###   ****Constrained Summarization using BART with Prefix Tuning****

In [2]:
!pip install torchcrf

Collecting torchcrf
  Downloading TorchCRF-1.1.0-py3-none-any.whl.metadata (2.3 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=1.0.0->torchcrf)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=1.0.0->torchcrf)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch>=1.0.0->torchcrf)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch>=1.0.0->torchcrf)
  Downloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cusolver-cu12==11.6.1.9 (from torch>=1.0.0->torchcrf)
  Downloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cusparse-cu12==12.3.1.170 (from torch>=1.0.0->torchcrf)
  Downloading nvid

In [2]:
import torch
import json
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

In [29]:
# Configuration
MODEL_NAME = "facebook/bart-base"
BATCH_SIZE = 4
MAX_LENGTH = 1024
NUM_EPOCHS = 4
LEARNING_RATE = 2e-5
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [22]:
# Perspective configurations
PERSPECTIVE_CONFIG = {
    "INFORMATION": {
        "prefix": "For information purposes",
        "prefix_ids": None,  # Will be initialized later
        "class_weight": 0.4,
        "prefix_weight": 0.3
    },
    "SUGGESTION": {
        "prefix": "It is suggested",
        "prefix_ids": None,
        "class_weight": 0.4,
        "prefix_weight": 0.3
    },
    "EXPERIENCE": {
        "prefix": "In user's experience",
        "prefix_ids": None,
        "class_weight": 0.4,
        "prefix_weight": 0.3
    },
    "CAUSE": {
        "prefix": "Some of the causes",
        "prefix_ids": None,
        "class_weight": 0.4,
        "prefix_weight": 0.3
    },
    "QUESTION": {
        "prefix": "It is inquired",
        "prefix_ids": None,
        "class_weight": 0.4,
        "prefix_weight": 0.3
    }
}

In [31]:
class ConstrainedHealthcareDataset(Dataset):
    def __init__(self, data_path, tokenizer):
        with open(data_path) as f:
            self.data = json.load(f)
        self.tokenizer = tokenizer
        self._init_prefix_ids()
        
        # Filter out items with empty labelled_summaries
        self.data = [item for item in self.data 
                    if item.get("labelled_summaries") and len(item["labelled_summaries"]) > 0]
        
    def _init_prefix_ids(self):
        for p in PERSPECTIVE_CONFIG.values():
            p["prefix_ids"] = self.tokenizer.encode(p["prefix"], add_special_tokens=False)[:4]
            
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        item = self.data[idx]
        answers = " ".join([ans.replace("\n", " ") for ans in item["answers"]])
        
        # Safely get summary type and perspective
        summary_types = list(item["labelled_summaries"].keys())
        if not summary_types:
            raise ValueError(f"Item {idx} has no labelled summaries")
            
        summary_type = summary_types[0]
        perspective = summary_type.split("_")[0]
        target = item["labelled_summaries"][summary_type]
        
        inputs = self.tokenizer(
            f"Summarize this {perspective.lower()} content: {answers}",
            max_length=MAX_LENGTH,
            truncation=True,
            padding="max_length",
            return_tensors="pt"
        )
        
        targets = self.tokenizer(
            target,
            max_length=MAX_LENGTH,
            truncation=True,
            padding="max_length",
            return_tensors="pt"
        )
        
        return {
            "input_ids": inputs["input_ids"].squeeze(),
            "attention_mask": inputs["attention_mask"].squeeze(),
            "labels": targets["input_ids"].squeeze(),
            "perspective": perspective,
            "answers": answers
        }

