In [1]:
import numpy as np
import torch
from torch import nn, optim
from dataset import get_twit_company_dataloaders, get_twit_sentiment_dataloaders, get_twit_company_sentiment_dataloaders
from model import LSTMTwitClassifier
import torch.nn.functional as F

# text, label = next(iter(dataloader_train))

In [2]:
# ---------------------------------------------------------------------------
# Experiment configuration. Everything a reader might want to tune lives here.
# ---------------------------------------------------------------------------
use_wandb = True

lr = 0.0005
embedding_size = 100
hidden_size = 100
epochs_cnt = 50
embeddings = "random"
lstm_layers = 1
dropout = 0.5
task = "text2company" # "text2sentiment"
use_company_info = True  # only consulted when task is not "text2company"
preprocessing = "tutorial"
use_stop_words = True

# Select the dataloader factory that matches the task:
#   - "text2company"                 -> company classification
#   - otherwise, with company info   -> sentiment from text + company
#   - otherwise, without it          -> sentiment from text only
if task == "text2company":
    get_dataloaders = get_twit_company_dataloaders
elif use_company_info:
    get_dataloaders = get_twit_company_sentiment_dataloaders
else:
    get_dataloaders = get_twit_sentiment_dataloaders

# Build the loaders through the *selected* factory. Previously the company
# factory was called unconditionally, so changing `task` had no effect on the
# data that was loaded.
# NOTE(review): assumes all three factories accept the same keyword arguments
# and return the same 4-tuple -- confirm against dataset.py.
dataset_train, dataloader_train, dataset_test, dataloader_test = get_dataloaders(
    embedding_dim=embedding_size,
    embedding=embeddings,
    preprocessing=preprocessing,
    use_stop_words=use_stop_words,
)

# NOTE(review): the first positional argument (4) looks like the number of
# output classes (the four companies reported in the evaluation cell); verify
# it is also correct for the sentiment tasks.
model = LSTMTwitClassifier(4, embedding_dim=embedding_size, hidden_dim=hidden_size, dropout=dropout, lstm_layers=lstm_layers)

if use_wandb:
    # Imported lazily so the notebook still runs without wandb installed when
    # use_wandb is False.
    import wandb

    wandb.init(project=task + '_twit_classification', entity='ars860')

    # Record the hyper-parameters of this run alongside its metrics.
    config = wandb.config
    config.loss = "BCE"
    config.optimizer = "Adam"
    config.learning_rate = lr
    config.hidden_size = hidden_size
    config.embedding_size = embedding_size
    config.embeddings = embeddings
    config.epochs = epochs_cnt
    config.dropout = dropout
    config.lstm_layers = lstm_layers
    config.stem = "snowballstemmer"
    config.preprocessing = preprocessing
    config.use_stop_words = use_stop_words

    if task == "text2sentiment":
        config.use_company_info = use_company_info

    wandb.watch(model)

# Binary cross-entropy is applied to softmax probabilities in the training loop.
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=lr)

def loss_on_test():
    """Evaluate the global ``model`` on ``dataloader_test`` and log to wandb.

    For every item yielded by the test dataloader the model output is passed
    through a softmax over dim 0, its binary cross-entropy against the
    flattened target is stored, and a hit is counted when the arg-max of the
    prediction equals the arg-max of the target.

    The model is put into eval mode for the pass and switched back to train
    mode afterwards. When ``use_wandb`` is true, the mean loss and
    ``hits / len(dataloader_test)`` are logged as ``test_loss`` and
    ``test_accuracy``. Nothing is returned.
    """
    hits = 0
    batch_losses = np.zeros(len(dataloader_test))

    with torch.no_grad():
        model.eval()
        for idx, batch in enumerate(dataloader_test):
            # The last element of each batch is the target; the rest are model inputs.
            *inputs, target = batch
            probs = F.softmax(model(*inputs), dim=0)

            batch_losses[idx] = F.binary_cross_entropy(probs, target.view(-1))
            hits += int(torch.argmax(probs) == torch.argmax(target))

    # Hand the model back to the training loop in train mode.
    model.train()
    if not use_wandb:
        return
    wandb.log({"test_loss": np.mean(batch_losses), "test_accuracy": hits / len(dataloader_test)})

# Training loop: runs `epochs_cnt` passes over `dataloader_train`, optimising
# `model` with `criterion` / `optimizer` and reporting a running mean loss.

# Ring buffer with the losses of the 100 most recent iterations (slot i % 100).
# It is created uninitialised, but is only read once (i + 1) % 100 == 0, i.e.
# after all 100 slots have been written at least once. It is not reset between
# epochs, so early in an epoch the window can still hold values from the
# previous one.
losses = np.empty(100)
model.train()
for epoch in range(epochs_cnt):
    # One loss value per iteration of this epoch, averaged at the end.
    epoch_loss = np.zeros(len(dataloader_train))

    # Each item unpacks into the model inputs (*args) and a trailing target.
    for i, (*args, target) in enumerate(dataloader_train):
        # Clear gradients left over from the previous iteration.
        model.zero_grad()

        prediction = model(*args)
        # NOTE(review): softmax over dim=0 presumably assumes the model returns
        # a 1-D vector of class scores for a single sample -- confirm.
        prediction = F.softmax(prediction, dim=0)

        # The target is flattened to 1-D before being compared to the prediction.
        loss = criterion(prediction, target.view(-1))

        loss.backward()
        optimizer.step()

        # From here on `loss` is a plain Python float, not a tensor.
        loss = loss.detach().item()
        losses[i % 100] = loss
        epoch_loss[i] = loss

        # Report the mean over the last 100 iterations every 100 iterations.
        if (i + 1) % 100 == 0:
            print(
                f"Epoch {epoch + 1}/{epochs_cnt}, iter: {i + 1}/{len(dataloader_train)}, mean loss: {np.mean(losses)}")
            if use_wandb:
                wandb.log({"loss": np.mean(losses)})

    # End-of-epoch reporting. The test-set evaluation only logs to wandb, so it
    # is skipped entirely when wandb is disabled.
    if use_wandb:
        wandb.log({"epoch_loss": np.mean(epoch_loss)})
        loss_on_test()

