In [27]:
import torch
import torch.nn as nn
from torch.nn import functional as F

In [28]:
import random
import re

In [29]:
with open('chat-data.txt', 'r', encoding = 'utf-8') as f:
    text = f.read()

In [30]:
len(text)

607584

In [31]:
chars = sorted(list(set(text)))
''.join(chars)

'\t\n !"#$%&\'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~°\u200d’…️🏻🏽👍💫🔥🔪🔫😂😃😄😅😆😊😌😢😨😪😭😮😯😵😶🙂🙃🙏🥲🥳\U0001f979🫂'

**Slight Preprocessing**

In [32]:
special_char = chars[104]
print(special_char)
possible_replacements = ['welcome', 'alright', 'okay']
text = re.sub(special_char, lambda x : random.choices(possible_replacements, weights = [3, 1, 2], k = 1)[0], text)

👍


In [33]:
chars = sorted(list(set(text)))
''.join(chars)

'\t\n !"#$%&\'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~°\u200d’…️🏻🏽💫🔥🔪🔫😂😃😄😅😆😊😌😢😨😪😭😮😯😵😶🙂🙃🙏🥲🥳\U0001f979🫂'

In [34]:
len(text)

609024

In [35]:
exotic_chars =  chars[99:]
''.join(exotic_chars)

'’…️🏻🏽💫🔥🔪🔫😂😃😄😅😆😊😌😢😨😪😭😮😯😵😶🙂🙃🙏🥲🥳\U0001f979🫂'

In [36]:
for exotic_char in exotic_chars:
    text = text.replace(exotic_char, '')

len(text)

607855

In [37]:
chars = sorted(list(set(text)))
''.join(chars)

'\t\n !"#$%&\'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~°\u200d'

**Hyperparameters**

In [38]:
vocab_size = len(chars)
batch_size = 64   # B
block_size = 256  # T (Context length)
n_layers = 6      # Number of blocks or units of the decoder in the architecture
n_embd = 384      # num_heads * head_size
num_heads = 6     # Number of attention heads in multiheaded attention
head_size = n_embd // num_heads  # Sequence length processed by a single head of attention
max_iters = 5000
eval_iters = 200
eval_interval = 500
learning_rate = 3e-4
dropout = 0.2      # % dropout
device = 'cuda' if torch.cuda.is_available() else 'cpu'


In [39]:
stoi = {ch:i for i, ch in enumerate(chars)}
itos = {i:ch for i, ch in enumerate(chars)}
encode = lambda text : [stoi[token] for token in text]
decode = lambda encoding : ''.join([itos[item] for item in encoding])

In [40]:
e = encode("This is a line of text")
print(e, decode(e), sep = '\n')

[54, 74, 75, 85, 2, 75, 85, 2, 67, 2, 78, 75, 80, 71, 2, 81, 72, 2, 86, 71, 90, 86]
This is a line of text


In [41]:
# Encode the text dataset into a torch tensor
data = torch.tensor(encode(text), dtype = torch.long)

In [42]:
data.shape

torch.Size([607855])

**Train and Val splits**

In [46]:
n = int(len(data) * 0.95)
train_data = data[:n]
val_data = data[n:]

In [47]:
train_data.shape, val_data.shape

(torch.Size([577462]), torch.Size([30393]))

In [48]:
def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    
    x = torch.stack([data[i : i+block_size] for i in ix])
    y = torch.stack([data[i+1 : i+block_size+1] for i in ix])
    x = x.to(device)
    y = y.to(device)
    return x, y

In [49]:
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval();
    
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for i in range(eval_iters):
            X, Y = get_batch(split)
            X = X.to(device)
            Y = Y.to(device)
            logits, loss = model(X, Y)
            losses[i] = loss.item()
        out[split] = losses.mean()
        
    model.train()
    return out
    

