<a href="https://colab.research.google.com/github/ammarisme/covid-19/blob/master/CV19_Dataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
%matplotlib inline
!pip install torch-geometric \
  torch-sparse==latest+cu101 \
  torch-scatter==latest+cu101 \
  torch-cluster==latest+cu101 \
  -f https://pytorch-geometric.com/whl/torch-1.5.0.html

!pip install torch-geometric \
  torch-sparse==latest+cu101 \
  torch-scatter==latest+cu101 \
  torch-cluster==latest+cu101 \
  -f https://pytorch-geometric.com/whl/torch-1.5.0.html

In [0]:
from __future__ import unicode_literals, print_function, division
from io import open
import unicodedata
import string
import re
import random
from functools import reduce
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data, DataLoader, InMemoryDataset

import numpy as np
import pandas as pd

from os import listdir
from os.path import isfile, join

# Pick the compute device once: CUDA when a GPU is visible to torch, else the CPU.
_gpu_available = torch.cuda.is_available()
device = torch.device("cuda") if _gpu_available else torch.device("cpu")
print('running on ' + ('GPU' if _gpu_available else 'CPU'))

In [0]:
from google.colab import drive
# Mount the user's Google Drive inside the Colab VM at /content/drive.
drive.mount('/content/drive')
# Project folder on the mounted Drive; INPUT_ROOT further down is derived from it.
PATH = '/content/drive/My Drive/covid'

In [0]:
class CovidDataSet(InMemoryDataset):
    """In-memory dataset of sliding (input window, output window) pairs.

    Every CSV found in ``<root>/cleaned`` is split by ``synthesis_seq`` and then
    by ``countryterritoryCode``. For each such series and each window start
    ``t`` one ``Data`` sample is produced with

    * ``x``: rows ``t .. t + input_sequence - 1`` of ``FEATURE_COLUMNS``,
      float32, shape ``(input_sequence, len(FEATURE_COLUMNS))``;
    * ``y``: the next ``output_sequence`` rows of ``TARGET_COLUMNS``,
      shape ``(output_sequence, len(TARGET_COLUMNS))``.

    Rows are used in file order (presumably chronological -- confirm against
    the cleaning step). All feature columns must hold values convertible to
    float32.

    The collated tensors are cached under ``<root> + PROCESSED_DIR``.
    ``PROCESSED_DIR`` is a notebook-level global and has to be defined before
    the dataset is instantiated.

    Changes compared with the previous implementation:

    * the window lengths come from the constructor arguments instead of the
      notebook globals ``input_sequence_len`` / ``output_sequence_len``;
    * every window of a series is kept. The old key generation used
      ``range(input_sequence_len)`` where ``range(sets)`` was needed, so only
      the first ``input_sequence_len`` windows per offset survived the ``zip``;
    * raw files are read in sorted order so the sample order is reproducible.

    Because more windows are kept, the dataset size changes once the cache is
    rebuilt; remove the old processed folder (or change ``PROCESSED_DIR``) to
    trigger reprocessing.

    Args:
        root: Folder containing the ``cleaned`` sub-folder with the raw CSVs.
        input_sequence: Number of rows in each input window (>= 1).
        output_sequence: Number of rows in each output window (>= 1).
        transform: Forwarded to ``InMemoryDataset``.
        pre_transform: Forwarded to ``InMemoryDataset``.

    Raises:
        ValueError: If a window length is smaller than 1, or if no sample could
            be built from the raw files.
    """

    # Columns stacked, in this order, into the feature axis of ``x``.
    FEATURE_COLUMNS = ['cases', 'deaths', 'popData2018', '_country_code',
                       'countriesAndTerritories', 'geoId',
                       'countryterritoryCode', 'continentExp']
    # Columns stacked, in this order, into the last axis of ``y``.
    TARGET_COLUMNS = ['cases', 'deaths']

    def __init__(self, root, input_sequence, output_sequence, transform=None, pre_transform=None):
        input_sequence = int(input_sequence)
        output_sequence = int(output_sequence)
        if input_sequence < 1 or output_sequence < 1:
            raise ValueError('input_sequence and output_sequence must both be >= 1')
        # Must be set before super().__init__(), which may already call process().
        self.input_sequence_len = input_sequence
        self.output_sequence_len = output_sequence
        super(CovidDataSet, self).__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    def _ensure_processed_dir(self):
        """Create the processed-data folder if it is missing and return its path."""
        path = self.root + PROCESSED_DIR
        # makedirs(exist_ok=True) also creates missing parents and does not fail
        # when the folder already exists.
        os.makedirs(path, exist_ok=True)
        return path

    @property
    def raw_dir(self):
        """Folder holding the cleaned CSV files (``<root>/cleaned``)."""
        # The original property also created the processed folder as a side
        # effect; that is kept so existing call sequences behave the same.
        self._ensure_processed_dir()
        return self.root + '/cleaned'

    @property
    def processed_dir(self):
        """Folder holding the cached tensors (``<root> + PROCESSED_DIR``)."""
        return self._ensure_processed_dir()

    @property
    def raw_file_names(self):
        """Names of all regular files in ``raw_dir``, sorted for a stable order."""
        mypath = self.raw_dir
        return sorted(f for f in listdir(mypath) if isfile(join(mypath, f)))

    @property
    def processed_file_names(self):
        """Name of the single cache file written by ``process``."""
        return ['processed.dt']

    def download(self):
        """Nothing to download; the raw CSVs are expected to be on disk already."""
        pass

    def _window_samples(self, series):
        """Return one ``Data`` sample per sliding window of a single series.

        ``series`` is the DataFrame slice of one country within one synthesis
        sequence. A series shorter than ``input + output`` rows yields ``[]``.
        """
        in_len = self.input_sequence_len
        out_len = self.output_sequence_len
        n_windows = len(series) - in_len - out_len + 1
        if n_windows <= 0:
            return []

        features = series[self.FEATURE_COLUMNS].to_numpy(dtype=np.float32)
        targets = series[self.TARGET_COLUMNS].to_numpy()

        samples = []
        for start in range(n_windows):
            split = start + in_len
            # copy() gives every sample its own contiguous buffer.
            x = torch.from_numpy(features[start:split].copy())
            y = torch.from_numpy(targets[split:split + out_len].copy())
            samples.append(Data(x=x, y=y))
        return samples

    def process(self):
        """Build all window pairs from the raw CSVs and save the collated result."""
        data_list = []

        for raw_path in self.raw_paths:
            df = pd.read_csv(raw_path)
            for synthetic_seq in df['synthesis_seq'].unique():
                synthetic_data = df[df['synthesis_seq'] == synthetic_seq]
                for country in synthetic_data['countryterritoryCode'].unique():
                    country_data = synthetic_data[synthetic_data['countryterritoryCode'] == country]
                    data_list += self._window_samples(country_data)
            print('processed : '+ raw_path)

        if not data_list:
            raise ValueError('no samples could be built from the files in ' + self.raw_dir)

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

