In [51]:
import os.path
import socket
import torch
import torch.nn as nn
from torch.nn import functional as F
import numpy as np



from torch.utils.data import Dataset

# Dataset root per known machine, keyed by hostname; any other host uses the default path.
_DATA_DIR_BY_HOST = {
    'LTSSL-sKTPpP5Xl': 'C:\\Users\\ams90\\PycharmProjects\\ConceptsBirds\\data',
    'LAPTOP-NA88OLS1': 'D:\\data\\',
}
data_dir = _DATA_DIR_BY_HOST.get(
    socket.gethostname(),
    '/home/bwc/ams90/datasets/caltecBirds/CUB_200_2011',
)

In [20]:
# All three splits live in the 'Text' sub-directory of the dataset root.
text_dir = os.path.join(data_dir, 'Text')
train_file = open(os.path.join(text_dir, 'train.csv'))
val_file = open(os.path.join(text_dir, 'validation.csv'))
test_file = open(os.path.join(text_dir, 'test.csv'))

In [21]:
# Read the train and validation splits fully into memory. The handles are
# already open; using them as context managers closes them as soon as the
# text has been read, so they are not leaked for the rest of the session.
# NOTE: the files are read as raw text, so CSV framing (the `text` header
# line and the opening quote visible in the preview) stays in the corpus.
with train_file, val_file:
    train_text = train_file.read()
    val_text = val_file.read()

In [22]:
# Peek at the first 1000 characters of the raw training text.
print(train_text[:1000])

text
"First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for reve

In [92]:
# Vocabulary: every distinct character that occurs in the training text, sorted.
chars = sorted(set(train_text))
vocab_size = len(chars)
print('Start',''.join(chars))

Start 
 !"$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz


In [24]:
# Lookup tables between characters and their integer ids (position in `chars`).
stoi = dict((ch, i) for i, ch in enumerate(chars))  # character -> integer, used for encoding
itoc = dict(enumerate(chars))                       # integer -> character, used for decoding


def encode(s):
    """Return the list of integer ids for the characters of string `s`."""
    return list(map(stoi.__getitem__, s))


def decode(l):
    """Return the string spelled by the sequence of integer ids `l`."""
    return ''.join(itoc[i] for i in l)

In [25]:
# Sanity check: one integer id per character of the input string.
encode('hi there')

[47, 48, 1, 59, 47, 44, 57, 44]

In [26]:
# Round trip: decoding the encoded string should give the original text back.
decode(encode('hi there'))

'hi there'

In [27]:
# Tokenise each split into character ids, then wrap the ids in int64 tensors.
train_tokens = encode(train_text)
val_tokens = encode(val_text)
train_data = torch.tensor(train_tokens, dtype=torch.long)
val_data = torch.tensor(val_tokens, dtype=torch.long)

In [28]:
# Show the first 1000 token ids of the training tensor.
train_data[:1000]

