<a href="https://colab.research.google.com/github/M-Amrollahi/Personal-Notes/blob/master/ML-notes/how_calculate_numbers_decoder_gpt.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install -q x-transformers
!pip install -q torchinfo

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/42.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 kB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
[?25h

In [37]:
import torch
from torch.optim import Adam
from torch import nn
from torch.utils.data import Dataset,DataLoader
from torch.functional import F
from tqdm.notebook import tqdm
import torchinfo
import random

## Create a dataset of arithmatic calculations

In [137]:
res = []
for i in range(0,100):
    for j in range(0,100):
        res.append( f"{i} + {j} = {i+j}")
        res.append(f"{i} - {j} = {i-j}")
        res.append(f"{i} * {j} = {i*j}")
        if j != 0:
            res.append( f"{i} / {j} = {i/j:.4f}")
        else:
            res.append( f"{i} / {j} = nan")

random.shuffle(res)
res[:10]


['90 / 49 = 1.8367',
 '50 * 81 = 4050',
 '98 + 53 = 151',
 '40 / 52 = 0.7692',
 '3 * 58 = 174',
 '83 + 69 = 152',
 '69 - 3 = 66',
 '17 / 86 = 0.1977',
 '47 + 79 = 126',
 '29 / 25 = 1.1600']

In [138]:
text = ";".join(res)

chars = sorted(list(set(text)))
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string
vocabSize = len(chars)

In [80]:
text[:100],chars

('4 - 1 = 3;1 - 6 = -5;3 / 4 = 0.7500;7 - 7 = 0;2 - 6 = -4;6 * 0 = 0;5 + 8 = 13;8 * 4 = 32;7 - 5 = 2;9',
 [' ',
  '*',
  '+',
  '-',
  '.',
  '/',
  '0',
  '1',
  '2',
  '3',
  '4',
  '5',
  '6',
  '7',
  '8',
  '9',
  ';',
  '=',
  'a',
  'n'])

## Create a pytorch dataset

In [94]:
class cls_data_transformer(Dataset):

    def __init__(self, text, tokens):
        super().__init__()
        self.m_n_tokens = tokens
        lst_formulas = text.split(";")

        self.m_encoded_formulas = list( map(encode, lst_formulas))


    def __len__(self):
        return len(self.m_encoded_formulas)

    def __getitem__(self, index):

        x = torch.cat( (torch.asarray( self.m_encoded_formulas[index]), torch.full((self.m_n_tokens,), stoi[" "])))[:self.m_n_tokens]
        x = x.long()
        return x[:-1], x[1:]


In [95]:
@torch.no_grad()
def f_valLoss(model, xLoader):
    model.eval()
    lstLoss = []
    for xb, yb in xLoader:
        xb,yb = xb.to(device), yb.to(device)
        logits = model.forward(xb)
        #loss = criterion.forward(logits, yb)
        loss = criterion.forward(logits.reshape(-1,logits.shape[-1]), yb.reshape(-1))

        lstLoss.append(loss.cpu().item())
    loss = sum(lstLoss)/len(lstLoss)

    return loss


def generate(model,idx, max_new_tokens):
    # idx is (B, T) array of indices in the current context
    for _ in range(max_new_tokens):
        # get the predictions
        logits = model(idx)
        # focus only on the last time step
        logits = logits[:, -1, :] # becomes (B, C)
        # apply softmax to get probabilities
        probs = F.softmax(logits, dim=-1) # (B, C)
        # sample from the distribution
        idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
        # append sampled index to the running sequence
        idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
    return idx

In [144]:
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = vocabSize,
    max_seq_len = 30,
    attn_layers = Decoder(
        dim = 50,
        depth = 5,
        heads = 3
    )
)
epochs = 5

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

model.to(device)



dsTrain = cls_data_transformer(text[:int(.9*len(text))], vocabSize)
dsVal = cls_data_transformer(text[int(.9*len(text)):], vocabSize)
dsTrainLoader = DataLoader(dsTrain, batch_size=8)
dsValLoader = DataLoader(dsVal, batch_size=64)

optimizer = Adam(model.parameters(), lr=3e-3)
criterion = nn.CrossEntropyLoss()

In [145]:
for i in range(epochs):
    lstLoss = list()
    for idx, (xb, yb) in enumerate(tqdm(dsTrainLoader)):


        xb,yb = xb.to(device), yb.to(device)
        logits = model.forward(xb)


        loss = criterion.forward(logits.reshape(-1,logits.shape[-1]), yb.reshape(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        lstLoss.append( loss.cpu().item())

    valLoss = f_valLoss(model,dsValLoader)

    print(f"epoch:{i+1}/{epochs} - train loss: {sum(lstLoss)/len(lstLoss):.4f} - val loss: {valLoss:.4f}")

  0%|          | 0/4500 [00:00<?, ?it/s]

epoch:0/5 - train loss: 0.8239 - val loss: 0.9187


  0%|          | 0/4500 [00:00<?, ?it/s]

epoch:1/5 - train loss: 0.7808 - val loss: 0.7597


  0%|          | 0/4500 [00:00<?, ?it/s]

epoch:2/5 - train loss: 0.7630 - val loss: 0.7526


  0%|          | 0/4500 [00:00<?, ?it/s]

epoch:3/5 - train loss: 0.7523 - val loss: 0.7454


  0%|          | 0/4500 [00:00<?, ?it/s]

epoch:4/5 - train loss: 0.7469 - val loss: 0.7599


In [146]:
for i in range(5):

    # choose random formula from val set
    j = random.randint(0,len(dsVal)-1)

    # choose part of the fomula so that to be guessed
    context = dsVal[j][0][:5].reshape(-1,1)

    # estimate and complete the arithmatic formula
    print(decode(generate(model,context, max_new_tokens=20)[0].tolist()))

12 + 80 = 91         
84 + 27 = 104       1
9 / 6 = 25.7097      
37 - 57 = -23       -
25 - 16 = 63         


In [147]:
torchinfo.summary(model)

Layer (type:depth-idx)                             Param #
TransformerWrapper                                 --
├─TokenEmbedding: 1-1                              --
│    └─Embedding: 2-1                              1,000
├─AbsolutePositionalEmbedding: 1-2                 --
│    └─Embedding: 2-2                              1,500
├─Identity: 1-3                                    --
├─Dropout: 1-4                                     --
├─Identity: 1-5                                    --
├─Decoder: 1-6                                     --
│    └─ModuleList: 2-3                             --
│    │    └─ModuleList: 3-1                        38,500
│    │    └─ModuleList: 3-2                        20,350
│    │    └─ModuleList: 3-3                        38,500
│    │    └─ModuleList: 3-4                        20,350
│    │    └─ModuleList: 3-5                        38,500
│    │    └─ModuleList: 3-6                        20,350
│    │    └─ModuleList: 3-7                    