In [None]:
!pip install evaluate transformers datasets accelerate bitsandbytes peft rouge_score nltk bert_score sentencepiece --quiet

import os
import gc
import json
import math
import warnings
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
import evaluate # Using evaluate library for metrics
from sklearn.metrics import accuracy_score # Import accuracy score


# Transformers imports
from transformers import (
    RobertaTokenizer,
    RobertaForSequenceClassification,
    BartTokenizer,
    BartForConditionalGeneration, # Changed to BART
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    Seq2SeqTrainer, # Optional, we are using a custom loop
    Seq2SeqTrainingArguments, # Optional
    DataCollatorForSeq2Seq,
    get_linear_schedule_with_warmup,
    BertTokenizer,
    BertModel,
    BitsAndBytesConfig # For potential 4/8-bit loading
)

# PEFT imports
from peft import (
    get_peft_config,
    get_peft_model,
    get_peft_model_state_dict,
    LoraConfig, # Changed to LoRA
    TaskType,
    PeftModel,
    PeftConfig
)

# Evaluation imports
import nltk
# from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction # Using evaluate's BLEU
from rouge_score import rouge_scorer # Using rouge_score library directly for Ea
from bert_score import score as bert_score_metric # Renamed to avoid conflict

# --- Basic Setup ---
# Silence library warnings and fetch the NLTK 'punkt' data without console output.
warnings.filterwarnings("ignore")
nltk.download('punkt', quiet=True)

# Uncomment while debugging to get clearer CUDA error locations:
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# Select the compute device once; later code in this file reads the `device` global.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if device.type != 'cuda':
    print("CUDA not available, running on CPU.")
else:
    print(f"CUDA Device Name: {torch.cuda.get_device_name(0)}")
    print(f"CUDA Version: {torch.version.cuda}")
    print(f"PyTorch Version: {torch.__version__}")

# Input data location. On Kaggle the dataset must be attached to the kernel;
# when that path is absent, fall back to a local directory.
BASE_PATH = '/kaggle/input/nlp-data'
if not os.path.exists(BASE_PATH):
    print(f"Warning: Kaggle path {BASE_PATH} not found. Using fallback './nlp-dataset'.")
    BASE_PATH = './nlp-dataset'  # Adjust to your local dataset path if needed
    # For a local dry run you may need to create train.json, valid.json and test.json
    # in that directory, each holding a list of records shaped like:
    #   {"question": "q", "answers": ["a"], "labelled_summaries": {"INFORMATION_SUMMARY": "s"}}

# Output layout: everything is written below OUTPUT_PATH.
OUTPUT_PATH = '/kaggle/working/'
GENERATED_DIR = os.path.join(OUTPUT_PATH, "generated")
CHECKPOINTS_DIR = os.path.join(OUTPUT_PATH, "checkpoints")
CLASSIFIER_CKPT_DIR = os.path.join(CHECKPOINTS_DIR, "classifier")
SUMMARIZER_CKPT_DIR = os.path.join(CHECKPOINTS_DIR, "summarizer")
PLOTS_DIR = os.path.join(OUTPUT_PATH, "plots")

# Create the base directory first, then each sub-directory (no error if they already exist).
for _output_dir in (
    OUTPUT_PATH,
    GENERATED_DIR,
    CHECKPOINTS_DIR,
    CLASSIFIER_CKPT_DIR,
    SUMMARIZER_CKPT_DIR,
    PLOTS_DIR,
):
    os.makedirs(_output_dir, exist_ok=True)
del _output_dir

print(f"Output base directory: {OUTPUT_PATH}")
print(f"Generated files directory: {GENERATED_DIR}")
print(f"Checkpoints directory: {CHECKPOINTS_DIR}")
print(f"Plots directory: {PLOTS_DIR}")


# Helper function for memory cleanup
def cleanup_memory():
    """Run the garbage collector and, when CUDA is available, empty the CUDA cache.

    Prints a confirmation line when done. Returns nothing.
    """
    gc.collect()
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.empty_cache()
    print("Memory cleaned up.")

# %% [code]
# --- ClassifierCustomDataset Implementation ---
# Perspective name -> integer class id used as the classification label.
PERSPECTIVE_LABELS = dict(
    INFORMATION=0,
    SUGGESTION=1,
    EXPERIENCE=2,
    CAUSE=3,
    QUESTION=4,
)
# Inverse lookup: integer class id -> perspective name.
LABEL_TO_PERSPECTIVE = dict(zip(PERSPECTIVE_LABELS.values(), PERSPECTIVE_LABELS.keys()))

class ClassifierCustomDataset(Dataset):
    """Dataset yielding tokenized question/answer text with a perspective class label.

    Each record is read as a dict with optional keys ``question`` (str),
    ``answers`` (list of str) and ``labelled_summaries`` (dict). The label comes
    from the FIRST key of ``labelled_summaries`` with ``_SUMMARY`` removed; a
    missing/empty mapping or an unknown name maps to the INFORMATION label.
    """

    def __init__(self, data, tokenizer, max_length=512):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Names accepted as labels, taken from the module-level label mapping.
        self.valid_perspective_names = set(PERSPECTIVE_LABELS.keys())

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``input_ids``, ``attention_mask`` and ``label`` tensors for record ``idx``."""
        record = self.data[idx]

        # Join every non-blank answer into one context string; the literal two-character
        # sequence backslash + 'n' inside an answer is replaced by a space.
        kept_answers = []
        for answer in record.get('answers', []):
            if answer.strip():
                kept_answers.append(answer.replace('\\n', ' ').strip())
        source_context = ' '.join(kept_answers)
        question = record.get('question', '').strip()
        text_input = f"Question: {question} Context: {source_context}"

        # The first labelled summary's key decides the class; INFORMATION is the default.
        summaries = record.get('labelled_summaries')
        target_perspective_raw = list(summaries.keys())[0] if summaries else "INFORMATION"
        perspective_name = target_perspective_raw.replace('_SUMMARY', '')

        if perspective_name not in self.valid_perspective_names:
            print(f"Warning: Found unexpected perspective '{target_perspective_raw}' in classifier data, defaulting to INFORMATION label.")
            label = PERSPECTIVE_LABELS["INFORMATION"]
        else:
            label = PERSPECTIVE_LABELS[perspective_name]

        # Pad/truncate to a fixed length so default batching can stack the tensors.
        encoded = self.tokenizer(
            text_input,
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt"
        )

        sample = {}
        sample["input_ids"] = encoded["input_ids"].squeeze(0)
        sample["attention_mask"] = encoded["attention_mask"].squeeze(0)
        sample["label"] = torch.tensor(label, dtype=torch.long)
        return sample

# %% [code]
# --- DataLoader Implementations ---

# Seq2Seq DataLoader (Remains mostly the same, uses perspective definitions from paper)
class SummarizationCustomDataset(Dataset):
    """Dataset that turns QA records into perspective-conditioned seq2seq examples.

    Each record is read as a dict with optional keys ``question`` (str),
    ``answers`` (list of str) and ``labelled_summaries`` (dict of summary name
    to summary text). Only the FIRST entry of ``labelled_summaries`` is used as
    the target; its key, with ``_SUMMARY`` removed, selects the perspective.
    """

    def __init__(self, data, tokenizer, max_source_length=1024, max_target_length=256):
        self.data = data
        self.tokenizer = tokenizer
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length  # Max length for summaries

        # Perspective attributes (per the original note: Table 3 of the paper), one row per
        # perspective: (name, definition, anchor phrase the summary starts with, tone).
        prompt_rows = (
            (
                "SUGGESTION",
                "Defined as advice or recommendations to assist users in making informed medical decisions, solving problems, or improving health issues.",
                "It is suggested",
                "Advisory, Recommending",
            ),
            (
                "INFORMATION",
                "Defined as knowledge about diseases, disorders, and health-related facts, providing insights into symptoms and diagnosis.",
                "For information purposes",
                "Informative, Educational",
            ),
            (
                "EXPERIENCE",
                "Defined as individual experiences, anecdotes, or firsthand insights related to health, medical treatments, medication usage, and coping strategies",
                "In user's experience",
                "Personal, Narrative",
            ),
            (
                "QUESTION",
                "Defined as inquiry made for deeper understanding.",
                "It is inquired",
                "Seeking Understanding",
            ),
            (
                "CAUSE",
                "Defined as reasons responsible for the occurrence of a particular medical condition, symptom, or disease",
                "Some of the causes",
                "Explanatory, Causal",
            ),
        )
        self.perspective_prompts = {
            name: {"defn": definition, "start_with": anchor, "tone_attribute": tone}
            for name, definition, anchor, tone in prompt_rows
        }
        self.valid_perspective_names = set(self.perspective_prompts.keys())

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Build one example: tokenized prompt, masked label ids, perspective and target text."""
        record = self.data[idx]

        # Source context = all non-blank answers joined by spaces; the literal two-character
        # sequence backslash + 'n' inside an answer is replaced by a space.
        usable_answers = (ans for ans in record.get('answers', []) if ans.strip())
        source_context = ' '.join(ans.replace('\\n', ' ').strip() for ans in usable_answers)
        question = record.get('question', '').strip()

        # Defaults apply when the record carries no labelled summaries.
        target_perspective = "INFORMATION"
        target_summary = ""
        summaries = record.get('labelled_summaries')
        if summaries:
            first_perspective_raw = list(summaries.keys())[0]
            target_summary = summaries[first_perspective_raw]
            candidate = first_perspective_raw.replace('_SUMMARY', '')
            if candidate in self.valid_perspective_names:
                target_perspective = candidate
            else:
                print(f"Warning: Found unexpected perspective '{first_perspective_raw}' in summarizer data, defaulting to INFORMATION.")
                target_perspective = "INFORMATION"

        prompt_details = self.perspective_prompts.get(target_perspective, self.perspective_prompts["INFORMATION"])
        defn = prompt_details['defn']
        start_with = prompt_details['start_with']
        tone_attribute = prompt_details['tone_attribute']

        # Prefix the anchor phrase unless the (non-empty) summary already starts with it,
        # compared case-insensitively.
        already_anchored = target_summary.lower().startswith(start_with.lower()) if target_summary else True
        if not already_anchored:
            target_summary = start_with + " " + target_summary

        # Prompt layout: perspective, definition, anchor, tone, content, then the question.
        task_prefix = (
            f"Summarize the following content according to Perspective: {target_perspective}; "
            f"{target_perspective} Definition: {defn}; "
            f"Begin Summary with: '{start_with}'; "
            f"Tone of summary: {tone_attribute}; "
            f"Content to summarize: {source_context}; "
            f"Associated question: {question}"
        )

        # Both sides are padded/truncated to fixed lengths.
        encoded_source = self.tokenizer(
            task_prefix,
            padding="max_length",
            max_length=self.max_source_length,
            truncation=True,
            return_tensors="pt"
        )
        encoded_target = self.tokenizer(
            target_summary,
            padding="max_length",
            max_length=self.max_target_length,
            truncation=True,
            return_tensors="pt"
        )

        # Padding positions in the labels are set to -100 so they are ignored in the loss.
        label_ids = encoded_target["input_ids"].squeeze(0)
        label_ids[label_ids == self.tokenizer.pad_token_id] = -100

        return {
            "input_ids": encoded_source["input_ids"].squeeze(0),
            "attention_mask": encoded_source["attention_mask"].squeeze(0),
            "labels": label_ids,
            "perspective": target_perspective,  # cleaned perspective name
            "target_summary_text": target_summary  # text form, for validation/logging
        }

