In [1]:
# (Cell 1) Install Libraries

!pip install nltk rouge-score py-rouge transformers tqdm datasets evaluate scispacy negspacy
!pip install https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.4/en_core_sci_lg-0.5.4.tar.gz
! python -m nltk.downloader punkt stopwords
! pip install rouge_score evaluate
! pip install evaluate

Collecting https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.4/en_core_sci_lg-0.5.4.tar.gz
  Using cached https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.4/en_core_sci_lg-0.5.4.tar.gz (531.2 MB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [4]:
# (Cell 2) Import Modules
# Import core libraries

import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from transformers import BartTokenizer, BartForConditionalGeneration # Using BART model
from torch.optim import AdamW
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm # Use auto version for better notebook integration
import nltk
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer
from nltk.tokenize import sent_tokenize
from sklearn.feature_extraction.text import TfidfVectorizer # For TF-IDF target generation
import gc  # Garbage collector
import os
import json
import re # For section parsing
import time # For runtime tracking
import random # For selecting qualitative examples
from torch.cuda.amp import autocast, GradScaler # For Mixed Precision Training
from collections import Counter
import evaluate

# --- spaCy / scispaCy / negspacy Imports (for Clinical Metrics) ---
import spacy
import scispacy # Required for loading sci models
# Optional: for advanced negation detection - uncomment if used
# from negspacy.negation import Negex


# --- Download NLTK resources ---
# Download necessary NLTK data for tokenization and stop words
try:
    nltk.data.find('tokenizers/punkt')
except nltk.downloader.DownloadError:
    print("Downloading NLTK punkt tokenizer...")
    nltk.download('punkt', quiet=True)
try:
    nltk.data.find('corpora/stopwords')
except nltk.downloader.DownloadError:
    print("Downloading NLTK stopwords...")
    nltk.download('stopwords', quiet=True)

nltk.download('punkt_tab')

# Download punkt_tab for sentence tokenization if needed (TF-IDF)
try:
    nltk.data.find('tokenizers/punkt/english.pickle')
except nltk.downloader.DownloadError:
     print("Downloading NLTK punkt_tab for sentence tokenization...")
     nltk.download('punkt_tab', quiet=True)


# --- Load spaCy and scispaCy model ---
# Initialize flags
SCISPACY_LOADED = False
TARGET_ENTITY_TYPES = set() # Set to empty initially
nlp_sci = None # Initialize nlp_sci model

print("\n--- Loading scispaCy model for clinical evaluation ---")
print(f"SpaCy version being used by script: {spacy.__version__}")
try:
    from spacy.util import get_data_path
    print(f"SpaCy data path being used by script: {get_data_path()}")
except Exception as e:
    print(f"Could not get spaCy data path: {e}")

try:
    # Using the large model for potentially better entity recognition
    # Ensure en_core_sci_lg model v0.5.4 is installed and compatible with your spaCy version
    nlp_sci = spacy.load("en_core_sci_lg")
    # Optional: Add negspacy component if you want to explore negation later
    # nlp_sci.add_pipe("negex", config={"ent_types":list(TARGET_ENTITY_TYPES)}) # Ensure types match
    print("scispaCy model 'en_core_sci_lg' loaded successfully.")
    SCISPACY_LOADED = True
    # Define target entity types for clinical evaluation (adjust as needed)
    TARGET_ENTITY_TYPES = {"PROBLEM", "TREATMENT", "TEST", "ENTITY"}
    print(f"Target entity types for clinical evaluation: {TARGET_ENTITY_TYPES}")

except OSError:
    print("Error: scispaCy model 'en_core_sci_lg' not found or incompatible.")
    print("Please ensure it is installed correctly and compatible with your spaCy version.")
    print("Refer to scispaCy documentation for installation instructions: https://allenai.github.io/scispacy/")
    SCISPACY_LOADED = False
    TARGET_ENTITY_TYPES = set() # Set to empty if model not loaded
except Exception as e:
    print(f"An unexpected error occurred while loading scispaCy model: {e}")
    SCISPACY_LOADED = False
    TARGET_ENTITY_TYPES = set()


# --- Evaluation Metric Scorers ---
# Initialize ROUGE and BLEU scorers
print("\nSetting up ROUGE and BLEU scorers...")
scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
smoother = SmoothingFunction().method1 # For BLEU smoothing

# Clean up memory after imports and initial setup
gc.collect()
torch.cuda.empty_cache()

[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt_tab.zip.



--- Loading scispaCy model for clinical evaluation ---
SpaCy version being used by script: 3.7.5
Could not get spaCy data path: cannot import name 'get_data_path' from 'spacy.util' (/usr/local/lib/python3.11/dist-packages/spacy/util.py)


  deserializers["tokenizer"] = lambda p: self.tokenizer.from_disk(  # type: ignore[union-attr]


scispaCy model 'en_core_sci_lg' loaded successfully.
Target entity types for clinical evaluation: {'TREATMENT', 'PROBLEM', 'ENTITY', 'TEST'}

Setting up ROUGE and BLEU scorers...


In [5]:
# (Cell 3) Configuration and Paths

# --- Configuration ---
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED) # Seed for random sample selection
tqdm.pandas() # Enable progress bars for pandas apply

from google.colab import drive
drive.mount('/content/drive')

! unzip /content/drive/MyDrive/mimic-iii-10k.zip -d /content/datasetA

# --- Paths ---

DRIVE_PATH = "/content/drive/MyDrive/BioBart_Radiology_Summarization" # Define your base drive path
CHECKPOINT_DIR = os.path.join(DRIVE_PATH, "checkpoints")
METRICS_FILE = os.path.join(DRIVE_PATH, "training_metrics.json")
FINAL_RESULTS_FILE = os.path.join(DRIVE_PATH, "final_results.json")
TOKENIZER_PATH = os.path.join(DRIVE_PATH, "tokenizer") # Path to save/load custom tokenizer
FINAL_MODEL_PATH = os.path.join(DRIVE_PATH, "complete_model") # Path to save final model
CONFIG_FILE = os.path.join(DRIVE_PATH, "training_config.json") # Path for configuration file
QUALITATIVE_EXAMPLES_FILE = os.path.join(DRIVE_PATH, "qualitative_examples.json") # Path for qualitative examples

# Create directories if they don't exist
os.makedirs(DRIVE_PATH, exist_ok=True)
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
print(f"Working directory: {DRIVE_PATH}")


# --- Data Parameters ---
DATA_PATH = '/content/datasetA/MIMIC -III (10000 patients)/NOTEEVENTS/NOTEEVENTS_sorted.csv' # Path to your MIMIC-III data
SAMPLE_SIZE = 104995 # Number of radiology reports to sample from the full dataset
min_word_count = 50 # Minimum word count for an original report to be included
max_word_count = 1024 # Maximum word count for the original report (truncate if longer)
tfidf_ratio = 0.3 # Ratio of sentences to select for TF-IDF extractive summary


# --- Special Tokens Definition ---
# Section Headers found in reports and corresponding special tokens
SECTION_HEADERS = {
    "INDICATION": ["[INDICATION_SEP]", "[REASON_SEP]"],
    "TECHNIQUE": ["[TECHNIQUE_SEP]"],
    "FINDINGS": ["[FINDINGS_SEP]"],
    "IMPRESSION": ["[IMPRESSION_SEP]", "[CONCLUSION_SEP]"]
}
# Regex to find any of the defined headers for parsing
ALL_HEADERS_REGEX = "|".join([h.replace("[","").replace("]","").replace("_SEP","") for sl in SECTION_HEADERS.values() for h in sl])
# List of unique section special tokens
SECTION_SPECIAL_TOKENS = list(set([token for sublist in SECTION_HEADERS.values() for token in sublist]))
print(f"Section special tokens defined: {SECTION_SPECIAL_TOKENS}")

# Length Control Tokens based on target summary word count
LENGTH_CONTROL_TOKENS = ["<SUM_SHORT>", "<SUM_MEDIUM>", "<SUM_LONG>"]
# Thresholds for defining short, medium, long summaries (in words)
short_threshold = 50
medium_threshold = 100
print(f"Length control tokens defined: {LENGTH_CONTROL_TOKENS}")

# All new tokens to be added to the tokenizer vocabulary
ALL_NEW_SPECIAL_TOKENS = list(set(SECTION_SPECIAL_TOKENS + LENGTH_CONTROL_TOKENS))


# --- Model and Training Parameters ---
model_name = "GanjinZero/biobart-v2-base" # Pre-trained BART model
MAX_INPUT_LENGTH = 512  # Max token length for model input (includes control + section tokens + text)
MAX_TARGET_LENGTH = 150 # Max token length for model output (TF-IDF target summary)

INITIAL_LR = 2e-5 # Initial learning rate for the optimizer
num_epochs = 15 # Total number of training epochs
batch_size = 8 # Batch size per GPU/device
gradient_accumulation_steps = 8 # Number of batches to accumulate gradients over
effective_batch_size = batch_size * gradient_accumulation_steps
print(f"Effective batch size: {effective_batch_size}")

# Progressive Unfreezing Schedule: Define which layers to unfreeze at which epoch
# Note: Epoch numbers are 0-indexed here, but in training logs they are +1 (1-indexed)
progressive_unfreezing_schedule = {
    3: "unfreeze_all_decoder", # Example: Unfreeze all decoder layers at epoch 3
    8: "unfreeze_half_encoder" # Example: Unfreeze top half of encoder layers at epoch 8
    # Add more epochs and layers as needed
}
# Initial freezing happens after loading the base model and resizing embeddings
# Default: Unfreeze shared embeddings, lm_head, and the last N decoder layers
num_decoder_layers_to_unfreeze_initial = 4 # Number of decoder layers to unfreeze initially


# Curriculum Learning Schedule: Define how the training data size increases
curriculum_learning_schedule = {
    "initial_size": 20000, # Initial number of training samples
    "increment": 10000, # Number of samples to add per phase
    "increment_every_epochs": 3 # Number of epochs per phase
}


# --- Generation Parameters for Validation and Testing ---
# These parameters control the beam search during model.generate()
generation_parameters = {
    "max_length": MAX_TARGET_LENGTH + 10, # Maximum length of generated summary
    "num_beams": 4, # Number of beams for beam search
    "length_penalty": 1.0, # Encourages longer summaries if > 1.0, shorter if < 1.0
    "early_stopping": True, # Stop beam search when all beams have generated an EOS token
    "no_repeat_ngram_size": 3 # Avoids repeating n-grams of this size
}


# --- Configuration Dictionary (for saving) ---
training_config = {
    "seed": SEED,
    "drive_path": DRIVE_PATH,
    "data_path": DATA_PATH,
    "model_name": model_name,
    "sample_size": SAMPLE_SIZE,
    "min_word_count": min_word_count,
    "max_word_count": max_word_count,
    "tfidf_ratio": tfidf_ratio,
    "section_headers": SECTION_HEADERS,
    "length_control_tokens": LENGTH_CONTROL_TOKENS,
    "length_control_thresholds": {"short": short_threshold, "medium": medium_threshold},
    "all_new_special_tokens": ALL_NEW_SPECIAL_TOKENS,
    "max_input_length": MAX_INPUT_LENGTH,
    "max_target_length": MAX_TARGET_LENGTH,
    "initial_lr": INITIAL_LR,
    "num_epochs": num_epochs,
    "batch_size": batch_size,
    "gradient_accumulation_steps": gradient_accumulation_steps,
    "effective_batch_size": effective_batch_size,
    "initial_unfrozen_decoder_layers": num_decoder_layers_to_unfreeze_initial,
    "progressive_unfreezing_schedule": progressive_unfreezing_schedule,
    "curriculum_learning_schedule": curriculum_learning_schedule,
    "generation_parameters": generation_parameters,
    "evaluation_entity_types": list(TARGET_ENTITY_TYPES) if TARGET_ENTITY_TYPES else [], # Ensure serializable
    "scispacy_model": "en_core_sci_lg",
    "scispacy_loaded": SCISPACY_LOADED # Record if scispaCy was loaded
}

# Save the configuration
print(f"\nSaving training configuration to {CONFIG_FILE}...")
try:
    with open(CONFIG_FILE, 'w', encoding='utf-8') as f:
        json.dump(training_config, f, ensure_ascii=False, indent=4)
    print("Configuration saved.")
except Exception as e:
    print(f"Error saving configuration: {e}")

Working directory: /content/drive/MyDrive/BioBart_TFIDF_Structured
Section special tokens defined: ['[TECHNIQUE_SEP]', '[REASON_SEP]', '[FINDINGS_SEP]', '[IMPRESSION_SEP]', '[CONCLUSION_SEP]', '[INDICATION_SEP]']
Length control tokens defined: ['<SUM_SHORT>', '<SUM_MEDIUM>', '<SUM_LONG>']
Effective batch size: 64

Saving training configuration to /content/drive/MyDrive/BioBart_TFIDF_Structured/training_config.json...
Configuration saved.


In [6]:
# (Cell 4) Utility Functions (Preprocessing Helpers)

# Function to add section tokens
def add_section_tokens(text, section_headers_map=SECTION_HEADERS, all_headers_regex=ALL_HEADERS_REGEX):
    """Parses report text and inserts section-specific special tokens."""
    if not text or not all_headers_regex: return text
    processed_text = text
    # Regex to find section headers, handling variations and start of line/document
    pattern = re.compile(r"(?:^|\\n\\s*)(" + all_headers_regex + r")\\s*:?\\s*", re.IGNORECASE)
    last_match_end = 0
    modified_parts = []
    try:
        for match in pattern.finditer(text):
            header_found = match.group(1).upper() # Get the matched header text
            match_start = match.start(1) # Start position of the matched header text
            special_token = None
            # Find the corresponding special token for the matched header
            for canonical, tokens in section_headers_map.items():
                # Create header names from the tokens in the map for matching
                headers_in_map = [h.replace("[","").replace("]","").replace("_SEP","") for h in tokens]
                if header_found in headers_in_map:
                    special_token = tokens[0] # Use the first token in the list as the main one
                    break

            if special_token:
                # Add text before the header, the special token, and a space
                modified_parts.append(text[last_match_end:match_start])
                modified_parts.append(special_token + " ")
                last_match_end = match.end(0) # Update the end position to the end of the full match (header + colon/space)

        # Add the remaining text after the last match
        modified_parts.append(text[last_match_end:])

        processed_text = "".join(modified_parts)
        processed_text = ' '.join(processed_text.split()) # Normalize whitespace
    except Exception as e:
        print(f"Warning: Error during section token insertion - {e}. Returning original text.")
        return text # Return original text on error

    return processed_text


# Function for TF-IDF Extractive Summary
def create_extractive_summary_tfidf(text, ratio=0.3):
    """Creates an extractive summary using TF-IDF."""
    if not text or len(text.split()) < min_word_count: # Use same threshold as original data filtering
        return text # Return original text if too short

    try:
        # Use NLTK's punkt for sentence tokenization
        sentences = sent_tokenize(text)

        if len(sentences) <= 3: # Handle very short texts or texts with few sentences
            num_sentences = max(1, int(len(sentences) * ratio))
            return ' '.join(sentences[:num_sentences])

        # Use scikit-learn's default English stop words
        tfidf_vectorizer = TfidfVectorizer(stop_words='english')
        tfidf_matrix = tfidf_vectorizer.fit_transform(sentences)

        # Calculate sentence scores by summing TF-IDF values
        sentence_scores = np.array(tfidf_matrix.sum(axis=1)).flatten()

        # Get indices of top sentences based on score
        num_sentences = max(1, int(len(sentences) * ratio))
        # Use argsort to get indices of sorted scores, take the top ones
        top_sentence_indices = sentence_scores.argsort()[-num_sentences:]
        # Sort indices to maintain original order of sentences in the summary
        top_sentence_indices = sorted(top_sentence_indices)

        summary = ' '.join([sentences[i] for i in top_sentence_indices])
        return summary
    except Exception as e:
        # Fallback to taking the first few sentences on error
        print(f"Warning: TF-IDF summarization failed - {e}. Returning first sentences.")
        try:
             sentences = sent_tokenize(text)
             num_sentences = max(1, int(len(sentences) * ratio))
             return ' '.join(sentences[:num_sentences])
        except:
             # Final fallback: return original text if short, else empty
             return text if len(text.split()) < min_word_count else ""


# Function to get length control token based on target summary word count
def get_length_control_token(target_summary_text, short_threshold=short_threshold, medium_threshold=medium_threshold):
    """Determines length control token based on target summary word count."""
    word_count = len(str(target_summary_text).split()) # Ensure it's a string
    if word_count <= short_threshold: return "<SUM_SHORT>"
    elif word_count <= medium_threshold: return "<SUM_MEDIUM>"
    else: return "<SUM_LONG>"

# Text normalization for evaluation
def normalize_text(text):
    """Normalizes text for consistent ROUGE/BLEU evaluation."""
    if not text: return "" # Return empty string for empty input
    # Convert to lower case and strip whitespace
    text = str(text).lower().strip() # Ensure string type
    # Replace multiple spaces with a single space
    text = ' '.join(text.split())
    return text

In [7]:
# (Cell 5) Data Loading and Preprocessing

print("Loading and preprocessing data...")

try:
    # Load data from the specified path
    data = pd.read_csv(DATA_PATH)
    print(f"Loaded data from {DATA_PATH}")

    # Filter for Radiology reports and create a copy to avoid SettingWithCopyWarning
    df = data[data['CATEGORY'] == 'Radiology'].copy()
    print(f"Filtered {len(df)} Radiology reports.")

    # --- Sampling ---
    if len(df) > SAMPLE_SIZE:
         full_df = df.sample(n=SAMPLE_SIZE, random_state=SEED).reset_index(drop=True)
    else:
         full_df = df.sample(frac=1.0, random_state=SEED).reset_index(drop=True)
    print(f"Selected {len(full_df)} Radiology samples for processing.")

    # --- Text Cleaning and Initial Filtering ---
    print("Cleaning text and applying initial word count filters...")
    full_df['TEXT'] = full_df['TEXT'].fillna("").astype(str) # Ensure string type and handle NaNs
    # Normalize whitespace (replace newlines/returns with spaces, then compress spaces)
    full_df['TEXT'] = full_df['TEXT'].progress_apply(lambda x: ' '.join(x.replace('\\n', ' ').replace('\\r', ' ').strip().split()))

    # Filter texts shorter than min_word_count
    original_len = len(full_df)
    full_df = full_df[full_df['TEXT'].apply(lambda x: len(x.split())) >= min_word_count].copy()
    print(f"Removed {original_len - len(full_df)} samples shorter than {min_word_count} words.")

    # Truncate texts longer than max_word_count
    full_df['TEXT'] = full_df['TEXT'].progress_apply(lambda x: ' '.join(x.split()[:max_word_count]))
    print(f"Truncated texts longer than {max_word_count} words.")


    # --- Generate TF-IDF Target Summary ---
    print(f"Generating TF-IDF target summaries ('target_text') with ratio {tfidf_ratio}...")
    full_df['target_text'] = full_df['TEXT'].progress_apply(lambda x: create_extractive_summary_tfidf(x, ratio=tfidf_ratio))


    # --- Add Section Tokens ---
    print("Adding section tokens to original text ('section_aware_text')...")
    full_df['section_aware_text'] = full_df['TEXT'].progress_apply(add_section_tokens)


    # --- Determine Length Control Token ---
    print(f"Determining length control tokens ('control_token') based on target summary word count (thresholds: short<={short_threshold}, medium<={medium_threshold})...")
    full_df['control_token'] = full_df['target_text'].apply(lambda x: get_length_control_token(x, short_threshold, medium_threshold))


    # --- Create Final Input Text ---
    print("Creating final input text ('input_text') by prepending control token...")
    # The final input text format is <LENGTH_TOKEN> <SECTION_AWARE_TEXT>
    full_df['input_text'] = full_df['control_token'] + " " + full_df['section_aware_text']


    # --- Final Cleanup: Remove samples with empty targets/inputs after processing ---
    # This step is crucial if TF-IDF or section parsing resulted in empty strings for some samples
    original_len = len(full_df)
    full_df = full_df[full_df['target_text'].apply(lambda x: len(str(x).strip()) > 0)].copy() # Ensure string type before strip
    full_df = full_df[full_df['input_text'].apply(lambda x: len(str(x).strip()) > 0)].copy()   # Ensure string type before strip
    print(f"Removed {original_len - len(full_df)} samples with empty targets/inputs after processing.")
    full_df = full_df.reset_index(drop=True) # Reset index after filtering


    # --- Data Splitting ---
    print("Splitting data into Train, Validation, and Test sets...")
    # Keep original TEXT column for clinical evaluation raw input
    train_df, temp_df = train_test_split(full_df[['input_text', 'target_text', 'TEXT']],
                                         test_size=0.3, random_state=SEED)
    val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=SEED)

    # Reset indices for easier access later
    train_df = train_df.reset_index(drop=True)
    val_df = val_df.reset_index(drop=True)
    test_df = test_df.reset_index(drop=True)

    print(f"Train set size: {len(train_df)}")
    print(f"Validation set size: {len(val_df)}")
    print(f"Test set size: {len(test_df)}")

except FileNotFoundError:
    print(f"Error: Data file not found at {DATA_PATH}. Please check the path in the Configuration cell.")
    # Exit or handle gracefully
    # exit()
except KeyError as e:
    print(f"Error: Missing expected column in CSV: {e}. Ensure 'TEXT' and 'CATEGORY' columns exist in your data file.")
    # exit()
except Exception as e:
    print(f"An unexpected error occurred during data loading/preprocessing: {e}")
    # exit()

Loading and preprocessing data...
Loaded data from /content/datasetA/MIMIC -III (10000 patients)/NOTEEVENTS/NOTEEVENTS_sorted.csv
Filtered 104995 Radiology reports.
Selected 104995 Radiology samples for processing.
Cleaning text and applying initial word count filters...


  0%|          | 0/104995 [00:00<?, ?it/s]

Removed 212 samples shorter than 50 words.


  0%|          | 0/104783 [00:00<?, ?it/s]

Truncated texts longer than 1024 words.
Generating TF-IDF target summaries ('target_text') with ratio 0.3...


  0%|          | 0/104783 [00:00<?, ?it/s]

Adding section tokens to original text ('section_aware_text')...


  0%|          | 0/104783 [00:00<?, ?it/s]

Determining length control tokens ('control_token') based on target summary word count (thresholds: short<=50, medium<=100)...
Creating final input text ('input_text') by prepending control token...
Removed 0 samples with empty targets/inputs after processing.
Splitting data into Train, Validation, and Test sets...
Train set size: 73348
Validation set size: 15717
Test set size: 15718


In [8]:
# (Cell 6) Custom Dataset Class

class MIMICDataset(Dataset):
    """Custom Dataset for MIMIC-III Radiology Reports."""

    def __init__(self, dataframe, tokenizer, max_input_length, max_target_length):
        self.dataframe = dataframe
        self.tokenizer = tokenizer
        self.max_input_length = max_input_length
        self.max_target_length = max_target_length
        # Check required columns in the dataframe
        required_cols = ['input_text', 'target_text', 'TEXT'] # TEXT holds original for raw_input
        if not all(col in dataframe.columns for col in required_cols):
             raise ValueError(f"Dataframe must contain columns: {required_cols}")

    def __len__(self):
        """Returns the number of samples in the dataset."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Retrieves a single sample by index."""
        if idx >= len(self.dataframe): raise IndexError("Index out of bounds")

        # Get the texts for the current sample
        input_text = str(self.dataframe.iloc[idx]['input_text'])   # Final input with control+section tokens
        target_text = str(self.dataframe.iloc[idx]['target_text']) # TF-IDF summary (Target for loss)
        raw_original_text = str(self.dataframe.iloc[idx]['TEXT'])  # Original report text (for clinical metrics)


        # Tokenize the input text
        input_encoding = self.tokenizer(
            input_text,
            max_length=self.max_input_length,
            padding='max_length',
            truncation=True,
            return_tensors="pt" # Return PyTorch tensors
        )

        # Tokenize the target text. Use as_target_tokenizer() for correct handling of BOS/EOS/padding tokens for the decoder.
        with self.tokenizer.as_target_tokenizer():
            target_encoding = self.tokenizer(
                target_text,
                max_length=self.max_target_length,
                padding='max_length',
                truncation=True,
                return_tensors="pt" # Return PyTorch tensors
            )

        labels = target_encoding['input_ids']
        # Replace padding token id in labels with -100 so it's ignored in the loss calculation
        labels[labels == self.tokenizer.pad_token_id] = -100

        return {
            'input_ids': input_encoding['input_ids'].flatten(), # Remove batch dimension
            'attention_mask': input_encoding['attention_mask'].flatten(), # Remove batch dimension
            'labels': labels.flatten(), # Remove batch dimension
            'raw_input': raw_original_text, # Include original text for clinical evaluation
            'raw_target': target_text      # Include TF-IDF target for standard evaluation
        }

In [9]:
# (Cell 7) Evaluation Metric Functions

def calculate_metrics(references, hypotheses):

    # Normalize texts before calculating metrics
    references = [normalize_text(ref) for ref in references]
    hypotheses = [normalize_text(hyp) for hyp in hypotheses]

    rouge1_scores, rouge2_scores, rougeL_scores, bleu_scores = [], [], [], []

    # Calculate ROUGE scores for each pair
    for ref, hyp in zip(references, hypotheses):
        # Skip empty references or hypotheses for ROUGE calculation
        if not ref or not hyp:
             # print(f"Warning: Skipping empty reference or hypothesis for ROUGE. Ref: '{ref[:20]}...', Hyp: '{hyp[:20]}...'")
            continue
        try:
            # Calculate ROUGE scores for the current reference-hypothesis pair
            scores = scorer.score(ref, hyp)
            rouge1_scores.append(scores['rouge1'].fmeasure)
            rouge2_scores.append(scores['rouge2'].fmeasure)
            rougeL_scores.append(scores['rougeL'].fmeasure)
        except Exception as e:
            print(f"Error calculating ROUGE for pair: ref='{ref[:50]}...', hyp='{hyp[:50]}...'. Error: {e}")


    # Calculate BLEU scores for each pair
    for ref, hyp in zip(references, hypotheses):
         # Skip empty references or hypotheses for BLEU calculation
         if not ref or not hyp:
              # print(f"Warning: Skipping empty reference or hypothesis for BLEU. Ref: '{ref[:20]}...', Hyp: '{hyp[:20]}...'")
             continue
         try:
            # Tokenize sentences for BLEU calculation
            ref_tokens = nltk.word_tokenize(ref)
            hyp_tokens = nltk.word_tokenize(hyp)
            # Skip if tokenization results in empty lists
            if not ref_tokens or not hyp_tokens:
                 # print(f"Warning: Skipping BLEU due to empty token lists. Ref: '{ref[:20]}...', Hyp: '{hyp[:20]}...'")
                 continue
            # Calculate BLEU score using sentence_bleu with smoothing
            # sentence_bleu expects a list of reference sentences (even if only one)
            bleu_score = sentence_bleu([ref_tokens], hyp_tokens, smoothing_function=smoother)
            bleu_scores.append(bleu_score)
         except Exception as e:
             print(f"Error calculating BLEU for pair: ref='{ref[:50]}...', hyp='{hyp[:50]}...'. Error: {e}")


    # Calculate average scores, handling cases with no valid scores
    avg_rouge1 = sum(rouge1_scores) / len(rouge1_scores) if rouge1_scores else 0.0
    avg_rouge2 = sum(rouge2_scores) / len(rouge2_scores) if rouge2_scores else 0.0
    avg_rougeL = sum(rougeL_scores) / len(rougeL_scores) if rougeL_scores else 0.0
    avg_bleu = sum(bleu_scores) / len(bleu_scores) if bleu_scores else 0.0

    return {
        'rouge-1': avg_rouge1,
        'rouge-2': avg_rouge2,
        'rouge-l': avg_rougeL,
        'bleu': avg_bleu
    }


# Clinical Metrics (Entity Overlap F1)
# scispaCy model (nlp_sci) and TARGET_ENTITY_TYPES are initialized in Cell 1

def calculate_clinical_metrics(references_raw, hypotheses, target_entity_types=TARGET_ENTITY_TYPES):

    # Check if scispaCy model is loaded and target entity types are defined
    if not SCISPACY_LOADED or not target_entity_types:
        if not SCISPACY_LOADED:
             print("Warning: scispaCy model not loaded. Skipping clinical metrics calculation.")
        elif not target_entity_types:
             print("Warning: No target entity types defined for clinical metrics. Skipping calculation.")
        return {'entity_recall': 0.0, 'entity_precision': 0.0, 'entity_f1': 0.0}

    all_recalls, all_precisions, all_f1s = [], [], []
    processed_pairs_count = 0 # Counter for pairs successfully processed by scispaCy

    print(f"Calculating clinical metrics for {len(references_raw)} pairs using entity types: {target_entity_types}...")
    # Use tqdm for a progress bar during this potentially slow process
    for ref_text, hyp_text in tqdm(zip(references_raw, hypotheses), total=len(references_raw), desc="Clinical Metrics"):
        # Ensure texts are strings and not empty or just whitespace
        ref_text = str(ref_text).strip()
        hyp_text = str(hyp_text).strip()

        if not ref_text or not hyp_text:
            # print(f"Skipping pair due to empty reference or hypothesis string after stripping.")
            continue

        try:
            # Process the original raw text (reference for entities) with scispaCy
            doc_ref = nlp_sci(ref_text)
            # Extract entities of target types, convert lemma to lower case for comparison
            ref_entities = {ent.lemma_.lower() for ent in doc_ref.ents if ent.label_ in target_entity_types}

            # Process the generated summary (hypothesis) with scispaCy
            doc_hyp = nlp_sci(hyp_text)
            hyp_entities = {ent.lemma_.lower() for ent in doc_hyp.ents if ent.label_ in target_entity_types}

            # Calculate the number of entities common to both original text and generated summary
            common_entities_count = len(ref_entities.intersection(hyp_entities))

            # Calculate Recall, Precision, and F1 score for this pair
            # Recall: Proportion of entities in the reference (original text) that are found in the hypothesis (generated summary)
            recall = common_entities_count / len(ref_entities) if len(ref_entities) > 0 else 0.0
            # Precision: Proportion of entities in the hypothesis (generated summary) that are found in the reference (original text)
            precision = common_entities_count / len(hyp_entities) if len(hyp_entities) > 0 else 0.0
            # F1 Score: Harmonic mean of Precision and Recall
            f1 = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0

            # Append scores for this pair to the lists
            all_recalls.append(recall)
            all_precisions.append(precision)
            all_f1s.append(f1)
            processed_pairs_count += 1 # Increment counter for successfully processed pairs

        except Exception as e:
            # Catch any errors during scispaCy processing for a specific pair
            print(f"\nError processing pair with scispaCy (Index {processed_pairs_count}) - ref='{ref_text[:50]}...', hyp='{hyp_text[:50]}...'. Error: {e}\n")
            # Decide how to handle errors: skipping the pair or appending zeros. Skipping is chosen here.
            continue

    print(f"Finished processing {processed_pairs_count} pairs for clinical metrics.")

    # Calculate average scores across all processed pairs, handling case with no processed pairs
    avg_recall = sum(all_recalls) / len(all_recalls) if all_recalls else 0.0
    avg_precision = sum(all_precisions) / len(all_precisions) if all_precisions else 0.0
    avg_f1 = sum(all_f1s) / len(all_f1s) if all_f1s else 0.0

    return {
        'entity_recall': avg_recall,
        'entity_precision': avg_precision,
        'entity_f1': avg_f1
    }

In [10]:
# (Cell 8) Training Strategy Functions

# Curriculum Learning Function
def get_curriculum_dataset(current_epoch, train_df, tokenizer,
                           initial_size=curriculum_learning_schedule["initial_size"],
                           increment=curriculum_learning_schedule["increment"],
                           increment_every=curriculum_learning_schedule["increment_every_epochs"]):

    # Determine the current phase based on the epoch number
    epoch_phase = current_epoch // increment_every

    # Calculate the target sample size for the current phase
    sample_size = initial_size + epoch_phase * increment

    # Ensure the calculated sample size does not exceed the total number of available training samples
    sample_size = min(sample_size, len(train_df))

    print(f"Epoch {current_epoch+1}: Using {sample_size} training samples (Curriculum Learning).")

    # Create a subset of the training dataframe using the calculated sample size
    # .iloc[:sample_size] takes the first 'sample_size' rows
    current_train_df = train_df.iloc[:sample_size]

    # Create and return a MIMICDataset instance for the selected subset
    return MIMICDataset(current_train_df, tokenizer, max_input_length=MAX_INPUT_LENGTH, max_target_length=MAX_TARGET_LENGTH)


# Progressive Unfreezing Function
def unfreeze_layers(model, epoch, optimizer, initial_lr, unfreezing_schedule=progressive_unfreezing_schedule):
    """
    Unfreezes more layers of the model at specified epochs and re-initializes optimizer if needed.

    Args:
        model: The model instance.
        epoch (int): The current epoch number (starting from 0).
        optimizer: The current optimizer instance.
        initial_lr: The initial learning rate.
        unfreezing_schedule (dict): Dictionary mapping epoch numbers (0-indexed) to unfreezing actions.

    Returns:
        The potentially re-initialized optimizer.
    """
    optimizer_reset_needed = False # Flag to indicate if the optimizer needs to be reset

    # Check if the current epoch is in the unfreezing schedule
    if epoch in unfreezing_schedule:
        action = unfreezing_schedule[epoch]
        print(f"\n--- Epoch {epoch+1}: Applying Progressive Unfreezing action: '{action}' ---")

        if action == "unfreeze_all_decoder":
            # Unfreeze all parameters in the decoder
            for name, param in model.model.decoder.named_parameters():
                 if not param.requires_grad: # Only change if not already trainable
                      param.requires_grad = True
                      optimizer_reset_needed = True
            print("All decoder layers are now trainable.")

        elif action == "unfreeze_half_encoder":
            # Unfreeze the top half of the encoder layers
            encoder_layers = model.model.encoder.layers
            num_encoder_layers = len(encoder_layers)
            num_layers_to_unfreeze = num_encoder_layers // 2 # Unfreeze the top half

            print(f"Unfreezing the top {num_layers_to_unfreeze} out of {num_encoder_layers} encoder layers.")
            # Iterate through encoder layers and unfreeze the top ones
            for i, layer in enumerate(encoder_layers):
                 # Unfreeze layers from num_layers_to_unfreeze onwards (0-indexed)
                 if i >= (num_encoder_layers - num_layers_to_unfreeze):
                      # Check if any parameter in the layer is not already trainable
                      if not any(p.requires_grad for p in layer.parameters()):
                           print(f"Unfreezing encoder layer {i} (Index {i}/{num_encoder_layers-1})")
                           for param in layer.parameters():
                                param.requires_grad = True
                           optimizer_reset_needed = True
            print("Specified encoder layers are now trainable.")

        else:
            print(f"Warning: Unknown unfreezing action '{action}' specified in schedule for epoch {epoch+1}.")


    # This ensures the new parameters are added to the optimizer's state.
    if optimizer_reset_needed:
         print("Re-initializing optimizer due to layer unfreezing...")
         # Create a new optimizer instance, including all currently trainable parameters
         trainable_model_params = filter(lambda p: p.requires_grad, model.parameters())
         optimizer = AdamW(trainable_model_params, lr=initial_lr) # Use the initial learning rate
         print("Optimizer re-initialized with current trainable parameters.")

    # Print the ratio of trainable parameters after applying any unfreezing
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Current trainable parameters: {trainable_params:,} / {total_params:,} ({100 * trainable_params / total_params:.2f}%)\\n")

    return optimizer # Return the (potentially new) optimizer instance

In [11]:
# (Cell 9) Checkpointing and Metrics Saving Functions

def save_checkpoint(model, optimizer, epoch, metrics, checkpoint_path):

    print(f"Saving checkpoint for epoch {epoch+1} to {checkpoint_path}...")
    try:
        checkpoint = {
            'epoch': epoch, # Save the 0-indexed epoch number
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'metrics': metrics # Save the latest validation metrics dictionary
        }
        # Ensure directory exists before saving
        os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
        torch.save(checkpoint, checkpoint_path)
        print("Checkpoint saved successfully.")
    except Exception as e:
        print(f"Error saving checkpoint to {checkpoint_path}: {e}")

def load_checkpoint(checkpoint_path, model, optimizer, device):

    # Check if the checkpoint file exists
    if os.path.exists(checkpoint_path):
        print(f"Loading checkpoint from {checkpoint_path}...")
        try:
            # Load the checkpoint dictionary, mapping to the specified device
            checkpoint = torch.load(checkpoint_path, map_location=device)

            # Load the model state dictionary
            model.load_state_dict(checkpoint['model_state_dict'])
            print("Model state loaded.")

            # Load the optimizer state dictionary only if an optimizer is provided
            if optimizer is not None and 'optimizer_state_dict' in checkpoint:
                 try:
                      # Load optimizer state - handles parameters added by unfreezing if they match
                      optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                      print("Optimizer state loaded.")
                 except ValueError as e:
                      # This can happen if the set of trainable parameters changed significantly
                      print(f"Warning: Could not load optimizer state, possibly due to parameter changes: {e}")
                      print("Optimizer state will be re-initialized for trainable parameters.")
                      # Re-initialize optimizer with current trainable parameters if loading failed
                      if optimizer is not None: # Double check optimizer is not None
                           trainable_model_params = filter(lambda p: p.requires_grad, model.parameters())
                           # Use the LR from the loaded checkpoint if available, otherwise fallback to initial
                           loaded_lr = next(iter(checkpoint['optimizer_state_dict']['param_groups']))['lr'] if 'optimizer_state_dict' in checkpoint and checkpoint['optimizer_state_dict']['param_groups'] else INITIAL_LR
                           optimizer = AdamW(trainable_model_params, lr=loaded_lr)
                           print("Optimizer re-initialized with current trainable parameters.")


            # Determine the epoch to start from. Add 1 because the saved epoch is the one that just finished.
            # Use .get with a default for robustness against older checkpoint formats
            start_epoch = checkpoint.get('epoch', -1) + 1
            # Load the metrics dictionary saved with this checkpoint
            loaded_metrics = checkpoint.get('metrics', {})

            print(f"Checkpoint loaded successfully. Resuming training from epoch {start_epoch}")
            return model, optimizer, start_epoch, loaded_metrics

        except Exception as e:
            # Handle errors during loading
            print(f"Error loading checkpoint from {checkpoint_path}: {e}")
            print("Starting training from scratch (Epoch 0).")
            # Return initial state if loading fails
            return model, optimizer, 0, {} # Start from epoch 0, return empty metrics

    else:
        # If checkpoint file does not exist
        print(f"Checkpoint file not found at {checkpoint_path}. Starting training from scratch (Epoch 0).")
        return model, optimizer, 0, {} # Start from epoch 0, return empty metrics


def save_metrics(metrics_dict, file_path):

    print(f"Saving metrics history to {file_path}...")
    try:
        # Ensure directory exists
        os.makedirs(os.path.dirname(file_path), exist_ok=True)

        with open(file_path, 'w', encoding='utf-8') as f:
            # Convert any numpy types (like numpy.float64) to native Python types
            # for JSON serialization. This is done for lists of numbers.
            serializable_metrics = {}
            for key, value in metrics_dict.items():
                if isinstance(value, list) and value and isinstance(value[0], (np.generic, float, int)):
                     # Convert each item in the list if it's a numpy generic type
                     serializable_metrics[key] = [item.item() if isinstance(item, np.generic) else item for item in value]
                else:
                     # For non-list values or lists of other types, just use the value directly
                     serializable_metrics[key] = value

            # Dump the serializable dictionary to the JSON file with indentation
            json.dump(serializable_metrics, f, ensure_ascii=False, indent=4)
        print("Metrics history saved successfully.")
    except Exception as e:
        print(f"Error saving metrics history to {file_path}: {e}")

In [12]:
# (Cell 10) Model and Tokenizer Setup

print(f"Loading base model and tokenizer: {model_name}")
try:
    # Load base tokenizer and model from Hugging Face Hub
    tokenizer = BartTokenizer.from_pretrained(model_name)
    model = BartForConditionalGeneration.from_pretrained(model_name)

    # Add ALL new special tokens defined in Configuration cell
    print(f"Adding {len(ALL_NEW_SPECIAL_TOKENS)} new special tokens to tokenizer vocabulary: {ALL_NEW_SPECIAL_TOKENS}")
    special_tokens_dict = {'additional_special_tokens': ALL_NEW_SPECIAL_TOKENS}
    num_added = tokenizer.add_special_tokens(special_tokens_dict)
    print(f"Number of tokens added: {num_added}")

    # Resize model embeddings to match the new tokenizer size.
    # This adds new embedding vectors for the newly added tokens, initialized randomly.
    print(f"Resizing model token embeddings from {model.config.vocab_size} to {len(tokenizer)}")
    model.resize_token_embeddings(len(tokenizer))
    print(f"Model vocabulary size after resizing: {model.get_input_embeddings().weight.shape[0]}")

    # Save the modified tokenizer configuration (including new tokens)
    tokenizer.save_pretrained(TOKENIZER_PATH)
    print(f"Tokenizer with special tokens saved to {TOKENIZER_PATH}")

except Exception as e:
    print(f"Error loading model/tokenizer or adding tokens: {e}")
    # Exit gracefully if model/tokenizer setup fails
    # exit()


# --- Initial Layer Freezing Strategy ---
# Freeze most layers initially and only train specific parts.
print("Applying initial layer freezing strategy...")

# Freeze all model parameters by default
for param in model.parameters():
    param.requires_grad = False

# Unfreeze shared embeddings
if model.model.shared is not None:
     for param in model.model.shared.parameters():
          param.requires_grad = True
     print("Unfroze shared embeddings.")

# Unfreeze the language model head
for param in model.lm_head.parameters():
    param.requires_grad = True
print("Unfroze language model head (lm_head).")

# Unfreeze the last N decoder layers (N is num_decoder_layers_to_unfreeze_initial)
decoder_layers = model.model.decoder.layers
num_decoder_layers = len(decoder_layers)
num_layers_to_unfreeze = min(num_decoder_layers_to_unfreeze_initial, num_decoder_layers) # Ensure N is not more than total layers

print(f"Unfreezing the last {num_layers_to_unfreeze} decoder layers.")
for i in range(num_decoder_layers - num_layers_to_unfreeze, num_decoder_layers):
    print(f"Unfreezing decoder layer {i} (Index {i}/{num_decoder_layers-1})")
    for param in decoder_layers[i].parameters():
        param.requires_grad = True

# Print the ratio of trainable parameters after initial freezing
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
print(f"\nTrainable parameters after initial freezing: {trainable_params:,} / {total_params:,} ({100 * trainable_params / total_params:.2f}%)\\n")

# Clean up memory
gc.collect()
torch.cuda.empty_cache()

Loading base model and tokenizer: GanjinZero/biobart-v2-base


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/1.13k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.59M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/892k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/772 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.72k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/666M [00:00<?, ?B/s]

Adding 9 new special tokens to tokenizer vocabulary: ['<SUM_SHORT>', '<SUM_LONG>', '[TECHNIQUE_SEP]', '<SUM_MEDIUM>', '[REASON_SEP]', '[FINDINGS_SEP]', '[IMPRESSION_SEP]', '[CONCLUSION_SEP]', '[INDICATION_SEP]']
Number of tokens added: 9
Resizing model token embeddings from 85401 to 85410


The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


Model vocabulary size after resizing: 85410
Tokenizer with special tokens saved to /content/drive/MyDrive/BioBart_TFIDF_Structured/tokenizer
Applying initial layer freezing strategy...
Unfroze shared embeddings.
Unfroze language model head (lm_head).
Unfreezing the last 4 decoder layers.
Unfreezing decoder layer 2 (Index 2/5)
Unfreezing decoder layer 3 (Index 3/5)
Unfreezing decoder layer 4 (Index 4/5)
Unfreezing decoder layer 5 (Index 5/5)

Trainable parameters after initial freezing: 103,401,984 / 166,411,776 (62.14%)\n


In [13]:
# (Cell 11) Create Datasets and DataLoaders

print("Creating datasets and dataloaders...")

try:
    # Create Dataset instances for validation and test sets
    # Training dataset is created dynamically in the training loop for Curriculum Learning
    val_dataset = MIMICDataset(val_df, tokenizer, max_input_length=MAX_INPUT_LENGTH, max_target_length=MAX_TARGET_LENGTH)
    test_dataset = MIMICDataset(test_df, tokenizer, max_input_length=MAX_INPUT_LENGTH, max_target_length=MAX_TARGET_LENGTH)

    # Create DataLoader instances for validation and test sets
    # num_workers > 0 can speed up data loading, pin_memory=True can speed up CPU-to-GPU transfer
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

    print(f"Validation Dataloader created with {len(val_dataloader)} batches.")
    print(f"Test Dataloader created with {len(test_dataloader)} batches.")

except ValueError as e:
     print(f"Error creating dataset: {e}")
     # exit()
except Exception as e:
     print(f"An unexpected error occurred during dataset/dataloader creation: {e}")
     # exit()

# --- Device Setup ---
# Determine the device to use (GPU if available, otherwise CPU)
gc.collect()
torch.cuda.empty_cache() # Clear GPU cache before moving model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Move the model to the selected device
try:
    model.to(device)
    print("Model moved to device.")
except RuntimeError as e:
    print(f"Error moving model to {device}: {e}. Trying CPU.")
    device = torch.device('cpu')
    model.to(device)
    print("Model moved to CPU.")


# --- Optimizer and Scaler Setup ---
# Initialize the optimizer with only the trainable parameters
# Using AdamW, a common choice for transformer fine-tuning
# The set of trainable parameters might change during progressive unfreezing
optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=INITIAL_LR)
print(f"Optimizer initialized with learning rate: {INITIAL_LR}")

