In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl

class SwiGLU(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
        
    def forward(self, x):
        # SwiGLU: Swish(xW1) ⊙ (xW3) W2
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
    
    def forward(self, x):
        # More numerically stable RMSNorm
        # Calculate mean of squares
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        # RMS normalization
        rms = torch.sqrt(mean_square + self.eps)
        return self.weight * x / rms

class FnetBlock(nn.Module):
    
    def __init__(self, embed_dim):
        super().__init__()
        self.rmsnorm1 = RMSNorm(embed_dim)
        
        self.rmsnorm2 = RMSNorm(embed_dim)
        
        # self.mlp = nn.Sequential(
        #     nn.Linear(embed_dim, embed_dim*2),
        #     nn.ReLU(),
        #     nn.Linear(embed_dim*2, embed_dim*2),
        #     nn.ReLU(),
        #     nn.Linear(embed_dim*2, embed_dim)
        # )
        
        self.mlp = SwiGLU(embed_dim, embed_dim*4)
        
    def forward(self, x):
        
        out = x + torch.fft.fft(self.rmsnorm1(x), dim=1).real
        out = out + self.mlp(self.rmsnorm2((out)))
        return out
    
    
class FNET(pl.LightningModule):
    
    def __init__(self, embed_dim, context_length, vocab_size, num_layers=4, lr=0.0001):
        super().__init__()
        
        self.lr = lr
        
        self.context_length = context_length
        
        self.word_embeddings = nn.Embedding(vocab_size, embed_dim)
        
        self.pos_embeddings = nn.Embedding(context_length, embed_dim)
        
        self.blocks = nn.ModuleList([FnetBlock(embed_dim) for _ in range(num_layers)])
        
        self.norm = RMSNorm(embed_dim)
        
        self.output = nn.Linear(embed_dim, vocab_size, bias=False)
        
        # print(f"self.output.weight.shape: {self.output.weight.shape}")
        
        # print(f"self.word_embeddings.weight.shape: {self.word_embeddings.weight.shape}")
        
        self.loss_func = nn.CrossEntropyLoss()
        
        self.output.weight = self.word_embeddings.weight
        
    def forward(self, input_ids, attention_mask:Optional[torch.tensor]=None):
        
        embs = self.word_embeddings(input_ids) + self.pos_embeddings(torch.arange(0, self.context_length).to(input_ids.device))
        
        if attention_mask:
            attention_mask = torch.tril(torch.ones((self.context_length, self.context_length), device=input_ids.device))
            mask = attention_mask.unsqueeze(-1).expand_as(embs)
            embs = embs*mask
            
        for layer in self.blocks:
            embs = layer(embs)
            
        embs = self.norm(embs)
        
        logits = self.output(embs)
        return logits
    # Fix 1: Update training_step to reshape logits and targets
    def training_step(self, batch, batch_idx):
        x, y = batch[0], batch[1]
        
        # print(x.shape, y.shape)
        
        out = self(x)  # Shape: (batch_size, seq_len, vocab_size)
        
        # Reshape for CrossEntropyLoss
        loss = self.loss_func(out.view(-1, out.size(-1)), y.view(-1))
        
        self.log("train_loss", loss, prog_bar=True)
        
        return loss

    # Fix 2: Update validation_step similarly
    def validation_step(self, batch, batch_idx):
        x, y = batch[0], batch[1]
        
        out = self(x)
        
        # Reshape for CrossEntropyLoss
        loss = self.loss_func(out.view(-1, out.size(-1)), y.view(-1))
        
        self.log("val_loss", loss, prog_bar=True)
        
        return loss
    
    def configure_optimizers(self):
        
        optimizer = torch.optim.AdamW(self.parameters(), lr = self.lr)
        
        # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        #     optimizer,
        #     mode='min'
        # )
        
        return {
            "optimizer": optimizer,
        }
         
        
    # def configure_gradient_clipping(self, optimizer, gradient_clip_val = None, gradient_clip_algorithm = None):
    #     self.clip_gradients(
    #         optimizer,
    #         gradient_clip_val=1.0, 
    #         gradient_clip_algorithm='norm'
    #     )
        
        
class TokenDataset(Dataset):
   def __init__(self, tokens_path="Tokens.pt", max_length=512):
       self.tokens = torch.load(tokens_path, weights_only=True)
       self.tokens = self.tokens[:4194304]
       self.max_length = max_length
       
   def __len__(self):
       return len(self.tokens) - self.max_length
   
   def __getitem__(self, idx):
       x = self.tokens[idx:idx + self.max_length]
       y = self.tokens[idx + 1:idx + self.max_length + 1]
       return x, y

# def create_dataloader(tokens_path="Tokens.pt", batch_size=16, max_length=512, num_workers=4):
#    dataset = TokenDataset(tokens_path, max_length)
#    return DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)        
  
# Fix 3: Update create_dataloader to use num_workers=0 (disable multiprocessing)
def create_dataloader(tokens_path="Tokens.pt", batch_size=16, max_length=512, num_workers=0):
    dataset = TokenDataset(tokens_path, max_length)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)      
        
        
        
        
               
def train_model():
    
    embed_dim, context_length, vocab_size = 768, 512, 30002
    
    model = FNET(embed_dim, context_length, vocab_size)
    
    train_loader = create_dataloader(batch_size=8)
    
    trainer = pl.Trainer(
        max_epochs=100,
        enable_progress_bar=True,
        num_nodes=1,
        enable_checkpointing=True,
        gradient_clip_val=1.0,
        gradient_clip_algorithm='norm'
    )
    
    trainer.fit(model, train_dataloaders=train_loader)
        
# if __name__ == "__main__":
    
#     x = torch.randn((1, 5, 512))
    
#     block = FnetBlock(512)
    
#     out = block(x)
#     print(out.shape)
    
#     input_ids = torch.randint(0, 20002, (1, 10))
    
#     model = FNET(512, 10, 20002)
    
#     out = model(input_ids)
#     print(out.shape)

In [2]:
train_model()

💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
c:\Users\Rohit Francis\Desktop\Codes\AI_Projects\Model_Tryouts\deeplearningenv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
c:\Users\Rohit Francis\Desktop\Codes\AI_Projects\Model_Tryouts\deeplearningenv\lib\site-packages\p

Epoch 0:   0%|          | 171/1048448 [00:36<61:26:43,  4.74it/s, v_num=0, train_loss=127.0]


Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined

In [2]:
from tokenizers import Tokenizer
import json

# Load pretrained tokenizer
tokenizer = Tokenizer.from_file("my_tokenizer.json")

In [3]:
with open("tokenizer_content.txt", "r", encoding='utf-8') as f:
    text = f.read()

In [4]:
print(text[:100000])


Architecturally, the school has a Catholic character. Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.
To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?
Saint Bernadette Soubirous
What is in front of the Notre Dame Main Building?
a copper statue of Christ
The Basilica of the Sacred heart at Notre Dame is beside to which structure?
the Main Building
What is the Grotto at Notre Dame?

In [11]:
len(tokenizer.encode(text[:200000]).ids)

41756

In [6]:
torch.log(torch.tensor(4240059))/torch.log(torch.tensor(2))
print(2**22)

4194304


In [7]:
encoded_tokens = tokenizer.encode(text)

tokens = torch.tensor(encoded_tokens.ids)
torch.save(tokens, "Tokens_small.pt")

In [12]:
max(encoded_tokens.ids)

29987

In [9]:
tensor = torch.load("Tokens.pt")

  tensor = torch.load("Tokens.pt")


In [11]:
tensor.shape

torch.Size([20727])

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

2.0

[5, 6, 7, 8, 9]