In [1]:
import os
import nilearn as nil
import numpy as np
import random
import torch
from scipy import ndimage
from torch.utils.data import Dataset
from PIL import Image

In [2]:
class FmriDataset(Dataset):
    """Dataset of masked 4-D fMRI volumes labelled with a binned score.

    The data directory is expected to contain one folder per subject; each
    subject folder holds a ``<subject>.preproc`` folder with the images and the
    score files ``0back_VAS-f.1D`` / ``2back_VAS-f.1D``.

    Each item is a tuple ``(img, label)`` where ``img`` is a CPU tensor laid out
    as ``(time, X, Y, Z)`` containing ``img_timesteps`` randomly chosen frames
    and ``label`` is an integer class in ``0..4``.

    Args:
        data_dir: Root directory with one sub-directory per subject.
        mask_path: Path of the mask image loaded with nilearn.
        img_shape: ``(nX, nY, nZ, nT)`` extent to which every image is cropped.
        img_timesteps: Number of frames sampled from each image.
    """

    # Number of score bins produced by ``get_class``.
    NUM_CLASSES = 5

    def __init__(self, data_dir='/data/fmri/data', mask_path='/data/fmri/mask/caudate_mask.nii',
                 img_shape=(57, 68, 49, 135), img_timesteps=15):
        self.data_dir = data_dir
        self.img_timesteps = img_timesteps
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.mask_path, self.img_shape = mask_path, img_shape
        self.samples = []
        # Index the image paths together with their raw scores
        self.index_data()
        self.mask = self.read_mask()
        self.class_weights = self.find_weights()

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load, mask and temporally sub-sample the image at position ``idx``.

        Returns:
            ``(img, label)`` with ``img`` shaped ``(time, X, Y, Z)`` and
            ``label`` the class index of the image's score.
        """
        img_path, score = self.samples[idx]
        label = self.get_class(score)
        img = self.read_image(img_path)
        img = self.apply_mask(img)
        img = self.apply_temporal_aug(img)
        return img, label

    def index_data(self):
        """
        Stores all the image paths with their respective raw scores in
        ``self.samples`` and counts the images per class in ``self.weights``.
        """
        # Reset both containers so that calling this again does not duplicate
        # samples while the counters start from zero.
        self.samples = []
        self.weights = {i: 0 for i in range(self.NUM_CLASSES)}
        # os.listdir order is arbitrary; sort so sample indices are stable
        # from run to run.
        for sub in sorted(os.listdir(self.data_dir)):
            sub_dir = os.path.join(self.data_dir, sub)
            preproc_dir = os.path.join(sub_dir, f'{sub}.preproc')
            for img_name in sorted(os.listdir(preproc_dir)):
                img_path = os.path.join(preproc_dir, img_name)
                score = self.get_score(sub_dir, img_name)
                self.weights[self.get_class(score)] += 1
                self.samples.append((img_path, score))

    def get_class(self, score):
        """
        Categorize each score into one of the five classes (bins)
        Returns values from 0-4 (5 classes): scores below 1 map to 0, scores
        of 100 or more map to 4, everything else to ``score // 20``.
        """
        if score < 1:
            return 0
        # int() keeps the label an integer even if a float score is passed in.
        return min(int(score // 20), self.NUM_CLASSES - 1)

    def get_score(self, sub_dir, img_name):
        """Look up the score that belongs to ``img_name``.

        The score file is chosen by whether ``'0back'`` occurs in the image
        name; the index into that file is the last character of the second
        dot-separated component of the image name.

        Raises:
            ValueError: If a non-blank line of the score file is not an integer.
            IndexError: If the index is outside the score file.
        """
        score_file = '0back_VAS-f.1D' if '0back' in img_name else '2back_VAS-f.1D'
        score_path = os.path.join(sub_dir, score_file)
        with open(score_path, 'r') as s_f:
            # Skip blank lines (e.g. a trailing newline) instead of crashing
            # on int('').
            scores = [int(line) for line in s_f if line.strip()]

        task_num = img_name.split('.')[1]
        score_num = int(task_num[-1:])
        return scores[score_num]

    def read_image(self, img_path):
        """Load an image, crop it to ``img_shape`` and standardise it.

        Returns:
            A float tensor on ``self.device`` with zero mean and unit standard
            deviation over the whole volume.
        """
        nX, nY, nZ, nT = self.img_shape
        volume = nil.image.load_img(img_path).get_fdata()[:nX, :nY, :nZ, :nT]
        img = torch.tensor(volume, dtype=torch.float, device=self.device)
        # Guard against a constant volume, whose zero std would turn every
        # voxel into NaN; for any non-degenerate image the result is unchanged.
        std = img.std().clamp_min(1e-8)
        return (img - img.mean()) / std

    def read_mask(self):
        """Load the mask, dilate it slice by slice and resize it to the image grid.

        Every ``ratio``-th slice along the third axis is dilated by one
        iteration, resized to ``(nX, nY)`` and thresholded back to {0, 1}.

        Returns:
            A float tensor of shape ``(nX, nY, nZ)`` on ``self.device``.
        """
        nX, nY, nZ, _ = self.img_shape
        mask_img = np.asarray(nil.image.load_img(self.mask_path).get_fdata())
        dilated_mask = np.zeros((nX, nY, nZ))
        ratio = round(mask_img.shape[2] / nZ)
        for k in range(nZ):
            # ndimage.binary_dilation is the supported entry point; the
            # ndimage.morphology namespace is deprecated.
            dilated = ndimage.binary_dilation(mask_img[:, :, k * ratio], iterations=1)
            slice_img = Image.fromarray(dilated.astype(np.uint8) * 255)
            # PIL sizes are (width, height), hence (nY, nX) for an (nX, nY) array.
            dilated_mask[:, :, k] = np.array(slice_img.resize((nY, nX)))

        # Resizing interpolates the 0/255 image; threshold to get a binary mask.
        dilated_mask = (dilated_mask > 64).astype(int)
        return torch.tensor(dilated_mask, dtype=torch.float, device=self.device)

    def apply_mask(self, img):
        """Multiply every time frame of ``img`` by the mask, in place.

        Args:
            img: Tensor shaped ``(X, Y, Z, T)`` matching the mask's ``(X, Y, Z)``.

        Returns:
            The same tensor object, now masked.
        """
        # Broadcasting over a trailing singleton axis replaces the per-frame loop.
        return img.mul_(self.mask.unsqueeze(-1))

    def apply_temporal_aug(self, img):
        """
        Image shape: X, Y, Z, t
        Take ``img_timesteps`` random timesteps (without replacement) from the
        ones available, in ascending order, and return them on the CPU with the
        time axis first: (t, X, Y, Z).
        """
        total_timesteps = img.shape[3]
        rand_timesteps = sorted(random.sample(range(0, total_timesteps), self.img_timesteps))
        # Select the frames on the tensor's own device first so only the
        # chosen frames are copied to the CPU.
        clip = img[:, :, :, rand_timesteps].cpu()
        # Move time axes to the first place followed by X, Y, Z
        return clip.permute(3, 0, 1, 2)

    def find_weights(self):
        """Compute inverse-frequency class weights from ``self.weights``.

        Returns:
            A dict mapping each class to ``largest class count / class count``.
            Classes without any sample get weight ``0.0`` instead of raising
            ``ZeroDivisionError``.
        """
        counts = dict(self.weights)
        max_count = max(counts.values(), default=0)
        return {cls: (max_count / count if count else 0.0) for cls, count in counts.items()}
            

In [3]:
import os
import random
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.autograd as dif
from torch.nn.modules.utils import _triple

In [4]:
"""2+1D MODULE"""
""" -------------------------------------------------------------------------"""
# R2Plus1D Convolution
class SpatioTemporalConv(nn.Module):
    r"""Factored (2+1)D convolution with the call signature of a 3D convolution.

    The input is first convolved over its two trailing (spatial) axes into an
    intermediate channel space, passed through batch norm and ReLU, and then
    convolved over the first (temporal) axis to produce the output channels.

    Args:
        in_channels (int): Number of channels in the input tensor
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to the sides of the input during their respective convolutions. Default: 0
        bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True):
        super().__init__()

        # Scalars are broadcast to (t, h, w) triples; triples are taken as given.
        k_t, k_h, k_w = _triple(kernel_size)
        s_t, s_h, s_w = _triple(stride)
        p_t, p_h, p_w = _triple(padding)

        # Number of intermediate channels M (section 3.5 of the paper), chosen
        # so the factored block has about as many parameters as the full
        # 3D convolution it replaces.
        full_3d_params = k_t * k_h * k_w * in_channels * out_channels
        factored_params = k_h * k_w * in_channels + k_t * out_channels
        intermed_channels = int(math.floor(full_3d_params / factored_params))

        # Spatial step: the temporal axis is neutralised (kernel 1, stride 1,
        # padding 0) so stride and padding are not applied twice.
        self.spatial_conv = nn.Conv3d(
            in_channels,
            intermed_channels,
            [1, k_h, k_w],
            stride=[1, s_h, s_w],
            padding=[0, p_h, p_w],
            bias=bias,
        )
        self.bn = nn.BatchNorm3d(intermed_channels)
        self.relu = nn.ReLU()

        # Temporal step: the spatial axes are neutralised. No norm/activation
        # follows here on purpose, so the module can stand in for a plain
        # Conv3d and the caller decides what comes next.
        self.temporal_conv = nn.Conv3d(
            intermed_channels,
            out_channels,
            [k_t, 1, 1],
            stride=[s_t, 1, 1],
            padding=[p_t, 0, 0],
            bias=bias,
        )

    def forward(self, x):
        """Apply spatial conv -> batch norm -> ReLU -> temporal conv."""
        spatial_features = self.spatial_conv(x)
        activated = self.relu(self.bn(spatial_features))
        return self.temporal_conv(activated)

