In [1]:
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_cosine_schedule_with_warmup, GPT2ForSequenceClassification
from torch.utils.data import Dataset, DataLoader
import json
import os

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [3]:
MODEL_NAME = 'gpt2-medium'

In [4]:
# Set hyperparameters
learning_rate = 1e-5
epochs = 8
dropout = 0.2
batch_size = 1

In [5]:
# Load fine-tuning model
tokenizer = GPT2Tokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token
model = GPT2LMHeadModel.from_pretrained(MODEL_NAME).to(device)

In [6]:
# Define the optimizer and scheduler
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=-1, last_epoch=-1)

In [7]:
class ChatDataset(Dataset):
    def __init__(self, filename, augmentation=True):
        with open(filename, 'r') as f:
            raw_data = json.load(f)
            
        if augmentation:
            with open('data/gen_dataset_mhy_math.json', 'r') as f1:
                math_data = json.load(f1)
            with open('data/gen_dataset_mhy_openbookqa.json', 'r') as f2:
                code_data = json.load(f2) 
            raw_data.extend(math_data)
            raw_data.extend(code_data)
            self.data = raw_data
        else:
            self.data = raw_data
        
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        chat = self.data[idx]['chat']
        instruction, demonstration = chat.rsplit('Assistant: ', 1)
        instruction = instruction + 'Assistant: '
            
        return instruction, demonstration, chat

In [8]:
# Load the dataset
train_dataset = ChatDataset('data/gen_dataset_mhy_train.json', augmentation=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size)
eval_dataset = ChatDataset('data/gen_dataset_mhy_val.json', augmentation=False)
eval_loader = DataLoader(eval_dataset, batch_size=batch_size)

In [9]:
# model validation
def evaluate(eval_loader, model, device, chat_max_length):
    model.eval()
    eval_loss_sum = 0
    num_eval_batches = 0
    
    with torch.no_grad():
        for step, batch in enumerate(eval_loader):
            
            # Tokenize the chat
            chat = tokenizer(batch[2], return_tensors='pt', max_length=chat_max_length, padding="max_length", truncation=True).input_ids.to(device) 

            # Count the length of instruction and demonstration
            instruction = batch[0]
            demonstration = batch[1]
            instruction_tokens_num = [tokenizer(item, return_tensors='pt').input_ids.size(1) for item in instruction] 
            demonstration_tokens_num = [tokenizer(item, return_tensors='pt').input_ids.size(1) for item in demonstration]
            assert instruction_tokens_num[0] + demonstration_tokens_num[0] <= chat_max_length, "The length of chat is larger than chat_max_length"

            # Add the eos_token to the end of demonstration if the length of original demonstration is less than 512
            demonstration_tokens_num = [demonstration_tokens_num[0] if instruction_tokens_num[0] + demonstration_tokens_num[0] == chat_max_length else demonstration_tokens_num[0]+1] 

            # Get the input_ids and target_ids
            input_ids = chat
            target_ids = input_ids.clone() 

            # Forward pass through the model
            outputs = model(input_ids, labels=target_ids)
            logits = outputs.logits

            # Only consider the loss for the demonstration part
            logits_demo = logits[:, instruction_tokens_num[0]-1:instruction_tokens_num[0]+demonstration_tokens_num[0]-1, :]         # [batch size, demonstration_max_length, number of classes]
            target_ids_demo = target_ids[:, instruction_tokens_num[0]:instruction_tokens_num[0]+demonstration_tokens_num[0]]         # [batch size, demonstration_max_length]
            batch_loss = torch.nn.functional.cross_entropy(logits_demo.permute(0, 2, 1), target_ids_demo) 

            eval_loss_sum += batch_loss.item()
            num_eval_batches = step+1
            
    return eval_loss_sum / num_eval_batches

In [10]:
# model training
best_eval_loss = 100
save_path = "./models/sft_model/sft_model_gpt2_with_original_data"

model.train()
model.zero_grad()


