In [9]:
# OS AND IO IMPORTS
import os
work_dir = os.getcwd()
import logging
logging.getLogger().setLevel(logging.CRITICAL)
import warnings
warnings.filterwarnings("ignore")
import pickle
import csv 
from tqdm import tqdm 

# ML AND SCI LIBRARIES
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import AdamW, get_linear_schedule_with_warmup

In [10]:
# Load the base model (M_0) together with its tokenizer.
model_name = "gpt2"

# Run on the GPU whenever CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

model = AutoModelForCausalLM.from_pretrained(model_name)
model = model.to(device)

tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    clean_up_tokenization_spaces=True,
)

# Reuse the end-of-sequence token as the padding token, both on the tokenizer
# and in the model configuration.
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = model.config.eos_token_id

In [11]:
class PizzaSynthDataset(Dataset):
    """Map-style dataset over the samples stored in a pickle file.

    The pickle must contain an iterable of items; only the first element of
    every item is kept as a sample.
    """

    def __init__(self, dataset, data_path=os.path.join(work_dir, "sdata")):
        """Read the samples from ``data_path/dataset``.

        Args:
            dataset: File name of the pickle inside ``data_path``.
            data_path: Directory that contains the pickle. The default is the
                ``sdata`` folder under ``work_dir`` (evaluated once, when the
                class is defined).
        """
        super().__init__()

        self.dataset_path = os.path.join(data_path, dataset)
        self.end_of_text_token = "<|endoftext|>"
        self.pizza_list = []

        # NOTE: unpickling can execute arbitrary code -- only load trusted files.
        with open(self.dataset_path, "rb") as handle:
            records = pickle.load(handle)

        # Keep just the first field of each stored item.
        self.pizza_list = [list(record)[0] for record in records]

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.pizza_list)

    def __getitem__(self, idx):
        """Return the sample stored at position ``idx``."""
        return self.pizza_list[idx]
                

In [12]:
# Build the dataset from the pickled synthetic samples, then wrap it in a
# shuffling loader that yields a single sample per iteration.
synth_dataset_file = "pizza_selfgpt2_synthdataset.pkl"
dataset = PizzaSynthDataset(synth_dataset_file)
pizza_loader = DataLoader(dataset=dataset, shuffle=True, batch_size=1)

In [13]:
# Training hyper-parameters.
BATCH_SIZE = 8        # packed sequences accumulated before each optimizer step
EPOCHS = 2            # full passes over the dataset
LEARNING_RATE = 3e-5  # learning rate handed to the optimizer
WARMUP_STEPS = 300    # scheduler steps spent in linear warm-up

# Maximum number of tokens in one packed training sequence. The packed
# training loop reads this name, but no cell defined it, so the notebook
# failed with a NameError on a fresh kernel.
# NOTE(review): the originally intended value is not visible in this notebook;
# 1024 is GPT-2's context length, i.e. the hard upper bound. Lower it if
# memory is tight.
MAX_SEQ_LEN = 1024

In [14]:
# Put the model in training mode and create the optimizer.
# NOTE(review): `transformers.AdamW` is deprecated upstream in favour of
# `torch.optim.AdamW`; the two have different default eps / weight_decay, so
# the optimizer is intentionally left unchanged here.
model.train()
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)

# The linear schedule decays the learning rate from its peak (reached after
# warm-up) down to zero at `num_training_steps`. Passing -1 made the multiplier
# clamp to 0 as soon as warm-up finished, which silently stopped all learning.
# One optimizer step per sample per epoch is the upper bound for every training
# loop in this notebook, so the rate can never hit zero prematurely.
total_training_steps = EPOCHS * len(pizza_loader)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=WARMUP_STEPS,
    num_training_steps=total_training_steps,
)

# Checkpoints are written here; create the folder up front because
# `torch.save` does not create missing directories.
models_folder = "trained_models"
os.makedirs(models_folder, exist_ok=True)

In [15]:
import plotly.graph_objects as go

# Live loss curve: a figure widget holding one (initially empty) scatter trace
# that the training loop fills in as it goes.
fig = go.FigureWidget(data=[])
fig.add_trace(go.Scatter())
fig.update_xaxes(title_text="Steps")
# Last expression of the cell, so the widget itself is rendered as the output.
fig.update_yaxes(title_text="Loss")

FigureWidget({
    'data': [{'type': 'scatter', 'uid': '371ff6bb-5849-49d8-83dc-f93c502529a4'}],
    'layout': {'template': '...', 'xaxis': {'title': {'text': 'Steps'}}, 'yaxis': {'title': {'text': 'Loss'}}}
})

