In [24]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os
import math
import random
import inspect
import re
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
# Read the whole corpus into memory, then normalise it: lower-case every
# character and turn line breaks into spaces so the text is one long line.
with open('../data/text-1.txt', 'r', encoding='utf-8') as f:
    text = f.read()
text = text.lower().replace('\n', ' ')

In [3]:
# Vocabulary: every distinct character that occurs in `text`, in sorted order.
chars = sorted(set(text))

In [4]:
# Two-way lookup tables between characters and integer token ids; a
# character's id is its position in the sorted `chars` list.
char2int = dict(zip(chars, range(len(chars))))
# sot = '>'
# assert sot not in char2int
# char2int[sot] = 0
int2char = {idx: ch for ch, idx in char2int.items()}
# Vocabulary size, i.e. how many distinct token ids exist.
numOfTokens = len(char2int)
print(f'{numOfTokens=}')

numOfTokens=74


In [39]:
ctxLen = 3  # context window: number of consecutive characters per sample

def getSamples(samplNum=100, trSamplPct=0.8, seed=123):
    """Draw random fixed-length windows from the corpus and split them.

    Reads the module-level ``text``, ``char2int`` and ``ctxLen``.  For every
    sampled start index ``i`` the input row is the encoded characters
    ``text[i:i+ctxLen]`` and the target row is the same window shifted one
    character to the right, ``text[i+1:i+ctxLen+1]``.

    Side effect: re-seeds Python's global ``random`` generator with ``seed``,
    so the same windows are drawn on every call with the same arguments.

    Args:
        samplNum: number of windows to draw (without replacement).
        trSamplPct: fraction of the windows that goes to the training split.
            The remainder is divided between validation and test; the test
            split receives the extra window when the remainder is odd.
        seed: seed passed to ``random.seed``.

    Returns:
        ``(Xtr, Ytr, Xval, Yval, Xtest, Ytest)`` -- tensors of token ids,
        one row of ``ctxLen`` ids per sampled window.

    Raises:
        ValueError: if ``trSamplPct`` is outside ``[0, 1]`` or the corpus is
            too short to provide ``samplNum`` distinct windows.
    """
    if not 0 <= trSamplPct <= 1:
        raise ValueError(f'trSamplPct must be within [0, 1], got {trSamplPct}')
    # A start index is valid only if the shifted target window still fits.
    numStarts = max(len(text) - ctxLen, 0)
    if not 0 <= samplNum <= numStarts:
        raise ValueError(
            f'cannot draw {samplNum} windows: the text provides only '
            f'{numStarts} start positions for ctxLen={ctxLen}'
        )

    random.seed(seed)
    X, Y = [], []
    for i in random.sample(range(numStarts), samplNum):
        # Encode only the ctxLen+1 characters this sample needs instead of
        # the whole corpus; input and target are two views of that span.
        window = [char2int[ch] for ch in text[i:i + ctxLen + 1]]
        X.append(window[:-1])
        Y.append(window[1:])

    i1 = int(trSamplPct * len(X))   # end of the training split
    i2 = (i1 + len(X)) // 2         # midpoint of what is left: val | test
    return (
        torch.tensor(X[:i1]), torch.tensor(Y[:i1]),
        torch.tensor(X[i1:i2]), torch.tensor(Y[i1:i2]),
        torch.tensor(X[i2:]), torch.tensor(Y[i2:]),
    )
# Materialise the train / validation / test splits as module-level tensors.
Xtr,Ytr, Xval,Yval, Xtest,Ytest = getSamples()

