In [5]:
import os
import numpy as np
from matplotlib import pyplot as plt
import visdom_plot
from PIL import Image
import gc

In [6]:
from sklearn.model_selection import KFold
import torch
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from torchvision.utils import save_image
import torch.utils as utils
import torch.optim as optim
import radam
import time

In [26]:
import inspect

# Scratch cell: three sample values and a list holding them in the same order.
x, y, z = 1, 2, 3
l = [x, y, z]

def retrieve_name(var):
    """Return the name of the first variable in the caller's scope bound to ``var``.

    Matching is by identity (``is``), so for interned objects such as small ints
    or short strings, any earlier caller-local name bound to the same object may
    be returned instead of the one you passed.

    Raises:
        RuntimeError: if the interpreter does not expose stack frames.
        IndexError: if no variable in the caller's local scope refers to ``var``.
    """
    frame = inspect.currentframe()
    if frame is None or frame.f_back is None:
        # inspect.currentframe() may return None on non-CPython interpreters.
        raise RuntimeError('retrieve_name() requires interpreter stack frame support')
    try:
        for var_name, var_val in frame.f_back.f_locals.items():
            if var_val is var:
                return var_name
    finally:
        # Drop the frame reference to avoid a reference cycle (see inspect docs).
        del frame
    raise IndexError('no variable in the calling scope refers to the given object')

# Scratch: show the list, then look up which caller-scope name is bound to `y`.
# NOTE(review): `l.index(find)` would raise ValueError, because `find` is a
# variable name (str), not an element of `l`.
print(l)
find = retrieve_name(y)

[1, 2, 3]


In [27]:
# Scratch: print the final character of a sample string.
s = 'thisisme4'
print(s[len(s) - 1])

4


In [12]:
# Global torch configuration used by every later cell.
SEED = 0
# Seed torch so weight init and dropout are reproducible under Restart & Run All.
torch.manual_seed(SEED)