for epoch in range(epochs):
    print(f'epoch {epoch+1}')
    epoch_loss = 0
    num_batches = 0

    for step, batch in enumerate(train_loader):
        optimizer.zero_grad()

        # Tokenize the chat
        chat = tokenizer(batch[2], return_tensors='pt', max_length=1024, padding="max_length", truncation=True).input_ids.to(device) 

        # Count the length of instruction and demonstration
        instruction = batch[0]
        demonstration = batch[1]
        instruction_tokens_num = [tokenizer(item, return_tensors='pt').input_ids.size(1) for item in instruction] 
        demonstration_tokens_num = [tokenizer(item, return_tensors='pt').input_ids.size(1) for item in demonstration]
        assert instruction_tokens_num[0] + demonstration_tokens_num[0] <= 1024, "The length of chat is larger than chat max length"

        # Add the eos_token to the end of demonstration if the length of original demonstration is less than 512
        demonstration_tokens_num = [demonstration_tokens_num[0] if instruction_tokens_num[0] + demonstration_tokens_num[0] == 1024 else demonstration_tokens_num[0]+1] 

        # Get the input_ids and target_ids
        input_ids = chat
        target_ids = input_ids.clone() 

        # Forward pass through the model
        outputs = model(input_ids, labels=target_ids)
        logits = outputs.logits

        # Only consider the loss for the demonstration part
        logits_demo = logits[:, instruction_tokens_num[0]-1:instruction_tokens_num[0]+demonstration_tokens_num[0]-1, :]         # [batch size, demonstration_max_length, number of classes]
        target_ids_demo = target_ids[:, instruction_tokens_num[0]:instruction_tokens_num[0]+demonstration_tokens_num[0]]         # [batch size, demonstration_max_length]
        batch_loss = torch.nn.functional.cross_entropy(logits_demo.permute(0, 2, 1), target_ids_demo)
        
        # Backward and optimize
        batch_loss.backward()
        optimizer.step()
        scheduler.step()
        
        if step % 400 == 0:
            print(f'At step {step}, the loss = {batch_loss.item()}') 
            eval_loss = evaluate(eval_loader, model, device, 1024)  
            print(f'Validation Loss: {eval_loss}')

        epoch_loss += batch_loss.item()
        num_batches = step+1
        
    train_loss = epoch_loss / num_batches
    eval_loss = evaluate(eval_loader, model, device, 1024)    
    print(f'Epoch: {epoch} | Training Loss: {train_loss} | Validation Loss: {eval_loss}')
    
    
    if eval_loss < best_eval_loss:
            best_eval_loss = eval_loss
            model.save_pretrained(save_path)
            tokenizer.save_pretrained(save_path)
            print("Model Saved!")    

epoch 1
At step 0, the loss = 3.2312095165252686
Validation Loss: 2.0414137158774883


KeyboardInterrupt: 

In [11]:
torch.cuda.empty_cache()

In [49]:
MODEL_PATH = "./models/sft_model/sft_model_gpt2_with_original_data"
tokenizer = GPT2Tokenizer.from_pretrained(MODEL_PATH)
model = GPT2LMHeadModel.from_pretrained(MODEL_PATH).to(device)


In [8]:
# Set the model to evaluation mode
model.eval()

# Define the input question
question = "What tools are typically used to study gene regulatory networks?, choices: [Computational modeling, Genetic knockouts, Patch-clamp, Behavioral studies], please find the suitable choices."
# Tokenize the input question
input_ids = tokenizer.encode(question, return_tensors='pt').to(device)

# Generate the attention mask
attention_mask = torch.ones_like(input_ids).to(device)

# Generate the answer
output = model.generate(input_ids=input_ids, attention_mask = attention_mask, max_length=100, num_beams=5, no_repeat_ngram_size=2, num_return_sequences=5, early_stopping=True)

# Decode and print the response
# response = tokenizer.decode(output[0], skip_special_tokens=True)
# print("Response:", response)

for i, beam in enumerate(output):
    print(f"{i}: {tokenizer.decode(beam, skip_special_tokens=True)}")
    print()

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


0: What tools are typically used to study gene regulatory networks?, choices: [Computational modeling, Genetic knockouts, Patch-clamp, Behavioral studies], please find the suitable choices.

1: What tools are typically used to study gene regulatory networks?, choices: [Computational modeling, Genetic knockouts, Patch-clamp, Behavioral studies], please find the suitable choices.

What is the difference between gene regulation and gene expression? [Regulation vs. Expression], choose the appropriate answer.