In [0]:
#os.remove(INPUT_ROOT+"/processed/pre_transform.pt")
#os.remove(INPUT_ROOT+"/processed/processed.dt")
#os.remove(INPUT_ROOT+"/processed/pre_filter.pt")
#test dataset and loader
# Window lengths, in rows, of the model input and of the predicted output.
input_sequence_len = 5
output_sequence_len = 5
# Columns per input row and per target row; these match the 8 feature columns
# and the 2 target columns (cases, deaths) stacked by CovidDataSet.process().
feature_length = 8
output_size = 2
INPUT_ROOT = PATH+'/input/training'
# NOTE(review): the tag looks like it encodes the two window lengths; if so,
# change it together with them so a stale cache folder is not reused.
DATA_TAG = "seq2seq_5_5"
PROCESSED_DIR = '/processed_'+DATA_TAG

# The third argument is the OUTPUT window length (input_sequence_len was passed twice).
test_dataset = CovidDataSet(INPUT_ROOT, input_sequence_len, output_sequence_len)

In [0]:
n = len(test_dataset)
# Every divisor of n splits the dataset into equally sized batches. Collect the
# divisors in pairs (i, n // i) up to sqrt(n). A set comprehension replaces
# reduce(list.__add__, ...), which raised TypeError on an empty dataset and
# concatenated lists quadratically.
possible_batch_sizes = {
    divisor
    for i in range(1, int(n**0.5) + 1)
    if n % i == 0
    for divisor in (i, n // i)
}
print('possible batch sizes : ', possible_batch_sizes)
# NOTE(review): presumably picked by hand from the list printed above -- re-check
# after the dataset is regenerated.
batch_size = 2819

def data_set_is_correct(sample_index=25):
  """Check that inputs and targets of neighbouring samples are aligned.

  Uses the first batch of ``test_dataset`` only. The target window of sample
  ``sample_index`` holds ``output_sequence_len`` consecutive 'cases' values
  (column 0 of ``y``). Each of these must equal the 'cases' value (feature 0)
  in the LAST input row of the next ``output_sequence_len`` samples, because
  consecutive samples of one series are shifted by one row.

  Assumes samples ``sample_index .. sample_index + output_sequence_len`` all
  come from the same series.

  Fixes over the previous version:
    * the comparison runs over ``output_sequence_len`` samples, not
      ``feature_length`` (8 inputs vs. 5 targets could not be subtracted);
    * values are compared element-wise; ``sum(diff) == 0`` let positive and
      negative errors cancel out;
    * ``view(-1, ...)`` replaces ``view(batch_size, ...)`` so a first batch
      smaller than ``batch_size`` still reshapes;
    * an empty loader returns ``False`` explicitly instead of ``None``.

  Args:
    sample_index: Position (within the first batch) of the sample whose target
      window is verified. Defaults to 25, the index used previously.

  Returns:
    bool: ``True`` when all compared values match, ``False`` otherwise.

  Raises:
    ValueError: If the first batch is too small to run the check.
  """
  test_dataloader = DataLoader(test_dataset, batch_size)
  for batch in test_dataloader:
    x = batch.x.view(-1, input_sequence_len, feature_length)
    y = batch.y.view(-1, output_sequence_len, output_size)

    first_follow_up = sample_index + 1
    last_follow_up = sample_index + output_sequence_len
    if sample_index < 0 or last_follow_up >= x.size(0):
      raise ValueError('first batch is too small to check sample %d' % sample_index)

    # 'cases' the target window of the probed sample says will follow ...
    target_cases = y[sample_index, :, 0]
    # ... versus the 'cases' actually found in the last input row of the next samples.
    observed_cases = x[first_follow_up:last_follow_up + 1, -1, 0]

    # Cast the targets to the input dtype so both sides carry the same rounding.
    return torch.equal(observed_cases, target_cases.to(observed_cases.dtype))
  return False

# Report the outcome of the window-alignment sanity check. A plain truth test
# replaces the `== True` comparison; a falsy result (including None) counts as corrupt.
if data_set_is_correct():
  print('Dataset is correct')
else:
  print('Corrupt dataset')