# set_printoptions only overrides the options that are passed (not None),
# so this single call is equivalent to the two separate ones it replaces.
torch.set_printoptions(linewidth=30, edgeitems=20)
torch.set_grad_enabled(True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Module-level PIL image -> float tensor conversion. (`global` is a no-op at
# notebook top level, so it is not needed to make this name module-wide.)
transform = transforms.Compose([
    transforms.ToTensor(),
    # transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])

In [4]:
def calculate_average(num):
    """Return the arithmetic mean of ``num``.

    Args:
        num: any iterable of numbers (list, tuple, generator, ...).

    Raises:
        ValueError: if ``num`` is empty, instead of a bare ZeroDivisionError.
    """
    # Materialize first so generators work (len() is not defined for them).
    values = list(num)
    if not values:
        raise ValueError('calculate_average() needs at least one value')
    return sum(values) / len(values)


def check_tensor():
    """Collect garbage, release cached CUDA memory and count live tensors.

    The previous version ran ``del obj`` on each tensor found via ``gc``. That
    only unbinds the loop variable, so no tensor was ever freed. Memory can only
    be reclaimed once all real references are gone. This function therefore
    runs the cycle collector, then asks the CUDA caching allocator (if CUDA is
    available) to release unused blocks, once.

    Returns:
        int: number of tensor objects (or objects whose ``.data`` is a tensor)
        still tracked by ``gc``. Useful for spotting leaks.
    """
    gc.collect()

    live_tensors = 0
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj) or (hasattr(obj, 'data')
                                        and torch.is_tensor(obj.data)):
                live_tensors += 1
        except Exception:
            # Some tracked objects raise on attribute access; they are not
            # tensors we care about, so skip them.
            continue

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return live_tensors


def read_data(raw_dir, img_dir, cross=True, train_size=1600, test_size=400,
              volume_shape=(64, 64, 64)):
    """Load raw uint8 volumes and their PNG labels from disk.

    Processing of each ``<stem>.raw`` file in ``raw_dir``:
    - Its voxels are read as uint8.
    - Half of its isovalue is subtracted. The isovalue is parsed from the
      second ``_``-separated field of the stem, e.g. ``vol_120.raw`` -> 120.
    - The result is divided by 255 and reshaped to ``(1, *volume_shape)``.
    - Its label is ``img_dir/<stem>.png``, loaded as a 3-channel float tensor.

    Files are visited in sorted order, so the split is identical on every
    machine (``os.listdir`` order is arbitrary). Volumes without a matching PNG
    are skipped, so data, labels and names stay index-aligned. Previously an
    unlabeled volume shifted every later label onto the wrong volume. Files
    that do not end in ``.raw`` are ignored.

    Args:
        raw_dir: directory containing ``.raw`` volumes.
        img_dir: directory containing the matching ``.png`` labels.
        cross: select the return layout (see Returns).
        train_size: number of leading samples used for training when ``cross``.
        test_size: number of samples after those used as the test split.
        volume_shape: spatial shape of every volume.

    Returns:
        If ``cross``: ``(train_data, train_labels, test_data, test_labels,
        test_names)`` as lists of tensors / file names.
        Otherwise: ``(TensorDataset(volumes, labels), names)`` over all samples.
    """
    to_tensor = transforms.ToTensor()
    list_data, list_label, raw_name = [], [], []

    for filename in sorted(os.listdir(raw_dir)):
        stem, ext = os.path.splitext(filename)
        if ext != '.raw':
            continue
        label_path = os.path.join(img_dir, stem + '.png')
        if not os.path.isfile(label_path):
            continue

        isovalue = int(stem.split('_')[1])
        list_data.append(_load_volume(os.path.join(raw_dir, filename),
                                      isovalue, volume_shape))
        list_label.append(_load_label(label_path, to_tensor))
        raw_name.append(filename)

    if cross:
        split_end = train_size + test_size
        return (list_data[:train_size], list_label[:train_size],
                list_data[train_size:split_end],
                list_label[train_size:split_end],
                raw_name[train_size:split_end])

    tensor_data = torch.stack(list_data)
    tensor_label = torch.stack(list_label)
    return utils.data.TensorDataset(tensor_data, tensor_label), raw_name


def _load_volume(path, isovalue, volume_shape):
    """Read one uint8 volume, centre it on isovalue/2, scale by 1/255, add a channel axis."""
    voxels = np.fromfile(path, dtype='uint8')
    # f = ((voxels.astype('float') - isovalue / 2) / 255) * 2
    shifted = (voxels.astype('float') - isovalue / 2) / 255
    return torch.Tensor(shifted).reshape([1, *volume_shape])


def _load_label(path, to_tensor):
    """Read one PNG label as a 3-channel float tensor and close the file handle."""
    # The network's last layer emits 3 channels, so force RGB. Grayscale, LA or
    # RGBA PNGs would otherwise give 1, 2 or 4 channels and mismatch the loss.
    with Image.open(path) as img:
        return to_tensor(img.convert('RGB'))


def flatten_indices(indices):
    """Turn 3-D max-pool indices of the first pooled depth slice into 2-D ones.

    ``indices`` is the 5-D (N, C, D, H, W) index tensor produced by
    ``MaxPool3d(..., return_indices=True)``. Only depth slice 0 is kept.
    Assuming 2x2x2 pooling, the un-pooled plane holds ``H * W * 4`` voxels.
    Offsets at or beyond one plane are shifted back by one plane so they
    index into a single 2-D plane. The result is an (N, C, H, W) int64 tensor,
    as ``MaxUnpool2d`` expects.
    """
    first_slice = indices[:, :, 0, :, :]
    pooled_h, pooled_w = first_slice.size()[2], first_slice.size()[3]
    plane = pooled_h * pooled_w * 4

    as_int = first_slice.int()
    wrapped = torch.where(first_slice >= plane, as_int - plane, as_int)
    return wrapped.long()


def visualize_output(batch, nrows=4, ncols=4, figsize=(25, 25)):
    """Plot a batch of (C, H, W) images in a grid of subplots with axes hidden.

    The grid has ``ncols`` columns and at least ``nrows`` rows. More rows are
    added when the batch does not fit. Previously more than 16 images raised
    IndexError. Images are detached and moved to the CPU before plotting, so
    CUDA or grad-tracking tensors can be passed directly.

    Args:
        batch: sequence of (C, H, W) tensors, e.g. a 4-D batch tensor.
        nrows: minimum number of grid rows (default keeps the old 4x4 layout).
        ncols: number of grid columns.
        figsize: matplotlib figure size in inches.
    """
    n_images = len(batch)
    # Ceil-divide without importing math: enough rows to hold every image.
    nrows = max(nrows, -(-n_images // ncols))
    # squeeze=False keeps `axes` 2-D even for a single row/column.
    _, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize,
                           squeeze=False)
    flat_axes = axes.flatten()
    for ax in flat_axes:
        ax.axis('off')

    for ax, img in zip(flat_axes, batch):
        # imshow expects HWC data on the CPU.
        ax.imshow(img.detach().cpu().permute(1, 2, 0))


def draw_image(pred, image):
    """Show the predicted batch, then the ground-truth batch, as two image grids."""
    for batch in (pred, image):
        visualize_output(batch.detach())


def generate_sample(train_idx, valid_idx, list_data, list_label, test_data,
                    test_label, batch_size, fold_num, shuffle_train=True):
    """Build the train/valid/test DataLoaders for one cross-validation fold.

    Args:
        train_idx, valid_idx: indices into ``list_data``/``list_label`` for this fold.
        list_data, list_label: per-sample input and target tensors.
        test_data, test_label: held-out test tensors (same for every fold).
        batch_size: batch size for all three loaders.
        fold_num: fold number, only used for the banner that is printed.
        shuffle_train: reshuffle the training set every epoch. KFold yields
            train indices in ascending order, so without shuffling every epoch
            sees the samples in the same file order.

    Returns:
        (train_loader, valid_loader, test_loader). Validation and test loaders
        keep their order, so predictions line up with the test file names.
    """
    def stack_at(items, idx):
        # Gather the selected per-sample tensors into one batch tensor.
        return torch.stack([items[i] for i in idx])

    train_dataset = utils.data.TensorDataset(stack_at(list_data, train_idx),
                                             stack_at(list_label, train_idx))
    valid_dataset = utils.data.TensorDataset(stack_at(list_data, valid_idx),
                                             stack_at(list_label, valid_idx))
    test_dataset = utils.data.TensorDataset(torch.stack(test_data),
                                            torch.stack(test_label))

    print('\n')
    print('-' * 15, " New Fold %s" % fold_num, '-' * 15)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=shuffle_train)
    valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                               batch_size=batch_size)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=batch_size,
                                              shuffle=False)

    return train_loader, valid_loader, test_loader


def cross_validation_split(dataset,
                           sample_size,
                           val_split,
                           batch_size,
                           shuffle=True,
                           random_seed=42):
    """Split ``dataset`` into train/validation DataLoaders via index samplers.

    The first ``floor(val_split * sample_size)`` indices of the (optionally
    shuffled) index list become the validation set, the rest the training set.

    Shuffling uses a private ``np.random.RandomState(random_seed)``. This yields
    exactly the permutation that ``np.random.seed(42); np.random.shuffle(...)``
    produced before, but no longer reseeds NumPy's global RNG as a hidden side
    effect.

    Returns:
        (train_loader, valid_loader), each sampling its indices in random order.
    """
    indices = list(range(sample_size))
    split = int(np.floor(val_split * sample_size))

    if shuffle:
        np.random.RandomState(random_seed).shuffle(indices)

    train_indices, valid_indices = indices[split:], indices[:split]

    train_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=utils.data.SubsetRandomSampler(train_indices))
    valid_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=utils.data.SubsetRandomSampler(valid_indices))

    return train_loader, valid_loader


