In [2]:
import os
import torch
import numpy as np
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, confusion_matrix
from datasets import load_dataset
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
from transformers import AdamW, get_scheduler
import string  # Ensure string is imported

# Use the first visible CUDA device when available; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load the DBpedia-14 dataset through Hugging Face `datasets` and keep handles to
# its train/test splits (these names are rebound to PyTorch datasets further down).
dataset = load_dataset('dbpedia_14')
train_dataset, test_dataset = dataset['train'], dataset['test']

# Tokenizer: Build Vocabulary
# Deletion table for ASCII punctuation. It is built once at import time rather than
# inside basic_tokenizer, which is called for every example of every split.
_PUNCTUATION_TABLE = str.maketrans("", "", string.punctuation)

def basic_tokenizer(text):
    """Split *text* into lower-case word tokens.

    The text is lower-cased, every character in ``string.punctuation`` is deleted
    (ASCII punctuation only; Unicode marks such as curly quotes or dashes are
    kept), and the result is split on runs of whitespace.

    Args:
        text: Raw input string.

    Returns:
        List of tokens; empty for an empty or whitespace-only string.
    """
    return text.lower().translate(_PUNCTUATION_TABLE).split()

def build_vocab(dataset, max_vocab_size=50000):
    """Map the most frequent corpus tokens to integer ids.

    For each example, the ``'title'`` and ``'content'`` fields are joined with a
    single space and tokenized with ``basic_tokenizer``. The ``max_vocab_size``
    most frequent tokens get ids ``2, 3, ...`` in descending frequency order.
    Id 0 is reserved for ``"[PAD]"`` and id 1 for ``"[UNK]"``.

    Args:
        dataset: Iterable of mapping-like examples with string ``'title'`` and
            ``'content'`` entries.
        max_vocab_size: Maximum number of corpus tokens to keep; the two
            reserved tokens are added on top of this.

    Returns:
        dict mapping token -> id, with at most ``max_vocab_size + 2`` entries.
    """
    from collections import Counter

    token_counts = Counter(
        token
        for example in dataset
        for token in basic_tokenizer(example['title'] + " " + example['content'])
    )

    vocab = {}
    # Start numbering at 2 so the two reserved ids below stay free.
    for token_id, (word, _count) in enumerate(token_counts.most_common(max_vocab_size), start=2):
        vocab[word] = token_id
    vocab["[PAD]"] = 0
    vocab["[UNK]"] = 1
    return vocab

# The vocabulary comes from the training split only; test-set words outside it are
# mapped to [UNK] when the examples are tokenized.
vocab = build_vocab(train_dataset, max_vocab_size=50000)
vocab_size = len(vocab)  # up to 50,000 corpus tokens plus [PAD] and [UNK]
print(f"Vocabulary size: {vocab_size}")

# Custom Tokenizer
def custom_tokenizer(example, vocab, max_length=128):
    """Encode one example as a fixed-length list of vocabulary ids.

    ``example['title']`` and ``example['content']`` are joined with a space and
    split with ``basic_tokenizer``. Each token is looked up in *vocab*, and
    unknown tokens map to ``vocab["[UNK]"]``. The id list is then truncated to
    *max_length*, or right-padded with ``vocab["[PAD]"]`` up to *max_length*.

    Args:
        example: Mapping with string ``'title'`` and ``'content'`` entries.
        vocab: Token -> id mapping containing ``"[UNK]"`` and ``"[PAD]"``.
        max_length: Target number of ids.

    Returns:
        List of exactly *max_length* ints.
    """
    words = basic_tokenizer(example['title'] + " " + example['content'])
    ids = [vocab.get(word, vocab["[UNK]"]) for word in words]
    if len(ids) >= max_length:
        return ids[:max_length]
    return ids + [vocab["[PAD]"]] * (max_length - len(ids))

def tokenize_dataset(dataset, vocab, max_length=128):
    """Encode every example of *dataset* with ``custom_tokenizer``.

    Args:
        dataset: Iterable of examples with ``'title'``, ``'content'`` and
            ``'label'`` entries.
        vocab: Token -> id mapping passed through to ``custom_tokenizer``.
        max_length: Fixed sequence length for every encoded example.

    Returns:
        List of dicts ``{'input_ids': list[int], 'labels': example['label']}``,
        in the same order as *dataset*.
    """
    return [
        {
            'input_ids': custom_tokenizer(example, vocab, max_length),
            'labels': example['label'],
        }
        for example in dataset
    ]

