In [1]:
import pandas as pd
import numpy as np
import torch
import os
from dateutil.parser import parse 
import datetime
from matplotlib import pyplot as plt

In [2]:
country_codes = ['ABW','AFG','AGO','ALB','AND','ARE','ARG','AUS','AUT','AZE','BDI','BEL','BEN','BFA','BGD','BGR','BHR','BHS','BIH','BLR','BLZ','BMU','BOL','BRA','BRB','BRN','BTN','BWA','CAF','CAN','CHE','CHL','CHN','CIV','CMR','COD','COG','COL','COM','CPV','CRI','CUB','CYP','CZE','DEU','DJI','DMA','DNK','DOM','DZA','ECU','EGY','ERI','ESP','EST','ETH','FIN','FJI','FRA','FRO','GAB','GBR','GEO','GHA','GIN','GMB','GRC','GRL','GTM','GUM','GUY','HKG','HND','HRV','HTI','HUN','IDN','IND','IRL','IRN','IRQ','ISL','ISR','ITA','JAM','JOR','JPN','KAZ','KEN','KGZ','KHM','KOR','KWT','LAO','LBN','LBR','LBY','LKA','LSO','LTU','LUX','LVA','MAC','MAR','MCO','MDA','MDG','MEX','MLI','MMR','MNG','MOZ','MRT','MUS','MWI','MYS','NAM','NER','NGA','NIC','NLD','NOR','NPL','NZL','OMN','PAK','PAN','PER','PHL','PNG','POL','PRI','PRT','PRY','PSE','QAT','RKS','ROU','RUS','RWA','SAU','SDN','SEN','SGP','SLB','SLE','SLV','SMR','SOM','SRB','SSD','SUR','SVK','SVN','SWE','SWZ','SYC','SYR','TCD','TGO','THA','TJK','TKM','TLS','TTO','TUN','TUR','TWN','TZA','UGA','UKR','URY','USA','UZB','VEN','VIR','VNM','VUT','YEM','ZAF','ZMB','ZWE']
filenames = ["c1_school_closing.csv", "c2_workplace_closing.csv", "c3_cancel_public_events.csv", "c4_restrictions_on_gatherings.csv", "c5_close_public_transport.csv", "c6_stay_at_home_requirements.csv", "c7_movementrestrictions.csv", "c8_internationaltravel.csv", "confirmed_cases.csv", "h1_public_information_campaigns.csv", "h2_testing_policy.csv", "h3_contact_tracing.csv", "h6_facial_coverings.csv"]

def dateConvertor(date):
    dt = parse(date)
    date = dt.strftime('%Y-%m-%d')
    return date

country_code2id = {}
for i in range(len(country_codes)):
    country_code2id[country_codes[i]] = i 

# date extraction
npi_date = pd.DataFrame({})
npi_date['Date'] = pd.read_csv(os.path.join('timeseries', filenames[0])).keys()[3:]
npi_date['Date'] = npi_date['Date'].apply(dateConvertor)

In [3]:
dataframes = {} 

countries_to_extract = ['ITA','IND','USA','CHN','BRA','IRN','CAN','GBR',
                        'FRA','ESP','BEL','DEU','NLD','MEX','TUR','SWE','ECU','RUS','PER','CHE'] # countries code for which you want data.

countries_to_extract = ['IND','USA','BRA','IRN','CAN','GBR','FRA','ESP','BEL','DEU','NLD','MEX','TUR','SWE','ECU','RUS','PER']

index = [country_code2id[code] for code in countries_to_extract]
static_data = pd.read_csv(os.path.join('timeseries', 'Consolidated.csv')).T[2:][index].T
population = static_data['Population'].to_numpy()

