In [9]:
import numpy as np
import torch
from torch import nn, optim
from dataset import get_twit_company_dataloaders, split_sentence
from model import LSTMTwitClassifier
import torch.nn.functional as F

# text, label = next(iter(dataloader_train))

In [13]:

# --- Configuration -----------------------------------------------------------
use_wandb = False

seed = 42          # fixed seed so repeated runs of this cell are comparable
lr = 0.0001
embedding_size = 100
hidden_size = 100
epochs_cnt = 50
embeddings = "random"
log_every = 100    # iterations between progress reports; also the running-mean window

# Seed before the data loaders and the model are built so that anything they
# draw from the NumPy / PyTorch global generators is repeatable.
# NOTE(review): assumes the project helpers use these generators; Python's
# built-in `random` module is not seeded here -- TODO confirm it is unused.
np.random.seed(seed)
torch.manual_seed(seed)

# --- Data and model ----------------------------------------------------------
dataset_train, dataloader_train, dataset_test, dataloader_test = get_twit_company_dataloaders(embedding_dim=embedding_size, embedding=embeddings)

model = LSTMTwitClassifier(4, embedding_dim=embedding_size, hidden_dim=hidden_size)

# --- Optional experiment tracking --------------------------------------------
if use_wandb:
    # Imported lazily so the notebook runs without wandb installed.
    import wandb

    wandb.init(project='twit_classification', entity='ars860')

    config = wandb.config
    config.loss = "BCE"
    config.optimizer = "Adam"
    config.learning_rate = lr
    config.hidden_size = hidden_size
    config.embedding_size = embedding_size
    config.embeddings = embeddings
    config.epochs = epochs_cnt
    config.seed = seed

    wandb.watch(model)

# --- Training ----------------------------------------------------------------
# NOTE(review): softmax followed by BCELoss is kept to match the logged config.
# For a single-label 4-class problem, nn.CrossEntropyLoss on the raw model
# output is the usual and numerically safer choice -- confirm the model's
# output shape before switching.
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=lr)

model.train()

# Ring buffer with the most recent `log_every` losses. It is deliberately not
# reset between epochs, so the reported mean is a rolling one.
losses = np.zeros(log_every)
for epoch in range(epochs_cnt):
    epoch_loss = np.zeros(len(dataloader_train))

    for i, (txt, company) in enumerate(dataloader_train):
        optimizer.zero_grad()

        prediction = model(txt)
        # Normalise over the first axis.
        # NOTE(review): presumably the model returns one 1-D score vector per
        # sample, matching `company.view(-1)` -- verify against the model.
        prediction = F.softmax(prediction, dim=0)

        loss = criterion(prediction, company.view(-1))

        loss.backward()
        optimizer.step()

        # `.item()` already returns a detached Python float.
        loss_value = loss.item()
        losses[i % log_every] = loss_value
        epoch_loss[i] = loss_value

        if (i + 1) % log_every == 0:
            running_mean = np.mean(losses)
            print(
                f"Epoch {epoch + 1}/{epochs_cnt}, iter: {i + 1}/{len(dataloader_train)}, mean loss: {running_mean}")
            if use_wandb:
                wandb.log({"loss": running_mean})

    if use_wandb:
        wandb.log({"epoch_loss": np.mean(epoch_loss)})

# [model.get_word_embedding(word) for word in "hello_world".split(' ')]

KeyboardInterrupt: 

In [11]:
# Measure accuracy on the training split and count how often each class is
# predicted.
print("Testing on train")

correct = 0
evaluated = 0                      # iterations actually scored
predictions_cnt = [0, 0, 0, 0]     # one counter per class index

