# Imports

In [None]:
import os, gc
import torch
from datetime import datetime

import warnings
warnings.filterwarnings("ignore")

import pandas as pd
pd.set_option('display.max_columns', None)

# Initial setup

## GPU

In [None]:
# Prefer CUDA when a GPU is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

## PyTorch Lightning and PyTorch Forecasting

In [None]:
from pytorch_forecasting import TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.data import GroupNormalizer, MultiNormalizer

# Load input

In [None]:
from dataclasses import dataclass

@dataclass
class args:
    """Notebook-wide settings, read directly as class attributes (``args.figPath``).

    None of the attributes is annotated, so ``@dataclass`` generates no fields;
    the class only serves as a namespace of constants.
    """

    # ---- inputs ----
    input_filePath = '../2022_May_cleaned/Total.csv'
    configPath = '../configurations/baseline.json'
    # alternative configuration:
    # configPath = '../config_2022_August.json'

    # ---- output locations (all derived from result_folder) ----
    result_folder = '../results/TFT_baseline'
    figPath = os.path.join(result_folder, 'figures')
    checkpoint_folder = os.path.join(result_folder, 'checkpoints')
    model_path = os.path.join(checkpoint_folder, 'best-epoch=0.ckpt')

    # ---- runtime behaviour ----
    # Turn off for batch jobs: progress bars flood the log. The notebook also
    # passes this flag as `show=` to the plotting helpers.
    show_progress_bar = True

    # interpret_output is memory hungry: on Total.csv with a hidden size of 64 it
    # ran out of memory even with 64GB, so the training split is only interpreted
    # for other inputs (the test split is interpreted otherwise).
    interpret_train = 'Total.csv' not in input_filePath

In [None]:
# Wall-clock start; the final cell reports the elapsed time against it.
start = datetime.now()
print('Started at', start)

# Raw input table for all counties and dates.
total_data = pd.read_csv(args.input_filePath)
print(total_data.shape)
total_data.head()

# Config

In [None]:
import json
import sys
sys.path.append( '..' )
from Class.Parameters import Parameters
from script.utils import *

# Load the experiment configuration. JSON text is UTF-8 by specification, so the
# encoding is given explicitly rather than relying on the platform's locale default.
with open(args.configPath, 'r', encoding='utf-8') as input_file:
    config = json.load(input_file)

# Parameters receives the raw dict and, additionally, its top-level keys as kwargs.
parameters = Parameters(config, **config)

In [None]:
# Short aliases for configuration values used throughout the notebook.
targets = parameters.data.targets
time_idx = parameters.data.time_idx

tft_params = parameters.model_parameters
batch_size = tft_params.batch_size
max_encoder_length = tft_params.input_sequence_length      # encoder (lookback) window
max_prediction_length = tft_params.target_sequence_length  # prediction horizon

# Processing

In [None]:
# Parse dates and keep county FIPS codes as strings (they are used as group ids).
total_data['Date'] = pd.to_datetime(total_data['Date'])
total_data['FIPS'] = total_data['FIPS'].astype(str)
print(f"There are {total_data['FIPS'].nunique()} unique counties in the dataset.")

## Adapt input to encoder length
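The input data length needs to be a multiple of the encoder length to create batched data loaders. This cell drops rows before `train_start` and recomputes the integer time index as days since the earliest remaining date.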

In [None]:
train_start = parameters.data.split.train_start

# Drop everything before the configured training start. .copy() makes the result an
# independent frame, so adding the time-index column below does not write into a
# view of the unfiltered data (pandas SettingWithCopy; warnings are silenced above).
total_data = total_data[total_data['Date'] >= train_start].copy()

# Integer time index: whole days since the earliest remaining date.
total_data[time_idx] = (total_data['Date'] - total_data['Date'].min()).dt.days

## Train validation test split and scaling

In [None]:
# Split into train/validation/test according to the configuration.
train_data, validation_data, test_data = train_validation_test_split(total_data, parameters)

In [None]:
# Scale all splits; target_scaler is kept to map predictions back (see upscale_prediction).
train_scaled, validation_scaled, test_scaled, target_scaler = scale_data(train_data, validation_data, test_data, parameters)

## Create dataset and dataloaders

In [None]:
def prepare_data(data: pd.DataFrame, pm: Parameters, train=False):
    """Build a ``TimeSeriesDataSet`` for *data* and a dataloader over it.

    All dataset settings (time index, targets, group ids, feature lists,
    encoder/prediction lengths) and the batch size are read from *pm*. The
    notebook-level ``targets``/``time_idx``/``batch_size``/length globals are
    derived from the same object, so results are unchanged when the notebook's
    ``parameters`` is passed. The function no longer silently depends on those
    globals matching its ``pm`` argument.

    Args:
        data: frame holding the time-index, id, feature and target columns
            named in ``pm.data`` (the notebook passes the scaled splits).
        pm: experiment parameters.
        train: forwarded to ``to_dataloader``. When False, the loader uses eight
            times the configured batch size.

    Returns:
        Tuple ``(dataset, dataloader)``.
    """
    data_cfg = pm.data
    model_cfg = pm.model_parameters
    target_cols = data_cfg.targets
    configured_batch = model_cfg.batch_size

    data_timeseries = TimeSeriesDataSet(
        data,
        time_idx=data_cfg.time_idx,
        target=target_cols,
        group_ids=data_cfg.id,
        max_encoder_length=model_cfg.input_sequence_length,
        max_prediction_length=model_cfg.target_sequence_length,
        static_reals=data_cfg.static_features,
        time_varying_known_reals=data_cfg.time_varying_known_features,
        time_varying_unknown_reals=data_cfg.time_varying_unknown_features,
        # One per-group normalizer for each target column.
        target_normalizer=MultiNormalizer(
            [GroupNormalizer(groups=data_cfg.id) for _ in range(len(target_cols))]
        ),
    )

    if train:
        dataloader = data_timeseries.to_dataloader(train=True, batch_size=configured_batch)
    else:
        dataloader = data_timeseries.to_dataloader(train=False, batch_size=configured_batch * 8)

    return data_timeseries, dataloader

In [None]:
# Only the dataloaders are needed downstream, so the datasets are discarded.
# prepare_data is called with its default train=False for every split (evaluation loaders).
_, train_dataloader = prepare_data(train_scaled, parameters)
_, validation_dataloader = prepare_data(validation_scaled, parameters)
_, test_dataloader = prepare_data(test_scaled, parameters)

# None of the scaled frames is used again; the unscaled *_data frames are what the
# predictions get aligned with. Release all three, including train_scaled.
del train_scaled, validation_scaled, test_scaled
gc.collect()

# Model

In [None]:
tft = TemporalFusionTransformer.load_from_checkpoint(args.model_path)
# NOTE(review): `device` from the GPU cell is not passed as map_location here;
# confirm which device predict()/interpret_output() actually run on.

print("Number of parameters in network: {:.1f}k".format(tft.size() / 1e3))

# Prediction Processor and PlotResults

In [None]:
from Class.PredictionProcessor import PredictionProcessor

# Helper that aligns model output with dataset rows; only the first id column is passed.
processor = PredictionProcessor(
    time_idx,
    parameters.data.id[0],
    max_prediction_length,
    targets,
    train_start,
    max_encoder_length,
)

In [None]:
from Class.Plotter import *

# Figures are written under args.figPath; the progress-bar flag also decides
# whether plots are shown interactively.
plotter = PlotResults(
    args.figPath,
    targets,
    show=args.show_progress_bar,
)

# Evaluate

## Train results

### Average

In [None]:
print('\n---Training prediction--\n')

# mode="raw" returns the full network output so its shapes can be inspected;
# return_index also returns the index identifying each predicted sample.
train_raw_predictions, train_index = tft.predict(
    train_dataloader, mode="raw", return_index=True, show_progress_bar=args.show_progress_bar
)

print('\nTrain raw prediction shapes\n')
for key in train_raw_predictions.keys():
    item = train_raw_predictions[key]
    if isinstance(item, list):
        # Some entries are lists (presumably one tensor per target); show the first
        # element's shape and tolerate an empty list instead of raising IndexError.
        first_shape = item[0].shape if item else None
        print(key, f'list of length {len(item)}', first_shape)
    else:
        print(key, item.shape)

print('\n---Training results--\n')
# upscale_prediction uses target_scaler, presumably to invert the target scaling
# before alignment with the unscaled train_data rows.
train_predictions = upscale_prediction(
    targets, train_raw_predictions['prediction'], target_scaler, max_prediction_length
)
train_result_merged = processor.align_result_with_dataset(train_data, train_predictions, train_index)
show_result(train_result_merged, targets)

plotter.summed_plot(train_result_merged, type='Train_error', plot_error=True)
gc.collect()

### By future days

In [None]:
# gc.collect()
# for day in range(1, max_prediction_length+1):
#     print(f'Day {day}')
#     df = processor.align_result_with_dataset(train_data, train_predictions, train_index, target_time_step = day)
#     show_result(df, targets)
#     plotter.summed_plot(df, type=f'Train_day_{day}')
#     break

## Validation results

In [None]:
print('\n---Validation results--\n')
# Use mode="raw" and the 'prediction' entry, exactly as the train and test splits do,
# so upscale_prediction receives the same kind of input for every split. (Without
# mode="raw", predict() returns its default point predictions instead.)
validation_raw_predictions, validation_index = tft.predict(
    validation_dataloader, mode="raw", return_index=True, show_progress_bar=args.show_progress_bar
)
validation_predictions = upscale_prediction(
    targets, validation_raw_predictions['prediction'], target_scaler, max_prediction_length
)

validation_result_merged = processor.align_result_with_dataset(
    validation_data, validation_predictions, validation_index
)
show_result(validation_result_merged, targets)
plotter.summed_plot(validation_result_merged, type='Validation')
gc.collect()

## Test results

### Average

In [None]:
print('\n---Test results--\n')

# Raw network output for the test split plus the index of each predicted sample.
test_raw_predictions, test_index = tft.predict(
    test_dataloader,
    mode="raw",
    return_index=True,
    show_progress_bar=args.show_progress_bar,
)
test_predictions = upscale_prediction(
    targets, test_raw_predictions['prediction'], target_scaler, max_prediction_length
)

test_result_merged = processor.align_result_with_dataset(
    test_data, test_predictions, test_index
)
show_result(test_result_merged, targets)
plotter.summed_plot(test_result_merged, 'Test')
gc.collect()

## Dump results

In [None]:
# Tag every split in place (the per-split frames keep the new column), then stack
# them and write a single CSV next to the figures.
for split_frame, split_name in (
    (train_result_merged, 'train'),
    (validation_result_merged, 'validation'),
    (test_result_merged, 'test'),
):
    split_frame['split'] = split_name

df = pd.concat([train_result_merged, validation_result_merged, test_result_merged])
df.to_csv(os.path.join(plotter.figPath, 'predictions.csv'), index=False)

df.head()

In [None]:
# The upscaled predictions are no longer needed once aligned into the merged frames.
del train_predictions
del validation_predictions
del test_predictions
gc.collect()

## Evaluation by county

In [None]:
fips_codes = test_result_merged['FIPS'].unique()

print('\n---Per county test results--\n')
count = 5  # number of counties to report

# Slice instead of enumerate/break. The per-county frame gets its own name so the
# notebook-level `index` and `df` variables are not overwritten.
for fips in fips_codes[:count]:
    print(f'FIPS {fips}')
    county_result = test_result_merged[test_result_merged['FIPS'] == fips]
    show_result(county_result, targets)
    print()

In [None]:
# Merged result frames are not needed anymore (their concatenation was saved to predictions.csv).
del train_result_merged
del validation_result_merged
del test_result_merged

# Interpret

In [None]:
# Choose which split to interpret; see args.interpret_train for why the training
# split is skipped on the full dataset.
if not args.interpret_train:
    raw_predictions, data, index = test_raw_predictions, test_data, test_index
else:
    raw_predictions, data, index = train_raw_predictions, train_data, train_index

## Weight plotter

In [None]:
# Plotter for attention and variable-importance figures, written under args.figPath.
plotWeights = PlotWeights(args.figPath, max_encoder_length, tft, show=args.show_progress_bar)

## Attention weights

In [None]:
# Attention analysis for the split chosen above (raw_predictions / index).
# interpret_output is the memory-heavy step noted on args.interpret_train; its result
# is passed straight into the processor rather than bound to a name, so it can be
# freed once the mean attention has been computed.
attention_mean, attention = processor.get_mean_attention(
    tft.interpret_output(raw_predictions), 
    index, return_attention=True
)
plotWeights.plot_attention(
    attention_mean, figure_name='Daily_attention', 
    limit=0, enable_markers=False, title='Attention with dates'
)
gc.collect()
# Aggregate the mean attention by day of the week and plot it.
attention_weekly = processor.get_attention_by_weekday(attention_mean)
plotWeights.plot_weekly_attention(attention_weekly, figure_name='Weekly_attention')

# Persist both the mean and the full attention tables, rounded to 3 decimals.
attention_mean.round(3).to_csv(os.path.join(plotWeights.figPath, 'attention_mean.csv'), index=False)
attention.round(3).to_csv(os.path.join(plotWeights.figPath, 'attention.csv'), index=False)

## Variable Importance

In [None]:
# Interpretation with reduction="sum": values are aggregated over all predictions.
interpretation = tft.interpret_output(raw_predictions, reduction="sum")
print(f'Interpretation:\n{interpretation}')

In [None]:
# One row per variable: raw importance, importance as a percentage of its group's
# total, and the interpretation key (group) it came from. Rows are collected in a
# list and the frame is built once, instead of growing it row by row with .loc.
importance_rows = []

for key in interpretation.keys():
    # Only the '*_variables' entries hold per-variable importances.
    if '_variables' not in key:
        continue

    # The model exposes the matching variable names under an attribute of the same name.
    features = getattr(tft, key)
    importance = interpretation[key]
    normalized = importance * 100 / torch.sum(importance)

    # A dedicated loop name: the previous `for index in ...` overwrote the global
    # `index` used by the attention cell, breaking it when cells are re-run.
    for var_idx, feature in enumerate(features):
        importance_rows.append(
            [feature, importance[var_idx].item(), normalized[var_idx].item(), key]
        )

    print(f'{key}: {features}')
    print(f'Importance: {importance}')
    print(f'Normalized: {normalized}\n')

results = pd.DataFrame(
    importance_rows, columns=['Feature', 'Importance', 'Normalized', 'Type']
)

In [None]:
# Save each interpretation figure, prefixed with the split it was computed on.
figures = plotWeights.plot_interpretation(interpretation)
split_prefix = 'Train' if args.interpret_train else 'Test'
for key in figures.keys():
    figures[key].savefig(
        os.path.join(plotter.figPath, f'{split_prefix}_{key}.jpg'), dpi=DPI
    )

In [None]:
# Variable importances, rounded to 3 decimals, saved alongside the figures.
results.round(3).to_csv(os.path.join(args.figPath, 'importance.csv'), index=False)

# End

In [None]:
# Take a single timestamp so the printed end time and the elapsed time agree.
end = datetime.now()
print(f'Ended at {end}. Elapsed time {end - start}')