# Visible Neural Network

In [1]:
# Data ----
from dataG2F.core import get_data
from dataG2F.qol  import ensure_dir_path_exists

# Data Utilities ----
import numpy  as np
import pandas as pd

# Model Building  ----
## General ====
import torch
from   torch import nn
import torch.nn.functional as F
from   torch.utils.data import Dataset
from   torch.utils.data import DataLoader

from vnnpaper.zma import \
    BigDataset,    \
    plDNN_general, \
    mask_parents,  \
    vnn_factory_1, \
    vnn_factory_2, \
    vnn_factory_3

# Hyperparameter Tuning ----
import os # needed for checking history (saved by lightning) 

## Logging with Pytorch Lightning ====
import lightning.pytorch as pl
from   lightning.pytorch.loggers import CSVLogger # used to save the history of each trial (used by ax)

## Adaptive Experimentation Platform ====
from ax.service.ax_client import AxClient, ObjectiveProperties
# from ax.utils.notebook.plotting import init_notebook_plotting, render

In [2]:
# Select torch's 'medium' float32 matmul precision mode
# (presumably trades some precision for speed — confirm against the torch docs).
torch.set_float32_matmul_precision('medium')

In [3]:
# init_notebook_plotting()

## Setup

In [4]:
# Directory holding this notebook's artifacts (lightning logs, saved models).
# Its last non-empty path component is reused below as the experiment name.
cache_path = '../nbs_artifacts/zma_g2f_individual_eres_1mods/'

In [5]:
## Settings ====
# run_hyps = 32 
# run_hyps_force = False # should we run more trials even if the target number has been reached?
# max_hyps = 64

# Run settings: unpacked below into `batch_size` and `max_epoch`.
params_run = {
    'batch_size': 256,  # batch size given to the training and validation DataLoaders
    'max_epoch' : 16, # passed to pl.Trainer(max_epochs=...); 256 is the commented-out alternative
}

# data settings
# Keys read by the data-preparation code below:
#   'y_var'           - list of response columns in `phno` (the code uses len(y_var) and y_var[i], so keep it a list)
#   'y_resid'         - compared against the *string* 'None'; any other value enables the residual step
#   'y_resid_strat'   - which residual strategy to apply ('naive_mean' or 'filter_mean' are the ones implemented below)
#   'holdout_parents' - parents passed to mask_parents() to define the test split
params_data = {
    # 'y_var': 'Yield_Mg_ha',
    'y_var': [
        # Description quoted from competition data readme
        # 'Yield_Mg_ha',     # Grain yield in Mg per ha at 15.5% grain moisture, using plot area without alley (Mg/ha).
        'Pollen_DAP_days', # Number of days after planting that 50% of plants in the plot began shedding pollen.
        # 'Silk_DAP_days',   # Number of days after planting that 50% of plants in the plot had visible silks.
        # 'Plant_Height_cm', # Measured as the distance between the base of a plant and the ligule of the flag leaf (centimeter).
        # 'Ear_Height_cm',   # Measured as the distance from the ground to the primary ear bearing node (centimeter).
        # 'Grain_Moisture',  # Water content in grain at harvest (percentage).
        # 'Twt_kg_m3'        # Shelled grain test weight (kg/m3), a measure of grain density.
    ],

    'y_resid': 'Env', # None, Env, Geno  (NOTE(review): the visible code only distinguishes 'None' from everything else)
    'y_resid_strat': 'naive_mean', # None, naive_mean, filter_mean, ...
    'holdout_parents': [
        ## 2022 ====
        'LH244',
        ## 2021 ====
        'PHZ51',
        # 'PHP02',
        # 'PHK76',
        ## 2019 ====
        # 'PHT69',
        'LH195',
        ## 2017 ====
        # 'PHW52',
        # 'PHN82',
        ## 2016 ====
        # 'DK3IIH6',
        ## 2015 ====
        # 'PHB47',
        # 'LH82',
        ## 2014 ====
        # 'LH198',
        # 'LH185',
        # 'PB80',
        # 'CG102',
 ],    
}

In [6]:
# params_list = [    

In [7]:
# The experiment is named after the cache directory (its last non-empty path
# component); lightning logs are written to a "lightning" folder inside it.
exp_name = list(filter(None, cache_path.split('/')))[-1]
lightning_log_dir = f'{cache_path}lightning'