static_data = static_data.drop('Population', axis = 1 )
cols_to_norm = ['Density', 'Median Age']
static_data[cols_to_norm] = static_data[cols_to_norm].apply(lambda x: (x - x.min()) / x.max()-x.min())
static_data = static_data.to_numpy()
tmp = static_data[:,3:]
final_static_data = static_data[:,0:3].astype(np.float64)
#print(final_static_data)
for file in filenames:
    npi_df = pd.read_csv(os.path.join('timeseries', file)).T[3:]
    npi_df['Date'] = npi_date['Date'].values
    npi_df.set_index('Date', drop=True, inplace=True)
    npi_df = npi_df[index] # selecting countries 
    # npi_df = npi_df.rolling(7).mean()
    npi_df = npi_df[64:335] # removing Jan, Feb and Dec data
    for col in npi_df:
        npi_df[col] = pd.to_numeric(npi_df[col], errors='coerce') # converting object to numeric 
    npi_df.interpolate(method='linear', inplace=True) # interpolate missing values 
    dataframes[file[:-4]] = npi_df
   
    if(file[:-4]=='confirmed_cases'):
        npi_df = pd.read_csv(os.path.join('timeseries', file)).T[3:]
        npi_df['Date'] = npi_date['Date'].values
        npi_df.set_index('Date', drop=True, inplace=True)
        npi_df = npi_df[index] # selecting countries 
        for col in npi_df:
            npi_df[col] = pd.to_numeric(npi_df[col], errors='coerce')
        npi_df = npi_df.interpolate(method='linear') # interpolate missing values     
        npi_df = npi_df.rolling(7).mean()
        npi_df = 100*npi_df.diff()/npi_df
        npi_df = npi_df[64:335] # removing Jan, Feb and Dec data
        dataframes['growth_rate'] = npi_df

dataframes['confirmed_cases'] = 100*dataframes['confirmed_cases'].div(population)
for col in dataframes['confirmed_cases']:
    dataframes['confirmed_cases'][col] = pd.to_numeric(dataframes['confirmed_cases'][col], errors='coerce')

In [4]:
def readData(attributes, history, date):
    index = dataframes['c1_school_closing'].index.get_loc(date)
    if(history>index):
        print('Not sufficient history')
        sys.exit()
    date = datetime.datetime.strptime(date, "%Y-%m-%d")
    ref_date = date - datetime.timedelta(21)
    if ref_date.month == date.month:
        temperature = tmp[:,int(date.month)].reshape((len(countries_to_extract),1)).astype(np.float64)
    else:
        temperature = ((tmp[:,int(date.month)] + tmp[:,int(ref_date.month)]) / 2).reshape((len(countries_to_extract),1)).astype(np.float64)
    data = []
    #past_growthrates = dataframes['growth_rate'].iloc[index-history:index].values.reshape((len(countries_to_extract),21))
    new_final_static_data = np.concatenate((final_static_data,temperature),axis=1)
    for att in attributes:
        temp = dataframes[att].iloc[index-history:index].values
        if(len(data)==0):
            data = np.asarray(temp)
        else:
            data = np.dstack((data, temp))
    #x = torch.cat((torch.from_numpy(data).to(dtype=torch.double).permute(1,0,2).view(len(countries_to_extract),-1),
                   #torch.from_numpy(new_final_static_data).to(dtype=torch.double)),dim = -1)
    #x = torch.cat((torch.from_numpy(data).to(dtype=torch.double).permute(1,0,2).view(len(countries_to_extract),-1),
                   #torch.from_numpy(past_growthrates).to(dtype=torch.double)),dim = -1)
    x = torch.from_numpy(data).to(dtype=torch.double).permute(1,0,2).view(len(countries_to_extract),-1)
    y = torch.from_numpy(dataframes['growth_rate'].iloc[index:index+7].values).to(dtype=torch.double).permute(1,0)
    return x,y

In [5]:
import torch.nn.functional as F
import torch 
import torch.optim as optim
import sys
import torch.nn as nn
import pandas as pd 
import os
import numpy as np
from dateutil.parser import parse 

