In [67]:
import pandas as pd
import nltk
import numpy as np
from torchtext.data import Field
import csv
from torchtext.data import TabularDataset
from torchtext.data import Iterator, BucketIterator

In [68]:
tokenize = lambda x: x.split()
TEXT = Field(sequential=True, tokenize=tokenize, lower=True, fix_length=200)
LABEL = Field(sequential=False, use_vocab=False)

In [69]:
train_stances = pd.read_csv("fn_data/train_stances.csv")
train_bodies = pd.read_csv("fn_data/train_bodies.csv")

In [70]:
train = pd.read_csv("lstmstuff/train_data_consolidated.csv")
val = pd.read_csv("lstmstuff/val_data_consolidated.csv")

In [71]:
train_datafields = [("Body ID", None),
                 ("Headline", TEXT),
                    ("Body", TEXT),
                 ("unrelated", LABEL),
                ("related", LABEL)]

trn, vld = TabularDataset.splits(
               path="",
               train='lstmstuff/train_data_consolidated.csv', validation="lstmstuff/val_data_consolidated.csv",
               format='csv',
               skip_header=True,
               fields=train_datafields)

In [72]:
TEXT.build_vocab(trn, vectors = 'glove.6B.100d')

In [73]:
train_iter, val_iter = BucketIterator.splits(
 (trn, vld),
 batch_sizes=(64,64),
 sort_key=lambda x: len(x.Body), # the BucketIterator needs to be told what function it should use to group the data.
 sort_within_batch=False,
 repeat=False
)

In [74]:
import torch
class BatchGenerator:
    def __init__(self, dl, x_field1, x_field2, y_field):
        self.dl, self.x_field1, self.x_field2, self.y_field = dl, x_field1, x_field2, y_field
        
    def __len__(self):
        return len(self.dl)
    
    def __iter__(self):
        for batch in self.dl:
            X1 = getattr(batch, self.x_field1)
            X2 = getattr(batch, self.x_field2)
            y = getattr(batch, self.y_field)
            yield (X1, X2, y)
            
train_batch_it = BatchGenerator(train_iter, 'Body','Headline', 'unrelated')
val_batch_it = BatchGenerator(val_iter, 'Body', 'Headline', 'unrelated')

In [75]:
import torch.nn as nn
import torch.nn.functional as F
# class RNN(nn.Module):
#     def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim):
#         super().__init__()
        
#         self.embedding = nn.Embedding(input_dim, embedding_dim)
#         self.lstm_headline = nn.LSTM(embedding_dim, hidden_dim)
#         self.lstm_article = nn.LSTM(embedding_dim, hidden_dim)
#         self.fc = nn.Linear(hidden_dim, output_dim)
        
#     def forward(self, x):

#         #x = [sent len, batch size]
        
#         embedded = self.embedding(x)
        
#         #embedded = [sent len, batch size, emb dim]
        
#         output, hidden = self.rnn(embedded)
        
#         #output = [sent len, batch size, hid dim]
#         #hidden = [1, batch size, hid dim]
        
#         assert torch.equal(output[-1,:,:], hidden.squeeze(0))
        
#         return self.fc(hidden.squeeze(0))

import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
class LSTMClassifier(nn.Module):

    def __init__(self, embedding_dim, hidden_dim, vocab_size, output_dim, batch_size):
        super(LSTMClassifier, self).__init__()
        self.hidden_dim = hidden_dim
        self.embed = nn.Embedding(vocab_size, embedding_dim)
        self.lstm_headline = nn.LSTM(embedding_dim, hidden_dim, bidirectional=True)
        self.lstm_article = nn.LSTM(hidden_dim, hidden_dim, bidirectional=True)
        self.fc = nn.Linear(hidden_dim*2, output_dim*8)
        self.fc2 = nn.Linear(output_dim*8, output_dim)
        self.batch_size = batch_size
        self.hidden1 = self.init_hidden()
        self.hidden2 = self.init_hidden()
        
    def init_hidden(self):
        h0 = Variable(torch.zeros(2, self.batch_size, self.hidden_dim))
        c0 = Variable(torch.zeros(2, self.batch_size, self.hidden_dim))
        return (h0, c0)
    
    def forward(self, headline, article):
        headline_emb = self.embed(headline)
        
        article_emb = self.embed(article)
        
        lstm_headline_out, hidden = self.lstm_headline(headline_emb, self.hidden)
        
        sum_padded_out_lstm1 = 0
        for tensor in torch.split(lstm_headline_out, self.hidden_dim, dim=2):
            sum_padded_out_lstm1 += tensor

        lstm_article_out, hidden2 = self.lstm_article(sum_padded_out_lstm1, self.hidden2)
        res = self.fc(torch.cat((lstm_article_out[-1, :, :self.hidden_dim], lstm_article_out[0, :, self.hidden_dim:]), 1))
        res2 = self.fc2(F.relu(res))
        return res2
        

