In [1]:
import os
import numpy as np
from matplotlib import pyplot as plt
import visdom_plot
from PIL import Image
import gc

In [3]:
from sklearn.model_selection import KFold
import torch
import torch.nn.functional as func
from torchvision.utils import save_image
import torchvision.transforms as transforms
import torch.utils.data as data
import torch.optim as optim
import radam
import time

import Network

In [26]:
# Leftover memory experiment: a float16 tensor of shape (2200, 1, 128, 128, 128)
# holds 2200 * 128**3 * 2 bytes (about 9.2 GB) and is never used afterwards.
# It is kept behind an opt-in flag so "Restart & Run All" does not allocate it.
RUN_MEMORY_PROBE = False
if RUN_MEMORY_PROBE:
    t = torch.ones([2200, 1, 128, 128, 128], dtype=torch.float16)

In [4]:
# Tensor printing: wrap lines at 30 characters and show 20 items at each edge
# of a summarised dimension.
torch.set_printoptions(linewidth=30, edgeitems=20)
# Autograd is switched on globally for the rest of the notebook.
torch.set_grad_enabled(True)

# Preprocessing for the label images loaded in read_data: PIL image -> float
# tensor via ToTensor. Per-channel Normalize(mean=0.5, std=0.5) is currently
# disabled.
transform = transforms.Compose([transforms.ToTensor()])

In [5]:
def calculate_average(num):
    """Return the arithmetic mean of the values in ``num``.

    Args:
        num: a sized iterable of numbers (here: lists of per-epoch losses).

    Returns:
        ``sum(num) / len(num)``. An empty ``num`` raises ZeroDivisionError.
    """
    total = sum(num)
    return total / len(num)


def check_tensor():
    """Collect garbage, release cached CUDA memory and count live tensors.

    The previous version ran ``del obj`` inside the scan. That only unbinds the
    loop variable, so it could not free a tensor that is still referenced
    elsewhere. It also called ``torch.cuda.empty_cache()`` once per tensor found.
    This version collects garbage first and empties the cache once.

    Returns:
        int: number of objects seen by ``gc`` that are tensors, or that expose a
        tensor through a ``.data`` attribute. Useful for spotting leaks.
    """
    gc.collect()

    live = 0
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj) or (hasattr(obj, 'data')
                                        and torch.is_tensor(obj.data)):
                live += 1
        except Exception:
            # Some objects raise on attribute access (e.g. dead weakref
            # proxies raise ReferenceError). This scan is best-effort, so skip
            # them without also swallowing KeyboardInterrupt/SystemExit.
            continue

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return live


def read_data(raw_dir, img_dir, cross=True, n_train=3200, n_test=400,
              volume_shape=(1, 128, 128, 128)):
    """Load paired (volume, label image) samples from disk.

    Every ``*.raw`` file in ``raw_dir`` is read as a flat uint8 array. Its file
    name is expected to look like ``PREFIX_ISOVALUE[...].raw``. The voxels are
    shifted by ``isovalue / 2``, divided by 255 and reshaped to ``volume_shape``.
    The label is the ``.png`` with the same stem in ``img_dir``, converted with
    the module-level ``transform``.

    Volumes without a matching image are skipped. Previously only the label was
    skipped, which silently shifted every later label onto the wrong volume.
    Files are processed in sorted order so the train/test split is reproducible.

    Args:
        raw_dir: directory containing the ``.raw`` volumes.
        img_dir: directory containing the ``.png`` label images.
        cross: if True, return split lists. Otherwise return one TensorDataset.
        n_train: number of leading samples returned for training/validation.
        n_test: number of samples after those returned as the test set.
        volume_shape: shape each raw volume is reshaped to.

    Returns:
        If ``cross``: ``(train_volumes, train_labels, test_volumes,
        test_labels, test_names)`` as lists. Otherwise:
        ``(TensorDataset(volumes, labels), names)`` over all samples.
    """
    list_data = []
    list_label = []
    raw_name = []

    for filename in sorted(os.listdir(raw_dir)):
        stem, ext = os.path.splitext(filename)
        if ext != '.raw':
            continue
        label_path = os.path.join(img_dir, stem + '.png')
        if not os.path.isfile(label_path):
            # Keep volumes, labels and names index-aligned.
            continue

        isovalue = int(stem.split('_')[1])
        voxels = np.fromfile(os.path.join(raw_dir, filename), dtype='uint8')
        voxels = (voxels.astype('float') - isovalue / 2) / 255
        list_data.append(torch.from_numpy(voxels).float().reshape(volume_shape))

        # Context manager closes the image file once the tensor is built.
        with Image.open(label_path) as im:
            list_label.append(transform(im))
        raw_name.append(filename)

    if cross:
        end = n_train + n_test
        return (list_data[:n_train], list_label[:n_train],
                list_data[n_train:end], list_label[n_train:end],
                raw_name[n_train:end])

    tensor_data = torch.stack(list_data)
    tensor_label = torch.stack(list_label)
    return torch.utils.data.TensorDataset(tensor_data, tensor_label), raw_name


