# In [5]:  (Jupyter cell prompt left over from the notebook export)
# Standard Libraries
import os
import sys
import evaluate
import argparse
import random
import yaml
from tqdm.notebook import tqdm # Use notebook-friendly tqdm
import warnings
import json
from sklearn.model_selection import train_test_split # For validation split
import gc # Garbage collector

# Data Handling
import numpy as np
import pandas as pd
import h5py

# Deep Learning - PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# Transformers Library
from transformers import (
    AutoTokenizer,
    T5ForConditionalGeneration,
    MT5ForConditionalGeneration,
    Seq2SeqTrainingArguments, # Using Seq2Seq specific args
    Seq2SeqTrainer,
    set_seed as transformers_set_seed # Alias to avoid conflict if needed
)
from safetensors.torch import load_file # Import for loading safetensors

# Evaluation Metric
from sacrebleu.metrics import BLEU, CHRF # Added CHRF
# Using load_metric for simplicity if sacrebleu gives issues with Trainer
# from datasets import load_metric

# Plotting for EDA
import matplotlib.pyplot as plt
import seaborn as sns

# --- Configuration ---
# Silence every Python warning for the rest of the session.
warnings.filterwarnings('ignore')

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Base locations: YAML configs, competition data, and the two output roots.
CONFIG_DIR = "D:/configs"
DATA_DIR = "D:/saudi-signfor-all-competition"
MODEL_OUTPUT_DIR = "output_finetuned"  # Fine-tuned model checkpoints
EVAL_RESULTS_DIR = "results_finetuned"  # Evaluation outputs

# Make sure both output roots exist before anything tries to write into them.
for _output_root in (MODEL_OUTPUT_DIR, EVAL_RESULTS_DIR):
    os.makedirs(_output_root, exist_ok=True)

# ===========================================
# Cell 2: Configuration Loading Utilities
# ===========================================
class Config:
    """Attribute-style view over a (possibly nested) configuration dictionary.

    Every key of the dictionary becomes an attribute. Values that are
    themselves dictionaries are wrapped recursively in ``Config``; all other
    values (including lists) are stored unchanged.
    """

    def __init__(self, config_dict):
        for name, raw_value in config_dict.items():
            wrapped = Config(raw_value) if isinstance(raw_value, dict) else raw_value
            setattr(self, name, wrapped)

    def __repr__(self):
        rendered = ', '.join(f"{name}={val!r}" for name, val in self.__dict__.items())
        return f"{self.__class__.__name__}({rendered})"

    def get(self, attribute_path, default=None):
        """Resolve a dotted attribute path such as 'ModelArguments.feature_dim'.

        Walks one attribute per path segment and returns ``default`` as soon
        as any segment is missing.
        """
        node = self
        for segment in attribute_path.split('.'):
            try:
                node = getattr(node, segment)
            except AttributeError:
                return default
        return node

def load_config_from_file(config_path):
    """Read a YAML file and return its contents wrapped in a ``Config``.

    Raises:
        FileNotFoundError: if ``config_path`` does not exist.
        yaml.YAMLError: re-raised (after printing a message) when the file
            cannot be parsed.

    An empty file yields an empty ``Config`` and prints a warning.
    """
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Config file not found: {config_path}")

    with open(config_path, "r", encoding='utf-8') as handle:
        try:
            parsed = yaml.safe_load(handle)
        except yaml.YAMLError as err:
            print(f"Error parsing YAML {config_path}: {err}")
            raise

    if parsed is not None:
        return Config(parsed)

    # safe_load returns None for a file with no YAML content.
    print(f"Warning: Config file {config_path} empty.")
    return Config({})

# --- Load Configurations ---
# Modify config_train.yaml to include:
# ModelArguments:
#   model_name_or_path: "D:/YTASL - Checkpoint/t5-base" # <--- PATH TO STARTING CHECKPOINT
# DatasetArguments:
#   validation_split_ratio: 0.1 # Or your desired ratio (e.g., 0.15)

train_config_path = os.path.join(CONFIG_DIR, 'config_train.yaml')
eval_config_path = os.path.join(CONFIG_DIR, 'config_eval.yaml') # Used only for final test eval settings


def _ensure_section(cfg, section_name):
    """Return ``cfg.<section_name>`` as a Config, creating an empty one if needed.

    ``Config.get`` tolerates missing sections on reads, but plain attribute
    assignment (``cfg.Section.key = value``) fails when the YAML file has no
    such section or leaves it empty. This helper makes the write side equally
    tolerant.
    """
    section = getattr(cfg, section_name, None)
    if not isinstance(section, Config):
        section = Config({})
        setattr(cfg, section_name, section)
    return section


print(f"Loading training config from: {train_config_path}")
train_cfg = load_config_from_file(train_config_path)
# Add default validation split ratio if not in config
if not train_cfg.get('DatasetArguments.validation_split_ratio'):
    print("Validation split ratio not found in config, using default 0.1 (10%)")
    _ensure_section(train_cfg, 'DatasetArguments').validation_split_ratio = 0.1
# Ensure starting model path exists in config
if not train_cfg.get('ModelArguments.model_name_or_path'):
    print("WARNING: 'model_name_or_path' not specified in train_cfg.ModelArguments. Defaulting to google-t5/t5-base.")
    _ensure_section(train_cfg, 'ModelArguments').model_name_or_path = 'google-t5/t5-base'


print("Training Config Loaded:")
print(train_cfg)
print("-" * 30)

print(f"Loading evaluation config from: {eval_config_path}")
eval_cfg = load_config_from_file(eval_config_path)

# --- Adjust Paths ---
# Eval paths (for final test run). Relative paths from the config are
# resolved against DATA_DIR; a truthy value implies the section exists.
test_data_path = eval_cfg.get('DatasetArguments.test_dataset_path')
if test_data_path:
    eval_cfg.DatasetArguments.test_dataset_path = os.path.join(DATA_DIR, test_data_path)
test_labels_path_final = eval_cfg.get('DatasetArguments.test_labels_dataset_path') # Path to *actual* test labels
if test_labels_path_final:
    eval_cfg.DatasetArguments.test_labels_dataset_path = os.path.join(DATA_DIR, test_labels_path_final)
eval_results_save_rel_path = eval_cfg.get('EvaluationArguments.results_save_path', 'final_test_results') # Subdir name
# The section may be absent from the YAML, so create it before writing to it.
_ensure_section(eval_cfg, 'EvaluationArguments').results_save_path = os.path.join(
    EVAL_RESULTS_DIR, eval_results_save_rel_path)
os.makedirs(eval_cfg.EvaluationArguments.results_save_path, exist_ok=True)

# Train paths
train_data_path = train_cfg.get('DatasetArguments.train_dataset_path')
if train_data_path:
    train_cfg.DatasetArguments.train_dataset_path = os.path.join(DATA_DIR, train_data_path)
train_labels_rel_path = train_cfg.get('DatasetArguments.train_labels_dataset_path')
if train_labels_rel_path:
    train_cfg.DatasetArguments.train_labels_dataset_path = os.path.join(DATA_DIR, train_labels_rel_path)
# Validation paths are derived from train split
train_output_rel_folder = train_cfg.get('TrainingArguments.output_folder', 'finetune_run') # Base output dir name
# Same guard as above: TrainingArguments may be missing from the YAML.
_ensure_section(train_cfg, 'TrainingArguments').output_folder = os.path.join(
    MODEL_OUTPUT_DIR, train_output_rel_folder)
start_model_path = train_cfg.get('ModelArguments.model_name_or_path') # Get the specified starting path

print("Evaluation Config Loaded (Paths adjusted):")
print(eval_cfg)
print("-" * 30)
print("Training Config Paths Adjusted:")
print(f"  Train/Val Data H5: {train_cfg.get('DatasetArguments.train_dataset_path', 'N/A')}")
print(f"  Train/Val Labels CSV: {train_cfg.get('DatasetArguments.train_labels_dataset_path', 'N/A')}")
print(f"  Validation Split Ratio: {train_cfg.get('DatasetArguments.validation_split_ratio', 'N/A')}")
print(f"  Starting Fine-tuning From: {start_model_path}")
print(f"  Base Fine-tuning Output Folder: {train_cfg.get('TrainingArguments.output_folder', 'N/A')}")
print("-" * 30)

# ===========================================
# Cell 3: EDA (Optional)
# ===========================================
# Skipped for brevity

