In [9]:
import os
import sys
import logging
import warnings

# Project root: ~/synth/highlighting (the models/ and data/ dirs hang off it).
home_dir = os.path.expanduser("~")
work_dir = os.path.join(home_dir, "synth", "highlighting")

# INFO-level root logger; silence every Python `warnings` warning for cleaner
# notebook output (library *logger* messages, e.g. transformers', still show).
logging.getLogger().setLevel(logging.INFO)
warnings.filterwarnings("ignore")

import time
import pickle
import json
from tqdm import tqdm

# ML AND SCI LIBRARIES
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer

import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from torch.nn import CrossEntropyLoss

# XENT code
from xentlang import X
from utils import Tee

# Everything is pinned to CUDA device index 3; checkpoints and pickled
# datasets live in the models/ and data/ subdirectories of work_dir.
device = torch.device("cuda", 3)
models_path, data_path = (os.path.join(work_dir, sub) for sub in ("models", "data"))

# Utility parameters
# NOTE(review): presumably an optional cap on how many examples to use
# (None = no cap) — confirm against the code that reads it.
cut_dataset = None

# Hyperparameters — LR, betas and clipping follow Karpathy's nanoGPT defaults.
LEARNING_RATE = 6e-4  # peak learning rate, taken from nanoGPT
EPOCHS = 15
# TODO add all the available hyperparameters
# Presumably the fraction of examples used for training (the remaining 40% is
# held out for testing) — a fraction, not a train:test ratio.
data_split = 0.6

# AdamW momentum coefficients. beta1 was 0.1, which leaves AdamW with almost no
# first-moment momentum and contradicts the nanoGPT settings copied here
# (beta1=0.9, beta2=0.95); 0.9 is also PyTorch's AdamW default for beta1.
beta1 = 0.9
beta2 = 0.95
grad_clip = 1.0  # gradient-clipping threshold (nanoGPT clips the grad norm at 1.0)


def load_model_and_tokenizer(path: str):
    """Load a causal LM (placed on ``device``) and its tokenizer from *path*.

    Composes :func:`load_model` and :func:`load_tokenizer` instead of
    duplicating their bodies, so model placement and the pad-token fallback
    are each defined in exactly one place.

    Returns:
        ``(model, tokenizer)``; the tokenizer's ``pad_token`` falls back to its
        ``eos_token`` when the checkpoint defines none.
    """
    # Same load order as before: model first, then tokenizer.
    model = load_model(path)
    tokenizer = load_tokenizer(path)
    return model, tokenizer

def load_model(path: str):
    """Load a Hugging Face causal LM from *path* and move it onto ``device``."""
    return AutoModelForCausalLM.from_pretrained(path).to(device)

def load_torch_model(path: str, map_location=None):
    """Load a whole pickled model object saved with ``torch.save``.

    Args:
        path: file written by ``torch.save(model, path)``.
        map_location: forwarded to ``torch.load``; use it to remap storages
            (e.g. ``"cpu"`` or ``device``). The default ``None`` keeps the
            original behaviour of restoring tensors where they were saved.
            Unlike :func:`load_model`, the result is not moved to ``device``.

    Returns:
        The unpickled object.

    SECURITY: ``weights_only=False`` runs full unpickling, which can execute
    arbitrary code — only load trusted, locally produced checkpoints.
    """
    return torch.load(path, map_location=map_location, weights_only=False)

def load_tokenizer(path: str):
    """Load the tokenizer stored at *path*, guaranteeing it has a pad token.

    Tokenizers that ship without a pad token (GPT-2's, for instance) reuse
    their EOS token, so padded encodings can be produced.
    """
    tok = AutoTokenizer.from_pretrained(path, clean_up_tokenization_spaces=True)
    missing_pad = tok.pad_token is None
    if missing_pad:
        tok.pad_token = tok.eos_token
    return tok

# DATA LOADING METHOD
def load_dataset(name: str):
    """Unpickle and return the dataset stored at ``<data_path>/<name>.pkl``.

    SECURITY: ``pickle.load`` can execute arbitrary code; only use this on
    trusted files produced by this project.
    """
    pkl_path = os.path.join(data_path, f"{name}.pkl")
    with open(pkl_path, "rb") as fh:
        payload = pickle.load(fh)
    return payload
    
