# MUNIT

## Data preparation

In [1]:
import glob
import pandas as pd
from sklearn.model_selection import train_test_split



if True:
    
    # write list to binary file
    def write_list(a_list, a_filename):
        
        df = pd.DataFrame(a_list, columns=['paths'])
        df.to_csv(a_filename, header=False, index=False)

        
    cirrus_paths = glob.glob("E:/Christina/Result_Data/CIRRUS_*/enface/*_cube_z.tif")
    fundus_paths = glob.glob("E:/Christina/Result_Data/ADAM_*/fundus/*jpg")
    
    cirrus_paths_train, cirrus_paths_test = train_test_split(cirrus_paths, test_size=20)
    fundus_paths_train, fundus_paths_test = train_test_split(fundus_paths, test_size=20)
    
    data_list_train_cirrus = "munit/datasets/cirrus_train.txt"
    data_list_test_cirrus  = "munit/datasets/cirrus_test.txt"
    data_list_train_fundus = "munit/datasets/fundus_train.txt"
    data_list_test_fundus  = "munit/datasets/fundus_test.txt"
    
    
    write_list(cirrus_paths_train, data_list_train_cirrus)
    write_list(cirrus_paths_test, data_list_test_cirrus)
    write_list(fundus_paths_train, data_list_train_fundus)
    write_list(fundus_paths_test, data_list_test_fundus)
    

ValueError: test_size=20 should be either positive and smaller than the number of samples 0 or a float in the (0, 1) range

# Train

In [2]:
"""
Copyright (C) 2018 NVIDIA Corporation.  All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""

# vgg_preprocess needs a device

import os
import sys
import shutil

from munit.trainer import MUNIT_Trainer, UNIT_Trainer
from munit.utils import get_all_data_loaders, prepare_sub_folder, write_html, write_loss, get_config, write_2images, Timer

import torch
import torch.nn
import torch.backends.cudnn as cudnn
from torch.autograd import Variable

import tensorboardX

config_file = 'munit/configs/img2img_list.yaml'
output_path = 'munit/results'
trainer = "MUNIT" # or UNIT

device = "cuda"
resume = False

# Load experiment setting
config = get_config(config_file)
max_iter = config['max_iter']
display_size = config['display_size']
config['vgg_model_path'] = output_path

# Setup model and data loader
if trainer == 'MUNIT':
    trainer = MUNIT_Trainer(config, device)
elif trainer == 'UNIT':
    trainer = UNIT_Trainer(config)
else:
    sys.exit("Only support MUNIT|UNIT")
trainer.to(device)
train_loader_a, train_loader_b, test_loader_a, test_loader_b = get_all_data_loaders(config)
train_display_images_a = torch.stack([train_loader_a.dataset[i] for i in range(display_size)]).to(device)
train_display_images_b = torch.stack([train_loader_b.dataset[i] for i in range(display_size)]).to(device)
test_display_images_a = torch.stack([test_loader_a.dataset[i] for i in range(display_size)]).to(device)
test_display_images_b = torch.stack([test_loader_b.dataset[i] for i in range(display_size)]).to(device)

# Setup logger and output folders
model_name = os.path.splitext(os.path.basename(config_file))[0]
train_writer = tensorboardX.SummaryWriter(os.path.join(output_path + "/logs", model_name))
output_directory = os.path.join(output_path + "/outputs", model_name)
checkpoint_directory, image_directory = prepare_sub_folder(output_directory)
shutil.copy(config_file, os.path.join(output_directory, 'config.yaml')) # copy config file to output folder

# Start training
start_iter = trainer.resume(checkpoint_directory, hyperparameters=config) # if resume else 0

for iterations in range(start_iter, max_iter):
    for i_batch, (images_a, images_b) in enumerate(zip(train_loader_a, train_loader_b)):
        
        
        images_a, images_b = images_a.to(device).detach(), images_b.to(device).detach()

        #with Timer("Elapsed time in update: %f"):
        # Main training code
        trainer.dis_update(images_a, images_b, config)
        trainer.gen_update(images_a, images_b, config)
        torch.cuda.synchronize()
            
        trainer.update_learning_rate()

    print("This Iteration: %08d/%d" % (iterations, max_iter))
    
    # log stuff
    if iterations % config['image_save_iter'] == 0: # config['image_save_iter']

            print("Save Iteration: %08d/%d" % (iterations, max_iter))
            write_loss(iterations, trainer, train_writer)

            # Write images
            with torch.no_grad():
                test_image_outputs = trainer.sample(test_display_images_a, test_display_images_b)
                train_image_outputs = trainer.sample(train_display_images_a, train_display_images_b)
            write_2images(test_image_outputs, display_size, image_directory, 'test_%08d' % (iterations + 1))
            write_2images(train_image_outputs, display_size, image_directory, 'train_%08d' % (iterations + 1))
            # HTML
            write_html(output_directory + "/index.html", iterations + 1, config['image_save_iter'], 'images')

            # save images
            with torch.no_grad():
                image_outputs = trainer.sample(train_display_images_a, train_display_images_b)
            write_2images(image_outputs, display_size, image_directory, 'train_current')

            # Save network weights
            trainer.save(checkpoint_directory, iterations)
            
        



FileNotFoundError: [Errno 2] No such file or directory: 'munit/datasets/fundus_train.txt'

In [11]:
"""
Copyright (C) 2018 NVIDIA Corporation.  All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""