# DataLoader Creation Functions
def create_dataloader(dataset, batch_size, shuffle=True, num_workers=2):
    """Build a ``DataLoader`` with a worker count capped by the host's CPU count.

    Args:
        dataset: Dataset to load from.
        batch_size: Number of samples per batch.
        shuffle: Whether to reshuffle the data every epoch.
        num_workers: Requested number of worker processes. The value used is
            capped at half the CPU count (or 1 when the count is unknown) and
            is never lower than 1.

    Returns:
        The configured ``DataLoader``.
    """
    cpu_total = os.cpu_count()
    worker_cap = cpu_total // 2 if cpu_total else 1
    worker_count = max(1, min(num_workers, worker_cap))
    # print(f"Using {worker_count} workers for DataLoader.")
    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=worker_count,
        # Pinned host memory only helps host-to-GPU copies; requesting it on a
        # CPU-only machine is useless and makes PyTorch emit a warning.
        pin_memory=torch.cuda.is_available(),
    )

def test_create_dataloader(test_dataset, test_batch_size):
    """Build a non-shuffling ``DataLoader`` for evaluation/test data.

    Delegates to ``create_dataloader`` with ``shuffle=False`` and a request for
    two workers, which applies the same CPU-based worker cap (minimum 1).
    """
    return create_dataloader(test_dataset, test_batch_size, shuffle=False, num_workers=2)

# %% [code]
# --- Classifier Training and Validation Functions ---

def classifier_validation(valid_dataloader, model, device):
    """Evaluate the classifier on a validation loader.

    Puts ``model`` in eval mode and runs every batch without gradients. A batch
    that raises is reported, followed by a memory cleanup, and then skipped.

    Args:
        valid_dataloader: Iterable of batches with ``input_ids``,
            ``attention_mask`` and ``label`` entries.
        model: Model called with ``input_ids``, ``attention_mask`` and
            ``labels``; its output must expose ``loss`` and ``logits``.
        device: Device the batch tensors are moved to.

    Returns:
        ``(valid_loss, accuracy)``: the mean of the per-batch losses
        (``inf`` when no batch succeeded) and the accuracy over all
        successfully processed samples (``0.0`` when there are none).
    """
    print("Classifier validation processing...")
    model.eval()
    batch_losses, predicted, expected = [], [], []

    with torch.no_grad():
        batches = tqdm(valid_dataloader, desc="Classifier Validation")
        for batch in batches:
            try:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['label'].to(device)

                result = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)

                batch_losses.append(result.loss.item())
                predicted.extend(result.logits.argmax(dim=1).cpu().numpy())
                expected.extend(labels.cpu().numpy())
                # Running mean over the batches processed so far.
                batches.set_postfix({'val_loss': np.mean(batch_losses)})
            except Exception as e:
                print(f"\nError during classifier validation batch: {e}")
                cleanup_memory()  # Best effort: free memory, then move on to the next batch
                continue

    if batch_losses:
        valid_loss = np.mean(batch_losses)
    else:
        valid_loss = float('inf')  # no batch succeeded

    if expected and predicted:
        accuracy = accuracy_score(expected, predicted)
    else:
        accuracy = 0.0

    print(f"Classifier Validation Loss: {valid_loss:.4f}, Accuracy: {accuracy:.4f}")
    return valid_loss, accuracy

def _missing_classifier_checkpoint_files(checkpoint_dir):
    """Return one message per expected checkpoint file that is absent from ``checkpoint_dir``."""
    messages = []

    config_path = os.path.join(checkpoint_dir, "config.json")
    if not os.path.exists(config_path):
        messages.append(f" - Missing: {config_path}")

    # The weights file is either pytorch_model.bin or model.safetensors depending on how
    # the model was serialized; the older "pytorch_model.safetensors" name that this
    # check used to look for is still accepted.
    weight_names = ("pytorch_model.bin", "model.safetensors", "pytorch_model.safetensors")
    if not any(os.path.exists(os.path.join(checkpoint_dir, name)) for name in weight_names):
        messages.append(" - Missing model file (.bin or .safetensors)")

    tokenizer_path = os.path.join(checkpoint_dir, "tokenizer_config.json")
    if not os.path.exists(tokenizer_path):
        messages.append(f" - Missing: {tokenizer_path}")

    return messages


def train_classifier(train_data_path, valid_data_path, classifier_output_dir, # Changed output_path to specific dir
                     train_batch_size=8, valid_batch_size=8,
                     learning_rate=2e-5, warmup_ratio=0.1, epochs=3): # Reduced epochs
    """Fine-tune ``roberta-base`` as a perspective classifier and keep the best checkpoint.

    After every epoch the model is validated; whenever the validation loss
    improves, the model and tokenizer are saved to ``classifier_output_dir``
    (overwriting the previous save) and the saved files are checked.

    Args:
        train_data_path: Path to the training JSON file.
        valid_data_path: Path to the validation JSON file.
        classifier_output_dir: Directory receiving the best checkpoint.
        train_batch_size: Batch size for training.
        valid_batch_size: Batch size for validation.
        learning_rate: AdamW learning rate.
        warmup_ratio: Fraction of all optimizer steps used for linear warmup.
        epochs: Number of training epochs.

    Returns:
        ``classifier_output_dir`` once training has run, or ``None`` when the
        data files could not be loaded. The directory is returned even if no
        checkpoint was saved; a warning is printed in that case.

    Raises:
        Any non-out-of-memory error raised inside a training step is re-raised
        after being reported. CUDA out-of-memory errors skip the batch.
    """
    print("--- Starting Classifier Training ---")
    # Load data
    try:
        with open(train_data_path, 'r') as json_file:
            train_data = json.load(json_file)
        with open(valid_data_path, 'r') as json_file:
            valid_data = json.load(json_file)
        print(f"Loaded {len(train_data)} training samples and {len(valid_data)} validation samples for classifier.")
    except Exception as e:
        print(f"Error loading classifier data: {e}")
        # Single value, matching the success path: a (None, None) tuple would be
        # truthy and break callers that test the returned path.
        return None

    # Ensure output directory exists (redundant if created globally, but safe)
    os.makedirs(classifier_output_dir, exist_ok=True)
    print(f"Classifier checkpoints will be saved to: {classifier_output_dir}")

    best_val_loss = float('inf')
    checkpoint_saved = False
    num_perspectives = len(PERSPECTIVE_LABELS) # Use the defined mapping

    # Initialize model and tokenizer
    tokenizer = RobertaTokenizer.from_pretrained('roberta-base', truncation=True, do_lower_case=True)
    model = RobertaForSequenceClassification.from_pretrained("roberta-base", num_labels=num_perspectives)
    model.to(device)

    # Create datasets and dataloaders
    train_dataset = ClassifierCustomDataset(train_data, tokenizer)
    eval_dataset = ClassifierCustomDataset(valid_data, tokenizer)
    train_dataloader = create_dataloader(train_dataset, train_batch_size, shuffle=True)
    eval_dataloader = create_dataloader(eval_dataset, valid_batch_size, shuffle=False)

    # Optimizer and linear warmup/decay schedule over all training steps
    optimizer = AdamW(model.parameters(), lr=learning_rate)
    total_steps = len(train_dataloader) * epochs
    num_warmup_steps = int(total_steps * warmup_ratio)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=total_steps)

    # Training loop
    for epoch in range(epochs):
        model.train()
        print(f"\n{'#'*25} Classifier Epoch: {epoch+1}/{epochs} {'#'*25}")
        train_losses = []
        progress_bar = tqdm(train_dataloader, desc=f"Classifier Epoch {epoch+1} Training")

        for batch in progress_bar:
            try:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['label'].to(device)

                optimizer.zero_grad()
                outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
                loss = outputs.loss
                if torch.isnan(loss):
                    print("NaN loss detected in classifier training! Skipping batch.")
                    optimizer.zero_grad()
                    continue
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) # Gradient clipping
                optimizer.step()
                scheduler.step()

                train_losses.append(loss.item())
                # Progress bar shows the mean over the most recent 50 steps.
                progress_bar.set_postfix({'train_loss': np.mean(train_losses[-50:]) if train_losses else 0.0})

            except RuntimeError as e:
                if "out of memory" not in str(e):
                    print(f"\nRuntime error during classifier training: {e}")
                    raise
                # Out of memory: drop this batch and keep training.
                print("\nCUDA OOM Error during classifier training. Skipping batch.")
                cleanup_memory()
                optimizer.zero_grad()
                continue
            except Exception as e:
                print(f"\nError during classifier training batch: {e}")
                raise

        avg_train_loss = np.mean(train_losses) if train_losses else 0.0
        print(f"Classifier Epoch {epoch+1} Average Train Loss: {avg_train_loss:.4f}")

        # Validation
        valid_loss, accuracy = classifier_validation(eval_dataloader, model, device)

        # --- Save Best Checkpoint ---
        if valid_loss < best_val_loss:
            best_val_loss = valid_loss
            print(f"* New best validation loss: {best_val_loss:.4f}. Saving model to {classifier_output_dir}... *")
            try:
                model.save_pretrained(classifier_output_dir)
                tokenizer.save_pretrained(classifier_output_dir)
                checkpoint_saved = True
                print(f"Successfully saved classifier model and tokenizer to {classifier_output_dir}")
                # --- Verification ---
                missing = _missing_classifier_checkpoint_files(classifier_output_dir)
                if not missing:
                    print("Classifier checkpoint files verified.")
                else:
                    print(f"!!! Warning: Checkpoint files verification failed in {classifier_output_dir} !!!")
                    for message in missing:
                        print(message)
            except Exception as e:
                # Best effort: a failed save is reported and training continues.
                print(f"!!! Error saving classifier checkpoint: {e} !!!")
        else:
            print(f"Validation loss ({valid_loss:.4f}) did not improve from best ({best_val_loss:.4f}). Not saving.")

    print("--- Classifier Training Finished ---")
    if not checkpoint_saved:
        print(f"Warning: no classifier checkpoint was saved to {classifier_output_dir} during this run.")

    # Drop the references to the model and optimizer state first; while they are
    # still alive the cleanup below cannot release their memory.
    del model, optimizer, scheduler
    cleanup_memory()
    # The directory holds the best checkpoint (the last successful save).
    return classifier_output_dir