# [model.get_word_embedding(word) for word in "hello_world".split(' ')]

[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\ars86\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


Tweet ignored due to unreadability: Поиск от 
Tweet ignored due to unreadability: Новите 
Tweet ignored due to unreadability: 看見 
Tweet ignored due to unreadability: نظام جديد .. و جهاز جديد شكراً جزيلاً 
Tweet ignored due to unreadability: الجهاز الجديد عجيب   
Tweet ignored due to unreadability: يبدو ان طفرة الاجهزة الالكترونية القادمة ستكون بقيادة موتورولا ،، لاسيم بعد استحواذ قوقل عليها.   
Tweet ignored due to unreadability: Με συγχισες 
Tweet ignored due to unreadability: На сайте 
Tweet ignored due to unreadability: Настоящий твиттерянин как только попадает в толпу стремиться тут же как можно быстрее попасть в 
Tweet ignored due to unreadability: Доброе утро 
Tweet ignored due to unreadability: 【
Tweet ignored due to unreadability: رقم الفلو والفلورز والتويتات  للبيع لاعلى سعر 


[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\ars86\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


Tweet ignored due to unreadability: قال الرئيس التنفيذي لشركة 
Tweet ignored due to unreadability: Улучшим продукты компании 
Tweet ignored due to unreadability: نفسي يوم يعدي علي تويتر من غير مشاكل فنية 
Tweet ignored due to unreadability: ツイッター検索 


wandb: Currently logged in as: ars860 (use `wandb login --relogin` to force relogin)
wandb: wandb version 0.12.5 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade


Epoch 1/50, iter: 100/3401, mean loss: 0.2360259737703018
Epoch 1/50, iter: 200/3401, mean loss: 0.00916188510862412
Epoch 1/50, iter: 300/3401, mean loss: 0.006217722113397031
Epoch 1/50, iter: 400/3401, mean loss: 0.0011987403784223716
Epoch 1/50, iter: 500/3401, mean loss: 0.01061335138359027
Epoch 1/50, iter: 600/3401, mean loss: 0.002946282478487774
Epoch 1/50, iter: 700/3401, mean loss: 0.0004199304972598839
Epoch 1/50, iter: 800/3401, mean loss: 0.00040287758041813505
Epoch 1/50, iter: 900/3401, mean loss: 0.0014199040793641871
Epoch 1/50, iter: 1000/3401, mean loss: 0.6121180474207768
Epoch 1/50, iter: 1100/3401, mean loss: 0.23033541007433087
Epoch 1/50, iter: 1200/3401, mean loss: 0.037400278793647886
Epoch 1/50, iter: 1300/3401, mean loss: 0.01491446694766637
Epoch 1/50, iter: 1400/3401, mean loss: 0.01240120724105509
Epoch 1/50, iter: 1500/3401, mean loss: 0.02042275810279534
Epoch 1/50, iter: 1600/3401, mean loss: 0.03190078524276032
Epoch 1/50, iter: 1700/3401, mean loss:

KeyboardInterrupt: 

In [None]:
# Accuracy of the trained model on the training dataloader, plus a count of
# how often each class index was predicted. Results go to the wandb run
# summary (when enabled), after which the run is finished.
print("Testing on train")
model.eval()

correct = 0
predictions_cnt = [0] * 4

with torch.no_grad():
    for i, (*args, target) in enumerate(dataloader_train):
        prediction = F.softmax(model(*args), dim=0)
        predicted_class = torch.argmax(prediction)

        # A hit is counted when predicted and target arg-max agree.
        correct += int(predicted_class == torch.argmax(target))
        predictions_cnt[predicted_class] += 1

        if i % 100 == 0:
            print(f"Iter: {i}/{len(dataloader_train)}")

print(f"Accuracy {correct / len(dataloader_train)}")

if use_wandb:
    wandb.run.summary.train_accuracy = correct / len(dataloader_train)
    # Class index -> company name, in the order the counts are stored.
    wandb.run.summary.classified_as = dict(
        zip(("apple", "google", "microsoft", "twitter"), predictions_cnt)
    )
    wandb.finish()

In [None]:
# Accuracy of the trained model on the test dataloader, plus a count of how
# often each class index was predicted.
print("Testing on test")

# Switch to eval mode explicitly so dropout is disabled even when this cell is
# run on its own; the training loop and loss_on_test() both leave the model in
# train mode.
model.eval()

correct = 0
predictions_cnt = [0, 0, 0, 0]

with torch.no_grad():
    # Each item unpacks into the model inputs (*args) and a trailing target.
    for i, (*args, target) in enumerate(dataloader_test):
        prediction = model(*args)
        # NOTE(review): softmax over dim=0 presumably assumes a 1-D vector of
        # class scores for a single sample -- confirm.
        prediction = F.softmax(prediction, dim=0)

        # A hit is counted when predicted and target arg-max agree.
        if torch.argmax(prediction) == torch.argmax(target):
            correct += 1

        predictions_cnt[torch.argmax(prediction)] += 1

        if i % 100 == 0:
            print(f"Iter: {i}/{len(dataloader_test)}")

# Hits divided by the number of items the dataloader yields.
print(f"Accuracy {correct / len(dataloader_test)}")