In [8]:
# A parameterization is needed for setup. These values will be overwritten by Ax if tuning is occurring.
# In this file `params` is redefined later (from the best Ax trial). It is included here to guarantee
# that other params dicts can be merged into it.
params = {
'default_out_nodes_inp'  : 1,
'default_out_nodes_edge' : 1,
'default_out_nodes_out'  : 1, #len(params_data['y_var']) if type(params_data['y_var']) == list else 1,

'default_drop_nodes_inp' : 0.0,
'default_drop_nodes_edge': 0.0,
'default_drop_nodes_out' : 0.0,

'default_reps_nodes_inp' : 1,
'default_reps_nodes_edge': 1,
'default_reps_nodes_out' : 1,

'default_decay_rate'     : 1
}

# Mirror every entry of `params` into a module-level variable of the same name.
# NOTE(review): none of these variables is read again in the visible cells — confirm whether they are still needed.
default_out_nodes_inp  = params['default_out_nodes_inp' ]
default_out_nodes_edge = params['default_out_nodes_edge'] 
default_out_nodes_out  = params['default_out_nodes_out' ]

default_drop_nodes_inp = params['default_drop_nodes_inp' ] 
default_drop_nodes_edge= params['default_drop_nodes_edge'] 
default_drop_nodes_out = params['default_drop_nodes_out' ] 

default_reps_nodes_inp = params['default_reps_nodes_inp' ]
default_reps_nodes_edge= params['default_reps_nodes_edge']
default_reps_nodes_out = params['default_reps_nodes_out' ]

default_decay_rate = params['default_decay_rate' ]

In [9]:
# Pull the frequently used run and data settings out of their dictionaries.
batch_size, max_epoch = params_run['batch_size'], params_run['max_epoch']
y_var = params_data['y_var']

In [10]:
# Saved outputs are prefixed with the cache directory's name ...
save_prefix = [piece for piece in cache_path.split('/') if piece][-1]

# ... followed by the residual strategy, unless that strategy is the string 'None'.
if params_data['y_resid_strat'] != 'None':
    save_prefix += '_' + params_data['y_resid_strat']

# Make sure the artifact directory is there before anything is written to it.
ensure_dir_path_exists(dir_path = cache_path)

In [11]:
# Index of the GPU to select (only 0 and 1 are honoured below).
use_gpu_num = 0

device = "cuda" if torch.cuda.is_available() else "cpu"
# Only pick a CUDA device when CUDA is actually available: calling
# torch.cuda.set_device on a CPU-only machine raises instead of falling back.
if device == "cuda" and use_gpu_num in [0, 1]:
    torch.cuda.set_device(use_gpu_num)
print(f"Using {device} device")

Using cuda device


## Load Data

In [12]:
# Data Prep ----
# Everything is fetched by key through dataG2F's get_data().
# `obs_geno_lookup` is rebuilt from `phno` a few cells below, after rows with missing y are dropped.
obs_geno_lookup          = get_data('obs_geno_lookup')
# `phno` is used as a pandas DataFrame with (at least) the y columns, 'Geno_Idx', 'Hybrid' and 'Env_Idx'.
phno                     = get_data('phno')
# Indexable collection of numpy arrays (each is converted with torch.from_numpy further down).
ACGT_gene_slice_list     = get_data('KEGG_slices')
# Passed unchanged to vnn_factory_1.
parsed_kegg_gene_entries = get_data('KEGG_entries')

In [13]:
# Keep only the observations that have a value for every requested y variable.
# single column version (kept for reference)
# phno = phno.loc[(phno[y_var].notna()), ].copy()
# phno = phno.reset_index().drop(columns='index')

# multicolumn: count the missing y values per row, then keep rows where that count is zero
na_array = phno.loc[:, y_var].isna().sum(axis=1).to_numpy()
mask_no_na = list(na_array == 0)

# filter, then renumber the rows 0..n-1 (the old index column is discarded)
phno = phno.loc[mask_no_na, ].copy().reset_index().drop(columns='index')

In [14]:
# update obs_geno_lookup so that it matches the filtered, renumbered `phno`

# one row per observation: its (new) row number and its genotype index
tmp = phno.reset_index().rename(columns={'index': 'Phno_Idx_new'}).loc[:, ['Phno_Idx_new', 'Geno_Idx']]
# Right-hand table: one row per distinct Geno_Idx. drop_duplicates() keeps the first occurrence and its
# original row label, which reset_index() exposes as 'Phno_Idx_Orig_new'. The merge joins on the one
# shared column, 'Geno_Idx'.
tmp = pd.merge(tmp,
          tmp.drop(columns='Phno_Idx_new').drop_duplicates().reset_index().rename(columns={'index': 'Phno_Idx_Orig_new'}))
