In [None]:

!pip uninstall -y gcsfs fsspec datasets
!pip install fsspec==2024.2.0 gcsfs==2024.2.0 datasets==2.18.0
!pip uninstall -y transformers accelerate peft
!pip install transformers==4.39.3 accelerate==0.27.2 peft==0.10.0

!pip install -q torch sacremoses evaluate seqeval

import os
import torch
import logging
from google.colab import drive
from transformers import AutoModelForMaskedLM, AutoTokenizer, Trainer, TrainingArguments, DataCollatorForLanguageModeling, pipeline
from peft import LoraConfig, get_peft_model, PeftModel, PeftConfig
from datasets import load_dataset
from torch.utils.data import Dataset
from typing import List, Dict, Tuple
import evaluate # For metrics

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Mount Google Drive so fine-tuned artifacts survive the Colab session.
drive.mount('/content/drive')

LLM_FINETUNED_DIR = "/content/drive/MyDrive/finetuned"
os.makedirs(LLM_FINETUNED_DIR, exist_ok=True)
logger.info(f"Created or found existing directory at {LLM_FINETUNED_DIR}")

# Dataset downloads go to ephemeral local disk rather than Drive.
PUBMEDQA_CACHE_DIR = "/tmp/pubmedqa_data_cache"
os.makedirs(PUBMEDQA_CACHE_DIR, exist_ok=True)
logger.info(f"Using temporary cache directory for datasets: {PUBMEDQA_CACHE_DIR}")


OUTPUT_ADAPTER_PATH = os.path.join(LLM_FINETUNED_DIR, "biobert_finetuned_adapter")
OUTPUT_MERGED_MODEL_PATH = os.path.join(LLM_FINETUNED_DIR, "biobert_finetuned_model_merged")

MODEL_NAME = "dmis-lab/biobert-v1.1"
MAX_LENGTH = 512  # tokenizer truncation/padding length (BERT's positional limit)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

logger.info(f"Initial device check: {DEVICE}")
if torch.cuda.is_available():
    logger.info(f"GPU: {torch.cuda.get_device_name(0)}")
    logger.info(f"Memory Allocated: {torch.cuda.memory_allocated() / 1e6:.2f} MB")
    logger.info(f"Memory Reserved: {torch.cuda.memory_reserved() / 1e6:.2f} MB")
else:
    # No CUDA GPU. This notebook has no TPU/XLA code path, so everything runs on
    # CPU; DEVICE was already set to "cpu" above.
    logger.info("No CUDA GPU available. Falling back to CPU.")

Found existing installation: gcsfs 2025.3.2
Uninstalling gcsfs-2025.3.2:
  Successfully uninstalled gcsfs-2025.3.2
Found existing installation: fsspec 2025.3.2
Uninstalling fsspec-2025.3.2:
  Successfully uninstalled fsspec-2025.3.2
Found existing installation: datasets 2.14.4
Uninstalling datasets-2.14.4:
  Successfully uninstalled datasets-2.14.4
Collecting fsspec==2024.2.0
  Downloading fsspec-2024.2.0-py3-none-any.whl.metadata (6.8 kB)
Collecting gcsfs==2024.2.0
  Downloading gcsfs-2024.2.0-py2.py3-none-any.whl.metadata (1.6 kB)
Collecting datasets==2.18.0
  Downloading datasets-2.18.0-py3-none-any.whl.metadata (20 kB)
Collecting pyarrow-hotfix (from datasets==2.18.0)
  Downloading pyarrow_hotfix-0.7-py3-none-any.whl.metadata (3.6 kB)
