In [1]:
import torch
import torch.nn as nn
from  torch.nn import functional as F
torch.manual_seed(1337)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [2]:
with open("input.txt", "r", encoding = "utf-8") as f:
    text=f.read()
print(len(text))    

1115394


In [3]:
vocab = list(set(text))
vocab.sort()
print("".join(vocab))
vocab_size = len(vocab)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz


In [4]:
#encoder
ctoi = { vocab[i]:i for i in range(len(vocab))}
itoc = { i:vocab[i] for i in range(len(vocab))}
def encode(s): return [ ctoi[i] for i in s]
def decode(t): return "".join([ itoc[i] for i in t])

tokens = encode("hi ho")
s = decode(tokens)
print(tokens, s)

[46, 47, 1, 46, 53] hi ho


In [5]:
tokens = encode(text)
print(tokens[:20])
print(decode(tokens[:20]))

[18, 47, 56, 57, 58, 1, 15, 47, 58, 47, 64, 43, 52, 10, 0, 14, 43, 44, 53, 56]
First Citizen:
Befor


In [6]:
data = torch.tensor(tokens, dtype=torch.long, device = device)
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

In [7]:
def get_batch(data, batch_size = 4, block_size = 8 ):
    indices = torch.randint(len(data)-block_size-1, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in indices], dim=0)
    y = torch.stack([data[i+1:i+block_size+1] for i in indices], dim=0)
    return x,y   

x,y =  get_batch(train_data)    
print(x)
print(y)

tensor([[53, 59,  6,  1, 58, 56, 47, 40],
        [49, 43, 43, 54,  1, 47, 58,  1],
        [13, 52, 45, 43, 50, 53,  8,  0],
        [ 1, 39,  1, 46, 53, 59, 57, 43]], device='cuda:0')
tensor([[59,  6,  1, 58, 56, 47, 40, 59],
        [43, 43, 54,  1, 47, 58,  1, 58],
        [52, 45, 43, 50, 53,  8,  0, 26],
        [39,  1, 46, 53, 59, 57, 43,  0]], device='cuda:0')


In [8]:
def compute_loss(model, dataset):
    model.eval()
    with torch.no_grad():
        total = 0
        for i in range(100):   
            x,y =  get_batch(train_data, 64, model.get_context_size())    
            _, loss = model(x,y)
            total += loss
        model.train()
        return float((total/100).cpu())