# %% [code]
# --- Energy Model Initialization and Custom Loss ---

# Module-level cache of the loaded energy models/tokenizers and the device they were
# loaded for, so repeated initialization can reuse them. Every slot starts as None.
energy_model_cache = dict.fromkeys(
    ("bert_tokenizer", "bert_model", "roberta_tokenizer", "roberta_model", "device")
)

def _resolve_classifier_checkpoint(checkpoint_path):
    """Return ``checkpoint_path`` if it contains a config and a weights file, else ``None``."""
    if not checkpoint_path or not os.path.exists(checkpoint_path):
        return None
    if not os.path.exists(os.path.join(checkpoint_path, "config.json")):
        return None
    # The weights file is either pytorch_model.bin or model.safetensors depending on how
    # the model was serialized; the older "pytorch_model.safetensors" name that this
    # check used to look for is still accepted.
    weight_names = ("pytorch_model.bin", "model.safetensors", "pytorch_model.safetensors")
    if any(os.path.exists(os.path.join(checkpoint_path, name)) for name in weight_names):
        return checkpoint_path
    return None


def initialize_energy_models(classifier_checkpoint_path=None, device='cpu', force_reload=False):
    """Load (or reuse) the BERT and RoBERTa models used by the energy terms.

    BERT (``bert-base-uncased``) is used for Et. The RoBERTa sequence classifier
    used for Ep is loaded from ``classifier_checkpoint_path`` when that
    directory holds a usable checkpoint, otherwise from base ``roberta-base``.
    Both models are moved to ``device`` and put in eval mode.

    Loaded objects are kept in the module-level ``energy_model_cache``. They are
    reused only when the device AND the classifier checkpoint actually in use
    both match the request, so asking for a fine-tuned checkpoint never returns
    a previously cached base model.

    Args:
        classifier_checkpoint_path: Directory of a saved classifier, or ``None``.
        device: Target device (anything ``torch.device`` accepts).
        force_reload: When true, ignore the cache and load everything again.

    Returns:
        ``(bert_tokenizer, bert_model, roberta_tokenizer, roberta_model)``, or
        four ``None`` values when loading failed (the cache is cleared then).
    """
    global energy_model_cache
    target_device = torch.device(device)
    # None means "use base roberta-base": no path given, path missing, or files missing.
    checkpoint_dir = _resolve_classifier_checkpoint(classifier_checkpoint_path)

    cache_is_usable = (
        not force_reload
        and energy_model_cache["bert_model"] is not None
        and energy_model_cache["roberta_model"] is not None
        and energy_model_cache["device"] == target_device
        # .get(): the cache may not have this key before the first successful load.
        and energy_model_cache.get("classifier_checkpoint_path") == checkpoint_dir
    )
    if cache_is_usable:
        print(f"Using cached energy models on device: {target_device}")
        return energy_model_cache["bert_tokenizer"], energy_model_cache["bert_model"], \
               energy_model_cache["roberta_tokenizer"], energy_model_cache["roberta_model"]

    print(f"Initializing energy models for device: {target_device}")

    try:
        # Drop stale cached models before cleaning up; while the cache still
        # references them their memory cannot be released.
        for key in energy_model_cache:
            energy_model_cache[key] = None
        cleanup_memory()

        # --- BERT for Et ---
        print("Loading BERT model (for Et)...")
        bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        bert_model = BertModel.from_pretrained('bert-base-uncased')
        bert_model.to(target_device)
        bert_model.eval()
        print(f"BERT model loaded on {bert_model.device}.")

        # --- RoBERTa Classifier for Ep ---
        num_perspectives = len(PERSPECTIVE_LABELS)
        print("Loading RoBERTa model (for Ep)...")
        if checkpoint_dir is not None:
            print(f"Loading fine-tuned RoBERTa classifier from: {checkpoint_dir}")
            roberta_tokenizer = RobertaTokenizer.from_pretrained(checkpoint_dir)
            roberta_model = RobertaForSequenceClassification.from_pretrained(checkpoint_dir, num_labels=num_perspectives)
        else:
            if classifier_checkpoint_path and os.path.exists(classifier_checkpoint_path):
                # The directory exists but lacks the config or the weights.
                print(f"!!! Error: Classifier checkpoint files missing in {classifier_checkpoint_path}. Loading base model instead.")
            print("Warning: Fine-tuned classifier checkpoint not found or invalid. Loading base roberta-base.")
            print("         The Ep energy term will be based on the base model, which might be less effective.")
            roberta_tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
            roberta_model = RobertaForSequenceClassification.from_pretrained('roberta-base', num_labels=num_perspectives)

        roberta_model.to(target_device)
        roberta_model.eval()
        print(f"RoBERTa model loaded on {roberta_model.device}.")

        print(f"Energy models initialized. BERT on: {bert_model.device}, RoBERTa on: {roberta_model.device}")

        # Update cache, remembering which classifier checkpoint is loaded
        energy_model_cache["bert_tokenizer"] = bert_tokenizer
        energy_model_cache["bert_model"] = bert_model
        energy_model_cache["roberta_tokenizer"] = roberta_tokenizer
        energy_model_cache["roberta_model"] = roberta_model
        energy_model_cache["device"] = target_device
        energy_model_cache["classifier_checkpoint_path"] = checkpoint_dir

        return bert_tokenizer, bert_model, roberta_tokenizer, roberta_model

    except Exception as e:
        print(f"FATAL Error initializing energy models: {e}")
        import traceback
        traceback.print_exc()
        # Clear cache on failure so a later call starts from scratch
        energy_model_cache = {k: None for k in energy_model_cache}
        return None, None, None, None


# --- Define Energy Functions (Ep, Ea, Et) based on paper ---

# Perspective Energy (Ep) - using RoBERTa classifier
def calculate_Ep(text, tokenizer, model, device):
    """Compute the perspective energy of ``text`` for every perspective.

    The classifier's logits are turned into probabilities with softmax; the
    energy of a perspective is ``-log(p + 1e-9)`` of its class probability, so
    a lower energy means the classifier finds that perspective more likely.

    Args:
        text: A single text to classify.
        tokenizer: Tokenizer matching ``model``; called with ``return_tensors="pt"``.
        model: Sequence classifier whose output exposes ``logits``; class
            indices are read through ``PERSPECTIVE_LABELS``.
        device: Device the tokenized inputs are moved to.

    Returns:
        Dict mapping each perspective name in ``PERSPECTIVE_LABELS`` to its
        energy as a Python float.
    """
    model.eval()  # Ensure model is in eval mode
    # Dynamic padding: for a single text nothing is padded, so the classifier processes
    # only the real tokens (truncated to 512) instead of a fixed 512-position input.
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512, padding=True).to(device)
    with torch.no_grad():
        logits = model(**inputs).logits
        probabilities = torch.softmax(logits, dim=-1).squeeze(0)  # drop the batch dimension
        log_probs = torch.log(probabilities + 1e-9)  # epsilon avoids log(0)
    # Convert all classes to Python floats in one transfer rather than one .item() per class.
    energies_by_index = (-log_probs).tolist()
    return {
        perspective: energies_by_index[index]
        for perspective, index in PERSPECTIVE_LABELS.items()
    }


# Anchor Energy (Ea) - ROUGE-1 between start of summary and anchor text
def calculate_Ea(generated_text, expected_anchor, scorer, device): # device is accepted for signature consistency but unused
    """Compute the anchor energy: how far the summary's opening is from the anchor phrase.

    The first ``k`` whitespace-separated words of ``generated_text`` (``k`` being
    the word count of ``expected_anchor``) are scored against the anchor with
    ``scorer``; the energy is ``1 - ROUGE-1 F-measure``, so lower is better.

    Returns:
        ``1.0`` when either the extracted opening or the anchor is empty,
        otherwise ``1.0 - scores['rouge1'].fmeasure``.
    """
    anchor_word_count = len(expected_anchor.split())
    leading_words = generated_text.split()[:anchor_word_count]
    start_of_generated = " ".join(leading_words)

    # Nothing to compare: treat as the worst (maximum) energy.
    if not (start_of_generated and expected_anchor):
        return 1.0

    rouge = scorer.score(target=expected_anchor, prediction=start_of_generated)
    return 1.0 - rouge['rouge1'].fmeasure


# Tone Energy (Et) - Cosine similarity between summary embedding and tone keyword embeddings
# Tone keywords/phrases per perspective; their embeddings are averaged and compared
# with the summary embedding when computing the tone energy.
TONE_KEYWORDS = dict(
    INFORMATION=["factual", "informative", "educational", "objective", "knowledge"],
    SUGGESTION=["advice", "recommend", "suggest", "should", "consider"],
    EXPERIENCE=["personal", "narrative", "my experience", "I felt", "anecdote"],
    CAUSE=["reason", "due to", "caused by", "explanation", "origin"],
    QUESTION=["inquiry", "wondering", "question", "clarify", "understand"],
)
# Averaged keyword embeddings, filled lazily and keyed by (perspective, device string).
keyword_embedding_cache = dict()

def get_bert_embedding(text, tokenizer, model, device):
    """Return a mean-pooled embedding of ``text`` from the model's last hidden state.

    The text is tokenized (truncated to 512 tokens), run through ``model`` in
    eval mode without gradients, and the last hidden states are averaged over
    the positions where the attention mask is set.

    Returns:
        The pooled embedding with the leading batch dimension squeezed away.
    """
    model.eval()
    # Single text for now; batching a list of texts is not handled here.
    encoded = tokenizer(text, return_tensors="pt", max_length=512, truncation=True, padding=True).to(device)
    with torch.no_grad():
        hidden_states = model(**encoded).last_hidden_state
        # Broadcast the attention mask over the hidden dimension so padded positions add 0.
        weights = encoded['attention_mask'].unsqueeze(-1).expand(hidden_states.size()).float()
        weighted_total = (hidden_states * weights).sum(dim=1)
        # Clamp guards against division by zero when the mask is all zeros.
        weight_total = weights.sum(dim=1).clamp(min=1e-9)
        pooled = weighted_total / weight_total
    return pooled.squeeze(0)

