Skip to content

Commit

Permalink
More intermediate changes
Browse files Browse the repository at this point in the history
  • Loading branch information
shayansiddiqui committed Oct 23, 2018
1 parent df15d72 commit 05aa676
Show file tree
Hide file tree
Showing 8 changed files with 358 additions and 245 deletions.
206 changes: 108 additions & 98 deletions Run_QuickNAT.ipynb

Large diffs are not rendered by default.

Binary file removed models/quicknat_model.model
Binary file not shown.
64 changes: 62 additions & 2 deletions quickNat_pytorch/data_utils.py
Expand Up @@ -27,6 +27,67 @@ def __getitem__(self, index):
def __len__(self):
return len(self.y)

def get_imdb_data():
# TODO: Need to change later
NumClass = 28
data_root = "../data/MALC_Coronal_Data/"
# Load DATA
Data = h5py.File(data_root+'Data.h5', 'r')
a_group_key = list(Data.keys())[0]
Data = list(Data[a_group_key])
Data = np.squeeze(np.asarray(Data))
Data = Data.astype(np.float32)
Label = h5py.File(data_root+'label.h5', 'r')
a_group_key = list(Label.keys())[0]
Label = list(Label[a_group_key])
Label = np.squeeze(np.asarray(Label))
Label = Label.astype(np.float32)
set = h5py.File(data_root+'set.h5', 'r')
a_group_key = list(set.keys())[0]
set = list(set[a_group_key])
set = np.squeeze(np.asarray(set))
sz = Data.shape
Data = Data.reshape([sz[0], 1, sz[1], sz[2]])
weights = Label[:,1,:,:]
Label = Label[:,0,:,:]
sz = Label.shape
print(sz)
Label = Label.reshape([sz[0], 1, sz[1], sz[2]])
weights = weights.reshape([sz[0], 1, sz[1], sz[2]])
train_id = set == 1
test_id = set == 3

Tr_Dat = Data[train_id, :, :, :]
Tr_Label = np.squeeze(Label[train_id, :, :, :]) - 1
Tr_weights = weights[train_id, :, :, :]
Tr_weights = np.tile(Tr_weights, [1, NumClass, 1, 1])
print(np.amax(Tr_Label))
print(np.amin(Tr_Label))

Te_Dat = Data[test_id, :, :, :]
Te_Label = np.squeeze(Label[test_id, :, :, :]) - 1
Te_weights = weights[test_id, :, :, :]
Te_weights = np.tile(Te_weights, [1, NumClass, 1, 1])

del Data
del Label
del weights

# sz = Tr_Dat.shape
# sz_test = Te_Dat.shape
# y2 = np.ones((sz[0], NumClass, sz[2], sz[3]))
# y_test = np.ones((sz_test[0], NumClass, sz_test[2], sz_test[3]))
# for i in range(NumClass):
# y2[:, i, :, :] = np.squeeze(np.multiply(np.ones(Tr_Label.shape), ((Tr_Label == i))))
# y_test[:, i, :, :] = np.squeeze(np.multiply(np.ones(Te_Label.shape), ((Te_Label == i))))
#
# Tr_Label_bin = y2
# Te_Label_bin = y_test

return (ImdbData(Tr_Dat, Tr_Label, Tr_weights),ImdbData(Te_Dat, Te_Label, Te_weights))
# return (ImdbData(Tr_Dat, Tr_Label, Tr_Label_bin, Tr_weights),
# ImdbData(Te_Dat, Te_Label, Te_Label_bin, Te_weights))

def get_data(data_params):
Data_train = h5py.File(os.path.join(data_params['base_dir'], data_params['train_data_file'] ), 'r')
Label_train = h5py.File(os.path.join(data_params['base_dir'], data_params['train_label_file'] ), 'r')
Expand Down Expand Up @@ -113,8 +174,7 @@ def _remove_black(data, labels, background_label):
clean_data, clean_labels = [], []
for i, frame in enumerate(labels):
unique, counts = np.unique(frame, return_counts=True)
idx = np.where(unique == background_label)[0][0]
if counts[idx] / sum(counts) < .9:
if counts[0] / sum(counts) < .95:
clean_labels.append(frame)
clean_data.append(data[i])
return np.array(clean_data), np.array(clean_labels)
Expand Down
149 changes: 97 additions & 52 deletions quickNat_pytorch/log_utils.py
Expand Up @@ -7,70 +7,100 @@
from sklearn.metrics import confusion_matrix
import itertools
import torch
import threading
import math
import pandas as pd
import os
import shutil

