## 1. Installations

Install necessary libraries, including `optuna` for hyperparameter optimization.

In [None]:
!pip install pandas numpy torch scikit-learn transformers sentencepiece evaluate rouge_score accelerate optuna

## 2. Imports

In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup 
import re
import os
import json
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import nltk 
import optuna 

from transformers import (
    MT5Tokenizer,
    MT5ForConditionalGeneration
)
from transformers.modeling_utils import PreTrainedModel 
from transformers.tokenization_utils_base import PreTrainedTokenizerBase 
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union 

import evaluate 
import traceback
import gc 

# nltk.word_tokenize needs the Punkt tokenizer data. nltk.data.find() raises
# LookupError when a resource is missing, so that is the exception to catch.
# Recent NLTK releases look up 'punkt_tab' rather than 'punkt'; requesting both
# keeps the notebook working across versions (on an older NLTK the extra
# download request simply fails quietly).
for _punkt_resource in ("punkt", "punkt_tab"):
    try:
        nltk.data.find(f"tokenizers/{_punkt_resource}")
    except LookupError:
        nltk.download(_punkt_resource, quiet=True)

## 3. Configuration 

In [None]:
# --- Model / tokenization ---
MODEL_NAME_OR_PATH = "google/mt5-small"
MAX_INPUT_LENGTH = 512    # max tokens kept from "prefix + article"
MAX_TARGET_LENGTH = 128   # max tokens kept from the summary / generated

# --- Optimisation ---
TEST_BATCH_SIZE = 8
ADAM_EPSILON = 1e-8
MAX_GRAD_NORM = 1.0       # gradient-clipping threshold

# Defaults for the final-training cell, which references these names.
# NOTE(review): the values below are placeholders chosen so the notebook does
# not fail with NameError; tune them for the real full-length run.
TRAIN_BATCH_SIZE = 4      # fallback when best params carry no batch size
VALID_BATCH_SIZE = 8      # fallback for the validation dataloader
NUM_TRAIN_EPOCHS = 3      # epochs for the final (post-search) training run

# --- Paths / data ---
OUTPUT_DIR_BASE = "./mt5_optuna_tuning_output"
# Global output directory that the final-training cell temporarily redirects.
OUTPUT_DIR = OUTPUT_DIR_BASE
COMBINED_TRAIN_DATA_PATH = "combined_train_data.csv"
COMBINED_TEST_DATA_PATH = "combined_test_data.csv"
DATA_DIR = "data"
MIN_SUMMARY_WORDS = 3     # summaries with fewer words are discarded
LOG_INTERVAL = 100        # print the training loss every N batches

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Optuna Config
N_OPTUNA_TRIALS = 20
EPOCHS_PER_OPTUNA_TRIAL = 2   # upper bound for epochs inside a single trial

## 4. Data Loading and Preprocessing 

In [None]:
def clean_text_series_for_metrics(text):
    """Normalise a text for token-overlap metrics.

    The value is converted to ``str``, lower-cased, whitespace runs are
    collapsed to a single space, punctuation/symbols are removed and the
    result is stripped.

    Unlike a plain ``[^\\w\\s]`` regex, Unicode combining marks (category
    ``M*``) are preserved. Python's ``\\w`` does not match them, so the regex
    silently deleted vowel signs and viramas from Hindi, Gujarati and Bengali
    text (e.g. "की" became "क"), corrupting the words being compared.

    Args:
        text: Any value; it is passed through ``str()`` first.

    Returns:
        The cleaned string (possibly empty).
    """
    # Local import keeps the function self-contained.
    import unicodedata

    collapsed = re.sub(r'\s+', ' ', str(text).lower())
    kept_chars = [
        ch for ch in collapsed
        # isalnum() + '_' mirrors regex \w; isspace() mirrors \s.
        if ch.isalnum() or ch == '_' or ch.isspace()
        or unicodedata.category(ch).startswith('M')
    ]
    return ''.join(kept_chars).strip()

