In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import sentencepiece as spm
from transformers import AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


## Tokens based on ASCII (deprecated)

In [5]:
# Character-level tokenizer kept only for reference; the notebook tokenizes
# with a pretrained tokenizer further down instead.
def encoder(input):
    """Map every character of ``input`` to its integer code point."""
    return list(map(ord, input))


def decoder(input):
    """Rebuild a string from an iterable of integer code points."""
    return ''.join(map(chr, input))


sample_input = "hi there!"
token = encoder(sample_input)
print(
    f"input         : {sample_input}\n"
    f"token         : {token}\n"
    f"decoded_token : {decoder(token)}"
)

input         : hi there!
token         : [104, 105, 32, 116, 104, 101, 114, 101, 33]
decoded_token : hi there!


## ~~SentencePiece-based tokens~~ (tested: the model does not perform well with this tokenizer)
- Requires running the `tokenizer_training` notebook first to generate the vocab library and a model

In [6]:
#using a vocab lib of 10000
# sp = spm.SentencePieceProcessor(model_file='token.model')
# print(sp.EncodeAsPieces("hello how are you?"))
# print(sp.tokenize("hello how are you?"))

In [7]:
# Original note: "gpt2 was good". The tokenizer actually loaded here is
# google/byt5-small; later cells tokenize and decode through this object.
tokenizer = AutoTokenizer.from_pretrained("google/byt5-small")
# Raw string: two characters (a backslash followed by "n"), not a newline.
text = r"\n"
tokens = tokenizer(text)
print(tokens)  # recorded run shows input_ids [95, 113, 1]: one id per character plus one trailing id (presumably end-of-sequence - confirm)

# Decode a single id back to text; the recorded run prints "n" for id 113.
print(tokenizer.decode(113))


{'input_ids': [95, 113, 1], 'attention_mask': [1, 1, 1]}
n


## reading the input and tokenizing

In [8]:
# Read the raw training text. The encoding is pinned so that the bytes handed
# to the tokenizer do not depend on the platform's default locale.
with open("../data/data.txt", "r", encoding="utf-8") as f:
    data = f.read()

# To fall back to the character-level tokenizer instead, uncomment:
# token_data = torch.tensor(encoder(data))

# Tokenize the whole corpus in one call and keep it as a 1-D tensor of token ids.
token_data = torch.tensor(tokenizer(data)['input_ids'])
vocab_n = tokenizer.vocab_size

# vocab_n sizes the model's embedding table and output layer, so every id must
# be a valid index. Fail here with a readable message instead of later inside
# the embedding lookup.
assert int(token_data.max()) < vocab_n, (
    f"token id {int(token_data.max())} is out of range for vocab_n={vocab_n}"
)

# Ordered 90% / 10% split into train and validation (no shuffling).
train_len = int(0.9 * len(token_data))
train_data = token_data[:train_len]
val_data = token_data[train_len:]

print(f"train data length : {train_data.numel()}")
print(f"val data length : {val_data.numel()}")


print("vocab lib size: ",vocab_n)

train data length : 1003854
val data length : 111540
vocab lib size:  256


## Create the input and target, with the batch dim

In [9]:
torch.manual_seed(5672)
def get_batch(batch, block_len, data):
    """Draw ``batch`` random windows from ``data`` for next-token training.

    Args:
        batch: Number of windows to draw.
        block_len: Length of every window.
        data: 1-D tensor of token ids; must be longer than ``block_len``.

    Returns:
        ``(input, target)``, both of shape ``(batch, block_len)``; ``target``
        is ``input`` shifted one position to the right inside ``data``.
    """
    # randint's upper bound is exclusive, so start + block_len + 1 <= len(data).
    starts = torch.randint(0, len(data) - block_len, (batch,)).tolist()
    contexts = []
    targets = []
    for start in starts:
        stop = start + block_len
        contexts.append(data[start:stop])
        targets.append(data[start + 1:stop + 1])
    return torch.stack(contexts), torch.stack(targets)