In [6]:
class Encoder(torch.nn.Module):
    def __init__(self, input_dim , out_dim = 64):
        super(Encoder, self).__init__()
        self.linear1 = torch.nn.Linear(input_dim,256)
        self.linear2 = torch.nn.Linear(256,512)
        self.linear3 = torch.nn.Linear(512,256)
        self.linear4 = torch.nn.Linear(256,128)
        self.linear5 = torch.nn.Linear(128,out_dim)
        self.prelu1   = torch.nn.PReLU()
        self.prelu2   = torch.nn.PReLU()
        self.prelu3   = torch.nn.PReLU()
        self.prelu4   = torch.nn.PReLU()
        self.tanh    = torch.nn.Tanh()
        self.dropout = torch.nn.Dropout(p=0.0)
        
    def forward(self, x):
        x = self.prelu1(self.linear1(x))
        x = self.dropout(self.prelu2(self.linear2(x)))
        x = self.dropout(self.prelu3(self.linear3(x)))
        x = self.dropout(self.prelu4(self.linear4(x)))
        x = self.tanh(self.linear5(x)/20)*24
        return x.squeeze()

class Decoder(torch.nn.Module):
    def __init__(self, output_dim, input_dim=64):
        super(Decoder, self).__init__()
        self.linear1 = torch.nn.Linear(input_dim,64)
        self.linear2 = torch.nn.Linear(64,128)
        self.linear3 = torch.nn.Linear(128,256)
        self.linear4 = torch.nn.Linear(256,output_dim)
        self.prelu1   = torch.nn.PReLU()
        self.prelu2   = torch.nn.PReLU()
        self.prelu3   = torch.nn.PReLU()
        self.relu   = torch.nn.ReLU()
        self.dropout = torch.nn.Dropout(p=0.0)
        
    def forward(self, x):
        x = self.prelu1(self.linear1(x))
        x = self.dropout(self.prelu2(self.linear2(x)))
        x = self.dropout(self.prelu3(self.linear3(x)))
        x = self.dropout(self.relu(self.linear4(x)))
        return x.squeeze()
    
class Decoder1(torch.nn.Module):
    def __init__(self,input_dim,output_dim=1):
        super(Decoder1, self).__init__()
        self.linear1 = torch.nn.Linear(input_dim,64)
        self.linear2 = torch.nn.Linear(64,32)
        self.linear3 = torch.nn.Linear(32,16)
        self.linear4 = torch.nn.Linear(16,7)
        self.prelu1   = torch.nn.PReLU()
        self.prelu2   = torch.nn.PReLU()
        self.prelu3   = torch.nn.PReLU()
        self.dropout = torch.nn.Dropout(p=0.0)
    
    def forward(self,x):
        x = self.dropout(self.prelu1(self.linear1(x)))
        x = self.dropout(self.prelu2(self.linear2(x)))
        x = self.dropout(self.prelu3(self.linear3(x)))
        x = self.linear4(x)
        return x.squeeze()

class Decoder2(torch.nn.Module):
    def __init__(self,input_dim,output_dim=168):
        super(Decoder2, self).__init__()
        

        self.linear1 = torch.nn.Linear(input_dim,64)
        self.linear4 = torch.nn.Linear(64,64)
        
        self.linear2 = torch.nn.Linear(64,128)
        self.linear3 = torch.nn.Linear(128,output_dim)
        self.prelu1   = torch.nn.PReLU()
        self.prelu2   = torch.nn.PReLU()
        self.prelu3   = torch.nn.PReLU()
        
        self.dropout = torch.nn.Dropout(p=0.0)
        self.relu = torch.nn.ReLU()
        
    def forward(self,x):
        x = self.dropout(self.prelu1(self.linear1(x)))
        x = self.dropout(self.prelu3(self.linear4(x)))
        
        x = self.dropout(self.prelu2(self.linear2(x)))
        x = self.relu(self.linear3(x))
        return x.squeeze() 

In [9]:
##training,testing and validation split
training_attributes = ["c1_school_closing", "c2_workplace_closing", "c3_cancel_public_events", "c4_restrictions_on_gatherings", "c5_close_public_transport", "c6_stay_at_home_requirements", "c7_movementrestrictions", "c8_internationaltravel", "h1_public_information_campaigns", "h2_testing_policy", "h3_contact_tracing", "h6_facial_coverings" ]
history = 21
train_dates = npi_date['Date'][100:300].values
x,y = readData(attributes=training_attributes, history=history, date=train_dates[-1])
print(x.shape)
print(y.shape)

