In [141]:
import torch
import torch.nn as nn
from  torch.nn import functional as F
torch.manual_seed(1337)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [142]:
with open("input.txt", "r", encoding = "utf-8") as f:
    text=f.read()
print(len(text))    

1115394


In [143]:
vocab = list(set(text))
vocab.sort()
print("".join(vocab))
vocab_size = len(vocab)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz


In [144]:
#encoder
ctoi = { vocab[i]:i for i in range(len(vocab))}
itoc = { i:vocab[i] for i in range(len(vocab))}
def encode(s): return [ ctoi[i] for i in s]
def decode(t): return "".join([ itoc[i] for i in t])

tokens = encode("hi ho")
s = decode(tokens)
print(tokens, s)

[46, 47, 1, 46, 53] hi ho


In [145]:
tokens = encode(text)
print(tokens[:20])
print(decode(tokens[:20]))

[18, 47, 56, 57, 58, 1, 15, 47, 58, 47, 64, 43, 52, 10, 0, 14, 43, 44, 53, 56]
First Citizen:
Befor


In [147]:
data = torch.tensor(tokens, dtype=torch.long, device = device)
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

In [148]:
def get_batch(data, batch_size = 4, block_size = 8 ):
    indices = torch.randint(len(data)-block_size-1, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in indices], dim=0)
    y = torch.stack([data[i+1:i+block_size+1] for i in indices], dim=0)
    return x,y   

x,y =  get_batch(train_data)    
print(x)
print(y)

tensor([[53, 59,  6,  1, 58, 56, 47, 40],
        [49, 43, 43, 54,  1, 47, 58,  1],
        [13, 52, 45, 43, 50, 53,  8,  0],
        [ 1, 39,  1, 46, 53, 59, 57, 43]])
tensor([[59,  6,  1, 58, 56, 47, 40, 59],
        [43, 43, 54,  1, 47, 58,  1, 58],
        [52, 45, 43, 50, 53,  8,  0, 26],
        [39,  1, 46, 53, 59, 57, 43,  0]])


In [149]:
def compute_loss(model, dataset):
    model.eval()
    with torch.no_grad():
        total = 0
        for i in range(100):   
            x,y =  get_batch(train_data, 64, model.get_context_size())    
            _, loss = model(x,y)
            total += loss
        model.train()
        return float((total/100).cpu())

