In [271]:
import torch
import torch.nn as nn
from  torch.nn import functional as F
torch.manual_seed(1337)
cuda = torch.device('cuda') 

In [272]:
with open("input.txt", "r", encoding = "utf-8") as f:
    text=f.read()
print(len(text))    

1115394


In [273]:
vocab = list(set(text))
vocab.sort()
print("".join(vocab))
vocab_size = len(vocab)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz


In [274]:
#encoder
ctoi = { vocab[i]:i for i in range(len(vocab))}
itoc = { i:vocab[i] for i in range(len(vocab))}
def encode(s): return [ ctoi[i] for i in s]
def decode(t): return "".join([ itoc[i] for i in t])

tokens = encode("hi ho")
s = decode(tokens)
print(tokens, s)

[46, 47, 1, 46, 53] hi ho


In [275]:
tokens = encode(text)
print(tokens[:20])
print(decode(tokens[:20]))

[18, 47, 56, 57, 58, 1, 15, 47, 58, 47, 64, 43, 52, 10, 0, 14, 43, 44, 53, 56]
First Citizen:
Befor


In [276]:
data = torch.tensor(tokens, dtype=torch.long, device = cuda)
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

In [277]:
def get_batch(data, batch_size = 4, block_size = 8 ):
    indices = torch.randint(len(data)-block_size-1, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in indices], dim=0)
    y = torch.stack([data[i+1:i+block_size+1] for i in indices], dim=0)
    return x,y   

x,y =  get_batch(train_data)    
print(x)
print(y)

tensor([[53, 59,  6,  1, 58, 56, 47, 40],
        [49, 43, 43, 54,  1, 47, 58,  1],
        [13, 52, 45, 43, 50, 53,  8,  0],
        [ 1, 39,  1, 46, 53, 59, 57, 43]], device='cuda:0')
tensor([[59,  6,  1, 58, 56, 47, 40, 59],
        [43, 43, 54,  1, 47, 58,  1, 58],
        [52, 45, 43, 50, 53,  8,  0, 26],
        [39,  1, 46, 53, 59, 57, 43,  0]], device='cuda:0')


In [282]:
def compute_loss(model, dataset):
    model.eval()
    with torch.no_grad():
        total = 0
        for i in range(100):   
            x,y =  get_batch(train_data, 64, model.get_context_size())    
            _, loss = model(x,y)
            total += loss
        model.train()
        return float((total/100).cpu())

