In [68]:
import os, datetime, argparse, tqdm, pickle
import numpy as np
import pandas as pd
import geopandas as gpd
import matplotlib.pyplot as plt
import utils
import mymodels
from os.path import join as pjoin

In [69]:

parser = argparse.ArgumentParser()
parser.add_argument('--file_traf', type = str, default = '../prepdata/traffic-volume-A-20180101-20190101.df')
parser.add_argument('--file_coarse', type = str, default = '../prepdata/coarse_grained_lte.h5')
parser.add_argument('--file_fine', type = str, default = '../prepdata/fine_grained_lte.h5')
parser.add_argument('--model_name', type = str, default = 'LastRepeat')
parser.add_argument('--memo', type = str, default = '')

parser.add_argument('--train_ratio', type = float, default = 0.7,
                    help = 'training set [default : 0.7]')
parser.add_argument('--val_ratio', type = float, default = 0.1,
                    help = 'validation set [default : 0.1]')
parser.add_argument('--test_ratio', type = float, default = 0.2,
                    help = 'testing set [default : 0.2]')
parser.add_argument('--cnn_size', type = int, default = 3)
parser.add_argument('--P', type = int, default = 6)
parser.add_argument('--Q', type = int, default = 1)
parser.add_argument('--time_slot', type = int, default = 60, 
                    help = 'a time step is 60 mins')

parser.add_argument('--L', type = int, default = 1,
                    help = 'number of STAtt Blocks')
# parser.add_argument('--LZ', type = int, default = 2,
#                     help = 'number of Regional STAtt Blocks')
parser.add_argument('--K', type = int, default = 8,
                    help = 'number of attention heads')
parser.add_argument('--d', type = int, default = 8,
                    help = 'dims of each head attention outputs')
parser.add_argument('--D', type = int, default = 64)
parser.add_argument('--batch_size', type = int, default = 32,
                    help = 'batch size')
parser.add_argument('--max_epoch', type = int, default = 1000,
                    help = 'epoch to run')
parser.add_argument('--patience', type = int, default = 10,
                    help = 'patience for early stop')
parser.add_argument('--learning_rate', type=float, default = 0.001,
                    help = 'initial learning rate')
parser.add_argument('--decay_epoch', type=int, default = 5,
                    help = 'decay epoch')

parser.add_argument('--save_dir', default = 'test',
                    help = 'save_dir')

args = parser.parse_args([])

    
args.test_name = args.memo + args.model_name

args.model_checkpoint_dir = f'checkpoint/'
args.model_checkpoint = os.path.join(args.model_checkpoint_dir, args.test_name)
args.test_dir = f'test_exp/'

if not os.path.isdir(args.model_checkpoint_dir):
    os.makedirs(args.model_checkpoint_dir)
if not os.path.isdir(args.test_dir):
    os.makedirs(args.test_dir)


(trainX, trainZC, trainZF, trainTE, trainY, 
            valX, valZC, valZF, valTE, valY, 
            testX, testZC, testZF, testTE, testY, extdata) = utils.loadVolumeData2(args)

In [70]:
parser = argparse.ArgumentParser()
parser.add_argument('--train_ratio', type = float, default = 0.7,
                    help = 'training set [default : 0.7]')
parser.add_argument('--val_ratio', type = float, default = 0.1,
                    help = 'validation set [default : 0.1]')
parser.add_argument('--test_ratio', type = float, default = 0.2,
                    help = 'testing set [default : 0.2]')
parser.add_argument('--P', type = int, default = 6)
parser.add_argument('--Q', type = int, default = 1)
parser.add_argument('--time_slot', type = int, default = 60, 
                    help = 'a time step is 60 mins')

args = parser.parse_args([])

In [71]:
traf_df = pd.read_hdf('../prepdata/traffic-volume-A-20180101-20190101.df')
warns = np.isnan(traf_df.iloc[24*59:24*151]).sum(0) > 900

In [72]:
skip_models = []# ['DNN', 'MyDCGRU_GST', 'MyGRU', 'MyLSTM']

args.test_dir = f'test_exp/'
label = np.load(os.path.join(args.test_dir, 'label.npy'))[..., ~warns]

