In [2]:
import os 
work_dir = os.path.join(os.getcwd())
import logging
logging.getLogger().setLevel(logging.INFO)
import warnings
warnings.filterwarnings("ignore")

import time
import pickle
from tqdm import tqdm

# ML AND SCI LIBRARIES
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from transformers import AutoModelForCausalLM, AutoTokenizer

# XENT code
from xentlang import X

device = torch.device("cuda:1")
models_path = os.path.join(work_dir, "models")
data_path = os.path.join(work_dir, "data")


# MODEL LOADING METHODS
def load_model_and_tokenizer(path: str):
    model = AutoModelForCausalLM.from_pretrained(path).to(device)
    tokenizer = AutoTokenizer.from_pretrained(path, clean_up_tokenization_spaces=True)
    return model, tokenizer

def load_model(path: str):
    model = AutoModelForCausalLM.from_pretrained(path).to(device)
    return model

# DATA LOADING METHOD
def load_dataset(name: str):
    with open(os.path.join(data_path, f"{name}.pkl"), "rb") as data:
        return pickle.load(data)
    
class TextDataset(Dataset):
    def __init__(self, dataset: list[str], tokenizer, max_length=1024):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.dataset = [self.tokenize(text) for text in tqdm(dataset)]

    def tokenize(self, text): 
        return self.tokenizer(
            text, 
            return_tensors="pt", 
            padding="max_length",
            truncation=True,
            max_length=self.max_length
        ).to(device)
    
    def tokenize_single(self, text):
        return self.tokenizer(
            text, 
            return_tensors="pt",
            padding=True
        ).to(device)

    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, index) -> str:
        return self.dataset[index]


In [3]:
torch.cuda.empty_cache()

In [5]:
# load the model
path = "models/gpt2/M0"
M0, tokenizer = load_model_and_tokenizer(path)


In [11]:
input = tokenizer.encode("Hey wassup bro", return_tensors="pt").to(device)
M0.generate(input, max_new_tokens=90)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


tensor([[10814,   373, 37330,  1379,    11,   314,  1101,   407,  8066,  6486,
            11,   314,  1101,   407,  8066,  6486,    11,   314,  1101,   407,
          8066,  6486,    11,   314,  1101,   407,  8066,  6486,    11,   314,
          1101,   407,  8066,  6486,    11,   314,  1101,   407,  8066,  6486,
            11,   314,  1101,   407,  8066,  6486,    11,   314,  1101,   407,
          8066,  6486,    11,   314,  1101,   407,  8066,  6486,    11,   314,
          1101,   407,  8066,  6486,    11,   314,  1101,   407,  8066,  6486,
            11,   314,  1101,   407,  8066,  6486,    11,   314,  1101,   407,
          8066,  6486,    11,   314,  1101,   407,  8066,  6486,    11,   314,
          1101,   407,  8066,  6486]], device='cuda:1')

In [4]:
def find_xent_def(tokens):
    """ Returns the index at which the xent function starts, needed for starting the loss computation """
    xdefseq = tokenizer.encode(X.xdef, return_tensors="pt").to(device)
    seq_len = xdefseq.shape[1]
    windows = tokens.input_ids.unfold(dimension=2, size=seq_len, step=1)
    matches = (windows==xdefseq).all(dim=3)
    indices = matches.nonzero().squeeze(0)
    return indices
    

In [5]:
for batch, data in enumerate(train_loader):
    indices = find_xent_def(data)
    data = data.input_ids.squeeze(0).squeeze(0)
    string = tokenizer.decode(data[indices[2]:], skip_special_tokens=True)
    if batch >= 4: break

In [6]:
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR

from torch.nn import CrossEntropyLoss

In [7]:
LEARNING_RATE = 1e-5

crossentropy = CrossEntropyLoss()
optimizer = AdamW(M0.parameters(), lr = LEARNING_RATE, betas=(0.9, 0.98), eps=1e-9, weight_decay=0.012)
lr_lambda = lambda epoch: 0.965 ** epoch
scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda, verbose=True)

In [8]:
# TODO: make it such that loss is computed only after the xent function is called. 

def train(model):
    model.train()
    total_loss = 0
    start_time = time.time()
    for batch, tokens in enumerate(train_loader):
        optimizer.zero_grad()
        xidx = find_xent_def(tokens)[2]
        tokens = tokens.input_ids.view(1, -1) # [1, T]
        logits = model(tokens).logits  # [1, T, L]
        loss = crossentropy(logits[0, xidx:][:-1], tokens[0, xidx:][1:])
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()
        total_loss += loss.item() 
        log_interval = 10 
        if batch % log_interval == 0 and batch > 0: 
            cur_loss = total_loss / log_interval 
            elapsed = time.time() - start_time
            print(f"| batch: {batch} | loss: {cur_loss:.5f} | has taken: {elapsed:.2f} seconds")
            total_loss = 0
            start_time = time.time()