def train(model, lr, batch_size, iterations, iter_eval):
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    print(compute_loss(model, train_data), compute_loss(model, val_data))

    for it in range(iterations):
        x,y =  get_batch(train_data, batch_size, model.get_context_size())    
        _, loss = model(x,y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()    
        if it % iter_eval == 0:
            print(it//iter_eval, compute_loss(model, train_data), compute_loss(model, val_data))
        
    torch.save(model.state_dict(), f"./{model.model.name}.pth")    

In [9]:
loss_fn = nn.CrossEntropyLoss()

class Bigram(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Embedding(vocab_size, vocab_size)
        self.name = f"bigram_{vocab_size}"

    def forward(self, x):
        return self.model(x)

class Generator(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x, y = None):
        p = self.model(x)
        if y!=None:
            ly = F.one_hot(y, vocab_size).type(torch.float32)
            loss = loss_fn(p.permute(0,2,1), ly.permute(0,2,1))
        else:
            loss = None
        return p, loss
    
    def generate(self, count):
        s = torch.zeros((1,1), dtype=torch.long, device = device)
        out = s
        for i in range(count):
            #print(s)
            p, _ = self.forward(s)
            probs = F.softmax(p[0], dim=1)
            s = torch.multinomial(probs,1)
            #print("sample", ss)
            out = torch.cat([out, s], dim=1)

        return decode(out[0].tolist())
    
    def get_context_size(self):
        return 1

In [10]:
bm = Generator(Bigram()).to(device)
print(bm.generate(200))
train(bm, lr=1e-3, batch_size=4, iterations=10000, iter_eval=1000)
print(bm.generate(200))


yq$;tfBfROkNdcuwdZZTkOMl;,ertK
w:!PLCkMBbeA$3:XaSGJO-3p&M-c?KL3auhpFYVXJFhNNNuhq$OMxv.tbVFYdXlrFZaAeNuw:cPPyREFkHDEZaYJFzyWNuX
Yo3&$LMtofBimzLB!!&V!Ox;Kl;l;ZcKe3 ixYeYEFngmi;;lxWvHFGEZEQG EsSXHB;kW3 J
4.627649307250977 4.631004810333252
0 4.639685153961182 4.63800573348999
1 4.330250263214111 4.331907749176025
2 4.061417579650879 4.0431013107299805
3 3.821509599685669 3.8229637145996094
4 3.6014459133148193 3.6163885593414307
5 3.44629168510437 3.424192190170288
6 3.2874021530151367 3.298607349395752
7 3.202650308609009 3.1776974201202393
8 3.050213575363159 3.0766634941101074
9 2.992992877960205 2.9892590045928955

BYGENilerjbouselplind me l.
lishe cnchiry:
Uug;Mnisspllw y.O:ur n'SIREDmopetelivIEjMPithy wJd mothakllo W,Coo wh VCeiib3MI'Thom bMxWivDThenghim$Fs p-LK3gAY-xT3b

ALENxmntcrurt f so;;3QQDLETm:
EN,CI ma


In [11]:
class Attention(nn.Module):
    def __init__(self, context_size, input_size, output_size):
        super().__init__()
        # KQV size
        self.output_size = output_size
        self.key = nn.Linear(input_size, output_size, bias=False)
        self.query = nn.Linear(input_size, output_size, bias=False)
        self.value = nn.Linear(input_size, output_size, bias=False)

        sz = context_size
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        self.register_buffer("mask", mask)


    def forward(self, x):
        em_key = self.key(x)
        em_query = self.query(x)
        em_value = self.value(x)

        # the attentions matrix must be the size of the context
        # as it is in reality an adjacency matrix
        att = em_query @ em_key.transpose(-2,-1)

        #print (att.shape)

        att /= self.output_size ** 0.5

        att += self.mask

        att = F.softmax(att, dim=-1)
        return att @ em_value 



In [12]:
dropout=0.1

class Block(nn.Module):
    def __init__(self, context_size, num_heads, embedding_size):
        super().__init__()

        self.ln1 = nn.LayerNorm(embedding_size)

        self.head = nn.ModuleList( [Attention(context_size, embedding_size, embedding_size//num_heads) for _ in range(num_heads)])
        self.linear = nn.Linear(embedding_size, embedding_size)
        self.dp1 = nn.Dropout(dropout)
        
        self.ln2 = nn.LayerNorm(embedding_size)

        self.ff = nn.Sequential(
            nn.Linear(embedding_size, 4 * embedding_size),
            nn.ReLU(),
            nn.Linear(4 * embedding_size, embedding_size),
            nn.Dropout(dropout),
        )


    def forward(self, x):

        lx = self.ln1(x)
        x1 = self.linear(torch.cat([head(lx) for head in self.head], dim=-1))
        x1 = self.dp1(x1)
        x = x + x1
        
        lx = self.ln2(x)
        x2 = self.ff(lx)
        x = x + x2

        return x


class ChatGPT(nn.Module):
    def __init__(self, context_size, num_blocks, num_heads, embedding_size):
        super().__init__()
        self.name = f"gpt_{context_size}_{num_blocks}_{num_heads}_{embedding_size}"
        self.context_size = context_size
        pos = torch.arange(0, context_size, dtype=torch.long)
        self.register_buffer("pos", pos)

        self.tok_embedding = nn.Embedding(vocab_size, embedding_size)
        self.pos_embedding = nn.Embedding(context_size, embedding_size)

        self.blocks = nn.Sequential( *[Block(context_size, num_heads, embedding_size) for _ in range(num_blocks)])

        self.ln = nn.LayerNorm(embedding_size) # final layer norm
        self.linear = nn.Linear(embedding_size, vocab_size)

    def forward(self, x):
        
        te = self.tok_embedding(x)
        pe = self.pos_embedding(self.pos)
        x = te + pe

        x = self.blocks(x)
        x = self.ln(x)

        x = self.linear(x)

        return x

In [13]:
class Generator(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x, y = None):
        p = self.model(x)
        if y!=None:
            ly = F.one_hot(y, vocab_size).type(torch.float32)
            loss = loss_fn(p.permute(0,2,1), ly.permute(0,2,1))
        else:
            loss = None
        return p, loss
    
    def generate(self, count, str=" "):
        self.eval()
        with torch.no_grad():
            s = torch.zeros((1,self.model.context_size), dtype=torch.long).to(device)

            prompt = torch.tensor([encode(str)], dtype=torch.long, device = device)
            prompt_len = len(str)

            s[0, -prompt_len:] = prompt
            out = s
            for i in range(count):
                p, _ = self.forward(out[:,-self.model.context_size:])
                probs = F.softmax(p[0], dim=1)
                s = torch.multinomial(probs,1)
                out = torch.cat([out, s[-1].unsqueeze(1)], dim=1)

            return decode(out[0].tolist()[self.model.context_size-prompt_len:])
        self.train()

    def get_context_size(self):
        return self.model.context_size

In [14]:
def Experiment(context_size = 8, num_blocks = 4, num_heads = 8, embedding_size = 64):
    print("configuration", context_size, num_blocks, num_heads, embedding_size)
    gen = ChatGPT(context_size, num_blocks, num_heads, embedding_size)
    cg = Generator(gen).to(device)
    #train(cg, lr=1e-4, iterations=10)
    return cg


In [15]:
e = Experiment(context_size = 8, num_blocks = 4, num_heads = 8, embedding_size = 8*8)
print(e.generate(500))
train(e, lr=1e-4, batch_size=64, iterations=2000, iter_eval=100 )
print(e.generate(500))


configuration 8 4 8 64
 UUJ$rA!xpD:: x;fDbRaxnxeWGsdGQ3qf3alANY!jtdogKW?':-cj$QN.Sia!nlkCn$x&OqCCxDNsa33 sPPu:KyYTg!D$UQ3ayF;:eDxqQa3x !Ed' an?McW$NfZF,xaKc$3cN&S'MYJ&f-QAc&Y$wxsUX$sf- IR.?'Bp$DUx3&snfcYl$-e
qN3an$:bm tfrxJakN.OEYt3-?YXNeOqxowgpffQ&xcnva$bYk,mo-hh.JYDKnxxhkrNx',Ts3MY;KL$a!&-d j'L f?xYXcPadTT$xGxfmUXfk'ZjRxOagfoaqq!UH$f!QWIJ$xxkNroBNzYNWgYysQaefxLPfhfy,$eKVP$:ulczdfjxBXKEz$Dc$xfp cTv:;!!PWYMeDT-cYTfWlvkr,ckxx3.N
VObWWfZx,NJmafcxXeaNul ;$Ha?YW:fbgdg'?soQ:-fPxMVflz3FfcBqyf-h!NfHfZ-fg3NUdV:
GnxXd$ ;f&IzDkd
4.343344211578369 4.34586763381958
0 4.325704097747803 4.32121467590332
1 3.391835927963257 3.3991169929504395
2 3.1427764892578125 3.1494877338409424
3 2.9176723957061768 2.9262022972106934
4 2.7576820850372314 2.769906759262085
5 2.6690285205841064 2.6696863174438477
6 2.6053688526153564 2.598721981048584
7 2.5632643699645996 2.554905652999878
8 2.5135509967803955 2.5118911266326904
9 2.476839542388916 2.4772424697875977
10 2.442214250564575 2.4460175037384033
11 2.423

In [16]:
e2 = Experiment(context_size = 256, num_blocks = 6, num_heads = 6, embedding_size = 6*64)
train(e2, lr=1e-4, batch_size=64, iterations=2000, iter_eval=50 )
print(e2.generate(500))

configuration 256 6 6 384
4.44791316986084 4.448210716247559
0 4.027475833892822 4.028568267822266
1 2.6850509643554688 2.686478614807129
2 2.5379507541656494 2.53707218170166
3 2.497612237930298 2.4965407848358154
4 2.4752283096313477 2.475990056991577
5 2.4551048278808594 2.455596923828125
6 2.439548969268799 2.4391403198242188
7 2.420083522796631 2.4195609092712402
8 2.4047975540161133 2.4027047157287598
9 2.379061460494995 2.374129056930542
10 2.3414764404296875 2.342571496963501
11 2.3015737533569336 2.3029565811157227
12 2.239624500274658 2.2383382320404053
13 2.1889402866363525 2.1866536140441895
14 2.1443870067596436 2.1475045680999756
15 2.100766181945801 2.0999557971954346
16 2.0617241859436035 2.063416004180908
17 2.0307254791259766 2.028887987136841
18 1.9959670305252075 1.9960229396820068
19 1.9612364768981934 1.961883783340454
20 1.932071328163147 1.9332457780838013
21 1.9072363376617432 1.9050612449645996
22 1.8824152946472168 1.8833842277526855
23 1.855997085571289 1.85

In [20]:
#train(e, lr=1e-4, iterations=2000, iter_eval=100 )
print(e2.generate(1500))

 MFrexADUCKEQUVEY:
OLUSAULUS:
Whomes, wiSy brABEd Ifid:
I chas chame warde mon hom thy condlitie amed.
-moe men beadlings ake sewor teek of too wercas?
Whe ous, comfick?

Provost:
Theres's nexore that the a musher! hath we tome
And pratent almoon tems,
And thost apsty but me flaon, I doth kissitle:
By of wiflance do that to the steelf,
And despition but roshor aurase bed
Shich togs; ark ope; the for proces, foot this fool.

FirDIR;:
The hath well give him Mastrougher! corve word,
And my day thegue soe as all pot our distant.

LUCESHENTIO:
My wentle find obly leaves to Eche most?
Iswake have then reacse my wear:
Aho desid wine thougins. Genter's cancil into Rosy
bagaliess of the with's the fall of hisping:
And how somme tway from grieforgum begar
Wich kill isspuit: it saim I am reth
Tolson to the brives fevor the ward all at Mostory.

Secoust:
Here do met tojes ant flace, I knot as acceived;
Where we dispose do is and the saddies
A some sing is dapious to did this beaw?
That a the it I 

In [21]:
train(e2, lr=1e-4, batch_size=64, iterations=200, iter_eval=50 )

1.5876940488815308 1.5876537561416626
0 1.6956408023834229 1.697298526763916
1 1.5812041759490967 1.5792243480682373
2 1.569296956062317 1.5666412115097046
3 1.5594935417175293 1.5636963844299316


In [22]:
train(e2, lr=1e-4, batch_size=64, iterations=1000, iter_eval=50 )

1.557078242301941 1.5565043687820435
0 1.7929167747497559 1.796095848083496
1 1.5489294528961182 1.5500853061676025
2 1.5443283319473267 1.5471572875976562
3 1.5341901779174805 1.5365087985992432
4 1.5264430046081543 1.5256478786468506
5 1.5187193155288696 1.5185668468475342
6 1.5133506059646606 1.5140659809112549
7 1.5043244361877441 1.5057636499404907
8 1.4962104558944702 1.50107741355896
9 1.494091510772705 1.4949402809143066
10 1.4884915351867676 1.4876561164855957
11 1.4838132858276367 1.4828325510025024
12 1.4737396240234375 1.4760366678237915
13 1.4673218727111816 1.4679436683654785
14 1.4627834558486938 1.4597004652023315
15 1.455433964729309 1.4584462642669678
16 1.4523550271987915 1.4547147750854492
17 1.447634220123291 1.4477453231811523
18 1.4383522272109985 1.4419149160385132
19 1.434388279914856 1.4319144487380981


In [23]:
print(e2.generate(1500))

 fitheven torte thers:
You offencaffector'd,
No, goother pat han of faten lis.

S$OVERDWIO:

SLANLEY:

LADYet we was grage his withaton line,
Who he hart they wilt to the date:
I all honour! back that the sust mine
To lreagely! but cheighn daughters, for her?
Why should with my dembation, they un,
And grior shall grows, and muster with about change at
And there woman: all then I day Poter:
And thereful hear nature alonged but of our
But birth. Who Plain all sentler lord.
Where's this hownour of, our batch, say sicon:
But let mest me higheld, to the harl of his
sleep-ows, if it that enteech stil the trick.
The firstirested me thou deed blust buck:
Not thy lovelgge anoptent preace weeporm.
Unto comes of thy king from that curfles;
To that we To was begeted that murder?
Do giving find thus not gives be and will,
Or sweak nothres in this house fearl his brotcheds!
Mese for how tuth the my for mostanch'd blood,
Whose now house from to repenable Duke thy kindled from his,
What it shall gend 

In [24]:
train(e2, lr=1e-4, batch_size=64, iterations=2000, iter_eval=100 )

1.4301226139068604 1.4327268600463867
0 1.495746374130249 1.4963414669036865
1 1.4183191061019897 1.4199295043945312
2 1.4172418117523193 1.4179918766021729
3 1.4026403427124023 1.4031051397323608
4 1.3996590375900269 1.3954665660858154
5 1.3859045505523682 1.3838927745819092
6 1.372943639755249 1.374499797821045
7 1.3700703382492065 1.36915123462677
8 1.3573803901672363 1.363472580909729
9 1.3544305562973022 1.352319359779358
10 1.3427348136901855 1.3455711603164673
11 1.3375816345214844 1.3371108770370483
12 1.3319720029830933 1.3297680616378784
13 1.3224854469299316 1.322540044784546
14 1.3132699728012085 1.3135772943496704
15 1.3076298236846924 1.3105628490447998
16 1.3033661842346191 1.2979401350021362
17 1.298403263092041 1.2947909832000732
18 1.2926664352416992 1.2923775911331177
19 1.2848608493804932 1.2877528667449951


In [26]:
print(e2.generate(1500))

 SLLYOUCPEY:
N'XINIUS:

BESISABELLA:
Hell.

VAULIA:
ANGELUS:
HERCHIO:
Said Mayorshonous.

FLORIZEL.

CLAUDIO:
That what's you shad be a man sprited you.

CORIOLANUS:
ES:
Chome with. an the our hopes bote home:
some, with so me light so do woout I'll wear true,
If I must royal lor with Lord satis;
And but I thank'd you; let I will it, a varnavour
With soft another's but the bed hast.
O, concil, who should upon me,
A care service to our foremia again?

First Keeper:
Why, get you go not, you mistress
So which hath in this issue.

GLOUCESTER:
And yet doth Romeo, strange, but by his herm.

PETAULINA:
On them all declaims when loved the king?

LADY CAPULE:
No, Perdon, the subject of both the old.

LEONTES:
Farewell! Madam, afford the traitor-doors it,
Most we not so in, good tunder! through thou back,
Thou royalt'st thou
Think to thy soft muff and thy sacrew:
Whom galland when thy longth woful die!

MENENIUS:
Thou not letters do no success? thou speaking, true!
Ortul! place what not thou was

In [27]:
train(e2, lr=1e-4, batch_size=64, iterations=2000, iter_eval=100 )
print(e2.generate(1500))

1.2768762111663818 1.2778184413909912
0 1.303093671798706 1.3015209436416626
1 1.2757031917572021 1.2750481367111206
2 1.2653851509094238 1.267059326171875
3 1.2613242864608765 1.2614907026290894
4 1.2557693719863892 1.2551026344299316
5 1.2508587837219238 1.2508403062820435
6 1.2467156648635864 1.245319128036499
7 1.2416040897369385 1.2416107654571533
8 1.2325907945632935 1.2322922945022583
9 1.233831763267517 1.2325283288955688
10 1.2230241298675537 1.2246685028076172
11 1.2208529710769653 1.2205132246017456
12 1.2142488956451416 1.2160379886627197
13 1.2105913162231445 1.2107101678848267
14 1.2037389278411865 1.2075015306472778
15 1.1991580724716187 1.2012897729873657
16 1.1987706422805786 1.1969448328018188
17 1.1917442083358765 1.193029522895813
18 1.1894736289978027 1.1862218379974365
19 1.1826637983322144 1.181326985359192
 Ris.
BESTMAMPSOLO:
SovereONGEO:
ASA.

Fithird SAMPSON:

Sir.

VOLUMELEO:
HESAMPSA:

Fithoping inthee, now as thank the behear you him.

Provost:
Do say:
Not 

In [28]:
print(e2.generate(1500, "Romeo, do you love me?"))

Romeo, do you love me?

JULVO.

FiLLORord:
AUFIDIUS:
Ere 'VI re her fame; lord:
in gene'er only, grim.

JULIET:
'Fit is Pray your best, thou shalt be my kind.

VALERIA:

VOLUMNIA:
Too quencest her sworn and same o's a?

ROMEO:
'Tis both my false; can be it not sound.

LRORD OF SANLEY:
Then jutst my handed is fair's as thing;
Not a bunchless, Norfolk, look to thee: he haman,
For their shocks me and their broaks: ah, look
This gravers bear in sounds all winessing
Things Warwick and think'd the rebellasts of Richmond?

Server:
Unbuck, brother,
Scorn Hastings. Nor this true hand must write!
To see her drop aggain, therest be'n mouths leave,
And that thou wert bone,
For no breathe, 'twere no of time sitting
Which Of the heart of thy son wit's life
Shall be a throne-solding how to they abroad:
'Tis have been her beard,
I will not with with the same may word them:
And inhen the way
To Bolingbroke of his loit borrow.

ROMEO:
He is a nature of hands for this sovereign's
For storer than resh rep