In [None]:
# Setup: mount Google Drive (the data lives there), import the libraries, seed the
# RNGs so weight init / DataLoader shuffling / dropout are reproducible, pick a device.
from google.colab import drive
drive.mount('/content/drive')
import numpy as np
import matplotlib.pyplot as plt
import torch 
import torch.nn as nn
import time
import torch.nn.functional as F 
from torch.utils.data import Dataset, DataLoader

SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the RNG on all devices (CPU and CUDA)

device = "cuda" if torch.cuda.is_available() else "cpu"


Mounted at /content/drive


In [None]:
# load subcortical data: images (X), label maps (y) and ids for each site.
# One configurable directory instead of nine copies of the same absolute path.
DATA_DIR = '/content/drive/MyDrive/dhl_exam/data/subcortical'

X_Guys = np.load(f'{DATA_DIR}/X_Guys.npy')
y_Guys = np.load(f'{DATA_DIR}/y_Guys.npy')
ids_Guys = np.load(f'{DATA_DIR}/ids_Guys.npy')
X_HH = np.load(f'{DATA_DIR}/X_HH.npy')
y_HH = np.load(f'{DATA_DIR}/y_HH.npy')
ids_HH = np.load(f'{DATA_DIR}/ids_HH.npy')
X_IOP = np.load(f'{DATA_DIR}/X_IOP.npy')
y_IOP = np.load(f'{DATA_DIR}/y_IOP.npy')
ids_IOP = np.load(f'{DATA_DIR}/ids_IOP.npy')

# center data: per-position z-scoring across samples
def centring(X):
    """Standardise ``X`` along axis 0: subtract the mean and divide by the std.

    Statistics are taken over the first axis (samples), so every remaining
    position is z-scored independently. ``1e-7`` is added to the std so that
    constant positions do not cause a division by zero.

    Returns a new array with the same shape as ``X``.
    """
    arr = np.asarray(X)
    mu = arr.mean(axis=0, keepdims=True)
    sigma = arr.std(axis=0, keepdims=True)
    return (arr - mu) / (sigma + 1e-7)

# Standardise each acquisition site separately (per-site statistics).
# NOTE(review): the IOP test site is normalised with its own statistics, and the
# train/val split below is taken *after* normalising Guys/HH as a whole — confirm intended.
X_Guys_centered, X_HH_centered, X_IOP_centered = (centring(site) for site in (X_Guys, X_HH, X_IOP))

# producing required train / val / test split
print("initial shapes")
for _arr in (X_Guys, y_Guys, ids_Guys, X_HH, y_HH, ids_HH, X_IOP, y_IOP, ids_IOP):
    print(_arr.shape)

# Train/val come from Guys + HH in file order (first 85% / last 15%, no shuffling,
# so the validation subjects are the tail of the HH site); test is all of IOP.
combined_Guys_HH_X = np.concatenate([X_Guys_centered, X_HH_centered], axis=0)
combined_Guys_HH_y = np.concatenate([y_Guys, y_HH], axis=0)
cut_X = int(len(combined_Guys_HH_X) * 0.85)
cut_y = int(len(combined_Guys_HH_y) * 0.85)

X_train, X_val = torch.Tensor(combined_Guys_HH_X[:cut_X]), torch.Tensor(combined_Guys_HH_X[cut_X:])
y_train, y_val = torch.Tensor(combined_Guys_HH_y[:cut_y]), torch.Tensor(combined_Guys_HH_y[cut_y:])
X_test, y_test = torch.Tensor(X_IOP_centered), torch.Tensor(y_IOP)

print("Check after split")
for _t in (X_train, y_train, X_val, y_val, X_test, y_test):
    print(_t.shape)