In [76]:
# model = RNN(len(TEXT.vocab), 100, 256, 1)
model = LSTMClassifier(100, 8, len(TEXT.vocab), 1, 64)

In [77]:
inputs = torch.randn(3, requires_grad=True)
targets = torch.empty(3).random_(2)
print(inputs)
print(targets)

tensor([ 0.3296, -0.0837, -1.2117], requires_grad=True)
tensor([1., 0., 0.])


In [78]:
import torch.optim as optim

optimizer = optim.Adam(model.parameters())
criterion = nn.BCEWithLogitsLoss()

In [79]:
counter = []
loss_history = [] 
iteration_number= 0

model.train()
# Train the model
for epoch in range(0,2):
    for i, data in enumerate(train_batch_it,0):
        img0, img1 , label = data
        print("inputs are")
        print(img0.shape)
        print(img1.shape)
        optimizer.zero_grad()
        model.hidden = model.init_hidden()
        model.hidden2 = model.init_hidden()
        output = model(img0,img1)
        print("OUTPUT")
        print(output)
        print(label)
        loss = criterion(output.squeeze(),label.squeeze().float())
        loss.backward()
        optimizer.step()
        if True:
            print("Epoch number {}\n Current loss {}\n".format(epoch,loss.item()))
            iteration_number +=10
            counter.append(iteration_number)
            loss_history.append(loss.item())


