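# Equivariant deep dMRI

Experiments with rotation-equivariant deep networks for diffusion MRI (dMRI): a synthetic end-to-end training example, followed by loading and masking real subject data.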

In [1]:
import logging
import sys

import torch
from torch import nn
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset

from equideepdmri.utils.q_space import Q_SamplingSchema
from equideepdmri.network.VoxelWiseSegmentationNetwork import VoxelWiseSegmentationNetwork

In [2]:
print(f'gpus: {torch.cuda.device_count()}')
# current_device()/get_device_properties() raise on a machine without CUDA,
# which would break Restart & Run All on a CPU-only kernel; guard the query.
if torch.cuda.is_available():
    print(f'active device: {torch.cuda.get_device_properties(torch.cuda.current_device())}')
else:
    print('active device: cpu (CUDA not available)')

gpus: 1
active device: _CudaDeviceProperties(name='GeForce GTX 1080', major=6, minor=1, total_memory=8118MB, multi_processor_count=20)


In [8]:
class ColorFormatter(logging.Formatter):
    """Logging formatter that wraps each record in an ANSI color chosen by its level.

    Every record is rendered as
    ``<asctime> - <logger name> - <LEVEL> - <message> (<file>:<line>)``.
    DEBUG/INFO are grey, WARNING is yellow, ERROR is red and CRITICAL is bold red.
    Levels not listed in FORMATS (custom levels) use the same template without color.
    A ``datefmt`` passed to the constructor is honoured for ``%(asctime)s``.
    """

    grey = "\x1b[38;21m"
    yellow = "\x1b[33;21m"
    red = "\x1b[31;21m"
    bold_red = "\x1b[31;1m"
    reset = "\x1b[0m"
    # Called `base_format` rather than `format`: a class attribute named `format`
    # would be silently shadowed by the `format` method defined below.
    base_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s (%(filename)s:%(lineno)d)"

    FORMATS = {
        logging.DEBUG: grey + base_format + reset,
        logging.INFO: grey + base_format + reset,
        logging.WARNING: yellow + base_format + reset,
        logging.ERROR: red + base_format + reset,
        logging.CRITICAL: bold_red + base_format + reset
    }

    def format(self, record):
        """Format `record` with the colored template for its level.

        Args:
            record: the logging.LogRecord to render.

        Returns:
            The formatted string, including ANSI color codes for standard levels.
        """
        # Fall back to the uncolored template for unknown levels. Otherwise
        # logging.Formatter(None) would print only the bare message.
        log_fmt = self.FORMATS.get(record.levelno, self.base_format)
        formatter = logging.Formatter(log_fmt, datefmt=self.datefmt)
        return formatter.format(record)

def init_logger(name, log_level):
    """Set `log_level` on logger `name` and attach a colored stream handler once.

    The handler is only added when the logger has no handlers of its own. Re-running
    the notebook cell therefore does not stack duplicate handlers.

    Args:
        name: logger name passed to logging.getLogger.
        log_level: level set on the logger (e.g. logging.DEBUG).

    Returns:
        The configured logging.Logger. Existing callers may ignore it.
    """
    logger = logging.getLogger(name)
    logger.setLevel(log_level)
    # Inspect only this logger's own handlers. hasHandlers() also walks the
    # ancestors, so a root handler (e.g. from logging.basicConfig) would stop
    # the colored handler from ever being attached.
    if not logger.handlers:
        stream = logging.StreamHandler()
        stream.setLevel(logging.DEBUG)
        stream.setFormatter(ColorFormatter())

        logger.addHandler(stream)
    return logger

logger_name = 'mudi'
# logging.getLogger returns one shared object per name, so `logger` is the same
# instance that init_logger configures on the next line.
logger = logging.getLogger(logger_name)
# DEBUG while developing; switch to something like logging.ERROR when training for real.
init_logger(logger_name, logging.DEBUG)

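### Example from *Rotation-Equivariant Deep Learning for Diffusion MRI*
See: [github.com/philip-mueller/equivariant-deep-dmri](https://github.com/philip-mueller/equivariant-deep-dmri)

The example below trains a `VoxelWiseSegmentationNetwork` on a small, randomly generated dataset. It checks that the setup works end to end before real data is used.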

In [4]:
def compute_binary_label_weights(training_dataloader: DataLoader) -> torch.Tensor:
    """Return the fraction of negative (zero) target voxels inside the brain masks.

    Iterates once over `training_dataloader`. Each batch must be a mapping with
    'target' and 'brain_mask' tensors that can be boolean-indexed together. Only
    voxels where the brain mask is non-zero are counted.

    NOTE(review): this value is passed as `pos_weight` to BCEWithLogitsLoss further
    down. PyTorch documents pos_weight as num_negative / num_positive, but this
    returns num_negative / num_total. That is always <= 1, so it down-weights the
    positive class. Left unchanged; confirm which weighting is intended.

    Returns:
        0-dim float tensor equal to 1 - (#positive voxels / #voxels in mask).

    Raises:
        ValueError: if the dataloader yields no voxels inside any brain mask.
    """
    num_P_voxels = 0
    num_total_voxels = 0
    for batch in training_dataloader:
        brain_mask = batch['brain_mask'].bool()
        target = batch['target'][brain_mask]  # 1-D: in-mask voxels only
        # Count non-zero labels without materialising the index tensor nonzero() builds.
        num_P_voxels += int((target != 0).sum())
        num_total_voxels += target.numel()

    if num_total_voxels == 0:
        raise ValueError('cannot compute label weights: no voxels inside the brain masks')
    return torch.tensor(1 - (num_P_voxels / num_total_voxels))


class RandomDMriSegmentationDataset:
    """Synthetic dMRI segmentation dataset built from random tensors.

    Every sample is a dict with keys 'sample_id' (str), 'input' (Q x Z x Y x X
    scan), 'target' (binary Z x Y x X labels) and 'brain_mask' (all ones). This is
    the layout the training loop in this notebook consumes.
    """

    def __init__(self, N, Q, num_b0, p_size: tuple, seed=None):
        """
        Args:
            N: number of samples.
            Q: number of q-space samples (input channels) per scan; must be >= num_b0.
            num_b0: number of leading q-vectors set to zero (presumably b=0 volumes).
            p_size: spatial size of every volume; must have length 3.
            seed: optional int. When given, all random tensors come from a private
                torch.Generator seeded with it. This makes the dataset reproducible
                without touching torch's global RNG. None (default) keeps the
                original behaviour of drawing from the global RNG.
        """
        self.N = N
        assert Q >= num_b0
        generator = None if seed is None else torch.Generator().manual_seed(seed)

        # Random q-vectors in [0, 1)^3; the first num_b0 are zeroed.
        q_vectors = torch.rand(Q, 3, generator=generator)
        q_vectors[:num_b0, :] = 0.0
        self.q_sampling_schema = Q_SamplingSchema(q_vectors)

        assert len(p_size) == 3
        self.scans = torch.randn(N, Q, *p_size, generator=generator)
        # Standard normal > 0.8 marks roughly 21% of voxels as positive.
        self.targets = (torch.randn(N, *p_size, generator=generator) > 0.8).float()
        self.brain_masks = torch.ones(N, *p_size)

    def __len__(self):
        return self.N

    def __getitem__(self, i):
        # Accept only a plain int index (no slices / index lists). This does not by
        # itself limit the DataLoader batch size; the training loop asserts
        # batch size 1 separately.
        assert isinstance(i, int)

        return {'sample_id': str(i), 'input': self.scans[i],
                'target': self.targets[i], 'brain_mask': self.brain_masks[i] }

In [5]:
# Seed torch's global RNG so the random dataset below, and anything else that
# later draws from the global RNG, is reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

dataset = RandomDMriSegmentationDataset(N=10, Q=8, num_b0=2, p_size=(10, 10, 10))
model = VoxelWiseSegmentationNetwork(
    q_sampling_schema_in=dataset.q_sampling_schema,
    pq_channels=[
        [7, 4]
    ],
    p_channels=[
        [20, 5],
        [10, 3],
        [5, 2],
        [1]
    ],
    pq_kernel={
        'kernel':'pq_TP',
        'p_radial_basis_type':'cosine'
    },
    p_kernel={
        'p_radial_basis_type':'cosine'
    },
    kernel_sizes=5,
    non_linearity={
        'tensor_non_lin':'gated',
        'scalar_non_lin':'swish'
    },
    q_reduction={
        'reduction':'length_weighted_average'
    }
)
print(model)

VoxelWiseSegmentationNetwork(
  (pq_layers): ModuleList(
    (0): Sequential(
      (conv): <EquivariantPQLayer (1,)->(11, 4)>
      (non_linearity): GatedBlockNonLin()
    )
  )
  (q_reduction_layer): QLengthWeightedAvgPool(
    (radial_basis): FiniteElement_RadialBasis(
      (model): FC()
    )
  )
  (p_layers): ModuleList(
    (0): Sequential(
      (conv): <EquivariantPLayer (7, 4)->(25, 5)>
      (non_linearity): GatedBlockNonLin()
    )
    (1): Sequential(
      (conv): <EquivariantPLayer (20, 5)->(13, 3)>
      (non_linearity): GatedBlockNonLin()
    )
    (2): Sequential(
      (conv): <EquivariantPLayer (10, 3)->(7, 2)>
      (non_linearity): GatedBlockNonLin()
    )
    (3): <EquivariantPLayer (5, 2)->(1,)>
  )
)


In [9]:
epochs = 3
dataloader = DataLoader(dataset=dataset, batch_size=1, shuffle=True)
pos_weight = compute_binary_label_weights(dataloader)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
optimizer = torch.optim.Adam(model.parameters(), lr=5.0e-03)

model.train()  # make the training mode explicit
for epoch in range(epochs):
    for batch in dataloader:
        sample_ids, x, target, brain_mask = batch['sample_id'], batch['input'], batch['target'], batch['brain_mask']

        assert brain_mask.size(0) == 1 and len(sample_ids) == 1 and target.size(0) == 1 and x.size(0) == 1, \
                        'Currently only batch-size 1 is supported'
        sample_ids = sample_ids[0]
        brain_mask = brain_mask.squeeze(0).bool()  # (Z x Y x X)
        target = target.squeeze(0)[brain_mask]  # (num_non_masked_voxels)
        # x keeps its batch dim because the model expects one; the output is squeezed below.

        # %-style args are formatted lazily, only if the logger emits the record.
        logger.debug('%s', x.shape)

        optimizer.zero_grad()

        predicted_scores = model(x).squeeze(0)  # (Z x Y x X)
        predicted_scores = predicted_scores[brain_mask]  # (num_non_masked_voxels)
        loss = criterion(predicted_scores, target)

        logger.info('loss %s', loss.item())

        loss.backward()
        optimizer.step()

[38;21m2021-06-10 10:40:03,503 - mudi - DEBUG - torch.Size([1, 8, 10, 10, 10]) (<ipython-input-9-2499bc7ed611>:18)[0m
[38;21m2021-06-10 10:40:03,611 - mudi - INFO - loss 2.8891196250915527 (<ipython-input-9-2499bc7ed611>:26)[0m
[38;21m2021-06-10 10:40:03,688 - mudi - DEBUG - torch.Size([1, 8, 10, 10, 10]) (<ipython-input-9-2499bc7ed611>:18)[0m
[38;21m2021-06-10 10:40:03,762 - mudi - INFO - loss 2.5472934246063232 (<ipython-input-9-2499bc7ed611>:26)[0m
[38;21m2021-06-10 10:40:03,798 - mudi - DEBUG - torch.Size([1, 8, 10, 10, 10]) (<ipython-input-9-2499bc7ed611>:18)[0m
[38;21m2021-06-10 10:40:03,867 - mudi - INFO - loss 2.2615816593170166 (<ipython-input-9-2499bc7ed611>:26)[0m
[38;21m2021-06-10 10:40:03,898 - mudi - DEBUG - torch.Size([1, 8, 10, 10, 10]) (<ipython-input-9-2499bc7ed611>:18)[0m
[38;21m2021-06-10 10:40:03,968 - mudi - INFO - loss 1.9356001615524292 (<ipython-input-9-2499bc7ed611>:26)[0m
[38;21m2021-06-10 10:40:04,003 - mudi - DEBUG - torch.Size([1, 8, 10, 1

In [8]:
import numpy as np
import pandas as pd
import os
import h5py

In [24]:
class MRISelectorSubjDataset(Dataset):
    """MRI dataset to select features from.

    Rows of the HDF5 dataset 'data1' are selected per subject. A header CSV maps
    each row index into 'data1' to a subject id.
    """

    def __init__(self, root_dir, dataf, headerf, subj_list):
        """
        Initialize the dataset

        Args:
            root_dir (string): Directory containing the data and header files
            dataf (string): Data file name (HDF5, read via h5py dataset 'data1')
            headerf (string): Header .csv file name
            subj_list (list): list of all the subjects to include

            batch_size & shuffle are defined with 'DataLoader' in pytorch
        """
        self.root_dir = root_dir
        self.dataf = dataf

        # After index_col=0: column 0 = row index into 'data1', column 1 = subject id.
        header = pd.read_csv(os.path.join(self.root_dir, headerf), index_col=0).to_numpy()
        self.ind = header[np.isin(header[:, 1], subj_list), 0]

        # Identity mapping from dataset position to position in self.ind.
        self.indexes = np.arange(len(self.ind))

    def __len__(self):
        """Denotes the total number of samples"""
        return len(self.ind)

    def __getitem__(self, index):
        """Load the data row(s) for sample `index` from the HDF5 file.

        Returns:
            numpy array read from dataset 'data1' at row(s) self.ind[index]
            (a 1-D row when `index` is a scalar).
        """
        logger.debug('loading data for index: %s', index)
        row_ids = self.ind[self.indexes[index]]

        # The file is opened per call, and the with-block closes it again.
        # Previously a handle leaked on every item. Indexing an h5py dataset
        # reads the slice into memory, so X stays valid after closing.
        with h5py.File(os.path.join(self.root_dir, self.dataf), 'r') as h5f:
            X = h5f['data1'][row_ids, :]

        logger.debug('Data for index %s: %s%s', index, X, X.shape)

        return X

In [25]:
root_dir = './data'
dataf = 'data_.hdf5'
headerf = 'header_.csv'
subj_list_train = np.array([11, 12, 13, 14])
subj_list_valid = np.array([15])  # held-out subject; not used in this cell

train_set = MRISelectorSubjDataset(root_dir, dataf, headerf, subj_list_train)
# Use subscription (the Dataset protocol) rather than calling __getitem__ directly.
train_set[1]

[38;21m2021-06-08 15:13:45,221 - mudi - DEBUG - loading data for index: 1 (<ipython-input-24-9f2c8dedc506>:31)[0m
[38;21m2021-06-08 15:13:45,223 - mudi - DEBUG - Data for index 1: [0.01232295 0.00841025 0.01193968 ... 0.01052783 0.00930519 0.01574411](1344,) (<ipython-input-24-9f2c8dedc506>:42)[0m


array([0.01232295, 0.00841025, 0.01193968, ..., 0.01052783, 0.00930519,
       0.01574411], dtype=float32)

In [11]:
import numpy as np
from nilearn.masking import apply_mask
import os

img_file = 'MB_Re_t_moco_registered_applytopup.nii.gz'
msk_file = 'brain_mask.nii.gz'

direc11 = './data/cdmri0011/'

# apply_mask returns a 2-D array, (1344, 108300) for this subject per the shape
# check below. Transposing it gives one row per masked voxel.
subject_img = os.path.join(direc11, img_file)
subject_mask = os.path.join(direc11, msk_file)
masked_data11 = apply_mask(imgs=subject_img, mask_img=subject_mask).T

In [25]:
masked_data11.T.shape

(1344, 108300)