In [1]:
import torch
import torch.nn as nn 
from torch.nn import functional as F

In [12]:
torch.__version__

'2.4.0'

In [2]:
n_emb=6
head_size=1
block_size=8

class Head(nn.Module):
    
    '''one head in self attention'''

    def __init__(self, head_size):
        super().__init__()
        self.key=nn.Linear(n_emb,head_size)
        self.query=nn.Linear(n_emb,head_size)
        self.value=nn.Linear(n_emb,head_size)
        
        self.register_buffer('trill', torch.tril(torch.ones(block_size,block_size)))


    
    def forward(self,x):
        batch, blocks, X = x.shape
        key = self.key(x) # batch, block_size, X -- shape
        query = self.query(x) # batch, block_size, X -- shape
        weight = query @ key.transpose(-2, -1) * X ** (-0.5)
        weight=weight.masked_fill(self.trill[:blocks, :blocks] ==0 , float('-inf'))
        weight=F.softmax(weight, dim=-1)
        out = weight @ self.value(x)
        return out


        
    

In [3]:
h=Head(2)
h

Head(
  (key): Linear(in_features=6, out_features=2, bias=True)
  (query): Linear(in_features=6, out_features=2, bias=True)
  (value): Linear(in_features=6, out_features=2, bias=True)
)

In [7]:
h(torch.zeros(3,8,6))

tensor([[[-0.3627,  0.3735],
         [-0.3627,  0.3735],
         [-0.3627,  0.3735],
         [-0.3627,  0.3735],
         [-0.3627,  0.3735],
         [-0.3627,  0.3735],
         [-0.3627,  0.3735],
         [-0.3627,  0.3735]],

        [[-0.3627,  0.3735],
         [-0.3627,  0.3735],
         [-0.3627,  0.3735],
         [-0.3627,  0.3735],
         [-0.3627,  0.3735],
         [-0.3627,  0.3735],
         [-0.3627,  0.3735],
         [-0.3627,  0.3735]],

        [[-0.3627,  0.3735],
         [-0.3627,  0.3735],
         [-0.3627,  0.3735],
         [-0.3627,  0.3735],
         [-0.3627,  0.3735],
         [-0.3627,  0.3735],
         [-0.3627,  0.3735],
         [-0.3627,  0.3735]]], grad_fn=<UnsafeViewBackward0>)

In [9]:
class MultiHeadAttention(nn.Module):
    '''multihead in self attention'''
    def __init__(self, head_size, num_heads):
        super(),__init__()
        self.heads=nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.layer=nn.Linear(n_emb,n_emb)
        
    def forward(self,x):
        out=torch.cat([h(x) for h in self.head], dim=-1)
        return self.layer(out)


    
   
    

In [17]:
!wget -P C:\Users\User\Documents\amala\Refonte-Learning https://gist.githubusercontent.com/Momnadar1/8805a6d53e92d6be17b9837711a5931a/raw/adc9cc97efc92232f01cbb6a1b13e8123d9f8f8d/shakepeare_s_plays.txt --no-check-certificate

