In [1]:
import datasets

dataset = datasets.load_dataset("synthseq/flipflop")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
flip_flop_dict = {'0': 0, "1": 1, "w": 2,"r": 3, "i": 4}

In [3]:
import torch
def tokenize_raw(batch):
    tokenized = [[flip_flop_dict[char] for char in s] for s in batch["text"]]
    return {"tokens": torch.tensor(tokenized, dtype=torch.int64)}

In [19]:
dataset.set_transform(tokenize_raw)

In [45]:
import torch
from torch.utils.data import DataLoader

class NextTokenDataset(torch.utils.data.Dataset):
    def __init__(self, hf_dataset):
        self.data = []
        for item in hf_dataset:
            tokens = item["tokens"]
            tokens = torch.tensor(tokens, dtype=torch.long)
            x = tokens[:-1]
            y = tokens[1:]
            self.data.append((x, y))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx] 

train_small = dataset["train"].select(range(2000))
val_small = dataset["val"].select(range(200))
train_dataset = NextTokenDataset(train_small)
val_dataset = NextTokenDataset(val_small)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)

  tokens = torch.tensor(tokens, dtype=torch.long)


In [46]:
import torch.nn as nn
import math

class Sinusoidal_Embedding(nn.Module):
    def __init__(self, embed_dim, max_len=512):
        super().__init__()
        pe = torch.zeros(max_len, embed_dim)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe) 

    def forward(self, x):
        return self.pe[:x.size(1)].unsqueeze(0).expand(x.size(0), -1, -1)

In [47]:
class Transformer(nn.Module):
    def __init__(
            self,
            max_seq_len:int,
            dictionary_size:int,
            num_attn_layer:int=2,
            num_attn_heads:int=2,
            attn_dim:int=4,
            
        ):
        super().__init__()
        self.word_embedding = nn.Embedding(
            num_embeddings=dictionary_size,
            embedding_dim=attn_dim,
            )
        self.positional_embedding = Sinusoidal_Embedding(
            embed_dim=attn_dim,
            max_len=max_seq_len
        )
        self.attention_layers = [
            nn.MultiheadAttention(
                embed_dim=attn_dim,
                num_heads=num_attn_heads,
                batch_first=True,
            ) 
            for _ in range(num_attn_layer)
        ]
        self.classification = [
            nn.Linear(in_features=attn_dim, out_features=attn_dim),
            nn.ReLU(),
            nn.Linear(in_features=attn_dim, out_features=dictionary_size),
        ]

    def forward(self, x):
        """
        shape of x should be: (batch_size, seq_len)
        """
        seq_len = x.size()[1]
        word_emb = self.word_embedding(x)
        pos_emb = self.positional_embedding(x)
        emb = word_emb + pos_emb

        attn_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool().to(x.device)

        for attn_layer in self.attention_layers:
            emb, _ = attn_layer(emb, emb, emb, attn_mask=attn_mask)

        
        for layer in self.classification:
            emb = layer(emb)

        return emb

In [None]:
def train_epoch(model, loader, optimizer, loss_fn):
    model.train()
    total_loss = 0
    for x, y in loader:
        logits = model(x)  # (batch, seq_len, vocab_size)

        loss = loss_fn(logits.view(-1, logits.size(-1)), y.view(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)


def eval_epoch(model, loader, loss_fn):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for x, y in loader:
            logits = model(x)
            loss = loss_fn(logits.view(-1, logits.size(-1)), y.view(-1))
            total_loss += loss.item()
    return total_loss / len(loader)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vocab_size = 5

model = Transformer(max_seq_len=512, dictionary_size=vocab_size)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-3)
loss_fn = nn.CrossEntropyLoss()

for epoch in range(200):
    train_loss = train_epoch(model, train_loader, optimizer, loss_fn)
    val_loss = eval_epoch(model, val_loader, loss_fn)
    print(f"Epoch {epoch+1}: train loss = {train_loss:.4f}, val loss = {val_loss:.4f}")

Epoch 1: train loss = 1.5250, val loss = 1.5243
Epoch 2: train loss = 1.5248, val loss = 1.5240
Epoch 3: train loss = 1.5237, val loss = 1.5210
Epoch 4: train loss = 1.5162, val loss = 1.5051
Epoch 5: train loss = 1.4749, val loss = 1.4438
Epoch 6: train loss = 1.4271, val loss = 1.4105
Epoch 7: train loss = 1.3991, val loss = 1.3868
Epoch 8: train loss = 1.3789, val loss = 1.3699
Epoch 9: train loss = 1.3642, val loss = 1.3570
Epoch 10: train loss = 1.3525, val loss = 1.3458
Epoch 11: train loss = 1.3408, val loss = 1.3316
Epoch 12: train loss = 1.3156, val loss = 1.2963
Epoch 13: train loss = 1.2799, val loss = 1.2647
Epoch 14: train loss = 1.2555, val loss = 1.2456
Epoch 15: train loss = 1.2379, val loss = 1.2292
Epoch 16: train loss = 1.2224, val loss = 1.2146
Epoch 17: train loss = 1.2085, val loss = 1.2015
Epoch 18: train loss = 1.1959, val loss = 1.1894
Epoch 19: train loss = 1.1843, val loss = 1.1783
Epoch 20: train loss = 1.1734, val loss = 1.1678
Epoch 21: train loss = 1.1633