In [2]:
import os, sys

# Make the project's `src/` directory importable. '..' is resolved against the
# kernel's current working directory, so this assumes the notebook runs from a
# directory one level below the project root -- TODO confirm.
# Insert at the front (rather than append) so the generic module names imported
# below (`data`, `utils`) resolve to the project's code and not to an installed
# package of the same name; the guard keeps re-running this cell idempotent.
module_path = os.path.abspath(os.path.join('..', 'src'))
if sys.path[:1] != [module_path]:
    sys.path.insert(0, module_path)
from data.create_qa_dataloaders import create_qa_dataloaders
from utils import set_seed
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch


[2023-07-17 13:57:57,570] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)


In [11]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [12]:
# Log in to your W&B account on console before running this
# https://docs.wandb.ai/quickstart#2-log-in-to-wb
# wandb.login() is the cell's last expression, so its return value is echoed
# as the cell output.

import wandb
wandb.login()



True

## Setting up the model and tokeniser, and loading the data

In [2]:
# Fix the random seed via the project's helper (presumably seeds torch/numpy/
# random -- see utils.set_seed) so runs are repeatable.
seed = 62
set_seed(seed)

# The same pretrained checkpoint supplies both the tokenizer and the "judge"
# language model that is fine-tuned below; the model is moved to `device`.
model_name = 'gpt2'
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
judge = GPT2LMHeadModel.from_pretrained(model_name).to(device)

NameError: name 'set_seed' is not defined

In [1]:
# Data-loading configuration.
train_prop = 0.8  # presumably the fraction of rows assigned to the training split
batch_size = 16
shuffle = True
# NOTE(review): this path is relative to the kernel's working directory, while
# `src` is located via '../src'; confirm it resolves from where the notebook runs.
data_path = 'data/processed/TruthfulQA_labeled.csv'

# Register a dedicated padding token, then resize the model's token embeddings
# to the enlarged vocabulary so the new id has an embedding row.
tokenizer.add_special_tokens({"pad_token": "<PAD>"})
judge.resize_token_embeddings(len(tokenizer))

train_loader, test_loader = create_qa_dataloaders(
    data_path, tokenizer, train_prop, batch_size, shuffle
)



NameError: name 'tokenizer' is not defined

## Training the model

In [7]:
from torch.optim import AdamW


# Optimisation hyperparameters.
lr = 5e-5
optimizer = AdamW(judge.parameters(), lr=lr)

# Look the padding id up on the tokenizer instead of assuming the pad token was
# the last one appended to the vocabulary. Without a registered pad token the old
# `len(tokenizer) - 1` silently pointed at a real token, so fail loudly instead.
pad_idx = tokenizer.pad_token_id
if pad_idx is None:
    raise ValueError("tokenizer has no pad token; add one before building the loss")
# Targets equal to pad_idx contribute nothing to the loss.
ce_loss = torch.nn.CrossEntropyLoss(ignore_index=pad_idx)
num_epochs = 20

log_accuracy_every_batch = 50  # How many steps we compute the accuracy over
save_every_epoch = 5  # Write a checkpoint every this many epochs

In [8]:
import time


# Hyperparameters recorded in the W&B run config (a single assignment; the
# previous `config = config = {...}` bound the same dict twice).
config = {
    "lr": lr,
    "batch size": batch_size,
    "epochs": num_epochs,
}

wandb.init(
    project="Finetuning-TruthfulQA",
    name=None,  # no explicit run name; W&B generates one
    config=config,
)

In [14]:
import torch.nn as nn
from tqdm import tqdm


def _last_token_logits(logits, attention_mask):
    """Select, for every sequence, the logits at its last non-padding position.

    ``logits[:, -1]`` is the final real token only when sequences are
    left-padded or unpadded; with right padding it is a padding position.
    Locating the last position whose attention-mask entry is non-zero is
    correct for either padding side.

    Args:
        logits: model logits indexed as ``[batch, position, vocab]``.
        attention_mask: ``[batch, position]`` mask, non-zero for real tokens.

    Returns:
        Tensor of shape ``[batch, vocab]``.
    """
    positions = torch.arange(attention_mask.size(1), device=attention_mask.device)
    # Padding positions are multiplied by 0, so the max is the last real index.
    last_idx = (positions * (attention_mask != 0)).max(dim=1).values
    batch_idx = torch.arange(logits.size(0), device=logits.device)
    return logits[batch_idx, last_idx.to(logits.device)]


