In [527]:
import os

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader, Dataset, TensorDataset
# `tqdm.tqdm_notebook` is deprecated (it emits a TqdmDeprecationWarning on every
# progress bar); `tqdm.notebook.tqdm` is the supported replacement.
from tqdm.notebook import tqdm

# Restrict CUDA to device 0. Set here, before any CUDA work is done in this notebook.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [528]:
class Encoder(nn.Module):
    """GRU encoder that consumes a batch-first time series.

    The constructor stores its three arguments and builds an ``nn.GRU`` mapping
    ``num_features`` inputs to ``hidden_size`` units. ``num_timesteps`` is only
    stored; the GRU itself does not use it.
    """

    def __init__(self, num_features, num_timesteps, hidden_size):
        super().__init__()

        self.num_features = num_features
        self.num_timesteps = num_timesteps
        self.hidden_size = hidden_size
        self.rnn = nn.GRU(num_features, hidden_size)

    def forward(self, inputs):
        """
        :params inputs: Input time series data of shape(batch_size, num_timesteps, num_features)
        :return: ``(out, last_h)`` where ``out`` holds the per-step GRU outputs with
                 the batch axis first again, and ``last_h`` is the final hidden
                 state reshaped to ``(-1, hidden_size)``.
        """
        # nn.GRU is built without batch_first, so feed it (time, batch, feature).
        time_major = inputs.permute(1, 0, 2)
        step_states, final_state = self.rnn(time_major)

        # Swap the first two axes back for the caller and drop the layer axis.
        batch_major_states = step_states.permute(1, 0, 2)
        flat_final_state = final_state.view(-1, self.hidden_size)
        return batch_major_states, flat_final_state
    
class Decoder(nn.Module):
    """Unrolls a GRU cell for a fixed number of output steps.

    The initial state is built from two encoder vectors: concatenated when
    ``concat`` is true (state width ``2 * hidden_size``), summed otherwise
    (state width ``hidden_size``). A shared linear layer maps the state to
    ``num_features_out`` values at every step.
    """

    def __init__(self, num_timesteps_out, hidden_size, num_features_out, concat=True):
        super().__init__()

        self.num_timesteps_out = num_timesteps_out
        self.hidden_size = hidden_size
        self.num_features_out = num_features_out
        self.concat = concat

        # Concatenation doubles the width of the recurrent state.
        state_size = hidden_size * 2 if concat else hidden_size
        self.rnn_cell = nn.GRUCell(state_size, state_size)
        self.linear = nn.Linear(state_size, num_features_out)

    def forward(self, encoder_ts_hid, encoder_features_hid):
        '''
        :param encoder_ts_hid: (batch_size, hidden_size)
        :param encoder_features_hid: (batch_size, hidden_size)
        :return: tensor of shape (batch_size, num_timesteps_out, num_features_out)
        '''
        if self.concat:
            state = torch.cat([encoder_ts_hid, encoder_features_hid], dim=-1)
        else:
            state = encoder_ts_hid + encoder_features_hid

        step_outputs = []
        for step in range(self.num_timesteps_out):
            # Step 0 is read straight from the initial state; every later step
            # first advances the cell, feeding the state as both input and hidden.
            if step > 0:
                state = self.rnn_cell(state, state)
            projected = self.linear(state)
            step_outputs.append(projected.view(-1, 1, self.num_features_out))

        return torch.cat(step_outputs, dim=-2)
    
class Project(nn.Module):
    """Two stacked Linear -> Dropout -> BatchNorm1d stages.

    The first stage maps ``num_features`` to ``hidden_size``; the second keeps
    the width at ``hidden_size``. Both dropout layers use probability ``dropout``.
    """

    def __init__(self, num_features, hidden_size, dropout=0.5):
        super().__init__()

        self.num_features = num_features

        # Stage 1: num_features -> hidden_size
        self.linear_1 = nn.Linear(num_features, hidden_size)
        self.dropout_1 = nn.Dropout(p=dropout)
        self.bn_1 = nn.BatchNorm1d(hidden_size)

        # Stage 2: hidden_size -> hidden_size
        self.linear_2 = nn.Linear(hidden_size, hidden_size)
        self.dropout_2 = nn.Dropout(p=dropout)
        self.bn_2 = nn.BatchNorm1d(hidden_size)

    def forward(self, inputs):
        """
        :params inputs: Input static features of shape(batch_size, num_features)
        :return: projected features of shape (batch_size, hidden_size)
        """
        stages = (
            (self.linear_1, self.dropout_1, self.bn_1),
            (self.linear_2, self.dropout_2, self.bn_2),
        )
        hidden = inputs
        for linear, dropout, batch_norm in stages:
            hidden = batch_norm(dropout(linear(hidden)))
        return hidden

