In [None]:
import gc
from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from datasets import Dataset
from sklearn.model_selection import train_test_split
import pandas as pd

# Clear GPU memory
# Run the garbage collector first: tensors kept alive only by reference cycles
# are released to PyTorch's caching allocator by gc.collect(), and only then can
# empty_cache() hand that (now unoccupied) cached memory back to the driver.
gc.collect()
torch.cuda.empty_cache()

In [None]:
# High-level text-generation helper built on the DistilGPT-2 checkpoint
pipe = pipeline(task="text-generation", model="distilbert/distilgpt2")

In [None]:

# Pick the compute device: CUDA when PyTorch can see a GPU, otherwise the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

In [None]:
# Load the tokenizer and causal-LM weights for the same checkpoint, then move
# the model onto the selected device
checkpoint = "distilbert/distilgpt2"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)
# Turn on gradient checkpointing for the upcoming fine-tuning run
model.gradient_checkpointing_enable()

In [None]:
def tokenize_function(examples):
    """Tokenize abstracts for causal language modeling.

    Args:
        examples: Mapping whose ``'medical_abstract'`` entry holds the text(s)
            to tokenize (a list of strings when used with
            ``Dataset.map(batched=True)``).

    Returns:
        The tokenizer output, truncated/padded to 512 tokens, with an extra
        ``'labels'`` field: a copy of ``input_ids`` in which every padded
        position (``attention_mask == 0``) is replaced by ``-100`` so that
        padding is never used as a training/evaluation target.
    """
    # Tokenize
    tokenized = tokenizer(
        examples['medical_abstract'],
        truncation=True,
        max_length=512,  # context length
        padding='max_length',
        return_tensors=None
    )

    def _mask_padding(ids, mask):
        # Keep real tokens as targets; mark padded slots with the ignore index.
        return [token if keep else -100 for token, keep in zip(ids, mask)]

    input_ids = tokenized['input_ids']
    attention_mask = tokenized['attention_mask']

    # For language modeling the labels are the input ids themselves. Because
    # the pad token is the EOS token in this notebook, padding cannot be told
    # apart from content by id alone, so the attention mask decides.
    is_batched = bool(input_ids) and isinstance(input_ids[0], (list, tuple))
    if is_batched:
        tokenized['labels'] = [
            _mask_padding(ids, mask)
            for ids, mask in zip(input_ids, attention_mask)
        ]
    else:
        tokenized['labels'] = _mask_padding(input_ids, attention_mask)

    return tokenized

# Load the full training data
splits = {'train': 'data/train-00000-of-00001.parquet', 'test': 'data/test-00000-of-00001.parquet'}
df_train = pd.read_parquet("hf://datasets/TimSchopf/medical_abstracts/" + splits["train"])
df_test = pd.read_parquet("hf://datasets/TimSchopf/medical_abstracts/" + splits["test"])

# Split training data into train and validation
train_df, val_df = train_test_split(df_train, test_size=0.1, random_state=42)

# train_test_split shuffles the rows, so the split frames carry a non-default
# index. preserve_index=False stops that index from being stored as an extra
# '__index_level_0__' column in the datasets.
train_dataset = Dataset.from_pandas(train_df, preserve_index=False)
val_dataset = Dataset.from_pandas(val_df, preserve_index=False)
test_dataset = Dataset.from_pandas(df_test, preserve_index=False)

# Set pad token for tokenizer (DistilGPT-2 tokenizer doesn't have one by default)
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = model.config.eos_token_id

# Tokenize every split the same way, dropping the raw columns afterwards
raw_columns = ['condition_label', 'medical_abstract']
train_dataset = train_dataset.map(tokenize_function, batched=True, remove_columns=raw_columns)
val_dataset = val_dataset.map(tokenize_function, batched=True, remove_columns=raw_columns)
test_dataset = test_dataset.map(tokenize_function, batched=True, remove_columns=raw_columns)

print("Tokenization complete!")