def visualize_output(batch, nrows=4, ncols=4, figsize=(25, 25)):
    """Plot up to ``nrows * ncols`` images from ``batch`` on a grid.

    Args:
        batch: iterable of channel-first image tensors. The ``permute(1, 2, 0)``
            below implies (C, H, W) per image.
        nrows, ncols: grid size (default 4x4, as before).
        figsize: matplotlib figure size in inches.

    Images beyond the grid capacity are ignored. The previous version raised
    IndexError for more than 16 images.
    """
    # squeeze=False keeps ``axes`` a 2-D array even for a 1x1 grid.
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize,
                             squeeze=False)
    flat_axes = axes.flatten()
    for ax in flat_axes:
        ax.axis('off')

    # Row-major fill, same order as the former axes[i // 4, i % 4] indexing.
    for ax, img in zip(flat_axes, batch):
        # matplotlib needs host memory without autograd, laid out as (H, W, C).
        hwc = img.detach().cpu().permute(1, 2, 0)
        if hwc.shape[-1] == 1:
            # imshow rejects (H, W, 1); draw single-channel images as 2-D.
            ax.imshow(hwc[..., 0], cmap='gray')
        else:
            ax.imshow(hwc)


def draw_image(pred, image):
    """Show the predicted batch, then the ground-truth batch, as image grids.

    Both tensors are detached from autograd before plotting.
    """
    for batch in (pred, image):
        visualize_output(batch.detach())


def generate_sample(train_idx, valid_idx, list_data, list_label, test_data,
                    test_label, batch_size, fold_num):
    """Build train/valid/test DataLoaders for one cross-validation fold.

    Args:
        train_idx, valid_idx: indices into ``list_data``/``list_label`` for
            this fold's training and validation samples.
        list_data, list_label: per-sample volume and label tensors.
        test_data, test_label: held-out test tensors (same for every fold).
        batch_size: batch size for all three loaders.
        fold_num: fold number, only used in the printed banner.

    Returns:
        ``(train_loader, valid_loader, test_loader)``. Only the training loader
        reshuffles each epoch. The test loader keeps its order so predictions
        line up with the test file names.
    """
    train_v = torch.stack([list_data[i] for i in train_idx])
    train_i = torch.stack([list_label[i] for i in train_idx])

    valid_v = torch.stack([list_data[i] for i in valid_idx])
    valid_i = torch.stack([list_label[i] for i in valid_idx])

    test_v = torch.stack(test_data)
    test_i = torch.stack(test_label)

    print()
    print()
    print('-' * 15, " New Fold %s" % fold_num, '-' * 15)

    # Fully qualified: the bare name ``utils`` used previously is never
    # imported in this notebook.
    tensor_dataset = torch.utils.data.TensorDataset
    data_loader = torch.utils.data.DataLoader

    train_loader = data_loader(tensor_dataset(train_v, train_i),
                               batch_size=batch_size, shuffle=True)
    valid_loader = data_loader(tensor_dataset(valid_v, valid_i),
                               batch_size=batch_size)
    test_loader = data_loader(tensor_dataset(test_v, test_i),
                              batch_size=batch_size, shuffle=False)

    return train_loader, valid_loader, test_loader