def evaluate(test_model, test_loader):
    test_model.eval()
    total_loss = 0
    nbatches = min(50, test_size)
    with torch.no_grad():
        for batch, tokens in enumerate(test_loader):
            tokens = tokens.input_ids.view(1, -1)
            logits = test_model(tokens).logits
            loss = crossentropy(logits[0, :][:-1], tokens[0, :][1:])
            total_loss += loss 
            if batch > nbatches: 
                break
    return total_loss / nbatches


In [9]:
EPOCHS = 20
best_loss = float("inf")
best_model = None

for epoch in tqdm(range(EPOCHS), desc="epoch"):
    train(M0)
    print("Evaluating...", end=" ")
    val_loss = evaluate(M0, test_loader=test_loader)
    print(f"Validation loss: {val_loss:.5f}")

    if val_loss < best_loss:
        best_loss = val_loss
        best_model = M0
    
    scheduler.step()

best_model.save_pretrained("models/gpt2-xl-M1")

epoch:   0%|          | 0/20 [00:00<?, ?it/s]We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.


| batch: 10 | loss: 4.89050 | has taken: 15.96 seconds
| batch: 20 | loss: 1.55253 | has taken: 11.55 seconds
Evaluating... 

epoch:   5%|▌         | 1/20 [00:44<14:03, 44.37s/it]

Validation loss: 2.12859
| batch: 10 | loss: 1.25339 | has taken: 12.84 seconds
| batch: 20 | loss: 0.99296 | has taken: 11.65 seconds
Evaluating... 

epoch:  10%|█         | 2/20 [01:25<12:47, 42.66s/it]

Validation loss: 2.09857
| batch: 10 | loss: 1.00045 | has taken: 12.83 seconds
| batch: 20 | loss: 0.86755 | has taken: 11.82 seconds
Evaluating... 

epoch:  15%|█▌        | 3/20 [02:07<11:58, 42.25s/it]

Validation loss: 2.10775
| batch: 10 | loss: 0.81689 | has taken: 12.82 seconds
| batch: 20 | loss: 0.70978 | has taken: 11.67 seconds


epoch:  15%|█▌        | 3/20 [02:39<15:06, 53.30s/it]


KeyboardInterrupt: 

In [13]:
import torch
device = torch.device("cuda:1")
M1 = torch.load("models/gpt2-xl-M1", weights_only=False).to(device)
M1.eval()

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 1600)
    (wpe): Embedding(1024, 1600)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-47): 48 x GPT2Block(
        (ln_1): LayerNorm((1600,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2SdpaAttention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((1600,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((1600,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=1600, out_features=50257, bias=False)
)

In [14]:
from transformers import AutoModelForCausalLM, AutoTokenizer

# MODEL LOADING METHODS
def load_model_and_tokenizer(path: str):
    model = AutoModelForCausalLM.from_pretrained(path).to(device)
    tokenizer = AutoTokenizer.from_pretrained(path, clean_up_tokenization_spaces=True)
    return model, tokenizer
Mbase, tokenizer = load_model_and_tokenizer("models/gpt2-xl-M0")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [16]:
M1(tokenizer.encode("Ciao", return_tensors="pt").to(device))

CausalLMOutputWithCrossAttentions(loss=None, logits=tensor([[[ 1.9008, -1.6847, -2.2090,  ..., -5.7860, -4.8315, 15.4105],
         [ 7.9671,  5.9680,  0.0433,  ..., -6.3949, -3.3520,  5.1771]]],
       device='cuda:1', grad_fn=<UnsafeViewBackward0>), past_key_values=((tensor([[[[ 0.6723,  0.4973, -0.9757,  ...,  0.1950,  0.3992,  0.2074],
          [-0.4869, -0.0345,  0.3615,  ..., -0.3100, -0.1830,  0.8324]],

         [[ 0.7735, -0.1537, -0.4134,  ..., -0.9463, -0.9032,  0.2914],
          [-1.3950,  0.3026,  0.0725,  ...,  0.5198, -0.0067, -0.6163]],

         [[ 0.6085,  1.0068,  0.1315,  ...,  0.7609,  0.1015,  0.0327],
          [-0.2179,  0.0455,  0.0267,  ..., -0.0882, -0.3272,  0.1117]],

         ...,

         [[-0.1425,  0.7093, -0.3140,  ..., -0.2944, -0.3393,  0.0643],
          [ 0.3780, -0.2175,  0.0946,  ..., -0.5287,  0.0228, -0.4280]],

         [[-0.5845,  0.1274, -0.0693,  ...,  1.2399, -0.0476,  0.4525],
          [-0.3209, -0.3499,  0.6144,  ...,  0.0550, -0.137

: 