In [1]:
from datasets import load_from_disk, Dataset, concatenate_datasets
from transformers import (
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    BitsAndBytesConfig
)
from peft import LoraConfig, get_peft_model
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from tokenizers.processors import TemplateProcessing
import random


# --- Helper function for creating grouped validation data ---
def group_products_for_validation(dataset: Dataset) -> Dataset:
    """Collapse (reactant, product) rows into one record per unique reactant.

    Args:
        dataset: Iterable of examples, each a mapping with a ``'reactant'``
            and a ``'product'`` entry.

    Returns:
        A new ``Dataset`` built from records shaped
        ``{'reactant': r, 'products': [p1, p2, ...]}``. Reactants, and the
        products of each reactant, keep first-seen order; a product that is
        repeated for the same reactant is kept once.
    """
    # A dict is used as an insertion-ordered set so that duplicate detection is
    # a hash lookup instead of a linear scan over the products collected so far.
    # NOTE(review): assumes product values are hashable (e.g. SMILES strings).
    products_by_reactant = {}
    for example in dataset:
        seen_products = products_by_reactant.setdefault(example['reactant'], {})
        seen_products[example['product']] = None

    grouped_records = [
        {'reactant': reactant, 'products': list(seen_products)}
        for reactant, seen_products in products_by_reactant.items()
        if seen_products  # Ensure products list is not empty
    ]
    return Dataset.from_list(grouped_records)

# --- Load Data ---
# Load the dataset saved on disk under ./data/data; it is indexed by split name below.
print("Loading raw data...")
df_raw = load_from_disk('./data/data')

# The 'train' split is used for training; the 'test' split is the source of validation data.
train_dataset_raw = df_raw['train']
validation_source_dataset_raw = df_raw['test']

print(f"Raw training samples: {len(train_dataset_raw)}")
print(f"Raw source samples for validation: {len(validation_source_dataset_raw)}")

# Merge rows that share a reactant so each validation item carries all of its reference products.
validation_dataset_grouped = group_products_for_validation(validation_source_dataset_raw)
print(f"Grouped validation samples (reactants): {len(validation_dataset_grouped)}")

# --- Configuration ---
# Length cap passed to the tokenizer (truncation) and to generate() elsewhere in this notebook.
max_seq_length = 256
# Checkpoint name of the seq2seq model loaded below.
MODEL_NAME = "google/flan-t5-base"

# The tokenizer is loaded from "smiles_tokenizer", not from MODEL_NAME.
# NOTE(review): presumably a custom SMILES vocabulary stored next to the notebook - confirm.
tokenizer = AutoTokenizer.from_pretrained("smiles_tokenizer", use_fast=True)

# Post-processing template: append </s> after every encoded sequence
# (and after each segment when a pair of sequences is encoded).
tokenizer.backend_tokenizer.post_processor = TemplateProcessing(
    single="$A </s>",
    pair="$A </s> $B </s>",
    special_tokens=[("</s>", tokenizer.eos_token_id)],
)

# Load the base model with 4-bit NF4 quantization and bfloat16 compute dtype.
quantization_config_bnb = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForSeq2SeqLM.from_pretrained(
    MODEL_NAME, quantization_config=quantization_config_bnb, device_map="auto"
)

