In [5]:
import numpy as np
import os
import matplotlib.pyplot as plt
import nibabel as nib
import skimage.transform as resize
import seaborn as sns
import torch.nn as nn
import torch
import skimage
from torchvision import transforms
from mpl_toolkits.mplot3d import Axes3D

In [6]:
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as dist
import torchvision.transforms as transforms
import random

In [None]:
filea = "C:\\Users\\vivek\\Music\\amos22\\Images\\amos_0004.nii.gz"
fileb = "C:\\Users\\vivek\\Music\\amos22\\Labels\\amos_0004.nii.gz"
image11 = nib.load(filea)
image12 = nib.load(fileb)
data11 = image11.get_fdata()
data12 = image12.get_fdata()
plt.imshow(data11[:,:,50], cmap = "gray")
plt.colorbar(orientation="vertical")
plt.grid(False)
plt.show()
plt.imshow(data12[:,:,50], cmap = "gray")
plt.colorbar(orientation="vertical")
plt.grid(False)
plt.show()
re_data1 = skimage.transform.resize(data11, (64, 64, 64), order=0, preserve_range=2, anti_aliasing=False)
re_data2 = skimage.transform.resize(data12, (64, 64, 64), order=0, preserve_range=2, anti_aliasing=False)
re_data1 = np.reshape(re_data1, (1, 64, 64, 64))
re_data2 = np.reshape(re_data2, (1, 64, 64, 64))
print(re_data1.shape)
print(re_data2.shape)

## GIN: Global Intensity Non-linear augmentation

