In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from Vocab import *
from model import AttentionModel

import pandas as pd
import os
import string
import random

In [11]:
df = pd.read_csv('stsds.csv')

In [12]:
df = df.sample(frac=1)
df.head()

Unnamed: 0,sentence_A,sentence_B,relatedness_score
2898,The people are walking on the road beside a be...,A waterfall is flowing calmly into a shallow pool,0.54
4609,The Collections API is a set of classes and in...,The Collections API is a set of classes and in...,1.0
3075,A man is holding a mask in his raised hand,A man elegantly dressed in black is wearing an...,0.64
3077,Two children are crouching under some metal bars,Two children are leaning on a rusty ledge,0.5
1914,A man is cutting a potato,There is no man cutting a potato,0.72


In [13]:
#Hyperparams
lr = 1
gamma = 0.95
embed_size = 128
hidden_size = 256
num_epochs = 20

In [14]:
textcat = open('stsds-cat.txt').read()
textcat[:400]

'a group of kids is playing in a yard and an old man is standing in the background a group of boys in a yard is playing and a man is standing in the background a group of children is playing in the house and there is no man standing in the background a group of kids is playing in a yard and an old man is standing in the background the young boys are playing outdoors and the man is smiling nearby th'

In [15]:
vocab = Vocabulary(textcat)

In [16]:
vocab_size = vocab.size()
print(vocab_size)

2394


In [17]:
model = AttentionModel(embed_size, hidden_size, vocab_size)
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=lr)

In [24]:
for e in range(num_epochs):
    total_loss = 0
    for i in range(len(df)):
        optimizer.zero_grad()
        
        t_a = torch.tensor(vocab.getSentenceArray(df["sentence_A"][i]))
        t_b = torch.tensor(vocab.getSentenceArray(df["sentence_B"][i]))
        
        out = model(t_a, t_b)
        loss = criterion(out, torch.tensor(df["relatedness_score"][i]).unsqueeze(0).unsqueeze(0))
        loss.backward()
        total_loss += loss.item()
        
        optimizer.step()

    for param_group in optimizer.param_groups:
        param_group['lr'] *= gamma

    print("Epoch", e, "Loss", total_loss)

  corr_attn_params = F.softmax(self.lin_attn(corr).view(1, -1))
  test_attn_params = F.softmax(self.lin_attn(test).view(1, -1))


Epoch 0 Loss 13.115756695424169
Epoch 1 Loss 12.599229694126262
Epoch 2 Loss 12.144656993464086
Epoch 3 Loss 11.671067405861818
Epoch 4 Loss 11.175695365036088
Epoch 5 Loss 10.767326957229649
Epoch 6 Loss 10.41033419956874
Epoch 7 Loss 10.24638896744735
Epoch 8 Loss 9.851840242232063
Epoch 9 Loss 9.646162262089256
Epoch 10 Loss 9.328080380828697
Epoch 11 Loss 9.078449785744514
Epoch 12 Loss 8.821452577373062
Epoch 13 Loss 8.63097195883105
Epoch 14 Loss 8.446839044158683
Epoch 15 Loss 8.299452829458549
Epoch 16 Loss 8.16892458541787
Epoch 17 Loss 8.042206304182194
Epoch 18 Loss 7.934978671360413
Epoch 19 Loss 7.825563083080411


In [25]:
torch.save(model, "saved_models/stsds4.pt")

In [20]:
mini = 1
for i in range(len(df)):
    if df["relatedness_score"][i] < mini:
        mini = df["relatedness_score"][i]
print(mini)

0.2


In [21]:
vocab.vocab[:100]

['a',
 'group',
 'of',
 'kid',
 'be',
 'play',
 'in',
 'yard',
 'and',
 'an',
 'old',
 'man',
 'stand',
 'the',
 'background',
 'boy',
 'child',
 'house',
 'there',
 'no',
 'young',
 'outdoors',
 'smile',
 'nearby',
 'near',
 'with',
 'two',
 'dog',
 'fight',
 'wrestle',
 'hug',
 'brown',
 'attack',
 'another',
 'animal',
 'front',
 'pant',
 'nobody',
 'rid',
 'bicycle',
 'on',
 'one',
 'wheel',
 'person',
 'black',
 'jacket',
 'do',
 'trick',
 'motorbike',
 'jersey',
 'dunk',
 'ball',
 'at',
 'basketball',
 'game',
 'by',
 'who',
 'into',
 'net',
 'crowd',
 'player',
 'people',
 'kickboxing',
 'spectator',
 'not',
 'watch',
 'woman',
 'spar',
 'match',
 'three',
 'jump',
 'leave',
 'sit',
 'red',
 'shirt',
 'angel',
 'make',
 'snow',
 'lie',
 'draw',
 'snowsuit',
 'wear',
 'costume',
 'gather',
 'forest',
 'look',
 'same',
 'direction',
 'mask',
 'scatter',
 'different',
 'some',
 'vicinity',
 'little',
 'girl',
 'like',
 'lone',
 'biker',
 'air',
 'alone']