initial shapes
(321, 40, 128, 128)
(321, 40, 128, 128)
(321,)
(185, 40, 128, 128)
(185, 40, 128, 128)
(185,)
(71, 40, 128, 128)
(71, 40, 128, 128)
(71,)
Check after split
torch.Size([430, 40, 128, 128])
torch.Size([430, 40, 128, 128])
torch.Size([76, 40, 128, 128])
torch.Size([76, 40, 128, 128])
torch.Size([71, 40, 128, 128])
torch.Size([71, 40, 128, 128])


In [None]:
# Sanity check: print every array's shape per site, then show one 2D slice
# (sample 10, slice 18) of the image (top row) and the label map (bottom row) per site.
site_panels = [
    ("Inspecting Guys data", "Guys X", "Guys y", X_Guys, y_Guys, ids_Guys),
    ("Inspecting HH data", "HH X", "HH y", X_HH, y_HH, ids_HH),
    ("Inspecting IOP data", "IOP X", "IOP y", X_IOP, y_IOP, ids_IOP),
]

for header, _, _, X_site, y_site, ids_site in site_panels:
    print(header)
    print(X_site.shape)
    print(y_site.shape)
    print(ids_site.shape)


fig, axes = plt.subplots(2, 3, figsize=(12, 8))

for col, (_, x_title, y_title, X_site, y_site, _) in enumerate(site_panels):
    axes[0, col].imshow(X_site[10][18])
    axes[0, col].set_title(x_title)
    axes[1, col].imshow(y_site[10][18])
    axes[1, col].set_title(y_title)

plt.tight_layout()
plt.show()

In [None]:
# slice data into 2D
def reslice(x, split_size):
  """Cut ``x`` into chunks of ``split_size`` along dim 1 and stack them along dim 0.

  With ``split_size=1`` an (N, D, H, W) tensor becomes (D*N, 1, H, W), ordered
  chunk-major: every sample's first slice, then every sample's second slice, ...
  """
  chunks = torch.split(x, split_size, dim=1)
  return torch.cat(chunks, dim=0)

# Reduce each 3D volume to one 2D slice and add a channel axis -> (N, 1, H, W).
# (Using every slice instead would be `reslice(x, 1)` on each tensor.)
# NOTE(review): train, val and test each take a *different* slice index, so the
# splits presumably come from different positions in the volume — confirm intended.
TRAIN_SLICE, VAL_SLICE, TEST_SLICE = 10, 20, 30

X_train, y_train = X_train[:, TRAIN_SLICE].unsqueeze(1), y_train[:, TRAIN_SLICE].unsqueeze(1)
X_val, y_val = X_val[:, VAL_SLICE].unsqueeze(1), y_val[:, VAL_SLICE].unsqueeze(1)
X_test, y_test = X_test[:, TEST_SLICE].unsqueeze(1), y_test[:, TEST_SLICE].unsqueeze(1)

print("Check shapes after slicing into 2D")
for split_tensor in (X_train, y_train, X_val, y_val, X_test, y_test):
    print(split_tensor.shape)

Check shapes after slicing into 2D
torch.Size([430, 1, 128, 128])
torch.Size([430, 1, 128, 128])
torch.Size([76, 1, 128, 128])
torch.Size([76, 1, 128, 128])
torch.Size([71, 1, 128, 128])
torch.Size([71, 1, 128, 128])


In [None]:
# convert y to one_hot

def to_one_hot(y, num_classes):
    """One-hot encode an integer label map, placing the class axis at dim 1.

    Args:
        y: label map of shape (N, 1, H, W) (array-like or tensor). Labels are
            shifted down by one before encoding, i.e. label ``k`` -> channel ``k - 1``.
            NOTE(review): this assumes labels run 1..num_classes. A label of 0 becomes
            index -1, which numpy silently wraps to the *last* channel — confirm the
            label range of the data.
        num_classes: number of output channels.

    Returns:
        float64 array of shape (N, num_classes, H, W); ``out[n, c, i, j]`` is 1.0 iff
        pixel (i, j) of sample n has (shifted) label ``c``.

    Raises:
        IndexError: if a shifted label lies outside [-num_classes, num_classes).
    """
    labels = np.array(y, dtype='int') - 1
    n, _, h, w = labels.shape
    # Index the identity matrix per pixel -> (N, H, W, C), then move the class axis
    # next to the batch axis. Reshaping the (N*H*W, C) result straight to
    # (N, C, H, W) would interleave classes and pixels instead of transposing them.
    one_hot = np.eye(num_classes)[labels.reshape(n, h, w)]
    return np.ascontiguousarray(np.moveaxis(one_hot, -1, 1), dtype=float)

