In [4]:
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import numpy as np

from tqdm import tqdm

from callformer.transformer import ModelDimensions, CallFormer
from callformer.tokenizer import Tokenizer

import pickle
from copy import copy
from datetime import date

DATA_PATH = "drive/MyDrive/search_ai/full_samples.pkl"
MODEL_PATH = "drive/MyDrive/search_ai/callformer_model.chkpt"


with open(DATA_PATH, "rb") as f:
    full_samples = pickle.load(f)

tokenizer = Tokenizer()

token_samples = []

tokens = [{"call": "<|searchnotes|>",
           "args": []},
           {"call": "<|summarize|>"}]

for sample in full_samples:
    search_start_date = ""
    if sample[2][0] != -1:
        search_start_date = f'"{date(year=sample[2][0], month=sample[2][1], day=sample[2][2]).strftime("%Y-%m-%d")}"'
    call_string = (
        f'{tokens[0]["call"]}'
        f'({search_start_date})'
        f'{tokens[1]["call"]}'
        )
    toks = tokenizer.encode(call_string)
    token_samples.append((
                         sample[0], 
                         sample[1],
                         sample[2], 
                         torch.from_numpy(np.array(sample[3])).unsqueeze(0).float(),
                         toks))
    


class CallFormerDataset(Dataset):
    def __init__(self, samples, model_dims: ModelDimensions):
        self.samples = samples
        self.n_ctx = model_dims.n_ctx
        
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        seq = self.samples[idx][-1]
        full_seq = torch.hstack( (
                        torch.Tensor([tokenizer.sot]), 
                        seq,
                        torch.Tensor([tokenizer.eot]),))
        input = F.pad(full_seq[:-1], (0, self.n_ctx - full_seq.shape[-1]+1), value=tokenizer.pad)
        target = F.pad(full_seq[1:], (0, self.n_ctx - full_seq.shape[-1]+1), value=tokenizer.pad)
        
        embedding = self.samples[idx][-2]

        assert embedding.ndim in (2, 3)
        if embedding.ndim == 2:
            embedding = embedding.unsqueeze(0)

        return embedding, input.to(torch.long), target.to(torch.long)



In [2]:
STATE_SIZE = token_samples[0][-2].shape[-1]

model_dims = ModelDimensions(
                n_vocab=tokenizer.vocab_size,
                n_ctx=100,
                n_state=STATE_SIZE,
                n_head=8,
                n_layer=2)

model = CallFormer(model_dims)

In [6]:
ds = CallFormerDataset(token_samples, model_dims)
dloader = DataLoader(ds, batch_size=2, shuffle=True)

loss_fn = torch.nn.CrossEntropyLoss()

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

epochs = 10

min_loss = np.inf

for n in range(epochs):
    for embedding, input, target in tqdm(dloader):
        logits = model.decoder(input, embedding)
        result = logits[:, 1:, :]
        
        optimizer.zero_grad()
        loss = loss_fn(result.mT, target[...,:-1])
        loss.backward()
        optimizer.step()

        if loss.item() < min_loss:
            min_loss = loss.item()
            save_model(model, optimizer, "callformer.pth")

  0%|          | 0/3250 [00:00<?, ?it/s]

245.65151977539062


  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  0%|          | 1/3250 [00:03<3:29:34,  3.87s/it]

27.367963790893555


  0%|          | 2/3250 [00:10<4:55:04,  5.45s/it]

202.55030822753906


  0%|          | 3/3250 [00:12<3:37:41,  4.02s/it]

40.33922576904297


  0%|          | 4/3250 [00:15<3:01:34,  3.36s/it]

43.945831298828125


  0%|          | 5/3250 [00:17<2:42:01,  3.00s/it]

35.72275161743164


  0%|          | 6/3250 [00:19<2:26:52,  2.72s/it]

28.86953353881836


  0%|          | 7/3250 [00:22<2:21:40,  2.62s/it]

20.169784545898438


  0%|          | 8/3250 [00:24<2:16:27,  2.53s/it]

17.445350646972656


  0%|          | 9/3250 [00:26<2:17:54,  2.55s/it]

17.238555908203125


  0%|          | 10/3250 [00:29<2:13:32,  2.47s/it]

9.66306209564209


  0%|          | 11/3250 [00:31<2:16:43,  2.53s/it]

9.126280784606934


  0%|          | 12/3250 [00:34<2:23:55,  2.67s/it]

7.996767044067383


  0%|          | 12/3250 [00:38<2:52:42,  3.20s/it]


KeyboardInterrupt: 