In [None]:
# Install libs
pip install torch transformers datasets accelerate


In [None]:
# Imports: stdlib first, then third-party.
import json

import torch
# transformers' own AdamW is deprecated and dropped in recent releases; use PyTorch's.
# Note: torch's AdamW defaults to weight_decay=0.01 (transformers' defaulted to 0.0).
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm  # renders as a notebook widget when available
from transformers import AutoTokenizer, AutoModelForCausalLM, get_scheduler

# Use the GPU when one is available; the model and each training batch are moved here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Load the pretrained TinyLlama chat checkpoint and its tokenizer from the same id,
# so the two can never drift apart.
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Fixed-length padding needs a pad token; reuse eos. Padded positions then carry the
# eos id, so they cannot be distinguished from a real eos by token id alone.
tokenizer.pad_token = tokenizer.eos_token

tiny_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(device)


In [None]:
from torch.utils.data import Dataset, DataLoader
import json

class PubMedQADataset(Dataset):
    """PubMedQA records filtered by MeSH term and tokenized for causal-LM fine-tuning.

    Args:
        data_path: JSON file holding a dict of PMID -> record. Each kept record is
            read for "CONTEXTS" (str or list of str), "QUESTION", "final_decision"
            and "LONG_ANSWER"; "MESHES" (list of str) is optional.
        tokenizer: Hugging Face-style tokenizer. It must have a pad token, because
            examples are padded to ``max_length``.
        mesh_filter_path: text file with one MeSH term per line (blank lines
            ignored). A record is kept if any of its MeSH terms matches,
            case-insensitively.
        max_length: every example is padded/truncated to exactly this many tokens.

    Each item is a dict of 1-D tensors: the tokenizer outputs plus "labels".
    """

    def __init__(self, data_path, tokenizer, mesh_filter_path, max_length=512):
        self.tokenizer = tokenizer
        self.max_length = max_length

        # Lower-cased set of MeSH terms to keep.
        with open(mesh_filter_path, "r", encoding="utf-8") as f:
            self.mesh_filter = {line.strip().lower() for line in f if line.strip()}

        with open(data_path, "r", encoding="utf-8") as f:
            self.training_data = json.load(f)

        # Keep only the PMIDs whose record shares at least one MeSH term with the filter.
        self.pmids = [
            pmid
            for pmid, item in self.training_data.items()
            if any(m.lower() in self.mesh_filter for m in item.get("MESHES", []))
        ]

    def __len__(self):
        return len(self.pmids)

    def __getitem__(self, idx):
        """Build, tokenize and label the prompt+answer text for the idx-th kept record."""
        data_item = self.training_data[self.pmids[idx]]

        contexts = data_item["CONTEXTS"]
        if isinstance(contexts, list):
            contexts = " ".join(contexts)

        question = contexts + " " + data_item["QUESTION"]
        answer = data_item["final_decision"] + " " + data_item["LONG_ANSWER"]
        prompt = f"<|user|> {question} <|assistant|> {answer}"

        # NOTE(review): tokenizers truncate from the right by default, so a long context
        # can cut off the answer entirely — consider truncating the context instead.
        inputs = self.tokenizer(
            prompt,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        # Labels mirror input_ids, but padding positions are set to -100 (the ignore
        # index of the causal-LM loss) so the model is not trained to emit padding.
        # Mask by attention_mask rather than pad id: pad may equal eos.
        # Prompt tokens are still trained on; mask them too to learn only the answer.
        labels = inputs["input_ids"].clone()
        if "attention_mask" in inputs:
            labels[inputs["attention_mask"] == 0] = -100
        inputs["labels"] = labels

        # Drop the batch dimension added by return_tensors="pt".
        return {key: val.squeeze(0) for key, val in inputs.items()}


In [None]:
# SFT: build the training DataLoader, then fine-tune the model and save it.
# NOTE(review): DATA_PATH / MESH_FILTER_PATH are placeholders — point them at the
# PubMedQA JSON (PMID -> record) and the MeSH filter list (one term per line).
DATA_PATH = "ori_pqal.json"
MESH_FILTER_PATH = "mesh_filter.txt"
OUTPUT_DIR = "fine_tuned_tinyllama_pubmedqa"
BATCH_SIZE = 4  # NOTE(review): tune to available GPU memory
LEARNING_RATE = 5e-5
SEED = 42

torch.manual_seed(SEED)  # makes DataLoader shuffling reproducible

# train_loader was previously never defined, so this cell failed on a fresh kernel.
train_dataset = PubMedQADataset(DATA_PATH, tokenizer, MESH_FILTER_PATH)
if len(train_dataset) == 0:
    # Otherwise the per-epoch average below would divide by zero.
    raise ValueError(
        f"No records in {DATA_PATH} matched the MeSH terms in {MESH_FILTER_PATH}"
    )
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# weight_decay is explicit because the default differs between AdamW implementations
# (transformers: 0.0, torch: 0.01); 0.0 keeps the original setting either way.
optimizer = AdamW(tiny_model.parameters(), lr=LEARNING_RATE, weight_decay=0.0)
num_epochs = 3
num_training_steps = num_epochs * len(train_loader)

# Linear decay to zero over every optimizer step, with no warmup.
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

tiny_model.train()

for epoch in range(num_epochs):
    epoch_loss = 0.0
    for batch in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}"):
        batch = {k: v.to(device) for k, v in batch.items()}

        # The batch carries "labels", so the model returns the LM loss directly.
        outputs = tiny_model(**batch)
        loss = outputs.loss
        epoch_loss += loss.item()

        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

    avg_loss = epoch_loss / len(train_loader)
    print(f"Epoch {epoch + 1}, Loss: {avg_loss:.4f}")

tiny_model.save_pretrained(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)