def create_combined_datasets_if_not_exist():
    """Build the combined train/test CSVs from per-language files in DATA_DIR.

    Creates DATA_DIR when it is missing. If both combined files already exist
    nothing else happens. Otherwise each per-language CSV (which must contain
    'Article' and 'Summary' columns) is loaded, cleaned, tagged with a 'lang'
    column, concatenated, shuffled with a fixed seed and written to the
    combined path. Missing or unreadable files are reported and skipped.
    """
    if not os.path.exists(DATA_DIR):
        os.makedirs(DATA_DIR)
        print(f"Created directory: {DATA_DIR}. Please upload your CSV files here.")
    if os.path.exists(COMBINED_TRAIN_DATA_PATH) and os.path.exists(COMBINED_TEST_DATA_PATH):
        print(f"Found existing combined datasets: {COMBINED_TRAIN_DATA_PATH} and {COMBINED_TEST_DATA_PATH}")
        return
    print(f"Combined datasets not found. Attempting to create them from CSVs in '{DATA_DIR}/' directory...")
    train_files_langs = {
        "english_train.csv": "en", "hindi_train.csv": "hi",
        "gujrati_train.csv": "gu", "bengali_train.csv": "bn",
    }
    test_files_langs = {
        "english_test.csv": "en", "hindi_test.csv": "hi",
        "gujrati_test.csv": "gu", "bengali_test.csv": "bn",
    }

    def process_files(file_lang_map, output_path, dataset_type):
        """Combine the files of one split and write them to ``output_path``."""
        if os.path.exists(output_path):
            print(f"Found existing combined {dataset_type} dataset: {output_path}")
            return
        all_data = []
        for file_name, lang in file_lang_map.items():
            full_file_path = os.path.join(DATA_DIR, file_name)
            if not os.path.exists(full_file_path):
                print(f"Missing file: {full_file_path}")
                continue
            try:
                df = pd.read_csv(full_file_path, encoding='utf-8', on_bad_lines='skip')
                if 'Article' not in df.columns or 'Summary' not in df.columns:
                    print(f" 'Article' or 'Summary' column missing in {file_name}. Skipping.")
                    continue
                # Drop missing values BEFORE casting to str: astype(str) turns
                # NaN into the literal text "nan", which dropna can no longer
                # detect and which would end up as a training article.
                df.dropna(subset=['Article', 'Summary'], inplace=True)
                df['Article'] = df['Article'].astype(str).str.strip()
                df['Summary'] = df['Summary'].astype(str).str.strip()
                keep_mask = (
                    (df['Article'].str.len() > 0)
                    & (df['Summary'].str.len() > 0)
                    & (df['Summary'].apply(lambda x: len(x.split())) >= MIN_SUMMARY_WORDS)
                )
                # .copy() so the 'lang' assignment below writes to an
                # independent frame rather than a filtered slice.
                df = df[keep_mask].copy()
                if df.empty:
                    continue
                df['lang'] = lang
                all_data.append(df[['Article', 'Summary', 'lang']])
                print(f"Loaded and processed for {dataset_type}: {file_name}, kept {len(df)} rows.")
            except Exception as e:
                # Best effort: one unreadable file must not abort the others.
                print(f"Error processing {file_name}: {e}")
        if not all_data:
            print(f"No data loaded for {dataset_type} after filtering.")
            return
        combined_df = pd.concat(all_data, ignore_index=True)
        if combined_df.empty:
            print(f"Combined {dataset_type} dataset is empty. Cannot save.")
            return
        # Fixed seed keeps the shuffled order reproducible across runs.
        combined_df = combined_df.sample(frac=1, random_state=42).reset_index(drop=True)
        combined_df.to_csv(output_path, index=False)
        print(f"Combined {dataset_type} dataset created and saved to: {output_path} with {len(combined_df)} rows.")

    process_files(train_files_langs, COMBINED_TRAIN_DATA_PATH, "train")
    process_files(test_files_langs, COMBINED_TEST_DATA_PATH, "test")

def load_data(data_path):
    """Read a CSV into a DataFrame, building the combined datasets on demand.

    When ``data_path`` does not exist and it is one of the combined dataset
    paths, an attempt is made to create the combined files first.

    Returns:
        The loaded DataFrame, or ``None`` when the file is missing (and could
        not be created) or is empty.
    """
    if not os.path.exists(data_path):
        print(f"Data file not found: {data_path}")
        # Only the combined datasets can be generated automatically.
        if data_path not in (COMBINED_TRAIN_DATA_PATH, COMBINED_TEST_DATA_PATH):
            return None
        create_combined_datasets_if_not_exist()
        if not os.path.exists(data_path):
            return None

    try:
        frame = pd.read_csv(data_path)
    except pd.errors.EmptyDataError:
        print(f"Warning: Data file {data_path} is empty.")
        return None

    print(f"Successfully loaded data from {data_path}, shape: {frame.shape}")
    return frame

## 5. Summarization Dataset Class

In [None]:
def prefix_by_lang(lang):
    """Return the task prefix prepended to an article written in ``lang``."""
    return "summarize in {}: ".format(lang)