def train(model, lr, batch_size, iterations, iter_eval):
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    print(compute_loss(model, train_data), compute_loss(model, val_data))

    for it in range(iterations):
        x,y =  get_batch(train_data, batch_size, model.get_context_size())    
        _, loss = model(x,y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()    
        if it % iter_eval == 0:
            print(it//iter_eval, compute_loss(model, train_data), compute_loss(model, val_data))
        
    torch.save(model.state_dict(), f"./{model.model.name}.pth")    

In [158]:
loss_fn = nn.CrossEntropyLoss()

class Bigram(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Embedding(vocab_size, vocab_size)
        self.name = f"bigram_{vocab_size}"

    def forward(self, x):
        return self.model(x)

class Generator(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x, y = None):
        p = self.model(x)
        if y!=None:
            ly = F.one_hot(y, vocab_size).type(torch.float32)
            loss = loss_fn(p.permute(0,2,1), ly.permute(0,2,1))
        else:
            loss = None
        return p, loss
    
    def generate(self, count):
        s = torch.zeros((1,1), dtype=torch.long, device = device)
        out = s
        for i in range(count):
            #print(s)
            p, _ = self.forward(s)
            probs = F.softmax(p[0], dim=1)
            s = torch.multinomial(probs,1)
            #print("sample", ss)
            out = torch.cat([out, s], dim=1)

        return decode(out[0].tolist())
    
    def get_context_size(self):
        return 1

In [159]:
bm = Generator(Bigram()).to(device)
print(bm.generate(200))
train(bm, lr=1e-3, batch_size=4, iterations=10000, iter_eval=1000)
print(bm.generate(200))


phIOWd3AqNcgg,G!;j
UtVYwJteWJc3xq.NBpFdLXaqK; eyjnB,Icl'Vn3M3M:JSe;bVbN N&DsRi?!DSaeyNZlSYjVkCzkSdocO'f rCrC'co&sLpCaOlmvxIRlq;-nspc;kDlMHulz.BnzPOI$ISUvRZe;oWJ ?!TxUj!!wTcIFx$G!Zvm,CruxtJoMZjPUZcXtLM
4.69345235824585 4.660988807678223
0 4.700205326080322 4.691677570343018
1 4.355329990386963 4.346986293792725
2 4.058333873748779 4.076082706451416
3 3.8295791149139404 3.8127782344818115
4 3.617060899734497 3.602130174636841
5 3.4610977172851562 3.4833998680114746
6 3.2942519187927246 3.3329217433929443
7 3.1765172481536865 3.201218366622925
8 3.1023147106170654 3.0953824520111084
9 3.0085549354553223 3.0099878311157227

Wr-FWJcKd.ug yasLUTEQwPZDxQifRUct'ssondobatelayo dkvT3n gqCray Ye LEmennden vofevANleyxugroeer! nd plBAD.zmme?!c oms
AQPnPoaky
Rbre mon aFL.xraJNOBoret slegrusoulyZDWIIU'HW'ey wnzvkelujhy sh.ary wixVY


In [168]:
class Attention(nn.Module):
    def __init__(self, context_size, input_size, output_size):
        super().__init__()
        # KQV size
        self.output_size = output_size
        self.key = nn.Linear(input_size, output_size, bias=False)
        self.query = nn.Linear(input_size, output_size, bias=False)
        self.value = nn.Linear(input_size, output_size, bias=False)

        sz = context_size
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        self.register_buffer("mask", mask)


    def forward(self, x):
        em_key = self.key(x)
        em_query = self.query(x)
        em_value = self.value(x)

        # the attentions matrix must be the size of the context
        # as it is in reality an adjacency matrix
        att = em_query @ em_key.transpose(-2,-1)

        #print (att.shape)

        att /= self.output_size ** 0.5

        att += self.mask

        att = F.softmax(att, dim=-1)
        return att @ em_value 



In [182]:
dropout=0.1

class Block(nn.Module):
    def __init__(self, context_size, num_heads, embedding_size):
        super().__init__()

        self.ln1 = nn.LayerNorm(embedding_size)

        self.head = nn.ModuleList( [Attention(context_size, embedding_size, embedding_size//num_heads) for _ in range(num_heads)])
        self.linear = nn.Linear(embedding_size, embedding_size)
        self.dp1 = nn.Dropout(dropout)
        
        self.ln2 = nn.LayerNorm(embedding_size)

        self.ff = nn.Sequential(
            nn.Linear(embedding_size, 4 * embedding_size),
            nn.ReLU(),
            nn.Linear(4 * embedding_size, embedding_size),
            nn.Dropout(dropout),
        )


    def forward(self, x):

        lx = self.ln1(x)
        x1 = self.linear(torch.cat([head(lx) for head in self.head], dim=-1))
        x1 = self.dp1(x1)
        x = x + x1
        
        lx = self.ln2(x)
        x2 = self.ff(lx)
        x = x + x2

        return x


class ChatGPT(nn.Module):
    def __init__(self, context_size, num_blocks, num_heads, embedding_size):
        super().__init__()
        self.name = f"gpt_{context_size}_{num_blocks}_{num_heads}_{embedding_size}"
        self.context_size = context_size
        pos = torch.arange(0, context_size, dtype=torch.long)
        self.register_buffer("pos", pos)

        self.tok_embedding = nn.Embedding(vocab_size, embedding_size)
        self.pos_embedding = nn.Embedding(context_size, embedding_size)

        self.blocks = nn.Sequential( *[Block(context_size, num_heads, embedding_size) for _ in range(num_blocks)])

        self.ln = nn.LayerNorm(embedding_size) # final layer norm
        self.linear = nn.Linear(embedding_size, vocab_size)

    def forward(self, x):
        
        te = self.tok_embedding(x)
        pe = self.pos_embedding(self.pos)
        x = te + pe

        x = self.blocks(x)
        x = self.ln(x)

        x = self.linear(x)

        return x

In [183]:
class Generator(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x, y = None):
        p = self.model(x)
        if y!=None:
            ly = F.one_hot(y, vocab_size).type(torch.float32)
            loss = loss_fn(p.permute(0,2,1), ly.permute(0,2,1))
        else:
            loss = None
        return p, loss
    
    def generate(self, count, str=" "):
        self.eval()
        with torch.no_grad():
            s = torch.zeros((1,self.model.context_size), dtype=torch.long).to(device)

            prompt = torch.tensor([encode(str)], dtype=torch.long, device = device)
            prompt_len = len(str)

            s[0, -prompt_len:] = prompt
            out = s
            for i in range(count):
                p, _ = self.forward(out[:,-self.model.context_size:])
                probs = F.softmax(p[0], dim=1)
                s = torch.multinomial(probs,1)
                out = torch.cat([out, s[-1].unsqueeze(1)], dim=1)

            return decode(out[0].tolist()[self.model.context_size-prompt_len:])
        self.train()

    def get_context_size(self):
        return self.model.context_size

In [179]:
def Experiment(context_size = 8, num_blocks = 4, num_heads = 8, embedding_size = 64):
    print("configuration", context_size, num_blocks, num_heads, embedding_size)
    gen = ChatGPT(context_size, num_blocks, num_heads, embedding_size)
    cg = Generator(gen).to(device)
    #train(cg, lr=1e-4, iterations=10)
    return cg


In [180]:
e = Experiment(context_size = 8, num_blocks = 4, num_heads = 8, embedding_size = 8*8)
print(e.generate(500))
train(e, lr=1e-4, batch_size=64, iterations=2000, iter_eval=100 )
print(e.generate(500))


configuration 8 4 8 64
 eGp
ncRaIC$czv:QtGPTnsM!q,n;gdDxnwggnNdFlTGgHRqqKQxn,ynkRs
$CRVeS,nvHng3l$olz$IUqKzigCnyVVneNcft?PxeNgq,nHg!BsQu,EA3nTe;MMIdOGDGnAxsyr.nFCqLg gGik,cUax,nyp'V!oe$jGgVvBHRuL;,aNDxR3GPWwRokbFT:SndQ$rEWjueYxeDxvQ,ZOdF,G;slxD3awiI?sZDn-gg
xdu'Bcav'qCZMsycDIsBeNX
nMkT,U$d!Vk&eGXx;t
ayBkRnOCnCTRsbOC,:R?NV;lDwNJF$EEvF&rhIqyDeq$h&pXgOyBsXBnnOdC& CnQ;Cj:xebsYEuIFqxzxjTeYNwOwfuqGZVG,xFMnFd3KNizxn
!Tb,:&$$y,VpBZB:xRj$TSOzh!!IcNECW kQzvCui:$i NiIRzDQhLRyIN?WEkPFX?RcmPGRxfX!aIzdQsuq'3cy3nYbsbRIE&qmegwnH:EyghV
4.4329142570495605 4.432741165161133
0 4.3880391120910645 4.388359546661377
1 3.350644826889038 3.352203607559204
2 2.9748764038085938 2.960547685623169
3 2.765075922012329 2.7615926265716553
4 2.624894142150879 2.638394355773926
5 2.5509159564971924 2.5549302101135254
6 2.5041286945343018 2.492727756500244
7 2.4508743286132812 2.4369056224823
8 2.40779185295105 2.3942034244537354
9 2.3743042945861816 2.3805015087127686
10 2.3426871299743652 2.338502883911133
11 2.308493

In [None]:
e2 = Experiment(context_size = 256, num_blocks = 6, num_heads = 6, embedding_size = 6*64)
train(e2, lr=1e-4, batch_size=64, iterations=2000, iter_eval=50 )
print(e2.generate(500))

In [304]:
#train(e, lr=1e-4, iterations=2000, iter_eval=100 )
print(e.generate(1500))

 hcha
Shsthe me
oomy thadr cer:

Ns bo;e-ve wolhgsi ils ut:
LsAr theefirertou the thime wly ardik shisT loug Yhseelne fre aale wepirsunt nol feomu twwilr-edo thit mat al th pd chane Yoparf That wy dfifr of, tae thal! bole Oj whenter hit o linet, noirtig fwirntor hop,
 was.

l'ndsd gu e of!ingus ht, outhrharge tidtegiyenen eas tkee Aarw,: on md woer
'gdeve gat thirgy ht, sons, A-hrealt, dherofarexeu fy f
heralg, one will
We scbatousSh wilonar rLukthrerq.
R:
COUSW Kd enrel, hseintotat shanldl lnoc ornels cnes thehrelfiere p afu karus oury ald omet vente oel k wopn m owy oso omie wd wyor.
H batce pore co-e se twatat whe firc k bin foprsns nork.
OENRUNNSRLRERAETE:
CF:
LCEENyxe,
Noirtoes ooro vuechenod is binsort tatn wors h! vafst, oind,
Wat gerasry se no thmy on.
ULNi w youe th stheou the mion, dhres
ON:
LW3ain.
ORPET
Tha, ad terhy dals wher thu the my tharpe-se ger aend drigiktr els lefh-
wege  deehrer pbraket gsle wor he esy dace, fh ofwr'd forif, ben bhe
Ri've fow ord we wel  abl theor

In [185]:
gen = ChatGPT(context_size=256, num_blocks=6, num_heads=6, embedding_size=6*64)
print(gen)

ChatGPT(
  (tok_embedding): Embedding(65, 384)
  (pos_embedding): Embedding(256, 384)
  (blocks): Sequential(
    (0): Block(
      (ln1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
      (head): ModuleList(
        (0-5): 6 x Attention(
          (key): Linear(in_features=384, out_features=64, bias=False)
          (query): Linear(in_features=384, out_features=64, bias=False)
          (value): Linear(in_features=384, out_features=64, bias=False)
        )
      )
      (linear): Linear(in_features=384, out_features=384, bias=True)
      (dp1): Dropout(p=0.1, inplace=False)
      (ln2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
      (ff): Sequential(
        (0): Linear(in_features=384, out_features=1536, bias=True)
        (1): ReLU()
        (2): Linear(in_features=1536, out_features=384, bias=True)
        (3): Dropout(p=0.1, inplace=False)
      )
    )
    (1): Block(
      (ln1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
      (head): ModuleLi