# restore observation order so that row k describes observation k
tmp = tmp.sort_values('Phno_Idx_new').reset_index(drop=True)

# columns: Phno_Idx_new, Geno_Idx, Phno_Idx_Orig_new
obs_geno_lookup = tmp.to_numpy()

In [15]:
# make holdout sets
holdout_parents = params_data['holdout_parents']

# create a mask for parent genotype; it is summed row-wise below, so a row
# with any flagged entry lands in the test split (presumably one column per
# holdout parent — confirm in vnnpaper.zma)
mask = mask_parents(holdout_parents= holdout_parents, df= phno, col_name= 'Hybrid')

# rows with no flagged entry train the model, every other row is held out
train_mask = mask.sum(axis=1).eq(0)
test_mask  = mask.sum(axis=1).gt(0)

# row labels of each split
train_idx = train_mask[train_mask].index
test_idx  = test_mask[test_mask].index

In [16]:
# convert y to residual if needed
# Note that 'y_resid' is compared with the string 'None'; any other value runs the strategy named by 'y_resid_strat'.

if params_data['y_resid'] == 'None':
    pass
else:
    if params_data['y_resid_strat'] == 'naive_mean':
        # use only data in the training set (especially since testers will be more likely to be found across envs)
        # get environmental means, subtract from observed value
        for i in range(len(y_var)):
            tmp = phno.loc[train_idx, ]
            # mean of the i-th y variable per environment, from training rows only
            env_mean = tmp.groupby(['Env_Idx']
                        ).agg(Env_Mean = (y_var[i], 'mean')
                        ).reset_index()
            # NOTE(review): merge() defaults to an inner join and returns a fresh 0..n-1 index. If some
            # Env_Idx had no training rows, those observations would be dropped here and train_idx /
            # test_idx / obs_geno_lookup would no longer line up with `phno` — confirm every environment
            # is represented in the training split.
            tmp = phno.merge(env_mean)
            tmp.loc[:, y_var[i]] = tmp.loc[:, y_var[i]] - tmp.loc[:, 'Env_Mean']
            phno = tmp.drop(columns='Env_Mean')

    if params_data['y_resid_strat'] == 'filter_mean':
        # for adjusting to environment we could use _all_ observations but ideally we will use the same set of genotypes across all observations
        # NOTE: the default for `tmp` is evaluated once, when the function is defined; every call below passes `tmp` explicitly.
        def minimum_hybrids_for_env(tmp = phno.loc[:, ['Env', 'Year', 'Hybrid']],
                                    year = 2014):
            """Return the (Env, Hybrid) pairs kept for `year` after filtering on per-pair row counts."""
            # Within each year what hybrids are most common?
            # After groupby(...).count() the 'Year' column no longer holds a year: it holds the number of
            # rows for each (Env, Hybrid) pair. Every `tmp.Year` below is therefore a count.
            tmp = tmp.loc[(tmp.Year == year), ].groupby(['Env', 'Hybrid']).count().reset_index().sort_values('Year')

            all_envs = set(tmp.Env)
            # if we filter on the number of sites a hybrid is planted at, what is the largest number of sites we can ask for before we lose a location?
            # site counts for sets which contain all envs
            # i = largest count threshold that still keeps every environment
            i = max([i for i in list(set(tmp.Year)) if len(set(tmp.loc[(tmp.Year >= i), 'Env'])) == len(all_envs)])

            before = len(set(tmp.loc[:, 'Hybrid']))
            after  = len(set(tmp.loc[(tmp.Year >= i), 'Hybrid']))
            print(f'Reducing {year} hybrids from {before} to {after} ({round(100*after/before)}%).')
            tmp = tmp.loc[(tmp.Year >= i), ['Env', 'Hybrid']].reset_index(drop=True)
            return tmp

        for i in range(len(y_var)):
            # NOTE(review): this `tmp` is overwritten before it is read (see the merge below).
            tmp = phno.loc[:, ['Env', 'Year', 'Hybrid']]
            # The comprehension's `i` (a year) is local to the comprehension and does not clobber the loop index `i`.
            filter_hybrids = [minimum_hybrids_for_env(tmp = phno.loc[:, ['Env', 'Year', 'Hybrid']], year = i) 
                            for i in list(set(phno.Year)) ]
            # attach the full observations to the retained (Env, Hybrid) pairs
            env_mean = pd.concat(filter_hybrids).merge(phno, how = 'left')

            # environment mean of the i-th y variable over the retained observations
            env_mean = env_mean.groupby(['Env_Idx']
                            ).agg(Env_Mean = (y_var[i], 'mean')
                            ).reset_index()

            # same inner-merge caveat as in the 'naive_mean' branch
            tmp = phno.merge(env_mean)
            tmp.loc[:, y_var[i]] = tmp.loc[:, y_var[i]] - tmp.loc[:, 'Env_Mean']
            phno = tmp.drop(columns='Env_Mean')
            
            