# Initialize GradScaler for Mixed Precision Training (AMP)
# This helps speed up training and reduce memory usage on compatible GPUs
scaler = GradScaler()
print("Initialized GradScaler for Automatic Mixed Precision (AMP).")

Creating datasets and dataloaders...
Validation Dataloader created with 1965 batches.
Test Dataloader created with 1965 batches.
Using device: cuda
Model moved to device.
Optimizer initialized with learning rate: 2e-05
Initialized GradScaler for Automatic Mixed Precision (AMP).


  scaler = GradScaler()


In [14]:
# (Cell 12) Training Loop

print("\nStarting training process...")

# Initialize variables for tracking metrics and best model
best_val_metric = 0.0 # Initialize best validation metric (tracking ROUGE-L)
best_model_epoch = -1 # Track the epoch number where the best model was found

# --- Checkpoint Loading and History Loading ---

# Define the metrics keys that are expected in the history file
# This list is comprehensive and includes all metrics we track
expected_keys = {
    'train_loss': [], 'val_loss': [],
    'rouge-1': [], 'rouge-2': [], 'rouge-l': [], 'bleu': [],
    'entity_recall': [], 'entity_precision': [], 'entity_f1': [],
    'sample_sizes': [], # Size of the training dataset subset used in each epoch
    'epoch_runtime_seconds': [] # Duration of each epoch in seconds
}

