In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tiktoken
import numpy as np
from Single_Block_Fortnite import EmbeddingNN, AttentionNN, NormNN, FFN


In [2]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Read the whole training corpus into memory.
# NOTE(review): this is an absolute, machine-specific path.
with open("C:/Users/PC/Desktop/Important Data/Shakespeare_Data.txt", "r", encoding="utf-8") as corpus_file:
    text = corpus_file.read()

# Character-level vocabulary: every distinct character of the corpus in sorted
# order, with the "<E>" marker appended as the final entry.
chars = sorted(set(text))
chars.append("<E>")
vocab_size = len(chars)

# Lookup tables between vocabulary entries and their integer ids.
stoi = dict(zip(chars, range(len(chars))))
itos = dict(enumerate(chars))


def encode(s):
    """Return the list of vocabulary ids for the characters of ``s``.

    Raises ``KeyError`` for a character that is not in the vocabulary.
    """
    return [stoi[c] for c in s]


def decode(l):
    """Return the string formed by joining the vocabulary entries for ids ``l``."""
    return ''.join(itos[i] for i in l)


# The full corpus as one 1-D tensor of ids, stored on `device`.
data = torch.tensor(encode(text), dtype=torch.long).to(device)


# Quick round-trip check, then show the vocabulary.
print(decode(encode("POG")))
print(chars)


POG
['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '<E>']


In [3]:
class TransformerBlock(nn.Module):
    """One transformer layer: an attention sub-layer, then a feed-forward sub-layer.

    Each sub-layer's output is added to that sub-layer's input, and the sum is
    then passed through a ``NormNN`` module.

    Args:
        hidden_dim: forwarded to ``AttentionNN``, ``FFN`` and both ``NormNN`` modules.
        num_heads: forwarded to ``AttentionNN``.
    """

    def __init__(self, hidden_dim, num_heads):
        super().__init__()

        # Construction order is kept stable so parameter initialisation is unchanged.
        self.attention = AttentionNN(hidden_dim, num_heads)
        self.ffn = FFN(hidden_dim)
        self.norm1 = NormNN(hidden_dim)
        self.norm2 = NormNN(hidden_dim)

    def forward(self, X):
        """Run ``X`` through both sub-layers and return the result."""
        # Attention sub-layer: residual add, then norm1.
        attended = self.norm1(X + self.attention(X))

        # Feed-forward sub-layer: the residual uses the norm1 output, then norm2.
        return self.norm2(self.ffn(attended) + attended)

In [4]:
class Transformer(nn.Module):
    """An embedding, a stack of ``N`` ``TransformerBlock`` layers, and a linear head.

    Args:
        vocab_size: forwarded to ``EmbeddingNN``; also the output width of the
            final linear layer.
        hidden_dim: forwarded to ``EmbeddingNN`` and every block; input width of
            the final linear layer.
        num_heads: forwarded to every ``TransformerBlock``.
        max_sequence_length: forwarded to ``EmbeddingNN``.
        N: number of ``TransformerBlock`` layers to stack.
    """

    def __init__(self, vocab_size, hidden_dim, num_heads, max_sequence_length, N):
        super().__init__()

        self.embedding = EmbeddingNN(vocab_size, hidden_dim, max_sequence_length)
        layers = [TransformerBlock(hidden_dim, num_heads) for _ in range(N)]
        self.blocks = nn.ModuleList(layers)
        self.linear = nn.Linear(hidden_dim, vocab_size)

    def forward(self, X):
        """Embed ``X``, apply every block in order, and return the linear head's output."""
        hidden = self.embedding(X)
        for layer in self.blocks:
            hidden = layer(hidden)
        return self.linear(hidden)

In [6]:
# Model size (passed to the Transformer constructor).
hidden_dim = 512
num_heads = 8
max_sequence_length = 512

# Training schedule: each step uses batch_size windows of seq_len tokens.
seq_len = 128
batch_size = 16
epochs = 3

In [7]:
# Build a 12-block model, its optimiser and the loss used for next-token prediction.
Transformer_model = Transformer(vocab_size, hidden_dim, num_heads, max_sequence_length, 12).to(device)
optimizer = optim.Adam(Transformer_model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

# Be explicit about the mode so this cell behaves the same no matter what ran before it.
Transformer_model.train()

# Every optimisation step consumes this many consecutive tokens of `data`.
tokens_per_batch = batch_size * seq_len

# Stays None when the corpus is too short to form a single batch.
loss = None

for epoch in range(epochs):
    print(f"Epoch Number: {epoch + 1}")

    # The range stops before `len(data) - tokens_per_batch`, so `end + 1` never
    # exceeds `len(data)` and the shifted target slice is always full length.
    for start in range(0, len(data) - tokens_per_batch, tokens_per_batch):
        end = start + tokens_per_batch

        # `data` already lives on `device`, so its slices need no transfer.
        batch_x = data[start:end].view(batch_size, seq_len)
        # Targets are the inputs shifted by one token, flattened for the loss.
        batch_y = data[start + 1:end + 1].view(-1)

        # Flatten the model output to (batch_size * seq_len, vocab_size) for the loss.
        logits = Transformer_model(batch_x).view(-1, vocab_size)
        loss = criterion(logits, batch_y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Reports the loss of the most recent batch only, not an epoch average.
    if loss is not None:
        print("Final loss:", loss.item())
    else:
        print("No batches were processed. Loss not defined.")

Epoch Number: 1
Final loss: 0.03188561648130417
Epoch Number: 2
Final loss: 0.02372751757502556
Epoch Number: 3
Final loss: 0.021613026037812233


In [8]:
# Saves the whole module object rather than only its state_dict. Loading this
# file later goes through pickle, so the model's class definitions must be
# importable at load time.
torch.save(obj=Transformer_model, f="Sigma_Transformer_full.pth")


In [21]:
# Maximum number of characters to sample after the prompt.
max_length = 100

# Encode the prompt typed by the user. `encode` raises KeyError for any
# character that does not occur in the vocabulary.
input_user = encode(input())
if not input_user:
    raise ValueError("Prompt must contain at least one character.")
# Explicit dtype: token ids must be integers regardless of the prompt contents.
input_text = torch.tensor([input_user], dtype=torch.long).to(device)

# Integer id of the end marker. The original compared the sampled tensor with
# the string "<E>", which can never match a token id, so the loop never stopped early.
end_token_id = stoi["<E>"]

inference = []

# Inference only: switch to eval mode and skip autograd bookkeeping so no
# computation graph is built (and kept alive) on every sampling step.
Transformer_model.eval()
with torch.no_grad():
    for _ in range(max_length):
        # Feed at most `max_sequence_length` most-recent tokens, the length the
        # model was constructed with. This is a no-op for shorter sequences.
        context = input_text[:, -max_sequence_length:]
        logits = Transformer_model(context)

        # Sample the next token from the distribution at the last position.
        next_token_logits = logits[:, -1, :]
        probs = torch.softmax(next_token_logits, dim=-1)
        next_token = torch.multinomial(probs, 1)

        input_text = torch.cat([input_text, next_token], dim=1)

        if next_token.item() == end_token_id:
            break
        inference.append(decode(next_token.view(-1).tolist()))

print(''.join(inference))


 fortnite


enntviveiviinvi; LULULLLULLLLLULUQULEinxcizizUQUCl;'ULULLULLULLILUQULLUCALLULLAULu,quveviz; LULUQUJU