In [17]:
# Center and scale the y values.
# No missing y may remain; the double reduction covers rows first, then the y columns.
assert 0 == phno.loc[:, y_var].isna().to_numpy().sum()

# a 2D array (observations x y variables) so that multiple ys work
y = phno.loc[:, y_var].to_numpy()
# Location and scale come from the training rows only, to prevent information leakage.
y_c = np.mean(y[train_idx], axis=0)
y_s = np.std(y[train_idx], axis=0)

y = (y - y_c) / y_s

## Fit Using VNNHelper

In [18]:
# Build the VNN helper from the KEGG entries, the placeholder `params` and the gene slices.
# NOTE(review): the second return value is later passed as `node_to_inp_num_dict` and used to index the
# slice list, so it presumably maps node names to positions in `ACGT_gene_slice_list` — confirm in vnnpaper.zma.
myvnn, new_lookup_dict = vnn_factory_1(parsed_kegg_gene_entries = parsed_kegg_gene_entries, params = params, ACGT_gene_slice_list = ACGT_gene_slice_list)

################################################################################
{'default_out_nodes_inp': 1, 'default_out_nodes_edge': 1, 'default_out_nodes_out': 1, 'default_drop_nodes_inp': 0.0, 'default_drop_nodes_edge': 0.0, 'default_drop_nodes_out': 0.0, 'default_reps_nodes_inp': 1, 'default_reps_nodes_edge': 1, 'default_reps_nodes_out': 1, 'default_decay_rate': 1}
################################################################################
Retaining 43.53%, 6067/13939 Entries
Removed node "Others"


### Calculate each node's membership in each matrix and its position within it

In [19]:
### Creating Structured Matrices for Layers
# `M_list[0].row_inp` is iterated in the next cell to choose and order the input tensors;
# the whole list is what vnn_factory_3 consumes when the model is built.
M_list = vnn_factory_2(vnn_helper = myvnn, node_to_inp_num_dict = new_lookup_dict)

### Setup Dataloader using `M_list`

In [20]:
lookup_dict = new_lookup_dict

# Load every KEGG slice and turn it into a float tensor.
vals = [torch.from_numpy(arr).to(torch.float) for arr in get_data('KEGG_slices')]

# Restrict to the tensors that will be used: follow the input order of the
# first structured matrix, flatten each slice to 4926 rows and join the
# results column-wise. (Iterating dd[0]['inp'] instead gave matching results.)
vals = torch.concat(
    [vals[lookup_dict[node]].reshape(4926, -1) for node in M_list[0].row_inp],
    axis = 1,
)

# keep the full input matrix on the GPU
vals = vals.to('cuda')

In [21]:
# Training loader: observations listed in `train_idx`, reshuffled every epoch.
# NOTE(review): here y is passed as `[:, None].squeeze()` whereas `_prep_data_to_dls` below passes
# `[:, None]` without the squeeze, so the two paths hand BigDataset differently shaped targets. These
# loaders are replaced by the `_prep_data_to_dls` call later in the notebook.
training_dataloader = DataLoader(BigDataset(
    lookups_are_filtered = False,
    lookup_obs  = torch.from_numpy(np.array(train_idx)), # row labels of the training observations
    lookup_geno = torch.from_numpy(obs_geno_lookup),
    y =           torch.from_numpy(y).to(torch.float32)[:, None].squeeze(),
    G =           vals,
    G_type = 'raw',
    send_batch_to_gpu = 'cuda:0'
    ),
    batch_size = batch_size,
    shuffle = True 
)

# Validation loader: same construction over `test_idx`, kept in a fixed order.
validation_dataloader = DataLoader(BigDataset(
    lookups_are_filtered = False,
    lookup_obs  = torch.from_numpy(np.array(test_idx)), # row labels of the held-out observations
    lookup_geno = torch.from_numpy(obs_geno_lookup),
    y =           torch.from_numpy(y).to(torch.float32)[:, None].squeeze(),
    G =           vals,
    G_type = 'raw',
    send_batch_to_gpu = 'cuda:0'
    ),
    batch_size = batch_size,
    shuffle = False 
)

## Structured Layer