def save_output(tensor, dir, name):
    """Write each image of a batch of RGB predictions to ``dir`` as a PNG.

    ``name[i]`` is the source file name of sample ``i``. Its ``.raw`` is
    replaced with ``.png`` to form the output name, and existing files are
    overwritten. ``dir`` is created if missing.

    Values are clamped to [0, 1] before conversion. ``ToPILImage`` multiplies
    float tensors by 255 and casts to uint8, so an unclamped 1.1 would wrap
    around to 24 instead of saturating at 255.
    """
    os.makedirs(dir, exist_ok=True)
    to_pil = transforms.ToPILImage(mode='RGB')
    to_tensor = transforms.ToTensor()

    for index in range(tensor.size()[0]):
        out_path = os.path.join(dir, name[index].replace('.raw', '.png'))
        # Round-trip through an 8-bit PIL image (as before) so the saved pixels
        # are quantized to uint8; .cpu() lets CUDA predictions be saved too.
        image = to_pil(tensor[index].detach().cpu().clamp(0, 1))
        save_image(to_tensor(image), out_path)


def train_model(model, optimizer, scheduler, train_loader):
    """Run one training epoch and return the mean smooth-L1 loss per batch.

    ``scheduler.step()`` is called once, after the epoch's optimizer updates.
    This is the order PyTorch >= 1.1 requires; stepping first skips the
    initial learning rate. Each batch is moved to the device of the model's
    parameters.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    num_batches = len(train_loader)
    if num_batches == 0:
        raise ValueError('train_loader is empty; cannot compute a mean loss')

    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    # Accumulate on-device to avoid a host sync on every batch.
    total_loss = torch.zeros((), device=device)
    for volume, image in train_loader:
        optimizer.zero_grad()
        pred = model(volume.to(device))
        # loss = F.mse_loss(pred, image)
        loss = F.smooth_l1_loss(pred, image.to(device))

        loss.backward()
        optimizer.step()
        total_loss += loss.detach()

    scheduler.step()
    return total_loss.item() / num_batches


def evaluate_model(model, scheduler, valid_loader):
    """Return the mean smooth-L1 loss of ``model`` over ``valid_loader``.

    ``scheduler`` is kept in the signature for backward compatibility but is
    deliberately not stepped here. The training step already advances it once
    per epoch, and stepping it again during validation made the LR decay twice
    as fast as configured. Batches are moved to the model's device, and no
    gradients are tracked.

    Raises:
        ValueError: if ``valid_loader`` yields no batches.
    """
    num_batches = len(valid_loader)
    if num_batches == 0:
        raise ValueError('valid_loader is empty; cannot compute a mean loss')

    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    total_loss = torch.zeros((), device=device)
    with torch.no_grad():
        for volume, image in valid_loader:
            pred_valid = model(volume.to(device))
            # loss_valid = F.mse_loss(pred_valid, image)
            loss_valid = F.smooth_l1_loss(pred_valid, image.to(device))
            total_loss += loss_valid

    return total_loss.item() / num_batches


def predict_model(model, test_loader):
    """Run ``model`` over ``test_loader`` and return all predictions on the CPU.

    Predictions are concatenated along dim 0 in loader order. The model is put
    in eval mode for the pass, so dropout is disabled, and its previous
    train/eval mode is restored afterwards. Inputs are moved to the model's
    device and outputs back to the CPU. Previously GPU outputs could not be
    concatenated onto the CPU accumulator. An empty loader yields an empty
    tensor, as before.
    """
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    was_training = model.training
    model.eval()
    preds = []
    try:
        with torch.no_grad():
            for volume, _ in test_loader:
                preds.append(model(volume.to(device)).cpu())
    finally:
        model.train(was_training)

    if not preds:
        return torch.Tensor()
    # A single cat at the end instead of growing the result every batch.
    return torch.cat(preds, 0)

In [5]:
# Input locations (raw volumes, PNG labels), the prediction output folder and
# the path where the trained network's weights are saved.
data_dir = './data/'
raw_dir, img_dir, result_dir = (os.path.join(data_dir, sub)
                                for sub in ('raw/', 'image/', 'result/'))
network_path = os.path.join(data_dir, 'network.pkl')

# Load everything and split it into cross-validation data and a held-out test set.
list_data, list_label, test_data, test_label, raw_name = read_data(
    raw_dir, img_dir, cross=True)

torch.Size([2, 64, 64])
torch.Size([2, 64, 64])
torch.Size([2, 64, 64])
torch.Size([2, 64, 64])
torch.Size([2, 64, 64])
torch.Size([2, 64, 64])
torch.Size([2, 64, 64])
torch.Size([2, 64, 64])
torch.Size([2, 64, 64])
torch.Size([2, 64, 64])
torch.Size([2, 64, 64])
torch.Size([2, 64, 64])
torch.Size([2, 64, 64])
torch.Size([2, 64, 64])
torch.Size([2, 64, 64])
torch.Size([2, 64, 64])
torch.Size([2, 64, 64])
torch.Size([2, 64, 64])
torch.Size([2, 64, 64])
torch.Size([2, 64, 64])
torch.Size([2, 64, 64])
torch.Size([2, 64, 64])
torch.Size([2, 64, 64])
torch.Size([2, 64, 64])
torch.Size([2, 64, 64])
torch.Size([2, 64, 64])
torch.Size([2, 64, 64])
torch.Size([2, 64, 64])
torch.Size([2, 64, 64])
torch.Size([2, 64, 64])
torch.Size([2, 64, 64])
torch.Size([2, 64, 64])
torch.Size([2, 64, 64])
torch.Size([2, 64, 64])
torch.Size([2, 64, 64])
torch.Size([2, 64, 64])
torch.Size([2, 64, 64])
torch.Size([2, 64, 64])
torch.Size([2, 64, 64])
torch.Size([2, 64, 64])
torch.Size([2, 64, 64])
torch.Size([2, 6

KeyboardInterrupt: 

In [None]:
class Network(nn.Module):
    """3-D conv encoder + 2-D max-unpool/deconv decoder mapping volumes to images.

    Input: (N, 1, 64, 64, 64). One input channel comes from ``conv1``, and the
    64³ size follows from ``fc1`` expecting 128*4*4*4 features after four 2x
    poolings. Output: (N, 3, 64, 64), from four 2x unpoolings of a 4x4 map and
    ``deconv4`` emitting 3 channels. Each decoder stage reuses the pooling
    indices of the matching encoder stage (via ``flatten_indices``).
    """

    def __init__(self):
        super(Network, self).__init__()

        # Convolution 1
        self.conv1 = nn.Conv3d(1, 16, kernel_size=3, padding=1)
        nn.init.xavier_uniform_(self.conv1.weight)
        self.prelu1 = nn.PReLU()
        self.max1 = nn.MaxPool3d(kernel_size=(2, 2, 2),
                                 stride=(2, 2, 2),
                                 return_indices=True)

        # Convolution 2
        self.conv2 = nn.Conv3d(16, 32, kernel_size=3, padding=1)
        nn.init.xavier_uniform_(self.conv2.weight)
        self.prelu2 = nn.PReLU()
        self.max2 = nn.MaxPool3d(kernel_size=(2, 2, 2),
                                 stride=(2, 2, 2),
                                 return_indices=True)

        # Convolution 3
        self.conv3 = nn.Conv3d(32, 64, kernel_size=3, padding=1)
        nn.init.xavier_uniform_(self.conv3.weight)
        self.prelu3 = nn.PReLU()
        self.max3 = nn.MaxPool3d(kernel_size=(2, 2, 2),
                                 stride=(2, 2, 2),
                                 return_indices=True)

        # Convolution 4
        self.conv4 = nn.Conv3d(64, 128, kernel_size=3, padding=1)
        nn.init.xavier_uniform_(self.conv4.weight)
        self.prelu4 = nn.PReLU()
        self.max4 = nn.MaxPool3d(kernel_size=(2, 2, 2),
                                 stride=(2, 2, 2),
                                 return_indices=True)

        # Fully Connected / Dense Layer 1: 128x4x4x4 volume -> 128x4x4 image
        self.fc1 = nn.Linear(128 * 4 * 4 * 4, 128 * 4 * 4)
        self.drop = nn.Dropout(0.2)

        # De Convolution 1
        self.maxUn1 = torch.nn.MaxUnpool2d(2, stride=2)
        self.deconv1 = torch.nn.ConvTranspose2d(128, 64, 3, padding=1)
        nn.init.xavier_uniform_(self.deconv1.weight)
        self.prelu5 = nn.PReLU()

        # De Convolution 2
        self.maxUn2 = torch.nn.MaxUnpool2d(2, stride=2)
        self.deconv2 = torch.nn.ConvTranspose2d(64, 32, 3, padding=1)
        nn.init.xavier_uniform_(self.deconv2.weight)
        self.prelu6 = nn.PReLU()

        # De Convolution 3
        self.maxUn3 = torch.nn.MaxUnpool2d(2, stride=2)
        self.deconv3 = torch.nn.ConvTranspose2d(32, 16, 3, padding=1)
        nn.init.xavier_uniform_(self.deconv3.weight)
        self.prelu7 = nn.PReLU()

        # De Convolution 4
        self.maxUn4 = torch.nn.MaxUnpool2d(2, stride=2)
        self.deconv4 = torch.nn.ConvTranspose2d(16, 3, 3, padding=1)
        nn.init.xavier_uniform_(self.deconv4.weight)

    def forward(self, data):
        """Map a (N, 1, 64, 64, 64) batch of volumes to (N, 3, 64, 64) images."""
        out = self.prelu1(self.conv1(data))
        out, indices1 = self.max1(out)

        out = self.prelu2(self.conv2(out))
        out, indices2 = self.max2(out)

        out = self.prelu3(self.conv3(out))
        out, indices3 = self.max3(out)

        out = self.prelu4(self.conv4(out))
        out, indices4 = self.max4(out)

        batch = out.size(0)
        out = out.view(batch, -1)
        out = F.leaky_relu(self.fc1(out))
        # Use the actual batch size; a hard-coded 10 broke any other batch size
        # (including a smaller final batch).
        out = out.view(batch, 128, 4, 4)
        out = self.drop(out)

        # Reduce each 3-D pooling index tensor to 2-D plane indices for MaxUnpool2d.
        indices1 = flatten_indices(indices1)
        indices2 = flatten_indices(indices2)
        indices3 = flatten_indices(indices3)
        indices4 = flatten_indices(indices4)

        out = self.maxUn1(out, indices4)
        out = self.prelu5(self.deconv1(out))

        out = self.maxUn2(out, indices3)
        out = self.prelu6(self.deconv2(out))

        out = self.maxUn3(out, indices2)
        out = self.prelu7(self.deconv3(out))

        out = self.maxUn4(out, indices1)
        out = self.deconv4(out)

        return out

In [None]:
# Hyper Parameters: epochs per fold, mini-batch size, and the optimizer's initial LR.
num_epochs, batch_size, learning_rate = 10, 10, 0.001

In [None]:
# K-fold training driver: train, save the last fold's weights, and write its
# test-set predictions to `result_dir`.
#
# - Only the first `folds_to_run` folds are trained. This replaces the old
#   unconditional `break`; set it to `skf.n_splits` for full cross-validation.
# - Each fold gets a fresh network/optimizer/scheduler, so folds never share weights.
# - The old `with torch.cuda.device(0):` wrapper created no CUDA tensors and
#   raises on CPU-only PyTorch builds, so it was removed.
# - To resume from disk instead: network.load_state_dict(torch.load(network_path)).
skf = KFold(n_splits=8, shuffle=True, random_state=0)
folds_to_run = 1

fold_train = {}
fold_valid = {}
network = None
test_loader = None

for fold_num, (t_idx, v_idx) in enumerate(skf.split(list_data, list_label),
                                          start=1):
    if fold_num > folds_to_run:
        break

    # `Network` is the class defined above; `Network.Network()` raised AttributeError.
    network = Network()
    optimizer = radam.RAdam(network.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

    train_loss = []
    valid_loss = []
    # Derive the env label from the real settings (it used to say "30 epochs").
    plotter = visdom_plot.VisdomLinePlotter(
        env_name='Volume to Image: {} fold, {} epochs each'.format(
            skf.n_splits, num_epochs))

    train_loader, valid_loader, test_loader = generate_sample(
        t_idx, v_idx, list_data, list_label, test_data, test_label,
        batch_size, fold_num)
    since = time.time()

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch + 1, num_epochs))

        network.train()
        t_loss = train_model(network, optimizer, scheduler, train_loader)
        print("     Train Loss:", t_loss)
        train_loss.append(t_loss)
        plotter.plot('loss', 'train', 'Fold %s Loss' % fold_num, epoch,
                     t_loss)

        network.eval()
        v_loss = evaluate_model(network, scheduler, valid_loader)
        print("     Valid Loss:", v_loss)
        valid_loss.append(v_loss)
        plotter.plot('loss', 'valid', 'Fold %s Loss' % fold_num, epoch,
                     v_loss)

        print('-' * 40)

    fold_train["Fold %s" % fold_num] = train_loss
    fold_valid["Fold %s" % fold_num] = valid_loss

    print('-' * 40)
    time_elapsed = time.time() - since
    print('{} Epoch Time Total: {:.0f}m {:.0f}s'.format(
        num_epochs, time_elapsed // 60, time_elapsed % 60))
    print('{} Epoch Train Loss Average: {}'.format(
        num_epochs, calculate_average(train_loss)))
    print('{} Epoch Valid Loss Average: {}'.format(
        num_epochs, calculate_average(valid_loss)))

if network is None:
    raise RuntimeError('No fold was trained; check folds_to_run / the data split')

torch.save(network.state_dict(), network_path)

network.eval()
pred = predict_model(network, test_loader)
save_output(pred, result_dir, raw_name)

print(fold_train)
print(fold_valid)