def calculate_Et(generated_text, perspective, bert_tokenizer, bert_model, device):
    """Compute the tone energy of a summary for ``perspective``.

    The energy is ``1 - cosine_similarity`` between the summary embedding and
    the mean embedding of the perspective's tone keywords, floored at ``0.0``;
    lower is better. Mean keyword embeddings are cached in
    ``keyword_embedding_cache`` per (perspective, model device).

    Returns:
        ``1.0`` for an empty summary, ``0.5`` when the perspective has no
        keywords (or none could be embedded), otherwise the floored energy.
    """
    if not generated_text:
        return 1.0  # maximum energy for an empty summary

    keywords = TONE_KEYWORDS.get(perspective, [])
    if not keywords:
        return 0.5  # neutral energy when no keywords are defined

    summary_embedding = get_bert_embedding(generated_text, bert_tokenizer, bert_model, device)
    if summary_embedding is None:
        return 1.0  # defensive guard kept from the original

    # The device is part of the key so embeddings computed on one device are not reused on another.
    cache_key = (perspective, str(bert_model.device))
    tone_centroid = keyword_embedding_cache.get(cache_key)
    if tone_centroid is None:
        embedded = (get_bert_embedding(word, bert_tokenizer, bert_model, device) for word in keywords)
        keyword_vectors = [vector for vector in embedded if vector is not None]
        if not keyword_vectors:
            return 0.5  # nothing could be embedded
        tone_centroid = torch.stack(keyword_vectors).mean(dim=0)
        keyword_embedding_cache[cache_key] = tone_centroid

    similarity = torch.nn.functional.cosine_similarity(summary_embedding, tone_centroid, dim=0)

    # Lower energy is better; never report a negative energy.
    return max(0.0, 1.0 - similarity.item())

# --- Combined Custom Loss Calculation ---
# Per-perspective prompt attributes, read from a throwaway dataset instance built with
# no data and no tokenizer; the 'start_with' phrases are the anchor text used for Ea.
perspective_details = SummarizationCustomDataset([], None).perspective_prompts # Get definitions for Ea anchor text
# Shared ROUGE-1 scorer (with stemming) passed to calculate_Ea.
rouge_calc_scorer = rouge_scorer.RougeScorer(['rouge1'], use_stemmer=True) # Initialize rouge scorer for Ea

def compute_custom_loss(decoded_target_summary, # Use decoded target labels
                      target_perspective, # Target perspective string
                      tokenizer, # Main tokenizer (BART) - NOT USED here, but kept for potential future use
                      bert_tokenizer, bert_model,
                      roberta_tokenizer, roberta_model,
                      device,
                      *,
                      alpha_ep=0.5, alpha_ea=0.3, alpha_et=0.2):
    """
    Calculates the energy-based perspective loss from the Ep, Ea and Et components.
    Operates on the DECODED TARGET SUMMARY for stability during training.

    For every perspective in ``perspective_details`` an energy
    ``alpha_ep * Ep + alpha_ea * Ea + alpha_et * Et`` is built. Ea and Et are
    only computed for ``target_perspective``; all other perspectives receive
    the defaults 1.0 (Ea) and 0.5 (Et). The energies are turned into
    log-probabilities with ``log_softmax(-E)`` and the negative log-probability
    of the target perspective is returned.

    Args:
        decoded_target_summary: Reference summary text to score.
        target_perspective: Key of ``perspective_details`` the summary belongs to.
        tokenizer: Unused; kept so existing callers do not have to change.
        bert_tokenizer, bert_model: Passed to ``calculate_Et``.
        roberta_tokenizer, roberta_model: Passed to ``calculate_Ep``.
        device: Device (string or ``torch.device``) for the energy models and
            the returned tensor.
        alpha_ep, alpha_ea, alpha_et: Weights of the three energy terms
            (keyword-only; defaults reproduce the previous fixed values).

    Returns:
        A detached scalar tensor on ``device``. It carries no gradient, so it
        cannot back-propagate into any model. A zero tensor is returned for a
        missing/blank summary, an unknown perspective, a failed Ep calculation,
        a non-finite result, or any exception raised while computing.
    """
    current_target_device = torch.device(device)

    def zero_loss():
        # Single definition of the "no contribution" result used by every bail-out path.
        return torch.tensor(0.0, device=current_target_device, requires_grad=False)

    # The isinstance test keeps a None/non-string summary from raising on .strip();
    # this function reports every other failure as a zero loss, so this one does too.
    if (not isinstance(decoded_target_summary, str)
            or not decoded_target_summary.strip()
            or target_perspective not in perspective_details):
        return zero_loss()

    try:
        # --- Ensure Energy Models are on Correct Device ---
        if bert_model.device != current_target_device: bert_model.to(current_target_device)
        if roberta_model.device != current_target_device: roberta_model.to(current_target_device)

        # --- Calculate Energy Components ---
        with torch.no_grad(): # No gradients needed through energy models themselves
            # Ep: Perspective Score (one value per perspective)
            Ep_dict = calculate_Ep(decoded_target_summary, roberta_tokenizer, roberta_model, current_target_device)
            if not Ep_dict: # Handle potential calculation failure
                print("Warning: Ep calculation failed.")
                return zero_loss()

            # Ea: Anchor Score, measured against the target perspective's opening phrase
            expected_anchor = perspective_details[target_perspective]['start_with']
            Ea_val = calculate_Ea(decoded_target_summary, expected_anchor, rouge_calc_scorer, current_target_device)

            # Et: Tone Score for the target perspective
            Et_val = calculate_Et(decoded_target_summary, target_perspective, bert_tokenizer, bert_model, current_target_device)

        # --- Combine Energies (Linear Combination per Perspective) ---
        perspectives_ordered = list(perspective_details.keys()) # Fixed order shared with target_idx below
        E_X = {}
        for p in perspectives_ordered:
            ep_val = Ep_dict.get(p)
            if ep_val is None: # Perspective absent from the Ep results
                print(f"Warning: Perspective '{p}' not found in Ep calculation results. Assigning high energy.")
                ep_val = 10.0 # Assign high energy if missing
            is_target = (p == target_perspective)
            ea_val = Ea_val if is_target else 1.0 # Non-target perspectives: default high energy
            et_val = Et_val if is_target else 0.5 # Non-target perspectives: default neutral energy
            E_X[p] = alpha_ep * ep_val + alpha_ea * ea_val + alpha_et * et_val

        # --- Calculate Probabilities using Boltzmann distribution ---
        try:
            energies_tensor = torch.tensor([E_X[k] for k in perspectives_ordered], device=current_target_device)
            # log_softmax of the negated energies: low energy <=> high probability
            log_probs = torch.nn.functional.log_softmax(-energies_tensor, dim=0)
        except Exception as e:
            print(f"Error during log_softmax calculation: {e}")
            print("E_X values:", E_X)
            return zero_loss()

        # --- Cross-Entropy Loss (Negative Log Likelihood of the target perspective) ---
        target_idx = perspectives_ordered.index(target_perspective)
        perspective_loss = -log_probs[target_idx]

        if not torch.isfinite(perspective_loss):
            print("Warning: NaN/Inf detected in custom perspective loss!")
            return zero_loss()

        # Detached on purpose: the value depends only on the reference summary,
        # so there is nothing meaningful to back-propagate.
        return perspective_loss.detach()

    except Exception as e:
        print(f"Error in compute_custom_loss for perspective {target_perspective}: {e}")
        import traceback
        traceback.print_exc()
        return zero_loss()


# %% [code]
# --- Summarization Model Training and Validation ---

def seq2seq_validation(eval_dataloader, model, tokenizer, device):
    """Score the summarizer on ``eval_dataloader``.

    The model is switched to eval mode. For each batch a forward pass with
    labels yields the cross-entropy loss (NaN losses are not recorded), then
    summaries are generated with 4-beam search (``max_length=256``) and
    decoded next to the batch's ``target_summary_text`` references.
    Out-of-memory and other per-batch errors are printed and the batch is
    dropped; validation keeps going.

    Returns:
        ``(avg_eval_loss, rouge_l_f1)``: the mean of the recorded batch losses
        (``inf`` when none was recorded) and the ``rougeLsum`` value computed
        over all prediction/reference pairs in which both strings are
        non-empty (0.0 when it could not be computed).
    """
    print("Summarizer validation processing...")
    model.eval()
    batch_losses = []
    hypotheses = []
    gold_summaries = []
    rouge_metric = evaluate.load('rouge', quiet=True)

    with torch.no_grad():
        progress_bar = tqdm(eval_dataloader, desc="Summarizer Validation")
        for batch in progress_bar:
            try:
                source_ids = batch['input_ids'].to(device)
                source_mask = batch['attention_mask'].to(device)
                label_ids = batch['labels'].to(device)
                batch_refs = batch['target_summary_text']

                # Teacher-forced pass: only used for the validation CE loss.
                forward_out = model(input_ids=source_ids, attention_mask=source_mask, labels=label_ids)
                ce_loss = forward_out.loss
                if not torch.isnan(ce_loss):
                    batch_losses.append(ce_loss.item())

                # Re-home the model before generating if it is reported on another device.
                if next(model.parameters()).device != torch.device(device):
                    model.to(device)

                generated_ids = model.generate(
                    input_ids=source_ids,
                    attention_mask=source_mask,
                    max_length=256,
                    num_beams=4,
                    early_stopping=True,
                )
                # Decoding happens on CPU copies of the generated ids.
                decoded = tokenizer.batch_decode(
                    generated_ids.cpu(),
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=True,
                )
                hypotheses.extend(decoded)
                gold_summaries.extend(batch_refs)

                running_loss = np.mean(batch_losses) if batch_losses else 0.0
                progress_bar.set_postfix({'val_loss': running_loss})
            except RuntimeError as err:
                if "out of memory" not in str(err):
                    # Reported but not re-raised so the remaining batches still run.
                    print(f"\nRuntime error during summarizer validation: {err}")
                else:
                    print("\nCUDA OOM Error during summarizer validation. Skipping batch.")
                    cleanup_memory()
            except Exception as err:
                # Reported but not re-raised so the remaining batches still run.
                print(f"\nError during summarizer validation batch: {err}")

    # inf signals that no batch produced a usable loss.
    avg_eval_loss = float('inf') if not batch_losses else np.mean(batch_losses)

    rouge_l_f1 = 0.0
    if not (hypotheses and gold_summaries):
        print("Warning: No predictions or references generated during validation.")
    else:
        try:
            # Pairs with an empty prediction or reference are left out of the ROUGE input.
            usable_pairs = [(hyp, ref) for hyp, ref in zip(hypotheses, gold_summaries) if hyp and ref]
            if usable_pairs:
                kept_preds = [hyp for hyp, _ in usable_pairs]
                kept_refs = [ref for _, ref in usable_pairs]
                rouge_scores = rouge_metric.compute(predictions=kept_preds, references=kept_refs)
                rouge_l_f1 = rouge_scores.get('rougeLsum', 0.0)
            else:
                print("Warning: No valid prediction/reference pairs for ROUGE calculation.")
        except Exception as err:
            print(f"Error calculating ROUGE score during validation: {err}")

    print(f"Summarizer Validation CE Loss: {avg_eval_loss:.4f}, ROUGE-Lsum F1: {rouge_l_f1:.4f}")
    cleanup_memory()
    # The caller selects the best model on CE loss; ROUGE is returned for reporting.
    return avg_eval_loss, rouge_l_f1


