In [4]:
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import numpy as np

from tqdm import tqdm

from callformer.transformer import ModelDimensions, CallFormer
from callformer.tokenizer import Tokenizer

import pickle
from copy import copy
from datetime import date

# Load the raw samples from disk.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only run this against a file you produced or otherwise trust.
with open("full_samples.pkl", "rb") as f:
    full_samples = pickle.load(f)

tokenizer = Tokenizer()

# The two call tokens every target string is assembled from. The "args"
# entry is declared here but is not read anywhere in this cell.
tokens = [{"call": "<|searchnotes|>",
           "args": []},
           {"call": "<|summarize|>"}]

# One tuple per sample: the first three fields are copied through unchanged,
# followed by the embedding tensor and the encoded call string.
token_samples = []

for sample in full_samples:
    start = sample[2]
    # A first component of -1 marks "no start date": the call then gets
    # empty parentheses instead of a quoted YYYY-MM-DD argument.
    if start[0] != -1:
        formatted_day = date(year=start[0], month=start[1], day=start[2]).strftime("%Y-%m-%d")
        search_start_date = f'"{formatted_day}"'
    else:
        search_start_date = ""

    call_string = f'{tokens[0]["call"]}({search_start_date}){tokens[1]["call"]}'
    toks = tokenizer.encode(call_string)

    # Convert sample[3] to a float tensor with an extra leading axis of size 1.
    sample_embedding = torch.from_numpy(np.array(sample[3])).unsqueeze(0).float()

    token_samples.append((sample[0], sample[1], sample[2], sample_embedding, toks))
    


class CallFormerDataset(Dataset):
    """Map-style dataset yielding fixed-length next-token training pairs.

    Each sample is an indexable record whose last element holds the token ids
    (anything exposing a 1-D shape that ``torch.as_tensor`` accepts) and whose
    second-to-last element holds the embedding returned alongside them.

    Args:
        samples: Indexable collection of sample records.
        model_dims: Object exposing ``n_ctx``, the fixed length of the
            returned ``input``/``target`` tensors.
        pad_token: Id used to fill every position after the real tokens.
            When ``None`` (the default) the module-level ``tokenizer.pad`` is
            looked up at item-access time, as the original code did.
    """

    def __init__(self, samples, model_dims: "ModelDimensions", pad_token: "int | None" = None):
        self.samples = samples
        self.n_ctx = model_dims.n_ctx
        self.pad_token = pad_token

    def __len__(self):
        """Return the number of samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(embedding, input, target)`` for sample ``idx``.

        ``input`` holds all tokens but the last and ``target`` holds all
        tokens but the first, so ``target[i]`` is the token following
        ``input[i]``. Both are ``torch.long`` tensors of length ``n_ctx``
        padded with the pad id.

        Raises:
            AssertionError: If the sample has more than ``n_ctx`` tokens.
        """
        sample = self.samples[idx]
        toks = torch.as_tensor(sample[-1], dtype=torch.long)
        seq_len = toks.shape[0]
        assert seq_len <= self.n_ctx, (
            f"sample {idx} has {seq_len} tokens but n_ctx is {self.n_ctx}"
        )

        pad = tokenizer.pad if self.pad_token is None else self.pad_token

        # Start from fully padded buffers. The previous implementation used
        # torch.empty and filled [:seq_len-1] and [seq_len:], which left
        # position seq_len-1 uninitialised (an arbitrary id after the cast).
        input_ids = torch.full((self.n_ctx,), int(pad), dtype=torch.long)
        target_ids = torch.full((self.n_ctx,), int(pad), dtype=torch.long)

        # Number of (input, target) pairs; clamped so an empty sequence
        # yields all-padding tensors instead of a shape error.
        n_pairs = max(seq_len - 1, 0)
        input_ids[:n_pairs] = toks[:n_pairs]
        target_ids[:n_pairs] = toks[1:seq_len]

        embedding = sample[-2]
        return embedding, input_ids, target_ids