# Initialize the dictionary to store the training history for all metrics
# This dictionary will be populated from a saved file or start fresh
all_metrics = {key: [] for key in expected_keys}

# Paths for checkpoints
latest_model_path = os.path.join(DRIVE_PATH, "latest_model.pt")
best_model_path = os.path.join(DRIVE_PATH, "best_model.pt")

# Determine the starting epoch and load model/optimizer state from checkpoint
start_epoch = 0 # Default start epoch
loaded_metrics_from_checkpoint = {} # Metrics dictionary saved within the checkpoint

# If the best model checkpoint is not found, try loading the latest model checkpoint
if os.path.exists(latest_model_path):
    print(f"--- Best model not found, loading latest model from {latest_model_path} to resume training ---")
    # load_checkpoint updates model, optimizer, start_epoch, and returns metrics saved with it
    model, optimizer, start_epoch, loaded_metrics_from_checkpoint = load_checkpoint(latest_model_path, model, optimizer, device)
    # Update best_val_metric based on the latest model's metrics (if available)
    best_val_metric = loaded_metrics_from_checkpoint.get('rouge-l', 0.0)
    # Get the epoch the latest model was trained at, default to start_epoch - 1
    best_model_epoch = loaded_metrics_from_checkpoint.get('epoch', start_epoch - 1)
    print(f"Loaded latest model (Epoch {start_epoch})") # Note: start_epoch is the *next* epoch number
    print(f"Resuming training from epoch {start_epoch}")

