# Dataset

In [438]:
import torch
from torch.utils.data import Dataset

def tokenise(txt):
    # Char tokens
    l = [c for c in txt]
    l = list(set(l))
    return l


class CodingDataset(Dataset):
    def __init__(self, src, seq_len):
        with open(src) as f:
            txt  = f.read()
            ln_s = txt.split("\n\n")
            self.pairs = [pair.split("\n") for pair in ln_s]
        self.tokens  = tokenise(txt)
        self.seq_len = seq_len

    def __len__(self):
        return len(self.pairs)

    def decode(self, idx_s):
        d   = torch.argmax(idx_s, dim=1)
        t_s = "".join([self.tokens[i] for i in d])
        return t_s

    def encode(self, txt):
        e = torch.nn.functional.one_hot(
            torch.tensor([self.tokens.index(c) for c in txt]),
            num_classes=len(self.tokens))
        l = e.shape[0]
        e = torch.concat(
            (
                e,
                torch.zeros((self.seq_len - l, len(self.tokens)))
            ),
            dim=0)
        return e
    
    def __getitem__(self, idx):
        comment, code = self.pairs[idx]
        comment, code = self.encode(comment), self.encode(code)
        return comment, code

In [439]:
seq_len = 32

In [440]:
dataset = CodingDataset(
    "./dataset/main.py",
    seq_len=seq_len)
comment, code = dataset[2]
comment, code, dataset.tokens

(tensor([[0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
          0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
          0., 0., 0.],
         [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0.],
         [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
          0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
          0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1

In [441]:
dataset.decode(comment), dataset.decode(code)

('# Print yah\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n',
 'print("yah")\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n')

In [442]:
comment.shape, code.shape

(torch.Size([32, 21]), torch.Size([32, 21]))

# LSTM Model

In [452]:
import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self, input_dim, out_dim, num_layers, hidden_dim=32, dropout=0.2):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout)
        self.fc1 = nn.Linear(hidden_dim, out_dim)
    
    def forward(self, x):
        x, _ = self.lstm(x)
        x = self.fc1(x)
        x = F.softmax(x, dim=-1)
        return x

# Transformer Model

In [444]:
import torch.nn as nn
import torch.nn.functional as F

from lib.transformer import TransformerEncoderLayer

class Model(nn.Module):
    def __init__(self, n_heads, out_dim, num_layers, model_size=32, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        encoder_layer = TransformerEncoderLayer(
            d_model=model_size,
            nhead=n_heads,
            relative_positional=False,
            relative_positional_distance=32,
            dim_feedforward=dim_feedforward,
            dropout=dropout)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
        self.w_out = nn.Linear(model_size, out_dim)
    
    def forward(self, x):
        print(x.shape)
        
        x = x.transpose(0, 1)
        x = self.transformer(x)
        x = x.transpose(0, 1)

        x = self.w_out(x)

        return x

# Trainset

In [453]:
#x = torch.unsqueeze(comment, dim=0)
#y = torch.unsqueeze(code, dim=0)

comment_s, code_s = [], []
# for i in range(2, 3):
for i in range(len(dataset)):
    comment, code = dataset[i]
    comment_s.append(comment)
    code_s.append(code)

x_s = torch.stack(comment_s, dim=0).to("cuda")
y_s = torch.stack(code_s, dim=0).to("cuda")

# Train

In [454]:
from torch.cuda.amp import autocast, GradScaler
import matplotlib.pyplot as plt


# LSTM
model = Model(
    input_dim=len(dataset.tokens),
    out_dim=len(dataset.tokens),
    num_layers=3,
    hidden_dim=256,
    dropout=0.2).to("cuda")

"""
# Transformer
# n_heads, out_dim, num_layers, model_size=32, dim_feedforward=2048, dropout=0.1):
model_size = 768
model = Model(
    model_size=model_size,
    dropout=0.2,
    num_layers=8,
    n_heads=2,
    dim_feedforward=3072,
    out_dim=len(dataset.tokens)).to("cuda")
"""

epochs = 10_000
lr = 1e-4

optim = torch.optim.Adam(params=model.parameters(), lr=lr)

losses = []

scaler = GradScaler()

for epoch_idx in range(epochs):
    with autocast():
        pred = model(x_s)

        # pred_code = dataset.decode(pred[0])

        loss = F.cross_entropy(pred, y_s)
        if epoch_idx % 100 == 0:
            print(epoch_idx, loss.item())
    
    # print(pred, y)
    losses.append(loss.item())

    loss.backward()
    optim.step()

plt.plot(losses)

0 2.337981700897217
100 2.3236582279205322
200 2.3219642639160156
300 2.2897744178771973
400 2.2175755500793457
500 2.2386913299560547


# Evaluate

In [None]:
code_pred = dataset.decode(model(x_s)[0])

torch.Size([6, 32, 21])


RuntimeError: einsum(): operands do not broadcast with remapped shapes [original->remapped]: [32, 6, 21]->[6, 1, 32, 1, 21] [2, 768, 384]->[1, 2, 1, 384, 768]

In [None]:
code_pred

'pritt(        vvvvvvvvvvvvvvvvvv'

# Try!

In [None]:
my_comment = "Sleep for five seconds"
my_code    = dataset.decode(
    model(
        dataset.encode(my_comment).unsqueeze(dim=0).to("cuda")
    )[0]
)

ValueError: 'S' is not in list

In [None]:
my_code

'prit(     vvvvvvvvvvvvvvvvvvvvvv'