In [None]:
def next_token_accuracy(model, tokenizer, test_dataset, n_samples=100):
    """
    How often is the correct next token in top-k predictions?

    For each of the first ``n_samples`` examples, the model is run once on the
    full ``input_ids`` sequence. Every position from 1 to ``len - 2`` whose own
    token and whose following token are both real (non-padding) is scored: the
    token at ``pos + 1`` is looked up in the top-1/top-5/top-10 logits emitted
    at ``pos``.

    Args:
        model: Causal LM (a ``torch.nn.Module``) whose call returns an object
            with a ``logits`` tensor of shape (batch, seq, vocab).
        tokenizer: Only ``pad_token_id`` is read, and only as a fallback when
            an example has no ``'attention_mask'``.
        test_dataset: Indexable collection of dicts with ``'input_ids'`` and,
            preferably, ``'attention_mask'``.
        n_samples: Maximum number of examples to evaluate.

    Returns:
        Dict with ``'top1_accuracy'``, ``'top5_accuracy'`` and
        ``'top10_accuracy'``; every value is ``nan`` if no position was scored.
    """
    model.eval()
    # Evaluate on whatever device the model lives on instead of assuming CUDA.
    model_device = next(model.parameters()).device

    top1_correct = 0
    top5_correct = 0
    top10_correct = 0
    total = 0

    with torch.no_grad():
        for i in range(min(n_samples, len(test_dataset))):
            example = test_dataset[i]
            input_ids = torch.tensor([example['input_ids']], device=model_device)
            seq_len = input_ids.shape[1]
            if seq_len < 3:
                # No position has both a predecessor and a successor.
                continue

            # Prefer the attention mask to find real tokens: the pad token can
            # be the same id as EOS, so comparing ids alone is ambiguous.
            if 'attention_mask' in example:
                is_real = torch.tensor(example['attention_mask'], device=model_device).bool()
            elif tokenizer.pad_token_id is None:
                is_real = torch.ones(seq_len, dtype=torch.bool, device=model_device)
            else:
                is_real = input_ids[0] != tokenizer.pad_token_id

            # Get predictions for each position
            logits = model(input_ids).logits[0]

            # Check each token (except first and last): logits at `pos` are
            # compared against the token at `pos + 1`.
            context = slice(1, seq_len - 1)
            targets = input_ids[0, 2:]
            # Skip pairs where either the context token or the target is padding.
            scored = is_real[context] & is_real[2:]
            if not scored.any():
                continue

            # Top-k predictions for all positions at once
            k = min(10, logits.shape[-1])
            ranked = torch.topk(logits[context], k=k, dim=-1).indices
            hits = (ranked == targets.unsqueeze(-1))[scored]

            top1_correct += int(hits[:, :1].any(dim=-1).sum().item())
            top5_correct += int(hits[:, :5].any(dim=-1).sum().item())
            top10_correct += int(hits[:, :10].any(dim=-1).sum().item())
            total += int(scored.sum().item())

    if total == 0:
        # Nothing could be scored (empty dataset or only padding).
        nan = float('nan')
        return {'top1_accuracy': nan, 'top5_accuracy': nan, 'top10_accuracy': nan}

    return {
        'top1_accuracy': top1_correct / total,
        'top5_accuracy': top5_correct / total,
        'top10_accuracy': top10_correct / total
    }

In [None]:
import math

def calculate_perplexity(model, dataset, tokenizer, n_samples=100):
    """Calculate perplexity on (a sample of) a dataset.

    Each example is scored separately; the per-example mean loss returned by
    the model is weighted by the number of tokens that were actually predicted,
    and ``exp`` of the weighted average is returned.

    Args:
        model: Causal LM (a ``torch.nn.Module``) accepting ``input_ids``,
            ``attention_mask`` and ``labels`` and returning an object with a
            ``loss`` attribute.
        dataset: Indexable collection of dicts with ``'input_ids'``,
            ``'attention_mask'`` and ``'labels'``.
        tokenizer: Unused; kept so existing callers keep working.
        n_samples: Maximum number of examples to evaluate (first N are used).

    Returns:
        Perplexity as a float, or ``nan`` when no token could be scored.
    """
    model.eval()
    # Follow the model's device rather than relying on a notebook-level global.
    model_device = next(model.parameters()).device
    total_loss = 0.0
    total_tokens = 0

    with torch.no_grad():
        for i in range(min(n_samples, len(dataset))):
            example = dataset[i]
            input_ids = torch.tensor([example['input_ids']], device=model_device)
            attention_mask = torch.tensor([example['attention_mask']], device=model_device)
            labels = torch.tensor([example['labels']], device=model_device)

            # Padded positions must not be scored: -100 is the label value the
            # Hugging Face causal-LM loss ignores.
            labels = labels.masked_fill(attention_mask == 0, -100)

            # The model shifts labels by one internally, so the first token is
            # never a target; count only the targets that contribute to loss.
            n_tokens = int((labels[:, 1:] != -100).sum().item())
            if n_tokens == 0:
                # Nothing to predict here; the mean loss would be undefined.
                continue

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )

            total_loss += outputs.loss.item() * n_tokens
            total_tokens += n_tokens

    if total_tokens == 0:
        return float('nan')

    avg_loss = total_loss / total_tokens
    return math.exp(avg_loss)

In [None]:


# Perplexity of the untouched pretrained model on the test split
print("Evaluating baseline (pretrained) model...")
baseline_perplexity = calculate_perplexity(model, test_dataset, tokenizer)
print(f"Baseline perplexity: {baseline_perplexity:.2f}")

# Section header for the sample generations
section_rule = "=" * 60
print("\n" + section_rule)
print("BASELINE GENERATION (Before Fine-tuning)")
print(section_rule)

# Sample three continuations of a fixed prompt
prompt = "The patient presented with"
inputs = tokenizer(prompt, return_tensors="pt").to(device)
sampling_options = dict(
    max_length=50,
    num_return_sequences=3,
    temperature=0.9,
    do_sample=True,
    top_p=0.95,
)
outputs = model.generate(**inputs, **sampling_options)

for i, output in enumerate(outputs):
    text = tokenizer.decode(output, skip_special_tokens=True)
    print("\nSample {}: {}".format(i + 1, text))

# Top-k next-token accuracy before fine-tuning, printed as a single table row
baseline_acc = next_token_accuracy(model, tokenizer, test_dataset)
print("{:<15} {:<10.1%} {:<10.1%} {:.1%}".format(
    'Baseline',
    baseline_acc['top1_accuracy'],
    baseline_acc['top5_accuracy'],
    baseline_acc['top10_accuracy'],
))



In [None]:
from transformers import (
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling
)

# Batch collator for language modeling; mlm=False selects the causal
# (next-token) objective that GPT-2 uses rather than masked-LM.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

In [None]:
from peft import LoraConfig, get_peft_model, TaskType, PeftModel

# Configure LoRA
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=8,                        # LoRA rank
    lora_alpha=32,              # LoRA scaling
    lora_dropout=0.1,
    # NOTE: in GPT-2 blocks "c_proj" names both the attention output projection
    # and the MLP output projection, so adapters are attached to both.
    target_modules=["c_attn", "c_proj"],
    # GPT-2 implements these projections as Conv1D, which stores weights as
    # (fan_in, fan_out); PEFT documents fan_in_fan_out=True for that layout.
    fan_in_fan_out=True,
)

# Apply LoRA to model
model = get_peft_model(model, lora_config)
# Prints trainable vs. total parameter counts; only the adapter weights train.
model.print_trainable_parameters()

# Move to the selected device (falls back to CPU when no GPU is present)
model = model.to(device)

# Training arguments
training_args = TrainingArguments(
    output_dir="./gpt2-lora-finetuned",
    num_train_epochs=5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=4,      # effective batch size of 16 per device
    learning_rate=5e-5,
    # Mixed-precision fp16 training needs a GPU; disable it on CPU-only runs.
    fp16=torch.cuda.is_available(),
    logging_steps=50,
    eval_steps=500,
    save_steps=500,                     # kept equal to eval_steps so the best
    eval_strategy="steps",              # checkpoint can be reloaded at the end
    save_strategy="steps",
    load_best_model_at_end=True,
    report_to="none",
)

# Train
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=data_collator,
)

print("Training with LoRA...")
trainer.train()

# Save LoRA adapters (only the adapter weights, not the full base model)
model.save_pretrained("./gpt2-lora-final")

In [None]:
import matplotlib.pyplot as plt
import math
# Pull the logged metrics out of the trainer state
log_history = trainer.state.log_history

# Evaluation entries carry 'eval_loss'; training entries carry a plain 'loss'
train_logs = [entry for entry in log_history if 'loss' in entry and 'eval_loss' not in entry]
eval_logs = [entry for entry in log_history if 'eval_loss' in entry]

train_steps = [entry['step'] for entry in train_logs]
train_losses = [entry['loss'] for entry in train_logs]

eval_steps = [entry['step'] for entry in eval_logs]
eval_losses = [entry['eval_loss'] for entry in eval_logs]

# Perplexity (exp(loss))
train_perplexities = list(map(math.exp, train_losses))
eval_perplexities = list(map(math.exp, eval_losses))

# Two side-by-side panels: loss on the left, perplexity on the right
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

panels = (
    (ax1, train_losses, eval_losses,
     'Train Loss', 'Validation Loss', 'Loss', 'Training Progress: Loss'),
    (ax2, train_perplexities, eval_perplexities,
     'Train Perplexity', 'Validation Perplexity', 'Perplexity', 'Training Progress: Perplexity'),
)
for axis, train_curve, eval_curve, train_label, eval_label, y_label, panel_title in panels:
    axis.plot(train_steps, train_curve, label=train_label, alpha=0.7)
    axis.plot(eval_steps, eval_curve, label=eval_label, marker='o', linewidth=2)
    axis.set_xlabel('Training Steps')
    axis.set_ylabel(y_label)
    axis.set_title(panel_title)
    axis.legend()
    axis.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('training_curves.png', dpi=150, bbox_inches='tight')