# Prioritize loading the best model checkpoint if it exists
elif os.path.exists(best_model_path):
    print(f"--- Loading best model from {best_model_path} to resume training ---")
    # load_checkpoint updates model, optimizer, start_epoch, and returns metrics saved with it
    model, optimizer, start_epoch, loaded_metrics_from_checkpoint = load_checkpoint(best_model_path, model, optimizer, device)
    # Update best_val_metric and best_model_epoch based on the loaded best model's metrics
    best_val_metric = loaded_metrics_from_checkpoint.get('rouge-l', 0.0)
    # Get the epoch the best model was trained at, default to start_epoch - 1 if not found
    best_model_epoch = loaded_metrics_from_checkpoint.get('epoch', start_epoch - 1)
    print(f"Loaded best model (Epoch {best_model_epoch + 1}) with ROUGE-L: {best_val_metric:.4f}")
    print(f"Resuming training from epoch {start_epoch}")

# If no checkpoint is found, training starts from scratch (Epoch 0)
else:
    print("--- No checkpoint found, starting training from scratch (Epoch 0) ---")
    start_epoch = 0 # Explicitly set start_epoch to 0


# --- Load Metrics History from File ---
# Load the full training history (metrics for all previous epochs) from the JSON file
# This is done *after* determining start_epoch from checkpoint
METRICS_FILE = os.path.join(DRIVE_PATH, "training_metrics.json") # Ensure path is defined