# vgg_preprocess needs a device

import os
import sys
import shutil

from munit.trainer import MUNIT_Trainer, UNIT_Trainer
from munit.utils import get_all_data_loaders, prepare_sub_folder, write_html, write_loss, get_config, write_2images, Timer

import torch
import torch.nn
import torch.backends.cudnn as cudnn
from torch.autograd import Variable

import tensorboardX

config_file = 'munit/configs/img2img_list.yaml'
output_path = 'munit/results'
trainer = "MUNIT" # or UNIT

device = "cuda"
resume = False

# Load experiment setting
config = get_config(config_file)
max_iter = config['max_iter']
display_size = config['display_size']
config['vgg_model_path'] = output_path

# Setup model and data loader
trainer = MUNIT_Trainer(config, device)

trainer.to(device)

# Setup logger and output folders
model_name = os.path.splitext(os.path.basename(config_file))[0]
train_writer = tensorboardX.SummaryWriter(os.path.join(output_path + "/logs", model_name))
output_directory = "D:/cir plex"
checkpoint_directory, image_directory = prepare_sub_folder(output_directory)
shutil.copy(config_file, os.path.join(output_directory, 'config.yaml')) # copy config file to output folder

# Start training
start_iter = trainer.resume(checkpoint_directory, hyperparameters=config) if resume else 0



            
        



# Inference

In [None]:
import os