In [10]:
# Smoke test of get_batch: 4 windows of 8 tokens each from the training split.
batch, block_len = 4, 8
ipt, tgt = get_batch(batch, block_len, train_data)
print(ipt.shape)
print(f"input : \n{ipt.numpy()}")
print(f"output: \n{tgt.numpy()}")

torch.Size([4, 8])
input : 
[[100 110 104  13  82 113 111 124]
 [105  35  76  35 101 104  35 118]
 [ 35 111 114 117 103  35 100 113]
 [ 13  86 107 100 111 111  47  35]]
output: 
[[110 104  13  82 113 111 124  35]
 [ 35  76  35 101 104  35 118 114]
 [111 114 117 103  35 100 113 103]
 [ 86 107 100 111 111  47  35 122]]


# Model Architecture

In [11]:
class Embedding(nn.Module):
    """Token embedding plus a learned positional ("time") embedding.

    Args:
        embedding_dim: Size of every embedding vector.
        n_vocal: Number of distinct token ids (rows of the token table).
        context_len: Longest sequence that has a positional vector.
    """

    def __init__(self, embedding_dim=64, n_vocal=128, context_len=50):
        super().__init__()
        self.token_embedding = nn.Embedding(n_vocal, embedding_dim)  # (n_vocal, embedding_dim)
        # Stored as (embedding_dim, context_len); name and layout are unchanged
        # so previously saved state dicts still match.
        self.time_embedding = nn.Parameter(torch.randn((embedding_dim, context_len)))
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Embed a batch of token ids.

        Args:
            x: Integer tensor of shape (batch, seq_len).

        Returns:
            Tensor of shape (batch, seq_len, embedding_dim).

        Raises:
            IndexError: If seq_len exceeds the number of stored positions.
        """
        batch, context_len = x.shape
        max_len = self.time_embedding.shape[1]
        if context_len > max_len:
            raise IndexError(
                f"sequence length {context_len} exceeds the positional table size {max_len}"
            )
        # (embedding_dim, seq_len) -> (seq_len, embedding_dim); broadcasts over the batch.
        time_embed = self.time_embedding[:, :context_len].transpose(0, 1)
        # Dropout is applied to the token embeddings only. The result has to be
        # bound to the name that is actually summed below; a misspelled target
        # here would silently disable the dropout.
        token_embed = self.dropout(self.token_embedding(x))
        return time_embed + token_embed  # (batch, seq_len, embedding_dim)

class Head(nn.Module):
    """A single scaled dot-product attention head.

    With ``mode="encoder"`` every position attends to every other position.
    Any other value applies a lower-triangular mask, so a position only sees
    itself and earlier positions.
    """

    def __init__(self, context_len=50, embedding_dim=64, attention_dim=8, mode="encoder"):
        super().__init__()
        # Creation order (query, value, key) is kept as-is: it determines the
        # order of random initialisation and of the state_dict entries.
        self.query = nn.Linear(embedding_dim, attention_dim, bias=False)
        self.value = nn.Linear(embedding_dim, attention_dim, bias=False)
        self.key = nn.Linear(embedding_dim, attention_dim, bias=False)
        self.mode = mode
        self.dropout = nn.Dropout(0.5)
        lower_triangle = torch.ones(context_len, context_len).tril()
        self.register_buffer('tril', lower_triangle)

    def forward(self, x):
        """Attend over ``x`` of shape (batch, time, channels).

        Returns a tensor of shape (batch, time, attention_dim).
        """
        _, steps, _ = x.shape
        queries, keys, values = self.query(x), self.key(x), self.value(x)
        scale = queries.shape[-1] ** 0.5
        scores = torch.bmm(queries, keys.transpose(1, 2)) / scale

        if self.mode != "encoder":
            # Positions above the diagonal get -inf so softmax gives them zero weight.
            hidden = self.tril[:steps, :steps] == 0
            scores = scores.masked_fill(hidden, float('-inf'))

        weights = self.dropout(F.softmax(scores, dim=-1))
        return torch.bmm(weights, values)

class SelfAttention(nn.Module):
    """Multi-head self-attention with pre-normalisation and a residual connection.

    Each head produces ``embedding_dim // heads`` channels. The concatenated
    heads feed a projection that expects ``embedding_dim`` inputs, so
    ``embedding_dim`` should be divisible by ``heads``.
    """

    def __init__(self, context_l=50, embedding_dim=64, heads=8, mode="decoder"):
        super().__init__()
        self.attention_dim = embedding_dim // heads
        head_modules = []
        for _ in range(heads):
            head_modules.append(
                Head(
                    context_len=context_l,
                    embedding_dim=embedding_dim,
                    attention_dim=self.attention_dim,
                    mode=mode,
                )
            )
        self.heads = nn.ModuleList(head_modules)
        self.projection_weight = nn.Linear(embedding_dim, embedding_dim)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Return ``x`` plus the projected multi-head attention output; the shape is unchanged."""
        _, _, channels = x.shape
        # Layer norm without learnable scale/shift, applied over the channel axis.
        normed = F.layer_norm(x, (channels,))
        merged = torch.cat([head(normed) for head in self.heads], dim=2)
        return x + self.dropout(self.projection_weight(merged))