# Tokenize datasets: both splits are encoded with the training-split vocabulary
# (train first, then test).
tokenized_train, tokenized_test = (
    tokenize_dataset(split, vocab) for split in (train_dataset, test_dataset)
)

# Convert to PyTorch Dataset
class CustomDataset(torch.utils.data.Dataset):
    """Map-style dataset over pre-tokenized records.

    Each record is a mapping with ``'input_ids'`` (a sequence of ints) and
    ``'labels'`` (an int). Items are converted to ``torch.long`` tensors lazily,
    on every ``__getitem__`` call.
    """

    def __init__(self, data):
        # Indexable sequence of records; stored as-is, not copied.
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data[idx]
        ids_tensor = torch.tensor(record['input_ids'], dtype=torch.long)
        label_tensor = torch.tensor(record['labels'], dtype=torch.long)
        return {'input_ids': ids_tensor, 'labels': label_tensor}

# Rebind the split names to PyTorch datasets over the tokenized records.
train_dataset = CustomDataset(tokenized_train)
test_dataset = CustomDataset(tokenized_test)

# DataLoaders: shuffled batches for training, fixed order for evaluation.
batch_size = 64
train_loader, test_loader = (
    DataLoader(train_dataset, batch_size=batch_size, shuffle=True),
    DataLoader(test_dataset, batch_size=batch_size, shuffle=False),
)

# Cyclic Attention Transformer
class CyclicAttentionBlock(nn.Module):
    """Self-attention plus an extra gated residual from a rolled copy of its output.

    Expects inputs laid out as ``(seq_len, batch, embed_dim)``, the layout
    ``nn.MultiheadAttention`` uses with its default ``batch_first=False``. Under
    that layout, ``torch.roll(..., dims=0)`` shifts along the sequence, so
    position ``i`` receives the attention output of position ``i - 1``, and the
    first position wraps around to the last.

    The same ``LayerNorm`` instance (shared weights) is applied after both
    residual additions. Dropout is applied to the block's final output.
    """

    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        self.attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)
        self.norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)
        # One scalar gate per position, computed from the attention output.
        self.gate = nn.Linear(embed_dim, 1)

    def forward(self, x):
        """Return a tensor with the same shape as *x*."""
        attended = self.attention(x, x, x)[0]
        hidden = self.norm(x + attended)
        gate_weight = torch.sigmoid(self.gate(attended))
        shifted = torch.roll(attended, shifts=1, dims=0)
        hidden = self.norm(hidden + gate_weight * shifted)
        return self.dropout(hidden)

class CombinedTransformerBlock(nn.Module):
    """Post-norm encoder layer with three attention stages and a feed-forward.

    Order: self-attention (+ residual, LayerNorm), then ``CyclicAttentionBlock``
    (which applies its own residuals and norm), then a second self-attention
    (+ residual, LayerNorm), then a ReLU feed-forward with dropout on its output
    (+ residual, LayerNorm). Input and output share the layout expected by
    ``nn.MultiheadAttention`` with ``batch_first=False``, i.e.
    ``(seq_len, batch, embed_dim)``.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
        super().__init__()
        self.attention1 = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.cyclic_attention = CyclicAttentionBlock(embed_dim, num_heads, dropout)
        self.attention2 = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim),
        )
        self.norm3 = nn.LayerNorm(embed_dim)
        # Only applied to the feed-forward output; the attention modules use
        # their own internal dropout.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return a tensor with the same shape as *x*."""
        x = self.norm1(x + self.attention1(x, x, x)[0])
        x = self.cyclic_attention(x)
        x = self.norm2(x + self.attention2(x, x, x)[0])
        return self.norm3(x + self.dropout(self.ff(x)))