class TextDataset(Dataset):
    """Map-style dataset of texts tokenized eagerly into fixed-length encodings.

    All texts are tokenized once in ``__init__`` and each encoding is moved to
    the module-level ``device`` immediately, so the whole corpus is held on
    that device and ``__getitem__`` is a plain list lookup.
    """

    def __init__(self, dataset: list[str], tokenizer, max_length=1024):
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Tokenize everything up front (tqdm reports progress).
        self.dataset = [self.tokenize(text) for text in tqdm(dataset)]

    def tokenize(self, text):
        """Encode *text* padded and truncated to exactly ``self.max_length`` tokens.

        With ``return_tensors="pt"`` on a single string each tensor carries a
        leading axis of size 1, i.e. ``(1, max_length)``. NOTE(review): after
        default DataLoader collation a batch is presumably
        ``(batch, 1, max_length)`` — the layout ``find_xent_def`` unfolds along
        dim 2 — so confirm before removing that axis.
        """
        return self.tokenizer(
            text,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
        ).to(device)

    def tokenize_single(self, text):
        """Encode *text* with dynamic padding and no truncation.

        Unlike :meth:`tokenize`, the result is not limited to ``max_length``.
        """
        return self.tokenizer(
            text,
            return_tensors="pt",
            padding=True,
        ).to(device)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        """Return the pre-tokenized encoding stored at *index*.

        The item is whatever :meth:`tokenize` produced — a tokenizer encoding
        with ``input_ids`` / ``attention_mask`` tensors — not the raw text.
        """
        return self.dataset[index]

def find_xent_def(tokens):
    """ Locate where the xent function definition (``X.xdef``) starts in a batch of encodings.

    The loss computation is meant to start from this position.

    Args:
        tokens: an encoding exposing ``input_ids``. The ``unfold`` along dim 2
            and ``all`` over dim 3 below are written for 3-D ``input_ids``,
            presumably ``(batch, 1, seq_len)`` as produced by collating
            ``TextDataset`` items — TODO confirm with callers. A plain 2-D
            ``(1, seq_len)`` encoding (e.g. from ``tokenize``) fails in
            ``unfold``.

    Returns:
        Not a single index but the ``nonzero()`` coordinates of every match:
        one row ``[batch_idx, 0, pos]`` per occurrence, ``pos`` being the start
        offset within the sequence. NOTE(review): ``squeeze(0)`` only drops the
        row axis when there is exactly one match overall, so the result is
        shape ``(3,)`` for one match but ``(k, 3)`` otherwise (``(0, 3)`` when
        nothing matches) — callers must handle both.
    """
    # Token ids of the definition marker, shape (1, seq_len); re-encoded on each
    # call via the module-level `tokenizer`.
    # NOTE(review): X.xdef is tokenized in isolation; a BPE tokenizer may encode
    # the same text differently mid-sequence (e.g. with a leading space), which
    # would make the match miss — verify.
    xdefseq = tokenizer.encode(X.xdef, return_tensors="pt").to(device)
    seq_len = xdefseq.shape[1]
    # Every length-seq_len sliding window along the sequence axis (dim 2).
    windows = tokens.input_ids.unfold(dimension=2, size=seq_len, step=1)
    # A window matches when all of its tokens equal the pattern (broadcast compare).
    matches = (windows==xdefseq).all(dim=3)
    indices = matches.nonzero().squeeze(0)
    return indices


# Load the "gpt2-xl-M0" checkpoint together with its tokenizer; `tokenizer` is
# also read as a global by find_xent_def.
path = os.path.join(models_path, "gpt2-xl-M0")
M0, tokenizer = load_model_and_tokenizer(path)


def tokenize(text):
    """Encode *text* as PyTorch tensors on ``device``.

    Pads but does NOT truncate: text longer than the model's context is
    returned in full (the tokenizer only emits a warning).
    """
    encoded = tokenizer(text, return_tensors="pt", padding=True)
    return encoded.to(device)


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [10]:
from highlight import WikiArticle

In [11]:
# Build a WikiArticle, handing it the padded, non-truncating `tokenize` helper.
a = WikiArticle(tokenize)

In [15]:
# Tokenize the article's full text. No truncation is applied, so the result can
# exceed the model's 1024-token limit (2341 tokens here, per the warning below).
tokenize(a.article["text"])

Token indices sequence length is longer than the specified maximum sequence length for this model (2341 > 1024). Running this sequence through the model will result in indexing errors


{'input_ids': tensor([[21926, 26345, 40926,  ...,   262, 20572, 25219]], device='cuda:3'), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1]], device='cuda:3')}