In [1]:
import torch
import numpy as np


from callformer.decoding import DecodingOptions
from callformer.transformer import ModelDimensions, CallFormer
from callformer.tokenizer import Tokenizer

In [2]:
import pickle
from copy import copy
from datetime import date

# NOTE(review): pickle.load can execute arbitrary code stored in the file; only
# load full_samples.pkl from a trusted source.
with open("full_samples.pkl", "rb") as samples_file:
    full_samples = pickle.load(samples_file)

tokenizer = Tokenizer()

# Special call tokens. Every target call string built below has the shape
# <|searchnotes|>(<optional quoted start date>)<|summarize|>.
tokens = [
    {"call": "<|searchnotes|>", "args": []},
    {"call": "<|summarize|>"},
]
search_call = tokens[0]["call"]
summarize_call = tokens[1]["call"]


def _search_arg(date_parts):
    """Return the double-quoted YYYY-MM-DD start date, or "" when date_parts[0] is -1."""
    if date_parts[0] == -1:
        return ""
    start = date(year=date_parts[0], month=date_parts[1], day=date_parts[2])
    return '"' + start.strftime("%Y-%m-%d") + '"'


# Each entry: (sample[0], sample[1], sample[2], sample[3] as a float tensor with
# an extra leading dim, encoded target call string).
token_samples = [
    (
        raw[0],
        raw[1],
        raw[2],
        torch.from_numpy(np.array(raw[3])).unsqueeze(0).float(),
        tokenizer.encode(f"{search_call}({_search_arg(raw[2])}){summarize_call}"),
    )
    for raw in full_samples
]

# Preview the first two samples with their decoded target calls.
[(today, prompt, tokenizer.decode(encoded)) for today, prompt, _, _, encoded in token_samples[:2]]

[('2021-01-03',
  'Today is Sunday, January 03, 2021. Give me a summary of my notes from the past two days. Focus on the ones that are related to quantum computing.',
  ['<|searchnotes|>("2021-01-01")<|summarize|>']),
 ('2033-03-27',
  'Today is Sunday, March 27, 2033. Summarize my thoughts on AI safety from the past three days. Organize the summary as a timeline.',
  ['<|searchnotes|>("2033-03-24")<|summarize|>'])]

In [5]:
# Hidden size = last dim of the per-sample float tensor (4th tuple entry) built above.
STATE_SIZE = token_samples[0][-2].shape[-1]
# Context length handed to ModelDimensions — presumably the max token sequence
# length; TODO confirm it covers the longest tokenized call string.
N_CTX = 10
N_HEAD = 8
N_LAYER = 2

model_dims = ModelDimensions(
    n_vocab=tokenizer.vocab_size,
    n_ctx=N_CTX,
    n_state=STATE_SIZE,
    n_head=N_HEAD,
    n_layer=N_LAYER,
)
model = CallFormer(model_dims)

In [6]:
# TODO(review): when last run, this call raised
# "TypeError: 'NoneType' object is not subscriptable"; the trace just before the
# error shows an MHA.forward() call receiving None as its second argument.
first_sample_tensor = token_samples[0][-2]
model.decode(first_sample_tensor, DecodingOptions())

(1) _main_loop() torch.Size([1, 1]) torch.Size([1, 1, 1536])
(2) _main_loop() torch.Size([1, 1]) torch.Size([1, 1, 1536])
PyTorchInference.logits(): torch.Size([1, 1]) torch.Size([1, 1, 1536])
(1) Decoder.forward(): torch.Size([1, 1, 1536]), torch.Size([1, 1])
(2) Decoder.forward(): self.token_embedding(x).shape=torch.Size([1, 1, 1536]), x.shape=torch.Size([1, 1])
(3) Decoder.forward(): torch.Size([1, 1, 1536]), torch.Size([1, 1, 1536])
MHA.forward(): torch.Size([1, 1, 1536]) None
MHA.forward(): torch.Size([1, 1, 1536]) torch.Size([1, 1, 1536]) torch.Size([1, 1, 1536])
(1) MHA.qkv_attention(): torch.Size([1, 1, 1536]) torch.Size([1, 1, 1536]) torch.Size([1, 1, 1536])
MHA.forward(): torch.Size([1, 1, 1536]) torch.Size([1, 1, 1536])
MHA.forward(): torch.Size([1, 1, 1536]) torch.Size([1, 1, 1536]) torch.Size([1, 1, 1536])
(1) MHA.qkv_attention(): torch.Size([1, 1, 1536]) torch.Size([1, 1, 1536]) torch.Size([1, 1, 1536])
MHA.forward(): torch.Size([1, 1, 1536]) None
MHA.forward(): torch.Siz

TypeError: 'NoneType' object is not subscriptable

In [None]:
def _total_bytes(tensors):
    """Sum of element count times bytes-per-element over an iterable of tensors."""
    return sum(t.nelement() * t.element_size() for t in tensors)


# Memory footprint of the model: trainable parameters plus registered buffers.
param_size = _total_bytes(model.parameters())
buffer_size = _total_bytes(model.buffers())

size_all_mb = (param_size + buffer_size) / 1024**2
print('model size: {:.3f}MB'.format(size_all_mb))

model size: 288.416MB


In [None]:
def sinusoids(length, channels, max_timescale=10000):
    """Return a sinusoidal positional-embedding table.

    Args:
        length: number of positions (rows of the table).
        channels: embedding width; must be even. The first ``channels // 2``
            columns are sines and the remaining columns are cosines, over
            geometrically spaced inverse timescales running from 1 down to
            ``1 / max_timescale``.
        max_timescale: largest timescale used.

    Returns:
        A ``(length, channels)`` floating-point tensor.

    Raises:
        AssertionError: if ``channels`` is odd.
    """
    assert channels % 2 == 0, f"channels must be even, got {channels}"
    half = channels // 2
    # With a single frequency (channels == 2), ``half - 1`` is 0. That would give
    # an infinite increment, and -inf * 0 turns the whole table into NaN. Clamp the
    # denominator to 1 so the lone frequency uses timescale 1.
    log_timescale_increment = np.log(max_timescale) / max(half - 1, 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(half))
    scaled_time = torch.arange(length)[:, None] * inv_timescales[None, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)


class Test(torch.nn.Module):
    """Toy module that adds a fixed 1 x 150 sinusoidal table to its input."""

    def __init__(self):
        super().__init__()
        table = sinusoids(1, 150)
        # Stored as a buffer (not an nn.Parameter), so it is not part of parameters().
        self.register_buffer("positional_embedding", table)

    def forward(self, x):
        """Return ``x`` plus the stored table, broadcast against ``x``'s shape."""
        pos = self.positional_embedding
        return x + pos
    
# Sanity check: adding the (1, 150) table to a (1, 1, 150) input keeps the input's shape.
test = Test()
demo_out = test(torch.randn(1, 1, 150))
demo_out.shape

torch.Size([1, 1, 150])