# Fine-tuning FLAN-T5 for Drug Discovery

Author: [Khairi Abidi](https://github.com/abidikhairi)

This notebook demonstrates fine-tuning FLAN-T5 on a drug discovery instruction dataset.

Key Features:

- Memory efficient: FP16 weights and gradient accumulation, sized for consumer GPUs

The model learns to recommend suitable drugs for a gene based on its neighborhood context.

## Installation and Setup

Install the required packages for finetuning with memory-efficient techniques.

In [1]:
%%capture
!pip install --quiet transformers datasets trl bitsandbytes peft trackio

In [2]:
%env CUDA_VISIBLE_DEVICES=0,1

env: CUDA_VISIBLE_DEVICES=0,1


In [4]:
%env WANDB_PROJECT=Drug-FLAN

env: WANDB_PROJECT=Drug-FLAN


## Connect to third-party services

- **WandB**: for experiment tracking.
- **HuggingFace Hub**: for uploading model checkpoints.

In [6]:
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
hf_token = user_secrets.get_secret("HUGGING_FACE_TOKEN")
wandb_token = user_secrets.get_secret("WANDB_API_KEY")
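
Outside Kaggle, `kaggle_secrets` is not available. A minimal sketch of a fallback, assuming the tokens are exported as environment variables with the same names as the Kaggle secrets; `get_secret` is a hypothetical helper and not part of the original notebook:

```python
import os


def get_secret(name, default=None):
    """Return a secret from Kaggle if possible, else from the environment.

    Hypothetical helper: assumes the environment variable has the same
    name as the Kaggle secret.
    """
    try:
        # Only importable inside a Kaggle kernel
        from kaggle_secrets import UserSecretsClient
        return UserSecretsClient().get_secret(name)
    except Exception:
        # Not on Kaggle (or the secret is not attached): use the environment
        return os.environ.get(name, default)


hf_token = get_secret("HUGGING_FACE_TOKEN")
wandb_token = get_secret("WANDB_API_KEY")
```

If neither source provides a value, both tokens are `None`, so it may be worth checking them before calling the login commands.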

In [7]:
!wandb login {wandb_token}

[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin


In [8]:
!huggingface-cli login --token {hf_token}

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `hf`CLI if you want to set the git credential as well.
Token is valid (permission: write).
The token `KAGGLE_TOKEN` has been saved to /root/.cache/huggingface/stored_tokens
Your token has been saved to /root/.cache/huggingface/token
Login successful.
The current active token is: `KAGGLE_TOKEN`


## GPU Environment Detection
Verify GPU availability and display hardware specifications for optimal training configuration.

In [9]:
import torch

# Verify CUDA availability and display GPU specifications
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"Number of GPUs: {torch.cuda.device_count()}")

if torch.cuda.is_available():
    # Display current GPU details for training optimization
    print(f"Current GPU: {torch.cuda.current_device()}")
    print(f"GPU name: {torch.cuda.get_device_name()}")
    print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    # Provide guidance for enabling GPU in Colab
    print("⚠️  No GPU available. This notebook requires a GPU for efficient training.")
    print("In Colab: Runtime → Change runtime type → Hardware accelerator → GPU")

CUDA available: False
Number of GPUs: 0
⚠️  No GPU available. This notebook requires a GPU for efficient training.
In Colab: Runtime → Change runtime type → Hardware accelerator → GPU


## Core Library Imports
Import essential libraries for finetuning, model configuration, and experiment tracking.

In [10]:
# Model and tokenization
from transformers import (
    T5ForConditionalGeneration,    # Seq2Seq language model loading
    T5Tokenizer,                   # Text tokenization
    DataCollatorForSeq2Seq,        # Batch inputs handling
)

# Model optimization
from torch.optim import AdamW
from transformers import get_scheduler

# Training and Setup
from transformers import (
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

# Dataset handling
from datasets import load_dataset

# Logging configuration
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)



In [13]:
model_name = "google/flan-t5-base"
max_seq_length = 1024

print(f"Loading model: {model_name}")
print(f"Max sequence length: {max_seq_length}")

Loading model: google/flan-t5-base
Max sequence length: 1024


In [15]:
model = T5ForConditionalGeneration.from_pretrained(
    model_name,
    trust_remote_code=True,               # Allow custom model code execution
    dtype=torch.float16,                 # Load weights in FP16 to roughly halve memory versus FP32
)

In [17]:
tokenizer = T5Tokenizer.from_pretrained(model_name, trust_remote_code=True)

# Ensure tokenizer has proper padding token for batch processing
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

In [20]:
print(f"✅ Model loaded successfully!")
print(f"📊 Model parameters: ~{sum(p.numel() for p in model.parameters()) / 1e6:.1f}M")
print(f"🧮 Quantized parameters: ~{sum(p.numel() for p in model.parameters() if hasattr(p, 'quant_type')) / 1e6:.1f}M")

✅ Model loaded successfully!
📊 Model parameters: ~247.6M
🧮 Quantized parameters: ~0.0M


In [21]:
def compute_model_size(model):
    # Sum the bytes used by all parameters and buffers
    n_bytes = 0
    for p in model.parameters():
        n_bytes += p.nelement() * p.element_size()
    for b in model.buffers():
        n_bytes += b.nelement() * b.element_size()

    # Convert bytes to GiB
    return n_bytes / (1024 ** 3)

print(f"📊 Model size : {compute_model_size(model):.2f} GB")

📊 Model size : 0.53 GB
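
The weight footprint is only part of what full fine-tuning needs. As a rough back-of-the-envelope sketch, assuming the gradients and both AdamW moment buffers are stored in the same dtype as the weights and ignoring activations and framework overhead (`estimate_training_memory_gib` is a hypothetical helper):

```python
def estimate_training_memory_gib(n_params, bytes_per_param=2):
    """Rough lower bound on memory for full fine-tuning with AdamW.

    Assumes weights, gradients, and both AdamW moment buffers all use the
    same dtype; activations and framework overhead are not counted.
    """
    weights = n_params * bytes_per_param
    gradients = n_params * bytes_per_param
    adam_moments = 2 * n_params * bytes_per_param  # first and second moments
    total_bytes = weights + gradients + adam_moments
    return total_bytes / (1024 ** 3)


n_params = 247.6e6  # parameter count reported above
print(f"FP16 estimate: {estimate_training_memory_gib(n_params, 2):.2f} GiB")
print(f"FP32 estimate: {estimate_training_memory_gib(n_params, 4):.2f} GiB")
```

Under these assumptions the total is four times the weight footprint. Activation memory grows with batch size and sequence length and comes on top of that.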


## Dataset Setup
Load and tokenize the drug discovery instruction dataset.

In [32]:
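def tokenize_dataset_example(examples):
    # Tokenize a batch of prompts and targets; padding=True pads to the longest item in the batch
    model_inputs = tokenizer(examples['input'], return_tensors='pt', padding=True)
    labels = tokenizer(examples['target'], return_tensors='pt', padding=True)

    # Replace padding ids in the labels with -100 so the loss ignores those positions
    label_ids = labels['input_ids']
    label_ids[label_ids == tokenizer.pad_token_id] = -100

    model_inputs['labels'] = label_ids
    return model_inputs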

print("✅ Dataset tokenization function defined (batch mode)")

✅ Dataset tokenization function defined (batch mode)
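
When label sequences are padded, the padded positions are usually set to `-100`, the value the cross-entropy loss in `transformers` skips. A minimal sketch on plain Python lists, assuming a pad token id of 0 (T5's default; in practice read `tokenizer.pad_token_id`). `mask_label_padding` is a hypothetical helper and the token ids are made up:

```python
IGNORE_INDEX = -100  # label value skipped by the cross-entropy loss in transformers
PAD_TOKEN_ID = 0     # assumed pad token id; read tokenizer.pad_token_id in practice


def mask_label_padding(label_ids, pad_token_id=PAD_TOKEN_ID):
    """Replace padding ids in each label sequence with IGNORE_INDEX."""
    return [
        [IGNORE_INDEX if tok == pad_token_id else tok for tok in seq]
        for seq in label_ids
    ]


# Two toy label sequences padded to the same length (token ids are made up)
batch = [[71, 1712, 1, 0, 0], [3, 9, 12, 7, 1]]
masked = mask_label_padding(batch)
print(masked)
```

Without this masking, the model is also trained to predict padding tokens, which inflates the apparent accuracy on short targets.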


In [33]:
def filter_dataset_example(example):
    # Keep examples whose tokenized prompt is at most 512 tokens long
    return len(example['input_ids']) <= 512

print("✅ Dataset filter functions defined")

✅ Dataset filter functions defined


In [35]:
# Load and preprocess DrugInstruct training dataset
print("🔄 Loading DrugInstruct dataset...")
dataset = load_dataset("khairi/drug-discovery-hetionet")

# Tokenize all examples in batches of 32
dataset = dataset.map(tokenize_dataset_example, batched=True, batch_size=32)

train_data = dataset['train']
valid_data = dataset['validation']

print(f"✅ Dataset loaded and processed!")
print(f"📊 Training examples: {len(train_data):,}")
print(f"📊 Validation examples: {len(valid_data):,}")
print(f"🎯 Sample protein: {tokenizer.decode(train_data[0]['input_ids'])}")

🔄 Loading DrugInstruct dataset...
✅ Dataset loaded and processed!
📊 Training examples: 1,589
📊 Validation examples: 100
🎯 Sample protein: Predict drugs for the central gene based on its neighborhood: Central node: NAT2 One-hop neighbors: <unk> process, cellular response to xenobiotic stimulus>, <unk> process, xenobiotic metabolic process>, <unk> process, response to xenobiotic stimulus> <unk> molecular function, acetyltransferase activity>, <unk> molecular function, N-acyltransferase activity>, <unk> molecular function, transferase activity, transferring acyl groups>, <unk> molecular function, N-acetyltransferase activity>, <unk> molecular function, arylamine N-acetyltransferase activity> Answer:</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><

## Training Setup
Configure training parameters optimized for finetuning with memory constraints.

In [37]:
# Prepare data collator
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

In [43]:
# Configure Finetuning training parameters for sequence modeling
training_args = Seq2SeqTrainingArguments(
    # Memory-efficient batch configuration
    per_device_train_batch_size=8,   # Small batch for GPU memory constraints
    gradient_accumulation_steps=8,   # Effective batch size = 8 * 8 = 64 per device
       
    # Training duration and monitoring
    max_steps=50,                     # Short demo run (increase to 500+ for production)
    logging_steps=5,                  # Log metrics every 5 steps for close monitoring
    save_steps=5,
    eval_steps=5,
    eval_strategy='steps',
    # num_train_epochs=1,            # Uncomment for production
    
    # Stability and output configuration
    output_dir="./finetuning_outputs",
    max_grad_norm=0.1,               # Aggressive gradient clipping for stable training
    report_to="none",                # Set to "wandb" to enable experiment tracking
    run_name='flan-base-hetionet-instruct',

    # Push to Hub, uncomment in production
    #push_to_hub=True,
    #hub_model_id='khairi/Shizuku-0.5B'
    
)
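
To see what these settings mean for this dataset, the arithmetic can be sketched as follows. This assumes a single device and the 1,589 training examples reported earlier; `training_budget` is a hypothetical helper, and the steps-per-epoch figure is approximate because the trainer's exact rounding may differ:

```python
import math


def training_budget(n_examples, per_device_batch, grad_accum, max_steps, n_devices=1):
    """Return effective batch size, approximate steps per epoch, and epochs covered."""
    effective_batch = per_device_batch * grad_accum * n_devices
    steps_per_epoch = math.ceil(n_examples / effective_batch)
    epochs_covered = max_steps * effective_batch / n_examples
    return effective_batch, steps_per_epoch, epochs_covered


# Values from this notebook: 1,589 training examples, batch 8, accumulation 8, 50 steps
effective, per_epoch, epochs = training_budget(1589, 8, 8, 50)
print(f"Effective batch size: {effective}")
print(f"Optimizer steps per epoch (approx.): {per_epoch}")
print(f"Epochs covered by max_steps: {epochs:.2f}")
```

With two visible GPUs, pass `n_devices=2`; the effective batch size doubles and the number of steps per epoch halves.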

### Optimizer and Scheduler Setup
Configure the **AdamW** optimizer over all model parameters.

Then set up a cosine learning rate scheduler with a short linear warmup followed by a slow cosine cooldown over the total number of training steps.

In [46]:
print("🚀 Initializing optimizer...")
print("🌙 Setting up cosine LR scheduler...")

optimizer = AdamW([
        {'params': model.parameters(), 'lr': 3e-4},
    ],
    betas=(0.99, 0.98),
    weight_decay=0.01
)

lr_scheduler = get_scheduler(
    name='cosine',
    optimizer=optimizer,
    num_warmup_steps=50,
    num_training_steps=1000
)

print("✅ Optimizer ready!")
print("✨ LR scheduler ready!")

🚀 Initializing optimizer...
🌙 Setting up cosine LR scheduler...
✅ Optimizer ready!
✨ LR scheduler ready!
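
To make the schedule concrete, here is a plain-Python sketch of the multiplier that a cosine schedule with linear warmup applies to the base learning rate. It is a re-implementation for illustration, written to mirror the usual half-cosine formula, and is not the library's own code; `cosine_lr_factor` is a hypothetical name:

```python
import math


def cosine_lr_factor(step, num_warmup_steps, num_training_steps, num_cycles=0.5):
    """Multiplier applied to the base learning rate at a given optimizer step."""
    if step < num_warmup_steps:
        # Linear warmup from 0 to 1
        return step / max(1, num_warmup_steps)
    progress = (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
    # Cosine decay from 1 to 0 over the remaining steps
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress)))


base_lr = 3e-4
for step in (0, 10, 25, 50, 500, 1000):
    lr = base_lr * cosine_lr_factor(step, num_warmup_steps=50, num_training_steps=1000)
    print(f"step {step:>4}: lr = {lr:.2e}")
```

With `max_steps=50` in the training arguments, the run stops at the end of warmup, so the cosine decay phase is never reached unless `max_steps` is raised or the scheduler's step counts are lowered.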


In [47]:
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=valid_data,
    tokenizer=tokenizer,
    data_collator=data_collator,
    optimizers=(optimizer, lr_scheduler)
)



In [None]:
# Execute FineTuning
print("🚀 Starting Finetuning...")

# Run the training process
trainer.train()

print("✅ Training completed successfully!")
print(f"💾 Model saved to: {training_args.output_dir}")

🚀 Starting Finetuning...