torch.Size([17, 252])
torch.Size([17, 7])


In [8]:
def init_weights(m):
    if type(m) == nn.Linear:
        torch.nn.init.xavier_uniform(m.weight)
        m.bias.data.fill_(0.01)
        
enc = Encoder(21*12)
dec = Decoder(21*12)
enc.load_state_dict(torch.load('./checkpoints/enc.pth'))
dec.load_state_dict(torch.load('./checkpoints/dec.pth'))
# enc.apply(init_weights)
# dec.apply(init_weights)

params = list(enc.parameters())+list(dec.parameters())
optimizer = optim.Adam(params, lr=2.5e-5, weight_decay=1e-5)
mse_loss = torch.nn.L1Loss()

def validation(enc, dec1):
    enc.eval()
    dec.eval() 
    valid_loss = 0
    for i in range(len(validation_dates)):
        x,y = readData(attributes=training_attributes, history=history, date=validation_dates[i])
        x,y = x.float(), y.float()
        features = dec(enc(x))
        loss = mse_loss(features,x)
        valid_loss += loss.item()
    enc.train()
    dec.train()
    return valid_loss

prev_validation_loss = validation(enc, dec)
print('validation loss before training %0.4f'%prev_validation_loss)
print('SCAM train exit')
sys.exit()

for epoch in range(1000):
    np.random.shuffle(train_dates)
    epoch_loss = 0
    enc.train()
    dec.train()
    loss_t = 0.0
    for i in range(len(train_dates)): 
        optimizer.zero_grad()
        x,y = readData(attributes=training_attributes, history=history, date=train_dates[i])
        x,y = x.float(), y.float()
        x_pred = dec(enc(x+(torch.rand(x.shape)-0.5)/8))
        loss = mse_loss(x_pred,x)
        loss.backward()
        optimizer.step()
        loss_t += loss.item()
    
    valid_loss = validation(enc, dec)
    print('epoch %d | training loss %0.4f | validation loss %0.4f'%(epoch, loss_t, valid_loss))
    
    if (epoch+1)%10 == 0:
        if valid_loss < prev_validation_loss:
            print('saving weights for lower loss')
            torch.save(enc.state_dict(), './checkpoints/enc.pth')
            torch.save(dec.state_dict(), './checkpoints/dec.pth')
            print('done saving')
            prev_validation_loss = valid_loss
        print('='*50)


NameError: name 'validation_dates' is not defined

In [10]:
def init_weights(m):
    if type(m) == nn.Linear:
        torch.nn.init.xavier_uniform(m.weight)
        m.bias.data.fill_(0.01)

enc = Encoder(21*12)
enc.load_state_dict(torch.load('./checkpoints/enc.pth'))
dec1 = Decoder1(64)
dec1.apply(init_weights)

params = list(dec1.parameters())
optim1 = optim.Adam(params, lr=1e-5, weight_decay=0.0)
mse_loss = torch.nn.L1Loss(reduction='none')
w = torch.tensor([10, 5, 3, 1, 0.5, 0.4, 0.3])
w = w/w.sum()

for epoch in range(200):
    np.random.shuffle(train_dates)
    epoch_loss = 0
    enc.train()
    dec1.train()
    loss_t = 0.0
    for i in range(len(train_dates)):
        x,y = readData(attributes=training_attributes, history=history, date=train_dates[i])
        x,y = x.float(), y.float()
        embed = enc(x)
        y_pred = dec1(embed)
        loss_mse = 0.0 
        for c in range(len(countries_to_extract)):
            loss_mse += (mse_loss(y_pred[c],y[c])*w).mean()
        loss = loss_mse
        loss /= len(countries_to_extract)
        loss.backward()
        optim1.step()
        optim1.zero_grad()
        loss_t += loss.item()
    print('epoch %d | training loss %0.4f'%(epoch, loss_t))
    
    if (epoch+1)%5 == 0:
        print('saving weights')
        torch.save(dec1.state_dict(), './checkpoints/dec1.pth')
        print('done saving')
        print('='*50)

  This is separate from the ipykernel package so we can avoid doing imports until
  allow_unreachable=True)  # allow_unreachable flag


