In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
%cd './drive/MyDrive/Cardiac Project/Registration_phase'

/content/drive/MyDrive/Cardiac Project/Registration_phase


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.normal import Normal
import torch.nn.functional as nnf
from torchvision import transforms
import torch.optim as optim
from sklearn.model_selection import train_test_split
from torch.utils import data

from tqdm import tqdm
import cv2
import glob
import numpy as np
import matplotlib.pyplot as plt

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 

In [4]:
path = './phase_2_voxel_morph/voxelMorph/'

In [5]:
def read_dataset(path):
    cine_dataset = []
    lge_dataset = []
    for filename in tqdm(glob.iglob(path + "**/*.png" , recursive = True)):
        class_type = filename.split('/')[3]
        image_name = filename.split('/')[4]
        image = cv2.imread(filename ,  cv2.IMREAD_GRAYSCALE)
        image = transforms.ToTensor()(image)
        if class_type == 'cine_images':
            cine_dataset.append((image , image_name))
        if class_type == 'lge_images':
            lge_dataset.append((image , image_name))

    cine_dataset.sort(key = lambda x: x[1]) 
    lge_dataset.sort(key = lambda x: x[1]) 
 
    dataset = []
    for i in range(len(cine_dataset)):
        cine_image = cine_dataset[i][0]
        lge_image = lge_dataset[i][0]
        dataset.append((cine_image , lge_image))

    return dataset

In [6]:
dataset = read_dataset(path)

606it [07:09,  1.41it/s]


In [7]:
train_dataset , test_dataset = train_test_split(dataset, test_size=0.33, random_state=42)

In [8]:
train_loader = data.DataLoader(train_dataset ,
                               batch_size = 20,
                               shuffle = True,
                               num_workers =  6,
                               worker_init_fn =  np.random.seed(42))  
test_loader = data.DataLoader(test_dataset ,
                               batch_size = 20,
                               shuffle = True,
                               num_workers =  6,
                               worker_init_fn =  np.random.seed(42))                      

  cpuset_checked))


In [9]:
def default_unet_features():
    nb_features = [
        [16, 32, 32, 32],             # encoder
        [32, 32, 32, 32, 32, 16, 16]  # decoder
    ]
    return nb_features

In [10]:
class ConvBlock(nn.Module):
    """
    Specific convolutional block followed by leakyrelu for unet.
    """

    def __init__(self, ndims, in_channels, out_channels, stride=1):
        super().__init__()

        Conv = getattr(nn, 'Conv%dd' % ndims)
        self.main = Conv(in_channels, out_channels, 3, stride, 1)
        self.activation = nn.LeakyReLU(0.2)

    def forward(self, x):
        out = self.main(x)
        out = self.activation(out)
        return out

