# Concrete Autoencoders dMRI for PyTorch

In [1]:
import project_path # Always import this first

In [2]:
from pathlib import Path

import numpy as np
import torch

from utils.env import DATA_PATH
from utils.logger import logger, logging_tqdm

In [3]:
# Repository root: the parent of the directory the notebook is executed from.
ROOT_PATH = Path.cwd().parent

In [4]:
# Record the installed torch version alongside the run logs.
logger.info("torch version %s", torch.__version__)

[38;21m2021-07-02 17:17:27,457 - geometric-dl - INFO - torch version 1.9.0 (<ipython-input-4-a395a760577f>:1)[0m


In [5]:
# Use the GPU if available, else fall back to the CPU.
has_cuda = torch.cuda.is_available()

logger.info("Is the GPU available? %s", has_cuda)
if has_cuda:
    # current_device() initialises CUDA and raises when no GPU is usable,
    # so it must only be queried once availability has been confirmed.
    logger.info("Current device: %s", torch.cuda.current_device())
logger.info("Device count: %s", torch.cuda.device_count())

device = torch.device("cuda" if has_cuda else "cpu")
if has_cuda:
    logger.info("Using device: %s", torch.cuda.get_device_properties(device))
else:
    logger.warning("No GPU detected! Training will be extremely slow")

[38;21m2021-07-02 17:17:27,496 - geometric-dl - INFO - Is the GPU available? True (<ipython-input-5-a5029b49c980>:4)[0m
[38;21m2021-07-02 17:17:27,498 - geometric-dl - INFO - Current device: 0 (<ipython-input-5-a5029b49c980>:5)[0m
[38;21m2021-07-02 17:17:27,500 - geometric-dl - INFO - Device count: 1 (<ipython-input-5-a5029b49c980>:6)[0m
[38;21m2021-07-02 17:17:27,501 - geometric-dl - INFO - Using device: _CudaDeviceProperties(name='NVIDIA GeForce GTX 1080', major=6, minor=1, total_memory=8118MB, multi_processor_count=20) (<ipython-input-5-a5029b49c980>:10)[0m


## Experiments

In [6]:
import pickle as pk
from datetime import datetime

from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from utils.concrete import (
    ConcreteAutoencoderFeatureSelector,
    decoder_1l,
    decoder_2l,
    decoder_3l,
)
from utils.dataset import MRISelectorSubjDataset

In [7]:
# import modules to build RunBuilder and RunManager helper classes
from collections  import OrderedDict
from collections import namedtuple
from itertools import product

class RunBuilder:
    """Expands a grid of hyper-parameters into the list of individual runs."""

    @staticmethod
    def get_runs(params):
        """
        Builds one ``Run`` namedtuple per combination of hyper-parameter values.

        Parameters:
            params (Dict): maps each hyper-parameter name to an iterable of
                candidate values.

        Returns:
            List: ``Run`` namedtuples whose fields are the keys of ``params``,
                in the order ``itertools.product`` yields the combinations.
        """
        Run = namedtuple("Run", params.keys())
        return [Run(*combination) for combination in product(*params.values())]

In [8]:
def train_model(train_subject, test_subject, params):
    """
    Trains the ConcreteAutoencoderFeatureSelector for every hyper-parameter combination.

    For each run produced by ``RunBuilder.get_runs(params)`` a selector is fitted
    on the training subjects; the model state dict and the selected feature
    indices are then written to ``<ROOT_PATH>/runs/models``. All runs share one
    TensorBoard writer, which is closed when the function exits (also on error
    or interruption).

    Parameters:
        train_subject (List): subjects to train on
        test_subject (List): subjects to test on. Validation is currently
            disabled (``fit`` is called with ``val_X=None``), so only the first
            entry is used, as a label in the output file names.
        params (Dict): model parameters to grid search on. Each value is a list
            of candidates; the keys read here are ``lr``, ``batch_size``,
            ``num_epochs``, ``n_features``, ``decoder`` and ``exclude``.
    """
    # Input size before any features are excluded.
    n_input_features = 1344

    # Locations that do not depend on the run, resolved once.
    data_dir = Path(ROOT_PATH, "data")
    dataf = "data_.hdf5"
    headerf = "header_.csv"
    models_dir = Path(ROOT_PATH, "runs", "models")
    # Create the output directory up front so a long training run cannot fail
    # at the very end because the directory is missing.
    models_dir.mkdir(parents=True, exist_ok=True)

    strftime = "%Y%m%d%H%M%S"
    writer = SummaryWriter(
        log_dir=Path(ROOT_PATH, "runs", datetime.now().strftime(strftime))
    )

    try:
        torch.manual_seed(14)

        for run in RunBuilder.get_runs(params):
            logger.info("running: %s", run)

            now = datetime.now()
            logger.info(run.exclude)
            exclude_str = "-".join(map(str, run.exclude))
            model_info_template_str = f"{now:%Y%m%d%H%M%S}_lr={run.lr}_batch_size={run.batch_size}_num_epochs={run.num_epochs}_n_features={run.n_features}_decoder={run.decoder.__name__}_test={test_subject[0]}_exclude={exclude_str}"

            checkpoint_path = str(models_dir / f"{model_info_template_str}_runtime.h5")
            # NOTE(review): the checkpoint monitors "val_loss" while fit() below
            # receives val_X=None -- confirm that metric is still reported.
            monitor_callback = ModelCheckpoint(
                checkpoint_path, monitor="val_loss", verbose=True
            )

            subj_list_train = np.array(train_subject)
            train_set = MRISelectorSubjDataset(
                data_dir, dataf, headerf, subj_list_train, run.exclude
            )
            train_gen = DataLoader(
                train_set,
                batch_size=run.batch_size,
                shuffle=True,
                num_workers=0,
                pin_memory=True,
                drop_last=True,
            )

            # Validation is disabled for now. To enable it, build a second
            # MRISelectorSubjDataset / DataLoader from np.array(test_subject)
            # (shuffle=False) and pass it to fit() as val_X.

            # False for a first training run; per the original notes, set to
            # True to continue training.
            checkpt = False

            selector = ConcreteAutoencoderFeatureSelector(
                K=run.n_features,
                decoder=run.decoder,
                device=device,
                num_features=run.n_features,
                num_epochs=run.num_epochs,
                learning_rate=run.lr,
                start_temp=10,
                min_temp=0.1,
                tryout_limit=1,
                # Excluded features are removed from the input.
                input_dim=n_input_features - len(run.exclude),
                checkpt=checkpt,
                callback=monitor_callback,
                writer=writer,
                path=ROOT_PATH,
            )

            selector.fit(X=train_gen, val_X=None)

            model = selector.get_params()
            torch.save(
                model.state_dict(),
                models_dir / f"{model_info_template_str}_params.pt",
            )

            # Move the indices to the CPU before handing them to numpy.
            indices = selector.get_indices().to("cpu")
            logger.info(np.sort(indices))
            np.savetxt(
                models_dir / f"{model_info_template_str}.txt",
                np.array(indices, dtype=int),
                fmt="%d",
            )
    finally:
        # Flush and release the event file even when a run fails or is
        # interrupted from the keyboard.
        writer.close()

### Model training

We use a learning rate of 0.001, a batch size of 256, and 2000 epochs. 2000 epochs is likely not enough to reach a high mean max of probabilities, but training would otherwise take too long. Our input size is 1344, so for the largest latent space we take half of that (672) and keep halving for five more latent space sizes, down to 21. Lastly, we have three decoders of varying complexity, `decoder_1l` being the least complex and `decoder_3l` the most complex.

In [10]:
# Grid search: every latent space size is combined with every decoder.
train_model(
    train_subject=[11, 12, 13, 14],
    test_subject=[15],
    params=OrderedDict(
        [
            ("lr", [0.001]),
            ("batch_size", [256]),
            ("num_epochs", [2000]),
            ("n_features", [21, 42, 84, 168, 336, 672]),  # latent space sizes
            ("decoder", [decoder_1l, decoder_2l, decoder_3l]),
            ("exclude", [[]]),  # features to exclude from training
        ]
    ),
)

[38;21m2021-07-02 17:18:17,192 - geometric-dl - INFO - running: Run(lr=0.001, batch_size=256, num_epochs=2000, n_features=21, decoder=<function decoder_1l at 0x7f134d7b0280>, exclude=[]) (<ipython-input-8-cd7596b3d763>:18)[0m
[38;21m2021-07-02 17:18:17,194 - geometric-dl - INFO - [] (<ipython-input-8-cd7596b3d763>:23)[0m
[38;21m2021-07-02 17:18:18,925 - geometric-dl - INFO - steps per epoch: 1830 (feature_selector.py:62)[0m
[38;21m2021-07-02 17:18:31,100 - geometric-dl - INFO - mean max of probabilities: 0.00084876, temperature: 9.97700064 (feature_selector.py:167)[0m


KeyboardInterrupt: 

In [9]:
# Single configuration: 336 latent features, the 3-layer decoder, and
# features 912 and 1148 excluded from training.
train_model(
    train_subject=[11, 12, 13, 14],
    test_subject=[15],
    params=OrderedDict(
        [
            ("lr", [0.001]),
            ("batch_size", [256]),
            ("num_epochs", [2000]),
            ("n_features", [336]),  # latent space sizes
            ("decoder", [decoder_3l]),
            ("exclude", [[912, 1148]]),
        ]
    ),
)

[38;21m2021-07-02 17:17:31,525 - geometric-dl - INFO - running: Run(lr=0.001, batch_size=256, num_epochs=2000, n_features=336, decoder=<function decoder_3l at 0x7f134d7b03a0>, exclude=[912, 1148]) (<ipython-input-8-cd7596b3d763>:18)[0m
[38;21m2021-07-02 17:17:31,527 - geometric-dl - INFO - [912, 1148] (<ipython-input-8-cd7596b3d763>:23)[0m
[38;21m2021-07-02 17:17:32,349 - geometric-dl - INFO - steps per epoch: 1830 (feature_selector.py:62)[0m
[38;21m2021-07-02 17:17:48,934 - geometric-dl - INFO - mean max of probabilities: 0.00083486, temperature: 9.97700064 (feature_selector.py:167)[0m
[38;21m2021-07-02 17:18:02,857 - geometric-dl - INFO - mean max of probabilities: 0.00176998, temperature: 9.95405417 (feature_selector.py:167)[0m


KeyboardInterrupt: 