alive_list = dict()
group_results = dict()
group_preds = dict()
for fname in sorted(os.listdir(args.test_dir)):
    skip = False
    for w in skip_models:
        if w in fname:
            skip = True
            break

    if 'pred' in fname and not skip:
        pred_tmp = np.load(os.path.join(args.test_dir, fname))[..., ~warns]
        group_preds.setdefault(fname[7:-4], [])
        group_preds[fname[7:-4]].append(pred_tmp)
        # print((fname[5:-4] + ' '*20)[:20], '\t'.join('%.4f'%_ for _ in utils.metric(pred_tmp, label)), sep='\t')

        group_results.setdefault(fname[7:-4], {})
        # try:
        for q in range(args.Q):
            # print(pred_tmp.shape, label.shape)
            group_results[fname[7:-4]].setdefault(q, [])
            group_results[fname[7:-4]][q].append(utils.metric(pred_tmp[:, q, :], label[:, q, :]))

        group_results[fname[7:-4]].setdefault('all', [])
        group_results[fname[7:-4]]['all'].append(utils.metric(pred_tmp, label))
        # except:
        #     pass


def Sorting(lst):
    lst2 = sorted(lst, key=len)
    return lst2
      
# Driver code
lst = list(group_results.keys())
# print(Sorting(lst))
lst = Sorting(lst)


for k in lst: #sorted(group_results):
    res_list = []
    for q in range(args.Q):
        ol = len(group_results[k][q])
        tmae = np.array(group_results[k][q]).mean(0)[0]
        if args.Q > 1:
            res_list.append(tmae)

    # print((k + '_'+str(ol) + f'-TA' + ' '*20)[:20], '\t'.join('%.4f'%_ for _ in np.array(group_results[k]['all']).mean(0)), sep='\t')
    res_list.extend(list(np.array(group_results[k]['all']).mean(0)))
    print((k + '_'+str(ol) + ' '*30)[:20], '\t'.join('%.3f'%_ for _ in res_list), sep='\t')

    
for key in group_preds:
    mae_preds = []
    for pred in group_preds[key]:
        mae_pred = utils.metric(pred_tmp[:, q, :], label[:, q, :])[0]
        mae_preds.append(mae_pred)
    
    mae_preds    
        
    
    group_preds[key] = np.mean(np.array(group_preds[key]), 0)

MyGRU_3             	120.624	186.238	0.172
MyLSTM_3            	103.820	155.810	0.161
MyGM0ZC_3           	58.785	103.466	0.089
MyGM0ZF_3           	57.143	103.315	0.088
MyGMAN0_3           	55.731	100.714	0.084
MyGM0ZCF_3          	54.658	99.521	0.084
MyGM0ZCFB_3         	58.244	104.305	0.089
MyGM0ZCFC_3         	55.582	99.788	0.085
MyGM0ZCFW_3         	54.725	96.899	0.084
MyGMDCGRU_3         	59.936	104.910	0.087
MyDCGRUSTE_3        	65.065	112.684	0.092
MyDCGRUSTE0ZC_3     	58.982	102.486	0.088
MyDCGRUSTE0ZF_3     	58.781	103.878	0.086
MyDCGRUSTE0ZCF_3    	55.192	96.653	0.084
MyDCGRUSTE0ZCFB_3   	50.578	91.881	0.079
MyDCGRUSTE0ZCFC_3   	56.064	98.432	0.085
MyDCGRUSTE0ZCFW_3   	57.306	99.696	0.086
MyDCGRUSTE0ZCFB2_3  	51.022	93.027	0.080
MyDCGRUSTE0ZCFBB_3  	61.883	107.797	0.088
MyDCGRUSTE0ZCFBV_3  	99.306	178.843	0.152
MyDCGRUSTE0ZCFBPB_3 	62.793	109.238	0.094


In [52]:
for i in range(label.shape[2]):
    print(f'Sensor {i}', fwarns.keys().tolist()[i], utils.metric(group_preds['MyDCGRUSTE'][..., i], label[..., i])[0],
                                                    utils.metric(group_preds['MyGMAN0'][..., i], label[..., i])[0],
                                                    utils.metric(group_preds['MyGM0ZCF'][..., i], label[..., i])[0],
                                                    utils.metric(group_preds['MyDCGRUSTE0ZCF'][..., i], label[..., i])[0],sep='\t')

