In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from datasets import load_dataset
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from collections import Counter
import re

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [3]:
dataset = load_dataset('ag_news')
train_data = dataset['train']
test_data = dataset['test']

In [None]:
def simple_tokenizer(text):
    text = text.lower()
    tokens = re.findall(r'\b\w+\b', text)
    return tokens

In [5]:
counter = Counter()
for example in train_data:
    counter.update(simple_tokenizer(example['text']))

specials = ['<pad>', '<unk>']
vocab = {token: idx for idx, token in enumerate(specials + list(counter.keys()))}
vocab['<pad>'] = 0
vocab['<unk>'] = 1
inv_vocab = {idx: token for token, idx in vocab.items()}

vocab_size = len(vocab)

In [6]:
def encode(text):
    tokens = simple_tokenizer(text)
    return [vocab.get(token, vocab['<unk>']) for token in tokens]

def collate_batch(batch):
    labels = torch.tensor([item['label'] for item in batch], dtype=torch.long)
    text_list = [torch.tensor(encode(item['text']), dtype=torch.long) for item in batch]
    texts = pad_sequence(text_list, batch_first=True, padding_value=vocab['<pad>'])
    return texts, labels

In [7]:
train_loader = DataLoader(train_data, batch_size=64, shuffle=True, collate_fn=collate_batch)
test_loader = DataLoader(test_data, batch_size=64, shuffle=False, collate_fn=collate_batch)

In [8]:
class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-8):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        norm = x.norm(dim=-1, keepdim=True) * (1.0 / (x.size(-1) ** 0.5))
        return self.scale * (x / (norm + self.eps))

class TransformerEncoderLayerWithRMSNorm(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        # Feedforward network
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = RMSNorm(d_model)
        self.norm2 = RMSNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, src, src_mask=None, src_key_padding_mask=None, **kwargs):

        # Self-attention
        src2 = self.self_attn(src, src, src, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)

        # Feedforward
        src2 = self.linear2(self.dropout(nn.functional.relu(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)

        return src

class TransformerClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_heads, num_layers, num_classes, max_len=512):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, embed_dim)
        self.pos_embed = nn.Embedding(max_len, embed_dim)

        encoder_layer = TransformerEncoderLayerWithRMSNorm(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=4 * embed_dim
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.fc = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        batch_size, seq_len = x.shape
        positions = torch.arange(0, seq_len, device=x.device).unsqueeze(0).expand(batch_size, seq_len)
        x = self.token_embed(x) + self.pos_embed(positions)

        x = x.permute(1, 0, 2)  # (seq_len, batch_size, embed_dim)
        x = self.transformer_encoder(x)
        x = x.permute(1, 0, 2)  # (batch_size, seq_len, embed_dim)

        x = x.mean(dim=1)
        return self.fc(x)

In [None]:
embed_dim = 128
num_heads = 4
num_layers = 2
num_classes = 4

model = TransformerClassifier(vocab_size, embed_dim, num_heads, num_layers, num_classes).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

epochs = 10



In [12]:
for epoch in range(epochs):
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for texts, labels in train_loader:
        texts, labels = texts.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(texts)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        preds = outputs.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    print(f"Epoch {epoch+1}: Loss={total_loss/len(train_loader):.4f}, Train Acc={correct/total:.4f}")

Epoch 1: Loss=0.1177, Train Acc=0.9603
Epoch 2: Loss=0.0998, Train Acc=0.9655
Epoch 3: Loss=0.0855, Train Acc=0.9701
Epoch 4: Loss=0.0731, Train Acc=0.9742
Epoch 5: Loss=0.0632, Train Acc=0.9775


In [14]:
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for texts, labels in test_loader:
        texts, labels = texts.to(device), labels.to(device)
        outputs = model(texts)
        preds = outputs.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

print(f"Test Accuracy: {correct/total:.4f}")

Test Accuracy: 0.9046
