In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import torch.optim as optim
from collections import Counter
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.utils import shuffle
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch.manual_seed(1337)

eval_iters = 200
n_embd = 384
n_head = 6
n_layer = 6
dropout = 0.2
batch_size = 64 # how many independent sequences will we process in parallel?
block_size = 50 # what is the maximum context length for predictions?
max_iters = 5000
eval_interval = 500
learning_rate = 3e-4

In [2]:
# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

torch.manual_seed(1337)

# here are all the unique characters that occur in this text
chars = sorted(list(set(text)))
vocab_size = len(chars)
# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

# Train and test splits
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]


# data loading
def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

@torch.no_grad()
def estimate_loss(model):
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits = model(X)[0]
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = Y.view(B*T)
            loss = F.cross_entropy(logits, targets)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [3]:
class LongTermMemory(nn.Module):
    def __init__(self, memory_size=10, input_size=50, feature_size=65):
        super().__init__()
        self.memory = nn.Parameter(torch.zeros(memory_size, input_size, feature_size))
        self.memory_size = memory_size
        self.feature_size = feature_size
        self.W_query = nn.Linear(feature_size, input_size)
        self.W_key = nn.Linear(feature_size, input_size)
        self.W_value = nn.Linear(feature_size, input_size)

    def forward(self, x):
        query = self.W_query(x)
        keys = self.W_key(self.memory)
        values = self.W_value(self.memory)

        attention_scores = torch.softmax(torch.matmul(query, keys.transpose(-2,-1)), dim=-1)
        retrieved_memory = torch.matmul(attention_scores, values)
        return retrieved_memory, attention_scores
    
    def update_memory(self, x, surprise, decay=0.9):
        self.memory.data *= decay

        with torch.no_grad():
            insertion_index = surprise.argmax().item() % self.memory.size(0)
            self.memory.data[insertion_index] = x


class TitanModel(nn.Module):
    def __init__(self, memory_size=50, input_size=50, feature_size=384):
        super().__init__()
        self.n_vocab = 65
        self.conl = 200
        self.embn = feature_size
        self.word_emb = nn.Embedding(self.n_vocab, self.embn)
        self.pos_enc = nn.Embedding(self.conl, self.embn)
        self.memory_module = LongTermMemory(memory_size,input_size, feature_size)
        self.output_layer = nn.Linear(self.embn, self.n_vocab)

    def forward(self, x, requires_embedding=True):
        if requires_embedding:
            B, T = x.shape

            tok = self.word_emb(x)
            pos = self.pos_enc(torch.arange(T, device=device))
            x = tok + pos
        retrieved_memory, attention_scores = self.memory_module(x)
        output = self.output_layer(retrieved_memory)
        return output, attention_scores
    
    def test_update(self, x, target, loss_fn):
        B, T = x.shape

        tok = self.word_emb(x)
        pos = self.pos_enc(torch.arange(T, device=device))
        x = tok + pos
        x = x.detach().clone().requires_grad_(True)

        output, attn = self(x, False)
        B, T, C = output.shape
        output = output.view(B*T, C)
        target = target.view(B*T)
        loss = loss_fn(output, target)
        loss.backward()
        surprise = torch.abs(x.grad)
        
        for i in range(surprise.shape[0]):
            self.memory_module.update_memory(x.detach()[i,:,:], surprise[i,:,:])

        return loss.item()
    
    def generate(self, x, max_length):
        for _ in range(max_length):
            x_cond = x[:, -self.conl:]
            logits, attn = self(x_cond)
            logits = logits[:,-1,:]
            probs = F.softmax(logits, dim=-1)
            x_next = torch.multinomial(probs, num_samples=1)
            x = torch.cat((x, x_next), dim=1)

        return x
    
model = TitanModel(memory_size=40, feature_size=384).to(device)

In [4]:
print(vocab_size)

65


In [5]:
print(get_batch('train'))

(tensor([[43, 62, 55,  ..., 58, 46, 39],
        [50,  1, 54,  ...,  1, 46, 47],
        [ 1, 49, 43,  ...,  1, 27, 18],
        ...,
        [ 1, 58, 46,  ..., 63,  1, 61],
        [47, 49, 43,  ..., 39, 41, 46],
        [21, 33, 31,  ..., 52, 47, 52]], device='cuda:0'), tensor([[62, 55, 59,  ..., 46, 39, 58],
        [ 1, 54, 59,  ..., 46, 47, 57],
        [49, 43, 43,  ..., 27, 18,  1],
        ...,
        [58, 46, 39,  ...,  1, 61, 47],
        [49, 43,  1,  ..., 41, 46,  1],
        [33, 31, 10,  ..., 47, 52, 45]], device='cuda:0'))


In [13]:
print(model(get_batch('train')[0]))

RuntimeError: The size of tensor a (64) must match the size of tensor b (40) at non-singleton dimension 0

In [8]:

m = model.to(device)
# print the number of parameters in the model
#print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')
print(next(m.parameters()).is_cuda)
# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

for iter in range(max_iters):

    # every once in a while evaluate the loss on train and val sets
    if iter % eval_interval == 0 or iter == max_iters - 1:
        losses = estimate_loss(model)
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, attn = model(xb)
    B, T, C = logits.shape
    logits = logits.view(B*T, C)
    targets = yb.view(B*T)
    #loss = F.cross_entropy(logits, targets)
    loss = model.test_update(xb,yb,F.cross_entropy)
    #optimizer.zero_grad(set_to_none=True)
    #loss.backward()
    #optimizer.step()

# generate from the model
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(m.generate(context, max_length=500)[0].tolist()))
#open('more.txt', 'w').write(decode(m.generate(context, max_new_tokens=10000)[0].tolist()))

True


RuntimeError: The size of tensor a (64) must match the size of tensor b (40) at non-singleton dimension 0

In [None]:
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(m.generate(context, max_length=500)[0].tolist()))


ly nbtho n rhaloi ante fehnersa ymnnhsa hhwAtgadd  dnvesi 
 enslag kypLof eh !rsae
Tthm et,tyot  Eetbo aoeaetalvyl :yteevto  n  wEaltUi, benat a,udio enielwa
EkahehHrndulpHdbnaeyu hlarahphobiIis.hoSye hyfywnys
 HeieH e o:Aehaasseio efuni,fK dla srhtsoe SdOe:e D e  s RreIte beeiAsHuetisbiG
crf;d  oe w
, ,ytqpstti i
sdeaf!ehgunc !
l hs h:ydesdnwsTrtebiorry s
ch
twt

n  dh
o eioUmtamtt euaiosd   hiohede tmEyfceahhi lchs    Ef  y sOtb  o
 iI ie  muOaapciklm  r ,
 oEess,naSonmeoah d
wltn;nsue UNtmy  
