In [13]:
import torch
import time
import torch.nn as nn
import torch.optim as optim
from datasets import load_dataset
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F

from collections import Counter
import re

In [14]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [15]:
dataset = load_dataset('ag_news')
train_data = dataset['train']
test_data = dataset['test']

In [17]:
def simple_tokenizer(text):
    text = text.lower()
    tokens = re.findall(r'\b\w+\b', text)
    return tokens

counter = Counter()
for example in train_data:
    counter.update(simple_tokenizer(example['text']))

specials = ['<pad>', '<unk>']
vocab = {token: idx for idx, token in enumerate(specials + list(counter.keys()))}
vocab['<pad>'] = 0
vocab['<unk>'] = 1
inv_vocab = {idx: token for token, idx in vocab.items()}

vocab_size = len(vocab)

In [18]:
def encode(text):
    tokens = simple_tokenizer(text)
    return [vocab.get(token, vocab['<unk>']) for token in tokens]
def collate_batch(batch):
    labels = torch.tensor([item['label'] for item in batch], dtype=torch.long)
    text_list = [torch.tensor(encode(item['text']), dtype=torch.long) for item in batch]
    texts = pad_sequence(text_list, batch_first=True, padding_value=vocab['<pad>'])
    return texts, labels
train_loader = DataLoader(train_data, batch_size=64, shuffle=True, collate_fn=collate_batch)
test_loader = DataLoader(test_data, batch_size=64, shuffle=False, collate_fn=collate_batch)

In [19]:
# ========== Define regular RMSNorm ==========
class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        mean_square = (x ** 2).mean(dim=-1, keepdim=True)
        inv_rms = 1.0 / torch.sqrt(mean_square + self.eps)
        return self.scale * x * inv_rms

# ========== Define Quantized RMSNorm ==========
class QuantizedRMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-5, scale_bits=8):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(dim))
        self.scale_bits = scale_bits

    def forward(self, x):
        x_int = torch.round(x * 127) / 127
        mean_square = (x_int ** 2).mean(dim=-1, keepdim=True)
        inv_rms = 1.0 / torch.sqrt(mean_square + self.eps)
        inv_rms_q = torch.round(inv_rms * 127) / 127
        return self.scale * x_int * inv_rms_q

# ========== Encoder Layer Factory ==========
def make_encoder_layer(norm_type, d_model, nhead, dim_feedforward=2048, dropout=0.1):
    class EncoderLayer(nn.Module):
        def __init__(self):
            super().__init__()
            self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
            self.linear1 = nn.Linear(d_model, dim_feedforward)
            self.dropout = nn.Dropout(dropout)
            self.linear2 = nn.Linear(dim_feedforward, d_model)
            self.norm1 = norm_type(d_model)
            self.norm2 = norm_type(d_model)
            self.dropout1 = nn.Dropout(dropout)
            self.dropout2 = nn.Dropout(dropout)

        def forward(self, src, src_mask=None, src_key_padding_mask=None, **kwargs):
            src2 = self.self_attn(src, src, src, attn_mask=src_mask,
                                  key_padding_mask=src_key_padding_mask)[0]
            src = src + self.dropout1(src2)
            src = self.norm1(src)

            src2 = self.linear2(self.dropout(F.relu(self.linear1(src))))
            src = src + self.dropout2(src2)
            src = self.norm2(src)

            return src

    return EncoderLayer

# ========== Transformer Classifier ==========
class TransformerClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_heads, num_layers, num_classes, norm_type, max_len=512):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, embed_dim)
        self.pos_embed = nn.Embedding(max_len, embed_dim)

        encoder_layer = make_encoder_layer(norm_type, embed_dim, num_heads, dim_feedforward=4 * embed_dim)()
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.fc = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        batch_size, seq_len = x.shape
        positions = torch.arange(0, seq_len, device=x.device).unsqueeze(0).expand(batch_size, seq_len)
        x = self.token_embed(x) + self.pos_embed(positions)
        x = x.permute(1, 0, 2)
        x = self.transformer_encoder(x)
        x = x.permute(1, 0, 2)
        x = x.mean(dim=1)
        return self.fc(x)