Sensor 0	A-01-I	98.381035	84.142136	83.36131	76.891464
Sensor 1	A-01-O	85.56044	74.35148	72.34007	71.99749
Sensor 2	A-02-I	100.03651	86.14253	87.28106	87.74827
Sensor 3	A-02-O	88.88697	77.47848	77.17087	72.59685
Sensor 4	A-03-I	60.33006	49.18087	46.03513	45.49823
Sensor 5	A-03-O	49.27024	41.501263	40.99067	40.036896
Sensor 6	A-04-I	33.42005	26.053843	26.620255	28.314121
Sensor 7	A-04-O	25.019495	22.420813	22.002342	21.848167
Sensor 8	A-07-I	30.665468	31.415398	30.85841	30.280851
Sensor 9	A-07-O	47.42576	44.203224	43.15869	43.75976
Sensor 10	A-08-I	60.197437	56.046104	56.918182	52.312714
Sensor 11	A-08-O	58.073586	57.634388	56.261272	54.435223
Sensor 12	A-09-O	52.13344	40.89382	40.402687	38.168858
Sensor 13	A-10-I	90.99236	78.09109	71.06549	70.33343
Sensor 14	A-10-O	84.272675	76.11994	76.91292	72.09267
Sensor 15	A-11-I	54.207973	57.600925	59.184933	54.920635
Sensor 16	A-11-O	67.9728	65.646576	65.208145	63.425682
Sensor 17	A-13-I	72.95406	63.72093	67.093864	67.75771
Sensor 18	A-13-O	54.7

In [11]:
traf_df.iloc[24*59:24*151]

Unnamed: 0,A-01-I,A-01-O,A-02-I,A-02-O,A-03-I,A-03-O,A-04-I,A-04-O,A-05-I,A-05-O,...,A-20-I,A-20-O,A-21-I,A-21-O,A-22-I,A-22-O,A-23-I,A-23-O,A-24-I,A-24-O
2018-03-01 00:00:00,648.0,801.0,805.0,980.0,305.0,366.0,29.0,90.0,,,...,739.0,1534.0,226.0,368.0,333.0,508.0,59.0,30.0,51.0,77.0
2018-03-01 01:00:00,537.0,635.0,739.0,840.0,336.0,328.0,16.0,55.0,,,...,528.0,1066.0,163.0,309.0,268.0,349.0,51.0,15.0,33.0,69.0
2018-03-01 02:00:00,402.0,476.0,574.0,697.0,305.0,278.0,17.0,43.0,,,...,442.0,882.0,153.0,236.0,194.0,285.0,43.0,9.0,23.0,64.0
2018-03-01 03:00:00,366.0,421.0,516.0,571.0,161.0,164.0,20.0,30.0,,,...,331.0,612.0,145.0,208.0,170.0,212.0,32.0,9.0,21.0,40.0
2018-03-01 04:00:00,362.0,361.0,474.0,474.0,201.0,151.0,21.0,16.0,,,...,336.0,501.0,138.0,209.0,176.0,233.0,32.0,13.0,23.0,50.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2018-05-31 19:00:00,1830.0,2115.0,1942.0,2052.0,551.0,1241.0,187.0,603.0,,,...,1174.0,1681.0,716.0,725.0,842.0,1011.0,737.0,134.0,249.0,654.0
2018-05-31 20:00:00,1459.0,1721.0,1591.0,1822.0,536.0,871.0,156.0,463.0,,,...,1087.0,1853.0,495.0,684.0,647.0,890.0,426.0,131.0,185.0,468.0
2018-05-31 21:00:00,1572.0,1695.0,1651.0,1939.0,627.0,871.0,170.0,419.0,,,...,1417.0,2328.0,499.0,716.0,750.0,1625.0,306.0,79.0,168.0,363.0
2018-05-31 22:00:00,1450.0,1689.0,1581.0,1947.0,589.0,863.0,114.0,298.0,,,...,1399.0,2209.0,419.0,646.0,719.0,1441.0,202.0,64.0,111.0,293.0


In [12]:
900 / traf_df.shape[0]

0.10273972602739725

In [13]:
traf_df
Traffic = np.nan_to_num(traf_df.values.astype(np.float32))

In [14]:
Traffic

array([[ 657., 1182.,  734., ...,   35.,  147.,  298.],
       [ 617., 1098.,  696., ...,   34.,   99.,  212.],
       [ 405.,  494.,  459., ...,   22.,   71.,   98.],
       ...,
       [1401., 1335., 1535., ...,    0.,  218.,  378.],
       [1365., 1122., 1494., ...,    0.,  204.,  357.],
       [1245.,  936., 1367., ...,    0.,  191.,  326.]], dtype=float32)

In [15]:
maxval = np.max(Traffic)
Traffic = Traffic / maxval

In [16]:
Traffic

array([[0.19160105, 0.34470692, 0.21405658, ..., 0.01020706, 0.04286964,
        0.0869058 ],
       [0.17993584, 0.32020998, 0.20297463, ..., 0.00991543, 0.02887139,
        0.06182561],
       [0.11811024, 0.14406532, 0.13385826, ..., 0.00641586, 0.02070574,
        0.02857976],
       ...,
       [0.40857393, 0.38932633, 0.44765237, ..., 0.        , 0.06357539,
        0.11023622],
       [0.39807525, 0.3272091 , 0.43569553, ..., 0.        , 0.05949256,
        0.10411198],
       [0.3630796 , 0.27296588, 0.3986585 , ..., 0.        , 0.05570137,
        0.09507145]], dtype=float32)

