In [64]:
import torch
import torch.nn.functional as F
from helper import *

In [65]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [66]:
words = read_words('names.txt')
stoi, itos = get_mapping(words)
nchars = len(stoi.keys())

In [67]:
train_split, val_split = 0.8, 0.1
block_size = 8

X, Y = build_dataset(words, stoi, block_size=block_size)
X, Y = torch.tensor(X), torch.tensor(Y)


n = len(X)
n1 = round(n * train_split)
n2 = round(n * val_split)


X_train, Y_train = X[:n1], Y[:n1]
X_val, Y_val = X[n1:n1+n2], Y[n1:n1+n2]
X_test, Y_test = X[n1+n2:], Y[n1+n2:]

In [70]:
class Linear:
    def __init__(self, fan_in, fan_out, b = None):
        self.W = torch.ones(fan_in, fan_out)
        self.b = None
        if b is not None:
            self.b = torch.zeros(fan_out)

    def __call__(self, x):
        out = x @ self.W
        if self.b is not None:
            out = out + self.b
        self.out = out
        return out
    
    def parameters(self):
        return [self.W] + ([] if self.b is None else [self.b])

class Tanh:
    def __call__(self, x):
        out = torch.tanh(x)
        self.out = out
        return out
    
    def parameters(self):
        return []
    
class BatchNorm1d:
    def __init__(self, hidden_dim, epsilon = 1e-7, momentum = 0.99):
        self.Training = True
        self.gamma = torch.ones(hidden_dim)
        self.beta = torch.zeros(hidden_dim)
        self.running_mean = torch.zeros((1, hidden_dim))
        self.running_var = torch.ones((1, hidden_dim))
        self.epsilon = epsilon
        self.momentum = momentum

    def __call__(self, x):
        if self.Training:
            if x.ndim == 2:
                mean = x.mean(0, keepdim = True)
                var = x.var(0, keepdim = True)
            else:
                mean = x.mean((0,1), keepdim = True)
                var = x.var((0,1), keepdim = True)
        else:
            mean = self.running_mean
            var = self.running_var

        x_gaus = (x - mean) / torch.sqrt(var + self.epsilon)
        out = (self.gamma * x_gaus) + self.beta

        if self.Training:
            with torch.no_grad():
                self.running_mean = self.running_mean * self.momentum + (1 - self.momentum) * mean
                self.running_var = self.running_var * self.momentum + (1 - self.momentum) * var

        self.out = out
        return out
    
    def parameters(self):
        return [self.gamma, self.beta]
    
class LinearBatchNorm1d:
    def __init__(self, fan_in, fan_out):
        self.Training = True
        self.Folding = True
        self.linear = Linear(fan_in, fan_out, False)
        self.bn = BatchNorm1d(fan_out)
        self.W_folded = None
        self.b_folded = None

    def __call__(self, x):
        if self.Training or not self.Folding:
            x = self.linear(x)
            out = self.bn(x)
        else:
            if self.W_folded is None:
                self.W_folded = self.bn.gamma * self.linear.W / (torch.sqrt(self.bn.running_var + self.bn.epsilon))
                self.b_folded = self.bn.beta
            out = x @ self.W_folded + self.b_folded
        self.out = out
        return out
    
    def setTraining(self, train):
        self.Training = train
        self.bn.Training = train
        if train:
            self.W_folded = None
            self.b_folded = None

    def parameters(self):
        return self.linear.parameters() + self.bn.parameters()
    
    
class Embedding:
    def __init__(self, num_class, emb_dim):
        self.C = torch.randn(num_class, emb_dim)


    def __call__(self, x):
        out = self.C[x]
        self.out = out
        return out
    
    def parameters(self):
        return [self.C]
    
class FlattenConsekutiv:
    def __init__(self, n):
        self.n = n

    def __call__(self, x):
        N, T, C = x.shape
        if T//self.n == 1:
            out = x.view(N, C*self.n)    
        else:
            out = x.view(N, T//self.n, C*self.n)
        self.out = out
        return out
    
    def parameters(self):
        return []
    

class Sequential:
    def __init__(self, layers):
        self.layers = layers

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
    
    def parameters(self):
        p = []
        for layer in self.layers:
            p = p + layer.parameters()
        return p

In [71]:
emb_dim = 10
hidden_dim = 200
n = 2
model = Sequential([
    Embedding(nchars, emb_dim),
    FlattenConsekutiv(n), LinearBatchNorm1d(emb_dim * n, hidden_dim), Tanh(),
    FlattenConsekutiv(n), LinearBatchNorm1d(hidden_dim * n, hidden_dim), Tanh(),
    FlattenConsekutiv(n),  LinearBatchNorm1d(hidden_dim * n, hidden_dim), Tanh(),
    Linear(hidden_dim, nchars),
])

for p in model.parameters():
    p.requires_grad = True

In [None]:
# for l in model.layers:
#     print(l.__class__.__name__, ':', tuple(l.out.shape))

AttributeError: 'Embedding' object has no attribute 'out'

In [72]:
iterations = 10000
lr = 0.1
reg  = 0.01
batch_size = 64


stepi = []
lossi = []

for k in range(iterations):
    for p in model.parameters():
        p.grad = None

    idx = torch.randint(0, X_train.shape[0], (batch_size,))
    logits = model(X_train[idx])

    loss = F.cross_entropy(logits, Y_train[idx])

    loss.backward()

    for p in model.parameters():
        p.data -= lr * p.grad
    stepi.append(k)
    lossi.append(loss.item())

    pred = logits.argmax(dim = 1)
    acc = (pred == Y_train[idx]).float().mean().data
    print(f"iteration {k} loss {loss.data}, acc {acc * 100}")

iteration 0 loss 3.295837163925171, acc 7.8125
iteration 1 loss 3.1051840782165527, acc 20.3125
iteration 2 loss 2.9744532108306885, acc 28.125
iteration 3 loss 3.3326783180236816, acc 14.0625
iteration 4 loss 3.037017583847046, acc 17.1875
iteration 5 loss 3.165541887283325, acc 18.75
iteration 6 loss 3.1402230262756348, acc 14.0625
iteration 7 loss 2.976803779602051, acc 26.5625
iteration 8 loss 3.053290843963623, acc 23.4375
iteration 9 loss 3.116036891937256, acc 18.75
iteration 10 loss 3.111933469772339, acc 14.0625
iteration 11 loss 3.1285409927368164, acc 15.625
iteration 12 loss 3.008713722229004, acc 18.75
iteration 13 loss 3.2880659103393555, acc 9.375
iteration 14 loss 3.1675779819488525, acc 23.4375
iteration 15 loss 3.2001566886901855, acc 15.625
iteration 16 loss 3.0270767211914062, acc 17.1875
iteration 17 loss 3.137791872024536, acc 14.0625
iteration 18 loss 3.2173874378204346, acc 10.9375
iteration 19 loss 3.094935178756714, acc 14.0625
iteration 20 loss 2.962682247161