# LoRA adapters on the modules named 'q' and 'v'; lora_alpha is fixed at twice the rank.
# NOTE(review): r=1536 is a very large rank for a "low-rank" adapter; the recorded output of
# this cell reports ~40.7% of all parameters as trainable - confirm this is intended.
r_lora = 1536
peft_config = LoraConfig(
    r=r_lora, lora_alpha=2 * r_lora, target_modules=['q', 'v'],
    lora_dropout=0.1, bias="none", task_type="SEQ_2_SEQ_LM"
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

# Make the model's special-token ids follow the custom tokenizer.
# The decoder start token is set to the tokenizer's pad token.
model.config.bos_token_id = tokenizer.bos_token_id
model.config.eos_token_id = tokenizer.eos_token_id
model.config.pad_token_id = tokenizer.pad_token_id
model.config.decoder_start_token_id = tokenizer.pad_token_id
# Resize the embedding matrix to the custom tokenizer's vocabulary size.
# NOTE(review): this runs after the PEFT wrapping and target_modules lists only 'q'/'v' -
# verify whether the resized embeddings are meant to stay frozen during training.
model.resize_token_embeddings(len(tokenizer))


# --- Preprocessing Functions ---
def preprocess_train_function(examples):
    """Tokenize a batch of training examples.

    Reactants (``examples["reactant"]``) become the encoder inputs and products
    (``examples["product"]``) are passed as ``text_target``. Both are truncated
    to the module-level ``max_seq_length``; no padding is requested here.

    Returns:
        Whatever the module-level ``tokenizer`` returns for this call.
    """
    source_texts = examples["reactant"]
    target_texts = examples["product"]
    # `text_target` replaces the older `as_target_tokenizer` approach.
    encoded = tokenizer(
        source_texts,
        text_target=target_texts,
        max_length=max_seq_length,
        truncation=True,
    )
    return encoded

def returnlength(example):
    """Return ``{'length': n}`` where ``n`` is the number of entries in ``example['labels']``."""
    label_count = len(example['labels'])
    return {'length': label_count}


def preprocess_eval_function(examples):
    """Tokenize a batch of grouped validation examples.

    Only the reactants are tokenized (truncated to the module-level
    ``max_seq_length``, without padding). The list of reference products of each
    reactant is carried along untouched under ``"all_target_texts"``.

    Returns:
        Dict with ``"input_ids"``, ``"attention_mask"`` and ``"all_target_texts"``.
    """
    encoded_reactants = tokenizer(
        examples["reactant"],
        max_length=max_seq_length,
        truncation=True,
        padding="do_not_pad",
    )
    eval_features = {
        field: encoded_reactants[field] for field in ("input_ids", "attention_mask")
    }
    eval_features["all_target_texts"] = examples["products"]
    return eval_features

print("Tokenizing training data...")
# Tokenize in batches and drop the raw text columns so only tokenizer outputs remain.
tokenized_train_dataset = train_dataset_raw.map(
    preprocess_train_function, batched=True, remove_columns=train_dataset_raw.column_names
)
# Add a per-example 'length' column (number of label tokens);
# the training loop reads batch["length"] when it is present.
tokenized_train_dataset = tokenized_train_dataset.map(
    returnlength)

print("Tokenizing validation data...")
# Validation rows keep the tokenized reactant plus the untokenized list of reference products.
tokenized_eval_dataset = validation_dataset_grouped.map(
    preprocess_eval_function, batched=True, remove_columns=validation_dataset_grouped.column_names
)

# --- Data Collators ---
# Training collator: pads each batch to its longest sequence. pad_to_multiple_of=8 is only
# requested when CUDA is available and the device's major compute capability is >= 8.
train_data_collator = DataCollatorForSeq2Seq(
    tokenizer, model=model, padding="longest", max_length=max_seq_length,
    pad_to_multiple_of=8 if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8 else None
)

# Modified eval collator: no longer creates initial_decoder_input_ids
def custom_eval_data_collator(features):
    """Collate validation features into one batch.

    Only the encoder side (reactant ``input_ids`` / ``attention_mask``) is padded,
    via the module-level ``tokenizer``. The per-item lists of reference product
    strings are attached unchanged under ``"all_target_texts"``; decoder inputs
    and labels are built later, inside the evaluation loop, one target at a time.

    Args:
        features: Sequence of dicts with ``"input_ids"``, ``"attention_mask"``
            and ``"all_target_texts"``.

    Returns:
        The object returned by ``tokenizer.pad`` with ``"all_target_texts"`` added.
    """
    encoder_ids, encoder_masks, target_text_groups = [], [], []
    for feature in features:
        encoder_ids.append(feature["input_ids"])
        encoder_masks.append(feature["attention_mask"])
        target_text_groups.append(feature["all_target_texts"])

    unpadded_inputs = {"input_ids": encoder_ids, "attention_mask": encoder_masks}
    collated = tokenizer.pad(
        unpadded_inputs,
        padding="longest",
        max_length=max_seq_length,
        return_tensors="pt",
    )
    collated["all_target_texts"] = target_text_groups
    return collated


# --- DataLoaders ---
train_batch_size = 32
eval_batch_size = 24  # Smaller: validation runs several forward passes per item

# Training loader: shuffled batches, padded by the seq2seq collator.
train_dataloader = DataLoader(
    dataset=tokenized_train_dataset,
    batch_size=train_batch_size,
    shuffle=True,
    collate_fn=train_data_collator,
    pin_memory=True,
)

# Validation loader: the evaluation loop still feeds the model one reactant at a time,
# so this batch size only controls how many reactants are collated together.
eval_dataloader = DataLoader(
    dataset=tokenized_eval_dataset,
    batch_size=eval_batch_size,
    collate_fn=custom_eval_data_collator,
    pin_memory=True,
)

Loading raw data...
Raw training samples: 4770
Raw source samples for validation: 508
Grouped validation samples (reactants): 429
trainable params: 169,869,312 || all params: 417,447,168 || trainable%: 40.6924
Tokenizing training data...
Tokenizing validation data...


Map:   0%|          | 0/429 [00:00<?, ? examples/s]

In [2]:

# %% Sanity check: print the first sample of the first training batch, then stop.
for batch in train_dataloader:
    # Encoder side: raw ids, attention mask and the decoded text (special tokens kept visible).
    print(f"Sample 0 Input IDs: {batch['input_ids'][0]}")
    print(f"attention mask for input{batch['attention_mask'][0]}")
    print(f"Decoded Input: {tokenizer.decode(batch['input_ids'][0], skip_special_tokens=False)}")
    # Decoder side: labels use -100 on ignored positions, which is swapped for the pad id before decoding.
    print(f"Sample 0 Labels: {batch['labels'][0]}")
    print(f"Decoded Labels: {tokenizer.convert_ids_to_tokens([l if l != -100 else tokenizer.pad_token_id for l in batch['labels'][0]], skip_special_tokens=False)}") # Handle -100 for decoding
    print(f"Sample 0 Decoder Input IDs: {batch['decoder_input_ids'][0]}")
    print(f"Decoded Decoder Input IDs: {tokenizer.convert_ids_to_tokens(batch['decoder_input_ids'][0], skip_special_tokens=False)}")
    # These two ids are expected to match: the model setup assigns the pad id as decoder start id.
    print(f"Model's decoder_start_token_id: {model.config.decoder_start_token_id}")
    print(f"Tokenizer's pad_token_id: {tokenizer.pad_token_id}")
    print(f"length:{batch['length']}")
    # Only the first batch is inspected.
    break


Sample 0 Input IDs: tensor([ 83, 152, 483,  11, 109,   3,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0])
attention mask for inputtensor([1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
Decoded Input: CCC(F)C(Cl)(Cl)Cl.Cl</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>
Sample 0 Labels: tensor([ 109,   11,  115,  149,  147,  483,    3, -100, -100, -100, -100, -100,
        -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
        -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
        -100, -

In [None]:
from bitsandbytes.optim import AdamW8bit
import os
import time
import math
import gc
import numpy as np
import torch
from transformers import get_scheduler
from rouge_score import rouge_scorer
from tqdm.auto import tqdm
import json

# --- Evaluation Metrics Calculation ---
def calculate_metrics_on_validation(model, tokenizer, eval_dataloader, device, rouge_calculator, progress_bar=None,
                                   max_rouge_samples=1000, calculate_rouge=False):
    model.eval()
    total_min_val_loss_sum = 0.0
    total_max_rouge_l_sum = 0.0
    val_loss_items_count = 0
    rouge_samples_processed = 0
    all_min_losses_for_perplexity = []

    with torch.autograd.set_detect_anomaly(False):
        for batch_data in eval_dataloader:
            batch_encoder_input_ids = batch_data["input_ids"].to(device, non_blocking=True)
            batch_encoder_attention_mask = batch_data["attention_mask"].to(device, non_blocking=True)
            batch_all_target_texts = batch_data["all_target_texts"]

            # Only generate predictions for RougeL if we're calculating it this epoch
            generated_ids_for_rouge = None
            if calculate_rouge and rouge_samples_processed < max_rouge_samples:
                num_to_gen_for_rouge = min(batch_encoder_input_ids.size(0), max_rouge_samples - rouge_samples_processed)
                if num_to_gen_for_rouge > 0:
                    with torch.no_grad(), torch.amp.autocast(dtype=torch.bfloat16, device_type="cuda", enabled=torch.cuda.is_available()):
                        generated_ids_for_rouge = model.generate(
                            input_ids=batch_encoder_input_ids[:num_to_gen_for_rouge],
                            attention_mask=batch_encoder_attention_mask[:num_to_gen_for_rouge],
                            max_length=max_seq_length, num_beams=3, early_stopping=True,
                            bos_token_id=model.config.bos_token_id,eos_token_id=model.config.eos_token_id, pad_token_id=model.config.pad_token_id,
                            decoder_start_token_id=model.config.decoder_start_token_id
                        )

            # Iterate through each reactant in the batch
            for i in range(batch_encoder_input_ids.size(0)):
                reactant_input_ids = batch_encoder_input_ids[i:i+1]
                reactant_attention_mask = batch_encoder_attention_mask[i:i+1]
                item_possible_target_texts = batch_all_target_texts[i]

                if not item_possible_target_texts or all(not t.strip() for t in item_possible_target_texts):
                    if calculate_rouge and rouge_samples_processed < max_rouge_samples and i < (num_to_gen_for_rouge if 'num_to_gen_for_rouge' in locals() else 0):
                        rouge_samples_processed += 1
                    continue

                # --- Calculate Min Loss (multiple forward passes per reactant) ---
                min_loss_for_item_numeric = float('inf')
                for target_text in item_possible_target_texts:
                    if not target_text.strip(): continue

                    tokenized_target = tokenizer(
                        target_text, max_length=max_seq_length,
                        padding="longest", truncation=True, return_tensors="pt"
                    )
                    target_labels_single = tokenized_target.input_ids.to(device)

                    decoder_input_ids_single = model.prepare_decoder_input_ids_from_labels(labels=target_labels_single.clone())

                    with torch.no_grad(), torch.amp.autocast(dtype=torch.bfloat16, device_type="cuda", enabled=torch.cuda.is_available()):
                        outputs_per_target = model(
                            input_ids=reactant_input_ids,
                            attention_mask=reactant_attention_mask,
                            decoder_input_ids=decoder_input_ids_single,
                            use_cache=False
                        )
                        logits_per_target = outputs_per_target.logits

                    output_seq_len = logits_per_target.size(1)
                    aligned_target_labels = torch.full((1, output_seq_len), tokenizer.pad_token_id, dtype=torch.long, device=device)
                    actual_target_len = min(output_seq_len, target_labels_single.size(1))
                    aligned_target_labels[0, :actual_target_len] = target_labels_single[0, :actual_target_len]
                    aligned_target_labels[aligned_target_labels == tokenizer.pad_token_id] = -100

                    loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100)
                    current_loss_value = loss_fct(
                        logits_per_target.view(-1, logits_per_target.size(-1)),
                        aligned_target_labels.view(-1)
                    )
                    min_loss_for_item_numeric = min(min_loss_for_item_numeric, current_loss_value.item())
                    del outputs_per_target, logits_per_target, tokenized_target, target_labels_single, decoder_input_ids_single, aligned_target_labels, current_loss_value

                if min_loss_for_item_numeric != float('inf'):
                    total_min_val_loss_sum += min_loss_for_item_numeric
                    all_min_losses_for_perplexity.append(min_loss_for_item_numeric)
                    val_loss_items_count += 1

                # --- Calculate Max ROUGE (only if requested) ---
                if calculate_rouge and rouge_samples_processed < max_rouge_samples and generated_ids_for_rouge is not None and i < len(generated_ids_for_rouge):
                    generated_text = tokenizer.decode(generated_ids_for_rouge[i], skip_special_tokens=True)
                    current_max_rouge_l_for_item = 0.0
                    for ref_text in item_possible_target_texts:
                        if not ref_text.strip(): continue
                        try:
                            rouge_scores = rouge_calculator.score(ref_text, generated_text)
                            current_max_rouge_l_for_item = max(current_max_rouge_l_for_item, rouge_scores['rougeL'].fmeasure)
                        except Exception: pass
                    total_max_rouge_l_sum += current_max_rouge_l_for_item
                    rouge_samples_processed += 1

            if 'generated_ids_for_rouge' in locals() and generated_ids_for_rouge is not None:
                del generated_ids_for_rouge
            del batch_encoder_input_ids, batch_encoder_attention_mask, batch_all_target_texts
            gc.collect(); torch.cuda.empty_cache()

    avg_min_val_loss = total_min_val_loss_sum / val_loss_items_count if val_loss_items_count > 0 else float('inf')

    # Calculate Rouge-L and perplexity only if requested or needed
    avg_max_rouge_l = 0.0
    if calculate_rouge:
        avg_max_rouge_l = total_max_rouge_l_sum / rouge_samples_processed if rouge_samples_processed > 0 else 0.0

    valid_losses_for_ppl = [l for l in all_min_losses_for_perplexity if l > 0 and l != float('inf')]
    perplexity = math.exp(sum(valid_losses_for_ppl) / len(valid_losses_for_ppl)) if valid_losses_for_ppl else float('inf')

    metrics = {"loss": avg_min_val_loss, "perplexity": perplexity}
    if calculate_rouge:
        metrics["rouge_l"] = avg_max_rouge_l

    # Display validation metrics based on what was calculated
    if progress_bar:
        status_msg = f"Validation - Loss: {metrics['loss']:.4f}, PPL: {metrics['perplexity']:.2f}"
        if calculate_rouge:
            status_msg += f", ROUGE-L: {metrics['rouge_l']:.4f} ({rouge_samples_processed} samples)"
        else:
            status_msg += f" (ROUGE not calculated this epoch)"
        progress_bar.write(status_msg)

    return metrics

# --- Scheduled Sampling Function ---
def get_scheduled_sampling_ratio(epoch, num_epochs, strategy="linear"):
    """
    Compute the scheduled-sampling ratio for an epoch.

    epoch: current epoch
    num_epochs: total number of epochs
    strategy: schedule type, either "linear" or "inverse_sigmoid"

    Returns: the probability of using the ground-truth labels (teacher forcing).
    Any other strategy name yields 1.0, i.e. teacher forcing all the time.
    """
    if strategy == "inverse_sigmoid":
        # Inverse-sigmoid decay: a smoother curve than the linear schedule.
        decay_rate = num_epochs * 0.2  # controls how fast the ratio decays
        return decay_rate / (decay_rate + math.exp(epoch / decay_rate))

    if strategy != "linear":
        # Unknown strategy: always use teacher forcing.
        return 1.0

    # Linear decay of the teacher-forcing probability from 1.0 down to 0.0,
    # reaching zero after 75% of the epochs.
    linear_ratio = 1.0 - (epoch - 1) / (num_epochs * 0.75)
    return max(0.0, linear_ratio)

# ROUGE-L scorer used by the validation routine.
rouge_calculator = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)

# --- Optimisation setup ---
model.config.use_cache = False
optimizer = AdamW8bit(model.parameters(), lr=5e-4, weight_decay=0.01)
num_epochs = 200
gradient_accumulation_steps = 4
# NOTE(review): the training loop autocasts to bfloat16; confirm loss scaling is still wanted.
scaler = torch.amp.GradScaler(enabled=torch.cuda.is_available())
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Configuration for metric calculations
rouge_calculation_frequency = 5  # Calculate RougeL every N epochs
metric_strategy_threshold = 2  # Use different metrics based on epoch % N

# --- Early stopping / best-model bookkeeping ---
early_stop_patience = 10
no_improve_epochs = 0
best_val_metrics = {"loss": float("inf"), "rouge_l": 0.0, "perplexity": float("inf"), "epoch": 0}
output_dir = "./best_model_multi_eval_v2_correct_loss"
os.makedirs(output_dir, exist_ok=True)

# Training loss: cross entropy with label smoothing; label value -100 is ignored.
label_smoothing_factor = 0.1
training_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100, label_smoothing=label_smoothing_factor)

# One optimizer update per `gradient_accumulation_steps` batches (the last partial group also
# triggers an update); the cosine schedule warms up over the first 10% of all updates.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
total_train_steps = num_epochs * num_update_steps_per_epoch
lr_scheduler = get_scheduler(
    name="cosine", optimizer=optimizer,
    num_warmup_steps=int(0.1 * total_train_steps), num_training_steps=total_train_steps,
)

# Initialize metrics tracking for visualization (one entry per epoch in every list)
metrics = {
    'epoch': [],
    'train_loss': [],
    'val_loss': [],
    'val_perplexity': [],
    'val_rouge_l': [],
    'teacher_forcing_ratio': [],
    'learning_rate': [],
    'epoch_time': [],
    'is_best': []
}

# For batch-level tracking (one entry per training batch in every list)
batch_metrics = {
    'global_step': [],
    'batch_loss': [],
    'learning_rate': []
}

print("\n--- Starting Training ---")
print(f"Device: {device}, Epochs: {num_epochs}, Grad Accum: {gradient_accumulation_steps}")
print(f"RougeL calculated every {rouge_calculation_frequency} epochs")
# The training loop requests strategy="inverse_sigmoid"; the message previously said "linear".
print("Using scheduled sampling with inverse_sigmoid decay strategy")
overall_progress_bar = tqdm(range(total_train_steps), desc="Training Progress")

# Main training loop. Per epoch: train with gradient accumulation (some batches use a
# free-running decoder input instead of teacher forcing), validate, update the best
# checkpoint and apply early stopping.
global_step = 0
for epoch in range(1, num_epochs + 1):
    epoch_start_time = time.time()
    model.train()
    total_train_loss_for_epoch = 0.0
    optimizer.zero_grad(set_to_none=True)

    # Probability that a batch of this epoch is trained with teacher forcing.
    teacher_forcing_ratio = get_scheduled_sampling_ratio(epoch, num_epochs, strategy="inverse_sigmoid")

    # ROUGE-L is only computed every `rouge_calculation_frequency` epochs.
    calculate_rouge_this_epoch = (epoch % rouge_calculation_frequency == 0)

    # Update the progress bar caption.
    overall_progress_bar.set_description(f"Epoch {epoch}/{num_epochs} (TF={teacher_forcing_ratio:.2f})")

    for step, batch in enumerate(train_dataloader, start=1):
        input_ids = batch["input_ids"].to(device, non_blocking=True)
        attention_mask = batch["attention_mask"].to(device, non_blocking=True)
        labels = batch["labels"].to(device, non_blocking=True)
        decoder_input_ids = batch["decoder_input_ids"].to(device, non_blocking=True)

        # Target lengths: taken from the batch when provided, otherwise the number of
        # label positions that are not ignored (-100).
        if "length" in batch:
            seq_lengths = batch["length"].to(device, non_blocking=True)
        else:
            seq_lengths = (labels != -100).sum(dim=1)

        # Decide per batch whether to replace teacher forcing by the model's own predictions.
        use_scheduled_sampling = teacher_forcing_ratio < 1.0 and np.random.random() >= teacher_forcing_ratio

        if use_scheduled_sampling:
            # ----- Scheduled sampling: greedy roll-out of the decoder input -----
            tgt_len = labels.size(1)

            # Start from the first decoder token only.
            curr_decoder_input_ids = decoder_input_ids[:, :1].clone()

            # The roll-out only produces integer token ids (argmax), so no gradient can flow
            # through it; running it under no_grad avoids storing activations for backward.
            # It is kept in its own autocast context so that weight casts cached while
            # gradients are disabled are not reused by the loss forward pass below.
            with torch.no_grad(), torch.amp.autocast(dtype=torch.bfloat16, device_type="cuda", enabled=torch.cuda.is_available()):
                for i in range(1, tgt_len):
                    rollout_outputs = model(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                        decoder_input_ids=curr_decoder_input_ids,
                        use_cache=False
                    )

                    # Greedy choice of the next token, appended to the running sequence.
                    next_token = torch.argmax(rollout_outputs.logits[:, -1, :], dim=-1).unsqueeze(-1)
                    curr_decoder_input_ids = torch.cat([curr_decoder_input_ids, next_token], dim=1)
                    del rollout_outputs

                    # Stop early once every sample has reached its target length.
                    if (i >= seq_lengths - 1).all().item():
                        break

            decoder_inputs_for_loss = curr_decoder_input_ids
        else:
            # Standard teacher forcing.
            decoder_inputs_for_loss = decoder_input_ids

        # Forward pass that carries the gradient (same code path for both modes).
        with torch.amp.autocast(dtype=torch.bfloat16, device_type="cuda", enabled=torch.cuda.is_available()):
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_inputs_for_loss,
                use_cache=False
            )
            logits = outputs.logits

            # Make labels and logits agree in length: a shorter roll-out only covers a prefix
            # of the labels; a longer one is padded with ignored positions (-100).
            logit_len = logits.size(1)
            if logit_len <= labels.size(1):
                loss_labels = labels[:, :logit_len]
            else:
                ignored_tail = torch.full(
                    (labels.size(0), logit_len - labels.size(1)),
                    -100, dtype=labels.dtype, device=labels.device
                )
                loss_labels = torch.cat([labels, ignored_tail], dim=1)
            loss = training_loss_fct(logits.reshape(-1, logits.size(-1)), loss_labels.reshape(-1))

        scaled_loss = loss / gradient_accumulation_steps
        scaler.scale(scaled_loss).backward()
        batch_loss_value = loss.item()
        total_train_loss_for_epoch += batch_loss_value

        # Track batch-level metrics
        batch_metrics['global_step'].append(global_step)
        batch_metrics['batch_loss'].append(batch_loss_value)
        batch_metrics['learning_rate'].append(lr_scheduler.get_last_lr()[0])

        # Optimizer update every `gradient_accumulation_steps` batches and on the last batch.
        if (step % gradient_accumulation_steps == 0) or (step == len(train_dataloader)):
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            lr_scheduler.step()
            optimizer.zero_grad(set_to_none=True)
            global_step += 1

            # Update the progress bar.
            avg_loss_so_far = total_train_loss_for_epoch / step
            overall_progress_bar.set_postfix({
                "loss": f"{avg_loss_so_far:.4f}",
                "lr": f"{lr_scheduler.get_last_lr()[0]:.3e}"
            })
            overall_progress_bar.update(1)

        # Free memory. Every name deleted here is bound on both code paths above.
        del input_ids, attention_mask, labels, decoder_input_ids, decoder_inputs_for_loss
        del outputs, logits, loss_labels, loss, scaled_loss
        if use_scheduled_sampling:
            del curr_decoder_input_ids

        torch.cuda.empty_cache()

    avg_train_loss_epoch = total_train_loss_for_epoch / len(train_dataloader)
    epoch_time = time.time() - epoch_start_time

    overall_progress_bar.write(
        f"üöÄEpoch {epoch} - Avg Train Loss: {avg_train_loss_epoch:.4f} - " +
        f"Time: {epoch_time:.2f}s - Teacher Forcing: {teacher_forcing_ratio:.2f}"
    )

    # Announce the validation phase.
    overall_progress_bar.write(
        f"Running validation for epoch {epoch}" +
        (f" (with ROUGE-L)" if calculate_rouge_this_epoch else " (loss only)")
    )

    # Compute validation metrics.
    validation_start_time = time.time()
    current_val_metrics = calculate_metrics_on_validation(
        model, tokenizer, eval_dataloader, device, rouge_calculator,
        progress_bar=overall_progress_bar,
        calculate_rouge=calculate_rouge_this_epoch
    )
    validation_time = time.time() - validation_start_time

    # Remember whether ROUGE-L was really measured BEFORE back-filling it. The back-filled
    # value only keeps logs/metrics complete and must not drive model selection.
    rouge_measured_this_epoch = "rouge_l" in current_val_metrics
    if not rouge_measured_this_epoch:
        # Reuse the best model's ROUGE-L when there is one, otherwise 0.0.
        current_val_metrics["rouge_l"] = best_val_metrics.get("rouge_l", 0.0) if epoch > 1 else 0.0

    # Decide whether this is the best model so far.
    is_best = False
    reason_for_best = ""

    if epoch % metric_strategy_threshold == 0 and rouge_measured_this_epoch:
        # Even epoch with a measured ROUGE-L: combined criterion.
        rouge_improved = current_val_metrics["rouge_l"] > best_val_metrics["rouge_l"] + 0.0005
        loss_improved_significantly = current_val_metrics["loss"] < best_val_metrics["loss"] * 0.98  # 2% improvement

        if rouge_improved and current_val_metrics["loss"] <= best_val_metrics["loss"] * 1.05:  # allow a 5% worse loss
            is_best = True
            reason_for_best = f"‚ÜëROUGE-L ({best_val_metrics['rouge_l']:.4f} ‚Üí {current_val_metrics['rouge_l']:.4f})"
        elif loss_improved_significantly and current_val_metrics["rouge_l"] >= best_val_metrics["rouge_l"] * 0.98:  # allow a 2% worse ROUGE
            is_best = True
            reason_for_best = f"‚ÜëLoss ({best_val_metrics['loss']:.4f} ‚Üí {current_val_metrics['loss']:.4f})"
    else:
        # Odd epoch or no measured ROUGE-L: validation loss decides.
        if current_val_metrics["loss"] < best_val_metrics["loss"]:
            is_best = True
            reason_for_best = f"‚ÜëLoss ({best_val_metrics['loss']:.4f} ‚Üí {current_val_metrics['loss']:.4f})"

    # Store metrics for visualization
    metrics['epoch'].append(epoch)
    metrics['train_loss'].append(avg_train_loss_epoch)
    metrics['val_loss'].append(current_val_metrics["loss"])
    metrics['val_perplexity'].append(current_val_metrics["perplexity"])
    metrics['val_rouge_l'].append(current_val_metrics["rouge_l"])
    metrics['teacher_forcing_ratio'].append(teacher_forcing_ratio)
    metrics['learning_rate'].append(lr_scheduler.get_last_lr()[0])
    metrics['epoch_time'].append(epoch_time)
    metrics['is_best'].append(is_best)

    if is_best:
        best_val_metrics = current_val_metrics.copy()
        best_val_metrics["epoch"] = epoch
        no_improve_epochs = 0

        # Save the model and tokenizer.
        model.save_pretrained(output_dir, save_embedding_layers=True)
        tokenizer.save_pretrained(output_dir)

        # Save training metrics as JSON at checkpoint
        metrics_path = os.path.join(output_dir, "training_metrics.json")
        with open(metrics_path, "w") as f:
            json.dump(metrics, f, indent=2)

        with open(os.path.join(output_dir, "best_metrics.txt"), "w") as f:
            f.write(f"Best Model from Epoch: {epoch}\nMetrics: {best_val_metrics}\n")

        overall_progress_bar.write(
            f"‚úÖ Ep {epoch}: New best! {reason_for_best}. Loss: {best_val_metrics['loss']:.4f}, "
            f"PPL: {best_val_metrics['perplexity']:.2f}" +
            (f", ROUGE: {best_val_metrics['rouge_l']:.4f}" if "rouge_l" in best_val_metrics else "") +
            f" - Val time: {validation_time:.2f}s"
        )
    else:
        no_improve_epochs += 1
        overall_progress_bar.write(
            f"‚ùó Ep {epoch}: No improvement ({no_improve_epochs}). "
            f"Best (Ep {best_val_metrics['epoch']}): Loss {best_val_metrics['loss']:.4f}, "
            f"PPL {best_val_metrics['perplexity']:.2f}" +
            (f", ROUGE {best_val_metrics['rouge_l']:.4f}" if "rouge_l" in best_val_metrics else "") +
            f" - Val time: {validation_time:.2f}s"
        )

        if no_improve_epochs >= early_stop_patience:
            overall_progress_bar.write(f"‚õî Early stopping at Epoch {epoch}.")
            break

    # # Generate intermediate plots every 10 epochs
    # if epoch % 50 == 0 or epoch == 1:
    #     try:
    #         create_sci_training_plots(metrics, output_dir, prefix=f"training_epoch_{epoch}", journal_format="nature")
    #         overall_progress_bar.write(f"üìä Generated training plots at epoch {epoch}")
    #     except Exception as e:
    #         overall_progress_bar.write(f"Warning: Could not generate plots: {e}")

    gc.collect()
    torch.cuda.empty_cache()

# Training is over: close the progress bar and announce completion.
overall_progress_bar.close()
print("\n===== Training Complete =====")

# Generate final training visualization plots
# try:
#     create_sci_training_plots(metrics, output_dir, prefix="final_training", journal_format="nature")
#     print("üìä Final training plots generated successfully")
# except Exception as e:
#     print(f"Warning: Could not generate final plots: {e}")

# Save full training metrics (the per-epoch history of the whole run)
metrics_path = os.path.join(output_dir, "final_training_metrics.json")
with open(metrics_path, "w") as f:
    f.write(json.dumps(metrics, indent=2))

# An epoch number of 0 means no checkpoint was ever selected as best.
print(
    f"Best model (Ep {best_val_metrics['epoch']}) saved to '{output_dir}'. Metrics: {best_val_metrics}"
    if best_val_metrics["epoch"] > 0
    else f"No best model saved. Check '{output_dir}'."
)


In [None]:
def _hide_spines(ax, sides=('top', 'right')):
    """Hide the given spines of ``ax`` (top and right by default)."""
    for side in sides:
        ax.spines[side].set_visible(False)


def _draw_grid(ax, axis='both', minor=True):
    """Draw the dashed major grid on ``ax`` and, if ``minor`` is true, minor ticks plus a dotted minor grid."""
    ax.grid(True, linestyle='--', alpha=0.5, color='#b0b0b0', linewidth=1.0, axis=axis)
    if minor:
        ax.minorticks_on()
        ax.grid(True, which='minor', linestyle=':', alpha=0.3, color='#b0b0b0', linewidth=0.8)


def create_sci_training_plots(metrics_data, output_dir, prefix="training", journal_format="nature"):
    """
    Render a 3x2 summary figure of the training run and save it to disk.

    Panels, in reading order: training vs. validation loss (best epochs
    highlighted), validation perplexity with ROUGE-L on a twin y-axis,
    teacher-forcing ratio, learning rate (log scale), time per epoch, and the
    number of epochs elapsed between successive best models.

    The figure is written to ``<output_dir>/plots/<prefix>_combined.<ext>``,
    once for every entry of ``FORMATS``. All matplotlib / seaborn settings are
    applied inside ``matplotlib.rc_context`` so the caller's global plotting
    configuration is restored afterwards, and the figure is always closed,
    even if building or saving it raises.

    Args:
        metrics_data: Mapping providing per-epoch sequences under the keys
            'epoch', 'train_loss', 'val_loss', 'val_perplexity',
            'val_rouge_l', 'teacher_forcing_ratio', 'learning_rate',
            'epoch_time' and 'is_best'. The sequences are expected to have
            the same length as 'epoch'.
        output_dir: Directory under which the ``plots`` sub-directory is
            created (if needed) and the images are saved.
        prefix: Prefix for saved files.
        journal_format: Journal style preset ('nature', 'science', 'ieee',
            'elsevier' or 'default'). Unknown names fall back to 'default'.

    Returns:
        None.

    Raises:
        KeyError: If ``metrics_data`` lacks one of the keys listed above.
    """
    import os
    import numpy as np
    import matplotlib.pyplot as plt
    import seaborn as sns
    from matplotlib.gridspec import GridSpec
    from cycler import cycler
    import matplotlib as mpl

    # Create directory for plots
    plots_dir = os.path.join(output_dir, "plots")
    os.makedirs(plots_dir, exist_ok=True)

    # Fixed DPI and output formats. The entries are used verbatim as file
    # extensions and are kept exactly as before so output file names do not change.
    DPI = 300
    FORMATS = ["jpg", "PNG"]

    # Journal-specific formatting presets. "markersize" is a scatter area (s=...);
    # it is divided by 20 below to obtain the line-marker size.
    journal_formats = {
        "default": {
            "figsize": (12, 15),
            "fontfamily": "serif",
            "fontname": "Times New Roman",
            "fontsize": 13,
            "labelsize": 15,
            "titlesize": 17,
            "linewidth": 2.25,
            "markersize": 120,
            "colors": ["#0072B2", "#D55E00", "#009E73", "#CC79A7", "#F0E442", "#56B4E9"],
        },
        "nature": {
            "figsize": (12, 15),
            "fontfamily": "sans-serif",
            "fontname": "Times New Roman",
            "fontsize": 13,
            "labelsize": 15,
            "titlesize": 17,
            "linewidth": 2.0,
            "markersize": 110,
            "colors": ["#0072B2", "#D55E00", "#009E73", "#CC79A7", "#E69F00", "#56B4E9"],
        },
        "science": {
            "figsize": (12, 15),
            "fontfamily": "sans-serif",
            "fontname": "Arial",
            "fontsize": 12,
            "labelsize": 14,
            "titlesize": 16,
            "linewidth": 2.0,
            "markersize": 110,
            "colors": ["#3C5488", "#DC0000", "#00A087", "#E64B35", "#8491B4", "#F39B7F"],
        },
        "ieee": {
            "figsize": (12, 15),
            "fontfamily": "serif",
            "fontname": "Times New Roman",
            "fontsize": 11,
            "labelsize": 13,
            "titlesize": 15,
            "linewidth": 2.0,
            "markersize": 110,
            "colors": ["#0072B2", "#D55E00", "#009E73", "#CC79A7", "#F0E442", "#56B4E9"],
        },
        "elsevier": {
            "figsize": (12, 15),
            "fontfamily": "serif",
            "fontname": "Times New Roman",
            "fontsize": 13,
            "labelsize": 15,
            "titlesize": 17,
            "linewidth": 2.0,
            "markersize": 110,
            "colors": ["#4477AA", "#EE6677", "#228833", "#CCBB44", "#66CCEE", "#AA3377"],
        }
    }

    # Get the specified journal format or default
    fmt = journal_formats.get(journal_format, journal_formats["default"])
    colors = fmt["colors"]
    line_width = fmt["linewidth"]

    # Marker shapes and line styles that remain distinguishable in grayscale
    markers = ['o', 's', '^', 'd', 'v', '<', '>']
    linestyles = ['-', '--', ':', '-.', (0, (3, 1, 1, 1)), (0, (3, 1, 1, 1, 1, 1))]

    # Extract metrics (a missing key raises KeyError before any plotting state is touched)
    epochs = metrics_data['epoch']
    train_loss = metrics_data['train_loss']
    val_loss = metrics_data['val_loss']
    val_perplexity = metrics_data['val_perplexity']
    val_rouge_l = metrics_data['val_rouge_l']
    teacher_forcing = metrics_data['teacher_forcing_ratio']
    learning_rates = metrics_data['learning_rate']
    epoch_times = metrics_data['epoch_time']
    is_best = metrics_data['is_best']

    # Epochs flagged as "best" and the validation loss reached at each of them
    best_epochs = [e for i, e in enumerate(epochs) if is_best[i]]
    best_val_losses = [val_loss[i] for i, _ in enumerate(epochs) if is_best[i]]

    # Draw roughly ten markers per curve regardless of the number of epochs
    mark_every = max(1, len(epochs) // 10)

    rc_overrides = {
        # Font settings
        'font.family': fmt["fontfamily"],
        f'font.{fmt["fontfamily"]}': [fmt["fontname"]],
        'font.size': fmt["fontsize"],
        'axes.labelsize': fmt["labelsize"],
        'axes.titlesize': fmt["titlesize"],
        'xtick.labelsize': fmt["fontsize"],
        'ytick.labelsize': fmt["fontsize"],
        'legend.fontsize': fmt["fontsize"],

        # Figure settings
        'figure.figsize': fmt["figsize"],
        'figure.dpi': DPI,
        'figure.facecolor': 'white',
        'figure.edgecolor': 'white',

        # Axes settings
        'axes.facecolor': 'white',
        'axes.edgecolor': '#333333',
        'axes.prop_cycle': cycler('color', colors),
        'axes.linewidth': 1.0,
        'axes.grid': True,
        'axes.axisbelow': True,  # Place grid behind data

        # Grid settings
        'grid.alpha': 0.5,
        'grid.color': '#b0b0b0',
        'grid.linestyle': '--',
        'grid.linewidth': 1.0,

        # Line settings
        'lines.linewidth': line_width,
        'lines.markersize': fmt["markersize"]/20,
        'lines.markeredgewidth': 1.0,

        # Tick settings
        'xtick.major.size': 4.5,
        'ytick.major.size': 4.5,
        'xtick.minor.size': 2.5,
        'ytick.minor.size': 2.5,
        'xtick.major.pad': 4.5,
        'ytick.major.pad': 4.5,
        'xtick.color': '#333333',
        'ytick.color': '#333333',
        'xtick.direction': 'out',
        'ytick.direction': 'out',
        'xtick.major.width': 1.0,
        'ytick.major.width': 1.0,
        'xtick.minor.width': 0.8,
        'ytick.minor.width': 0.8,

        # Legend settings
        'legend.frameon': True,
        'legend.framealpha': 1.0,  # No transparency
        'legend.edgecolor': '#cccccc',
        'legend.fancybox': True,

        # Save settings
        'savefig.bbox': 'tight',
        'savefig.pad_inches': 0.1,
        'savefig.dpi': DPI,
        'savefig.transparent': False,  # Consistent white background
    }

    # rc_context restores the caller's rcParams on exit, so the style reset,
    # the overrides and the seaborn palette below do not leak into later plots.
    with mpl.rc_context():
        plt.style.use('default')  # Reset to default style first
        mpl.rcParams.update(rc_overrides)
        sns.set_palette(colors)

        fig = plt.figure(figsize=fmt["figsize"], dpi=DPI, facecolor='white')
        try:
            # Explicit GridSpec spacing instead of tight_layout (which emitted warnings)
            gs = GridSpec(3, 2, figure=fig)
            gs.update(wspace=0.35, hspace=0.45)

            # 1.1 Loss curves (train vs validation)
            ax1 = fig.add_subplot(gs[0, 0])
            ax1.plot(epochs, train_loss, color=colors[0],
                     linestyle=linestyles[0], marker=markers[0], markevery=mark_every,
                     label='Training Loss', linewidth=line_width)
            ax1.plot(epochs, val_loss, color=colors[1],
                     linestyle=linestyles[1], marker=markers[1], markevery=mark_every,
                     label='Test Loss', linewidth=line_width)

            # Highlight best epochs
            if best_epochs:
                ax1.scatter(best_epochs, best_val_losses, c='gold', s=fmt["markersize"],
                            edgecolors='black', zorder=5, label='Best Models')

            ax1.set_xlabel('Epoch')
            ax1.set_ylabel('Loss')
            ax1.set_title('Training and Validation Loss')
            ax1.legend(loc='upper right', framealpha=1.0)  # No transparency
            _draw_grid(ax1)
            _hide_spines(ax1)

            # 1.2 Perplexity (left axis) and ROUGE-L (right, twin axis)
            ax2 = fig.add_subplot(gs[0, 1])
            color = colors[2]
            ax2.set_xlabel('Epoch')
            ax2.set_ylabel('Perplexity', color=color)
            ax2.plot(epochs, val_perplexity, color=color,
                     linestyle=linestyles[0], marker=markers[2], markevery=mark_every,
                     linewidth=line_width)
            ax2.tick_params(axis='y', labelcolor=color)
            ax2.set_ylim(bottom=0)
            _hide_spines(ax2)

            ax2_twin = ax2.twinx()
            color = colors[3]
            ax2_twin.set_ylabel('ROUGE-L Score', color=color)
            ax2_twin.plot(epochs, val_rouge_l, color=color,
                          linestyle=linestyles[1], marker=markers[3], markevery=mark_every,
                          linewidth=line_width)
            ax2_twin.tick_params(axis='y', labelcolor=color)
            _hide_spines(ax2_twin, sides=('top', 'left'))

            # Avoid a degenerate (bottom == top) y-range when every ROUGE-L value is zero
            rouge_peak = max(val_rouge_l) if val_rouge_l else 0
            top_value = max(rouge_peak*1.1, 0.1) if rouge_peak > 0 else 1.0
            ax2_twin.set_ylim(bottom=0, top=top_value)

            ax2.set_title('Validation Metrics: Perplexity and ROUGE-L')
            _draw_grid(ax2)

            # 1.3 Teacher Forcing Ratio
            ax3 = fig.add_subplot(gs[1, 0])
            ax3.plot(epochs, teacher_forcing, color=colors[4],
                     linestyle=linestyles[0], marker=markers[4], markevery=mark_every,
                     linewidth=line_width)
            ax3.set_xlabel('Epoch')
            ax3.set_ylabel('Teacher Forcing Ratio')
            ax3.set_title('Teacher Forcing Schedule')
            ax3.set_ylim(0, 1.05)
            _draw_grid(ax3)
            _hide_spines(ax3)

            # 1.4 Learning Rate
            ax4 = fig.add_subplot(gs[1, 1])
            ax4.plot(epochs, learning_rates, color=colors[5],
                     linestyle=linestyles[0], marker=markers[5], markevery=mark_every,
                     linewidth=line_width)
            ax4.set_xlabel('Epoch')
            ax4.set_ylabel('Learning Rate')
            ax4.set_title('Learning Rate Schedule')
            ax4.set_yscale('log')
            _draw_grid(ax4)
            _hide_spines(ax4)

            # 1.5 Training time per epoch
            ax5 = fig.add_subplot(gs[2, 0])
            time_bars = ax5.bar(epochs, epoch_times, color=colors[0], alpha=0.7,
                                edgecolor='black', linewidth=0.8, width=0.7)

            # Add value labels above bars only if the plot is not too crowded
            if len(epochs) <= 15:
                for bar in time_bars:
                    height = bar.get_height()
                    ax5.text(bar.get_x() + bar.get_width()/2., height + 0.1,
                             f'{int(height)}', ha='center', va='bottom',
                             fontsize=fmt["fontsize"]-1, alpha=1.0)  # No transparency

            ax5.set_xlabel('Epoch')
            ax5.set_ylabel('Time (seconds)')
            ax5.set_title('Training Time per Epoch')
            _draw_grid(ax5, axis='y', minor=False)
            _hide_spines(ax5)

            # 1.6 Best models distribution
            ax6 = fig.add_subplot(gs[2, 1])
            if best_epochs:
                # Gap, in epochs, between consecutive best models (the first gap is measured from 0)
                diff_epochs = np.diff([0] + best_epochs)
                ax6.bar(range(len(diff_epochs)), diff_epochs, color=colors[1], alpha=0.7,
                        edgecolor='black', linewidth=0.8, width=0.7)

                ax6.set_xlabel('Best Model Index')
                ax6.set_ylabel('Epochs Between Best Models')
                ax6.set_title('Training Progress Pace')
                _draw_grid(ax6, axis='y', minor=False)
            else:
                ax6.text(0.5, 0.5, "No best models recorded",
                         horizontalalignment='center', verticalalignment='center',
                         fontsize=fmt["fontsize"]+2)
                ax6.set_title('Training Progress Pace')
            _hide_spines(ax6)

            # Overall title; margins are set manually instead of tight_layout (which emitted warnings)
            fig.suptitle('Training Process Metrics', fontsize=fmt["titlesize"]+4, y=0.98)
            fig.subplots_adjust(top=0.92, bottom=0.08, left=0.10, right=0.92)

            # Save through the figure object itself rather than pyplot's "current figure"
            for format_type in FORMATS:
                fig.savefig(os.path.join(plots_dir, f"{prefix}_combined.{format_type}"),
                            dpi=DPI, bbox_inches='tight',
                            facecolor='white', edgecolor='none')
        finally:
            # Always release the figure, even when plotting or saving failed
            plt.close(fig)


In [None]:
import json

# The metrics file read below was previously written with:
# with open('./best_model_multi_eval_v3_correct_loss/training_metrics', 'w') as f:
#     json.dump(metrics, f)
output_dir = "./best_model_multi_eval_v3_correct_loss"

# Read the saved metrics dictionary back from disk.
with open('./best_model_multi_eval_v3_correct_loss/training_metrics', 'r') as metrics_file:
    loaded_dict = json.loads(metrics_file.read())

# Render the combined training figure from the reloaded metrics.
create_sci_training_plots(
    loaded_dict,
    output_dir,
    prefix="final_training",
    journal_format="nature",
)