class SelfAttention(nn.Module):
    """Attention pooling over the second-to-last axis with one learned vector.

    Each row of the input is scored by ``tanh(row . weight)``; the scores are
    exponentiated and normalised along the second-to-last axis, then used to
    form a weighted sum of the rows.
    """

    def __init__(self, hidden_size):
        super().__init__()

        # Scoring vector, initialised uniformly by nn.init.uniform_.
        self.weight = nn.Parameter(torch.empty(hidden_size))
        nn.init.uniform_(self.weight)

    def forward(self, inputs):
        """
        :params inputs: Input hidden features of shape(..., atten_dim, hidden_size)
        :return: ``(pooled, a)`` — the weighted sum over ``atten_dim`` and the
                 attention weights of shape (..., atten_dim, 1)
        """
        epsilon = 1e-10  # guards the normalising division
        scores = torch.tanh(inputs.matmul(self.weight.view(-1, 1)))
        exp_scores = scores.exp()
        a = exp_scores / (exp_scores.sum(dim=-2, keepdim=True) + epsilon)
        pooled = (inputs * a).sum(dim=-2)
        return pooled, a

In [529]:
class PropagationNet(nn.Module):
    """Forecasting network combining one target series with four feature groups.

    Three GRU encoders summarise the target, weather and containment series;
    two ``Project`` blocks embed the static population and healthcare features.
    The four auxiliary embeddings are pooled by ``SelfAttention`` and passed to
    the ``Decoder`` together with the target-series encoding.
    """

    def __init__(self, num_timesteps_in=14,
                     num_timesteps_out=7,
                     ts_num_features=3,
                     weather_num_features=11,
                     containment_num_features=18,
                     population_num_features=11,
                     healthcare_num_features=18,
                     hidden_size=64):
        super().__init__()

        self.hidden_size = hidden_size

        # Sequence encoders (one per time-varying input).
        self.ts_encoder = Encoder(ts_num_features, num_timesteps_in, hidden_size)
        self.weather_encoder = Encoder(weather_num_features, num_timesteps_in, hidden_size)
        self.containment_encoder = Encoder(containment_num_features, num_timesteps_in, hidden_size)

        # Static-feature projections.
        self.population_project = Project(population_num_features, hidden_size)
        self.healthcare_project = Project(healthcare_num_features, hidden_size)

        self.attention = SelfAttention(hidden_size)
        self.decoder = Decoder(num_timesteps_out, hidden_size, ts_num_features)

    def forward(self, inputs):
        """
        :params inputs: list of
        ts_input (batch_size, timesteps_in, ts_num_features)
        weather_input (batch_size, timesteps_in, weather_num_features)
        containment_input (batch_size, timesteps_in, containment_num_features)

        population_input (batch_size, population_num_features)
        healthcare_input (batch_size, healthcare_num_features)

        :return: ``(features_hid, a, outs)`` — the stacked auxiliary embeddings,
                 the attention weights over them, and the decoder output.
        """
        assert(len(inputs)==5)
        ts, weather_ts, containment_ts, population, healthcare = inputs

        # Only the final hidden state of each encoder is used.
        encoder_ts_hid = self.ts_encoder(ts)[1]
        encoder_weather_hid = self.weather_encoder(weather_ts)[1]
        encoder_containment_hid = self.containment_encoder(containment_ts)[1]

        project_population_hid = self.population_project(population)
        project_healthcare_hid = self.healthcare_project(healthcare)

        # Stack the four embeddings along a new axis 1, in the order
        # weather, containment, population, healthcare.
        auxiliary_embeddings = (
            encoder_weather_hid,
            encoder_containment_hid,
            project_population_hid,
            project_healthcare_hid,
        )
        features_hid = torch.cat(
            [emb.view(-1, 1, self.hidden_size) for emb in auxiliary_embeddings],
            dim=1,
        )

        features_att, a = self.attention(features_hid)
        outs = self.decoder(encoder_ts_hid, features_att)

        return features_hid, a, outs

In [530]:
class COVID_Dataset(Dataset):
    """Dataset over five index-aligned inputs and an optional target.

    ``data`` is indexed at positions 0-4 for the target series, weather,
    policy, population and healthcare inputs. The dataset length is the first
    dimension of ``data[0]``.
    """

    def __init__(self, data, y=None):
        super().__init__()

        self.ts = data[0]
        self.weather = data[1]
        self.policy = data[2]
        self.population = data[3]
        self.healthcare = data[4]

        self.y = y

    def __getitem__(self, index):
        """Return the five inputs at ``index``.

        Without targets the result is a 5-tuple; with targets it is
        ``([five inputs], y[index])``.
        """
        sample = (
            self.ts[index],
            self.weather[index],
            self.policy[index],
            self.population[index],
            self.healthcare[index],
        )
        if self.y is None:
            return sample
        return list(sample), self.y[index]

    def __len__(self):
        """Number of samples, taken from the first axis of the target series."""
        n_samples = self.ts.shape[0]
        return n_samples

In [531]:
class COVID_Loss(nn.Module):
    """Weighted sum of an embedding-norm penalty and the mean absolute error.

    loss = reg_weight * reg + (1 - reg_weight) * MAE, where ``reg`` is the
    squared L2 norm of the hidden embeddings (summed over the last axis, then
    averaged over the two remaining axes).
    """

    def __init__(self, reg_weight=0.5):
        super().__init__()
        self.reg_weight = reg_weight

    def forward(self, predictions, actuals):
        """
        :param predictions: triple ``(hid, a, preds)``; the middle element is ignored
        :param actuals: target tensor compared against ``preds``
        :return: scalar loss tensor
        """
        hid, _, preds = predictions

        squared_norms = (hid * hid).sum(dim=-1)
        reg = squared_norms.mean(dim=-1).mean(dim=-1)
        mae = nn.functional.l1_loss(preds, actuals)

        return reg * self.reg_weight + mae * (1 - self.reg_weight)

In [532]:
# Load the pre-built feature pickles. The train and test files each unpack into an
# (inputs, targets) pair.
# NOTE(review): read_pickle executes arbitrary code on load; only use trusted files.
train, train_y = pd.read_pickle('../features/train_set.5.7.pkl')
test, test_y = pd.read_pickle('../features/test_set.5.7.pkl')
countries = pd.read_pickle('../features/countries.5.7.pkl')

feautures_list = pd.read_pickle('../features/feature_list.5.7.pkl')

# Wrap both splits as datasets; only the training loader is shuffled.
dtrain = COVID_Dataset(data=train, y=train_y)
dtest = COVID_Dataset(data=test, y=test_y)
train_loader = DataLoader(dataset=dtrain, shuffle=True, batch_size=64)
test_loader = DataLoader(dataset=dtest, shuffle=False, batch_size=64)

In [533]:
# Number of entries in each of the first five feature groups.
tuple(len(feautures_list[group]) for group in range(5))

(3, 11, 18, 18, 11)

In [580]:
# Output location for this experiment; the checkpoint sub-directory is created up front.
out_dir = '../results/exp-01/'
initial_checkpoint = None  # presumably a checkpoint path to resume from; None trains from scratch
os.makedirs(out_dir +'/checkpoint', exist_ok=True)

# Model on the GPU, and AdamW over the parameters that require gradients.
net = PropagationNet().cuda()
optimizer = optim.AdamW(
    [param for param in net.parameters() if param.requires_grad],
    lr=0.001,
    weight_decay=0.02,
)

# Checkpoint payload holding the model's state dict.
checkpoint = dict(model=net.state_dict())

In [581]:
# Optionally resume from a saved checkpoint.
# Fixes two bugs in the previous version: the misspelled name `inital_checkpoint`
# (NameError), and indexing the path `initial_checkpoint['model']` instead of the
# loaded checkpoint dict.
if initial_checkpoint is not None:
    checkpoint = torch.load(initial_checkpoint)
    net.load_state_dict(checkpoint['model'])
    print('load model from', initial_checkpoint)