2: What tools are typically used to study gene regulatory networks?, choices: [Computational modeling, Genetic knockouts, Patch-clamp, Behavioral studies], please find the suitable choices.

What is the difference between gene regulation and gene expression? [Regulation vs. Expression], choose the correct answer.

3: What tools are typically used to study gene regulatory networks?, choices: [Computational modeling, Genetic knockouts, Patch-clamp, Behavioral studies], please find the s

In [23]:
# new model validation
def evaluate(eval_loader, model, device, chat_max_length):
    model.eval()
    eval_loss_sum = 0
    num_eval_batches = 0
    
    with torch.no_grad():
        for step, batch in enumerate(eval_loader):
            
            # Tokenize the chat
            chat = tokenizer(batch[2], return_tensors='pt', max_length=chat_max_length, padding="max_length", truncation=True).input_ids.to(device) 

            # Count the length of instruction and demonstration
            instruction_tokens_num = [tokenizer(item, return_tensors='pt').input_ids.size(1) for item in batch[0]][0]
            demonstration_tokens_num = [tokenizer(item, return_tensors='pt').input_ids.size(1) for item in batch[1]][0]
            
            # Get the input_ids and target_ids
            input_ids = chat
            target_ids = input_ids.clone() 

            # Forward pass through the model
            outputs = model(input_ids, labels=target_ids)
            logits = outputs.logits

            # Only consider the loss for the demonstration part
            logits_demo = logits[:, instruction_tokens_num-1:-1, :]         # [batch size, demonstration_max_length, number of classes]
            target_ids_demo = target_ids[:, instruction_tokens_num:]        # [batch size, demonstration_max_length]
            batch_loss = torch.nn.functional.cross_entropy(logits_demo.permute(0, 2, 1), target_ids_demo) 

            eval_loss_sum += batch_loss.item()
            num_eval_batches = step+1
            
    return eval_loss_sum / num_eval_batches

In [24]:
# new model training
best_eval_loss = 100
save_path = "./models/sft_model/sft_model_gpt2_medium_with_augmentation_data"

model.train()
model.zero_grad()


for epoch in range(epochs):
    print(f'epoch {epoch+1}')
    epoch_loss = 0
    num_batches = 0

    for step, batch in enumerate(train_loader):
        optimizer.zero_grad()

        # Tokenize the chat
        chat = tokenizer(batch[2], return_tensors='pt', max_length=1024, padding="max_length", truncation=True).input_ids.to(device) 

        # Count the length of instruction and demonstration
        instruction_tokens_num = [tokenizer(item, return_tensors='pt').input_ids.size(1) for item in batch[0]][0]
        demonstration_tokens_num = [tokenizer(item, return_tensors='pt').input_ids.size(1) for item in batch[1]][0]

        # Get the input_ids and target_ids
        input_ids = chat
        target_ids = input_ids.clone() 

        # Forward pass through the model
        outputs = model(input_ids, labels=target_ids)
        logits = outputs.logits

        # Only consider the loss for the demonstration part
        logits_demo = logits[:, instruction_tokens_num-1:-1, :]         # [batch size, demonstration_max_length, number of classes]
        target_ids_demo = target_ids[:, instruction_tokens_num:]        # [batch size, demonstration_max_length]
        batch_loss = torch.nn.functional.cross_entropy(logits_demo.permute(0, 2, 1), target_ids_demo)
        
        # Backward and optimize
        batch_loss.backward()
        optimizer.step()
        scheduler.step()
        
        if step % 400 == 0:
            print(f'At step {step}, the loss = {batch_loss.item()}') 
            eval_loss = evaluate(eval_loader, model, device, 1024)  
            print(f'Validation Loss: {eval_loss}')

        epoch_loss += batch_loss.item()
        num_batches = step+1
        
    train_loss = epoch_loss / num_batches
    eval_loss = evaluate(eval_loader, model, device, 1024)    
    print(f'Epoch: {epoch} | Training Loss: {train_loss} | Validation Loss: {eval_loss}')

    if eval_loss < best_eval_loss:
            best_eval_loss = eval_loss
            model.save_pretrained(save_path)
            tokenizer.save_pretrained(save_path)
            print("Model Saved!")    

epoch 1
At step 0, the loss = 4.567424774169922


KeyboardInterrupt: 