# Human Activity Recognition - SSL JOINT

Last Update : 24 July 2019

In [1]:
N_THREADS = 8
# Nota Bene : notebooks don't deallocate GPU memory
IS_FORCE_CPU = False # can also be set in the trainer

## Environment

In [2]:
cd ..

/master


In [3]:
%autosave 600
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

# CENTER PLOTS
from IPython.core.display import HTML
display(HTML(""" <style> .output_png {display: table-cell; text-align: center; margin:auto; }
.prompt display:none;}  </style>"""))

import os
if IS_FORCE_CPU:
    os.environ['CUDA_VISIBLE_DEVICES'] = ""

import sys
sys.path.append("notebooks")

import numpy as np
import matplotlib.pyplot as plt
from functools import partial
import pandas as pd
import h5py


import torch
torch.set_num_threads(N_THREADS)

Autosaving every 600 seconds


# Dataset

In [4]:
from skssl.transformers.neuralproc.datasplit import CntxtTrgtGetter, GetRandomIndcs, get_all_indcs
from utils.data.tsdata import get_timeseries_dataset, SparseMultiTimeSeriesDataset

get_cntxt_trgt_test = CntxtTrgtGetter(contexts_getter=GetRandomIndcs(min_n_indcs=0.1, max_n_indcs=0.5),
                                     targets_getter=get_all_indcs,
                                     is_add_cntxts_to_trgts=False)  # don't context points to tagrtes

get_cntxt_trgt = CntxtTrgtGetter(contexts_getter=GetRandomIndcs(min_n_indcs=0.01, max_n_indcs=0.5),
                                 targets_getter=GetRandomIndcs(min_n_indcs=0.5, max_n_indcs=0.99),
                                 is_add_cntxts_to_trgts=False)  # don't context points to tagrtes

In [5]:
data_both = get_timeseries_dataset("har")(split="both")

def cntxt_trgt_collate(get_cntxt_trgt, is_repeat_batch=False):
    def mycollate(batch):
        min_length = min([v.size(0) for b in batch for k,v in b[0].items() if "X" in k])
        # chose first min_legth of each (assumes that randomized)
        
        batch = [({k:v[:min_length, ...] for k,v in b[0].items()}, b[1]) for b in batch]        
        collated = torch.utils.data.dataloader.default_collate(batch)
        
        X = collated[0]["X"]
        y = collated[0]["y"]
        
        if is_repeat_batch:
            
            X = torch.cat([X,X], dim=0)
            y = torch.cat([y,y], dim=0)
            collated[1] = torch.cat([collated[1], collated[1]], dim=0) # targets
        
        collated[0]["X"], collated[0]["y"], collated[0]["X_trgt"], collated[0]["y_trgt"] = get_cntxt_trgt(X, y)
        
        return collated
    return mycollate

In [6]:
X_DIM = 1  # 1D spatial input (although actually 2 but the first is for sparse channels)
Y_DIM = data_both.data.shape[-1] # multiple channels
N_TARGETS = len(np.unique(data_both.targets))

sampling_percentages = [0.5]
label_percentages = [N_TARGETS, N_TARGETS*2, 0.01, 0.05, 0.1, 0.3, 0.5, 1]

## Model

In [7]:
import torch.nn as nn
from skssl.transformers import GlobalNeuralProcess, NeuralProcessLoss, AttentiveNeuralProcess
from skssl.utils.helpers import rescale_range
from skssl.predefined import UnetCNN, CNN, MLP, SparseSetConv, SetConv, MlpRBF, GaussianRBF, BatchSparseSetConv
from skssl.transformers.neuralproc.datasplit import precomputed_cntxt_trgt_split
from utils.helpers import count_parameters

In [8]:
from copy import deepcopy

models = {}

unet = partial(UnetCNN,
               Conv=torch.nn.Conv1d,
               Pool=torch.nn.MaxPool1d,
               upsample_mode="linear",
               n_layers=18,
               is_double_conv=True,
               is_depth_separable=True,
               Normalization=torch.nn.BatchNorm1d,
               is_chan_last=True,
               bottleneck=None,
               kernel_size=7,
               max_nchannels=256,
              is_force_same_bottleneck=True,
               _is_summary=True,
              )