class SummarizationDataset(Dataset):
    """Torch dataset yielding tokenised (article, summary) pairs.

    The input frame must provide 'Article', 'Summary' and 'lang' columns.
    Rows with a missing/empty article or summary, or with a summary shorter
    than MIN_SUMMARY_WORDS words, are dropped at construction time.
    """

    def __init__(self, dataframe, tokenizer, max_input_len, max_target_len):
        """Filter ``dataframe`` and store the tokenizer and length limits.

        Args:
            dataframe: Source rows; it is not modified.
            tokenizer: Callable tokenizer used in ``__getitem__``.
            max_input_len: Truncation length for the prefixed article.
            max_target_len: Truncation length for the summary.
        """
        initial_len = len(dataframe)
        # Drop missing values BEFORE casting to str: astype(str) would turn
        # NaN into the literal text "nan" and let it through the filters.
        self.data = dataframe.dropna(subset=['Article', 'Summary']).copy()
        self.data['Article'] = self.data['Article'].astype(str).str.strip()
        self.data['Summary'] = self.data['Summary'].astype(str).str.strip()
        keep_mask = (
            (self.data['Article'].str.len() > 0)
            & (self.data['Summary'].str.len() > 0)
            & (self.data['Summary'].apply(lambda x: len(x.split())) >= MIN_SUMMARY_WORDS)
        )
        self.data = self.data[keep_mask].reset_index(drop=True)
        if len(self.data) < initial_len:
            print(f"SummarizationDataset: Initialized with {len(self.data)} rows after filtering {initial_len - len(self.data)} empty/short entries.")
        if len(self.data) == 0:
            print("CRITICAL: SummarizationDataset is empty after filtering. No data available.")
        self.tokenizer = tokenizer
        self.max_input_len = max_input_len
        self.max_target_len = max_target_len

    def __len__(self):
        """Number of rows that survived filtering."""
        return len(self.data)

    def __getitem__(self, index):
        """Tokenise one row.

        Returns:
            A dict with un-padded 1-D tensors 'input_ids', 'attention_mask'
            and 'labels'; padding is left to the collator.

        Raises:
            IndexError: If ``index`` is past the end of the dataset.
        """
        if index >= len(self.data):
            raise IndexError("Index out of bounds in SummarizationDataset")
        row = self.data.iloc[index]
        input_text = prefix_by_lang(row['lang']) + str(row['Article'])
        target_text = str(row['Summary'])
        input_enc = self.tokenizer(input_text, max_length=self.max_input_len, padding='do_not_pad', truncation=True, return_tensors="pt")
        target_enc = self.tokenizer(target_text, max_length=self.max_target_len, padding='do_not_pad', truncation=True, return_tensors="pt")
        # The tokenizer returns a batch of one; drop that leading dimension.
        labels = target_enc["input_ids"].squeeze(0).clone()
        if labels.ndim == 0:
            # Keep labels 1-D so the collator can take len() of them.
            labels = labels.unsqueeze(0)
        return {
            "input_ids": input_enc["input_ids"].squeeze(0),
            "attention_mask": input_enc["attention_mask"].squeeze(0),
            "labels": labels,
        }

## 6. Custom Data Collator

