In [1]:
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import torch # PyTorch
import torch.nn as nn # PyTorch neural network module
from torch.utils.data import Dataset, DataLoader # PyTorch data utilities
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.optim import AdamW, SGD
# from apex.optimizers import FusedLAMB

import matplotlib.pyplot as plt
import os
import numpy as np
import time
import gc
import atexit
import copy
import random
import json
import csv
from datetime import datetime
import sys
np.set_printoptions(precision=4, suppress=True) 

sys.path.append('../')

cuda


In [None]:
# import custom files
from S2S import *
from data_utils import * 
from model_structure_param import * # Define hyperparameters
from plot_util import *
from common import *
from transformer import *

In [None]:
# Column names for the per-run results table, grouped by topic.
# NOTE(review): 'pred_thres_change_percision_lst' is misspelled ("percision"), but the
# spelling is kept exactly as-is because other code or existing files may key on it.
header = [
    # "up" class
    'background_up', 'up_change_pred_pct', 'up_change_pred_precision',
    # "down" class
    'background_dn', 'dn_change_pred_pct', 'dn_change_pred_precision',
    # "no change" class
    'background_none', 'none_change_pred_pct', 'none_change_pred_precision',
    # overall and per-step accuracy
    'accuracy', 'accuracy_lst',
    # thresholded-change metrics
    'pred_thres_change_accuracy', 'pred_thres_change_accuracy_lst',
    'pred_thres_change_precision', 'pred_thres_change_percision_lst',
    'pred_thres_actual_change_precision', 'pred_thres_actual_change_precision_lst',
    'pred_thres_up_actual_precision', 'pred_thres_dn_actual_precision',
    # run bookkeeping
    'model_pth', 'time', 'best_k', 'epoch_num',
]

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

# reduction='none' makes the loss return one squared error per element
# instead of a single averaged scalar.
loss_fn = nn.MSELoss(reduction='none')

# NOTE(review): anomaly detection is a debugging aid and slows autograd down;
# consider switching it off for full training runs.
torch.autograd.set_detect_anomaly(True)

In [None]:
def _direction_labels(values, threshold):
    """Label every element of ``values`` with a movement direction.

    Returns a pair ``(raw, thresholded)`` of integer arrays shaped like ``values``:

    * ``raw`` is -1 where the value is <= 0 and +1 everywhere else. An exact zero
      therefore counts as "down", and NaN (which fails the ``<= 0`` comparison)
      counts as "up".
    * ``thresholded`` equals ``raw`` except that entries whose absolute value is
      strictly below ``threshold`` are set to 0 ("no change").
    """
    values = np.asarray(values)
    raw = np.where(values <= 0, -1, 1)
    thresholded = np.where(np.abs(values) < threshold, 0, raw)
    return raw, thresholded