tensor([59, 44, 63, 59,  0,  3, 19, 48, 57, 58, 59,  1, 16, 48, 59, 48, 65, 44,
        53, 11,  0, 15, 44, 45, 54, 57, 44,  1, 62, 44,  1, 55, 57, 54, 42, 44,
        44, 43,  1, 40, 53, 64,  1, 45, 60, 57, 59, 47, 44, 57,  7,  1, 47, 44,
        40, 57,  1, 52, 44,  1, 58, 55, 44, 40, 50,  9,  0,  0, 14, 51, 51, 11,
         0, 32, 55, 44, 40, 50,  7,  1, 58, 55, 44, 40, 50,  9,  0,  0, 19, 48,
        57, 58, 59,  1, 16, 48, 59, 48, 65, 44, 53, 11,  0, 38, 54, 60,  1, 40,
        57, 44,  1, 40, 51, 51,  1, 57, 44, 58, 54, 51, 61, 44, 43,  1, 57, 40,
        59, 47, 44, 57,  1, 59, 54,  1, 43, 48, 44,  1, 59, 47, 40, 53,  1, 59,
        54,  1, 45, 40, 52, 48, 58, 47, 13,  0,  0, 14, 51, 51, 11,  0, 31, 44,
        58, 54, 51, 61, 44, 43,  9,  1, 57, 44, 58, 54, 51, 61, 44, 43,  9,  0,
         0, 19, 48, 57, 58, 59,  1, 16, 48, 59, 48, 65, 44, 53, 11,  0, 19, 48,
        57, 58, 59,  7,  1, 64, 54, 60,  1, 50, 53, 54, 62,  1, 16, 40, 48, 60,
        58,  1, 26, 40, 57, 42, 48, 60, 

In [29]:
# Context length: number of tokens in one training window.
block_size = 8
# block_size + 1 consecutive tokens: enough for block_size inputs plus the shifted-by-one targets.
train_data[:block_size+1]

tensor([59, 44, 63, 59,  0,  3, 19, 48, 57])

In [30]:
# One window of inputs and the same window shifted right by one token (the targets).
x, y = train_data[:block_size], train_data[1:block_size + 1]
# Every prefix of x is a training example whose label is the token that follows it.
for t in range(block_size):
    context, target = x[:t + 1], y[t]
    print(context, target)
    

tensor([59]) tensor(44)
tensor([59, 44]) tensor(63)
tensor([59, 44, 63]) tensor(59)
tensor([59, 44, 63, 59]) tensor(0)
tensor([59, 44, 63, 59,  0]) tensor(3)
tensor([59, 44, 63, 59,  0,  3]) tensor(19)
tensor([59, 44, 63, 59,  0,  3, 19]) tensor(48)
tensor([59, 44, 63, 59,  0,  3, 19, 48]) tensor(57)


In [36]:
batch_size, block_size = 4,8
def get_batch(split):
    """Sample a random batch of (input, target) windows from one split.

    `split == 'train'` draws from `train_data`; any other value draws from
    `val_data`. Both returned tensors have shape (batch_size, block_size),
    and the targets are the inputs shifted one position to the right.
    Reads the module-level `batch_size` and `block_size` at call time.
    """
    data = val_data if split != 'train' else train_data
    starts = torch.randint(high=(len(data) - block_size), size=(batch_size,))
    # Broadcast each start offset against 0..block_size-1 to index all windows at once.
    window = starts.unsqueeze(1) + torch.arange(block_size)
    x = data[window]
    y = data[window + 1]
    return x, y
xb, yb = get_batch('train')

xb, yb


(tensor([[24, 22, 27, 20,  1, 21, 18, 27],
         [52,  9,  0,  0, 14, 18, 43, 48],
         [47,  1, 47, 48, 52,  1, 59, 54],
         [43,  7,  0, 33, 47, 44,  1, 52]]),
 tensor([[22, 27, 20,  1, 21, 18, 27, 31],
         [ 9,  0,  0, 14, 18, 43, 48, 51],
         [ 1, 47, 48, 52,  1, 59, 54,  1],
         [ 7,  0, 33, 47, 44,  1, 52, 40]]))

In [50]:
# Spell out every (context, target) training example contained in the batch.
for b in range(batch_size):
    row_x, row_y = xb[b], yb[b]
    for t in range(block_size):
        context, target = row_x[:t+1], row_y[t]
        print(f"when the input is {context.tolist()} the target is {target}")
        

when the input is [24] the target is 22
when the input is [24, 22] the target is 27
when the input is [24, 22, 27] the target is 20
when the input is [24, 22, 27, 20] the target is 1
when the input is [24, 22, 27, 20, 1] the target is 21
when the input is [24, 22, 27, 20, 1, 21] the target is 18
when the input is [24, 22, 27, 20, 1, 21, 18] the target is 27
when the input is [24, 22, 27, 20, 1, 21, 18, 27] the target is 31
when the input is [52] the target is 9
when the input is [52, 9] the target is 0
when the input is [52, 9, 0] the target is 0
when the input is [52, 9, 0, 0] the target is 14
when the input is [52, 9, 0, 0, 14] the target is 18
when the input is [52, 9, 0, 0, 14, 18] the target is 43
when the input is [52, 9, 0, 0, 14, 18, 43] the target is 48
when the input is [52, 9, 0, 0, 14, 18, 43, 48] the target is 51
when the input is [47] the target is 1
when the input is [47, 1] the target is 47
when the input is [47, 1, 47] the target is 48
when the input is [47, 1, 47, 48]

In [87]:
class BigramLanguageModel(nn.Module):
    """Bigram model: the logits for the next token depend only on the current token."""

    def __init__(self, vocab_size):
        """Create the (vocab_size x vocab_size) table of next-token logits."""
        super().__init__()
        # each token directly reads the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, when targets are given, the loss.

        Args:
            idx: integer tensor of token ids with dimensions (batch_size, block_size).
            targets: optional integer tensor with the same dimensions as `idx`.

        Returns:
            (logits, loss). Without targets, logits has dimensions
            (batch_size, block_size, vocab_size) and loss is None. With
            targets, logits is returned flattened to
            (batch_size * block_size, vocab_size) and loss is the
            cross-entropy between those logits and the flattened targets.
        """
        # logits is tensor with dimensions (batch_size, block_size, vocab_size)
        logits = self.token_embedding_table(idx)

        if targets is None:
            return logits, None

        # F.cross_entropy expects (N, C) logits and (N,) targets, so fold
        # the batch and time dimensions together.
        b, t, c = logits.shape
        logits = logits.view(b * t, c)
        # reshape (not view) so non-contiguous targets are accepted as well
        targets = targets.reshape(-1)
        loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, idx, max_new_tokens):
        """Extend `idx` by sampling `max_new_tokens` tokens one at a time.

        Args:
            idx: integer tensor of token ids with dimensions (batch_size, block_size).
            max_new_tokens: number of tokens to append to every row.

        Returns:
            Tensor with dimensions (batch_size, block_size + max_new_tokens);
            the leading columns are the original `idx`.
        """
        # Sampling needs no gradients; skip building an autograd graph per step.
        with torch.no_grad():
            for _ in range(max_new_tokens):
                # idx is tensor with dimensions (batch_size, current_length)
                logits, _loss = self(idx)
                # focus on the last character
                logits = logits[:, -1, :]
                probs = F.softmax(logits, dim=-1)
                # sample from the probability distribution
                idx_next = torch.multinomial(input=probs, num_samples=1)
                idx = torch.cat((idx, idx_next), dim=1)
        return idx
    
# Build the model over the character vocabulary and run one forward pass on the sample batch.
m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb)
# Targets were passed, so forward() returns logits flattened to (batch * block, vocab).
print(logits.shape)

torch.Size([32, 66])


In [93]:
# Sample 100 new tokens from the model, starting from a single token with id 0.
start_idx = torch.zeros((1,1), dtype=torch.long)
sampled_ids = m.generate(idx=start_idx, max_new_tokens=100)[0].tolist()
print(decode(sampled_ids))


BQP"VpQmdP:OY&,ZDDffjzP"TDjNa.lhUe'HV-rorxZOJQtlzY oYVY"tUk;UADM;HpdF cSJQ,WFwxYXuclgJnHuRSJNb&
?ITM


In [96]:
# AdamW over all model parameters with a learning rate of 1e-3.
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

In [101]:
# get_batch() reads this module-level value, so training uses batches of 32.
batch_size = 32
num_steps = 10000
for steps in range(num_steps):
    # Clear the gradients left over from the previous step.
    optimizer.zero_grad(set_to_none=True)
    # Forward pass on a freshly sampled training batch.
    xb, yb = get_batch('train')
    logits, loss = m(xb, yb)
    # Backward pass and parameter update.
    loss.backward()
    optimizer.step()
# Loss of the final mini-batch only.
print(loss.item())

2.4722740650177


In [102]:
# Sample 100 new tokens from the model again, starting from a single token with id 0.
start_idx = torch.zeros((1,1), dtype=torch.long)
sampled_ids = m.generate(idx=start_idx, max_new_tokens=100)[0].tolist()
print(decode(sampled_ids))


wer thad nbrear a:
REENangrat at, st ilend ct whillld IARO,
Heg gheee cosom. onges d ur toon!
NUKE:
