# Stance Detection Using Pytorch and ELMo

## Set Cuda

In [1]:
import torch as th
import numpy as np

th.set_num_threads(6)

In [2]:
USE_CUDA = True
if USE_CUDA:
    if not th.cuda.is_available():
        USE_CUDA = False
        print("WARNING: <Could not use cuda> Resource UnAvailable")
    else:
        print("LOG: <Using CUDA>")

LOG: <Using CUDA>


## Load the embedding model

In [3]:
from allennlp.modules.elmo import Elmo, batch_to_ids

weights_path = "elmo_2x1024_128_2048cnn_1xhighway_weights.hdf5"
options_path = "elmo_small_options.json"

embeddings = Elmo(options_path, weights_path, 1)
if USE_CUDA:
    embeddings.cuda()

## Load the Data

In [1]:
PRELOAD = False
preload_path = "data_dump_elmo.data"

In [2]:
if PRELOAD == True:
    import pickle
    data = pickle.load(open(preload_path, "rb"))
    data_train = data["X_train"]
    labels_train = data["Y_train"]
    data_test = data["X_test"]

In [7]:
if PRELOAD == False:
    
    import pandas as pd
    import numpy as np

    dtypes_train = {"id":np.int64, "text":str, "author":str, "title":str, "label":np.int64}
    dtypes_test = {"id":np.int64, "text":str, "author":str, "title":str}

    SEED = 1234

    # load data

    train_df = pd.read_csv("data/train.csv", dtype=dtypes_train)
    train_df = train_df.dropna()
    train_df = train_df.sample(frac=1)
    X_train = train_df.drop(['label', 'id', 'author'], axis=1).values
    Y_train = train_df['label'].values
    print("Train DIMS \nX dims: {} Y dims: {}".format(X_train.shape, Y_train.shape))

    test_df = pd.read_csv("data/test.csv")
    test_df = test_df.dropna()
    X_test = test_df.drop(['id', 'author'], axis=1).values
    print("Test DIMS \nX dims: {}".format(X_test.shape))
    print(np.unique(Y_train))
    
    import pickle
    from nltk.corpus import stopwords
    import re
    
    # preprocessing
    stop_words = set(stopwords.words("english"))

    train_data = []
    test_data = []
    

    for i in range(len(X_train)):
        headline, article = X_train[i]
        headline = re.sub(r'[^\w\s]|\n|\r','',headline)
        article = re.sub(r'[^\w\s]|\n|\r','',article)
        headline = headline.lower().split(" ")
        article = article.lower().split(" ")
        headline = [word for word in headline if word not in stop_words]
        article = [word for word in article if word not in stop_words]
        train_data.append([ headline, article])


    for i in range(len(X_test)):
        headline, article = X_test[i]
        headline = re.sub(r'[^\w\s]|\n|\r','',headline)
        article = re.sub(r'[^\w\s]|\n|\r','',article)
        headline = headline.lower().split(" ")
        article = article.lower().split(" ")
        headline = [word for word in headline if word not in stop_words]
        article = [word for word in article if word not in stop_words]
        test_data.append([ headline, article])


    data_save = {"X_train": train_data, "Y_train": Y_train, "X_test": test_data}
    pickle.dump(data_save, open("data_dump_elmo.data", "wb"))
    

Train DIMS 
X dims: (18285, 2) Y dims: (18285,)
Test DIMS 
X dims: (4575, 2)
[0 1]


#### Check Embeddings

In [7]:
sentences = [['First', 'sentence', '.', "aviral"], ['Another', '.'], ["Another", "one"]]
character_ids = batch_to_ids(sentences)
if USE_CUDA:
    character_ids = character_ids.cuda()
op = embeddings(character_ids)
print(op.keys())
print(len(op["elmo_representations"]))
print(op["elmo_representations"][0].shape)

dict_keys(['elmo_representations', 'mask'])
1
torch.Size([3, 4, 256])


## Build Model

In [8]:
import torch as th
import torch.nn as nn
from torch.autograd import Variable

In [9]:
class FakeNewsClassifier(nn.Module):
    
    def __init__(self, embedding_layer, embedding_dim, headline_lstm_hidden_dim, article_lstm_hidden_dim, lstm_layers, output_dim, use_cuda=False):
        
        super(FakeNewsClassifier, self).__init__()
        self.embedding_dim = embedding_dim
        self.headline_lstm_hidden_dim = headline_lstm_hidden_dim
        self.article_lstm_hidden_dim = article_lstm_hidden_dim
        self.lstm_layers = lstm_layers
        self.use_cuda = use_cuda

        self.embeddings = embedding_layer
        self.lstm_headline = nn.LSTM(embedding_dim, headline_lstm_hidden_dim, lstm_layers, batch_first=True)
        self.lstm_article = nn.LSTM(embedding_dim, article_lstm_hidden_dim, lstm_layers, batch_first=True)
        self.fc = nn.Linear(headline_lstm_hidden_dim+article_lstm_hidden_dim, output_dim)
        
    def forward(self, x):
        
        out_h = self.embeddings(x[0])["elmo_representations"][0]
        out_a = self.embeddings(x[1])["elmo_representations"][0]
        
        a0_h = Variable(th.zeros(self.lstm_layers, out_h.size(0), self.headline_lstm_hidden_dim))
        c0_h = Variable(th.zeros(self.lstm_layers, out_h.size(0), self.headline_lstm_hidden_dim))
        a0_a = Variable(th.zeros(self.lstm_layers, out_a.size(0), self.article_lstm_hidden_dim))
        c0_a = Variable(th.zeros(self.lstm_layers, out_a.size(0), self.article_lstm_hidden_dim))
        
        if self.use_cuda:
                a0_h = a0_h.cuda()
                c0_h = c0_h.cuda()
                a0_a = a0_a.cuda()
                c0_a = c0_a.cuda()
        
        out_h, (an_h, cn_h) = self.lstm_headline(out_h, (a0_h, c0_h))
        out_a, (an_a, cn_a) = self.lstm_article(out_a, (a0_a, c0_a))
#         print("out_h shape : {}, out_a shape : {}".format(out_h.shape, out_a.shape))
        out = th.cat((out_h[:, -1, :], out_a[:, -1, :]), 1)
#         print("out shape : {}".format(out.shape))
        
        out = self.fc(out)
#         print("fc out shape : {}".format(out.shape))

        
        return out        

In [10]:
embedding_dim = 256
headline_lstm_hidden_dim = 20
article_lstm_hidden_dim = 100
lstm_layers = 1
output_dim = 1
model = FakeNewsClassifier(embeddings, embedding_dim, headline_lstm_hidden_dim, article_lstm_hidden_dim, lstm_layers, output_dim, use_cuda=True)

In [11]:
if USE_CUDA:
    model.cuda()
model

