# MUNIT

## Data preparation

In [None]:
import glob
import pandas as pd
from sklearn.model_selection import train_test_split



if True:  # set to False to skip regenerating the train/test list files
    import os

    # Number of images per modality held out for testing (an int test_size in
    # train_test_split means an absolute count), and the seed that makes the
    # held-out selection reproducible between notebook runs.
    TEST_SIZE = 20
    SPLIT_SEED = 0

    def write_list(a_list, a_filename):
        """Write ``a_list`` to ``a_filename`` as plain text, one entry per line.

        The list is written through pandas as a single-column CSV with no
        header and no index. Entries containing commas, quotes or newlines are
        CSV-quoted by pandas -- NOTE(review): confirm the munit list loader
        tolerates that if such paths can occur.

        Parameters
        ----------
        a_list : list of str
            Entries (file paths) to write.
        a_filename : str
            Destination file. Its parent directory is created if missing.
        """
        parent_dir = os.path.dirname(a_filename)
        if parent_dir:
            # DataFrame.to_csv does not create directories; a fresh checkout
            # may not have munit/datasets yet.
            os.makedirs(parent_dir, exist_ok=True)
        df = pd.DataFrame(a_list, columns=['paths'])
        df.to_csv(a_filename, header=False, index=False)

    # glob() returns files in file-system order, so sort before splitting;
    # otherwise the seeded split would still differ between machines.
    cirrus_paths = sorted(glob.glob("E:/Christina/Result_Data/CIRRUS_*/enface/*_cube_z.tif"))
    fundus_paths = sorted(glob.glob("E:/Christina/Result_Data/ADAM_*/fundus/*jpg"))

    cirrus_paths_train, cirrus_paths_test = train_test_split(
        cirrus_paths, test_size=TEST_SIZE, random_state=SPLIT_SEED)
    fundus_paths_train, fundus_paths_test = train_test_split(
        fundus_paths, test_size=TEST_SIZE, random_state=SPLIT_SEED)

    # List files consumed by the MUNIT config (domain A = CIRRUS, B = fundus
    # presumably -- confirm against munit/configs/img2img_list.yaml).
    data_list_train_cirrus = "munit/datasets/cirrus_train.txt"
    data_list_test_cirrus  = "munit/datasets/cirrus_test.txt"
    data_list_train_fundus = "munit/datasets/fundus_train.txt"
    data_list_test_fundus  = "munit/datasets/fundus_test.txt"

    write_list(cirrus_paths_train, data_list_train_cirrus)
    write_list(cirrus_paths_test, data_list_test_cirrus)
    write_list(fundus_paths_train, data_list_train_fundus)
    write_list(fundus_paths_test, data_list_test_fundus)
    

# Train

