In [22]:
import os
import pickle
from glob import glob
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm.notebook import tqdm
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
import numpy as np

from support import *

In [23]:
def load_results(pattern='./result-50p/*.pkl'):
    """Load every pickle matching `pattern` and flatten them into one list.

    Each file must hold an iterable of result records; all records are
    concatenated in file order. Files are visited in sorted order because
    `glob` alone yields them in arbitrary filesystem order, which would make
    the sample order (and everything derived from it) non-reproducible.

    NOTE: pickle.load can execute arbitrary code -- only point this at files
    you produced yourself. Unpickling the records presumably needs the record
    class from `support` to be importable.
    """
    records = []
    for path in sorted(glob(pattern)):
        # `with` closes each handle instead of leaking it until GC.
        with open(path, 'rb') as fh:
            records.extend(pickle.load(fh))
    return records


results = load_results()

In [24]:
# Sanity check on the raw targets: (largest, smallest) `result` over all records.
# The training cell divides targets by 45000, so they only land in [0, 1] while this max stays <= 45000.
(
    np.max([record.result for record in results]),
    np.min([record.result for record in results]),
)

(44963, 20)

In [25]:
# Vocabularies of every broad / narrow word that appears in any result record.
broad_words = {word for r in results for word in r.broad_words}
narrow_words = {word for r in results for word in r.narrow_words}

# word -> column index of the feature vectors. The sets are sorted before
# enumerating: set iteration order for str depends on the per-process hash seed
# (PYTHONHASHSEED), so enumerating the raw set assigned different columns on
# every kernel restart. Assumes the words are mutually comparable (e.g. all str).
broad_words_toi = {word: i for i, word in enumerate(sorted(broad_words))}
narrow_words_toi = {word: i for i, word in enumerate(sorted(narrow_words))}

# Inverse maps: column index -> word.
broad_words_ito = {i: word for word, i in broad_words_toi.items()}
narrow_words_ito = {i: word for word, i in narrow_words_toi.items()}

# Persist the mappings; `with` guarantees the files are flushed and closed.
with open('broad_words_toi.pkl', 'wb') as fh:
    pickle.dump(broad_words_toi, fh)
with open('narrow_words_toi.pkl', 'wb') as fh:
    pickle.dump(narrow_words_toi, fh)

In [26]:
def _signed_indicator(words, flags, word_to_index, size):
    """Encode parallel word/flag lists as a float32 vector of length `size`.

    Position word_to_index[word] is set to +1 when the paired flag is truthy and
    -1 when it is falsy; positions of words not listed stay 0. Pairs are written
    in order (a repeated word keeps its last flag) and zip stops at the shorter list.
    """
    vector = torch.zeros(size, dtype=torch.float)
    for word, flag in zip(words, flags):
        vector[word_to_index[word]] = 1 if flag else -1
    return vector


# One (broad_vector, narrow_vector, target) triple of float tensors per record.
samples = []
for record in results:
    broad_vec = _signed_indicator(record.broad_words, record.broad_sel, broad_words_toi, len(broad_words))
    narrow_vec = _signed_indicator(record.narrow_words, record.narrow_sel, narrow_words_toi, len(narrow_words))
    target_vec = torch.tensor([record.result], dtype=torch.float)
    samples.append((broad_vec, narrow_vec, target_vec))
torch.save(samples, 'samples.pt')

In [27]:
# Reload the cached samples (written by the previous cell) and batch them for training.
# NOTE(review): with no `generator=` argument the shuffle order depends on the
# global torch RNG state, so batch order is only reproducible if that is seeded.
samples = torch.load('samples.pt')
dataloader = DataLoader(samples, batch_size=128, shuffle=True)
# (number of samples, broad vocabulary size, narrow vocabulary size)
len(samples), len(broad_words), len(narrow_words)

(14472, 127, 608)