# Modified train_seq2seq function
def train_seq2seq(train_data_path, valid_data_path, model_name, summarizer_output_dir, # Changed output_path to specific dir
                  classifier_checkpoint_path, # Path to the trained RoBERTa classifier
                  train_batch_size=2, # Small batch size for BART-large
                  gradient_accumulation_steps=8, # Effective batch size = 2 * 8 = 16
                  valid_batch_size=4,
                  learning_rate=5e-5, # Standard LR for fine-tuning
                  lora_r=16, lora_alpha=32, lora_dropout=0.05, # LoRA config
                  lambda_perspective=0.1, # Weight for the custom energy loss
                  warmup_ratio=0.1, epochs=3, # Reduced epochs
                  use_energy_loss=True): # Flag to enable/disable custom loss
    """Fine-tune a BART summarizer with LoRA adapters and keep the best checkpoint.

    Loads the train/validation JSON files, wraps ``model_name`` in a LoRA
    adapter on the ``q_proj``/``v_proj`` modules, and trains it with AdamW, a
    linear warmup schedule, gradient accumulation and gradient clipping
    (max norm 1.0). After every epoch ``seq2seq_validation`` is run; whenever
    the validation cross-entropy improves, the adapter and tokenizer are saved
    to ``summarizer_output_dir``.

    When ``use_energy_loss`` is true, ``lambda_perspective`` times the mean
    ``compute_custom_loss`` value of the batch is added to the cross-entropy.
    NOTE: ``compute_custom_loss`` returns a detached tensor, so this term
    changes the reported loss but contributes no gradient.

    Args:
        train_data_path, valid_data_path: JSON files with the samples.
        model_name: Base BART checkpoint name or path.
        summarizer_output_dir: Directory that receives the adapter and tokenizer.
        classifier_checkpoint_path: Passed to ``initialize_energy_models``.
        train_batch_size, valid_batch_size: DataLoader batch sizes.
        gradient_accumulation_steps: Micro-batches per optimizer step.
        learning_rate: AdamW learning rate.
        lora_r, lora_alpha, lora_dropout: LoRA hyper-parameters.
        lambda_perspective: Weight of the energy-based loss term.
        warmup_ratio: Fraction of the optimizer steps used for warmup.
        epochs: Number of passes over the training data.
        use_energy_loss: Enables the energy-based loss term.

    Returns:
        ``(summarizer_output_dir, tokenizer, train_losses_history)`` where the
        history holds one average training loss per epoch, or
        ``(None, None, [])`` when the data or the model could not be loaded.

    Raises:
        Any exception raised inside a training batch, except CUDA
        out-of-memory errors, which skip the batch.
    """
    print(f"--- Starting Summarizer Training (Model: {model_name}, PEFT: LoRA) ---")
    global target_device # Make target_device accessible for energy model checks
    target_device = device # Assign global device based on initial setup

    # --- Load data ---
    try:
        # JSON is UTF-8 by specification; do not depend on the platform's locale encoding.
        with open(train_data_path, 'r', encoding='utf-8') as json_file:
            train_data = json.load(json_file)
        with open(valid_data_path, 'r', encoding='utf-8') as json_file:
            valid_data = json.load(json_file)
        print(f"Loaded {len(train_data)} training samples and {len(valid_data)} validation samples for summarizer.")
    except Exception as e:
        print(f"Error loading summarizer data: {e}")
        return None, None, []

    # --- Initialize Energy Models ---
    bert_tokenizer, bert_model, roberta_tokenizer, roberta_model = None, None, None, None
    if use_energy_loss:
        print("Initializing energy models for custom loss...")
        bert_tokenizer, bert_model, roberta_tokenizer, roberta_model = initialize_energy_models(
            classifier_checkpoint_path, device=device, force_reload=True # Force reload for training context
        )
        if bert_model is None or roberta_model is None:
            print("Warning: Failed to initialize one or more energy models. Disabling energy loss.")
            use_energy_loss = False
        else:
            print("Energy models initialized successfully.")

    # --- Initialize Summarization Model (BART + LoRA) ---
    model = None
    tokenizer = None
    try:
        print(f"Loading base model: {model_name}")
        model = BartForConditionalGeneration.from_pretrained(model_name)
        tokenizer = BartTokenizer.from_pretrained(model_name)
        print("Base model loaded.")

        # Configure LoRA
        lora_config = LoraConfig(
            r=lora_r,
            lora_alpha=lora_alpha,
            target_modules=["q_proj", "v_proj"], # Target modules for BART attention
            lora_dropout=lora_dropout,
            bias="none",
            task_type=TaskType.SEQ_2_SEQ_LM
        )

        # Apply LoRA PEFT to the model
        model = get_peft_model(model, lora_config) # model is now the PeftModel
        print("Applied LoRA configuration.")
        model.print_trainable_parameters()

        # Move model to device *after* PEFT application
        print(f"Moving main model to {device}...")
        model.to(device)
        print(f"Main model is now on device: {next(model.parameters()).device}")

    except Exception as e:
        print(f"Error initializing summarization model: {e}")
        import traceback
        traceback.print_exc()
        cleanup_memory()
        return None, None, []

    # --- DataLoaders and Optimizer ---
    train_dataset = SummarizationCustomDataset(train_data, tokenizer)
    eval_dataset = SummarizationCustomDataset(valid_data, tokenizer)
    train_dataloader = create_dataloader(train_dataset, train_batch_size, shuffle=True)
    eval_dataloader = create_dataloader(eval_dataset, valid_batch_size, shuffle=False)

    # Collected once: the same parameters are optimized and clipped on every step.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = AdamW(trainable_params, lr=learning_rate)
    total_steps = math.ceil(len(train_dataloader) / gradient_accumulation_steps) * epochs
    num_warmup_steps = int(total_steps * warmup_ratio)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=total_steps)

    # --- Training Loop ---
    os.makedirs(summarizer_output_dir, exist_ok=True)
    print(f"Summarizer checkpoints will be saved to: {summarizer_output_dir}")

    best_val_loss = float('inf')
    train_losses_history = []
    global_step = 0
    # Number of most recent micro-batches averaged for the progress-bar display.
    display_window = 50 * gradient_accumulation_steps

    for epoch in range(epochs):
        model.train()
        # The energy models only score text; keep them in eval mode.
        if use_energy_loss and bert_model and roberta_model:
            bert_model.eval()
            roberta_model.eval()

        print(f"\n{'#'*25} Summarizer Epoch: {epoch+1}/{epochs} {'#'*25}")
        epoch_train_losses = []
        epoch_ce_losses = []
        epoch_p_losses = []
        progress_bar = tqdm(train_dataloader, desc=f"Summarizer Epoch {epoch+1} Training")

        optimizer.zero_grad() # Start every epoch with a clean accumulation window

        for step, batch in enumerate(progress_bar):
            try:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)
                perspectives = batch['perspective'] # List of perspectives
                target_summaries = batch['target_summary_text'] # List of target texts

                # --- Standard Forward Pass ---
                outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
                ce_loss = outputs.loss

                if torch.isnan(ce_loss):
                    # NOTE: skipping also bypasses the optimizer-step check below, so a
                    # skipped batch on an accumulation boundary defers that step to the
                    # next boundary.
                    print(f"NaN detected in CE loss at step {step}. Skipping batch.")
                    cleanup_memory()
                    continue

                # --- Calculate Custom Perspective Loss (if enabled) ---
                perspective_loss = torch.tensor(0.0, device=device)
                if use_energy_loss and bert_model is not None and roberta_model is not None:
                    # Device consistency is checked once per batch instead of once per sample.
                    if bert_model.device != target_device: bert_model.to(target_device)
                    if roberta_model.device != target_device: roberta_model.to(target_device)

                    batch_perspective_loss = []
                    for summary_text, perspective_name in zip(target_summaries, perspectives):
                        loss_item = compute_custom_loss(
                            summary_text, perspective_name,
                            tokenizer, bert_tokenizer, bert_model,
                            roberta_tokenizer, roberta_model, device
                        )
                        # Anything that is not a tensor counts as "no contribution".
                        if not isinstance(loss_item, torch.Tensor):
                            loss_item = torch.tensor(0.0, device=device)
                        batch_perspective_loss.append(loss_item)

                    if batch_perspective_loss:
                        valid_losses = [item.to(device) for item in batch_perspective_loss if torch.isfinite(item)]
                        if valid_losses:
                            perspective_loss = torch.mean(torch.stack(valid_losses))
                        else:
                            perspective_loss = torch.tensor(0.0, device=device) # Every item failed

                    if torch.isnan(perspective_loss):
                        print(f"NaN detected in Perspective loss at step {step}. Setting to 0.")
                        perspective_loss = torch.tensor(0.0, device=device)

                # --- Combine Losses ---
                total_loss = ce_loss + lambda_perspective * perspective_loss

                # --- Gradient Accumulation and Backward Pass ---
                scaled_loss = total_loss / gradient_accumulation_steps
                if torch.isnan(scaled_loss):
                    print(f"NaN detected in final scaled loss at step {step}. Skipping backward pass for this batch.")
                    cleanup_memory()
                    continue

                scaled_loss.backward()

                epoch_train_losses.append(total_loss.item()) # Log unscaled loss
                epoch_ce_losses.append(ce_loss.item())
                epoch_p_losses.append(perspective_loss.item())

                # --- Optimizer Step ---
                # Step at the end of every accumulation window and on the epoch's last batch.
                if (step + 1) % gradient_accumulation_steps == 0 or (step + 1 == len(train_dataloader)):
                    torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad() # Zero gradients *after* optimizer step
                    global_step += 1

                    # Update progress bar display (more stable averages)
                    avg_loss = np.mean(epoch_train_losses[-display_window:]) if epoch_train_losses else 0.0
                    avg_ce_loss = np.mean(epoch_ce_losses[-display_window:]) if epoch_ce_losses else 0.0
                    avg_p_loss = np.mean(epoch_p_losses[-display_window:]) if epoch_p_losses else 0.0
                    progress_bar.set_postfix({
                        'loss': f"{total_loss.item():.4f}",
                        'avg_loss': f"{avg_loss:.4f}",
                        'avg_CE': f"{avg_ce_loss:.4f}",
                        'avg_P': f"{avg_p_loss:.4f}",
                        'lr': f"{scheduler.get_last_lr()[0]:.2e}"
                    })

                    # Periodic cleanup, once per 100 optimizer steps. It lives inside the
                    # step branch because global_step only changes here; checked per
                    # micro-batch it would also fire on every batch while global_step
                    # sits at 0 or at a multiple of 100.
                    if global_step % 100 == 0:
                        cleanup_memory()

            except RuntimeError as e:
                if "out of memory" in str(e):
                    print(f"\nCUDA OOM Error at step {step}. Skipping batch.")
                    cleanup_memory()
                    optimizer.zero_grad() # Important to clear potentially bad grads
                    continue
                print(f"\nRuntime error in training batch {step}: {e}")
                raise # Bare raise keeps the original traceback
            except Exception as e:
                print(f"\nError in training batch {step}: {e}")
                import traceback
                traceback.print_exc()
                raise

        # --- End of Epoch ---
        avg_epoch_train_loss = np.mean(epoch_train_losses) if epoch_train_losses else float('inf')
        train_losses_history.append(avg_epoch_train_loss)
        print(f"Summarizer Epoch {epoch+1} Average Train Loss: {avg_epoch_train_loss:.4f}")

        # --- Validation ---
        valid_loss, valid_rouge = seq2seq_validation(eval_dataloader, model, tokenizer, device)

        # --- Save Best Model (based on validation CE loss) ---
        if valid_loss < best_val_loss:
            best_val_loss = valid_loss
            print(f"* New best validation loss: {best_val_loss:.4f}. Saving PEFT model to {summarizer_output_dir}... *")
            try:
                model.save_pretrained(summarizer_output_dir) # Saves only LoRA weights + config
                tokenizer.save_pretrained(summarizer_output_dir) # Save tokenizer config
                print(f"Successfully saved PEFT adapter and tokenizer to {summarizer_output_dir}")

                # --- Verification ---
                adapter_model_path_bin = os.path.join(summarizer_output_dir, "adapter_model.bin")
                adapter_model_path_safe = os.path.join(summarizer_output_dir, "adapter_model.safetensors")
                adapter_config_path = os.path.join(summarizer_output_dir, "adapter_config.json")
                tokenizer_config_path = os.path.join(summarizer_output_dir, "tokenizer_config.json")

                adapter_model_exists = os.path.exists(adapter_model_path_bin) or os.path.exists(adapter_model_path_safe)

                if adapter_model_exists and os.path.exists(adapter_config_path) and os.path.exists(tokenizer_config_path):
                    print("PEFT adapter and tokenizer files verified.")
                else:
                    print(f"!!! Warning: PEFT checkpoint files verification failed in {summarizer_output_dir} !!!")
                    if not adapter_model_exists: print(f" - Missing: adapter_model (.bin or .safetensors)")
                    if not os.path.exists(adapter_config_path): print(" - Missing:", adapter_config_path)
                    if not os.path.exists(tokenizer_config_path): print(" - Missing:", tokenizer_config_path)

            except Exception as e:
                # A failed save is reported but does not abort training.
                print(f"!!! Error saving PEFT adapter checkpoint: {e} !!!")

        else:
            print(f"Validation loss ({valid_loss:.4f}) did not improve from best ({best_val_loss:.4f}). Not saving.")

    print("--- Summarizer Training Finished ---")
    cleanup_memory()
    # The checkpoint *directory* is returned, not the model object itself.
    return summarizer_output_dir, tokenizer, train_losses_history


