In [1]:
!git clone https://github.com/Mateusz-Wojciechowski/sentimentAnalysis.git
%cd sentimentAnalysis

fatal: destination path 'sentimentAnalysis' already exists and is not an empty directory.
/content/sentimentAnalysis


In [2]:
!pip install 'portalocker>=2.0.0'



In [3]:
import torch
from torchtext.datasets import IMDB
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from SentimentModel import SentimentModel
import torch.optim as optim
import torch.nn as nn
import numpy as np
import random

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Tokenizer
tokenizer = get_tokenizer('basic_english')

def yield_tokens(data_iter):
    for text in data_iter:
        yield tokenizer(text)

def process_text(text, vocab):
    return torch.tensor(vocab(tokenizer(text)), dtype=torch.long).to(device)

def collate_batch(batch):
    label_list, text_list = [], []
    for label, text in batch:
        label_tensor = torch.tensor([label-1], dtype=torch.float).to(device)
        processed_text = process_text(text, vocab)
        label_list.append(label_tensor)
        text_list.append(processed_text)
    return torch.stack(label_list).to(device), pad_sequence(text_list, padding_value=vocab["<pad>"], batch_first=True).to(device)

def calculate_accuracy(preds, y):
    preds = torch.sigmoid(preds)
    rounded_preds = torch.round(preds)
    correct = (rounded_preds == y).float()
    accuracy = correct.sum() / len(correct)
    return accuracy

train_data = list(IMDB(split='train'))
random.shuffle(train_data)

vocab = build_vocab_from_iterator(yield_tokens(data_iter for label, data_iter in train_data), specials=["<unk>", "<pad>"])
vocab.set_default_index(vocab["<unk>"])

train_loader = DataLoader(train_data, batch_size=8, shuffle=False, collate_fn=collate_batch)

test_data = list(IMDB(split='test'))
random.shuffle(test_data)
test_loader = DataLoader(test_data, batch_size=32, shuffle=False, collate_fn=collate_batch)

d_model = 512
num_heads = 8
max_seq_len = 5000
d_ff = 2048
learning_rate = 0.001
num_classes = 1
vocab_size = len(vocab)
num_epochs = 100

model = SentimentModel(d_model, d_ff, num_heads, max_seq_len, num_classes, vocab_size)
model.to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
loss_fn = nn.BCEWithLogitsLoss()

for epoch in range(num_epochs):
    print(f"Epoch: {epoch + 1}")
    total_loss = 0
    total_accuracy = 0
    total_examples = 0
    i = 0

    model.train()
    for labels, sequences in train_loader:
        if i% 100 == 0:
          print(f"batch {i}")

        i +=1
        output = model(sequences)
        loss = loss_fn(output, labels)
        accuracy = calculate_accuracy(output, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_accuracy += accuracy.item()
        total_examples += labels.size(0)

    print(f"Loss in epoch {epoch + 1} is {total_loss}")
    print(f"Accuracy in epoch {epoch + 1} is {total_accuracy / total_examples}")


Using device: cuda
Epoch: 1
batch 0
batch 100
batch 200
batch 300
batch 400
batch 500
batch 600
batch 700
batch 800
batch 900
batch 1000
batch 1100
batch 1200
batch 1300
batch 1400
batch 1500
batch 1600
batch 1700
batch 1800
batch 1900
batch 2000
batch 2100
batch 2200
batch 2300
batch 2400
batch 2500
batch 2600
batch 2700
batch 2800
batch 2900
batch 3000
batch 3100
Loss in epoch 1 is 1867.0120996832848
Accuracy in epoch 1 is 0.08358
Epoch: 2
batch 0
batch 100
batch 200
batch 300
batch 400
batch 500
batch 600
batch 700
batch 800
batch 900
batch 1000
batch 1100
batch 1200
batch 1300
batch 1400
batch 1500
batch 1600
batch 1700
batch 1800
batch 1900
batch 2000
batch 2100
batch 2200
batch 2300
batch 2400
batch 2500
batch 2600
batch 2700
batch 2800
batch 2900
batch 3000
batch 3100
Loss in epoch 2 is 1201.5547805679962
Accuracy in epoch 2 is 0.104435
Epoch: 3
batch 0
batch 100
batch 200
batch 300
batch 400
batch 500
batch 600
batch 700
batch 800
batch 900
batch 1000
batch 1100
batch 1200
batc

KeyboardInterrupt: ignored