In [1]:
import pickle

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, random_split
from torch.nn.functional import binary_cross_entropy_with_logits as BCElogits
from torch.nn.functional import softmax

import numpy as np

from tqdm import tqdm


class Model(nn.Module):
    """MLP producing layer-normalised, temperature-scaled outputs.

    The body is four ``nn.Linear`` layers with GELU activations between them,
    one ``nn.LayerNorm`` after the second activation and two ``nn.Dropout``
    layers (default drop probability). A final ``nn.LayerNorm`` is applied to
    the output of the last linear layer before dividing by the temperature.

    Args:
        dim_input: Size of the last dimension of the input.
        dim_hidden: Width of the hidden linear layers.
        dim_output: Size of the output. Any falsy value (``None`` or ``0``)
            falls back to ``dim_input``.
        temp: Divisor applied to the normalised output in ``forward``.
    """

    def __init__(self,
                 dim_input,
                 dim_hidden,
                 dim_output=None,
                 temp=1.):
        super().__init__()
        if not dim_output:
            dim_output = dim_input
        self.temperature = temp

        # Attribute names and the position of every layer inside the
        # Sequential determine the state_dict keys used by checkpoints.
        self.layers = nn.Sequential(
            nn.Linear(dim_input, dim_hidden),
            nn.GELU(),
            nn.Linear(dim_hidden, dim_hidden),
            nn.GELU(),
            nn.LayerNorm(dim_hidden),
            nn.Dropout(),
            nn.Linear(dim_hidden, dim_hidden),
            nn.GELU(),
            nn.Dropout(),
            nn.Linear(dim_hidden, dim_output),
        )
        self.mlp_ln = nn.LayerNorm(dim_output)

    def forward(self, x):
        """Return ``LayerNorm(MLP(x)) / temperature``."""
        normalised = self.mlp_ln(self.layers(x))
        return normalised / self.temperature
    

class SearchDataset(Dataset):
    """In-memory dataset of date-prefixed embeddings and digit-wise date labels.

    Every record of ``source_data`` becomes one ``(features, label)`` pair:

    * ``features`` - three date components scaled by 1/100 followed by the
      embedding values (see ``prepend_date``).
    * ``label`` - a 70-element multi-hot vector (see ``make_label``).

    Both arrays are converted to ``float32`` tensors and moved to ``device``
    once, in the constructor.
    """

    def __init__(self, source_data, device):
        """Build the feature and label tensors.

        Args:
            source_data: Iterable of 4-item records
                ``(date, _, source_label, embed)``; the second item is unused.
            device: Target passed to ``Tensor.to`` for both tensors.
        """
        features = []
        targets = []
        for record in source_data:
            date, _, source_label, embed = record
            targets.append(self.make_label(source_label))
            features.append(self.prepend_date(date, embed))

        feature_array = np.array(features, dtype=np.float32)
        target_array = np.array(targets, dtype=np.float32)

        self.data = torch.from_numpy(feature_array).to(device)
        self.labels = torch.from_numpy(target_array).to(device)

    def __len__(self):
        """Number of samples held by the dataset."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(features, label)`` pair stored at ``idx``."""
        return self.data[idx], self.labels[idx]

    def prepend_date(self, date, embed):
        """Return ``embed`` with the scaled date components placed in front.

        ``date`` is split on ``"-"``; the first two characters of the first
        component are dropped and every component is converted with
        ``int(...) / 100``.

        NOTE(review): the ``+`` below is list concatenation only when
        ``embed`` is a plain list - confirm against the data producer.
        """
        head, *rest = date.split("-")
        components = [head[2:], *rest]
        scaled = [int(part)/100. for part in components]
        return scaled + embed

    def make_label(self, label):
        """Encode ``label`` as a 70-element multi-hot ``float64`` vector.

        Layout: indices 0-29 hold the digits of ``str(label[0])`` after its
        first character is dropped (10 slots per digit), 30-49 hold the two
        digits of ``label[1]`` and 50-69 the two digits of ``label[2]``, the
        latter two left-padded with ``"0"`` to width 2.

        When ``label[0] == -1`` the digit-0 slot of each of the seven groups
        is set instead (presumably a "no date" marker - TODO confirm).
        """
        encoded = np.zeros(30+20+20)
        if label[0] == -1:
            encoded[0::10] = 1
            return encoded

        # Year: everything after the leading character of its decimal form.
        for pos, digit in enumerate(str(label[0])[1:]):
            encoded[pos*10+int(digit)] = 1

        # Month (offset 30) and day (offset 50), two zero-padded digits each.
        for offset, field in ((30, 1), (50, 2)):
            padded = str(label[field]).rjust(2, "0")
            for pos, digit in enumerate(padded):
                encoded[offset+pos*10+int(digit)] = 1
        return encoded


def save_model(model, optim, pth):
    """Write the model and optimizer state dicts to ``f"{pth}.ckpt"``.

    Args:
        model: Object exposing ``state_dict()``; stored under
            ``'model_state_dict'``.
        optim: Object exposing ``state_dict()``; stored under
            ``'optimizer_state_dict'``.
        pth: Path stem. ``".ckpt"`` is always appended, even when ``pth``
            already ends with that suffix.
    """
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optim.state_dict(),
    }
    target = f"{pth}.ckpt"
    torch.save(checkpoint, target)
  

def load_model(model, optim, pth, map_location=None):
    """Restore model and optimizer state from ``f"{pth}.ckpt"``.

    Args:
        model: Module whose ``load_state_dict`` receives the
            ``'model_state_dict'`` entry of the checkpoint.
        optim: Optimizer whose ``load_state_dict`` receives the
            ``'optimizer_state_dict'`` entry.
        pth: Path stem; ``".ckpt"`` is appended, mirroring ``save_model``.
        map_location: Forwarded to ``torch.load``. Pass e.g. ``"cpu"`` (or the
            device the model lives on) to open a checkpoint that was written
            on a device that is not available on this machine. ``None`` keeps
            ``torch.load``'s default behaviour.

    Returns:
        The same ``(model, optim)`` objects, updated in place.
    """
    checkpoint = torch.load(f"{pth}.ckpt", map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    optim.load_state_dict(checkpoint['optimizer_state_dict'])
    return model, optim


def make_prediction(logits):
    """Decode a 1-D score vector into a ``"2YYY-MM-DD"`` date string.

    The vector is read as seven consecutive groups of ten scores: three year
    digits (indices 0-29), two month digits (30-49) and two day digits
    (50-69). The index of the largest score in each group is that digit; the
    leading ``"2"`` of the year is fixed.

    Args:
        logits: 1-D tensor of per-digit scores (model output or a label
            vector).

    Returns:
        The decoded date as a string, e.g. ``"2021-03-05"``.
    """
    # argmax is unchanged by softmax (a strictly increasing map), so the raw
    # scores are used directly. Converting to int makes the string formatting
    # independent of how tensors render themselves.
    digits = [int(torch.argmax(logits[start:start + 10]))
              for start in range(0, 70, 10)]
    year1, year2, year3, month1, month2, day1, day2 = digits

    return f"2{year1}{year2}{year3}-{month1}{month2}-{day1}{day2}"


def train_one_epoch(model,
                    optimizer,
                    training_loader,
                    loss_fn):
    """Run one optimisation pass over ``training_loader``.

    For every ``(inputs, labels)`` batch the gradients are reset, the loss
    ``loss_fn(model(inputs), labels)`` is back-propagated and the optimizer
    takes one step. The model's train/eval mode is not touched here; the
    caller is expected to set it.

    Args:
        model: Callable mapping a batch of inputs to logits.
        optimizer: Optimizer exposing ``zero_grad()`` and ``step()``.
        training_loader: Iterable yielding ``(inputs, labels)`` pairs.
        loss_fn: Callable ``(logits, labels) -> scalar tensor``.

    Returns:
        The mean of the per-batch loss values as a Python float.

    Raises:
        ValueError: If ``training_loader`` yields no batches.
    """
    running_loss = 0.
    n_batches = 0

    for inputs, labels in training_loader:
        optimizer.zero_grad()

        logits = model(inputs)

        loss = loss_fn(logits, labels)
        loss.backward()

        optimizer.step()

        running_loss += loss.item()
        n_batches += 1

    if n_batches == 0:
        raise ValueError("training_loader yielded no batches")

    # Divide by the number of batches, not by the index of the last one:
    # the index is one too small and is zero for a single-batch epoch.
    return running_loss / n_batches


def get_available_device():
    """Return ``torch.device('cuda')`` if CUDA is available, else ``torch.device('cpu')``."""
    if torch.cuda.is_available():
        return torch.device('cuda')
    return torch.device('cpu')


Main settings

In [2]:
# --- Paths ---
DATA_PATH = "full_samples.pkl"      # pickled sample records read by the data-loading cell
# NOTE: save_model/load_model append ".ckpt" to the path they are given, so
# the checkpoint file on disk is "search_model.ckpt.ckpt".
SAVE_PATH = "search_model.ckpt"

# --- Training hyper-parameters ---
BATCH_SIZE = 32
TRAIN_SHARE = 0.8                   # fraction of samples used for training; the rest is validation
EPOCHS = 1000

# --- Reproducibility ---
# The train/validation split, weight initialisation, dropout and batch
# shuffling all draw from torch's global RNG; seed it so a fresh
# "Restart & Run All" reproduces the same run.
SEED = 42
torch.manual_seed(SEED)

DEVICE = get_available_device()

  return torch._C._cuda_getDeviceCount() > 0


Load data and create training and validation sets and dataloaders

In [3]:
# SECURITY: pickle.load can execute arbitrary code embedded in the file.
# Only point DATA_PATH at data you produced yourself or otherwise trust.
with open(DATA_PATH, "rb") as f:
    full_samples = pickle.load(f)

full_set = SearchDataset(full_samples, DEVICE)

# Fractional lengths: random_split divides the dataset by share.
train_set, val_set = random_split(full_set, [TRAIN_SHARE, 1-TRAIN_SHARE])

train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
# Validation only averages per-batch losses. A fixed batch order keeps the
# batch composition, and therefore the reported average, identical from
# epoch to epoch for the same weights, so "best validation loss" comparisons
# are not perturbed by reshuffling.
val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False)

Create model and optimizer

In [4]:
# Input width is read from the first training sample; the hidden layers are
# four times as wide and the output has 70 logits (the label vector length).
n_features = train_set[0][0].shape[0]
model = Model(n_features, n_features*4, 70)
model = model.to(DEVICE)

# Adam with its default hyper-parameters.
optim = torch.optim.Adam(model.parameters())

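Train the model. After every epoch the validation loss is computed, and a checkpoint is saved whenever it improves on the best value seen so far.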

In [None]:

# Lowest validation loss seen so far; infinity guarantees that the first
# epoch always produces a checkpoint.
best_vloss = float("inf")

loss_fn = BCElogits

for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch + 1))

    # Training pass (dropout active).
    model.train()
    avg_loss = train_one_epoch(model,
                               optim,
                               train_loader,
                               loss_fn)

    # Validation pass: dropout off and autograd disabled, so no computation
    # graph is built or kept alive while the losses are summed.
    model.eval()

    running_vloss = 0.0
    n_val_batches = 0
    with torch.no_grad():
        for vinputs, vlabels in val_loader:
            voutputs = model(vinputs)
            # .item() keeps the accumulator a plain Python float.
            running_vloss += loss_fn(voutputs, vlabels).item()
            n_val_batches += 1

    if n_val_batches == 0:
        raise ValueError("val_loader yielded no batches")

    avg_vloss = running_vloss / n_val_batches
    print(f'LOSS train {avg_loss} valid {avg_vloss}')

    # Keep only the checkpoint with the best validation loss.
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        save_model(model, optim, SAVE_PATH)


Check accuracy on training and validation sets. Load the best checkpoint instead of using the last model state from the training loop.

In [None]:
def _exact_matches(net, dataset):
    """Return one bool per sample: does the decoded prediction equal the decoded label?"""
    return [make_prediction(net(inp)) == make_prediction(lab)
            for inp, lab in tqdm(dataset)]


# Rebuild the architecture and optimizer, then restore the checkpointed state.
input_width = train_set[0][0].shape[0]
model = Model(input_width, input_width*4, 70).to(DEVICE)
optim = torch.optim.Adam(model.parameters())
load_model(model, optim, SAVE_PATH)

# Inference mode: disables dropout.
model.eval()


with torch.no_grad():
    # A sample counts as correct only if the whole decoded date string matches.
    results = _exact_matches(model, train_set)
    print (f"Training set accuracy: {sum(results)/len(results)}")

    results = _exact_matches(model, val_set)
    print (f"Validation set accuracy: {sum(results)/len(results)}")