In [3]:
import numpy as np
import torch
from torch import nn, optim
from dataset import get_twit_company_dataloaders, get_twit_sentiment_dataloaders, get_twit_company_sentiment_dataloaders
from model import LSTMTwitClassifier
import torch.nn.functional as F

# text, label = next(iter(dataloader_train))

In [4]:

use_wandb = True

lr = 0.0005
embedding_size = 100
hidden_size = 100
epochs_cnt = 50
embeddings = "random"
lstm_layers = 1
dropout = 0.5
task = "text2sentiment" # text2company
use_company_info = True

get_dataloaders = get_twit_company_dataloaders if task == "text2company" else\
get_twit_sentiment_dataloaders if not use_company_info else get_twit_company_sentiment_dataloaders

dataset_train, dataloader_train, dataset_test, dataloader_test = get_twit_company_dataloaders(embedding_dim=embedding_size, embedding=embeddings)

model = LSTMTwitClassifier(4, embedding_dim=embedding_size, hidden_dim=hidden_size, dropout=dropout, lstm_layers=lstm_layers)

if use_wandb:
    import wandb

    wandb.init(project=task + '_twit_classification', entity='ars860')

    config = wandb.config
    config.loss = "BCE"
    config.optimizer = "Adam"
    config.learning_rate = lr
    config.hidden_size = hidden_size
    config.embedding_size = embedding_size
    config.embeddings = embeddings
    config.epochs = epochs_cnt
    config.dropout = dropout
    config.lstm_layers = lstm_layers
    config.stem = "snowballstemmer"

    if task == "text2sentiment":
        config.use_company_info = use_company_info

    wandb.watch(model)

criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=lr)

def loss_on_test():
    correct = 0
    losses = np.zeros(len(dataloader_test))

    with torch.no_grad():
        model.eval()
        for i, (*args, target) in enumerate(dataloader_test):
            prediction = model(*args)
            prediction = F.softmax(prediction, dim=0)

            losses[i] = F.binary_cross_entropy(prediction, target.view(-1))
            if torch.argmax(prediction) == torch.argmax(target):
                correct += 1

            # predictions_cnt[torch.argmax(prediction)] += 1

            # if i % 100 == 0:
            #     print(f"Iter: {i}/{len(dataloader_test)}")

    model.train()
    if use_wandb:
        wandb.log({"test_loss": np.mean(losses), "test_accuracy": correct / len(dataloader_test)})

losses = np.empty(100)
model.train()
for epoch in range(epochs_cnt):
    epoch_loss = np.zeros(len(dataloader_train))

    for i, (*args, target) in enumerate(dataloader_train):
        model.zero_grad()

        prediction = model(*args)
        prediction = F.softmax(prediction, dim=0)

        loss = criterion(prediction, target.view(-1))

        loss.backward()
        optimizer.step()

        loss = loss.detach().item()
        losses[i % 100] = loss
        epoch_loss[i] = loss

        if (i + 1) % 100 == 0:
            print(
                f"Epoch {epoch + 1}/{epochs_cnt}, iter: {i + 1}/{len(dataloader_train)}, mean loss: {np.mean(losses)}")
            if use_wandb:
                wandb.log({"loss": np.mean(losses)})

    if use_wandb:
        wandb.log({"epoch_loss": np.mean(epoch_loss)})
        loss_on_test()

# [model.get_word_embedding(word) for word in "hello_world".split(' ')]

Tweet ignored due to unreadability: http://t.co/48emAEID 
Tweet ignored due to unreadability: http://t.co/Izh7KaiU 
Tweet ignored due to unreadability: http://t.co/e5ClGzsI 
Tweet ignored due to unreadability: http://t.co/18xg3ivo! 
Tweet ignored due to unreadability: Поиск от 
Tweet ignored due to unreadability: サムスン電子のスマートフォン新機種「ギャラクシー・ネクサス」、グーグルの基本ソフト（ＯＳ）「アンドロイド」最新版を搭載。「クラウド」活用、音声認識やカメラの機能も向上させた戦略モデル 
Tweet ignored due to unreadability: Новите 
Tweet ignored due to unreadability: اعرف الكثير عن نظام ايسكريم ساندويتش http://t.co/Fzjd2Zx1 
Tweet ignored due to unreadability: 看見 
Tweet ignored due to unreadability: نظام جديد .. و جهاز جديد شكراً جزيلاً 
Tweet ignored due to unreadability: ايسكريم ساندويش، عسل، زنجبيل .. مشكلة من كثر المسميات احسهم مسوين مقادير مب انظمة !! 😝  
Tweet ignored due to unreadability: الجهاز الجديد عجيب   
Tweet ignored due to unreadability: Я немного потрясен :) 
Tweet ignored due to unreadability: 今日発表だった＾ﾛ＾　 
Tweet ignored due to unreadability: يبدو ان طفر