In [11]:
class Unet(nn.Module):
    """
    A unet architecture. Layer features can be specified directly as a list of encoder and decoder
    features or as a single integer along with a number of unet levels. The default network features
    per layer (when no options are specified) are:

        encoder: [16, 32, 32, 32]
        decoder: [32, 32, 32, 32, 32, 16, 16]
    """

    def __init__(self, inshape, nb_features=None, nb_levels=None, feat_mult=1):
        super().__init__()
        """
        Parameters:
            inshape: Input shape. e.g. (192, 192, 192)
            nb_features: Unet convolutional features. Can be specified via a list of lists with
                the form [[encoder feats], [decoder feats]], or as a single integer. If None (default),
                the unet features are defined by the default config described in the class documentation.
            nb_levels: Number of levels in unet. Only used when nb_features is an integer. Default is None.
            feat_mult: Per-level feature multiplier. Only used when nb_features is an integer. Default is 1.
        """

        # ensure correct dimensionality
        ndims = len(inshape)
        assert ndims in [1, 2, 3], 'ndims should be one of 1, 2, or 3. found: %d' % ndims

        # default encoder and decoder layer features if nothing provided
        if nb_features is None:
            nb_features = default_unet_features()

        # build feature list automatically
        if isinstance(nb_features, int):
            if nb_levels is None:
                raise ValueError('must provide unet nb_levels if nb_features is an integer')
            feats = np.round(nb_features * feat_mult ** np.arange(nb_levels)).astype(int)
            self.enc_nf = feats[:-1]
            self.dec_nf = np.flip(feats)
        elif nb_levels is not None:
            raise ValueError('cannot use nb_levels if nb_features is not an integer')
        else:
            self.enc_nf, self.dec_nf = nb_features

        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')

        # configure encoder (down-sampling path)
        prev_nf = 2
        self.downarm = nn.ModuleList()
        for nf in self.enc_nf:
            self.downarm.append(ConvBlock(ndims, prev_nf, nf, stride=2))
            prev_nf = nf

        # configure decoder (up-sampling path)
        enc_history = list(reversed(self.enc_nf))
        self.uparm = nn.ModuleList()
        for i, nf in enumerate(self.dec_nf[:len(self.enc_nf)]):
            channels = prev_nf + enc_history[i] if i > 0 else prev_nf
            self.uparm.append(ConvBlock(ndims, channels, nf, stride=1))
            prev_nf = nf

        # configure extra decoder convolutions (no up-sampling)
        prev_nf += 2
        self.extras = nn.ModuleList()
        for nf in self.dec_nf[len(self.enc_nf):]:
            self.extras.append(ConvBlock(ndims, prev_nf, nf, stride=1))
            prev_nf = nf
 
    def forward(self, x):

        # get encoder activations
        x_enc = [x]
        for layer in self.downarm:
            x_enc.append(layer(x_enc[-1]))

        # conv, upsample, concatenate series
        x = x_enc.pop()
        for layer in self.uparm:
            x = layer(x)
            x = self.upsample(x)
            x = torch.cat([x, x_enc.pop()], dim=1)

        # extra convs at full resolution
        for layer in self.extras:
            x = layer(x)

        return x


In [12]:
class SpatialTransformer(nn.Module):
    """
    N-D Spatial Transformer
    """

    def __init__(self, size, mode='bilinear'):
        super().__init__()

        self.mode = mode

        # create sampling grid
        vectors = [torch.arange(0, s) for s in size]
        grids = torch.meshgrid(vectors)
        grid = torch.stack(grids)
        grid = torch.unsqueeze(grid, 0)
        grid = grid.type(torch.FloatTensor)

        # registering the grid as a buffer cleanly moves it to the GPU, but it also
        # adds it to the state dict. this is annoying since everything in the state dict
        # is included when saving weights to disk, so the model files are way bigger
        # than they need to be. so far, there does not appear to be an elegant solution.
        # see: https://discuss.pytorch.org/t/how-to-register-buffer-without-polluting-state-dict
        self.register_buffer('grid', grid)

    def forward(self, src, flow):
        # new locations
        new_locs = self.grid + flow
        shape = flow.shape[2:]

        # need to normalize grid values to [-1, 1] for resampler
        for i in range(len(shape)):
            new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (shape[i] - 1) - 0.5)

        # move channels dim to last position
        # also not sure why, but the channels need to be reversed
        if len(shape) == 2:
            new_locs = new_locs.permute(0, 2, 3, 1)
            new_locs = new_locs[..., [1, 0]]
        elif len(shape) == 3:
            new_locs = new_locs.permute(0, 2, 3, 4, 1)
            new_locs = new_locs[..., [2, 1, 0]]

        return nnf.grid_sample(src, new_locs, align_corners=True, mode=self.mode)


class VecInt(nn.Module):
    """
    Integrates a vector field via scaling and squaring.
    """

    def __init__(self, inshape, nsteps):
        super().__init__()
        
        assert nsteps >= 0, 'nsteps should be >= 0, found: %d' % nsteps
        self.nsteps = nsteps
        self.scale = 1.0 / (2 ** self.nsteps)
        self.transformer = SpatialTransformer(inshape)

    def forward(self, vec):
        vec = vec * self.scale
        for _ in range(self.nsteps):
            vec = vec + self.transformer(vec, vec)
        return vec


class ResizeTransform(nn.Module):
    """
    Resize a transform, which involves resizing the vector field *and* rescaling it.
    """

    def __init__(self, vel_resize, ndims):
        super().__init__()
        self.factor = 1.0 / vel_resize
        self.mode = 'linear'
        if ndims == 2:
            self.mode = 'bi' + self.mode
        elif ndims == 3:
            self.mode = 'tri' + self.mode

    def forward(self, x):
        if self.factor < 1:
            # resize first to save memory
            x = nnf.interpolate(x, align_corners=True, scale_factor=self.factor, mode=self.mode)
            x = self.factor * x

        elif self.factor > 1:
            # multiply first to save memory
            x = self.factor * x
            x = nnf.interpolate(x, align_corners=True, scale_factor=self.factor, mode=self.mode)

        # don't do anything if resize is 1
        return x