In [20]:
# ========== Performance Measurement Functions ==========
def measure_performance_train(model, train_loader, criterion, optimizer, device):
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    start_time = time.time()
    torch.cuda.reset_peak_memory_stats()

    for texts, labels in train_loader:
        texts, labels = texts.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(texts)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        preds = outputs.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    end_time = time.time()
    train_time = end_time - start_time
    gpu_memory = torch.cuda.max_memory_allocated() / 1e6
    return total_loss / len(train_loader), correct / total, train_time, gpu_memory

def measure_performance_test(model, test_loader, device):
    model.eval()
    correct = 0
    total = 0
    start_time = time.time()
    with torch.no_grad():
        for texts, labels in test_loader:
            texts, labels = texts.to(device), labels.to(device)
            outputs = model(texts)
            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    end_time = time.time()
    test_time = end_time - start_time
    return correct / total, test_time



In [21]:
# ========== Instantiate Models ==========
embed_dim = 128
num_heads = 4
num_layers = 2
num_classes = 4

model_fp32 = TransformerClassifier(vocab_size, embed_dim, num_heads, num_layers, num_classes, RMSNorm).to(device)
model_quant = TransformerClassifier(vocab_size, embed_dim, num_heads, num_layers, num_classes, QuantizedRMSNorm).to(device)

criterion = nn.CrossEntropyLoss()
optimizer_fp32 = torch.optim.Adam(model_fp32.parameters(), lr=1e-3)
optimizer_quant = torch.optim.Adam(model_quant.parameters(), lr=1e-3)

epochs = 5

# ========== Training Loop ==========
for epoch in range(epochs):
    # FP32 model
    loss_fp32, acc_fp32, time_fp32, mem_fp32 = measure_performance_train(
        model_fp32, train_loader, criterion, optimizer_fp32, device
    )
    test_acc_fp32, test_time_fp32 = measure_performance_test(model_fp32, test_loader, device)

    # Quantized RMSNorm model
    loss_quant, acc_quant, time_quant, mem_quant = measure_performance_train(
        model_quant, train_loader, criterion, optimizer_quant, device
    )
    test_acc_quant, test_time_quant = measure_performance_test(model_quant, test_loader, device)

    # Print results
    print(f"\nEpoch {epoch+1} Results:")
    print(f"FP32 RMSNorm --> Train Loss: {loss_fp32:.4f}, Train Acc: {acc_fp32:.4f}, Time: {time_fp32:.2f}s, GPU Mem: {mem_fp32:.2f} MB, Test Acc: {test_acc_fp32:.4f}, Test Time: {test_time_fp32:.2f}s")
    print(f"Quantized RMSNorm --> Train Loss: {loss_quant:.4f}, Train Acc: {acc_quant:.4f}, Time: {time_quant:.2f}s, GPU Mem: {mem_quant:.2f} MB, Test Acc: {test_acc_quant:.4f}, Test Time: {test_time_quant:.2f}s")



Epoch 1 Results:
FP32 RMSNorm --> Train Loss: 0.4804, Train Acc: 0.8209, Time: 14.32s, GPU Mem: 630.35 MB, Test Acc: 0.8874, Test Time: 0.35s
Quantized RMSNorm --> Train Loss: 1.3352, Train Acc: 0.3698, Time: 14.68s, GPU Mem: 735.71 MB, Test Acc: 0.4220, Test Time: 0.37s

Epoch 2 Results:
FP32 RMSNorm --> Train Loss: 0.2690, Train Acc: 0.9096, Time: 13.83s, GPU Mem: 735.71 MB, Test Acc: 0.9026, Test Time: 0.34s
Quantized RMSNorm --> Train Loss: 1.2690, Train Acc: 0.4337, Time: 14.59s, GPU Mem: 735.71 MB, Test Acc: 0.4563, Test Time: 0.37s

Epoch 3 Results:
FP32 RMSNorm --> Train Loss: 0.2079, Train Acc: 0.9295, Time: 13.84s, GPU Mem: 735.71 MB, Test Acc: 0.9108, Test Time: 0.35s
Quantized RMSNorm --> Train Loss: 1.2388, Train Acc: 0.4511, Time: 14.60s, GPU Mem: 735.71 MB, Test Acc: 0.4704, Test Time: 0.37s

Epoch 4 Results:
FP32 RMSNorm --> Train Loss: 0.1705, Train Acc: 0.9421, Time: 13.82s, GPU Mem: 735.71 MB, Test Acc: 0.9120, Test Time: 0.34s
Quantized RMSNorm --> Train Loss: 1.21