In [5]:
"""MODEL"""


""" Classifier """
class Custom3D(nn.Module):
    """Five-stage (2+1)D convolutional classifier followed by two linear layers.

    The input is expected as ``(batch, nc, nX, nY, nZ)`` where
    ``nc = params.nT // params.nDivT`` is used as the channel axis. The output
    holds log-probabilities over ``params.nClass`` classes.

    Args:
        params: Object exposing the attributes ``ndf`` (base channel width),
            ``nT``, ``nDivT``, ``nClass``, ``nX``, ``nY`` and ``nZ``.
    """

    def __init__(self, params):
        super(Custom3D, self).__init__()
        self.ndf = params.ndf
        self.nc = params.nT // params.nDivT
        self.nClass = params.nClass

        # The state sizes below are for the 57 x 68 x 49 input used in this
        # notebook (nc = 15 channels).
        self.conv1 = nn.Sequential(
            ## input is nc x 57 x 68 x 49
            SpatioTemporalConv(self.nc, self.ndf, 5, 2, 1, bias = False),
            nn.ReLU(True),
            ## state size. (ndf) x 28 x 33 x 24
        )
        self.conv2 = nn.Sequential(
            SpatioTemporalConv(self.ndf * 1, self.ndf * 2, 4, 2, 1, bias = False),
            nn.BatchNorm3d(self.ndf * 2),
            nn.ReLU(True),
            ## state size. (ndf*2) x 14 x 16 x 12
        )
        self.conv3 = nn.Sequential(
            SpatioTemporalConv(self.ndf * 2, self.ndf * 4, 4, 2, 1, bias = False),
            nn.BatchNorm3d(self.ndf * 4),
            nn.ReLU(True),
            ## state size. (ndf*4) x 7 x 8 x 6
        )
        self.conv4 = nn.Sequential(
            SpatioTemporalConv(self.ndf * 4, self.ndf * 4, 4, 2, 1, bias = False),
            nn.BatchNorm3d(self.ndf * 4),
            nn.ReLU(True),
            ## state size. (ndf*4) x 3 x 4 x 3
        )
        self.conv5 = nn.Sequential(
            SpatioTemporalConv(self.ndf * 4, self.ndf * 2, 3, 1, 0, bias = False),
            nn.ReLU(True),
            ## state size. (ndf*2) x 1 x 2 x 1
        )

        # Probe the convolutional stack once to learn the flattened feature
        # size. Run it in eval mode without autograd so the random probe does
        # not leak into the BatchNorm running statistics or build a graph.
        self._to_linear = None
        was_training = self.training
        self.eval()
        with torch.no_grad():
            self.convs(torch.randn(1, self.nc, params.nX, params.nY, params.nZ))
        self.train(was_training)

        self.fc1 = nn.Linear(self._to_linear, self.ndf * 1)

        self.fc2 = nn.Sequential(
            nn.Linear(self.ndf * 1, self.nClass),
        )

    def convs(self, x):
        """Run the five convolutional stages.

        On the first call the number of features per sample is recorded in
        ``self._to_linear`` so the first linear layer can be sized without
        hand-computing the convolution output shape.
        """
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)

        if self._to_linear is None:
            self._to_linear = x[0].numel()

        return x

    def forward(self, x):
        """Return log-probabilities of shape ``(batch, nClass)``."""
        x = self.convs(x)
        # Flatten per sample; unlike view(-1, n) this can never silently fold
        # a shape mismatch into the batch axis.
        x = x.flatten(start_dim=1)
        # nn.functional is reached through the already imported ``nn`` so the
        # class does not depend on an ``F`` alias imported in a later cell.
        x = nn.functional.relu(self.fc1(x))
        return nn.functional.log_softmax(self.fc2(x), dim=1)

