In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F
torch.manual_seed(1337)

<torch._C.Generator at 0x1da9d545830>

In [21]:
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [22]:
with open('../data/input.txt', 'r', encoding='utf-8') as f:
    shakespeare = f.read()
    
# List all unique characters that occurs in the input text
chars = sorted(list(set(shakespeare)))
vocab_size = len(chars)

# mapping from characters to integers for encoding
stoi = { ch:i for i, ch in enumerate(chars) }
itos = { i:ch for i, ch in enumerate(chars) }
encode = lambda s : [stoi[c] for c in s] # Take a string and output a list of integers
decode = lambda i : ''.join([itos[c] for c in i]) # Take a list of integers and output a list of string

# encode the entire text dataset and store it into a torch.Tensor
data = torch.tensor(encode(shakespeare), dtype=torch.long)

# Let's now split up the data into train and validation sets
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

In [23]:
torch.manual_seed(1337)
batch_size = 4 # How many independent sequences will be process in parallel?
block_size = 64 # What is the maximum context length for predictions?

def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x.to(device=DEVICE), y.to(device=DEVICE)

In [24]:
class Head(nn.Module):
    """ One head of self-attention """
    
    def __init__(self, n_embd, block_size, head_size) -> None:
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril((torch.ones(block_size, block_size))))
    
    def forward(self, x):
        B, T, C = x.shape
        
        # Compute the attention score using k and q 
        k = self.key(x)
        q = self.query(x)
        w = q @ k.transpose(-2, -1) * C ** -0.5
        w = w.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        w = F.softmax(w, dim=-1)
        
        # weight aggregation of the values using the attention score
        v = self.value(x)
        out = w @ v
        
        return out