if os.path.exists(METRICS_FILE):
    print(f"Attempting to load previous metrics history from {METRICS_FILE}")
    try:
        with open(METRICS_FILE, 'r', encoding='utf-8') as f:
            loaded_metrics_hist = json.load(f)

        # Validate the structure and load history up to the start_epoch
        if isinstance(loaded_metrics_hist, dict):
            # Create a temporary dictionary with the expected keys and empty lists
            temp_metrics_for_load = {key: [] for key in expected_keys}
            valid_load_successful = True # Flag to track if loading was entirely successful

            # Determine the length of history available in the file based on the first expected key
            first_key = list(expected_keys.keys())[0]
            if first_key in loaded_metrics_hist and isinstance(loaded_metrics_hist[first_key], list):
                 file_history_len = len(loaded_metrics_hist[first_key])
            else:
                 # If the first key is missing or not a list, the file structure is likely invalid
                 file_history_len = 0
                 valid_load_successful = False
                 print(f"Warning: Metrics file {METRICS_FILE} has unexpected format or is empty.")


            if valid_load_successful:
                # Calculate the actual number of epochs to load history for.
                # This is the minimum of the determined start_epoch and the history length in the file.
                effective_history_len_to_load = min(start_epoch, file_history_len)

                if effective_history_len_to_load > 0:
                    print(f"Loading history for {effective_history_len_to_load} previous epochs.")
                    # Load data for each expected key up to the effective history length
                    for key in expected_keys:
                        if key in loaded_metrics_hist and isinstance(loaded_metrics_hist[key], list) and len(loaded_metrics_hist[key]) >= effective_history_len_to_load:
                             # Slice the list to get only the data up to start_epoch (0-indexed)
                             temp_metrics_for_load[key] = loaded_metrics_hist[key][:effective_history_len_to_load]
                        else:
                             # If a key is missing in the file or its list is shorter than needed,
                             # print a warning and keep the corresponding list empty in temp_metrics_for_load.
                             print(f"Warning: Data for key '{key}' is missing or incomplete in metrics file for loading up to epoch {start_epoch}. History for this key might be reset.")
                             valid_load_successful = False # Mark load as not fully successful for all keys
                             temp_metrics_for_load[key] = [] # Ensure list is empty if load failed for this key

                    if valid_load_successful: # If history for all expected keys was loaded correctly
                         all_metrics = temp_metrics_for_load # Replace the initial empty all_metrics with loaded history
                         print("Metrics history loaded successfully and aligned with checkpoint.")
                    else:
                         print("Issues found during metrics history loading. Starting history fresh to avoid corruption.")
                         # If loading issues occurred, reset all_metrics to empty lists
                         all_metrics = {key: [] for key in expected_keys}

                # Handle cases where the file history length is greater than start_epoch
                # This indicates a potential inconsistency between checkpoint and metrics file
                elif file_history_len > start_epoch:
                     print(f"Warning: Metrics file contains data for {file_history_len} epochs, but resuming from epoch {start_epoch}. This suggests inconsistency. Starting history fresh.")
                     all_metrics = {key: [] for key in expected_keys} # Start fresh due to inconsistency

                else: # file_history_len <= start_epoch and effective_history_len_to_load == 0
                    # This happens if start_epoch is 0 or the file history is very short/empty
                    print("No previous history needs to be loaded (start_epoch is 0 or history file is empty/short).")

            else: # valid_load_successful == False due to initial file format check failure
                 print("Metrics file has an invalid initial structure. Starting history fresh.")
                 all_metrics = {key: [] for key in expected_keys} # Start fresh

    except json.JSONDecodeError:
        print(f"Error decoding JSON from {METRICS_FILE}. File might be corrupt. Starting history fresh.")
        all_metrics = {key: [] for key in expected_keys} # Start fresh
    except Exception as e:
        print(f"An unexpected error occurred loading metrics history: {e}. Starting history fresh.")
        all_metrics = {key: [] for key in expected_keys} # Start fresh