def prepare_training_inputs(
                            tokens: torch.Tensor,
                            embedding: torch.Tensor,
                            tokenizer: "Tokenizer",
                            ) -> "tuple[torch.Tensor, torch.Tensor]":
    """Shift a token batch right by one and put the start token in front.

    Args:
        tokens: Token ids of shape ``(batch, seq)``. A 1-D tensor is treated
            as a single sequence and gets a leading batch axis, mirroring
            what is done for ``embedding``.
        embedding: Tensor of rank 2 or 3; a rank-2 tensor gets a leading
            axis of size 1.
        tokenizer: Object whose ``vocab_lookup`` maps ``"<|start|>"`` to the
            start-token id.

    Returns:
        ``(tokens, embedding)``: ``tokens`` keeps its sequence length — the
        last position is dropped to make room for the start token — and
        ``embedding`` is rank 3.

    Raises:
        AssertionError: If ``embedding`` is not rank 2 or 3.
    """
    assert embedding.ndim in (2, 3)

    if embedding.ndim == 2:
        embedding = embedding.unsqueeze(0)
    if tokens.ndim == 1:
        # Without this, shape[0] below would be the sequence length rather
        # than the batch size.
        tokens = tokens.unsqueeze(0)

    n_batches: int = tokens.shape[0]
    start_token = tokenizer.vocab_lookup["<|start|>"]
    # new_full inherits dtype and device from `tokens`, so the concatenation
    # cannot fail on a device mismatch or silently change the dtype.
    start_column = tokens.new_full((n_batches, 1), int(start_token))
    tokens = torch.cat((start_column, tokens[..., :-1]), dim=1)
    return tokens, embedding

In [2]:
# Size of the last axis of the first sample's embedding; passed to the model
# as n_state below.
STATE_SIZE = token_samples[0][-2].size(-1)

# Model hyper-parameters: the vocabulary size comes from the tokenizer, the
# remaining values are fixed for this experiment.
model_dims = ModelDimensions(
    n_vocab=tokenizer.vocab_size,
    n_ctx=100,
    n_state=STATE_SIZE,
    n_head=8,
    n_layer=2,
)

model = CallFormer(model_dims)

In [5]:
# Single pass over the dataset in shuffled mini-batches of two samples.
ds = CallFormerDataset(token_samples, model_dims)
dloader = DataLoader(ds, batch_size=2, shuffle=True)

loss_fn = torch.nn.CrossEntropyLoss()

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Put the model in training mode explicitly so a re-run after any evaluation
# does not train with eval-mode layers.
# NOTE(review): assumes CallFormer is a torch.nn.Module — it already exposes
# .parameters() above; confirm.
model.train()

progress = tqdm(dloader)
# `input_tokens` rather than `input`: the old name shadowed the builtin for
# the rest of the notebook session.
for embedding, input_tokens, target in progress:
    input_tokens, embedding = prepare_training_inputs(input_tokens, embedding, tokenizer)

    optimizer.zero_grad()
    logits = model.decoder(input_tokens, embedding)
    # Drop the first output position.
    # NOTE(review): presumably that position belongs to the embedding prefix
    # rather than to a token — confirm against the decoder.
    result = logits[:, 1:, :]

    # CrossEntropyLoss needs (batch, classes, seq) logits against (batch, seq)
    # targets; fail with a readable message if the lengths disagree.
    if result.shape[1] != target.shape[1]:
        raise ValueError(
            f"decoder produced {result.shape[1]} scored positions but target has "
            f"{target.shape[1]}; check the shift applied to the logits"
        )

    loss = loss_fn(result.mT, target)
    loss.backward()
    optimizer.step()

    # Show the running loss on the progress bar instead of printing one line
    # per step.
    progress.set_postfix(loss=f"{loss.item():.4f}")

  0%|          | 0/3250 [00:00<?, ?it/s]

torch.Size([2, 100]) torch.Size([2, 1])





AssertionError: 