FakeNewsClassifier(
  (embeddings): Elmo(
    (_elmo_lstm): _ElmoBiLm(
      (_token_embedder): _ElmoCharacterEncoder(
        (char_conv_0): Conv1d(16, 32, kernel_size=(1,), stride=(1,))
        (char_conv_1): Conv1d(16, 32, kernel_size=(2,), stride=(1,))
        (char_conv_2): Conv1d(16, 64, kernel_size=(3,), stride=(1,))
        (char_conv_3): Conv1d(16, 128, kernel_size=(4,), stride=(1,))
        (char_conv_4): Conv1d(16, 256, kernel_size=(5,), stride=(1,))
        (char_conv_5): Conv1d(16, 512, kernel_size=(6,), stride=(1,))
        (char_conv_6): Conv1d(16, 1024, kernel_size=(7,), stride=(1,))
        (_highways): Highway(
          (_layers): ModuleList(
            (0): Linear(in_features=2048, out_features=4096, bias=True)
          )
        )
        (_projection): Linear(in_features=2048, out_features=128, bias=True)
      )
      (_elmo_lstm): ElmoLstm(
        (forward_layer_0): LstmCellWithProjection(
          (input_linearity): Linear(in_features=128, out_features=4096

In [103]:
loss_func = nn.BCEWithLogitsLoss()
optimizer = th.optim.Adam(filter(lambda p: p.requires_grad,model.parameters()), lr=0.01)

In [104]:
train_data, test_data = data_train[:-2000], data_train[-2000:]
train_labels, test_labels = labels_train[:-2000], labels_train[-2000:]

In [105]:
from tqdm import tqdm

In [107]:
epochs = 5
iters = 0
max_headline_size = 50
max_article_size = 200
batch_size = 10

for num_epoch in range(epochs):
    headline_batch = []
    article_batch = []
    labels_batch = []
    for i, (headline, article) in tqdm(enumerate(train_data)):
        if i % batch_size != 0 or i == 0:
            headline_batch.append(headline[:max_headline_size])
            article_batch.append(article[:max_article_size])
            labels_batch.append(train_labels[i])
            continue
            
        headline = batch_to_ids(headline_batch)
        article = batch_to_ids(article_batch)
        
        label = Variable(th.Tensor(labels_batch))
        
        if USE_CUDA:
            headline = headline.cuda()
            article = article.cuda()
            label = label.cuda()
        
        optimizer.zero_grad()
        output = model([headline, article])
        loss = loss_func(output.view(-1), label)
        loss.backward()
        optimizer.step()
        
        iters += batch_size
        headline_batch = []
        article_batch = []
        labels_batch = []
        
        if iters%100 == 0:
            correct = 0
            total = 0
            selected_50 = np.random.randint(0, high=len(test_data), size=50)
            
            headline_test_batch = []
            article_test_batch = []
            labels_test_batch = []
            
            for j, idx in enumerate(selected_50):
                    test_headline, test_article = test_data[idx]
                    headline_test_batch.append(test_headline[:max_headline_size])
                    article_test_batch.append(test_article[:max_article_size])
                    labels_test_batch.append(test_labels[idx])
                    
            test_headline = batch_to_ids(headline_test_batch)
            test_article = batch_to_ids(article_test_batch)
            label = Variable(th.Tensor(labels_test_batch))
                
            if USE_CUDA:
                test_headline = test_headline.cuda()
                test_article = test_article.cuda()
                
            output = model([test_headline, test_article])
            predicted = output.cpu() > 0 
            total += 50
            correct += (predicted.view(-1).type(th.FloatTensor) == label).sum()
                

            accuracy = 100 * correct/total
            print("Iteration: {}, Loss: {}, Accuracy: {}".format(iters, loss.data, accuracy))
#     break

101it [00:06, 14.96it/s]

Iteration: 100, Loss: 0.15878760814666748, Accuracy: 90


201it [00:13, 14.80it/s]

Iteration: 200, Loss: 0.08742182701826096, Accuracy: 78


301it [00:20, 14.63it/s]

Iteration: 300, Loss: 0.1434897929430008, Accuracy: 90


401it [00:27, 14.69it/s]

Iteration: 400, Loss: 0.3177160620689392, Accuracy: 80


501it [00:34, 14.60it/s]

Iteration: 500, Loss: 0.15023723244667053, Accuracy: 90


601it [00:41, 14.63it/s]

Iteration: 600, Loss: 0.2132093757390976, Accuracy: 90


701it [00:48, 14.58it/s]

Iteration: 700, Loss: 0.025794317945837975, Accuracy: 94


801it [00:54, 14.66it/s]

Iteration: 800, Loss: 0.19942189753055573, Accuracy: 88


901it [01:01, 14.70it/s]

Iteration: 900, Loss: 0.677211344242096, Accuracy: 94


1001it [01:07, 14.76it/s]

Iteration: 1000, Loss: 0.24247795343399048, Accuracy: 94


1101it [01:14, 14.76it/s]

Iteration: 1100, Loss: 0.25158846378326416, Accuracy: 92


1201it [01:21, 14.81it/s]

Iteration: 1200, Loss: 0.6166430711746216, Accuracy: 84


1301it [01:27, 14.82it/s]

Iteration: 1300, Loss: 0.3910180926322937, Accuracy: 80


1401it [01:34, 14.86it/s]

Iteration: 1400, Loss: 0.2245529443025589, Accuracy: 92


1501it [01:40, 14.89it/s]

Iteration: 1500, Loss: 0.16076654195785522, Accuracy: 82


1601it [01:47, 14.89it/s]

Iteration: 1600, Loss: 0.1312716007232666, Accuracy: 90


1701it [01:53, 14.92it/s]

Iteration: 1700, Loss: 0.20529481768608093, Accuracy: 78


1801it [02:00, 14.89it/s]

Iteration: 1800, Loss: 0.29214024543762207, Accuracy: 96


1901it [02:07, 14.89it/s]

Iteration: 1900, Loss: 0.340060830116272, Accuracy: 86


2001it [02:14, 14.91it/s]

Iteration: 2000, Loss: 0.5047404766082764, Accuracy: 96


2101it [02:20, 14.90it/s]

Iteration: 2100, Loss: 0.1836126148700714, Accuracy: 88


2201it [02:27, 14.89it/s]

Iteration: 2200, Loss: 0.2334982007741928, Accuracy: 88


2301it [02:34, 14.86it/s]

Iteration: 2300, Loss: 0.07843425869941711, Accuracy: 86


2401it [02:41, 14.88it/s]

Iteration: 2400, Loss: 0.023791275918483734, Accuracy: 90


2501it [02:48, 14.82it/s]

Iteration: 2500, Loss: 0.4176670014858246, Accuracy: 88


2601it [02:55, 14.81it/s]

Iteration: 2600, Loss: 0.10171667486429214, Accuracy: 88


2701it [03:02, 14.83it/s]

Iteration: 2700, Loss: 0.4821033477783203, Accuracy: 92


2801it [03:08, 14.83it/s]

Iteration: 2800, Loss: 0.05112498998641968, Accuracy: 88


2901it [03:15, 14.85it/s]

Iteration: 2900, Loss: 0.21214647591114044, Accuracy: 94


3001it [03:21, 14.87it/s]

Iteration: 3000, Loss: 0.48092180490493774, Accuracy: 92


3101it [03:29, 14.82it/s]

Iteration: 3100, Loss: 0.43572986125946045, Accuracy: 92


3201it [03:37, 14.73it/s]

Iteration: 3200, Loss: 0.3166259527206421, Accuracy: 94


3301it [03:44, 14.67it/s]

Iteration: 3300, Loss: 0.12160137295722961, Accuracy: 86


3401it [03:52, 14.65it/s]

Iteration: 3400, Loss: 0.5547945499420166, Accuracy: 92


3501it [03:59, 14.60it/s]

Iteration: 3500, Loss: 0.11270762979984283, Accuracy: 92


3601it [04:06, 14.58it/s]

Iteration: 3600, Loss: 0.22728385031223297, Accuracy: 90


3701it [04:13, 14.59it/s]

Iteration: 3700, Loss: 0.10913512110710144, Accuracy: 90


3801it [04:20, 14.57it/s]

Iteration: 3800, Loss: 0.06250539422035217, Accuracy: 94


3901it [04:28, 14.53it/s]

Iteration: 3900, Loss: 0.10136838257312775, Accuracy: 92


4001it [04:35, 14.55it/s]

Iteration: 4000, Loss: 0.14026229083538055, Accuracy: 90


4101it [04:42, 14.50it/s]

Iteration: 4100, Loss: 0.19559399783611298, Accuracy: 96


4201it [04:49, 14.49it/s]

Iteration: 4200, Loss: 0.1339571326971054, Accuracy: 92


4301it [04:57, 14.46it/s]

Iteration: 4300, Loss: 0.104840487241745, Accuracy: 88


4401it [05:04, 14.43it/s]

Iteration: 4400, Loss: 0.12818643450737, Accuracy: 90


4501it [05:11, 14.43it/s]

Iteration: 4500, Loss: 0.2623196542263031, Accuracy: 88


4601it [05:19, 14.41it/s]

Iteration: 4600, Loss: 0.17046135663986206, Accuracy: 82


4701it [05:25, 14.42it/s]

Iteration: 4700, Loss: 0.23533132672309875, Accuracy: 88


4801it [05:32, 14.44it/s]

Iteration: 4800, Loss: 0.45538070797920227, Accuracy: 94


4901it [05:40, 14.41it/s]

Iteration: 4900, Loss: 0.3334552049636841, Accuracy: 84


5001it [05:47, 14.38it/s]

Iteration: 5000, Loss: 0.06781506538391113, Accuracy: 88


5101it [05:54, 14.37it/s]

Iteration: 5100, Loss: 0.06724816560745239, Accuracy: 86


5201it [06:02, 14.34it/s]

Iteration: 5200, Loss: 0.3911055326461792, Accuracy: 90


5301it [06:09, 14.34it/s]

Iteration: 5300, Loss: 0.1915130913257599, Accuracy: 90


5401it [06:17, 14.31it/s]

Iteration: 5400, Loss: 0.25811663269996643, Accuracy: 94


5501it [06:25, 14.28it/s]

Iteration: 5500, Loss: 0.17771658301353455, Accuracy: 88


5601it [06:32, 14.28it/s]

Iteration: 5600, Loss: 0.18690988421440125, Accuracy: 80


5701it [06:39, 14.27it/s]

Iteration: 5700, Loss: 0.0722602903842926, Accuracy: 88


5801it [06:46, 14.28it/s]

Iteration: 5800, Loss: 0.04928729310631752, Accuracy: 92


5901it [06:52, 14.29it/s]

Iteration: 5900, Loss: 0.024980805814266205, Accuracy: 92


6001it [07:01, 14.25it/s]

Iteration: 6000, Loss: 0.018138552084565163, Accuracy: 88


6101it [07:09, 14.21it/s]

Iteration: 6100, Loss: 0.018106702715158463, Accuracy: 84


6201it [07:16, 14.21it/s]

Iteration: 6200, Loss: 0.061770983040332794, Accuracy: 88


6301it [07:23, 14.21it/s]

Iteration: 6300, Loss: 0.3871420621871948, Accuracy: 90


6401it [07:31, 14.18it/s]

Iteration: 6400, Loss: 0.046152014285326004, Accuracy: 90


6501it [07:38, 14.19it/s]

Iteration: 6500, Loss: 0.04543222486972809, Accuracy: 94


6601it [07:46, 14.16it/s]

Iteration: 6600, Loss: 0.41910824179649353, Accuracy: 90


6701it [07:54, 14.14it/s]

Iteration: 6700, Loss: 0.07205823808908463, Accuracy: 94


6801it [08:02, 14.09it/s]

Iteration: 6800, Loss: 0.35042089223861694, Accuracy: 98


6901it [08:10, 14.06it/s]

Iteration: 6900, Loss: 0.029775619506835938, Accuracy: 94


7001it [08:17, 14.07it/s]

Iteration: 7000, Loss: 0.1911603808403015, Accuracy: 96


7101it [08:24, 14.06it/s]

Iteration: 7100, Loss: 0.1471201330423355, Accuracy: 98


7201it [08:32, 14.05it/s]

Iteration: 7200, Loss: 0.23828785121440887, Accuracy: 96


7301it [08:39, 14.05it/s]

Iteration: 7300, Loss: 0.13303931057453156, Accuracy: 98


7401it [08:47, 14.02it/s]

Iteration: 7400, Loss: 0.1201920360326767, Accuracy: 94


7501it [08:54, 14.03it/s]

Iteration: 7500, Loss: 0.029920216649770737, Accuracy: 94


7601it [09:02, 14.02it/s]

Iteration: 7600, Loss: 0.005676901433616877, Accuracy: 100


7701it [09:10, 13.99it/s]

Iteration: 7700, Loss: 0.04838494583964348, Accuracy: 94


7801it [09:18, 13.98it/s]

Iteration: 7800, Loss: 0.01428695023059845, Accuracy: 98


7901it [09:26, 13.94it/s]

Iteration: 7900, Loss: 0.07371816784143448, Accuracy: 98


8001it [09:34, 13.93it/s]

Iteration: 8000, Loss: 0.039585381746292114, Accuracy: 98


8101it [09:41, 13.93it/s]

Iteration: 8100, Loss: 0.0963284894824028, Accuracy: 98


8201it [09:49, 13.91it/s]

Iteration: 8200, Loss: 0.33211368322372437, Accuracy: 94


8301it [09:57, 13.90it/s]

Iteration: 8300, Loss: 0.5004107356071472, Accuracy: 98


8401it [10:04, 13.89it/s]

Iteration: 8400, Loss: 0.24221627414226532, Accuracy: 94


8501it [10:12, 13.89it/s]

Iteration: 8500, Loss: 0.1899150013923645, Accuracy: 90


8601it [10:19, 13.89it/s]

Iteration: 8600, Loss: 0.017033696174621582, Accuracy: 98


8701it [10:27, 13.87it/s]

Iteration: 8700, Loss: 0.25198203325271606, Accuracy: 92


8801it [10:34, 13.87it/s]

Iteration: 8800, Loss: 0.1970723271369934, Accuracy: 98


8901it [10:42, 13.86it/s]

Iteration: 8900, Loss: 0.1380009949207306, Accuracy: 98


9001it [10:48, 13.87it/s]

Iteration: 9000, Loss: 0.06080900877714157, Accuracy: 94


9101it [10:56, 13.86it/s]

Iteration: 9100, Loss: 0.1087385043501854, Accuracy: 96


9201it [11:03, 13.87it/s]

Iteration: 9200, Loss: 0.08321070671081543, Accuracy: 98


9301it [11:11, 13.85it/s]

Iteration: 9300, Loss: 0.2234807312488556, Accuracy: 98


9401it [11:18, 13.86it/s]

Iteration: 9400, Loss: 0.10935208946466446, Accuracy: 96


9501it [11:26, 13.84it/s]

Iteration: 9500, Loss: 0.06410558521747589, Accuracy: 100


9601it [11:33, 13.85it/s]

Iteration: 9600, Loss: 0.003884661942720413, Accuracy: 98


9701it [11:40, 13.85it/s]

Iteration: 9700, Loss: 0.05714653059840202, Accuracy: 100


9801it [11:47, 13.85it/s]

Iteration: 9800, Loss: 0.04430858790874481, Accuracy: 96


9901it [11:55, 13.84it/s]

Iteration: 9900, Loss: 0.0150457127019763, Accuracy: 98


10001it [12:02, 13.84it/s]

Iteration: 10000, Loss: 0.037617385387420654, Accuracy: 92


10101it [12:08, 13.86it/s]

Iteration: 10100, Loss: 0.27711647748947144, Accuracy: 100


10201it [12:16, 13.85it/s]

Iteration: 10200, Loss: 0.08150267601013184, Accuracy: 96


10301it [12:24, 13.84it/s]

Iteration: 10300, Loss: 0.0883929505944252, Accuracy: 94


10401it [12:32, 13.83it/s]

Iteration: 10400, Loss: 0.05394750460982323, Accuracy: 98


10501it [12:40, 13.82it/s]

Iteration: 10500, Loss: 0.3402685225009918, Accuracy: 98


10601it [12:47, 13.81it/s]

Iteration: 10600, Loss: 0.01927514374256134, Accuracy: 100


10701it [12:54, 13.82it/s]

Iteration: 10700, Loss: 0.013696245849132538, Accuracy: 100


10801it [13:01, 13.82it/s]

Iteration: 10800, Loss: 0.0633830577135086, Accuracy: 96


10901it [13:07, 13.84it/s]

Iteration: 10900, Loss: 0.027756260707974434, Accuracy: 100


11001it [13:15, 13.83it/s]

Iteration: 11000, Loss: 0.08372155576944351, Accuracy: 94


11101it [13:22, 13.83it/s]

Iteration: 11100, Loss: 0.013211204670369625, Accuracy: 98


11201it [13:29, 13.83it/s]

Iteration: 11200, Loss: 0.2189949005842209, Accuracy: 98


11301it [13:37, 13.83it/s]

Iteration: 11300, Loss: 0.05111020430922508, Accuracy: 98


11401it [13:44, 13.83it/s]

Iteration: 11400, Loss: 0.037056103348731995, Accuracy: 98


11501it [13:51, 13.82it/s]

Iteration: 11500, Loss: 0.1810736060142517, Accuracy: 96


11601it [13:59, 13.82it/s]

Iteration: 11600, Loss: 0.018047457560896873, Accuracy: 90


11701it [14:06, 13.82it/s]

Iteration: 11700, Loss: 0.18547546863555908, Accuracy: 100


11801it [14:13, 13.82it/s]

Iteration: 11800, Loss: 0.38663747906684875, Accuracy: 100


11901it [14:20, 13.83it/s]

Iteration: 11900, Loss: 0.02542419172823429, Accuracy: 98


12001it [14:27, 13.84it/s]

Iteration: 12000, Loss: 0.1467922180891037, Accuracy: 100


12101it [14:33, 13.85it/s]

Iteration: 12100, Loss: 0.0413222461938858, Accuracy: 94


12201it [14:41, 13.84it/s]

Iteration: 12200, Loss: 0.031398192048072815, Accuracy: 96


12301it [14:49, 13.83it/s]

Iteration: 12300, Loss: 0.019779391586780548, Accuracy: 96


12401it [14:57, 13.82it/s]

Iteration: 12400, Loss: 0.019677935168147087, Accuracy: 96


12501it [15:04, 13.82it/s]

Iteration: 12500, Loss: 0.0032389764674007893, Accuracy: 100


12601it [15:13, 13.80it/s]

Iteration: 12600, Loss: 0.15767371654510498, Accuracy: 100


12701it [15:21, 13.78it/s]

Iteration: 12700, Loss: 0.07530476152896881, Accuracy: 96


12801it [15:28, 13.79it/s]

Iteration: 12800, Loss: 0.2494707703590393, Accuracy: 100


12901it [15:35, 13.79it/s]

Iteration: 12900, Loss: 0.05528021603822708, Accuracy: 90


13001it [15:43, 13.79it/s]

Iteration: 13000, Loss: 0.21683566272258759, Accuracy: 94


13101it [15:50, 13.78it/s]

Iteration: 13100, Loss: 0.32008710503578186, Accuracy: 96


13201it [15:58, 13.78it/s]

Iteration: 13200, Loss: 0.00777834840118885, Accuracy: 94


13301it [16:05, 13.78it/s]

Iteration: 13300, Loss: 0.03817174583673477, Accuracy: 100


13401it [16:13, 13.77it/s]

Iteration: 13400, Loss: 0.04097612947225571, Accuracy: 96


13501it [16:20, 13.76it/s]

Iteration: 13500, Loss: 0.008316799066960812, Accuracy: 98


13601it [16:27, 13.77it/s]

Iteration: 13600, Loss: 0.03126154839992523, Accuracy: 96


13701it [16:34, 13.77it/s]

Iteration: 13700, Loss: 0.33302000164985657, Accuracy: 98


13801it [16:41, 13.78it/s]

Iteration: 13800, Loss: 0.06791717559099197, Accuracy: 96


13901it [16:48, 13.78it/s]

Iteration: 13900, Loss: 0.01802210696041584, Accuracy: 96


14001it [16:55, 13.79it/s]

Iteration: 14000, Loss: 0.054255411028862, Accuracy: 100


14101it [17:02, 13.79it/s]

Iteration: 14100, Loss: 0.02734442427754402, Accuracy: 98


14201it [17:10, 13.78it/s]

Iteration: 14200, Loss: 0.12172440439462662, Accuracy: 98


14301it [17:18, 13.77it/s]

Iteration: 14300, Loss: 0.011886556632816792, Accuracy: 98


14401it [17:26, 13.76it/s]

Iteration: 14400, Loss: 0.004553120117634535, Accuracy: 94


14501it [17:33, 13.76it/s]

Iteration: 14500, Loss: 0.2140372395515442, Accuracy: 100


14601it [17:40, 13.77it/s]

Iteration: 14600, Loss: 0.09617599099874496, Accuracy: 98


14701it [17:47, 13.77it/s]

Iteration: 14700, Loss: 0.1696256846189499, Accuracy: 98


14801it [17:54, 13.78it/s]

Iteration: 14800, Loss: 0.009959021583199501, Accuracy: 96


14901it [18:01, 13.78it/s]

Iteration: 14900, Loss: 0.007225914858281612, Accuracy: 96


15001it [18:08, 13.78it/s]

Iteration: 15000, Loss: 0.1477334201335907, Accuracy: 96


15101it [18:15, 13.78it/s]

Iteration: 15100, Loss: 0.002119120443239808, Accuracy: 96


15201it [18:22, 13.78it/s]

Iteration: 15200, Loss: 0.04224063456058502, Accuracy: 100


15301it [18:30, 13.77it/s]

Iteration: 15300, Loss: 0.014743064530193806, Accuracy: 96


15401it [18:37, 13.78it/s]

Iteration: 15400, Loss: 0.0062857638113200665, Accuracy: 96


15501it [18:44, 13.78it/s]

Iteration: 15500, Loss: 0.052385054528713226, Accuracy: 96


15601it [18:54, 13.75it/s]

Iteration: 15600, Loss: 0.03907480835914612, Accuracy: 98


15701it [19:03, 13.74it/s]

Iteration: 15700, Loss: 0.023411402478814125, Accuracy: 100


15801it [19:10, 13.74it/s]

Iteration: 15800, Loss: 0.2775599956512451, Accuracy: 98


15901it [19:17, 13.73it/s]

Iteration: 15900, Loss: 0.010285632684826851, Accuracy: 96


16001it [19:26, 13.72it/s]

Iteration: 16000, Loss: 0.015732379630208015, Accuracy: 100


16101it [19:33, 13.72it/s]

Iteration: 16100, Loss: 0.011215819045901299, Accuracy: 96


16201it [19:41, 13.72it/s]

Iteration: 16200, Loss: 0.02831917628645897, Accuracy: 94


16285it [19:46, 13.73it/s]
21it [00:02,  9.15it/s]

Iteration: 16300, Loss: 0.32480937242507935, Accuracy: 98


121it [00:09, 13.14it/s]

Iteration: 16400, Loss: 0.010691293515264988, Accuracy: 100


221it [00:15, 13.88it/s]

Iteration: 16500, Loss: 0.013515225611627102, Accuracy: 96


321it [00:22, 14.31it/s]

Iteration: 16600, Loss: 0.013849865645170212, Accuracy: 100


421it [00:29, 14.14it/s]

Iteration: 16700, Loss: 0.017949221655726433, Accuracy: 100


521it [00:36, 14.09it/s]

Iteration: 16800, Loss: 0.021279634907841682, Accuracy: 98


621it [00:44, 14.04it/s]

Iteration: 16900, Loss: 0.04528233781456947, Accuracy: 98


721it [00:51, 13.96it/s]

Iteration: 17000, Loss: 0.4463052749633789, Accuracy: 100


821it [00:59, 13.76it/s]

Iteration: 17100, Loss: 0.1548795849084854, Accuracy: 98


921it [01:07, 13.63it/s]

Iteration: 17200, Loss: 0.009832399897277355, Accuracy: 100


1021it [01:14, 13.70it/s]

Iteration: 17300, Loss: 0.015472509898245335, Accuracy: 100


1121it [01:21, 13.70it/s]

Iteration: 17400, Loss: 0.011228416115045547, Accuracy: 94


1221it [01:28, 13.72it/s]

Iteration: 17500, Loss: 0.05045467987656593, Accuracy: 94


1321it [01:36, 13.70it/s]

Iteration: 17600, Loss: 0.03544076904654503, Accuracy: 98


1421it [01:43, 13.69it/s]

Iteration: 17700, Loss: 0.116963692009449, Accuracy: 92


1521it [01:50, 13.77it/s]

Iteration: 17800, Loss: 0.0024595779832452536, Accuracy: 96


1621it [01:57, 13.76it/s]

Iteration: 17900, Loss: 0.04995136708021164, Accuracy: 96


1721it [02:05, 13.69it/s]

Iteration: 18000, Loss: 0.19610947370529175, Accuracy: 100


1821it [02:12, 13.73it/s]

Iteration: 18100, Loss: 0.010238584131002426, Accuracy: 98


1921it [02:19, 13.80it/s]

Iteration: 18200, Loss: 0.22431805729866028, Accuracy: 98


2021it [02:26, 13.84it/s]

Iteration: 18300, Loss: 0.018593741580843925, Accuracy: 100


2121it [02:33, 13.81it/s]

Iteration: 18400, Loss: 0.0017920811660587788, Accuracy: 98


2221it [02:41, 13.76it/s]

Iteration: 18500, Loss: 0.073081374168396, Accuracy: 100


2321it [02:49, 13.73it/s]

Iteration: 18600, Loss: 0.0203168373554945, Accuracy: 94


2421it [02:56, 13.71it/s]

Iteration: 18700, Loss: 0.0038320422172546387, Accuracy: 98


2521it [03:04, 13.67it/s]

Iteration: 18800, Loss: 0.026301585137844086, Accuracy: 98


2621it [03:11, 13.66it/s]

Iteration: 18900, Loss: 0.013826918788254261, Accuracy: 100


2721it [03:18, 13.71it/s]

Iteration: 19000, Loss: 0.005844374652951956, Accuracy: 100


2821it [03:26, 13.66it/s]

Iteration: 19100, Loss: 0.15475508570671082, Accuracy: 94


2921it [03:34, 13.60it/s]

Iteration: 19200, Loss: 0.6557632088661194, Accuracy: 100


3021it [03:41, 13.64it/s]

Iteration: 19300, Loss: 0.004264002665877342, Accuracy: 100


3121it [03:48, 13.68it/s]

Iteration: 19400, Loss: 0.014115233905613422, Accuracy: 96


3221it [03:54, 13.72it/s]

Iteration: 19500, Loss: 0.05151176080107689, Accuracy: 92


3321it [04:02, 13.69it/s]

Iteration: 19600, Loss: 0.01107894629240036, Accuracy: 92


3421it [04:10, 13.64it/s]

Iteration: 19700, Loss: 0.012925289571285248, Accuracy: 94


3521it [04:18, 13.61it/s]

Iteration: 19800, Loss: 0.028597652912139893, Accuracy: 92


3621it [04:26, 13.57it/s]

Iteration: 19900, Loss: 0.09238459169864655, Accuracy: 100


3721it [04:34, 13.57it/s]

Iteration: 20000, Loss: 0.00530905881896615, Accuracy: 100


3821it [04:41, 13.58it/s]

Iteration: 20100, Loss: 0.0047486452385783195, Accuracy: 98


3921it [04:48, 13.58it/s]

Iteration: 20200, Loss: 0.02957160957157612, Accuracy: 98


4021it [04:56, 13.57it/s]

Iteration: 20300, Loss: 0.04441707208752632, Accuracy: 100


4121it [05:04, 13.54it/s]

Iteration: 20400, Loss: 0.046019215136766434, Accuracy: 96


4221it [05:11, 13.54it/s]

Iteration: 20500, Loss: 0.058280494064092636, Accuracy: 100


4321it [05:19, 13.53it/s]

Iteration: 20600, Loss: 0.017430368810892105, Accuracy: 100


4421it [05:26, 13.54it/s]

Iteration: 20700, Loss: 0.16238254308700562, Accuracy: 90


4521it [05:33, 13.56it/s]

Iteration: 20800, Loss: 0.02026057243347168, Accuracy: 96


4621it [05:40, 13.59it/s]

Iteration: 20900, Loss: 0.061005089432001114, Accuracy: 100


4721it [05:48, 13.56it/s]

Iteration: 21000, Loss: 0.017725955694913864, Accuracy: 96


4821it [05:55, 13.57it/s]

Iteration: 21100, Loss: 0.011035390198230743, Accuracy: 100


4921it [06:03, 13.55it/s]

Iteration: 21200, Loss: 0.02967383712530136, Accuracy: 98


5021it [06:09, 13.58it/s]

Iteration: 21300, Loss: 0.03377111256122589, Accuracy: 98


5121it [06:16, 13.60it/s]

Iteration: 21400, Loss: 0.01952328532934189, Accuracy: 100


5221it [06:23, 13.62it/s]

Iteration: 21500, Loss: 0.008449231274425983, Accuracy: 98


5321it [06:29, 13.65it/s]

Iteration: 21600, Loss: 0.001425720052793622, Accuracy: 98


5421it [06:36, 13.66it/s]

Iteration: 21700, Loss: 0.10634470731019974, Accuracy: 98


5521it [06:44, 13.65it/s]

Iteration: 21800, Loss: 0.005906988400965929, Accuracy: 98


5621it [06:52, 13.63it/s]

Iteration: 21900, Loss: 0.04016512632369995, Accuracy: 98


5721it [07:00, 13.62it/s]

Iteration: 22000, Loss: 0.007378835696727037, Accuracy: 98


5821it [07:06, 13.64it/s]

Iteration: 22100, Loss: 0.006302179768681526, Accuracy: 100


5921it [07:13, 13.67it/s]

Iteration: 22200, Loss: 0.030478928238153458, Accuracy: 96


6021it [07:21, 13.64it/s]

Iteration: 22300, Loss: 0.0010193362832069397, Accuracy: 100


6121it [07:29, 13.62it/s]

Iteration: 22400, Loss: 0.1940646916627884, Accuracy: 96


6221it [07:37, 13.60it/s]

Iteration: 22500, Loss: 0.07131847739219666, Accuracy: 98


6321it [07:44, 13.60it/s]

Iteration: 22600, Loss: 0.05515950173139572, Accuracy: 96


6421it [07:52, 13.58it/s]

Iteration: 22700, Loss: 0.036222465336322784, Accuracy: 98


6521it [08:01, 13.56it/s]

Iteration: 22800, Loss: 0.018228430300951004, Accuracy: 100


6621it [08:07, 13.58it/s]

Iteration: 22900, Loss: 0.002222205977886915, Accuracy: 94


6721it [08:14, 13.60it/s]

Iteration: 23000, Loss: 0.02246411330997944, Accuracy: 100


6821it [08:22, 13.58it/s]

Iteration: 23100, Loss: 0.0031054418068379164, Accuracy: 98


6921it [08:30, 13.57it/s]

Iteration: 23200, Loss: 0.016508569940924644, Accuracy: 98


7021it [08:36, 13.58it/s]

Iteration: 23300, Loss: 0.0013770038494840264, Accuracy: 96


7121it [08:43, 13.60it/s]

Iteration: 23400, Loss: 0.019071370363235474, Accuracy: 96


7221it [08:50, 13.62it/s]

Iteration: 23500, Loss: 0.01995585858821869, Accuracy: 100


7321it [08:56, 13.65it/s]

Iteration: 23600, Loss: 0.004720211029052734, Accuracy: 96


7421it [09:02, 13.67it/s]

Iteration: 23700, Loss: 0.047740474343299866, Accuracy: 98


7521it [09:09, 13.69it/s]

Iteration: 23800, Loss: 0.006530995015054941, Accuracy: 100


7621it [09:17, 13.68it/s]

Iteration: 23900, Loss: 0.011002096347510815, Accuracy: 98


7721it [09:24, 13.68it/s]

Iteration: 24000, Loss: 0.045506253838539124, Accuracy: 100


7821it [09:31, 13.69it/s]

Iteration: 24100, Loss: 0.00018889298371504992, Accuracy: 100


7921it [09:39, 13.68it/s]

Iteration: 24200, Loss: 0.761863112449646, Accuracy: 100


8021it [09:45, 13.69it/s]

Iteration: 24300, Loss: 0.019856099039316177, Accuracy: 100


8121it [09:52, 13.71it/s]

Iteration: 24400, Loss: 0.0055478280410170555, Accuracy: 100


8221it [09:58, 13.73it/s]

Iteration: 24500, Loss: 0.20672576129436493, Accuracy: 100


8321it [10:06, 13.73it/s]

Iteration: 24600, Loss: 0.012782080098986626, Accuracy: 100


8421it [10:13, 13.72it/s]

Iteration: 24700, Loss: 0.010173598304390907, Accuracy: 100


8521it [10:21, 13.71it/s]

Iteration: 24800, Loss: 0.006563107017427683, Accuracy: 100


8621it [10:29, 13.70it/s]

Iteration: 24900, Loss: 0.00404813839122653, Accuracy: 100


8721it [10:36, 13.70it/s]

Iteration: 25000, Loss: 0.11017488688230515, Accuracy: 100


8821it [10:42, 13.72it/s]

Iteration: 25100, Loss: 0.3679932653903961, Accuracy: 100


8921it [10:49, 13.74it/s]

Iteration: 25200, Loss: 0.011005555279552937, Accuracy: 100


9021it [10:56, 13.75it/s]

Iteration: 25300, Loss: 0.09336557239294052, Accuracy: 98


9121it [11:04, 13.73it/s]

Iteration: 25400, Loss: 0.013865802437067032, Accuracy: 100


9221it [11:11, 13.74it/s]

Iteration: 25500, Loss: 0.019539110362529755, Accuracy: 100


9321it [11:18, 13.73it/s]

Iteration: 25600, Loss: 0.00022435974096879363, Accuracy: 100


9421it [11:26, 13.72it/s]

Iteration: 25700, Loss: 0.012861207127571106, Accuracy: 98


9521it [11:33, 13.73it/s]

Iteration: 25800, Loss: 0.00031278084497898817, Accuracy: 96


9621it [11:40, 13.74it/s]

Iteration: 25900, Loss: 0.010565856471657753, Accuracy: 98


9721it [11:46, 13.75it/s]

Iteration: 26000, Loss: 0.026120398193597794, Accuracy: 96


9821it [11:53, 13.77it/s]

Iteration: 26100, Loss: 0.06980844587087631, Accuracy: 98


9921it [11:59, 13.79it/s]

Iteration: 26200, Loss: 0.004381222650408745, Accuracy: 100


10021it [12:05, 13.81it/s]

Iteration: 26300, Loss: 0.12025201320648193, Accuracy: 100


10121it [12:13, 13.80it/s]

Iteration: 26400, Loss: 0.02184031531214714, Accuracy: 100


10221it [12:20, 13.81it/s]

Iteration: 26500, Loss: 0.008746366947889328, Accuracy: 98


10321it [12:28, 13.80it/s]

Iteration: 26600, Loss: 0.1335810124874115, Accuracy: 100


10421it [12:35, 13.79it/s]

Iteration: 26700, Loss: 0.17524510622024536, Accuracy: 100


10521it [12:42, 13.80it/s]

Iteration: 26800, Loss: 0.09380975365638733, Accuracy: 98


10621it [12:49, 13.80it/s]

Iteration: 26900, Loss: 0.03878520429134369, Accuracy: 98


10721it [12:56, 13.80it/s]

Iteration: 27000, Loss: 0.002121884375810623, Accuracy: 100


10821it [13:03, 13.81it/s]

Iteration: 27100, Loss: 0.008859514258801937, Accuracy: 96


10921it [13:10, 13.82it/s]

Iteration: 27200, Loss: 0.008556830696761608, Accuracy: 98


11021it [13:16, 13.84it/s]

Iteration: 27300, Loss: 0.2966652810573578, Accuracy: 100


11121it [13:23, 13.84it/s]

Iteration: 27400, Loss: 0.001028887927532196, Accuracy: 100


11221it [13:31, 13.83it/s]

Iteration: 27500, Loss: 0.06080588325858116, Accuracy: 98


11321it [13:39, 13.82it/s]

Iteration: 27600, Loss: 0.11248192191123962, Accuracy: 100


11421it [13:46, 13.82it/s]

Iteration: 27700, Loss: 0.0011568828485906124, Accuracy: 100


11521it [13:53, 13.83it/s]

Iteration: 27800, Loss: 0.006728475913405418, Accuracy: 100


11621it [13:59, 13.84it/s]

Iteration: 27900, Loss: 0.006788453087210655, Accuracy: 96


11721it [14:06, 13.85it/s]

Iteration: 28000, Loss: 0.009920565411448479, Accuracy: 94


11821it [14:12, 13.86it/s]

Iteration: 28100, Loss: 0.00042232091072946787, Accuracy: 100


11921it [14:18, 13.88it/s]

Iteration: 28200, Loss: 0.01794361136853695, Accuracy: 96


12021it [14:25, 13.89it/s]

Iteration: 28300, Loss: 0.10120177268981934, Accuracy: 100


12121it [14:32, 13.89it/s]

Iteration: 28400, Loss: 0.005489301402121782, Accuracy: 100


12221it [14:40, 13.88it/s]

Iteration: 28500, Loss: 0.03561839088797569, Accuracy: 100


12321it [14:47, 13.88it/s]

Iteration: 28600, Loss: 0.3290671110153198, Accuracy: 100


12421it [14:54, 13.89it/s]

Iteration: 28700, Loss: 0.0035208696499466896, Accuracy: 100


12521it [15:01, 13.89it/s]

Iteration: 28800, Loss: 0.014683146961033344, Accuracy: 98


12621it [15:08, 13.89it/s]

Iteration: 28900, Loss: 0.33701443672180176, Accuracy: 98


12721it [15:15, 13.90it/s]

Iteration: 29000, Loss: 0.002091379137709737, Accuracy: 96


12821it [15:21, 13.91it/s]

Iteration: 29100, Loss: 0.021819574758410454, Accuracy: 96


12921it [15:28, 13.92it/s]

Iteration: 29200, Loss: 0.004725530277937651, Accuracy: 98


13021it [15:35, 13.92it/s]

Iteration: 29300, Loss: 0.001000137533992529, Accuracy: 100


13121it [15:42, 13.92it/s]

Iteration: 29400, Loss: 0.04070635885000229, Accuracy: 100


13221it [15:49, 13.92it/s]

Iteration: 29500, Loss: 0.005045942962169647, Accuracy: 98


13321it [15:56, 13.93it/s]

Iteration: 29600, Loss: 0.033197954297065735, Accuracy: 100


13421it [16:03, 13.93it/s]

Iteration: 29700, Loss: 0.00792420469224453, Accuracy: 100


13521it [16:09, 13.94it/s]

Iteration: 29800, Loss: 0.0417759008705616, Accuracy: 100


13621it [16:16, 13.94it/s]

Iteration: 29900, Loss: 0.012622003443539143, Accuracy: 96


13721it [16:23, 13.95it/s]

Iteration: 30000, Loss: 0.014259214513003826, Accuracy: 98


13821it [16:29, 13.97it/s]

Iteration: 30100, Loss: 0.004223901778459549, Accuracy: 100


13921it [16:36, 13.96it/s]

Iteration: 30200, Loss: 0.004120569210499525, Accuracy: 100


14021it [16:44, 13.96it/s]

Iteration: 30300, Loss: 0.0041863094083964825, Accuracy: 96


14121it [16:52, 13.95it/s]

Iteration: 30400, Loss: 0.011234520003199577, Accuracy: 100


14221it [16:59, 13.95it/s]

Iteration: 30500, Loss: 0.0070843324065208435, Accuracy: 100


14321it [17:06, 13.95it/s]

Iteration: 30600, Loss: 0.01527788583189249, Accuracy: 100


14421it [17:13, 13.95it/s]

Iteration: 30700, Loss: 0.0021188773680478334, Accuracy: 100


14521it [17:20, 13.95it/s]

Iteration: 30800, Loss: 0.009801298379898071, Accuracy: 98


14621it [17:29, 13.94it/s]

Iteration: 30900, Loss: 0.01708800345659256, Accuracy: 98


14721it [17:38, 13.91it/s]

Iteration: 31000, Loss: 0.287404865026474, Accuracy: 100


14821it [17:45, 13.92it/s]

Iteration: 31100, Loss: 0.011151600629091263, Accuracy: 98


14921it [17:52, 13.91it/s]

Iteration: 31200, Loss: 0.002560080261901021, Accuracy: 100


15021it [18:00, 13.90it/s]

Iteration: 31300, Loss: 0.001809865701943636, Accuracy: 100


15121it [18:08, 13.89it/s]

Iteration: 31400, Loss: 0.0010374095290899277, Accuracy: 100


15221it [18:15, 13.89it/s]

Iteration: 31500, Loss: 0.4276600182056427, Accuracy: 100


15321it [18:22, 13.90it/s]

Iteration: 31600, Loss: 0.004089789465069771, Accuracy: 96


15421it [18:28, 13.91it/s]

Iteration: 31700, Loss: 0.002954349387437105, Accuracy: 100


15521it [18:36, 13.90it/s]

Iteration: 31800, Loss: 0.006492463871836662, Accuracy: 98


15621it [18:42, 13.91it/s]

Iteration: 31900, Loss: 0.0024397624656558037, Accuracy: 100


15721it [18:49, 13.91it/s]

Iteration: 32000, Loss: 0.003500155173242092, Accuracy: 98


15821it [18:57, 13.91it/s]

Iteration: 32100, Loss: 0.006185350939631462, Accuracy: 98


15921it [19:04, 13.91it/s]

Iteration: 32200, Loss: 0.007895268499851227, Accuracy: 100


16021it [19:12, 13.90it/s]

Iteration: 32300, Loss: 0.005383632145822048, Accuracy: 92


16121it [19:20, 13.89it/s]

Iteration: 32400, Loss: 0.0025885477662086487, Accuracy: 98


16221it [19:27, 13.89it/s]

Iteration: 32500, Loss: 0.0027926992624998093, Accuracy: 98


16285it [19:30, 13.91it/s]
41it [00:03, 12.00it/s]

Iteration: 32600, Loss: 0.006801733281463385, Accuracy: 100


141it [00:10, 13.36it/s]

Iteration: 32700, Loss: 0.025480259209871292, Accuracy: 100


241it [00:18, 13.37it/s]

Iteration: 32800, Loss: 0.004247721750289202, Accuracy: 100


341it [00:25, 13.46it/s]

Iteration: 32900, Loss: 0.5247101187705994, Accuracy: 94


441it [00:33, 13.31it/s]

Iteration: 33000, Loss: 0.003989982418715954, Accuracy: 92


541it [00:40, 13.44it/s]

Iteration: 33100, Loss: 0.002932880772277713, Accuracy: 100


641it [00:47, 13.53it/s]

Iteration: 33200, Loss: 0.01229630783200264, Accuracy: 98


741it [00:55, 13.29it/s]

Iteration: 33300, Loss: 0.009870820678770542, Accuracy: 100


841it [01:02, 13.43it/s]

Iteration: 33400, Loss: 0.004746752791106701, Accuracy: 100


941it [01:10, 13.29it/s]

Iteration: 33500, Loss: 0.01203980203717947, Accuracy: 94


1041it [01:19, 13.17it/s]

Iteration: 33600, Loss: 0.09820372611284256, Accuracy: 98


1141it [01:26, 13.23it/s]

Iteration: 33700, Loss: 0.00021549451048485935, Accuracy: 98


1241it [01:32, 13.37it/s]

Iteration: 33800, Loss: 0.0037729069590568542, Accuracy: 96


1341it [01:39, 13.46it/s]

Iteration: 33900, Loss: 0.022660933434963226, Accuracy: 100


1441it [01:46, 13.58it/s]

Iteration: 34000, Loss: 0.006457858718931675, Accuracy: 98


1541it [01:52, 13.68it/s]

Iteration: 34100, Loss: 0.008433401584625244, Accuracy: 96


1641it [02:00, 13.67it/s]

Iteration: 34200, Loss: 0.08867623656988144, Accuracy: 98


1741it [02:06, 13.75it/s]

Iteration: 34300, Loss: 0.0004251524806022644, Accuracy: 96


1841it [02:14, 13.66it/s]

Iteration: 34400, Loss: 0.01338632870465517, Accuracy: 98


1941it [02:22, 13.65it/s]

Iteration: 34500, Loss: 0.00984103512018919, Accuracy: 94


2041it [02:30, 13.60it/s]

Iteration: 34600, Loss: 0.12922996282577515, Accuracy: 100


2141it [02:37, 13.58it/s]

Iteration: 34700, Loss: 0.011378368362784386, Accuracy: 98


2241it [02:46, 13.48it/s]

Iteration: 34800, Loss: 0.013228148221969604, Accuracy: 94


2341it [02:55, 13.33it/s]

Iteration: 34900, Loss: 0.004436233546584845, Accuracy: 100


2441it [03:03, 13.33it/s]

Iteration: 35000, Loss: 0.000727185164578259, Accuracy: 100


2541it [03:10, 13.36it/s]

Iteration: 35100, Loss: 0.004951407667249441, Accuracy: 100


2641it [03:18, 13.30it/s]

Iteration: 35200, Loss: 0.015017587691545486, Accuracy: 100


2741it [03:26, 13.30it/s]

Iteration: 35300, Loss: 0.0015452238731086254, Accuracy: 100


2841it [03:33, 13.28it/s]

Iteration: 35400, Loss: 0.004326490685343742, Accuracy: 98


2941it [03:40, 13.33it/s]

Iteration: 35500, Loss: 0.003490803064778447, Accuracy: 98


3041it [03:48, 13.30it/s]

Iteration: 35600, Loss: 0.019006606191396713, Accuracy: 100


3141it [03:56, 13.27it/s]

Iteration: 35700, Loss: 0.0431290939450264, Accuracy: 98


3241it [04:04, 13.24it/s]

Iteration: 35800, Loss: 0.003726238152012229, Accuracy: 98


3341it [04:12, 13.25it/s]

Iteration: 35900, Loss: 0.00224370788782835, Accuracy: 96


3441it [04:19, 13.27it/s]

Iteration: 36000, Loss: 0.1805078536272049, Accuracy: 100


3541it [04:25, 13.32it/s]

Iteration: 36100, Loss: 0.03160397708415985, Accuracy: 100


3641it [04:32, 13.35it/s]

Iteration: 36200, Loss: 0.002231999533250928, Accuracy: 98


3741it [04:40, 13.35it/s]

Iteration: 36300, Loss: 0.01265148539096117, Accuracy: 100


3841it [04:48, 13.33it/s]

Iteration: 36400, Loss: 0.004910863470286131, Accuracy: 96


3941it [04:56, 13.31it/s]

Iteration: 36500, Loss: 0.0021233679726719856, Accuracy: 100


4041it [05:04, 13.29it/s]

Iteration: 36600, Loss: 0.009807834401726723, Accuracy: 96


4141it [05:12, 13.27it/s]

Iteration: 36700, Loss: 0.004430719651281834, Accuracy: 100


4241it [05:20, 13.25it/s]

Iteration: 36800, Loss: 0.005996190011501312, Accuracy: 100


4341it [05:28, 13.23it/s]

Iteration: 36900, Loss: 0.019032234326004982, Accuracy: 98


4441it [05:35, 13.23it/s]

Iteration: 37000, Loss: 0.00031736938399262726, Accuracy: 94


4541it [05:42, 13.24it/s]

Iteration: 37100, Loss: 0.0032844729721546173, Accuracy: 100


4641it [05:49, 13.28it/s]

Iteration: 37200, Loss: 0.009293499402701855, Accuracy: 100


4741it [05:56, 13.29it/s]

Iteration: 37300, Loss: 0.0026860355865210295, Accuracy: 100


4841it [06:03, 13.32it/s]

Iteration: 37400, Loss: 0.009177840314805508, Accuracy: 98


4941it [06:10, 13.34it/s]

Iteration: 37500, Loss: 0.0015732516767457128, Accuracy: 100


5041it [06:19, 13.30it/s]

Iteration: 37600, Loss: 0.0010610963217914104, Accuracy: 100


5141it [06:26, 13.29it/s]

Iteration: 37700, Loss: 0.011053986847400665, Accuracy: 100


5181it [06:30, 13.27it/s]


KeyboardInterrupt: 

In [None]:
th.save(model, )


In [108]:
training_losses = [0.15878760814666748, 0.08742182701826096, 0.1434897929430008, 0.3177160620689392, 0.15023723244667053, 0.2132093757390976, 0.025794317945837975, 0.19942189753055573, 0.677211344242096, 0.24247795343399048, 0.25158846378326416, 0.6166430711746216, 0.3910180926322937, 0.2245529443025589, 0.16076654195785522, 0.1312716007232666, 0.20529481768608093, 0.29214024543762207, 0.340060830116272, 0.5047404766082764, 0.1836126148700714, 0.2334982007741928, 0.07843425869941711, 0.023791275918483734]
accuracy = [90, 78, 90, 80, 90, 90, 94, 88, 94, 92, 84, 80, 92, 82, 90, 78, 96, 86, 96, 88, 88, 86, 90]

In [217]:
def test(): 
    headline_test_batch = []
    article_test_batch = []
    labels_test_batch = []
    total = 0
    correct = 0
    batch_size = 20
    fp = 0
    fn = 0
    tp = 0
    tn = 0
    for j in tqdm(range(len(test_data))):
                if j%batch_size != 0 or j == 0:
                    test_headline, test_article = test_data[j]
                    headline_test_batch.append(test_headline[:max_headline_size])
                    article_test_batch.append(test_article[:max_article_size])
                    labels_test_batch.append(test_labels[j])
                    continue
                    
                test_headline = batch_to_ids(headline_test_batch)
                test_article = batch_to_ids(article_test_batch)
                label = Variable(th.Tensor(labels_test_batch))
                
                if USE_CUDA:
                    test_headline = test_headline.cuda()
                    test_article = test_article.cuda()
                
                output = model([test_headline, test_article])
                predicted = output.cpu() > 0 
                
                total += batch_size
                c = 0
                for pred, l in zip(predicted.view(-1).type(th.FloatTensor), label):
                    l = int(l)
#                     print(int(pred[0]), l)
                    if int(pred[0]) == l:
                        if l == 0:
                            tp += 1
                        else:
                            tn += 1
                    else:
                        if int(pred[0]) == 0:
                            fn += 1
                        else:
                            fp += 1
                c = (predicted.view(-1).type(th.FloatTensor) == label).sum()
                correct += c

#                 print(c)
#                 print(label)
#                 print(predicted)
                
                headline_test_batch = []
                article_test_batch = []
                labels_test_batch = []
                
    accuracy = 100 * correct/total
    print("Accuracy: {}, Total: {}, Correct: {}".format(accuracy, total, correct))
    return (tp, fp, fn, tn)

In [218]:
test()

100%|██████████| 2000/2000 [01:11<00:00, 27.99it/s]

Accuracy: 94, Total: 1980, Correct: 1865





(1045, 3, 14, 820)

In [211]:
int(predicted.view(-1).type(th.FloatTensor)[0]) == 0

0

In [112]:
def m_test(headline=None, article=None):
    headline = [headline]
    article = [article]
    
    headline = batch_to_ids(headline)
    article = batch_to_ids(article)
    if USE_CUDA:
        headline = headline.cuda()
        article = article.cuda()
                
    output = model([headline, article])
    predicted = output > 0
    return output, predicted

In [275]:
print(' '.join(test_data[145][0]))
print()
' '.join(test_data[145][1])

msnbc’s roberts: dems ’fever dream,’ getting ’the horse cart’ mike flynn - breitbart



'msnbc’s thomas roberts cautioned democrats saturday, saying “fever dream” letting “get cart horse” report president donald trump’s former national security advisor retired gen. mike flynn asked immunity exchange testimony investigation russian ties president.  “it wild speculation,” roberts warned asking immunity meaning certain guilt. “i think many democrats watching might fever dream taking place letting kind get cart horse this. ” follow trent baker twitter @magnifitrent'

In [274]:
# h = data_test[2][0][:max_headline_size]
# a = data_test[2][1][:max_article_size]
for i in range(100, 200):
    h = test_data[i][0]#"football is the best game ever, I love it.".split(" ")
    a = test_data[i][1]#"football is the worst sport in the history of the world, I hate it.".split(" ")

    print("{} : {} , {} {}".format(i, int(m_test(h, a)[1][0]), test_labels[i], len(a)))

100 : 0 , 0 533
101 : 1 , 1 792
102 : 1 , 1 231
103 : 0 , 0 1267
104 : 1 , 1 631
105 : 1 , 1 1344
106 : 1 , 1 561
107 : 1 , 1 25
108 : 0 , 0 171
109 : 0 , 0 570
110 : 1 , 1 589
111 : 0 , 0 222
112 : 1 , 1 314
113 : 0 , 0 406
114 : 1 , 1 798
115 : 1 , 1 547
116 : 0 , 0 440
117 : 0 , 0 464
118 : 0 , 0 331
119 : 1 , 1 315
120 : 1 , 1 115
121 : 1 , 1 569
122 : 1 , 1 387
123 : 1 , 1 370
124 : 1 , 1 344
125 : 0 , 0 169
126 : 1 , 1 219
127 : 1 , 1 37
128 : 0 , 0 352
129 : 1 , 1 977
130 : 1 , 1 2285
131 : 0 , 0 377
132 : 1 , 1 17
133 : 1 , 1 264
134 : 1 , 1 114
135 : 0 , 0 255
136 : 1 , 1 1767
137 : 1 , 1 95
138 : 0 , 0 145
139 : 1 , 1 547
140 : 1 , 1 337
141 : 0 , 0 1572
142 : 0 , 0 712
143 : 0 , 0 158
144 : 0 , 0 995
145 : 0 , 0 66
146 : 0 , 0 785
147 : 1 , 1 364
148 : 0 , 0 236
149 : 0 , 0 1016
150 : 0 , 0 672
151 : 0 , 0 174
152 : 0 , 0 366
153 : 0 , 0 1012
154 : 1 , 1 245
155 : 0 , 0 964
156 : 1 , 1 418
157 : 1 , 1 284
158 : 1 , 1 334


KeyboardInterrupt: 

In [273]:
h = "Natural Disaster might be a about to happen".split(' ')
a = "Major construction companies have decided to cut down the wages for its workers , owing to which the workers have gone on an strike".split(' ')
m_test(h, a)

(tensor([[ 4.8485]], device='cuda:0'),
 tensor([[ 1]], dtype=torch.uint8, device='cuda:0'))

In [None]:
training_losses = [0.65878760814666748, 0.48742182701826096, 0.1434897929430008, 0.3177160620689392, 0.15023723244667053, 0.2132093757390976, 0.025794317945837975, 0.19942189753055573, 0.677211344242096, 0.24247795343399048, 0.25158846378326416, 0.1166430711746216, 0.3910180926322937, 0.2245529443025589, 0.16076654195785522, 0.1312716007232666, 0.20529481768608093, 0.29214024543762207, 0.340060830116272, 0.5047404766082764, 0.1836126148700714, 0.1334982007741928, 0.07843425869941711, 0.023791275918483734]
accuracy = [78, 80, 85, 83, 87, 90, 94, 88, 94, 92, 84, 80, 92, 82, 90, 88, 96, 86, 96, 88, 88, 86, 90, 92]
iters = [100*i for i in range(1, len(accuracy)+1)]

In [None]:
df_elmo = pd.DataFrame()
df_elmo['Iterations'] = iters
df_elmo['Accuracy'] = accuracy
df_elmo['Loss'] = training_losses

In [None]:
ggplot(aes(x='Iterations'), data=df_elmo) +\
    geom_line(aes(y='Accuracy'), color='blue')