epoch 0 | training loss 48.7309
epoch 1 | training loss 43.8228
epoch 2 | training loss 41.3564
epoch 3 | training loss 39.6740
epoch 4 | training loss 38.4775
saving weights
done saving
epoch 5 | training loss 37.6657
epoch 6 | training loss 36.9690
epoch 7 | training loss 36.3875
epoch 8 | training loss 35.9047
epoch 9 | training loss 35.4820
saving weights
done saving
epoch 10 | training loss 35.1111
epoch 11 | training loss 34.7820
epoch 12 | training loss 34.4975
epoch 13 | training loss 34.2305
epoch 14 | training loss 33.9850
saving weights
done saving
epoch 15 | training loss 33.7875
epoch 16 | training loss 33.5811
epoch 17 | training loss 33.3953
epoch 18 | training loss 33.2327
epoch 19 | training loss 33.0778
saving weights
done saving
epoch 20 | training loss 32.9428
epoch 21 | training loss 32.7884
epoch 22 | training loss 32.6632
epoch 23 | training loss 32.5319
epoch 24 | training loss 32.4176
saving weights
done saving
epoch 25 | training loss 32.3159
epoch 26 | traini

epoch 169 | training loss 27.5838
saving weights
done saving
epoch 170 | training loss 27.5325
epoch 171 | training loss 27.5265
epoch 172 | training loss 27.4938
epoch 173 | training loss 27.4548
epoch 174 | training loss 27.4460
saving weights
done saving
epoch 175 | training loss 27.4159
epoch 176 | training loss 27.4133
epoch 177 | training loss 27.3716
epoch 178 | training loss 27.3476
epoch 179 | training loss 27.3255
saving weights
done saving
epoch 180 | training loss 27.2941
epoch 181 | training loss 27.2744
epoch 182 | training loss 27.2775
epoch 183 | training loss 27.2534
epoch 184 | training loss 27.2152
saving weights
done saving
epoch 185 | training loss 27.2013
epoch 186 | training loss 27.1849
epoch 187 | training loss 27.1619
epoch 188 | training loss 27.1264
epoch 189 | training loss 27.1113
saving weights
done saving
epoch 190 | training loss 27.0904
epoch 191 | training loss 27.0622
epoch 192 | training loss 27.0465
epoch 193 | training loss 27.0127
epoch 194 | tra

In [33]:
enc = Encoder(21*12)
enc.load_state_dict(torch.load('./checkpoints/enc.pth'))
dec1 = Decoder1(64)
dec1.load_state_dict(torch.load('./checkpoints/dec1.pth'))
mse_loss = torch.nn.L1Loss(reduction='none')



def validation(enc, dec1, dates):
    enc.eval()
    dec1.eval()
    loss_list = [0]*len(countries_to_extract)
    for i in range(len(dates)):
        x,y = readData(attributes=training_attributes, history=history, date=dates[i])
        x,y = x.float(), y.float()
        if y.shape[1] != 7:
            continue
        embed = enc(x)
        y_pred = dec1(embed)
        print('GT\n', y[:,0])
        print('using given NPI\n', y_pred[:,0])
        embed = enc(x*0)
        y_pred = dec1(embed)
        print('using zero NPI\n', y_pred[:,0])
        sys.exit()
        loss = mse_loss(y_pred, y)[:,0]
        loss_list = [loss_list[i]+loss[i].item() for i in range(len(countries_to_extract))]
    loss_list = [l/len(dates) for l in loss_list]
    return loss_list

train_dates = npi_date['Date'][150:300].values
loss_list = validation(enc, dec1, train_dates)
for i in range(len(countries_to_extract)):
    print(i,countries_to_extract[i],loss_list[i])

GT
 tensor([3.3557, 1.7672, 2.3862, 1.0832, 0.2652, 0.1236, 0.1928, 0.1609, 0.1613,
        0.2194, 0.1228, 2.3998, 0.5780, 0.7817, 1.5962, 0.9984, 1.1281])