In [22]:
class NeuralNetwork(nn.Module):
    """Feed an input through an ordered collection of layers, one after another."""

    def __init__(self, layer_list):
        """Register every module of `layer_list`, preserving the given order."""
        super().__init__()
        # nn.ModuleList makes the layers' parameters visible to the parent module
        self.layer_list = nn.ModuleList(layer_list)

    def forward(self, x):
        """Return the output of the last layer after chaining all layers on `x`."""
        out = x
        for layer in self.layer_list:
            out = layer(out)
        return out

## Tiny Test Study

In [23]:
def _prep_data_to_dls(
        obs_geno_lookup,
        vals,
        batch_size,
        params_data
        ):
    """Build the training and validation DataLoaders for one holdout configuration.

    Repeats the notebook's cell-by-cell preparation inside a function: reload `phno`,
    drop rows with a missing y, rebuild the observation/genotype lookup, split rows by
    holdout parents, optionally convert y to residuals, standardise y with training-row
    statistics and wrap everything in BigDataset / DataLoader.

    Args:
        obs_geno_lookup: Not read — the name is rebound from `phno` before its first use.
            Kept in the signature for the existing callers.
        vals: Passed unchanged to BigDataset as `G`.
        batch_size: Batch size for both DataLoaders.
        params_data: Dict providing 'y_var' (list of column names), 'holdout_parents',
            'y_resid' and 'y_resid_strat'.

    Returns:
        Tuple `(training_dataloader, validation_dataloader)`; the training loader
        shuffles, the validation loader does not.

    Raises:
        AssertionError: if a y value is still missing after filtering/residualisation.
    """
    y_var = params_data['y_var']

    # Data Prep ----
    phno                     = get_data('phno')
    # make sure that the given y variable is there
    # single column version
    # phno = phno.loc[(phno[y_var].notna()), ].copy()
    # phno = phno.reset_index().drop(columns='index')

    # multicolumn
    # mask based on the y variables: keep rows with zero missing y values
    na_array = phno[y_var].isna().to_numpy().sum(axis=1)
    mask_no_na = list(0 == na_array)

    phno = phno.loc[mask_no_na, ].copy()
    phno = phno.reset_index().drop(columns='index')


    # update obs_geno_lookup (this discards the value passed in as an argument)
    tmp = phno.reset_index().rename(columns={'index': 'Phno_Idx_new'}).loc[:, ['Phno_Idx_new', 'Geno_Idx']]
    # join on 'Geno_Idx' with the row label of each genotype's first occurrence ('Phno_Idx_Orig_new')
    tmp = pd.merge(tmp,
            tmp.drop(columns='Phno_Idx_new').drop_duplicates().reset_index().rename(columns={'index': 'Phno_Idx_Orig_new'}))
    tmp = tmp.sort_values('Phno_Idx_new').reset_index(drop=True)

    obs_geno_lookup = tmp.to_numpy()
    
    
    # make holdout sets
    holdout_parents = params_data['holdout_parents']

    
    # create a mask for parent genotype
    mask = mask_parents(df= phno, col_name= 'Hybrid', holdout_parents= holdout_parents)

    # rows with no flagged entry train the model; rows with at least one are held out
    train_mask = mask.sum(axis=1) == 0
    test_mask  = mask.sum(axis=1) > 0

    train_idx = train_mask.loc[train_mask].index
    test_idx  = test_mask.loc[test_mask].index


    # convert y to residual if needed ('y_resid' is compared with the string 'None')
    if params_data['y_resid'] == 'None':
        pass
    else:
        if params_data['y_resid_strat'] == 'naive_mean':
            # use only data in the training set (especially since testers will be more likely to be found across envs)
            # get environmental means, subtract from observed value
            for i in range(len(y_var)):
                tmp = phno.loc[train_idx, ]
                env_mean = tmp.groupby(['Env_Idx']
                            ).agg(Env_Mean = (y_var[i], 'mean')
                            ).reset_index()
                # NOTE(review): merge() defaults to an inner join and returns a fresh 0..n-1 index. If an
                # Env_Idx has no training rows (possible for some holdout sets) its observations are dropped
                # and train_idx / test_idx / obs_geno_lookup stop lining up with `phno` — confirm.
                tmp = phno.merge(env_mean)
                tmp.loc[:, y_var[i]] = tmp.loc[:, y_var[i]] - tmp.loc[:, 'Env_Mean']
                phno = tmp.drop(columns='Env_Mean')

        if params_data['y_resid_strat'] == 'filter_mean':
            # for adjusting to environment we could use _all_ observations but ideally we will use the same set of genotypes across all observations
            # The default for `tmp` is evaluated when this inner function is defined; calls below pass `tmp` explicitly.
            def minimum_hybrids_for_env(tmp = phno.loc[:, ['Env', 'Year', 'Hybrid']],
                                        year = 2014):
                """Return the (Env, Hybrid) pairs kept for `year` after filtering on per-pair row counts."""
                # Within each year what hybrids are most common?
                # After groupby(...).count() the 'Year' column holds the number of rows per (Env, Hybrid)
                # pair, so every `tmp.Year` below is a count rather than a year.
                tmp = tmp.loc[(tmp.Year == year), ].groupby(['Env', 'Hybrid']).count().reset_index().sort_values('Year')

                all_envs = set(tmp.Env)
                # if we filter on the number of sites a hybrid is planted at, what is the largest number of sites we can ask for before we lose a location?
                # site counts for sets which contain all envs
                # i = largest count threshold that still keeps every environment
                i = max([i for i in list(set(tmp.Year)) if len(set(tmp.loc[(tmp.Year >= i), 'Env'])) == len(all_envs)])

                before = len(set(tmp.loc[:, 'Hybrid']))
                after  = len(set(tmp.loc[(tmp.Year >= i), 'Hybrid']))
                print(f'Reducing {year} hybrids from {before} to {after} ({round(100*after/before)}%).')
                tmp = tmp.loc[(tmp.Year >= i), ['Env', 'Hybrid']].reset_index(drop=True)
                return tmp

            for i in range(len(y_var)):
                # NOTE(review): this `tmp` is overwritten before it is read.
                tmp = phno.loc[:, ['Env', 'Year', 'Hybrid']]
                # the comprehension's `i` (a year) is local to the comprehension; the loop index `i` is untouched
                filter_hybrids = [minimum_hybrids_for_env(tmp = phno.loc[:, ['Env', 'Year', 'Hybrid']], year = i) 
                                for i in list(set(phno.Year)) ]
                env_mean = pd.concat(filter_hybrids).merge(phno, how = 'left')

                env_mean = env_mean.groupby(['Env_Idx']
                                ).agg(Env_Mean = (y_var[i], 'mean')
                                ).reset_index()

                # same inner-merge caveat as in the 'naive_mean' branch
                tmp = phno.merge(env_mean)
                tmp.loc[:, y_var[i]] = tmp.loc[:, y_var[i]] - tmp.loc[:, 'Env_Mean']
                phno = tmp.drop(columns='Env_Mean')


    # center and scale the y value data
    assert 0 == phno.loc[:, y_var].isna().sum().sum() # second sum is for multiple y_vars

    y = phno.loc[:, y_var].to_numpy() # added to make multiple ys work
    # use train index to prevent information leakage
    y_c = y[train_idx].mean(axis=0)
    y_s = y[train_idx].std(axis=0)

    y = (y - y_c)/y_s



    # NOTE(review): y is passed as `[:, None]` here, whereas the cell-level loaders earlier in the notebook
    # add `.squeeze()`. `y` is already 2D (rows x y variables), so this inserts an extra axis — confirm
    # which shape BigDataset / plDNN_general expects.
    training_dataloader = DataLoader(BigDataset(
        lookups_are_filtered = False,
        lookup_obs  = torch.from_numpy(np.array(train_idx)), # row labels of the training observations
        lookup_geno = torch.from_numpy(obs_geno_lookup),
        y =           torch.from_numpy(y).to(torch.float32)[:, None],
        G =           vals,
        G_type = 'raw',
        send_batch_to_gpu = 'cuda:0'
        ),
        batch_size = batch_size,
        shuffle = True 
    )

    validation_dataloader = DataLoader(BigDataset(
        lookups_are_filtered = False,
        lookup_obs  = torch.from_numpy(np.array(test_idx)), # row labels of the held-out observations
        lookup_geno = torch.from_numpy(obs_geno_lookup),
        y =           torch.from_numpy(y).to(torch.float32)[:, None],
        G =           vals,
        G_type = 'raw',
        send_batch_to_gpu = 'cuda:0'
        ),
        batch_size = batch_size,
        shuffle = False 
    )

    return training_dataloader, validation_dataloader


