In [9]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
import os
from os.path import join as oj
import sys
sys.path.append('../src')
import numpy as np
import seaborn as sns
import torch
from matplotlib import pyplot as plt
from sklearn import metrics
from sklearn.linear_model import RidgeCV
from sklearn.model_selection import train_test_split
plt.style.use('dark_background')
import data
from skorch.callbacks import Checkpoint
from skorch import NeuralNetRegressor
from config import *
from tqdm import tqdm
import train_reg
import config
import pandas as pd
import features
from scipy.stats import skew, pearsonr
import outcomes
from sklearn.model_selection import KFold
from torch import nn, optim
from torch.nn import functional as F

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


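# Prepare data

Load one of the large datasets and keep only hard tracks (`lifetime > 15`) from the training cells; held-out test cells are excluded. Then downsample each track `X` to a fixed length of 40 points and add the regression target `Y_sig_mean`.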

In [40]:
# Datasets to load; this is one of the large datasets.
dsets = ['clath_aux+gak_a7d2_new']
# Model inputs: the downsampled track plus summary features.
feat_names = ['X_same_length', 'mean_total_displacement',
              'mean_square_displacement', 'lifetime']
# Extra columns carried alongside the features (cell id and regression target).
meta = ['cell_num', 'Y_sig_mean']
# Number of points each track 'X' is downsampled to.
length = 40

dfs = []
for dset in dsets:
    df = data.get_data(dset=dset)
    df = df[df.lifetime > 15]  # only keep hard tracks
    # Exclude held-out test data. The .copy() makes the column assignment
    # below write to an independent frame instead of a slice of the original
    # (avoids SettingWithCopyWarning).
    df = df[df.cell_num.isin(config.DSETS[dset]['train'])].copy()

    # Downsample every track to the same length; iterating the column
    # directly avoids building a row Series per track via df.iloc[i].
    df['X_same_length'] = [features.downsample(x, length) for x in df['X']]

    # regression response
    df = train_reg.add_sig_mean(df)

    # remove extraneous feats and rows with missing values
    df = df[feat_names + meta].dropna()
    dfs.append(df)

# Merge all dsets by stacking rows. A column-wise .merge would instead
# inner-join on the shared columns.
df_full = pd.concat(dfs)
# Downstream cells train on `df`, so point it at the merged frame.
df = df_full


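## Train neural networks

Train each architecture (`nn_lstm`, `fcnn`, `nn_cnn`) to regress `Y_sig_mean` from the prepared features. Each model's results are written to its own `.pkl` file under `out_dir`.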

In [36]:
# Name of this training run; its models and results are written under
# DIR_RESULTS/<run_name>.
run_name = 'dec10_deep'
out_dir = '{}/{}'.format(DIR_RESULTS, run_name)

In [37]:
os.makedirs(out_dir, exist_ok=True)

# Training configuration shared by every architecture.
outcome_def = 'Y_sig_mean'
num_epochs = 100
num_hidden = 20
# Architectures to compare ('nn_attention' was also listed as an option).
model_types = ['nn_lstm', 'fcnn', 'nn_cnn']

for model_type in model_types:
    # `dset` is the last dataset name left over from the data-preparation loop.
    out_name = oj(out_dir, f'{dset}_{outcome_def}_{model_type}.pkl')
    train_reg.train_reg(
        df,
        feat_names=feat_names,
        model_type=model_type,
        outcome_def=outcome_def,
        out_name=out_name,
        fcnn_hidden_neurons=num_hidden,
        fcnn_epochs=num_epochs,
    )

100%|██████████| 100/100 [00:37<00:00,  2.75it/s]


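# Analyze results

Load the saved results, drop the columns whose names contain `std` or `_f`, and sort the models by $R^2$ (best first). In the run below, `nn_cnn` has the highest $R^2$ (0.271). `nn_lstm` has a negative $R^2$ (-1.11), meaning it does worse than always predicting the mean.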

In [39]:
results = train_reg.load_results(out_dir)
# Keep only the summary metric columns: drop any column whose name contains
# 'std' or '_f' (presumably spread / per-feature entries -- TODO confirm).
keep_cols = [col for col in results if 'std' not in col and '_f' not in col]
# Best model (highest r2) first.
r = results[keep_cols].sort_values(by=['r2'], ascending=False)
r

{'r2': [0.22065207252390961, 0.2563547629913744, 0.2880939093918967, 0.27601263002905874, 0.2537535676394833, 0.3583728130619288, 0.24911720548590122, 0.31005096646270314], 'pearsonr': [0.5935782428833882, 0.5775870631014847, 0.5602878898534248, 0.557773593167131, 0.5274713289435491, 0.600967953128585, 0.6345618685257912, 0.5626297095474697]}
dict_keys(['r2', 'pearsonr'])
{'r2': [-1.7804643716601247, -1.2376744635106083, -1.374558865574127, -0.733147571663531, -1.0138287542827298, -0.6223858838549365, -0.6202399539165495, -0.9802613520405998], 'pearsonr': [0.3174476193625304, 0.376058913381081, 0.33957780906042495, 0.36184644668071425, 0.3295159842852877, 0.40571672197074904, 0.4008338563746303, 0.3442416390508416]}
dict_keys(['r2', 'pearsonr'])
{'r2': [-1.7746432990748389, 0.22588681400214916, 0.2718245622728861, 0.25046333440540247, 0.2486752821555649, 0.36035165809677605, 0.1670131832803008, 0.30054097815271763], 'pearsonr': [0.502593821195028, 0.5800764476672919, 0.5566928944684711

Unnamed: 0_level_0,cv_accuracy_by_cell,pearsonr,r2
model_type,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
clath_aux+gak_a7d2_new_Y_sig_mean_nn_cnn,"[0.22065207252390961, 0.2563547629913744, 0.28...",0.575,0.271
clath_aux+gak_a7d2_new_Y_sig_mean_fcnn,"[-1.7746432990748389, 0.22588681400214916, 0.2...",0.557,-0.072
clath_aux+gak_a7d2_new_Y_sig_mean_nn_lstm,"[-1.7804643716601247, -1.2376744635106083, -1....",0.356,-1.11
