In [34]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from collections import Counter
import matplotlib.pyplot as plt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch.manual_seed(1337)

eval_iters = 200
n_embd = 384
n_head = 6
n_layer = 6
dropout = 0.2
batch_size = 64 # how many independent sequences will we process in parallel?
block_size = 50 # what is the maximum context length for predictions?
max_iters = 5000
eval_interval = 500
learning_rate = 3e-4

win_size = 25

In [35]:
def _chunk(x, w):
    '''convert x into overlapping chunks. Chunk size = 2w, overlap size = w'''
    x = x.view(x.size(0), x.size(1)// (w * 2), w * 2, x.size(2))
    chunk_size = list(x.size())
    chunk_size[1] = chunk_size[1] * 2 - 1
    chunk_stride = list(x.stride())
    chunk_stride[1] = chunk_stride[1] // 2

    return x.as_strided(size=chunk_size, stride=chunk_stride)

def _skew(x, dir, pad):
    '''Convert diagonals into columns'''
    x_pad = F.pad(x, dir, value=pad)
    x_pad = x_pad.view(*x_pad.size()[:-2], x_pad.size(-1), x_pad.size(-2))
    return x_pad

def _skewv(x, pad):
    B, C, M, L = x.size()
    x = F.pad(x, (0, M+1), value=pad)
    x = x.view(B, C, -1)
    x = x[:, :, :-M]
    x = x.view(B, C, M, M+L)
    x= x[:,:,:,:-1]
    return x


def sliding_chunk_matmul(q, k, w, pad):
    B,T,C = q.shape

    assert T % (w * 2) == 0
    assert q.shape == k.shape

    chunk_count = T // w- 1

    #q = q.transpose(1,2)
    #k = k.transpose(1,2)

    qchunk = _chunk(q, w)
    kchunk = _chunk(k, w)

    chunked_attn = torch.einsum('bcxd, bcyd->bcxy', (qchunk, kchunk))

    diag_chunk_attn = _skew(chunked_attn, dir=(0,0,0,1), pad=pad)
    diag_attn = torch.zeros((B,chunk_count + 1, w, w * 2 + 1), device=q.device)
    #diag_attn = diag_chunk_attn.new_empty((B,chunk_count + 1, w, w * 2 + 1))
    
    diag_attn[:, :-1, :, w:] = diag_chunk_attn[:, :, :w, :w+1]
    diag_attn[:, -1, :, w:] = diag_chunk_attn[:, -1, w:, :w+1]

    diag_attn[:, 1:, :, :w] = diag_chunk_attn[:, :, -(w+1):-1, w+1:]
    diag_attn[:, 0, 1:w, 1:w] = diag_chunk_attn[:, 0, :w-1, 1-w:]

    diag_attn = diag_attn.view(B, T, 2 * w +1)#.transpose(2, 1)

    return diag_attn

def sliding_chunk_matmul_v(attn, v, w):
    B,T,C = v.shape

    assert T % (w * 2) == 0
    assert attn.size()[:2] == v.size()[:2]
    assert attn.size(2) == 2 * w + 1

    chunk_count = T // w- 1

    #q = q.transpose(1,2)
    #k = k.transpose(1,2)

    chunk_prob = attn.reshape(B, T//w, w, 2*w+1)

    pad_v = F.pad(v, (0,0,w,w), value=-1)

    chunk_v_size = (B, chunk_count+1, 3*w, C)
    chunk_v_stride = pad_v.stride()
    chunk_v_stride = chunk_v_stride[0], w * chunk_v_stride[1], chunk_v_stride[1], chunk_v_stride[2]
    chunk_v = pad_v.as_strided(size=chunk_v_size, stride=chunk_v_stride)

    skew_prob = _skewv(chunk_prob, pad=0)
    context = torch.einsum('bcwd,bcdh->bcwh', (skew_prob, chunk_v))

    return context.view(B, T, C)

In [36]:
v = torch.rand(1,8,1,384)
k = torch.rand(1,8,1,384)
q = torch.rand(1,8,1,384)

v2 = torch.rand(1,50,10)
k2 = torch.rand(1,50,10)
q2 = torch.rand(1,50,10)

#sliding_chunks_matmul_qk(q,v, 2, 0)
attn = sliding_chunk_matmul(q2, k2, win_size, 0)
print(attn)
sliding_chunk_matmul_v(attn, v2,win_size)

tensor([[[0.0000, 0.0000, 0.0000,  ..., 1.6600, 1.8496, 1.8551],
         [0.0000, 2.2924, 1.3036,  ..., 2.7999, 2.0772, 2.4124],
         [0.0000, 1.2098, 2.4661,  ..., 2.8661, 2.9248, 2.9016],
         ...,
         [2.2204, 2.3682, 2.8499,  ..., 1.5737, 3.2960, 1.9400],
         [2.4293, 2.8384, 2.3147,  ..., 3.3519, 2.2143, 2.5128],
         [3.3137, 2.3607, 2.9765,  ..., 0.0000, 0.0000, 0.0000]]])


tensor([[[ 17.6377,  20.0018,  18.1015,  25.6739,  18.4093,  20.7069,  22.0540,
           19.2566,  20.7492,  23.0306],
         [-16.5276, -14.6500, -16.9186,  -6.3146, -14.4112, -12.2407,  -9.9455,
          -12.7977, -11.7430,  -9.5899],
         [-21.1664, -20.7011, -22.0192,  -9.5137, -18.2107, -15.8548, -13.7926,
          -17.1841, -14.9103, -12.9988],
         [-37.9667, -36.4896, -39.0957, -29.2821, -35.3898, -35.1796, -33.2945,
          -35.4808, -33.7225, -32.2936],
         [-10.8993,  -8.9786, -12.6159,   1.5115,  -9.0796,  -6.0268,  -5.3072,
           -7.8472,  -3.8100,  -2.9759],
         [-15.6266, -15.9906, -17.2088,  -3.8856, -14.3705, -10.5381,  -8.7546,
          -11.7928,  -8.9814,  -7.2145],
         [-26.3584, -26.0478, -28.4386, -16.8168, -23.6308, -21.9483, -20.9290,
          -21.5302, -20.1243, -20.6471],
         [  3.0895,   4.5688,   2.5393,  18.2942,   7.9842,  10.7343,  12.5366,
           11.6372,  12.9217,  13.3890],
         [-15.1271, -15.3044, -1

In [37]:
class attnSlidingWindow(nn.Module):
    '''
        Attention(Q, K, V ) = softmax( QK^T/√d_k)V 
    
    '''
    #Takes number of embedded, head_size, context length
    def __init__(self, embn, hdim, con_l, drop=0.0):

        super(attnSlidingWindow, self).__init__()
        #dim is (d_k) when sqrt'd it is meant to counter small gradients in large sets of queries and keys
        self.k = nn.Linear(embn, hdim, bias=False)
        self.q = nn.Linear(embn, hdim, bias=False)
        self.v = nn.Linear(embn, hdim, bias=False)
        self.d_k = np.sqrt(hdim)

        self.register_buffer('mask', torch.tril(torch.ones(con_l,con_l)))
        #Simple drop out 
        self.drop = nn.Dropout(drop)

    def forward(self, x, ret_att=False):
        #batch X length X dim
        B,T,C = x.shape
        k = self.k(x)
        q = self.q(x)

        n = sliding_chunk_matmul(q, k, win_size, 0) * k.shape[-1]**-0.5
        n = n.masked_fill(self.mask[:T,:n.shape[-1]]==0, float('-inf'))
        #Drop out referenced later in paper but not in original diagram
        att = self.drop(F.softmax(n, dim=-1))

        v = self.v(x)

        out = sliding_chunk_matmul_v(att, v, win_size)
        if ret_att:
            return out, att 
        return out

In [38]:
#Scaled dot product attention testing
#dim should be size of q and k


windowattn = attnSlidingWindow(384, 10, 100, drop=0.2)

v = torch.rand(1,50,384)



print(windowattn(v))

tensor([[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
         [-1.2500e+00, -1.2500e+00, -1.2500e+00, -1.2500e+00, -1.2500e+00,
          -1.2500e+00, -1.2500e+00, -1.2500e+00, -1.2500e+00, -1.2500e+00],
         [-1.2500e+00, -1.2500e+00, -1.2500e+00, -1.2500e+00, -1.2500e+00,
          -1.2500e+00, -1.2500e+00, -1.2500e+00, -1.2500e+00, -1.2500e+00],
         [-1.2500e+00, -1.2500e+00, -1.2500e+00, -1.2500e+00, -1.2500e+00,
          -1.2500e+00, -1.2500e+00, -1.2500e+00, -1.2500e+00, -1.2500e+00],
         [-9.9393e-01, -9.9393e-01, -9.9393e-01, -9.9393e-01, -9.9393e-01,
          -9.9393e-01, -9.9393e-01, -9.9393e-01, -9.9393e-01, -9.9393e-01],
         [-8.3276e-01, -8.3276e-01, -8.3276e-01, -8.3276e-01, -8.3276e-01,
          -8.3276e-01, -8.3276e-01, -8.3276e-01, -8.3276e-01, -8.3276e-01],
         [-7.2674e-01, -7.2674e-01, -7.2674e-01, -7.2674e-01, -7.2674e-01,
          -7.2674e-

In [39]:
class multiHeadedAttention(nn.Module):
    def __init__(self, n_heads, dims, embn, con_l, dropout=0.0):
        super(multiHeadedAttention, self).__init__()
        #d_k=d_v = dims/h

        self.n_heads = n_heads

        self.attn = nn.ModuleList([attnSlidingWindow(embn, dims, con_l) for _ in range(n_heads)])
        #Final linear layer after concat and attention
        self.fc = nn.Linear(n_heads*dims, embn)

        self.drop = nn.Dropout(dropout)
        

    def forward(self, x):
        out = torch.cat([h(x) for h in self.attn], dim=-1)
        out = self.drop(self.fc(out))
        return out

In [40]:
#heads, d_model, d_km d_v as per the paper
torch.manual_seed(1337)
multiHead = multiHeadedAttention(6, 50, 384, 512, dropout=0.2)

#batches, dims, dimensionalityxn_heads

v = torch.rand(1,50,384)


print(multiHead(v))

tensor([[[-2.1719, -0.8455, -0.2509,  ...,  0.0000, -0.0000,  0.4248],
         [-2.1719, -0.8455, -0.2509,  ...,  1.2850, -0.4226,  0.4248],
         [-2.1719, -0.0000, -0.2509,  ...,  1.2850, -0.4226,  0.0000],
         ...,
         [-0.0000, -0.2245, -0.0643,  ...,  0.4177, -0.2174,  0.2085],
         [-0.0000, -0.2608, -0.0000,  ...,  0.4212, -0.2153,  0.2435],
         [-0.0000, -0.0000, -0.0838,  ...,  0.5351, -0.2666,  0.2365]]],
       grad_fn=<MulBackward0>)


In [41]:
class positionFeedFoward(nn.Module):
    def __init__(self, inp, hid, drop=0.0):
        super(positionFeedFoward, self).__init__()
        self.w1 = nn.Linear(inp,4*hid)
        self.w2 = nn.Linear(4*hid,inp)
        self.drop = nn.Dropout(drop)

    def forward(self, x):

        x = self.w2(F.relu(self.w1(x)))
        x = self.drop(x)

        return x

In [42]:
class Decoder(nn.Module):
    '''Combinds MultiHeadedAttention and FeeForward, three layers'''
    def __init__(self, nheads, embn, con_l, drop=0.0):
        super(Decoder, self).__init__()
        head_size = embn // nheads
        self.slf_attn = multiHeadedAttention(nheads, head_size,embn, con_l, dropout=drop)
        
        self.ffn = positionFeedFoward(embn, embn, drop=drop)

        self.norm1 = nn.LayerNorm(embn)
        self.norm2 = nn.LayerNorm(embn)

    def forward(self, x):
        x = x + self.slf_attn(self.norm1(x))
        x = x + self.ffn(self.norm2(x))

        return x

In [43]:
#heads, d_model, d_km d_v as per the paper
enc = Decoder(8, 64, 512)
#batches, dims, dimensionalityxn_heads

v = torch.rand(1,50,64)


enc(v)

tensor([[[ 1.1804, -0.3295, -0.4308,  ...,  2.0671,  0.9802,  0.9299],
         [ 0.9376, -0.4714, -0.2999,  ...,  1.7397,  1.5895,  0.6495],
         [ 1.4974,  0.0910, -0.2238,  ...,  1.9109,  1.7711,  1.0007],
         ...,
         [ 1.0018, -0.1989, -0.3057,  ...,  1.8529,  1.4204,  0.9704],
         [ 1.0441,  0.1497,  0.5606,  ...,  0.6702,  1.6538,  0.8308],
         [ 0.9663, -0.2711, -0.1064,  ...,  1.7594,  1.4071,  0.4604]]],
       grad_fn=<AddBackward0>)

In [44]:
class languageModel(nn.Module):
    '''Decoder model'''
    def __init__(
            self, n_vocab, embn, n_layers, n_head, dropout=0.2 , con_l=200
    ):
        super(languageModel, self).__init__()
        self.con_l = con_l
        self.word_emb = nn.Embedding(n_vocab, embn)
        self.pos_enc = nn.Embedding(con_l, embn)
        self.stack = nn.Sequential(
            *[Decoder( n_head, embn, con_l, drop=dropout) for _ in range(n_layers)]
        )
       
        self.layer_norm = nn.LayerNorm(embn)
        self.fc = nn.Linear(embn, n_vocab)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, x, tar=None):
        #batch, time
        B, T = x.shape

        tok = self.word_emb(x)
        pos = self.pos_enc(torch.arange(T, device=device))
        x = tok + pos
        x = self.stack(x)
        x = self.layer_norm(x)
        logits = self.fc(x)

        if tar is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            tar = tar.view(B*T)
            loss = F.cross_entropy(logits, tar)

        return logits, loss
    
    def generate(self, x, max_length):
        #x is a BxT array of in current context
        fullout=x
        for _ in range(max_length):
            x_cond = x[:, -win_size*2:]
            logits, loss = self(x_cond)
            logits = logits[:,-1,:]
            probs = F.softmax(logits, dim=-1)
            x_next = torch.multinomial(probs, num_samples=1)
            x = torch.cat((x, x_next), dim=1)

        return x
    

In [45]:

# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

torch.manual_seed(1337)

# here are all the unique characters that occur in this text
chars = sorted(list(set(text)))
vocab_size = len(chars)
# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

# Train and test splits
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]


# data loading
def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

@torch.no_grad()
def estimate_loss(model):
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits = model(X)[0]
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = Y.view(B*T)
            loss = F.cross_entropy(logits, targets)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [46]:
model = languageModel(vocab_size,  384,6, 6, con_l=100
    )
m = model.to(device)
# print the number of parameters in the model
#print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')
print(next(m.parameters()).is_cuda)
# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

for iter in range(max_iters):

    # every once in a while evaluate the loss on train and val sets
    if iter % eval_interval == 0 or iter == max_iters - 1:
        losses = estimate_loss(model)
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = model(xb, yb)
    #B, T, C = logits.shape
    #logits = logits.view(B*T, C)
    #targets = yb.view(B*T)
    #loss = F.cross_entropy(logits, targets)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# generate from the model
context = torch.zeros((1, 50), dtype=torch.long, device=device)
print(decode(m.generate(context, max_length=500)[0].tolist()))
#open('more.txt', 'w').write(decode(m.generate(context, max_new_tokens=10000)[0].tolist()))

True


KeyboardInterrupt: 

In [33]:
cont = torch.zeros((1, 50), dtype=torch.long, device=device)
print(decode(m.generate(cont, max_length=500)[0].tolist()))

AssertionError: 