# Smoke test of the helper with the notebook-level settings; this replaces the
# dataloaders that were built cell-by-cell above.
training_dataloader, validation_dataloader = _prep_data_to_dls(
        obs_geno_lookup= obs_geno_lookup,
        vals= vals,
        batch_size= batch_size,
        params_data= params_data
        )


In [24]:
def _mk_holdout_lists(
    possible_holdouts = [
        ## 2022 ====
        'LH244', # Held out in Hyps tuning
        ## 2021 ====
        'PHZ51', # Held out in Hyps tuning
        'PHP02',
        'PHK76',
        ## 2019 ====
        'PHT69',
        'LH195', # Held out in Hyps tuning
        ## 2017 ====
        'PHW52',
        'PHN82',
        ## 2016 ====
        'DK3IIH6',
        ## 2015 ====
        'PHB47',
        'LH82',
        ## 2014 ====
        'LH198',
        'LH185',
        'PB80',
        'CG102',
 ],
    n_holdouts = 3,
    np_seed = 6642,
):
    rng = np.random.default_rng(seed=np_seed)
    h = np.array(possible_holdouts)
    rng.shuffle(h)

    n_blocks = int(np.floor(len(h)/n_holdouts))
    blocks = np.repeat(np.array([i for i in range(n_blocks)]), [n_holdouts])

    return [h[blocks == i].tolist() for i in range(n_blocks)]

