In [1]:
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_cosine_schedule_with_warmup, GPT2ForSequenceClassification
from torch.utils.data import Dataset, DataLoader
import json
import os

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [3]:
MODEL_NAME = 'gpt2'

In [4]:
# Set hyperparameters
learning_rate = 1e-5
epochs = 8
dropout = 0.2
batch_size = 2

In [5]:
# Load fine-tuning model
tokenizer = GPT2Tokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token
model = GPT2LMHeadModel.from_pretrained(MODEL_NAME).to(device)

In [6]:
# Define the optimizer and scheduler
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=-1, last_epoch=-1)

In [11]:
# data preprocessing (reduce 38 pieces of data in training set)
path_original_train = 'data/gen_dataset_mhy_train.json'
path_filter_original_train = 'data/filter_gen_dataset_mhy_train.json' 

with open(path_original_train, 'r') as f:
    raw_data = json.load(f)

filter_data = []
for item in raw_data:
    if len(item['chat'].rsplit('Assistant: ', 1)) == 2:
        filter_data.append(item)

with open(path_filter_original_train, 'w') as f:
    json.dump(filter_data, f)

In [13]:
# data preprocessing (reduce 2 pieces of data in validation set)
path_original_val = 'data/gen_dataset_mhy_val.json'
path_filter_original_val = 'data/filter_gen_dataset_mhy_val.json' 

with open(path_original_val, 'r') as f:
    raw_data = json.load(f)

filter_data = []
for item in raw_data:
    if len(item['chat'].rsplit('Assistant: ', 1)) == 2:
        filter_data.append(item)

with open(path_filter_original_val, 'w') as f:
    json.dump(filter_data, f)

In [8]:
class ChatDataset(Dataset):
    def __init__(self, filename, confidence=True):
        with open(filename, 'r') as f:
            raw_data = json.load(f)
            
        if confidence:
            self.data = [item for item in raw_data if item['confidence'] == 6]
        else:
            self.data = raw_data
        
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        chat = self.data[idx]['chat']
        instruction, demonstration = chat.rsplit('Assistant: ', 1)
            
        return instruction, demonstration

In [9]:
# Load the dataset
train_dataset = ChatDataset('data/filter_gen_dataset_mhy_train.json')
train_loader = DataLoader(train_dataset, batch_size=batch_size)
eval_dataset = ChatDataset('data/filter_gen_dataset_mhy_val.json', confidence=False)
eval_loader = DataLoader(eval_dataset, batch_size=batch_size)

In [10]:
# model validation
def evaluate(eval_loader, model, device, batch_size, instruction_max_length, demonstration_max_length):
    model.eval()
    eval_loss_sum = 0
    num_eval_batches = 0
    
    with torch.no_grad():
        for step, batch in enumerate(eval_loader):
            
            # Tokenize the instruction and demonstration
            instruction = tokenizer(batch[0], return_tensors='pt', max_length=instruction_max_length, padding="max_length", truncation=True).input_ids.to(device)
            demonstration = tokenizer(batch[1], return_tensors='pt', max_length=demonstration_max_length, padding="max_length", truncation=True).input_ids.to(device)
            
            # Concatenate the instruction and demonstration
            input_ids = torch.cat([instruction, demonstration], dim=-1)
            target_ids = input_ids.clone()
            
            # Forward pass through the model
            outputs = model(input_ids, labels=target_ids)
            logits = outputs.logits
            
            # Only consider the loss for the demonstration part
            logits_demo = logits[:, -demonstration_max_length-1:-1, :]         # [batch size, demonstration_max_length, number of classes]
            target_ids_demo = target_ids[:, -demonstration_max_length:]        # [batch size, demonstration_max_length]
            batch_loss = torch.nn.functional.cross_entropy(logits_demo.permute(0, 2, 1), target_ids_demo)

            eval_loss_sum += batch_loss.item()
            num_eval_batches = step+1
            
    return eval_loss_sum / num_eval_batches

In [11]:
# model training
instruction_max_length = 512
demonstration_max_length = 512
best_eval_loss = 100
save_path = "./models/sft_model/sft_model_gpt2_with_original_data"

model.train()
model.zero_grad()


for epoch in range(epochs):
    print(f'epoch {epoch+1}')
    epoch_loss = 0
    num_batches = 0

    for step, batch in enumerate(train_loader):
        optimizer.zero_grad()

        # Tokenize the instruction and demonstration
        instruction = tokenizer(batch[0], return_tensors='pt', max_length=instruction_max_length, padding="max_length", truncation=True).input_ids.to(device)
        demonstration = tokenizer(batch[1], return_tensors='pt', max_length=demonstration_max_length, padding="max_length", truncation=True).input_ids.to(device)
        
        # Concatenate the instruction and demonstration
        input_ids = torch.cat([instruction, demonstration], dim=-1)
        target_ids = input_ids.clone()

        # Forward pass through the model
        outputs = model(input_ids, labels=target_ids)
        logits = outputs.logits

        # Only consider the loss for the demonstration part
        logits_demo = logits[:, -demonstration_max_length-1:-1, :]         # [batch size, demonstration_max_length, number of classes]
        target_ids_demo = target_ids[:, -demonstration_max_length:]        # [batch size, demonstration_max_length]
        batch_loss = torch.nn.functional.cross_entropy(logits_demo.permute(0, 2, 1), target_ids_demo)
        
        # Backward and optimize
        batch_loss.backward()
        optimizer.step()
        scheduler.step()
        
        if step % 200 == 0:
            print(f'At step {step}, the loss = {batch_loss.item()}') 
            eval_loss = evaluate(eval_loader, model, device, 4, 512, 512)  
            print(f'Validation Loss: {eval_loss}')

        epoch_loss += batch_loss.item()
        num_batches = step+1
        
    train_loss = epoch_loss / num_batches
    eval_loss = evaluate(eval_loader, model, device, 4, 512, 512)    
    print(f'Epoch: {epoch} | Training Loss: {train_loss} | Validation Loss: {eval_loss}')
    
    
    if eval_loss < best_eval_loss:
            best_eval_loss = eval_loss
            model.save_pretrained(save_path)
            tokenizer.save_pretrained(save_path)
            print("Model Saved!")    

epoch 1
At step 0, the loss = 7.067453384399414
Validation Loss: 5.305339722959404
At step 200, the loss = 0.5005541443824768
Validation Loss: 0.9278261139606818
At step 400, the loss = 0.36695098876953125
Validation Loss: 0.7949621954152727
At step 600, the loss = 0.8579390048980713
Validation Loss: 0.7757964397597517
Epoch: 0 | Training Loss: 0.7477622972522184 | Validation Loss: 0.7637681983474992
Model Saved!
epoch 2
At step 0, the loss = 0.6939657926559448
Validation Loss: 0.763514370095526
At step 200, the loss = 0.4450378715991974
Validation Loss: 0.7518466388058459
At step 400, the loss = 0.31622013449668884
Validation Loss: 0.7465434743935226
At step 600, the loss = 0.7996694445610046
Validation Loss: 0.7399270613351439
Epoch: 1 | Training Loss: 0.600207140055112 | Validation Loss: 0.7411697934962745
Model Saved!
epoch 3
At step 0, the loss = 0.6504468321800232
Validation Loss: 0.7413206140837099
At step 200, the loss = 0.42666545510292053
Validation Loss: 0.7382949777775341
A

In [15]:
torch.cuda.empty_cache()