In [582]:
# Train for a fixed number of epochs, reporting the mean batch loss on the training
# and test loaders after each epoch.
epochs = 12
criterion = COVID_Loss(reg_weight=0.1)
save_checkpoints = False  # set True to write a checkpoint after every epoch

for epoch in range(epochs):
    # ---- training pass ----
    net.train()
    loss_meter = []
    for inputs, truth in tqdm(train_loader):
        optimizer.zero_grad()
        inputs = [item.float().cuda() for item in inputs]
        truth = truth.float().cuda()
        logit = net(inputs)
        loss  = criterion(logit, truth)

        loss.backward()
        optimizer.step()
        loss_meter.append(loss.item())

    train_loss = np.mean(loss_meter)
    print('TRAIN LOSS: {}'.format(train_loss))

    # ---- evaluation pass ----
    net.eval()
    eval_loss_meter = []
    # No gradients are needed here; no_grad avoids building autograd graphs.
    with torch.no_grad():
        for inputs, truth in tqdm(test_loader):
            inputs = [item.float().cuda() for item in inputs]
            truth = truth.float().cuda()
            logit = net(inputs)
            loss  = criterion(logit, truth)

            eval_loss_meter.append(loss.item())

    eval_loss = np.mean(eval_loss_meter)
    print('EVAL LOSS: {}'.format(eval_loss))

    if save_checkpoints:
        # The previous (disabled) version referenced an undefined `val_str` and saved
        # the module-level `checkpoint`, which no longer tracks `net` after a resume.
        # Snapshot the current weights and tag the file with this epoch's eval loss.
        # NOTE(review): the file name says "smape" but the value is the eval loss.
        val_str = '{:.4f}'.format(eval_loss)
        torch.save({'model': net.state_dict()},
                   out_dir +'/checkpoint/epoch_%02d_val_group_smape_%s_model.pth'%(epoch, val_str))
        print('\n save model to /checkpoint/epoch_%02d_val_group_smape_%s_model.pth'%(epoch,val_str))

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  import sys


HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))


TRAIN LOSS: 3.8369567662897244


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


EVAL LOSS: 2.0428675413131714


HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))


TRAIN LOSS: 3.0641234794133148


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


EVAL LOSS: 1.8204519152641296


HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))


TRAIN LOSS: 2.6199166909070084


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


EVAL LOSS: 1.5248517990112305


HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))


TRAIN LOSS: 2.2685406778899715


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


EVAL LOSS: 1.2704797983169556


HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))


TRAIN LOSS: 1.959607502104531


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


EVAL LOSS: 1.2156828045845032


HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))


TRAIN LOSS: 1.6973458679629043


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


EVAL LOSS: 1.0922191143035889


HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))


TRAIN LOSS: 1.4711020278259062


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


EVAL LOSS: 1.0432737469673157


HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))


TRAIN LOSS: 1.2721382127681249


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


EVAL LOSS: 0.9504579901695251


HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))


TRAIN LOSS: 1.1032636526604773


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


EVAL LOSS: 0.914818525314331


HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))


TRAIN LOSS: 0.9637511691576998


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


EVAL LOSS: 0.857856273651123


HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))


TRAIN LOSS: 0.843052670149736


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


EVAL LOSS: 0.8285035490989685


HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))


TRAIN LOSS: 0.7403007661792594


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


EVAL LOSS: 0.760029673576355


In [583]:
# Run the trained network over the test loader and collect, per batch, the
# predictions, labels, hidden embeddings and attention weights as NumPy arrays.
eval_loss_meter = []
preds = []
hiddens = []
attens = []
labels = []

# Put the network in eval mode explicitly so this cell does not depend on the
# mode left behind by an earlier cell, and skip autograd bookkeeping.
net.eval()
with torch.no_grad():
    for inputs, truth in tqdm(test_loader):
        inputs = [item.float().cuda() for item in inputs]
        truth = truth.float().cuda()
        logit = net(inputs)
        _hid,_a,_pred = logit
        _hid = _hid.detach().cpu().numpy()
        _a = _a.detach().cpu().numpy()
        _pred = _pred.detach().cpu().numpy()
        _label = truth.detach().cpu().numpy()

        # expm1 is applied to both predictions and labels before storing them.
        preds.append(np.expm1(_pred))
        labels.append(np.expm1(_label))
        hiddens.append(_hid)
        attens.append(_a)