# One-hot encode the label maps of every split into 5 channels.
NUM_CLASSES = 5
y_train, y_val, y_test = (to_one_hot(labels, NUM_CLASSES) for labels in (y_train, y_val, y_test))

for encoded in (y_train, y_val, y_test):
    print(encoded.shape)


In [None]:
from torch.utils.data import Dataset, DataLoader

# numpy views of the image tensors for numpy_dataset (which calls torch.from_numpy)
X_train, X_val, X_test = (np.asarray(images) for images in (X_train, X_val, X_test))

class numpy_dataset(Dataset):
    """Map-style dataset pairing two equally indexed numpy arrays.

    Both arrays are wrapped with ``torch.from_numpy`` (memory is shared, dtype kept).
    Item ``i`` is the pair ``(data[i], target[i])``; the length is ``len(data)``.
    """

    def __init__(self, data, target):
        self.data, self.target = map(torch.from_numpy, (data, target))

    def __getitem__(self, index):
        return self.data[index], self.target[index]

    def __len__(self):
        return len(self.data)
    
train_dataset = numpy_dataset(X_train, y_train)
val_dataset = numpy_dataset(X_val, y_val)
test_dataset = numpy_dataset(X_test, y_test)

# Only the training loader drops its ragged last batch. Validation and test must
# score every sample: with drop_last=True the 76 val / 71 test slices would silently
# lose their last 12 / 7 samples (batch size 16).
BATCH_SIZE = 16
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)


In [None]:
class SegNet(nn.Module):
    """SegNet-style encoder/decoder that outputs one sigmoid probability map per class.

    Encoder: five (3x3 conv -> BatchNorm -> ReLU -> 2x2 max-pool) stages with
    64/128/256/512/512 channels; the pooling indices are kept. Decoder: max-unpool
    with the stage-5..2 indices, each followed by conv -> BatchNorm -> ReLU
    (512/256/128/64 channels), then a 3x3 conv to ``num_classes`` channels. The
    stage-1 indices are not used, so the decoder ends at half resolution and the
    logits are bilinearly upsampled to the input size before the sigmoid.

    Args:
        in_channels: number of input channels.
        num_classes: number of output channels.

    Input:  (B, in_channels, H, W) with H, W >= 32 (five 2x2 poolings).
    Output: (B, num_classes, H, W), values in [0, 1].
    """

    def __init__(self, in_channels, num_classes):
        super(SegNet, self).__init__()

        self.in_channels = in_channels
        self.num_classes = num_classes

        # encoder
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(256)
        self.conv4 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(512)
        self.conv5 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn5 = nn.BatchNorm2d(512)

        # decoder
        self.conv6 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn6 = nn.BatchNorm2d(512)
        self.conv7 = nn.Conv2d(512, 256, kernel_size=3, padding=1)
        self.bn7 = nn.BatchNorm2d(256)
        self.conv8 = nn.Conv2d(256, 128, kernel_size=3, padding=1)
        self.bn8 = nn.BatchNorm2d(128)
        self.conv9 = nn.Conv2d(128, 64, kernel_size=3, padding=1)
        self.bn9 = nn.BatchNorm2d(64)
        self.conv10 = nn.Conv2d(64, num_classes, kernel_size=3, padding=1)

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
        self.unpool = nn.MaxUnpool2d(kernel_size=2, stride=2)

    def forward(self, x):
        input_size = tuple(x.shape[-2:])

        # Encoder. Each pre-pool spatial size is recorded so that unpooling restores
        # it exactly: MaxUnpool2d cannot infer an odd size on its own.
        x = F.relu(self.bn1(self.conv1(x)))
        x, _ = self.pool(x)  # stage-1 indices are not used by this decoder
        x = F.relu(self.bn2(self.conv2(x)))
        size2 = x.shape[-2:]
        x, pool2_indices = self.pool(x)
        x = F.relu(self.bn3(self.conv3(x)))
        size3 = x.shape[-2:]
        x, pool3_indices = self.pool(x)
        x = F.relu(self.bn4(self.conv4(x)))
        size4 = x.shape[-2:]
        x, pool4_indices = self.pool(x)
        x = F.relu(self.bn5(self.conv5(x)))
        size5 = x.shape[-2:]
        x, pool5_indices = self.pool(x)

        # Decoder
        x = self.unpool(x, pool5_indices, output_size=size5)
        x = F.relu(self.bn6(self.conv6(x)))
        x = self.unpool(x, pool4_indices, output_size=size4)
        x = F.relu(self.bn7(self.conv7(x)))
        x = self.unpool(x, pool3_indices, output_size=size3)
        x = F.relu(self.bn8(self.conv8(x)))
        x = self.unpool(x, pool2_indices, output_size=size2)
        x = F.relu(self.bn9(self.conv9(x)))
        x = self.conv10(x)
        # Upsample to the *input* resolution instead of a hard-coded 128x128.
        x = F.interpolate(x, size=input_size, mode='bilinear', align_corners=False)
        return torch.sigmoid(x)


