In [1]:
!pip install pytorch-lightning --q

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m825.4/825.4 kB[0m [31m14.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m5.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m121.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m96.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m58.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m6.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m17.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [41]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl

class SwiGLU(nn.Module):
    """Gated feed-forward layer computing ``w2(silu(w1(x)) * w3(x))``.

    All three projections are bias-free. The attribute names ``w1``/``w2``/``w3``
    define the state-dict keys, so saved checkpoints depend on them.

    Args:
        dim: input and output feature size (last dimension of ``x``).
        hidden_dim: width of the gated hidden layer.
    """

    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # gate projection
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)  # output projection
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)  # value projection

    def forward(self, x):
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        return self.w2(gate * value)

class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    Computes ``weight * x / sqrt(mean(x**2, dim=-1) + eps)``. Unlike LayerNorm
    it does not subtract the mean and has no bias term.

    Args:
        dim: size of the last dimension of the input (length of ``weight``).
        eps: added inside the square root to avoid division by zero.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        # Learnable per-feature gain, initialised to 1 (identity scaling).
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        denom = (x.pow(2).mean(dim=-1, keepdim=True) + self.eps).sqrt()
        return self.weight * x / denom

class FnetBlock(nn.Module):
    """Pre-norm FNet block: FFT token mixing followed by a SwiGLU MLP.

    Both sub-layers are residual:
    ``h = x + Re(FFT_dim1(rmsnorm1(x)))`` then ``h + mlp(rmsnorm2(h))``.
    The mixing step has no learned weights; only the two norms and the MLP
    carry parameters.

    NOTE(review): the FFT runs along dim 1 only, which is the sequence axis
    assuming (batch, seq, embed_dim) input (TODO confirm). The FNet paper
    uses a 2-D FFT over the sequence and hidden dims. The mixing is also
    non-causal: every output position depends on every input position,
    including later ones.

    Args:
        embed_dim: feature size of the input (last dimension).
    """

    def __init__(self, embed_dim):
        super().__init__()
        self.rmsnorm1 = RMSNorm(embed_dim)
        self.rmsnorm2 = RMSNorm(embed_dim)
        # Gated MLP with a hidden width of 4x the model width.
        self.mlp = SwiGLU(embed_dim, embed_dim * 4)

    def forward(self, x):
        mixed = torch.fft.fft(self.rmsnorm1(x), dim=1).real
        hidden = x + mixed
        return hidden + self.mlp(self.rmsnorm2(hidden))