In [25]:
class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """
    
    def __init__(self, n_embd, block_size, head_size, n_head, dropout):
        super().__init__()
        self.head = nn.ModuleList([Head(n_embd, block_size, head_size) for _ in range(n_head)])
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Use multiple heads of self-attention 
        # Concatanate each output alone the embedding dimension (B, T ,C), that is merged alone C
        out = torch.cat([h(x) for h in self.head], dim=-1)
        out = self.dropout(self.proj(out))
        return out

In [26]:
class FeedForward(nn.Module):
    """ A simple linear layer followed by a non-linearity """
    
    def __init__(self, n_embd, dropout) -> None:
        super().__init__()
        self.net = nn.Sequential(
                    nn.Linear(n_embd, 4 * n_embd),
                    nn.ReLU(),
                    nn.Linear(4 * n_embd, n_embd),
                    nn.Dropout(dropout)
        )
        
    def forward(self, x):
        return self.net(x)

In [27]:
class AttnBlock(nn.Module):
    def __init__(self, n_embd, block_size, n_head, dropout) -> None:
        super().__init__()
        head_size = n_embd // n_head
        self.attn = MultiHeadAttention(n_embd, block_size, head_size, n_head, dropout)
        self.ffn = FeedForward(n_embd, dropout)
        self.ln = nn.LayerNorm(n_embd)
    
    def forward(self, x):
        # Use the residual connection with layernorm
        x = x + self.attn(self.ln(x))
        x = x + self.ffn(self.ln(x))

        return x

In [28]:
# A modified neural network for bigram model
class Transformer(nn.Module):
    def __init__(self, vocab_size, n_embd, block_size, n_head, n_layer, dropout) -> None:
        super().__init__()
        self.block_size = block_size
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.attn_blocks = nn.Sequential(*[AttnBlock(n_embd, block_size, n_head, dropout) for _ in range(n_layer)])
        self.ln = nn.LayerNorm(n_embd)
        self.ln_head = nn.Linear(n_embd, vocab_size)
    
    def forward(self, idx, targets=None):
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx) # (B, T, n_embd)
        pos_emb = self.position_embedding_table(torch.arange(T, device=DEVICE))
        x = tok_emb + pos_emb
        x = self.attn_blocks(x)
        x = self.ln(x)
        logits = self.ln_head(x)            
        
        if targets is None: # for generation without providing target
            loss = None
        else:
            # Reshape the logits tensor to meet definition of the cross_entropy function in Pytorch
            B, T, C = logits.shape
            logits = logits.view(B*T, C) # Concatenate B, T
            targets = targets.view(B*T)    # Same reshaping to the target tensor
            loss = F.cross_entropy(logits, targets) # calculate the loss
        
        return logits, loss
    
    def generation(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens to meet position embedding range
            idx_crop = idx[:, -self.block_size:]
            
            # get the predictions
            logits, _ = self(idx_crop)
            
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
            
        return idx

In [33]:
model = Transformer(vocab_size, n_embd=32, block_size=block_size, n_head=4, n_layer=6, dropout=0.0)
model.to(device=DEVICE)

Transformer(
  (token_embedding_table): Embedding(65, 32)
  (position_embedding_table): Embedding(64, 32)
  (attn_blocks): Sequential(
    (0): AttnBlock(
      (attn): MultiHeadAttention(
        (head): ModuleList(
          (0-3): 4 x Head(
            (key): Linear(in_features=32, out_features=8, bias=False)
            (query): Linear(in_features=32, out_features=8, bias=False)
            (value): Linear(in_features=32, out_features=8, bias=False)
          )
        )
        (proj): Linear(in_features=32, out_features=32, bias=True)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (ffn): FeedForward(
        (net): Sequential(
          (0): Linear(in_features=32, out_features=128, bias=True)
          (1): ReLU()
          (2): Linear(in_features=128, out_features=32, bias=True)
          (3): Dropout(p=0.0, inplace=False)
        )
      )
      (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
    )
    (1): AttnBlock(
      (attn): MultiHeadAttention(

In [34]:
# params = sum(p.numel() for p in model.parameters() if p.requires_grad)
params = sum(p.numel() for p in model.parameters())
print(f'TOTAL params num: {params}')

TOTAL params num: 81601


In [35]:
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

In [36]:
from tqdm import tqdm

batch_size = 32
epochs = 5000
losses = []
for e in tqdm(range(epochs)):
    xb, yb = get_batch('train')
    
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

100%|██████████| 5000/5000 [01:49<00:00, 45.50it/s]


In [37]:
print(loss.item())

1.7031662464141846


In [39]:
print(decode(model.generation(idx=torch.zeros((1, 1), dtype=torch.long, device=DEVICE), max_new_tokens=500)[0].tolist()))


MOPEES:
Power grivent I why myself have mysel--'Kome.

Gentlematan:
What shall so not us bod, whom.

My some should sold. A En shondluce.

Serself Our be thou tend?
Carst to Pesirk ovet some so upon sixe fear.
As not come, not sun?

Your lowfUere, soxe are and barakes again
Abucces my poserfore long my doscelly no crive,
Tcleadins, his I it him.

GLOUCUSTERCE:
Fyreions, to meimed the no stee summe wip.

GLOUCEMELA:
Which brukle Enjeporn man I so hander
Sibe make chook Yow their forcesy, know our


In [40]:
model_size_str = str(params/1e6)
model_file_name = 'transformer' + model_size_str + 'M.pth'
model_save_path = '../params/transformer/' + model_file_name
torch.save(model.state_dict(), model_save_path)

In [45]:
trained_model = Transformer(vocab_size, n_embd=32, block_size=block_size, n_head=4, n_layer=6, dropout=0.0).to(DEVICE)
trained_model.load_state_dict(torch.load('../params/transformer/transformer0.081601M.pth'))

<All keys matched successfully>

In [46]:
print(decode(trained_model.generation(idx=torch.zeros((1, 1), dtype=torch.long, device=DEVICE), max_new_tokens=500)[0].tolist()))



Miser, I blanives have to ackin, Clombeing how
and, good atce
It have druess? Do nobnow.

Find, I knothged the shall thou fries thank be craise.

My Somenor of Gloundence.

QUEEA:
I in canise on cloil him.

JULIET:
Which well, sir did hrother what stendemisss Plencuaic;
When sab here but have to muttan'd,
Let Blovege yermbeich bloing here.

ixcluver:
Go no
Cleapt my Glorain:
Lethey corthall, hust in by Clrest to brumblel,
Her crince you'lk'
the sby love yhalt he jroyal
him, thing wear wonl is t