using given NPI
 tensor([1.6954e+00, 1.3272e+00, 2.0652e+00, 1.0936e+00, 5.0876e-01, 1.7097e-01,
        3.7676e-02, 1.6899e-02, 1.5145e-01, 1.1265e-03, 4.5441e-01, 1.9157e+00,
        7.1283e-01, 5.3871e-01, 1.5428e+00, 1.3036e+00, 1.8252e+00],
       grad_fn=<SelectBackward>)
using zero NPI
 tensor([0.0580, 0.0580, 0.0580, 0.0580, 0.0580, 0.0580, 0.0580, 0.0580, 0.0580,
        0.0580, 0.0580, 0.0580, 0.0580, 0.0580, 0.0580, 0.0580, 0.0580],
       grad_fn=<SelectBackward>)


SystemExit: 

In [None]:

enc = Encoder(21*12)
enc.load_state_dict(torch.load('./checkpoints/enc.pth'))

dec2 = Decoder2(input_dim=64+12,output_dim=21*12)
dec2.load_state_dict(torch.load('./checkpoints/dec2_prescriptor5.pth'))

optim_3 = optim.Adam(dec2.parameters(), lr=1e-5, weight_decay=1e-5)

mse_loss = torch.nn.MSELoss(reduction='none')
total_dates = npi_date['Date'][85:300].values

# gr_weight = (5+torch.tensor([2,8,1,3,3,4,4,9,6,7,2,13]))/5
# gr_weight = gr_weight/gr_weight.max()
gr_weight_new = torch.tensor([0.20525048673152924, 0.7103343605995178, 0.5421071648597717, 0.7157260775566101, 0.4592893123626709, 0.8335421681404114, 0.765260636806488, 1.0, 0.33554038405418396, 0.6577407121658325, 0.5881587862968445, 0.866989254951477])
gr_weight_hist = gr_weight
for i in range(history-1):
    gr_weight_hist = torch.cat((gr_weight_hist, gr_weight), dim=0)

max_range = torch.tensor([3,3,2,4,2,3,2,4,1,2,1,3])
max_range_hist = max_range
for i in range(history-1):
    max_range_hist = torch.cat((max_range_hist, max_range), dim=0)

for epoch in range(300):
    dec2.train()
    np.random.shuffle(total_dates)
    dec2.train()
    train_loss = 0.0
    for i in range(len(total_dates)-history):
        optim_3.zero_grad()
        x,y = readData(attributes=training_attributes, history=history, date=total_dates[i])
        x_new,_ = readData(attributes=training_attributes, history=history, date=total_dates[i+history])
        x,y, x_new = x.float(), y.float(), x_new.float()
        embed = enc(x)
        attr_weights = torch.rand(embed.shape[0],12)*3
        embed_new = torch.cat((embed, attr_weights), dim=1)
        x_pred = torch.round(dec2(embed_new))
        npi_loss = (x_pred.view(x_pred.shape[0],-1,history).sum(dim=2)*attr_weights).sum()/(len(countries_to_extract)*21)
        
        loss2 = 0.0
        loss3 = 0.0
        for c in range(len(countries_to_extract)):
            loss2 += mse_loss(x_pred[c],x_new[c]).mean()
            loss3 += ((max_range_hist-x_pred[c])*gr_weight_hist).mean()
        
        loss2 /= len(countries_to_extract) 
        loss3 /= len(countries_to_extract) 
        loss = 0.5*loss2 + 0.05*npi_loss + 2.0*loss3 
        
        loss.backward()
        optim_3.step()
        train_loss += loss.item()
        
    print('epoch %d | train_loss %0.4f'%(epoch, train_loss))
    if (epoch+1)%5 == 0:
        print('saving')
        torch.save(dec2.state_dict(), './checkpoints/dec2_prescriptor6.pth')
        print('done saving')
        print('GT_NPI')
        print(x[0,0:12])
        print('prescribed_NPI')
        print(x_pred[0,0:12])
        print("="*50)