In [60]:
class MlpPredictor(nn.Module):
    """Character-level next-token predictor built from a small MLP.

    Pipeline: token embedding -> flatten the context window -> linear ->
    leaky ReLU -> linear -> softmax over the vocabulary.

    Args:
        numOfTokens: vocabulary size (number of distinct token ids).
        ctxLen: number of tokens in the context window fed to the network.
        embSize: dimensionality of each token embedding.
        hidDim: width of the hidden layer.
    """
    def __init__(self,numOfTokens,ctxLen,embSize,hidDim=100):
        super().__init__()
        self.ctxLen = ctxLen
        self.embSize = embSize
        self.emb = nn.Embedding(numOfTokens,embSize)
        self.lin1 = nn.Linear(ctxLen*embSize,hidDim)
        self.lin2 = nn.Linear(hidDim,numOfTokens)

    def forward(self,x):
        """Return next-token probabilities for a batch of context windows.

        Args:
            x: integer tensor of token ids, expected shape ``(batch, ctxLen)``.

        Returns:
            Tensor of shape ``(batch, numOfTokens)`` whose rows sum to 1.
            These are probabilities, not logits, so losses that expect raw
            logits (e.g. ``F.cross_entropy``) must not be applied directly.
        """
        # Concatenate the embeddings of one window into a single feature row.
        x = self.emb(x).view(-1,self.ctxLen*self.embSize)
        x = self.lin1(x)
        x = F.leaky_relu(x)
        x = self.lin2(x)
        x = F.softmax(x,-1)
        return x

    def generate(self,ctx,resLen):
        """Sample ``resLen`` characters that continue the string ``ctx``.

        Uses the module-level ``char2int`` / ``int2char`` tables.  A context
        shorter than the model's window is left-padded with spaces (so ``' '``
        must be part of the vocabulary); a longer one keeps only its tail.

        Args:
            ctx: seed text; may be empty.
            resLen: number of characters to generate.

        Returns:
            The generated characters as a string (``ctx`` is not included).

        Raises:
            KeyError: if the (padded) context contains a character that is
                missing from ``char2int``.
        """
        # Size the window from the model itself, not from the notebook-level
        # `ctxLen`, so instances built with another window length still work.
        if len(ctx) < self.ctxLen:
            ctx = ' '*(self.ctxLen-len(ctx)) + ctx
        elif len(ctx) > self.ctxLen:
            # Length-based slice: `ctx[-n:]` would keep everything for n == 0.
            ctx = ctx[len(ctx)-self.ctxLen:]
        device = self.emb.weight.device
        res = []
        with torch.no_grad():
            while len(res) < resLen:
                # Batch of one window, placed where the parameters live.
                x = torch.tensor([[char2int[ch] for ch in ctx]], device=device)
                probs = self(x)[0]
                nextToken = torch.multinomial(probs, 1)
                nextChar = int2char[nextToken.item()]
                res.append(nextChar)
                # Slide the window: drop the oldest char, append the new one.
                ctx = ctx[1:] + nextChar
        return ''.join(res)

In [26]:
def print2dEmb(embLayer):
    """Scatter-plot the first two embedding dimensions, one labelled dot per token.

    Each row of ``embLayer.weight`` becomes a dot at (column 0, column 1),
    labelled with the character that the module-level ``int2char`` maps the
    row index to.  Draws into a new 8x8 inch figure; returns nothing.

    Args:
        embLayer: layer exposing a 2-D ``weight`` tensor with at least two
            columns (e.g. ``nn.Embedding``).
    """
    # detach + cpu + numpy: plain values matplotlib can consume regardless of
    # autograd tracking or the device the layer lives on.
    W = embLayer.weight.detach().cpu().numpy()
    xs, ys = W[:,0], W[:,1]
    plt.figure(figsize=(8,8))
    plt.scatter(xs,ys,s=200)
    for i in range(W.shape[0]):
        plt.text(float(xs[i]), float(ys[i]), int2char[i], ha="center", va="center", color='white')
    # The first positional argument of grid() is the on/off flag, not `which`;
    # a bare 'minor' there just switched the major grid on, so say so plainly.
    plt.grid(True)

def showParamsStats(model, layerNameFilter=None):
    """Print mean/std and plot a density histogram for each selected parameter.

    Opens a new 20x4 inch figure and draws one curve per parameter (left bin
    edge vs. density), with the parameter names as legend.

    Args:
        model: module whose ``named_parameters()`` are inspected.
        layerNameFilter: compiled regular expression, or a pattern string
            that is compiled here.  Only parameters whose name matches from
            its start (``pattern.match``) are shown; ``None`` selects all.
    """
    if isinstance(layerNameFilter, str):
        layerNameFilter = re.compile(layerNameFilter)
    plt.figure(figsize=(20,4))
    legends = []
    for pName, pValue in model.named_parameters():
        if layerNameFilter is not None and not layerNameFilter.match(pName):
            continue
        print(f'layer \'{pName}\'[{pValue.nelement()}] mean:{pValue.mean()}, std:{pValue.std()},')
        # Histogram a detached CPU copy: no autograd tracking is needed for a
        # diagnostic, and the values must be on the CPU to be plotted.
        hy,hx = torch.histogram(pValue.detach().cpu(), density=True)
        # hx holds one more edge than there are bins; plot against left edges.
        plt.plot(hx[:-1].numpy(),hy.numpy())
        legends.append(pName)
    plt.legend(legends)
    

In [61]:
# Instantiate the predictor with 2-dimensional embeddings (the two columns
# that print2dEmb plots).
# NOTE(review): no torch seed is set in the visible cells, so the initial
# weights presumably differ between kernel runs -- confirm that is intended.
model = MlpPredictor(numOfTokens=numOfTokens,ctxLen=ctxLen,embSize=2)
# numOfParams = sum(p.nelement() for p in model.parameters())
# print(f'{numOfParams=}')
# [(pName,pValue.shape) for pName,pValue in model.named_parameters()]
# showParamsStats(model,layerNameFilter=re.compile('^lin.*weight$'))
# print2dEmb(model.emb)

In [70]:
# Sample 10 characters from the model, starting from an empty context
# (generate() left-pads a short context with spaces).
model.generate('',10)

'm™ wt;8)/h'