In [5]:
"""
Copyright (C) 2018 NVIDIA Corporation.  All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""
# from __future__ import print_function
from munit.utils import get_config, pytorch03_to_pytorch04
from munit.trainer import MUNIT_Trainer, UNIT_Trainer
from torch.autograd import Variable
import torchvision.utils as vutils
import sys
import torch
import os
from torchvision import transforms
from PIL import Image

config_file = 'munit/configs/img2img_list.yaml'
output_folder = 'E:/plex_to_cirrus_munit'
trainer = "MUNIT" # or UNIT

if not os.path.exists(output_folder):
    os.makedirs(output_folder)

# Load experiment setting
config = get_config(config_file)
num_style = 1 # if opts.style != '' else opts.num_style

# Setup model and data loader
config['vgg_model_path'] = output_folder

style_dim = config['gen']['style_dim']
trainer = MUNIT_Trainer(config)

checkpoint = "MUNIT/results/ckpt/plex_cir/gen_00000201.pt"

try:
    state_dict = torch.load(checkpoint)
    trainer.gen_a.module.load_state_dict(state_dict['a'])
    trainer.gen_b.module.load_state_dict(state_dict['b'])
except:
    state_dict = pytorch03_to_pytorch04(torch.load(checkpoint), trainer)
    trainer.gen_a.module.load_state_dict(state_dict['a'])
    trainer.gen_b.module.load_state_dict(state_dict['b'])

trainer.cuda()
trainer.eval()
encode = trainer.gen_b.module.encode # if opts.a2b else trainer.gen_b.encode # encode function
style_encode = trainer.gen_a.module.encode # if opts.a2b else trainer.gen_a.encode # encode function
decode = trainer.gen_a.module.decode # if opts.a2b else trainer.gen_a.decode # decode function

new_size = 256 # config['new_size']

import glob
content_paths = glob.glob("E:/enface_plex/*")
style_paths = glob.glob("D:/Christina/Results_Data/CIRRUS_Glaucoma/enface/*_cube*")
import random

for image_id, content_path in enumerate(content_paths):

    with torch.no_grad():
        transform = transforms.Compose([transforms.Resize(new_size),
                                        transforms.ToTensor(),
                                       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        
        
        #content_path = "C:/snec_data/Result_Data/plex/enface/enface_SAMPLE_Angio (6mmx6mm)_7-31-2019_10-51-21_OD_sn0027_cube_z.tif"
        style_path = random.choice(style_paths) # "C:/snec_data/Result_Data/cirrus/enface/enface_PSNASERI0096_Angiography 6x6 mm_10-3-2018_14-42-42_OD_sn1080_cube_z.tif"
        #style = "E:/enface_PSNASERI0003_Angiography 6x6 mm_7-3-2018_10-24-16_OS_sn0029_cube_z.tif"
        #input_path = "C:/Users/Prinzessin/projects/image_data/iChallenge_AMD_OD_Fovea_lesions/ADAM_400_testset/V0001.jpg"
        # style = "C:/Users/Prinzessin/Documents/gray.png"

        

        print(content_path)

        content_image = Variable(transform(Image.open(content_path).convert('RGB')).unsqueeze(0).cuda()) 
        style_image = Variable(transform(Image.open(style_path).convert('RGB')).unsqueeze(0).cuda()) # if style != '' else None
        # Variable(transform(Image.open(style).convert('RGB')).unsqueeze(0).cuda()) if style != '' else None
        # Variable(transform(Image.open(input_path).convert('RGB')).unsqueeze(0).cuda())
        
        """
        device = "cuda"
        # Write images
        with torch.no_grad():

            s_a1 = Variable(torch.randn(1, 8, 1, 1).to(device))
            s_b1 = Variable(torch.randn(1, 8, 1, 1).to(device))
            s_a2 = Variable(torch.randn(x_a.size(0), 8, 1, 1).to(device))
            s_b2 = Variable(torch.randn(x_b.size(0), 8, 1, 1).to(device))
            x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], []
            for i in range(x_a.size(0)):
                c_a, s_a_fake = trainer.gen_a.module.encode(x_a[i].unsqueeze(0))
                c_b, s_b_fake = trainer.gen_b.module.encode(x_b[i].unsqueeze(0))
                x_a_recon.append(trainer.gen_a.module.decode(c_a, s_a_fake))
                x_b_recon.append(trainer.gen_b.module.decode(c_b, s_b_fake))
                x_ba1.append(trainer.gen_a.module.decode(c_b, s_a1[i].unsqueeze(0)))
                x_ba2.append(trainer.gen_a.module.decode(c_b, s_a2[i].unsqueeze(0)))
                x_ab1.append(trainer.gen_b.module.decode(c_a, s_b1[i].unsqueeze(0)))
                x_ab2.append(trainer.gen_b.module.decode(c_a, s_b2[i].unsqueeze(0)))
            x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
            x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2)
            x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2)

            vutils.save_image(x_ab1.data, os.path.join(output_folder, 'x_ab1.jpg'), padding=0, normalize=True)
            vutils.save_image(x_ba1.data, os.path.join(output_folder, 'x_ba1.jpg'), padding=0, normalize=True)
            #return x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2
        """
        
        
        
        # Start testing
        content, _ = encode(content_image)

        #print(content.shape)



        

        #print(style_rand)
        #print(content.permute(2,3,1,0).squeeze()[:,:,5].shape)

        # vutils.save_image(content.permute(2,3,1,0).squeeze()[:,:,10].data, os.path.join(output_folder, 'a.jpg'), padding=0, normalize=True)

        if False: #style != '':
            _, style = style_encode(style_image)

        if True: # else:
            style_rand = Variable(torch.randn(num_style, style_dim, 1, 1).cuda())
            style = style_rand

        #print(style.shape)

        #print(style)

        #for j in range(num_style):
        j = 0
        s = style[j].unsqueeze(0)
        outputs = decode(content, s)
        
        print(outputs.shape)
        
        # outputs = (outputs + 1) / 2.
        path = os.path.join(output_folder, 'output_img{:03d}.jpg'.format(image_id))
        
        
        print(outputs.shape)
        
        invTrans = transforms.Compose([ transforms.Normalize(mean = [ 0., 0., 0. ],
                                                     std = [ 1/0.5, 1/0.5, 1/0.5 ]),
                                transforms.Normalize(mean = [ -0.5, -0.5, -0.5 ],
                                                     std = [ 1., 1., 1. ]),
                               ])

        outputs = invTrans(outputs)
        
        
        
        vutils.save_image(outputs.data, path, padding=0, normalize=False)

        # also save input images
        # vutils.save_image(content_image.data, os.path.join(output_folder, 'input.jpg'), padding=0, normalize=True)
        


E:/enface_plex\PREMODEL_HTN_851_Angio (6mmx6mm)_4-1-2022_14-47-25_OS_sn1861_cube_z.tif
torch.Size([1, 3, 256, 256])
torch.Size([1, 3, 256, 256])
E:/enface_plex\PREMODEL_HTN_851_Angio (6mmx6mm)_4-1-2022_14-41-21_OD_sn1856_cube_z.tif
torch.Size([1, 3, 256, 256])
torch.Size([1, 3, 256, 256])
E:/enface_plex\PREMODEL_HTN_849_Angio (6mmx6mm)_3-9-2022_16-26-12_OS_sn1838_cube_z.tif
torch.Size([1, 3, 256, 256])
torch.Size([1, 3, 256, 256])
E:/enface_plex\PREMODEL_HTN_849_Angio (6mmx6mm)_3-9-2022_16-22-50_OD_sn1834_cube_z.tif
torch.Size([1, 3, 256, 256])
torch.Size([1, 3, 256, 256])
E:/enface_plex\PREMODEL_HTN_848_Angio (6mmx6mm)_3-9-2022_10-47-16_OS_sn1830_cube_z.tif
torch.Size([1, 3, 256, 256])
torch.Size([1, 3, 256, 256])
E:/enface_plex\PREMODEL_HTN_848_Angio (6mmx6mm)_3-9-2022_10-38-17_OD_sn1826_cube_z.tif
torch.Size([1, 3, 256, 256])
torch.Size([1, 3, 256, 256])
E:/enface_plex\PREMODEL_HTN_846_Angio (6mmx6mm)_3-1-2022_15-2-56_OS_sn1803_cube_z.tif
torch.Size([1, 3, 256, 256])
torch.Size([1, 