In [7]:
class GradlessGCReplayNonlinBlock3D(nn.Module):
    def __init__(self, in_channel=1, out_channel=1, scale_pool = [1,3], use_act=True):
        """
        Randomly re-initialised 3D convolution (+ optional leaky ReLU) layer.

        Fresh kernel weights and per-channel shifts are drawn from a standard
        normal distribution on every forward call, and each sample in the batch
        gets its own kernel. This is done in one call by folding the batch into
        the channel axis and using a grouped convolution (groups = batch size).
        The module has no learnable parameters.

        Args:
            in_channel: channels expected on the input (must equal the input's
                channel dimension).
            out_channel: channels produced per sample.
            scale_pool: candidate cubic kernel sizes; one is picked at random on
                every forward call. Only odd sizes preserve the spatial shape
                (padding is k // 2).
            use_act: apply F.leaky_relu after the convolution when True.
        """
        super(GradlessGCReplayNonlinBlock3D, self).__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.scale_pool = scale_pool
        self.use_act = use_act

    def forward(self, x_in):
        """
        Args:
            x_in: tensor of shape (nb, nc, nx, ny, nz) with nc == in_channel.

        Returns:
            Tensor of shape (nb, out_channel, nx, ny, nz) on x_in's device and
            dtype.
        """
        # Choose a random kernel size for this call.
        idx_k = torch.randint(high=len(self.scale_pool), size=(1,)).item()
        k = self.scale_pool[idx_k]

        nb, nc, nx, ny, nz = x_in.shape

        # Allocate the random weights where the input lives (instead of a
        # hard-coded .cuda()), so the block runs on CPU and on any GPU, and
        # matches the input dtype expected by conv3d.
        ker = torch.randn([self.out_channel * nb, self.in_channel, k, k, k],
                          device=x_in.device, dtype=x_in.dtype)
        shift = torch.randn([self.out_channel * nb, 1, 1, 1],
                            device=x_in.device, dtype=x_in.dtype)

        # Fold batch into channels: sample b's channels form convolution group b,
        # so every sample is convolved with its own random kernel.
        x_grouped = x_in.reshape(1, nb * nc, nx, ny, nz)
        x_conv = F.conv3d(x_grouped, ker, stride=1, padding=k // 2, dilation=1, groups=nb)
        x_conv = x_conv + shift
        if self.use_act:
            x_conv = F.leaky_relu(x_conv)

        # Unfold back to (batch, out_channel, ...).
        return x_conv.reshape(nb, self.out_channel, nx, ny, nz)

In [8]:
class GINGroup_Conv3D(nn.Module):
    def __init__(self, out_channel=1, interim_channel=2, in_channel=1, scale_pool = [1,3], n_layer=4, out_norm = 'frob'):
        """
        Global Intensity Non-linear augmentation (GIN) for 3D volumes.

        Stacks randomly re-initialised conv blocks
        (GradlessGCReplayNonlinBlock3D):
        in_channel -> interim_channel -> ... -> interim_channel -> out_channel,
        with a leaky ReLU after every block except the last.

        Args:
            out_channel: channels of the augmented output. The blend with the
                input in `forward` assumes out_channel matches the input's
                channel count (or one of them is 1).
            interim_channel: channels of the hidden blocks.
            in_channel: channels of the input volume.
            scale_pool: candidate kernel sizes shared by every block.
            n_layer: total number of blocks; values below 2 still build 2
                (first + last).
            out_norm: 'frob' rescales every output sample to the Frobenius
                norm of its input sample; any other value skips rescaling.
        """
        super(GINGroup_Conv3D, self).__init__()
        self.scale_pool = scale_pool
        self.out_channel = out_channel
        self.in_channel = in_channel
        self.n_layer = n_layer
        self.out_norm = out_norm

        # The blocks own no parameters or buffers and create their random
        # weights per call, so no .cuda() is needed here.
        layers = [
            GradlessGCReplayNonlinBlock3D(out_channel=interim_channel, in_channel=in_channel, scale_pool=scale_pool)
        ]
        layers += [
            GradlessGCReplayNonlinBlock3D(out_channel=interim_channel, in_channel=interim_channel, scale_pool=scale_pool)
            for _ in range(n_layer - 2)
        ]
        layers.append(
            GradlessGCReplayNonlinBlock3D(out_channel=out_channel, in_channel=interim_channel, scale_pool=scale_pool, use_act=False)
        )
        self.layers = nn.ModuleList(layers)

    def forward(self, x_in):
        """
        Args:
            x_in: tensor of shape (nb, nc, nx, ny, nz), or a list of such
                tensors concatenated along the batch axis. Cast to float32.

        Returns:
            Tensor of shape (nb, out_channel, nx, ny, nz):
            alpha * GIN(x) + (1 - alpha) * x with one alpha ~ U(0, 1) per
            sample, optionally rescaled to the input's Frobenius norm.
        """
        # Concatenate *before* casting: a Python list has no .float().
        if isinstance(x_in, list):
            x_in = torch.cat(x_in, dim=0)
        x_in = x_in.float()

        nb = x_in.shape[0]

        # One mixing weight per sample, created on the input's device;
        # shape (nb, 1, 1, 1, 1) broadcasts over channels and voxels.
        alphas = torch.rand(nb, device=x_in.device)[:, None, None, None, None]

        x = x_in
        for blk in self.layers:
            x = blk(x)
        mixed = alphas * x + (1.0 - alphas) * x_in

        if self.out_norm == 'frob':
            # Per-sample Frobenius norm over all channels and voxels.
            in_frob = x_in.flatten(1).norm(dim=1)
            self_frob = mixed.flatten(1).norm(dim=1)
            scale = in_frob / (self_frob + 1e-5)  # eps avoids division by zero
            mixed = mixed * scale[:, None, None, None, None]

        return mixed

In [9]:

if __name__ == '__main__':
    # Demo: apply GIN twice to the same volume and compare one slice.
    slice_idx = 50  # index along the last spatial axis to display
    augmenter = GINGroup_Conv3D().cuda()

    # (1, 64, 64, 64) -> (batch=1, channel=1, 64, 64, 64)
    image = np.reshape(re_data1, (1, 1, *re_data1.shape[-3:]))
    image_input = torch.from_numpy(image).cuda()
    print(image_input.shape)

    # Each forward pass draws new random conv weights, so the two outputs are
    # different augmentations of the same input. No gradients are needed here.
    with torch.no_grad():
        output1 = augmenter(image_input).cpu().numpy()
        output2 = augmenter(image_input).cpu().numpy()
    # Named input_np so the builtin `input` is not shadowed for later cells.
    input_np = image_input.cpu().numpy()

    panels = [(input_np, "input"), (output1, "GIN output 1"), (output2, "GIN output 2")]
    for col, (volume, title) in enumerate(panels, start=1):
        plt.subplot(1, 3, col)
        plt.imshow(volume[0, 0, :, :, slice_idx], cmap="gray")
        plt.title(title)
        plt.axis("off")
    plt.show()

NameError: name 're_data1' is not defined

In [None]:
# Build a 3D Perlin-noise mask `field` and its complement `field2`, used below
# to blend the two GIN outputs location-dependently (T1 / T2).
# FIXME: `pnoise3` is not imported anywhere in this notebook; it presumably
# comes from the third-party `noise` package (`from noise import pnoise3`) —
# add that to the imports cell so this cell survives Restart & Run All.
width = 64
height = 64
depth = 64

# Spatial frequency of the noise: coordinates are scaled by it, so smaller
# values give smoother, larger blobs.
freq = random.uniform(0, 0.05)

field = np.zeros((depth, height, width))
for z in range(depth):
    for y in range(height):
        for x in range(width):
            field[z, y, x] = pnoise3(x * freq, y * freq, z * freq)

# Rescale to [0, 1] so `field` and `1 - field` are complementary, non-negative
# blending weights. Raw noise values are not guaranteed to lie in [0, 1]
# (presumably signed, roughly [-1, 1]), which made `1 - field` exceed 1.
# The max(..., 1e-8) guards a constant field (e.g. freq == 0).
field_range = max(field.max() - field.min(), 1e-8)
field = (field - field.min()) / field_range

field2 = 1 - field
field = np.reshape(field, (1, 1, depth, height, width))
field2 = np.reshape(field2, (1, 1, depth, height, width))

In [None]:
# 3D scatter of both noise masks: one point per voxel, coloured by mask value.
# indexing="ij" yields coordinate grids of shape (depth, height, width), the
# same layout as field[0, 0] (filled as field[z, y, x]), so each point is
# coloured with its own voxel's value. The default "xy" meshgrid produced
# (height, width, depth) grids whose flattened order did not match the mask.
z_coords, y_coords, x_coords = np.meshgrid(
    np.arange(depth), np.arange(height), np.arange(width), indexing="ij"
)
for mask, title in ((field, "Perlin mask (field)"), (field2, "Complement mask (field2)")):
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(x_coords.ravel(), y_coords.ravel(), z_coords.ravel(),
               c=mask[0, 0].ravel(), cmap='gray')
    ax.set(title=title, xlabel="x", ylabel="y", zlabel="z")
    plt.show()

In [None]:
# Location-dependent mix of the two GIN outputs: T1 weights output1 by `field`
# and output2 by `field2`; T2 swaps the two masks.
T1, T2 = (
    field * output1 + field2 * output2,
    field2 * output1 + field * output2,
)

In [None]:
# Montage of every slice of T1 along its last axis (an 8x8 grid for 64 slices).
n_slices = T1.shape[-1]
n_cols = int(np.ceil(np.sqrt(n_slices)))
n_rows = int(np.ceil(n_slices / n_cols))
fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 15), squeeze=False)
for j, ax in enumerate(axes.flat):
    if j < n_slices:
        ax.imshow(T1[0, 0, :, :, j], cmap="gray")
    ax.axis("off")
fig.suptitle("T1 slices")
plt.show()

In [None]:
# Montage of every slice of T2 along its last axis (an 8x8 grid for 64 slices).
n_slices = T2.shape[-1]
n_cols = int(np.ceil(np.sqrt(n_slices)))
n_rows = int(np.ceil(n_slices / n_cols))
fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 15), squeeze=False)
for j, ax in enumerate(axes.flat):
    if j < n_slices:
        ax.imshow(T2[0, 0, :, :, j], cmap="gray")
    ax.axis("off")
fig.suptitle("T2 slices")
plt.show()