In [17]:
group_preds['MyDCGRUSTE'][..., ~warns]

IndexError: boolean index did not match indexed array along dimension 2; dimension is 39 but corresponding boolean dimension is 48

In [18]:
fwarns = warns[warns != True]

In [26]:
commute = (testTE[:, -1, :][:, 0] < 5) & (testTE[:, -1, :][:, 1] >= 7) & (testTE[:, -1, :][:, 1] <= 9)
for i in range(label.shape[2]):
    print(f'Sensor {i}', fwarns.keys().tolist()[i], utils.metric(group_preds['MyDCGRUSTE'][..., i][commute], label[..., i][commute])[0],
                                                    utils.metric(group_preds['MyGMAN0'][..., i][commute], label[..., i][commute])[0],
                                                    utils.metric(group_preds['MyDCGRUSTE0ZC'][..., i][commute], label[..., i][commute])[0],
                                                    utils.metric(group_preds['MyDCGRUSTE0ZF'][..., i][commute], label[..., i][commute])[0], 
                                                    utils.metric(group_preds['MyDCGRUSTE0ZCF'][..., i][commute], label[..., i][commute])[0],  sep='\t')

Sensor 0	A-01-I	185.41121	165.8687	162.89012	176.40988	137.92969
Sensor 1	A-01-O	96.8062	82.41535	80.05809	78.97288	84.39222
Sensor 2	A-02-I	190.80211	155.66032	221.52083	192.23141	239.70831
Sensor 3	A-02-O	152.77234	101.007904	118.11484	100.345764	93.02018
Sensor 4	A-03-I	116.90529	108.52097	86.91127	95.968704	71.34325
Sensor 5	A-03-O	47.64522	40.53748	49.23188	38.251846	40.16583
Sensor 6	A-04-I	124.47191	61.885723	106.70409	102.36122	86.7087
Sensor 7	A-04-O	17.514273	17.67049	15.633587	17.389841	18.435324
Sensor 8	A-07-I	41.695564	45.428734	51.260544	54.949028	46.867332
Sensor 9	A-07-O	47.327663	46.45765	44.858025	45.694405	48.767082
Sensor 10	A-08-I	111.12915	68.126465	82.99921	71.155045	66.686386
Sensor 11	A-08-O	53.819466	45.414696	49.83624	46.170036	48.912086
Sensor 12	A-09-O	77.66058	42.62307	56.751522	42.365593	42.014107
Sensor 13	A-10-I	168.5426	119.2716	140.98665	166.5139	113.49205
Sensor 14	A-10-O	132.50723	80.38561	106.43115	102.950615	92.998314
Sensor 15	A-11-I	58.99434	65

In [37]:
ms = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
for wh in range(24*7):
    w = wh//24 
    h = wh%24
    
    commute = (testTE[:, -1, :][:, 0] == w) & (testTE[:, -1, :][:, 1] == h)

    print(f'{ms[w]} {h} ', utils.metric(group_preds['MyDCGRUSTE'][commute], label[commute])[0],
                                                    utils.metric(group_preds['MyGMAN0'][commute], label[commute])[0],
                                                    utils.metric(group_preds['MyDCGRUSTE0ZC'][commute], label[commute])[0],
                                                    utils.metric(group_preds['MyDCGRUSTE0ZF'][commute], label[commute])[0], 
                                                    utils.metric(group_preds['MyDCGRUSTE0ZCF'][commute], label[commute])[0],  sep='\t')

Mon 0 	25.09652	24.197514	25.609127	26.53332	24.41088
Mon 1 	20.186337	17.092667	36.604378	18.931084	19.54401
Mon 2 	14.607178	13.912594	14.989511	13.703994	12.982497
Mon 3 	16.41562	13.727605	16.0079	14.735295	14.624308
Mon 4 	13.606535	11.613772	11.695979	11.994465	12.334438
Mon 5 	31.495031	26.315222	28.969208	26.89504	25.128586
Mon 6 	124.350044	61.47558	64.395226	65.67952	56.72077
Mon 7 	147.22972	52.77512	83.49855	86.87452	67.872505
Mon 8 	72.7456	57.725746	70.9473	58.426193	67.08542
Mon 9 	56.572636	53.23605	106.94682	88.68622	83.406845
Mon 10 	70.67924	55.009487	63.083347	61.439407	57.2616
Mon 11 	70.63276	56.71161	58.46366	58.4405	52.16812
Mon 12 	58.606766	56.295918	52.55388	50.824802	52.53046
Mon 13 	56.02157	50.721058	39.673744	39.932156	37.855522
Mon 14 	63.22322	72.4669	59.876186	62.544792	59.9778
Mon 15 	61.915146	57.099773	55.736446	57.747707	55.71995
Mon 16 	70.53536	56.144608	63.86136	64.24026	60.954304
Mon 17 	102.37306	67.47901	79.158936	83.56695	77.17917
Mon 18 	92