# Concatenate the per-batch arrays along the first (sample) axis.
preds = np.vstack(preds)
labels = np.vstack(labels)
hiddens = np.vstack(hiddens)
attens = np.vstack(attens)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  import sys


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))




In [584]:
# Predictions for output feature 0: one row per country, one column per forecast step.
pd.DataFrame(data=preds[:, :, 0], index=countries)

Unnamed: 0,0,1,2,3,4,5,6
Afghanistan,17.468708,17.615965,19.022629,19.831581,20.709557,21.703835,22.795341
Albania,21.691587,19.018944,21.278215,23.446054,24.880253,25.808229,26.450159
Algeria,110.215225,84.447662,92.827911,100.786606,106.877098,111.546852,115.084541
Argentina,133.488129,123.582825,132.501862,139.436050,142.984344,145.088394,146.425339
Armenia,47.007568,45.778870,50.354271,53.662048,54.918179,55.301556,55.470726
...,...,...,...,...,...,...,...
US,7068.974609,16785.025391,18066.712891,19117.410156,19286.054688,19214.947266,19063.462891
Ukraine,153.146698,137.118378,158.585052,189.375717,212.239471,223.251007,229.897018
United Arab Emirates,66.274788,62.924728,70.687508,76.060783,78.164436,79.173302,79.829895
United Kingdom,2069.060303,2901.153809,3195.222900,3597.949951,3819.951660,3960.794922,4053.763672


In [585]:
# Labels for output feature 0, laid out the same way as the predictions table.
pd.DataFrame(data=labels[:, :, 0], index=countries)

Unnamed: 0,0,1,2,3,4,5,6
Afghanistan,63.000000,35.999996,8.000000,18.000000,50.000000,18.000000,56.000000
Albania,16.000000,18.000000,27.000002,29.000002,27.999998,16.000000,6.000000
Algeria,131.000015,138.999985,185.000000,80.000008,68.999992,103.000015,45.000000
Argentina,0.000000,79.000000,131.999969,186.000000,0.000000,103.000015,74.000008
Armenia,39.000000,92.000000,73.000008,34.000004,51.999996,11.000000,20.000002
...,...,...,...,...,...,...,...
US,25069.998047,30379.988281,31745.011719,33283.007812,28151.994141,29514.990234,30803.992188
Ukraine,149.000015,103.000015,175.000015,152.999985,82.999985,11.000000,143.000000
United Arab Emirates,149.999969,210.000000,240.000015,241.000031,294.000031,277.000031,283.000061
United Kingdom,4384.001953,4309.000977,4516.001465,3788.000977,5958.997070,3842.999512,3670.000488


In [586]:
# Squared L2 norm of each feature-group embedding, shown for a few selected countries.
pd.DataFrame(
    np.square(hiddens).sum(axis=-1),
    index=countries,
    columns=['weather', 'policy', 'population', 'healthcare'],
).loc[['China', 'France', 'Germany', 'Iran', 'Spain', 'Italy', 'United Kingdom', 'US']]

Unnamed: 0,weather,policy,population,healthcare
China,0.004009,0.009616,0.861645,23.558056
France,0.043741,0.006399,47.969727,1.098384
Germany,0.003977,0.321409,3.001686,2.396453
Iran,0.003712,0.178437,0.095415,0.087101
Spain,0.001889,0.091989,0.340818,0.647001
Italy,0.003584,0.002957,116.511742,0.89131
United Kingdom,0.012887,0.010778,0.533335,1.830337
US,0.004171,1.25953,19.111851,80.03653


In [587]:
# Attention weight given to each feature group, shown for a few selected countries.
pd.DataFrame(
    data=attens[:, :, 0],
    index=countries,
    columns=['weather', 'policy', 'population', 'healthcare'],
).loc[['China', 'France', 'Germany', 'Iran', 'Spain']]

Unnamed: 0,weather,policy,population,healthcare
China,0.173039,0.201219,0.079034,0.546708
France,0.190516,0.125471,0.343536,0.340477
Germany,0.156808,0.129694,0.345636,0.367862
Iran,0.294845,0.37617,0.173687,0.155298
Spain,0.193559,0.198087,0.08972,0.518634
