# Drug Composition to Medical Indication Prediction
## Part 3: BioBERT Fine-Tuning

This notebook covers:
1. Loading BioBERT pre-trained model
2. Preparing data for transformers
3. Fine-tuning BioBERT for multi-label classification
4. Evaluation and comparison with baselines

## 1. Install and Import Libraries

In [None]:
# Install additional requirements if needed
!pip install transformers datasets accelerate

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    set_seed,
)
from sklearn.metrics import precision_recall_fscore_support, hamming_loss, accuracy_score
import warnings
# NOTE(review): this also hides deprecation warnings from transformers/torch.
warnings.filterwarnings('ignore')

# Seed the Python/NumPy/PyTorch RNGs *before* the model is created, so the
# randomly initialised classification head is reproducible across runs.
SEED = 42
set_seed(SEED)

# Check GPU availability
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 2. Load Preprocessed Data

In [None]:
# Load the preprocessed splits.
# allow_pickle=True is needed for the object (string) text arrays; only load
# files produced by this project, since pickled data can execute code.
X_train = np.load('X_train.npy', allow_pickle=True)
X_val = np.load('X_val.npy', allow_pickle=True)
X_test = np.load('X_test.npy', allow_pickle=True)
y_train = np.load('y_train.npy')
y_val = np.load('y_val.npy')
y_test = np.load('y_test.npy')

# Fitted label binarizer used to map label columns back to condition names.
# NOTE: pickle.load runs arbitrary code from the file — trusted artifacts only.
with open('mlb.pkl', 'rb') as f:
    mlb = pickle.load(f)

num_labels = y_train.shape[1]

# Sanity checks: every split has one label row per text, and all splits share
# the same label space as the binarizer.
for split_name, split_texts, split_labels in (
    ('train', X_train, y_train),
    ('val', X_val, y_val),
    ('test', X_test, y_test),
):
    assert len(split_texts) == len(split_labels), (
        f"{split_name}: {len(split_texts)} texts vs {len(split_labels)} label rows"
    )
    assert split_labels.ndim == 2 and split_labels.shape[1] == num_labels, (
        f"{split_name}: expected {num_labels} label columns, got shape {split_labels.shape}"
    )
assert len(mlb.classes_) == num_labels, "mlb.classes_ does not match the label matrix width"

print(f"Training samples: {len(X_train)}")
print(f"Validation samples: {len(X_val)}")
print(f"Test samples: {len(X_test)}")
print(f"Number of labels: {num_labels}")

## 3. Load BioBERT Tokenizer and Model

In [None]:
# Load BioBERT tokenizer
MODEL_NAME = 'dmis-lab/biobert-v1.1'

print(f"Loading tokenizer from {MODEL_NAME}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Store human-readable label names in the model config, so the saved
# fine-tuned model can map output indices back to condition names by itself.
id2label = {i: str(name) for i, name in enumerate(mlb.classes_)}
label2id = {name: i for i, name in id2label.items()}

print(f"Loading model from {MODEL_NAME}...")
# problem_type="multi_label_classification" makes the sequence-classification
# head train with a sigmoid/BCE-with-logits loss (one independent output per label).
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=num_labels,
    problem_type="multi_label_classification",
    id2label=id2label,
    label2id=label2id,
)

model.to(device)

print("Model loaded successfully!")
print(f"Model parameters: {model.num_parameters():,}")

## 4. Create PyTorch Dataset

In [None]:
class DrugIndicationDataset(Dataset):
    """PyTorch Dataset that tokenizes drug texts on the fly for multi-label classification.

    Args:
        texts: Sequence of input texts; each item is converted with ``str`` before tokenizing.
        labels: Multi-hot label matrix with one row per text, or ``None`` for an
            unlabelled (inference-only) dataset.
        tokenizer: Hugging Face-style tokenizer callable.
        max_length: Every example is padded/truncated to exactly this many tokens.

    Each item is a dict with 1-D ``input_ids`` and ``attention_mask`` tensors. When
    labels are given, it also has a float32 ``labels`` tensor, since multi-label
    (BCE-with-logits) loss needs float targets.

    Raises:
        ValueError: if ``labels`` is given and its length differs from ``texts``.
    """

    def __init__(self, texts, labels, tokenizer, max_length=512):
        if labels is not None and len(labels) != len(texts):
            raise ValueError(
                f"texts and labels must have the same length ({len(texts)} != {len(labels)})"
            )
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = str(self.texts[idx])

        # Tokenize a single example; return_tensors='pt' yields shape (1, max_length),
        # so flatten to 1-D and let the DataLoader collate the batch dimension.
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        item = {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
        }
        if self.labels is not None:
            # torch.tensor copies and converts any numeric row (e.g. int64 from a
            # label binarizer) to float32, replacing the legacy FloatTensor constructor.
            item['labels'] = torch.tensor(self.labels[idx], dtype=torch.float32)
        return item

# Build one dataset per split; all share the tokenizer and default max_length.
train_dataset, val_dataset, test_dataset = (
    DrugIndicationDataset(split_texts, split_targets, tokenizer)
    for split_texts, split_targets in ((X_train, y_train), (X_val, y_val), (X_test, y_test))
)

for split_title, split_ds in (
    ("Train", train_dataset),
    ("Validation", val_dataset),
    ("Test", test_dataset),
):
    print(f"{split_title} dataset size: {len(split_ds)}")

# Smoke-test: fetch one tokenized training example and inspect its tensors.
sample = train_dataset[0]
print(f"\nSample batch keys: {sample.keys()}")
print(f"Input IDs shape: {sample['input_ids'].shape}")
print(f"Labels shape: {sample['labels'].shape}")

## 5. Define Evaluation Metrics

In [None]:
def compute_metrics(pred, threshold=0.5):
    """Compute multi-label classification metrics from a Trainer prediction.

    Args:
        pred: Object with ``predictions`` (raw logits, one column per label; may be a
            tuple whose first element is the logits) and ``label_ids`` (multi-hot targets).
        threshold: Probability cut-off applied after the sigmoid to switch a label on.

    Returns:
        Dict with macro/micro F1, precision and recall, Hamming loss, and
        exact-match (subset) accuracy.
    """
    logits = pred.predictions
    # Some models return extra outputs alongside the logits; the logits come first.
    if isinstance(logits, (tuple, list)):
        logits = logits[0]
    labels = np.asarray(pred.label_ids).astype(int)

    # Cast to float32 first: fp16 logits can hit missing Half kernels on CPU.
    probs = torch.sigmoid(torch.as_tensor(np.asarray(logits), dtype=torch.float32)).numpy()
    preds = (probs > threshold).astype(int)

    precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(
        labels, preds, average='macro', zero_division=0
    )
    precision_micro, recall_micro, f1_micro, _ = precision_recall_fscore_support(
        labels, preds, average='micro', zero_division=0
    )

    return {
        'f1_macro': f1_macro,
        'f1_micro': f1_micro,
        'precision_macro': precision_macro,
        'recall_macro': recall_macro,
        'precision_micro': precision_micro,
        'recall_micro': recall_micro,
        'hamming_loss': hamming_loss(labels, preds),
        # Fraction of samples whose entire label set is predicted exactly.
        'exact_match': accuracy_score(labels, preds),
    }

## 6. Configure Training Arguments

In [None]:
import inspect

# Training arguments
# `evaluation_strategy` was renamed `eval_strategy` in newer transformers releases,
# and the old keyword was later removed. The install cell does not pin a version,
# so pass the evaluation schedule under whichever name TrainingArguments accepts.
_EVAL_STRATEGY_KEY = (
    'eval_strategy'
    if 'eval_strategy' in inspect.signature(TrainingArguments).parameters
    else 'evaluation_strategy'
)

training_args = TrainingArguments(
    output_dir='./biobert_results',
    save_strategy='epoch',  # must match the eval schedule when load_best_model_at_end=True
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model='f1_macro',  # key from compute_metrics (Trainer adds the "eval_" prefix)
    greater_is_better=True,  # higher macro-F1 is better
    logging_dir='./logs',
    logging_steps=50,
    save_total_limit=2,
    # NOTE(review): if the run has fewer than ~500 optimizer steps, the learning rate
    # never finishes warming up; consider warmup_ratio for small datasets.
    warmup_steps=500,
    fp16=torch.cuda.is_available(),  # Use mixed precision if GPU available
    **{_EVAL_STRATEGY_KEY: 'epoch'},
)

print("Training configuration:")
print(f"  Batch size: {training_args.per_device_train_batch_size}")
print(f"  Learning rate: {training_args.learning_rate}")
print(f"  Epochs: {training_args.num_train_epochs}")
print(f"  FP16: {training_args.fp16}")

## 7. Initialize Trainer

In [None]:
# Wire model, data splits and metric function together. The validation split
# is what the Trainer evaluates each epoch and uses to pick the best checkpoint.
trainer_kwargs = dict(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
)
trainer = Trainer(**trainer_kwargs)

print("Trainer initialized!")

## 8. Train Model

In [None]:
print("Starting training...")
print("This may take several hours depending on your hardware.\n")

# Run the full fine-tuning loop and keep its summary statistics.
train_result = trainer.train()
run_stats = train_result.metrics

print("\nTraining complete!")
runtime_seconds = run_stats['train_runtime']
throughput = run_stats['train_samples_per_second']
print(f"Training time: {runtime_seconds:.2f} seconds")
print(f"Training samples/second: {throughput:.2f}")

## 9. Evaluate on Test Set

In [None]:
print("Evaluating on test set...")

# Metric keys come back prefixed with "eval_" (e.g. eval_f1_macro).
test_results = trainer.evaluate(test_dataset)

banner = "=" * 60
print("\n" + banner)
print("BioBERT Test Results")
print(banner)
for metric_name, metric_value in test_results.items():
    print(f"{metric_name:25s}: {metric_value:.4f}")
print(banner)

## 10. Generate Predictions

In [None]:
# Run inference on the test split: logits -> sigmoid probabilities -> 0/1 at 0.5.
predictions = trainer.predict(test_dataset)
test_logits = torch.tensor(predictions.predictions)
y_pred_probs = test_logits.sigmoid().numpy()
y_pred_biobert = np.greater(y_pred_probs, 0.5).astype(int)

print(f"Predictions shape: {y_pred_biobert.shape}")

# Persist hard predictions and probabilities for later analysis.
for out_path, out_array in (
    ('biobert_predictions.npy', y_pred_biobert),
    ('biobert_probabilities.npy', y_pred_probs),
):
    np.save(out_path, out_array)

## 11. Compare with Baselines

In [None]:
# Load baseline results
# NOTE: pickle.load executes arbitrary code — only load files this project produced.
with open('baseline_results.pkl', 'rb') as f:
    baseline_results = pickle.load(f)


def _baseline_metric(split, *keys):
    """Return the first of ``keys`` present in ``baseline_results[split]``, else NaN.

    A missing metric is reported as NaN (blank in the table/CSV, no bar in the plot)
    instead of 0, which would look like a genuine zero score.
    """
    split_results = baseline_results[split]
    for key in keys:
        if key in split_results:
            return split_results[key]
    print(f"Note: none of {keys} found in baseline_results['{split}']; reporting NaN")
    return np.nan


# Create comparison
comparison = pd.DataFrame({
    'Model': ['TF-IDF + LR', 'SentenceEmb + LR', 'BioBERT'],
    'F1 (Macro)': [
        baseline_results['tfidf_test']['f1_macro'],
        baseline_results['embedding_test']['f1_macro'],
        test_results['eval_f1_macro']
    ],
    'F1 (Micro)': [
        baseline_results['tfidf_test']['f1_micro'],
        baseline_results['embedding_test']['f1_micro'],
        test_results['eval_f1_micro']
    ],
    'Precision (Macro)': [
        _baseline_metric('tfidf_test', 'precision_macro', 'precision'),
        _baseline_metric('embedding_test', 'precision_macro', 'precision'),
        test_results['eval_precision_macro']
    ],
    'Recall (Macro)': [
        _baseline_metric('tfidf_test', 'recall_macro', 'recall'),
        _baseline_metric('embedding_test', 'recall_macro', 'recall'),
        test_results['eval_recall_macro']
    ],
    'Hamming Loss': [
        baseline_results['tfidf_test']['hamming_loss'],
        baseline_results['embedding_test']['hamming_loss'],
        test_results['eval_hamming_loss']
    ]
})

print("\n" + "="*90)
print("FINAL MODEL COMPARISON")
print("="*90)
print(comparison.to_string(index=False))
print("="*90)

# Save comparison
comparison.to_csv('final_model_comparison.csv', index=False)

In [None]:
# Three panels: macro F1, macro precision vs. recall, and Hamming loss.
bar_colors = ['skyblue', 'lightcoral', 'lightgreen']
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
ax_f1, ax_pr, ax_hamming = axes


def _single_metric_panel(ax, column, ylabel, title):
    """Draw one bar per model for ``comparison[column]`` on ``ax``."""
    ax.bar(comparison['Model'], comparison[column], color=bar_colors)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.tick_params(axis='x', rotation=45)
    ax.grid(axis='y', alpha=0.3)


_single_metric_panel(ax_f1, 'F1 (Macro)', 'F1 Score (Macro)', 'F1 Score Comparison (Macro Average)')

# Grouped bars: precision left of each tick, recall right of it.
x = np.arange(len(comparison))
width = 0.35
for offset, column, legend_label in (
    (-width / 2, 'Precision (Macro)', 'Precision'),
    (width / 2, 'Recall (Macro)', 'Recall'),
):
    ax_pr.bar(x + offset, comparison[column], width, label=legend_label, alpha=0.8)
ax_pr.set_ylabel('Score')
ax_pr.set_title('Precision vs Recall')
ax_pr.set_xticks(x)
ax_pr.set_xticklabels(comparison['Model'], rotation=45)
ax_pr.legend()
ax_pr.grid(axis='y', alpha=0.3)

_single_metric_panel(ax_hamming, 'Hamming Loss', 'Hamming Loss', 'Hamming Loss (Lower is Better)')

plt.tight_layout()
plt.savefig('model_comparison.png', dpi=300, bbox_inches='tight')
plt.show()

## 12. Analyze BioBERT Predictions

In [None]:
# Show example predictions
def show_biobert_predictions(idx, y_pred=None):
    """Print true vs. predicted conditions for one test example.

    Args:
        idx: Row index into ``X_test`` / ``y_test``.
        y_pred: Optional multi-hot prediction matrix aligned with ``y_test``;
            defaults to the notebook-level ``y_pred_biobert``.

    Reports per-example recall (correct / true labels) and precision
    (correct / predicted labels); "n/a" when the denominator is empty.
    """
    if y_pred is None:
        y_pred = y_pred_biobert

    text = str(X_test[idx])
    preview = text[:400] + ("..." if len(text) > 400 else "")

    print(f"\nExample {idx}")
    print("="*80)
    print(f"Input text (first 400 chars):\n{preview}\n")

    # Slice (not index) so inverse_transform receives a 2-D, one-row matrix.
    true_labels = mlb.inverse_transform(y_test[idx:idx+1])[0]
    pred_labels = mlb.inverse_transform(y_pred[idx:idx+1])[0]

    print(f"True conditions ({len(true_labels)}):")
    for label in true_labels:
        print(f"  - {label}")

    print(f"\nPredicted conditions ({len(pred_labels)}):")
    for label in pred_labels:
        print(f"  - {label}")

    # Sorted so the printed lists are stable across runs.
    correct = sorted(set(true_labels) & set(pred_labels))
    missed = sorted(set(true_labels) - set(pred_labels))
    extra = sorted(set(pred_labels) - set(true_labels))

    print(f"\nCorrect: {correct}")
    print(f"Missed: {missed}")
    print(f"Extra: {extra}")

    recall_text = f"{len(correct) / len(true_labels):.2%}" if true_labels else "n/a"
    precision_text = f"{len(correct) / len(pred_labels):.2%}" if pred_labels else "n/a"
    print(f"\nRecall (correct / true): {recall_text}")
    print(f"Precision (correct / predicted): {precision_text}")

# Show several examples (never more than the test set holds)
for i in range(min(5, len(X_test))):
    show_biobert_predictions(i)

## 13. Save Fine-tuned Model

In [None]:
# Write model weights/config and tokenizer files into the same directory.
FINETUNED_DIR = './biobert_finetuned'
for artifact in (model, tokenizer):
    artifact.save_pretrained(FINETUNED_DIR)

print("Model and tokenizer saved to ./biobert_finetuned/")

# Persist the test-set metric dict for later comparison notebooks.
with open('biobert_results.pkl', 'wb') as results_file:
    pickle.dump(test_results, results_file)

print("Results saved!")

## Summary

### BioBERT Performance:
- ✅ Fine-tuned BioBERT for multi-label drug indication prediction
- ✅ Compared with TF-IDF and SentenceTransformers baselines
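- ✅ Compared test-set metrics in `final_model_comparison.csv`: check the table to confirm whether BioBERT actually outperforms the baselines on this dataset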

### Key Achievements:
1. **Transfer Learning**: Leveraged pre-trained biomedical knowledge
2. **Multi-Label Classification**: Predicted multiple medical conditions per drug
3. **Domain Adaptation**: Fine-tuned on drug-specific data

### Next Steps:
- Hyperparameter tuning (learning rate, batch size, epochs) and per-label decision thresholds (currently fixed at 0.5), tuned on the validation set
- Try other models (PubMedBERT, SciBERT)
- Ensemble methods combining multiple models
- Error analysis for specific condition types