In [6]:
# Build the dataset with its default paths. The constructor indexes every
# image, loads the mask and computes the class weights, so it reads from disk.
data = FmriDataset()

In [7]:
# Per-class weights: count of the most frequent class divided by each class count.
data.class_weights

{0: 1.0,
 1: 3.9206349206349205,
 2: 7.484848484848484,
 3: 7.264705882352941,
 4: 5.369565217391305}

In [None]:
# Shape of one sample after masking and temporal sub-sampling (time axis first);
# presumably (15, 57, 68, 49) with the default arguments - TODO confirm.
data[0][0].shape

### Main file code

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
from torch import optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms, utils

#### Input parameters

In [None]:
from easydict import EasyDict as edict

In [None]:
# Run configuration, exposed with attribute access (params.nEpochs, ...).
params = edict(dict(
    # Environment / experiment layout (not referenced by the cells shown here)
    path='/data/fmri',
    nGPU=2,
    # Training schedule
    nEpochs=10,
    nBacks=2,
    nTasks=4,
    nClass=5,       # number of output classes of the classifier
    batchSize=10,
    # Image extent: nT timesteps of an nX x nY x nZ volume
    nT=135,
    nX=57,
    nY=68,
    nZ=49,
    nDivT=9,        # the model takes nT // nDivT = 15 frames as channels
    ndf=64,         # base channel width of the convolutional stack
    # Optimizer hyper-parameters
    lr=0.001,
    beta1=0.5,
    beta2=0.999,
))
params