plt.switch_backend('agg')
plt.axis('scaled')

def _dice_confusion_matrix(batch_output, labels_batch, num_classes):
dice_cm = torch.zeros(num_classes,num_classes)
batch_size, H, W = batch_output.size()
for i in range(num_classes):
GT = (labels_batch == i).float()
for j in range(num_classes):
Pred = (batch_output == j).float()
inter = torch.sum(torch.mul(GT, Pred)) + 0.0001
#union = torch.sum(GT) + torch.sum(Pred) + 0.0001
#dice_cm[i,j] = 2 * torch.div(inter, union)
dice_cm[i,j] = inter / (batch_size * H * W)

avg_dice = torch.mean(torch.diagflat(dice_cm))
return avg_dice, dice_cm

def _dice_score_perclass(batch_output, labels, num_classes):
dice_perclass = torch.zeros(num_classes)
for i in range(num_classes):
GT = (labels == i).float()
Pred = (batch_output == i).float()
inter = torch.sum(torch.mul(GT, Pred)) + 0.0001
union = torch.sum(GT) + torch.sum(Pred) + 0.0001
dice_perclass[i] = (2 * torch.div(inter, union)) / len(batch_output)

return dice_perclass

class LogWriter:
def __init__(self, num_class, cm_cmap = plt.cm.Blues, cm_normalized= True):
self.train_writer = SummaryWriter("logs/train")
self.val_writer = SummaryWriter("logs/val")
self.cm_cmap = cm_cmap
def __init__(self, num_class, log_dir_name, exp_dir_name, use_last_checkpoint=False, labels=None, cm_cmap = plt.cm.Blues):
self.num_class=num_class
self.cm_normalized = cm_normalized
train_log_path, val_log_path = os.path.join(log_dir_name, exp_dir_name, "train"), os.path.join(log_dir_name, exp_dir_name, "val")
if not use_last_checkpoint:
if os.path.exists(train_log_path):
shutil.rmtree(train_log_path)
if os.path.exists(val_log_path):
shutil.rmtree(val_log_path)

self.writer = {
'train' : SummaryWriter(train_log_path),
'val' : SummaryWriter(val_log_path)
}

self.cm_cmap = cm_cmap
self._cm = {
'train': [],
'val': []
'train': torch.zeros(self.num_class, self.num_class),
'val': torch.zeros(self.num_class, self.num_class)
}
self.fig = plt.figure() #For confusion matrix
self._ds = torch.zeros(self.num_class)
self.labels = labels

def loss_per_iter(self, loss_value, i):
print('train : [iteration : ' + str(i) + '] : ' + str(loss_value))
self.train_writer.add_scalar('loss/per_iteration', loss_value, i)
self.writer['train'].add_scalar('loss/per_iteration', loss_value, i)

def loss_per_epoch(self, train_loss_value, val_loss_value, epoch, num_epochs):
self.train_writer.add_scalar('loss/per_epoch', train_loss_value, epoch)
self.val_writer.add_scalar('loss/per_epoch', val_loss_value, epoch)
print('[Epoch : ' + str(epoch) + '/' + str(num_epochs) + '] : train loss = ' + str(train_loss_value) + ', val loss = ' + str(val_loss_value))
def loss_per_epoch(self, loss_arr, phase, epoch):
writer = self.writer[phase]
if phase == 'train':
loss = loss_arr[-1]
else:
loss = np.mean(loss_arr)
self.writer[phase].add_scalar('loss/per_epoch', loss, epoch)
print('epoch '+phase + ' loss = ' + str(loss))

def close(self):
self.train_writer.close()
self.val_writer.close()

def update_cm_per_iter(self, predicted_labels, correct_labels, labels, phase):
self._cm[phase].append(confusion_matrix(correct_labels.flatten(), predicted_labels.flatten(), range(self.num_class)))


