# Fine-tuning FLAN-T5 for Drug Discovery

- Author: [Khairi Abidi](https://github.com/abidikhairi)
- Author: [Ranim Mechergui](https://github.com/ranim.mechergui)

This notebook demonstrates fine-tuning FLAN-T5 on a drug discovery instruction dataset.

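Key features:

- Memory efficient: full fine-tuning of the ~248M-parameter `flan-t5-base` fits on a single 16 GB GPU (Tesla T4), so no LoRA or quantization is applied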

The model learns to recommend candidate drugs for a gene based on its neighborhood in the Hetionet knowledge graph.

In [1]:
%env CUDA_VISIBLE_DEVICES=0,1

env: CUDA_VISIBLE_DEVICES=0,1


In [2]:
%env WANDB_PROJECT=Drug-FLAN

env: WANDB_PROJECT=Drug-FLAN


## Connect to Third-Party Services

- **WandB**: experiment tracking.
- **Hugging Face Hub**: uploading model checkpoints.

In [3]:
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
hf_token = user_secrets.get_secret("HUGGING_FACE_TOKEN")
wandb_token = user_secrets.get_secret("WANDB_API_KEY")

In [4]:
!wandb login {wandb_token}

[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin


In [5]:
!huggingface-cli login --token {hf_token}

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: write).
The token `KAGGLE_TOKEN` has been saved to /root/.cache/huggingface/stored_tokens
Your token has been saved to /root/.cache/huggingface/token
Login successful.
The current active token is: `KAGGLE_TOKEN`


## GPU Environment Detection
Verify GPU availability and display hardware specifications for optimal training configuration.

In [6]:
import torch

# Verify CUDA availability and display GPU specifications
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"Number of GPUs: {torch.cuda.device_count()}")

if torch.cuda.is_available():
    # Display current GPU details for training optimization
    print(f"Current GPU: {torch.cuda.current_device()}")
    print(f"GPU name: {torch.cuda.get_device_name()}")
    print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    # Provide guidance for enabling GPU in Colab
    print("⚠️  No GPU available. This notebook requires a GPU for efficient training.")
    print("In Colab: Runtime → Change runtime type → Hardware accelerator → GPU")

CUDA available: True
Number of GPUs: 2
Current GPU: 0
GPU name: Tesla T4
GPU memory: 15.8 GB


## Core Library Imports
Import the libraries needed for fine-tuning, model configuration, and experiment tracking.

In [7]:
# Model and tokenization
from transformers import (
    T5ForConditionalGeneration,    # Seq2Seq language model loading
    T5Tokenizer,                   # Text tokenization
    DataCollatorForSeq2Seq,        # Batch inputs handling
)

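# Optimizer and scheduler utilities (not used below: the Trainer builds
# AdamW and the cosine schedule from `optim` and `lr_scheduler_type`)
from torch.optim import AdamW
from transformers import get_scheduler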

# Training and Setup
from transformers import (
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

# Dataset handling
from datasets import load_dataset

# Logging configuration
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)



In [8]:
model_name = "google/flan-t5-base"
max_seq_length = 512

print(f"Loading model: {model_name}")
print(f"Max sequence length: {max_seq_length}")

Loading model: google/flan-t5-base
Max sequence length: 512


In [9]:
model = T5ForConditionalGeneration.from_pretrained(
    model_name,
    trust_remote_code=True,               # Not required for T5 (built into transformers); harmless here
    # dtype=torch.float16,                # Optional FP16 weights; left disabled, so training runs in FP32
)


In [10]:
tokenizer = T5Tokenizer.from_pretrained(model_name, trust_remote_code=True)

# T5 already defines <pad> (id 0); this fallback only matters for tokenizers without one
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token


You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [11]:
print(f"✅ Model loaded successfully!")
print(f"📊 Model parameters: ~{sum(p.numel() for p in model.parameters()) / 1e6:.1f}M")
print(f"🧮 Quantized parameters: ~{sum(p.numel() for p in model.parameters() if hasattr(p, 'quant_type')) / 1e6:.1f}M")

✅ Model loaded successfully!
📊 Model parameters: ~247.6M
🧮 Quantized parameters: ~0.0M


In [12]:
def compute_model_size(model):
    """Return the memory taken by the model's parameters and buffers, in GiB."""
    n_bytes = 0
    for p in model.parameters():
        n_bytes += p.nelement() * p.element_size()
    for b in model.buffers():
        n_bytes += b.nelement() * b.element_size()

    return n_bytes / (1024 ** 3)  # bytes → GiB

print(f"📊 Model size : {compute_model_size(model):.2f} GB")

📊 Model size : 0.92 GB
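The reported size follows from the parameter count: each FP32 parameter takes 4 bytes, and `compute_model_size` divides by 1024³. A rough sketch, assuming plain FP32 training with AdamW (weights, gradients, and Adam's two moment buffers, i.e. about 16 bytes per parameter) and ignoring activations and CUDA overhead, estimates the static training footprint:

```python
# Rough static memory estimate for full FP32 fine-tuning with AdamW.
# Assumption: 4 bytes each for weights, gradients, and Adam's two moment
# buffers (exp_avg, exp_avg_sq); activations and CUDA overhead are ignored.
BYTES_PER_FP32 = 4
GIB = 1024 ** 3


def weights_gib(n_params: float) -> float:
    """Memory for the FP32 weights alone, in GiB."""
    return n_params * BYTES_PER_FP32 / GIB


def adamw_training_gib(n_params: float) -> float:
    """Weights + gradients + two AdamW moments (4 copies of the parameters), in GiB."""
    return 4 * weights_gib(n_params)


n_params = 247.6e6  # parameter count printed above for google/flan-t5-base
print(f"weights:  {weights_gib(n_params):.2f} GiB")         # 0.92 GiB
print(f"training: {adamw_training_gib(n_params):.2f} GiB")  # 3.69 GiB, before activations
```

Roughly 3.7 GiB of static state leaves room for activations at a per-device batch of 4 on a 15.8 GB T4, which is consistent with running full fine-tuning here rather than a parameter-efficient method.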


## Dataset Setup
Load and tokenize the drug discovery instruction dataset.

In [13]:
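def tokenize_dataset_example(examples):
    # Batched tokenization: prompts become encoder inputs, targets become labels.
    # padding=True pads each map batch to its longest sequence.
    model_inputs = tokenizer(examples['input'], return_tensors='pt', padding=True)
    labels = tokenizer(examples['target'], return_tensors='pt', padding=True)

    model_inputs['labels'] = labels['input_ids']
    return model_inputs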

print("✅ Dataset tokenization function defined (batch mode)")

✅ Dataset tokenization function defined (batch mode)
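A caveat on this function: because it pads inside `map` (`padding=True`), the label sequences already contain T5's pad id (0) when they reach `DataCollatorForSeq2Seq`. The collator only writes its `label_pad_token_id` (-100 by default) into the padding it adds itself, so the pre-padded zeros stay in `labels` and are counted by the cross-entropy loss. One fix is to tokenize without padding (e.g. `tokenizer(examples['input'], max_length=max_seq_length, truncation=True)` and `tokenizer(text_target=examples['target'], max_length=max_seq_length, truncation=True)`) and let the collator pad each batch; another is to mask label padding explicitly. The plain-Python sketch below shows the masking step, assuming T5's pad id of 0; `mask_label_padding` is a hypothetical helper, not a library function.

```python
from typing import List

IGNORE_INDEX = -100  # ignore_index of the cross-entropy loss in T5ForConditionalGeneration


def mask_label_padding(labels: List[List[int]], pad_token_id: int = 0) -> List[List[int]]:
    """Replace pad ids in label sequences with IGNORE_INDEX so the loss skips them.

    Assumes pad_token_id never occurs as a real target token, which holds for
    T5 labels (id 0 is <pad>; the decoder start token is prepended by the model,
    not stored in the labels).
    """
    return [[IGNORE_INDEX if tok == pad_token_id else tok for tok in seq] for seq in labels]


# Toy label ids for illustration: 1 is T5's </s>, 0 is <pad>.
batch = [[71, 1000, 1], [71, 1, 0]]
print(mask_label_padding(batch))  # [[71, 1000, 1], [71, 1, -100]]
```

In the notebook this would be applied to `labels['input_ids'].tolist()` before assigning `model_inputs['labels']`; the no-padding variant is usually preferable because it also makes `max_seq_length` take effect.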


In [14]:
def filter_dataset_example(example):
    # Keep examples whose input fits within the model's maximum sequence length
    return len(example['input_ids']) <= max_seq_length

print("✅ Dataset filter function defined")

✅ Dataset filter function defined
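Because the inputs are padded during `map`, `len(example['input_ids'])` measures the padded length of the map batch a row came from, not the prompt itself. A sketch of a length check that counts real tokens through the attention mask instead, assuming rows shaped like the tokenized dataset (`input_ids` plus `attention_mask`); `within_length` is a hypothetical name:

```python
from typing import Dict, List


def within_length(example: Dict[str, List[int]], max_len: int = 512) -> bool:
    """True if the unpadded prompt fits in max_len tokens.

    The attention mask is 1 for real tokens and 0 for padding, so its sum is the
    true sequence length even when input_ids were padded in advance.
    """
    return sum(example["attention_mask"]) <= max_len


row = {"input_ids": [37, 12, 1, 0, 0], "attention_mask": [1, 1, 1, 0, 0]}
print(len(row["input_ids"]), sum(row["attention_mask"]))  # 5 3
print(within_length(row, max_len=4))                       # True
```

With `datasets`, this would plug in as `dataset.filter(lambda ex: within_length(ex, max_seq_length))`.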


In [15]:
# Load the Hetionet-based drug discovery instruction dataset from the Hugging Face Hub
print("🔄 Loading DrugInstruct dataset...")
dataset = load_dataset("khairi/drug-discovery-hetionet")

# Tokenize every split (train and validation) in batches of 4 examples
dataset = dataset.map(tokenize_dataset_example, batched=True, batch_size=4)
# Optional length filter (disabled in this run):
# dataset = dataset.filter(filter_dataset_example)

train_data = dataset['train']
valid_data = dataset['validation']

print(f"✅ Dataset loaded and processed!")
print(f"📊 Training examples: {len(train_data):,}")
print(f"📊 Validation examples: {len(valid_data):,}")
print(f"🎯 Sample input: {tokenizer.decode(train_data[0]['input_ids'])}")
print(f"🎯 Sample output: {tokenizer.decode(train_data[0]['labels'])}")

🔄 Loading DrugInstruct dataset...



✅ Dataset loaded and processed!
📊 Training examples: 1,416
📊 Validation examples: 84
🎯 Sample input: Predict drugs for the central gene based on its neighborhood: Central node: NAT2 One-hop neighbors: [process, Biological Process::GO:0071466], [process, Biological Process::GO:0009410], [process, Biological Process::GO:0006805] [molecular function, Molecular Function::GO:0004060], [molecular function, Molecular Function::GO:0008080], [molecular function, Molecular Function::GO:0016746], [molecular function, Molecular Function::GO:0016407], [molecular function, Molecular Function::GO:0016747] Answer:</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><
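The decoded sample shows the prompt template: an instruction, the central gene, and its one-hop Hetionet neighbors written as `[relation, Node type::identifier]` pairs, ending with `Answer:`; the target is a dash-prefixed list of drug names. The sketch below rebuilds a prompt in that shape. It is inferred from this single decoded example, so the exact grouping and spacing are assumptions (the sketch joins everything with spaces, matching the decoded text; the raw dataset strings may use other whitespace), and `build_prompt` is a hypothetical helper, not part of the dataset's tooling.

```python
from typing import Dict, List


def build_prompt(gene: str, neighbors: Dict[str, List[str]]) -> str:
    """Format a gene and its one-hop neighbors like the decoded training prompt.

    neighbors maps a relation label (e.g. "process") to node identifiers
    (e.g. "Biological Process::GO:0071466"). Hypothetical helper inferred from
    one decoded example.
    """
    groups = [
        ", ".join(f"[{relation}, {node}]" for node in nodes)
        for relation, nodes in neighbors.items()
    ]
    return (
        "Predict drugs for the central gene based on its neighborhood: "
        f"Central node: {gene} One-hop neighbors: {' '.join(groups)} Answer:"
    )


prompt = build_prompt(
    "NAT2",
    {
        "process": ["Biological Process::GO:0071466", "Biological Process::GO:0009410"],
        "molecular function": ["Molecular Function::GO:0004060"],
    },
)
print(prompt)
```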

## Training Setup
Configure the data collator and training arguments for fine-tuning under GPU memory constraints.

In [16]:
train_data

Dataset({
    features: ['input', 'target', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 1416
})

In [17]:
valid_data

Dataset({
    features: ['input', 'target', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 84
})

In [18]:
# Prepare data collator
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

In [19]:
# Configure fine-tuning hyperparameters for sequence-to-sequence training
training_args = Seq2SeqTrainingArguments(
    # Memory-efficient batch configuration
    per_device_train_batch_size=4,   # Small per-GPU batch to fit T4 memory
    gradient_accumulation_steps=4,   # Effective batch = 4 per GPU × 2 GPUs × 4 steps = 32

    # Training duration and monitoring
    # max_steps=1000,                # Alternative to num_train_epochs for a fixed-length run
    num_train_epochs=10,
    logging_steps=50,                # Log training loss every 50 optimizer steps
    eval_strategy='steps',           # Evaluate on the validation split...
    eval_steps=50,                   # ...every 50 steps
    save_steps=50,                   # Write a checkpoint every 50 steps
    save_total_limit=3,              # Keep only the 3 most recent checkpoints

    # Optimizer and learning-rate schedule
    optim='adamw_torch',
    learning_rate=1e-5,
    weight_decay=0.01,
    lr_scheduler_type='cosine',
    warmup_ratio=0.2,                # Linear warmup over the first 20% of steps

    # Stability and output configuration
    output_dir="./finetuning_outputs",
    max_grad_norm=0.1,               # Aggressive gradient clipping for stable training
    report_to="wandb",               # Use W&B for experiment tracking
    run_name='flan-base-hetionet-instruct',

    # Push checkpoints to the Hugging Face Hub (set to False for local-only runs)
    push_to_hub=True,
    hub_model_id='khairi/Shizuku-247M',
)
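The effective batch size depends on how many GPUs the Trainer sees. With `CUDA_VISIBLE_DEVICES=0,1` and a plain single-process launch (no `torchrun`), `Trainer` wraps the model in `torch.nn.DataParallel` and multiplies `per_device_train_batch_size` by the GPU count. The sketch below works through the resulting schedule for the 1,416 training examples. It assumes a trailing partial gradient-accumulation window still counts as an optimizer step (ceiling division); depending on the `transformers` version it may be floored instead, which would give 440 steps. Under the ceiling assumption the total is consistent with the final step (450) logged during training. `training_schedule` is a hypothetical helper.

```python
import math


def training_schedule(n_examples, per_device_bs, n_gpus, grad_accum, epochs, warmup_ratio):
    """Approximate Trainer step counts for single-process (DataParallel) training.

    Assumption: a trailing partial gradient-accumulation window still yields an
    optimizer step (ceiling division); some transformers versions floor instead.
    """
    step_batch = per_device_bs * max(1, n_gpus)             # examples per forward pass
    batches_per_epoch = math.ceil(n_examples / step_batch)   # dataloader length, no drop_last
    updates_per_epoch = math.ceil(batches_per_epoch / grad_accum)
    total_steps = updates_per_epoch * epochs
    return {
        "effective_batch": step_batch * grad_accum,
        "updates_per_epoch": updates_per_epoch,
        "total_steps": total_steps,
        "warmup_steps": math.ceil(total_steps * warmup_ratio),
    }


print(training_schedule(1416, per_device_bs=4, n_gpus=2, grad_accum=4, epochs=10, warmup_ratio=0.2))
# {'effective_batch': 32, 'updates_per_epoch': 45, 'total_steps': 450, 'warmup_steps': 90}
```

On a single GPU the same settings would give an effective batch of 16 and roughly 890 optimizer steps, so the visible device count changes both the batch size and the length of the cosine schedule.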

In [20]:
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=valid_data,
    tokenizer=tokenizer,
    data_collator=data_collator
)



In [21]:
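# Run fine-tuning
print("🚀 Starting Finetuning...")

# Train; checkpoints are written to output_dir (and pushed to the Hub) every save_steps
trainer.train()

# Save the final model and tokenizer to output_dir (also pushed, since push_to_hub=True)
trainer.save_model()

print("✅ Training completed successfully!")
print(f"💾 Model saved to: {training_args.output_dir}")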

[34m[1mwandb[0m: Currently logged in as: [33mflursky[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


🚀 Starting Finetuning...


[34m[1mwandb[0m: Tracking run with wandb version 0.20.1
[34m[1mwandb[0m: Run data is saved locally in [35m[1m/kaggle/working/wandb/run-20251008_142339-vs0t86kt[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33mflan-base-hetionet-instruct[0m
[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/flursky/Drug-FLAN[0m
[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/flursky/Drug-FLAN/runs/vs0t86kt[0m
Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


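| Step | Training Loss | Validation Loss |
|-----:|--------------:|----------------:|
| 50   | 22.7392 | 20.139389 |
| 100  | 15.0464 | 9.549397  |
| 150  | 6.2251  | 3.816029  |
| 200  | 3.606   | 2.459501  |
| 250  | 2.8342  | 1.678527  |
| 300  | 2.4208  | 1.406151  |
| 350  | 2.1766  | 1.28788   |
| 400  | 2.087   | 1.254877  |
| 450  | 2.0227  | 1.250299  |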




✅ Training completed successfully!
💾 Model saved to: ./finetuning_outputs


In [22]:
row = valid_data[0]

In [23]:
row.keys()

dict_keys(['input', 'target', 'input_ids', 'attention_mask', 'labels'])

In [24]:
inputs = tokenizer(row['input'], return_tensors='pt')

inputs = {k: v.to(model.device) for k, v in inputs.items()}

In [25]:
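# No max_new_tokens is passed, so generation uses the default length limit
# (max_length=20 unless the checkpoint's generation config overrides it)
output = model.generate(**inputs)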

In [26]:
print(tokenizer.decode(output[0]))

<pad>- - - - - - - - - -


In [27]:
print(row['target'])

- Adenosine monophosphate 
- Adenosine triphosphate
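Exact string match is a harsh test for a list-valued answer. A sketch of a set-based comparison, assuming targets and generations follow the dash-prefixed list format shown above: strip special tokens, split on standalone `-` markers, normalize case and whitespace, and compute precision, recall, and F1 over drug names. `parse_drug_list` and `set_prf` are hypothetical helpers, not part of the notebook or of `transformers`; in the notebook they would be applied to `tokenizer.decode(output[0])` and `row['target']`.

```python
import re
from typing import Set, Tuple

# Special tokens that can appear in decoded T5 output.
SPECIAL_TOKENS = re.compile(r"<pad>|</s>|<unk>")
# A dash that stands alone (start of text or after whitespace, followed by
# whitespace or end of text) separates items; hyphens inside names are kept.
ITEM_SEPARATOR = re.compile(r"(?:^|\s)-(?=\s|$)")


def parse_drug_list(text: str) -> Set[str]:
    """Turn '- Drug A - Drug B' (one item per dash) into a set of lowercase names."""
    text = SPECIAL_TOKENS.sub(" ", text)
    items = (" ".join(part.split()).lower() for part in ITEM_SEPARATOR.split(text))
    return {item for item in items if item}


def set_prf(predicted: Set[str], target: Set[str]) -> Tuple[float, float, float]:
    """Precision, recall, and F1 of predicted drug names against the target set."""
    hits = len(predicted & target)
    precision = hits / len(predicted) if predicted else 0.0
    recall = hits / len(target) if target else 0.0
    f1 = 2 * precision * recall / (precision + recall) if hits else 0.0
    return precision, recall, f1


target = parse_drug_list("- Adenosine monophosphate \n- Adenosine triphosphate")
generated = parse_drug_list("<pad>- - - - - - - - - -")
print(sorted(target))              # ['adenosine monophosphate', 'adenosine triphosphate']
print(set_prf(generated, target))  # (0.0, 0.0, 0.0)
```

Averaging these scores over the 84 validation rows gives a more informative signal than the validation loss alone, since the loss does not show whether the generated names are actually correct.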