In [None]:
# Use the first CUDA device when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
device

In [None]:
# Keep the module-level constant in sync with the configured batch size
# instead of repeating the literal, so the two cannot drift apart.
BATCH_SIZE = params.batchSize

#### Loss function and optimizer

In [None]:
# Loss weights ordered by class index, created directly on the training device.
# The class count comes from the configuration so it matches the width of the
# model's output layer.
class_weights = torch.tensor(
    [data.class_weights[i] for i in range(params.nClass)],
    dtype=torch.float32,
    device=device,
)
class_weights

In [None]:
# Instantiate the classifier from the configuration and move it to the
# selected device; the bare `net` below prints the layer structure.
net = Custom3D(params=params).to(device)
net

In [None]:
# Custom3D.forward already returns log-probabilities (log_softmax), so the
# matching criterion is NLLLoss; CrossEntropyLoss would apply log_softmax a
# second time. The class weights counter the class imbalance.
loss_function = nn.NLLLoss(weight=class_weights)
# Take the learning rate from the configuration rather than repeating it.
optimizer = optim.Adam(net.parameters(), lr=params.lr)

In [None]:
def train_test_length(total, test_pct=0.2):
    """Split ``total`` items into train and test counts.

    The train count is ``(1 - test_pct) * total`` rounded down; the test set
    receives the remainder, so the two counts always add up to ``total``.

    Args:
        total: Number of items to split (non-negative).
        test_pct: Fraction of items reserved for testing, in ``[0, 1]``.

    Returns:
        ``(train_count, test_count)``.

    Raises:
        ValueError: If ``total`` is negative or ``test_pct`` is outside ``[0, 1]``.
    """
    if total < 0:
        raise ValueError(f"total must be non-negative, got {total}")
    if not 0 <= test_pct <= 1:
        raise ValueError(f"test_pct must be within [0, 1], got {test_pct}")
    # The small tolerance absorbs binary floating-point noise: (1 - 0.9) * 10
    # evaluates to 0.9999999999999998 and would otherwise truncate to 0.
    train_count = min(total, int((1 - test_pct) * total + 1e-9))
    test_count = total - train_count
    return train_count, test_count

In [None]:
# 80/20 split sizes for the full dataset; the bare tuple displays them.
train_count, test_count = train_test_length(total=len(data), test_pct=0.2)
train_count, test_count

In [None]:
# Seed the split so the same samples land in train/test on every run;
# otherwise results are not comparable between kernel restarts.
trainset, testset = random_split(
    data,
    [train_count, test_count],
    generator=torch.Generator().manual_seed(0),
)

In [None]:
# Shuffle only the training batches; evaluation order has no effect on the
# metrics, and a fixed order keeps test runs comparable.
train_loader = DataLoader(trainset, batch_size=params.batchSize, shuffle=True)
test_loader = DataLoader(testset, batch_size=params.batchSize, shuffle=False)

In [None]:
def train(net):
    """Train ``net`` for ``params.nEpochs`` epochs and print the mean loss per epoch.

    Relies on the notebook-level ``params``, ``train_loader``, ``device``,
    ``optimizer`` and ``loss_function``.

    Args:
        net: The model to optimise; it must be the model whose parameters were
            handed to ``optimizer``.
    """
    # Make sure BatchNorm layers use batch statistics even if the model was
    # switched to eval mode earlier in the session.
    net.train()
    for epoch in range(params.nEpochs):
        running_loss, n_batches = 0.0, 0
        for batch in train_loader:
            inputs, labels = batch[0].to(device), batch[1].to(device)
            optimizer.zero_grad()

            outputs = net(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()

            # .item() detaches the scalar so no autograd graph is kept alive.
            running_loss += loss.item()
            n_batches += 1

        # Report the epoch average instead of only the last batch; an empty
        # loader yields NaN rather than a NameError on an unbound `loss`.
        mean_loss = running_loss / n_batches if n_batches else float('nan')
        print(f'Epoch: {epoch} | Loss: {mean_loss}')

In [None]:
# Run the training loop; one line with the loss is printed per epoch.
train(net)