kwargs = dict(x_dim=X_DIM, 
              y_dim=Y_DIM,
              min_std=5e-3,
                n_tmp_queries=128,
                r_dim=64,
              keys_to_tmp_attn=partial(SetConv, RadialBasisFunc=GaussianRBF),
              TmpSelfAttn=unet,
              tmp_to_queries_attn=partial(SetConv, RadialBasisFunc=GaussianRBF),
              is_skip_tmp=False,
              is_use_x=False,
              get_cntxt_trgt=precomputed_cntxt_trgt_split,
              is_encode_xy=False,
             Classifier=partial(MLP, input_size=256, output_size=N_TARGETS, dropout=0.5, hidden_size=64, n_hidden_layers=1))

models["ssl_classifier_gnp_large_shared_bottleneck_sup"] = partial(GlobalNeuralProcess, **kwargs)

kwargs_bis = deepcopy(kwargs)
kwargs_bis["TmpSelfAttn"] = partial(unet, n_layers=14, max_nchannels=128)
kwargs_bis["Classifier"] = partial(MLP, input_size=128, output_size=N_TARGETS, dropout=0.5, hidden_size=64, n_hidden_layers=3)
kwargs_bis["r_dim"] = 16

models["ssl_classifier_gnp_small_shared_bottleneck_sup"] = partial(GlobalNeuralProcess, **kwargs_bis)

kwargs_bis = deepcopy(kwargs)
kwargs_bis["TmpSelfAttn"] = partial(unet, n_layers=10, bottleneck=True, max_nchannels=128)
kwargs_bis["Classifier"] = partial(MLP, input_size=64, output_size=N_TARGETS,  dropout=0.5, hidden_size=64, n_hidden_layers=3)
kwargs_bis["r_dim"] = 16

models["ssl_classifier_gnp_mini_shared_bottleneck_sup"] = partial(GlobalNeuralProcess, **kwargs_bis)

In [9]:
from utils.helpers import count_parameters
for k,v in models.items():
    print(k, "- N Param:", count_parameters(v()))

ssl_classifier_gnp_large_shared_bottleneck_sup - N Param: 1023774
ssl_classifier_gnp_small_shared_bottleneck_sup - N Param: 105654
ssl_classifier_gnp_mini_shared_bottleneck_sup - N Param: 38262




# Training

In [10]:
from ntbks_helpers import train_models_
from skorch.dataset import CVSplit
from utils.data.ssldata import get_train_dev_test_ssl

In [11]:
N_EPOCHS = 100 
BATCH_SIZE = 32
IS_RETRAIN = True # if false load precomputed
chckpnt_dirname="results/challenge/har/"

In [12]:
from skssl.utils.helpers import HyperparameterInterpolator

n_steps_per_epoch = len(data_both)//BATCH_SIZE
hi=HyperparameterInterpolator(1e-5, 10, N_EPOCHS*n_steps_per_epoch, start_step=n_steps_per_epoch*10, mode="linear")

In [16]:
data_trainers = {}

for sampling_perc in sampling_percentages:
    for label_perc in [1]:#label_percentages:
        data_train, _, data_test = get_train_dev_test_ssl("har", 
                                                          n_labels=label_perc, 
                                                          data_perc=sampling_perc, 
                                                          dev_size=0)
        
        # add test as unlabeled data
        data_train.data = np.concatenate([data_train.data, data_test.data], axis=0)
        data_train.targets = np.concatenate([data_train.targets, np.ones_like(data_test.targets)], axis=0)
        data_train.indcs = np.concatenate([data_train.indcs, data_test.indcs], axis=0)
        
        data_trainers.update(train_models_({"{}%har_{}%lab".format(int(sampling_perc*100), 
                                                                   int(label_perc*100)): (data_train, data_test)}, 
                              models, 
                              criterion=partial(NeuralProcessLoss, 
                                                ssl_loss="supervised",
                                                get_lambda_sup=lambda: hi(True)),
                                patience=15,
                              chckpnt_dirname=chckpnt_dirname,
                              max_epochs=N_EPOCHS,
                              batch_size=BATCH_SIZE,
                              is_retrain=IS_RETRAIN,
                              #train_split=CVSplit(0.05),
                              iterator_train__collate_fn=cntxt_trgt_collate(get_cntxt_trgt, is_repeat_batch=True),  
                              iterator_valid__collate_fn=cntxt_trgt_collate(get_cntxt_trgt_test)))