# Score in inference mode so layers such as dropout behave deterministically,
# and restore the previous mode afterwards, even if the cell is interrupted.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        for i, (txt, company) in enumerate(dataloader_train):
            # argmax is unchanged by softmax, so the raw model output is used.
            # NOTE(review): argmax over the flattened tensors assumes one
            # sample per iteration -- confirm the loader's batch size is 1.
            predicted_class = torch.argmax(model(txt)).item()
            target_class = torch.argmax(company).item()

            if predicted_class == target_class:
                correct += 1

            predictions_cnt[predicted_class] += 1
            evaluated += 1

            if i % 100 == 0:
                print(f"Iter: {i}/{len(dataloader_train)}")
finally:
    model.train(was_training)

# An empty loader yields NaN instead of raising ZeroDivisionError.
train_accuracy = correct / evaluated if evaluated else float("nan")
print(f"Accuracy {train_accuracy}")

if use_wandb:
    wandb.run.summary.train_accuracy = train_accuracy
    wandb.run.summary.classified_as = {
        "apple": predictions_cnt[0],
        "google": predictions_cnt[1],
        "microsoft": predictions_cnt[2],
        "twitter": predictions_cnt[3]
    }
    # NOTE(review): the run is closed here, so nothing executed after this
    # cell can be logged to it.
    wandb.finish()

Testing on train
Iter: 0/3413
Iter: 100/3413
Iter: 200/3413
Iter: 300/3413
Iter: 400/3413
Iter: 500/3413
Iter: 600/3413
Iter: 700/3413
Iter: 800/3413
Iter: 900/3413
Iter: 1000/3413
Iter: 1100/3413
Iter: 1200/3413
Iter: 1300/3413
Iter: 1400/3413
Iter: 1500/3413
Iter: 1600/3413
Iter: 1700/3413
Iter: 1800/3413
Iter: 1900/3413
Iter: 2000/3413
Iter: 2100/3413
Iter: 2200/3413
Iter: 2300/3413
Iter: 2400/3413
Iter: 2500/3413
Iter: 2600/3413
Iter: 2700/3413
Iter: 2800/3413
Iter: 2900/3413
Iter: 3000/3413
Iter: 3100/3413
Iter: 3200/3413
Iter: 3300/3413
Iter: 3400/3413
Accuracy 0.2566656900087899


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
loss,0.06627
_runtime,1180.0
_timestamp,1633522701.0
_step,1749.0
epoch_loss,0.14353
train_accuracy,0.25667


0,1
loss,▂▃▆▁▁▃▁▂▃█▂▁▇▂▂▅▂▃▂▂▂▃▁▁▂▁▂▂▆▁▁▄▁▂▅▂▃▅▂▃
_runtime,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
epoch_loss,▆█▇▆▃▃▃▃▃▃▃▂▃▃▃▂▂▂▂▂▂▃▃▂▁▁▁▁▂▂▂▃▁▁▁▂▂▂▄▂


In [12]:
# Measure accuracy on the held-out test split and count how often each class
# is predicted.
print("Testing on test")

correct = 0
evaluated = 0                      # iterations actually scored
predictions_cnt = [0, 0, 0, 0]     # one counter per class index

# Score in inference mode so layers such as dropout behave deterministically,
# and restore the previous mode afterwards, even if the cell is interrupted.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        for i, (txt, company) in enumerate(dataloader_test):
            # argmax is unchanged by softmax, so the raw model output is used.
            # NOTE(review): argmax over the flattened tensors assumes one
            # sample per iteration -- confirm the loader's batch size is 1.
            predicted_class = torch.argmax(model(txt)).item()
            target_class = torch.argmax(company).item()

            if predicted_class == target_class:
                correct += 1

            predictions_cnt[predicted_class] += 1
            evaluated += 1

            if i % 100 == 0:
                print(f"Iter: {i}/{len(dataloader_test)}")
finally:
    model.train(was_training)

# An empty loader yields NaN instead of raising ZeroDivisionError.
test_accuracy = correct / evaluated if evaluated else float("nan")
print(f"Accuracy {test_accuracy}")

Testing on test
Iter: 0/342
Iter: 100/342
Iter: 200/342
Iter: 300/342
Accuracy 0.3128654970760234