--2024-08-23 13:22:11--  https://gist.githubusercontent.com/Momnadar1/8805a6d53e92d6be17b9837711a5931a/raw/adc9cc97efc92232f01cbb6a1b13e8123d9f8f8d/shakepeare_s_plays.txt
Resolving gist.githubusercontent.com (gist.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to gist.githubusercontent.com (gist.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 5583374 (5.3M) [text/plain]
Saving to: 'C:/Users/User/Documents/amala/Refonte-Learning/shakepeare_s_plays.txt.1'

     0K .......... .......... .......... .......... ..........  0% 1.62M 3s
    50K .......... .......... .......... .......... ..........  1% 27.1M 2s
   100K .......... .......... .......... .......... ..........  2% 5.54M 1s
   150K .......... .......... .......... .......... ..........  3% 2.96M 2s
   200K .......... .......... .......... .......... ..........  4% 3.72M 1s
   250K .......... .......... .......... ..........

In [1]:
with open('shakepeare_s_plays.txt', 'r', encoding="utf8") as f:
  text = f.read()
    

In [3]:
print(text[:100])

# Hamlet

ACT I
SCENE I. Elsinore. A platform before the castle.

    FRANCISCO at his post. Enter t


In [5]:
# a:1, b:,2 ....z:26 A:27 etc
chars=sorted(list(set(text)))
vocab_size=len(chars)

In [7]:
print(''.join(chars),vocab_size)

	
 !#$&'(),-.0123456789:;=>?ABCDEFGHIJKLMNOPQRSTUVWXYZ[]abcdefghijklmnopqrstuvwxyzÀè—‘’“”… 90


In [9]:
#encoding and decoding
str_to_int={char:i for i,char in enumerate(chars)}
int_to_str={i:char for i,char in enumerate(chars)}

encode=lambda string:[str_to_int[s] for s in string]
decode=lambda indexs:[int_to_str[i] for i in indexs]

encode('hi')

[63, 64]

In [11]:
decode(encode('hi'))

['h', 'i']

In [13]:
''.join(decode(encode('hi')))

'hi'

In [25]:
''.join(decode(encode('hi work!!')))

'hi work!!'

In [27]:
len(text)

5580526

In [15]:
import torch

data=torch.tensor(encode(text), dtype=torch.int64)
data[:100],data.dtype, data.shape

(tensor([ 4,  2, 35, 56, 68, 67, 60, 75,  1,  1, 28, 30, 47,  2, 36,  1, 46, 30,
         32, 41, 32,  2, 36, 12,  2, 32, 67, 74, 64, 69, 70, 73, 60, 12,  2, 28,
          2, 71, 67, 56, 75, 61, 70, 73, 68,  2, 57, 60, 61, 70, 73, 60,  2, 75,
         63, 60,  2, 58, 56, 74, 75, 67, 60, 12,  1,  1,  2,  2,  2,  2, 33, 45,
         28, 41, 30, 36, 46, 30, 42,  2, 56, 75,  2, 63, 64, 74,  2, 71, 70, 74,
         75, 12,  2, 32, 69, 75, 60, 73,  2, 75]),
 torch.int64,
 torch.Size([5580526]))

In [17]:
splitter=int(0.9*len(data))
train,test=data[:splitter],data[splitter:]


In [19]:
block_size=8
train[:block_size+1]

tensor([ 4,  2, 35, 56, 68, 67, 60, 75,  1])

In [21]:
x=data[:block_size]
y=data[1:block_size+1]

for next in range(block_size):
    context=x[:next+1]
    target=y[next]
    print(f"This is my context:{context}, while the target is:{target}")



This is my context:tensor([4]), while the target is:2
This is my context:tensor([4, 2]), while the target is:35
This is my context:tensor([ 4,  2, 35]), while the target is:56
This is my context:tensor([ 4,  2, 35, 56]), while the target is:68
This is my context:tensor([ 4,  2, 35, 56, 68]), while the target is:67
This is my context:tensor([ 4,  2, 35, 56, 68, 67]), while the target is:60
This is my context:tensor([ 4,  2, 35, 56, 68, 67, 60]), while the target is:75
This is my context:tensor([ 4,  2, 35, 56, 68, 67, 60, 75]), while the target is:1


In [23]:
batch_size=3
def batches(split):
   data=train if split=='train' else test
   #randmly selecting
   print(len(data))
   indexes=torch.randint(len(data)-block_size, (batch_size,))
   x=torch.stack([data[i:i+block_size] for i in indexes])
   y=torch.stack([data[i+1:i+1+block_size] for i in indexes])
   return x,y
   #print(indexes)
x,y=batches('train')
#print(x.shape)
#print(x)
#print(y)
for b in range(batch_size):
    for next in range(block_size):
        context=x[b,:next+1]
        target=y[b,next]
        print(f"This is my context:{context}, while the target is:{target}")
    

5022473
This is my context:tensor([60]), while the target is:67
This is my context:tensor([60, 67]), while the target is:67
This is my context:tensor([60, 67, 67]), while the target is:2
This is my context:tensor([60, 67, 67,  2]), while the target is:56
This is my context:tensor([60, 67, 67,  2, 56]), while the target is:74
This is my context:tensor([60, 67, 67,  2, 56, 74]), while the target is:2
This is my context:tensor([60, 67, 67,  2, 56, 74,  2]), while the target is:57
This is my context:tensor([60, 67, 67,  2, 56, 74,  2, 57]), while the target is:80
This is my context:tensor([80]), while the target is:2
This is my context:tensor([80,  2]), while the target is:64
This is my context:tensor([80,  2, 64]), while the target is:74
This is my context:tensor([80,  2, 64, 74]), while the target is:2
This is my context:tensor([80,  2, 64, 74,  2]), while the target is:75
This is my context:tensor([80,  2, 64, 74,  2, 75]), while the target is:63
This is my context:tensor([80,  2, 64, 7

In [29]:
import torch
import torch.nn as nn
from torch.nn import functional as F

class TextGeni(nn.Module):
    def __init__(self):
        super().__init__()
        self.lookup_token_emd_table=nn.Embedding(vocab_size,vocab_size)
    def forward(self,x):
        batches, block_size=x.shape
        out=self.lookup_token_emd_table(x)
        return out
    def generate(self, x, max_tokens):
        for _ in range(max_tokens):
            logits=self(x)
            #print(logits)
            print(logits.shape)
            logits=logits[:, -1, :]
            print(logits)
            probabilities=F.softmax(logits, dim=-1)
            
        
geni=TextGeni() 
output=geni(x)
#print(output.shape)
#print(output)
geni.generate((torch.zeros((1,1),dtype=torch.long)),10)


torch.Size([1, 1, 90])
tensor([[-1.8129,  0.7621, -0.0365,  0.9324,  2.1147,  1.5161,  0.6990,  0.0609,
          0.1383,  1.1420,  0.4576, -0.5154,  0.7747, -0.4448, -1.3667, -0.6333,
          1.0655, -2.3393, -0.6099, -0.2064, -0.8126, -1.6596, -0.0900,  0.1723,
         -1.5762,  0.0262, -0.2389,  1.0453,  1.6601, -0.3469,  0.0386,  0.0477,
         -0.2516, -0.0183, -0.1743,  1.1065,  1.2372, -0.8451,  0.5743, -0.8301,
         -0.4035, -0.6499, -0.3282,  1.3803,  0.0377,  1.0299, -0.1952,  0.1535,
         -1.7138,  0.6067,  0.0058, -0.0295,  1.3418,  1.3028, -0.9031,  1.0262,
          0.7465,  1.0692,  0.9010,  0.3545, -0.1480, -0.7613,  0.0462,  0.2363,
         -0.5312, -0.1501, -0.1388, -0.2764, -1.2187, -0.7612, -0.1348, -0.3404,
          0.9116, -0.9193,  0.4334, -0.0486,  2.7682,  1.1959,  0.3251, -0.7664,
         -0.8045,  0.0730, -1.1163,  1.9472, -0.2097, -0.0771, -0.9505, -1.3920,
          0.7582, -0.7046]], grad_fn=<SliceBackward0>)
torch.Size([1, 1, 90])
tensor([

In [35]:
geni_text = []
def generate(x, max_tokens):
    for _ in range(max_tokens):
      logits = geni(x)
      logit = logits[:, -1, :]
      # print(logits.shape)
      probilities = F.softmax(logits, dim=-1).view(-1, vocab_size) # 1, 90
      next_x = torch.multinomial(probilities, num_samples=1)
      geni_text.append(int(next_x))
    return geni_text
      # print(next_x)
print(''.join(decode(generate(torch.zeros((1,1), dtype=torch.long), 200))))        

P$$#KR]>3&èuz#AA3ku!#P=vX&Zz—A—W#vè1uHwrèPuDèzL3:)BIdèQYYSu
2
ov'u?CU?$RVPA?cb.?qb'=u?  ugY9H$$A#u$auuAè”HzbzsxZ]uu!è cè’P'è-)R#$$è=uA.Pl]))u#k$]…fg=H4bh#kH:ba#YG:$èèTcu3]Y29—pè”cGubE—#R!Av”u#l#utwM#P


In [37]:
''.join(decode(geni_text))

"P$$#KR]>3&èuz#AA3ku!#P=vX&Zz—A—W#vè1uHwrèPuDèzL3:)BIdèQYYSu\n2\nov'u?CU?$RVPA?cb.?qb'=u?  ugY9H$$A#u$auuAè”HzbzsxZ]uu!è cè’P'è-)R#$$è=uA.Pl]))u#k$]…fg=H4bh#kH:ba#YG:$èèTcu3]Y29—pè”cGubE—#R!Av”u#l#utwM#P"

In [39]:
a = [1,2,3,4]
a[-1]

4