In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, Trainer, TrainingArguments, DataCollatorForLanguageModeling
import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from sklearn.model_selection import train_test_split
import wandb
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model
from datasets import load_dataset

import pandas as pd
import ast
from bs4 import BeautifulSoup
import unicodedata

In [None]:
# Load the labelled abstracts, keeping only the columns used below, then:
#   * parse the stringified 'EDAM Topics' column into Python objects, and
#   * strip HTML markup from each abstract, leaving plain text.
dataset = (
    pd.read_csv('datasets/staging_test_set.csv')[['PMID', 'Abstract', 'EDAM Topics']]
    .assign(**{
        'EDAM Topics': lambda frame: frame['EDAM Topics'].apply(ast.literal_eval),
        'Abstract': lambda frame: frame['Abstract'].apply(
            lambda html: BeautifulSoup(html, 'html.parser').get_text(strip=True)
        ),
    })
)

In [None]:
# Abstracts of the first 25 rows of the raw model outputs; these are the
# examples that must be kept out of the fine-tuning data.
# `.copy()` makes the frame independent of the intermediate slices so the
# column assignment below cannot trigger a SettingWithCopyWarning.
test_data = pd.read_csv('outputs/raw_model_outputs.csv').iloc[:25][['Abstract']].copy()

# Normalise every abstract to Unicode NFKD in one vectorised pass rather than
# a positional `.loc[i]` loop, which only works while the index is 0..n-1.
test_data['Abstract'] = test_data['Abstract'].apply(
    lambda text: unicodedata.normalize("NFKD", text)
)

In [None]:
# Display the dataset rows whose 'Abstract' string is exactly equal to one of
# the held-out test abstracts (these are the rows removed in the next cell).
dataset[dataset['Abstract'].isin(test_data['Abstract'])]

In [None]:
# Remove every held-out test abstract from the pool used for fine-tuning.
# The test abstracts were NFKD-normalised when they were loaded, whereas the
# abstracts in `dataset` were not, so an exact string comparison alone can miss
# a held-out abstract that contains characters altered by normalisation.
# Match on both the stored text and its NFKD form to avoid that leakage.
normalized_abstracts = dataset['Abstract'].apply(
    lambda text: unicodedata.normalize("NFKD", text)
)
is_held_out = (
    dataset['Abstract'].isin(test_data['Abstract'])
    | normalized_abstracts.isin(test_data['Abstract'])
)
dataset = dataset[~is_held_out]

In [None]:
# Draw a reproducible (fixed seed) random sample of 1000 rows to fine-tune on.
# NOTE(review): sampling without replacement presumably fails if fewer than
# 1000 rows remain after the held-out abstracts were removed — confirm the size.
training_set = dataset.sample(n=1000, random_state=42)

In [None]:
## Prepare dataset for finetuning

# Prompt skeleton; it contains the placeholder "<topics>" that is filled in
# here (other placeholders are substituted per example later on).
with open('templates/prompt_template.txt', 'r') as template_file:
    template = template_file.read()

# One EDAM topic per line; surrounding whitespace and newlines are dropped.
with open('EDAM/edam_topics.txt', 'r') as edam_file:
    full_edam_topics = [topic.strip() for topic in edam_file]

# Insert the full topic vocabulary, one topic per line, into the template.
formatted_topics = "\n".join(full_edam_topics)
template = template.replace("<topics>", formatted_topics)

In [None]:
# Build one (prompt, target) pair per training row:
#   prompt -> template with the abstract and the number of expected topics filled in
#   target -> the row's EDAM topics joined into a comma-separated string
inputs, outputs = [], []
for abstract, edam_topics in zip(training_set['Abstract'], training_set['EDAM Topics']):
    prompt = (
        template
        .replace('<abstract>', abstract)
        .replace('<num_terms>', str(len(edam_topics)))
    )
    inputs.append(prompt)
    outputs.append(', '.join(edam_topics))


In [None]:
# Attach the prompts and targets as new columns, then discard any row that
# has a missing value in any column.
training_set = training_set.assign(Instruction=inputs, Output=outputs).dropna()

In [None]:
# Persist the prepared examples (without the DataFrame index) so the
# fine-tuning section below can reload them from disk.
training_set.to_csv('datasets/llm-finetune-data-1000.csv', index=False)

## Finetuning model

In [None]:
# Reload the saved CSV through `datasets.load_dataset`; note that this rebinds
# `training_set` from the pandas DataFrame above to the object returned by
# `load_dataset` for the 'train' split.
training_set = load_dataset('csv', data_files='datasets/llm-finetune-data-1000.csv', split='train')

In [None]:
# 4-bit NF4 quantisation (no double quantisation) with bfloat16 compute.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=False,
)

# Load the base model from local disk, quantised on load and spread across
# the available devices.
# `torch_dtype` is set explicitly because Flash Attention 2 only runs in
# half precision; using bfloat16 keeps the attention kernels in the same dtype
# as the 4-bit compute dtype configured above instead of relying on the
# loader's implicit default.
model = AutoModelForCausalLM.from_pretrained(
    "/nvme/models/mixtral-8x7b-instruct-model/",
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="auto"
)

In [None]:
# Load the tokenizer from its own local directory (stored separately from the
# model weights).
tokenizer = AutoTokenizer.from_pretrained("/nvme/models/mixtral-8x7b-instruct-tokenizer/")

