In [None]:
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
import torch.optim as optim
from torch.utils.data import DataLoader

import torchtext
import torchtext.transforms as T
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

from datasets import load_dataset

# Choose the compute device: CUDA first, then Apple MPS, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Check the settings: library versions followed by the selected device.
print(torch.__version__, torchtext.__version__, sep="\n")
print(f"Using {device} device")

In [None]:
# Prepare the dataset: load the "ModApte" train and test splits and keep only
# the two columns this notebook reads ("text" and "topics").
def _load_reuters_split(split):
    """Load one split in torch format, restricted to the text and topics columns."""
    dataset = load_dataset(
        "ucirvine/reuters21578", "ModApte", split=split, trust_remote_code=True
    ).with_format("torch")
    unused_columns = [
        name for name in dataset.column_names if name not in ["text", "topics"]
    ]
    return dataset.remove_columns(unused_columns)


train_dataset = _load_reuters_split("train")
test_dataset = _load_reuters_split("test")

# Gather every topic label that occurs in either split.
unique_topics = set()
for dataset_split in (train_dataset, test_dataset):
    for entry in dataset_split:
        unique_topics.update(entry["topics"])

print(f"Number of unique topics: {len(unique_topics)}")

In [None]:
# Tokenize the dataset
tokenizer = get_tokenizer("basic_english")


def yield_tokens(data_iter):
    """Lazily yield the token list for the "text" field of each sample."""
    yield from (tokenizer(sample["text"]) for sample in data_iter)


# The vocabulary is built from the training split only; the four special
# tokens are placed first, and a token must occur at least twice to be kept.
special_tokens = ["<pad>", "<sos>", "<eos>", "<unk>"]
vocab = build_vocab_from_iterator(
    yield_tokens(train_dataset),
    min_freq=2,
    specials=special_tokens,
    special_first=True,
)

# Any token missing from the vocabulary is looked up as "<unk>".
vocab.set_default_index(vocab["<unk>"])

print(f"Vocab size: {len(vocab)}")

In [None]:
# Maximum number of ids kept by the truncation step. Truncation runs after
# <sos> is prepended and before <eos> is appended, so a transformed sequence
# holds at most MAX_TOKENS + 1 ids.
MAX_TOKENS = 512

# Tokens -> ids -> prepend <sos> -> truncate -> append <eos> -> tensor.
# The special-token ids are looked up in the vocabulary rather than hard-coded,
# so the pipeline stays correct if the order of the specials ever changes.
text_transform = T.Sequential(
    T.VocabTransform(vocab=vocab),
    T.AddToken(vocab["<sos>"], begin=True),
    T.Truncate(MAX_TOKENS),
    T.AddToken(vocab["<eos>"], begin=False),
    T.ToTensor(),
)

# Sanity check: show the padded id tensor produced for a short example.
print(
    pad_sequence(
        [text_transform(tokenizer("Hello world"))],
        batch_first=True,
        padding_value=vocab["<pad>"],
    )
)


def text_tokenizer(batch):
    """Tokenize every string in ``batch`` and return the list of token lists."""
    return [tokenizer(x) for x in batch]


# Sort the topics before numbering them. Iteration order of a set of strings
# varies between interpreter runs, which would make the class ids (and any
# saved model or reported metric that depends on them) irreproducible.
topic_to_idx = {topic: idx for idx, topic in enumerate(sorted(unique_topics))}


def collate_batch(batch):
    """Turn a list of dataset samples into padded id and label tensors.

    Samples with no topics are skipped. For the remaining samples the text is
    tokenized and transformed to ids, and the first topic is used as the label.

    Args:
        batch: Iterable of samples, each indexable by "text" and "topics".

    Returns:
        A ``(texts, labels)`` tuple. ``texts`` is an int64 tensor of shape
        (kept_samples, longest_sequence) padded with the ``<pad>`` id, and
        ``labels`` is an int64 tensor of shape (kept_samples,). When every
        sample was skipped both tensors are empty.
    """
    text_list, label_list = [], []
    for data_sample in batch:
        if not data_sample["topics"]:
            continue

        try:
            # Compute both values before storing either one, so that a failure
            # in one of them can never leave the two lists with different
            # lengths (which would pair texts with the wrong labels).
            label = topic_to_idx[data_sample["topics"][0]]
            processed_text = text_transform(tokenizer(data_sample["text"]))
        except Exception as e:
            print(f"Error processing data sample: {data_sample}")
            print(f"Exception: {e}")
            continue

        text_list.append(processed_text)
        label_list.append(label)

    if not text_list:
        # pad_sequence raises on an empty list; hand back empty tensors instead.
        return (
            torch.empty((0, 0), dtype=torch.int64),
            torch.empty((0,), dtype=torch.int64),
        )

    text_list = pad_sequence(text_list, batch_first=True, padding_value=vocab["<pad>"])
    label_list = torch.tensor(label_list, dtype=torch.int64)
    return text_list, label_list


# Training batches are reshuffled every epoch; the incomplete final batch is
# dropped so every training step sees the same batch size.
data_loader_train = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    drop_last=True,
    collate_fn=collate_batch,
)

# Evaluation must not be shuffled: combined with drop_last, shuffling made
# every evaluation pass score a different random subset of the test split,
# so test metrics were not comparable between epochs.
# NOTE(review): drop_last is kept because a very small final batch could
# consist solely of samples without topics, which the batching and training
# code may not handle; revisit once that case is covered.
data_loader_test = DataLoader(
    test_dataset,
    batch_size=32,
    shuffle=False,
    drop_last=True,
    collate_fn=collate_batch,
)