In [None]:
"""
Copyright (C) 2018 NVIDIA Corporation.  All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""

# vgg_preprocess needs a device

import os
import sys
import shutil

from munit.trainer import MUNIT_Trainer, UNIT_Trainer
from munit.utils import get_all_data_loaders, prepare_sub_folder, write_html, write_loss, get_config, write_2images, Timer

import torch
import torch.nn
import torch.backends.cudnn as cudnn
from torch.autograd import Variable

import tensorboardX

# --- Run settings ------------------------------------------------------------
resume = False     # when True, start_iter comes from trainer.resume(checkpoint_directory, ...)
device = "cuda"    # torch device string used for the trainer and all input tensors
trainer = "MUNIT"  # trainer to build: "MUNIT" or "UNIT" (rebound to the trainer object later)
output_path = 'munit/results'                   # root folder for logs/ and outputs/
config_file = 'munit/configs/img2img_list.yaml'  # experiment YAML read below

# --- Experiment settings from the YAML config ---------------------------------
config = get_config(config_file)
max_iter, display_size = config['max_iter'], config['display_size']
# Presumably where the VGG model used by the trainer is stored/loaded -- confirm
# in munit.trainer / munit.utils.
config['vgg_model_path'] = output_path

# --- Model ---------------------------------------------------------------------
# Until here `trainer` holds the trainer name; it is rebound to the trainer object.
if trainer not in ('MUNIT', 'UNIT'):
    sys.exit("Only support MUNIT|UNIT")
trainer = MUNIT_Trainer(config, device) if trainer == 'MUNIT' else UNIT_Trainer(config)
trainer.to(device)

# --- Data ----------------------------------------------------------------------
train_loader_a, train_loader_b, test_loader_a, test_loader_b = get_all_data_loaders(config)


def _stack_display_images(loader, count, target_device):
    """Stack the first ``count`` items of ``loader.dataset`` into one tensor on ``target_device``.

    Items are indexed from the dataset directly, not drawn through the loader.
    The result is computed once here and reused whenever samples are written.
    """
    return torch.stack([loader.dataset[idx] for idx in range(count)]).to(target_device)


train_display_images_a = _stack_display_images(train_loader_a, display_size, device)
train_display_images_b = _stack_display_images(train_loader_b, display_size, device)
test_display_images_a = _stack_display_images(test_loader_a, display_size, device)
test_display_images_b = _stack_display_images(test_loader_b, display_size, device)

# --- Logging and output folders -------------------------------------------------
# The run is named after the config file without its extension
# (e.g. "img2img_list" for img2img_list.yaml).
model_name = os.path.splitext(os.path.basename(config_file))[0]
output_directory = os.path.join(f"{output_path}/outputs", model_name)

train_writer = tensorboardX.SummaryWriter(
    os.path.join(f"{output_path}/logs", model_name)
)
checkpoint_directory, image_directory = prepare_sub_folder(output_directory)

# Keep a copy of the exact config next to the results.
shutil.copy(config_file, os.path.join(output_directory, 'config.yaml'))

# Start training
# Each value of `iterations` is one full pass over the paired training loaders;
# the learning-rate scheduler is still stepped once per batch.
start_iter = trainer.resume(checkpoint_directory, hyperparameters=config) if resume else 0

# torch.cuda.synchronize() only works on a CUDA device; skip it otherwise so the
# `device` setting can actually be changed (e.g. to "cpu").
sync_cuda = torch.device(device).type == "cuda"

for iterations in range(start_iter, max_iter):
    # zip() stops at the shorter loader, so one pass covers min(len(a), len(b)) batches.
    for images_a, images_b in zip(train_loader_a, train_loader_b):
        images_a, images_b = images_a.to(device).detach(), images_b.to(device).detach()

        # Main training code: dis_update then gen_update on the same batch.
        trainer.dis_update(images_a, images_b, config)
        trainer.gen_update(images_a, images_b, config)
        if sync_cuda:
            # Blocks until all queued CUDA kernels have finished.
            torch.cuda.synchronize()

        trainer.update_learning_rate()

    print("This Iteration: %08d/%d" % (iterations, max_iter))

    # Every `image_save_iter` passes: log losses, write sample grids, checkpoint.
    if iterations % config['image_save_iter'] == 0:
        print("Save Iteration: %08d/%d" % (iterations, max_iter))
        write_loss(iterations, trainer, train_writer)

        # Write sample translations of the fixed display images.
        with torch.no_grad():
            test_image_outputs = trainer.sample(test_display_images_a, test_display_images_b)
            train_image_outputs = trainer.sample(train_display_images_a, train_display_images_b)
        write_2images(test_image_outputs, display_size, image_directory, 'test_%08d' % (iterations + 1))
        write_2images(train_image_outputs, display_size, image_directory, 'train_%08d' % (iterations + 1))
        # HTML
        write_html(output_directory + "/index.html", iterations + 1, config['image_save_iter'], 'images')

        # 'train_current' mirrors the latest train samples; reuse the batch just
        # sampled instead of running a second forward pass over the same images.
        write_2images(train_image_outputs, display_size, image_directory, 'train_current')

        # Save network weights
        trainer.save(checkpoint_directory, iterations)
            
        



# Inference

In [None]:
import os