def get_direct_diff(y_batch, y_pred, threshold=None):
    """Compare the direction of predicted values against the actual values.

    Each element of ``y_batch`` (actual) and ``y_pred`` (predicted) is turned into
    a direction label: +1 (up), -1 (down) or, once the threshold is applied,
    0 (no change, i.e. ``abs(value) < threshold``). The function then counts how
    the two labelings agree, both over all cells and per column (reducing axis 0).

    Args:
        y_batch: 2-D array-like of actual values. Axis 0 is treated as the batch
            and axis 1 as the prediction window.
        y_pred: array-like of predicted values, expected to have the same shape
            as ``y_batch``.
        threshold: magnitude below which a value counts as "no change". When
            ``None`` (the default) the module-level ``policy_threshold`` is used,
            which matches the previous behaviour of this function.

    Returns:
        A 24-element tuple, in this order:

        0.  all_cells: number of cells (``shape[0] * shape[1]``).
        1.  same_thres_cells: cells whose thresholded labels match.
        2.  all_cells_lst: per-column cell count (``shape[0]`` repeated).
        3.  same_thres_cells_lst: per-column count of matching thresholded labels.
        4.  actual_thres_up: actual cells labelled +1 after thresholding.
        5.  actual_thres_dn: actual cells labelled -1 after thresholding.
        6.  actual_thres_no: actual cells labelled 0 after thresholding.
        7.  t_thres_up: predicted +1 where actual is +1.
        8.  f_thres_up: predicted +1 where actual is not +1.
        9.  t_thres_dn: predicted -1 where actual is -1.
        10. f_thres_dn: predicted -1 where actual is not -1.
        11. t_thres_no: predicted 0 where actual is 0.
        12. f_thres_no: predicted 0 where actual is not 0.
        13. actual_thres_change: actual cells with a non-zero thresholded label.
        14. all_pred_thres_change: predicted cells with a non-zero thresholded label.
        15. true_pred_thres_change: cells where the thresholded labels match and
            the actual label is non-zero.
        16. true_pred_thres_actual_change: cells where a non-zero thresholded
            prediction matches the raw (un-thresholded) actual direction.
        17. pred_thres_up_actual_up: thresholded prediction +1 and raw actual +1.
        18. pred_thres_dn_actual_dn: thresholded prediction -1 and raw actual -1.
        19. pred_thres_up: predicted cells labelled +1 after thresholding.
        20. actual_thres_change_lst: per-column version of item 13.
        21. all_pred_thres_change_lst: per-column version of item 14.
        22. true_pred_thres_change_lst: per-column version of item 15.
        23. true_pred_thres_actual_change_lst: per-column version of item 16.

    Raises:
        AssertionError: if the internal consistency checks on the counts fail.
    """
    if threshold is None:
        # Backward-compatible default: fall back to the notebook-wide setting.
        threshold = policy_threshold

    y_batch = np.asarray(y_batch)
    y_pred = np.asarray(y_pred)

    # Raw actual direction is kept because some metrics compare a thresholded
    # prediction against the un-thresholded actual movement.
    actual_direct, actual_thres_direct = _direction_labels(y_batch, threshold)
    _, pred_thres_direct = _direction_labels(y_pred, threshold)

    batch_size = y_batch.shape[0]
    pred_window = y_batch.shape[1]

    all_cells_lst = np.full((pred_window,), batch_size)
    all_cells = batch_size * pred_window

    # Agreement between the thresholded labelings (per column and overall).
    same_thres = actual_thres_direct == pred_thres_direct
    same_thres_cells_lst = np.count_nonzero(same_thres, axis=0)
    same_thres_cells = np.count_nonzero(same_thres)

    # Per-column counts of "change" cells (label != 0).
    actual_thres_change_lst = np.count_nonzero(actual_thres_direct != 0, axis=0)
    true_pred_thres_change_lst = np.count_nonzero(same_thres & (actual_thres_direct != 0), axis=0)
    all_pred_thres_change_lst = np.count_nonzero(pred_thres_direct != 0, axis=0)

    actual_thres_change = np.sum(actual_thres_change_lst)
    true_pred_thres_change = np.sum(true_pred_thres_change_lst)
    all_pred_thres_change = np.sum(all_pred_thres_change_lst)

    # Confusion-style counts, grouped by the predicted label.
    t_thres_up = np.sum((actual_thres_direct == 1) & (pred_thres_direct == 1))
    f_thres_up = np.sum((actual_thres_direct != 1) & (pred_thres_direct == 1))

    t_thres_dn = np.sum((actual_thres_direct == -1) & (pred_thres_direct == -1))
    f_thres_dn = np.sum((actual_thres_direct != -1) & (pred_thres_direct == -1))

    t_thres_no = np.sum((actual_thres_direct == 0) & (pred_thres_direct == 0))
    f_thres_no = np.sum((actual_thres_direct != 0) & (pred_thres_direct == 0))

    actual_thres_up = np.sum(actual_thres_direct == 1)
    actual_thres_dn = np.sum(actual_thres_direct == -1)
    actual_thres_no = np.sum(actual_thres_direct == 0)

    # Sanity checks: every cell falls into exactly one actual class and exactly
    # one predicted class, and the matches equal the sum of the true counts.
    assert actual_thres_up + actual_thres_dn + actual_thres_no == all_cells
    assert t_thres_up + f_thres_up + t_thres_dn + f_thres_dn + t_thres_no + f_thres_no == all_cells
    assert same_thres_cells == t_thres_up + t_thres_dn + t_thres_no, f'{same_thres_cells} != {t_thres_up} + {t_thres_dn} + {t_thres_no}'

    # Thresholded prediction versus the raw actual direction.
    pred_thres_up_actual_up_lst = np.sum((actual_direct == 1) & (pred_thres_direct == 1), axis=0)
    pred_thres_dn_actual_dn_lst = np.sum((actual_direct == -1) & (pred_thres_direct == -1), axis=0)
    pred_thres_up_actual_up = np.sum(pred_thres_up_actual_up_lst)
    pred_thres_dn_actual_dn = np.sum(pred_thres_dn_actual_dn_lst)

    true_pred_thres_actual_change_lst = pred_thres_up_actual_up_lst + pred_thres_dn_actual_dn_lst
    true_pred_thres_actual_change = np.sum(true_pred_thres_actual_change_lst)

    pred_thres_up = np.sum(pred_thres_direct == 1)

    return all_cells, same_thres_cells, \
            all_cells_lst, same_thres_cells_lst, \
            \
            actual_thres_up, actual_thres_dn, actual_thres_no, \
            t_thres_up, f_thres_up, t_thres_dn, f_thres_dn, t_thres_no, f_thres_no, \
            \
            actual_thres_change, all_pred_thres_change, true_pred_thres_change, true_pred_thres_actual_change, pred_thres_up_actual_up, pred_thres_dn_actual_dn, pred_thres_up,\
            actual_thres_change_lst, all_pred_thres_change_lst, true_pred_thres_change_lst, true_pred_thres_actual_change_lst