# %% [code]
# --- Inference Function ---

def inference(test_file, base_model_name, peft_ckpt_dir, output_csv_path, batch_size_test=4, max_source_length=1024, max_target_length=256):
    """Generate summaries for a test set with a saved LoRA adapter and write them to CSV.

    Args:
        test_file: JSON file holding the test samples.
        base_model_name: Base BART checkpoint the adapter was trained on.
        peft_ckpt_dir: Directory with ``adapter_config.json``, the adapter
            weights (``.bin`` or ``.safetensors``) and the tokenizer files.
        output_csv_path: Destination CSV; missing parent directories are created.
        batch_size_test: Batch size for generation.
        max_source_length, max_target_length: Passed to
            ``SummarizationCustomDataset``; ``max_target_length`` also caps the
            number of generated tokens.

    Returns:
        A DataFrame with the columns PERSPECTIVE, PREDICTED, ACTUAL_SUMMARY,
        SOURCE_QUESTION and SOURCE_ANSWERS, or ``None`` when the test data, the
        checkpoint files or the model could not be loaded, or when writing the
        CSV failed. Batches that raise during generation are reported and skipped.
    """
    print("--- Starting Inference ---")
    # Load test data
    try:
        # JSON is UTF-8 by specification; do not depend on the platform's locale encoding.
        with open(test_file, 'r', encoding='utf-8') as json_file:
            test_data = json.load(json_file)
        print(f"Loaded {len(test_data)} test samples from {test_file}")
    except Exception as e:
        print(f"Error loading test data from {test_file}: {e}")
        return None

    # Verify PEFT checkpoint directory and necessary files
    print(f"Checking PEFT checkpoint directory: {peft_ckpt_dir}")
    adapter_config_path = os.path.join(peft_ckpt_dir, "adapter_config.json")
    adapter_model_path_bin = os.path.join(peft_ckpt_dir, "adapter_model.bin")
    adapter_model_path_safe = os.path.join(peft_ckpt_dir, "adapter_model.safetensors")
    if not os.path.exists(adapter_config_path):
        print(f"Error: adapter_config.json not found in {peft_ckpt_dir}")
        return None
    if not (os.path.exists(adapter_model_path_bin) or os.path.exists(adapter_model_path_safe)):
        print(f"Error: adapter_model file (.bin or .safetensors) not found in {peft_ckpt_dir}")
        return None
    print("PEFT checkpoint files seem present.")

    # Load base model and tokenizer
    peft_model = None
    tokenizer = None
    try:
        print(f"Loading base model '{base_model_name}' for inference...")
        base_model = BartForConditionalGeneration.from_pretrained(base_model_name)
        tokenizer = BartTokenizer.from_pretrained(peft_ckpt_dir) # Load tokenizer from PEFT dir (if saved)
        print("Base model and tokenizer loaded.")

        # Load PEFT adapter weights
        print(f"Loading PEFT adapter from: {peft_ckpt_dir}")
        peft_model = PeftModel.from_pretrained(base_model, peft_ckpt_dir, is_trainable=False)
        print("PEFT adapter loaded.")

        # Prefer the module-level `target_device` (train_seq2seq sets it); fall back to
        # auto-detection so inference also works when that global was never assigned.
        requested_device = globals().get('target_device')
        if requested_device is None:
            requested_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        print(f"Moving PEFT model to device: {requested_device}")
        peft_model.to(requested_device)
        peft_model.eval() # Set to evaluation mode
        # Use the concrete device reported by the parameters for every later comparison:
        # an index-less device('cuda') never compares equal to device('cuda:0').
        model_device = next(peft_model.parameters()).device
        print(f"PEFT model moved to {model_device} and set to eval mode.")

    except Exception as e:
        print(f"Error loading model for inference: {e}")
        import traceback
        traceback.print_exc()
        cleanup_memory()
        return None

    # Create test dataset and dataloader
    # Use the same dataset class as training to ensure prompt consistency
    test_dataset = SummarizationCustomDataset(test_data, tokenizer, max_source_length, max_target_length)
    test_dataloader = test_create_dataloader(test_dataset, batch_size_test)

    # Generate predictions
    results = []
    with torch.no_grad():
        progress_bar = tqdm(test_dataloader, desc="Inference")
        # Position of the current batch's first sample in test_data.
        # NOTE(review): assumes the dataloader yields samples in test_data order,
        # one sample per test_data entry - confirm against SummarizationCustomDataset.
        current_idx = 0
        for batch in progress_bar:
            batch_len = 0
            try:
                batch_len = len(batch['perspective'])
                input_ids = batch['input_ids'].to(model_device)
                attention_mask = batch['attention_mask'].to(model_device)

                # Check model device just before generation
                if next(peft_model.parameters()).device != model_device:
                    print("Warning: Model moved off device? Moving back...")
                    peft_model.to(model_device)

                # Use the loaded PeftModel for generation
                outputs = peft_model.generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    num_beams=5,
                    max_new_tokens=max_target_length, # Use max_new_tokens for consistency with newer HF versions
                    min_length=10, # Avoid very short summaries
                    repetition_penalty=1.2,
                    early_stopping=True,
                    no_repeat_ngram_size=3,
                    pad_token_id=tokenizer.pad_token_id, # Explicitly set pad token id
                    eos_token_id=tokenizer.eos_token_id   # Explicitly set eos token id
                )

                # Decode generated summaries (move outputs to CPU first)
                predictions = tokenizer.batch_decode(outputs.cpu(), skip_special_tokens=True, clean_up_tokenization_spaces=True)

                # Store results (perspective, prediction, actual, source)
                rows = zip(batch['perspective'], predictions, batch['target_summary_text'])
                for offset, (perspective, prediction, reference) in enumerate(rows):
                    original_item_index = current_idx + offset
                    if original_item_index >= len(test_data):
                        print(f"Warning: Index mismatch during inference result processing ({original_item_index} vs {len(test_data)})")
                        continue

                    item = test_data[original_item_index] # Get original item for full context
                    answers = item.get('answers', [])
                    source_context = ' '.join([ans.replace('\\n', ' ').strip() for ans in answers if ans.strip()])
                    question = item.get('question', '').strip()

                    results.append({
                        'PERSPECTIVE': perspective,
                        'PREDICTED': prediction,
                        'ACTUAL_SUMMARY': reference,
                        'SOURCE_QUESTION': question,
                        'SOURCE_ANSWERS': source_context
                    })

            except RuntimeError as e:
                if "out of memory" in str(e):
                    print(f"\nCUDA OOM Error during inference. Skipping batch.")
                    cleanup_memory()
                    continue
                # Reported but not re-raised so the remaining batches still run.
                print(f"\nRuntime error during inference: {e}")
            except Exception as e:
                # Reported but not re-raised so the remaining batches still run.
                print(f"\nError during inference batch: {e}")
            finally:
                # Advance even when the batch was skipped; otherwise every later batch
                # would be paired with the wrong test_data entries.
                current_idx += batch_len

    # Save results to CSV
    try:
        results_df = pd.DataFrame(results)
        # dirname is '' for a bare filename, and os.makedirs('') raises.
        output_dir = os.path.dirname(output_csv_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        results_df.to_csv(output_csv_path, index=False)
        print(f"Inference complete. Results saved to {output_csv_path}")
        # --- Verification ---
        if os.path.exists(output_csv_path):
            print(f"Inference results CSV file verified at {output_csv_path}")
        else:
            print(f"!!! Warning: Inference results CSV file NOT found at {output_csv_path} after saving !!!")
    except Exception as e:
        print(f"!!! Error saving inference results to CSV: {e} !!!")
        results_df = None # Ensure df is None if saving failed

    cleanup_memory()
    return results_df


# %% [code]
# --- Evaluation Metrics ---

class EvaluationMetrics:
    """Score generated summaries against reference summaries.

    The constructor normalises both inputs to lists of strings, reconciles a
    length mismatch by trimming to the shorter list, and loads the ROUGE,
    METEOR and BLEU metrics through ``evaluate.load`` when there is data to
    score. BERTScore is computed on demand with ``bert_score_metric``.

    Every ``compute_*`` method returns a dict of scores scaled to 0-100,
    an all-zero dict when no pair has both a non-empty prediction and a
    non-empty reference, and ``{}`` when the data is invalid or the metric
    computation raises.
    """

    def __init__(self, predictions, references):
        """Store cleaned prediction/reference strings and load the metrics.

        Args:
            predictions: Iterable of generated texts (entries may be None/NaN).
            references: Iterable of reference texts (entries may be None/NaN).
        """
        # None, NaN and any other falsy entry become the empty string.
        self.predictions = [self._clean_text(p) for p in predictions]
        self.references = [self._clean_text(r) for r in references]

        if not self.predictions or not self.references:
            print("Warning: Predictions or References list is empty or became empty after cleaning.")
            self.valid_data = False
        elif len(self.predictions) != len(self.references):
            print(f"Warning: Mismatch in prediction ({len(self.predictions)}) and reference ({len(self.references)}) counts.")
            # Trim both lists to the shorter length so they can still be paired.
            min_len = min(len(self.predictions), len(self.references))
            self.predictions = self.predictions[:min_len]
            self.references = self.references[:min_len]
            self.valid_data = bool(min_len)
            print(f"Trimmed to {min_len} pairs for evaluation.")
        else:
            self.valid_data = True

        # Load metrics only if data is valid
        if self.valid_data:
            self.rouge_metric = evaluate.load('rouge', quiet=True)
            self.meteor_metric = evaluate.load('meteor', quiet=True)
            self.bleu_metric = evaluate.load('bleu', quiet=True)
            # BERTScore will be handled within its method due to potential slowness

    @staticmethod
    def _clean_text(value):
        """Return ``value`` as a string; None, NaN and falsy values map to ""."""
        # A float NaN is truthy, so it must be caught before the truthiness test;
        # otherwise it would be stringified to "nan" and scored as real text.
        if isinstance(value, float) and math.isnan(value):
            return ""
        return str(value) if value else ""

    def _nonempty_pairs(self):
        """Return (predictions, references) keeping only pairs where both are non-empty."""
        kept = [(p, r) for p, r in zip(self.predictions, self.references) if p and r]
        return [p for p, _ in kept], [r for _, r in kept]

    def compute_rouge_score(self):
        """Return ROUGE-1/2/L/Lsum scaled to 0-100 ({} on invalid data or error)."""
        if not self.valid_data:
            return {}
        paired_preds, paired_refs = self._nonempty_pairs()
        if not paired_preds:
            return {"rouge1": 0, "rouge2": 0, "rougeL": 0, "rougeLsum": 0}
        try:
            results = self.rouge_metric.compute(predictions=paired_preds, references=paired_refs)
            return {
                "rouge1": results.get('rouge1', 0.0) * 100,
                "rouge2": results.get('rouge2', 0.0) * 100,
                "rougeL": results.get('rougeL', 0.0) * 100,
                "rougeLsum": results.get('rougeLsum', 0.0) * 100,
            }
        except Exception as e:
            print(f"Error computing ROUGE: {e}")
            return {}

    def compute_meteor_score(self):
        """Return METEOR scaled to 0-100 ({} on invalid data or error)."""
        if not self.valid_data:
            return {}
        paired_preds, paired_refs = self._nonempty_pairs()
        if not paired_preds:
            return {"meteor": 0}
        try:
            # Meteor requires NLTK wordnet, download if needed
            try:
                nltk.data.find('corpora/wordnet.zip')
            except LookupError:
                # Only a missing resource triggers the download; interrupts propagate.
                nltk.download('wordnet', quiet=True)
                nltk.download('omw-1.4', quiet=True) # Also download Open Multilingual Wordnet

            results = self.meteor_metric.compute(predictions=paired_preds, references=paired_refs)
            return {"meteor": results.get('meteor', 0.0) * 100}
        except Exception as e:
            print(f"Error computing METEOR: {e}")
            return {}

    def compute_bleu_scores(self):
        """Return corpus BLEU scaled to 0-100 ({} on invalid data or error)."""
        if not self.valid_data:
            return {}
        paired_preds, paired_refs = self._nonempty_pairs()
        if not paired_preds:
            return {"bleu": 0}
        try:
            # BLEU expects references as list of lists for multiple refs, but evaluate handles single refs
            results = self.bleu_metric.compute(predictions=paired_preds, references=paired_refs)
            return {"bleu": results.get('bleu', 0.0) * 100}
        except Exception as e:
            print(f"Error computing BLEU: {e}")
            return {}

    def compute_bertscore(self, lang="en"):
        """Return mean BERTScore precision/recall/F1 scaled to 0-100.

        Args:
            lang: Language code forwarded to ``bert_score_metric``.

        Returns:
            Dict with ``bertscore_precision``, ``bertscore_recall`` and
            ``bertscore_f1``; ``{}`` on invalid data or error.
        """
        if not self.valid_data:
            return {}
        paired_preds, paired_refs = self._nonempty_pairs()
        if not paired_preds:
            return {"bertscore_precision": 0, "bertscore_recall": 0, "bertscore_f1": 0}
        try:
            # BERTScore computation needs device context
            with torch.no_grad(): # Ensure no grads during scoring
                P, R, F1 = bert_score_metric(paired_preds, paired_refs, lang=lang, verbose=False, device=device)
            return {
                "bertscore_precision": P.mean().item() * 100,
                "bertscore_recall": R.mean().item() * 100,
                "bertscore_f1": F1.mean().item() * 100,
            }
        except Exception as e:
            print(f"Error computing BERTScore: {e}")
            # Attempt cleanup if it was an OOM error
            if "out of memory" in str(e).lower():
                cleanup_memory()
            return {}

    def evaluate_all(self):
        """Run all evaluation metrics and return a combined dictionary of results"""
        if not self.valid_data:
            print("Cannot evaluate: Invalid or empty prediction/reference data.")
            return {}

        print("Calculating evaluation metrics...")
        results = {}
        results.update(self.compute_rouge_score())
        results.update(self.compute_meteor_score())
        results.update(self.compute_bleu_scores())

        print("Calculating BERTScore (this may take a moment)...")
        # Add BERTScore calculation here, ensuring device consistency
        bertscore_results = self.compute_bertscore()
        results.update(bertscore_results)
        print("Evaluation metrics calculation complete.")
        return results

# %% [code]
# --- Main Execution ---

if __name__ == "__main__": # Use this block if running as a script
    # Pipeline driver. Runs, in order and each behind its own flag:
    #   1. classifier training, 2. summarizer training (+ loss plot),
    #   3. inference on the test file, 4. metric evaluation of the results.
    # A stage that cannot find its prerequisites disables the stages that
    # depend on it instead of raising.

    # --- Configuration ---
    TRAIN_CLASSIFIER = True # Set to False if you already have a trained classifier
    TRAIN_SUMMARIZER = True # Set to False to skip summarizer training
    RUN_INFERENCE = True   # Set to False to skip inference
    RUN_EVALUATION = True  # Set to False to skip evaluation

    CLASSIFIER_EPOCHS = 1 # Quick train for demo
    SUMMARIZER_EPOCHS = 1 # Quick train for demo
    SUMMARIZER_BASE_MODEL = "facebook/bart-large"
    USE_ENERGY_LOSS_TRAINING = False # <<< DISABLE Energy Loss by default due to complexity/instability
                                     # Set to True to experiment, ensure classifier is trained first

    TRAIN_FILE = os.path.join(BASE_PATH, 'train.json')
    VALID_FILE = os.path.join(BASE_PATH, 'valid.json')
    TEST_FILE = os.path.join(BASE_PATH, 'test.json')

    # --- Verify Data Files Exist ---
    print("Checking for data files...")
    for f_path in [TRAIN_FILE, VALID_FILE, TEST_FILE]:
        if not os.path.exists(f_path):
            print(f"!!! FATAL ERROR: Data file not found at {f_path} !!!")
            print("Please ensure the dataset is correctly linked/placed.")
            # exit() # Exit if data is missing
            # For testing without data, comment out exit() and expect downstream errors
        else:
            print(f"Found: {f_path}")

    # Use the predefined checkpoint directory paths
    classifier_checkpoint_dir = CLASSIFIER_CKPT_DIR
    summarizer_checkpoint_dir = SUMMARIZER_CKPT_DIR

    # --- 1. Train Classifier (RoBERTa for Ep) ---
    classifier_trained_successfully = False
    if TRAIN_CLASSIFIER:
        print("\n--- Initiating Classifier Training ---")
        # The function now takes the specific output directory
        train_classifier(
            train_data_path=TRAIN_FILE,
            valid_data_path=VALID_FILE,
            classifier_output_dir=classifier_checkpoint_dir, # Pass the defined path
            epochs=CLASSIFIER_EPOCHS,
            # Adjust batch sizes if needed based on GPU memory
            train_batch_size=8,
            valid_batch_size=16
        )
        # Check if training likely succeeded by verifying output files.
        # Transformers names the weights file "pytorch_model.bin" or, with
        # safetensors serialization, "model.safetensors" (not
        # "pytorch_model.safetensors"), so both real names are checked.
        config_path = os.path.join(classifier_checkpoint_dir, "config.json")
        model_path = os.path.join(classifier_checkpoint_dir, "pytorch_model.bin")
        model_path_safe = os.path.join(classifier_checkpoint_dir, "model.safetensors")
        if os.path.exists(config_path) and (os.path.exists(model_path) or os.path.exists(model_path_safe)):
            print(f"Classifier training finished. Checkpoint seems saved in {classifier_checkpoint_dir}")
            classifier_trained_successfully = True
        else:
            print(f"Classifier training finished, but checkpoint files were NOT verified in {classifier_checkpoint_dir}. Energy loss might fail.")
            classifier_trained_successfully = False

    elif not os.path.exists(os.path.join(classifier_checkpoint_dir, "config.json")):
        print(f"Classifier training skipped, and required config.json not found at {classifier_checkpoint_dir}.")
        if USE_ENERGY_LOSS_TRAINING:
            print("Disabling energy loss because the required classifier checkpoint is missing.")
            USE_ENERGY_LOSS_TRAINING = False
        classifier_trained_successfully = False
    else:
        print(f"Skipping classifier training. Using existing checkpoint: {classifier_checkpoint_dir}")
        classifier_trained_successfully = True # Assume existing checkpoint is valid

    # Ensure energy loss is disabled if classifier training failed or was skipped without a valid checkpoint
    if USE_ENERGY_LOSS_TRAINING and not classifier_trained_successfully:
        print("Warning: Disabling energy loss as a valid classifier checkpoint is not available.")
        USE_ENERGY_LOSS_TRAINING = False


    # --- 2. Train Summarizer (BART + LoRA + Optional Energy Loss) ---
    summarizer_tokenizer = None
    summarizer_loss_history = []
    summarizer_trained_successfully = False

    if TRAIN_SUMMARIZER:
        print("\n--- Initiating Summarizer Training ---")
        # train_seq2seq now returns the directory path, tokenizer, and history
        summarizer_saved_dir, summarizer_tokenizer, summarizer_loss_history = train_seq2seq(
            train_data_path=TRAIN_FILE,
            valid_data_path=VALID_FILE,
            model_name=SUMMARIZER_BASE_MODEL,
            summarizer_output_dir=summarizer_checkpoint_dir, # Pass the defined path
            classifier_checkpoint_path=classifier_checkpoint_dir if USE_ENERGY_LOSS_TRAINING else None,
            epochs=SUMMARIZER_EPOCHS,
            use_energy_loss=USE_ENERGY_LOSS_TRAINING,
            # Adjust batch sizes and grad accum based on GPU memory
            train_batch_size=2,        # BART-Large needs small batches
            gradient_accumulation_steps=16, # Effective batch size 32
            valid_batch_size=4,
            lambda_perspective=0.1 # Weight for custom loss if used
        )
        # Check if training likely succeeded by verifying output files
        adapter_config_path = os.path.join(summarizer_checkpoint_dir, "adapter_config.json")
        adapter_model_path_bin = os.path.join(summarizer_checkpoint_dir, "adapter_model.bin")
        adapter_model_path_safe = os.path.join(summarizer_checkpoint_dir, "adapter_model.safetensors")
        if os.path.exists(adapter_config_path) and (os.path.exists(adapter_model_path_bin) or os.path.exists(adapter_model_path_safe)):
            print(f"Summarizer training finished. PEFT Checkpoint seems saved in {summarizer_checkpoint_dir}")
            summarizer_trained_successfully = True
        else:
            print(f"Summarizer training finished, but PEFT checkpoint files were NOT verified in {summarizer_checkpoint_dir}. Inference might fail.")
            summarizer_trained_successfully = False

        # Plot training loss if history is available
        if summarizer_loss_history:
            try:
                plt.figure(figsize=(10, 6))
                # Filter out potential inf values before plotting
                plot_loss_history = [l for l in summarizer_loss_history if np.isfinite(l)]
                if plot_loss_history:
                    plt.plot(range(1, len(plot_loss_history) + 1), plot_loss_history, marker='o')
                    plt.title('Summarizer Training Loss per Epoch')
                    plt.xlabel('Epoch')
                    plt.ylabel('Average Loss')
                    plt.grid(True)
                    plot_path = os.path.join(PLOTS_DIR, "summarizer_training_loss.png")
                    plt.savefig(plot_path)
                    print(f"Training loss plot saved to {plot_path}")
                    # plt.show() # Show plot in notebook if desired
                    plt.close() # Close the plot to free memory
                else:
                    print("No finite loss values recorded to plot.")
            except Exception as e:
                print(f"Error plotting training loss: {e}")

    elif not os.path.exists(os.path.join(summarizer_checkpoint_dir, "adapter_config.json")):
        print(f"Summarizer training skipped, and required adapter_config.json not found at {summarizer_checkpoint_dir}.")
        RUN_INFERENCE = False # Disable inference if no model
        RUN_EVALUATION = False
        summarizer_trained_successfully = False
    else:
        print(f"Skipping summarizer training. Using existing checkpoint: {summarizer_checkpoint_dir}")
        # Need to load the tokenizer if skipping training but running inference
        try:
            summarizer_tokenizer = BartTokenizer.from_pretrained(summarizer_checkpoint_dir)
            summarizer_trained_successfully = True # Assume valid checkpoint if tokenizer loads
        except Exception as e:
            print(f"Could not load tokenizer from {summarizer_checkpoint_dir}. Trying base model tokenizer. Error: {e}")
            try:
                summarizer_tokenizer = BartTokenizer.from_pretrained(SUMMARIZER_BASE_MODEL)
                summarizer_trained_successfully = True # Still assume checkpoint might be usable
            except Exception as e2:
                print(f"Could not load base model tokenizer either: {e2}")
                summarizer_trained_successfully = False
                RUN_INFERENCE = False
                RUN_EVALUATION = False


    # --- 3. Run Inference ---
    inference_results_df = None
    generated_csv_path = os.path.join(GENERATED_DIR, "bart_lora_generated_results.csv")

    if RUN_INFERENCE:
        if not summarizer_trained_successfully:
            print(f"Cannot run inference: Summarizer model checkpoint is missing or invalid in {summarizer_checkpoint_dir}.")
        else:
            print("\n--- Initiating Inference ---")
            inference_results_df = inference(
                test_file=TEST_FILE,
                base_model_name=SUMMARIZER_BASE_MODEL,
                peft_ckpt_dir=summarizer_checkpoint_dir, # Use the specific path
                output_csv_path=generated_csv_path,
                batch_size_test=8 # Adjust as needed based on GPU memory for inference
            )
            if inference_results_df is not None:
                print("Sample Inference Results:")
                print(inference_results_df.head())


    # --- 4. Run Evaluation ---
    if RUN_EVALUATION:
        if inference_results_df is None:
            # Try to load from CSV if inference wasn't run in this session but file exists
            if os.path.exists(generated_csv_path):
                print(f"\n--- Loading previous inference results for Evaluation from {generated_csv_path} ---")
                try:
                    inference_results_df = pd.read_csv(generated_csv_path)
                except Exception as e:
                    print(f"Error loading inference results CSV: {e}")
                    inference_results_df = None # Ensure it's None if loading fails
            else:
                print("Cannot run evaluation: No inference results DataFrame available and CSV file not found.")

        if inference_results_df is not None:
            print("\n--- Initiating Evaluation ---")
            # Ensure columns exist and handle potential NaN values robustly
            if 'PREDICTED' not in inference_results_df.columns or 'ACTUAL_SUMMARY' not in inference_results_df.columns:
                print("Error: Required columns ('PREDICTED', 'ACTUAL_SUMMARY') not found in inference results DataFrame.")
            else:
                # fillna('') first so missing cells become empty strings, not "nan"
                predictions = inference_results_df['PREDICTED'].fillna('').astype(str).tolist()
                references = inference_results_df['ACTUAL_SUMMARY'].fillna('').astype(str).tolist()

                evaluator = EvaluationMetrics(predictions, references)
                all_scores = evaluator.evaluate_all()

                print("\n--- Evaluation Scores ---")
                print(json.dumps(all_scores, indent=2))

                # Save scores to a file
                scores_path = os.path.join(GENERATED_DIR, "bart_lora_evaluation_scores.json")
                try:
                    # Ensure the directory exists
                    os.makedirs(os.path.dirname(scores_path), exist_ok=True)
                    with open(scores_path, 'w') as f:
                        json.dump(all_scores, f, indent=2)
                    print(f"Evaluation scores saved to {scores_path}")
                    # --- Verification ---
                    if os.path.exists(scores_path):
                        print(f"Evaluation scores JSON file verified at {scores_path}")
                    else:
                        print(f"!!! Warning: Evaluation scores JSON file NOT found at {scores_path} after saving !!!")
                except Exception as e:
                    print(f"!!! Error saving evaluation scores: {e} !!!")
        else:
            print("Skipping evaluation as inference results are missing or failed to load.")

    print("\n--- Pipeline Finished ---")
    print(f"All outputs generated in: {OUTPUT_PATH}")
    print(f"Checkpoints saved in: {CHECKPOINTS_DIR}")
    print(f"Generated files (CSV, JSON) saved in: {GENERATED_DIR}")
    print(f"Plots saved in: {PLOTS_DIR}")

Using device: cuda
CUDA Device Name: Tesla T4
CUDA Version: 12.4
PyTorch Version: 2.5.1+cu124
Output base directory: /kaggle/working/
Generated files directory: /kaggle/working/generated
Checkpoints directory: /kaggle/working/checkpoints
Plots directory: /kaggle/working/plots
Checking for data files...
Found: /kaggle/input/nlp-data/train.json
Found: /kaggle/input/nlp-data/valid.json
Found: /kaggle/input/nlp-data/test.json

--- Initiating Classifier Training ---
--- Starting Classifier Training ---
Loaded 2236 training samples and 959 validation samples for classifier.
Classifier checkpoints will be saved to: /kaggle/working/checkpoints/classifier


Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



######################### Classifier Epoch: 1/1 #########################


Classifier Epoch 1 Training:   0%|          | 0/280 [00:00<?, ?it/s]

Classifier Epoch 1 Average Train Loss: 1.1479
Classifier validation processing...


Classifier Validation:   0%|          | 0/60 [00:00<?, ?it/s]

Classifier Validation Loss: 1.0488, Accuracy: 0.6382
* New best validation loss: 1.0488. Saving model to /kaggle/working/checkpoints/classifier... *
Successfully saved classifier model and tokenizer to /kaggle/working/checkpoints/classifier
 - Missing model file (.bin or .safetensors)
--- Classifier Training Finished ---
Memory cleaned up.
Classifier training finished, but checkpoint files were NOT verified in /kaggle/working/checkpoints/classifier. Energy loss might fail.

--- Initiating Summarizer Training ---
--- Starting Summarizer Training (Model: facebook/bart-large, PEFT: LoRA) ---
Loaded 2236 training samples and 959 validation samples for summarizer.
Loading base model: facebook/bart-large


config.json:   0%|          | 0.00/1.63k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.02G [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.02G [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Base model loaded.
Applied LoRA configuration.
trainable params: 2,359,296 || all params: 408,650,752 || trainable%: 0.5773
Moving main model to cuda...
Main model is now on device: cuda:0
Summarizer checkpoints will be saved to: /kaggle/working/checkpoints/summarizer

######################### Summarizer Epoch: 1/1 #########################


Summarizer Epoch 1 Training:   0%|          | 0/1118 [00:00<?, ?it/s]

Memory cleaned up.
Memory cleaned up.
Memory cleaned up.
Memory cleaned up.
Memory cleaned up.
Memory cleaned up.
Memory cleaned up.
Memory cleaned up.
Memory cleaned up.
Memory cleaned up.
Memory cleaned up.
Memory cleaned up.
Memory cleaned up.
Memory cleaned up.
Memory cleaned up.
Summarizer Epoch 1 Average Train Loss: 3.0905
Summarizer validation processing...


Downloading builder script:   0%|          | 0.00/6.27k [00:00<?, ?B/s]

Summarizer Validation:   0%|          | 0/240 [00:00<?, ?it/s]