In [13]:
class VxmDense(nn.Module):
    """
    VoxelMorph network for (unsupervised) nonlinear registration between two images.
    """

    def __init__(self,
        inshape,
        nb_unet_features=None,
        nb_unet_levels=None,
        unet_feat_mult=1,
        int_steps=7,
        int_downsize=2,
        bidir=False,
        use_probs=False):
        """ 
        Parameters:
            inshape: Input shape. e.g. (192, 192, 192)
            nb_unet_features: Unet convolutional features. Can be specified via a list of lists with
                the form [[encoder feats], [decoder feats]], or as a single integer. If None (default),
                the unet features are defined by the default config described in the unet class documentation.
            nb_unet_levels: Number of levels in unet. Only used when nb_features is an integer. Default is None.
            unet_feat_mult: Per-level feature multiplier. Only used when nb_features is an integer. Default is 1.
            int_steps: Number of flow integration steps. The warp is non-diffeomorphic when this value is 0.
            int_downsize: Integer specifying the flow downsample factor for vector integration. The flow field
                is not downsampled when this value is 1.
            bidir: Enable bidirectional cost function. Default is False.
            use_probs: Use probabilities in flow field. Default is False.
        """
        super().__init__()

        # internal flag indicating whether to return flow or integrated warp during inference
        self.training = True

        # ensure correct dimensionality
        ndims = len(inshape)
        assert ndims in [1, 2, 3], 'ndims should be one of 1, 2, or 3. found: %d' % ndims
        # configure core unet model
        self.unet_model = Unet(
            inshape,
            nb_features=nb_unet_features,
            nb_levels=nb_unet_levels,
            feat_mult=unet_feat_mult
        )

        # configure unet to flow field layer
        Conv = getattr(nn, 'Conv%dd' % ndims)
        self.flow = Conv(self.unet_model.dec_nf[-1], ndims, kernel_size=3, padding=1)

        # init flow layer with small weights and bias
        self.flow.weight = nn.Parameter(Normal(0, 1e-5).sample(self.flow.weight.shape))
        self.flow.bias = nn.Parameter(torch.zeros(self.flow.bias.shape))

        # probabilities are not supported in pytorch
        if use_probs:
            raise NotImplementedError('Flow variance has not been implemented in pytorch - set use_probs to False')

        # configure optional resize layers
        resize = int_steps > 0 and int_downsize > 1
        self.resize = ResizeTransform(int_downsize, ndims) if resize else None
        self.fullsize = ResizeTransform(1 / int_downsize, ndims) if resize else None

        # configure bidirectional training
        self.bidir = bidir

        # configure optional integration layer for diffeomorphic warp
        down_shape = [int(dim / int_downsize) for dim in inshape]
        self.integrate = VecInt(down_shape, int_steps) if int_steps > 0 else None

        # configure transformer
        self.transformer = SpatialTransformer(inshape)

    def forward(self, source, target, registration=False):
        '''
        Parameters:
            source: Source image tensor.
            target: Target image tensor.
            registration: Return transformed image and flow. Default is False.
        '''

        # concatenate inputs and propagate unet
        x = torch.cat([source, target], dim=1)
        x = self.unet_model(x)

        # transform into flow field
        flow_field = self.flow(x)

        # resize flow for integration
        pos_flow = flow_field
        if self.resize:
            pos_flow = self.resize(pos_flow)

        preint_flow = pos_flow

        # negate flow for bidirectional model
        neg_flow = -pos_flow if self.bidir else None

        # integrate to produce diffeomorphic warp
        if self.integrate:
            pos_flow = self.integrate(pos_flow)
            neg_flow = self.integrate(neg_flow) if self.bidir else None

            # resize to final resolution
            if self.fullsize:
                pos_flow = self.fullsize(pos_flow)
                neg_flow = self.fullsize(neg_flow) if self.bidir else None

        # warp image with flow field
        y_source = self.transformer(source, pos_flow)
        y_target = self.transformer(target, neg_flow) if self.bidir else None

        # return non-integrated flow field if training
        if not registration:
            return (y_source, y_target, preint_flow) if self.bidir else (y_source, preint_flow)
        else:
            return y_source, pos_flow

