# Fine tuning FLAN-T5 on Drug Discovery

Author: [Khairi Abidi](https://github.com/abidikhairi)

This notebook demonstrates fine-tuning FLAN-T5 on a drug-discovery instruction dataset.

Key Features:

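- Memory Efficient: intended to use LoRA so training fits on consumer GPUs (note: the cells below currently fine-tune all model weights; no LoRA adapter is applied yet)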

The model learns to recommend suitable drugs given a gene context.

In [None]:
# Expose only GPUs 0 and 1 to this kernel; this must run before CUDA is initialized (torch is imported further down)
%env CUDA_VISIBLE_DEVICES=0,1

In [None]:
# Weights & Biases project name under which runs from this notebook are grouped
%env WANDB_PROJECT=Drug-FLAN

## Connect to 3rd party services

- **WandB**: for experiment tracking.
- **Hugging Face Hub**: for uploading model checkpoints.

In [None]:
from kaggle_secrets import UserSecretsClient

# Pull the Hugging Face and W&B credentials from Kaggle's secret store instead of hard-coding them
user_secrets = UserSecretsClient()
hf_token, wandb_token = (
    user_secrets.get_secret(secret_name)
    for secret_name in ("HUGGING_FACE_TOKEN", "WANDB_API_KEY")
)

In [None]:
import os

# Hand the API key to the wandb client through the kernel's environment rather than a
# `!wandb login <key>` shell command, where the key sits on a command line visible to other processes.
if not wandb_token:
    raise ValueError("WANDB_API_KEY secret is empty or missing")
os.environ["WANDB_API_KEY"] = wandb_token

In [None]:
import os

# Make the Hub token available to huggingface_hub (used when pushing checkpoints) via the
# kernel's environment instead of passing it on a shell command line.
if not hf_token:
    raise ValueError("HUGGING_FACE_TOKEN secret is empty or missing")
os.environ["HF_TOKEN"] = hf_token

## GPU Environment Detection
Verify GPU availability and display hardware specifications for optimal training configuration.

In [None]:
import torch

# Verify CUDA availability and report every visible GPU (CUDA_VISIBLE_DEVICES may expose several)
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"Number of GPUs: {torch.cuda.device_count()}")

if torch.cuda.is_available():
    print(f"Current GPU: {torch.cuda.current_device()}")
    # Query name and memory from the same device index so the two lines always agree
    for device_idx in range(torch.cuda.device_count()):
        props = torch.cuda.get_device_properties(device_idx)
        print(f"GPU {device_idx} name: {props.name}")
        print(f"GPU {device_idx} memory: {props.total_memory / 1e9:.1f} GB")
else:
    # Provide guidance for enabling a GPU accelerator
    print("⚠️  No GPU available. This notebook requires a GPU for efficient training.")
    print("In Colab: Runtime → Change runtime type → Hardware accelerator → GPU")

## Core Library Imports
Import essential libraries for finetuning, model configuration, and experiment tracking.

In [None]:
# Model and tokenization
from transformers import (
    T5ForConditionalGeneration,    # Seq2Seq language model loading
    T5Tokenizer,                   # Text tokenization
    DataCollatorForSeq2Seq,        # Batch inputs handling
)

# Model optimization
from torch.optim import AdamW
from transformers import get_scheduler

# Training and Setup
from transformers import (
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

# Dataset handling
from datasets import load_dataset

# Logging configuration
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

In [None]:
model_name = "google/flan-t5-base"  # Hub checkpoint to fine-tune
max_seq_length = 512                # maximum sequence length, in tokens

print("Loading model: " + model_name)
print("Max sequence length: " + str(max_seq_length))

In [None]:
# T5 is a built-in Transformers architecture, so no Hub-hosted code needs to be trusted/executed.
# No dtype override is passed, so the library's default loading dtype applies.
model = T5ForConditionalGeneration.from_pretrained(model_name)

In [None]:
# Built-in T5 tokenizer class: no remote code needs to be trusted
tokenizer = T5Tokenizer.from_pretrained(model_name)

# Batch padding requires a pad token; fall back to EOS only if the checkpoint defines none
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

In [None]:
# Summarize the loaded model: total parameter count, plus parameters exposing a `quant_type`
# attribute (presumably quantized weights — none are expected, as no quantization config is passed)
total_params = sum(p.numel() for p in model.parameters())
quantized_params = sum(p.numel() for p in model.parameters() if hasattr(p, 'quant_type'))

print("✅ Model loaded successfully!")
print(f"📊 Model parameters: ~{total_params / 1e6:.1f}M")
print(f"🧮 Quantized parameters: ~{quantized_params / 1e6:.1f}M")

In [None]:
def compute_model_size(model):
    """Return the memory footprint of ``model``'s parameters and buffers in GiB.

    Each tensor contributes ``nelement() * element_size()`` bytes; tensors are taken
    from ``model.parameters()`` and ``model.buffers()``. The byte total is divided
    by ``1024 ** 3``.
    """
    param_bytes = sum(t.nelement() * t.element_size() for t in model.parameters())
    buffer_bytes = sum(t.nelement() * t.element_size() for t in model.buffers())
    return (param_bytes + buffer_bytes) / (1024 ** 3)

model_size_gb = compute_model_size(model)
print(f"📊 Model size : {model_size_gb:.2f} GB")

## Dataset Setup
Load and tokenize the drug-discovery instruction dataset.

In [None]:
def tokenize_dataset_example(examples, tok=None, max_length=512, max_target_length=512):
    """Tokenize a batch of instruction/answer pairs for seq2seq training.

    Intended for ``Dataset.map(..., batched=True)``: ``examples`` maps column names
    to lists, and the ``'input'`` and ``'target'`` columns are tokenized.

    Sequences are truncated but deliberately NOT padded here. Padding is left to the
    batch collator (``DataCollatorForSeq2Seq`` pads labels with its
    ``label_pad_token_id``, -100 by default, which the loss ignores). Padding labels
    here would fill them with the real pad-token id, and the model would be trained
    to predict padding.

    Args:
        examples: batch dict with ``'input'`` and ``'target'`` lists of strings.
        tok: tokenizer to use; defaults to the notebook-level ``tokenizer``.
        max_length: maximum number of tokens kept per input (512 = ``max_seq_length``).
        max_target_length: maximum number of tokens kept per target.

    Returns:
        The tokenizer output for the inputs (``input_ids``, ``attention_mask``, ...)
        with an added ``'labels'`` entry holding the target token ids.
    """
    tok = tokenizer if tok is None else tok
    model_inputs = tok(examples['input'], truncation=True, max_length=max_length)
    targets = tok(examples['target'], truncation=True, max_length=max_target_length)

    model_inputs['labels'] = targets['input_ids']
    return model_inputs

print("✅ Dataset tokenization function defined (batch mode)")

In [None]:
def filter_dataset_example(example, max_length=512):
    """Return True when a tokenized example's ``input_ids`` are shorter than ``max_length``.

    Meant for ``Dataset.filter`` on single (unbatched) examples; the default of 512
    matches the notebook's ``max_seq_length``.
    """
    return len(example['input_ids']) < max_length

print("✅ Dataset filter functions defined")

In [None]:
# Load the drug-discovery instruction dataset from the Hugging Face Hub
print("🔄 Loading DrugInstruct dataset...")
dataset = load_dataset("khairi/drug-discovery-hetionet")

# Tokenize every split; the raw 'input'/'target' text columns are kept alongside the token ids
dataset = dataset.map(tokenize_dataset_example, batched=True, batch_size=4)

train_data = dataset['train']
valid_data = dataset['validation']

print(f"✅ Dataset loaded and processed!")
print(f"📊 Training examples: {len(train_data):,}")
print(f"📊 Validation examples: {len(valid_data):,}")
# The decoded text is the instruction prompt; drop <pad>/</s> markers for readability
print(f"🎯 Sample input: {tokenizer.decode(train_data[0]['input_ids'], skip_special_tokens=True)}")

## Training Setup
Configure training parameters optimized for finetuning with memory constraints.

In [None]:
# Inspect the tokenized training split
train_data

In [None]:
# Inspect the tokenized validation split
valid_data

In [None]:
# Prepare data collator
# Pad each batch dynamically; label positions are padded with -100 (the collator's default) so the loss skips them
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, label_pad_token_id=-100)

In [None]:
# Configure fine-tuning hyperparameters for sequence-to-sequence training
training_args = Seq2SeqTrainingArguments(
    # Memory-efficient batch configuration
    per_device_train_batch_size=8,   # Small per-GPU batch for memory constraints
    gradient_accumulation_steps=4,   # Effective batch = 8 * 4 = 32 per GPU (x2 if both GPUs from CUDA_VISIBLE_DEVICES=0,1 are used)

    # Training duration and monitoring
    num_train_epochs=5,
    logging_steps=100,               # Log training metrics every 100 optimizer steps
    eval_strategy='steps',
    eval_steps=100,                  # Evaluate on the validation split every 100 steps
    save_steps=100,                  # Write a checkpoint every 100 steps...
    save_total_limit=3,              # ...keeping at most 3 checkpoints on disk

    # Stability and output configuration
    output_dir="./finetuning_outputs",
    max_grad_norm=0.1,               # Aggressive gradient clipping for stable training
    report_to="wandb",               # Track the run in Weights & Biases (project set via WANDB_PROJECT)
    run_name='flan-base-hetionet-instruct',

    # Upload checkpoints to the Hugging Face Hub
    push_to_hub=True,
    # NOTE(review): this repo name does not match the run (flan-t5-base, not a 0.5B model);
    # confirm it is the intended destination before pushing.
    hub_model_id='khairi/Shizuku-0.5B',
)

### Optimizer and Scheduler Setup
Configure the **AdamW** optimizer over all model parameters.

Then set up a cosine learning-rate scheduler with a short warmup (fast warmup, slow cooldown) over the total number of training steps.

In [None]:
import math

print("🚀 Initializing optimizer...")
print("🌙 Setting up cosine LR scheduler...")

# AdamW over all model weights; betas use the common Transformer setting (0.9, 0.98)
optimizer = AdamW(
    [{'params': model.parameters(), 'lr': 1e-4}],
    betas=(0.9, 0.98),
    weight_decay=0.01,
)

# The cosine schedule must span the whole run. A hard-coded step count shorter than the
# run makes the decay finish early, and past `num_training_steps` the schedule no longer
# decays as intended. Derive the count from the data and the training arguments instead.
if training_args.max_steps > 0:
    num_training_steps = training_args.max_steps
else:
    # `train_batch_size` is the per-device batch times the number of GPUs used by the Trainer.
    # Accumulation is rounded up so the estimate never falls short of the Trainer's own count.
    batches_per_epoch = math.ceil(len(train_data) / training_args.train_batch_size)
    updates_per_epoch = max(math.ceil(batches_per_epoch / training_args.gradient_accumulation_steps), 1)
    num_training_steps = math.ceil(training_args.num_train_epochs * updates_per_epoch)

lr_scheduler = get_scheduler(
    name='cosine',
    optimizer=optimizer,
    num_warmup_steps=10,
    num_training_steps=num_training_steps,
)

print("✅ Optimizer ready!")
print("✨ LR scheduler ready!")

In [None]:
# Wire model, data, collator and the custom (optimizer, scheduler) pair into the trainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    tokenizer=tokenizer,
    train_dataset=train_data,
    eval_dataset=valid_data,
    optimizers=(optimizer, lr_scheduler),
)

In [None]:
# Execute fine-tuning
print("🚀 Starting Finetuning...")

# Run the training loop (intermediate checkpoints go to output_dir every save_steps)
trainer.train()

# train() only writes step checkpoints; save the final weights (and tokenizer) to output_dir
# explicitly so the message below is true. With push_to_hub enabled this also uploads them.
trainer.save_model()

print("✅ Training completed successfully!")
print(f"💾 Model saved to: {training_args.output_dir}")

In [None]:
# Take the first validation example for a qualitative check of the fine-tuned model
row = valid_data[0]

In [None]:
# List the example's fields (raw 'input'/'target' text plus the tokenized columns)
row.keys()

In [None]:
# Tokenize the prompt, truncated to max_seq_length tokens, and move the tensors to the model's device
inputs = tokenizer(row['input'], return_tensors='pt', truncation=True, max_length=max_seq_length)

inputs = {k: v.to(model.device) for k, v in inputs.items()}

In [None]:
# eval() disables dropout for inference. Without an explicit limit, generate() falls back to the
# default generation length (max_length=20) and cuts answers short; generation still stops at EOS.
model.eval()
output = model.generate(**inputs, max_new_tokens=max_seq_length)

In [None]:
# Drop <pad>/</s> markers so the prediction is directly comparable to the reference target
print(tokenizer.decode(output[0], skip_special_tokens=True))

In [None]:
# Reference answer from the dataset, for comparison with the generated prediction
print(row['target'])