class FCN_flexible(nn.Module):
    """Plain encoder/decoder FCN (no skip connections) with configurable depth.

    Encoder: ``num_layers`` stages of 3x3 conv -> ReLU -> Dropout -> 2x2 max-pool with
    64, 128, 256, ... channels. Middle: 3x3 conv doubling the channels, ReLU, Dropout,
    then a 1x1 conv back. Decoder: mirrored 2x2 stride-2 transposed convs (each
    followed by ReLU -> Dropout), then a 1x1 conv to ``num_classes`` channels and a
    sigmoid.

    Args:
        input_shape: (channels, H, W); only the channel count is used.
        num_classes: number of output maps.
        dropout_prob: dropout probability used after every ReLU.
        num_layers: number of down/up-sampling stages; 0 gives a network that keeps
            the input resolution.

    For H and W divisible by ``2 ** num_layers`` the output is (B, num_classes, H, W).
    """

    def __init__(self, input_shape=(1, 128, 128), num_classes=2, dropout_prob=0.5, num_layers=5):
        super(FCN_flexible, self).__init__()
        encoder_layers = []
        in_channels = input_shape[0]
        for i in range(num_layers):
            out_channels = 64 * (2 ** i)
            encoder_layers.extend([
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Dropout(dropout_prob),
                nn.MaxPool2d(kernel_size=2, stride=2)
            ])
            in_channels = out_channels

        self.encoder = nn.Sequential(*encoder_layers)
        self.middle = nn.Sequential(
            nn.Conv2d(in_channels, in_channels * 2, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_prob),
            nn.Conv2d(in_channels * 2, in_channels, kernel_size=1)
        )
        decoder_layers = []
        for i in range(num_layers - 1, -1, -1):
            out_channels = 64 * (2 ** i)
            decoder_layers.extend([
                nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
                nn.ReLU(inplace=True),
                nn.Dropout(dropout_prob)
            ])
            in_channels = out_channels

        # `in_channels` now holds the decoder's output width (64 when num_layers >= 1).
        # Using it instead of the loop variable `out_channels` also covers
        # num_layers == 0, where neither loop runs and `out_channels` is never bound.
        decoder_layers.extend([
            nn.Conv2d(in_channels, num_classes, kernel_size=1),
            nn.Sigmoid()
        ])

        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        x = self.encoder(x)
        x = self.middle(x)
        x = self.decoder(x)
        return x