class FNET(pl.LightningModule):
    """FNet-style language model trained with next-token cross-entropy.

    Pipeline: token embeddings + learned absolute position embeddings ->
    ``num_layers`` ``FnetBlock``s -> ``RMSNorm`` -> output projection.
    The output projection's weight is tied to the token embedding matrix.

    NOTE(review): ``FnetBlock`` mixes tokens with an FFT over the whole
    sequence, so the output at position t also depends on tokens after t.
    The training targets are the inputs shifted by one (see
    ``TokenDataset``), so the model can learn to copy upcoming input tokens
    rather than predict them. This is an architectural issue and is not
    addressed here.

    Args:
        embed_dim: width of the embeddings and of every block.
        context_length: number of learned positions, i.e. the maximum
            sequence length ``forward`` accepts.
        vocab_size: number of token ids (rows of the embedding matrix).
        num_layers: number of stacked ``FnetBlock``s.
        lr: AdamW learning rate.
    """

    def __init__(self, embed_dim, context_length, vocab_size, num_layers=2, lr=0.0001):
        super().__init__()
        # Store the constructor arguments in checkpoints so that
        # load_from_checkpoint can rebuild the model without repeating them.
        # Explicitly passed kwargs still take precedence.
        self.save_hyperparameters()

        self.lr = lr
        self.context_length = context_length

        # Submodules are created in the same order as before so parameter
        # order (optimizer state) and RNG-driven initialisation are unchanged.
        self.word_embeddings = nn.Embedding(vocab_size, embed_dim)
        self.pos_embeddings = nn.Embedding(context_length, embed_dim)
        self.blocks = nn.ModuleList([FnetBlock(embed_dim) for _ in range(num_layers)])
        self.norm = RMSNorm(embed_dim)
        self.output = nn.Linear(embed_dim, vocab_size, bias=False)
        self.loss_func = nn.CrossEntropyLoss()

        # Weight tying: the output projection reuses the input embedding matrix.
        self.output.weight = self.word_embeddings.weight

    def forward(self, input_ids, attention_mask: Optional[torch.Tensor] = None):
        """Return logits of shape (batch, seq_len, vocab_size).

        Args:
            input_ids: integer token ids of shape (batch, seq_len), with
                ``seq_len <= context_length``. Shorter inputs use the first
                ``seq_len`` position embeddings.
            attention_mask: optional (batch, seq_len) tensor with 1/True for
                real tokens and 0/False for padding. Padded positions have
                their embeddings zeroed before the blocks.

        Raises:
            ValueError: if ``seq_len`` exceeds ``context_length``.
        """
        seq_len = input_ids.size(1)
        if seq_len > self.context_length:
            raise ValueError(
                f"sequence length {seq_len} exceeds context_length {self.context_length}"
            )

        positions = torch.arange(seq_len, device=input_ids.device)
        embs = self.word_embeddings(input_ids) + self.pos_embeddings(positions)

        if attention_mask is not None:
            embs = embs * attention_mask.unsqueeze(-1).to(embs.dtype)

        for layer in self.blocks:
            embs = layer(embs)

        return self.output(self.norm(embs))

    def _shared_step(self, batch, stage):
        """Compute and log the token-level cross-entropy loss as ``{stage}_loss``."""
        x, y = batch[0], batch[1]
        logits = self(x)  # (batch, seq_len, vocab_size)
        # CrossEntropyLoss expects (N, C) logits and (N,) targets.
        loss = self.loss_func(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
        self.log(f"{stage}_loss", loss, prog_bar=True)
        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss for one ``(x, y)`` batch, logged as ``train_loss``."""
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Return the validation loss for one ``(x, y)`` batch, logged as ``val_loss``."""
        return self._shared_step(batch, "val")

    def configure_optimizers(self):
        """Use AdamW over all parameters at ``self.lr``. No LR scheduler is configured."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)
        return {
            "optimizer": optimizer,
        }


class TokenDataset(Dataset):
    """Sliding next-token-prediction windows over a 1-D tensor of token ids.

    Item ``idx`` is ``(tokens[idx:idx+max_length], tokens[idx+1:idx+max_length+1])``.
    The target is the input shifted one position to the right.

    Args:
        tokens_path: path to a tensor saved with ``torch.save``. It is loaded
            with ``weights_only=True``, so no arbitrary objects are unpickled.
        max_length: length of each input/target window.
        max_tokens: keep only the first ``max_tokens`` tokens; ``None`` keeps
            all of them. The default, 103218, is the cut this class always
            applied, so existing callers see the same data.
    """

    def __init__(self, tokens_path="Tokens.pt", max_length=256, max_tokens: Optional[int] = 103218):
        self.tokens = torch.load(tokens_path, weights_only=True)
        if max_tokens is not None:
            self.tokens = self.tokens[:max_tokens]
        self.max_length = max_length

    def __len__(self):
        # Each window needs max_length + 1 tokens. Clamp at 0 so a short file
        # gives an empty dataset instead of a negative length, which len()
        # rejects with ValueError.
        return max(0, len(self.tokens) - self.max_length)

    def __getitem__(self, idx):
        x = self.tokens[idx:idx + self.max_length]
        y = self.tokens[idx + 1:idx + self.max_length + 1]
        return x, y

def create_dataloader(tokens_path="Tokens.pt", batch_size=16, max_length=512, num_workers=0, shuffle=True):
    """Build a DataLoader over ``TokenDataset`` windows.

    Args:
        tokens_path: token tensor file passed to ``TokenDataset``.
        batch_size: number of windows per batch.
        max_length: window length. It should equal the model's
            ``context_length`` (``train_model`` passes 256).
        num_workers: DataLoader worker processes. The default, 0, loads
            batches in the main process with no worker multiprocessing.
        shuffle: reshuffle windows every epoch. Pass ``False`` for
            validation/evaluation loaders so results are reproducible.

    Returns:
        A ``DataLoader`` yielding ``(x, y)`` batches of shape
        (<= batch_size, max_length).
    """
    dataset = TokenDataset(tokens_path, max_length)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)





def train_model(
    checkpoint_path: Optional[str] = "/content/lightning_logs/version_7/checkpoints/epoch=3-step=6436.ckpt",
    tokens_path="Tokens.pt",
    embed_dim=512,
    context_length=256,
    vocab_size=30002,
    num_layers=4,
    batch_size=64,
    max_epochs=100,
    seed: Optional[int] = None,
):
    """Train an ``FNET`` model on ``TokenDataset`` windows with PyTorch Lightning.

    All defaults reproduce the previous hard-coded run.

    Args:
        checkpoint_path: Lightning checkpoint to initialise weights from, or
            ``None`` to start from a freshly initialised model.
        tokens_path: token tensor file for the training data.
        embed_dim, context_length, vocab_size, num_layers: model
            architecture. They must match the checkpoint when one is given.
            ``context_length`` is also used as the dataset window length.
        batch_size: training batch size.
        max_epochs: Lightning ``max_epochs``.
        seed: when given, seeds all RNGs via ``pl.seed_everything``.

    NOTE: ``load_from_checkpoint`` restores model weights only. Optimizer
    state and epoch/step counters start fresh. To resume the full training
    state, pass the checkpoint to ``trainer.fit(..., ckpt_path=...)`` instead.
    """
    if seed is not None:
        pl.seed_everything(seed)

    if checkpoint_path is not None:
        model = FNET.load_from_checkpoint(
            checkpoint_path,
            embed_dim=embed_dim,
            context_length=context_length,
            vocab_size=vocab_size,
            num_layers=num_layers,
        )
    else:
        model = FNET(embed_dim, context_length, vocab_size, num_layers=num_layers)

    # Windows must be exactly context_length tokens to match the model.
    train_loader = create_dataloader(tokens_path=tokens_path, batch_size=batch_size, max_length=context_length)

    trainer = pl.Trainer(
        max_epochs=max_epochs,
        enable_progress_bar=True,
        num_nodes=1,
        enable_checkpointing=True,
        gradient_clip_val=1.0,
        gradient_clip_algorithm='norm'
    )

    trainer.fit(model, train_dataloaders=train_loader)

# if __name__ == "__main__":

#     x = torch.randn((1, 5, 512))

#     block = FnetBlock(512)

#     out = block(x)
#     print(out.shape)

#     input_ids = torch.randint(0, 20002, (1, 10))

#     model = FNET(512, 10, 20002)

#     out = model(input_ids)
#     print(out.shape)

In [4]:
# Size of the smaller token file (a 1-D tensor of token ids).
# weights_only=True matches TokenDataset and avoids unpickling arbitrary objects.
torch.load("Tokens_small.pt", weights_only=True).shape

torch.Size([4240059])

In [None]:
# Launch training; hyperparameters, data path and checkpoint path come from train_model().
train_model()

INFO:pytorch_lightning.utilities.rank_zero:💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.11/dist-packages/pytorch_lightning/trainer/configuration_validator.py:70: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name            | Type             | Params | Mode 
-------------------------------------------------------------
0 | word_embeddings | Embedding        | 15.4 M | train
1 | pos_embeddings  | Em

Training: |          | 0/? [00:00<?, ?it/s]

In [7]:
# Return memory held by PyTorch's CUDA caching allocator (and not used by live tensors) to the device.
torch.cuda.empty_cache()

# Testing

In [9]:
checkpoint_path = "/content/lightning_logs/version_7/checkpoints/epoch=3-step=6436.ckpt"

# Must match the architecture the checkpoint was trained with (see train_model()).
embed_dim, context_length, vocab_size = 512, 256, 30002

# .eval() switches to inference mode and returns the module itself. No current
# FNET layer behaves differently in eval mode, but this keeps inference correct
# if dropout or running-stat normalization is added later.
model = FNET.load_from_checkpoint(
    checkpoint_path,
    embed_dim=embed_dim,
    context_length=context_length,
    vocab_size=vocab_size,
    num_layers=4,
).eval()


In [8]:
from tokenizers import Tokenizer
import json  # NOTE(review): json is not used in any visible cell.

# Load the pretrained tokenizer from its serialized JSON definition.
tokenizer = Tokenizer.from_file("my_tokenizer.json")

In [22]:
# Encode a sample passage followed by question/answer lines; `out.ids` (the token
# ids) is used to build the model input.
out = tokenizer.encode("""Her first acting role of 2006 was in the comedy film The Pink Panther starring opposite Steve Martin, grossing $158.8 million at the box office worldwide. Her second film Dreamgirls, the film version of the 1981 Broadway musical loosely based on The Supremes, received acclaim from critics and grossed $154 million internationally. In it, she starred opposite Jennifer Hudson, Jamie Foxx, and Eddie Murphy playing a pop singer based on Diana Ross. To promote the film, Beyoncé released "Listen" as the lead single from the soundtrack album. In April 2007, Beyoncé embarked on The Beyoncé Experience, her first worldwide concert tour, visiting 97 venues and grossed over $24 million.[note 1] Beyoncé conducted pre-concert food donation drives during six major stops in conjunction with her pastor at St. John's and America's Second Harvest. At the same time, B'Day was re-released with five additional songs, including her duet with Shakira "Beautiful Liar".
What movie did Beyonce act in 2006?
The Pink Panther
Her second movie Beyonce did was what film?
Dreamgirls
The single, "Listen" was featured in which movie?
Dreamgirls
""")

In [30]:
# Pad (with id 0, as before) or truncate to exactly context_length tokens so the
# input matches the window length used in training, instead of a hand-counted
# pad that only fits one specific text.
# NOTE(review): confirm 0 is the tokenizer's padding id.
token_ids = out.ids[:context_length]
inp = torch.tensor([token_ids + [0] * (context_length - len(token_ids))], device=model.device)

In [31]:
# Sanity check: expect (1, context_length).
inp.shape

torch.Size([1, 256])

In [34]:
# Inference only: no_grad skips building the autograd graph, so per-layer
# activations are not kept alive on the GPU alongside `logits`.
with torch.no_grad():
    logits = model(inp)

In [35]:
# Greedy decoding: the highest-scoring token id at every position, shape (batch, seq_len).
ids = logits.argmax(dim=-1)

In [38]:
# Decode the per-position predictions for the first (only) sequence in the batch.
# NOTE(review): the output mostly echoes tokens from the input passage, which is
# consistent with the non-causal FFT mixing in FnetBlock.
print(tokenizer.decode(ids[0].tolist()))

, the the the to of in the comedy the in in Pink opposite , , to , at Dreamgirls Dreamgirls Dreamgirls ? million featured the , . " List List " Dreamgirls Dreamgirls Dreamgirls Dreamgirls of the what the movie loosely musical on The The The The , 2006 in act Beyonce did ". grossed Li million Li million " ira , it her opposite her including , songs , five x , - re was die ' a , pop based on At . . To . s film the Beyoncé released ' released " " as with her single from the from ? . In April In - Beyoncé - Beyoncé Beyoncé Beyoncé Experience , Beyoncé , million over the and tour 97 visiting 97 tour and the 24 million , Beyoncé , note Beyoncé Beyoncé Beyoncé Beyoncé 2007 - In April In . ? from single from single her with as " " released ' s Beyoncé the film s To To . . At on based time , a ' Day was re - , x five , songs , including her opposite her it , ira " million Li Beautiful Li grossed ". did Beyonce act in 2006 , The Pink The The second musical loosely movie the what the of Dreamgirl

In [None]:
"""The single, "Listen" was featured in which movie?
Dreamgirls
Beyonce's first world tour was when?
2007
How much money did Beyonce's tour make in 2007?
24 million
How many millions of dollars did ''The Pink Panther'' gross world-wide?
158.8 million
What did Beyonce call her first concert tour?
The Beyoncé Experience
Who was Beyonce's duet with in ''Beautiful Liar''?
Shakira
Which film did Beyoncé star with Steve Martin in?
The Pink Panther
Beyoncé's role in Dreamgirls was based on what pop singer?
Diana Ross.
What was the lead single for the Dreamgirls soundtrack?
Listen
What was the name of Beyoncé's first international tour?
The Beyoncé Experience
What pop singer did a duet with Beyoncé on Beautiful Liar?
Shakira"""

In [40]:
# Drop the model and the inference tensors, then return the freed GPU memory
# cached by PyTorch to the device. Only names defined in earlier cells are
# deleted: `del` stops at the first undefined name and leaves the rest alive.
del model, logits, inp, ids
torch.cuda.empty_cache()