In [16]:
# Fine-tune with one optimizer step per sample, streaming the loss to the
# live plot and saving a checkpoint after every epoch.

# Steps between plot refreshes: every refresh re-sends the whole series to the
# widget, so refreshing on each step makes the traffic grow quadratically.
PLOT_EVERY = 10

# `torch.save` does not create missing directories.
os.makedirs(models_folder, exist_ok=True)

loss_series = []
for epoch in tqdm(range(EPOCHS), desc="Epoch"):
    for pizza in pizza_loader:
        # The loader yields batches of exactly one text. Cap the token count
        # at the tokenizer's maximum so an over-long sample cannot overflow
        # the model's position embeddings.
        token_ids = tokenizer.encode(pizza[0])[: tokenizer.model_max_length]
        if len(token_ids) < 2:
            # A causal-LM loss needs at least one (token, next token) pair.
            continue
        pizza_tensor = torch.tensor(token_ids).unsqueeze(0).to(device)

        loss = model(pizza_tensor, labels=pizza_tensor).loss
        loss.backward()
        optimizer.step()
        scheduler.step()
        # The optimizer owns every model parameter, so one zero_grad suffices.
        optimizer.zero_grad()

        # Store plain floats instead of 0-d arrays.
        loss_series.append(loss.item())
        if len(loss_series) % PLOT_EVERY == 0:
            fig.data[0].y = loss_series

    # Make sure the plot is current at the end of each epoch, then checkpoint.
    fig.data[0].y = loss_series
    torch.save(model.state_dict(), os.path.join(models_folder, f"M_1_{epoch}.pt"))

Epoch:   0%|          | 0/2 [03:01<?, ?it/s]


KeyboardInterrupt: 

In [None]:
# Fine-tune on packed sequences with gradient accumulation.
#
# Consecutive samples are concatenated into one "pack" of at most MAX_SEQ_LEN
# tokens (MAX_SEQ_LEN is a notebook-level constant). Each full pack gets one
# forward/backward pass, and the optimizer steps once every BATCH_SIZE packs.

# Loop state, initialised here so the cell runs on a fresh kernel and is
# idempotent when re-run.
tmp_pizza_tensor = None  # pack currently being filled
proc_seq_count = 0       # packs back-propagated since the last optimizer step
batch_count = 0          # optimizer steps since the last loss report
sum_loss = 0.0           # loss accumulated since the last loss report

for epoch in tqdm(range(EPOCHS), desc="Epoch"):
    for pizza in pizza_loader:
        token_ids = tokenizer.encode(pizza[0])

        # Skip empty samples and samples that would not fit in a pack alone.
        if not token_ids or len(token_ids) > MAX_SEQ_LEN:
            continue
        pizza_tensor = torch.tensor(token_ids).unsqueeze(0).to(device)

        # First usable sample: it starts the first pack.
        if tmp_pizza_tensor is None:
            tmp_pizza_tensor = pizza_tensor
            continue

        if tmp_pizza_tensor.size(1) + pizza_tensor.size(1) <= MAX_SEQ_LEN:
            # Still room: append the sample (minus its first token) and keep
            # filling the same pack.
            tmp_pizza_tensor = torch.cat([tmp_pizza_tensor, pizza_tensor[:, 1:]], dim=1)
            continue

        # The pack is full: train on it and start the next pack with the
        # current sample.
        work_pizza_tensor = tmp_pizza_tensor
        tmp_pizza_tensor = pizza_tensor

        loss = model(work_pizza_tensor, labels=work_pizza_tensor).loss
        loss.backward()
        sum_loss += loss.item()

        # Was `proc_seq_count += proc_seq_count + 1`, which produced
        # 1, 3, 7, 15, ... and therefore never equalled BATCH_SIZE, so the
        # optimizer never stepped.
        proc_seq_count += 1
        if proc_seq_count == BATCH_SIZE:
            proc_seq_count = 0
            batch_count += 1
            optimizer.step()
            scheduler.step()
            # The optimizer owns every model parameter, so one zero_grad suffices.
            optimizer.zero_grad()

        # Report (and reset) the accumulated loss after every optimizer step.
        if batch_count == 1:
            print(f"sum loss: {sum_loss}")
            batch_count = 0
            sum_loss = 0.0

    

Epoch:   0%|          | 0/3 [03:12<?, ?it/s]


KeyboardInterrupt: 