In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchtext.data.functional import to_map_style_dataset
from torchtext.data.utils import get_tokenizer
from torchtext.datasets import IMDB
from torchtext.vocab import build_vocab_from_iterator

In [2]:
!pip install torchtext

Defaulting to user installation because normal site-packages is not writeable
Collecting torchtext
  Downloading torchtext-0.18.0-cp38-cp38-manylinux1_x86_64.whl (2.0 MB)
[2K     [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0mm eta [36m0:00:01[0m[36m0:00:01[0m0m
[?25hCollecting torch>=2.3.0
  Downloading torch-2.4.1-cp38-cp38-manylinux1_x86_64.whl (797.1 MB)
[2K     [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m797.1/797.1 MB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0mm eta [36m0:00:01[0m[36m0:00:04[0m
Collecting triton==3.0.0
  Downloading triton-3.0.0-1-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (209.4 MB)
[2K     [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m209.4/209.4 MB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0mm eta [36m0:00:01[0m[36m0:00:02[0m
Collecting nvidia-nccl-cu12==2.20.5
  Downloading nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014

In [None]:

# Hyperparameters
BATCH_SIZE = 64
EMBEDDING_DIM = 100
HIDDEN_DIM = 128
NUM_CLASSES = 2  # binary sentiment: negative / positive
LEARNING_RATE = 0.001
NUM_EPOCHS = 10

# Seed torch's global RNG so weight initialisation and DataLoader shuffling
# are reproducible across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Tokenizer and vocabulary
# torchtext's built-in English tokenizer: maps a raw review string to a list of token strings.
tokenizer = get_tokenizer(tokenizer="basic_english")

def yield_tokens(data_iter):
    """Lazily tokenize every text in ``data_iter``.

    Args:
        data_iter: iterable of ``(label, text)`` pairs; the label is ignored.

    Yields:
        ``tokenizer(text)`` for each pair, in input order.
    """
    yield from (tokenizer(review) for _label, review in data_iter)

# Load both IMDB splits and materialise them as map-style datasets
# (indexable, with len()), so they can be iterated more than once.
train_iter, test_iter = (to_map_style_dataset(split) for split in IMDB())

# Build the vocabulary from the training split only; special tokens come first.
special_tokens = ["<unk>", "<pad>", "<bos>", "<eos>"]
vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=special_tokens)

# Tokens never seen during vocab construction fall back to the <unk> id.
unk_index = vocab["<unk>"]
vocab.set_default_index(unk_index)

# Text preprocessing function
def preprocess(text):
    """Convert a raw text string into a list of vocabulary ids.

    Args:
        text: raw review string.

    Returns:
        One ``vocab`` id per token produced by ``tokenizer``, in order.
    """
    token_ids = []
    for token in tokenizer(text):
        token_ids.append(vocab[token])
    return token_ids

# Data loading

# Depending on the torchtext release, IMDB yields labels either as ints 1 (neg) / 2 (pos)
# or as the strings "neg" / "pos". nn.CrossEntropyLoss with NUM_CLASSES = 2 needs
# 0-based class indices, so map every known form onto 0 / 1.
IMDB_LABEL_TO_INDEX = {1: 0, 2: 1, "neg": 0, "pos": 1}


def collate_batch(batch):
    """Turn a list of ``(label, text)`` pairs into padded tensors for the LSTM.

    Args:
        batch: sequence of ``(label, text)`` pairs from the IMDB dataset.

    Returns:
        labels: int64 tensor of shape (batch,) holding class indices 0 / 1.
        padded_text: int64 tensor of shape (batch, max_tokens) of vocab ids,
            right-padded with ``vocab["<pad>"]``.
        text_lengths: int64 tensor of shape (batch,) with the number of
            tokens (not characters) in each text.

    Raises:
        KeyError: if a label is not one of the forms in ``IMDB_LABEL_TO_INDEX``.
    """
    raw_labels, raw_texts = zip(*batch)
    labels = torch.tensor([IMDB_LABEL_TO_INDEX[label] for label in raw_labels], dtype=torch.long)

    # A text that tokenizes to nothing would have length 0, which
    # pack_padded_sequence rejects; represent it by a single <unk> token instead.
    token_ids = [
        torch.tensor(preprocess(text) or [vocab["<unk>"]], dtype=torch.long)
        for text in raw_texts
    ]
    # Lengths must count tokens after preprocessing. Counting characters of the raw
    # string makes padding and packing read far past each real sequence.
    text_lengths = torch.tensor([len(ids) for ids in token_ids], dtype=torch.long)

    padded_text = nn.utils.rnn.pad_sequence(token_ids, batch_first=True, padding_value=vocab["<pad>"])
    return labels, padded_text, text_lengths

# Shuffle the training set every epoch. The torchtext IMDB split is read label
# directory by label directory, so unshuffled batches come in long single-class runs.
# Test order does not affect accuracy, so that loader stays sequential.
train_loader = DataLoader(train_iter, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
test_loader = DataLoader(test_iter, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)


# RNN model for classification
class RNNClassifier(nn.Module):
    """Single-layer LSTM classifier over right-padded batches of token ids.

    Args:
        vocab_size: number of rows in the embedding table.
        embedding_dim: size of each token embedding.
        hidden_dim: LSTM hidden-state size.
        num_classes: number of output logits.
        padding_idx: embedding row reserved for padding. Defaults to the
            notebook-level ``vocab["<pad>"]``, so existing calls are unchanged.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_classes, padding_idx=None):
        super().__init__()
        if padding_idx is None:
            padding_idx = vocab["<pad>"]
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, x, lengths):
        """Return class logits of shape (batch, num_classes).

        Args:
            x: (batch, max_len) int64 token ids, right-padded.
            lengths: true token count of each row (1-D tensor or list).
        """
        embedded = self.embedding(x)
        # pack_padded_sequence needs lengths as a 1-D int64 tensor on the CPU.
        cpu_lengths = torch.as_tensor(lengths, dtype=torch.int64).cpu()
        packed_embedded = nn.utils.rnn.pack_padded_sequence(
            embedded, cpu_lengths, batch_first=True, enforce_sorted=False
        )
        # With packed input, h_n holds each sequence's hidden state at its own last
        # real token, restored to the original batch order. Reading output[:, -1, :]
        # after unpacking would instead hit zero padding for all shorter sequences.
        _, (h_n, _) = self.lstm(packed_embedded)
        return self.fc(h_n[-1])


# Train the model
model = RNNClassifier(len(vocab), EMBEDDING_DIM, HIDDEN_DIM, NUM_CLASSES)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)


def train_one_epoch(loader):
    """Run one optimisation pass of ``model`` over ``loader``.

    Returns:
        The sum of the per-batch loss values (as Python floats).
    """
    model.train()
    running_loss = 0
    for batch_labels, batch_text, batch_lengths in loader:
        optimizer.zero_grad()
        batch_loss = criterion(model(batch_text, batch_lengths), batch_labels)
        batch_loss.backward()
        optimizer.step()
        running_loss += batch_loss.item()
    return running_loss


for epoch in range(NUM_EPOCHS):
    epoch_loss = train_one_epoch(train_loader)
    print(f"Epoch {epoch+1}/{NUM_EPOCHS}, Loss: {epoch_loss / len(train_loader):.4f}")


# Evaluate the model (example)
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for labels, text, lengths in test_loader:
        outputs = model(text, lengths)
        # Predicted class = index of the largest logit. The legacy `.data`
        # access is unnecessary inside no_grad.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# Report NaN instead of raising ZeroDivisionError if the test loader is empty.
accuracy = 100 * correct / total if total else float("nan")
print(f"Test Accuracy: {accuracy:.2f}%")