In [14]:
import torch
import torch.nn.functional as F
import numpy as np
import math


class NCC:
    """
    Local (over window) normalized cross correlation loss.
    """

    def __init__(self, win=None):
        self.win = win

    def loss1(self, y_true, y_pred):
        
        I = y_true
        J = y_pred

        # get dimension of volume
        # assumes I, J are sized [batch_size, *vol_shape, nb_feats]
        ndims = len(list(I.size())) - 2
        assert ndims in [1, 2, 3], "volumes should be 1 to 3 dimensions. found: %d" % ndims
        # set window size
        win = [9] * ndims if self.win is None else self.win

        # compute filters
        sum_filt = torch.ones([1, 1, *win]).to("cuda")

        pad_no = math.floor(win[0]/2)

        # get convolution function
        conv_fn = getattr(F, 'conv%dd' % ndims)

        # compute CC squares
        I2 = I * I
        J2 = J * J
        IJ = I * J
        I_sum = conv_fn(I, sum_filt, stride=(1,1), padding=1)
        J_sum = conv_fn(J, sum_filt, stride=(1,1), padding=1)
        I2_sum = conv_fn(I2, sum_filt, stride=(1,1), padding=1)
        J2_sum = conv_fn(J2, sum_filt, stride=(1,1), padding=1)
        IJ_sum = conv_fn(IJ, sum_filt, stride=(1,1), padding=1)

        win_size = np.prod(win)
        u_I = I_sum / win_size
        u_J = J_sum / win_size

        cross = IJ_sum - u_J * I_sum - u_I * J_sum + u_I * u_J * win_size
        I_var = I2_sum - 2 * u_I * I_sum + u_I * u_I * win_size
        J_var = J2_sum - 2 * u_J * J_sum + u_J * u_J * win_size

        cc = cross * cross / (I_var * J_var + 1e-5)

        return -torch.mean(cc)
    
    def loss(self, y_true, y_pred):
        n = 9
        I = y_true
        J = y_pred
        batch_size, channels, xdim, ydim = I.shape
        I2 = torch.mul(I, I)
        J2 = torch.mul(J, J)
        IJ = torch.mul(I, J)
        sum_filter = torch.ones((1, channels, n, n))
        sum_filter = sum_filter.to("cuda")
        I_sum = torch.conv2d(I, sum_filter, padding=1, stride=(1,1))
        J_sum = torch.conv2d(J, sum_filter,  padding=1 ,stride=(1,1))
        I2_sum = torch.conv2d(I2, sum_filter, padding=1, stride=(1,1))
        J2_sum = torch.conv2d(J2, sum_filter, padding=1, stride=(1,1))
        IJ_sum = torch.conv2d(IJ, sum_filter, padding=1, stride=(1,1))
        win_size = n**2
        u_I = I_sum / win_size
        u_J = J_sum / win_size
        cross = IJ_sum - u_J*I_sum - u_I*J_sum + u_I*u_J*win_size
        I_var = I2_sum - 2 * u_I * I_sum + u_I*u_I*win_size
        J_var = J2_sum - 2 * u_J * J_sum + u_J*u_J*win_size
        cc = cross*cross / (I_var*J_var + np.finfo(float).eps)
        return torch.mean(cc)


class MSE:
    """
    Mean squared error loss.
    """

    def loss(self, y_true, y_pred):
        return torch.mean((y_true - y_pred) ** 2)


class Dice:
    """
    N-D dice for segmentation
    """

    def loss(self, y_true, y_pred):
        ndims = len(list(y_pred.size())) - 2
        vol_axes = list(range(2, ndims+2))
        top = 2 * (y_true * y_pred).sum(dim=vol_axes)
        bottom = torch.clamp((y_true + y_pred).sum(dim=vol_axes), min=1e-5)
        dice = torch.mean(top / bottom)
        return -dice