In [50]:
# One head of self attention in multiheaded attentino
class Head(nn.Module):
    def __init__(self, head_size):
        super().__init__()  # Initialize parameters for the derived class object
        
        self.head_size = head_size
        self.query = nn.Linear(n_embd, head_size, bias = False) # Part of sequence being processes currently
        self.key = nn.Linear(n_embd, head_size, bias = False)   # Parts of the sequence to attend to
        self.value = nn.Linear(n_embd, head_size, bias = False) # Parts of sequence other than the current part
        # tril will have no learnable parameters
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size))) 
        
        self.dropout = nn.Dropout(dropout)
        
    # Forward method of a derived class of nn.Module is called by the __call__ method of the base nn.Module
    # class. Objects of classes with __call__ method are called 'callable objects'
    def forward(self, x):
        B, T, C = x.shape # B - batch_size, T - block_size (time dimension), C - channels
        k = self.key(x) # (B, T, head_size)
        q = self.key(x) # (B, T, h_s)
        wei = (q @ k.transpose(-2, -1)) * (self.head_size**(-0.5)) # (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
        wei = F.softmax(wei, dim = -1)
        wei = self.dropout(wei)
        
        # Weighted aggregation of values
        v = self.value(x)  # (B, T, h_s)
        out = wei @ v      # (B, T, h_s)
        return out;
        

In [51]:
# Multiple heads of attention
class MultiHeadedAttention(nn.Module):
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(head_size * num_heads, n_embd)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        out = torch.cat([head(x) for head in self.heads], dim = -1)
        out = self.dropout(self.proj(out))
        return out
        

In [52]:
# Feed Forward Neural Network
class FeedForward(nn.Module):
    def __init__(self, n_embd):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout)
        )
        
    def forward(self, x):
        out = self.network(x)
        return out

In [53]:
# A block in the transformer decoder
class Block(nn.Module):
    def __init__(self, n_embd, num_heads):
        super().__init__()
        self.self_attention = MultiHeadedAttention(num_heads, head_size)
        self.feed_forward_network = FeedForward(n_embd)
        self.layer_norm_1 = nn.LayerNorm(n_embd)
        self.layer_norm_2 = nn.LayerNorm(n_embd)
        
    def forward(self, x):
        # Adding x -> The purpose of the residual connection is to ensure that important
        # information from the input sequence is preserved and propagated through the network.
        # Also for improved gradient flow (no vanishing or exploding gradients)
        # Original paper => Self Attention -> Add and Layer_Norm -> Feed_Forward
        # More recently =>  Layer_Norm -> Self_Attention -> Add and Feed_Forward
        x = x + self.self_attention(self.layer_norm_1(x)) 
        x = x + self.feed_forward_network(self.layer_norm_2(x))
        return x
        

In [54]:
class GPTLanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        # https://towardsdatascience.com/master-positional-encoding-part-i-63c05d90a0c3
        self.position_embedding_table = nn.Embedding(block_size, n_embd) # Positional encoding
        self.blocks = nn.Sequential(*[Block(n_embd, num_heads) for _ in range(n_layers)])
        self.final_layer_norm = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size) # lanuage modelling head
        
        self.apply(self._init_weights)
        
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean = 0.0, std = 0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
                
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean = 0.0, std = 0.02)
            
    
    def forward(self, idx, targets = None):
        B, T = idx.shape
        
        # Both idx and targets have shape (B, T)
        
        # Forward pass through the whole decoder architecture
        tok_emb = self.token_embedding_table(idx) # (B, T, C)
        pos_emb = self.position_embedding_table(torch.arange(T, device = device)) # (T, C)
        x = tok_emb + pos_emb # (B, T, C) after broadcast and add
        x = self.blocks(x) # (B, T, C)
        x = self.final_layer_norm(x) # (B, T, C)
        logits = self.lm_head(x)  # (B, T, vocab_size)
        
        # Loss calculation - Cross Entropy (negative log likelihood loss)
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
            
        return logits, loss
    
    
    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            # latter block_size part of the sequence
            idx_cond = idx[:, -block_size:]
            # compute the logits and loss
            logits, loss = self(idx_cond)
            # Pick from only the current timestep
            logits = logits[:, -1, :] # (B, T, C) to (B, C)
            probabilites = F.softmax(logits, dim = -1) # (B, C)

            # Sample from the distribution using the probabilites
            idx_next = torch.multinomial(probabilites, num_samples = 1)  # (B, 1)
            idx = torch.cat((idx, idx_next), dim = 1) # Append the sampled token(s) to the running sequence
        
        return idx
            
        
        