class MLP(nn.Module):
    """Position-wise feed-forward block with pre-normalisation and a residual connection.

    The hidden layer is ``embedding_dim * mlp_multiplier`` wide.
    """

    def __init__(self, embedding_dim=64, mlp_multiplier=4):
        super().__init__()
        hidden_dim = embedding_dim * mlp_multiplier
        self.mlp_weight = nn.Linear(embedding_dim, hidden_dim)
        self.mlp_projection = nn.Linear(hidden_dim, embedding_dim)
        self.gelu1 = nn.GELU()
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Return ``x`` plus the feed-forward output; the shape is unchanged."""
        _, _, channels = x.shape
        hidden = self.mlp_weight(F.layer_norm(x, (channels,)))
        # Dropout sits between the activation and the output projection.
        hidden = self.dropout(self.gelu1(hidden))
        return x + self.mlp_projection(hidden)

class Logits(nn.Module):
    """Output head: layer-normalise, then project to one score per vocabulary entry."""

    def __init__(self, n_vocal=128, embedding_dim=64):
        super().__init__()
        self.encode = nn.Linear(embedding_dim, n_vocal, bias=False)

    def forward(self, x):
        """Map (batch, time, channels) to raw scores of shape (batch, time, n_vocal).

        No softmax is applied; the result is unnormalised logits.
        """
        _, _, channels = x.shape
        return self.encode(F.layer_norm(x, (channels,)))

class TransformerMini(nn.Module):
    """Single-block transformer: embedding -> self-attention -> MLP -> logits."""

    def __init__(self, context_l=50, n_vocal=128, embedding_dim=64, attention_heads=8, mode="decoder"):
        super().__init__()
        self.embedding = Embedding(embedding_dim, n_vocal, context_l)
        self.self_attention = SelfAttention(context_l, embedding_dim, attention_heads, mode)
        self.mlp = MLP(embedding_dim)
        self.logits = Logits(n_vocal, embedding_dim)

    def forward(self, x):
        """Run token ids of shape (batch, context_len) through every stage.

        Returns logits of shape (batch, context_len, n_vocal).
        """
        stages = (self.embedding, self.self_attention, self.mlp, self.logits)
        for stage in stages:
            x = stage(x)
        return x

# training

In [12]:
# Checkpoint to warm-start from; switch to the None line to train from scratch.
pretrained_model = "../models/transformer_trained_model/trained_v0.0.3.pth"
# pretrained_model = None
context_lenth = 512
device = "cuda" if torch.cuda.is_available() else "cpu"

# Place the model on the detected device instead of calling .cuda()
# unconditionally, so the notebook also runs on machines without a GPU.
model = TransformerMini(context_l=context_lenth, n_vocal=vocab_n, embedding_dim=256, attention_heads=32).to(device)
if pretrained_model:
    # map_location lets a checkpoint written during a GPU run load on a CPU-only machine.
    load_result = model.load_state_dict(
        torch.load(pretrained_model, map_location=device, weights_only=True),
        strict=False,
    )
    # strict=False tolerates missing/unexpected keys; report them so that
    # tensors left at their random initial values do not go unnoticed.
    if load_result.missing_keys or load_result.unexpected_keys:
        print(f"checkpoint mismatch - missing: {load_result.missing_keys}, "
              f"unexpected: {load_result.unexpected_keys}")
model.train()

TransformerMini(
  (embedding): Embedding(
    (token_embedding): Embedding(256, 256)
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (self_attention): SelfAttention(
    (heads): ModuleList(
      (0-31): 32 x Head(
        (query): Linear(in_features=256, out_features=8, bias=False)
        (value): Linear(in_features=256, out_features=8, bias=False)
        (key): Linear(in_features=256, out_features=8, bias=False)
        (dropout): Dropout(p=0.5, inplace=False)
      )
    )
    (projection_weight): Linear(in_features=256, out_features=256, bias=True)
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (mlp): MLP(
    (mlp_weight): Linear(in_features=256, out_features=1024, bias=True)
    (mlp_projection): Linear(in_features=1024, out_features=256, bias=True)
    (gelu1): GELU(approximate='none')
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (logits): Logits(
    (encode): Linear(in_features=256, out_features=256, bias=False)
  )
)

In [77]:
# Training hyper-parameters read by the optimizer and training-loop cells.
lr = 1e-3  # learning rate handed to the optimizer
wd=0  # weight decay handed to the optimizer
batch = 64  # sequences per batch; rebinds the `batch` name set by the earlier batching demo cell
epoch = 10000  # number of loop iterations; each one draws a single random batch, not a full pass over the data

In [78]:
# Cross-entropy as the next-token loss; AdamW over every model parameter.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(params=model.parameters(), lr=lr, weight_decay=wd)

In [80]:
for i in range(epoch):
    # ---- optimisation step on one random training batch ----
    model.train()
    xb, yb = (part.to(device) for part in get_batch(batch, context_lenth, train_data))

    optimizer.zero_grad()
    output = model(xb)
    b, t, n = output.shape
    # Flatten (batch, time) so every position counts as one classification example.
    loss = criterion(output.view(-1, n), yb.view(-1))
    loss.backward()
    optimizer.step()

    # ---- validation probe on one random batch of 128-token windows ----
    model.eval()
    with torch.no_grad():
        vxb, vyb = (part.to(device) for part in get_batch(batch, 128, val_data))
        output = model(vxb)
        b, t, n = output.shape
        vloss = criterion(output.view(-1, n), vyb.view(-1))

    # ---- progress line ----
    print(f"Epoch: {i}, Train Loss: {loss.item():.4f}, Val Loss: {vloss.item():.4f}")


Epoch: 0, Train Loss: 2.4586, Val Loss: 2.4232
Epoch: 1, Train Loss: 2.4485, Val Loss: 2.4135
Epoch: 2, Train Loss: 2.4588, Val Loss: 2.4402
Epoch: 3, Train Loss: 2.4541, Val Loss: 2.4219
Epoch: 4, Train Loss: 2.4507, Val Loss: 2.4166
Epoch: 5, Train Loss: 2.4574, Val Loss: 2.4049
Epoch: 6, Train Loss: 2.4357, Val Loss: 2.4313
Epoch: 7, Train Loss: 2.4434, Val Loss: 2.4194
Epoch: 8, Train Loss: 2.4567, Val Loss: 2.4000
Epoch: 9, Train Loss: 2.4402, Val Loss: 2.4254
Epoch: 10, Train Loss: 2.4333, Val Loss: 2.4473
Epoch: 11, Train Loss: 2.4401, Val Loss: 2.4167
Epoch: 12, Train Loss: 2.4559, Val Loss: 2.4417
Epoch: 13, Train Loss: 2.4496, Val Loss: 2.4329
Epoch: 14, Train Loss: 2.4569, Val Loss: 2.4310
Epoch: 15, Train Loss: 2.4526, Val Loss: 2.4539
Epoch: 16, Train Loss: 2.4407, Val Loss: 2.4575
Epoch: 17, Train Loss: 2.4543, Val Loss: 2.4186
Epoch: 18, Train Loss: 2.4409, Val Loss: 2.4135
Epoch: 19, Train Loss: 2.4379, Val Loss: 2.4057
Epoch: 20, Train Loss: 2.4360, Val Loss: 2.4119
Ep

KeyboardInterrupt: 

In [None]:
def generator(model:TransformerMini, token_len):
    """Sample ``token_len`` tokens from ``model`` and stream the decoded text to stdout.

    Generation starts from a single token with id 0 and feeds the model its
    own samples back in. Nothing is returned; the text is only printed. The
    model's train/eval mode is not changed here, so pass it in the mode you want.

    Args:
        model: TransformerMini to sample from.
        token_len: Number of tokens to sample.
    """
    # Read device and context size from the model itself rather than from
    # notebook globals, so the function still works if the model lives on
    # another device or was built with a different context length.
    model_device = next(model.parameters()).device
    max_context = model.embedding.time_embedding.shape[1]

    token = torch.tensor([[0]], device=model_device)
    # Ids sampled so far that have not produced any printable text yet.
    pending = []
    with torch.no_grad():
        for _ in range(token_len):
            out = model(token)
            sm_out = F.softmax(out[:, -1, :], dim=-1)
            predict = torch.multinomial(sm_out, num_samples=1)
            # Keep the most recent max_context tokens; the model accepts
            # inputs up to exactly that length.
            token = torch.cat((token, predict), dim=1)[:, -max_context:]

            # With a byte-level tokenizer a multi-byte UTF-8 character spans
            # several ids, and decoding them one at a time loses it. Buffer
            # ids until they decode to text; a UTF-8 character is at most
            # 4 bytes, so an undecodable buffer is discarded after 4 ids.
            pending.append(predict.item())
            text = tokenizer.decode(pending)
            if text or len(pending) >= 4:
                print(text, end='')
                pending.clear()
        

In [14]:
# Switch the model to eval mode (model.eval() returns the model itself) and stream 1000 sampled tokens.
generator(model.eval(), 1000)

ABULIAu herfor knoureembe I seaccomerereaught.

LINGE:
If thall slady see Made sair I thing fee the's a neass
Hamel thesen foirscory muteth wothe pers es heir a joy fiter
band therefuld of his mus to Bus erear theser laike so an crock
On yousend; ust me seand all and you hour of trus ham,
This is will goveredly weas bocith to be Band!

DUKE OF YORK:
Sill buson, home bethe to he'll hemselfe?

Tho can-do wand nould you, murce you and him with he porse,
Wore me and by warland wrefor nuth your of you.

Shard your king thall worder be such you;
And he lover hou moth, comarters letche aded murt.

KING RICHARD II:
And your of noy lood, kis, I was was him noth kind:
the ve have vuse sajoy with.

CORD OFIO:
But of peas dieng in a wing you you,
Sothen A she a bee He paid to sold to spot I what he fort,
The the for stre? wher hy, sher nobery is barrest thall dime and weared
Her forrdius she, shapur ther that of me?

POLIXENRY Bordirselve ine somblid a cone appost? I hrue bron rear them do der lea

In [84]:
# Persist the current weights. This is the same path the model-setup cell loads
# from, so running it overwrites the checkpoint this session started with.
torch.save(model.state_dict(), "../models/transformer_trained_model/trained_v0.0.3.pth")

In [85]:
# Ask PyTorch's CUDA caching allocator to release cached memory blocks that are not in use.
torch.cuda.empty_cache()