In [None]:
import os
import time

import torch
from datasets import load_dataset

dataset = load_dataset("stanfordnlp/imdb")
train_dataset, test_dataset = dataset["train"], dataset["test"]

# Summary statistics: split sizes and the label vocabulary.
label_names = train_dataset.features["label"].names
print(f"Number of training examples: {len(train_dataset)}")
print(f"Number of test examples: {len(test_dataset)}")
print(f"Number of classes: {len(label_names)}")
print(f"Classes: {label_names}")

# One raw training record, to show its fields (e.g. text and label).
print(f"Sample: {train_dataset[0]}")

In [None]:
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from sklearn.model_selection import train_test_split

class Dataset(torch.utils.data.Dataset):
    """Map-style torch dataset turning whitespace-tokenised text into fixed-length id tensors.

    Each item is ``(token_ids, label)``. For a non-negative ``max_length``, ``token_ids``
    has exactly ``max_length`` entries: longer texts are cut, shorter ones are right-padded
    with the id of ``pad_token``. Words missing from ``vocab_to_idx`` map to the id of
    ``unk_token``.
    """

    def __init__(self, text_list, label_list, vocab_to_idx, max_length=32, pad_token="<PAD>", unk_token="<UNK>"):
        self.text_list, self.label_list = text_list, label_list
        self.vocab_to_idx = vocab_to_idx
        self.max_length = max_length
        self.pad_token, self.unk_token = pad_token, unk_token

    def __len__(self):
        # The dataset size is taken from the label list.
        return len(self.label_list)

    def _encode(self, text):
        """Split ``text`` on whitespace and map each word to its id (unknown words -> unk id)."""
        vocab = self.vocab_to_idx
        ids = []
        for word in text.split():
            if word in vocab:
                ids.append(vocab[word])
            else:
                ids.append(vocab[self.unk_token])
        return ids

    def _fit_length(self, ids):
        """Truncate or right-pad ``ids`` to ``max_length`` entries."""
        overflow = len(ids) - self.max_length
        if overflow > 0:
            return ids[: self.max_length]
        # Padding branch (also taken when the length already matches, adding nothing).
        return ids + [self.vocab_to_idx[self.pad_token]] * (-overflow)

    def __getitem__(self, index):
        raw_text, raw_label = self.text_list[index], self.label_list[index]
        token_ids = self._fit_length(self._encode(raw_text))
        return torch.tensor(token_ids), torch.tensor(raw_label)

# Build train / validation / test torch datasets from the raw Hugging Face splits.
# Read from `dataset` (loaded in the first cell) rather than `train_dataset` / `test_dataset`,
# because this cell rebinds those names to the wrapped torch datasets; reading them here
# would break on a second run of the cell.
# NOTE(review): `vocab_to_idx` is not defined in any visible cell — confirm it is built
# (ideally from the training split only) before this cell runs.
raw_train = dataset["train"]
raw_test = dataset["test"]

train_list = list(raw_train["text"])
full_train_label_list = list(raw_train["label"])

# split the training set into training and validation set (fixed seed for reproducibility)
train_text_list, val_text_list, train_label_list, val_label_list = train_test_split(
    train_list, full_train_label_list, test_size=0.2, random_state=42
)

train_dataset = Dataset(train_text_list, train_label_list, vocab_to_idx)
val_dataset = Dataset(val_text_list, val_label_list, vocab_to_idx)

# test dataset: the untouched Hugging Face test split
test_text_list = list(raw_test["text"])
test_label_list = list(raw_test["label"])
test_dataset = Dataset(test_text_list, test_label_list, vocab_to_idx)

print("Number of training examples: {}".format(len(train_dataset)))
print("Number of validation examples: {}".format(len(val_dataset)))
print("Number of test examples: {}".format(len(test_dataset)))