In [55]:
model = GPTLanguageModel()
m = model.to(device)

# Using AdamW optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr = learning_rate)

for iteration in range(max_iters):
    
    # Periodically evaluate loss and once after all training iterations are done
    if iteration % eval_interval == 0 or iteration == max_iters - 1:
        losses = estimate_loss()
        print(f"Step {iteration} : Train Loss  {losses['train']:.4f}, Val Loss {losses['val']:.4f}")
        
    # sample o batch of data
    xb, yb = get_batch('train')
    
    # Evaluate loss
    logits, loss = model(xb, yb)
    print(f"loss : {loss}")
    optimizer.zero_grad(set_to_none = True)  # Setting grads to None rather than 0 for memory efficiency
    loss.backward()
    optimizer.step()

Step 0 : Train Loss  4.5896, Val Loss 4.6097
loss : 4.594367027282715
loss : 3.8994452953338623
loss : 3.6169638633728027
loss : 3.479755401611328
loss : 3.309957981109619
loss : 3.358476400375366
loss : 3.267212152481079
loss : 3.305375576019287
loss : 3.180311679840088
loss : 3.191066265106201
loss : 3.190197706222534
loss : 3.0351381301879883
loss : 3.178213119506836
loss : 3.023707628250122
loss : 3.08182954788208
loss : 2.9676549434661865
loss : 2.9725615978240967
loss : 3.005405902862549
loss : 3.0240659713745117
loss : 3.098667860031128
loss : 3.137862205505371
loss : 3.021477222442627
loss : 3.0966060161590576
loss : 2.943950891494751
loss : 3.0849602222442627
loss : 3.017197370529175
loss : 2.983914852142334
loss : 3.027194023132324
loss : 3.055826425552368
loss : 2.9053826332092285
loss : 2.8429179191589355
loss : 2.884289503097534
loss : 2.8416547775268555
loss : 2.9595534801483154
loss : 2.8305060863494873
loss : 2.8801348209381104
loss : 2.8043065071105957
loss : 2.8624861

loss : 2.4092953205108643
loss : 2.4704999923706055
loss : 2.392650604248047
loss : 2.4062252044677734
loss : 2.318263292312622
loss : 2.348052501678467
loss : 2.273977279663086
loss : 2.3864355087280273
loss : 2.44482421875
loss : 2.3584861755371094
loss : 2.348767042160034
loss : 2.425851583480835
loss : 2.301772356033325
loss : 2.338906764984131
loss : 2.4108381271362305
loss : 2.3891804218292236
loss : 2.3257131576538086
loss : 2.4690256118774414
loss : 2.5030364990234375
loss : 2.4199769496917725
loss : 2.3715693950653076
loss : 2.318317174911499
loss : 2.2343692779541016
loss : 2.3696553707122803
loss : 2.4946095943450928
loss : 2.294741630554199
loss : 2.340785503387451
loss : 2.430307388305664
loss : 2.3587801456451416
loss : 2.295630693435669
loss : 2.42000412940979
loss : 2.349846839904785
loss : 2.344032049179077
loss : 2.320509195327759
loss : 2.437417507171631
loss : 2.4130353927612305
loss : 2.4545137882232666
loss : 2.2845187187194824
loss : 2.2783730030059814
loss : 2.2

loss : 2.185746908187866
loss : 2.0257627964019775
loss : 2.059203624725342
loss : 2.2017905712127686
loss : 2.250516176223755
loss : 2.1290485858917236
loss : 2.250824451446533
loss : 2.0988190174102783
loss : 2.1555135250091553
loss : 2.097372055053711
loss : 2.0150437355041504
loss : 2.111776113510132
loss : 2.094329833984375
loss : 2.0271918773651123
loss : 2.1236143112182617
loss : 2.243177652359009
loss : 2.1897523403167725
loss : 2.1100521087646484
loss : 2.1240978240966797
loss : 2.196443796157837
loss : 2.0887317657470703
loss : 1.9677135944366455
loss : 2.0829861164093018
loss : 2.0637261867523193
loss : 2.1260006427764893
loss : 2.0470921993255615
loss : 2.1386637687683105
loss : 1.9070701599121094
loss : 1.9171069860458374
loss : 2.1169440746307373
loss : 2.1476047039031982
loss : 2.109861135482788
loss : 2.049635887145996
loss : 1.9608360528945923
loss : 1.979924201965332
loss : 2.0695650577545166
loss : 2.0271313190460205
loss : 2.056251287460327
loss : 2.160946846008301