Downloading fsspec-2024.2.0-py3-none-any.whl (170 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m170.9/170.9 kB[0m [31m5.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading gcsfs-2024.2.0-py2.py3-none-any.whl (33 kB)
Downloading dat

In [None]:
# Data Preparation & Custom Dataset Classes

class LLMFineTuningDataset(Dataset):
    """Map-style dataset over pre-tokenized encodings.

    Args:
        encodings: Mapping of feature name -> per-example sequence, e.g. the
            ``BatchEncoding`` returned by a tokenizer called with
            ``return_tensors="pt"``, or a plain ``dict``. Each value must be
            indexable by example position and have as many entries as
            ``encodings["input_ids"]``.
        answers: Optional per-example reference answers. Stored on the dataset
            only; ``__getitem__`` returns encoding features, not answers.
    """

    def __init__(self, encodings, answers=None):
        self.encodings = encodings
        self.answers = answers

    def __len__(self):
        # Item access (rather than ``.input_ids``) works for both BatchEncoding
        # and plain dict encodings.
        return len(self.encodings["input_ids"])

    def __getitem__(self, idx):
        """Return example ``idx`` as ``{feature_name: value}`` for every encoding key.

        Tensor values are cloned/detached so the returned item does not share
        storage with the stored encodings; other values are returned as indexed.
        """
        return {
            key: val[idx].clone().detach() if isinstance(val, torch.Tensor) else val[idx]
            for key, val in self.encodings.items()
        }

def _context_to_text(context) -> str:
    """Flatten a PubMedQA ``context`` field into a single passage string.

    In the Hugging Face ``qiaojin/PubMedQA`` schema ``context`` is a mapping whose
    ``"contexts"`` entry holds the abstract paragraphs; joining the mapping itself
    would only join its keys (e.g. "contexts labels meshes"). A plain string or a
    list/tuple of strings is also accepted. Returns "" for empty input.
    """
    if not context:
        return ""
    if isinstance(context, dict):
        context = context.get("contexts") or []
    if isinstance(context, str):
        return context
    return " ".join(str(part) for part in context if part)


def load_and_prepare_data(stage: str, tokenizer, num_samples: int = 200) -> Dataset:
    """Load and tokenize qiaojin/PubMedQA (``pqa_artificial``) for one fine-tuning stage.

    Args:
        stage: One of ``"domain_adaptation"`` (context + long answer text for MLM),
            ``"few_shot"`` (question/context pairs, long answers kept on the
            dataset), or ``"contrastive"`` (not supported by this dataset).
        tokenizer: Hugging Face tokenizer used to encode the texts.
        num_samples: Maximum number of rows taken from the start of the train
            split (clamped to the split size).

    Returns:
        An ``LLMFineTuningDataset``, or ``None`` for the ``"contrastive"`` stage or
        when no usable rows were found.

    Raises:
        ValueError: If ``stage`` is not one of the names above.
    """
    logger.info(f"Loading data for stage: {stage}")

    # Resolve stages that need no data before paying for the dataset download.
    if stage == "contrastive":
        logger.info("Skipping Contrastive Fine-tuning as qiaojin/PubMedQA does not directly support 'gene-trait similarity/dissimilarity pairs' without complex custom logic.")
        logger.warning("For Stage 3 as defined in Module 1, a dataset with explicit positive and negative gene-trait examples (like OMIM data with specific labeling) would be necessary.")
        return None
    if stage not in ("domain_adaptation", "few_shot"):
        raise ValueError(f"Unknown stage: {stage}")

    raw_dataset = load_dataset("qiaojin/PubMedQA", "pqa_artificial", cache_dir=PUBMEDQA_CACHE_DIR)
    full_train = raw_dataset["train"]
    # select() raises when asked for more rows than exist, so clamp the count.
    train_split = full_train.select(range(min(num_samples, len(full_train))))

    if stage == "domain_adaptation":
        logger.info("Preparing data for Domain Adaptation (Masked Language Modeling).")
        texts = []
        for item in train_split:
            context_text = _context_to_text(item["context"])
            long_answer = item["long_answer"]
            if context_text and long_answer:
                texts.append(f"{context_text} {long_answer}")
        if not texts:
            logger.warning(f"No usable examples found for stage: {stage}")
            return None
        encodings = tokenizer(
            texts,
            max_length=MAX_LENGTH,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        return LLMFineTuningDataset(encodings)

    # stage == "few_shot"
    logger.info("Preparing data for Few-shot Fine-tuning (Extractive Question Answering).")
    questions, contexts, answers = [], [], []
    for item in train_split:
        context_text = _context_to_text(item["context"])
        question = item["question"]
        long_answer = item["long_answer"]
        if context_text and question and long_answer:
            questions.append(question)
            contexts.append(context_text)
            # long_answer is generally not a verbatim span of the context, so no
            # character offset is recorded for it.
            answers.append(long_answer)
    if not questions:
        logger.warning(f"No usable examples found for stage: {stage}")
        return None

    # Encode as a (question, context) pair: the tokenizer inserts [CLS]/[SEP] and
    # segment ids itself, whereas literal "[CLS]"/"[SEP]" text would add a second set.
    encodings = tokenizer(
        questions,
        contexts,
        max_length=MAX_LENGTH,
        padding="max_length",
        truncation=True,
        return_tensors="pt"
    )

    # Answers are stored on the dataset, not inside encodings.
    return LLMFineTuningDataset(encodings, answers=answers)


In [None]:
#  Fine-Tuning Functions

def fine_tune_stage(model, tokenizer, dataset: Dataset, stage: str, epochs: int = 3,
                    learning_rate: float = 2e-5, report_to: str = "none"):
    """Train a fresh LoRA adapter on ``model`` for one fine-tuning stage.

    Args:
        model: Base transformers model, or the PeftModel returned by a previous
            call (its adapter is merged into the base weights before a new one is
            added, so earlier stages are not lost).
        tokenizer: Tokenizer used by the data collator for masking/padding.
        dataset: Training dataset; ``None`` skips the stage.
        stage: Stage name; ``"domain_adaptation"`` uses masked-LM collation, any
            other stage uses ``mlm=False`` collation. Also used in output paths.
        epochs: Number of training epochs.
        learning_rate: Optimizer learning rate.
        report_to: Value for ``TrainingArguments.report_to``. Defaults to
            ``"none"`` so no experiment tracker (e.g. wandb) prompts for
            credentials mid-run.

    Returns:
        The trained PeftModel, or ``model`` unchanged when ``dataset`` is None.
    """
    if dataset is None:  # e.g. the skipped contrastive stage
        logger.info(f"Skipping {stage} fine-tuning due to unsuitable dataset.")
        return model

    # bf16 only where the GPU supports it; TrainingArguments rejects bf16=True otherwise.
    use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    logger.info(f"Starting {stage} fine-tuning... (bf16 mixed precision: {use_bf16})")

    # A model coming from a previous stage is already a PeftModel. Wrapping it again
    # would re-inject the same "default" adapter into the already-adapted layers and
    # lose the previous stage's LoRA weights, so bake them into the base model first.
    if isinstance(model, PeftModel):
        model = model.merge_and_unload()

    lora_config = LoraConfig(
        r=8,
        lora_alpha=32,
        target_modules=["query", "key", "value"],
        lora_dropout=0.1,
        bias="none",
        task_type="FEATURE_EXTRACTION",
        # dmis-lab/biobert-v1.1 has no pretrained MLM head transform
        # (cls.predictions.transform.* is newly initialized on load). Without this,
        # get_peft_model freezes that random layer for the whole run.
        modules_to_save=["predictions.transform"],
    )

    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    if stage == "domain_adaptation":
        data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
    else:
        # NOTE(review): with mlm=False the labels are the unmasked input_ids, so a
        # masked-LM model is trained to reproduce its input; confirm this is intended.
        data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    training_args = TrainingArguments(
        output_dir=os.path.join(LLM_FINETUNED_DIR, f"{stage}_checkpoints"),
        num_train_epochs=epochs,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,  # effective batch size 16
        save_strategy="epoch",
        logging_steps=10,
        learning_rate=learning_rate,
        bf16=use_bf16,
        remove_unused_columns=False,  # keep every encoding key the dataset returns
        dataloader_num_workers=os.cpu_count() or 0,
        report_to=report_to,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        data_collator=data_collator
    )

    trainer.train()
    logger.info(f"{stage} fine-tuning completed.")
    trainer.save_model(os.path.join(LLM_FINETUNED_DIR, f"{stage}_adapter"))
    return model

def save_final_model(model, tokenizer):
    """Save the fine-tuned model and tokenizer to Google Drive.

    Writes:
      * ``LLM_FINETUNED_DIR``: ``model.save_pretrained`` output (for a PeftModel
        this is the LoRA adapter), the tokenizer, and ``model_state_dict.pth``
        holding the full ``state_dict``.
      * ``OUTPUT_MERGED_MODEL_PATH``: a standalone model with any LoRA weights
        merged in, plus the tokenizer, loadable without peft. The training cell
        checks this directory for ``tokenizer.json`` to decide whether to skip
        training.

    The caller's ``model`` is not modified; merging happens on a deep copy, which
    temporarily needs memory for a second copy of the weights.
    """
    import copy  # local import keeps the notebook's import cell unchanged

    logger.info(f"Saving final fine-tuned model to Google Drive directory: {LLM_FINETUNED_DIR}")
    os.makedirs(LLM_FINETUNED_DIR, exist_ok=True)

    #  Save using save_pretrained
    model.save_pretrained(LLM_FINETUNED_DIR)
    tokenizer.save_pretrained(LLM_FINETUNED_DIR)
    logger.info("Model and tokenizer saved using save_pretrained (Hugging Face format).")

    # Save the model's state_dict as a .pth file as well
    pth_save_path = os.path.join(LLM_FINETUNED_DIR, "model_state_dict.pth")
    torch.save(model.state_dict(), pth_save_path)
    logger.info(f"Model state_dict saved to {pth_save_path} (.pth format).")

    # Standalone merged copy, written where the training cell looks for a finished run.
    if isinstance(model, PeftModel):
        merged_model = copy.deepcopy(model).merge_and_unload()
    else:
        merged_model = model
    merged_model.save_pretrained(OUTPUT_MERGED_MODEL_PATH)
    tokenizer.save_pretrained(OUTPUT_MERGED_MODEL_PATH)
    logger.info(f"Merged model and tokenizer saved to {OUTPUT_MERGED_MODEL_PATH}.")

In [None]:
# Training Loop

# Stages run in this order; a stage whose loader returns None (or an empty
# dataset) is skipped.
FINE_TUNING_STAGES = ("domain_adaptation", "few_shot", "contrastive")

if os.path.exists(os.path.join(OUTPUT_MERGED_MODEL_PATH, "tokenizer.json")):
    logger.info(f"Final merged model already exists at {OUTPUT_MERGED_MODEL_PATH}. Skipping training.")
    # Load the saved result so later cells see the fine-tuned weights rather than
    # the base checkpoint.
    tokenizer = AutoTokenizer.from_pretrained(OUTPUT_MERGED_MODEL_PATH)
    model = AutoModelForMaskedLM.from_pretrained(OUTPUT_MERGED_MODEL_PATH).to(DEVICE)
else:
    logger.info(f"Loading base model {MODEL_NAME}...")
    model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME).to(DEVICE)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    for stage in FINE_TUNING_STAGES:
        stage_dataset = load_and_prepare_data(stage, tokenizer)
        if stage_dataset:
            model = fine_tune_stage(model, tokenizer, stage_dataset, stage, epochs=1)

    save_final_model(model, tokenizer)

logger.info("LLM Fine-tuning process complete.")


Some weights of BertForMaskedLM were not initialized from the model checkpoint at dmis-lab/biobert-v1.1 and are newly initialized: ['cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Downloading readme:   0%|          | 0.00/5.19k [00:00<?, ?B/s]

Downloading data: 100%|██████████| 233M/233M [00:00<00:00, 237MB/s]


Generating train split:   0%|          | 0/211269 [00:00<?, ? examples/s]

trainable params: 442,368 || all params: 108,783,172 || trainable%: 0.4066511316658426


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize?ref=models
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mdagikas22[0m ([33mdagikas22-jsad[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss
10,11.1838




trainable params: 442,368 || all params: 108,783,172 || trainable%: 0.4066511316658426


Step,Training Loss
10,11.0773