In [28]:
class Model(nn.Module):
    """Two-branch MLP regressor over the broad and narrow word vectors.

    `broad` is (batch, n_broad) and `narrow` is (batch, n_narrow); each branch
    maps its input to 32 features, the 64 concatenated features go through a
    small head, and the result is a strictly positive (batch, 1) prediction.
    """

    def __init__(self, n_broad, n_narrow):
        super().__init__()
        self.broad = nn.Sequential(
            nn.Linear(n_broad, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
        )
        self.narrow = nn.Sequential(
            nn.Linear(n_narrow, 300),
            nn.ReLU(),
            nn.Linear(300, 150),
            nn.ReLU(),
            nn.Linear(150, 75),
            nn.ReLU(),
            nn.Linear(75, 32),
        )
        self.combined = nn.Sequential(
            nn.Linear(32 + 32, 32),  # broad (32) + narrow (32) features
            nn.ReLU(),
            nn.Linear(32, 1),
            # Softplus rather than ReLU on the single output: predictions stay
            # positive, but the gradient is never exactly zero, so the output
            # unit cannot get stuck at 0. Parameter-free, so state_dict keys
            # (combined.0.*, combined.2.*) are unchanged.
            nn.Softplus(),
        )

    def forward(self, broad, narrow):
        """Return the (batch, 1) prediction for a batch of broad/narrow vectors."""
        broad_features = self.broad(broad)
        narrow_features = self.narrow(narrow)
        joint = torch.cat((broad_features, narrow_features), dim=1)
        return self.combined(joint)
    
# Seed the global torch RNG before building the model, so weight init is
# reproducible across kernel restarts. The DataLoader's shuffle, which runs
# later during training, also draws from this RNG when no `generator=` is given.
torch.manual_seed(0)
model = Model(len(broad_words), len(narrow_words))
optimizer = optim.Adam(model.parameters(), lr=1e-5)
criterion = nn.MSELoss()

In [29]:
# Per-epoch mean training loss. It is created in its own cell so that re-running
# the training cell keeps appending (continuing the same curve).
losses = list()

In [None]:
# Targets are divided by RESULT_SCALE so they fall in (0, 1] (the raw results
# observed above range from 20 to 44963). The prediction cells multiply model
# outputs back by the same 45000.
RESULT_SCALE = 45000
N_EPOCHS = 200

# `model` and `losses` are created in earlier cells, so re-running this cell
# continues training and extends the loss curve.
model.train()
for epoch in tqdm(range(N_EPOCHS)):
    batch_losses = []
    for broad, narrow, result in dataloader:
        target = result / RESULT_SCALE
        optimizer.zero_grad()
        output = model(broad, narrow)
        loss = criterion(output, target)
        batch_losses.append(loss.item())
        loss.backward()
        optimizer.step()
    # One point per epoch: mean MSE over this epoch's mini-batches.
    losses.append(np.mean(batch_losses))

fig, ax = plt.subplots()
ax.plot(np.log(losses))
ax.set(title='Training loss', xlabel='epoch', ylabel='log(mean batch MSE)');

  0%|          | 0/200 [00:00<?, ?it/s]

In [None]:
# Prediction for an input where every broad and narrow word is marked +1,
# multiplied back by 45000 (the factor the training cell divides targets by).
# no_grad: pure inference, so no autograd graph is built.
with torch.no_grad():
    pred_all_plus = model(torch.ones(1, len(broad_words)), torch.ones(1, len(narrow_words))) * 45000
pred_all_plus

In [None]:
# Prediction for an input where every broad and narrow word is marked -1,
# multiplied back by 45000 (the factor the training cell divides targets by).
# no_grad: pure inference, so no autograd graph is built.
with torch.no_grad():
    pred_all_minus = model(-torch.ones(1, len(broad_words)), -torch.ones(1, len(narrow_words))) * 45000
pred_all_minus

In [None]:
# Persist the learned parameters (state_dict), not the module object. Reloading
# needs Model(len(broad_words), len(narrow_words)) and the word->index mappings
# pickled earlier, since those fix the input column layout.
torch.save(obj=model.state_dict(), f='model1.pt')