loss : 1.5197137594223022
loss : 1.580690622329712
loss : 1.706526279449463
loss : 1.5951104164123535
loss : 1.6024038791656494
loss : 1.5762357711791992
loss : 1.6013884544372559
loss : 1.663265585899353
loss : 1.5583595037460327
loss : 1.602799654006958
loss : 1.5740516185760498
loss : 1.6534732580184937
loss : 1.6725341081619263
loss : 1.6663711071014404
loss : 1.5697890520095825
loss : 1.6745049953460693
loss : 1.5514754056930542
loss : 1.6673332452774048
loss : 1.677909255027771
loss : 1.6372150182724
loss : 1.6972520351409912
loss : 1.6239728927612305
loss : 1.6440584659576416
loss : 1.5965166091918945
loss : 1.6692332029342651
loss : 1.607088565826416
loss : 1.5839025974273682
loss : 1.6329294443130493
loss : 1.4539520740509033
loss : 1.556962251663208
loss : 1.4123847484588623
loss : 1.659501314163208
loss : 1.6794260740280151
loss : 1.5399638414382935
loss : 1.6710841655731201
loss : 1.541279911994934
loss : 1.6557152271270752
loss : 1.677404761314392
loss : 1.6775599718093872

loss : 1.397759199142456
loss : 1.4017058610916138
loss : 1.4202641248703003
loss : 1.3330955505371094
loss : 1.3763505220413208
loss : 1.3693838119506836
loss : 1.476276159286499
loss : 1.3672266006469727
loss : 1.433622121810913
loss : 1.4135396480560303
loss : 1.473800539970398
loss : 1.495715856552124
loss : 1.4577780961990356
loss : 1.4562187194824219
loss : 1.5070117712020874
loss : 1.500606656074524
loss : 1.387931227684021
loss : 1.4562638998031616
loss : 1.396591067314148
loss : 1.433739185333252
loss : 1.429668664932251
loss : 1.4104597568511963
loss : 1.450168251991272
loss : 1.485832691192627
loss : 1.4642728567123413
loss : 1.4046478271484375
loss : 1.3547717332839966
loss : 1.391914963722229
loss : 1.402285099029541
loss : 1.3998130559921265
loss : 1.3922338485717773
loss : 1.4158837795257568
loss : 1.4627273082733154
loss : 1.3608320951461792
loss : 1.4119353294372559
loss : 1.4556756019592285
loss : 1.4028375148773193
loss : 1.4956254959106445
loss : 1.4505131244659424


loss : 1.321151614189148
loss : 1.3541505336761475
loss : 1.3186520338058472
loss : 1.3249645233154297
loss : 1.3963241577148438
loss : 1.3441747426986694
loss : 1.324652075767517
loss : 1.2839033603668213
loss : 1.3348767757415771
loss : 1.3264321088790894
loss : 1.3078577518463135
loss : 1.351017713546753
loss : 1.3535723686218262
loss : 1.3919739723205566
loss : 1.2412229776382446
loss : 1.4145535230636597
loss : 1.3673105239868164
loss : 1.3477081060409546
loss : 1.3316664695739746
loss : 1.2828069925308228
loss : 1.3315277099609375
loss : 1.296713948249817
loss : 1.3693233728408813
loss : 1.3572561740875244
loss : 1.2950165271759033
loss : 1.312694787979126
loss : 1.2700657844543457
loss : 1.3447580337524414
loss : 1.36197030544281
loss : 1.291976809501648
loss : 1.3184733390808105
loss : 1.3224384784698486
loss : 1.3668339252471924
loss : 1.3676879405975342
loss : 1.3158369064331055
loss : 1.3434418439865112
loss : 1.3720698356628418
loss : 1.219823956489563
loss : 1.238033294677