In [40]:
ms = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
for wh in range(24*7):
    w = wh//24 
    h = wh%24
    
    commute = (testTE[:, -1, :][:, 0] == w) & (testTE[:, -1, :][:, 1] == h)
    
    i=23
    print(f'{ms[w]} {h} ', utils.metric(group_preds['MyDCGRUSTE'][..., i][commute], label[..., i][commute])[0],
                                                    utils.metric(group_preds['MyGMAN0'][..., i][commute], label[..., i][commute])[0],
                                                    utils.metric(group_preds['MyDCGRUSTE0ZC'][..., i][commute], label[..., i][commute])[0],
                                                    utils.metric(group_preds['MyDCGRUSTE0ZF'][..., i][commute], label[..., i][commute])[0], 
                                                    utils.metric(group_preds['MyDCGRUSTE0ZCF'][..., i][commute], label[..., i][commute])[0],  sep='\t')

Mon 0 	23.321411	32.86383	26.249893	28.951385	24.313446
Mon 1 	52.419647	16.751656	18.439484	27.89978	20.893661
Mon 2 	2.0480347	8.534195	4.5230865	1.3084335	4.226654
Mon 3 	8.050735	13.4867935	12.340805	10.235817	9.707855
Mon 4 	34.674355	26.408684	23.194885	30.805038	28.866852
Mon 5 	13.939941	29.584717	22.33725	23.373535	23.95044
Mon 6 	289.95496	99.857605	59.64496	73.50787	65.503784
Mon 7 	321.7102	91.186646	146.51746	210.30225	147.57007
Mon 8 	101.354004	71.3595	91.236206	107.823975	102.89258
Mon 9 	50.57483	42.63269	251.61304	145.75366	144.71643
Mon 10 	39.218628	41.42566	123.70264	52.462646	79.41431
Mon 11 	212.77966	132.97192	135.19763	141.35437	105.47839
Mon 12 	196.12988	94.42554	139.11816	124.48462	76.22742
Mon 13 	244.96472	94.51123	131.4862	110.83264	84.59326
Mon 14 	52.609863	114.80298	81.448975	73.57068	71.55542
Mon 15 	119.78259	15.685913	43.01526	25.695435	40.95691
Mon 16 	52.94214	62.298584	53.90686	56.590454	54.335327
Mon 17 	51.332886	75.48096	34.60913	60.4718	30.16

In [41]:
ms = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
for wh in range(24*7):
    w = wh//24 
    h = wh%24
    
    commute = (testTE[:, -1, :][:, 0] == w) & (testTE[:, -1, :][:, 1] == h)
    
    i=24
    print(f'{ms[w]} {h} ', utils.metric(group_preds['MyDCGRUSTE'][..., i][commute], label[..., i][commute])[0],
                                                    utils.metric(group_preds['MyGMAN0'][..., i][commute], label[..., i][commute])[0],
                                                    utils.metric(group_preds['MyDCGRUSTE0ZC'][..., i][commute], label[..., i][commute])[0],
                                                    utils.metric(group_preds['MyDCGRUSTE0ZF'][..., i][commute], label[..., i][commute])[0], 
                                                    utils.metric(group_preds['MyDCGRUSTE0ZCF'][..., i][commute], label[..., i][commute])[0],  sep='\t')

Mon 0 	17.350494	5.7703247	15.744263	10.041382	14.067398
Mon 1 	31.711273	23.273712	105.08567	18.376755	29.669312
Mon 2 	11.891739	9.790642	4.1897507	3.460846	8.908524
Mon 3 	18.61139	14.847504	24.820206	11.052513	13.516289
Mon 4 	4.5825577	11.203461	14.380669	5.296364	6.892235
Mon 5 	35.318573	27.757126	32.299896	34.93071	36.272614
Mon 6 	75.21387	25.2854	43.54297	49.381958	30.919678
Mon 7 	430.70837	74.69226	233.2041	163.03174	198.24414
Mon 8 	47.735962	25.85852	33.50232	34.772705	45.896484
Mon 9 	77.19592	63.344482	274.31372	116.23633	131.75806
Mon 10 	117.22998	75.857544	122.87427	103.16589	110.4762
Mon 11 	191.12073	94.196655	195.97278	159.2085	143.44897
Mon 12 	132.53894	47.343994	63.173218	22.77649	68.128174
Mon 13 	81.68274	87.797485	87.15967	74.72583	83.44104
Mon 14 	118.925476	98.3385	110.48535	107.54712	78.35083
Mon 15 	14.798584	19.903076	13.136719	21.303589	21.20459
Mon 16 	82.93359	51.60254	69.2561	67.7168	49.4917
Mon 17 	140.51721	41.210693	77.40796	54.651245	43.208984
M

