In [1]:
import torch 
import torch.nn as nn
from torch.nn import functional as F
device = 'cuda' if torch.cuda.is_available() else 'CPU'
print(device)

torch.manual_seed(1337)

CPU


<torch._C.Generator at 0x1876ddaf010>

In [2]:
with open('datos_sancho_mini.txt', 'r', encoding='utf-8') as f:
    text = f.read()
print(len(text))

147758


In [3]:
chars = sorted(list(set(text)))
vocab_size = len(chars)
print('--- Vocabulario---')
print(''.join(chars))
print(f'Tamaño del vocabulario {vocab_size}')

stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }

encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i]for i in l])

test_string = 'Hola Sancho'
encoded_string = encode(test_string)
decoded_string = decode(encoded_string)
print('\nPrueba de Tokenización:')
print(f'{encoded_string}')

print('\nPrueba de Tokenización:')
print(f'{decoded_string}')



--- Vocabulario---

 ,.;<>ABCDEFGHIJLMNOPQRSTUVXYZabcdefghijlmnopqrstuvxyz|ÁÉÍÑÓÚÜáéíñóúü
Tamaño del vocabulario 70

Prueba de Tokenización:
[14, 44, 41, 31, 1, 24, 31, 43, 33, 38, 44]

Prueba de Tokenización:
Hola Sancho


In [5]:
n_embd = 256
block_size = 256

token_embedding_table  = nn.Embedding(vocab_size, n_embd)
position_embedding_table = nn.Embedding(block_size, n_embd)

#Ejemplo de como se combina
#Tomemos un indice de Token de ejemplo y una posicion

idx_ejemplo = torch.tensor([[encode('hola')[0]]], dtype=torch.long) #H --> 40
pos_ejemplo= torch.arange(0,1, dtype=torch.long) #Posicion 0

tok_emb = token_embedding_table(idx_ejemplo) # (1,1,n_embd)
pos_emb = position_embedding_table(pos_ejemplo) #(1, n_embd)

x =tok_emb + pos_emb # asi se combinan
 

In [None]:
ln_ejemplo = nn.LayerNorm(n_embd)

x_normalizado = ln_ejemplo(x)

tensor([[ 3.4578e-01, -2.0047e+00, -3.1606e-01,  2.6110e+00,  1.8168e+00,
         -1.5309e-01, -1.4299e+00,  8.1075e-01,  4.2175e-02, -4.2113e-01,
         -7.8404e-01,  7.3967e-01, -2.9461e-01,  4.7753e-01, -8.0787e-01,
          1.0477e+00, -5.3975e-02, -9.8579e-01,  1.4262e+00, -9.5409e-02,
         -4.6216e-01, -3.3573e-01,  7.6422e-01,  2.6135e-01,  1.5747e-01,
         -7.0762e-01, -5.1379e-01, -1.0642e+00, -1.1218e+00, -2.2097e+00,
         -1.4141e+00,  1.6781e+00,  2.5293e+00,  1.1368e+00,  9.1290e-01,
          1.2258e+00,  3.6111e-01, -6.2860e-01, -3.6603e-02, -6.0793e-02,
         -4.5537e-01, -5.9513e-01,  1.7138e+00, -4.5430e-02,  7.9988e-01,
         -1.3894e+00, -1.2008e+00, -1.8659e+00,  4.3022e-01, -1.3809e-01,
         -1.4678e+00, -1.0547e+00, -1.2502e-01, -2.1079e+00,  7.7740e-02,
         -1.1599e+00, -8.9589e-01, -9.8372e-01,  1.4710e+00,  1.0209e+00,
         -1.7082e-01,  1.0647e-01, -2.0987e+00,  1.5355e+00,  1.1961e+00,
          1.7936e-01,  6.9488e-01, -6.

In [None]:
n_head = 4 
dropout = 0.2

class Head(nn.module):
    """ Una cabeza de self-attention"""
    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias = False)
        self.query = nn.Linear(n_embd, head_size, bias = False)
        self.value = nn.Linear(n_embd, head_size, bias = False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        f.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)
        q = self.query(x)
        wei = q @ k.transpose(-2,-1) * (C/4)**-0.5
        wei = wei.masked_fill(self.tril[:T,:T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        v = self.value(x)
        out = wei @ v
        return out

class MultiHeadAttention(nn.Module):
    """ Multiples cabezas de self attention en paralelo"""
    def __init_(self, num_heads, head_size):
        super().__init__()
        self.heads() = nn.ModuleList([Head(head_size=head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim = -1)
        out = self.dropout(self.proj(out))
        return out