# ===========================================
# Cell 4: Dataset Class (Modified for pre-split data)
# ===========================================
class VideoDataset(Dataset):
    """Feature sequences read from an HDF5 file, paired with target sentences.

    Each sample is identified by an HDF5 key. The constructor only records
    which keys are usable; the feature arrays themselves are read lazily in
    ``__getitem__``, which opens the file on every access.

    Args:
        h5_file_path: Path of the HDF5 file holding the features.
        data_keys_to_use: Keys to consider for this dataset (already split).
        id_to_sentence_map: Mapping from key to target sentence.
        tokenizer: Tokenizer used to turn sentences into label ids
            (train/val only).
        h5_lengths_map: Mapping from key to sequence length, as produced by
            ``preload_h5_info``.
        max_seq_length: Samples longer than this are skipped; shorter ones
            are zero-padded up to it.
        max_label_length: Label sequences are padded/truncated to this length.
        is_test_set: When True, labels are returned as raw strings and
            samples without a sentence are kept.
        feature_dim: Expected size of the second feature axis.
    """

    def __init__(self, h5_file_path, data_keys_to_use, id_to_sentence_map,
                 tokenizer, h5_lengths_map, max_seq_length=600,
                 max_label_length=512, is_test_set=False, feature_dim=208):
        self.h5_file_path = h5_file_path
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        self.max_label_length = max_label_length
        self.is_test_set = is_test_set
        self.feature_dim = feature_dim
        self.id_to_sentence_map = id_to_sentence_map # Store sentence map

        # One (key, sentence, length) tuple per usable sample.
        self.data_info = []
        print(f"Initializing dataset ({'Test' if is_test_set else 'Train/Val'}). Processing {len(data_keys_to_use)} provided keys...")

        skipped_count = 0
        for key in tqdm(data_keys_to_use, desc="Preparing dataset samples"):
            original_length = h5_lengths_map.get(key)
            sentence = self.id_to_sentence_map.get(key)

            # Train/val samples need a real sentence; test samples may lack one.
            if not self.is_test_set and not self._has_usable_label(sentence):
                skipped_count += 1
                continue

            # Skip unknown, empty, or over-long sequences.
            if not original_length or original_length > self.max_seq_length:
                skipped_count += 1
                continue

            self.data_info.append((key, sentence, original_length))

        print(f"Finished preparing dataset: {len(self.data_info)} samples included, {skipped_count} samples skipped.")
        if not self.data_info: print("WARNING: Dataset is empty after processing.")

    @staticmethod
    def _has_usable_label(sentence):
        """Return True when ``sentence`` can serve as a training target.

        Rejects ``None``, blank strings, and float NaN. NaN is what pandas
        produces for empty CSV cells; without this check it would be
        tokenized as the literal text "nan".
        """
        if sentence is None:
            return False
        if isinstance(sentence, float) and sentence != sentence:  # NaN != NaN
            return False
        if isinstance(sentence, str) and not sentence.strip():
            return False
        return True

    def __len__(self):
        """Number of usable samples."""
        return len(self.data_info)

    def __getitem__(self, idx):
        """Load, pad, and label one sample.

        Returns a dict with ``features`` (padded to ``max_seq_length`` rows),
        ``attention_mask`` (1 for real rows, 0 for padding), ``labels``
        (token ids with padding replaced by -100 for train/val, the raw
        sentence string for the test set) and ``key``. Returns ``None`` when
        the sample cannot be loaded, so the collate function can drop it.
        """
        key, sentence, _ = self.data_info[idx]
        try:
            with h5py.File(self.h5_file_path, 'r') as f:
                item = f[key]
                if isinstance(item, h5py.Group) and item.keys():
                    # A group stores its features in its first member.
                    data_to_load = item[list(item.keys())[0]]
                elif isinstance(item, h5py.Dataset):
                    data_to_load = item
                else:
                    raise ValueError(f"No dataset found for key {key}")
                features_np = data_to_load[()]
            if features_np.ndim != 2:
                raise ValueError(f"Expected 2-D features for key {key}, got shape {features_np.shape}")
            if features_np.shape[1] != self.feature_dim:
                print(f"Warning: Feature dim mismatch key {key}. Expected {self.feature_dim}, Got {features_np.shape[1]}. Using actual.")
            features = torch.tensor(features_np, dtype=torch.float32)
        except Exception as e:
            print(f"Error loading key {key}: {e}")
            return None # Signal error to collate_fn

        # Derive the length from the loaded tensor (not the cached length) so
        # the padding and the mask can never disagree with the data.
        features = features[:self.max_seq_length]
        seq_len = features.shape[0]
        num_padding = self.max_seq_length - seq_len

        padded_features = F.pad(features, (0, 0, 0, num_padding), value=0.0)
        attention_mask = torch.zeros(self.max_seq_length, dtype=torch.long)
        attention_mask[:seq_len] = 1

        if self.is_test_set:
            labels_output = sentence if sentence is not None else "" # Return string for test set
        else: # Train/Val -> Tokenize
            tokenized_output = self.tokenizer(
                str(sentence), truncation=True, padding="max_length",
                max_length=self.max_label_length, return_tensors="pt")
            input_ids = tokenized_output["input_ids"].squeeze(0)
            # -100 marks padding positions in the labels.
            input_ids[input_ids == self.tokenizer.pad_token_id] = -100
            labels_output = input_ids

        return {"features": padded_features, "attention_mask": attention_mask, "labels": labels_output, "key": key}

    @staticmethod
    def preload_h5_info(file_path):
        """Scan an HDF5 file and collect the non-empty keys and their lengths.

        Returns:
            ``(keys, lengths)`` where ``keys`` lists every top-level key whose
            sequence length is positive and ``lengths`` maps each of those
            keys to that length. For a group the length is taken from its
            first member. Returns ``([], {})`` when the path is missing or
            the file cannot be opened.
        """
        keys = []
        lengths = {}
        if not file_path or not os.path.exists(file_path):
            print(f"Error: H5 path {file_path} invalid.")
            return [], {}
        try:
            with h5py.File(file_path, 'r') as f:
                print(f"Scanning HDF5 {file_path}...")
                valid_keys = list(f.keys())
                for key in tqdm(valid_keys, desc="Preloading HDF5 info"):
                    try:
                        item = f[key]
                        if isinstance(item, h5py.Group):
                            members = list(item.keys())
                            seq_len = item[members[0]].shape[0] if members else 0
                        elif isinstance(item, h5py.Dataset):
                            seq_len = item.shape[0]
                        else:
                            continue
                        if seq_len > 0:
                            keys.append(key)
                            lengths[key] = seq_len
                    except Exception as e:
                        # One unreadable entry should not abort the scan.
                        print(f"Err key {key}: {e}")
            print(f"Found {len(keys)} non-empty keys in HDF5.")
            return keys, lengths
        except Exception as e:
            print(f"Err opening HDF5 {file_path}: {e}")
            return [], {}