class CustomTransformer(nn.Module):
    """Token classifier: embedding -> stacked encoder blocks -> mean pool -> linear.

    ``forward`` takes ``input_ids`` of shape ``(batch, seq_len)`` and returns
    logits of shape ``(batch, num_classes)``.

    NOTE(review): no positional encoding is added to the embeddings. No padding
    mask is passed to the attention layers, and the mean pooling averages over
    every position, so padded positions contribute to both.
    """

    def __init__(self, vocab_size, embed_dim, num_heads, ff_dim, num_classes, num_layers):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        encoder_blocks = [
            CombinedTransformerBlock(embed_dim, num_heads, ff_dim)
            for _ in range(num_layers)
        ]
        self.layers = nn.ModuleList(encoder_blocks)
        self.global_pooling = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(embed_dim, num_classes)

    def forward(self, input_ids):
        # (batch, seq, embed) -> (seq, batch, embed), the layout the blocks expect.
        hidden = self.embedding(input_ids).transpose(0, 1)
        for block in self.layers:
            hidden = block(hidden)
        # (seq, batch, embed) -> (batch, embed, seq) so the 1-D pool averages over seq.
        pooled = self.global_pooling(hidden.permute(1, 2, 0)).squeeze(-1)
        return self.fc(pooled)

# Model Initialization
embed_dim = 1024
num_heads = 8
ff_dim = 2048
num_classes = 14  # DBpedia has 14 classes
num_layers = 3
epochs = 5

model = CustomTransformer(vocab_size, embed_dim, num_heads, ff_dim, num_classes, num_layers).to(device)
# transformers.AdamW is deprecated in favour of torch.optim.AdamW; its constructor
# warns that it will be removed. The two classes use different defaults, so eps
# and weight_decay are pinned to transformers.AdamW's values (1e-6 and 0.0).
# Otherwise torch.optim.AdamW would silently default to eps=1e-8 and
# weight_decay=1e-2.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, eps=1e-6, weight_decay=0.0)

# Training Loop: one optimizer step per batch; prints the mean batch loss per epoch.
print("Training...")
for epoch in range(epochs):
    model.train()
    total_loss = 0
    for batch in tqdm(train_loader):
        input_ids = batch['input_ids'].to(device)
        labels = batch['labels'].to(device)
        optimizer.zero_grad()
        logits = model(input_ids)
        loss = F.cross_entropy(logits, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch + 1}/{epochs} Loss: {total_loss / len(train_loader)}")

# Evaluation Loop: gather logits and labels as NumPy arrays for scikit-learn.
print("Evaluating...")
model.eval()
all_logits = []
all_labels = []
with torch.no_grad():
    for batch in tqdm(test_loader):
        input_ids = batch['input_ids'].to(device)
        # Labels are only consumed on the host by sklearn, so they stay on the CPU
        # instead of making a device round trip.
        labels = batch['labels']
        logits = model(input_ids)
        all_logits.append(logits.cpu().numpy())
        all_labels.append(labels.numpy())

all_logits = np.concatenate(all_logits, axis=0)
all_labels = np.concatenate(all_labels, axis=0)

# Metrics (class-frequency-weighted averages over the 14 classes)
predictions = np.argmax(all_logits, axis=1)
accuracy = accuracy_score(all_labels, predictions)
f1 = f1_score(all_labels, predictions, average="weighted")
precision = precision_score(all_labels, predictions, average="weighted")
recall = recall_score(all_labels, predictions, average="weighted")

print(f"\nAccuracy: {accuracy:.4f}")
print(f"F1 Score: {f1:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")


README.md:   0%|          | 0.00/7.64k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/106M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/13.3M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/560000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/70000 [00:00<?, ? examples/s]

Vocabulary size: 50002




Training...


100%|██████████| 8750/8750 [1:04:24<00:00,  2.26it/s]


Epoch 1/5 Loss: 0.12996144127249717


100%|██████████| 8750/8750 [1:04:21<00:00,  2.27it/s]


Epoch 2/5 Loss: 0.06815697464145987


100%|██████████| 8750/8750 [1:04:20<00:00,  2.27it/s]


Epoch 3/5 Loss: 0.05200324921938591


100%|██████████| 8750/8750 [1:04:21<00:00,  2.27it/s]


Epoch 4/5 Loss: 0.04161729629899242


100%|██████████| 8750/8750 [1:04:21<00:00,  2.27it/s]


Epoch 5/5 Loss: 0.034360936810911104
Evaluating...


100%|██████████| 1094/1094 [02:36<00:00,  6.98it/s]


Accuracy: 0.9805
F1 Score: 0.9805
Precision: 0.9806
Recall: 0.9805