wandb: Currently logged in as: ars860 (use `wandb login --relogin` to force relogin)
wandb: wandb version 0.12.4 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade


Epoch 1/50, iter: 100/3382, mean loss: 0.42261880833655596
Epoch 1/50, iter: 200/3382, mean loss: 0.023846450009150432
Epoch 1/50, iter: 300/3382, mean loss: 0.010459279580500152
Epoch 1/50, iter: 400/3382, mean loss: 0.008468175009966217
Epoch 1/50, iter: 500/3382, mean loss: 0.021027950721418165
Epoch 1/50, iter: 600/3382, mean loss: 0.003911573119510194
Epoch 1/50, iter: 700/3382, mean loss: 0.005205889387016213
Epoch 1/50, iter: 800/3382, mean loss: 0.0016406561168128065
Epoch 1/50, iter: 900/3382, mean loss: 0.00439424554352513
Epoch 1/50, iter: 1000/3382, mean loss: 0.5940594247990703
Epoch 1/50, iter: 1100/3382, mean loss: 0.33317440437152984
Epoch 1/50, iter: 1200/3382, mean loss: 0.07702367916877847
Epoch 1/50, iter: 1300/3382, mean loss: 0.03264839368152025
Epoch 1/50, iter: 1400/3382, mean loss: 0.02000695823757269
Epoch 1/50, iter: 1500/3382, mean loss: 0.023298537346854574
Epoch 1/50, iter: 1600/3382, mean loss: 0.04022621084817729
Epoch 1/50, iter: 1700/3382, mean loss: 0

In [5]:
model.eval()

print("Testing on train")

correct = 0
predictions_cnt = [0, 0, 0, 0]

with torch.no_grad():
    for i, (*args, target) in enumerate(dataloader_train):
        prediction = model(*args)
        prediction = F.softmax(prediction, dim=0)

        if torch.argmax(prediction) == torch.argmax(target):
            correct += 1

        predictions_cnt[torch.argmax(prediction)] += 1

        if i % 100 == 0:
            print(f"Iter: {i}/{len(dataloader_train)}")

print(f"Accuracy {correct / len(dataloader_train)}")

if use_wandb:
    wandb.run.summary.train_accuracy = correct / len(dataloader_train)
    wandb.run.summary.classified_as = {
        "apple": predictions_cnt[0],
        "google": predictions_cnt[1],
        "microsoft": predictions_cnt[2],
        "twitter": predictions_cnt[3]
    }
    wandb.finish()

Testing on train
Iter: 0/3382
Iter: 100/3382
Iter: 200/3382
Iter: 300/3382
Iter: 400/3382
Iter: 500/3382
Iter: 600/3382
Iter: 700/3382
Iter: 800/3382
Iter: 900/3382
Iter: 1000/3382
Iter: 1100/3382
Iter: 1200/3382
Iter: 1300/3382
Iter: 1400/3382
Iter: 1500/3382
Iter: 1600/3382
Iter: 1700/3382
Iter: 1800/3382
Iter: 1900/3382
Iter: 2000/3382
Iter: 2100/3382
Iter: 2200/3382
Iter: 2300/3382
Iter: 2400/3382
Iter: 2500/3382
Iter: 2600/3382
Iter: 2700/3382
Iter: 2800/3382
Iter: 2900/3382
Iter: 3000/3382
Iter: 3100/3382
Iter: 3200/3382
Iter: 3300/3382
Accuracy 0.9943820224719101


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
loss,0.00716
_runtime,1598.0
_timestamp,1633892405.0
_step,1749.0
epoch_loss,0.01722
test_loss,0.54065
test_accuracy,0.80473
train_accuracy,0.99438


0,1
loss,▁▁▁▁▅█▁▁▁▂▁▁▁▂▂▂▂▂▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
_runtime,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
epoch_loss,▇██▇▇▇▇█▇▇▆▆▄▅▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▂▃▂▁▁▁▁▁▁▁▁
test_loss,██▇█▇▇▆▅▄▄▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▂▂▂▂▂▂▁
test_accuracy,▁▁▁▁▁▁▁▂▂▃▄▅▆▇▇▇█████████████████▇██████


In [6]:
print("Testing on test")

correct = 0
predictions_cnt = [0, 0, 0, 0]

with torch.no_grad():
    for i, (*args, target) in enumerate(dataloader_test):
        prediction = model(*args)
        prediction = F.softmax(prediction, dim=0)

        if torch.argmax(prediction) == torch.argmax(target):
            correct += 1

        predictions_cnt[torch.argmax(prediction)] += 1

        if i % 100 == 0:
            print(f"Iter: {i}/{len(dataloader_test)}")

print(f"Accuracy {correct / len(dataloader_test)}")

Testing on test
Iter: 0/338
Iter: 100/338
Iter: 200/338
Iter: 300/338
Accuracy 0.8047337278106509