# ===========================================
# Cell 5: Model Definition
# ===========================================
class SNLTraslationModel(nn.Module):
    """T5/mT5 model fed with projected feature sequences instead of token ids.

    A small projection (Linear -> Dropout -> GELU) maps each input feature
    vector to the base model's hidden size; the result is passed to the base
    model as ``inputs_embeds``.

    Args:
        feature_dim: Size of each input feature vector.
        model_name_or_path: Checkpoint to load. The mT5 class is used when
            the string contains "mt5" (case-insensitive), the T5 class
            otherwise.
        dropout_rate: Dropout probability used in the projection.
    """

    def __init__(self, feature_dim, model_name_or_path, dropout_rate=0.1):
        super().__init__()
        print("Initializing SNLTraslationModel:")
        print(f"  Loading Base Model From: {model_name_or_path}")
        print(f"  Feature Dim (Input): {feature_dim}")

        # Load base model (T5 or mT5 based on config/path)
        is_mt5 = 'mt5' in model_name_or_path.lower()
        base_cls = MT5ForConditionalGeneration if is_mt5 else T5ForConditionalGeneration
        try:
            self.model = base_cls.from_pretrained(model_name_or_path)
            print(f"  Base {'MT5' if is_mt5 else 'T5'} model loaded successfully.")
        except Exception as e:
            print(f"ERROR loading base model from {model_name_or_path}: {e}")
            raise

        self.model_hidden_dim = self.model.config.d_model
        print(f"  Model Hidden Dim (d_model): {self.model_hidden_dim}")

        # Custom linear layer projects features -> model's expected input dim
        self.custom_linear = nn.Sequential(
            nn.Linear(feature_dim, self.model_hidden_dim),
            nn.Dropout(dropout_rate),
            nn.GELU(),
        )
        print(f"  Projection Layer: Linear({feature_dim} -> {self.model_hidden_dim})")

    def forward(self, features, attention_mask, labels=None):
        """Project the features and run the base model.

        Returns the base model's output object; it carries the loss when
        ``labels`` is given.
        """
        try:
            projected_features = self.custom_linear(features)
        except Exception as e:
            print(f"Linear proj error: {e}\nInp shape: {features.shape}")
            raise
        try:
            outputs = self.model(inputs_embeds=projected_features, attention_mask=attention_mask, labels=labels, return_dict=True)
        except Exception as e:
            print(f"Model fwd error: {e}\nEmbed shape: {projected_features.shape}")
            raise
        return outputs # Contains loss and logits

    @torch.no_grad()
    def generate(self, features, attention_mask, **generate_kwargs):
        """Generate token ids from feature sequences.

        Runs in eval mode so dropout is disabled, then restores training mode
        if the module was training before the call. Without the restore, one
        validation pass would silently turn dropout off for the rest of a
        hand-written training loop.
        """
        was_training = self.training
        self.eval()
        try:
            inputs_embeds = self.custom_linear(features)
            return self.model.generate(inputs_embeds=inputs_embeds, attention_mask=attention_mask, **generate_kwargs)
        finally:
            if was_training:
                self.train()

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        """Delegates gradient checkpointing enabling to the underlying HF model."""
        print("Enabling gradient checkpointing on underlying model...")
        if hasattr(self.model, 'gradient_checkpointing_enable'):
            self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)
        else:
            print("Warning: Underlying model does not have 'gradient_checkpointing_enable' method.")