epoch 0 | train_loss 492.5075
epoch 1 | train_loss 505.6701
epoch 2 | train_loss 520.8059
epoch 3 | train_loss 538.1214
epoch 4 | train_loss 550.7062
saving
done saving
GT_NPI
tensor([3., 2., 2., 3., 1., 2., 2., 4., 2., 2., 2., 4.])
prescribed_NPI
tensor([2., 1., 1., 2., 1., 1., 1., 3., 1., 2., 1., 3.],
       grad_fn=<SliceBackward>)
epoch 5 | train_loss 568.8742
epoch 6 | train_loss 588.5715
epoch 7 | train_loss 607.8962
epoch 8 | train_loss 628.5334
epoch 9 | train_loss 645.0157
saving
done saving
GT_NPI
tensor([3., 3., 2., 4., 2., 1., 2., 3., 2., 1., 2., 0.])
prescribed_NPI
tensor([1., 1., 1., 2., 1., 1., 1., 2., 1., 1., 1., 2.],
       grad_fn=<SliceBackward>)
epoch 10 | train_loss 664.2870
epoch 11 | train_loss 680.8904
epoch 12 | train_loss 700.1942
epoch 13 | train_loss 716.7489
epoch 14 | train_loss 734.9568
saving
done saving
GT_NPI
tensor([3., 3., 2., 4., 2., 3., 2., 4., 2., 1., 2., 1.])
prescribed_NPI
tensor([1., 1., 0., 1., 0., 0., 0., 1., 0., 1., 1., 1.],
       grad_fn=<

In [19]:
enc = Encoder(21*12)
enc.load_state_dict(torch.load('./checkpoints/enc.pth'))

dec2 = Decoder2(input_dim=64+12,output_dim=21*12)
dec2.load_state_dict(torch.load('./checkpoints/dec2_prescriptor3.pth'))

x,y = readData(attributes=training_attributes, history=history, date=total_dates[50])
x_new,y_ = readData(attributes=training_attributes, history=history, date=total_dates[50+21]) # future NPI
x,y = x.float(), y.float()
attr_weights = torch.rand(embed.shape[0],12)*3
x_pred = torch.round(dec2(torch.cat((enc(x), attr_weights), dim=1)))

for i in range(17):
    print('NPI cost\n', attr_weights[i])
    print('NPI GT\n', x_new[i,0:12])
    print('NPI pred\n', x_pred[i,0:12])
    print("="*50)


NPI cost
 tensor([0.2537, 1.2809, 1.9333, 0.8621, 2.5226, 0.8534, 0.7366, 0.2273, 1.5493,
        2.8106, 0.7414, 2.7061])
NPI GT
 tensor([3., 3., 2., 4., 2., 3., 2., 4., 2., 1., 2., 0.], dtype=torch.float64)
NPI pred
 tensor([3., 3., 2., 4., 2., 3., 3., 4., 3., 3., 2., 5.],
       grad_fn=<SliceBackward>)
NPI cost
 tensor([0.6596, 0.7806, 2.4984, 2.7556, 0.1131, 2.3510, 0.8722, 0.6665, 2.0760,
        2.1736, 1.1642, 0.8699])
NPI GT
 tensor([3., 3., 2., 4., 1., 2., 2., 3., 2., 3., 1., 1.], dtype=torch.float64)
NPI pred
 tensor([3., 3., 2., 4., 1., 2., 2., 4., 2., 4., 1., 4.],
       grad_fn=<SliceBackward>)
NPI cost
 tensor([0.0294, 2.2114, 2.4352, 2.5469, 1.9960, 2.7121, 1.3715, 2.1045, 2.5396,
        1.0447, 0.2097, 2.7707])
NPI GT
 tensor([3., 3., 2., 3., 2., 1., 2., 3., 2., 1., 0., 2.], dtype=torch.float64)
NPI pred
 tensor([3., 4., 2., 4., 3., 3., 3., 4., 3., 3., 2., 5.],
       grad_fn=<SliceBackward>)
NPI cost
 tensor([0.4743, 2.9148, 0.3974, 0.6002, 2.3440, 2.3070, 2.0523, 2.