Sensor 0	A-01-I	84.98047	113.104675	154.71362
Sensor 1	A-01-O	25.097229	20.963724	46.07713
Sensor 2	A-02-I	65.64443	129.15083	162.72498
Sensor 3	A-02-O	44.0213	48.090454	60.805744
Sensor 4	A-03-I	4.037313	25.405539	30.317108
Sensor 5	A-03-O	32.552246	32.148987	45.63637
Sensor 6	A-04-I	15.905898	17.20217	17.743277
Sensor 7	A-04-O	6.449056	5.843	8.029735
Sensor 8	A-07-I	38.493896	29.871643	49.350586
Sensor 9	A-07-O	31.83309	59.410908	51.76121
Sensor 10	A-08-I	37.65912	52.847656	54.719055
Sensor 11	A-08-O	13.406006	37.271038	26.968994
Sensor 12	A-09-O	28.678223	58.3408	62.313374
Sensor 13	A-10-I	41.93982	36.00749	63.61255
Sensor 14	A-10-O	43.07953	48.48885	57.32005
Sensor 15	A-11-I	83.242714	47.108704	57.77881
Sensor 16	A-11-O	42.227825	34.89258	44.59792
Sensor 17	A-13-I	51.990215	57.498108	63.66559
Sensor 18	A-13-O	75.41996	37.126324	50.722942
Sensor 19	A-14-I	20.594584	41.282116	36.42093
Sensor 20	A-14-O	20.956705	7.166931	28.199362
Sensor 21	A-16-I	59.05363	50.13562	51.719257
Sensor 22

In [25]:
for i in range(label.shape[2]):
    print(f'Sensor {i}', fwarns.keys().tolist()[i], utils.metric(group_preds['MyDCGRUSTE'][..., i][commute], label[..., i][commute])[0],
                                                    utils.metric(group_preds['MyGMAN0'][..., i][commute], label[..., i][commute])[0],
                                                    utils.metric(group_preds['MyDCGRUSTE0ZC'][..., i][commute], label[..., i][commute])[0],
                                                    utils.metric(group_preds['MyDCGRUSTE0ZF'][..., i][commute], label[..., i][commute])[0], 
                                                    utils.metric(group_preds['MyDCGRUSTE0ZCF'][..., i][commute], label[..., i][commute])[0],  sep='\t')

Sensor 0	A-01-I	185.41121	165.8687	162.89012	176.40988	137.92969
Sensor 1	A-01-O	96.8062	82.41535	80.05809	78.97288	84.39222
Sensor 2	A-02-I	190.80211	155.66032	221.52083	192.23141	239.70831
Sensor 3	A-02-O	152.77234	101.007904	118.11484	100.345764	93.02018
Sensor 4	A-03-I	116.90529	108.52097	86.91127	95.968704	71.34325
Sensor 5	A-03-O	47.64522	40.53748	49.23188	38.251846	40.16583
Sensor 6	A-04-I	124.47191	61.885723	106.70409	102.36122	86.7087
Sensor 7	A-04-O	17.514273	17.67049	15.633587	17.389841	18.435324
Sensor 8	A-07-I	41.695564	45.428734	51.260544	54.949028	46.867332
Sensor 9	A-07-O	47.327663	46.45765	44.858025	45.694405	48.767082
Sensor 10	A-08-I	111.12915	68.126465	82.99921	71.155045	66.686386
Sensor 11	A-08-O	53.819466	45.414696	49.83624	46.170036	48.912086
Sensor 12	A-09-O	77.66058	42.62307	56.751522	42.365593	42.014107
Sensor 13	A-10-I	168.5426	119.2716	140.98665	166.5139	113.49205
Sensor 14	A-10-O	132.50723	80.38561	106.43115	102.950615	92.998314
Sensor 15	A-11-I	58.99434	65