--- Training 50%har_100%lab/ssl_classifier_gnp_large_shared_bottleneck_sup ---



HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

  epoch    train_loss    valid_acc    valid_loss    cp      dur
-------  ------------  -----------  ------------  ----  -------
      1        [36m7.5321[0m       [32m0.4028[0m        [35m1.3262[0m     +  19.6591


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

      2        [36m3.3597[0m       [32m0.4255[0m        [35m1.2522[0m     +  19.6911


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

      3        [36m2.5699[0m       [32m0.4744[0m        [35m1.1186[0m     +  19.4092


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

      4        [36m2.1422[0m       0.3424        1.2435        19.2636


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

      5        2.3167       0.4564        1.2193        19.4043


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

      6        [36m1.5239[0m       0.2613        1.3679        19.9147


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

      7        [36m1.0620[0m       0.3431        1.2926        19.8916


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

      8        [36m0.9851[0m       0.2888        1.2991        19.4789


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

      9        [36m0.6137[0m       0.2229        1.2691        19.5189


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     10        [36m0.6024[0m       0.2803        1.2428        19.4890


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     11        [36m0.3724[0m       0.2989        1.2443        19.4781


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     12        [36m0.3429[0m       0.2559        1.2962        19.4867


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     13        [36m0.3342[0m       [32m0.5008[0m        [35m1.1012[0m     +  19.9636


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     14        [36m0.1712[0m       [32m0.5365[0m        [35m0.8978[0m     +  19.9798


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     15        0.4094       [32m0.7533[0m        [35m0.6943[0m     +  19.9674


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     16       [36m-0.0259[0m       [32m0.7893[0m        [35m0.6546[0m     +  20.5171


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     17        0.0054       0.7747        0.6724        20.3850


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     18       [36m-0.2439[0m       [32m0.8612[0m        [35m0.5396[0m     +  19.8439


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     19       -0.0341       0.8554        0.5409        20.0502


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     20       [36m-0.2484[0m       0.8402        0.5677        20.0995


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     21       [36m-0.3402[0m       0.8297        0.6114        20.0648


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     22       -0.3295       0.8558        [35m0.5084[0m     +  20.0584


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     23       [36m-0.5370[0m       0.7207        0.8613        20.0417


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     24       -0.4453       0.7370        0.8001        20.0259


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     25       -0.3771       [32m0.8629[0m        0.6186        20.0584


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     26       -0.4190       0.8476        0.5770        20.0894


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     27       [36m-0.9132[0m       0.8276        0.6471        20.1025


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     28       [36m-0.9307[0m       0.8018        0.6804        20.0368


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     29       -0.8597       0.7570        0.8004        20.0645


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     30       [36m-1.1461[0m       0.6817        1.0666        20.0746


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     31       -1.0441       0.7984        0.7584        20.0795


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     32       [36m-1.2756[0m       0.8117        0.6727        19.7019


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     33       -1.2270       0.7615        0.8902        20.0779


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     34       [36m-1.4517[0m       0.6566        2.0343        19.5795


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     35       -1.2376       0.6929        1.4825        20.0908


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     36       -1.1875       0.7319        1.0867        20.0832


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

Stopping since valid_loss has not improved in the last 15 epochs.
Re-initializing module.
Re-initializing optimizer.
50%har_100%lab/ssl_classifier_gnp_large_shared_bottleneck_sup best epoch: 22 val_loss: 0.5084086644030407

--- Training 50%har_100%lab/ssl_classifier_gnp_small_shared_bottleneck_sup ---



HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

  epoch    train_loss    valid_acc    valid_loss    cp      dur
-------  ------------  -----------  ------------  ----  -------
      1       [36m14.0398[0m       [32m0.5976[0m        [35m0.9148[0m     +  12.3125


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

      2        [36m7.0436[0m       [32m0.6793[0m        [35m0.8033[0m     +  12.2642


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

      3        [36m6.2107[0m       [32m0.7207[0m        [35m0.7517[0m     +  12.3682


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

      4        [36m5.7943[0m       0.6963        [35m0.7456[0m     +  12.3822


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

      5        [36m5.7856[0m       0.7075        0.7540        12.3885


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

      6        [36m5.3946[0m       [32m0.7360[0m        [35m0.6951[0m     +  12.4475


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

      7        5.5279       0.7326        [35m0.6874[0m     +  12.3692


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

      8        [36m5.2173[0m       0.7357        0.7098        12.3456


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

      9        5.9873       [32m0.7458[0m        [35m0.6298[0m     +  12.2939


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     10        5.3129       0.7336        0.6421        12.2792


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     11        5.2666       [32m0.7482[0m        [35m0.6020[0m     +  12.3308


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     12        [36m4.9254[0m       [32m0.7516[0m        0.6705        12.4474


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     13        [36m4.7255[0m       [32m0.8656[0m        [35m0.5613[0m     +  12.3024


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     14        [36m4.6960[0m       0.8385        0.5992        12.3723


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     15        [36m4.6069[0m       [32m0.9033[0m        [35m0.5091[0m     +  12.3562


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     16        [36m4.4096[0m       0.8622        0.5527        12.3502


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     17        4.6337       0.8975        [35m0.5046[0m     +  12.3094


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     18        4.5122       0.8992        0.5130        12.3663


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     19        4.5664       0.8765        0.5312        12.3702


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     20        [36m4.3387[0m       0.8979        [35m0.4969[0m     +  12.4668


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     21        4.4244       0.8833        0.5595        12.3502


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     22        4.4329       0.8758        0.5552        12.3943


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     23        [36m4.2756[0m       0.8582        0.5878        12.3532


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     24        5.8671       0.8826        0.5257        12.5007


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     25        4.9290       [32m0.9125[0m        [35m0.4774[0m     +  12.4498


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     26        4.7812       0.8958        0.4943        10.7630


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     27        4.3649       0.9063        0.4804        12.3717


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     28        4.4945       0.8728        0.5243        12.3585


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     29        4.6024       0.8982        0.5164        12.3639


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     30        4.4427       0.8307        0.6006        12.4121


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     31        4.7519       0.8829        0.4939        12.3666


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     32        4.8632       0.8198        0.7387        12.4145


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     33        5.1069       0.8575        0.5338        12.4257


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     34        4.5967       0.8510        0.5863        12.3624


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     35        5.0382       0.8694        0.5076        12.3245


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     36        5.1648       0.8409        0.5865        12.3844


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     37        4.7189       0.8480        0.5930        12.3990


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     38        4.8180       0.8361        0.6233        12.5205


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     39        4.8375       0.8487        0.6180        12.4306


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

Stopping since valid_loss has not improved in the last 15 epochs.
Re-initializing module.
Re-initializing optimizer.
50%har_100%lab/ssl_classifier_gnp_small_shared_bottleneck_sup best epoch: 25 val_loss: 0.4774247264311926

--- Training 50%har_100%lab/ssl_classifier_gnp_mini_shared_bottleneck_sup ---



HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

  epoch    train_loss    valid_acc    valid_loss    cp      dur
-------  ------------  -----------  ------------  ----  -------
      1       [36m19.7676[0m       [32m0.6081[0m        [35m0.9583[0m     +  10.2063


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

      2       [36m12.0665[0m       [32m0.6994[0m        [35m0.8146[0m     +  10.2191


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

      3       [36m11.0202[0m       [32m0.7275[0m        [35m0.7572[0m     +  10.2042


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

      4       [36m10.2313[0m       0.6793        0.7963        10.2053


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

      5        [36m9.5970[0m       0.7021        [35m0.7397[0m     +  8.0226


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

      6        [36m9.3443[0m       [32m0.7302[0m        [35m0.6728[0m     +  7.7528


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

      7        [36m9.0081[0m       [32m0.7397[0m        0.7296        10.2284


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

      8        [36m8.7885[0m       [32m0.8385[0m        [35m0.6066[0m     +  10.2022


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

      9        [36m8.5079[0m       [32m0.8500[0m        0.7241        10.2336


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     10        [36m8.3935[0m       [32m0.8534[0m        0.6674        10.2231


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     11        [36m8.3451[0m       [32m0.8877[0m        [35m0.5636[0m     +  10.2174


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     12        8.3800       0.8395        0.7435        10.2181


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     13        [36m8.2904[0m       0.8371        0.6833        10.2149


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     14        [36m8.2744[0m       0.8588        0.6562        10.2246


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     15        [36m8.0773[0m       0.8616        0.5731        10.2196


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     16        [36m7.9922[0m       0.8436        0.6775        10.2135


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     17        8.2477       0.8700        0.6102        10.2258


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     18        8.0371       0.8585        0.7337        10.2132


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     19        8.2653       0.8575        0.7888        10.2184


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     20        [36m7.9786[0m       0.8711        0.6581        10.1971


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     21        8.0877       0.8731        0.6935        10.2073


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     22        8.2308       0.8619        0.5763        10.2670


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     23        8.1264       0.8426        0.7242        10.2126


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     24        8.0781       0.8687        0.6566        10.2173


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

     25        8.2907       0.8724        0.6768        10.2268


HBox(children=(IntProgress(value=0, max=346), HTML(value='')))

Stopping since valid_loss has not improved in the last 15 epochs.
Re-initializing module.
Re-initializing optimizer.
50%har_100%lab/ssl_classifier_gnp_mini_shared_bottleneck_sup best epoch: 11 val_loss: 0.563553226977071


In [17]:
for k,t in data_trainers.items():
    print()
    for e, h in enumerate(t.history[::-1]):
        if h["valid_loss_best"]:
            print(k, "epoch:", len(t.history)-e, 
                  "val_loss:", h["valid_loss"], 
                  "val_acc:", h["valid_acc"])
            break


50%har_100%lab/ssl_classifier_gnp_large_shared_bottleneck_sup epoch: 22 val_loss: 0.5084086644030407 val_acc: 0.8557855446216491

50%har_100%lab/ssl_classifier_gnp_small_shared_bottleneck_sup epoch: 25 val_loss: 0.4774247264311926 val_acc: 0.9124533423820834

50%har_100%lab/ssl_classifier_gnp_mini_shared_bottleneck_sup epoch: 11 val_loss: 0.563553226977071 val_acc: 0.8876823888700374


In [42]:
for i in range(100*n_steps_per_epoch):
    x = hi(True)
    if i % n_steps_per_epoch == 0:
        print(i//n_steps_per_epoch, x)

0 1.0004304823679598e-05
1 1.1486478813866367e-05
2 1.3188242248387777e-05
3 1.514212809866449e-05
4 1.7385489213651198e-05
5 1.9961212402148445e-05
6 2.291853830899504e-05
7 2.6314002758887478e-05
8 3.0212517563695364e-05
9 3.4688611455294726e-05
10 3.9827854867094556e-05
11 4.5728495802106476e-05
12 5.2503338060793106e-05
13 6.0281897735174225e-05
14 6.921287919534444e-05
15 7.946701790236032e-05
16 9.124034439415184e-05
17 0.00010475793184276736
18 0.00012027819882579596
19 0.00013809784956895856
20 0.0001585575461035302
21 0.0001820484207744256
22 0.00020901955359994335
23 0.00023998655742943475
24 0.00027554143502317284
25 0.0003163638964943012
26 0.0003632343534707939
27 0.00041704883838956735
28 0.00047883613413804584
29 0.00054977744151416
30 0.0006312289604917681
31 0.0007247478169823302
32 0.0008321218307402155
33 0.0009554036934909864
34 0.0010969502106731646
35 0.0012594673569861518
36 0.0014460620070807454
37 0.001660301330338795
38 0.0019062809852045635
39 0.0021887034167