def conv_block(in_channels, out_channels):
    """Two stacked (3x3 conv, padding 1) -> BatchNorm -> ReLU units in one ``nn.Sequential``.

    Spatial size is preserved; channels go in_channels -> out_channels -> out_channels.
    """
    layers = []
    for c_in in (in_channels, out_channels):
        layers += [
            nn.Conv2d(c_in, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
    return nn.Sequential(*layers)

class Encoder(nn.Module):
    """U-Net down-sampling step: 2x2 max-pool, then ``conv_block(in_channels, out_channels)``.

    Halves H and W (floor) and maps ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.pool = nn.MaxPool2d(2)
        self.conv = conv_block(in_channels, out_channels)

    def forward(self, x):
        pooled = self.pool(x)
        return self.conv(pooled)

class Decoder(nn.Module):
    """U-Net up-sampling step: 2x transposed-conv upsample, concat with the skip, conv_block.

    The upsample maps ``in_channels -> in_channels // 2``. The skip tensor must supply
    the remaining ``in_channels - in_channels // 2`` channels, so that the
    concatenation ``(skip, upsampled)`` feeds ``conv_block(in_channels, out_channels)``.
    The output has the skip's spatial size.
    """

    def __init__(self, in_channels, out_channels):
        super(Decoder, self).__init__()
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = conv_block(in_channels, out_channels)

    def forward(self, x, skip):
        x = self.up(x)
        # If an encoder input had an odd H or W, pooling floored it and the upsampled
        # map ends up smaller than the skip. Zero-pad it (centred) so the concatenation
        # works for any input size; this is a no-op when the sizes already agree.
        diff_h = skip.shape[-2] - x.shape[-2]
        diff_w = skip.shape[-1] - x.shape[-1]
        if diff_h or diff_w:
            x = F.pad(x, [diff_w // 2, diff_w - diff_w // 2,
                          diff_h // 2, diff_h - diff_h // 2])
        x = torch.cat((skip, x), dim=1)
        return self.conv(x)

class UNet(nn.Module):
    """Four-level U-Net (64-128-256-512-1024 channels) returning raw per-class logits.

    Args:
        in_channels: channels of the input image.
        out_channels: number of output maps (one per class).

    No activation is applied to the output. For H and W divisible by 16 the output
    has shape (B, out_channels, H, W).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        widths = [64, 128, 256, 512, 1024]
        self.init_conv = conv_block(in_channels, widths[0])
        # down path: 64->128, 128->256, 256->512, 512->1024
        self.encoders = nn.ModuleList(
            Encoder(c_in, c_out) for c_in, c_out in zip(widths[:-1], widths[1:])
        )
        # up path: 1024->512, 512->256, 256->128, 128->64
        self.decoders = nn.ModuleList(
            Decoder(c_in, c_out) for c_in, c_out in zip(widths[:0:-1], widths[-2::-1])
        )
        self.final_conv = nn.Conv2d(widths[0], out_channels, kernel_size=1)

    def forward(self, x):
        # Encoder path: keep every resolution's features for the skip connections.
        features = [self.init_conv(x)]
        for down in self.encoders:
            features.append(down(features[-1]))

        # Decoder path: start at the bottleneck and consume skips deepest-first.
        out = features.pop()
        for up in self.decoders:
            out = up(out, features.pop())

        return self.final_conv(out)






In [None]:
# training loops adapted from Kaggle
def dice_coef_metric(pred, label):
    """Dice coefficient ``2*sum(pred*label) / (sum(pred) + sum(label))`` over all elements.

    Returns 1.0 when both ``pred`` and ``label`` sum to zero (an empty prediction of an
    empty target counts as perfect agreement).
    """
    pred_total = pred.sum()
    label_total = label.sum()
    if pred_total == 0 and label_total == 0:
        return 1.
    return 2.0 * (pred * label).sum() / (pred_total + label_total)

def dice_coef_loss(pred, label):
    """Soft Dice loss ``1 - (2*sum(pred*label) + 1) / (sum(pred) + sum(label) + 1)``.

    The +1 smoothing keeps the ratio defined, giving a loss of 0 when both inputs are all zero.
    """
    smooth = 1.0
    overlap = 2.0 * (pred * label).sum()
    total = pred.sum() + label.sum()
    return 1 - (overlap + smooth) / (total + smooth)

def bce_dice_loss(pred, label):
    """Soft Dice loss plus binary cross-entropy (``nn.BCELoss``, default mean reduction).

    ``pred`` must already be probabilities in [0, 1] (the non-logits BCE is used);
    ``label`` is a float target with the same shape.
    """
    return dice_coef_loss(pred, label) + nn.BCELoss()(pred, label)

def train_loop(model, loader, loss_func,optimizer):
    """Run one training epoch of ``model`` over ``loader``.

    Per batch: move image and mask to the global ``device`` as float32, forward, compute
    the Dice of the 0.5-thresholded prediction (logging only), compute the loss on the
    raw outputs, backpropagate, step the optimiser and clear the gradients.

    Returns:
        (train_dices, train_losses): per-batch Dice scores and per-batch loss values.
    """
    model.train()
    batch_dices, batch_losses = [], []

    for image, mask in loader:
        image = image.to(device).float()
        mask = mask.to(device).float()
        probs = model(image)

        # Binarise a detached numpy copy for the metric only.
        binary = np.copy(probs.data.cpu().numpy())
        binary[binary < 0.5] = 0.0
        binary[binary >= 0.5] = 1.0
        batch_dices.append(dice_coef_metric(binary, mask.data.cpu().numpy()))

        loss = loss_func(probs, mask)
        batch_losses.append(loss.item())

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    return batch_dices, batch_losses


def train_model_early_stopping(train_loader, val_loader, loss_func,optimizer, scheduler, num_epochs, patience=5, model=None):
    """Train with per-epoch validation and early stopping on the validation Dice.

    Args:
        train_loader, val_loader: loaders yielding (image, mask) batches.
        loss_func: ``loss_func(outputs, mask)`` -> scalar tensor.
        optimizer: optimiser over the model's parameters.
        scheduler: passed to ``eval_loop``, which steps it with the validation Dice.
        num_epochs: maximum number of epochs.
        patience: stop after this many consecutive epochs without a strictly better
            validation Dice.
        model: network to train. If omitted, the notebook-global ``model`` is used
            (the only behaviour of earlier versions of this function).

    Returns:
        (train_loss_history, train_dice_history, val_loss_history, val_dice_history,
        epochs_run): per-epoch histories and the number of epochs actually run
        (0 when ``num_epochs`` is 0). The model keeps the weights of the last epoch
        run, not those of the best epoch.
    """
    if model is None:
        # Backward compatibility: callers that don't pass `model` train the global one.
        model = globals()["model"]

    train_loss_history = []
    train_dice_history = []
    val_loss_history = []
    val_dice_history = []

    best_val_dice = 0
    consecutive_no_improvement = 0
    epochs_run = 0  # defined even if the loop body never runs (num_epochs == 0)

    for epoch in range(num_epochs):
        epochs_run = epoch + 1
        train_dices, train_losses = train_loop(model, train_loader, loss_func, optimizer)
        train_mean_dice = np.array(train_dices).mean()
        train_mean_loss = np.array(train_losses).mean()
        val_mean_dice, val_mean_loss = eval_loop(model, val_loader, loss_func,scheduler)

        train_loss_history.append(train_mean_loss)
        train_dice_history.append(train_mean_dice)
        val_loss_history.append(val_mean_loss.cpu().numpy())
        val_dice_history.append(val_mean_dice)

        print('Epoch: {}/{} |  Train Loss: {:.3f}, Val Loss: {:.3f}, Train DICE: {:.3f}, Val DICE: {:.3f}'.format(epoch+1, num_epochs,
                                                                                                             train_mean_loss,
                                                                                                             val_mean_loss,
                                                                                                             train_mean_dice,
                                                                                                             val_mean_dice))

        # Check for improvement in validation dice coefficient
        if val_mean_dice > best_val_dice:
            best_val_dice = val_mean_dice
            consecutive_no_improvement = 0
            print('Best validation dice coefficient improved to {:.3f}'.format(best_val_dice))
        else:
            consecutive_no_improvement += 1
            print('No improvement in validation dice coefficient for {} consecutive epochs'.format(consecutive_no_improvement))
            if consecutive_no_improvement >= patience:
                print('Early stopping triggered after {} epochs'.format(epoch+1))
                break

    return train_loss_history, train_dice_history, val_loss_history, val_dice_history, epochs_run

def eval_loop(model, loader, loss_func, scheduler,training=True):
    """Evaluate ``model`` on ``loader`` in eval mode without gradients.

    Returns:
        (val_mean_dice, val_mean_loss): mean per-batch Dice of the 0.5-thresholded
        predictions and mean per-batch loss (kept as a tensor).

    When ``training`` is true, ``scheduler.step`` is called with the mean Dice
    (suits ``ReduceLROnPlateau`` in 'max' mode).

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.eval()
    val_loss = 0
    val_dice = 0
    num_batches = 0
    with torch.no_grad():
        for image, mask in loader:
            image = image.to(device).float()
            mask = mask.to(device).float()

            outputs = model(image)
            loss = loss_func(outputs, mask)

            out_cut = np.copy(outputs.data.cpu().numpy())
            out_cut[np.nonzero(out_cut < 0.5)] = 0.0
            out_cut[np.nonzero(out_cut >= 0.5)] = 1.0
            dice = dice_coef_metric(out_cut, mask.data.cpu().numpy())

            val_loss += loss
            val_dice += dice
            num_batches += 1

        if num_batches == 0:
            raise ValueError("eval_loop: loader yielded no batches")

        # Average over the number of batches. Dividing by the last enumerate index
        # would be one short, and would divide by zero for a single-batch loader.
        val_mean_dice = val_dice / num_batches
        val_mean_loss = val_loss / num_batches

        if training:
            scheduler.step(val_mean_dice)

    return val_mean_dice, val_mean_loss

In [None]:
def prediction_dice(net, test_dataloader):
    """Mean per-batch Dice of ``net``'s 0.5-thresholded predictions over a loader.

    ``net`` is switched to eval mode first (BatchNorm running statistics, dropout off),
    so the score does not depend on the mode the network was last left in. Batches are
    weighted equally, so a smaller last batch counts as much as a full one.

    Raises:
        ValueError: if the loader yields no batches.
    """
    net.eval()
    test_dice = 0
    num_batches = 0

    with torch.no_grad():
        for data, target in test_dataloader:
            data = data.to(device).float()
            target = target.to(device).float()

            pred = net(data)
            out_cut = np.copy(pred.data.cpu().numpy())
            out_cut[np.nonzero(out_cut < 0.5)] = 0.0
            out_cut[np.nonzero(out_cut >= 0.5)] = 1.0
            test_dice += dice_coef_metric(out_cut, target.data.cpu().numpy())
            num_batches += 1

    if num_batches == 0:
        raise ValueError("prediction_dice: loader yielded no batches")
    return test_dice / num_batches

def predict(net, test_dataloader, all_batches=False):
    """Run ``net`` over a loader and return numpy ``(inputs, targets, predictions)``.

    By default only the *last* batch is returned (handy for plotting a few examples).
    With ``all_batches=True`` every batch is kept and the three arrays are concatenated
    along axis 0. ``net`` is switched to eval mode first.

    Raises:
        ValueError: if the loader yields no batches.
    """
    net.eval()
    collected = []

    with torch.no_grad():
        for data, target in test_dataloader:
            data = data.to(device).float()
            target = target.to(device).float()
            pred = net(data)
            batch = (data.data.cpu().numpy(), target.data.cpu().numpy(), pred.data.cpu().numpy())
            if all_batches:
                collected.append(batch)
            else:
                collected = [batch]  # only the most recent batch is kept

    if not collected:
        raise ValueError("predict: loader yielded no batches")
    if not all_batches:
        return collected[-1]
    return tuple(np.concatenate(parts, axis=0) for parts in zip(*collected))

def count_parameters(model):
    """Total number of scalar parameters in ``model`` that have ``requires_grad`` set."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

In [None]:
# Train the U-Net (1 input channel -> 5 class maps) with Adam and ReduceLROnPlateau
# on the validation Dice, with early stopping.
num_epochs = 100
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# UNet is defined as UNet(in_channels, out_channels) and returns raw logits, but
# bce_dice_loss uses nn.BCELoss, which needs probabilities in [0, 1]; add a sigmoid.
model = nn.Sequential(UNet(1, 5), nn.Sigmoid()).to(device)
print("Number of parameters: ", count_parameters(model))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=3)
start_time = time.time()
# train_model_early_stopping returns five values; the last is the number of epochs run.
train_loss_history, train_dice_history, val_loss_history, val_dice_history, epochs_run = train_model_early_stopping(train_dataloader, val_dataloader, bce_dice_loss, optimizer, scheduler, num_epochs)
end_time = time.time()
print("time taken", end_time - start_time)
end_time = time.time()

Epoch: 1/100 |  Train Loss: 1.286, Val Loss: 1.617, Train DICE: 0.186, Val DICE: 0.249
Best validation dice coefficient improved to 0.249
Epoch: 2/100 |  Train Loss: 1.082, Val Loss: 1.379, Train DICE: 0.355, Val DICE: 0.487
Best validation dice coefficient improved to 0.487
Epoch: 3/100 |  Train Loss: 0.504, Val Loss: 0.285, Train DICE: 0.871, Val DICE: 0.988
Best validation dice coefficient improved to 0.988
Epoch: 4/100 |  Train Loss: 0.143, Val Loss: 0.124, Train DICE: 0.984, Val DICE: 0.990
Best validation dice coefficient improved to 0.990
Epoch: 5/100 |  Train Loss: 0.089, Val Loss: 0.087, Train DICE: 0.985, Val DICE: 0.990
Best validation dice coefficient improved to 0.990
Epoch: 6/100 |  Train Loss: 0.077, Val Loss: 0.079, Train DICE: 0.985, Val DICE: 0.989
No improvement in validation dice coefficient for 1 consecutive epochs
Epoch: 7/100 |  Train Loss: 0.070, Val Loss: 0.069, Train DICE: 0.985, Val DICE: 0.990
Best validation dice coefficient improved to 0.990
Epoch: 8/100 |

In [None]:
# Mean per-batch test Dice (IOP site) of the trained U-Net.
prediction_dice(model, test_dataloader) #unet

0.9309134400427358

In [None]:
# Train SegNet (1 input channel -> 5 sigmoid class maps) with Adam and ReduceLROnPlateau
# on the validation Dice, with early stopping.
num_epochs = 100
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SegNet(1, 5).to(device)
print("Number of parameters: ", count_parameters(model))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=3)
start_time = time.time()
# train_model_early_stopping returns five values; the last is the number of epochs run.
train_loss_history, train_dice_history, val_loss_history, val_dice_history, epochs_run = train_model_early_stopping(train_dataloader, val_dataloader, bce_dice_loss, optimizer, scheduler, num_epochs)
end_time = time.time()
print("time taken", end_time - start_time)

In [None]:
# Mean per-batch test Dice (IOP site) of the trained SegNet.
prediction_dice(model, test_dataloader) #segnet

0.8644542985718776