In [21]:
for i in range(label.shape[2]):
    print(f'Sensor {i}', fwarns.keys().tolist()[i], utils.metric(group_preds['MyDCGRUSTE'][..., i], label[..., i])[0],
                                                    utils.metric(group_preds['MyGMAN0'][..., i], label[..., i])[0],
                                                    utils.metric(group_preds['MyDCGRUSTE0ZC'][..., i], label[..., i])[0],
                                                    utils.metric(group_preds['MyDCGRUSTE0ZF'][..., i], label[..., i])[0], 
                                                    utils.metric(group_preds['MyDCGRUSTE0ZCF'][..., i], label[..., i])[0],  sep='\t')

Sensor 0	A-01-I	98.381035	84.142136	80.76664	82.29185	76.891464
Sensor 1	A-01-O	85.56044	74.35148	74.923225	76.18488	71.99749
Sensor 2	A-02-I	100.03651	86.14253	90.6723	87.28877	87.74827
Sensor 3	A-02-O	88.88697	77.47848	78.46168	75.989456	72.59685
Sensor 4	A-03-I	60.33006	49.18087	46.809776	48.66286	45.49823
Sensor 5	A-03-O	49.27024	41.501263	43.61061	42.426407	40.036896
Sensor 6	A-04-I	33.42005	26.053843	31.621008	29.885628	28.314121
Sensor 7	A-04-O	25.019495	22.420813	22.038057	21.765368	21.848167
Sensor 8	A-07-I	30.665468	31.415398	30.658943	30.924383	30.280851
Sensor 9	A-07-O	47.42576	44.203224	45.10787	45.324726	43.75976
Sensor 10	A-08-I	60.197437	56.046104	55.56432	53.764084	52.312714
Sensor 11	A-08-O	58.073586	57.634388	55.272003	53.65041	54.435223
Sensor 12	A-09-O	52.13344	40.89382	42.68752	39.74805	38.168858
Sensor 13	A-10-I	90.99236	78.09109	76.992294	81.29863	70.33343
Sensor 14	A-10-O	84.272675	76.11994	75.357254	74.41447	72.09267
Sensor 15	A-11-I	54.207973	57.600925	57.063

In [22]:
for i in range(label.shape[2]):
    print(f'Sensor {i}', fwarns.keys().tolist()[i], utils.metric(group_preds['MyDCGRUSTE'][..., i], label[..., i])[2],
                                                    utils.metric(group_preds['MyGMAN0'][..., i], label[..., i])[2],
                                                    utils.metric(group_preds['MyDCGRUSTE0ZC'][..., i], label[..., i])[2],
                                                    utils.metric(group_preds['MyDCGRUSTE0ZF'][..., i], label[..., i])[2], 
                                                    utils.metric(group_preds['MyDCGRUSTE0ZCF'][..., i], label[..., i])[2],  sep='\t')

Sensor 0	A-01-I	0.0680891521686359	0.06326572058494026	0.06090712258743087	0.06072605497392306	0.058531023436333984
Sensor 1	A-01-O	0.0659686333299416	0.06123029102630152	0.06076629750413629	0.06093579858456636	0.05842540965159312
Sensor 2	A-02-I	0.06275827170097709	0.05883861698951366	0.05952089683964285	0.056906381771636656	0.05722687367935127
Sensor 3	A-02-O	0.06074379146674851	0.05664922445587364	0.055985483304558666	0.054338770046641005	0.05248270850427553
Sensor 4	A-03-I	0.0845772768153598	0.08008101479635932	0.07603555743957587	0.07792679682456635	0.07584246227708533
Sensor 5	A-03-O	0.08314340769151778	0.07701266576403379	0.07903932621564085	0.07708366545630906	0.07412308667820872
Sensor 6	A-04-I	0.3008184882415916	0.3093707787389357	0.32103421393647746	0.29381658033103775	0.29827196492160324
Sensor 7	A-04-O	0.11694840034726055	0.11504487258811842	0.11283780136660042	0.11560588213526622	0.11550068330210518
Sensor 8	A-07-I	0.060088602143575094	0.06350846365740284	0.06112580619938

In [23]:
for i in range(label.shape[2]):
    print(f'Sensor {i}', traf_df.columns[i], utils.metric(group_preds['MyGMAN0'][..., i], label[..., i]), utils.metric(group_preds['MyDCGRUSTE0ZCF'][..., i], label[..., i]), sep='\t')