# ===========================================
# Cell 6: Fine-tuning Setup and Execution
# ===========================================
def set_seed(seed_value):
    """Seed Python, NumPy, PyTorch (CPU and CUDA) and Transformers with one value."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed_value)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed_value)
    transformers_set_seed(seed_value)

def collate_fn_train_val(batch):
    """Stack per-sample tensors into one batch dict for training/validation.

    Samples that are ``None`` (failed loads) are dropped first. Returns
    ``None`` when nothing is left, otherwise a dict with the stacked
    ``features``, ``attention_mask`` and ``labels`` tensors.
    """
    usable = [item for item in batch if item is not None]
    if len(usable) == 0:
        return None
    try:
        stacked = {}
        # Labels are expected to be tensors that already use -100 for padding.
        for field in ("features", "attention_mask", "labels"):
            stacked[field] = torch.stack([item[field] for item in usable])
        return stacked
    except Exception as e:
        print(f"Collate error: {e}")
        raise

# --- Metrics ---
# Load metrics (safer way)
# --- Metric Calculation Function ---
# Load the sacreBLEU and chrF metrics through the `evaluate` library.
# The metric names are defined up front so they always exist, even when
# loading fails; `metrics_loaded` tells compute_metrics whether they are usable.
sacrebleu_metric = None
chrf_metric = None
try:
    sacrebleu_metric = evaluate.load("sacrebleu")
    chrf_metric = evaluate.load("chrf")
    metrics_loaded = True
    print("Metrics 'sacrebleu' and 'chrf' loaded using 'evaluate' library.")
except Exception as e:
    # Reset both so a partial load is never mistaken for a usable pair.
    sacrebleu_metric = None
    chrf_metric = None
    print(f"Warning: Could not load metrics using 'evaluate': {e}")
    print("BLEU/CHRF calculation during training/evaluation might fail.")
    metrics_loaded = False




# Define compute_metrics function (needs tokenizer available)
def compute_metrics_fn(tokenizer_for_decode):
    """Build a ``compute_metrics`` callback that decodes with the given tokenizer.

    The returned callback takes an object with ``predictions`` and
    ``label_ids`` and returns a dict with ``bleu``, ``chrf++`` and
    ``gen_len``, each rounded to four decimals. Both score keys are always
    present; they are 0.0 when the metrics could not be loaded or computed.
    """
    def compute_metrics(eval_pred):
        # predictions may be a tuple of outputs; the first entry is used.
        raw_predictions = eval_pred.predictions
        if isinstance(raw_predictions, tuple):
            raw_predictions = raw_predictions[0]
        labels = eval_pred.label_ids

        if raw_predictions is None or labels is None:
            print("Warning: Logits or labels are None in compute_metrics.")
            return {}

        if isinstance(raw_predictions, torch.Tensor):
            raw_predictions = raw_predictions.detach().cpu().numpy()
        raw_predictions = np.asarray(raw_predictions)
        pad_id = tokenizer_for_decode.pad_token_id

        if raw_predictions.ndim == 3:
            # (batch, seq, vocab) logits -> most likely token per position.
            predictions = np.argmax(raw_predictions, axis=-1)
        else:
            # Already token ids (e.g. produced by generation); negative
            # filler values such as -100 cannot be decoded, so map them to pad.
            predictions = np.where(raw_predictions < 0, pad_id, raw_predictions)

        decoded_preds = tokenizer_for_decode.batch_decode(predictions, skip_special_tokens=True)

        # Decode labels (handle -100)
        labels = np.where(labels != -100, labels, pad_id)
        decoded_labels = tokenizer_for_decode.batch_decode(labels, skip_special_tokens=True)

        # sacreBLEU/chrF expect a list of references per prediction.
        decoded_labels_for_bleu = [[label] for label in decoded_labels]

        # Defaults keep both keys present whatever happens below.
        results = {"bleu": 0.0, "chrf++": 0.0}
        if metrics_loaded:
            try:
                bleu_result = sacrebleu_metric.compute(predictions=decoded_preds, references=decoded_labels_for_bleu)
                # word_order=2 is what turns chrF into chrF++.
                chrf_result = chrf_metric.compute(predictions=decoded_preds, references=decoded_labels_for_bleu, word_order=2)

                results["bleu"] = bleu_result["score"]
                results["chrf++"] = chrf_result["score"]
            except Exception as e:
                print(f"Error calculating metrics: {e}")
                results["bleu"] = 0.0
                results["chrf++"] = 0.0
        else:
            print("Metrics library not loaded, skipping BLEU/CHRF calculation.")

        # Average number of non-pad tokens per prediction.
        prediction_lens = [np.count_nonzero(pred != pad_id) for pred in predictions]
        results["gen_len"] = np.mean(prediction_lens) if prediction_lens else 0.0

        return {k: round(float(v), 4) for k, v in results.items()}
    return compute_metrics

# --- Fine-tuning Configuration ---
# Errors in this section raise instead of calling exit(), so a failed
# precondition stops execution immediately in scripts and notebooks alike.
print("\n--- Setting up Fine-Tuning ---")
seed = train_cfg.get('TrainingArguments.seed', 42)
set_seed(seed)
val_split_ratio = train_cfg.get('DatasetArguments.validation_split_ratio', 0.1)
start_model_path = train_cfg.get('ModelArguments.model_name_or_path') # Use the configured path
feature_dim = train_cfg.get('ModelArguments.feature_dim', 208)

# --- Experiment Name & Output Dir ---
# Last path component of the starting checkpoint; normpath drops a trailing
# separator so the name is never empty.
start_model_name = os.path.basename(os.path.normpath(start_model_path))
# Prefer the 'num_train_epochs' key, falling back to the older 'epochs' key.
epochs_setting = train_cfg.get(
    'TrainingArguments.num_train_epochs',
    train_cfg.get('TrainingArguments.epochs', 'Def'))
exp_name_parts = [
    "SSL_FineTune",
    f"Start_{start_model_name.replace(':','-')}", # Sanitize colons
    f"LR_{train_cfg.get('TrainingArguments.learning_rate', 'Def')}",
    f"Epochs_{epochs_setting}",
    f"Seed_{seed}"
]
exp_name_base = "-".join(exp_name_parts)
# Keep only characters that are safe in a directory name.
exp_name_safe = "".join(c if c.isalnum() or c in ('-', '_', '.') else '_' for c in exp_name_base)
training_output_dir_base = train_cfg.get('TrainingArguments.output_folder') # Already joined with MODEL_OUTPUT_DIR
training_output_dir = os.path.join(training_output_dir_base, exp_name_safe)
print(f"Fine-tuning Output Directory: {training_output_dir}")

# --- Load Tokenizer ---
print(f"Loading Tokenizer from: {start_model_path}")
try:
    tokenizer = AutoTokenizer.from_pretrained(start_model_path)
except Exception as e:
    print(f"Error loading tokenizer {start_model_path}: {e}")
    raise

# --- Prepare Data Split ---
print("\n--- Preparing Train/Validation Split ---")
train_labels_path = train_cfg.get('DatasetArguments.train_labels_dataset_path')
if not train_labels_path or not os.path.exists(train_labels_path):
    raise FileNotFoundError(f"ERROR: Train labels {train_labels_path} not found.")
all_train_df = pd.read_csv(train_labels_path)
all_train_df['ID'] = all_train_df['ID'].astype(str)
all_ids = all_train_df['ID'].tolist()
all_sentences = all_train_df['Translation'].tolist()
train_ids, valid_ids, train_sentences, valid_sentences = train_test_split(all_ids, all_sentences, test_size=val_split_ratio, random_state=seed)
print(f"Splitting data: {len(train_ids)} train samples, {len(valid_ids)} validation samples.")
train_id_to_sentence = dict(zip(train_ids, train_sentences))
valid_id_to_sentence = dict(zip(valid_ids, valid_sentences))

# Keep only the IDs that have a non-empty entry in the HDF5 file.
h5_path = train_cfg.get('DatasetArguments.train_dataset_path')
all_h5_keys, h5_lengths_map = VideoDataset.preload_h5_info(h5_path)
all_h5_keys_set = set(all_h5_keys)
train_keys_final = [id_ for id_ in train_ids if id_ in all_h5_keys_set]
valid_keys_final = [id_ for id_ in valid_ids if id_ in all_h5_keys_set]
print(f"Filtered keys vs HDF5: {len(train_keys_final)} train, {len(valid_keys_final)} validation")

# --- Instantiate Datasets ---
print("\nInstantiating Training Dataset...")
train_dataset = VideoDataset(
    h5_file_path=h5_path, data_keys_to_use=train_keys_final, id_to_sentence_map=train_id_to_sentence,
    tokenizer=tokenizer, h5_lengths_map=h5_lengths_map,
    max_seq_length=train_cfg.get('DatasetArguments.max_sequence_length', 600),
    max_label_length=train_cfg.get('DatasetArguments.max_label_length', 128),
    is_test_set=False, feature_dim=feature_dim )

print("\nInstantiating Validation Dataset...")
valid_dataset = VideoDataset(
    h5_file_path=h5_path, data_keys_to_use=valid_keys_final, id_to_sentence_map=valid_id_to_sentence,
    tokenizer=tokenizer, h5_lengths_map=h5_lengths_map,
    max_seq_length=train_cfg.get('DatasetArguments.max_sequence_length', 600),
    max_label_length=train_cfg.get('DatasetArguments.max_label_length', 128),
    is_test_set=False, feature_dim=feature_dim )

# --- Initialize Model ---
print(f"\nInitializing Model for Fine-tuning from: {start_model_path}")
model = SNLTraslationModel(
    feature_dim=feature_dim,
    model_name_or_path=start_model_path,
    dropout_rate=train_cfg.get('TrainingArguments.dropout_rate', 0.1)
)

# Fix missing decoder_start_token_id in the loaded config
if model.model.config.decoder_start_token_id is None and tokenizer.pad_token_id is not None:
    print("Setting decoder_start_token_id to pad_token_id...")
    model.model.config.decoder_start_token_id = tokenizer.pad_token_id
elif model.model.config.decoder_start_token_id is None:
    raise ValueError("ERROR: Cannot set decoder_start_token_id because tokenizer.pad_token_id is also None.")
else:
    print(f"Decoder_start_token_id already set: {model.model.config.decoder_start_token_id}")

model.to(device)

# ===========================================
# Cell 6 (repeated): Fine-tuning Setup and Execution
# NOTE: this repeats the Cell 6 definitions above; the later definitions are the ones in effect.
# ===========================================
def set_seed(seed_value):
    """Make runs reproducible by seeding every random number source in use."""
    # Python's own RNG, NumPy, then PyTorch on the CPU.
    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    # CUDA generators only exist when a GPU is available.
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.manual_seed_all(seed_value)
    # Finally the Transformers helper.
    transformers_set_seed(seed_value)

def collate_fn_train_val(batch):
    """Combine dataset samples into a batch for the train/validation loaders.

    ``None`` entries (samples that failed to load) are removed. When no
    sample remains the function returns ``None``; otherwise it returns a dict
    holding the stacked ``features``, ``attention_mask`` and ``labels``.
    """
    kept = list(filter(lambda entry: entry is not None, batch))
    if not kept:
        return None

    def _stack(field):
        return torch.stack([entry[field] for entry in kept])

    try:
        feature_batch = _stack('features')
        mask_batch = _stack('attention_mask')
        label_batch = _stack('labels')  # Expect tensors with -100
    except Exception as e:
        print(f"Collate error: {e}")
        raise
    return {"features": feature_batch, "attention_mask": mask_batch, "labels": label_batch}

# --- Metrics ---
# Load metrics (safer way)
# Try to obtain both metrics from the `evaluate` library. `metrics_loaded`
# records the outcome so compute_metrics can skip scoring when loading failed.
try:
    sacrebleu_metric = evaluate.load("sacrebleu")
    chrf_metric = evaluate.load("chrf")
except Exception as load_error:
    metrics_loaded = False
    print(f"Warning: Could not load metrics using 'evaluate': {load_error}")
    print("BLEU/CHRF calculation during training/evaluation might fail.")
else:
    metrics_loaded = True
    print("Metrics 'sacrebleu' and 'chrf' loaded using 'evaluate' library.")

# Define compute_metrics function (needs tokenizer available)
def compute_metrics_fn(tokenizer_for_decode):
    """Build a ``compute_metrics`` callback that decodes with the given tokenizer.

    The returned callback takes an object with ``predictions`` and
    ``label_ids`` and returns a dict with ``bleu``, ``chrf`` and ``gen_len``,
    each rounded to four decimals. Both score keys are always present; they
    are 0.0 when the metrics could not be loaded or computed.
    """
    def compute_metrics(eval_pred):
        # predictions may be a tuple of outputs; the first entry is used.
        raw_predictions = eval_pred.predictions
        if isinstance(raw_predictions, tuple):
            raw_predictions = raw_predictions[0]
        labels = eval_pred.label_ids

        if raw_predictions is None or labels is None:
            print("Warning: Logits or labels are None in compute_metrics.")
            return {}

        # Work on a CPU NumPy array from here on.
        if isinstance(raw_predictions, torch.Tensor):
            raw_predictions = raw_predictions.detach().cpu().numpy()
        raw_predictions = np.asarray(raw_predictions)
        pad_id = tokenizer_for_decode.pad_token_id

        if raw_predictions.ndim == 3:
            # (batch, seq, vocab) logits -> most likely token per position.
            predictions_ids = np.argmax(raw_predictions, axis=-1)
        else:
            # Already token ids (e.g. produced by generation); negative
            # filler values such as -100 cannot be decoded, so map them to pad.
            predictions_ids = np.where(raw_predictions < 0, pad_id, raw_predictions)

        decoded_preds = tokenizer_for_decode.batch_decode(predictions_ids, skip_special_tokens=True)

        # Decode labels (handle -100)
        labels = np.where(labels != -100, labels, pad_id)
        decoded_labels = tokenizer_for_decode.batch_decode(labels, skip_special_tokens=True)

        # sacreBLEU/chrF expect a list of references per prediction.
        decoded_labels_for_bleu = [[label] for label in decoded_labels]

        # Defaults keep both keys present whatever happens below, so code
        # that looks up either metric never hits a missing key.
        results = {"bleu": 0.0, "chrf": 0.0}
        if metrics_loaded:
            try:
                bleu_result = sacrebleu_metric.compute(predictions=decoded_preds, references=decoded_labels_for_bleu)
                # No word_order is passed, so this relies on the metric's
                # default; chrF++ would need word_order=2.
                chrf_result = chrf_metric.compute(predictions=decoded_preds, references=decoded_labels_for_bleu)

                results["bleu"] = bleu_result["score"] # Main BLEU score
                results["chrf"] = chrf_result["score"] # Main chrF score
            except Exception as e:
                print(f"Error calculating metrics: {e}")
                results["bleu"] = 0.0
                results["chrf"] = 0.0
        else:
            print("Metrics library not loaded, skipping BLEU/CHRF calculation.")

        # Average number of non-pad tokens per prediction.
        prediction_lens = [np.count_nonzero(pred != pad_id) for pred in predictions_ids]
        results["gen_len"] = np.mean(prediction_lens) if prediction_lens else 0.0

        # Round results
        return {k: round(float(v), 4) for k, v in results.items()}
    return compute_metrics

# --- Fine-tuning Configuration ---
# Errors in this section raise instead of calling exit(), so a failed
# precondition stops execution immediately in scripts and notebooks alike.
print("\n--- Setting up Fine-Tuning ---")
seed = train_cfg.get('TrainingArguments.seed', 42)
set_seed(seed)
val_split_ratio = train_cfg.get('DatasetArguments.validation_split_ratio', 0.1)
start_model_path = train_cfg.get('ModelArguments.model_name_or_path') # Use the configured path
feature_dim = train_cfg.get('ModelArguments.feature_dim', 208)

# --- Experiment Name & Output Dir ---
# Last path component of the starting checkpoint; normpath drops a trailing
# separator so the name is never empty.
start_model_name = os.path.basename(os.path.normpath(start_model_path))
exp_name_parts = [
    "SSL_FineTune",
    f"Start_{start_model_name.replace(':','-')}", # Sanitize colons
    f"LR_{train_cfg.get('TrainingArguments.learning_rate', 'Def')}",
    f"Epochs_{train_cfg.get('TrainingArguments.num_train_epochs', 'Def')}", # Use correct arg name
    f"Seed_{seed}"
]
exp_name_base = "-".join(exp_name_parts)
# Keep only characters that are safe in a directory name.
exp_name_safe = "".join(c if c.isalnum() or c in ('-', '_', '.') else '_' for c in exp_name_base)
training_output_dir_base = train_cfg.get('TrainingArguments.output_folder') # Already joined with MODEL_OUTPUT_DIR
training_output_dir = os.path.join(training_output_dir_base, exp_name_safe)
print(f"Fine-tuning Output Directory: {training_output_dir}")

# --- Load Tokenizer ---
print(f"Loading Tokenizer from: {start_model_path}")
try:
    tokenizer = AutoTokenizer.from_pretrained(start_model_path, legacy=False) # Try setting legacy=False for potential SentencePiece fix
except Exception as e:
    print(f"Error loading tokenizer {start_model_path}: {e}")
    raise

# --- Prepare Data Split ---
print("\n--- Preparing Train/Validation Split ---")
train_labels_path = train_cfg.get('DatasetArguments.train_labels_dataset_path')
if not train_labels_path or not os.path.exists(train_labels_path):
    raise FileNotFoundError(f"ERROR: Train labels {train_labels_path} not found.")
all_train_df = pd.read_csv(train_labels_path)
all_train_df['ID'] = all_train_df['ID'].astype(str)
all_ids = all_train_df['ID'].tolist()
all_sentences = all_train_df['Translation'].tolist()
train_ids, valid_ids, train_sentences, valid_sentences = train_test_split(all_ids, all_sentences, test_size=val_split_ratio, random_state=seed)
print(f"Splitting data: {len(train_ids)} train samples, {len(valid_ids)} validation samples.")
train_id_to_sentence = dict(zip(train_ids, train_sentences))
valid_id_to_sentence = dict(zip(valid_ids, valid_sentences))

# Keep only the IDs that have a non-empty entry in the HDF5 file.
h5_path = train_cfg.get('DatasetArguments.train_dataset_path')
all_h5_keys, h5_lengths_map = VideoDataset.preload_h5_info(h5_path)
all_h5_keys_set = set(all_h5_keys)
train_keys_final = [id_ for id_ in train_ids if id_ in all_h5_keys_set]
valid_keys_final = [id_ for id_ in valid_ids if id_ in all_h5_keys_set]
print(f"Filtered keys vs HDF5: {len(train_keys_final)} train, {len(valid_keys_final)} validation")

# --- Instantiate Datasets ---
print("\nInstantiating Training Dataset...")
train_dataset = VideoDataset(
    h5_file_path=h5_path, data_keys_to_use=train_keys_final, id_to_sentence_map=train_id_to_sentence,
    tokenizer=tokenizer, h5_lengths_map=h5_lengths_map,
    max_seq_length=train_cfg.get('DatasetArguments.max_sequence_length', 512), # Use updated value
    max_label_length=train_cfg.get('DatasetArguments.max_label_length', 128),
    is_test_set=False, feature_dim=feature_dim )

print("\nInstantiating Validation Dataset...")
valid_dataset = VideoDataset(
    h5_file_path=h5_path, data_keys_to_use=valid_keys_final, id_to_sentence_map=valid_id_to_sentence,
    tokenizer=tokenizer, h5_lengths_map=h5_lengths_map,
    max_seq_length=train_cfg.get('DatasetArguments.max_sequence_length', 512), # Use updated value
    max_label_length=train_cfg.get('DatasetArguments.max_label_length', 128),
    is_test_set=False, feature_dim=feature_dim )

# --- Initialize Model ---
print(f"\nInitializing Model for Fine-tuning from: {start_model_path}")
model = SNLTraslationModel(
    feature_dim=feature_dim,
    model_name_or_path=start_model_path,
    # Use dropout from TrainingArguments if specified, else default
    dropout_rate=train_cfg.get('TrainingArguments.dropout_rate', 0.1)
)

# --- > FIX DECODER START TOKEN ID IN MODEL CONFIG < ---
if model.model.config.decoder_start_token_id is None and tokenizer.pad_token_id is not None:
    print("Setting model.config.decoder_start_token_id to pad_token_id...")
    model.model.config.decoder_start_token_id = tokenizer.pad_token_id
elif model.model.config.decoder_start_token_id is None:
    # Stop if we can't fix it
    raise ValueError("ERROR! Cannot set decoder_start_token_id: tokenizer.pad_token_id is None.")
else:
    print(f"Model config decoder_start_token_id already set: {model.model.config.decoder_start_token_id}")

model.to(device)

# --- Training Arguments ---
print("\nSetting Training Arguments for Fine-tuning...")

# --- > Prepare GenerationConfig < ---
from transformers import GenerationConfig

# Sampling is normally off for BLEU evaluation; take the switch from eval_cfg.
_use_sampling = bool(eval_cfg.get('GenerationArguments.do_sample', False))

# Overrides applied on top of the starting checkpoint's generation defaults.
_generation_overrides = {
    "max_length": train_cfg.get('DatasetArguments.max_label_length', 128),
    # Allow overriding in train_cfg, otherwise use eval_cfg
    "num_beams": train_cfg.get('TrainingArguments.generation_num_beams', eval_cfg.get('GenerationArguments.num_beams', 4)),
    "early_stopping": eval_cfg.get('GenerationArguments.early_stopping', True),
    "length_penalty": eval_cfg.get('GenerationArguments.length_penalty', 1.0),
    "no_repeat_ngram_size": eval_cfg.get('GenerationArguments.no_repeat_ngram_size', 0),
    "do_sample": _use_sampling,
}
if _use_sampling:
    _generation_overrides["top_p"] = eval_cfg.get('GenerationArguments.top_p', 1.0)
    _generation_overrides["temperature"] = eval_cfg.get('GenerationArguments.temperature', 1.0)
else:
    # Use the neutral value 1.0 rather than deleting the attributes: the
    # config object still reads top_p/temperature when it validates itself,
    # so removing them can raise AttributeError later.
    _generation_overrides["top_p"] = 1.0
    _generation_overrides["temperature"] = 1.0

gen_config = GenerationConfig.from_pretrained(
    start_model_path, # Load defaults from the starting checkpoint's config
    **_generation_overrides,
)

# --- > Explicitly set the decoder_start_token_id in GenerationConfig <---
if gen_config.decoder_start_token_id is None and tokenizer.pad_token_id is not None:
    print("Setting decoder_start_token_id in GenerationConfig...")
    gen_config.decoder_start_token_id = tokenizer.pad_token_id
elif gen_config.decoder_start_token_id is None:
    # Stop if we can't fix it
    raise ValueError("ERROR! Cannot set decoder_start_token_id in GenerationConfig: tokenizer.pad_token_id is None.")

# Seq2SeqTrainingArguments keyword -> (option name under 'TrainingArguments' in
# the training config, fallback used when the option is absent).
# The config keeps the older option name 'evaluation_strategy'; the library
# keyword it feeds is the renamed 'eval_strategy'.
_config_backed_training_options = {
    'eval_strategy': ('evaluation_strategy', 'epoch'),
    'logging_strategy': ('logging_strategy', 'epoch'),
    'save_strategy': ('save_strategy', 'epoch'),
    'learning_rate': ('learning_rate', 5e-5),
    'per_device_train_batch_size': ('per_device_train_batch_size', 1),
    'per_device_eval_batch_size': ('per_device_eval_batch_size', 4),
    'gradient_accumulation_steps': ('gradient_accumulation_steps', 4),
    'num_train_epochs': ('num_train_epochs', 20),
    'weight_decay': ('weight_decay', 0.01),
    'fp16': ('fp16', True),
    'gradient_checkpointing': ('gradient_checkpointing', True),
    'optim': ('optim', 'adafactor'),
    'load_best_model_at_end': ('load_best_model_at_end', True),
    'metric_for_best_model': ('metric_for_best_model', 'eval_bleu'),
    'greater_is_better': ('greater_is_better', True),
    'predict_with_generate': ('predict_with_generate', True),
    'save_total_limit': ('save_total_limit', 2),
    'dataloader_num_workers': ('dataloader_num_workers', 0),
}

_resolved_training_options = {
    keyword: train_cfg.get(f'TrainingArguments.{option}', fallback)
    for keyword, (option, fallback) in _config_backed_training_options.items()
}
# Mixed precision is only requested when a CUDA device is actually present.
_resolved_training_options['fp16'] = (
    _resolved_training_options['fp16'] and torch.cuda.is_available()
)

training_args = Seq2SeqTrainingArguments(
    output_dir=training_output_dir,
    generation_config=gen_config,  # GenerationConfig object prepared above
    report_to="none",  # Set Wandb arguments if needed
    seed=seed,
    **_resolved_training_options,
)

# --- Initialize Trainer ---
print("\nInitializing Seq2SeqTrainer...")

# compute_metrics_fn is a factory: handing it the tokenizer yields the metrics
# callable that the trainer invokes during evaluation.
compute_metrics_with_tokenizer = compute_metrics_fn(tokenizer)

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    tokenizer=tokenizer,  # used for padding/saving
    data_collator=collate_fn_train_val,
    compute_metrics=compute_metrics_with_tokenizer,
    eval_dataset=valid_dataset,
    train_dataset=train_dataset,
)

# --- Start Fine-tuning ---
# Runs training, then stores the resulting model under
# <training_output_dir>/best_finetuned_checkpoint together with tokenizer,
# generation config, trainer state and train metrics. Any exception is reported
# and swallowed so that the notebook can continue; `best_model_checkpoint_path`
# is left as None unless the trainer recorded a best checkpoint.
print("\n--- Starting Fine-tuning ---")
best_model_checkpoint_path = None # Initialize in case training fails
try:
    # Resume only when the flag is set AND a non-empty path is configured.
    # An empty string would otherwise be forwarded to trainer.train(), which
    # treats any non-None value as a checkpoint directory to load from.
    resume_checkpoint = None
    if train_cfg.get('ModelArguments.resume'):
        resume_checkpoint = train_cfg.get('ModelArguments.resume_checkpoint') or None
        if resume_checkpoint is None:
            print("Warning: ModelArguments.resume is set but resume_checkpoint is empty; training from the start.")

    train_result = trainer.train(resume_from_checkpoint=resume_checkpoint)
    print("Fine-tuning Finished.")

    # --- Save Final Model, State, and Metrics ---
    print("\nSaving final best model and tokenizer...")
    # Trainer should have loaded the best model if load_best_model_at_end=True.
    # The path to the best checkpoint is stored in the trainer state.
    best_model_checkpoint_path = trainer.state.best_model_checkpoint
    final_save_path = os.path.join(training_output_dir, "best_finetuned_checkpoint")

    if best_model_checkpoint_path:
        print(f"Best checkpoint identified at: {best_model_checkpoint_path}")
        # If load_best_model_at_end was true, trainer.model is already the
        # best one; it only needs saving to the designated final location.
        trainer.save_model(final_save_path)
        print(f"Best fine-tuned model saved to: {final_save_path}")
        # Save the tokenizer that was handed to the trainer. Using the local
        # object avoids the `trainer.tokenizer` attribute, which newer
        # transformers releases deprecate.
        if tokenizer is not None:
            tokenizer.save_pretrained(final_save_path)
            print(f"Tokenizer saved to: {final_save_path}")
        if hasattr(model.model, 'generation_config'):  # Save the generation config used
            model.model.generation_config.save_pretrained(final_save_path)
            print(f"Generation config saved to: {final_save_path}")
    else:
        print("No best checkpoint found (or load_best_model_at_end=False). Saving final model state instead.")
        trainer.save_model(final_save_path)  # Save whatever the final state is

    trainer.save_state()  # Saves trainer state in the main output dir

    metrics = train_result.metrics
    print("\n--- Final Fine-tuning Metrics (from train_result) ---")
    print(metrics)
    if trainer.state.best_metric is not None:
        # Checkpoint directories end in "-<step>", so the last dash-separated
        # piece is reported as the step.
        best_checkpoint = trainer.state.best_model_checkpoint
        best_step = best_checkpoint.split('-')[-1] if best_checkpoint else 'N/A'
        print(f"Best validation {trainer.args.metric_for_best_model}: {trainer.state.best_metric:.4f} at step {best_step}")

    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)  # Saves to train_results.json

except Exception as e:
    import traceback
    print(f"\n--- Fine-tuning Interrupted or Failed: {e} ---")
    traceback.print_exc()
    # Attempt to get best checkpoint path even if training failed mid-epoch after saving one
    if trainer.state.best_model_checkpoint:
        best_model_checkpoint_path = trainer.state.best_model_checkpoint
        print(f"(Attempting to use best checkpoint found before interruption: {best_model_checkpoint_path})")
    else:
        print("(No best checkpoint identified before interruption)")

# Release the training-time objects before the final evaluation starts.
print("\nCleaning up training objects...")
# best_model_checkpoint_path is intentionally NOT deleted: Cell 7 reads it.
del model, trainer, train_dataset, valid_dataset
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

print("\n--- Fine-tuning Phase Complete ---")




# ===========================================
# Cell 7: Final Evaluation on Test Set
# ===========================================
# Final evaluation: load the best fine-tuned checkpoint, generate predictions
# for every sample of the test HDF5 file, write them to CSV and -- when a test
# labels CSV is available -- compute corpus BLEU and chrF++ and write them to
# JSON. The whole cell is skipped when no best checkpoint directory exists.
print("\n--- Setting up Evaluation on ACTUAL TEST SET ---")

# --- Load Best Fine-tuned Model ---
if best_model_checkpoint_path and os.path.isdir(best_model_checkpoint_path):
    print(f"Loading best fine-tuned model for final test evaluation from: {best_model_checkpoint_path}")
    # Pass eval_cfg just for structure/defaults, actual config loaded from checkpoint path
    model_eval, tokenizer_eval = load_trained_model_for_eval(eval_cfg, best_model_checkpoint_path)

    # --- Load ACTUAL Test Dataset ---
    print("\nLoading ACTUAL Test Dataset...")
    test_h5_path_eval = eval_cfg.get('DatasetArguments.test_dataset_path')
    test_labels_path_eval = eval_cfg.get('DatasetArguments.test_labels_dataset_path')

    # Resolve (and create) the results directory once so both the predictions
    # file and the metrics file can rely on it.
    eval_results_save_dir = eval_cfg.get('EvaluationArguments.results_save_path', EVAL_RESULTS_DIR)
    if eval_results_save_dir:
        os.makedirs(eval_results_save_dir, exist_ok=True)

    # Check if REAL test labels file exists
    test_labels_exist_eval = False
    test_id_to_sentence_map_eval = {}
    if test_labels_path_eval and os.path.exists(test_labels_path_eval):
        try:
            test_df_eval = pd.read_csv(test_labels_path_eval)
            test_df_eval['ID'] = test_df_eval['ID'].astype(str)
            # Column-wise zip builds the same ID -> Translation mapping as a
            # row-by-row iterrows() loop (later duplicates win) without the
            # per-row overhead.
            test_id_to_sentence_map_eval = dict(zip(test_df_eval['ID'], test_df_eval['Translation']))
            test_labels_exist_eval = True
            print(f"Actual test labels file found and loaded: {test_labels_path_eval}")
        except Exception as e:
            print(f"Warning: Found test labels file ({test_labels_path_eval}) but couldn't read: {e}. Cannot calculate BLEU.")
            test_labels_exist_eval = False
    else:
        print(f"Actual test labels file not found or specified ({test_labels_path_eval}). Cannot calculate BLEU.")

    # Preload HDF5 info for the TEST file
    test_h5_keys, test_h5_lengths_map = VideoDataset.preload_h5_info(test_h5_path_eval)
    if not test_h5_keys:
        # Abort with a non-zero status and a message instead of the
        # interactive-only exit() helper.
        sys.exit("ERROR: Could not load keys from test HDF5 file.")

    # Filter test keys (include all HDF5 keys if no labels exist)
    if test_labels_exist_eval:
        test_keys_final = [id_ for id_ in test_h5_keys if id_ in test_id_to_sentence_map_eval]
        print(f"Using {len(test_keys_final)} keys present in both test HDF5 and test labels CSV.")
    else:
        test_keys_final = test_h5_keys  # Use all keys from HDF5 if no labels
        print(f"Using all {len(test_keys_final)} keys found in test HDF5 (no labels provided/loaded).")

    # Use sequence length consistent with training: same config key AND same
    # fallback (512) as the train/validation datasets.
    max_seq_len_eval = train_cfg.get('DatasetArguments.max_sequence_length', 512)

    test_dataset_eval = VideoDataset(
        h5_file_path=test_h5_path_eval,
        data_keys_to_use=test_keys_final,
        id_to_sentence_map=test_id_to_sentence_map_eval,
        tokenizer=tokenizer_eval,  # Use tokenizer loaded from best checkpoint
        h5_lengths_map=test_h5_lengths_map,
        max_seq_length=max_seq_len_eval,
        max_label_length=eval_cfg.get('GenerationArguments.max_length', 128),
        is_test_set=True,  # Returns strings for labels if they exist
        feature_dim=train_cfg.get('ModelArguments.feature_dim', 208)  # Match feature dim used in training
    )

    # --- Create DataLoader ---
    eval_batch_size = eval_cfg.get('EvaluationArguments.batch_size', 8)
    eval_num_workers = eval_cfg.get('EvaluationArguments.dataloader_num_workers', 0)
    test_loader_eval = DataLoader(
        test_dataset_eval,
        batch_size=eval_batch_size,
        collate_fn=collate_fn_eval,
        shuffle=False,
        num_workers=eval_num_workers,
    )
    print(f"\nCreated Test DataLoader: {len(test_dataset_eval)} samples, Batch size: {eval_batch_size}")

    # --- Run Evaluation Loop on Test Set ---
    print("\n--- Starting Evaluation Loop on TEST SET ---")
    keys_list_test = []
    predictions_test = []
    references_test = []

    # Use generation settings from eval config
    gen_args_config_eval = eval_cfg.get('GenerationArguments', Config({}))
    generation_config_test = {
        "max_length": gen_args_config_eval.get('max_length', 128),
        "num_beams": gen_args_config_eval.get('num_beams', 4),
        "early_stopping": gen_args_config_eval.get('early_stopping', True),
        "length_penalty": gen_args_config_eval.get('length_penalty', 1.0),
        "no_repeat_ngram_size": gen_args_config_eval.get('no_repeat_ngram_size', 0),
        "do_sample": gen_args_config_eval.get('do_sample', False),
    }
    # Sampling parameters are only forwarded when sampling is enabled.
    if generation_config_test["do_sample"]:
        generation_config_test["top_p"] = gen_args_config_eval.get('top_p', 1.0)
        generation_config_test["temperature"] = gen_args_config_eval.get('temperature', 1.0)
    print(f"Test Generation config: {generation_config_test}")

    model_eval.eval()  # Ensure model is in eval mode
    with torch.no_grad():
        for batch in tqdm(test_loader_eval, desc="Evaluating on Test Set"):
            if batch is None:
                continue
            try:
                output_sequences = model_eval.generate(
                    features=batch["features"].to(device),  # Ensure data on correct device
                    attention_mask=batch["attention_mask"].to(device),
                    **generation_config_test,
                )
                # Move back to CPU for decoding
                batch_preds = tokenizer_eval.batch_decode(output_sequences.cpu(), skip_special_tokens=True)
                keys_list_test.extend(batch["keys"])
                predictions_test.extend(batch_preds)
                if test_labels_exist_eval:
                    references_test.extend(batch["labels"])
            except Exception as e:
                import traceback
                print(f"\nError during test set generation: {e}")
                traceback.print_exc()
                # Keep IDs, predictions and references aligned: a failed batch
                # still contributes one placeholder row per sample.
                error_count = len(batch["keys"])
                keys_list_test.extend(batch["keys"])
                predictions_test.extend(["<GENERATION_ERROR>"] * error_count)
                if test_labels_exist_eval:
                    references_test.extend(batch["labels"])

    print("\n--- Test Set Evaluation Loop Finished ---")

    # True only when there is exactly one reference per prediction.
    references_usable = bool(
        test_labels_exist_eval
        and references_test
        and len(references_test) == len(predictions_test)
    )

    # --- Save Test Set Results ---
    if not keys_list_test:
        print("WARNING: No keys were processed during test set evaluation. Cannot save results.")
    else:
        results_df_test = pd.DataFrame({'ID': keys_list_test, 'Prediction': predictions_test})
        if references_usable:
            results_df_test['Reference'] = references_test
            print(f"Added Reference column to test results (Total samples: {len(results_df_test)})")

        # Use the fine-tuning experiment name for the results file
        prediction_save_file_test = os.path.join(eval_results_save_dir, f"{exp_name_safe}_TEST_predictions.csv")
        print(f"\nSaving test set predictions to: {prediction_save_file_test}")
        try:
            results_df_test.to_csv(prediction_save_file_test, encoding='utf-8', index=False)
            print("Test predictions saved successfully.")
        except Exception as e:
            print(f"Error saving test predictions: {e}")

        if not results_df_test.empty:
            print("\nSample Test Predictions:\n", results_df_test.head())

    # --- Calculate Metrics on Test Set (if references exist) ---
    if references_usable and metrics_loaded:
        print("\n--- Calculating Final Metrics on TEST SET ---")
        # One list of references per prediction, as the metrics expect.
        references_sacrebleu_test = [[ref] for ref in references_test]
        try:
            bleu_result_test = bleu_metric.compute(predictions=predictions_test, references=references_sacrebleu_test)
            # chrF++ is chrF with word bigrams, i.e. word_order=2. The chrF
            # metric has no `word_match` keyword; passing it raised a TypeError
            # that the handler below swallowed, so no test metrics were saved.
            chrf_result_test = chrf_metric.compute(
                predictions=predictions_test,
                references=references_sacrebleu_test,
                word_order=2,
            )

            print(f"\nCorpus BLEU Score (Test Set): {bleu_result_test['score']:.4f}")
            print(f"Corpus CHRF++ Score (Test Set): {chrf_result_test['score']:.4f}")

            metrics_summary_test = {
                "BLEU": round(bleu_result_test["score"], 4),
                "CHRF++": round(chrf_result_test["score"], 4),
            }
            print('=' * 30)
            print("Detailed BLEU Scores (Test Set):")
            print('-' * 10)
            # Note: these are the n-gram precisions reported by the metric,
            # not full BLEU-n scores.
            for n, prec in zip((1, 2, 3, 4), bleu_result_test['precisions']):
                metrics_summary_test[f"BLEU-{n}"] = round(prec, 4)
                print(f"BLEU-{n} Precision: {prec:.2f}")
            print('=' * 30)

            metrics_save_file_test = os.path.join(eval_results_save_dir, f"{exp_name_safe}_TEST_metrics.json")
            print(f"Saving test metrics summary to: {metrics_save_file_test}")
            with open(metrics_save_file_test, 'w', encoding='utf-8') as f:
                json.dump(metrics_summary_test, f, indent=2)
            print("Test metrics saved.")
        except Exception as e:
            print(f"Error calculating test metrics: {e}")
    else:
        print("\nSkipping final metric calculation for test set.")

else:
    print("Skipping final evaluation on test set as the best fine-tuned checkpoint was not found.")

print("\n--- Full Script Finished ---")

Using device: cuda
Loading training config from: D:/configs\config_train.yaml
Training Config Loaded:
Config(ModelArguments=Config(model_name_or_path='D:/YTASL - Checkpoint/t5-base', base_model_name='T5', feature_dim=208, hidden_dim=768, resume=False, resume_checkpoint=''), DatasetArguments=Config(train_dataset_path='SSL.keypoints.train_signers_train_sentences.0.h5', train_labels_dataset_path='SSL.keypoints.train_signers_train_sentences.csv', validation_split_ratio=0.15, max_sequence_length=512, max_label_length=128), TrainingArguments=Config(output_folder='finetune_run', seed=42, num_train_epochs=20, learning_rate=5e-05, weight_decay=0.01, per_device_train_batch_size=1, gradient_accumulation_steps=4, per_device_eval_batch_size=4, fp16=True, gradient_checkpointing=True, optim='adafactor', evaluation_strategy='epoch', logging_strategy='epoch', save_strategy='epoch', save_total_limit=2, load_best_model_at_end=True, metric_for_best_model='eval_bleu', greater_is_better=True, dataloader_num

Preloading HDF5 info:   0%|          | 0/24109 [00:00<?, ?it/s]

Found 24109 non-empty keys in HDF5.
Filtered keys vs HDF5: 20492 train, 3617 validation

Instantiating Training Dataset...
Initializing dataset (Train/Val). Processing 20492 provided keys...


Preparing dataset samples:   0%|          | 0/20492 [00:00<?, ?it/s]

Finished preparing dataset: 20190 samples included, 302 samples skipped.

Instantiating Validation Dataset...
Initializing dataset (Train/Val). Processing 3617 provided keys...


Preparing dataset samples:   0%|          | 0/3617 [00:00<?, ?it/s]

Finished preparing dataset: 3572 samples included, 45 samples skipped.

Initializing Model for Fine-tuning from: D:/YTASL - Checkpoint/t5-base
Initializing SNLTraslationModel:
  Loading Base Model From: D:/YTASL - Checkpoint/t5-base
  Feature Dim (Input): 208


Some weights of T5ForConditionalGeneration were not initialized from the model checkpoint at D:/YTASL - Checkpoint/t5-base and are newly initialized: ['decoder.block.0.layer.0.SelfAttention.k.weight', 'decoder.block.0.layer.0.SelfAttention.o.weight', 'decoder.block.0.layer.0.SelfAttention.q.weight', 'decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight', 'decoder.block.0.layer.0.SelfAttention.v.weight', 'decoder.block.0.layer.0.layer_norm.weight', 'decoder.block.0.layer.1.EncDecAttention.k.weight', 'decoder.block.0.layer.1.EncDecAttention.o.weight', 'decoder.block.0.layer.1.EncDecAttention.q.weight', 'decoder.block.0.layer.1.EncDecAttention.v.weight', 'decoder.block.0.layer.1.layer_norm.weight', 'decoder.block.0.layer.2.DenseReluDense.wi.weight', 'decoder.block.0.layer.2.DenseReluDense.wo.weight', 'decoder.block.0.layer.2.layer_norm.weight', 'decoder.block.1.layer.0.SelfAttention.k.weight', 'decoder.block.1.layer.0.SelfAttention.o.weight', 'decoder.block.1.layer.0.SelfA

  Base T5 model loaded successfully.
  Model Hidden Dim (d_model): 512
  Projection Layer: Linear(208 -> 512)
Setting decoder_start_token_id to pad_token_id...
Metrics 'sacrebleu' and 'chrf' loaded using 'evaluate' library.

--- Setting up Fine-Tuning ---
Fine-tuning Output Directory: output_finetuned\finetune_run\SSL_FineTune-Start_t5-base-LR_5e-05-Epochs_20-Seed_42
Loading Tokenizer from: D:/YTASL - Checkpoint/t5-base

--- Preparing Train/Validation Split ---
Splitting data: 20492 train samples, 3617 validation samples.
Scanning HDF5 D:/saudi-signfor-all-competition\SSL.keypoints.train_signers_train_sentences.0.h5...


Preloading HDF5 info:   0%|          | 0/24109 [00:00<?, ?it/s]

Found 24109 non-empty keys in HDF5.
Filtered keys vs HDF5: 20492 train, 3617 validation

Instantiating Training Dataset...
Initializing dataset (Train/Val). Processing 20492 provided keys...


Preparing dataset samples:   0%|          | 0/20492 [00:00<?, ?it/s]

Finished preparing dataset: 20190 samples included, 302 samples skipped.

Instantiating Validation Dataset...
Initializing dataset (Train/Val). Processing 3617 provided keys...


Preparing dataset samples:   0%|          | 0/3617 [00:00<?, ?it/s]

Finished preparing dataset: 3572 samples included, 45 samples skipped.

Initializing Model for Fine-tuning from: D:/YTASL - Checkpoint/t5-base
Initializing SNLTraslationModel:
  Loading Base Model From: D:/YTASL - Checkpoint/t5-base
  Feature Dim (Input): 208


Some weights of T5ForConditionalGeneration were not initialized from the model checkpoint at D:/YTASL - Checkpoint/t5-base and are newly initialized: ['decoder.block.0.layer.0.SelfAttention.k.weight', 'decoder.block.0.layer.0.SelfAttention.o.weight', 'decoder.block.0.layer.0.SelfAttention.q.weight', 'decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight', 'decoder.block.0.layer.0.SelfAttention.v.weight', 'decoder.block.0.layer.0.layer_norm.weight', 'decoder.block.0.layer.1.EncDecAttention.k.weight', 'decoder.block.0.layer.1.EncDecAttention.o.weight', 'decoder.block.0.layer.1.EncDecAttention.q.weight', 'decoder.block.0.layer.1.EncDecAttention.v.weight', 'decoder.block.0.layer.1.layer_norm.weight', 'decoder.block.0.layer.2.DenseReluDense.wi.weight', 'decoder.block.0.layer.2.DenseReluDense.wo.weight', 'decoder.block.0.layer.2.layer_norm.weight', 'decoder.block.1.layer.0.SelfAttention.k.weight', 'decoder.block.1.layer.0.SelfAttention.o.weight', 'decoder.block.1.layer.0.SelfA

  Base T5 model loaded successfully.
  Model Hidden Dim (d_model): 512
  Projection Layer: Linear(208 -> 512)
Setting model.config.decoder_start_token_id to pad_token_id...

Setting Training Arguments for Fine-tuning...
Setting decoder_start_token_id in GenerationConfig...

Initializing Seq2SeqTrainer...

--- Starting Fine-tuning ---


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Enabling gradient checkpointing on underlying model...


Epoch,Training Loss,Validation Loss



--- Fine-tuning Interrupted or Failed: `decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation. ---
(No best checkpoint identified before interruption)

Cleaning up training objects...

--- Fine-tuning Phase Complete ---

--- Setting up Evaluation on ACTUAL TEST SET ---
Skipping final evaluation on test set as the best fine-tuned checkpoint was not found.

--- Full Script Finished ---


Traceback (most recent call last):
  File "C:\Users\Fatima\AppData\Local\Temp\ipykernel_18120\2461493475.py", line 742, in <module>
    train_result = trainer.train(
  File "C:\Users\Fatima\anaconda3\envs\pytorch_cuda\lib\site-packages\transformers\trainer.py", line 2245, in train
    return inner_training_loop(
  File "C:\Users\Fatima\anaconda3\envs\pytorch_cuda\lib\site-packages\transformers\trainer.py", line 2661, in _inner_training_loop
    self._maybe_log_save_evaluate(
  File "C:\Users\Fatima\anaconda3\envs\pytorch_cuda\lib\site-packages\transformers\trainer.py", line 3096, in _maybe_log_save_evaluate
    metrics = self._evaluate(trial, ignore_keys_for_eval)
  File "C:\Users\Fatima\anaconda3\envs\pytorch_cuda\lib\site-packages\transformers\trainer.py", line 3045, in _evaluate
    metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
  File "C:\Users\Fatima\anaconda3\envs\pytorch_cuda\lib\site-packages\transformers\trainer_seq2seq.py", line 197, in evaluate
    return super().