plt.show()

# Text summary: first vs. last logged value of each curve
summary_rule = "=" * 60
print("\n" + summary_rule)
print("LEARNING EVIDENCE")
print(summary_rule)

first_train_loss, last_train_loss = train_losses[0], train_losses[-1]
print(f"Initial train loss: {first_train_loss:.4f}")
print(f"Final train loss: {last_train_loss:.4f}")
print(f"Improvement: {first_train_loss - last_train_loss:.4f}")
print()

first_val_ppl, last_val_ppl = eval_perplexities[0], eval_perplexities[-1]
print(f"Initial val perplexity: {first_val_ppl:.2f}")
print(f"Final val perplexity: {last_val_ppl:.2f}")
print(f"Improvement: {first_val_ppl - last_val_ppl:.2f}")

In [None]:
final_rule = "=" * 60
print(final_rule)
print("FINAL EVALUATION")
print(final_rule)

# Perplexity of the fine-tuned model on the same test split as the baseline
finetuned_perplexity = calculate_perplexity(model, test_dataset, tokenizer)

# Absolute and relative drop in perplexity versus the pretrained baseline
absolute_gain = baseline_perplexity - finetuned_perplexity
relative_gain = (1 - finetuned_perplexity/baseline_perplexity)*100

print("\nPerplexity Comparison:")
print(f"  Baseline (pretrained):  {baseline_perplexity:.2f}")
print(f"  Fine-tuned:             {finetuned_perplexity:.2f}")
print(f"  Improvement:            {absolute_gain:.2f}")
print(f"  Relative improvement:   {relative_gain:.1f}%")

In [None]:
def generate_comparison(prompt, model, tokenizer, num_samples=3):
    """Sample ``num_samples`` continuations of ``prompt`` from ``model``.

    Args:
        prompt: Text to continue.
        model: Generative model (a ``torch.nn.Module``) exposing ``generate``.
        tokenizer: Tokenizer used to encode the prompt and decode the samples.
        num_samples: Number of sequences to sample.

    Returns:
        List of decoded strings (special tokens stripped), one per sample.
    """
    # Put the inputs on the model's own device so the call does not depend on
    # a notebook-level `device` variable staying in sync with the model.
    model_device = next(model.parameters()).device
    inputs = tokenizer(prompt, return_tensors="pt").to(model_device)

    outputs = model.generate(
        **inputs,
        max_length=100,
        num_return_sequences=num_samples,
        temperature=0.9,
        do_sample=True,
        top_p=0.95,
        repetition_penalty=1.2,
        no_repeat_ngram_size=2,
        # Pass the pad id explicitly instead of letting generate() fall back
        # to EOS with a warning on every call.
        pad_token_id=tokenizer.pad_token_id,
    )

    return [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]

# Prompts to sample from (adjust for your domain)
test_prompts = [
    "The patient presented with",
    "Treatment options include",
    "The study found that",
]

heavy_rule = "=" * 60
light_rule = "-" * 60

print("\n" + heavy_rule)
print("GENERATION COMPARISON")
print(heavy_rule)

for prompt in test_prompts:
    # Per-prompt header
    print("\n" + heavy_rule)
    print("PROMPT: '{}'".format(prompt))
    print(heavy_rule)

    # Numbered samples from the fine-tuned model
    print("\nFINE-TUNED MODEL:")
    print(light_rule)
    finetuned_gens = generate_comparison(prompt, model, tokenizer, num_samples=3)
    for i, text in enumerate(finetuned_gens, 1):
        print("\n{}. {}".format(i, text))



In [None]:
# Fixed-width table: model name, then top-1 / top-5 / top-10 accuracy
table_rule = "=" * 60
header_template = "{:<15} {:<10} {:<10} {}"
row_template = "{:<15} {:<10.1%} {:<10.1%} {:.1%}"

print("\n" + table_rule)
print("NEXT-TOKEN PREDICTION ACCURACY")
print(table_rule)
print(header_template.format('Model', 'Top-1', 'Top-5', 'Top-10'))
print("-" * 60)
print(row_template.format(
    'Baseline',
    baseline_acc['top1_accuracy'],
    baseline_acc['top5_accuracy'],
    baseline_acc['top10_accuracy'],
))

# Calculate accuracy for the fine-tuned model
finetuned_acc = next_token_accuracy(model, tokenizer, test_dataset)
print(row_template.format(
    'Fine-tuned',
    finetuned_acc['top1_accuracy'],
    finetuned_acc['top5_accuracy'],
    finetuned_acc['top10_accuracy'],
))