def cross_validation_split(dataset,
                           sample_size,
                           val_split,
                           batch_size,
                           shuffle=True,
                           random_seed=42):
    """Split ``dataset`` into randomly sampled train/validation DataLoaders.

    Args:
        dataset: any map-style dataset with at least ``sample_size`` items.
        sample_size: number of leading indices to split.
        val_split: fraction of indices (floored) assigned to validation.
        batch_size: batch size of both loaders.
        shuffle: permute the indices before splitting.
        random_seed: seed for the permutation (default 42, as before).

    Returns:
        ``(train_loader, valid_loader)``. Each draws from its own index subset
        via ``SubsetRandomSampler``.
    """
    indices = list(range(sample_size))
    split = int(np.floor(val_split * sample_size))

    if shuffle:
        # Private RNG: same permutation as np.random.seed(random_seed) +
        # np.random.shuffle, without resetting NumPy's global random state for
        # the rest of the notebook.
        np.random.RandomState(random_seed).shuffle(indices)

    train_indices, valid_indices = indices[split:], indices[:split]

    train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
    valid_sampler = torch.utils.data.SubsetRandomSampler(valid_indices)

    train_loader = torch.utils.data.DataLoader(dataset,
                                               batch_size=batch_size,
                                               sampler=train_sampler)
    valid_loader = torch.utils.data.DataLoader(dataset,
                                               batch_size=batch_size,
                                               sampler=valid_sampler)

    return train_loader, valid_loader


def save_output(tensor, dir, name):
    """Save each image of ``tensor`` as a PNG in ``dir``.

    Image ``i`` is written to ``dir`` under ``name[i]`` with ``.raw`` replaced
    by ``.png``. An existing file is overwritten, since ``save_image`` writes
    it anew. That is why the explicit ``os.remove`` was dropped.

    Args:
        tensor: batch of images, indexed along dim 0.
        dir: output directory, with or without a trailing separator.
        name: file names of the source volumes, at least one per image. Extra
            names are ignored.

    Raises:
        ValueError: if ``name`` has fewer entries than ``tensor`` has images.
    """
    count = tensor.size(0)
    if len(name) < count:
        raise ValueError('save_output: %d images but only %d names'
                         % (count, len(name)))

    for index in range(count):
        out_path = os.path.join(dir, name[index].replace('.raw', '.png'))
        save_image(tensor[index], out_path)