Sensor 0	A-01-I	(84.142136, 158.17505, 0.06326572058494026)	(76.891464, 141.05565, 0.058531023436333984)
Sensor 1	A-01-O	(74.35148, 136.05843, 0.06123029102630152)	(71.99749, 123.05061, 0.05842540965159312)
Sensor 2	A-02-I	(86.14253, 146.35397, 0.05883861698951366)	(87.74827, 134.40698, 0.05722687367935127)
Sensor 3	A-02-O	(77.47848, 119.95932, 0.05664922445587364)	(72.59685, 111.914734, 0.05248270850427553)
Sensor 4	A-03-I	(49.18087, 95.31088, 0.08008101479635932)	(45.49823, 78.80139, 0.07584246227708533)
Sensor 5	A-03-O	(41.501263, 65.71722, 0.07701266576403379)	(40.036896, 62.68133, 0.07412308667820872)
Sensor 6	A-04-I	(26.053843, 51.79244, 0.3093707787389357)	(28.314121, 48.863285, 0.29827196492160324)
Sensor 7	A-04-O	(22.420813, 37.072033, 0.11504487258811842)	(21.848167, 33.438274, 0.11550068330210518)
Sensor 8	A-05-I	(31.415398, 45.80056, 0.06350846365740284)	(30.280851, 42.40468, 0.0604288264624296)
Sensor 9	A-05-O	(44.203224, 64.89703, 0.04829352933291279)	(43.75976, 63.515984

In [19]:
label

array([[[2185., 1766., 2408., ...,    0.,  380.,  493.],
        [2071., 1860., 2409., ...,    0.,  453.,  540.],
        [2014., 1686., 1977., ...,    0.,  511.,  598.]],

       [[2071., 1860., 2409., ...,    0.,  453.,  540.],
        [2014., 1686., 1977., ...,    0.,  511.,  598.],
        [1972., 2150., 2010., ...,    0.,  418.,  573.]],

       [[2014., 1686., 1977., ...,    0.,  511.,  598.],
        [1972., 2150., 2010., ...,    0.,  418.,  573.],
        [2057., 1732., 2024., ...,    0.,  483.,  488.]],

       ...,

       [[1833., 1858., 1937., ...,    0.,  271.,  566.],
        [1454., 1552., 1592., ...,    0.,  221.,  473.],
        [1401., 1335., 1535., ...,    0.,  218.,  378.]],

       [[1454., 1552., 1592., ...,    0.,  221.,  473.],
        [1401., 1335., 1535., ...,    0.,  218.,  378.],
        [1365., 1122., 1494., ...,    0.,  204.,  357.]],

       [[1401., 1335., 1535., ...,    0.,  218.,  378.],
        [1365., 1122., 1494., ...,    0.,  204.,  357.],
        

array([[[2201.948   , 1801.0413  , 2408.8882  , ...,  118.17151 ,
          346.6506  ,  469.18573 ],
        [2157.2505  , 1794.9213  , 2352.3003  , ...,  134.6008  ,
          362.52203 ,  548.57404 ],
        [2071.4763  , 1788.9254  , 2228.0269  , ...,  151.67215 ,
          387.33246 ,  591.4836  ]],

       [[2159.05    , 1805.7067  , 2344.9575  , ...,  125.79533 ,
          386.12314 ,  559.01965 ],
        [2091.0864  , 1785.1785  , 2234.6711  , ...,  149.76949 ,
          380.98477 ,  588.97    ],
        [2078.9692  , 1800.1062  , 2191.3096  , ...,  145.23595 ,
          392.50497 ,  610.6024  ]],

       [[2047.7803  , 1803.8646  , 2260.3828  , ...,  135.51163 ,
          428.03925 ,  582.47327 ],
        [2063.1108  , 1808.4633  , 2226.2583  , ...,  143.18228 ,
          415.48273 ,  609.4136  ],
        [2063.5015  , 1801.7828  , 2164.0037  , ...,  147.73392 ,
          398.18036 ,  588.0248  ]],

       ...,

       [[1668.5629  , 1984.0696  , 1744.0242  , ...,   75.35905

KeyError: 'myDCGRUSTE'

In [None]:


def Sorting(lst):
    lst2 = sorted(lst, key=len)
    return lst2
      
# Driver code
lst = list(group_results.keys())
# print(Sorting(lst))
lst = Sorting(lst)


for k in lst: #sorted(group_results):
    res_list = []
    for q in range(args.Q):
        ol = len(group_results[k][q])
        tmae = np.array(group_results[k][q]).mean(0)[0]
        res_list.append(tmae)

    # print((k + '_'+str(ol) + f'-TA' + ' '*20)[:20], '\t'.join('%.4f'%_ for _ in np.array(group_results[k]['all']).mean(0)), sep='\t')
    res_list.extend(list(np.array(group_results[k]['all']).mean(0)))
    print((k + '_'+str(ol) + ' '*30)[:20], '\t'.join('%.4f'%_ for _ in res_list), sep='\t')

         
