# Pathcheck Team LSTM
Link to [Master Spreadsheet](https://docs.google.com/document/d/1-fFRebE7bEcE4AUnVpj63r9dqJF-0PbTC95FciGVYY0/edit)

## Instructions 
Before you start this notebook, load the data from [here](https://github.com/leaf-ai/covid-xprize/tree/master/covid_xprize/examples/predictors/lstm/data).

~~Make sure you click the `..` folder at the top of Google Colab's folder navigation tab to navigate to the root.~~ Make a new folder called `data` ~~in the root~~ in the default folder `/content/`, and upload all of the CSVs there.
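
A quick way to confirm the upload worked is to check that the expected files are present. The sketch below is only an illustration: the helper name `missing_data_files` is hypothetical, and the file names are the ones the path constants in the predictor cell of this notebook point at.

```python
import os

# File names taken from the path constants defined in the predictor cell.
EXPECTED_FILES = [
    "OxCGRT_latest.csv",
    "Additional_Context_Data_Global.csv",
    "US_states_populations.csv",
    "uk_populations.csv",
    "brazil_populations.csv",
]

def missing_data_files(data_dir, expected=EXPECTED_FILES):
    """Return the expected file names that are not present in data_dir."""
    return [
        name for name in expected
        if not os.path.isfile(os.path.join(data_dir, name))
    ]

# Example (in Colab): missing_data_files("/content/data") should return []
# once every CSV has been uploaded.
```

An empty list means every file was found; anything else names the files still to upload.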

See the `MockArgs` class to change the parameters. Don't use `clr` yet. I'd also recommend changing the `stepsize` parameter (it determines how many steps the gradients are retained for before the optimizer update step is called).

Ideally we can use tensorboardX to see the charts for error metrics, but I haven't figured out how to SSH-tunnel a port from a Google Colab yet.

## Things to work on

**MVP**

* **Tuning the model for good performance** (these parameters seem to plateau at about 10 epochs). Decreasing the learning rate or trying a different optimizer (SGD) would be a good place to start. If you get a good predictor, please download the `.pth` file!
* add validation loop # currently trains 10 models and picks the best one.
* ~~**finish the lstm rollout rework from keras to using the new pytorch model (see `_lstm_get_test_rollouts`)** looks like # m(features)~~
* **When this is done ^ upload a copy of the notebook to our sandbox for the robo judge**
* ~~figure out tensorboard for metrics~~ (others will need to test their configs) (see the note in [welcome.ipynb](https://prcx-pathcheck4307.xprizenotebooks.org/lab/tree/tutorials/welcome.ipynb))

**Ayush's pointers (Novel Qualitative contributions)**
* ~~Use a GRU (in progress)~~
* Use a seq2seq PyTorch model [(drop in)](https://github.com/gautham20/pytorch-ts)
* ~~use smooth l1 loss~~
* [weight losses](https://github.com/ActiveConclusion/COVID19_mobility/blob/master/google_reports) by the mobility data in each county/region

**Ablation Studies**
* Vary the number of days parameter `NB_LOOKBACK_DAYS` (if allowed)
* Study Different spatial region resolutions (if applicable)
* Add or preprocess context (numerical) data (such as [weather csv](https://repo.ijs.si/vitojanko/covid-from-scratch/-/blob/master/Data/features.csv), [more xprize columns](https://docs.google.com/spreadsheets/d/1waAGAoF0NE9AUHP6094kUjG8lsCdyrM53AlzaeegQic/edit#gid=1237230743))


**Utility/minor issues**
* ~~double check that saved models load~~
* get pytext-nlp to work (or use the original nlp notebook to improve the model). Pytext currently crashes colab when loaded
* Cleaning up and ~~organizing code~~



In [None]:
# ! pip install pytext-nlp
# ! pip install tensorboardX
%load_ext tensorboard


The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


In [None]:
import os
import pandas as pd
import requests
import numpy as np
import re
from pandas.api.types import is_string_dtype, is_numeric_dtype
import warnings
from pdb import set_trace
from torch import nn, optim, as_tensor
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torch.nn.init import *
import sklearn
# from sklearn_pandas import DataFrameMapper
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.impute import SimpleImputer
import urllib.request
# import pytext
from sklearn.metrics import classification_report, f1_score
from sklearn.model_selection import train_test_split
import pickle
import gc
import time
import copy
import shutil
import math  # needed by cyclical_lr (math.floor)

In [None]:
import os
import shutil
import json
import argparse
import time
from datetime import datetime
from collections import Counter

from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, WeightedRandomSampler, SubsetRandomSampler
from torch.utils.tensorboard import SummaryWriter


In [None]:
NPI_COLUMNS = ['C1_School closing',
               'C2_Workplace closing',
               'C3_Cancel public events',
               'C4_Restrictions on gatherings',
               'C5_Close public transport',
               'C6_Stay at home requirements',
               'C7_Restrictions on internal movement',
               'C8_International travel controls',
               'H1_Public information campaigns',
               'H2_Testing policy',
               'H3_Contact tracing',
               'H6_Facial Coverings']

### Metric/Misc Helper functions

In [None]:
from sklearn import metrics
import sys
from functools import partial

def get_evaluation(y_true, y_prob, list_metrics):
    y_pred = np.round(y_prob)
    output = {}
    rmse = partial(metrics.mean_squared_error, squared=False)
    func_map = {
        "explained_variance": metrics.explained_variance_score,
        "mean_absolute_error": metrics.mean_absolute_error,
        "mean_squared_error": metrics.mean_squared_error,
        "rmse": rmse,
        "r2": metrics.r2_score
    }

    for m, metric_func in func_map.items():
        if m in list_metrics:
            output[m] = metric_func(y_true, y_pred)

    return output

class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

# cyclic learning rate scheduling

def cyclical_lr(stepsize, min_lr=1.7e-3, max_lr=1e-2):

    # Scaler: we can adapt this if we do not want the triangular CLR
    def scaler(x):
        return 1.0

    # Lambda function to calculate the LR
    def lr_lambda(it):
        return min_lr + (max_lr - min_lr) * relative(it, stepsize)

    # Additional function to see where on the cycle we are
    def relative(it, stepsize):
        cycle = math.floor(1 + it / (2 * stepsize))
        x = abs(it / stepsize - 2 * cycle + 1)
        return max(0, (1 - x)) * scaler(cycle)

    return lr_lambda
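
To see what the triangular schedule does, here is a self-contained sketch that mirrors the `relative` helper above and evaluates it at a few iterations. The function name `triangular_lr` and the `stepsize=4` value are chosen only for illustration. Note that `torch.optim.lr_scheduler.LambdaLR` multiplies the optimizer's initial learning rate by the value the lambda returns, which is presumably why `run` builds the SGD optimizer with `lr=1` when `clr` is selected.

```python
import math

def triangular_lr(it, stepsize, min_lr=1.7e-3, max_lr=1e-2):
    """Triangular cyclical learning rate at iteration `it`."""
    cycle = math.floor(1 + it / (2 * stepsize))
    x = abs(it / stepsize - 2 * cycle + 1)
    return min_lr + (max_lr - min_lr) * max(0.0, 1 - x)

# One full cycle takes 2 * stepsize iterations:
# the rate climbs for `stepsize` steps, then falls for `stepsize` steps.
schedule = [triangular_lr(it, stepsize=4) for it in range(9)]
for it, lr in enumerate(schedule):
    print(f"iteration {it}: lr = {lr:.5f}")
```

With these values the rate starts at `min_lr`, peaks at `max_lr` on iteration 4, and is back at `min_lr` on iteration 8.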

### Pytorch Model and Dataset Classes

See [this thread](https://discuss.pytorch.org/t/initializing-rnn-gru-and-lstm-correctly/23605/5) for details on initialization

In [None]:
#PYTORCH MODELS

def init_weights(m):
    if type(m) in [nn.GRU, nn.LSTM, nn.RNN]:
        for name, param in m.named_parameters():
            for idx in range(4):
                mul = param.shape[0]//4
                if 'weight_ih' in name:
                    torch.nn.init.xavier_uniform_(param.data[idx*mul:(idx+1)*mul])
                elif 'weight_hh' in name:
                    torch.nn.init.orthogonal_(param.data[idx*mul:(idx+1)*mul])
                elif 'bias' in name:
                    param.data.fill_(0.01)
    elif type(m) == nn.Linear:
        torch.nn.init.xavier_uniform_(m.weight.data)
        m.bias.data.fill_(0.01)

class ContextEncoder(nn.Module):
  def __init__(self, nb_loopback_days, nb_context_dim, hidden_dim, args):
    super(ContextEncoder, self).__init__()
    self.args = args
    self.hidden_dim = hidden_dim
    self.nb_loopback_days = nb_loopback_days  # this is just the sequence length
    self.nb_context_dim = nb_context_dim
    
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    self.activation = nn.Softplus()

    self.rnn = nn.LSTM(self.nb_context_dim, hidden_size=self.hidden_dim, num_layers=args.n_hidden_layers, batch_first=False)
    self.ln_rnn = nn.LayerNorm(self.hidden_dim)
    self.hidden2out = nn.Linear(self.hidden_dim, 1)
    # h0 = torch.zeros(1, self.args.batch_size, self.hidden_dim).to(device).float()
    # # h1 = torch.zeros(1, self.args.batch_size, self.hidden_dim).to(device).double()
    # self.hidden = h0
  
  def forward(self, sequence):
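    # nn.LSTM with batch_first=False expects (seq_len, batch, features). transpose swaps the
    # axes; view would only reinterpret memory and mix time steps across samples in the batch.
    rnn_out, _ = self.rnn(sequence.transpose(0, 1).contiguous())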
    rnn_out = self.ln_rnn(rnn_out)
    rnn_out = torch.transpose(rnn_out, 1, 0) #swap the batch and sequence dims back
    pre_activation = self.hidden2out(rnn_out[:,-1])
    pred = self.activation(pre_activation)
    return pred

  def _detach(self):
      pass
      # self.hidden = self.hidden.detach()

class ActionEncoder(nn.Module):
  def __init__(self, nb_loopback_days, nb_action_dim, hidden_dim, args):
    super(ActionEncoder, self).__init__()
    self.args = args
    self.hidden_dim = hidden_dim
    self.nb_loopback_days = nb_loopback_days
    self.nb_action_dim = nb_action_dim
    
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    self.activation = nn.Sigmoid()

    # input size, hidden size, num_layers
    self.rnn = nn.LSTM(self.nb_action_dim, hidden_size=self.hidden_dim, num_layers=args.n_hidden_layers, batch_first=False)
    self.ln_rnn = nn.LayerNorm(self.hidden_dim)
    self.hidden2out = nn.Linear(self.hidden_dim, 1)
    # h0 = torch.zeros(1, self.args.batch_size, self.hidden_dim).to(device).double()
    # # h1 = torch.zeros(1, self.args.batch_size, self.hidden_dim).to(device).double()
    # self.hidden = h0 


  def forward(self, sequence):
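    # nn.LSTM with batch_first=False expects (seq_len, batch, features). transpose swaps the
    # axes; view would only reinterpret memory and mix time steps across samples in the batch.
    rnn_out, _ = self.rnn(sequence.transpose(0, 1).contiguous())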
    rnn_out = self.ln_rnn(rnn_out)
    rnn_out = torch.transpose(rnn_out, 1, 0) #swap the batch and sequence dims back
    pre_activation = self.hidden2out(rnn_out[:,-1])
    pred = self.activation(pre_activation)
    return pred

  def _detach(self):
      pass
      # self.hidden = self.hidden.detach()
  

class CombinedModel(nn.Module):
    def __init__(self, nb_loopback_days, nb_action_dim, nb_context_dim, hidden_dim, args):
      super(CombinedModel, self).__init__()
      self.args = args
      self.hidden_dim = hidden_dim
      self.nb_loopback_days = nb_loopback_days
      self.nb_action_dim = nb_action_dim
      self.nb_context_dim = nb_context_dim

      self.action_encoder = ActionEncoder(nb_loopback_days, nb_action_dim, args.hidden_dim_action, args).float()
      self.context_encoder = ContextEncoder(nb_loopback_days, nb_context_dim, hidden_dim, args).float()

      self.lambda_layer = _combine_r_and_d
      self.apply(init_weights)

    def forward(self, sequence):
      """
      take in a concatenated sequence in the form [context, action], return prediction
      """
      context, action = torch.split(sequence, [self.nb_context_dim, self.nb_action_dim], dim=2)
      c = self.context_encoder(context)
      a = self.action_encoder(action)
      return self.lambda_layer((c,a))

    def _detach(self):
        self.context_encoder._detach()
        self.action_encoder._detach()

    @torch.no_grad() 
    def predict(self, input):
        """
        Predict the output for a single data input
        input shapes (TEST_DAYS, LOOPBACK_DAYS, feature_dim)
        """
        [context_input, action_input] = input
        data = XPrizeDataset(context_input, action_input, self.args, zero_pad = True)
        prediction_loader = DataLoader(data, batch_size=self.args.batch_size, shuffle=False)
        _, (feats, _) = list(enumerate(prediction_loader))[0]
        pred = self.forward(feats).squeeze(dim=-1)
        pred = pred.cpu().detach().numpy()[0] #remove dummy inputs for batch prediction
        return pred

class XPrizeDataset(Dataset):
    def __init__(self, context, action, args, label=None, zero_pad=False, transform=torch.from_numpy):
        batch_size, to_sequence = args.batch_size, args.to_sequence
        self.args = args
        self.context = context.astype(float)
        self.action = action.astype(float)
        if to_sequence:
            # NOTE assumes inputs are given in order
            label_builder = []
            for i in range(len(self.context) - args.nb_lookback_days -1):
                label_builder.append(label[i + 1 : i + 1 + args.nb_lookback_days] if label is not None else np.zeros(args.nb_lookback_days))
            label = np.array(label_builder)

        # print(f"pre reshaping for dataset action shape {self.action.shape}")
        if zero_pad: #pad with 0 entries
            seq_len, seq_depth, feature_dim = self.context.shape
            self.context = np.vstack((context, np.zeros((np.abs(batch_size - seq_len), seq_depth, feature_dim))))
            # print(self.context.shape)
            seq_len, seq_depth, feature_dim = self.action.shape
            self.action = np.vstack((action, np.zeros((np.abs(batch_size - seq_len), seq_depth, feature_dim))))
            # print(self.action.shape)
            label = np.zeros(len(self.context)) if not to_sequence else np.vstack((label, np.zeros((np.abs(batch_size - len(label)), label.shape[1]))))
        self.label = label.astype(float)
        self.transform = transform
        # print(f"dataset conext {self.context.shape}")
        # print(f"dataset action {self.action.shape}")
        # print(f"dataset label {self.label.shape}")

    def __len__(self):
        return len(self.label)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        
        selected_contexts = self.context[idx]
        selected_actions = self.action[idx]
        labels = self.label[idx]
        inputs = np.hstack((selected_contexts, selected_actions))
        if self.transform:
          # pre_shape = inputs.shape
          inputs = self.transform(inputs).float()
          labels = torch.tensor(labels).float()
        return inputs, labels
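
Both encoders use an LSTM built with `batch_first=False`, which expects input shaped `(seq_len, batch, features)`, while the dataset yields batches shaped `(batch, seq_len, features)`. The following sketch, on a toy tensor with made-up sizes, shows how swapping axes with `transpose` differs from reinterpreting the shape with `view`. It only illustrates the tensor semantics and is not part of the model.

```python
import torch

batch, seq_len, feat = 2, 3, 1
# Sample 0 holds 0, 1, 2 over time; sample 1 holds 10, 11, 12.
x = torch.tensor([[0., 1., 2.], [10., 11., 12.]]).reshape(batch, seq_len, feat)

swapped = x.transpose(0, 1)            # (seq_len, batch, feat), samples stay intact
viewed = x.view(seq_len, batch, feat)  # same shape, but memory is only re-chunked

# The time series that ends up in the "sample 0" slot under each approach
sample0_transposed = swapped[:, 0, 0].tolist()
sample0_viewed = viewed[:, 0, 0].tolist()
print(sample0_transposed, sample0_viewed)
```

Both results have the same shape, so the difference is silent: only the transposed version keeps each sample's time steps together.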


## Regularizers

We have a special regularizer from Ayush, in two versions. Both are meant to create a monotonic approach to the optimal solution. #1 is a version of knowledge distillation that prevents the gradient update from moving too far from the last model's weights. #2 is from [this paper](https://arxiv.org/pdf/2006.13593.pdf) and is better suited to regression, so that is what we will be using. They both help to disentangle different components of the loss for better updates.

**Retrospective (#2) loss usage**

Note: each time we update the backup models we should set K = K + K_0*0.05
```
# Step 1 - model definition
model = MyModel()
model_retr = MyModel() # an add-on backup model
model.train()
model_retr.eval() # don't train this
# Step 2 - do forward propagation
X, Y = image_batch, label_batch
outputs = model(X)
stale_outputs = model_retr(X)
labels = Y
# Step 3 - Update retr model periodically
if step % N == 0:
   model_retr.load_state_dict(model.state_dict())
```
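
As a worked example of version #2, the retrospective loss takes `(K+1)` times the L1 distance to the labels and subtracts `K` times the L1 distance to the stale model's outputs, matching the `retr_loss` function defined in the next cell. The sketch uses plain Python lists with made-up numbers; `mean_abs_error` and `retrospective_loss` are hypothetical helper names.

```python
def mean_abs_error(a, b):
    """Mean absolute difference between two equal-length sequences (L1 loss)."""
    return sum(abs(x - y) for x, y in zip(a, b)) / len(a)

def retrospective_loss(outputs, stale_outputs, labels, K=2):
    """(K + 1) * L1(outputs, labels) - K * L1(outputs, stale_outputs)."""
    return ((K + 1) * mean_abs_error(outputs, labels)
            - K * mean_abs_error(outputs, stale_outputs))

labels = [1.0, 2.0]
stale = [3.0, 4.0]    # predictions from the frozen backup model
current = [2.0, 3.0]  # predictions from the model being trained

loss = retrospective_loss(current, stale, labels, K=2)
print(loss)
```

The value can go negative when the current outputs are much closer to the labels than to the stale outputs, which is why the training loop clamps the term at zero.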

In [None]:
# Self-distillation objectives
# Loss Version 1
# def retr_loss_kd(outputs, stale_outputs, labels, alpha=0.7, T=1):
#     KD_loss = nn.KLDivLoss()(F.log_softmax(outputs/T, dim=1),
#                                  F.softmax(stale_outputs/T, dim=1)) * (alpha * T * T) + \
#                   F.CRITERION(outputs, labels) * (1. - alpha) # knowledge distillation; this contrasts with the second one (take small steps forward, limiting divergence from the past)
#     return KD_loss
# Loss Version 2 (From the paper)
def retr_loss(outputs, stale_outputs, labels, l1_loss=None, K=2):
    if l1_loss is None:
        l1_loss = nn.L1Loss()
    retr_loss = (K+1)*l1_loss(outputs, labels) - K*l1_loss(outputs, stale_outputs)
    return retr_loss

#maybe ignore 1 and just use 2 (not for regression)


def xprize_regularizer(model, lambda_l1=0.1):
    """Penalizes negative action-encoder weights, pushing them to be non-negative."""
    # with torch.enable_grad():
    lossl1 = torch.tensor(0).float()
    for model_param_name, model_param_value in model.named_parameters():
        # print(model_param_value)
        if "action_encoder" in model_param_name and not model_param_name.endswith("hidden2out.bias"):
            lossl1 += lambda_l1 * model_param_value.clip(max=0).sum()
    lossl1 = lossl1.abs() 
    return lossl1  

def l2_regularizer(model, reg=1e-6):
    """ forces magnitude of weights to be small"""
    # with torch.enable_grad():
    l2_loss = torch.tensor(0).float()
    for name, param in model.named_parameters():
        if 'bias' not in name:
            l2_loss = l2_loss + (0.5 * reg * torch.sum(torch.pow(param, 2)))
    return l2_loss

def l1_regularizer(model, reg=1e-6):
    """ lasso regression forces us to focus on a subset of the weights"""

    # with torch.enable_grad():
    l1_loss = torch.tensor(0).float()
    for name, param in model.named_parameters():
        if 'bias' not in name:
            l1_loss = l1_loss + (reg * torch.sum(torch.abs(param)))
    return l1_loss

def orth_regularizer(model, reg=1e-6):
    """ forces weights to be orthogonal, TODO check to see if we should take out the fully connected weights from this penalty, and or the rnn ones"""
    # with torch.enable_grad():
    orth_loss = torch.tensor(0).float()
    for name, param in model.named_parameters():
        if 'bias' not in name:
            param_flat = param.view(param.shape[0], -1)
            sym = torch.mm(param_flat, torch.t(param_flat))
            sym -= torch.eye(param_flat.shape[0])
            orth_loss = orth_loss + (reg * sym.abs().sum())
    return orth_loss

def call_regularizers(model, args):
    with torch.enable_grad():
        loss = torch.tensor(0).float()
        if args.xprize_regularizer:
            loss += xprize_regularizer(model, args.xprize_lambda)
        if args.l1_regularizer:
            assert not args.l2_regularizer, "shouldn't use l2 and l1 reg"  
            loss += l1_regularizer(model, args.l1_lambda)
        if args.l2_regularizer:
            assert not args.l1_regularizer, "shouldn't use l2 and l1 reg" 
            loss += l2_regularizer(model, args.l2_lambda) 
        if args.orth_regularizer:
            loss += orth_regularizer(model, args.orth_lambda)
        return loss
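
To make the sign penalty concrete, here is a plain-Python sketch of the arithmetic in `xprize_regularizer`: only the negative entries contribute, and the result is reported as a positive penalty. The weights and the helper name `negative_weight_penalty` are invented for illustration.

```python
def negative_weight_penalty(weights, lambda_l1=0.1):
    """Sum the negative entries (clip at max=0), scale by lambda_l1, return the magnitude."""
    clipped_sum = sum(min(w, 0.0) for w in weights)
    return abs(lambda_l1 * clipped_sum)

# Only -0.2 and -0.3 contribute: 0.1 * (-0.5) has magnitude 0.05.
penalty = negative_weight_penalty([0.5, -0.2, -0.3, 1.0], lambda_l1=0.1)
print(penalty)
```

A weight vector with no negative entries incurs no penalty at all, so the term only acts on weights that violate the sign constraint.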

### Args Run and Train

In [None]:
class MockArgs(object):
  def __init__(self):
    #META
    self.dry_run = False
    self.flush_history = False
    self.log_path = "./logs/"
    self.model_name="lstm2xc4xa"
    self.output = "./modelsaves/"
    self.log_every= 100
    self.checkpoint = False


    # Run loop Params
    self.scheduler = None # one of "clr", "step"
    self.optimizer = "adam" # one of "adam", "sgd"
    self.epochs = 17
    self.learning_rate = .008
    self.min_lr = 1.7e-3 # CLR ONLY
    self.max_lr = 1e-2 # CLR ONLY
    self.early_stopping = True
    self.patience = 3 #patience for early stopping (how long to wait for improvement)
    self.weight_decay = 0
    self.momentum = .8
    self.model_selecting_metric = "loss" # highly recommend "loss", see get_evaluation() for metrics, note for some metrics higher is better, may need to change the scoring logic
    self.criterion = "l1" # one of "l1", "smooth_l1", "mse"
    self.smooth_l1_beta = 1 # 0 = l1 loss


    #DATA Params
    self.batch_size = 250
    self.shuffle = True
    self.num_workers = 4
    self.val_split_percentage = .2


    ## From Xprize model params
    self.nb_lookback_days = 21
    self.nb_test_days = 14
    self.window_size = 7
    self.num_trials = 3
    self.n_hidden_layers = 1
    self.hidden_dim= 32
    self.hidden_dim_action= 64
    self.max_nb_countries = 20


    #Regularizers
    self.xprize_regularizer = False
    self.xprize_lambda = 1e-6
    self.l1_regularizer = False
    self.l1_lambda = 1e-6
    self.orth_regularizer = False
    self.orth_lambda = 1e-6
    self.l2_regularizer = False
    self.l2_lambda = 1e-6
    self.regularizer_ratio = .5
    self.retr_loss= True
    self.retr_warmup = 0
    self.retr_burnin = [0]

    if not os.path.exists(self.output):
        os.makedirs(self.output)
    if not os.path.exists(self.log_path):
        os.makedirs(self.log_path)

    self.stepsize = 1
    # For seq2seq (not impl)
    self.embedding_dim = 16 #DO NOT USE
    self.to_sequence = False

  def set_regularizers(self):
      self.l1_lambda = self.xprize_lambda / self.regularizer_ratio
      self.l2_lambda = self.xprize_lambda / self.regularizer_ratio
      self.orth_lambda = self.xprize_lambda / (self.regularizer_ratio*2)

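def criterion_to_function(criterion):
    """Map a criterion name ("l1", "smooth_l1", "mse") to a PyTorch loss module.

    NOTE: the "smooth_l1" branch reads beta from the module-level `args` object,
    so `args` must exist before this function is called.
    """
    if criterion == 'l1':
        return nn.L1Loss(reduction="mean")
    elif criterion == "smooth_l1":
        return nn.SmoothL1Loss(beta=args.smooth_l1_beta)
    elif criterion == "mse":
        return nn.MSELoss()
    else:
        raise ValueError(f"{criterion} criterion not found")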


 
def run(args, model, X_context, X_action, y, trial_num=0):

    if args.flush_history == 1:
        objects = os.listdir(args.log_path)
        for f in objects:
            if os.path.isdir(args.log_path + f):
                shutil.rmtree(args.log_path + f)

    now = datetime.now()
    logdir = args.log_path + args.model_name + f"_{trial_num}/"
    if os.path.exists(logdir):
        shutil.rmtree(logdir)
    os.makedirs(logdir, exist_ok=True)
    log_file = logdir + "log.txt"
    writer = SummaryWriter(logdir)
    criterion = criterion_to_function(args.criterion)


    if args.optimizer == "sgd":
        if args.scheduler == "clr":
            optimizer = torch.optim.SGD(
                model.parameters(), lr=1, momentum=0.9, weight_decay=0.00001
            )
        else:
            optimizer = torch.optim.SGD(
                model.parameters(), lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay
            )
    elif args.optimizer == "adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)

    best_score = float('inf')
    best_epoch = 0
    best_model_path = ""

    root_set = XPrizeDataset(X_context, X_action, args, label=y, zero_pad=False)
    val_n = int(len(root_set)*args.val_split_percentage)
    train_n = len(root_set) - val_n
    train_set, val_set = torch.utils.data.random_split(root_set, [train_n, val_n])

    training_dataloader = DataLoader(train_set, batch_size=args.batch_size, shuffle=args.shuffle, num_workers=args.num_workers, drop_last=True)
    val_dataloader = DataLoader(val_set, batch_size=args.batch_size, shuffle=args.shuffle, num_workers=args.num_workers, drop_last=True)

    if args.scheduler == "clr":
        stepsize = int(args.stepsize * len(training_dataloader))
        clr = cyclical_lr(stepsize, args.min_lr, args.max_lr)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, [clr])
    else:
        scheduler = None

    model_retr = None
    K_0 = 0
    K = 0
    if args.retr_loss:
        model_retr = copy.deepcopy(model)
        model_retr.load_state_dict(model.state_dict())
        K_0 = 2
        K = 2

    for epoch in range(args.epochs):
        if epoch % args.stepsize == 0 and epoch >= args.retr_warmup:
            K += K_0*.15
            # model_retr.load_state_dict(model.state_dict())

        training_loss, mets, K = train(
              model,
              training_dataloader,
              optimizer,
              criterion,
              epoch,
              writer,
              log_file,
              scheduler,
              args,
              args.log_every,
              model_retr,
              K,
          )
        
        val_loss, val_mets = evaluate(
              model,
              val_dataloader,  # evaluate on the held-out split, not the training data
              criterion,
              epoch,
              writer,
              log_file,
              args,
              args.log_every,
        )

        update_str = "[Epoch: {} / {}] train_loss: {:.4f} | val_loss: {:.4f}".format(
                epoch + 1,
                args.epochs,
                training_loss,
                val_loss,
            )
        for m, metric in val_mets.items():
            update_str += f" | val {m}: {metric.avg:.4f}"
            if m == args.model_selecting_metric:
                epoch_score = metric.avg

        if args.model_selecting_metric == "loss":
            epoch_score = val_loss

        print(update_str)
        print("=" * 50)

        # learning rate scheduling

        if args.scheduler == "step":
            if args.optimizer == "sgd" and ((epoch + 1) % 3 == 0) and epoch > 0:
                current_lr = optimizer.state_dict()["param_groups"][0]["lr"]
                current_lr /= 2
                print("Decreasing learning rate to {0}".format(current_lr))
                for param_group in optimizer.param_groups:
                    param_group["lr"] = current_lr

        # model checkpoint

        if epoch_score < best_score or epoch == args.epochs:
            best_score = epoch_score 
            best_epoch = epoch
            if args.checkpoint == 1:
                torch.save(
                    model.state_dict(),
                    args.output
                    + "model_{}_epoch_{}_lr_{}_loss_{}_score_{}.pth".format(
                        args.model_name,
                        epoch,
                        optimizer.state_dict()["param_groups"][0]["lr"],
                        round(training_loss, 4),
                        round(epoch_score, 4),
                    ),
                )
            torch.save(model.state_dict(), f"{args.output}{args.model_name}.pth")

        if bool(args.early_stopping):
            if epoch - best_epoch > args.patience > 0:
                print(
                    "Stop training at epoch {}. The lowest score achieved is {} at epoch {}".format(
                        epoch, best_score, best_epoch
                    )
                )

                break
    model.load_state_dict(torch.load(f"{args.output}{args.model_name}.pth"))



def train(
    model,
    training_dataloader,
    optimizer,
    criterion,
    epoch,
    writer,
    log_file,
    scheduler,
    args,
    print_every=25,
    model_retr = None,
    K=0
):
    model.train()
    do_retr = epoch >= args.retr_warmup
    if model_retr is not None:
        model_retr.eval()
    losses = AverageMeter()
    num_iter_per_epoch = len(training_dataloader)
    # torch.autograd.set_detect_anomaly(True)
    progress_bar = tqdm(enumerate(training_dataloader), total=num_iter_per_epoch)

    y_true = []
    y_pred = []
    LIST_METRICS = ["rmse", "mean_absolute_error", "explained_variance"] # add to the args object
    meters = {m: AverageMeter() for m in LIST_METRICS}
    if args.dry_run:
        i, (feats, labels) = list(enumerate(training_dataloader))[0]
        temp_out = model(feats)
        pred_shape = temp_out.shape
        print(f"DRY RUNNING, prediction shape will be filled with zeros of shape {temp_out.shape} and type {temp_out.type()}")
        print(f"labels are {labels.shape} and type {labels.type()}")

    for iter, batch in progress_bar:
        features, labels = batch
        if torch.cuda.is_available():
            features = features.cuda()
            labels = labels.cuda()
        
        features.requires_grad_()
        optimizer.zero_grad()
        if not args.dry_run:
            predictions = model(features).squeeze(dim=-1)
            if model_retr is not None and do_retr:
                stale_predictions = model_retr(features).squeeze(dim=-1)
            # print(predictions.shape)
            # print(torch.squeeze(predictions, dim=-1))
        else:
            predictions = torch.zeros(temp_out.shape)

        # y_true += labels.cpu().numpy().tolist()
        # y_pred += predictions.cpu().detach().numpy().tolist()

        loss = criterion(predictions, labels)
        if args.retr_loss and do_retr:
             r_loss = retr_loss(predictions, stale_predictions, labels, K=K).clip(min=0)
             loss = (loss*1.5 + r_loss*.5) if epoch in args.retr_burnin else loss + r_loss
        reg = call_regularizers(model, args)
        loss += reg
        if not args.dry_run:
            loss.backward()

        optimizer.step() #step
        # detach hidden states
        # model._detach()
        if args.scheduler == "clr":
            scheduler.step()

        training_metrics = get_evaluation(
            labels.cpu().numpy(),
            predictions.cpu().detach().numpy(),
            list_metrics=LIST_METRICS,
        )

        losses.update(loss.item(), features.size(0))
        for m, meter in meters.items():
            meter.update(training_metrics[m], features.size(0))

        writer.add_scalar("Train/Loss", loss.item(), epoch * num_iter_per_epoch + iter)

        for metric, value in training_metrics.items():
            writer.add_scalar(
                f"Train/{metric}",
                value,
                epoch * num_iter_per_epoch + iter,
            )


        lr = optimizer.state_dict()["param_groups"][0]["lr"]

        if (iter % print_every == 0) and (iter > 0):
            print(
                "[Training - Epoch: {}], LR: {} , Iteration: {}/{} , Loss: {}".format(
                    epoch + 1, lr, iter, num_iter_per_epoch, losses.avg
                )
            )

        gc.collect()

    writer.add_scalar("Train/loss/epoch", losses.avg, epoch)
    for m, meter in meters.items():
        writer.add_scalar(f"Train/{m}/epoch", meter.avg, epoch)

    with open(log_file, "a") as f:
        f.write(f"Training on Epoch {epoch} \n")
        f.write(f"Average loss: {losses.avg} \n")
        for m, meter in meters.items():
            f.write(f"Average {m}: {meter.avg} \n")
        f.write("*" * 25)
        f.write("\n")
    model.eval()
    return losses.avg, meters, K




def evaluate(
    model,
    val_dataloader,
    criterion,
    epoch,
    writer,
    log_file,
    args,
    print_every=25,
):
    model.eval()
    losses = AverageMeter()
    num_iter_per_epoch = len(val_dataloader)
    # torch.autograd.set_detect_anomaly(True)
    progress_bar = tqdm(enumerate(val_dataloader), total=num_iter_per_epoch)

    y_true = []
    y_pred = []
    LIST_METRICS = ["rmse", "mean_absolute_error", "explained_variance"] # add to the args object
    meters = {m: AverageMeter() for m in LIST_METRICS}
    for iter, batch in progress_bar:
        features, labels = batch
        if torch.cuda.is_available():
            features = features.cuda()
            labels = labels.cuda()
        with torch.no_grad():
            predictions = model.forward(features).squeeze(dim=-1)
            loss = criterion(predictions, labels)
            reg = call_regularizers(model, args)
            loss += reg

        # detach hidden states
        # model._detach()

        val_metrics = get_evaluation(
            labels.cpu().numpy(),
            predictions.cpu().detach().numpy(),
            list_metrics=LIST_METRICS,
        )

        losses.update(loss.item(), features.size(0))
        for m, meter in meters.items():
            meter.update(val_metrics[m], features.size(0))

        # writer.add_scalar("Val/Loss", loss.item(), epoch * num_iter_per_epoch + iter)

        # for metric, value in val_metrics.items():
        #     writer.add_scalar(
        #         f"Val/{metric}",
        #         value,
        #         epoch * num_iter_per_epoch + iter,
        #     )
        # lr = optimizer.state_dict()["param_groups"][0]["lr"]

        if (iter % print_every == 0) and (iter > 0):
            print(
                "[Val - Epoch: {}], Iteration: {}/{} , Loss: {}".format(
                    epoch + 1, iter, num_iter_per_epoch, losses.avg
                )
            )

        gc.collect()

    writer.add_scalar("Val/loss/epoch", losses.avg, epoch)
    for m, meter in meters.items():
        writer.add_scalar(f"Val/{m}/epoch", meter.avg, epoch)

    with open(log_file, "a") as f:
        f.write(f"Val on Epoch {epoch} \n")
        f.write(f"Average loss: {losses.avg} \n")
        for m, meter in meters.items():
            f.write(f"Average {m}: {meter.avg} \n")
        f.write("*" * 25)
        f.write("\n")
    return losses.avg, meters
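
The checkpoint and early-stopping rules in `run` can be isolated into a few lines. The sketch below replays a list of made-up per-epoch validation scores through the same comparison (`epoch - best_epoch > patience > 0`); `early_stop_epoch` is a hypothetical helper, not something the notebook defines.

```python
def early_stop_epoch(scores, patience=3):
    """Return (stop_epoch, best_epoch); stop_epoch is None if training runs to the end."""
    best_score, best_epoch = float("inf"), 0
    for epoch, score in enumerate(scores):
        if score < best_score:  # lower is better, as with the loss
            best_score, best_epoch = score, epoch
        if epoch - best_epoch > patience > 0:
            return epoch, best_epoch
    return None, best_epoch

stop, best = early_stop_epoch([0.9, 0.7, 0.8, 0.75, 0.72, 0.71, 0.74], patience=3)
print(stop, best)
```

With `patience=3`, training stops on the fourth consecutive epoch without improvement, and setting `patience` to 0 disables the check.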



### XPrize Predictor Class (csvs, rollout,etc)

In [None]:
# High level flow
## context input goes to the model h(x)
## action input goes to the model g()
## both inputs go in parallel to their own LSTM models
## each LSTM's output is fed to its respective dense layer
## the dense layer outputs are obtained
## the final output is (1-g)h()


# Copyright 2020 (c) Cognizant Digital Business, Evolutionary AI. All rights reserved. Issued under the Apache 2.0 License.

import os
import urllib.request

# Suppress noisy Tensorflow debug logging
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

# noinspection PyPep8Naming
# import keras.backend as K
import numpy as np
import pandas as pd
# from keras.callbacks import EarlyStopping
# from keras.constraints import Constraint
# from keras.layers import Dense
# from keras.layers import Input
# from keras.layers import LSTM
# from keras.layers import Lambda
# from keras.models import Model

# See https://github.com/OxCGRT/covid-policy-tracker
DATA_URL = "https://raw.githubusercontent.com/OxCGRT/covid-policy-tracker/master/data/OxCGRT_latest.csv"

# ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) # CHANGE_FOR_SANDBOX
ROOT_DIR = os.path.abspath('') # For colab and ipynb
DATA_PATH = os.path.join(ROOT_DIR, 'data')
DATA_FILE_PATH = os.path.join(DATA_PATH, 'OxCGRT_latest.csv')
ADDITIONAL_CONTEXT_FILE = os.path.join(DATA_PATH, "Additional_Context_Data_Global.csv")
ADDITIONAL_US_STATES_CONTEXT = os.path.join(DATA_PATH, "US_states_populations.csv")
ADDITIONAL_UK_CONTEXT = os.path.join(DATA_PATH, "uk_populations.csv")
ADDITIONAL_BRAZIL_CONTEXT = os.path.join(DATA_PATH, "brazil_populations.csv")

NPI_COLUMNS = ['C1_School closing',
               'C2_Workplace closing',
               'C3_Cancel public events',
               'C4_Restrictions on gatherings',
               'C5_Close public transport',
               'C6_Stay at home requirements',
               'C7_Restrictions on internal movement',
               'C8_International travel controls',
               'H1_Public information campaigns',
               'H2_Testing policy',
               'H3_Contact tracing',
               'H6_Facial Coverings']

CONTEXT_COLUMNS = ['CountryName',
                   'RegionName',
                   'GeoID',
                   'Date',
                   'ConfirmedCases',
                   'ConfirmedDeaths',
                   'Population']
NB_LOOKBACK_DAYS = 21
NB_TEST_DAYS = 14
WINDOW_SIZE = 7
US_PREFIX = "United States / "
NUM_TRIALS = 1
LSTM_SIZE = 32
MAX_NB_COUNTRIES = 20


# Functions to be used for lambda layers in model
def _combine_r_and_d(x):
    r, d = x
    return r * (1. - d)


class XPrizePredictor(object):
    """
    A class that computes a fitness for Prescriptor candidates.
    """

    def __init__(self, path_to_model_weights, data_url, args = None):
        self.args = args if args else MockArgs()

        if path_to_model_weights:

            # Load model weights
            nb_context = 1  # Only time series of new cases rate is used as context
            nb_action = len(NPI_COLUMNS)
            self.predictor = self._construct_model(nb_context=nb_context,
                                                      nb_action=nb_action,
                                                      lstm_size=self.args.hidden_dim,
                                                      nb_lookback_days=self.args.nb_lookback_days)
            self.predictor.load_state_dict(torch.load(path_to_model_weights))
            self.predictor.eval()

            # Make sure data is available to make predictions
            if not os.path.exists(DATA_FILE_PATH):
                urllib.request.urlretrieve(DATA_URL, DATA_FILE_PATH)

        self.df = self._prepare_dataframe(data_url)
        geos = self.df.GeoID.unique()
        self.country_samples = self._create_country_samples(self.df, geos, self.args)
        
    def predict(self,
                start_date_str: str,
                end_date_str: str,
                path_to_ips_file: str) -> pd.DataFrame:
        start_date = pd.to_datetime(start_date_str, format='%Y-%m-%d')
        end_date = pd.to_datetime(end_date_str, format='%Y-%m-%d')
        nb_days = (end_date - start_date).days + 1

        # Load the npis into a DataFrame, handling regions
        npis_df = self._load_original_data(path_to_ips_file)

        # Prepare the output
        forecast = {"CountryName": [],
                    "RegionName": [],
                    "Date": [],
                    "PredictedDailyNewCases": []}

        # For each requested geo
        geos = npis_df.GeoID.unique()
        for g in geos:
            cdf = self.df[self.df.GeoID == g]
            if len(cdf) == 0:
                # we don't have historical data for this geo: return zeroes
                pred_new_cases = [0] * nb_days
                geo_start_date = start_date
            else:
                last_known_date = cdf.Date.max()
                # Start predicting from start_date, unless there's a gap since last known date
                geo_start_date = min(last_known_date + np.timedelta64(1, 'D'), start_date)
                npis_gdf = npis_df[(npis_df.Date >= geo_start_date) & (npis_df.Date <= end_date)]

                pred_new_cases = self._get_new_cases_preds(cdf, g, npis_gdf)

            # Append forecast data to results to return
            country = npis_df[npis_df.GeoID == g].iloc[0].CountryName
            region = npis_df[npis_df.GeoID == g].iloc[0].RegionName
            for i, pred in enumerate(pred_new_cases):
                forecast["CountryName"].append(country)
                forecast["RegionName"].append(region)
                current_date = geo_start_date + pd.offsets.Day(i)
                forecast["Date"].append(current_date)
                forecast["PredictedDailyNewCases"].append(pred)

        forecast_df = pd.DataFrame.from_dict(forecast)
        # Return only the requested predictions
        return forecast_df[(forecast_df.Date >= start_date) & (forecast_df.Date <= end_date)]

    def _get_new_cases_preds(self, c_df, g, npis_df):
        cdf = c_df[c_df.ConfirmedCases.notnull()]
        initial_context_input = self.country_samples[g]['X_test_context'][-1]
        initial_action_input = self.country_samples[g]['X_test_action'][-1]
        # Predictions with passed npis
        cnpis_df = npis_df[npis_df.GeoID == g]
        npis_sequence = np.array(cnpis_df[NPI_COLUMNS])
        # Get the predictions with the passed NPIs
        self.predictor.eval()
        preds = self._roll_out_predictions(self.predictor,
                                           initial_context_input,
                                           initial_action_input,
                                           npis_sequence)
        # Gather info to convert to total cases
        prev_confirmed_cases = np.array(cdf.ConfirmedCases)
        prev_new_cases = np.array(cdf.NewCases)
        initial_total_cases = prev_confirmed_cases[-1]
        pop_size = np.array(cdf.Population)[-1]  # Population size doesn't change over time
        # Compute predictor's forecast
        pred_new_cases = self._convert_ratios_to_total_cases(
            preds,
            self.args.window_size,
            prev_new_cases,
            initial_total_cases,
            pop_size)

        return pred_new_cases

    def _prepare_dataframe(self, data_url: str) -> pd.DataFrame:
        """
        Loads the Oxford dataset, merges in the additional context data (population per geo),
        cleans it up and prepares the necessary columns.
        :param data_url: the url containing the original data
        :return: a Pandas DataFrame with the historical data
        """
        # Original df from Oxford
        df1 = self._load_original_data(data_url)

        # Additional context df (e.g Population for each country)
        df2 = self._load_additional_context_df()

        # Merge the 2 DataFrames
        df = df1.merge(df2, on=['GeoID'], how='left', suffixes=('', '_y'))

        # Drop countries with no population data
        df.dropna(subset=['Population'], inplace=True)

        #  Keep only needed columns
        columns = CONTEXT_COLUMNS + NPI_COLUMNS
        df = df[columns]

        # Fill in missing values
        self._fill_missing_values(df)

        # Compute number of new cases and deaths each day
        df['NewCases'] = df.groupby('GeoID').ConfirmedCases.diff().fillna(0)
        df['NewDeaths'] = df.groupby('GeoID').ConfirmedDeaths.diff().fillna(0)

        # Replace negative values (which do not make sense for these columns) with 0
        df['NewCases'] = df['NewCases'].clip(lower=0)
        df['NewDeaths'] = df['NewDeaths'].clip(lower=0)

        # Compute smoothed versions of new cases and deaths each day
        df['SmoothNewCases'] = df.groupby('GeoID')['NewCases'].rolling(
            self.args.window_size, center=False).mean().fillna(0).reset_index(0, drop=True)
        df['SmoothNewDeaths'] = df.groupby('GeoID')['NewDeaths'].rolling(
            self.args.window_size, center=False).mean().fillna(0).reset_index(0, drop=True)

        # Compute percent change in new cases and deaths each day
        df['CaseRatio'] = df.groupby('GeoID').SmoothNewCases.pct_change(
        ).fillna(0).replace(np.inf, 0) + 1
        df['DeathRatio'] = df.groupby('GeoID').SmoothNewDeaths.pct_change(
        ).fillna(0).replace(np.inf, 0) + 1

        # Add column for proportion of population infected
        df['ProportionInfected'] = df['ConfirmedCases'] / df['Population']

        # Create column of value to predict
        df['PredictionRatio'] = df['CaseRatio'] / (1 - df['ProportionInfected'])

        return df

    @staticmethod
    def _load_original_data(data_url):
        latest_df = pd.read_csv(data_url,
                                parse_dates=['Date'],
                                encoding="ISO-8859-1",
                                dtype={"RegionName": str,
                                       "RegionCode": str},
                                # on_bad_lines needs pandas >= 1.3; it replaces error_bad_lines=False, which pandas 2.0 removed
                                on_bad_lines="warn")
        # GeoID is CountryName / RegionName
        # np.where usage: if A then B else C
        latest_df["GeoID"] = np.where(latest_df["RegionName"].isnull(),
                                      latest_df["CountryName"],
                                      latest_df["CountryName"] + ' / ' + latest_df["RegionName"])
        return latest_df

    @staticmethod
    def _fill_missing_values(df):
        """
        Fill missing values by interpolation, ffill, and filling NaNs
        :param df: Dataframe to be filled
        """
        df.update(df.groupby('GeoID').ConfirmedCases.apply(
            lambda group: group.interpolate(limit_area='inside')))
        # Drop country / regions for which no number of cases is available
        df.dropna(subset=['ConfirmedCases'], inplace=True)
        df.update(df.groupby('GeoID').ConfirmedDeaths.apply(
            lambda group: group.interpolate(limit_area='inside')))
        # Drop country / regions for which no number of deaths is available
        df.dropna(subset=['ConfirmedDeaths'], inplace=True)
        for npi_column in NPI_COLUMNS:
            df.update(df.groupby('GeoID')[npi_column].ffill().fillna(0))

    @staticmethod
    def _load_additional_context_df():
        # File containing the population for each country
        # Note: this file contains only countries population, not regions
        additional_context_df = pd.read_csv(ADDITIONAL_CONTEXT_FILE,
                                            usecols=['CountryName', 'Population'])
        additional_context_df['GeoID'] = additional_context_df['CountryName']

        # US states population
        additional_us_states_df = pd.read_csv(ADDITIONAL_US_STATES_CONTEXT,
                                              usecols=['NAME', 'POPESTIMATE2019'])
        # Rename the columns to match measures_df ones
        additional_us_states_df.rename(columns={'POPESTIMATE2019': 'Population'}, inplace=True)
        # Prefix with country name to match measures_df
        additional_us_states_df['GeoID'] = US_PREFIX + additional_us_states_df['NAME']

        # Append the new data to additional_df
        # (pd.concat replaces DataFrame.append, which pandas 2.0 removed)
        additional_context_df = pd.concat([additional_context_df, additional_us_states_df])

        # UK population
        additional_uk_df = pd.read_csv(ADDITIONAL_UK_CONTEXT)
        # Append the new data to additional_df
        additional_context_df = pd.concat([additional_context_df, additional_uk_df])

        # Brazil population
        additional_brazil_df = pd.read_csv(ADDITIONAL_BRAZIL_CONTEXT)
        # Append the new data to additional_df
        additional_context_df = pd.concat([additional_context_df, additional_brazil_df])

        return additional_context_df


    @staticmethod
    def _create_country_samples(df: pd.DataFrame, geos: list, args: MockArgs) -> dict:
        """
        For each geo, creates the numpy arrays used to train and test the model
        :param df: a Pandas DataFrame with historical data for countries (the "Oxford" dataset)
        :param geos: a list of geo names
        :param args: the run configuration; provides nb_lookback_days and nb_test_days
        :return: a dictionary of train and test sets, for each specified geo
        """
        context_column = 'PredictionRatio'
        action_columns = NPI_COLUMNS
        outcome_column = 'PredictionRatio'
        country_samples = {}
        for g in geos:
            cdf = df[df.GeoID == g]
            cdf = cdf[cdf.ConfirmedCases.notnull()]
            context_data = np.array(cdf[context_column])
            action_data = np.array(cdf[action_columns])
            outcome_data = np.array(cdf[outcome_column])
            context_samples = []
            action_samples = []
            outcome_samples = []
            nb_total_days = outcome_data.shape[0]
            for d in range(args.nb_lookback_days, nb_total_days):
                context_samples.append(context_data[d - args.nb_lookback_days:d])
                action_samples.append(action_data[d - args.nb_lookback_days:d])
                outcome_samples.append(outcome_data[d])
            if len(outcome_samples) > 0:
                X_context = np.expand_dims(np.stack(context_samples, axis=0), axis=2)
                X_action = np.stack(action_samples, axis=0)
                y = np.stack(outcome_samples, axis=0)
                country_samples[g] = {
                    'X_context': X_context,
                    'X_action': X_action,
                    'y': y,
                    'X_train_context': X_context[:-args.nb_test_days],
                    'X_train_action': X_action[:-args.nb_test_days],
                    'y_train': y[:-args.nb_test_days],
                    'X_test_context': X_context[-args.nb_test_days:],
                    'X_test_action': X_action[-args.nb_test_days:],
                    'y_test': y[-args.nb_test_days:],
                }
        return country_samples




    # Function for performing roll outs into the future
    # TODO: debug with PyTorch
    @staticmethod
    def _roll_out_predictions(predictor, initial_context_input, initial_action_input, future_action_sequence):
        nb_roll_out_days = future_action_sequence.shape[0]
        pred_output = np.zeros(nb_roll_out_days)
        context_input = np.expand_dims(np.copy(initial_context_input), axis=0)
        action_input = np.expand_dims(np.copy(initial_action_input), axis=0)

        for d in range(nb_roll_out_days):
            action_input[:, :-1] = action_input[:, 1:]
            # Use the passed actions
            action_sequence = future_action_sequence[d]
            action_input[:, -1] = action_sequence
            pred = predictor.predict([context_input, action_input])  # TODO: fix this for PyTorch outputs, i.e. model(input), once the model has been trained
            pred_output[d] = pred
            context_input[:, :-1] = context_input[:, 1:]
            context_input[:, -1] = pred
        return pred_output

    # Functions for converting predictions back to number of cases
    @staticmethod
    def _convert_ratio_to_new_cases(ratio,
                                    window_size,
                                    prev_new_cases_list,
                                    prev_pct_infected):
        return (ratio * (1 - prev_pct_infected) - 1) * \
               (window_size * np.mean(prev_new_cases_list[-window_size:])) \
               + prev_new_cases_list[-window_size]

    def _convert_ratios_to_total_cases(self,
                                       ratios,
                                       window_size,
                                       prev_new_cases,
                                       initial_total_cases,
                                       pop_size):
        new_new_cases = []
        prev_new_cases_list = list(prev_new_cases)
        curr_total_cases = initial_total_cases
        for ratio in ratios:
            new_cases = self._convert_ratio_to_new_cases(ratio,
                                                         window_size,
                                                         prev_new_cases_list,
                                                         curr_total_cases / pop_size)
            # new_cases can't be negative!
            new_cases = max(0, new_cases)
            # Which means total cases can't go down
            curr_total_cases += new_cases
            # Update prev_new_cases_list for next iteration of the loop
            prev_new_cases_list.append(new_cases)
            new_new_cases.append(new_cases)
        return new_new_cases

    @staticmethod
    def _smooth_case_list(case_list, window):
        return pd.Series(case_list).rolling(window).mean().to_numpy()

    def train(self, skip_training = False):
        print("Creating numpy arrays for Keras for each country...")
        geos = self._most_affected_geos(self.df, self.args.max_nb_countries, self.args.nb_lookback_days)
        country_samples = self._create_country_samples(self.df, geos, self.args)
        print("Numpy arrays created")

        # Aggregate data for training
        all_X_context_list = [country_samples[c]['X_train_context']
                              for c in country_samples]
        all_X_action_list = [country_samples[c]['X_train_action']
                             for c in country_samples]
        all_y_list = [country_samples[c]['y_train']
                      for c in country_samples]
        X_context = np.concatenate(all_X_context_list)
        X_action = np.concatenate(all_X_action_list)
        y = np.concatenate(all_y_list)

        # Clip outliers
        MIN_VALUE = 0.
        MAX_VALUE = 2.
        X_context = np.clip(X_context, MIN_VALUE, MAX_VALUE)
        y = np.clip(y, MIN_VALUE, MAX_VALUE)

        # Aggregate data for testing only on top countries
        test_all_X_context_list = [country_samples[g]['X_train_context']
                                   for g in geos]
        test_all_X_action_list = [country_samples[g]['X_train_action']
                                  for g in geos]
        test_all_y_list = [country_samples[g]['y_train']
                           for g in geos]
        test_X_context = np.concatenate(test_all_X_context_list)
        test_X_action = np.concatenate(test_all_X_action_list)
        test_y = np.concatenate(test_all_y_list)

        test_X_context = np.clip(test_X_context, MIN_VALUE, MAX_VALUE)
        test_y = np.clip(test_y, MIN_VALUE, MAX_VALUE)



        # Run full training several times to find best model
        # and gather data for setting acceptance threshold
        models = [] if not skip_training else [self.predictor]
        if not skip_training:
            train_losses = []
            val_losses = []
            test_losses = []
            for t in range(self.args.num_trials):
                print('Trial', t)

                # X_context, X_action, y = self._permute_data(X_context, X_action, y, seed=t)
                
                model = self._construct_model(
                    nb_context=X_context.shape[-1],
                    nb_action=X_action.shape[-1],
                    lstm_size=self.args.hidden_dim,
                    nb_lookback_days=self.args.nb_lookback_days
                )
                run(self.args, model, X_context, X_action, y, t)
                models.append(model)

        # Gather test info
        country_indeps = []
        country_predss = []
        country_casess = []
        for model in models: # TODO update rollouts to pytorch
            country_indep, country_preds, country_cases = self._lstm_get_test_rollouts(model,
                                                                                       self.df,
                                                                                       geos,
                                                                                       country_samples)
            country_indeps.append(country_indep)
            country_predss.append(country_preds)
            country_casess.append(country_cases)

        # Compute cases mae
        test_case_maes = []
        for m in range(len(models)):
            total_loss = 0
            for g in geos:
                true_cases = np.sum(np.array(self.df[self.df.GeoID == g].NewCases)[-self.args.nb_test_days:])
                pred_cases = np.sum(country_casess[m][g][-self.args.nb_test_days:])
                total_loss += np.abs(true_cases - pred_cases)
            test_case_maes.append(total_loss)

        # Select best model
        print(f"MAE per model: {test_case_maes}")
        best_model = models[np.argmin(test_case_maes)]
        print(f"best model was number {np.argmin(test_case_maes)}")
        for t in range(self.args.num_trials):
            logdir = self.args.log_path + self.args.model_name + f"_{t}/"
            writer = SummaryWriter(logdir)
            for i, mae in enumerate(test_case_maes):
                writer.add_scalar("Xprize/test/mae", mae, i)

        self.predictor = best_model
        torch.save(best_model.state_dict(), f"{self.args.output}{self.args.model_name}.pth")
        print("Done")
        return best_model

    @staticmethod
    def _most_affected_geos(df, nb_geos, min_historical_days):
        """
        Returns the list of most affected geos, in terms of confirmed deaths.
        :param df: the data frame containing the historical data
        :param nb_geos: the number of geos to return
        :param min_historical_days: the minimum days of historical data the geos must have
        :return: a list of geo names of size nb_geos if there were enough, and otherwise a list of all the
        geo names that have more than min_historical_days data points.
        """
        # By default use most affected geos with enough history
        gdf = df.groupby('GeoID')['ConfirmedDeaths'].agg(['max', 'count']).sort_values(by='max', ascending=False)
        filtered_gdf = gdf[gdf["count"] > min_historical_days]
        geos = list(filtered_gdf.head(nb_geos).index)
        return geos

    # Shuffling data prior to train/val split
    def _permute_data(self, X_context, X_action, y, seed=301):
        np.random.seed(seed)
        p = np.random.permutation(y.shape[0])
        X_context = X_context[p]
        X_action = X_action[p]
        y = y[p]
        return X_context, X_action, y

    # Construct model
    def _construct_model(self, nb_context, nb_action, lstm_size=32, nb_lookback_days=21):
        model = CombinedModel(nb_lookback_days, nb_action, nb_context, lstm_size, self.args)
        return model

    # Train model
    def _train_model(self, training_model, X_context, X_action, y, epochs=1, verbose=0):
        pass

    # Functions for computing test metrics
    def _lstm_roll_out_predictions(self, model, initial_context_input, initial_action_input, future_action_sequence):
        nb_test_days = future_action_sequence.shape[0]
        pred_output = np.zeros(nb_test_days)
        context_input = np.expand_dims(np.copy(initial_context_input), axis=0)
        action_input = np.expand_dims(np.copy(initial_action_input), axis=0)  # TODO: make a dataset out of these using XPrizeDataloader (see the PyTorch train() and run() methods)
        for d in range(nb_test_days):
            action_input[:, :-1] = action_input[:, 1:]
            action_input[:, -1] = future_action_sequence[d]

            pred = model.predict([context_input, action_input])
            pred_output[d] = pred
            context_input[:, :-1] = context_input[:, 1:]
            context_input[:, -1] = pred
        return pred_output

    def _lstm_get_test_rollouts(self, model, df, top_geos, country_samples):
        country_indep = {}
        country_preds = {}
        country_cases = {}
        for g in top_geos:
            X_test_context = country_samples[g]['X_test_context']
            X_test_action = country_samples[g]['X_test_action']
            country_indep[g] = model.predict([X_test_context, X_test_action])  # TODO: make a dataset out of these using XPrizeDataloader (see the PyTorch train() and run() methods)

            initial_context_input = country_samples[g]['X_test_context'][0]
            initial_action_input = country_samples[g]['X_test_action'][0]
            y_test = country_samples[g]['y_test']

            nb_test_days = y_test.shape[0]
            nb_actions = initial_action_input.shape[-1]

            future_action_sequence = np.zeros((nb_test_days, nb_actions))
            future_action_sequence[:nb_test_days] = country_samples[g]['X_test_action'][:, -1, :]
            current_action = country_samples[g]['X_test_action'][:, -1, :][-1]
            future_action_sequence[14:] = current_action
            preds = self._lstm_roll_out_predictions(model,
                                                    initial_context_input,
                                                    initial_action_input,
                                                    future_action_sequence)
            country_preds[g] = preds

            prev_confirmed_cases = np.array(
                df[df.GeoID == g].ConfirmedCases)[:-nb_test_days]
            prev_new_cases = np.array(
                df[df.GeoID == g].NewCases)[:-nb_test_days]
            initial_total_cases = prev_confirmed_cases[-1]
            pop_size = np.array(df[df.GeoID == g].Population)[0]

            pred_new_cases = self._convert_ratios_to_total_cases(
                preds, self.args.window_size, prev_new_cases, initial_total_cases, pop_size)
            country_cases[g] = pred_new_cases
        return country_indep, country_preds, country_cases
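
`_create_country_samples` above turns each geo's daily series into overlapping lookback windows. A minimal sketch of that windowing on synthetic data, assuming a lookback of 3 days and 2 action columns (both toy values; the notebook uses `args.nb_lookback_days` and `len(NPI_COLUMNS)`):

```python
import numpy as np

LOOKBACK = 3   # toy value (assumption); the notebook uses args.nb_lookback_days
N_ACTIONS = 2  # toy value (assumption); the notebook uses len(NPI_COLUMNS)

# Synthetic stand-ins: one ratio per day and one row of NPI levels per day
ratios = np.arange(10, dtype=float)
actions = np.arange(20, dtype=float).reshape(10, N_ACTIONS)

context_samples, action_samples, outcomes = [], [], []
for d in range(LOOKBACK, len(ratios)):
    context_samples.append(ratios[d - LOOKBACK:d])   # days d-LOOKBACK .. d-1
    action_samples.append(actions[d - LOOKBACK:d])   # NPIs over the same days
    outcomes.append(ratios[d])                       # value to predict for day d

X_context = np.expand_dims(np.stack(context_samples), axis=2)  # (samples, LOOKBACK, 1)
X_action = np.stack(action_samples)                            # (samples, LOOKBACK, N_ACTIONS)
y = np.array(outcomes)                                         # (samples,)

print(X_context.shape, X_action.shape, y.shape)  # (7, 3, 1) (7, 3, 2) (7,)
```

Ten days with a lookback of 3 give seven samples; the first sample sees days 0 to 2 and is labelled with day 3.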

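The two rollout methods above share one idea: predict a day, then push that prediction into the context window and the planned NPIs into the action window before predicting the next day. The sketch below shows that loop with a plain callable in place of the model. `roll_out`, `predict_fn` and `toy_predict` are hypothetical names, and the toy rule (window mean, damped by the latest NPI level) is invented purely so the example runs. With a PyTorch model, `predict_fn` would presumably convert the arrays to tensors and call the model under `torch.no_grad()`; the forward signature of `CombinedModel` is not shown in this section, so that part is left out.

```python
import numpy as np

def roll_out(predict_fn, initial_context, initial_action, future_actions):
    """Autoregressive rollout: each prediction becomes the newest context value.

    predict_fn is a hypothetical callable taking arrays of shape
    (1, lookback, 1) and (1, lookback, n_actions) and returning one float.
    """
    context = np.expand_dims(np.copy(initial_context), axis=0)
    action = np.expand_dims(np.copy(initial_action), axis=0)
    preds = np.zeros(future_actions.shape[0])
    for d in range(future_actions.shape[0]):
        # Slide the action window one day forward and append the planned NPIs
        action[:, :-1] = action[:, 1:]
        action[:, -1] = future_actions[d]
        preds[d] = predict_fn(context, action)
        # Slide the context window and append the model's own prediction
        context[:, :-1] = context[:, 1:]
        context[:, -1] = preds[d]
    return preds

def toy_predict(context, action):
    # Invented rule, not the notebook's model: window mean damped by the latest NPI level
    return float(context.mean() * (1.0 - 0.1 * action[0, -1, 0]))

initial_context = np.array([[1.0], [1.0], [4.0]])   # (lookback, 1)
initial_action = np.zeros((3, 2))                   # (lookback, n_actions)
future_actions = np.array([[0.0, 0.0], [1.0, 1.0]]) # two days of planned NPIs

preds = roll_out(toy_predict, initial_context, initial_action, future_actions)
print(preds)
```

The second prediction depends on the first one, which is why errors can compound over long rollouts.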

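`_convert_ratio_to_new_cases` inverts the target construction. Writing `w` for the window, `s` for the trailing mean of new cases and `p` for the infected share: the model output times `(1 - p)` gives back `CaseRatio = s_today / s_yesterday`, and because the window sum `w * s` gains today's count and loses the count from `w` days ago, today's count is `(CaseRatio - 1) * w * s_yesterday + n[-w]`. A small numeric check of that identity, with made-up numbers:

```python
import numpy as np

def ratio_to_new_cases(ratio, window, prev_new_cases, prev_pct_infected):
    # Same arithmetic as _convert_ratio_to_new_cases in the class above
    case_ratio = ratio * (1 - prev_pct_infected)
    return ((case_ratio - 1) * window * np.mean(prev_new_cases[-window:])
            + prev_new_cases[-window])

window = 3
history = [10.0, 20.0, 30.0]   # toy daily new cases; their trailing mean is 20
pct_infected = 0.1             # toy share of the population already infected
target_smooth = 30.0           # suppose the next trailing mean should be 30

case_ratio = target_smooth / np.mean(history)        # growth of the smoothed series
prediction_ratio = case_ratio / (1 - pct_infected)   # the value the model is trained to output

new_cases = ratio_to_new_cases(prediction_ratio, window, history, pct_infected)

# Appending the recovered count should reproduce the target trailing mean
recovered_smooth = np.mean(history[1:] + [new_cases])
print(new_cases, recovered_smooth)
```

Here the recovered daily count is about 40, and the trailing mean over the last three days (20, 30, 40) comes back to the target of 30.
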
# MAIN
To train new models, run the first two predictor lines:
```
predictor = XPrizePredictor(None, DATA_URL, args = args)
predictor_model = predictor.train()
```


To load an existing model, run the second two:
```
predictor = XPrizePredictor(f"./modelsaves/{args.model_name}.pth", DATA_URL, args = args)
predictor_model = predictor.train(skip_training=True)
```
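
The sweep cell below rebuilds `args.model_name` by hand for every configuration. One possible way to keep the names consistent is a small helper that is called after the run id is set. This is only a sketch: `make_model_name` is a hypothetical helper, `SimpleNamespace` stands in for `MockArgs`, and the learning rate and batch size are illustrative values, not necessarily the notebook's defaults.

```python
from types import SimpleNamespace

def make_model_name(args, version, run_id):
    # Hypothetical helper mirroring the f-string used in the sweep cell
    lr = str(args.learning_rate)[1:]  # e.g. ".003" from 0.003, as in the notebook
    name = (f"lstm_v{version}-{run_id}_{args.optimizer}_{args.criterion}"
            f"_lr_{lr}_batch_{args.batch_size}")
    if args.xprize_regularizer:
        name += f"_XR_{args.xprize_lambda}"
    return name

# Stand-in for MockArgs; values are illustrative assumptions
base = dict(optimizer="adam", criterion="l1", learning_rate=0.003,
            batch_size=50, xprize_lambda=1e-6, xprize_regularizer=True)

# (run id, hidden_dim, hidden_dim_action) as they appear in the sweep cell
sweep = [("2xa-K.15", 32, 64), ("2xca-K.15", 64, 64), ("2xc4xa-K.15", 64, 128)]

names = []
for run_id, hidden_dim, hidden_dim_action in sweep:
    args = SimpleNamespace(**base, hidden_dim=hidden_dim,
                           hidden_dim_action=hidden_dim_action)
    names.append(make_model_name(args, version=12, run_id=run_id))

print("\n".join(names))
```

Building the name inside the loop means each run id is bound before the name that uses it.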

In [None]:
# This block was used for the hyperparameter sweep. Follow these examples for TensorBoard, or skip to the prediction section.

VERSION_NUM = 12

block_criterion = "l1"

args = MockArgs()
# args.batch_size = 50
args.dry_run = False
# args.epochs = 20
# args.learning_rate = .003
# args.num_trials = 5
args.optimizer= "adam"
args.criterion= block_criterion
args.flush_history = False
args.xprize_lambda = 1e-6
args.set_regularizers()
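# NOTE: RUN_NUM is only assigned further down in this cell, so it must already exist (e.g. from an earlier cell or run) for this line to work
args.model_name = f"lstm_v{VERSION_NUM}-{RUN_NUM}_{args.optimizer}_{args.criterion}_lr_{str(args.learning_rate)[1:]}_batch_{args.batch_size}"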

# model_path = f'{args.output}{args.model_name}.pth'
# predictor = XPrizePredictor(None, DATA_URL, args = args)
# predictor_model = predictor.train() # TERRIBLE with ratio .5

RUN_NUM = "2xa-K.15"


args.hidden_dim= 32
args.hidden_dim_action= 64
args.xprize_regularizer = True
args.set_regularizers()
args.model_name = f"lstm_v{VERSION_NUM}-{RUN_NUM}_{args.optimizer}_{args.criterion}_lr_{str(args.learning_rate)[1:]}_batch_{args.batch_size}_XR_{args.xprize_lambda}"

# model_path = f'{args.output}{args.model_name}.pth'
# predictor2 = XPrizePredictor(None, DATA_URL, args = args)
# predictor_model2 = predictor2.train() # pretty steady, 2.4 million mae with ratio .5

RUN_NUM = "2xca-K.15"


args.hidden_dim= 64
args.hidden_dim_action= 64
args.xprize_regularizer = True
args.set_regularizers()
args.model_name = f"lstm_v{VERSION_NUM}-{RUN_NUM}_{args.optimizer}_{args.criterion}_lr_{str(args.learning_rate)[1:]}_batch_{args.batch_size}_XR_{args.xprize_lambda}"

# model_path = f'{args.output}{args.model_name}.pth'
# predictor2 = XPrizePredictor(None, DATA_URL, args = args)
# predictor_model2 = predictor2.train() # pretty steady, 2.4 million mae with ratio .5

RUN_NUM = "2xc4xa-K.15"
args.hidden_dim= 64
args.hidden_dim_action= 128
args.xprize_regularizer = True
args.set_regularizers()
args.model_name = f"lstm_v{VERSION_NUM}-{RUN_NUM}_{args.optimizer}_{args.criterion}_lr_{str(args.learning_rate)[1:]}_batch_{args.batch_size}_XR_{args.xprize_lambda}"

model_path = f'{args.output}{args.model_name}.pth'
predictor2 = XPrizePredictor(None, DATA_URL, args = args)
predictor_model2 = predictor2.train() # pretty steady, 2.4 million mae with ratio .5


# args.l2_regularizer=True
# args.set_regularizers()
# args.model_name = f"lstm_v{VERSION_NUM}-{RUN_NUM}_{args.optimizer}_{args.criterion}_lr_{str(args.learning_rate)[1:]}_batch_{args.batch_size}_XR_{args.xprize_lambda}_L2"

# model_path = f'{args.output}{args.model_name}.pth'
# predictor3 = XPrizePredictor(None, DATA_URL, args = args)
# predictor_model3 = predictor3.train() # TERRIBLE with ratio .5


# args.orth_regularizer=True
# args.set_regularizers()
# args.model_name = f"lstm_v{VERSION_NUM}-{RUN_NUM}_{args.optimizer}_{args.criterion}_lr_{str(args.learning_rate)[1:]}_batch_{args.batch_size}_XR_{args.xprize_lambda}_L2_ORTH"

# model_path = f'{args.output}{args.model_name}.pth'
# predictor4 = XPrizePredictor(None, DATA_URL, args = args)
# predictor_model4 = predictor4.train() # 2.2 million mae took last epoch

# predictor = XPrizePredictor(f"./modelsaves/{args.model_name}.pth", DATA_URL, args = args)
# predictor_model = predictor.train(skip_training=True)


Creating numpy arrays for Keras for each country...
Numpy arrays created
Trial 0


[Epoch: 1 / 17] train_loss: 0.8277 | val_loss: 0.1517 | val rmse: 0.1454 | val mean_absolute_error: 0.0707 | val explained_variance: -0.0000
[Epoch: 2 / 17] train_loss: 0.0967 | val_loss: 0.0728 | val rmse: 0.1463 | val mean_absolute_error: 0.0708 | val explained_variance: 0.0000
[Epoch: 3 / 17] train_loss: 0.0729 | val_loss: 0.0689 | val rmse: 0.1458 | val mean_absolute_error: 0.0706 | val explained_variance: -0.0000
[Epoch: 4 / 17] train_loss: 0.0690 | val_loss: 0.0695 | val rmse: 0.1453 | val mean_absolute_error: 0.0707 | val explained_variance: -0.0000
[Epoch: 5 / 17] train_loss: 0.0691 | val_loss: 0.0681 | val rmse: 0.1433 | val mean_absolute_error: 0.0700 | val explained_variance: 0.0000
[Epoch: 6 / 17] train_loss: 0.0696 | val_loss: 0.0687 | val rmse: 0.1456 | val mean_absolute_error: 0.0707 | val explained_variance: -0.0000
[Epoch: 7 / 17] train_loss: 0.0689 | val_loss: 0.0687 | val rmse: 0.1439 | val mean_absolute_error: 0.0707 | val explained_variance: 0.0000
[Epoch: 8 / 17] train_loss: 0.0686 | val_loss: 0.0691 | val rmse: 0.1457 | val mean_absolute_error: 0.0712 | val explained_variance: -0.0000
[Epoch: 9 / 17] train_loss: 0.0696 | val_loss: 0.0686 | val rmse: 0.1446 | val mean_absolute_error: 0.0708 | val explained_variance: 0.0000
Stop training at epoch 8. The lowest loss achieved is 0.06856683434711562 at epoch 4
Trial 1
[Epoch: 1 / 17] train_loss: 1.6771 | val_loss: 0.2148 | val rmse: 0.1411 | val mean_absolute_error: 0.0689 | val explained_variance: 0.0000
[Epoch: 2 / 17] train_loss: 0.1741 | val_loss: 0.0985 | val rmse: 0.1442 | val mean_absolute_error: 0.0695 | val explained_variance: 0.0000
[Epoch: 3 / 17] train_loss: 0.0852 | val_loss: 0.0682 | val rmse: 0.1423 | val mean_absolute_error: 0.0698 | val explained_variance: -0.0000
[Epoch: 4 / 17] train_loss: 0.0700 | val_loss: 0.0684 | val rmse: 0.1429 | val mean_absolute_error: 0.0693 | val explained_variance: 0.0000
[Epoch: 5 / 17] train_loss: 0.0689 | val_loss: 0.0680 | val rmse: 0.1443 | val mean_absolute_error: 0.0698 | val explained_variance: 0.0000
[Epoch: 6 / 17] train_loss: 0.0678 | val_loss: 0.0678 | val rmse: 0.1444 | val mean_absolute_error: 0.0698 | val explained_variance: -0.0000
[Epoch: 7 / 17] train_loss: 0.0675 | val_loss: 0.0682 | val rmse: 0.1461 | val mean_absolute_error: 0.0704 | val explained_variance: 0.0000
[Epoch: 8 / 17] train_loss: 0.0673 | val_loss: 0.0678 | val rmse: 0.1449 | val mean_absolute_error: 0.0697 | val explained_variance: -0.0000

[Epoch: 9 / 17] train_loss: 0.0682 | val_loss: 0.0674 | val rmse: 0.1432 | val mean_absolute_error: 0.0697 | val explained_variance: 0.0000



100%|██████████| 18/18 [00:09<00:00,  1.93it/s]
100%|██████████| 18/18 [00:02<00:00,  7.51it/s]

[Epoch: 10 / 17] train_loss: 0.0682 | val_loss: 0.0686 | val rmse: 0.1444 | val mean_absolute_error: 0.0704 | val explained_variance: -0.0000



100%|██████████| 18/18 [00:09<00:00,  1.90it/s]
100%|██████████| 18/18 [00:02<00:00,  7.56it/s]

[Epoch: 11 / 17] train_loss: 0.0690 | val_loss: 0.0685 | val rmse: 0.1445 | val mean_absolute_error: 0.0696 | val explained_variance: -0.0000



100%|██████████| 18/18 [00:09<00:00,  1.93it/s]
100%|██████████| 18/18 [00:02<00:00,  7.29it/s]

[Epoch: 12 / 17] train_loss: 0.0680 | val_loss: 0.0678 | val rmse: 0.1437 | val mean_absolute_error: 0.0700 | val explained_variance: 0.0000



100%|██████████| 18/18 [00:09<00:00,  1.94it/s]
100%|██████████| 18/18 [00:02<00:00,  7.24it/s]

[Epoch: 13 / 17] train_loss: 0.0686 | val_loss: 0.0703 | val rmse: 0.1436 | val mean_absolute_error: 0.0699 | val explained_variance: 0.0000
Stop training at epoch 12. The lowest loss achieved is 0.07032734817928737 at epoch 8
Trial 2



100%|██████████| 18/18 [00:09<00:00,  1.94it/s]
100%|██████████| 18/18 [00:02<00:00,  7.41it/s]

[Epoch: 1 / 17] train_loss: 1.3923 | val_loss: 0.1781 | val rmse: 0.1489 | val mean_absolute_error: 0.0713 | val explained_variance: -0.0000



100%|██████████| 18/18 [00:09<00:00,  1.94it/s]
100%|██████████| 18/18 [00:02<00:00,  7.30it/s]

[Epoch: 2 / 17] train_loss: 0.1457 | val_loss: 0.1085 | val rmse: 0.1486 | val mean_absolute_error: 0.0716 | val explained_variance: -0.0000



100%|██████████| 18/18 [00:09<00:00,  1.94it/s]
100%|██████████| 18/18 [00:02<00:00,  7.32it/s]

[Epoch: 3 / 17] train_loss: 0.0903 | val_loss: 0.0819 | val rmse: 0.1490 | val mean_absolute_error: 0.0720 | val explained_variance: 0.0000



100%|██████████| 18/18 [00:09<00:00,  1.94it/s]
100%|██████████| 18/18 [00:02<00:00,  7.41it/s]

[Epoch: 4 / 17] train_loss: 0.0724 | val_loss: 0.0696 | val rmse: 0.1484 | val mean_absolute_error: 0.0715 | val explained_variance: -0.0000



100%|██████████| 18/18 [00:09<00:00,  1.93it/s]
100%|██████████| 18/18 [00:02<00:00,  7.25it/s]

[Epoch: 5 / 17] train_loss: 0.0700 | val_loss: 0.0693 | val rmse: 0.1494 | val mean_absolute_error: 0.0716 | val explained_variance: 0.0000



100%|██████████| 18/18 [00:09<00:00,  1.93it/s]
100%|██████████| 18/18 [00:02<00:00,  7.28it/s]

[Epoch: 6 / 17] train_loss: 0.0696 | val_loss: 0.0688 | val rmse: 0.1472 | val mean_absolute_error: 0.0707 | val explained_variance: -0.0000



100%|██████████| 18/18 [00:09<00:00,  1.93it/s]
100%|██████████| 18/18 [00:02<00:00,  7.50it/s]

[Epoch: 7 / 17] train_loss: 0.0697 | val_loss: 0.0692 | val rmse: 0.1480 | val mean_absolute_error: 0.0714 | val explained_variance: 0.0000



100%|██████████| 18/18 [00:09<00:00,  1.95it/s]
100%|██████████| 18/18 [00:02<00:00,  7.33it/s]

[Epoch: 8 / 17] train_loss: 0.0690 | val_loss: 0.0703 | val rmse: 0.1477 | val mean_absolute_error: 0.0717 | val explained_variance: -0.0000



100%|██████████| 18/18 [00:09<00:00,  1.94it/s]
100%|██████████| 18/18 [00:02<00:00,  7.41it/s]

[Epoch: 9 / 17] train_loss: 0.0692 | val_loss: 0.0686 | val rmse: 0.1451 | val mean_absolute_error: 0.0705 | val explained_variance: 0.0000



100%|██████████| 18/18 [00:09<00:00,  1.93it/s]
100%|██████████| 18/18 [00:02<00:00,  7.46it/s]

[Epoch: 10 / 17] train_loss: 0.0694 | val_loss: 0.0698 | val rmse: 0.1488 | val mean_absolute_error: 0.0715 | val explained_variance: -0.0000



100%|██████████| 18/18 [00:09<00:00,  1.90it/s]
100%|██████████| 18/18 [00:02<00:00,  7.66it/s]

[Epoch: 11 / 17] train_loss: 0.0702 | val_loss: 0.0688 | val rmse: 0.1467 | val mean_absolute_error: 0.0710 | val explained_variance: 0.0000



100%|██████████| 18/18 [00:09<00:00,  1.94it/s]
100%|██████████| 18/18 [00:02<00:00,  7.32it/s]

[Epoch: 12 / 17] train_loss: 0.0697 | val_loss: 0.0693 | val rmse: 0.1468 | val mean_absolute_error: 0.0715 | val explained_variance: 0.0000



100%|██████████| 18/18 [00:09<00:00,  1.94it/s]
100%|██████████| 18/18 [00:02<00:00,  7.67it/s]


[Epoch: 13 / 17] train_loss: 0.0696 | val_loss: 0.0692 | val rmse: 0.1466 | val mean_absolute_error: 0.0713 | val explained_variance: -0.0000
Stop training at epoch 12. The lowest loss achieved is 0.06919974730246597 at epoch 8
MAE per model: [2816290.3218064103, 4157846.5849052356, 2090085.088321695]
best model was number 2
Done
Creating numpy arrays for Keras for each country...
Numpy arrays created
Trial 0


100%|██████████| 18/18 [00:09<00:00,  1.81it/s]
100%|██████████| 18/18 [00:02<00:00,  7.16it/s]

[Epoch: 1 / 17] train_loss: 1.1590 | val_loss: 0.1709 | val rmse: 0.1451 | val mean_absolute_error: 0.0703 | val explained_variance: -0.0000



100%|██████████| 18/18 [00:10<00:00,  1.75it/s]
100%|██████████| 18/18 [00:02<00:00,  6.95it/s]

[Epoch: 2 / 17] train_loss: 0.1278 | val_loss: 0.0976 | val rmse: 0.1501 | val mean_absolute_error: 0.0717 | val explained_variance: -0.0000



100%|██████████| 18/18 [00:10<00:00,  1.78it/s]
100%|██████████| 18/18 [00:02<00:00,  6.87it/s]

[Epoch: 3 / 17] train_loss: 0.0792 | val_loss: 0.0777 | val rmse: 0.1476 | val mean_absolute_error: 0.0711 | val explained_variance: -0.0000



100%|██████████| 18/18 [00:10<00:00,  1.76it/s]
100%|██████████| 18/18 [00:02<00:00,  7.25it/s]

[Epoch: 4 / 17] train_loss: 0.0737 | val_loss: 0.0717 | val rmse: 0.1483 | val mean_absolute_error: 0.0709 | val explained_variance: 0.0000



100%|██████████| 18/18 [00:10<00:00,  1.75it/s]
100%|██████████| 18/18 [00:02<00:00,  7.05it/s]

[Epoch: 5 / 17] train_loss: 0.0718 | val_loss: 0.0741 | val rmse: 0.1482 | val mean_absolute_error: 0.0713 | val explained_variance: 0.0000



100%|██████████| 18/18 [00:10<00:00,  1.79it/s]
100%|██████████| 18/18 [00:02<00:00,  6.72it/s]

[Epoch: 6 / 17] train_loss: 0.0712 | val_loss: 0.0701 | val rmse: 0.1481 | val mean_absolute_error: 0.0709 | val explained_variance: -0.0000



100%|██████████| 18/18 [00:10<00:00,  1.78it/s]
100%|██████████| 18/18 [00:02<00:00,  7.08it/s]

[Epoch: 7 / 17] train_loss: 0.0700 | val_loss: 0.0697 | val rmse: 0.1465 | val mean_absolute_error: 0.0704 | val explained_variance: 0.0000



100%|██████████| 18/18 [00:10<00:00,  1.78it/s]
100%|██████████| 18/18 [00:02<00:00,  7.16it/s]

[Epoch: 8 / 17] train_loss: 0.0706 | val_loss: 0.0700 | val rmse: 0.1469 | val mean_absolute_error: 0.0706 | val explained_variance: 0.0000



100%|██████████| 18/18 [00:10<00:00,  1.78it/s]
100%|██████████| 18/18 [00:02<00:00,  7.11it/s]

[Epoch: 9 / 17] train_loss: 0.0705 | val_loss: 0.0716 | val rmse: 0.1494 | val mean_absolute_error: 0.0717 | val explained_variance: -0.0000



100%|██████████| 18/18 [00:10<00:00,  1.80it/s]
100%|██████████| 18/18 [00:02<00:00,  7.20it/s]

[Epoch: 10 / 17] train_loss: 0.0706 | val_loss: 0.0709 | val rmse: 0.1503 | val mean_absolute_error: 0.0716 | val explained_variance: 0.0000



100%|██████████| 18/18 [00:10<00:00,  1.78it/s]
100%|██████████| 18/18 [00:02<00:00,  6.88it/s]

[Epoch: 11 / 17] train_loss: 0.0723 | val_loss: 0.0714 | val rmse: 0.1492 | val mean_absolute_error: 0.0711 | val explained_variance: 0.0000
Stop training at epoch 10. The lowest loss achieved is 0.07139254361391068 at epoch 6
Trial 1



100%|██████████| 18/18 [00:10<00:00,  1.78it/s]
100%|██████████| 18/18 [00:02<00:00,  6.97it/s]

[Epoch: 1 / 17] train_loss: 1.3472 | val_loss: 0.1076 | val rmse: 0.1454 | val mean_absolute_error: 0.0713 | val explained_variance: -0.0000



100%|██████████| 18/18 [00:10<00:00,  1.77it/s]
100%|██████████| 18/18 [00:02<00:00,  7.27it/s]

[Epoch: 2 / 17] train_loss: 0.1702 | val_loss: 0.1123 | val rmse: 0.1471 | val mean_absolute_error: 0.0716 | val explained_variance: -0.0000



100%|██████████| 18/18 [00:10<00:00,  1.77it/s]
100%|██████████| 18/18 [00:02<00:00,  7.03it/s]

[Epoch: 3 / 17] train_loss: 0.0908 | val_loss: 0.0735 | val rmse: 0.1475 | val mean_absolute_error: 0.0718 | val explained_variance: -0.0000



100%|██████████| 18/18 [00:10<00:00,  1.78it/s]
100%|██████████| 18/18 [00:02<00:00,  7.29it/s]

[Epoch: 4 / 17] train_loss: 0.0741 | val_loss: 0.0702 | val rmse: 0.1465 | val mean_absolute_error: 0.0711 | val explained_variance: -0.0000



100%|██████████| 18/18 [00:10<00:00,  1.78it/s]
100%|██████████| 18/18 [00:02<00:00,  7.08it/s]

[Epoch: 5 / 17] train_loss: 0.0711 | val_loss: 0.0704 | val rmse: 0.1459 | val mean_absolute_error: 0.0712 | val explained_variance: -0.0000



100%|██████████| 18/18 [00:10<00:00,  1.77it/s]
100%|██████████| 18/18 [00:02<00:00,  7.10it/s]

[Epoch: 6 / 17] train_loss: 0.0700 | val_loss: 0.0716 | val rmse: 0.1476 | val mean_absolute_error: 0.0717 | val explained_variance: 0.0000



100%|██████████| 18/18 [00:10<00:00,  1.77it/s]
100%|██████████| 18/18 [00:02<00:00,  6.88it/s]

[Epoch: 7 / 17] train_loss: 0.0711 | val_loss: 0.0706 | val rmse: 0.1476 | val mean_absolute_error: 0.0718 | val explained_variance: 0.0000



100%|██████████| 18/18 [00:10<00:00,  1.78it/s]
100%|██████████| 18/18 [00:02<00:00,  6.86it/s]

[Epoch: 8 / 17] train_loss: 0.0705 | val_loss: 0.0704 | val rmse: 0.1461 | val mean_absolute_error: 0.0711 | val explained_variance: 0.0000
Stop training at epoch 7. The lowest loss achieved is 0.07036312276290523 at epoch 3
Trial 2



100%|██████████| 18/18 [00:10<00:00,  1.77it/s]
100%|██████████| 18/18 [00:02<00:00,  7.20it/s]

[Epoch: 1 / 17] train_loss: 0.7233 | val_loss: 0.0774 | val rmse: 0.1450 | val mean_absolute_error: 0.0705 | val explained_variance: -0.0000



100%|██████████| 18/18 [00:10<00:00,  1.78it/s]
100%|██████████| 18/18 [00:02<00:00,  6.91it/s]

[Epoch: 2 / 17] train_loss: 0.0966 | val_loss: 0.0774 | val rmse: 0.1472 | val mean_absolute_error: 0.0711 | val explained_variance: 0.0000



100%|██████████| 18/18 [00:10<00:00,  1.77it/s]
100%|██████████| 18/18 [00:02<00:00,  7.16it/s]


[Epoch: 3 / 17] train_loss: 0.0752 | val_loss: 0.0728 | val rmse: 0.1455 | val mean_absolute_error: 0.0704 | val explained_variance: -0.0000


100%|██████████| 18/18 [00:10<00:00,  1.77it/s]
100%|██████████| 18/18 [00:02<00:00,  7.10it/s]

[Epoch: 4 / 17] train_loss: 0.0711 | val_loss: 0.0710 | val rmse: 0.1469 | val mean_absolute_error: 0.0711 | val explained_variance: -0.0000



100%|██████████| 18/18 [00:10<00:00,  1.78it/s]
100%|██████████| 18/18 [00:02<00:00,  6.89it/s]

[Epoch: 5 / 17] train_loss: 0.0695 | val_loss: 0.0701 | val rmse: 0.1445 | val mean_absolute_error: 0.0704 | val explained_variance: -0.0000



100%|██████████| 18/18 [00:10<00:00,  1.78it/s]
100%|██████████| 18/18 [00:02<00:00,  7.03it/s]

[Epoch: 6 / 17] train_loss: 0.0709 | val_loss: 0.0715 | val rmse: 0.1460 | val mean_absolute_error: 0.0708 | val explained_variance: 0.0000



100%|██████████| 18/18 [00:10<00:00,  1.76it/s]
100%|██████████| 18/18 [00:02<00:00,  6.97it/s]

[Epoch: 7 / 17] train_loss: 0.0711 | val_loss: 0.0711 | val rmse: 0.1461 | val mean_absolute_error: 0.0708 | val explained_variance: -0.0000



100%|██████████| 18/18 [00:10<00:00,  1.75it/s]
100%|██████████| 18/18 [00:02<00:00,  6.99it/s]

[Epoch: 8 / 17] train_loss: 0.0706 | val_loss: 0.0695 | val rmse: 0.1429 | val mean_absolute_error: 0.0700 | val explained_variance: 0.0000



100%|██████████| 18/18 [00:10<00:00,  1.77it/s]
100%|██████████| 18/18 [00:02<00:00,  7.05it/s]

[Epoch: 9 / 17] train_loss: 0.0705 | val_loss: 0.0703 | val rmse: 0.1478 | val mean_absolute_error: 0.0713 | val explained_variance: 0.0000



100%|██████████| 18/18 [00:10<00:00,  1.77it/s]
100%|██████████| 18/18 [00:02<00:00,  7.14it/s]

[Epoch: 10 / 17] train_loss: 0.0710 | val_loss: 0.0719 | val rmse: 0.1481 | val mean_absolute_error: 0.0716 | val explained_variance: 0.0000



100%|██████████| 18/18 [00:10<00:00,  1.78it/s]
100%|██████████| 18/18 [00:02<00:00,  6.97it/s]

[Epoch: 11 / 17] train_loss: 0.0705 | val_loss: 0.0698 | val rmse: 0.1452 | val mean_absolute_error: 0.0708 | val explained_variance: 0.0000



100%|██████████| 18/18 [00:10<00:00,  1.78it/s]
100%|██████████| 18/18 [00:02<00:00,  7.02it/s]


[Epoch: 12 / 17] train_loss: 0.0707 | val_loss: 0.0738 | val rmse: 0.1476 | val mean_absolute_error: 0.0709 | val explained_variance: 0.0000
Stop training at epoch 11. The lowest loss achieved is 0.07382188116510709 at epoch 7
MAE per model: [1627487.0607890007, 1428147.6382434994, 1697741.750663823]
best model was number 1
Done
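
The `Stop training at epoch N` lines in these logs are consistent with early stopping on `val_loss` with a patience of 4 epochs, counting epochs from 0: every trial stops exactly 4 epochs after its best one. The value printed as the "lowest loss" appears to be the final epoch's `val_loss` rather than the minimum (for example 0.0738 in the last trial above, although epoch 8 of that trial logged 0.0695). A minimal sketch of a tracker that stores the best loss and its epoch separately; the class name `EarlyStopper` and its fields are hypothetical and not taken from the notebook's training code:

```python
class EarlyStopper:
    """Track the best validation loss; stop after `patience` epochs without improvement."""

    def __init__(self, patience=4):
        self.patience = patience
        self.best_loss = float("inf")
        self.best_epoch = -1

    def step(self, epoch, val_loss):
        """Record one 0-indexed epoch and return True when training should stop."""
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            self.best_epoch = epoch
        return epoch - self.best_epoch >= self.patience


# val_loss values copied from Trial 2 of the run above (log epochs 1 to 12).
val_losses = [0.0774, 0.0774, 0.0728, 0.0710, 0.0701, 0.0715,
              0.0711, 0.0695, 0.0703, 0.0719, 0.0698, 0.0738]

stopper = EarlyStopper(patience=4)
for epoch, loss in enumerate(val_losses):
    if stopper.step(epoch, loss):
        print(f"Stop training at epoch {epoch}. "
              f"The lowest loss achieved is {stopper.best_loss} at epoch {stopper.best_epoch}")
        break
```

With these inputs the sketch stops at epoch 11 with the best epoch at 7, the same indices the log reports, but it prints the minimum loss instead of the last one.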


In [None]:
RUN_NUM = "4xca-K.15"
args.hidden_dim= 128
args.hidden_dim_action= 128
args.xprize_regularizer = True
args.set_regularizers()
args.model_name = f"lstm_v{VERSION_NUM}-{RUN_NUM}_{args.optimizer}_{args.criterion}_lr_{str(args.learning_rate)[1:]}_batch_{args.batch_size}_XR_{args.xprize_lambda}"


model_path = f'{args.output}{args.model_name}.pth'
predictor3 = XPrizePredictor(None, DATA_URL, args = args)
predictor_model3 = predictor3.train()
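
The `model_name` f-string uses `str(args.learning_rate)[1:]` to drop the leading `0`, so `0.001` becomes `.001`. That slice assumes plain decimal notation. Python prints small floats in scientific notation (`str(1e-05)` is `'1e-05'`), in which case the slice yields `e-05`. A hedged sketch of a more robust formatter; `format_lr` is a hypothetical helper and the learning rates are placeholder values, not the notebook's settings:

```python
from decimal import Decimal


def format_lr(lr):
    """Format a learning rate in plain decimal notation without the leading zero."""
    text = format(Decimal(str(lr)), "f")  # e.g. '1e-05' becomes '0.00001'
    return text[1:] if text.startswith("0.") else text


print(str(0.001)[1:])    # slice used in the cell above
print(str(1e-05)[1:])    # same slice on a float printed in scientific notation
print(format_lr(0.001))
print(format_lr(1e-05))
```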


Creating numpy arrays for Keras for each country...
Numpy arrays created
Trial 0


100%|██████████| 18/18 [00:10<00:00,  1.64it/s]
100%|██████████| 18/18 [00:02<00:00,  6.92it/s]

[Epoch: 1 / 17] train_loss: 1.6147 | val_loss: 0.0739 | val rmse: 0.1485 | val mean_absolute_error: 0.0719 | val explained_variance: -0.0000



100%|██████████| 18/18 [00:11<00:00,  1.63it/s]
100%|██████████| 18/18 [00:02<00:00,  6.83it/s]

[Epoch: 2 / 17] train_loss: 0.1436 | val_loss: 0.1122 | val rmse: 0.1485 | val mean_absolute_error: 0.0718 | val explained_variance: -0.0000



100%|██████████| 18/18 [00:10<00:00,  1.65it/s]
100%|██████████| 18/18 [00:02<00:00,  6.83it/s]

[Epoch: 3 / 17] train_loss: 0.0838 | val_loss: 0.0784 | val rmse: 0.1496 | val mean_absolute_error: 0.0723 | val explained_variance: 0.0000



100%|██████████| 18/18 [00:11<00:00,  1.63it/s]
100%|██████████| 18/18 [00:02<00:00,  6.71it/s]

[Epoch: 4 / 17] train_loss: 0.0734 | val_loss: 0.0737 | val rmse: 0.1473 | val mean_absolute_error: 0.0717 | val explained_variance: -0.0000



100%|██████████| 18/18 [00:10<00:00,  1.65it/s]
100%|██████████| 18/18 [00:02<00:00,  6.81it/s]

[Epoch: 5 / 17] train_loss: 0.0723 | val_loss: 0.0708 | val rmse: 0.1477 | val mean_absolute_error: 0.0720 | val explained_variance: -0.0000



100%|██████████| 18/18 [00:11<00:00,  1.63it/s]
100%|██████████| 18/18 [00:02<00:00,  6.74it/s]

[Epoch: 6 / 17] train_loss: 0.0717 | val_loss: 0.0700 | val rmse: 0.1473 | val mean_absolute_error: 0.0713 | val explained_variance: 0.0000



100%|██████████| 18/18 [00:10<00:00,  1.65it/s]
100%|██████████| 18/18 [00:02<00:00,  6.73it/s]

[Epoch: 7 / 17] train_loss: 0.0711 | val_loss: 0.0712 | val rmse: 0.1487 | val mean_absolute_error: 0.0717 | val explained_variance: -0.0000



100%|██████████| 18/18 [00:10<00:00,  1.67it/s]
100%|██████████| 18/18 [00:02<00:00,  6.74it/s]

[Epoch: 8 / 17] train_loss: 0.0713 | val_loss: 0.0708 | val rmse: 0.1476 | val mean_absolute_error: 0.0720 | val explained_variance: -0.0000



100%|██████████| 18/18 [00:11<00:00,  1.63it/s]
100%|██████████| 18/18 [00:02<00:00,  6.79it/s]

[Epoch: 9 / 17] train_loss: 0.0712 | val_loss: 0.0704 | val rmse: 0.1487 | val mean_absolute_error: 0.0721 | val explained_variance: -0.0000



100%|██████████| 18/18 [00:10<00:00,  1.64it/s]
100%|██████████| 18/18 [00:02<00:00,  6.58it/s]

[Epoch: 10 / 17] train_loss: 0.0708 | val_loss: 0.0705 | val rmse: 0.1489 | val mean_absolute_error: 0.0722 | val explained_variance: -0.0000
Stop training at epoch 9. The lowest loss achieved is 0.07054322771728039 at epoch 5
Trial 1



100%|██████████| 18/18 [00:11<00:00,  1.63it/s]
100%|██████████| 18/18 [00:02<00:00,  6.84it/s]

[Epoch: 1 / 17] train_loss: 3.7734 | val_loss: 1.0483 | val rmse: 1.0557 | val mean_absolute_error: 1.0460 | val explained_variance: 0.0000



100%|██████████| 18/18 [00:10<00:00,  1.64it/s]
100%|██████████| 18/18 [00:02<00:00,  6.62it/s]

[Epoch: 2 / 17] train_loss: 4.2035 | val_loss: 1.0486 | val rmse: 1.0559 | val mean_absolute_error: 1.0462 | val explained_variance: 0.0000



100%|██████████| 18/18 [00:11<00:00,  1.63it/s]
100%|██████████| 18/18 [00:02<00:00,  6.76it/s]

[Epoch: 3 / 17] train_loss: 4.4257 | val_loss: 1.0479 | val rmse: 1.0554 | val mean_absolute_error: 1.0457 | val explained_variance: 0.0000



100%|██████████| 18/18 [00:10<00:00,  1.65it/s]
100%|██████████| 18/18 [00:02<00:00,  6.82it/s]

[Epoch: 4 / 17] train_loss: 4.7082 | val_loss: 1.0467 | val rmse: 1.0548 | val mean_absolute_error: 1.0453 | val explained_variance: 0.0000



100%|██████████| 18/18 [00:11<00:00,  1.62it/s]
100%|██████████| 18/18 [00:02<00:00,  6.61it/s]

[Epoch: 5 / 17] train_loss: 4.9116 | val_loss: 1.0434 | val rmse: 1.0558 | val mean_absolute_error: 1.0463 | val explained_variance: 0.0000



100%|██████████| 18/18 [00:10<00:00,  1.64it/s]
100%|██████████| 18/18 [00:02<00:00,  6.83it/s]

[Epoch: 6 / 17] train_loss: 4.0091 | val_loss: 1.5250 | val rmse: 1.9588 | val mean_absolute_error: 1.9538 | val explained_variance: -0.0000



100%|██████████| 18/18 [00:11<00:00,  1.62it/s]
100%|██████████| 18/18 [00:02<00:00,  6.95it/s]

[Epoch: 7 / 17] train_loss: 0.5267 | val_loss: 0.2148 | val rmse: 0.1492 | val mean_absolute_error: 0.0724 | val explained_variance: 0.0000



100%|██████████| 18/18 [00:11<00:00,  1.64it/s]
100%|██████████| 18/18 [00:02<00:00,  6.67it/s]

[Epoch: 8 / 17] train_loss: 0.2067 | val_loss: 0.1401 | val rmse: 0.1475 | val mean_absolute_error: 0.0715 | val explained_variance: -0.0000



100%|██████████| 18/18 [00:11<00:00,  1.62it/s]
100%|██████████| 18/18 [00:02<00:00,  6.66it/s]

[Epoch: 9 / 17] train_loss: 0.0922 | val_loss: 0.0845 | val rmse: 0.1490 | val mean_absolute_error: 0.0717 | val explained_variance: 0.0000



100%|██████████| 18/18 [00:10<00:00,  1.64it/s]
100%|██████████| 18/18 [00:02<00:00,  6.84it/s]

[Epoch: 10 / 17] train_loss: 0.0763 | val_loss: 0.0721 | val rmse: 0.1480 | val mean_absolute_error: 0.0716 | val explained_variance: 0.0000



100%|██████████| 18/18 [00:10<00:00,  1.64it/s]
100%|██████████| 18/18 [00:02<00:00,  6.68it/s]

[Epoch: 11 / 17] train_loss: 0.0721 | val_loss: 0.0718 | val rmse: 0.1483 | val mean_absolute_error: 0.0719 | val explained_variance: 0.0000



100%|██████████| 18/18 [00:10<00:00,  1.64it/s]
100%|██████████| 18/18 [00:02<00:00,  6.80it/s]

[Epoch: 12 / 17] train_loss: 0.0716 | val_loss: 0.0703 | val rmse: 0.1477 | val mean_absolute_error: 0.0713 | val explained_variance: -0.0000



100%|██████████| 18/18 [00:10<00:00,  1.64it/s]
100%|██████████| 18/18 [00:02<00:00,  6.86it/s]

[Epoch: 13 / 17] train_loss: 0.0707 | val_loss: 0.0711 | val rmse: 0.1478 | val mean_absolute_error: 0.0721 | val explained_variance: 0.0000



100%|██████████| 18/18 [00:11<00:00,  1.63it/s]
100%|██████████| 18/18 [00:02<00:00,  6.90it/s]

[Epoch: 14 / 17] train_loss: 0.0713 | val_loss: 0.0711 | val rmse: 0.1477 | val mean_absolute_error: 0.0718 | val explained_variance: -0.0000



100%|██████████| 18/18 [00:11<00:00,  1.63it/s]
100%|██████████| 18/18 [00:02<00:00,  6.57it/s]

[Epoch: 15 / 17] train_loss: 0.0724 | val_loss: 0.0709 | val rmse: 0.1486 | val mean_absolute_error: 0.0720 | val explained_variance: -0.0000



100%|██████████| 18/18 [00:10<00:00,  1.66it/s]
100%|██████████| 18/18 [00:02<00:00,  6.69it/s]

[Epoch: 16 / 17] train_loss: 0.0713 | val_loss: 0.0712 | val rmse: 0.1495 | val mean_absolute_error: 0.0722 | val explained_variance: -0.0000
Stop training at epoch 15. The lowest loss achieved is 0.07120429683062765 at epoch 11
Trial 2



100%|██████████| 18/18 [00:11<00:00,  1.63it/s]
100%|██████████| 18/18 [00:02<00:00,  6.72it/s]

[Epoch: 1 / 17] train_loss: 1.7100 | val_loss: 0.1275 | val rmse: 0.1468 | val mean_absolute_error: 0.0715 | val explained_variance: -0.0000



100%|██████████| 18/18 [00:11<00:00,  1.64it/s]
100%|██████████| 18/18 [00:02<00:00,  6.83it/s]

[Epoch: 2 / 17] train_loss: 0.2000 | val_loss: 0.0733 | val rmse: 0.1500 | val mean_absolute_error: 0.0726 | val explained_variance: 0.0000



100%|██████████| 18/18 [00:11<00:00,  1.63it/s]
100%|██████████| 18/18 [00:02<00:00,  6.70it/s]

[Epoch: 3 / 17] train_loss: 0.0859 | val_loss: 0.0756 | val rmse: 0.1469 | val mean_absolute_error: 0.0713 | val explained_variance: -0.0000



100%|██████████| 18/18 [00:11<00:00,  1.63it/s]
100%|██████████| 18/18 [00:02<00:00,  6.54it/s]

[Epoch: 4 / 17] train_loss: 0.0725 | val_loss: 0.0722 | val rmse: 0.1491 | val mean_absolute_error: 0.0722 | val explained_variance: 0.0000



100%|██████████| 18/18 [00:10<00:00,  1.65it/s]
100%|██████████| 18/18 [00:02<00:00,  6.75it/s]

[Epoch: 5 / 17] train_loss: 0.0706 | val_loss: 0.0734 | val rmse: 0.1492 | val mean_absolute_error: 0.0724 | val explained_variance: 0.0000



100%|██████████| 18/18 [00:10<00:00,  1.64it/s]
100%|██████████| 18/18 [00:02<00:00,  6.71it/s]

[Epoch: 6 / 17] train_loss: 0.0714 | val_loss: 0.0732 | val rmse: 0.1485 | val mean_absolute_error: 0.0718 | val explained_variance: -0.0000



100%|██████████| 18/18 [00:10<00:00,  1.64it/s]
100%|██████████| 18/18 [00:02<00:00,  6.72it/s]

[Epoch: 7 / 17] train_loss: 0.0734 | val_loss: 0.0718 | val rmse: 0.1492 | val mean_absolute_error: 0.0722 | val explained_variance: -0.0000



100%|██████████| 18/18 [00:10<00:00,  1.65it/s]
100%|██████████| 18/18 [00:02<00:00,  6.65it/s]

[Epoch: 8 / 17] train_loss: 0.0725 | val_loss: 0.0712 | val rmse: 0.1490 | val mean_absolute_error: 0.0719 | val explained_variance: -0.0000



100%|██████████| 18/18 [00:11<00:00,  1.61it/s]
100%|██████████| 18/18 [00:02<00:00,  6.79it/s]

[Epoch: 9 / 17] train_loss: 0.0706 | val_loss: 0.0718 | val rmse: 0.1467 | val mean_absolute_error: 0.0711 | val explained_variance: 0.0000



100%|██████████| 18/18 [00:10<00:00,  1.64it/s]
100%|██████████| 18/18 [00:02<00:00,  6.83it/s]

[Epoch: 10 / 17] train_loss: 0.0749 | val_loss: 0.0738 | val rmse: 0.1475 | val mean_absolute_error: 0.0715 | val explained_variance: -0.0000



100%|██████████| 18/18 [00:11<00:00,  1.60it/s]
100%|██████████| 18/18 [00:02<00:00,  6.78it/s]

[Epoch: 11 / 17] train_loss: 0.0717 | val_loss: 0.0703 | val rmse: 0.1458 | val mean_absolute_error: 0.0711 | val explained_variance: 0.0000



100%|██████████| 18/18 [00:11<00:00,  1.64it/s]
100%|██████████| 18/18 [00:02<00:00,  6.82it/s]

[Epoch: 12 / 17] train_loss: 0.0713 | val_loss: 0.0708 | val rmse: 0.1488 | val mean_absolute_error: 0.0717 | val explained_variance: -0.0000



100%|██████████| 18/18 [00:10<00:00,  1.64it/s]
100%|██████████| 18/18 [00:02<00:00,  6.84it/s]

[Epoch: 13 / 17] train_loss: 0.0727 | val_loss: 0.0702 | val rmse: 0.1460 | val mean_absolute_error: 0.0711 | val explained_variance: 0.0000



100%|██████████| 18/18 [00:11<00:00,  1.63it/s]
100%|██████████| 18/18 [00:02<00:00,  6.73it/s]

[Epoch: 14 / 17] train_loss: 0.0705 | val_loss: 0.0715 | val rmse: 0.1491 | val mean_absolute_error: 0.0717 | val explained_variance: 0.0000



100%|██████████| 18/18 [00:10<00:00,  1.64it/s]
100%|██████████| 18/18 [00:02<00:00,  6.95it/s]

[Epoch: 15 / 17] train_loss: 0.0700 | val_loss: 0.0702 | val rmse: 0.1479 | val mean_absolute_error: 0.0714 | val explained_variance: -0.0000



100%|██████████| 18/18 [00:11<00:00,  1.63it/s]
100%|██████████| 18/18 [00:02<00:00,  6.71it/s]

[Epoch: 16 / 17] train_loss: 0.0724 | val_loss: 0.0699 | val rmse: 0.1466 | val mean_absolute_error: 0.0709 | val explained_variance: 0.0000



100%|██████████| 18/18 [00:10<00:00,  1.64it/s]
100%|██████████| 18/18 [00:02<00:00,  6.79it/s]


[Epoch: 17 / 17] train_loss: 0.0719 | val_loss: 0.0718 | val rmse: 0.1485 | val mean_absolute_error: 0.0719 | val explained_variance: 0.0000
MAE per model: [1562964.7320169755, 1913905.0185998445, 1564491.0364312034]
best model was number 0
Done
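
Each run trains several trials and keeps one of them. In all three runs logged above, the `best model was number N` line matches the index of the smallest entry in the corresponding `MAE per model` list. A minimal sketch of that selection using the printed values; `pick_best` is a hypothetical name, not the notebook's function:

```python
def pick_best(maes):
    """Return the index of the trial with the lowest MAE."""
    return min(range(len(maes)), key=maes.__getitem__)


# "MAE per model" lists copied from the three runs logged above.
runs = [
    [2816290.3218064103, 4157846.5849052356, 2090085.088321695],
    [1627487.0607890007, 1428147.6382434994, 1697741.750663823],
    [1562964.7320169755, 1913905.0185998445, 1564491.0364312034],
]
for maes in runs:
    print(f"best model was number {pick_best(maes)}")
```

These MAEs are on a very different scale from the `val mean_absolute_error` column in the epoch logs, so they are presumably computed on a different quantity than the per-epoch validation metric.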


## Prediction
To run prediction, you need to download another set of CSVs and add them to `/content/data/` (the default directory is `/content/`). Upload all files from [here](https://github.com/leaf-ai/covid-xprize/tree/master/covid_xprize/validation/data).

In [None]:

NPIS_INPUT_FILE = "./data/2020-09-30_historical_ip.csv" # same as ./data
start_date = "2020-08-01"
end_date = "2020-08-31"

# predictor = XPrizePredictor(f"./modelsaves/{args.model_name}.pth", DATA_URL, args = args)
# predictor_model = predictor.train(skip_training=True)

# Smooth l1 loss for the GRU has the best MAE so far: MAE per model: [2796938.8493539672]

Prediction timeit output: 

37.7 s ± 150 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [None]:
# %%timeit
predictor3.predictor = predictor_model3  # attach the model returned by train() to predictor3
# Rebuild the model with the trained model's dimensions, then copy its weights across.
trained = predictor3.predictor
new_predictor = CombinedModel(trained.nb_loopback_days, trained.nb_action_dim, trained.nb_context_dim, trained.hidden_dim, trained.args)
new_predictor.load_state_dict(trained.state_dict())
print(new_predictor)
# eval() puts the module in inference mode before predicting.
predictor3.predictor = new_predictor.eval()
# predictor.compute_maes()
# preds_df = predictor.predict(start_date, end_date, NPIS_INPUT_FILE)
# preds_df = predictor2.predict(start_date, end_date, NPIS_INPUT_FILE)
with torch.no_grad():
    preds_df = predictor3.predict(start_date, end_date, NPIS_INPUT_FILE)

CombinedModel(
  (action_encoder): ActionEncoder(
    (activation): Sigmoid()
    (rnn): LSTM(12, 128)
    (ln_rnn): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (hidden2out): Linear(in_features=128, out_features=1, bias=True)
  )
  (context_encoder): ContextEncoder(
    (activation): Softplus(beta=1, threshold=20)
    (rnn): LSTM(1, 128)
    (ln_rnn): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (hidden2out): Linear(in_features=128, out_features=1, bias=True)
  )
)


In [None]:
preds_df[preds_df["CountryName"] == "United States"]

Unnamed: 0,CountryName,RegionName,Date,PredictedDailyNewCases
5425,United States,,2020-08-01,168976.740204
5426,United States,,2020-08-02,217419.020776
5427,United States,,2020-08-03,205848.838913
5428,United States,,2020-08-04,219066.038361
5429,United States,,2020-08-05,165365.279051
...,...,...,...,...
7063,United States,Wyoming,2020-08-27,0.000000
7064,United States,Wyoming,2020-08-28,232.091698
7065,United States,Wyoming,2020-08-29,0.000000
7066,United States,Wyoming,2020-08-30,0.000000


## Next Steps

Refactor [predict.py](https://github.com/leaf-ai/covid-xprize/blob/master/covid_xprize/examples/predictors/lstm/predict.py) to work with our model, then move it to the team sandbox and get it running there for the robojudge. It may not be too helpful, but the validation section of this example gives more background: [example link](https://github.com/leaf-ai/covid-xprize/blob/master/covid_xprize/examples/predictors/lstm/Example-LSTM-Predictor.ipynb)

In [None]:
# TODO get this API to run
# !python predict.py -s 2020-08-01 -e 2020-08-04 -ip ../../../validation/data/2020-09-30_historical_ip.csv -o predictions/2020-08-01_2020-08-04.csv
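
The commented command passes a start date, an end date, an intervention-plan CSV and an output path. A minimal sketch of a `predict.py` entry point that accepts the same short flags; the long option names, `parse_args`, `run` and the `predict_fn` hook are assumptions for illustration. `predict_fn` is expected to have the call shape of `predictor3.predict(start_date, end_date, NPIS_INPUT_FILE)` used earlier and to return a DataFrame:

```python
import argparse


def parse_args(argv=None):
    """Parse the four flags used in the command above."""
    parser = argparse.ArgumentParser(description="Write daily case predictions to a CSV file.")
    parser.add_argument("-s", "--start_date", required=True, help="first day, YYYY-MM-DD")
    parser.add_argument("-e", "--end_date", required=True, help="last day, YYYY-MM-DD")
    parser.add_argument("-ip", "--interventions_plan", dest="ip_file", required=True,
                        help="path to the intervention-plan CSV")
    parser.add_argument("-o", "--output_file", required=True, help="path of the CSV to write")
    return parser.parse_args(argv)


def run(args, predict_fn):
    """Call the predictor (hypothetical hook) and save its DataFrame to disk."""
    preds_df = predict_fn(args.start_date, args.end_date, args.ip_file)
    preds_df.to_csv(args.output_file, index=False)


# Same arguments as the commented command, parsed without touching sys.argv.
args = parse_args(["-s", "2020-08-01", "-e", "2020-08-04",
                   "-ip", "../../../validation/data/2020-09-30_historical_ip.csv",
                   "-o", "predictions/2020-08-01_2020-08-04.csv"])
print(args.start_date, args.end_date, args.ip_file, args.output_file)
```

Keeping the parsing separate from `run` means the model-loading code can be swapped in later without changing the command-line interface.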

## Tensorboard

For a remote server (viewed from your laptop)


1) Start the remote server and run tensorboard on the server
```bash
tensorboard --logdir=./logs/ --host $SERVER_IP --port $SERVER_PORT
```
2) SSH tunnel the port to your laptop, then open `http://localhost:6006` in a browser

```bash
ssh uname@hostname.edu -L 6006:$SERVER_IP:$SERVER_PORT
```

For Normal Jupyter notebooks (untested)




In [None]:
# %load_ext tensorboard
# %tensorboard --logdir ./logs

For Colab (untested)

In [None]:
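# LOG_DIR = './logs'  # same directory as the commands above
# get_ipython().system_raw(
#     'tensorboard --logdir {} --host 0.0.0.0 --port 6006 &'
#     .format(LOG_DIR)
# )
# The curl call queries ngrok's local API on port 4040, so an ngrok tunnel
# to port 6006 must already be running for it to return a public URL.
# ! curl -s http://localhost:4040/api/tunnels | python3 -c \
#     "import sys, json; print(json.load(sys.stdin)['tunnels'][0]['public_url'])"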