def train(model, lr, batch_size, iterations, iter_eval):
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    print(compute_loss(model, train_data), compute_loss(model, val_data))

    for it in range(iterations):
        x,y =  get_batch(train_data, batch_size, model.get_context_size())    
        _, loss = model(x,y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()    
        if it % iter_eval == 0:
            print(it//iter_eval, compute_loss(model, train_data), compute_loss(model, val_data))
        
    torch.save(model.state_dict(), "./mymodel.pth")    

In [288]:
loss_fn = nn.CrossEntropyLoss()

class Bigram(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Embedding(vocab_size, vocab_size)
    def forward(self, x):
        return self.model(x)

class Generator(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x, y = None):
        p = self.model(x)
        if y!=None:
            ly = F.one_hot(y, vocab_size).type(torch.float32)
            loss = loss_fn(p.permute(0,2,1), ly.permute(0,2,1))
        else:
            loss = None
        return p, loss
    
    def generate(self, count):
        s = torch.zeros((1,1), dtype=torch.long, device = cuda)
        out = s
        for i in range(count):
            #print(s)
            p, _ = self.forward(s)
            probs = F.softmax(p[0], dim=1)
            s = torch.multinomial(probs,1)
            #print("sample", ss)
            out = torch.cat([out, s], dim=1)

        return decode(out[0].tolist())
    def get_context_size(self):
        return 1

In [292]:
bm = Generator(Bigram()).cuda()
print(bm.generate(200))
train(bm, lr=1e-3, batch_size=4, iterations=10000, iter_eval=1000)
print(bm.generate(200))


? .vf&EuQDBg'PCWo!KzqdX:KGB&E$YxUPoGi'lwW;&U CzqS'KfjZ?X-VwDmVw&cXazZC$Dk:wG-L!.Z?GuFAzYuYy!CgCz;.!RNgQ!I;-nseOs,mJ$a.hdXqV:$PDo3tS
,&JWrUPR paJVnNiAmIsVf&.IMD$n'Ipf&p!A3d$YM,;kE3;l,MKkGhviLxQ!oM
.UPu
4.652835369110107 4.6526336669921875
0 4.63444709777832 4.641069412231445
1 4.309333801269531 4.295302391052246
2 4.044012546539307 4.044595718383789
3 3.8016421794891357 3.8051838874816895
4 3.5808169841766357 3.57332181930542
5 3.4287338256835938 3.411273717880249
6 3.2615296840667725 3.275895833969116
7 3.143493890762329 3.1579220294952393
8 3.0518107414245605 3.057640314102173
9 2.9937422275543213 2.9483277797698975

s;IKWlane.Nl t,
Mkeng?MXG$PSL&Tve ystinilgrOf.hGrymo ngYKm fasltcy s,e,
AJX twhte eerd
BGr cURurld;.FCisenookpe w;I't Okxl SeILZo pSLOTuJL?dBambH
BBTSCredow
GTy',wimismifaLHNue
.
Rpuo inC: fo ymio twa


In [293]:
class Attention(nn.Module):
    def __init__(self, context_size, input_size, output_size):
        super().__init__()
        # KQV size
        self.output_size = output_size
        self.key = nn.Linear(input_size, output_size, bias=False)
        self.query = nn.Linear(input_size, output_size, bias=False)
        self.value = nn.Linear(input_size, output_size, bias=False)

        sz = context_size
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        self.register_buffer("mask", mask)


    def forward(self, x):
        em_key = self.key(x)
        em_query = self.query(x)
        em_value = self.value(x)

        # the attentions matrix must be the size of the context
        # as it is in reality an adjacency matrix
        att = em_key @ em_query.transpose(-2,-1)

        #print (att.shape)

        att /= self.output_size ** 0.5

        att += self.mask

        att = F.softmax(att, dim=1)
        return att @ em_value 



In [294]:
dropout=0.1

class Block(nn.Module):
    def __init__(self, context_size, num_heads, embedding_size):
        super().__init__()

        self.ln1 = nn.LayerNorm(embedding_size)

        attention_size = embedding_size//num_heads
        self.head = nn.ModuleList( [Attention(context_size, embedding_size, attention_size) for _ in range(num_heads)])

        self.linear = nn.Linear(embedding_size, embedding_size)
        self.dp1 = nn.Dropout(dropout)
        self.ln2 = nn.LayerNorm(embedding_size)

        self.ff = nn.Sequential(
            nn.Linear(embedding_size, 4 * embedding_size),
            nn.ReLU(),
            nn.Linear(4 * embedding_size, embedding_size),
            nn.Dropout(dropout),
        )


    def forward(self, x):

        x = self.ln1(x)
        x = x + torch.cat([head(x) for head in self.head], dim=-1)
        x = self.linear(x)
        x = self.dp1(x)

        x = self.ln2(x)
        x = x + self.ff(x)

        return x


class ChatGPT(nn.Module):
    def __init__(self, context_size, num_blocks, num_heads, embedding_size):
        super().__init__()
        self.context_size = context_size
        pos = torch.arange(0, context_size, dtype=torch.long)
        self.register_buffer("pos", pos)

        self.tok_embedding = nn.Embedding(vocab_size, embedding_size)
        self.pos_embedding = nn.Embedding(context_size, embedding_size)

        self.blocks = nn.ModuleList( [Block(context_size, num_heads, embedding_size) for _ in range(num_blocks)])
        self.ff = nn.Sequential(nn.Linear(embedding_size,embedding_size), nn.ReLU())

        self.ln = nn.LayerNorm(embedding_size) # final layer norm
        self.linear = nn.Linear(embedding_size, vocab_size)

    def forward(self, x):
        
        te = self.tok_embedding(x)
        pe = self.pos_embedding(self.pos)
        x = te+pe

        for block in self.blocks:
            x = block(x)
        x = self.ff(x)
        x = self.ln(x)
        x = self.linear(x)

        return x

In [300]:
class Generator(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x, y = None):
        p = self.model(x)
        if y!=None:
            ly = F.one_hot(y, vocab_size).type(torch.float32)
            loss = loss_fn(p.permute(0,2,1), ly.permute(0,2,1))
        else:
            loss = None
        return p, loss
    
    def generate(self, count, str=" "):
        self.eval()
        with torch.no_grad():
            s = torch.zeros((1,self.model.context_size), dtype=torch.long).cuda()

            prompt = torch.tensor([encode(str)], dtype=torch.long, device = cuda)
            prompt_len = len(str)

            s[0, -prompt_len:] = prompt
            out = s
            for i in range(count):
                p, _ = self.forward(out[:,-self.model.context_size:])
                probs = F.softmax(p[0], dim=1)
                s = torch.multinomial(probs,1)
                out = torch.cat([out, s[-1].unsqueeze(1)], dim=1)

            return decode(out[0].tolist()[self.model.context_size-prompt_len:])
        self.train()

    def get_context_size(self):
        return self.model.context_size

In [297]:
def Experiment(context_size = 8, num_blocks = 4, num_heads = 8, embedding_size = 64):
    print("configuration", context_size, num_blocks, num_heads, embedding_size)
    gen = ChatGPT(context_size, num_blocks, num_heads, embedding_size)
    cg = Generator(gen).cuda()
    #train(cg, lr=1e-4, iterations=10)
    return cg


In [301]:
e = Experiment(context_size = 8, num_blocks = 4, num_heads = 8, embedding_size = 8*8)
print(e.generate(500))
train(e, lr=1e-4, batch_size=64, iterations=2000, iter_eval=100 )
print(e.generate(500))


configuration 8 4 8 64
 dk-soUrPESe !RCGz?dJvrkKMkeWDCHPyzoZYu YkHB'i,iQQtSag,edCadBDOYYOQMKCDMwGwBLtlYCmd .-,D;l nXl,oCRiZKXb;xWVsRQHd;z!KfJA'fJGQzjAYECyr3,zfYa!r,KOGK,
VOYvkfHBVzKJ?r.M  t-!ZcYOlBXqJyA3bWk$u:!JOoCRD3-ZSSCeD S,
:S3exDVYeUD,rHSqeAOvY:rOCS&kia'c:lQy&SHFXgWAuk'CbFH-hSZHAvqCC3kCkjbSrrbrjNOC$ZJ-KeZ&bXAHELcZKvZhZLcGkr$3J'e?SG oKJZxC
3DELYrDJdcYh
BdAbKmmAGPxAUYH!SGSMtaKSmD zpUwVFG.MGx$ YBia.AtSG!XvmXSMPbSdrEPLevZY;OLJJvFYDpzYYv !3Hmr-ADkcQELY?dfzJQHHJz3OevJC!J.y D.VLwU-E xwYRGLLi CpMlvMxK-vkTAPZZKjfJcFD
p'OBh
4.437148571014404 4.432563781738281
0 4.40033483505249 4.396738052368164
1 3.428243637084961 3.4338808059692383
2 3.150041103363037 3.154521942138672
3 2.9211373329162598 2.9315361976623535
4 2.709534168243408 2.7067461013793945
5 2.4750494956970215 2.4850308895111084
6 2.2734384536743164 2.2822346687316895
7 2.0735151767730713 2.070044755935669
8 1.8975770473480225 1.8976020812988281
9 1.7318187952041626 1.7282023429870605
10 1.5803191661834717 1.5726232528686523
11 1.4

In [303]:
e2 = Experiment(context_size = 256, num_blocks = 6, num_heads = 6, embedding_size = 6*64)
train(e2, lr=1e-4, batch_size=64, iterations=2000, iter_eval=100 )
print(e2.generate(500))

configuration 256 6 6 384
4.423750877380371 4.423376560211182
0 3.834313154220581 3.834627866744995
1 2.518970012664795 2.5189764499664307
2 2.417795419692993 2.418952226638794
3 2.047351598739624 2.046271562576294
4 1.5564768314361572 1.5556992292404175
5 1.0190346240997314 1.0173932313919067
6 0.608237624168396 0.6081591844558716
7 0.34533897042274475 0.34360986948013306
8 0.19945649802684784 0.200038880109787
9 0.11963575333356857 0.12120773643255234
10 0.08075428009033203 0.08019108325242996
11 0.07577800005674362 0.0764923170208931
12 0.04499192163348198 0.04530977085232735
13 0.03225255012512207 0.03188958764076233
14 0.02567167952656746 0.025795895606279373
15 0.02208765409886837 0.022337250411510468
16 0.020169099792838097 0.019875528290867805
17 0.017486875876784325 0.01775313727557659
18 0.01560223288834095 0.015622979961335659
19 0.015707334503531456 0.015536091290414333
 ICKHYYTDdDTHCWDOBDO:
SRAPE sIL:
Ahee, sfeeer tat atheo -Nit,
HoUEANhINAvIErU:

Ss
IB:
ViN the bois Iis s

In [304]:
#train(e, lr=1e-4, iterations=2000, iter_eval=100 )
print(e.generate(1500))

 hcha
Shsthe me
oomy thadr cer:

Ns bo;e-ve wolhgsi ils ut:
LsAr theefirertou the thime wly ardik shisT loug Yhseelne fre aale wepirsunt nol feomu twwilr-edo thit mat al th pd chane Yoparf That wy dfifr of, tae thal! bole Oj whenter hit o linet, noirtig fwirntor hop,
 was.

l'ndsd gu e of!ingus ht, outhrharge tidtegiyenen eas tkee Aarw,: on md woer
'gdeve gat thirgy ht, sons, A-hrealt, dherofarexeu fy f
heralg, one will
We scbatousSh wilonar rLukthrerq.
R:
COUSW Kd enrel, hseintotat shanldl lnoc ornels cnes thehrelfiere p afu karus oury ald omet vente oel k wopn m owy oso omie wd wyor.
H batce pore co-e se twatat whe firc k bin foprsns nork.
OENRUNNSRLRERAETE:
CF:
LCEENyxe,
Noirtoes ooro vuechenod is binsort tatn wors h! vafst, oind,
Wat gerasry se no thmy on.
ULNi w youe th stheou the mion, dhres
ON:
LW3ain.
ORPET
Tha, ad terhy dals wher thu the my tharpe-se ger aend drigiktr els lefh-
wege  deehrer pbraket gsle wor he esy dace, fh ofwr'd forif, ben bhe
Ri've fow ord we wel  abl theor