In [1]:
import os

DATA_DIR = "data"  # This may need to be changed on different machines

# Make sure the working directory is one from which DATA_DIR is reachable.
# Per the original note, the notebook sits in src/nb, so the project root
# (where the data path starts) is two levels up.
if not os.path.exists(DATA_DIR):
    project_root = os.path.join("..", "..")
    # Validate BEFORE changing directory: if the data directory is missing we
    # fail with the kernel's working directory untouched, so re-running this
    # cell does not keep climbing two more levels each time.
    if not os.path.exists(os.path.join(project_root, DATA_DIR)):
        # Explicit exception instead of `assert`, which is stripped under -O.
        raise FileNotFoundError(f"ERROR: DATA_DIR={DATA_DIR} not found")
    os.chdir(project_root)

from tqdm.notebook import tqdm

import torch
# get Dataset class
from torch.utils.data import DataLoader
from torch import nn
import pandas as pd
import json
import numpy as np

from transformers import GPT2LMHeadModel, AdamW, GPT2Tokenizer

from src.lib.decoder_dataset import DecoderDataset
from src.lib.decoder import Decoder
from src.lib.util import to_device

In [2]:
# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Directory holding both the source CSVs and the saved dataset states.
balanced_dir = os.path.join(DATA_DIR, "decoded_cds", "balanced")

# Split name -> DecoderDataset, filled in the order dev, test, train.
datasets = {}
for dataset_name in ("dev", "test", "train"):
    save_path = os.path.join(balanced_dir, f"{dataset_name}_dataset.pth")

    if os.path.exists(save_path):
        # A saved state already exists for this split: rebuild the dataset from it.
        dataset = DecoderDataset.from_state_dict(save_path)
    else:
        # First run for this split: build the dataset from its CSV (first column
        # is the index) and save its state so later runs can take the branch above.
        df = pd.read_csv(os.path.join(balanced_dir, f"{dataset_name}.csv"), index_col=0)
        dataset = DecoderDataset(df)
        dataset.save_state_dict(save_path)

    datasets[dataset_name] = dataset

In [3]:
# Training hyperparameters. `optimizer_name` is only a label that gets written
# to the checkpoint config; it must be kept in sync by hand with the optimizer
# constructed below.
optimizer_name = "AdamW"
learning_rate = 5e-5
batch_size = 20

# One DataLoader per split, all with the same batch size and worker count.
# NOTE(review): no `shuffle` is passed, so the train split is iterated in the
# dataset's own order every epoch -- confirm that this is intended.
data_loaders = {
    split_name: DataLoader(split_dataset, batch_size=batch_size, num_workers=10)
    for split_name, split_dataset in datasets.items()
}

# Model is created and moved to the selected device before the optimizer is
# built from its parameters.
decoder = Decoder().to(device)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(decoder.parameters(), lr=learning_rate)

Some weights of the model checkpoint at models/gpt2_large were not used when initializing GPT2LMHeadModel: ['transformer.extra_embedding_project.bias', 'transformer.extra_embedding_project.weight']
- This IS expected if you are initializing GPT2LMHeadModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GPT2LMHeadModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [4]:
epochs = 10

loss_history = []      # one entry per training batch, across all epochs
val_loss_history = []  # one entry per epoch (mean of the per-batch dev losses)


def _label_logits(model, inputs, label_idx):
    """Return, for each sample, the logits row at that sample's label position.

    Row ``i`` of the result is ``model(inputs).logits[i, label_idx[i]]``.
    This is the same selection as ``logits[:, label_idx].diagonal().t()``
    but indexes the pairs directly instead of materialising a
    (batch, batch, vocab) intermediate and then taking its diagonal.

    Assumes ``label_idx`` is a 1-D integer tensor with one position per sample
    and that ``.logits`` is laid out as (batch, position, vocab) -- TODO confirm
    against DecoderDataset / Decoder.
    """
    logits = model(inputs).logits
    # Built on label_idx's device so both index tensors live together.
    sample_idx = torch.arange(label_idx.shape[0], device=label_idx.device)
    return logits[sample_idx, label_idx]


for epoch in range(epochs):
    # ---- training pass ----
    decoder.train()
    pbar = tqdm(data_loaders["train"])
    for batch in pbar:
        batch = to_device(batch, device)

        x, y = batch

        label, label_idx = y

        logits = _label_logits(decoder, x, label_idx)

        # calculate loss and backprop
        loss = loss_fn(logits, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_value = loss.item()
        loss_history.append(loss_value)

        # Show a running mean over the last 30 batches to smooth the readout.
        pbar.set_description(f"Epoch {epoch} Loss: {np.mean(loss_history[-30:]):.4f}")

    # ---- validation pass on the dev split (no gradients, eval mode) ----
    with torch.no_grad():
        decoder.eval()
        total_val_loss = 0
        val_loss_count = 0
        for batch in data_loaders["dev"]:
            batch = to_device(batch, device)

            x, y = batch

            label, label_idx = y

            logits = _label_logits(decoder, x, label_idx)

            loss = loss_fn(logits, label)

            loss_value = loss.item()
            total_val_loss += loss_value
            val_loss_count += 1

        val_loss_history.append(total_val_loss / val_loss_count)

    # ---- save model checkpoint ----
    # Mean training loss over this epoch's batches (the last len(train loader) entries).
    avg_loss = np.mean(loss_history[-len(data_loaders["train"]):])
    model_name = f"decoder_{epoch}_{avg_loss:.4f}"
    save_path = os.path.join("decoder_checkpoints", "balanced", model_name)
    # makedirs (not mkdir): the parent "decoder_checkpoints/balanced" may not
    # exist yet, and failing here would throw away a whole epoch of training.
    os.makedirs(save_path, exist_ok=True)
    torch.save(decoder.state_dict(), os.path.join(save_path, "model.pth"))
    # save loss histories (cumulative up to and including this epoch)
    histories = {
        "loss_history": loss_history,
        "val_loss_history": val_loss_history
    }
    with open(os.path.join(save_path, "history.json"), "w") as f:
        json.dump(histories, f)
    # save config
    config = {
        "epochs": epochs,
        "batch_size": batch_size,
        "learning_rate": learning_rate,
        "optimizer": optimizer_name,
    }

    with open(os.path.join(save_path, "config.json"), "w") as f:
        json.dump(config, f)

  0%|          | 0/13669 [00:00<?, ?it/s]