loss : 1.3059706687927246
loss : 1.250885248184204
loss : 1.2000830173492432
loss : 1.2180994749069214
loss : 1.2545859813690186
loss : 1.2920747995376587
loss : 1.1569658517837524
loss : 1.252584457397461
loss : 1.3147879838943481
loss : 1.2745342254638672
loss : 1.266066312789917
loss : 1.277614712715149
loss : 1.2582751512527466
loss : 1.2950770854949951
loss : 1.2376466989517212
loss : 1.285091519355774
loss : 1.3233920335769653
loss : 1.3100976943969727
loss : 1.3492313623428345
loss : 1.277780294418335
loss : 1.2348591089248657
loss : 1.2531218528747559
loss : 1.2838404178619385
loss : 1.2350252866744995
loss : 1.2685632705688477
loss : 1.2790088653564453
loss : 1.3000984191894531
loss : 1.2262601852416992
loss : 1.2513234615325928
loss : 1.2465778589248657
loss : 1.2443115711212158
loss : 1.2436277866363525
loss : 1.211607575416565
loss : 1.2171056270599365
loss : 1.2637766599655151
loss : 1.2902249097824097
loss : 1.2671940326690674
loss : 1.2871214151382446
loss : 1.2288860082

KeyboardInterrupt: 

In [74]:
# After training for roughly 13 minutes
# Generate from the model
context = torch.zeros((1, 1), dtype = torch.long, device = device)
encoded_text = m.generate(context, max_new_tokens = 500)[0].tolist()
print(decode(encoded_text))

	f med are free the more
On ya I think he was just input offer 5 marks... So ig only you want improve to get to mat you seneshn
Oh
Dude
Come areferristely cally on worth its just go to eg ound one?
Those happth tomorrow
Can you have?
Well so for youbother 1^2..
really latelect
This is really than go to allies a lot who wont me in the know 
We get this her won't changes by this with you really are standing psifically create rt
Yours retuive valuela la
I can all be an lotta whan I mangered that
But


In [57]:
# Saving the model
torch.save(m.state_dict(), './models/chat-model')

In [58]:
# Loading the model
model = GPTLanguageModel()
model = model.to(device)
model.load_state_dict(torch.load('./models/chat-model'))
model.eval()

GPTLanguageModel(
  (token_embedding_table): Embedding(99, 384)
  (position_embedding_table): Embedding(256, 384)
  (blocks): Sequential(
    (0): Block(
      (self_attention): MultiHeadedAttention(
        (heads): ModuleList(
          (0): Head(
            (query): Linear(in_features=384, out_features=64, bias=False)
            (key): Linear(in_features=384, out_features=64, bias=False)
            (value): Linear(in_features=384, out_features=64, bias=False)
            (dropout): Dropout(p=0.2, inplace=False)
          )
          (1): Head(
            (query): Linear(in_features=384, out_features=64, bias=False)
            (key): Linear(in_features=384, out_features=64, bias=False)
            (value): Linear(in_features=384, out_features=64, bias=False)
            (dropout): Dropout(p=0.2, inplace=False)
          )
          (2): Head(
            (query): Linear(in_features=384, out_features=64, bias=False)
            (key): Linear(in_features=384, out_features=64, bias

In [70]:
def generate_response(input_text):
    encoded_context = torch.as_tensor(encode(input_text), dtype = torch.long, device = device)
    encoded_context = torch.stack((encoded_context,)) # To turn it into a tensor of dimensions (1, context_length)
    n = len(input_text)
    encoded_text = model.generate(encoded_context, max_new_tokens = random.randint(20, 200))[0].tolist()
    decoded_response = decode(encoded_text)[n:]
    return decoded_response, len(decoded_response) 

In [77]:
input_text = input('You : ')
response, size = generate_response(input_text)
print(f'Response ({size} characters) : {response}')

You : Hey there, how's everything going?
Response (104 characters) : 
 It says came a good to mark slow
So its a website?
naah....I know how but the best...https://and every