In [24]:
class HealthcareSummarizerWithConstraints(torch.nn.Module):
    def __init__(self, tokenizer):
        super().__init__()
        self.model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
        self.tokenizer = tokenizer  # Store tokenizer as an attribute
        self.classifier = torch.nn.Linear(self.model.config.d_model, len(PERSPECTIVE_CONFIG))
        
    def forward(self, input_ids, attention_mask, labels, perspective):
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            output_hidden_states=True
        )
        
        # Base loss
        base_loss = outputs.loss
        
        # Get hidden states for classification
        last_hidden = outputs.decoder_hidden_states[-1][:, 0, :]  # First token embedding
        class_logits = self.classifier(last_hidden)
        
        # Perspective classification loss
        perspective_idx = torch.tensor([list(PERSPECTIVE_CONFIG.keys()).index(p) 
                                      for p in perspective]).to(DEVICE)
        class_loss = torch.nn.functional.cross_entropy(class_logits, perspective_idx)
        
        # Prefix constraint loss
        prefix_loss = self._calculate_prefix_loss(outputs.logits, perspective)
        
        # Semantic similarity loss
        semantic_loss = self._calculate_semantic_similarity(
            outputs.logits, 
            input_ids,
            attention_mask
        )
        
        # Combine losses
        total_loss = (
            base_loss +
            PERSPECTIVE_CONFIG[perspective[0]]["class_weight"] * class_loss +
            PERSPECTIVE_CONFIG[perspective[0]]["prefix_weight"] * prefix_loss +
            0.2 * semantic_loss
        )
        
        return total_loss
    
    def _calculate_prefix_loss(self, logits, perspectives):
        batch_size = logits.size(0)
        prefix_loss = 0
        
        for i in range(batch_size):
            perspective = perspectives[i]
            prefix_ids = PERSPECTIVE_CONFIG[perspective]["prefix_ids"]
            prefix_logits = logits[i, :len(prefix_ids)]
            prefix_targets = torch.tensor(prefix_ids).to(DEVICE)
            
            prefix_loss += torch.nn.functional.cross_entropy(
                prefix_logits, 
                prefix_targets,
                ignore_index=self.model.config.pad_token_id
            )
            
        return prefix_loss / batch_size
    
    def _calculate_semantic_similarity(self, logits, input_ids, attention_mask):
        # Generate candidate summaries
        generated_ids = torch.argmax(logits, dim=-1)
        generated_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) 
                         for ids in generated_ids]
        
        # Get embeddings for original content and generated summaries
        with torch.no_grad():
            input_embeddings = self.model.get_input_embeddings()(input_ids)
            input_means = torch.mean(input_embeddings, dim=1)
            
            generated_embeddings = self.model.get_input_embeddings()(generated_ids)
            generated_means = torch.mean(generated_embeddings, dim=1)
            
        # Cosine similarity loss
        similarities = torch.cosine_similarity(input_means, generated_means, dim=-1)
        return 1 - torch.mean(similarities)

In [25]:
# Initialize model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = HealthcareSummarizerWithConstraints(tokenizer).to(DEVICE)

In [32]:
# Create datasets and dataloaders
train_dataset = ConstrainedHealthcareDataset("/kaggle/input/nlp-healthcare/train.json", tokenizer)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [27]:
# Optimizer
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)

In [33]:
# Training loop
for epoch in range(NUM_EPOCHS):
    model.train()
    total_loss = 0
    
    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
        inputs = batch["input_ids"].to(DEVICE)
        masks = batch["attention_mask"].to(DEVICE)
        labels = batch["labels"].to(DEVICE)
        perspectives = batch["perspective"]
        
        optimizer.zero_grad()
        loss = model(inputs, masks, labels, perspectives)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        
    print(f"Average Loss: {total_loss/len(train_loader):.4f}")

Epoch 1: 100%|██████████| 559/559 [15:40<00:00,  1.68s/it]


Average Loss: 0.2737


Epoch 2: 100%|██████████| 559/559 [15:39<00:00,  1.68s/it]


Average Loss: 0.2376


Epoch 3: 100%|██████████| 559/559 [15:40<00:00,  1.68s/it]


Average Loss: 0.2291


Epoch 4: 100%|██████████| 559/559 [15:39<00:00,  1.68s/it]

Average Loss: 0.2079