class Grad:
    """
    N-D gradient loss.
    """

    def __init__(self, penalty='l1', loss_mult=None):
        self.penalty = penalty
        self.loss_mult = loss_mult

    def loss(self, _, y_pred):
        y_pred = y_pred.permute(0,2,3,1)

        dy = torch.abs(y_pred[:, 1:, :, :] - y_pred[:, :-1, :, :]) 
        dx = torch.abs(y_pred[ :, :, 1:, :] - y_pred[:, :, :-1, :]) 

        if self.penalty == 'l2':
            dy = dy * dy
            dx = dx * dx

        d = torch.mean(dx) + torch.mean(dy) 
        grad = d / 2.0

        if self.loss_mult is not None:
            grad *= self.loss_mult
        return grad

In [15]:
def train(model , train_loader ,epochs):
    lamda = 0.001
    for epoch in range(epochs):
        epoch_loss = 0
        epoch_dice = 0
        for fixed_batch , moving_batch in train_loader:
            optimizer.zero_grad()
            fixed_batch = fixed_batch.to(device)
            moving_batch = moving_batch.to(device)
            registered_images , flow = model(moving_batch , fixed_batch)

            train_reg_loss = Lregist.loss(fixed_batch , registered_images)
            train_smooth_loss = Lsmooth.loss(fixed_batch , registered_images)
            train_loss = -1.0 * train_reg_loss + lamda * train_smooth_loss
            train_loss.backward()

            optimizer.step()
            epoch_loss += train_loss.item()
        print(epoch , epoch_loss/len(train_dataset) )

In [16]:
model = VxmDense(inshape = (64 , 64),
                nb_unet_features=None,
                nb_unet_levels=None,
                unet_feat_mult=1,
                int_steps=1,
                int_downsize=1,
                bidir=False,
                use_probs=False)

model = model.to(device)


In [17]:
optimizer = optim.Adam(model.parameters(), lr=1e-3)


In [18]:
Lregist = NCC(win = [9 , 9])
Lsmooth = Grad(penalty ='l2')

In [19]:
train(model , train_loader ,epochs = 50)

  cpuset_checked))


0 -0.012529958933591843
1 -0.01311676561832428
2 -0.013211853951215744
3 -0.013267641812562942
4 -0.013288931101560592
5 -0.013286093845963479
6 -0.013350236266851424
7 -0.013408442586660385
8 -0.013580862358212471
9 -0.01394116185605526
10 -0.014516581147909165
11 -3.3365823984891176
12 -0.00403778962790966
13 -0.0016822963580489158
14 -0.0011476461216807364
15 -0.0009828377468511463
16 -0.0009351460263133049
17 -0.0009197291405871511
18 -0.000913036591373384
19 -0.0009107905020937323
20 -0.0009102887660264969
21 -0.00091005505528301
22 -0.0009099380252882838
23 -0.0009099105698987842
24 -0.0009099036827683449
25 -0.0009099017968401313
26 -0.0009099017549306154
27 -0.0009099016943946481
28 -0.0009099016105756164
29 -0.0009099016105756164
30 -0.0009099016431719064
31 -0.0009099016059190035
32 -0.0009099015640094876
33 -0.0009099015686661005
34 -0.0009099015779793263
35 -0.0009099015733227134
36 -0.0009099015826359391
37 -0.0009099016012623906
38 -0.0009099015407264232
39 -0.00090990159

In [20]:
def test(model , test_loader):
    with torch.no_grad():
        all_loss = 0
        for fixed_batch , moving_batch in test_loader:
            fixed_batch = fixed_batch.to(device)
            moving_batch = moving_batch.to(device)
            registered_images ,flow = model(moving_batch , fixed_batch)
            test_loss = ncc_loss.loss(fixed_batch , registered_images)
            all_loss += test_loss.item()
        print(all_loss/len(test_dataset)) 
        return fixed_batch, moving_batch , registered_images  

In [21]:
fixed_batch, moving_batch , registered_images    = test(model , test_loader)

  cpuset_checked))


NameError: ignored

In [None]:
sample = 5
fig , (ax1 , ax2 , ax3) = plt.subplots(1,3,figsize = (10,10))
ax1.imshow(moving_batch[sample].squeeze(0).cpu() , cmap = 'gray')
ax2.imshow(fixed_batch[sample].squeeze(0).cpu() , cmap = 'gray')
ax3.imshow(registered_images[sample].squeeze(0).cpu() , cmap = 'gray')