def _count_yes_no_correct(last_logits, label, yes_id, no_id):
    """Count examples labelled Yes/No whose highest-scoring token equals the label."""
    # softmax is monotonic, so the argmax of the raw logits picks the same token.
    top_token = last_logits.argmax(dim=-1)
    is_yes_no = (label == yes_id) | (label == no_id)
    return int(((top_token == label) & is_yes_no).sum().item())


def _evaluate(model, loader, loss_fn, yes_id, no_id):
    """Return ``(mean batch loss, accuracy over all examples)`` on ``loader``.

    Runs in eval mode without gradients and always switches the model back to
    train mode afterwards. Returns NaNs for an empty loader instead of
    dividing by zero.
    """
    model.eval()
    total_loss = 0.0
    n_batches = 0
    n_correct = 0
    n_seen = 0
    try:
        with torch.no_grad():
            for eval_ids, eval_mask, eval_label in loader:
                eval_ids = eval_ids.to(device)
                eval_mask = eval_mask.to(device)
                eval_label = eval_label.to(device)

                eval_out = model(input_ids=eval_ids, attention_mask=eval_mask)
                eval_logits = _last_token_logits(eval_out.logits, eval_mask)

                total_loss += loss_fn(eval_logits, eval_label).item()
                n_correct += _count_yes_no_correct(eval_logits, eval_label, yes_id, no_id)
                # Count real examples so a short final batch is weighted correctly.
                n_seen += eval_label.size(0)
                n_batches += 1
    finally:
        model.train()
    if n_seen == 0:
        return float("nan"), float("nan")
    return total_loss / n_batches, n_correct / n_seen


# Token ids for the two answers. NOTE(review): GPT-2's BPE encodes "Yes" and
# " Yes" as different tokens; this assumes the dataset labels use the
# no-leading-space form -- confirm against create_qa_dataloaders.
yes_idx = tokenizer("Yes").input_ids[0]
no_idx = tokenizer("No").input_ids[0]

eval_every_batch = 50  # Training steps between evaluations on the test set

global_step = 0
window_correct = 0  # Correct train predictions since train/acc was last logged
window_seen = 0  # Train examples since train/acc was last logged

judge.train()
for epoch in range(num_epochs):
    for input_ids, attention_mask, label in tqdm(train_loader, desc=f"epoch {epoch}"):
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        label = label.to(device)

        # Clear gradients before the forward pass so an interrupted earlier run
        # (e.g. a KeyboardInterrupt after backward()) cannot leak stale gradients.
        optimizer.zero_grad()
        output = judge(input_ids=input_ids, attention_mask=attention_mask)
        last_logits = _last_token_logits(output.logits, attention_mask)

        loss = ce_loss(last_logits, label)
        loss.backward()
        optimizer.step()
        global_step += 1

        # Gather every metric for this step into a single wandb.log call so they
        # share one W&B step; log plain Python floats rather than tensors.
        metrics = {"train/loss": loss.item()}

        window_correct += _count_yes_no_correct(last_logits, label, yes_idx, no_idx)
        window_seen += label.size(0)
        if global_step % log_accuracy_every_batch == 0:
            metrics["train/acc"] = window_correct / window_seen
            window_correct = 0
            window_seen = 0

        # Test loop
        if global_step % eval_every_batch == 0:
            test_loss, test_acc = _evaluate(judge, test_loader, ce_loss, yes_idx, no_idx)
            metrics["test/loss"] = test_loss
            metrics["test/acc"] = test_acc

        wandb.log(metrics)

    # Checkpoint after every `save_every_epoch` completed epochs and after the
    # final epoch; the file name records the number of completed epochs.
    completed_epochs = epoch + 1
    if completed_epochs % save_every_epoch == 0 or completed_epochs == num_epochs:
        judge_save_path = "gpt2-judge-finetuned-epoch{:}.pt".format(completed_epochs)
        torch.save(judge.state_dict(), judge_save_path)
        wandb.save(judge_save_path)

259it [39:02,  9.04s/it]


KeyboardInterrupt: 