This is a project for medical 3d voxel classification mission for machine learning course.

In [1]:
import torch
import torch.nn as nn
import torch.utils.data as udata
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm_notebook
from GoogLeNet import GoogLeNet
from SimpleNet import SimpleNet
from SimpleNet import CropNet
from SimpleNet import DenseSharp
from tensorboardX import SummaryWriter
from DataProcessing import *

hyper parameters definition

In [2]:
LR = 1e-4
batch_size = 15
EPOCH = 500

First read data from files

In [3]:
voxel_train, seg_train, total_batch_size = data_read('data/train_val/candidate{}.npz', 584, notebook=True)
voxel_test, seg_test, test_batch_size = data_read('data/test/candidate{}.npz', 584, notebook=True)
train_label = pd.read_csv('data/train_val.csv').values[:, 1].astype(int)
print('Read Complete!')


HBox(children=(IntProgress(value=0, description='reading', max=584, style=ProgressStyle(description_width='ini…




HBox(children=(IntProgress(value=0, description='reading', max=584, style=ProgressStyle(description_width='ini…


Read Complete!


In [4]:
train_label = data_to_tensor(train_label)
voxel_train = data_to_tensor(voxel_train)
voxel_test = data_to_tensor(voxel_test)
seg_train = data_to_tensor(seg_train, dtype=torch.bool)
seg_test = data_to_tensor(seg_test, dtype=torch.bool)
train_data = udata.TensorDataset(voxel_train, train_label)
validate_batch_size = round(0.2 * total_batch_size)
train_batch_size = total_batch_size - validate_batch_size
train_data, validate_data = udata.random_split(train_data, [train_batch_size, validate_batch_size])
voxel_loader = udata.DataLoader(train_data, batch_size, shuffle=True)
voxel_loaderv = udata.DataLoader(validate_data, batch_size, shuffle=True)

In [5]:
masked_voxel_train = data_augment(voxel_train, normalize=True, mask=seg_train)
masked_voxel_test = data_augment(voxel_test, normalize=True, mask=seg_test)
masked_voxel_train = masked_voxel_train.unsqueeze(1)
masked_voxel_test = masked_voxel_test.unsqueeze(1)
masked_train_data = udata.TensorDataset(masked_voxel_train, train_label)
masked_train_data, masked_validate_data = udata.random_split(masked_train_data, [train_batch_size, validate_batch_size])
masked_voxel_loader = udata.DataLoader(masked_train_data, batch_size, shuffle=True)
masked_voxel_loaderv = udata.DataLoader(masked_validate_data, batch_size, shuffle=True)

In [6]:
resize_voxel_train = data_resize(voxel_train, seg_train, 32, masked=True)
resize_voxel_test = data_resize(voxel_test, seg_test, 32, masked=True)
resize_voxel_train = resize_voxel_train.unsqueeze(1)
resize_voxel_test = resize_voxel_test.unsqueeze(1)
resize_train_data = udata.TensorDataset(resize_voxel_train, train_label)
resize_train_data, resize_validate_data = udata.random_split(resize_train_data, [train_batch_size, validate_batch_size])
resize_voxel_loader = udata.DataLoader(resize_train_data, batch_size, shuffle=True)
resize_voxel_loaderv = udata.DataLoader(resize_validate_data, batch_size, shuffle=True)

In [7]:
crop_voxel_train = data_crop(voxel_train, seg_train, crop_size=44)
crop_voxel_test = data_crop(voxel_test, seg_test, crop_size=44)
crop_voxel_train = crop_voxel_train.unsqueeze(1)
crop_voxel_test = crop_voxel_test.unsqueeze(1)
crop_train_data = udata.TensorDataset(crop_voxel_train, train_label)
crop_train_data, crop_validate_data = udata.random_split(crop_train_data, [train_batch_size, validate_batch_size])
crop_voxel_loader = udata.DataLoader(crop_train_data, batch_size, shuffle=True)
crop_voxel_loaderv = udata.DataLoader(crop_validate_data, batch_size, shuffle=True)

We can now build the model.

In [None]:
writer = SummaryWriter()
torch.cuda.empty_cache()
net = DenseSharp().cuda()
batch_num = int(train_batch_size / batch_size) + 1
batch_numv = int(validate_batch_size / batch_size) + 1
losses = torch.zeros(EPOCH * batch_num)
accuracys = torch.zeros(EPOCH * batch_num)
lossesv = torch.zeros(EPOCH * batch_num)
accuracysv = torch.zeros(EPOCH * batch_num)
optimizer = torch.optim.Adam(net.parameters(), lr=LR)
loss_func = nn.MSELoss()
max_accuracy = 0
for epoch in range(EPOCH):    
    for j, (voxel, label) in enumerate(tqdm_notebook(resize_voxel_loader, desc='training')):
        voxel = voxel.cuda()
        label = label.cuda()
        prediction = net(voxel)
        loss = loss_func(prediction, label.to(dtype=torch.float32))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses[epoch * batch_num + j] = loss.item()
        prediction[prediction > 0.5] = 1
        prediction[prediction < 0.5] = 0
        accuracy = (((prediction.round() == label.round()).sum()).to(dtype=torch.float32) / batch_size).item()
        accuracys[epoch * batch_num + j] = accuracy
        writer.add_scalar('scalar/loss', loss.item(), epoch * batch_num + j)
        writer.add_scalar('scalar/accuracy', accuracy, epoch * batch_num + j)
    avg_loss = torch.mean(losses[epoch * batch_num:(epoch + 1) * batch_num])
    avg_accuracy = torch.mean(accuracys[epoch * batch_num:(epoch + 1) * batch_num])
    print('loss: ', avg_loss.item())
    print('accuracy: ', avg_accuracy.item())
    
    for j, (voxel, label) in enumerate(tqdm_notebook(resize_voxel_loaderv, desc='validating')):
        voxel = voxel.cuda()
        label = label.cuda()
        prediction = net(voxel)
        loss = loss_func(prediction, label.to(dtype=torch.float32))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lossesv[epoch * batch_numv + j] = loss.item()
        prediction[prediction > 0.5] = 1
        prediction[prediction < 0.5] = 0
        accuracy = (((prediction.round() == label.round()).sum()).to(dtype=torch.float32) / batch_size).item()
        accuracysv[epoch * batch_numv + j] = accuracy
        writer.add_scalar('scalar/loss_validate', loss.item(), epoch * batch_numv + j)
        writer.add_scalar('scalar/accuracy_validate', accuracy, epoch * batch_numv + j)
    avg_loss = torch.mean(lossesv[epoch * batch_numv:(epoch + 1) * batch_numv])
    avg_accuracy = torch.mean(accuracysv[epoch * batch_numv:(epoch + 1) * batch_numv])
    print('loss: ', avg_loss.item())
    print('accuracy: ', avg_accuracy.item())
    
    if avg_accuracy > max_accuracy:
        max_accuracy = avg_accuracy.item()
        torch.save(net.state_dict(), 'net_best.pkl')


HBox(children=(IntProgress(value=0, description='training', max=25, style=ProgressStyle(description_width='ini…


loss:  0.2390819489955902
accuracy:  0.6186667084693909


HBox(children=(IntProgress(value=0, description='validating', max=7, style=ProgressStyle(description_width='in…


loss:  0.22302262485027313
accuracy:  0.6095238327980042


HBox(children=(IntProgress(value=0, description='training', max=25, style=ProgressStyle(description_width='ini…


loss:  0.2269986867904663
accuracy:  0.6560001373291016


HBox(children=(IntProgress(value=0, description='validating', max=7, style=ProgressStyle(description_width='in…


loss:  0.23836612701416016
accuracy:  0.571428656578064


HBox(children=(IntProgress(value=0, description='training', max=25, style=ProgressStyle(description_width='ini…


loss:  0.2245134860277176
accuracy:  0.626666784286499


HBox(children=(IntProgress(value=0, description='validating', max=7, style=ProgressStyle(description_width='in…


loss:  0.2516498267650604
accuracy:  0.5809523463249207


HBox(children=(IntProgress(value=0, description='training', max=25, style=ProgressStyle(description_width='ini…


loss:  0.22695964574813843
accuracy:  0.6480001211166382


HBox(children=(IntProgress(value=0, description='validating', max=7, style=ProgressStyle(description_width='in…


loss:  0.215408593416214
accuracy:  0.6095237731933594


HBox(children=(IntProgress(value=0, description='training', max=25, style=ProgressStyle(description_width='ini…


loss:  0.22206443548202515
accuracy:  0.6293334364891052


HBox(children=(IntProgress(value=0, description='validating', max=7, style=ProgressStyle(description_width='in…


loss:  0.22967903316020966
accuracy:  0.5809524655342102


HBox(children=(IntProgress(value=0, description='training', max=25, style=ProgressStyle(description_width='ini…


loss:  0.2143627554178238
accuracy:  0.642666757106781


HBox(children=(IntProgress(value=0, description='validating', max=7, style=ProgressStyle(description_width='in…


loss:  0.24214176833629608
accuracy:  0.5428572297096252


HBox(children=(IntProgress(value=0, description='training', max=25, style=ProgressStyle(description_width='ini…


loss:  0.22524717450141907
accuracy:  0.642666757106781


HBox(children=(IntProgress(value=0, description='validating', max=7, style=ProgressStyle(description_width='in…


loss:  0.2232247292995453
accuracy:  0.5809523463249207


HBox(children=(IntProgress(value=0, description='training', max=25, style=ProgressStyle(description_width='ini…


loss:  0.21820291876792908
accuracy:  0.6560000777244568


HBox(children=(IntProgress(value=0, description='validating', max=7, style=ProgressStyle(description_width='in…


loss:  0.23259107768535614
accuracy:  0.571428656578064


HBox(children=(IntProgress(value=0, description='training', max=25, style=ProgressStyle(description_width='ini…


loss:  0.2140943706035614
accuracy:  0.6906667351722717


HBox(children=(IntProgress(value=0, description='validating', max=7, style=ProgressStyle(description_width='in…


loss:  0.20009945333003998
accuracy:  0.6095237731933594


HBox(children=(IntProgress(value=0, description='training', max=25, style=ProgressStyle(description_width='ini…


loss:  0.2210243195295334
accuracy:  0.6666668057441711


HBox(children=(IntProgress(value=0, description='validating', max=7, style=ProgressStyle(description_width='in…


loss:  0.20609258115291595
accuracy:  0.5809524655342102


HBox(children=(IntProgress(value=0, description='training', max=25, style=ProgressStyle(description_width='ini…


loss:  0.20810005068778992
accuracy:  0.6933334469795227


HBox(children=(IntProgress(value=0, description='validating', max=7, style=ProgressStyle(description_width='in…


loss:  0.2161097228527069
accuracy:  0.5904762148857117


HBox(children=(IntProgress(value=0, description='training', max=25, style=ProgressStyle(description_width='ini…


loss:  0.20738358795642853
accuracy:  0.6853334903717041


HBox(children=(IntProgress(value=0, description='validating', max=7, style=ProgressStyle(description_width='in…


loss:  0.23065759241580963
accuracy:  0.6190477013587952


HBox(children=(IntProgress(value=0, description='training', max=25, style=ProgressStyle(description_width='ini…


loss:  0.20133857429027557
accuracy:  0.6746667623519897


HBox(children=(IntProgress(value=0, description='validating', max=7, style=ProgressStyle(description_width='in…


loss:  0.1966933310031891
accuracy:  0.5999999642372131


HBox(children=(IntProgress(value=0, description='training', max=25, style=ProgressStyle(description_width='ini…


loss:  0.20041730999946594
accuracy:  0.7146668434143066


HBox(children=(IntProgress(value=0, description='validating', max=7, style=ProgressStyle(description_width='in…


loss:  0.19175343215465546
accuracy:  0.6285714507102966


HBox(children=(IntProgress(value=0, description='training', max=25, style=ProgressStyle(description_width='ini…


loss:  0.19453196227550507
accuracy:  0.7226668000221252


HBox(children=(IntProgress(value=0, description='validating', max=7, style=ProgressStyle(description_width='in…


loss:  0.19470271468162537
accuracy:  0.5904761552810669


HBox(children=(IntProgress(value=0, description='training', max=25, style=ProgressStyle(description_width='ini…


loss:  0.19520659744739532
accuracy:  0.6960001587867737


HBox(children=(IntProgress(value=0, description='validating', max=7, style=ProgressStyle(description_width='in…


loss:  0.20832403004169464
accuracy:  0.6285714507102966


HBox(children=(IntProgress(value=0, description='training', max=25, style=ProgressStyle(description_width='ini…


loss:  0.21004609763622284
accuracy:  0.661333441734314


HBox(children=(IntProgress(value=0, description='validating', max=7, style=ProgressStyle(description_width='in…


loss:  0.22069552540779114
accuracy:  0.6095238327980042


HBox(children=(IntProgress(value=0, description='training', max=25, style=ProgressStyle(description_width='ini…


loss:  0.20112121105194092
accuracy:  0.6933334469795227


HBox(children=(IntProgress(value=0, description='validating', max=7, style=ProgressStyle(description_width='in…


loss:  0.18599964678287506
accuracy:  0.6380952596664429


HBox(children=(IntProgress(value=0, description='training', max=25, style=ProgressStyle(description_width='ini…


loss:  0.19320380687713623
accuracy:  0.6986667513847351


HBox(children=(IntProgress(value=0, description='validating', max=7, style=ProgressStyle(description_width='in…


loss:  0.19714342057704926
accuracy:  0.6095237731933594


HBox(children=(IntProgress(value=0, description='training', max=25, style=ProgressStyle(description_width='ini…


loss:  0.18300144374370575
accuracy:  0.7493334412574768


HBox(children=(IntProgress(value=0, description='validating', max=7, style=ProgressStyle(description_width='in…


loss:  0.17974352836608887
accuracy:  0.6857143044471741


HBox(children=(IntProgress(value=0, description='training', max=25, style=ProgressStyle(description_width='ini…

visualization

In [None]:
writer.close()
print(max_accuracy)

In [None]:
torch.cuda.empty_cache()
batch_size = 1
net.load_state_dict(torch.load('net_best.pkl'))
net = net.cpu()
test_loader = udata.DataLoader(resize_voxel_test, batch_size, shuffle=False)
prediction = torch.zeros(test_batch_size)
for j, voxel in enumerate(tqdm_notebook(test_loader)):
    # voxel = voxel.cuda()
    predict = net(voxel)
    prediction[j] = predict.cpu().detach()

In [None]:
submit = pd.read_csv('data/sampleSubmission.csv')
submit['Predicted'] = prediction.numpy()
submit.to_csv('submit120705.csv', index=False)


In [None]:
prediction_abs = prediction.clone()
prediction_abs[prediction_abs >= 0.5] = 1
prediction_abs[prediction_abs < 0.5] = 0
submit = pd.read_csv('data/sampleSubmission.csv')
submit['Predicted'] = prediction_abs.numpy()
submit.to_csv('submit120702_zc.csv', index=False)

In [None]:
print(prediction.numpy())

In [None]:
l1, = plt.plot(list(range(500)), accuracysv[2500:3000])
plt.legend(handles=[l1], labels=['train'], loc='best')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.title('Loss on Train Set')
plt.show()
