In [9]:
import numpy as np
import torch
from torch import nn, optim
from dataset import get_twit_company_dataloaders, split_sentence
from model import LSTMTwitClassifier
import torch.nn.functional as F

# text, label = next(iter(dataloader_train))

In [None]:
# Configuration + training cell: builds the dataloaders and the LSTM classifier,
# optionally attaches Weights & Biases, then trains with Adam for `epochs_cnt` epochs.
use_wandb = True

lr = 0.0001
embedding_size = 100
hidden_size = 100
epochs_cnt = 50
embeddings = "random"
seed = 42        # fixed so that Restart & Run All reproduces the same run
log_every = 100  # iterations between progress reports; also the running-mean window size

# Seed NumPy and PyTorch before any data or model is created.
# NOTE(review): if `dataset` / `model` draw from Python's `random` module, that
# generator would need seeding too -- confirm against those modules.
np.random.seed(seed)
torch.manual_seed(seed)

dataset_train, dataloader_train, dataset_test, dataloader_test = get_twit_company_dataloaders(
    embedding_dim=embedding_size, embedding=embeddings)

model = LSTMTwitClassifier(4, embedding_dim=embedding_size, hidden_dim=hidden_size)

if use_wandb:
    # Imported lazily so the notebook still runs without wandb when use_wandb is False.
    import wandb

    wandb.init(project='twit_classification', entity='ars860')

    # Record the hyper-parameters of this run.
    config = wandb.config
    config.loss = "BCE"
    config.optimizer = "Adam"
    config.learning_rate = lr
    config.hidden_size = hidden_size
    config.embedding_size = embedding_size
    config.embeddings = embeddings
    config.epochs = epochs_cnt
    config.seed = seed

    wandb.watch(model)

# NOTE(review): softmax followed by BCELoss is an unusual pairing for what looks
# like single-label classification over 4 classes; nn.CrossEntropyLoss on the raw
# model output is the conventional choice. Left unchanged so runs stay comparable.
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=lr)

# Circular buffer holding the most recent `log_every` losses. Zero-initialised
# (instead of np.empty) so it never contains arbitrary memory; it is always
# completely overwritten by the time the first report of an epoch is produced.
losses = np.zeros(log_every)

model.train()  # make training mode explicit
for epoch in range(epochs_cnt):
    # Per-iteration losses of the current epoch, averaged once the epoch ends.
    epoch_loss = np.zeros(len(dataloader_train))

    for i, (txt, company) in enumerate(dataloader_train):
        model.zero_grad()

        prediction = model(txt)
        prediction = F.softmax(prediction, dim=0)

        # The target is flattened to line up with the prediction tensor.
        loss = criterion(prediction, company.view(-1))

        loss.backward()
        optimizer.step()

        # .item() already returns a plain Python float detached from the graph.
        loss = loss.item()
        losses[i % log_every] = loss
        epoch_loss[i] = loss

        if (i + 1) % log_every == 0:
            print(
                f"Epoch {epoch + 1}/{epochs_cnt}, iter: {i + 1}/{len(dataloader_train)}, mean loss: {np.mean(losses)}")
            if use_wandb:
                wandb.log({"loss": np.mean(losses)})

    if use_wandb:
        wandb.log({"epoch_loss": np.mean(epoch_loss)})

# [model.get_word_embedding(word) for word in "hello_world".split(' ')]

wandb: wandb version 0.12.4 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade


Epoch 1/50, iter: 100/3413, mean loss: 0.5202424854040146
Epoch 1/50, iter: 200/3413, mean loss: 0.11388149753911421
Epoch 1/50, iter: 300/3413, mean loss: 0.04468637207639404
Epoch 1/50, iter: 400/3413, mean loss: 0.019677399628562853
Epoch 1/50, iter: 500/3413, mean loss: 0.05077616944734473
Epoch 1/50, iter: 600/3413, mean loss: 0.012700586859718897
Epoch 1/50, iter: 700/3413, mean loss: 0.015684741334989667
Epoch 1/50, iter: 800/3413, mean loss: 0.016751434097241145
Epoch 1/50, iter: 900/3413, mean loss: 0.01777815209039545
Epoch 1/50, iter: 1000/3413, mean loss: 0.6576214437345334
Epoch 1/50, iter: 1100/3413, mean loss: 1.0616246896982193
Epoch 1/50, iter: 1200/3413, mean loss: 0.5431772115826606
Epoch 1/50, iter: 1300/3413, mean loss: 0.45224565982818604
Epoch 1/50, iter: 1400/3413, mean loss: 0.27440092399716376
Epoch 1/50, iter: 1500/3413, mean loss: 0.12184399755671621
Epoch 1/50, iter: 1600/3413, mean loss: 0.12785146688111126
Epoch 1/50, iter: 1700/3413, mean loss: 0.1007484

In [None]:
# Evaluation on the training set: counts correct predictions and how often each
# of the four classes is predicted, then stores the results in the W&B summary.
print("Testing on train")

correct = 0
predictions_cnt = [0, 0, 0, 0]  # number of samples assigned to each class index

# Put the model in inference mode so that layers such as dropout or batch norm
# (if the model has any) behave deterministically during evaluation.
model.eval()

with torch.no_grad():
    for i, (txt, company) in enumerate(dataloader_train):
        prediction = model(txt)
        prediction = F.softmax(prediction, dim=0)

        # Compute the predicted class once and convert it to a plain int so it
        # can be compared and used as a list index directly.
        predicted_class = torch.argmax(prediction).item()
        if predicted_class == torch.argmax(company).item():
            correct += 1

        predictions_cnt[predicted_class] += 1

        if i % 100 == 0:
            print(f"Iter: {i}/{len(dataloader_train)}")

# len(dataloader) is the number of batches; this equals the number of samples
# only if every batch holds a single sample, which the single argmax per
# iteration above already presumes -- TODO confirm the dataloader's batch size.
train_accuracy = correct / len(dataloader_train)
print(f"Accuracy {train_accuracy}")

if use_wandb:
    wandb.run.summary.train_accuracy = train_accuracy
    wandb.run.summary.classified_as = {
        "apple": predictions_cnt[0],
        "google": predictions_cnt[1],
        "microsoft": predictions_cnt[2],
        "twitter": predictions_cnt[3]
    }
    wandb.finish()

In [None]:
# Evaluation on the held-out test set: counts correct predictions and how often
# each of the four classes is predicted, then prints the accuracy.
# NOTE(review): the result is only printed here; the train-evaluation cell calls
# wandb.finish(), so nothing from this cell reaches the W&B run.
print("Testing on test")

correct = 0
predictions_cnt = [0, 0, 0, 0]  # number of samples assigned to each class index

# Put the model in inference mode so that layers such as dropout or batch norm
# (if the model has any) behave deterministically during evaluation.
model.eval()

with torch.no_grad():
    for i, (txt, company) in enumerate(dataloader_test):
        prediction = model(txt)
        prediction = F.softmax(prediction, dim=0)

        # Compute the predicted class once and convert it to a plain int so it
        # can be compared and used as a list index directly.
        predicted_class = torch.argmax(prediction).item()
        if predicted_class == torch.argmax(company).item():
            correct += 1

        predictions_cnt[predicted_class] += 1

        if i % 100 == 0:
            print(f"Iter: {i}/{len(dataloader_test)}")

# len(dataloader) is the number of batches; this equals the number of samples
# only if every batch holds a single sample, which the single argmax per
# iteration above already presumes -- TODO confirm the dataloader's batch size.
test_accuracy = correct / len(dataloader_test)
print(f"Accuracy {test_accuracy}")