In [None]:
class LSTM(nn.Module):
    """Embedding -> multi-layer LSTM -> linear projection applied at every step.

    Args:
        num_emb: Number of rows in the embedding table (vocabulary size).
        output_size: Size of the last dimension of the returned logits.
        num_layers: Number of stacked LSTM layers.
        hidden_size: Width of both the embeddings and the LSTM hidden state.
        dropout: Dropout probability between stacked LSTM layers. It is only
            applied when ``num_layers > 1``; PyTorch warns about (and ignores)
            inter-layer dropout on a single-layer LSTM.
    """

    def __init__(self, num_emb, output_size, num_layers=1, hidden_size=128, dropout=0.5):
        super().__init__()

        self.embedding = nn.Embedding(num_emb, hidden_size)
        self.lstm = nn.LSTM(
            input_size=hidden_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            # Inter-layer dropout has nothing to act on with a single layer.
            dropout=dropout if num_layers > 1 else 0.0,
        )
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, input_seq, hidden_input=None, mem_input=None):
        """Run the network over a batch of id sequences.

        Args:
            input_seq: Integer tensor of token ids, shape (batch, seq_len).
            hidden_input: Optional initial hidden state, shape
                (num_layers, batch, hidden_size). Zeros when omitted.
            mem_input: Optional initial cell state, same shape as
                ``hidden_input``. Zeros when omitted.

        Returns:
            A tuple ``(logits, hidden_output, mem_output)`` where ``logits``
            has shape (batch, seq_len, output_size) and the two states are the
            final hidden and cell states of the LSTM.
        """
        input_embs = self.embedding(input_seq)

        if hidden_input is None and mem_input is None:
            # nn.LSTM initialises both states to zeros when given no state.
            initial_state = None
        else:
            # Only one state supplied: fill in the other with zeros.
            if hidden_input is None:
                hidden_input = torch.zeros_like(mem_input)
            if mem_input is None:
                mem_input = torch.zeros_like(hidden_input)
            initial_state = (hidden_input, mem_input)

        output, (hidden_output, mem_output) = self.lstm(input_embs, initial_state)
        return self.fc(output), hidden_output, mem_output

In [None]:
# Model dimensions: the embedding table covers the whole vocabulary and the
# classifier emits one logit per known topic.
num_emb = len(vocab)
out_size = len(unique_topics)
hidden_size = 128
num_layers = 3

# Training schedule and objective.
num_epochs = 50
loss_fn = nn.CrossEntropyLoss()

# Build the network, then move its parameters to the selected device.
model = LSTM(
    num_emb=num_emb,
    output_size=out_size,
    num_layers=num_layers,
    hidden_size=hidden_size,
)
model = model.to(device)

optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [None]:
train_acc = 0
test_acc = 0

# Per-epoch history: one entry is appended to each list after every epoch.
training_loss_logger = []
test_loss_logger = []
training_acc_logger = []
test_acc_logger = []


def _last_token_logits(pred, data):
    """Select, for each sequence, the logits at its last non-padding position.

    Batches are right-padded, so position -1 of a shorter sequence is the
    network's output after reading padding rather than after its real final
    token. This assumes the ``<pad>`` id appears only as padding.

    Args:
        pred: Per-step logits of shape (batch, seq_len, num_classes).
        data: Token ids of shape (batch, seq_len) that produced ``pred``.

    Returns:
        Tensor of shape (batch, num_classes).
    """
    lengths = (data != vocab["<pad>"]).sum(dim=1)
    last_positions = (lengths - 1).clamp(min=0)
    rows = torch.arange(data.size(0), device=data.device)
    return pred[rows, last_positions]


def _run_epoch(data_loader, epoch, training):
    """Run one pass over ``data_loader`` and return (mean loss, accuracy %).

    When ``training`` is true the model is put in train mode, gradients are
    computed and the optimizer is stepped after every batch; otherwise the
    model is put in eval mode and gradient tracking is disabled.
    """
    model.train(training)
    running_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    with torch.set_grad_enabled(training):
        for batch_idx, (data, target) in enumerate(data_loader):
            if target.numel() == 0:
                # Nothing to learn from or score in a batch without labels.
                continue

            # The model lives on `device`; its inputs have to be there too,
            # otherwise the forward pass fails on CUDA/MPS.
            data = data.to(device)
            target = target.to(device)

            bs = target.shape[0]
            hidden = torch.zeros(num_layers, bs, hidden_size, device=device)
            memory = torch.zeros(num_layers, bs, hidden_size, device=device)

            pred, _, _ = model(data, hidden, memory)
            logits = _last_token_logits(pred, data)
            loss = loss_fn(logits, target)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # Calculate accuracy
            predicted = logits.argmax(dim=1)
            correct += (predicted == target).sum().item()
            total += bs
            running_loss += loss.item()
            num_batches += 1

            if training:
                print(f"Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss.item()}")

    # Guard the divisions so an empty loader reports zeros instead of raising.
    mean_loss = running_loss / max(num_batches, 1)
    accuracy = 100 * correct / max(total, 1)
    return mean_loss, accuracy


for epoch in range(num_epochs):
    train_loss, train_acc = _run_epoch(data_loader_train, epoch, training=True)
    training_loss_logger.append(train_loss)
    training_acc_logger.append(train_acc)
    print(
        f"Epoch [{epoch + 1}/{num_epochs}], Loss: {train_loss:.4f}, Training Accuracy: {train_acc:.2f}%"
    )

    # Evaluation loop
    test_loss, test_acc = _run_epoch(data_loader_test, epoch, training=False)
    test_loss_logger.append(test_loss)
    test_acc_logger.append(test_acc)
    print(
        f"Epoch [{epoch + 1}/{num_epochs}], Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.2f}%"
    )