In [None]:
# Training-time model settings.
model.config.use_cache = False # silence the warnings. Please re-enable for inference!
model.config.pretraining_tp = 1
model.gradient_checkpointing_enable()

# Pad with a token that is NOT the end-of-sequence token.
# DataCollatorForLanguageModeling (used for the Trainer below) replaces every
# label equal to `pad_token_id` with -100, so padding with EOS would also mask
# the genuine EOS at the end of each example and the model would never be
# trained to stop generating. Fall back to EOS only if the tokenizer defines
# no unknown token.
tokenizer.pad_token = (
    tokenizer.unk_token if tokenizer.unk_token is not None else tokenizer.eos_token
)
tokenizer.add_eos_token = True

# Cell output: show the resulting BOS/EOS insertion flags.
tokenizer.add_bos_token, tokenizer.add_eos_token

In [None]:
# Start a Weights & Biases run for this training job; `anonymous="allow"`
# is passed so the call is not tied to a specific logged-in account.
run = wandb.init(project='Fine Tune Mixtral 8x7B', job_type="training", anonymous="allow")

In [None]:
# LoRA adapter hyper-parameters for causal language modelling.
# NOTE(review): "gate_proj" is a LLaMA/Mistral-style MLP module name; the
# Mixtral expert layers may use different names, in which case this entry
# matches nothing and only the attention projections are adapted — confirm
# with `model.print_trainable_parameters()`.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj"],
)

# Prepare the quantised model for k-bit training, then wrap it with the adapters.
model = get_peft_model(prepare_model_for_kbit_training(model), peft_config)

In [None]:
training_arguments = TrainingArguments(
    # --- output, checkpointing and logging ---
    output_dir="./results",
    save_steps=5000,
    logging_steps=30,
    report_to="wandb",
    # --- run length and batching ---
    num_train_epochs=1,
    max_steps=-1,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,
    group_by_length=True,
    # --- optimiser and learning-rate schedule ---
    optim="paged_adamw_8bit",
    learning_rate=2e-4,
    weight_decay=0.001,
    max_grad_norm=0.3,
    # NOTE(review): the plain "constant" schedule appears to ignore warm-up, so
    # `warmup_ratio` may have no effect here; "constant_with_warmup" would be
    # needed if warm-up is intended — confirm.
    lr_scheduler_type="constant",
    warmup_ratio=0.3,
    # --- numeric precision ---
    # NOTE(review): both mixed-precision flags are off although the quantised
    # model computes in bfloat16 — confirm this is intentional.
    fp16=False,
    bf16=False,
)

In [None]:
def generate_prompt(user_query, sep="\n\n### "):
    """Format one training example as an instruction-style prompt string.

    Args:
        user_query: Mapping with an "Instruction" entry (str) and, optionally,
            an "Output" entry holding the expected answer.
        sep: Unused; kept so existing callers that pass it keep working.

    Returns:
        "<s> [INST]" + system message + newline + instruction + "[/INST]"
        + answer + "</s>". The answer is the empty string when "Output" is
        missing or is not a string (for example ``None`` for an empty CSV cell).

    Raises:
        KeyError: If "Instruction" is missing.
        TypeError: If "Instruction" is not a string.
    """
    sys_msg = "The conversation between a Human and a helpful AI BT."

    # Only a missing or non-string answer falls back to ""; anything else
    # (including KeyboardInterrupt) propagates instead of being swallowed by a
    # bare `except`.
    try:
        answer = user_query["Output"]
    except KeyError:
        answer = ""
    if not isinstance(answer, str):
        answer = ""

    # NOTE(review): the literal "<s>"/"</s>" markers are added here in addition
    # to any BOS/EOS tokens the tokenizer inserts itself — confirm the
    # duplication is intended.
    return "<s> [INST]" + sys_msg + "\n" + user_query["Instruction"] + "[/INST]" + answer + "</s>"

In [None]:
def tokenize(prompt):
    """Tokenize a formatted prompt for training.

    The tokenizer's EOS token string is appended to `prompt`, and the result is
    encoded with the notebook-level `tokenizer`, truncated to at most 2048
    tokens and padded up to exactly that length.

    Args:
        prompt: The full prompt text (str).

    Returns:
        Whatever the tokenizer call returns for the encoded text.
    """
    text_with_eos = prompt + tokenizer.eos_token
    encoded = tokenizer(
        text_with_eos,
        padding="max_length",
        max_length=2048,
        truncation=True,
    )
    return encoded

In [None]:
def _to_model_inputs(example):
    """Format one dataset row as a prompt and tokenize it."""
    return tokenize(generate_prompt(example))


# Tokenize every example; the raw text columns are dropped once encoded.
train_data = training_set.map(_to_model_inputs, remove_columns=["Instruction", "Output"])

# Causal-LM collation (mlm=False).
causal_lm_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

trainer = Trainer(
    model=model,
    args=training_arguments,
    train_dataset=train_data,
    data_collator=causal_lm_collator,
)

# silence the warnings. Please re-enable for inference!
model.config.use_cache = False

In [None]:
# Run the fine-tuning loop configured above.
trainer.train()
# Save the trained model and the tokenizer to separate local directories.
# NOTE(review): the model was wrapped with `get_peft_model`, so this presumably
# writes only the LoRA adapter weights rather than a full merged model — confirm.
trainer.model.save_pretrained("/nvme/models/mixtral-8x7b-instruct-finetuned-model/")
tokenizer.save_pretrained('/nvme/models/mixtral-8x7b-instruct-finetuned-tokenizer/')