else:
    print("Metrics history file not found. Starting history fresh.")
    # If the file does not exist, all_metrics remains the initial dictionary with empty lists.


# --- Main Training and Validation Loop ---
# Define the interval for performing expensive evaluations (Clinical)
evaluation_interval = 3 # Perform full evaluation every 3 epochs

for epoch in range(start_epoch, num_epochs):
    print(f"\n===== Epoch {epoch+1}/{num_epochs} =====")
    epoch_start_time = time.time() # Record the start time of the current epoch

    # --- Progressive Unfreezing ---
    # Apply the unfreezing strategy based on the current epoch number
    # This function returns the (potentially new) optimizer instance
    optimizer = unfreeze_layers(model, epoch, optimizer, INITIAL_LR, progressive_unfreezing_schedule)

    # --- Curriculum Learning ---
    # Get the subset of the training data for the current epoch
    current_train_dataset = get_curriculum_dataset(epoch, train_df, tokenizer,
                                                   initial_size=curriculum_learning_schedule["initial_size"],
                                                   increment=curriculum_learning_schedule["increment"],
                                                   increment_every=curriculum_learning_schedule["increment_every_epochs"])
    # Create the DataLoader for the current training dataset subset
    # Use more workers/pin_memory if resources allow for faster data loading
    train_dataloader = DataLoader(current_train_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)

    # Store the size of the training sample used in this epoch
    # This is always recorded
    all_metrics['sample_sizes'].append(len(current_train_dataset))


    # --- Training Phase ---
    model.train() # Set the model to training mode
    total_train_loss = 0 # Variable to accumulate training loss over the epoch
    optimizer.zero_grad() # Zero the gradients at the beginning of the epoch or accumulation step

    # Use tqdm for a progress bar over the training batches
    progress_bar_train = tqdm(train_dataloader, desc=f"Epoch {epoch+1} Training", leave=False)
    for idx, batch in enumerate(progress_bar_train):
        # Move batch data to the appropriate device (GPU/CPU)
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        # --- Forward Pass with Mixed Precision (AMP) ---
        # autocast enables automatic mixed precision, which speeds up training
        with autocast(): # Note: Update to torch.amp.autocast('cuda', ...) for newer PyTorch
            # Perform the forward pass, calculate loss
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels # Labels are used internally to calculate the loss
            )
            loss = outputs.loss
            # Scale the loss by the gradient accumulation steps
            # This is necessary because gradients are summed over multiple batches
            loss = loss / gradient_accumulation_steps

        # --- Backward Pass with GradScaler ---
        # Scale the loss before backward pass in AMP
        scaler.scale(loss).backward()

        # --- Optimizer Step (with Gradient Accumulation) ---
        # Perform optimizer step only after accumulating gradients over several batches
        # Or perform a step for the last batch even if it's less than accumulation steps
        if (idx + 1) % gradient_accumulation_steps == 0 or (idx + 1) == len(train_dataloader):
            # Optional: Gradient clipping to prevent exploding gradients
            # scaler.unscale_(optimizer) # Unscale gradients before clipping
            # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            # Perform optimizer step using the scaled gradients
            scaler.step(optimizer)
            # Update the scale for the next iteration
            scaler.update()
            # Zero the gradients after the optimizer step
            optimizer.zero_grad()

        # Accumulate the loss (undo the scaling by accumulation steps for accurate logging)
        total_train_loss += loss.item() * gradient_accumulation_steps
        # Update the progress bar with the current loss
        progress_bar_train.set_postfix({'loss': loss.item() * gradient_accumulation_steps})

        # Clean up batch variables and clear GPU cache to save memory
        del input_ids, attention_mask, labels, outputs, loss
        if device == torch.device('cuda'):
            torch.cuda.empty_cache()
        gc.collect() # Collect garbage

    # Calculate the average training loss for the epoch
    avg_train_loss = total_train_loss / len(train_dataloader) # Divide by number of batches
    print(f"Epoch {epoch+1} Average Training Loss: {avg_train_loss:.4f}")
    # Store the average training loss in the metrics history
    # This is always recorded
    all_metrics['train_loss'].append(avg_train_loss)


    # --- Validation Phase ---
    model.eval() # Set the model to evaluation mode
    total_val_loss = 0 # Variable to accumulate validation loss
    # Lists to store generated summaries, references, and raw inputs for evaluation
    val_references = [] # TF-IDF targets for ROUGE/BLEU
    val_hypotheses = [] # Model generated summaries
    val_raw_inputs = [] # Original texts for Clinical (Entity) metrics

    print(f"\nRunning Validation for Epoch {epoch+1}...")
    # Use tqdm for a progress bar over the validation batches
    progress_bar_val = tqdm(val_dataloader, desc=f"Epoch {epoch+1} Validation", leave=False)

    # Disable gradient calculation during validation to save memory and speed up
    with torch.no_grad():
        for batch in progress_bar_val:
            # Move batch data to the appropriate device
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            # Get raw texts from the batch (these are not tensors)
            raw_input_texts = batch['raw_input']
            raw_target_texts = batch['raw_target']

            # --- Calculate Validation Loss ---
            # Use mixed precision for validation loss calculation as well
            with autocast(): # Note: Update to torch.amp.autocast('cuda', ...) for newer PyTorch
                 # Perform forward pass to calculate loss
                 outputs = model(
                     input_ids=input_ids,
                     attention_mask=attention_mask,
                     labels=labels
                 )
                 loss = outputs.loss
            # Accumulate the validation loss
            total_val_loss += loss.item()


            # --- Generate Summaries ---
            # Generate summaries from the input texts using beam search (or other decoding strategies)
            # This is done for all validation batches to calculate standard metrics
            # Use autocast here if generation benefits from mixed precision (check performance)
            # with autocast():
            generated_ids = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                # Use generation parameters defined in Configuration cell
                max_length=generation_parameters["max_length"],
                num_beams=generation_parameters["num_beams"],
                length_penalty=generation_parameters["length_penalty"],
                early_stopping=generation_parameters["early_stopping"],
                no_repeat_ngram_size=generation_parameters["no_repeat_ngram_size"]
            )

            # Decode the generated token IDs back into human-readable text
            hypotheses_batch = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

            # Store the generated summaries, raw targets, and raw inputs for metric calculation later
            val_hypotheses.extend(hypotheses_batch)
            val_references.extend(raw_target_texts) # TF-IDF targets are references for standard metrics
            val_raw_inputs.extend(raw_input_texts)  # Original inputs are references for clinical metrics


            # Clean up batch variables and clear GPU cache
            del input_ids, attention_mask, labels, raw_input_texts, raw_target_texts, outputs, loss, generated_ids
            if device == torch.device('cuda'):
                torch.cuda.empty_cache()
            gc.collect() # Collect garbage


    # Calculate the average validation loss for the epoch
    avg_val_loss = total_val_loss / len(val_dataloader) # Divide by number of batches
    print(f"Epoch {epoch+1} Average Validation Loss: {avg_val_loss:.4f}")
    # Store the average validation loss in the metrics history
    # This is always recorded
    all_metrics['val_loss'].append(avg_val_loss)


    # --- Calculate Standard Validation Metrics (Always) ---
    print("\nCalculating Standard Validation Metrics (ROUGE, BLEU)...")
    # Calculate standard metrics comparing generated summaries to TF-IDF targets
    metrics_std = calculate_metrics(val_references, val_hypotheses)
    print(f"Std Metrics - R1: {metrics_std['rouge-1']:.4f}, R2: {metrics_std['rouge-2']:.4f}, RL: {metrics_std['rouge-l']:.4f}, B: {metrics_std['bleu']:.4f}")

    # Store Standard Metrics (Always)
    all_metrics['rouge-1'].append(metrics_std['rouge-1'])
    all_metrics['rouge-2'].append(metrics_std['rouge-2'])
    all_metrics['rouge-l'].append(metrics_std['rouge-l'])
    all_metrics['bleu'].append(metrics_std['bleu'])


    # --- Conditional Calculation of Clinical Metrics ---
    # Check if the current epoch is an interval where we perform the full evaluation
    if (epoch + 1) % evaluation_interval == 0:
        print(f"\nEpoch {epoch+1}: Performing full evaluation (Clinical)...")

        # Calculate Clinical Metrics (Entity Overlap)
        print("Calculating Clinical Validation Metrics (Entity Overlap)...")
        metrics_clin = calculate_clinical_metrics(val_raw_inputs, val_hypotheses)
        print(f"Clinical Metrics - ER: {metrics_clin['entity_recall']:.4f}, EP: {metrics_clin['entity_precision']:.4f}, EF1: {metrics_clin['entity_f1']:.4f}")

        # Store Clinical Metrics for this epoch
        all_metrics['entity_recall'].append(metrics_clin['entity_recall'])
        all_metrics['entity_precision'].append(metrics_clin['entity_precision'])
        all_metrics['entity_f1'].append(metrics_clin['entity_f1'])

    else:
        # If this is not an evaluation interval epoch, append placeholder values
        print(f"\nEpoch {epoch+1}: Skipping full evaluation (Clinical). Appending placeholder values.")
        all_metrics['entity_recall'].append(0.0) # Or use None, but 0.0 might be simpler for plotting
        all_metrics['entity_precision'].append(0.0)
        all_metrics['entity_f1'].append(0.0)


    # --- Record Epoch Runtime ---
    epoch_end_time = time.time() # Record the end time of the current epoch
    epoch_duration = epoch_end_time - epoch_start_time # Calculate epoch duration
    all_metrics['epoch_runtime_seconds'].append(epoch_duration) # Store duration in history
    print(f"Epoch {epoch+1} took {epoch_duration:.2f} seconds ({epoch_duration/60:.2f} minutes).")


    # --- Checkpoint Saving ---
    # Combine current epoch metrics for saving with the checkpoint (only include computed metrics)
    # If conditional evaluation happened, include all metrics; otherwise, only standard + loss
    if (epoch + 1) % evaluation_interval == 0:
         current_epoch_metrics = {**metrics_std, **metrics_clin,
                                  'train_loss': avg_train_loss, 'val_loss': avg_val_loss, 'epoch': epoch}
    else:
         current_epoch_metrics = {**metrics_std,
                                  'train_loss': avg_train_loss, 'val_loss': avg_val_loss, 'epoch': epoch,
                                  'entity_recall': 0.0, 'entity_precision': 0.0, 'entity_f1': 0.0} # Add clinical placeholders too


    # Save the latest model checkpoint after each epoch
    save_checkpoint(model, optimizer, epoch, current_epoch_metrics, latest_model_path)

    # Determine the current metric value used for tracking the "best" model
    # The code currently uses ROUGE-L from standard metrics (calculated every epoch)
    current_metric_value = metrics_std['rouge-l']

    # Check if the current model is the best seen so far based on best_val_metric
    if current_metric_value > best_val_metric:
        best_val_metric = current_metric_value # Update the best metric value
        best_model_epoch = epoch # Store the epoch number of the new best model
        print(f"\nNew best model found! Metric: {best_val_metric:.4f} at epoch {epoch+1}")
        # Save a separate checkpoint specifically for the best model
        # Save the full set of metrics calculated in this epoch with the best model checkpoint
        save_checkpoint(model, optimizer, epoch, current_epoch_metrics, best_model_path) # Save full metrics with best model
    else:
         # Ensure best_model_epoch is correctly reflected if a checkpoint was loaded initially
         if best_model_epoch == -1: # Case where no best model was found initially
              print(f"Current model metric ({current_metric_value:.4f}). No best model found yet.")
         else:
              print(f"Current model metric ({current_metric_value:.4f}) is not better than best ({best_val_metric:.4f} at epoch {best_model_epoch+1}).")


    # Save a backup checkpoint periodically (e.g., every 5 epochs)
    backup_interval = 5 # Define backup interval
    if (epoch + 1) % backup_interval == 0:
        backup_path = os.path.join(CHECKPOINT_DIR, f"backup_epoch_{epoch+1}.pt")
        save_checkpoint(model, optimizer, epoch, current_epoch_metrics, backup_path)


    # Save the entire metrics history to the JSON file after each epoch
    # This happens every epoch to ensure consistent list lengths
    save_metrics(all_metrics, METRICS_FILE)


    # Clean up memory before the next epoch
    del current_train_dataset # Delete the dataset object to free memory
    gc.collect()
    if device == torch.device('cuda'): torch.cuda.empty_cache()

