# Concrete Autoencoders dMRI for PyTorch

In [1]:
import project_path # Always import this first

In [2]:
from pathlib import Path

import torch
from torch import nn
from torch import Tensor
from torch import reshape as tshape
from torch import matmul as tmat

import numpy as np

from utils.env import DATA_PATH
from utils.logger import logger, logging_tqdm

In [3]:
# Project root: the parent of the directory the notebook is executed from.
ROOT_PATH = Path.cwd().parent

In [4]:
# Record the installed PyTorch version so the logged results can be tied to a release.
logger.info('torch version %s', torch.__version__)

[38;21m2021-06-25 12:34:40,708 - geometric-dl - INFO - torch version 1.9.0 (<ipython-input-4-a395a760577f>:1)[0m


In [5]:
# Use the GPU if available, else fall back to the CPU.
has_cuda = torch.cuda.is_available()

logger.info('Is the GPU available? %s', has_cuda)
if has_cuda:
    # current_device() initialises the CUDA runtime and raises on machines
    # without a usable GPU, so it may only be queried when CUDA is available.
    logger.info('Current device: %s', torch.cuda.current_device())
# device_count() is safe everywhere: it reports 0 when no GPU is present.
logger.info('Device count: %s', torch.cuda.device_count())

device = torch.device('cuda' if has_cuda else 'cpu')
if has_cuda:
    logger.info('Using device: %s', torch.cuda.get_device_properties(device))

[38;21m2021-06-25 12:34:40,799 - geometric-dl - INFO - Is the GPU available? True (<ipython-input-5-a91727d1b032>:4)[0m
[38;21m2021-06-25 12:34:40,799 - geometric-dl - INFO - Current device: 0 (<ipython-input-5-a91727d1b032>:5)[0m
[38;21m2021-06-25 12:34:40,799 - geometric-dl - INFO - Device count: 1 (<ipython-input-5-a91727d1b032>:6)[0m
[38;21m2021-06-25 12:34:40,799 - geometric-dl - INFO - Using device: _CudaDeviceProperties(name='Quadro RTX 4000', major=7, minor=5, total_memory=8192MB, multi_processor_count=36) (<ipython-input-5-a91727d1b032>:10)[0m


## Experiments

In [8]:
import pickle as pk
from datetime import datetime

from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from utils.dataset import MRISelectorSubjDataset
from utils.concrete import (
    ConcreteAutoencoderFeatureSelector, 
    decoder_1l, 
    decoder_2l, 
    decoder_3l)

In [6]:
# import modules to build RunBuilder and RunManager helper classes
from collections  import OrderedDict
from collections import namedtuple
from itertools import product

class RunBuilder():
    """Expands a grid of hyper-parameters into the individual runs it describes."""

    @staticmethod
    def get_runs(params):
        """
        Builds every combination of the given hyper-parameter values.

        Parameters:
            params (Dict): maps each hyper-parameter name to an iterable of
                candidate values.

        Returns:
            List of ``Run`` namedtuples, one per element of the cartesian
            product of the value iterables, with one field per key of
            ``params`` (in the mapping's iteration order).
        """
        Run = namedtuple('Run', params.keys())
        return [Run(*combination) for combination in product(*params.values())]

In [9]:
def train_model(train_subject, test_subject, params):
    """
    Trains one ConcreteAutoencoderFeatureSelector per hyper-parameter combination.

    For every run produced by ``RunBuilder.get_runs(params)`` the selector is
    fitted on the training subjects and validated on the test subjects. Per run,
    the model state dict (``*_params.pt``) and the selected feature indices
    (``*.txt``) are written to ``<ROOT_PATH>/runs/models``. After the last run the
    final model is additionally saved as ``epoch=<num_epochs>_net.pth``.
    TensorBoard logs go to ``<ROOT_PATH>/runs/<timestamp>``.

    Parameters:
        train_subject (List): subjects to train on
        test_subject (List): subjects to test on; must be non-empty because its
            first entry is embedded in the output file names
        params (Dict): model parameters to grid search on. The fields ``lr``,
            ``batch_size``, ``num_epochs``, ``n_features`` and ``decoder`` are
            read from every run.
    """
    strftime = "%Y%m%d%H%M%S"
    runs_dir = Path(ROOT_PATH, 'runs')
    models_dir = runs_dir / 'models'
    # torch.save / np.savetxt do not create missing parent directories.
    models_dir.mkdir(parents=True, exist_ok=True)

    # These do not depend on the hyper-parameters, so build them only once.
    root_dir = Path(ROOT_PATH, 'data')
    dataf = 'data_.hdf5'
    headerf = 'header_.csv'
    subj_list_train = np.array(train_subject)
    subj_list_valid = np.array(test_subject)

    writer = SummaryWriter(log_dir=Path(runs_dir, datetime.now().strftime(strftime)))

    model = None
    last_run = None
    try:
        torch.manual_seed(14)

        for run in RunBuilder.get_runs(params):
            # Format the timestamp explicitly: f'{now:strftime}' would use the
            # literal text "strftime" as the format spec instead of the pattern.
            timestamp = datetime.now().strftime(strftime)
            # NOTE(review): `{run}` embeds the namedtuple repr, including the
            # decoder function's repr, in the file name -- confirm this is wanted.
            model_info_template_str = (
                f'{timestamp}_{run}_K={run.n_features}_epoch={run.num_epochs}'
                f'_test={test_subject[0]}_dec={run.decoder.__name__}')

            checkpoint_path = str(models_dir / f'{model_info_template_str}_runtime.h5')
            monitor_callback = ModelCheckpoint(checkpoint_path, monitor='val_loss', verbose=True)

            train_set = MRISelectorSubjDataset(root_dir, dataf, headerf, subj_list_train)
            train_gen = DataLoader(
                train_set,
                batch_size=run.batch_size,
                shuffle=True,
                num_workers=0,
                pin_memory=True,
                drop_last=True)

            # for the validation dataset
            valid_set = MRISelectorSubjDataset(root_dir, dataf, headerf, subj_list_valid)
            valid_gen = DataLoader(
                valid_set,
                batch_size=run.batch_size,
                shuffle=False,
                num_workers=0,
                pin_memory=True,
                drop_last=True)

            # False for a first training; set to True to continue training.
            checkpt = False

            selector = ConcreteAutoencoderFeatureSelector(
                K=run.n_features,
                decoder=run.decoder,
                device=device,
                num_features=run.n_features,
                num_epochs=run.num_epochs,
                learning_rate=run.lr,
                start_temp=10,
                min_temp=0.1,
                tryout_limit=1,
                input_dim=1344,
                checkpt=checkpt,
                callback=monitor_callback,
                writer=writer,
                path=ROOT_PATH)

            selector.fit(X=train_gen, val_X=valid_gen)

            model = selector.get_params()
            last_run = run
            torch.save(model.state_dict(), models_dir / f'{model_info_template_str}_params.pt')

            indices = selector.get_indices().to('cpu')
            logger.info(np.sort(indices))
            np.savetxt(models_dir / f'{model_info_template_str}.txt', np.array(indices, dtype=int), fmt='%d')
    finally:
        # Flush and release the TensorBoard event file even if training is
        # interrupted or fails.
        writer.close()

    # Save the model of the last run. Skipped when the grid produced no runs;
    # the epoch count comes from that run (the bare name `num_epochs` is undefined).
    if model is not None:
        torch.save(model.state_dict(), models_dir / f'epoch={last_run.num_epochs}_net.pth')

### Model training

We use a learning rate of 0.001, a batch size of 256 and 2000 epochs. 2000 epochs is likely not enough to reach a high mean max of probabilities, but training would otherwise take too long. Our input size is 1344, so for the largest latent space we take half of that (672) and keep halving for five more latent space sizes, down to 21. Lastly, we have three decoders of increasing complexity, `decoder_1l` being the least complex and `decoder_3l` the most complex.

In [10]:
# Grid search: train on subjects 11-14, validate on subject 15.
train_model(
    [11, 12, 13, 14],
    [15],
    OrderedDict([
        ('lr', [0.001]),
        ('batch_size', [256]),
        ('num_epochs', [2000]),
        # latent space sizes
        ('n_features', [21, 42, 84, 168, 336, 672]),
        ('decoder', [decoder_1l, decoder_2l, decoder_3l]),
    ]),
)

[38;21m2021-06-25 12:34:48,299 - geometric-dl - INFO - steps per epoch: 1830 (feature_selector.py:44)[0m
[38;21m2021-06-25 12:35:00,906 - geometric-dl - INFO - epoch: 0/2000, loss: 470.0418, val loss: 135.8331 (feature_selector.py:139)[0m
[38;21m2021-06-25 12:35:00,908 - geometric-dl - INFO - mean max of probabilities: 0.00082948, temperature: 9.97169192 (feature_selector.py:141)[0m
[38;21m2021-06-25 12:35:10,795 - geometric-dl - INFO - epoch: 1/2000, loss: 469.0682, val loss: 135.4150 (feature_selector.py:139)[0m
[38;21m2021-06-25 12:35:10,797 - geometric-dl - INFO - mean max of probabilities: 0.00174358, temperature: 9.94346398 (feature_selector.py:141)[0m
[38;21m2021-06-25 12:35:20,555 - geometric-dl - INFO - epoch: 2/2000, loss: 468.1396, val loss: 136.4905 (feature_selector.py:139)[0m
[38;21m2021-06-25 12:35:20,563 - geometric-dl - INFO - mean max of probabilities: 0.00469668, temperature: 9.91531595 (feature_selector.py:141)[0m
[38;21m2021-06-25 12:35:30,435 - geom

KeyboardInterrupt: 