# Concrete Autoencoders dMRI for PyTorch

In [1]:
import project_path # Always import this first

In [2]:
from pathlib import Path

import torch
from torch import nn
from torch import Tensor
from torch import reshape as tshape
from torch import matmul as tmat

import numpy as np

from utils.env import DATA_PATH
from utils.logger import logger, logging_tqdm

In [3]:
ROOT_PATH = Path().cwd().parent

In [4]:
logger.info('torch version %s', torch.__version__)

[38;21m2021-06-11 19:45:08,195 - geometric-dl - INFO - torch version 1.8.1 (<ipython-input-4-a395a760577f>:1)[0m


In [5]:
logger.info('Current device: %s', torch.cuda.current_device())
logger.info('Device count: %s', torch.cuda.device_count())
logger.info('Is the GPU available? %s', torch.cuda.is_available())

# # use gpu if available, else cpu
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info('Using device: %s', torch.cuda.get_device_properties(device))

[38;21m2021-06-11 19:45:08,243 - geometric-dl - INFO - Current device: 0 (<ipython-input-5-bfdeb86c5565>:1)[0m
[38;21m2021-06-11 19:45:08,244 - geometric-dl - INFO - Device count: 1 (<ipython-input-5-bfdeb86c5565>:2)[0m
[38;21m2021-06-11 19:45:08,245 - geometric-dl - INFO - Is the GPU available? True (<ipython-input-5-bfdeb86c5565>:3)[0m
[38;21m2021-06-11 19:45:08,247 - geometric-dl - INFO - Using device: _CudaDeviceProperties(name='NVIDIA GeForce GTX 1080', major=6, minor=1, total_memory=8118MB, multi_processor_count=20) (<ipython-input-5-bfdeb86c5565>:7)[0m


## MUDI data

In [12]:
# packages related to data reading
import pandas as pd
import os
import h5py

# pytorch
from torch.utils.data import Dataset, DataLoader

In [13]:
class MRISelectorSubjDataset(Dataset):
    """MRI dataset to select features from."""
    
    # pytorch
    def __init__(self, root_dir, dataf, headerf, subj_list):
        """
        batch_size & shuffle are defined with 'DataLoader' in pytorch 

        Args:
            root_dir (string): Directory with the .csv files
            data (string): Data .csv file
            header (string): Header .csv file
            subj_list (list): list of all the subjects to include
        """
        
        self.root_dir = root_dir
        self.dataf = dataf
        self.headerf = headerf
        self.subj_list = subj_list
        
        # load the header
        subj = self.subj_list[0]
        self.header = pd.read_csv(os.path.join(self.root_dir, self.headerf), index_col=0).to_numpy()
        self.ind = self.header[np.isin(self.header[:,1],self.subj_list),0]
        
        self.indexes = np.arange(len(self.ind))
        
    def __len__(self):
        """Denotes the total number of samples"""
        return len(self.ind)
    
    def __getitem__(self, index):
        """Generates one sample of data"""
        indexes = self.indexes[index]
        
        # Find list of IDs
        #list_IDs_temp = [self.ind[k] for k in indexes]
        list_IDs_temp = self.ind[indexes]
        
        h5f = h5py.File(os.path.join(self.root_dir, self.dataf), 'r')
        X = h5f.get('data1')
        X = X[list_IDs_temp,:]
        
        return X

## Concrete Autoencoder

In [14]:
import math
import pickle as pk

import torch.nn.functional as F
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.callbacks.base import Callback
from torch.autograd import Variable
from torch.utils.tensorboard import SummaryWriter

In [15]:
class ConcreteSelect(nn.Module):
    
    def __init__(self, output_dim, input_shape, n_features = 500, start_temp = 10.0, min_temp = 0.1, alpha = 0.99999, **kwargs):
        super(ConcreteSelect, self).__init__(**kwargs)
        # encoder
        self.output_dim = output_dim
        self.input_shape = input_shape # the input layer has output (None,N_params_in). In this case, probably equal to input_dim
        self.start_temp = start_temp
        #self.min_temp = K.constant(min_temp)
        self.min_temp = nn.init.constant_(Tensor(np.zeros(1)),min_temp).to(device)
        #self.alpha = K.constant(alpha)
        self.alpha = nn.init.constant_(Tensor(np.zeros(1)),alpha).to(device)
        #self.name = name
              
        # equivalent to build in Keras
        self.temp = Variable(Tensor([self.start_temp]), requires_grad = False).to(device)
        tensor_logits = nn.init.xavier_normal_(torch.empty(self.output_dim,self.input_shape)).to(device)
        self.logits = nn.Parameter(tensor_logits, requires_grad = True).to(device)

        # for the decoder, we define three different Linear/dense layers and the activation function
        self.dense800 = nn.Linear(n_features,800)
        #self.dense800 = nn.Linear(500,800) # the example for the standard 500 features value
        self.dense1000 = nn.Linear(800,1000)
        self.dense1344 = nn.Linear(1000,1344)
        self.act = nn.LeakyReLU(0.2)
        
    # equivalent to call in Keras -> encoder, the concrete layer itself   
    def encoder(self, X, training = None):
        
        uniform = torch.rand(self.logits.size()).to(device)
        gumbel = -torch.log(-torch.log(uniform)).to(device)
        self.temp = torch.maximum(self.min_temp, self.temp * self.alpha).to(device)
        #print('temperature {}'.format(self.temp))
        #noisy_logits = (self.logits + gumbel.to(device)) / self.temp
        noisy_logits = ((self.logits + gumbel) / self.temp).to(device)
        samples = F.softmax(noisy_logits, dim = 1)
                
        #numClasses = self.logits.size()[1]
        dim_argmax = len(self.logits.size())-1
        discrete_logits = F.one_hot(torch.argmax(self.logits.to(device),dim_argmax),num_classes = self.logits.size()[1])
        
        # probably unnecessary
        if training is None:
            training = self.training
        
        if self.training:
            self.selections = samples
        else:
            self.selections = discrete_logits
        
        #Y = torch.dot(X,torch.transpose(self.selections, 0, 1)) 
        # dot is not exactly equal to a dot product, it could be a matrix product in keras 
        Y = torch.matmul(X,torch.transpose(self.selections.float(), 0, 1))
        return Y
    
    # decoder: we suppose the two-layers scheme. In keras this is defined outside
    def decoder(self,x):
        #x.to("cpu")
        x = self.act(self.dense800(x))
        x = self.act(self.dense1000(x))
        x = self.dense1344(x)
        
        return x
    
    def forward(self, X, training = None):
        y = self.encoder(X) # selected features
        x = self.decoder(y) # reconstructed signals

        return x, y

In [16]:
class StopperCallback(EarlyStopping):
    
    def __init__(self, mean_max_target = 0.998):#, writer=None):
        self.mean_max_target = mean_max_target
        #self.writer = writer
        #super(StopperCallback, self).__init__(monitor = '', patience = float('inf'), verbose = 1, mode = 'max')#, baseline = self.mean_max_target)
        super(StopperCallback, self).__init__(monitor = '', patience = float('inf'), verbose = True, mode = 'max')

In [55]:
class ConcreteAutoencoderFeatureSelector():
    
    #def __init__(self, K, output_function, num_epochs = 100, learning_rate = 0.001, start_temp = 10.0, min_temp = 0.1, tryout_limit = 5, input_dim = 1344, callback=None, writer=None): #batch_size = None, 
    def __init__(self, K, num_features = 500, num_epochs = 100, learning_rate = 0.001, start_temp = 10.0, min_temp = 0.1, tryout_limit = 5, input_dim = 1344, checkpt=True, callback=None, writer=None, path = ''):#, losstrain=None, lossval=None): #batch_size = None, 
        self.K = K # equivalent to output_dim
        # self.output_function = output_function # this function is now included in the ConcreteSelect class
        # but now we have to define the number of features to be extracted from the encoder
        self.num_features = num_features
        self.num_epochs = num_epochs
#         self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.start_temp = start_temp
        self.min_temp = min_temp
        self.tryout_limit = tryout_limit
        self.input_dim = input_dim
        self.checkpt = checkpt
        self.callback = callback
        self.writer = writer
        self.path = path #str(Path(ROOT_PATH, 'runs', 'models'))
        #self.losstrain = losstrain
        #self.lossval = lossval
        
    def fit(self, X, val_X=None):
#         if self.batch_size is None:
#             self.batch_size = max(len(X) // 256, 16)
        
        num_epochs = self.num_epochs
        steps_per_epoch = X.__len__()#(len(X) + self.batch_size - 1) // self.batch_size
        logger.info("steps per epoch: %s", steps_per_epoch)
        writer = self.writer
        #losses,losses_val=[],[]
        
        for i in range(self.tryout_limit):
            
            alpha = math.exp(math.log(self.min_temp / self.start_temp) / (num_epochs * steps_per_epoch))
            
            # we apply the model
            self.model = ConcreteSelect(self.K, self.input_dim, self.num_features, self.start_temp, self.min_temp, alpha).cuda()
            
            # we define the loss and the optimizer functions
            criterion = nn.MSELoss().cuda()
            optimizer = torch.optim.Adam(self.model.parameters(),lr=self.learning_rate) 
            
            if self.checkpt==True:
                checkpoint = torch.load(self.path)
                self.model.load_state_dict(checkpoint['model_state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                epoch_check = checkpoint['epoch']
                loss = checkpoint['loss']
            self.model.train()
            
            stopper_callback = StopperCallback()#writer=self.writer)

            logger.info('%s', self.callback)
            
            for epoch in range(num_epochs):
                if self.checkpt==True:
                    if epoch < epoch_check:
                        continue
                
                value_stop = torch.mean(torch.max(F.softmax(self.model.logits, dim = 1),1).values)
                logger.info('mean max of probabilities: %s %s %s', value_stop, '- temperature', self.model.temp)
                
                if value_stop >= stopper_callback.mean_max_target:
                    break
                
                self.model.train()
                with torch.profiler.profile(
                    activities=[
                        torch.profiler.ProfilerActivity.CPU,
                        torch.profiler.ProfilerActivity.CUDA],
                    schedule=torch.profiler.schedule(
                        wait=2,
                        warmup=3,
                        active=6),
                    on_trace_ready=torch.profiler.tensorboard_trace_handler('./logs', worker_name='worker0'),
                    record_shapes=True,
                    profile_memory=True,
                    with_stack=True
                ) as p:
                    for j, signals in enumerate(X):
                        signals = signals.to(device)
                        # just to check how it's going, the next two lines can be commented or removed
                        if(j%500 == 0):
                            logger.info("iteration: %s", j)

                        # steps in pytorch:
                        # 1. Initialise gradients at the start of each batch
                        # 2. Run the forward and then the backwards pass
                        # 3. Compute the loss and update the weights

                        # Initialise gradients
                        optimizer.zero_grad()

                        outputs, selected_features = self.model(signals)
                        loss = criterion(outputs, signals) # like criterion(yhat,target) -> the target in the autoencoder is the input

                        writer.add_scalar(str(Path(ROOT_PATH, 'runs', 'scalars')), loss, epoch)

                        #print('Epoch {}: Loss = {}'.format(epoch+1, loss.item())) # just to check how it's going

                        # Backward pass
                        loss.backward()

                        # Compute the loss and update the weights
                        optimizer.step()
                        p.step()
                        
                if val_X is not None:
                    # Evaluate the model
                    self.model.eval()
                    
                    #steps_per_epoch_val = val_X.__len__()
                    for j, signals in enumerate(val_X):
                        signals = signals.to(device)
                        outputs_pred, selected_features_pred = self.model(signals)

                        loss = criterion(outputs_pred,signals)
                
                # save for checkpoint
                torch.save({'epoch': epoch,
                        'model_state_dict': self.model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'loss': loss.item(),
                        }, self.path)
                
            num_epochs *= 2
        
        self.probabilities = F.softmax(self.model.logits, dim = 1)
        self.indices = torch.argmax(self.model.logits, 1)
        
        return self
    
    def get_indices(self):
        val = torch.argmax(self.model.logits, 1)
        return val
        #return K.get_value(K.argmax(self.model.get_layer('concrete_select').logits))
    
    def get_mask(self):
        #nn.functional.one_hot(torch.argmax(self.logits),list(self.logits.size())[1], dim = )
        dim_argmax = len(self.model.logits.size())-1
        val = torch.sum(nn.functional.one_hot(torch.argmax(self.model.logits,dim_argmax),self.model.logits.size()[1]))
        return val
        #return K.get_value(K.sum(K.one_hot(K.argmax(self.model.get_layer('concrete_select').logits), self.model.get_layer('concrete_select').logits.shape[1]), axis = 0))
    

    def get_support(self, indices = False):
        return self.get_indices() if indices else self.get_mask()
    
    def get_params(self):
        return self.model
        #return self.output_function(self.concrete_select)

In [18]:
# import modules to build RunBuilder and RunManager helper classes
from collections  import OrderedDict
from collections import namedtuple
from itertools import product

# Read in the hyper-parameters and return a Run namedtuple containing all the 
# combinations of hyper-parameters
class RunBuilder():
  @staticmethod
  def get_runs(params):

    Run = namedtuple('Run', params.keys())

    runs = []
    for v in product(*params.values()):
      runs.append(Run(*v))
    
    return runs

In [19]:
# put all hyper params into a OrderedDict, easily expandable
params = OrderedDict(
    lr = [.001],
    batch_size = [256]
#     batch_size = [64]
)

## Experiment 3: 2 layers

In [48]:
n_means = 50
num_epochs = 1
#dec = decoder_2l
#dec = mudi_net(n_meas)
decstr = 'l2'

In [49]:
testsubj = 15
testsubjstr = '15'

In [56]:
from torch.utils.tensorboard import SummaryWriter
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.nn.utils.rnn import pad_sequence
writer = SummaryWriter()

torch.manual_seed(14)

for run in RunBuilder.get_runs(params):
    model_info_template_str = f'{run}_K={n_means}_epoch={num_epochs}_test={testsubjstr}_dec={decstr}'

    checkpoint_path = str(Path(ROOT_PATH, 'runs', 'models', f'{model_info_template_str}_runtime.h5'))
    monitor_callback = ModelCheckpoint(checkpoint_path, monitor='val_loss', verbose=True)
    
    root_dir = Path(ROOT_PATH, 'data')
    dataf = 'data_.hdf5'
    headerf = 'header_.csv'
    subj_list_train = np.array([11, 12, 13, 14])
    subj_list_valid = np.array([15])
    
    train_set = MRISelectorSubjDataset(root_dir,dataf,headerf,subj_list_train)
    train_gen = DataLoader(
        train_set, 
        batch_size = run.batch_size, 
        shuffle = True, 
        num_workers = 15, 
        pin_memory=False, 
        drop_last=True)
    
    # for the validation dataset
    valid_set = MRISelectorSubjDataset(root_dir,dataf,headerf,subj_list_valid)
    valid_gen = DataLoader(
        valid_set, 
        batch_size = run.batch_size, 
        shuffle = False, 
        num_workers = 15,
        pin_memory=False, 
        drop_last=True)

    path = str(Path(ROOT_PATH, 'runs', 'models', 'check15', 'model.pt'))
    # 1st time
    checkpt = False
    # Continue training
    # checkpt = True
    # temp = Tensor([10]) # check last value if necessary
    
    selector = ConcreteAutoencoderFeatureSelector(
        K=n_means, 
        num_features=n_means, 
        num_epochs=num_epochs, 
        learning_rate=run.lr, 
        start_temp=10, 
        min_temp=0.1, 
        tryout_limit=1, 
        input_dim=1344, 
        checkpt = checkpt, 
        callback=monitor_callback, 
        writer=writer, 
        path = path)#,losstrain=losstrain,lossval=lossval)    
    
    selector.fit(X=train_gen, val_X=valid_gen)
    
    model = selector.get_params()
    
    #model.save('./runs/models/' + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_test' + testsubjstr + '_dec' + decstr + '.h5')
    torch.save(model, Path(ROOT_PATH, 'runs', 'models', f'{model_info_template_str}.pt'))
    # save only parameters
    torch.save(model.state_dict(), Path(ROOT_PATH, 'runs', 'models', f'{model_info_template_str}_params.pt'))
    
    print(np.sort(selector.get_indices().to('cpu')))
    np.savetxt(Path(ROOT_PATH, 'runs', 'models', f'{model_info_template_str}.txt'), np.array(selector.get_indices(), dtype=int), fmt='%d')
    
    #model.save('./runs/models/' + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_test' + testsubjstr + '_dec' + decstr + '.h5')
    #torch.save(model, './runs/models/' + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_test' + testsubjstr + '_dec' + decstr + '.h5')

torch.save(model.state_dict(), Path(ROOT_PATH, 'runs', 'models', f'epoch={num_epochs}_net.pth'))
model_file = open(Path(ROOT_PATH, 'runs', 'models', f'epoch={num_epochs}_net.bin'),'wb')
pk.dump(model,model_file,pk.HIGHEST_PROTOCOL)      
model_file.close() 

[38;21m2021-06-11 15:14:27,555 - geometric-dl - INFO - steps per epoch: 1830 (<ipython-input-55-2bb0c577d71e>:29)[0m
[38;21m2021-06-11 15:14:27,584 - geometric-dl - INFO - <pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint object at 0x7fc227730d00> (<ipython-input-55-2bb0c577d71e>:54)[0m
[38;21m2021-06-11 15:14:28,273 - geometric-dl - INFO - mean max of probabilities: tensor(0.0008, device='cuda:0', grad_fn=<MeanBackward0>) - temperature tensor([10.], device='cuda:0') (<ipython-input-55-2bb0c577d71e>:62)[0m
[38;21m2021-06-11 15:14:29,734 - geometric-dl - INFO - iteration: 0 (<ipython-input-55-2bb0c577d71e>:85)[0m


KeyboardInterrupt: 

In [None]:
print(np.sort(selector.get_indices()))
np.savetxt('./runs/textfiles/' + f'{run}' + 'K=' + str(n_meas) + '_epoch=' + str(num_epochs) + '_test' + testsubjstr + '_dec' + decstr + '.txt', np.array(selector.get_indices(), dtype=int), fmt='%d')

In [None]:
a = np.loadtxt('./runs/textfiles/Run(lr=0.001, batch_size=256)K=500_epoch=2000_test15_decl2.txt')
a = np.sort(a.astype(int))
print(a)

In [None]:
testsubj = 14
testsubjstr = '14'

In [None]:
from torch.utils.tensorboard import SummaryWriter
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.nn.utils.rnn import pad_sequence
writer = SummaryWriter()

"""def pad_collate(batch):
    xx = list(zip(*batch))
    xx_pad = pad_sequence(torch.as_tensor(xx), batch_first=True, padding_value=0)
    return xx_pad #, xlens"""

#torch.manual_seed(14)

for run in RunBuilder.get_runs(params):
    monitor_callback = ModelCheckpoint('./runs/models/' + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_test' + testsubjstr + '_dec' + decstr + '_runtime.h5', monitor='val_loss', verbose=True)
    
    root_dir = './MUDI/data'
    dataf = 'data_.hdf5'
    headerf = 'header_.csv'
    subj_list_train = np.array([11, 12, 13, 15])
    subj_list_valid = np.array([14])
    
    train_set = MRISelectorSubjDataset(root_dir,dataf,headerf,subj_list_train)
    train_gen = DataLoader(train_set, batch_size = run.batch_size, shuffle = True, pin_memory=False, drop_last=True)
    #train_gen = DataLoader(train_set, batch_size = run.batch_size, shuffle = True, num_workers = 4, pin_memory=False, collate_fn = pad_collate)
    #train_gen = DataLoader(train_set, batch_size = run.batch_size, shuffle = True, num_workers = 0, pin_memory=False)#, collate_fn = pad_collate)
    # for the validation dataset
    valid_set = MRISelectorSubjDataset(root_dir,dataf,headerf,subj_list_valid)
    valid_gen = DataLoader(valid_set, batch_size = run.batch_size, shuffle = False, pin_memory=False, drop_last=True)
    
    """### Allocate memory for losses
    n_batch=0   # Count how many mini-batches of size mbatch we created
    for j,signals in enumerate(train_gen):
        n_batch = n_batch+1
        signals = signals[:,:,ind_MUDI]
        print(signals.size())
    losstrain = np.zeros((num_epochs,n_batch)) + np.nan
    
    n_batch=0   # Count how many mini-batches of size mbatch we created
    for j,signals in enumerate(valid_gen):
        n_batch = n_batch+1
    lossval = np.zeros((num_epochs,n_batch)) + np.nan"""
    
    path = './runs/models/check14/model.pt'
    # 1st time
    checkpt = False
    # Continue training
    checkpt = False
    
    selector = ConcreteAutoencoderFeatureSelector(K=n_means, num_features=n_means, num_epochs=num_epochs, learning_rate=run.lr, start_temp=10.0, min_temp=0.1, 
                                                  tryout_limit=5, input_dim=1344, checkpt = checkpt, callback=monitor_callback, writer=writer, path = path)#,losstrain=losstrain,lossval=lossval)    

    #selector.fit(X=train_gen, val_X=valid_gen)
    selector.fit(X=train_gen, val_X=valid_gen)
    
    model = selector.get_params()
    
    print(np.sort(selector.get_indices()))
    np.savetxt('./runs/textfiles/' + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_test' + testsubjstr + '_dec' + decstr + '.txt', np.array(selector.get_indices(), dtype=int), fmt='%d')
    
    #model.save('./runs/models/' + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_test' + testsubjstr + '_dec' + decstr + '.h5')
    torch.save(model, './runs/models/' + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_test' + testsubjstr + '_dec' + decstr + '.pt')
    # save only parameters
    torch.save(model.state_dict(),'./runs/models/params_' + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_test' + testsubjstr + '_dec' + decstr + '.pt')
    
    torch.save(model.state_dict(), os.path.join('./runs/models/','epoch{}_net.pth'.format(num_epochs)) )
    model_file = open(os.path.join('./runs/models/','epoch{}_net.bin'.format(num_epochs)),'wb')
    pk.dump(model,model_file,pk.HIGHEST_PROTOCOL)      
    model_file.close()

In [None]:
model.save_weights('./runs/models/' + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_test' + testsubjstr + '_dec' + decstr + '.h5')

In [None]:
print(np.sort(selector.get_indices()))
np.savetxt('./runs/textfiles/' + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_test' + testsubjstr + '_dec' + decstr + '.txt', np.array(selector.get_indices(), dtype=int), fmt='%d') 

In [None]:
testsubj = 13
testsubjstr = '13'

In [None]:
from torch.utils.tensorboard import SummaryWriter
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.nn.utils.rnn import pad_sequence
writer = SummaryWriter()

for run in RunBuilder.get_runs(params):
    monitor_callback = ModelCheckpoint('./runs/models/' + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_test' + testsubjstr + '_dec' + decstr + '_runtime.h5', monitor='val_loss', verbose=True)

    root_dir = './data'
    dataf = 'data_.hdf5'
    headerf = 'header_.csv'
    subj_list_train = np.array([11, 12, 14, 15])
    subj_list_valid = np.array([13])
    
    train_set = MRISelectorSubjDataset(root_dir,dataf,headerf,subj_list_train)
    train_gen = DataLoader(train_set, batch_size = run.batch_size, shuffle = True, pin_memory=False, drop_last=True)
    #train_gen = DataLoader(train_set, batch_size = run.batch_size, shuffle = True, num_workers = 4, pin_memory=False, collate_fn = pad_collate)
    #train_gen = DataLoader(train_set, batch_size = run.batch_size, shuffle = True, num_workers = 0, pin_memory=False)#, collate_fn = pad_collate)
    # for the validation dataset
    valid_set = MRISelectorSubjDataset(root_dir,dataf,headerf,subj_list_valid)
    valid_gen = DataLoader(valid_set, batch_size = run.batch_size, shuffle = False, pin_memory=False, drop_last=True)
    
    path = './runs/models/check13/model.pt'
    # 1st time
    checkpt = False
    # Continue training
    checkpt = True
    
    selector = ConcreteAutoencoderFeatureSelector(K=n_means, num_features=n_means, num_epochs=num_epochs, learning_rate=run.lr, start_temp=10.0, min_temp=0.1, 
                                                  tryout_limit=5, input_dim=1344, checkpt = checkpt, callback=monitor_callback, writer=writer, path = path)#,losstrain=losstrain,lossval=lossval)    
    
    selector.fit(X=train_gen, val_X=valid_gen)
    
    model = selector.get_params()
    
    print(np.sort(selector.get_indices()))
    np.savetxt('./runs/textfiles/' + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_test' + testsubjstr + '_dec' + decstr + '.txt', np.array(selector.get_indices(), dtype=int), fmt='%d')
    
    #model.save_weights('./runs/models/' + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_test' + testsubjstr + '_dec' + decstr + '.h5')
    torch.save(model, './runs/models/' + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_test' + testsubjstr + '_dec' + decstr + '.pt')
    # save only parameters
    #torch.save(model.state_dict(),'./runs/models/params_' + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_test' + testsubjstr + '_dec' + decstr + '.pt')
    
torch.save(model.state_dict(), os.path.join('./runs/models/','epoch{}_net.pth'.format(num_epochs)) )
model_file = open(os.path.join('./runs/models/','epoch{}_net.bin'.format(num_epochs)),'wb')
pk.dump(model,model_file,pk.HIGHEST_PROTOCOL)      
model_file.close()

In [None]:
torch.save(model.state_dict(),'./runs/models/params_' + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_test' + testsubjstr + '_dec' + decstr + '.pt')
    
torch.save(model.state_dict(), os.path.join('./runs/models/','epoch{}_net.pth'.format(num_epochs)) )
model_file = open(os.path.join('./runs/models/','epoch{}_net.bin'.format(num_epochs)),'wb')
pk.dump(model,model_file,pk.HIGHEST_PROTOCOL)      
model_file.close()

In [None]:
print(np.sort(selector.get_indices()))
np.savetxt('./runs/textfiles/' + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_test' + testsubjstr + '_dec' + decstr + '.txt', np.array(selector.get_indices(), dtype=int), fmt='%d') 

In [None]:
testsubj = 12
testsubjstr = '12'

In [None]:
from torch.utils.tensorboard import SummaryWriter
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.nn.utils.rnn import pad_sequence
writer = SummaryWriter()

for run in RunBuilder.get_runs(params):
    monitor_callback = ModelCheckpoint('./runs/models/' + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_test' + testsubjstr + '_dec' + decstr + '_runtime.h5', monitor='val_loss', verbose=True)
    
    root_dir = './data'
    dataf = 'data_.hdf5'
    headerf = 'header_.csv'
    subj_list_train = np.array([11, 13, 14, 15])
    subj_list_valid = np.array([12])
    
    train_set = MRISelectorSubjDataset(root_dir,dataf,headerf,subj_list_train)
    train_gen = DataLoader(train_set, batch_size = run.batch_size, shuffle = True, pin_memory=False, drop_last=True)
    #train_gen = DataLoader(train_set, batch_size = run.batch_size, shuffle = True, num_workers = 4, pin_memory=False, collate_fn = pad_collate)
    #train_gen = DataLoader(train_set, batch_size = run.batch_size, shuffle = True, num_workers = 0, pin_memory=False)#, collate_fn = pad_collate)
    # for the validation dataset
    valid_set = MRISelectorSubjDataset(root_dir,dataf,headerf,subj_list_valid)
    valid_gen = DataLoader(valid_set, batch_size = run.batch_size, shuffle = False, pin_memory=False, drop_last=True)
    
    path = './runs/models/check12/model.pt'
    # 1st time
    checkpt = False
    # Continue training
    checkpt = True
    
    selector = ConcreteAutoencoderFeatureSelector(K=n_means, num_features=n_means, num_epochs=num_epochs, learning_rate=run.lr, start_temp=10.0, min_temp=0.1, 
                                                  tryout_limit=5, input_dim=1344, checkpt = checkpt, callback=monitor_callback, writer=writer, path = path)#,losstrain=losstrain,lossval=lossval)    
    
    selector.fit(X=train_gen, val_X=valid_gen)
    
    model = selector.get_params()
    
    print(np.sort(selector.get_indices()))
    np.savetxt('./runs/textfiles/' + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_test' + testsubjstr + '_dec' + decstr + '.txt', np.array(selector.get_indices(), dtype=int), fmt='%d')
    
    #model.save('./runs/models/' + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_test' + testsubjstr + '_dec' + decstr + '.h5')
    torch.save(model, './runs/models/' + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_test' + testsubjstr + '_dec' + decstr + '.pt')
    # save only parameters
    torch.save(model.state_dict(),'./runs/models/params_' + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_test' + testsubjstr + '_dec' + decstr + '.pt')
    
    torch.save(model.state_dict(), os.path.join('./runs/models/','epoch{}_net.pth'.format(num_epochs)) )
    model_file = open(os.path.join('./runs/models/','epoch{}_net.bin'.format(num_epochs)),'wb')
    pk.dump(model,model_file,pk.HIGHEST_PROTOCOL)      
    model_file.close()

In [None]:
print(np.sort(selector.get_indices()))
np.savetxt('./runs/textfiles/' + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_test' + testsubjstr + '_dec' + decstr + '.txt', np.array(selector.get_indices(), dtype=int), fmt='%d') 

In [None]:
testsubj = 11
testsubjstr = '11'

In [None]:
from torch.utils.tensorboard import SummaryWriter
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.nn.utils.rnn import pad_sequence
writer = SummaryWriter()

for run in RunBuilder.get_runs(params):
    monitor_callback = ModelCheckpoint('./runs/models/' + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_test' + testsubjstr + '_dec' + decstr + '_runtime.h5', monitor='val_loss', verbose=True)
    
    root_dir = './data'
    dataf = 'data_.hdf5'
    headerf = 'header_.csv'
    subj_list_train = np.array([12, 13, 14, 15])
    subj_list_valid = np.array([11])
    
    train_set = MRISelectorSubjDataset(root_dir,dataf,headerf,subj_list_train)
    train_gen = DataLoader(train_set, batch_size = run.batch_size, shuffle = True, pin_memory=False, drop_last=True)
    #train_gen = DataLoader(train_set, batch_size = run.batch_size, shuffle = True, num_workers = 4, pin_memory=False, collate_fn = pad_collate)
    #train_gen = DataLoader(train_set, batch_size = run.batch_size, shuffle = True, num_workers = 0, pin_memory=False)#, collate_fn = pad_collate)
    # for the validation dataset
    valid_set = MRISelectorSubjDataset(root_dir,dataf,headerf,subj_list_valid)
    valid_gen = DataLoader(valid_set, batch_size = run.batch_size, shuffle = False, pin_memory=False, drop_last=True)  
    
    path = './runs/models/check11/model.pt'
    # 1st time
    checkpt = False
    # Continue training
    checkpt = True
    
    selector = ConcreteAutoencoderFeatureSelector(K=n_means, num_features=n_means, num_epochs=num_epochs, learning_rate=run.lr, start_temp=10.0, min_temp=0.1, 
                                                  tryout_limit=5, input_dim=1344, checkpt = checkpt, callback=monitor_callback, writer=writer, path = path)#,losstrain=losstrain,lossval=lossval)    
    
    selector.fit(X=train_gen, val_X=valid_gen)
    
    model = selector.get_params()
    
    print(np.sort(selector.get_indices()))
    np.savetxt('./runs/textfiles/' + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_test' + testsubjstr + '_dec' + decstr + '.txt', np.array(selector.get_indices(), dtype=int), fmt='%d')
    
    #model.save('./runs/models/' + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_test' + testsubjstr + '_dec' + decstr + '.h5')
    torch.save(model, './runs/models/' + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_test' + testsubjstr + '_dec' + decstr + '.pt')
    # save only parameters
    torch.save(model.state_dict(),'./runs/models/params_' + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_test' + testsubjstr + '_dec' + decstr + '.pt')
    
    torch.save(model.state_dict(), os.path.join('./runs/models/','epoch{}_net.pth'.format(num_epochs)) )
    model_file = open(os.path.join('./runs/models/','epoch{}_net.bin'.format(num_epochs)),'wb')
    pk.dump(model,model_file,pk.HIGHEST_PROTOCOL)      
    model_file.close()

In [None]:
print(np.sort(selector.get_indices()))
np.savetxt('./runs/textfiles/' + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_test' + testsubjstr + '_dec' + decstr + '.txt', np.array(selector.get_indices(), dtype=int), fmt='%d') 

In [None]:
for run in RunBuilder.get_runs(params):
    for trial in range(3):
        logdir = "./runs/scalars/" + datetime.now().strftime("%Y%m%d-%H%M%S") + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_testnone_dec' + decstr + '_trial' + str(trial)

        """tensorboard_callback = keras.callbacks.TensorBoard(log_dir=logdir)
        monitor_callback = keras.callbacks.ModelCheckpoint('./runs/models/' + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_testnone_dec' + decstr + '_runtime'  + '_trial' + str(trial) + '.h5', monitor='val_loss', verbose=0, save_weights_only=True)

        trainset = MRISelectorSubjDataset(root_dir='./data', dataf='data_.hdf5', headerf ='header_.csv',
                                      subj_list=np.array([11, 12, 13, 14, 15]), batch_size=run.batch_size)"""
        
        tensorboard_callback = torch.utils.tensorboard(log_dir=logdir)
        monitor_callback = pytorch_lightning.callbacks.ModelCheckpoint('./runs/models/' + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_test' + testsubjstr + '_dec' + decstr + '_runtime.h5', monitor='val_loss', verbose=True)

        root_dir = './data'
        dataf = 'data_.hdf5'
        headerf = 'header_.csv'
        subj_list = np.array([11, 12, 13, 14, 15])

        train_set = MRISelectorSubjDataset(root_dir,dataf,headerf,subj_list)
        train_gen = DataLoader(train_set, batch_size = run.batch_size, shuffle = True)
        # for the validation dataset
        #valid_set = MRISelectorSubjDataset(root_dir,dataf,headerf,subj_list)
        #valid_gen = DataLoader(valid_set, batch_size = run.batch_size, shuffle = False)
        
        selector = ConcreteAutoencoderFeatureSelector(K=n_means, output_function=dec, num_epochs=num_epochs, learning_rate=run.lr, start_temp=10.0, min_temp=0.1, 
                                                      tryout_limit=5, input_dim=1344, callback=[tensorboard_callback, monitor_callback])

        selector.fit(X=trainset)

        model = selector.get_params()

        model.save_weights('./runs/models/' + f'{run}' + 'K=' + str(n_means) + '_epoch=' + str(num_epochs) + '_testnone_dec' + decstr + '_trial' + str(trial) + '.h5')