In [34]:
# Generation with constraints
def constrained_generate(text, perspective, max_length=150):
    model.eval()
    prefix = PERSPECTIVE_CONFIG[perspective]["prefix"]
    
    inputs = tokenizer(
        f"Summarize this {perspective.lower()} content: {text}",
        max_length=MAX_LENGTH,
        truncation=True,
        return_tensors="pt"
    ).to(DEVICE)
    
    # Force prefix in generation
    prefix_ids = tokenizer.encode(prefix, add_special_tokens=False)[:4]
    forced_ids = torch.tensor([prefix_ids], device=DEVICE)
    
    outputs = model.model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=max_length,
        num_beams=4,
        early_stopping=True,
        forced_bos_token_id=forced_ids
    )
    
    summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    # Post-hoc validation
    if not summary.startswith(prefix):
        summary = f"{prefix} {summary.split(' ', 1)[-1]}"
        
    return summary

In [35]:
# Example usage
sample = train_dataset[0]
generated = constrained_generate(
    text=sample["answers"],
    perspective="INFORMATION"
)
print("Generated Summary:")
print(generated)

Generated Summary:
For information purposes, Parkinson's disease is one of the most common neurologic disorders of the elderly. The term "parkinsonism" refers to any condition that causes any combination of the types of movement abnormalities seen in Parkinson's condition by damaging or destroying dopamine neurons in a certain area of the brain. Parkinsonism describes the common symptoms of Parkinson's, such as tremor, rigidity, akinesia or bradykinesia and postural instability. Those patients who respond to drug treatment for Parkinson's are diagnosed with it, and those who do not have parkinsonism are treated.


In [3]:
pip install rouge_score

