# Concrete Autoencoders dMRI for PyTorch

In [1]:
import project_path # Always import this first

In [2]:
from pathlib import Path

import torch
from torch import nn
from torch import Tensor
from torch import reshape as tshape
from torch import matmul as tmat

import numpy as np

from utils.env import DATA_PATH
from utils.logger import logger, logging_tqdm

In [3]:
ROOT_PATH = Path().cwd().parent

In [4]:
logger.info('torch version %s', torch.__version__)

[38;21m2021-06-24 11:08:15,424 - geometric-dl - INFO - torch version 1.9.0 (<ipython-input-4-a395a760577f>:1)[0m


In [5]:
logger.info('Current device: %s', torch.cuda.current_device())
logger.info('Device count: %s', torch.cuda.device_count())
logger.info('Is the GPU available? %s', torch.cuda.is_available())

# # use gpu if available, else cpu
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info('Using device: %s', torch.cuda.get_device_properties(device))

[38;21m2021-06-24 11:08:15,534 - geometric-dl - INFO - Current device: 0 (<ipython-input-5-bfdeb86c5565>:1)[0m
[38;21m2021-06-24 11:08:15,534 - geometric-dl - INFO - Device count: 1 (<ipython-input-5-bfdeb86c5565>:2)[0m
[38;21m2021-06-24 11:08:15,534 - geometric-dl - INFO - Is the GPU available? True (<ipython-input-5-bfdeb86c5565>:3)[0m
[38;21m2021-06-24 11:08:15,534 - geometric-dl - INFO - Using device: _CudaDeviceProperties(name='Quadro RTX 4000', major=7, minor=5, total_memory=8192MB, multi_processor_count=36) (<ipython-input-5-bfdeb86c5565>:7)[0m


## Concrete Autoencoder

In [6]:
# import modules to build RunBuilder and RunManager helper classes
from collections  import OrderedDict
from collections import namedtuple
from itertools import product

# Read in the hyper-parameters and return a Run namedtuple containing all the 
# combinations of hyper-parameters
class RunBuilder():
  @staticmethod
  def get_runs(params):

    Run = namedtuple('Run', params.keys())

    runs = []
    for v in product(*params.values()):
      runs.append(Run(*v))
    
    return runs

In [7]:
# put all hyper params into a OrderedDict, easily expandable
params = OrderedDict(
    lr = [.001],
    batch_size = [256]
#     batch_size = [64]
)

## Experiments

In [12]:
import pickle as pk
from datetime import datetime

from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from utils.dataset import MRISelectorSubjDataset
from utils.concrete import (
    ConcreteAutoencoderFeatureSelector, 
    decoder_1l, 
    decoder_2l, 
    decoder_3l)

In [22]:
def train_model(train_subject, test_subject, n_means=500, num_epochs=2000, decoder=decoder_1l):
    strftime = "%Y%m%d%H%M%S"
    writer = SummaryWriter(log_dir=Path(ROOT_PATH, 'runs', datetime.now().strftime(strftime)))

    torch.manual_seed(14)

    for run in RunBuilder.get_runs(params):
        now = datetime.now()
        model_info_template_str = f'{now:strftime}_{run}_K={n_means}_epoch={num_epochs}_test={test_subject[0]}_dec={decoder.__name__}'

        checkpoint_path = str(Path(ROOT_PATH, 'runs', 'models', f'{model_info_template_str}_runtime.h5'))
        monitor_callback = ModelCheckpoint(checkpoint_path, monitor='val_loss', verbose=True)

        root_dir = Path(ROOT_PATH, 'data')
        dataf = 'data_.hdf5'
        headerf = 'header_.csv'
        subj_list_train = np.array(train_subject)
        subj_list_valid = np.array(test_subject)

        train_set = MRISelectorSubjDataset(root_dir,dataf,headerf,subj_list_train)
        train_gen = DataLoader(
            train_set, 
            batch_size = run.batch_size, 
            shuffle = True, 
            num_workers = 0, 
            pin_memory=True, 
            drop_last=True)

        # for the validation dataset
        valid_set = MRISelectorSubjDataset(root_dir,dataf,headerf,subj_list_valid)
        valid_gen = DataLoader(
            valid_set, 
            batch_size = run.batch_size, 
            shuffle = False, 
            num_workers = 0,
            pin_memory=True, 
            drop_last=True)

        # 1st time
        checkpt = False
        # Continue training
        # checkpt = True
        # temp = Tensor([10]) # check last value if necessary

        selector = ConcreteAutoencoderFeatureSelector(
            K=n_means,
            decoder=decoder,
            device=device,
            num_features=n_means, 
            num_epochs=num_epochs, 
            learning_rate=run.lr, 
            start_temp=10, 
            min_temp=0.1, 
            tryout_limit=1, 
            input_dim=1344, 
            checkpt = checkpt, 
            callback=monitor_callback, 
            writer=writer, 
            path = ROOT_PATH)#,losstrain=losstrain,lossval=lossval)    

        selector.fit(X=train_gen, val_X=valid_gen)

        model = selector.get_params()
        torch.save(model.state_dict(), Path(ROOT_PATH, 'runs', 'models', f'{model_info_template_str}_params.pt'))

        indices = selector.get_indices().to('cpu')
        logger.info(np.sort(indices))
        np.savetxt(Path(ROOT_PATH, 'runs', 'models', f'{model_info_template_str}.txt'), np.array(indices, dtype=int), fmt='%d')

    torch.save(model.state_dict(), Path(ROOT_PATH, 'runs', 'models', f'epoch={num_epochs}_net.pth'))

In [23]:
%%time

train_model([11, 12, 13, 14], [15], num_epochs=1, decoder=decoder_2l)

[38;21m2021-06-24 11:38:03,830 - geometric-dl - INFO - steps per epoch: 1830 (feature_selector.py:44)[0m
[38;21m2021-06-24 11:38:03,931 - geometric-dl - INFO - <pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint object at 0x000001DE90B3FA90> (feature_selector.py:73)[0m
[38;21m2021-06-24 11:38:03,939 - geometric-dl - INFO - epoch: 0/1 (feature_selector.py:84)[0m
[38;21m2021-06-24 11:38:03,941 - geometric-dl - INFO - mean max of probabilities: 0.00082959, temperature: 10.00000000 (feature_selector.py:85)[0m
[38;21m2021-06-24 11:38:03,981 - geometric-dl - INFO - iteration: 0 (feature_selector.py:109)[0m
[38;21m2021-06-24 11:38:06,253 - geometric-dl - INFO - iteration: 500 (feature_selector.py:109)[0m
[38;21m2021-06-24 11:38:08,473 - geometric-dl - INFO - iteration: 1000 (feature_selector.py:109)[0m
[38;21m2021-06-24 11:38:10,683 - geometric-dl - INFO - iteration: 1500 (feature_selector.py:109)[0m
[38;21m2021-06-24 11:38:13,175 - geometric-dl - INFO - loss: 0.017

In [11]:
print(np.sort(selector.get_indices().to('cpu')))
np.savetxt('./runs/textfiles/' + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_test' + testsubjstr + '_dec' + decstr + '.txt', np.array(selector.get_indices().to('cpu'), dtype=int), fmt='%d')

NameError: name 'selector' is not defined

In [None]:
a = np.loadtxt('./runs/textfiles/Run(lr=0.001, batch_size=256)K=500_epoch=2000_test15_decl2.txt')
a = np.sort(a.astype(int))
print(a)

In [None]:
testsubj = 14
testsubjstr = '14'

In [None]:
from torch.utils.tensorboard import SummaryWriter
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.nn.utils.rnn import pad_sequence
writer = SummaryWriter()

"""def pad_collate(batch):
    xx = list(zip(*batch))
    xx_pad = pad_sequence(torch.as_tensor(xx), batch_first=True, padding_value=0)
    return xx_pad #, xlens"""

#torch.manual_seed(14)

for run in RunBuilder.get_runs(params):
    monitor_callback = ModelCheckpoint('./runs/models/' + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_test' + testsubjstr + '_dec' + decstr + '_runtime.h5', monitor='val_loss', verbose=True)
    
    root_dir = './MUDI/data'
    dataf = 'data_.hdf5'
    headerf = 'header_.csv'
    subj_list_train = np.array([11, 12, 13, 15])
    subj_list_valid = np.array([14])
    
    train_set = MRISelectorSubjDataset(root_dir,dataf,headerf,subj_list_train)
    train_gen = DataLoader(train_set, batch_size = run.batch_size, shuffle = True, pin_memory=False, drop_last=True)
    #train_gen = DataLoader(train_set, batch_size = run.batch_size, shuffle = True, num_workers = 4, pin_memory=False, collate_fn = pad_collate)
    #train_gen = DataLoader(train_set, batch_size = run.batch_size, shuffle = True, num_workers = 0, pin_memory=False)#, collate_fn = pad_collate)
    # for the validation dataset
    valid_set = MRISelectorSubjDataset(root_dir,dataf,headerf,subj_list_valid)
    valid_gen = DataLoader(valid_set, batch_size = run.batch_size, shuffle = False, pin_memory=False, drop_last=True)
    
    """### Allocate memory for losses
    n_batch=0   # Count how many mini-batches of size mbatch we created
    for j,signals in enumerate(train_gen):
        n_batch = n_batch+1
        signals = signals[:,:,ind_MUDI]
        print(signals.size())
    losstrain = np.zeros((num_epochs,n_batch)) + np.nan
    
    n_batch=0   # Count how many mini-batches of size mbatch we created
    for j,signals in enumerate(valid_gen):
        n_batch = n_batch+1
    lossval = np.zeros((num_epochs,n_batch)) + np.nan"""
    
    path = './runs/models/check14/model.pt'
    # 1st time
    checkpt = False
    # Continue training
    checkpt = False
    
    selector = ConcreteAutoencoderFeatureSelector(K=n_means, num_features=n_means, num_epochs=num_epochs, learning_rate=run.lr, start_temp=10.0, min_temp=0.1, 
                                                  tryout_limit=5, input_dim=1344, checkpt = checkpt, callback=monitor_callback, writer=writer, path = path)#,losstrain=losstrain,lossval=lossval)    

    #selector.fit(X=train_gen, val_X=valid_gen)
    selector.fit(X=train_gen, val_X=valid_gen)
    
    model = selector.get_params()
    
    print(np.sort(selector.get_indices()))
    np.savetxt('./runs/textfiles/' + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_test' + testsubjstr + '_dec' + decstr + '.txt', np.array(selector.get_indices(), dtype=int), fmt='%d')
    
    #model.save('./runs/models/' + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_test' + testsubjstr + '_dec' + decstr + '.h5')
    torch.save(model, './runs/models/' + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_test' + testsubjstr + '_dec' + decstr + '.pt')
    # save only parameters
    torch.save(model.state_dict(),'./runs/models/params_' + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_test' + testsubjstr + '_dec' + decstr + '.pt')
    
    torch.save(model.state_dict(), os.path.join('./runs/models/','epoch{}_net.pth'.format(num_epochs)) )
    model_file = open(os.path.join('./runs/models/','epoch{}_net.bin'.format(num_epochs)),'wb')
    pk.dump(model,model_file,pk.HIGHEST_PROTOCOL)      
    model_file.close()

In [None]:
model.save_weights('./runs/models/' + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_test' + testsubjstr + '_dec' + decstr + '.h5')

In [None]:
print(np.sort(selector.get_indices()))
np.savetxt('./runs/textfiles/' + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_test' + testsubjstr + '_dec' + decstr + '.txt', np.array(selector.get_indices(), dtype=int), fmt='%d') 

In [None]:
testsubj = 13
testsubjstr = '13'

In [None]:
from torch.utils.tensorboard import SummaryWriter
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.nn.utils.rnn import pad_sequence
writer = SummaryWriter()

for run in RunBuilder.get_runs(params):
    monitor_callback = ModelCheckpoint('./runs/models/' + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_test' + testsubjstr + '_dec' + decstr + '_runtime.h5', monitor='val_loss', verbose=True)

    root_dir = './data'
    dataf = 'data_.hdf5'
    headerf = 'header_.csv'
    subj_list_train = np.array([11, 12, 14, 15])
    subj_list_valid = np.array([13])
    
    train_set = MRISelectorSubjDataset(root_dir,dataf,headerf,subj_list_train)
    train_gen = DataLoader(train_set, batch_size = run.batch_size, shuffle = True, pin_memory=False, drop_last=True)
    #train_gen = DataLoader(train_set, batch_size = run.batch_size, shuffle = True, num_workers = 4, pin_memory=False, collate_fn = pad_collate)
    #train_gen = DataLoader(train_set, batch_size = run.batch_size, shuffle = True, num_workers = 0, pin_memory=False)#, collate_fn = pad_collate)
    # for the validation dataset
    valid_set = MRISelectorSubjDataset(root_dir,dataf,headerf,subj_list_valid)
    valid_gen = DataLoader(valid_set, batch_size = run.batch_size, shuffle = False, pin_memory=False, drop_last=True)
    
    path = './runs/models/check13/model.pt'
    # 1st time
    checkpt = False
    # Continue training
    checkpt = True
    
    selector = ConcreteAutoencoderFeatureSelector(K=n_means, num_features=n_means, num_epochs=num_epochs, learning_rate=run.lr, start_temp=10.0, min_temp=0.1, 
                                                  tryout_limit=5, input_dim=1344, checkpt = checkpt, callback=monitor_callback, writer=writer, path = path)#,losstrain=losstrain,lossval=lossval)    
    
    selector.fit(X=train_gen, val_X=valid_gen)
    
    model = selector.get_params()
    
    print(np.sort(selector.get_indices()))
    np.savetxt('./runs/textfiles/' + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_test' + testsubjstr + '_dec' + decstr + '.txt', np.array(selector.get_indices(), dtype=int), fmt='%d')
    
    #model.save_weights('./runs/models/' + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_test' + testsubjstr + '_dec' + decstr + '.h5')
    torch.save(model, './runs/models/' + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_test' + testsubjstr + '_dec' + decstr + '.pt')
    # save only parameters
    #torch.save(model.state_dict(),'./runs/models/params_' + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_test' + testsubjstr + '_dec' + decstr + '.pt')
    
torch.save(model.state_dict(), os.path.join('./runs/models/','epoch{}_net.pth'.format(num_epochs)) )
model_file = open(os.path.join('./runs/models/','epoch{}_net.bin'.format(num_epochs)),'wb')
pk.dump(model,model_file,pk.HIGHEST_PROTOCOL)      
model_file.close()

In [None]:
torch.save(model.state_dict(),'./runs/models/params_' + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_test' + testsubjstr + '_dec' + decstr + '.pt')
    
torch.save(model.state_dict(), os.path.join('./runs/models/','epoch{}_net.pth'.format(num_epochs)) )
model_file = open(os.path.join('./runs/models/','epoch{}_net.bin'.format(num_epochs)),'wb')
pk.dump(model,model_file,pk.HIGHEST_PROTOCOL)      
model_file.close()

In [None]:
print(np.sort(selector.get_indices()))
np.savetxt('./runs/textfiles/' + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_test' + testsubjstr + '_dec' + decstr + '.txt', np.array(selector.get_indices(), dtype=int), fmt='%d') 

In [None]:
testsubj = 12
testsubjstr = '12'

In [None]:
from torch.utils.tensorboard import SummaryWriter
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.nn.utils.rnn import pad_sequence
writer = SummaryWriter()

for run in RunBuilder.get_runs(params):
    monitor_callback = ModelCheckpoint('./runs/models/' + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_test' + testsubjstr + '_dec' + decstr + '_runtime.h5', monitor='val_loss', verbose=True)
    
    root_dir = './data'
    dataf = 'data_.hdf5'
    headerf = 'header_.csv'
    subj_list_train = np.array([11, 13, 14, 15])
    subj_list_valid = np.array([12])
    
    train_set = MRISelectorSubjDataset(root_dir,dataf,headerf,subj_list_train)
    train_gen = DataLoader(train_set, batch_size = run.batch_size, shuffle = True, pin_memory=False, drop_last=True)
    #train_gen = DataLoader(train_set, batch_size = run.batch_size, shuffle = True, num_workers = 4, pin_memory=False, collate_fn = pad_collate)
    #train_gen = DataLoader(train_set, batch_size = run.batch_size, shuffle = True, num_workers = 0, pin_memory=False)#, collate_fn = pad_collate)
    # for the validation dataset
    valid_set = MRISelectorSubjDataset(root_dir,dataf,headerf,subj_list_valid)
    valid_gen = DataLoader(valid_set, batch_size = run.batch_size, shuffle = False, pin_memory=False, drop_last=True)
    
    path = './runs/models/check12/model.pt'
    # 1st time
    checkpt = False
    # Continue training
    checkpt = True
    
    selector = ConcreteAutoencoderFeatureSelector(K=n_means, num_features=n_means, num_epochs=num_epochs, learning_rate=run.lr, start_temp=10.0, min_temp=0.1, 
                                                  tryout_limit=5, input_dim=1344, checkpt = checkpt, callback=monitor_callback, writer=writer, path = path)#,losstrain=losstrain,lossval=lossval)    
    
    selector.fit(X=train_gen, val_X=valid_gen)
    
    model = selector.get_params()
    
    print(np.sort(selector.get_indices()))
    np.savetxt('./runs/textfiles/' + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_test' + testsubjstr + '_dec' + decstr + '.txt', np.array(selector.get_indices(), dtype=int), fmt='%d')
    
    #model.save('./runs/models/' + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_test' + testsubjstr + '_dec' + decstr + '.h5')
    torch.save(model, './runs/models/' + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_test' + testsubjstr + '_dec' + decstr + '.pt')
    # save only parameters
    torch.save(model.state_dict(),'./runs/models/params_' + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_test' + testsubjstr + '_dec' + decstr + '.pt')
    
    torch.save(model.state_dict(), os.path.join('./runs/models/','epoch{}_net.pth'.format(num_epochs)) )
    model_file = open(os.path.join('./runs/models/','epoch{}_net.bin'.format(num_epochs)),'wb')
    pk.dump(model,model_file,pk.HIGHEST_PROTOCOL)      
    model_file.close()

In [None]:
print(np.sort(selector.get_indices()))
np.savetxt('./runs/textfiles/' + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_test' + testsubjstr + '_dec' + decstr + '.txt', np.array(selector.get_indices(), dtype=int), fmt='%d') 

In [None]:
testsubj = 11
testsubjstr = '11'

In [None]:
from torch.utils.tensorboard import SummaryWriter
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.nn.utils.rnn import pad_sequence
writer = SummaryWriter()

for run in RunBuilder.get_runs(params):
    monitor_callback = ModelCheckpoint('./runs/models/' + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_test' + testsubjstr + '_dec' + decstr + '_runtime.h5', monitor='val_loss', verbose=True)
    
    root_dir = './data'
    dataf = 'data_.hdf5'
    headerf = 'header_.csv'
    subj_list_train = np.array([12, 13, 14, 15])
    subj_list_valid = np.array([11])
    
    train_set = MRISelectorSubjDataset(root_dir,dataf,headerf,subj_list_train)
    train_gen = DataLoader(train_set, batch_size = run.batch_size, shuffle = True, pin_memory=False, drop_last=True)
    #train_gen = DataLoader(train_set, batch_size = run.batch_size, shuffle = True, num_workers = 4, pin_memory=False, collate_fn = pad_collate)
    #train_gen = DataLoader(train_set, batch_size = run.batch_size, shuffle = True, num_workers = 0, pin_memory=False)#, collate_fn = pad_collate)
    # for the validation dataset
    valid_set = MRISelectorSubjDataset(root_dir,dataf,headerf,subj_list_valid)
    valid_gen = DataLoader(valid_set, batch_size = run.batch_size, shuffle = False, pin_memory=False, drop_last=True)  
    
    path = './runs/models/check11/model.pt'
    # 1st time
    checkpt = False
    # Continue training
    checkpt = True
    
    selector = ConcreteAutoencoderFeatureSelector(K=n_means, num_features=n_means, num_epochs=num_epochs, learning_rate=run.lr, start_temp=10.0, min_temp=0.1, 
                                                  tryout_limit=5, input_dim=1344, checkpt = checkpt, callback=monitor_callback, writer=writer, path = path)#,losstrain=losstrain,lossval=lossval)    
    
    selector.fit(X=train_gen, val_X=valid_gen)
    
    model = selector.get_params()
    
    print(np.sort(selector.get_indices()))
    np.savetxt('./runs/textfiles/' + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_test' + testsubjstr + '_dec' + decstr + '.txt', np.array(selector.get_indices(), dtype=int), fmt='%d')
    
    #model.save('./runs/models/' + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_test' + testsubjstr + '_dec' + decstr + '.h5')
    torch.save(model, './runs/models/' + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_test' + testsubjstr + '_dec' + decstr + '.pt')
    # save only parameters
    torch.save(model.state_dict(),'./runs/models/params_' + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_test' + testsubjstr + '_dec' + decstr + '.pt')
    
    torch.save(model.state_dict(), os.path.join('./runs/models/','epoch{}_net.pth'.format(num_epochs)) )
    model_file = open(os.path.join('./runs/models/','epoch{}_net.bin'.format(num_epochs)),'wb')
    pk.dump(model,model_file,pk.HIGHEST_PROTOCOL)      
    model_file.close()

In [None]:
print(np.sort(selector.get_indices()))
np.savetxt('./runs/textfiles/' + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_test' + testsubjstr + '_dec' + decstr + '.txt', np.array(selector.get_indices(), dtype=int), fmt='%d') 

In [None]:
for run in RunBuilder.get_runs(params):
    for trial in range(3):
        logdir = "./runs/scalars/" + datetime.now().strftime("%Y%m%d-%H%M%S") + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_testnone_dec' + decstr + '_trial' + str(trial)

        """tensorboard_callback = keras.callbacks.TensorBoard(log_dir=logdir)
        monitor_callback = keras.callbacks.ModelCheckpoint('./runs/models/' + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_testnone_dec' + decstr + '_runtime'  + '_trial' + str(trial) + '.h5', monitor='val_loss', verbose=0, save_weights_only=True)

        trainset = MRISelectorSubjDataset(root_dir='./data', dataf='data_.hdf5', headerf ='header_.csv',
                                      subj_list=np.array([11, 12, 13, 14, 15]), batch_size=run.batch_size)"""
        
        tensorboard_callback = torch.utils.tensorboard(log_dir=logdir)
        monitor_callback = pytorch_lightning.callbacks.ModelCheckpoint('./runs/models/' + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_test' + testsubjstr + '_dec' + decstr + '_runtime.h5', monitor='val_loss', verbose=True)

        root_dir = './data'
        dataf = 'data_.hdf5'
        headerf = 'header_.csv'
        subj_list = np.array([11, 12, 13, 14, 15])

        train_set = MRISelectorSubjDataset(root_dir,dataf,headerf,subj_list)
        train_gen = DataLoader(train_set, batch_size = run.batch_size, shuffle = True)
        # for the validation dataset
        #valid_set = MRISelectorSubjDataset(root_dir,dataf,headerf,subj_list)
        #valid_gen = DataLoader(valid_set, batch_size = run.batch_size, shuffle = False)
        
        selector = ConcreteAutoencoderFeatureSelector(K=n_means, output_function=dec, num_epochs=num_epochs, learning_rate=run.lr, start_temp=10.0, min_temp=0.1, 
                                                      tryout_limit=5, input_dim=1344, callback=[tensorboard_callback, monitor_callback])

        selector.fit(X=trainset)

        model = selector.get_params()

        model.save_weights('./runs/models/' + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_testnone_dec' + decstr + '_trial' + str(trial) + '.h5')