# Model and tokenizer

In [None]:
from transformers import BioGptForCausalLM, AutoTokenizer, TrainingArguments, Trainer, DataCollatorForLanguageModeling
from datasets import load_dataset
import torch

# Hugging Face Hub id of the checkpoint to load.
model_name = "microsoft/biogpt"
# The tokenizer and the causal-LM weights are both loaded from that same checkpoint id.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = BioGptForCausalLM.from_pretrained(model_name)

# 1. Import Dataset

In [None]:
# Load the PubMedQA dataset ("pqa_labeled" configuration)
dataset = load_dataset("pubmed_qa", "pqa_labeled")

# Show dataset structure (splits, column names and row counts)
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['pubid', 'question', 'context', 'long_answer', 'final_decision'],
        num_rows: 1000
    })
})


# 2. Preprocessing

## 2.1. Formatting dataset

In [None]:
# Keep a handle on the "train" split; the following cells read from it.
train_dataset = dataset["train"]
train_dataset

Dataset({
    features: ['pubid', 'question', 'context', 'long_answer', 'final_decision'],
    num_rows: 1000
})

In [None]:
# Index the row first, then the column: this reads a single example instead of
# pulling the entire "context" column just to look at its first entry.
train_dataset[0]["context"]

{'contexts': ['Programmed cell death (PCD) is the regulated death of cells within an organism. The lace plant (Aponogeton madagascariensis) produces perforations in its leaves through PCD. The leaves of the plant consist of a latticework of longitudinal and transverse veins enclosing areoles. PCD occurs in the cells at the center of these areoles and progresses outwards, stopping approximately five cells from the vasculature. The role of mitochondria during PCD has been recognized in animals; however, it has been less studied during PCD in plants.',
  'The following paper elucidates the role of mitochondrial dynamics during developmentally regulated PCD in vivo in A. madagascariensis. A single areole within a window stage leaf (PCD is occurring) was divided into three areas based on the progression of PCD; cells that will not undergo PCD (NPCD), cells in early stages of PCD (EPCD), and cells in late stages of PCD (LPCD). Window stage leaves were stained with the mitochondrial dye MitoT

In [None]:
# Fetch the first example once (row first, then column) rather than reading the
# whole "context" column twice, then print its first two passages.
first_context = train_dataset[0]["context"]
print(first_context["contexts"][0])
print(first_context["contexts"][1])

Programmed cell death (PCD) is the regulated death of cells within an organism. The lace plant (Aponogeton madagascariensis) produces perforations in its leaves through PCD. The leaves of the plant consist of a latticework of longitudinal and transverse veins enclosing areoles. PCD occurs in the cells at the center of these areoles and progresses outwards, stopping approximately five cells from the vasculature. The role of mitochondria during PCD has been recognized in animals; however, it has been less studied during PCD in plants.
The following paper elucidates the role of mitochondrial dynamics during developmentally regulated PCD in vivo in A. madagascariensis. A single areole within a window stage leaf (PCD is occurring) was divided into three areas based on the progression of PCD; cells that will not undergo PCD (NPCD), cells in early stages of PCD (EPCD), and cells in late stages of PCD (LPCD). Window stage leaves were stained with the mitochondrial dye MitoTracker Red CMXRos an

In [None]:
def format_dataset(examples):
    """Turn PubMedQA examples into single training strings.

    Each example becomes
    ``"Question: <q> Context: <passages> Answer: <long_answer> Decision: <final_decision>"``,
    where ``<passages>`` is every entry of ``context["contexts"]`` joined by a
    single space.

    Args:
        examples: Column-oriented mapping (a dataset split or a plain dict of
            lists) exposing ``"question"``, ``"context"``, ``"long_answer"``
            and ``"final_decision"``. Every ``"context"`` entry must be a
            mapping with a ``"contexts"`` list of strings.

    Returns:
        list[str]: One formatted string per example, in input order.
    """
    texts = []
    # The contexts are taken from `examples` itself (not from a notebook-level
    # global), so the function is correct for any split, subset or batch.
    for question, context, answer, decision in zip(
        examples["question"],
        examples["context"],
        examples["long_answer"],
        examples["final_decision"],
    ):
        passages = " ".join(context["contexts"])
        texts.append(
            f"Question: {question} Context: {passages} Answer: {answer} Decision: {decision}"
        )

    return texts

formatted_dataset = format_dataset(train_dataset)
print(len(formatted_dataset))

1000


In [None]:
# Print the first three formatted examples, one per line.
preview = formatted_dataset[:3]
for sample in preview:
    print(sample)

Question: Do mitochondria play a role in remodelling lace plant leaves during programmed cell death? Context: Programmed cell death (PCD) is the regulated death of cells within an organism. The lace plant (Aponogeton madagascariensis) produces perforations in its leaves through PCD. The leaves of the plant consist of a latticework of longitudinal and transverse veins enclosing areoles. PCD occurs in the cells at the center of these areoles and progresses outwards, stopping approximately five cells from the vasculature. The role of mitochondria during PCD has been recognized in animals; however, it has been less studied during PCD in plants. The following paper elucidates the role of mitochondrial dynamics during developmentally regulated PCD in vivo in A. madagascariensis. A single areole within a window stage leaf (PCD is occurring) was divided into three areas based on the progression of PCD; cells that will not undergo PCD (NPCD), cells in early stages of PCD (EPCD), and cells in la

In [None]:
## 2.2. Tokenization and encoding

In [None]:
import tqdm.auto as tqdm

def tokenize_text(text_list, tokenizer):
    """Encode every string of ``text_list`` with ``tokenizer.encode``.

    Args:
        text_list: Iterable of strings to encode.
        tokenizer: Object exposing ``encode(text)``.

    Returns:
        list: One ``tokenizer.encode`` result per input string, in input order.
    """
    # tqdm only wraps the iterable to display a progress bar.
    return [tokenizer.encode(sample) for sample in tqdm.tqdm(text_list)]

tokenized_texts = tokenize_text(text_list=formatted_dataset, tokenizer=tokenizer)


  0%|          | 0/1000 [00:00<?, ?it/s]

## 2.3. Save the tokenized data

In [None]:
import os
import pickle
def save(data, dir="./data"):
    """Pickle ``data`` to ``<dir>/tokenized_text.pickle``.

    The directory is created when missing. If the pickle file already exists
    the user is prompted on stdin: typing ``R`` overwrites it, any other
    answer leaves the existing file untouched.

    Args:
        data: Any picklable object.
        dir: Target directory for the pickle file.
    """
    print(f"Saving data to {dir}")
    os.makedirs(dir, exist_ok=True)

    data_path = f"{dir}/tokenized_text.pickle"

    def _dump():
        # Single place where the file is actually (over)written.
        with open(data_path, "wb") as handle:
            pickle.dump(data, handle)

    if not os.path.exists(data_path):
        _dump()
        print("Data saved successfully")
        return

    print("Dataset exists, type R to replace or anything to skip")
    if input("you:") != "R":
        print("Skipping...")
        return

    _dump()
    print("Data successfully replaced")

save(tokenized_texts)


Saving data to ./data
Dataset exists, type R to replace or anything to skip
Data successfully replaced and saved


# Checkpoint...
## 2.4. Load the saved tokenized dataset

In [1]:
!pip install sacremoses
import sacremoses
from transformers import BioGptForCausalLM, AutoTokenizer, TrainingArguments, Trainer, DataCollatorForLanguageModeling
import torch
import tqdm.auto as tqdm

# Re-create the tokenizer and model after the checkpoint so this part of the
# notebook can run without the preprocessing cells.
model_name = "microsoft/biogpt"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = BioGptForCausalLM.from_pretrained(model_name)
# Batch collator used by the Trainer below; mlm=False turns masked-LM masking off.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,
)

Collecting sacremoses
  Downloading sacremoses-0.1.1-py3-none-any.whl.metadata (8.3 kB)
Downloading sacremoses-0.1.1-py3-none-any.whl (897 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m897.5/897.5 kB[0m [31m14.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: sacremoses
Successfully installed sacremoses-0.1.1


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/595 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/927k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/696k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.56G [00:00<?, ?B/s]

In [2]:
### Freeze the whole model, then re-enable gradients for the last 4 layers

# Turn gradients off for every parameter of the model.
model.requires_grad_(False)

# Unfreeze only the last 4 entries of model.biogpt.layers
# (widen or narrow the slice to unfreeze more/fewer layers).
for layer in model.biogpt.layers[-4:]:
    layer.requires_grad_(True)


In [3]:
# Sum the element counts of every parameter that still receives gradients.
trainable_params = sum(
    weight.numel() for weight in model.parameters() if weight.requires_grad
)
print(f"Total trainable parameters: {trainable_params}")

Total trainable parameters: 50384896


In [4]:
import pickle
def load(dir="tokenized_text.pickle"):
    """Read back a pickled object from the file at ``dir``.

    Args:
        dir: Path of the pickle *file* (despite the parameter name).

    Returns:
        The unpickled object.

    Note:
        Unpickling can execute arbitrary code; only load files you created.
    """
    print(f"loading dataset from {dir}")
    with open(dir, "rb") as stream:
        restored = pickle.Unpickler(stream).load()
    print("data loaded successfully")
    return restored
# NOTE(review): load() defaults to "tokenized_text.pickle" in the working directory,
# whereas save() writes to "./data/tokenized_text.pickle" by default — confirm the
# file has been copied/uploaded next to the notebook before running this cell.
loaded_data = load()

loading dataset from tokenized_text.pickle
data loaded successfully


# 3. Training

In [5]:
# Share of the examples used for training; the remainder is held out for evaluation.
train_frac = 0.8
# Legacy name kept for backward compatibility: despite what it suggests, it has
# always held the *training* share, not the test share.
test_frac = train_frac
training_num = int(len(loaded_data) * train_frac)
# Sequential (unshuffled) split: the first `training_num` items train, the rest evaluate.
training_data = loaded_data[:training_num]
testing_data = loaded_data[training_num:]

In [6]:
import os
# Directories for model checkpoints and for training logs.
output_dir = "./model"
logging_dir = "./logs"

# exist_ok=True makes both calls no-ops when the directory is already there.
# (The previous `if os.path.exists:` tested the function object itself, which is
# always truthy, so the logs directory was never created.)
os.makedirs(output_dir, exist_ok=True)
os.makedirs(logging_dir, exist_ok=True)

# Fine-tuning configuration: evaluate and checkpoint every 10 steps, keep at most
# 5 checkpoints, and reload the checkpoint with the lowest eval loss at the end.
training_args = TrainingArguments(
    output_dir=output_dir,
    overwrite_output_dir=True,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=4,
    logging_dir=logging_dir,  # reuse the variable instead of repeating the "./logs" literal
    logging_steps=10,
    save_steps=10,
    eval_strategy="steps",
    # Explicit rather than inherited from logging_steps: load_best_model_at_end
    # needs save_steps to stay a multiple of eval_steps.
    eval_steps=10,
    save_total_limit=5,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    warmup_steps=5,
    weight_decay=0.01,
    lr_scheduler_type="cosine",
    seed=42
)

In [7]:
# Wire the model, training configuration, train/eval splits, tokenizer and
# collator into a Trainer.
# NOTE(review): the datasets are plain Python lists taken from the loaded pickle;
# batching them is left entirely to `data_collator`.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=training_data,
    eval_dataset=testing_data,
    processing_class=tokenizer,
    data_collator=data_collator,
)

In [8]:
trainer.train()



<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33msaurabhbhattarai1999[0m ([33msaurabhbhattarai1999-university-of-roehampton-london[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


Step,Training Loss,Validation Loss
10,2.5714,2.246562
20,2.2617,2.113255
30,2.1771,2.086507
40,2.2275,2.066851
50,2.2572,2.058095
60,2.2256,2.052861
70,2.1989,2.04935
80,2.1818,2.045508
90,2.203,2.043167
100,2.2012,2.041673


There were missing keys in the checkpoint model loaded: ['output_projection.weight'].


TrainOutput(global_step=400, training_loss=2.053743071556091, metrics={'train_runtime': 1892.3943, 'train_samples_per_second': 1.691, 'train_steps_per_second': 0.211, 'total_flos': 2639309330055168.0, 'train_loss': 2.053743071556091, 'epoch': 4.0})

# 4. Text generation

In [9]:
import torch

def generate_text(prompt, model, tokenizer, max_length=512, num_return_sequences=1, temperature=1.0, top_p=0.9):
    """Sample one or more continuations of ``prompt`` from ``model``.

    Args:
        prompt: Text the generation starts from.
        model: Model exposing ``to``, ``eval`` and ``generate``.
        tokenizer: Tokenizer exposing ``encode(..., return_tensors='pt')`` and
            ``decode(..., skip_special_tokens=True)``.
        max_length: Passed to ``model.generate`` as ``max_length``.
        num_return_sequences: Number of sequences to sample.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling threshold.

    Returns:
        str when ``num_return_sequences == 1`` (the decoded sequence);
        list[str] with every decoded sequence otherwise. Previously only the
        first sequence was returned and the others were silently discarded.

    Side effects:
        Moves ``model`` to the selected device and leaves it in eval mode.
    """
    # Check for GPU availability and set the device accordingly
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Move the model to the appropriate device and switch off training-only
    # behaviour such as dropout, which would otherwise perturb the samples.
    model.to(device)
    model.eval()

    # Encode the prompt text
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    # A single prompt is never padded, so every position is a real token.
    attention_mask = torch.ones_like(input_ids)

    # Generate text from the model
    generated_ids = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_length=max_length,
        num_return_sequences=num_return_sequences,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,  # Set to True for sampling (more randomness), False for greedy decoding
        no_repeat_ngram_size=2  # Helps avoid repetition in the generated text
    )

    # Decode every generated sequence
    decoded = [
        tokenizer.decode(sequence, skip_special_tokens=True)
        for sequence in generated_ids
    ]

    # Keep the historical single-string return value for the default case.
    return decoded[0] if num_return_sequences == 1 else decoded

# Example usage:
# NOTE(review): the fine-tuning texts built by format_dataset start with "Question: ";
# this prompt does not use that prefix — confirm whether that is intentional.
prompt = "Can I die of heart attack?"
generated_response = generate_text(prompt, model, tokenizer)

generated_response


'Can I die of heart attack? Context: A community-based survey in the city of São Paulo, Brazil. For the first time in Latin America, we aimed to assess the prevalence of cardiovascular diseases and risk factors among the adult population of a city in São Parda, São Polínetica. The population under study was formed by all adults who had at least one of the following risk conditions: smoking, diabetes mellitus, hypertension, or dyslipidemia. Of 988 randomly selected adults, 589 (61%) completed the survey. Mean age was 54.7 years (SD 16.7), and 516 were women (73% of those who responded). We observed a high prevalence rate for hypertension (55%, 95% CI 55-58%; women 65% and men 49%: p = 0.0001) and a low prevalence for smoking (11% in women, 15% men: n.s.) among subjects who reported being overweight or obese (body mass index (BMI) > or = 30 kg / m2) (women 25% women; men 15; p < 0.0001 and p for the interaction between gender and BMI = 0.001) or having diabetes (10% obese women and 5% no

In [10]:
generated_response

'Can I die of heart attack? Context: A community-based survey in the city of São Paulo, Brazil. For the first time in Latin America, we aimed to assess the prevalence of cardiovascular diseases and risk factors among the adult population of a city in São Parda, São Polínetica. The population under study was formed by all adults who had at least one of the following risk conditions: smoking, diabetes mellitus, hypertension, or dyslipidemia. Of 988 randomly selected adults, 589 (61%) completed the survey. Mean age was 54.7 years (SD 16.7), and 516 were women (73% of those who responded). We observed a high prevalence rate for hypertension (55%, 95% CI 55-58%; women 65% and men 49%: p = 0.0001) and a low prevalence for smoking (11% in women, 15% men: n.s.) among subjects who reported being overweight or obese (body mass index (BMI) > or = 30 kg / m2) (women 25% women; men 15; p < 0.0001 and p for the interaction between gender and BMI = 0.001) or having diabetes (10% obese women and 5% no