In [25]:
# JSON path and log dir
def read_hyp_exp_results(exp_name, exp_yvar, **kwargs):
    """Load the training histories (and, if present, the Ax client) of an experiment.

    Args:
        exp_name: Experiment name, used to build the default log path.
        exp_yvar: Response variable name, used to build the default log path.
        **kwargs: Optional `log_path`. When it is missing or None the path defaults to
            `../nbs_artifacts/{exp_name}/lightning/{exp_name}__{exp_yvar}`.

    Returns:
        `(ax_client, metrics)` when `<log_path>.json` exists, otherwise `metrics`
        alone. `metrics` is a tidy DataFrame with the columns step, epoch, split
        ('train_loss' or 'val_loss'), loss and version.

    Raises:
        FileNotFoundError: if `log_path` does not exist (from `os.listdir`).
        ValueError: if no `version_<n>/metrics.csv` file is found under `log_path`.
    """
    log_path = kwargs.get('log_path')
    if log_path is None:
        log_path = f'../nbs_artifacts/{exp_name}/lightning/{exp_name}__{exp_yvar}'
    json_path = log_path+'.json'

    # Collect the version numbers of every `version_<n>` directory that holds a
    # metrics file. Other entries are skipped; parsing them as integers would
    # either crash or point at the wrong directory.
    version_prefix = 'version_'
    version_logs = sorted(
        int(entry[len(version_prefix):])
        for entry in os.listdir(log_path)
        if entry.startswith(version_prefix)
        and entry[len(version_prefix):].isdigit()
        and os.path.exists(f'{log_path}/{entry}/metrics.csv')
    )
    if not version_logs:
        raise ValueError(f'No version_<n>/metrics.csv files found in {log_path}')

    # Produce a tidy df for each version
    def _get_tidy_metrics(metrics_path, i):
        """Read one metrics.csv and reshape it to one row per recorded loss value."""
        df = pd.read_csv(metrics_path)
        df = pd.melt(
                    df, 
                    id_vars= ['step', 'epoch'],
                    var_name= 'split',
                    value_vars= ['train_loss', 'val_loss'],
                    value_name= 'loss'
                ).dropna(
                ).reset_index(drop=True
                ).assign(version = i)
        return df

    # all training histories, ordered by version number
    metrics = pd.concat([_get_tidy_metrics(metrics_path = f'{log_path}/version_{i}/metrics.csv', 
                                           i = i) for i in version_logs])

    # The Ax experiment is only returned when its JSON dump sits next to the log directory.
    if os.path.exists(json_path):
        return (AxClient.load_from_json_file(filepath = json_path), metrics)
    else:
        return metrics


In [26]:
# get the experiment name from cache path and then swap the suffix to be the hyperparams step
hyps_dir = [e for e in cache_path.split('/') if e != ''][-1].replace('1mods', '0hyps')

# NOTE: read_hyp_exp_results returns the (ax_client, metrics) pair only when the experiment's
# .json file exists; otherwise it returns the metrics frame alone and this unpacking fails.
ax_client, metrics = read_hyp_exp_results(
    exp_name = hyps_dir,
    # the string if only one yvar, the whole list if multi
    exp_yvar = (lambda x: x[0] if len(x) == 1 else x)(y_var), 
    # Pass in log_path if we have a multi-y model (this will prevent log_path from being dynamically set)
    # If None, log_path will be built from exp_name and exp_yvar inside read_hyp_exp_results
    log_path = (lambda x: 
                None if len(x) == 1 
                else '../nbs_artifacts/'+hyps_dir+'/lightning/'+hyps_dir
                )(y_var)
)