def image_per_epoch(self, prediction, ground_truth, phase, epoch):
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 10))
ax[0].imshow(prediction, cmap = 'jet', vmin=0, vmax=self.num_class-1)
ax[0].set_title("Predicted", fontsize=10, color = "blue")
ax[0].axis('off')
ax[1].imshow(ground_truth, cmap = 'jet', vmin=0, vmax=self.num_class-1)
ax[1].set_title("Ground Truth", fontsize=10, color = "blue")
ax[1].axis('off')
def update_cm_per_iter(self, predictions, correct_labels, phase):
_, batch_output = torch.max(predictions, dim=1)
_, cm_batch = _dice_confusion_matrix(batch_output, correct_labels, self.num_class)
self._cm[phase]+=cm_batch.cpu()
del cm_batch, batch_output

if phase == 'train':
self.train_writer.add_figure('sample_prediction/' + phase, fig, epoch)
else:
self.val_writer.add_figure('sample_prediction/' + phase, fig, epoch)

def reset_cms(self):
self._cm = {key : [] for key, item in self._cm.items()}

def update_dice_score_per_iteration(self, predictions, correct_labels, epoch):
_, batch_output = torch.max(predictions, dim=1)
score_vector = _dice_score_perclass(batch_output, correct_labels, self.num_class)
self._ds += score_vector.cpu()

def cm_per_epoch(self, labels, phase, epoch, iteration):
cm = np.mean(self._cm[phase], axis = 0)
print("CM Shape : ", cm.shape)
def dice_score_per_epoch(self, epoch, i_batch):
ds = (self._ds / (i_batch + 1)).cpu().numpy()
self.writer['val'].add_histogram("dice_score/", ds, epoch)
self.writer['val'].add_text("dice_score/", str(ds), epoch)

if self.cm_normalized:
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

fig = matplotlib.figure.Figure(figsize=(10, 10), dpi=360, facecolor='w', edgecolor='k')
def cm_per_epoch(self, phase, epoch, i_batch):
cm = (self._cm[phase] / (i_batch + 1)).cpu().numpy()
fig = matplotlib.figure.Figure(figsize=(10, 10), dpi=180, facecolor='w', edgecolor='k')
ax = fig.add_subplot(1, 1, 1)

classes = [re.sub(r'([a-z](?=[A-Z])|[A-Z](?=[A-Z][a-z]))', r'\1 ', x) for x in labels]
classes = [re.sub(r'([a-z](?=[A-Z])|[A-Z](?=[A-Z][a-z]))', r'\1 ', x) for x in self.labels]
classes = ['\n'.join(wrap(l, 40)) for l in classes]

tick_marks = np.arange(len(classes))
Expand All @@ -88,16 +118,31 @@ def cm_per_epoch(self, labels, phase, epoch, iteration):
ax.yaxis.set_label_position('left')
ax.yaxis.tick_left()

fmt = '.2f' if self.cm_normalized else 'd'
thresh = cm.max() / 2.
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
ax.text(j, i, format(cm[i, j], fmt) if cm[i,j]!=0 else '.', horizontalalignment="center", fontsize=6, verticalalignment='center', color= "white" if cm[i, j] > thresh else "black")
ax.text(j, i, format(cm[i, j], '.2f') if cm[i,j]!=0 else '.', horizontalalignment="center", fontsize=6, verticalalignment='center', color= "white" if cm[i, j] > thresh else "black")

fig.set_tight_layout(True)
np.set_printoptions(precision=2)
if phase == 'train':
self.train_writer.add_figure('confusion_matrix/' + phase, fig, epoch)
else:
self.val_writer.add_figure('confusion_matrix/' + phase, fig, epoch)

self.writer[phase].add_figure('confusion_matrix/' + phase, fig, epoch)


def image_per_epoch(self, prediction, ground_truth, phase, epoch):
ncols = 2
nrows = len(prediction)
fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=(10, 20))

for i in range(nrows):
ax[i][0].imshow(prediction[i], cmap = 'jet', vmin=0, vmax=self.num_class-1)
ax[i][0].set_title("Predicted", fontsize=10, color = "blue")
ax[i][0].axis('off')
ax[i][1].imshow(ground_truth[i], cmap = 'jet', vmin=0, vmax=self.num_class-1)
ax[i][1].set_title("Ground Truth", fontsize=10, color = "blue")
ax[i][1].axis('off')
fig.set_tight_layout(True)
self.writer[phase].add_figure('sample_prediction/' + phase, fig, epoch)

def close(self):
self.writer['train'].close()
self.writer['val'].close()