def train_model(model, optimizer, scheduler, train_loader,
                loss_fn=func.mse_loss):
    """Train ``model`` for one epoch and return its mean per-batch loss.

    Args:
        model: module mapping a batch of volumes to predicted images.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: LR scheduler, stepped once at the end of the epoch.
        train_loader: iterable of ``(volume, image)`` batches.
        loss_fn: loss callable ``(pred, target) -> scalar tensor``. Defaults
            to MSE. Pass e.g. ``func.smooth_l1_loss`` to switch.

    Returns:
        float: average of the batch losses.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    total_loss = 0.0
    n_batches = 0

    for volume, image in train_loader:
        optimizer.zero_grad()
        pred = model(volume)
        loss = loss_fn(pred, image)
        loss.backward()
        optimizer.step()
        # Detach so the running sum does not keep each batch's graph alive.
        total_loss += loss.detach()
        n_batches += 1

    if n_batches == 0:
        raise ValueError('train_loader yielded no batches')

    # Since PyTorch 1.1 the scheduler must step after the optimizer. Stepping
    # first skips the initial learning rate.
    scheduler.step()
    return float(total_loss) / n_batches


def evaluate_model(model, scheduler, valid_loader, loss_fn=func.mse_loss):
    """Return the mean per-batch loss of ``model`` over ``valid_loader``.

    Args:
        model: module to evaluate. The caller is responsible for ``.eval()``.
        scheduler: kept for backward compatibility and deliberately NOT
            stepped. The training step already advances the schedule, so
            stepping here too made the learning rate decay twice per epoch.
        valid_loader: iterable of ``(volume, image)`` batches.
        loss_fn: loss callable ``(pred, target) -> scalar tensor``. Defaults
            to MSE.

    Returns:
        float: average of the batch losses.

    Raises:
        ValueError: if ``valid_loader`` yields no batches.
    """
    total_loss = 0.0
    n_batches = 0

    with torch.no_grad():
        for volume, image in valid_loader:
            pred_valid = model(volume)
            total_loss += loss_fn(pred_valid, image)
            n_batches += 1

    if n_batches == 0:
        raise ValueError('valid_loader yielded no batches')
    return float(total_loss) / n_batches


def predict_model(model, test_loader):
    """Run ``model`` over ``test_loader`` and return all predictions.

    Args:
        model: module to run. The caller is responsible for ``.eval()``.
        test_loader: iterable of ``(volume, image)`` batches. Targets are
            ignored.

    Returns:
        Tensor: the per-batch outputs concatenated along dim 0, in loader
        order. An empty tensor is returned if the loader yields nothing.
    """
    outputs = []
    with torch.no_grad():
        for volume, _ in test_loader:
            outputs.append(model(volume))

    if not outputs:
        return torch.Tensor()
    # One concatenation at the end instead of re-copying the growing result
    # every batch. This also keeps the outputs on the model's device.
    return torch.cat(outputs, 0)

In [6]:
# Input/output locations, relative to the notebook's working directory.
# Every directory keeps a trailing '/' so file names can be appended directly.
data_dir = './data/'
raw_dir = data_dir + 'raw/'          # .raw volumes
img_dir = data_dir + 'image/'        # .png label images
result_dir = data_dir + 'result/'    # predicted .png images
network_path = data_dir + 'network.pkl'  # saved model state_dict

In [None]:
# Load the data set. This replaces module-level `global raw1..image5`
# declarations, which were no-ops and whose names were never used. The
# training cell below uses `list_data`, `list_label`, `test_data`,
# `test_label` and `raw_name`, which previously were never defined in this
# notebook. read_data returns the leading samples for K-fold training/
# validation, the following ones as the held-out test set, and the test
# file names.
list_data, list_label, test_data, test_label, raw_name = read_data(raw_dir, img_dir)

In [7]:
# Training hyper-parameters: epochs per fold, samples per batch, and the
# initial learning rate given to the optimizer.
num_epochs, batch_size, learning_rate = 10, 10, 1e-3

In [None]:
# Train on K-fold splits of the training data (only the first `max_folds`
# folds are run), then save the weights and the test-set predictions.
# NOTE(review): `torch.cuda.device(0)` only selects the current CUDA device.
# Nothing in this notebook moves the model or batches with `.cuda()` or
# `.to(...)`, so the computation itself appears to stay on the CPU. Confirm.
with torch.cuda.device(0):
    # To resume instead of training from scratch, call
    # network.load_state_dict(torch.load(network_path)) after creating it below.
    n_splits = 8
    max_folds = 1  # folds actually trained (was an unconditional `break`)

    skf = KFold(n_splits=n_splits, shuffle=True, random_state=0)

    fold_train = {}
    fold_valid = {}

    fold_num = 1
    for t_idx, v_idx in skf.split(list_data):
        # Fresh model/optimizer/scheduler per fold. Reusing one network across
        # folds would let a later fold validate on samples it already trained on.
        network = Network.Network()
        optimizer = radam.RAdam(network.parameters(), lr=learning_rate)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

        train_loss = []
        valid_loss = []
        # Env name is derived from the actual settings. The old literal said
        # "30 epochs" while num_epochs is 10.
        plotter = visdom_plot.VisdomLinePlotter(
            env_name='Volume to Image: {} fold, {} epochs each'.format(
                n_splits, num_epochs))

        train_loader, valid_loader, test_loader = generate_sample(
            t_idx, v_idx, list_data, list_label, test_data, test_label,
            batch_size, fold_num)
        since = time.time()

        for epoch in range(num_epochs):
            print('Epoch {}/{}'.format(epoch + 1, num_epochs))

            network.train()
            t_loss = train_model(network, optimizer, scheduler, train_loader)
            print("     Train Loss:", t_loss)
            train_loss.append(t_loss)
            plotter.plot('loss', 'train', 'Fold %s Loss' % fold_num, epoch,
                         t_loss)

            network.eval()
            v_loss = evaluate_model(network, scheduler, valid_loader)
            print("     Valid Loss:", v_loss)
            valid_loss.append(v_loss)
            plotter.plot('loss', 'valid', 'Fold %s Loss' % fold_num, epoch,
                         v_loss)

            print('-' * 40)

            # Periodic snapshot of the test-set predictions.
            if epoch % 5 == 0:
                pred = predict_model(network, test_loader)
                save_output(pred, result_dir, raw_name)

        fold_train["Fold %s" % fold_num] = train_loss
        fold_valid["Fold %s" % fold_num] = valid_loss

        print('-' * 40)
        time_elapsed = time.time() - since
        print('{} Epoch Time Total: {:.0f}m {:.0f}s'.format(
            num_epochs, time_elapsed // 60, time_elapsed % 60))
        print('{} Epoch Train Loss Average: {}'.format(
            num_epochs, calculate_average(train_loss)))
        print('{} Epoch Valid Loss Average: {}'.format(
            num_epochs, calculate_average(valid_loss)))

        fold_num += 1
        if fold_num > max_folds:
            break

    # Weights and predictions come from the last fold that was trained.
    torch.save(network.state_dict(), network_path)

    pred = predict_model(network, test_loader)
    save_output(pred, result_dir, raw_name)

    print(fold_train)
    print(fold_valid)