# `params` from here on is the best trial's parameterization; it replaces the placeholder dict defined during setup.
idx, params, loss = ax_client.get_best_trial()
params

  return _class(
  return _class(
  return _class(
  return _class(
[INFO 06-26 13:44:40] ax.service.ax_client: Starting optimization with verbose logging. To disable logging, set the `verbose_logging` argument to `False`. Note that float values in the logs are rounded to 6 decimal points.


{'default_out_nodes_inp': 3,
 'default_out_nodes_edge': 25,
 'default_drop_nodes_inp': 0.10965091358870267,
 'default_drop_nodes_edge': 0.9536853883601725,
 'default_drop_nodes_out': 0.9462229676358401,
 'default_reps_nodes_inp': 1,
 'default_reps_nodes_edge': 3,
 'default_reps_nodes_out': 3,
 'default_decay_rate': 7.0,
 'default_out_nodes_out': 1}

In [1]:
# Split the candidate holdout parents into blocks of three; one model is trained per block.
holdouts = _mk_holdout_lists(
    n_holdouts = 3,
    np_seed = 6642, # NOTE(review): the 1st entry is presumably meant to match the CV split used for hyperparameter tuning — confirm
)


for holdout in holdouts:
    # copy of params_data in which only 'holdout_parents' is replaced by the current block
    tmp_params = {
                k:(params_data[k]
                if k != 'holdout_parents'
                else holdout )
                for k in params_data 
                }
    training_dataloader, validation_dataloader = _prep_data_to_dls(
            obs_geno_lookup= obs_geno_lookup,
            vals= vals,
            batch_size= batch_size,
            # Here I'm using a dictionary comprehension to make up a replacement to params_data on the fly. 
            # I'm not overwriting holdout_parents to make bugs easier to track
            params_data= tmp_params
            )

    # pulled from evaluate(parameterization)
    # Rebuild the network from scratch for every block, using the best trial's `params`.
    myvnn, new_lookup_dict = vnn_factory_1(parsed_kegg_gene_entries = parsed_kegg_gene_entries, params = params, ACGT_gene_slice_list = ACGT_gene_slice_list)
    M_list = vnn_factory_2(vnn_helper = myvnn, node_to_inp_num_dict = new_lookup_dict)
    layer_list =  vnn_factory_3(M_list = M_list)
    model = NeuralNetwork(layer_list = layer_list)

    VNN = plDNN_general(model)  
    # NOTE(review): `optimizer` is not used again in this cell — confirm whether the call is still needed.
    optimizer = VNN.configure_optimizers()
    logger = CSVLogger(lightning_log_dir, name=exp_name)
    # record both the model hyperparameters and the data settings of this run
    logger.log_hyperparams(params={
        'params': params,
        'params_data': tmp_params
    })

    trainer = pl.Trainer(max_epochs=max_epoch, logger=logger)
    trainer.fit(model=VNN, train_dataloaders=training_dataloader, val_dataloaders=validation_dataloader)


    # base name: 'vnn__<y variable or "multitarget">__<holdout parents joined by "_">'
    prefix = ''.join([
        'vnn',
        '__'+(lambda x: x[0] if len(x) == 1 else 'multitarget')(x=y_var)+'__',
        '_'.join(holdout)
        ]
        )
    # NOTE(review): `cache_path` already ends in '/', so this path contains '//' while the save below
    # uses cache_path+'models/'; presumably both resolve to the same directory.
    ensure_dir_path_exists(cache_path+'/models/')
    

    # add a pseudorep number
    # existing files for this prefix, sorted as strings
    pseudo_reps = sorted([e for e in os.listdir(cache_path+'/models/') if e[0:len(prefix)] == prefix])
    if pseudo_reps == []:
        save_name = prefix+'_rep00.pt'
    else:
        # increment replicate number (zero-padded to two digits)
        # NOTE(review): the last file is chosen by string order, which stops matching numeric order past rep99.
        rep_num = pseudo_reps[-1].split('_')[-1].replace('rep', '').replace('.pt', '')
        rep_num = (lambda x: '0'+x if len(x) == 1 else x)( str(1+int(rep_num)) )
        save_name = prefix+f'_rep{rep_num}.pt'

    # save the fitted model (the NeuralNetwork module itself, not the lightning wrapper `VNN`)
    torch.save(model, cache_path+'models/'+save_name)