# Print message after training loop completes
print("\n===== Training Completed =====")

# Calculate and print the total training duration (excluding time spent on crashes/restarts)
# This sum includes the runtime of each successfully completed epoch
total_training_seconds = sum(all_metrics.get('epoch_runtime_seconds', []))
total_training_minutes = total_training_seconds / 60.0
print(f"Total training duration (successful epochs): {total_training_minutes:.2f} minutes.")



Starting training process...
--- Best model not found, loading latest model from /content/drive/MyDrive/BioBart_TFIDF_Structured/latest_model.pt to resume training ---
Loading checkpoint from /content/drive/MyDrive/BioBart_TFIDF_Structured/latest_model.pt...
Model state loaded.
Optimizer state loaded.
Checkpoint loaded successfully. Resuming training from epoch 15
Loaded latest model (Epoch 15)
Resuming training from epoch 15
Attempting to load previous metrics history from /content/drive/MyDrive/BioBart_TFIDF_Structured/training_metrics.json
Loading history for 14 previous epochs.
Metrics history loaded successfully and aligned with checkpoint.

===== Training Completed =====
Total training duration (successful epochs): 2541.46 minutes.


In [None]:
# (Cell 13) Final Test Set Evaluation

print("\n===== Evaluating on Test Set using Best Model =====")
test_eval_start_time = time.time() # Record the start time of test evaluation

# Load the best performing model for the final evaluation
# Prioritize the best model checkpoint
best_model_path = os.path.join(DRIVE_PATH, "best_model.pt")
latest_model_path = os.path.join(DRIVE_PATH, "latest_model.pt")

if os.path.exists(best_model_path):
    print(f"Loading best model from {best_model_path} for final evaluation...")
    # Load the model state dictionary. Pass None for the optimizer as it's not needed for evaluation.
    model, _, _, _ = load_checkpoint(best_model_path, model, None, device)
elif os.path.exists(latest_model_path):
    print("Warning: Best model checkpoint not found. Evaluating with the latest model.")
    # Load the latest model checkpoint if the best is not available
    model, _, _, _ = load_checkpoint(latest_model_path, model, None, device)
else:
    # If neither checkpoint is found, print an error and exit or handle appropriately
    print("Error: No model checkpoint found (best or latest) for final evaluation.")
    # exit() # Exit the script

# Set the model to evaluation mode
model.eval()

# Lists to store generated summaries, references, and raw inputs for test evaluation
test_references = [] # TF-IDF Targets for ROUGE/BLEU
test_hypotheses = [] # Model generated summaries
test_raw_inputs = [] # Original Texts for Clinical (Entity) metrics

print("Running inference on the test set...")
# Use tqdm for a progress bar over the test batches
progress_bar_test = tqdm(test_dataloader, desc="Testing")