inputs are
torch.Size([200, 64])
torch.Size([200, 64])
OUTPUT
tensor([[0.3814],
        [0.3800],
        [0.3698],
        [0.3796],
        [0.3918],
        [0.3804],
        [0.3710],
        [0.3951],
        [0.4026],
        [0.3450],
        [0.3763],
        [0.3923],
        [0.3922],
        [0.3920],
        [0.4257],
        [0.3836],
        [0.3957],
        [0.3657],
        [0.3450],
        [0.3889],
        [0.3698],
        [0.3651],
        [0.3817],
        [0.4167],
        [0.3957],
        [0.3577],
        [0.3750],
        [0.3612],
        [0.4201],
        [0.3667],
        [0.3472],
        [0.3711],
        [0.4016],
        [0.3683],
        [0.3471],
        [0.3969],
        [0.4102],
        [0.3748],
        [0.3601],
        [0.3906],
        [0.3753],
        [0.3796],
        [0.3829],
        [0.4004],
        [0.3865],
        [0.3845],
        [0.3654],
        [0.3651],
        [0.3742],
        [0.4009],
        [0.3657],
        [0.3819],
  

Epoch number 0
 Current loss 0.6360768675804138

inputs are
torch.Size([200, 64])
torch.Size([200, 64])
OUTPUT
tensor([[0.4290],
        [0.4600],
        [0.4259],
        [0.4178],
        [0.4262],
        [0.4622],
        [0.4087],
        [0.4302],
        [0.4121],
        [0.4548],
        [0.4318],
        [0.4298],
        [0.4123],
        [0.4096],
        [0.4153],
        [0.4386],
        [0.4280],
        [0.4654],
        [0.4487],
        [0.4293],
        [0.4016],
        [0.4368],
        [0.4534],
        [0.4290],
        [0.4065],
        [0.4334],
        [0.4244],
        [0.4344],
        [0.4405],
        [0.4254],
        [0.4153],
        [0.4517],
        [0.4262],
        [0.4188],
        [0.4147],
        [0.4480],
        [0.4262],
        [0.4218],
        [0.4709],
        [0.4142],
        [0.4483],
        [0.4159],
        [0.4257],
        [0.3814],
        [0.4125],
        [0.3957],
        [0.4187],
        [0.4242],
        [0.4508],
       

Epoch number 0
 Current loss 0.608489453792572

inputs are
torch.Size([200, 64])
torch.Size([200, 64])
OUTPUT
tensor([[0.4616],
        [0.4889],
        [0.4963],
        [0.4943],
        [0.5140],
        [0.4836],
        [0.4940],
        [0.4959],
        [0.4743],
        [0.4601],
        [0.4815],
        [0.4666],
        [0.4669],
        [0.4156],
        [0.4658],
        [0.4863],
        [0.4578],
        [0.4720],
        [0.4571],
        [0.4883],
        [0.4944],
        [0.4597],
        [0.4921],
        [0.4726],
        [0.4968],
        [0.5092],
        [0.4566],
        [0.4628],
        [0.4447],
        [0.4940],
        [0.4899],
        [0.4886],
        [0.4932],
        [0.4754],
        [0.5187],
        [0.4677],
        [0.4700],
        [0.5045],
        [0.4614],
        [0.5165],
        [0.5035],
        [0.4446],
        [0.4986],
        [0.5004],
        [0.4797],
        [0.4648],
        [0.4183],
        [0.4858],
        [0.4711],
        

Epoch number 0
 Current loss 0.5757606625556946

inputs are
torch.Size([200, 64])
torch.Size([200, 64])
OUTPUT
tensor([[0.4904],
        [0.5300],
        [0.5064],
        [0.5428],
        [0.5794],
        [0.5211],
        [0.5535],
        [0.4943],
        [0.5405],
        [0.5018],
        [0.5464],
        [0.5311],
        [0.5573],
        [0.4633],
        [0.5303],
        [0.5210],
        [0.5465],
        [0.5101],
        [0.5094],
        [0.5316],
        [0.5486],
        [0.4985],
        [0.5156],
        [0.4884],
        [0.4884],
        [0.5288],
        [0.5013],
        [0.5660],
        [0.4884],
        [0.5486],
        [0.4997],
        [0.5181],
        [0.5530],
        [0.5164],
        [0.5288],
        [0.5448],
        [0.4811],
        [0.4915],
        [0.5248],
        [0.4937],
        [0.5012],
        [0.5185],
        [0.5740],
        [0.5446],
        [0.5157],
        [0.5443],
        [0.4825],
        [0.5333],
        [0.5415],
       

Epoch number 0
 Current loss 0.616249680519104

inputs are
torch.Size([200, 64])
torch.Size([200, 64])
OUTPUT
tensor([[0.5605],
        [0.5888],
        [0.5411],
        [0.5549],
        [0.5677],
        [0.5723],
        [0.5781],
        [0.5930],
        [0.5501],
        [0.5985],
        [0.5469],
        [0.5922],
        [0.5885],
        [0.6116],
        [0.5283],
        [0.5783],
        [0.6145],
        [0.6282],
        [0.5306],
        [0.6421],
        [0.6393],
        [0.6310],
        [0.5958],
        [0.5724],
        [0.6146],
        [0.6313],
        [0.5620],
        [0.5306],
        [0.5742],
        [0.5515],
        [0.5806],
        [0.5416],
        [0.5587],
        [0.5901],
        [0.6465],
        [0.5555],
        [0.5659],
        [0.5952],
        [0.5802],
        [0.5624],
        [0.6150],
        [0.5430],
        [0.5552],
        [0.5703],
        [0.5523],
        [0.6024],
        [0.5465],
        [0.6168],
        [0.5858],
        

Epoch number 0
 Current loss 0.5709412097930908

inputs are
torch.Size([200, 64])
torch.Size([200, 64])
OUTPUT
tensor([[0.6091],
        [0.7089],
        [0.6803],
        [0.6430],
        [0.5946],
        [0.6556],
        [0.6826],
        [0.6869],
        [0.5934],
        [0.6170],
        [0.6778],
        [0.5977],
        [0.6518],
        [0.6633],
        [0.6632],
        [0.6518],
        [0.6488],
        [0.6965],
        [0.6635],
        [0.6010],
        [0.6541],
        [0.7177],
        [0.5985],
        [0.6561],
        [0.6423],
        [0.6778],
        [0.6142],
        [0.6138],
        [0.5878],
        [0.5742],
        [0.7175],
        [0.5963],
        [0.6290],
        [0.6000],
        [0.6295],
        [0.7040],
        [0.6512],
        [0.5650],
        [0.6484],
        [0.6350],
        [0.6970],
        [0.6389],
        [0.6148],
        [0.5895],
        [0.5827],
        [0.6055],
        [0.6412],
        [0.6331],
        [0.6277],
       

Epoch number 0
 Current loss 0.590815007686615

inputs are
torch.Size([200, 64])
torch.Size([200, 64])
OUTPUT
tensor([[0.7107],
        [0.6749],
        [0.7720],
        [0.6896],
        [0.6753],
        [0.6597],
        [0.6419],
        [0.6902],
        [0.5983],
        [0.6462],
        [0.7329],
        [0.7822],
        [0.6750],
        [0.6211],
        [0.6554],
        [0.6711],
        [0.6381],
        [0.7469],
        [0.6869],
        [0.7737],
        [0.7461],
        [0.6451],
        [0.7012],
        [0.6902],
        [0.6636],
        [0.6983],
        [0.6933],
        [0.6851],
        [0.6711],
        [0.6974],
        [0.6270],
        [0.6820],
        [0.5872],
        [0.7124],
        [0.6063],
        [0.6798],
        [0.6808],
        [0.6320],
        [0.6400],
        [0.6807],
        [0.7899],
        [0.7173],
        [0.7375],
        [0.6105],
        [0.7663],
        [0.7007],
        [0.6381],
        [0.7588],
        [0.5837],
        

Epoch number 0
 Current loss 0.5709192156791687

inputs are
torch.Size([200, 64])
torch.Size([200, 64])
OUTPUT
tensor([[0.8220],
        [0.7435],
        [0.7896],
        [0.8616],
        [0.7819],
        [0.6435],
        [0.8389],
        [0.7061],
        [0.8190],
        [0.8493],
        [0.7613],
        [0.7833],
        [0.7787],
        [0.7121],
        [0.7384],
        [0.7095],
        [0.7900],
        [0.7177],
        [0.6503],
        [0.6468],
        [0.6459],
        [0.8171],
        [0.7159],
        [0.7510],
        [0.6205],
        [0.8029],
        [0.7935],
        [0.7482],
        [0.7852],
        [0.8447],
        [0.7809],
        [0.7306],
        [0.6360],
        [0.6868],
        [0.6732],
        [0.6256],
        [0.6696],
        [0.7922],
        [0.8020],
        [0.7997],
        [0.7000],
        [0.6781],
        [0.7331],
        [0.7315],
        [0.6961],
        [0.6999],
        [0.8693],
        [0.6869],
        [0.6962],
       

Epoch number 0
 Current loss 0.5559014081954956

inputs are
torch.Size([200, 64])
torch.Size([200, 64])
OUTPUT
tensor([[0.7425],
        [0.7417],
        [0.7388],
        [0.8396],
        [0.7458],
        [0.8150],
        [0.7576],
        [0.7308],
        [0.6919],
        [0.8089],
        [0.7462],
        [0.7816],
        [0.7664],
        [0.6399],
        [0.8019],
        [0.8509],
        [0.7736],
        [0.7409],
        [0.8320],
        [0.7259],
        [0.7781],
        [0.7156],
        [0.8002],
        [0.8082],
        [0.9402],
        [0.8474],
        [0.8662],
        [0.8509],
        [0.7937],
        [0.8956],
        [0.7618],
        [0.6860],
        [0.7749],
        [0.8242],
        [0.6949],
        [0.9409],
        [0.7763],
        [0.8391],
        [0.8644],
        [0.8707],
        [0.7255],
        [0.8415],
        [0.7983],
        [0.7897],
        [0.6983],
        [0.7284],
        [0.7397],
        [0.8237],
        [0.8666],
       

Epoch number 0
 Current loss 0.6116665601730347

inputs are
torch.Size([200, 64])
torch.Size([200, 64])
OUTPUT
tensor([[0.8145],
        [0.8526],
        [0.9647],
        [0.8175],
        [0.9893],
        [0.7812],
        [0.8602],
        [0.9494],
        [0.7537],
        [0.8649],
        [0.9586],
        [0.7352],
        [0.9808],
        [0.9128],
        [0.8074],
        [0.8929],
        [0.8375],
        [0.8751],
        [0.9521],
        [0.8602],
        [0.8572],
        [0.8434],
        [0.8711],
        [0.8378],
        [0.8479],
        [0.9356],
        [0.8330],
        [0.8439],
        [0.9782],
        [0.8361],
        [1.0040],
        [0.7828],
        [0.8772],
        [0.8803],
        [0.8422],
        [0.8772],
        [0.9056],
        [0.9298],
        [0.9243],
        [1.0301],
        [0.9257],
        [0.8500],
        [0.9372],
        [0.8616],
        [0.8419],
        [0.9022],
        [0.9946],
        [0.9587],
        [0.7874],
       

Epoch number 0
 Current loss 0.559809148311615

inputs are
torch.Size([200, 64])
torch.Size([200, 64])
OUTPUT
tensor([[1.0047],
        [0.9392],
        [0.9183],
        [0.9019],
        [1.0342],
        [1.0073],
        [1.1358],
        [0.8880],
        [0.9375],
        [0.9292],
        [0.8837],
        [1.0641],
        [0.9458],
        [0.8437],
        [0.9372],
        [1.0547],
        [0.8737],
        [1.0168],
        [0.8986],
        [1.0142],
        [0.8633],
        [0.8253],
        [0.8104],
        [0.7746],
        [1.0698],
        [0.8981],
        [0.9717],
        [0.8777],
        [0.7733],
        [0.8266],
        [0.8376],
        [0.7837],
        [0.9285],
        [0.8626],
        [1.0292],
        [1.0655],
        [0.9185],
        [0.9458],
        [0.8888],
        [0.8333],
        [1.0036],
        [1.0468],
        [0.9137],
        [0.8550],
        [0.8738],
        [0.8572],
        [1.0002],
        [0.9347],
        [0.8668],
        

Epoch number 0
 Current loss 0.472965806722641

inputs are
torch.Size([200, 64])
torch.Size([200, 64])
OUTPUT
tensor([[1.0232],
        [1.0852],
        [0.9534],
        [1.0278],
        [0.9648],
        [1.0816],
        [0.9688],
        [1.1409],
        [0.9806],
        [0.8685],
        [1.1084],
        [0.8098],
        [1.1066],
        [1.1084],
        [0.9032],
        [1.1217],
        [1.0245],
        [0.9324],
        [1.0703],
        [1.0186],
        [0.8894],
        [0.7529],
        [0.9476],
        [0.8971],
        [0.9707],
        [1.1070],
        [0.9110],
        [0.9478],
        [1.0366],
        [1.1218],
        [1.1066],
        [0.9186],
        [1.0690],
        [0.9337],
        [1.1066],
        [1.1066],
        [1.0147],
        [0.9692],
        [0.9858],
        [0.9631],
        [1.0989],
        [1.0443],
        [1.0102],
        [0.9551],
        [0.9439],
        [0.9988],
        [0.9757],
        [0.8651],
        [0.9891],
        

Epoch number 0
 Current loss 0.5250301957130432

inputs are
torch.Size([200, 64])
torch.Size([200, 64])
OUTPUT
tensor([[0.9939],
        [0.9919],
        [0.9257],
        [0.9494],
        [0.8767],
        [1.1846],
        [1.1716],
        [0.9366],
        [0.9909],
        [0.8322],
        [1.1179],
        [0.9715],
        [1.0850],
        [1.1846],
        [1.1031],
        [0.9553],
        [0.9452],
        [1.0051],
        [0.9886],
        [1.0942],
        [0.7122],
        [1.0027],
        [1.0267],
        [1.1447],
        [0.9661],
        [0.9713],
        [1.1016],
        [0.9285],
        [1.0282],
        [1.0989],
        [1.1574],
        [0.8104],
        [0.9417],
        [0.9939],
        [0.9407],
        [0.9196],
        [1.0860],
        [1.0066],
        [1.0258],
        [0.9960],
        [1.1478],
        [0.9708],
        [0.9127],
        [1.0243],
        [0.9665],
        [0.9917],
        [1.0606],
        [1.0192],
        [1.1317],
       

Epoch number 0
 Current loss 0.612392246723175

inputs are
torch.Size([200, 64])
torch.Size([200, 64])
OUTPUT
tensor([[0.9890],
        [0.9374],
        [0.9984],
        [0.9845],
        [1.1155],
        [1.0201],
        [1.0045],
        [0.8817],
        [1.1590],
        [0.9597],
        [1.1578],
        [1.2164],
        [1.0892],
        [1.0078],
        [1.0108],
        [1.0723],
        [1.2197],
        [1.1660],
        [1.2099],
        [1.0339],
        [1.0072],
        [1.1446],
        [1.0256],
        [1.1415],
        [1.1039],
        [1.2070],
        [1.0786],
        [0.9529],
        [1.0935],
        [1.0068],
        [1.1270],
        [1.0771],
        [1.0490],
        [1.2649],
        [1.0723],
        [1.0379],
        [1.2158],
        [1.0905],
        [1.2064],
        [1.0250],
        [1.0500],
        [1.2510],
        [1.2151],
        [1.2064],
        [1.2038],
        [1.2314],
        [1.2614],
        [1.1469],
        [1.0494],
        

Epoch number 0
 Current loss 0.6609072089195251

inputs are
torch.Size([200, 64])
torch.Size([200, 64])
OUTPUT
tensor([[1.0516],
        [1.0571],
        [1.0070],
        [1.0970],
        [1.0527],
        [1.0979],
        [1.2747],
        [1.1632],
        [1.0240],
        [1.0783],
        [1.0603],
        [1.2201],
        [1.1614],
        [1.1373],
        [1.2103],
        [0.9751],
        [1.0057],
        [1.0726],
        [1.0343],
        [1.0222],
        [1.0691],
        [1.2336],
        [1.0971],
        [1.1592],
        [1.0783],
        [0.9589],
        [1.1775],
        [1.2177],
        [1.1509],
        [1.2838],
        [1.1708],
        [1.0304],
        [1.1934],
        [1.2025],
        [1.3052],
        [1.3674],
        [1.0981],
        [1.3130],
        [1.1002],
        [1.2527],
        [0.9827],
        [1.1706],
        [1.1173],
        [1.1257],
        [1.1111],
        [1.2646],
        [1.0618],
        [1.1620],
        [1.0204],
       

Epoch number 0
 Current loss 0.5414247512817383

inputs are
torch.Size([200, 64])
torch.Size([200, 64])
OUTPUT
tensor([[1.1954],
        [1.2613],
        [1.2810],
        [1.1498],
        [1.1193],
        [1.0785],
        [1.0597],
        [1.2544],
        [1.0659],
        [1.2282],
        [1.0961],
        [1.1363],
        [1.2743],
        [1.1516],
        [1.0969],
        [1.1923],
        [1.1954],
        [1.0721],
        [1.0685],
        [1.0347],
        [1.0862],
        [1.2476],
        [1.0525],
        [1.0407],
        [1.0654],
        [1.1209],
        [1.0018],
        [1.0479],
        [1.3116],
        [1.0785],
        [0.7887],
        [1.2997],
        [0.9551],
        [1.3068],
        [1.0705],
        [1.0967],
        [1.2878],
        [0.8551],
        [1.0275],
        [1.0659],
        [1.3431],
        [1.1952],
        [1.1007],
        [1.2581],
        [0.9812],
        [1.0382],
        [1.1106],
        [1.0523],
        [1.3106],
       

Epoch number 0
 Current loss 0.5013499855995178

inputs are
torch.Size([200, 64])
torch.Size([200, 64])
OUTPUT
tensor([[1.2807],
        [1.0970],
        [1.0867],
        [1.0681],
        [1.0524],
        [0.9788],
        [1.1839],
        [1.0730],
        [1.2238],
        [1.3046],
        [1.2385],
        [1.0511],
        [1.2125],
        [1.2241],
        [1.1860],
        [0.8961],
        [1.2141],
        [0.9333],
        [1.0326],
        [1.1397],
        [1.0318],
        [1.0572],
        [0.9842],
        [1.1116],
        [1.1517],
        [1.1996],
        [1.1025],
        [1.2477],
        [1.1589],
        [1.0457],
        [1.0733],
        [1.2138],
        [1.0451],
        [1.0954],
        [1.1070],
        [0.9938],
        [1.1281],
        [0.9189],
        [1.2273],
        [0.7471],
        [0.9220],
        [1.2518],
        [1.1831],
        [1.1123],
        [1.2133],
        [1.0142],
        [1.0761],
        [1.0789],
        [1.1435],
       

Epoch number 0
 Current loss 0.7115350961685181

inputs are
torch.Size([200, 64])
torch.Size([200, 64])
OUTPUT
tensor([[1.2707],
        [0.7594],
        [1.1067],
        [1.0072],
        [1.1052],
        [0.9946],
        [1.0741],
        [0.9000],
        [1.1296],
        [1.0782],
        [1.2129],
        [1.1582],
        [1.1206],
        [1.2527],
        [1.2435],
        [1.0923],
        [0.8849],
        [0.9455],
        [1.0140],
        [0.9734],
        [1.1261],
        [1.0477],
        [1.1594],
        [1.1827],
        [0.9491],
        [1.1998],
        [1.3160],
        [1.0805],
        [1.2430],
        [0.9793],
        [1.0140],
        [0.9918],
        [1.1421],
        [1.2000],
        [1.0570],
        [1.0083],
        [0.9910],
        [1.0756],
        [1.1600],
        [1.0102],
        [1.0331],
        [0.9846],
        [0.8481],
        [1.1963],
        [1.1080],
        [1.2298],
        [1.1512],
        [1.0911],
        [1.0102],
       

Epoch number 0
 Current loss 0.6159089803695679

inputs are
torch.Size([200, 64])
torch.Size([200, 64])
OUTPUT
tensor([[1.1508],
        [1.0189],
        [1.0561],
        [1.0723],
        [1.1089],
        [1.0887],
        [1.0465],
        [1.1973],
        [0.9754],
        [1.2071],
        [1.1005],
        [0.9919],
        [1.1293],
        [1.1770],
        [1.1832],
        [1.0214],
        [0.9795],
        [1.0782],
        [0.8238],
        [1.1737],
        [0.8371],
        [1.0974],
        [0.9554],
        [1.0340],
        [0.9812],
        [1.1579],
        [1.0562],
        [1.0252],
        [0.9533],
        [1.0853],
        [0.9222],
        [0.9873],
        [1.0706],
        [1.1889],
        [1.0358],
        [0.9130],
        [0.9030],
        [1.2431],
        [1.0508],
        [0.9432],
        [1.0180],
        [1.1838],
        [1.2874],
        [0.9743],
        [1.0750],
        [1.0506],
        [1.0756],
        [1.1330],
        [1.1277],
       

Epoch number 0
 Current loss 0.6256895065307617

inputs are
torch.Size([200, 64])
torch.Size([200, 64])
OUTPUT
tensor([[0.9387],
        [1.1265],
        [0.6920],
        [0.9701],
        [0.8686],
        [1.0343],
        [1.0133],
        [1.0425],
        [1.2357],
        [0.9931],
        [1.0356],
        [1.0005],
        [0.8880],
        [1.0176],
        [0.9846],
        [1.0887],
        [1.1442],
        [1.1622],
        [0.9642],
        [0.8368],
        [1.1246],
        [0.9142],
        [1.0322],
        [0.8494],
        [1.0675],
        [0.7929],
        [0.9850],
        [1.2123],
        [1.2252],
        [0.8474],
        [1.1432],
        [1.0547],
        [1.0173],
        [0.9219],
        [1.0825],
        [1.0251],
        [1.0373],
        [0.9999],
        [0.8977],
        [1.2075],
        [1.2163],
        [1.0891],
        [0.8424],
        [0.9771],
        [0.8686],
        [1.1173],
        [1.0217],
        [1.0597],
        [1.0992],
       

Epoch number 0
 Current loss 0.5452011823654175

inputs are
torch.Size([200, 64])
torch.Size([200, 64])
OUTPUT
tensor([[1.0373],
        [0.8032],
        [0.9157],
        [1.1456],
        [1.0416],
        [0.9839],
        [0.8752],
        [1.1456],
        [0.9890],
        [0.9422],
        [1.1456],
        [0.8703],
        [0.8238],
        [0.9605],
        [0.9333],
        [0.9226],
        [0.9075],
        [0.8390],
        [0.8183],
        [0.9071],
        [0.9941],
        [0.9347],
        [1.1976],
        [1.0055],
        [0.8547],
        [0.8314],
        [1.0075],
        [1.1493],
        [1.0109],
        [1.1444],
        [0.7941],
        [1.0326],
        [0.9162],
        [1.0087],
        [0.8018],
        [1.0865],
        [0.8203],
        [0.8703],
        [0.8589],
        [0.9435],
        [0.7736],
        [1.0326],
        [0.8702],
        [1.0226],
        [0.9629],
        [0.8632],
        [0.6763],
        [0.8342],
        [1.0150],
       

Epoch number 0
 Current loss 0.47686469554901123

inputs are
torch.Size([200, 64])
torch.Size([200, 64])
OUTPUT
tensor([[1.0246],
        [1.0485],
        [0.9686],
        [1.2129],
        [0.8047],
        [1.1212],
        [0.8999],
        [1.0114],
        [0.9356],
        [1.0837],
        [0.8231],
        [1.1191],
        [1.0228],
        [0.9350],
        [1.0488],
        [1.1536],
        [1.0671],
        [1.1544],
        [1.0395],
        [0.9105],
        [1.0063],
        [0.9383],
        [0.9675],
        [0.8752],
        [1.1339],
        [1.0103],
        [0.9841],
        [1.0248],
        [1.1274],
        [1.1286],
        [1.0940],
        [1.0217],
        [1.0063],
        [0.9631],
        [0.8895],
        [1.2184],
        [1.0586],
        [1.0241],
        [0.8329],
        [0.9198],
        [1.1243],
        [1.1037],
        [1.1201],
        [0.9616],
        [0.8184],
        [1.1313],
        [1.1178],
        [0.9582],
        [0.8179],
      

Epoch number 0
 Current loss 0.6027752161026001

inputs are
torch.Size([200, 64])
torch.Size([200, 64])
OUTPUT
tensor([[1.0130],
        [1.1396],
        [1.2541],
        [1.0997],
        [1.0449],
        [0.9866],
        [1.0865],
        [0.9875],
        [1.2381],
        [1.3201],
        [1.1638],
        [1.2690],
        [0.9089],
        [1.0837],
        [1.1438],
        [1.0327],
        [1.0174],
        [0.9503],
        [0.7661],
        [1.0507],
        [1.0596],
        [1.1978],
        [1.0612],
        [0.9126],
        [1.1602],
        [1.0380],
        [1.0397],
        [1.1458],
        [0.8561],
        [1.2139],
        [0.9896],
        [1.1205],
        [1.1632],
        [1.0984],
        [1.1438],
        [1.0927],
        [0.8839],
        [1.1602],
        [1.1395],
        [1.1036],
        [0.8854],
        [1.1074],
        [1.0360],
        [0.9492],
        [0.8649],
        [0.6101],
        [0.9924],
        [0.8447],
        [1.1481],
       

KeyboardInterrupt: 

In [None]:
def binary_accuracy(preds, y):
    """
    Returns accuracy per batch, i.e. if you get 8/10 right, this returns 0.8, NOT 8
    """

    #round predictions to the closest integer
    rounded_preds = torch.round(torch.sigmoid(preds))
    correct = (rounded_preds == y).float() #convert into float for division 
    acc = correct.sum()/len(correct)
    return acc

In [None]:
def evaluate(model, iterator, criterion):
    
    epoch_loss = 0
    epoch_acc = 0
    
    model.eval()
    batch_num = 0
    with torch.no_grad():
        for i, data in enumerate(iterator,0):
            img0, img1 , label = data
            output = model(img0,img1)
            print(output)
            accuracy = binary_accuracy(output, label)
            print(accuracy)
            print("*********")

In [92]:
evaluate(model, val_batch_it, criterion)


tensor([[[0.9185],
         [0.9187],
         [0.9170],
         [0.9189],
         [0.9106],
         [0.9188],
         [0.9092],
         [0.8995],
         [0.8916],
         [0.9031],
         [0.9194],
         [0.9069],
         [0.9171],
         [0.9066],
         [0.9158],
         [0.9192],
         [0.9131],
         [0.9097],
         [0.9165],
         [0.8739],
         [0.9172],
         [0.9167],
         [0.9177],
         [0.9173],
         [0.9176],
         [0.9158],
         [0.9020],
         [0.9132],
         [0.7335],
         [0.9157],
         [0.9154],
         [0.8896],
         [0.9159],
         [0.9177],
         [0.9068],
         [0.9103],
         [0.9181],
         [0.8871],
         [0.9093],
         [0.9139],
         [0.9161],
         [0.9128],
         [0.9160],
         [0.8978],
         [0.8889],
         [0.9132],
         [0.9145],
         [0.8390],
         [0.9104],
         [0.9141],
         [0.9147],
         [0.9195],
         [0.

RuntimeError: Expected object of type torch.FloatTensor but found type torch.LongTensor for argument #2 'other'