Collecting rouge_score
  Downloading rouge_score-0.1.2.tar.gz (17 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: rouge_score
  Building wheel for rouge_score (setup.py) ... [?25l[?25hdone
  Created wheel for rouge_score: filename=rouge_score-0.1.2-py3-none-any.whl size=24935 sha256=a0ef5c85e2b6c923232e2b4db0a5ca1df899c6d95c28514118c44953ca682b09
  Stored in directory: /root/.cache/pip/wheels/1e/19/43/8a442dc83660ca25e163e1bd1f89919284ab0d0c1475475148
Successfully built rouge_score
Installing collected packages: rouge_score
Successfully installed rouge_score-0.1.2
Note: you may need to restart the kernel to use updated packages.


In [5]:
!pip install bert_score

Collecting bert_score
  Downloading bert_score-0.3.13-py3-none-any.whl.metadata (15 kB)
Downloading bert_score-0.3.13-py3-none-any.whl (61 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.1/61.1 kB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: bert_score
Successfully installed bert_score-0.3.13


In [38]:
from rouge_score import rouge_scorer
import json
from tqdm import tqdm

def evaluate_model(model, tokenizer, data_path, max_length=150):
    # Load evaluation data
    with open(data_path) as f:
        eval_data = json.load(f)
    
    # Filter out items with empty labelled_summaries
    eval_data = [item for item in eval_data 
                if item.get("labelled_summaries") and len(item["labelled_summaries"]) > 0]
    
    # Initialize ROUGE scorer
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
    
    # Store results
    rouge1_scores = []
    rouge2_scores = []
    rougeL_scores = []
    
    model.eval()
    
    for item in tqdm(eval_data, desc="Evaluating"):
        # Get the first summary type and perspective
        summary_type = list(item["labelled_summaries"].keys())[0]
        perspective = summary_type.split("_")[0]
        reference = item["labelled_summaries"][summary_type]
        answers = " ".join([ans.replace("\n", " ") for ans in item["answers"]])
        
        # Generate summary
        generated_summary = constrained_generate(
            text=answers,
            perspective=perspective,
            max_length=max_length
        )
        
        # Calculate ROUGE scores
        scores = scorer.score(reference, generated_summary)
        rouge1_scores.append(scores['rouge1'].fmeasure)
        rouge2_scores.append(scores['rouge2'].fmeasure)
        rougeL_scores.append(scores['rougeL'].fmeasure)
    
    # Calculate average scores
    avg_rouge1 = sum(rouge1_scores) / len(rouge1_scores)
    avg_rouge2 = sum(rouge2_scores) / len(rouge2_scores)
    avg_rougeL = sum(rougeL_scores) / len(rougeL_scores)
    
    return {
        'rouge1': avg_rouge1,
        'rouge2': avg_rouge2,
        'rougeL': avg_rougeL,
        'num_samples': len(eval_data)
    }

# Evaluation function for both validation and test sets
def run_evaluation(model, tokenizer):
    print("\nEvaluating on validation set...")
    valid_results = evaluate_model(model, tokenizer, "/kaggle/input/nlp-healthcare/valid.json")
    print(f"Validation ROUGE-1: {valid_results['rouge1']:.4f}")
    print(f"Validation ROUGE-2: {valid_results['rouge2']:.4f}")
    print(f"Validation ROUGE-L: {valid_results['rougeL']:.4f}")
    print(f"Validation samples: {valid_results['num_samples']}")
    
    print("\nEvaluating on test set...")
    test_results = evaluate_model(model, tokenizer, "/kaggle/input/nlp-healthcare/test.json")
    print(f"Test ROUGE-1: {test_results['rouge1']:.4f}")
    print(f"Test ROUGE-2: {test_results['rouge2']:.4f}")
    print(f"Test ROUGE-L: {test_results['rougeL']:.4f}")
    print(f"Test samples: {test_results['num_samples']}")
    
    return {
        'validation': valid_results,
        'test': test_results
    }

# Run evaluation after training
final_results = run_evaluation(model, tokenizer)

# You can also save the results to a file
with open("evaluation_results.json", "w") as f:
    json.dump(final_results, f, indent=2)


Evaluating on validation set...


Evaluating: 100%|██████████| 959/959 [07:49<00:00,  2.04it/s]


Validation ROUGE-1: 0.3651
Validation ROUGE-2: 0.1826
Validation ROUGE-L: 0.2914
Validation samples: 959

Evaluating on test set...


Evaluating: 100%|██████████| 640/640 [05:13<00:00,  2.04it/s]

Test ROUGE-1: 0.3689
Test ROUGE-2: 0.1870
Test ROUGE-L: 0.2949
Test samples: 640





In [3]:
with open("/kaggle/input/nlp-healthcare/train.json") as f:
        train = json.load(f)

In [5]:
len(train)

2236

In [6]:
train[0]

{'uri': '4367393',
 'question': 'what is parkinesonism?',
 'context': '',
 'answers': ['u spelt it wrong !!\nParkinson\'s disease is one of the most common neurologic disorders of the elderly. The term "parkinsonism" refers to any condition that causes any combination of the types of movement abnormalities seen in Parkinson\'s disease by damaging or destroying dopamine neurons in a certain area of the brain.',
  "Parkinsonism describes the common symptoms of Parkinson's disease - tremor, rigidity, akinesia or bradykinesia and postural instability. Those patients who respond to drug treatment for Parkinson's disease are diagnosed with it, and those who do not have parkinsonism."],
 'labelled_answer_spans': {'INFORMATION': [{'txt': 'Parkinson\'s disease is one of the most common neurologic disorders of the elderly. The term "parkinsonism" refers to any condition that causes any combination of the types of movement abnormalities seen in Parkinson\'s disease by damaging or destroying dopam

In [7]:
with open("/kaggle/input/nlp-healthcare/test.json") as f:
        test = json.load(f)

In [8]:
len(test)

640

In [9]:
test[0]

{'uri': '1309809',
 'question': 'what is orgasm?',
 'context': '',
 'answers': ['An orgasm, also known as a sexual climax, is a pleasurable physical, psychological or emotional response to prolonged sexual stimulation. It is often accompanied by a notable physiological reaction, such as ejaculation, blushing or spasm and may be followed by aftershocks.\n\nDictionaries still give the subsidiary meaning, "a similar point of intensity of emotional excitement," but as of 2005 this usage has become obscure. It can be startling to modern readers when encountered in older literature.\n\n\nGeneral\nBoth males and females can experience orgasm, but the exact response varies across gender. Generally speaking, orgasm is the third stage of four in the human sexual response cycle, which is the currently accepted model of the physiological process of sexual stimulation',
  'You asked two questions that you might think are the same; but they are not. Orgasm is an experience resulting from the combine

In [11]:
with open("/kaggle/input/nlp-healthcare/valid.json") as f:
        valid = json.load(f)

In [12]:
len(valid)

959

In [13]:
valid[0]

{'uri': '3392171',
 'question': 'do braces hurt????',
 'context': 'pain?\nhard to talk?',
 'answers': ["yes yes yes. But not horribly painful. you get used to them and they become easier to talk with. the pain is only when they get tightened and until you get used to them. They give you wax to put around the spots that hurt you when your tongue rubs against the brackets and you won't feel the pain.",
  "They hurt for a bit when you first get them.They feel tight.Then, they settle down. However, they hurt each time you get them adjusted for a few days afterward. But,dont worry,you'll get used to the pain.",
  'yes yes',
  'they hurt when u first get them but then u get used 2 them really  easily. sometimes u like drool or spit or something but unless u have a pallete expander u can talk normal. if u have 1 ur "K\'s" get all screwed up.',
  'Yes THEy HuRt WhEn YoU fIrSt GeT them but just for a few days...If you eat something hard and it gets stuck in them it mite hurt...and also after yo

###  ****Perspective-Controlled Healthcare Text Summarization using Multi-Task DeBERTa-BART Framework****

In [12]:
!pip install transformers datasets rouge-score bert-score nltk

Collecting fsspec<=2024.12.0,>=2023.1.0 (from fsspec[http]<=2024.12.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.12.0-py3-none-any.whl.metadata (11 kB)
Downloading fsspec-2024.12.0-py3-none-any.whl (183 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m183.9/183.9 kB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hInstalling collected packages: fsspec
  Attempting uninstall: fsspec
    Found existing installation: fsspec 2025.3.2
    Uninstalling fsspec-2025.3.2:
      Successfully uninstalled fsspec-2025.3.2
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
gcsfs 2024.10.0 requires fsspec==2024.10.0, but you have fsspec 2024.12.0 which is incompatible.
bigframes 1.36.0 requires rich<14,>=12.4.4, but you have rich 14.0.0 which is incompatible.[0m[31m
[0mSuccessfully installed fsspec-2024.12.0


In [19]:
import json
import torch
from transformers import (
    DebertaV2Tokenizer, 
    DebertaV2ForSequenceClassification,
    BartForConditionalGeneration,
    AutoTokenizer
)
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
import numpy as np
from rouge_score import rouge_scorer
import bert_score
from nltk.translate.meteor_score import meteor_score
from nltk.translate.bleu_score import corpus_bleu
from tqdm import tqdm

In [20]:
# Configuration
PERSPECTIVES = ["INFORMATION", "QUESTION", "EXPERIENCE", "SUGGESTION", "CAUSE"]
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 4
CLASSIFIER_EPOCHS = 3
GENERATOR_EPOCHS = 5
MAX_LENGTH = 1024

In [21]:
# 1. Dataset Class
class HealthcareDataset(Dataset):
    def __init__(self, filepath):
        with open(filepath) as f:
            self.data = json.load(f)
        
        # Filter items with labelled summaries
        self.data = [item for item in self.data 
                    if item.get("labelled_summaries") and len(item["labelled_summaries"]) > 0]
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        item = self.data[idx]
        answers = " ".join([ans.replace("\n", " ") for ans in item["answers"]])
        perspective = list(item["labelled_summaries"].keys())[0].split("_")[0]
        summary = list(item["labelled_summaries"].values())[0]
        
        return {
            "text": answers,
            "perspective": PERSPECTIVES.index(perspective),
            "summary": summary,
            "perspective_name": perspective
        }

In [22]:
# 2. Model Initialization
classifier_tokenizer = DebertaV2Tokenizer.from_pretrained("microsoft/deberta-v3-base")
classifier = DebertaV2ForSequenceClassification.from_pretrained(
    "microsoft/deberta-v3-base",
    num_labels=len(PERSPECTIVES)
).to(DEVICE)

Some weights of DebertaV2ForSequenceClassification were not initialized from the model checkpoint at microsoft/deberta-v3-base and are newly initialized: ['classifier.bias', 'classifier.weight', 'pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [23]:
generator_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
generator = BartForConditionalGeneration.from_pretrained("facebook/bart-base").to(DEVICE)

In [28]:
import os
from datetime import datetime

# Create output directories
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
os.makedirs(f"models/{timestamp}/classifier", exist_ok=True)
os.makedirs(f"models/{timestamp}/generator", exist_ok=True)

In [34]:
def train_classifier():
    train_data = HealthcareDataset("/kaggle/input/nlp-healthcare/train.json")
    train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
    
    optimizer = AdamW(classifier.parameters(), lr=2e-5)
    classifier.train()
    
    best_loss = float('inf')
    
    for epoch in range(CLASSIFIER_EPOCHS):
        total_loss = 0
        correct = 0
        total = 0
        
        for batch in tqdm(train_loader, desc=f"Classifier Epoch {epoch+1}"):
            inputs = classifier_tokenizer(
                batch["text"],
                padding=True,
                truncation=True,
                max_length=MAX_LENGTH,
                return_tensors="pt"
            ).to(DEVICE)
            
            labels = torch.tensor(batch["perspective"]).to(DEVICE)
            
            optimizer.zero_grad()
            outputs = classifier(**inputs, labels=labels)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            correct += (torch.argmax(outputs.logits, dim=1) == labels).sum().item()
            total += labels.size(0)
        
        avg_loss = total_loss / len(train_loader)
        accuracy = correct / total
        
        print(f"Epoch {epoch+1} - Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")
        
        # Save model after each epoch
        epoch_dir = f"models/{timestamp}/classifier/epoch_{epoch+1}"
        os.makedirs(epoch_dir, exist_ok=True)
        classifier.save_pretrained(epoch_dir)
        classifier_tokenizer.save_pretrained(epoch_dir)
        
        # Save best model
        if avg_loss < best_loss:
            best_loss = avg_loss
            classifier.save_pretrained(f"models/{timestamp}/classifier/best")
            classifier_tokenizer.save_pretrained(f"models/{timestamp}/classifier/best")

def train_generator():
    train_data = HealthcareDataset("/kaggle/input/nlp-healthcare/train.json")
    train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
    
    optimizer = AdamW(generator.parameters(), lr=3e-5)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=1, factor=0.5)
    generator.train()
    
    best_loss = float('inf')
    
    for epoch in range(GENERATOR_EPOCHS):
        total_loss = 0
        total_bleu = 0
        
        for batch in tqdm(train_loader, desc=f"Generator Epoch {epoch+1}"):
            input_texts = [
                f"Summarize this {PERSPECTIVES[perspective].lower()} content: {text}"
                for text, perspective in zip(batch["text"], batch["perspective"])
            ]
            
            inputs = generator_tokenizer(
                input_texts,
                padding=True,
                truncation=True,
                max_length=MAX_LENGTH,
                return_tensors="pt"
            ).to(DEVICE)
            
            targets = generator_tokenizer(
                batch["summary"],
                padding=True,
                truncation=True,
                max_length=150,
                return_tensors="pt"
            ).to(DEVICE)
            
            optimizer.zero_grad()
            outputs = generator(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                labels=targets["input_ids"]
            )
            loss = outputs.loss
            loss.backward()
            torch.nn.utils.clip_grad_norm_(generator.parameters(), 1.0)
            optimizer.step()
            
            total_loss += loss.item()
            
            # Calculate batch BLEU for monitoring
            with torch.no_grad():
                generated_ids = generator.generate(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    max_length=150
                )
                generated_summaries = generator_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
                bleu = corpus_bleu([[ref.split()] for ref in batch["summary"]], 
                                  [gen.split() for gen in generated_summaries])
                total_bleu += bleu
        
        avg_loss = total_loss / len(train_loader)
        avg_bleu = total_bleu / len(train_loader)
        scheduler.step(avg_loss)
        
        print(f"Epoch {epoch+1} - Loss: {avg_loss:.4f}, BLEU: {avg_bleu:.4f}, LR: {optimizer.param_groups[0]['lr']:.2e}")
        
        # Save model after each epoch
        epoch_dir = f"models/{timestamp}/generator/epoch_{epoch+1}"
        os.makedirs(epoch_dir, exist_ok=True)
        generator.save_pretrained(epoch_dir)
        generator_tokenizer.save_pretrained(epoch_dir)
        
        # Save best model
        if avg_loss < best_loss:
            best_loss = avg_loss
            generator.save_pretrained(f"models/{timestamp}/generator/best")
            generator_tokenizer.save_pretrained(f"models/{timestamp}/generator/best")

In [35]:
# 4. Evaluation Function
def evaluate(filepath):
    dataset = HealthcareDataset(filepath)
    loader = DataLoader(dataset, batch_size=BATCH_SIZE)
    
    results = {p: {"R1": [], "R2": [], "RL": [], "BERTScore": [], "METEOR": [], "BLEU": []}
               for p in PERSPECTIVES}
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
    
    classifier.eval()
    generator.eval()
    
    with torch.no_grad():
        for batch in tqdm(loader, desc="Evaluating"):
            # Classify perspectives
            cls_inputs = classifier_tokenizer(
                batch["text"],
                padding=True,
                truncation=True,
                max_length=MAX_LENGTH,
                return_tensors="pt"
            ).to(DEVICE)
            
            cls_outputs = classifier(**cls_inputs)
            pred_perspectives = torch.argmax(cls_outputs.logits, dim=1).cpu().numpy()
            
            # Generate summaries
            input_texts = [
                f"Summarize this {PERSPECTIVES[p].lower()} content: {text}"
                for text, p in zip(batch["text"], pred_perspectives)
            ]
            
            gen_inputs = generator_tokenizer(
                input_texts,
                padding=True,
                truncation=True,
                max_length=MAX_LENGTH,
                return_tensors="pt"
            ).to(DEVICE)
            
            generated_ids = generator.generate(
                input_ids=gen_inputs["input_ids"],
                attention_mask=gen_inputs["attention_mask"],
                max_length=150
            )
            
            generated_summaries = generator_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
            
            # Calculate metrics for each item
            for i, (ref, gen, true_perspective) in enumerate(zip(
                batch["summary"], generated_summaries, batch["perspective_name"])):
                
                rouge = scorer.score(ref, gen)
                bert_score_val = bert_score.score([gen], [ref], lang="en")[2].mean()
                meteor = meteor_score([ref.split()], gen.split())
                bleu = corpus_bleu([[ref.split()]], [gen.split()])
                
                results[true_perspective]["R1"].append(rouge["rouge1"].fmeasure)
                results[true_perspective]["R2"].append(rouge["rouge2"].fmeasure)
                results[true_perspective]["RL"].append(rouge["rougeL"].fmeasure)
                results[true_perspective]["BERTScore"].append(bert_score_val.item())
                results[true_perspective]["METEOR"].append(meteor)
                results[true_perspective]["BLEU"].append(bleu)
    
    # Aggregate results
    final_results = {}
    for p in PERSPECTIVES:
        if results[p]["R1"]:
            final_results[p] = {
                "R1": np.mean(results[p]["R1"]),
                "R2": np.mean(results[p]["R2"]),
                "RL": np.mean(results[p]["RL"]),
                "BERTScore": np.mean(results[p]["BERTScore"]),
                "METEOR": np.mean(results[p]["METEOR"]),
                "BLEU": np.mean(results[p]["BLEU"])
            }
    
    return final_results

In [36]:
# 5. Main Execution
if __name__ == "__main__":
    # Train models
    print("Training classifier...")
    train_classifier()
    
    print("\nTraining generator...")
    train_generator()
    
    # Evaluate on all splits
    print("\nEvaluating on validation set...")
    valid_results = evaluate("/kaggle/input/nlp-healthcare/valid.json")
    
    print("\nEvaluating on test set...")
    test_results = evaluate("/kaggle/input/nlp-healthcare/test.json")
    
    # Print results
    def print_results(name, results):
        print(f"\n{name} Results:")
        print("| Perspective | R1 | R2 | RL | BERTScore | METEOR | BLEU |")
        print("|---|---|---|---|---|---|---|")
        for p in PERSPECTIVES:
            if p in results:
                print(f"| {p} | {results[p]['R1']:.2f} | {results[p]['R2']:.2f} | {results[p]['RL']:.2f} | "
                      f"{results[p]['BERTScore']:.3f} | {results[p]['METEOR']:.3f} | {results[p]['BLEU']:.3f} |")
    
    print_results("Validation", valid_results)
    print_results("Test", test_results)
    
    # Save results
    with open("results.json", "w") as f:
        json.dump({
            "validation": valid_results,
            "test": test_results
        }, f, indent=2)

Training classifier...


  labels = torch.tensor(batch["perspective"]).to(DEVICE)
Classifier Epoch 1: 100%|██████████| 559/559 [04:47<00:00,  1.95it/s]


Epoch 1 - Loss: 1.1323, Accuracy: 0.5933


Classifier Epoch 2: 100%|██████████| 559/559 [04:44<00:00,  1.97it/s]


Epoch 2 - Loss: 1.0279, Accuracy: 0.6228


Classifier Epoch 3: 100%|██████████| 559/559 [04:46<00:00,  1.95it/s]


Epoch 3 - Loss: 0.9128, Accuracy: 0.6591

Training generator...


The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
Generator Epoch 1: 100%|██████████| 559/559 [14:33<00:00,  1.56s/it]


Epoch 1 - Loss: 1.8698, BLEU: 0.0475, LR: 3.00e-05


The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
Generator Epoch 2: 100%|██████████| 559/559 [15:56<00:00,  1.71s/it]


Epoch 2 - Loss: 1.2092, BLEU: 0.0773, LR: 3.00e-05


The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
Generator Epoch 3: 100%|██████████| 559/559 [16:30<00:00,  1.77s/it]


Epoch 3 - Loss: 1.0309, BLEU: 0.0883, LR: 3.00e-05


The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
Generator Epoch 4: 100%|██████████| 559/559 [16:27<00:00,  1.77s/it]


Epoch 4 - Loss: 0.9185, BLEU: 0.1038, LR: 3.00e-05


The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
Generator Epoch 5: 100%|██████████| 559/559 [16:44<00:00,  1.80s/it]


Epoch 5 - Loss: 0.8236, BLEU: 0.1165, LR: 3.00e-05

Evaluating on validation set...


Evaluating:   0%|          | 0/240 [00:00<?, ?it/s]

tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/482 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/1.42G [00:00<?, ?B/s]

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction(


Evaluating on test set...


Evaluating:   0%|          | 0/160 [00:00<?, ?it/s]Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: [


Validation Results:
| Perspective | R1 | R2 | RL | BERTScore | METEOR | BLEU |
|---|---|---|---|---|---|---|
| INFORMATION | 0.44 | 0.22 | 0.33 | 0.894 | 0.298 | 0.087 |
| QUESTION | 0.18 | 0.04 | 0.14 | 0.855 | 0.147 | 0.007 |
| EXPERIENCE | 0.20 | 0.05 | 0.14 | 0.849 | 0.129 | 0.009 |
| SUGGESTION | 0.30 | 0.13 | 0.23 | 0.875 | 0.192 | 0.048 |
| CAUSE | 0.29 | 0.13 | 0.24 | 0.876 | 0.251 | 0.063 |

Test Results:
| Perspective | R1 | R2 | RL | BERTScore | METEOR | BLEU |
|---|---|---|---|---|---|---|
| INFORMATION | 0.44 | 0.21 | 0.33 | 0.893 | 0.297 | 0.084 |
| QUESTION | 0.16 | 0.03 | 0.12 | 0.849 | 0.146 | 0.007 |
| EXPERIENCE | 0.20 | 0.05 | 0.14 | 0.848 | 0.131 | 0.007 |
| SUGGESTION | 0.31 | 0.13 | 0.24 | 0.877 | 0.191 | 0.042 |
| CAUSE | 0.31 | 0.15 | 0.25 | 0.875 | 0.264 | 0.070 |