In [None]:
class CustomSummarizationCollator:
    """Pad a list of dataset items into one seq2seq training batch.

    Produces 'input_ids' and 'attention_mask', and, when the items carry
    labels, 'labels' (padding replaced by ``label_pad_token_id``) plus
    'decoder_input_ids'.
    """

    def __init__(self, tokenizer: "PreTrainedTokenizerBase", model: Optional["PreTrainedModel"] = None, label_pad_token_id: int = -100, pad_to_multiple_of: Optional[int] = None):
        """Store collaborators.

        Args:
            tokenizer: Provides ``pad()`` and ``pad_token_id``.
            model: Optional model; if it exposes
                ``prepare_decoder_input_ids_from_labels`` that method builds
                the decoder inputs.
            label_pad_token_id: Value written into padded label positions.
            pad_to_multiple_of: If set, padded lengths are rounded up to a
                multiple of this value.
        """
        self.tokenizer = tokenizer
        self.model = model
        self.label_pad_token_id = label_pad_token_id
        self.pad_to_multiple_of = pad_to_multiple_of

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Collate ``features`` (dicts of 1-D tensors) into a padded batch."""
        # Pad ids and mask in ONE call so the tokenizer pads the mask with 0.
        # Padding the mask under the "input_ids" key would fill it with
        # pad_token_id, which is only correct when that id happens to be 0.
        padded_inputs = self.tokenizer.pad(
            {
                "input_ids": [feature["input_ids"].tolist() for feature in features],
                "attention_mask": [feature["attention_mask"].tolist() for feature in features],
            },
            padding="longest",
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        batch = {
            "input_ids": padded_inputs["input_ids"],
            "attention_mask": padded_inputs["attention_mask"],
        }
        if "labels" in features[0] and features[0]["labels"] is not None:
            label_tensors = [feature["labels"] for feature in features]
            target_len = max(len(t) for t in label_tensors)
            if self.pad_to_multiple_of is not None:
                multiple = self.pad_to_multiple_of
                target_len = (target_len + multiple - 1) // multiple * multiple

            pad_id = self.tokenizer.pad_token_id
            padded_labels = torch.full(
                (len(label_tensors), target_len), pad_id, dtype=label_tensors[0].dtype
            )
            # Track padding by POSITION so a genuine pad-id token inside a
            # label sequence is not mistaken for padding.
            is_padding = torch.ones_like(padded_labels, dtype=torch.bool)
            for row, label_tensor in enumerate(label_tensors):
                length = len(label_tensor)
                padded_labels[row, :length] = label_tensor
                is_padding[row, :length] = False

            if self.model is not None and hasattr(self.model, "prepare_decoder_input_ids_from_labels"):
                batch["decoder_input_ids"] = self.model.prepare_decoder_input_ids_from_labels(labels=padded_labels.clone())
            else:
                # Shift right by one; position 0 holds the pad id as the
                # decoder start token.
                decoder_input_ids = padded_labels.new_full(padded_labels.shape, pad_id)
                decoder_input_ids[..., 1:] = padded_labels[..., :-1]
                batch["decoder_input_ids"] = decoder_input_ids

            batch["labels"] = padded_labels.masked_fill(is_padding, self.label_pad_token_id)
        return batch

## 7. Metrics Calculation Functions

In [None]:
# Load the ROUGE metric once at module level; evaluate_model_for_optuna calls
# rouge_metric.compute() on the decoded predictions and references.
rouge_metric = evaluate.load("rouge")

def calculate_jaccard(str1, str2):
    """Jaccard similarity between the token sets of two texts.

    Both texts are cleaned with ``clean_text_series_for_metrics`` and split
    with ``nltk.word_tokenize``. Two empty token sets count as identical
    (1.0); exactly one empty set gives 0.0.
    """
    first_tokens = set(nltk.word_tokenize(clean_text_series_for_metrics(str1)))
    second_tokens = set(nltk.word_tokenize(clean_text_series_for_metrics(str2)))

    if not (first_tokens or second_tokens):
        return 1.0
    if not (first_tokens and second_tokens):
        return 0.0

    shared = first_tokens & second_tokens
    combined = first_tokens | second_tokens
    return len(shared) / len(combined)

def calculate_cosine_tfidf(list_of_references, list_of_predictions):
    """Mean TF-IDF cosine similarity (0-100) over reference/prediction pairs.

    A single TF-IDF vocabulary is fitted on all cleaned references and
    predictions; each pair is then vectorised and compared. Pairs where either
    cleaned text is empty are skipped.

    Args:
        list_of_references: Reference texts.
        list_of_predictions: Predicted texts, aligned with the references.

    Returns:
        The average similarity times 100, or 0.0 when the inputs are empty,
        differ in length, the vocabulary cannot be fitted, or no pair could
        be scored.
    """
    if not list_of_references or not list_of_predictions or len(list_of_references) != len(list_of_predictions):
        return 0.0

    # Clean every text once and reuse the result for fitting and scoring.
    cleaned_all_refs = [clean_text_series_for_metrics(s) for s in list_of_references]
    cleaned_all_preds = [clean_text_series_for_metrics(s) for s in list_of_predictions]

    # Only texts that were non-blank before cleaning contribute to the corpus.
    cleaned_refs = [c for raw, c in zip(list_of_references, cleaned_all_refs) if str(raw).strip()]
    cleaned_preds = [c for raw, c in zip(list_of_predictions, cleaned_all_preds) if str(raw).strip()]
    if not cleaned_refs or not cleaned_preds:
        return 0.0

    # The cleaned text is whitespace-delimited, so tokenise on whitespace
    # (tokens of 2+ characters, like the default). The default pattern relies
    # on regex \w, which does not match combining marks and would split
    # Hindi/Gujarati/Bengali words wherever a vowel sign occurs.
    vectorizer = TfidfVectorizer(token_pattern=r"(?u)\S\S+")
    try:
        vectorizer.fit(cleaned_refs + cleaned_preds)
    except ValueError:
        print("Warning: TF-IDF Vectorizer could not be fitted (empty corpus after cleaning?).")
        return 0.0

    total_similarity = 0
    count = 0
    for ref, pred in zip(cleaned_all_refs, cleaned_all_preds):
        if not ref or not pred:
            continue
        try:
            tfidf_ref = vectorizer.transform([ref])
            tfidf_pred = vectorizer.transform([pred])
            total_similarity += cosine_similarity(tfidf_ref, tfidf_pred)[0, 0]
            count += 1
        except ValueError:
            print(f"Skipping cosine for pair due to empty vector: REF='{ref}', PRED='{pred}'")
            continue
    return (total_similarity / count) * 100 if count > 0 else 0.0

## 8. Training and Evaluation Loops

In [None]:

def evaluate_model_for_optuna(model, dataloader, tokenizer, device):
    """Generate summaries for every batch and return ROUGE-1 scaled by 100.

    The model is switched to eval mode, summaries are produced with 4-beam
    search, and predictions/labels are decoded and stripped. Empty strings are
    replaced by the placeholder "<empty>" before scoring. Returns 0.0 when the
    dataloader yielded nothing.
    """
    model.eval()
    predictions = []
    references = []

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            generated_ids = model.generate(
                input_ids=input_ids, attention_mask=attention_mask,
                max_length=MAX_TARGET_LENGTH, num_beams=4, early_stopping=True
            )
            predicted_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

            # -100 marks ignored label positions; swap it for the pad id so
            # the tokenizer can decode the sequence.
            decodable_labels = labels.masked_fill(labels == -100, tokenizer.pad_token_id)
            reference_texts = tokenizer.batch_decode(decodable_labels, skip_special_tokens=True)

            predictions += [text.strip() for text in predicted_texts]
            references += [text.strip() for text in reference_texts]

    rouge_preds = [text or "<empty>" for text in predictions]
    rouge_labels = [text or "<empty>" for text in references]

    if not (rouge_preds and rouge_labels):
        return 0.0

    rouge_results = rouge_metric.compute(predictions=rouge_preds, references=rouge_labels, use_stemmer=True)
    return rouge_results.get('rouge1', 0.0) * 100

def objective(trial: optuna.trial.Trial):
    """Optuna objective: fine-tune a fresh model briefly and score it.

    Samples learning rate, epoch count, weight decay, train batch size and
    warmup ratio, trains a newly loaded model on a fixed 90/10 split of
    ``df_train_val_global`` and returns the validation ROUGE-1 (x100).

    Returns:
        The validation ROUGE-1, or 0.0 when the datasets/dataloaders are
        empty or the loss becomes invalid (None/NaN/inf).
    """
    gc.collect()
    # Compare the device type so an indexed device such as "cuda:0" matches.
    if DEVICE.type == "cuda":
        torch.cuda.empty_cache()

    print(f"\nStarting Optuna Trial: {trial.number}")
    lr = trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True)
    trial_epochs = trial.suggest_int("num_train_epochs", 1, EPOCHS_PER_OPTUNA_TRIAL)  # Short epochs for trials
    trial_weight_decay = trial.suggest_float("weight_decay", 1e-5, 1e-1, log=True)
    trial_train_batch_size = trial.suggest_categorical("train_batch_size", [2, 4, 8])
    trial_warmup_ratio = trial.suggest_float("warmup_ratio", 0.0, 0.2)

    print(f"  Trial {trial.number} Params: LR={lr:.2e}, Epochs={trial_epochs}, WD={trial_weight_decay:.2e}, BS={trial_train_batch_size}")

    current_tokenizer = MT5Tokenizer.from_pretrained(MODEL_NAME_OR_PATH)
    current_model = MT5ForConditionalGeneration.from_pretrained(MODEL_NAME_OR_PATH)
    current_model.to(DEVICE)

    # Fixed seed: every trial is trained and validated on the same split.
    current_train_df, current_val_df = train_test_split(df_train_val_global, test_size=0.1, random_state=42)
    current_train_dataset = SummarizationDataset(current_train_df, current_tokenizer, MAX_INPUT_LENGTH, MAX_TARGET_LENGTH)
    current_val_dataset = SummarizationDataset(current_val_df, current_tokenizer, MAX_INPUT_LENGTH, MAX_TARGET_LENGTH)

    if len(current_train_dataset) == 0 or len(current_val_dataset) == 0:
        print(f"Warning: Trial {trial.number} has empty train or val dataset after init. Skipping.")
        return 0.0

    current_collator = CustomSummarizationCollator(tokenizer=current_tokenizer, model=current_model, label_pad_token_id=-100)
    current_train_dataloader = DataLoader(current_train_dataset, batch_size=trial_train_batch_size, collate_fn=current_collator, shuffle=True, num_workers=2, pin_memory=True)
    current_val_dataloader = DataLoader(current_val_dataset, batch_size=trial_train_batch_size, collate_fn=current_collator, num_workers=2, pin_memory=True)

    if len(current_train_dataloader) == 0 or len(current_val_dataloader) == 0:
        print(f"Warning: Trial {trial.number} has empty train or val dataloader. Skipping.")
        return 0.0

    optimizer = AdamW(current_model.parameters(), lr=lr, eps=ADAM_EPSILON, weight_decay=trial_weight_decay)
    total_steps = len(current_train_dataloader) * trial_epochs
    # Use the sampled ratio: with a hard-coded value the "warmup_ratio"
    # parameter had no effect and Optuna searched a meaningless dimension.
    num_warmup = int(trial_warmup_ratio * total_steps)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup, num_training_steps=total_steps)

    for epoch in range(trial_epochs):
        current_model.train()
        print(f"  Trial {trial.number}, Epoch {epoch+1}/{trial_epochs} Training...")
        for batch_idx, batch in enumerate(current_train_dataloader):
            optimizer.zero_grad()
            input_ids = batch['input_ids'].to(DEVICE)
            attention_mask = batch['attention_mask'].to(DEVICE)
            labels = batch['labels'].to(DEVICE)
            decoder_input_ids = batch.get('decoder_input_ids', None)
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids.to(DEVICE)

            # A batch whose labels are all ignored (-100) has no loss signal.
            if (labels != -100).sum() == 0:
                continue

            outputs = current_model(input_ids=input_ids, attention_mask=attention_mask, labels=labels, decoder_input_ids=decoder_input_ids)
            loss = outputs.loss
            if loss is None or torch.isnan(loss) or torch.isinf(loss):
                # Diverged: give the trial the worst score and stop it.
                print(f"    Trial {trial.number} WARNING: Invalid loss at step {batch_idx}. Skipping.")
                return 0.0
            loss.backward()
            torch.nn.utils.clip_grad_norm_(current_model.parameters(), MAX_GRAD_NORM)
            optimizer.step()
            scheduler.step()

            if batch_idx % LOG_INTERVAL == 0:
                print(f"    Trial {trial.number}, Ep {epoch+1}, Batch {batch_idx+1}/{len(current_train_dataloader)}, Loss: {loss.item():.4f}")

    print(f"  Trial {trial.number} Evaluating...")
    validation_rouge1 = evaluate_model_for_optuna(current_model, current_val_dataloader, current_tokenizer, DEVICE)
    print(f"  Trial {trial.number} Validation ROUGE-1: {validation_rouge1:.4f}")

    # Release per-trial objects so the next trial starts with free memory.
    del current_model, current_tokenizer, current_collator, current_train_dataset, current_val_dataset, current_train_dataloader, current_val_dataloader, optimizer, scheduler
    gc.collect()
    if DEVICE.type == "cuda":
        torch.cuda.empty_cache()

    return validation_rouge1

## 9. Prepare Global Data

In [None]:
# Build the combined CSVs if needed, then load both splits into module-level
# frames that the Optuna objective and the final-training cell read.
create_combined_datasets_if_not_exist()
df_train_val_global = load_data(COMBINED_TRAIN_DATA_PATH)
df_test_global = load_data(COMBINED_TEST_DATA_PATH)

# Only a message is printed here; execution is not halted. The later cells
# re-check df_train_val_global before using it.
if df_train_val_global is None or df_train_val_global.empty:
    print("STOPPING: Global training data for Optuna could not be loaded or is empty.")

## 10. Run Optuna Hyperparameter Search

In [None]:
# Run an in-memory Optuna study and persist the best hyperparameters as JSON.
if 'df_train_val_global' in globals() and df_train_val_global is not None and not df_train_val_global.empty:
    study = optuna.create_study(direction="maximize", study_name="mt5_summarization_tuning")

    try:
        study.optimize(objective, n_trials=N_OPTUNA_TRIALS, timeout=3600*4)  # Example timeout of 4 hours
    except Exception as e:
        print(f"Error during Optuna optimization: {e}")
        traceback.print_exc()

    print("\nOptuna Hyperparameter Search Finished.")
    # study.best_trial raises ValueError when no trial finished successfully
    # (e.g. the very first trial crashed), so check before accessing it.
    completed_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]
    if completed_trials:
        print("Best trial:")
        best_trial = study.best_trial
        print(f"  Value (Validation ROUGE-1): {best_trial.value:.4f}")
        print("  Params: ")
        for key, value in best_trial.params.items():
            print(f"    {key}: {value}")

        # Save best params
        os.makedirs(OUTPUT_DIR_BASE, exist_ok=True)
        best_params_file = os.path.join(OUTPUT_DIR_BASE, "best_hyperparameters.json")
        with open(best_params_file, "w") as f:
            json.dump(best_trial.params, f, indent=4)
        print(f"Best hyperparameters saved to {best_params_file}")
    else:
        print("No trials completed in Optuna study.")
else:
    print("Global training data not loaded. Skipping Optuna hyperparameter search.")

## 11. Train & Evaluate Final Model

In [None]:
def run_final_training_and_evaluation(best_params, full_num_epochs):
    """Train a fresh model with ``best_params`` and evaluate it on the test set.

    Args:
        best_params: Hyperparameter dict; 'learning_rate' and 'weight_decay'
            are required, 'train_batch_size' and 'warmup_ratio' are optional.
        full_num_epochs: Number of epochs for the final training run.

    Side effects:
        Creates ``<OUTPUT_DIR_BASE>/final_model_with_best_params`` and points
        the module-level ``OUTPUT_DIR`` at it while training and evaluation
        run; the previous value is restored afterwards, even on error.

    NOTE(review): ``train_model`` and ``evaluate_model`` are not defined in the
    visible part of the notebook. They are presumably expected to write the
    checkpoint/results under ``OUTPUT_DIR`` - confirm before relying on this.
    """
    global OUTPUT_DIR

    print("\n--- Starting Final Training with Best Hyperparameters ---")
    print(f"Best Params: {best_params}")
    print(f"Training for {full_num_epochs} epochs.")

    # Check the data first so no model is downloaded/loaded needlessly.
    if 'df_train_val_global' not in globals() or df_train_val_global is None or df_train_val_global.empty:
        print("Global training data not available for final training. Exiting.")
        return

    final_output_dir = os.path.join(OUTPUT_DIR_BASE, "final_model_with_best_params")
    os.makedirs(final_output_dir, exist_ok=True)

    final_tokenizer = MT5Tokenizer.from_pretrained(MODEL_NAME_OR_PATH)
    final_model = MT5ForConditionalGeneration.from_pretrained(MODEL_NAME_OR_PATH)
    final_model.to(DEVICE)

    current_train_df, current_val_df = train_test_split(df_train_val_global, test_size=0.1, random_state=42)
    final_train_dataset = SummarizationDataset(current_train_df, final_tokenizer, MAX_INPUT_LENGTH, MAX_TARGET_LENGTH)
    final_val_dataset = SummarizationDataset(current_val_df, final_tokenizer, MAX_INPUT_LENGTH, MAX_TARGET_LENGTH)

    if len(final_train_dataset) == 0 or len(final_val_dataset) == 0:
        print("Final train or val dataset is empty. Exiting final training.")
        return

    # Resolve the fallbacks lazily: dict.get(key, DEFAULT) evaluates DEFAULT
    # even when the key is present, which fails if the constant is undefined.
    train_batch_size = best_params.get('train_batch_size')
    val_batch_size = train_batch_size
    if train_batch_size is None:
        train_batch_size = TRAIN_BATCH_SIZE
        val_batch_size = VALID_BATCH_SIZE

    final_collator = CustomSummarizationCollator(tokenizer=final_tokenizer, model=final_model, label_pad_token_id=-100)
    final_train_dataloader = DataLoader(final_train_dataset, batch_size=train_batch_size, collate_fn=final_collator, shuffle=True, num_workers=2, pin_memory=True)
    final_val_dataloader = DataLoader(final_val_dataset, batch_size=val_batch_size, collate_fn=final_collator, num_workers=2, pin_memory=True)

    if len(final_train_dataloader) == 0 or len(final_val_dataloader) == 0:
        print("Final train or val dataloader is empty. Exiting final training.")
        return

    final_optimizer = AdamW(final_model.parameters(),
                            lr=best_params['learning_rate'],
                            eps=ADAM_EPSILON,
                            weight_decay=best_params['weight_decay'])
    final_total_steps = len(final_train_dataloader) * full_num_epochs
    # Use the warmup ratio stored with the best params when present.
    final_warmup_ratio = best_params.get('warmup_ratio', 0.1)
    final_num_warmup = int(final_warmup_ratio * final_total_steps)
    final_scheduler = get_linear_schedule_with_warmup(final_optimizer, num_warmup_steps=final_num_warmup, num_training_steps=final_total_steps)

    # OUTPUT_DIR may not have been defined yet; fall back to the base dir.
    original_output_dir = globals().get("OUTPUT_DIR", OUTPUT_DIR_BASE)
    OUTPUT_DIR = final_output_dir
    try:
        print(f"Training final model with best params, outputting to: {final_output_dir}")
        train_model(final_model, final_train_dataloader, final_val_dataloader, final_optimizer, final_scheduler, full_num_epochs, DEVICE)

        # Evaluate on Test Set
        if 'df_test_global' not in globals() or df_test_global is None or df_test_global.empty:
            print("Global test data not available or empty. Skipping final test set evaluation.")
            return

        print("\n--- Evaluating Final Model on Test Set ---")
        model_for_test_eval = MT5ForConditionalGeneration.from_pretrained(final_output_dir)
        model_for_test_eval.to(DEVICE)
        tokenizer_for_test_eval = MT5Tokenizer.from_pretrained(final_output_dir)

        final_test_dataset = SummarizationDataset(df_test_global, tokenizer_for_test_eval, MAX_INPUT_LENGTH, MAX_TARGET_LENGTH)
        if len(final_test_dataset) == 0:
            print("Final test dataset is empty.")
            return

        final_test_collator = CustomSummarizationCollator(tokenizer=tokenizer_for_test_eval, model=model_for_test_eval, label_pad_token_id=-100)
        final_test_dataloader = DataLoader(final_test_dataset, batch_size=TEST_BATCH_SIZE, collate_fn=final_test_collator)
        if len(final_test_dataloader) == 0:
            print("Final test dataloader is empty.")
            return

        evaluate_model(model_for_test_eval, final_test_dataloader, tokenizer_for_test_eval, DEVICE, is_test_set=True)
    finally:
        # Always restore the global, including on early return or exception.
        OUTPUT_DIR = original_output_dir

# --- Main Execution for Optuna ---
# Reuse saved best hyperparameters when available; otherwise run (or resume) a
# SQLite-backed Optuna study, save its best parameters and start final training.
if 'df_train_val_global' in globals() and df_train_val_global is not None and not df_train_val_global.empty:
    study_db_path = f"sqlite:///{os.path.join(OUTPUT_DIR_BASE, 'optuna_study.db')}"
    best_params_from_file = None
    best_params_file_path = os.path.join(OUTPUT_DIR_BASE, "best_hyperparameters.json")

    if os.path.exists(best_params_file_path):
        print(f"Found existing best hyperparameters file: {best_params_file_path}")
        with open(best_params_file_path, 'r') as f:
            best_params_from_file = json.load(f)

    if best_params_from_file:
        print("Using previously found best hyperparameters for final training.")
        run_final_training_and_evaluation(best_params_from_file, full_num_epochs=NUM_TRAIN_EPOCHS)  # Use original NUM_TRAIN_EPOCHS for full run
    else:
        print(f"No existing best hyperparameters found. Running Optuna study: {study_db_path}")
        # The SQLite file lives in OUTPUT_DIR_BASE, which must exist first.
        os.makedirs(OUTPUT_DIR_BASE, exist_ok=True)
        study = optuna.create_study(direction="maximize",
                                    study_name="mt5_summarization_tuning",
                                    storage=study_db_path,
                                    load_if_exists=True)
        try:
            study.optimize(objective, n_trials=N_OPTUNA_TRIALS, timeout=3600*6)
        except Exception as e:
            print(f"Error during Optuna optimization: {e}")
            traceback.print_exc()

        print("\nOptuna Hyperparameter Search Finished.")
        # Having trials is not enough: study.best_trial raises ValueError
        # unless at least one of them reached the COMPLETE state.
        completed_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]
        if completed_trials:
            print("Best trial:")
            best_trial = study.best_trial
            print(f"  Value (Validation ROUGE-1): {best_trial.value:.4f}")
            print("  Params: ")
            for key, value in best_trial.params.items():
                print(f"    {key}: {value}")
            with open(best_params_file_path, "w") as f:
                json.dump(best_trial.params, f, indent=4)
            print(f"Best hyperparameters saved to {best_params_file_path}")
            # Run final training with these best params
            run_final_training_and_evaluation(best_trial.params, full_num_epochs=NUM_TRAIN_EPOCHS)
        else:
            print("No trials completed in Optuna study.")
else:
    print("Global training data not loaded. Skipping Optuna hyperparameter search and final training.")



## 12. Example Prediction

In [None]:
# Load the final fine-tuned model (if one was saved) and summarise two samples.
final_model_dir = os.path.join(OUTPUT_DIR_BASE, "final_model_with_best_params")
# Depending on the transformers version / save options the weights are stored
# as pytorch_model.bin or model.safetensors, so accept either file.
_weight_file_names = ("pytorch_model.bin", "model.safetensors")
if any(os.path.exists(os.path.join(final_model_dir, name)) for name in _weight_file_names):
    print("\n--- Example Prediction using the BEST fine-tuned model ---")
    loaded_model = MT5ForConditionalGeneration.from_pretrained(final_model_dir)
    loaded_model.to(DEVICE)
    loaded_model.eval()
    loaded_tokenizer = MT5Tokenizer.from_pretrained(final_model_dir)

    sample_text_en = "Several research groups have been working on developing new types of batteries that could store more energy and charge faster. One promising approach involves using solid-state electrolytes instead of liquid ones, which could improve safety and energy density. These advancements are crucial for the future of electric vehicles and portable electronics."
    sample_text_hi = "कई शोध समूह नई प्रकार की बैटरियों को विकसित करने पर काम कर रहे हैं जो अधिक ऊर्जा संग्रहीत कर सकें और तेजी से चार्ज हो सकें। एक आशाजनक दृष्टिकोण में तरल इलेक्ट्रोलाइट्स के बजाय ठोस-अवस्था वाले इलेक्ट्रोलाइट्स का उपयोग करना शामिल है, जिससे सुरक्षा और ऊर्जा घनत्व में सुधार हो सकता है। ये प्रगति इलेक्ट्रिक वाहनों और पोर्टेबल इलेक्ट्रॉनिक्स के भविष्य के लिए महत्वपूर्ण हैं।"

    # Build the model inputs with the same helper the training dataset uses,
    # so the prompt prefix cannot drift from the one seen during fine-tuning.
    sample_article_en = prefix_by_lang("en") + sample_text_en
    sample_article_hi = prefix_by_lang("hi") + sample_text_hi

    samples = [
        ("en", sample_text_en, sample_article_en),
        ("hi", sample_text_hi, sample_article_hi),
    ]
    for lang_code, raw_text, sample_article in samples:
        print(f"\nInput Article ({lang_code}): {raw_text}")
        inputs = loaded_tokenizer(sample_article, return_tensors="pt", max_length=MAX_INPUT_LENGTH, truncation=True, padding=True).to(DEVICE)

        with torch.no_grad():
            summary_ids = loaded_model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                num_beams=4,
                max_length=MAX_TARGET_LENGTH,
                early_stopping=True
            )
        generated_summary = loaded_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
        print(f"Generated Summary ({lang_code}): {generated_summary}")
else:
    print(f"Fine-tuned model not found in {final_model_dir}. Skipping example prediction.")