In [None]:
class LSTMClassifier(torch.nn.Module):
    """Bidirectional (optionally stacked) LSTM binary classifier over token-id sequences.

    Embeds token ids, runs a bidirectional LSTM, and maps the concatenated final forward
    and backward hidden states of the top layer to a single sigmoid probability.

    Args:
        vocab_size: number of rows in the embedding table.
        embedding_dim: size of each token embedding.
        hidden_dim: hidden size of each LSTM direction.
        num_layers: number of stacked LSTM layers.
        padding_idx: optional id of the padding token; that embedding row is zero-initialised
            and receives no gradient. Defaults to None (no special row), as before.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers, padding_idx=None):
        super().__init__()
        self.embedding = torch.nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        self.lstm = torch.nn.LSTM(embedding_dim, hidden_dim, batch_first=True, bidirectional=True, num_layers=num_layers)
        # Forward and backward states are concatenated, hence 2 * hidden_dim features.
        self.fc = torch.nn.Linear(hidden_dim * 2, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        """Return the predicted probability of the positive class for each sequence.

        Args:
            x: token ids of shape (batch, seq_len) (the LSTM is batch_first).

        Returns:
            Tensor of shape (batch, 1) with values in [0, 1].
        """
        embedded = self.embedding(x)
        # h_n has shape (num_layers * 2, batch, hidden_dim). For the top layer, index -2 is
        # the forward direction after reading the whole sequence left-to-right and index -1
        # is the backward direction after reading it right-to-left. Taking output[:, -1, :]
        # instead would give the backward direction after seeing only the last token.
        # NOTE(review): if inputs are right-padded, the forward state also runs over the pad
        # steps; packing with true lengths would avoid that.
        _, (h_n, _) = self.lstm(embedded)
        final_state = torch.cat((h_n[-2], h_n[-1]), dim=1)
        return self.sigmoid(self.fc(final_state))

In [None]:
def _train_one_epoch(model, train_dataloader, criterion, optimizer, device, epoch):
    """Run one optimisation pass over `train_dataloader`, showing the running batch loss."""
    p_bar = tqdm(train_dataloader, desc="Epoch {}".format(epoch + 1))
    model.train()
    for inputs, labels in p_bar:
        optimizer.zero_grad()
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs = model(inputs)
        # Targets are reshaped to (batch, 1) floats to match a single-output head.
        loss = criterion(outputs, labels.unsqueeze(1).float())
        loss.backward()
        optimizer.step()
        p_bar.set_postfix({"Loss": loss.item()})


def _predict(model, dataloader, device):
    """Return (model scores, true labels) as Python lists over every batch of `dataloader`."""
    model.eval()
    preds = []
    targets = []
    with torch.no_grad():
        for inputs, labels in dataloader:
            outputs = model(inputs.to(device))
            preds.extend(outputs.squeeze(1).tolist())
            targets.extend(labels.tolist())
    return preds, targets


def train_model(model, train_dataloader, val_dataloader, num_epochs, criterion, optimizer, device, model_name, save_dir="models"):
    """Train `model`, evaluating after each epoch and checkpointing the best validation accuracy.

    Args:
        model: module whose output for a batch has shape (batch, 1).
        train_dataloader / val_dataloader: iterables of (inputs, labels) batches.
        num_epochs: number of training epochs.
        criterion: loss taking (outputs, float targets of shape (batch, 1)).
        optimizer: optimizer over the model parameters.
        device: device the batches are moved to.
        model_name: prefix of the checkpoint file name.
        save_dir: directory for the checkpoint (created if missing). Defaults to "models".

    Returns:
        (best_accuracy, best_accuracy_epoch) with a 1-based epoch, or (0, 0) if no epoch ran.
    """
    checkpoint_path = os.path.join(save_dir, "{}_model_imdb.pt".format(model_name))
    best_accuracy = 0
    best_accuracy_epoch = 0
    for epoch in range(num_epochs):
        _train_one_epoch(model, train_dataloader, criterion, optimizer, device, epoch)

        # evaluate on validation set
        val_preds, val_labels = _predict(model, val_dataloader, device)
        # NOTE(review): val_preds are raw model scores, not 0/1 labels; compute_metrics
        # (defined outside this view) must threshold them — confirm.
        accuracy, precision, recall, f1 = compute_metrics(val_preds, val_labels)
        # The first epoch always checkpoints, so a file exists and the returned epoch is a
        # real epoch even if validation accuracy is 0.
        if best_accuracy_epoch == 0 or accuracy > best_accuracy:
            best_accuracy = accuracy
            best_accuracy_epoch = epoch + 1
            # torch.save does not create missing parent directories.
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            torch.save(model.state_dict(), checkpoint_path)
        print(f"Epoch {epoch + 1} Validation Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")

    return best_accuracy, best_accuracy_epoch

In [None]:
# LSTM
# Seed torch's global RNG before building the model so weight initialisation (and any
# shuffling that draws from the global generator) is reproducible on Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# NOTE(review): vocab_size, embedding_dim, hidden_dim, device, train_dataloader and
# val_dataloader are not defined in any visible cell — confirm they are set up earlier.
model = LSTMClassifier(vocab_size, embedding_dim, hidden_dim, num_layers=2)
model.to(device)
# The model ends in a sigmoid, so plain BCE on probabilities is used.
criterion = torch.nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
num_epochs = 10
print("*" * 80)
print("LSTM")

# perf_counter is monotonic and intended for measuring elapsed intervals.
t0 = time.perf_counter()
best_accuracy, best_accuracy_epoch = train_model(model, train_dataloader, val_dataloader, num_epochs, criterion, optimizer, device, "bidirectional_lstm_2")
t1 = time.perf_counter()
print(f"Best validation accuracy for LSTM: {best_accuracy:.4f} at epoch {best_accuracy_epoch}")
print(f"Training took {t1 - t0:.4f} seconds")
print("*" * 80)