# Disable gradient calculation during inference
with torch.no_grad():
    for batch in progress_bar_test:
        # Move batch data to the appropriate device
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        # Get raw texts from the batch (these are not tensors)
        raw_input_texts = batch['raw_input']
        raw_target_texts = batch['raw_target']

        # --- Generate Summaries for Test Samples ---
        # Use the same generation parameters as used during validation
        generated_ids = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=generation_parameters["max_length"],
            num_beams=generation_parameters["num_beams"],
            length_penalty=generation_parameters["length_penalty"],
            early_stopping=generation_parameters["early_stopping"],
            no_repeat_ngram_size=generation_parameters["no_repeat_ngram_size"]
        )

        # Decode the generated token IDs back into human-readable text
        hypotheses_batch = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

        # Store the generated summaries, raw targets, and raw inputs for metric calculation later
        test_hypotheses.extend(hypotheses_batch)
        test_references.extend(raw_target_texts) # TF-IDF targets are references for standard metrics
        test_raw_inputs.extend(raw_input_texts)  # Original inputs are references for clinical metrics


        # Clean up batch variables and clear GPU cache
        del input_ids, attention_mask, raw_input_texts, raw_target_texts, generated_ids
        if device == torch.device('cuda'):
            torch.cuda.empty_cache()
        gc.collect() # Collect garbage


# --- Record Test Evaluation Runtime ---
test_eval_end_time = time.time() # Record the end time of test evaluation
test_eval_duration_seconds = test_eval_end_time - test_eval_start_time
test_eval_duration_minutes = test_eval_duration_seconds / 60.0
print(f"Test set evaluation duration: {test_eval_duration_minutes:.2f} minutes.")


# --- Calculate Final Test Metrics ---
print("\nCalculating Final Test Metrics...")

# Calculate standard metrics (ROUGE, BLEU) comparing generated summaries to TF-IDF targets
final_metrics_std = calculate_metrics(test_references, test_hypotheses)
print(f"Final Standard Metrics - R1: {final_metrics_std['rouge-1']:.4f}, R2: {final_metrics_std['rouge-2']:.4f}, RL: {final_metrics_std['rouge-l']:.4f}, B: {final_metrics_std['bleu']:.4f}")

# Calculate clinical metrics (Entity Overlap) comparing generated summaries to original raw inputs
final_metrics_clin = calculate_clinical_metrics(test_raw_inputs, test_hypotheses)
print(f"Final Clinical Metrics - ER: {final_metrics_clin['entity_recall']:.4f}, EP: {final_metrics_clin['entity_precision']:.4f}, EF1: {final_metrics_clin['entity_f1']:.4f}")


===== Evaluating on Test Set using Best Model =====
Loading best model from /content/drive/MyDrive/BioBart_TFIDF_Structured/best_model.pt for final evaluation...
Loading checkpoint from /content/drive/MyDrive/BioBart_TFIDF_Structured/best_model.pt...
Model state loaded.
Checkpoint loaded successfully. Resuming training from epoch 15
Running inference on the test set...


Testing:   0%|          | 0/1965 [00:00<?, ?it/s]



In [None]:
# (Cell 14) Save Final Results and Model

# --- Save Final Results Summary ---
# Combine all final test metrics and relevant information into a summary dictionary
final_results_summary = {
    'test_metrics_standard': final_metrics_std,
    'test_metrics_clinical': final_metrics_clin,
    'training_history_file': METRICS_FILE, # Path to the full training history JSON
    'training_config_file': CONFIG_FILE, # Path to the training configuration JSON
    'qualitative_examples_file': QUALITATIVE_EXAMPLES_FILE, # Path to the qualitative examples JSON
    'best_model_epoch': best_model_epoch + 1 if best_model_epoch != -1 else 'N/A', # +1 for human-readable epoch number
    'best_model_validation_metric': best_val_metric, # The best validation metric achieved
    'final_training_samples_used': all_metrics['sample_sizes'][-1] if all_metrics['sample_sizes'] else 'N/A', # Size of training data in the last epoch
    'total_training_duration_minutes': total_training_minutes, # Total time spent in training loop
    'test_evaluation_duration_minutes': test_eval_duration_minutes # Total time spent evaluating on test set
}

print(f"\nSaving final results summary to {FINAL_RESULTS_FILE}...")
try:
    # Save the summary dictionary to a JSON file
    with open(FINAL_RESULTS_FILE, 'w', encoding='utf-8') as f:
        json.dump(final_results_summary, f, ensure_ascii=False, indent=4)
    print("Final results summary saved successfully.")
except Exception as e:
    print(f"Error saving final results summary: {e}")

# --- Save the Final Trained Model ---
# Save the complete fine-tuned model using save_pretrained
# This saves the model's architecture and weights in a format that can be easily reloaded
print(f"\nSaving the complete fine-tuned model (from best epoch) to {FINAL_MODEL_PATH}...")
try:
    # Ensure the directory for the final model exists
    os.makedirs(FINAL_MODEL_PATH, exist_ok=True)
    model.save_pretrained(FINAL_MODEL_PATH)
    # It's good practice to also save the tokenizer with the model,
    # though it was already saved earlier, doing it again here ensures they are together.
    tokenizer.save_pretrained(FINAL_MODEL_PATH)
    print("Final model and tokenizer saved successfully!")
except Exception as e:
    print(f"Error saving final model: {e}")


# --- Save Qualitative Examples ---
# Select a small number of random examples from the test set results
print(f"\nSelecting and saving qualitative examples to {QUALITATIVE_EXAMPLES_FILE}...")
num_examples_to_save = 10 # Define how many examples to save
saved_examples = [] # List to store the selected example dictionaries

# Ensure there are test samples available
if len(test_raw_inputs) > 0:
    # Select random indices from the range of test samples
    # Use min() to handle cases where test set size is less than num_examples_to_save
    selected_indices = random.sample(range(len(test_raw_inputs)), min(num_examples_to_save, len(test_raw_inputs)))

    # Iterate through the selected indices and create a dictionary for each example
    for idx in selected_indices:
        saved_examples.append({
            "original_input": test_raw_inputs[idx], # The original full report text
            "tfidf_target": test_references[idx], # The TF-IDF extractive summary (training target)
            "generated_summary": test_hypotheses[idx] # The summary generated by the model
        })

    try:
        # Save the list of example dictionaries to a JSON file
        with open(QUALITATIVE_EXAMPLES_FILE, 'w', encoding='utf-8') as f:
            json.dump(saved_examples, f, ensure_ascii=False, indent=4)
        print(f"{len(saved_examples)} qualitative examples saved successfully.")
    except Exception as e:
        print(f"Error saving qualitative examples: {e}")
else:
    print("No test samples available to save qualitative examples.")

In [None]:
# (Cell 15) Example Inference

# --- Example Inference Function ---
# This function demonstrates how to use the fine-tuned model to generate a summary for a new raw input text.
# It applies the same preprocessing steps (section tokens, length token) as used during training.

def generate_summary(input_text_raw, model, tokenizer, device, max_gen_length=generation_parameters["max_length"]):

    # Set the model to evaluation mode
    model.eval()

    # --- Apply Preprocessing Steps as done during training (Crucial) ---
    # The inference input MUST be formatted consistently with the training input.

    # 1. Add Section Tokens
    section_aware_text = add_section_tokens(input_text_raw, section_headers_map=SECTION_HEADERS, all_headers_regex=ALL_HEADERS_REGEX)


    # Using Option A for this example: prepend a default length token.
    # Choose one of your defined length control tokens.
    control_token_for_inference = "<SUM_MEDIUM>" # Example default for inference


    # 2. Create the final formatted input text
    # The format must match training: <LENGTH_TOKEN> <SECTION_AWARE_TEXT>
    final_input_text = control_token_for_inference + " " + section_aware_text

    # Add a basic check for empty input text after processing
    if not final_input_text.strip():
         print("Warning: Input text is empty after preprocessing. Cannot generate summary.")
         return "Error: Empty input text after preprocessing."


    # --- Tokenize the preprocessed input text ---
    encoding = tokenizer(
        final_input_text,
        max_length=MAX_INPUT_LENGTH, # Use the same max input length as training
        padding='max_length', # Pad to max_length
        truncation=True, # Truncate if longer than max_length
        return_tensors="pt" # Return PyTorch tensors
    )

    # Move tokenized input to the appropriate device
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)


    summary_text = "Error generating summary." # Default error message in case of failure

    # --- Generate Summary using the model ---
    with torch.no_grad(): # Disable gradient calculation
        try:
            # Generate summary using the same parameters as validation/test evaluation
            generated_ids = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=max_gen_length, # Use the specified max generation length
                num_beams=generation_parameters["num_beams"],
                length_penalty=generation_parameters["length_penalty"],
                early_stopping=generation_parameters["early_stopping"],
                no_repeat_ngram_size=generation_parameters["no_repeat_ngram_size"]
            )
            # Decode the generated IDs back to text, skipping special tokens
            summary_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

        except Exception as e:
            print(f"Error during summary generation: {e}")


    # Clean up tensors and clear GPU cache
    del input_ids, attention_mask, encoding
    if device == torch.device('cuda'):
        torch.cuda.empty_cache()
    gc.collect() # Collect garbage

    return summary_text


# --- Example Usage of Inference Function ---
print("\n--- Example Inference ---")

# Define a sample raw input text (like a new radiology report)
sample_text_raw = """
INDICATION: Evaluate for pneumonia.
TECHNIQUE: Portable anteroposterior chest X-ray.
FINDINGS: The lungs are clear bilaterally without evidence of consolidation, effusion, or pneumothorax. The heart size is normal. Mediastinal and hilar contours are unremarkable. Visualized osseous structures are intact.
IMPRESSION: No acute cardiopulmonary abnormality.
"""


# Generate a summary for the sample raw text
generated_text = generate_summary(sample_text_raw, model, tokenizer, device)

# Print the original text and the generated summary
print(f"Input Text (Raw):\n{sample_text_raw}\n")
print(f"Generated Summary:\n{generated_text}")

print("\n===== Script Finished =====")