10 changes: 5 additions & 5 deletions quickNat_pytorch/net_api/losses.py
@@ -1,6 +1,6 @@
import torch
import numpy as np
from torch.nn.modules.loss import _Loss
from torch.nn.modules.loss import _Loss, _WeightedLoss
from torch.autograd import Function, Variable
import torch.nn as nn
import torch.nn.functional as F
Expand Down Expand Up @@ -33,7 +33,7 @@ def dice_coeff(input, target):
return s / (i + 1)


class DiceLoss(_Loss):
class DiceLoss(_WeightedLoss):
def forward(self, output, target, weights=None, ignore_index=None):
"""
output : NxCxHxW Variable
Expand Down Expand Up @@ -69,7 +69,7 @@ def forward(self, output, target, weights=None, ignore_index=None):
return loss_per_channel.sum() / output.size(1)


class CrossEntropyLoss2d(nn.Module):
class CrossEntropyLoss2d(_WeightedLoss):
def __init__(self, weight=None):
super(CrossEntropyLoss2d, self).__init__()
self.nll_loss = nn.CrossEntropyLoss(weight)
Expand All @@ -78,15 +78,15 @@ def forward(self, inputs, targets):
return self.nll_loss(inputs, targets)


class CombinedLoss(nn.Module):
class CombinedLoss(_Loss):
def __init__(self):
super(CombinedLoss, self).__init__()
self.cross_entropy_loss = CrossEntropyLoss2d()
self.dice_loss = DiceLoss()

def forward(self, input, target, weight):
# TODO: why?
target = target.type(torch.LongTensor).cuda()
#target = target.type(torch.LongTensor).cuda()
input_soft = F.softmax(input, dim = 1)
y2 = torch.mean(self.dice_loss(input_soft, target))
y1 = torch.mean(torch.mul(self.cross_entropy_loss.forward(input, target), weight))
Expand Down
3 changes: 2 additions & 1 deletion quickNat_pytorch/net_api/sub_module.py
Expand Up @@ -54,7 +54,7 @@ def __init__(self, params):
self.batchnorm2 = nn.BatchNorm2d(num_features=conv1_out_size)
self.batchnorm3 = nn.BatchNorm2d(num_features=conv2_out_size)
self.prelu = nn.PReLU()
if params['drop_out'] > 0.0:
if params['drop_out'] > 0:
self.drop_out_needed = True
self.drop_out = nn.Dropout2d(params['drop_out'])
else:
Expand Down Expand Up @@ -86,6 +86,7 @@ def forward(self, input):
if self.se_block_type is not se.SELayer.NONE:
out_block = self.SELayer(out_block)


if self.drop_out_needed:
out_block = self.drop_out(out_block)

Expand Down
18 changes: 10 additions & 8 deletions quickNat_pytorch/quickNAT.py
Expand Up @@ -9,7 +9,7 @@
class quickNAT(nn.Module):
"""
A PyTorch implementation of QuickNAT
Coded by Abhijit
Coded by Abhijit and Shayan
param ={
'num_channels':1,
Expand All @@ -32,7 +32,6 @@ def __init__(self, params):
self.encode1 = sm.EncoderBlock(params)
params['num_channels'] = 64
self.encode2 = sm.EncoderBlock(params)
# params['num_channels'] = 64 # This can be used to change the numchannels for each block
self.encode3 = sm.EncoderBlock(params)
self.encode4 = sm.EncoderBlock(params)
self.bottleneck = sm.DenseBlock(params)
Expand Down Expand Up @@ -92,20 +91,23 @@ def predict(self, X, enable_dropout = False):
- X: Volume to be predicted
"""
self.eval()
torch.no_grad()

if type(X) is np.ndarray:
X = Variable(torch.Tensor(X).cuda())
X = torch.tensor(X, requires_grad = False).cuda(non_blocking=True)
elif type(X) is torch.Tensor and not X.is_cuda:
X = X.cuda()
X = X.cuda(non_blocking=True)

if enable_dropout:
self.enable_test_dropout()

out = self.forward(X)
with torch.no_grad():
out = self.forward(X)

max_val, idx = torch.max(out,1)
idx = idx.data.cpu().numpy()
idx = np.squeeze(idx)
return idx
prediction = np.squeeze(idx)
del X, out, idx, max_val
return prediction



0 comments on commit 05aa676

Please sign in to comment.