In [37]:
from google.colab import drive
import os

drive.mount('/content/drive')
%cd /content/drive/MyDrive/Colab Notebooks

os.chdir('./OCTA_CycleGAN/')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/MyDrive/Colab Notebooks


In [38]:
# Prepare package
!pip install visdom
!pip install lpips



In [39]:
# !git clone https://github.com/fbcotter/pytorch_wavelets
%cd ./pytorch_wavelets/
!pip install .
%cd ../

/content/drive/MyDrive/Colab Notebooks/OCTA_CycleGAN/pytorch_wavelets
Processing /content/drive/MyDrive/Colab Notebooks/OCTA_CycleGAN/pytorch_wavelets
[33m  DEPRECATION: A future pip version will change local packages to be built in-place without first copying to a temporary directory. We recommend you use --use-feature=in-tree-build to test your packages with this new behavior before it becomes the default.
   pip 21.3 will remove support for this functionality. You can find discussion regarding this at https://github.com/pypa/pip/issues/7555.[0m
Building wheels for collected packages: pytorch-wavelets
  Building wheel for pytorch-wavelets (setup.py) ... [?25l[?25hdone
  Created wheel for pytorch-wavelets: filename=pytorch_wavelets-1.3.0-py3-none-any.whl size=54869 sha256=1bd8b90678ec3ae4d3d247ef5fa576573c109f499abeba87c73d3d569f5a9bfe
  Stored in directory: /tmp/pip-ephem-wheel-cache-_uu3q3ks/wheels/82/1f/1d/df88cea24a9de9a259b29c50aa658dd7e6ed94eb3b6b6d3152
Successfully built py

/content/drive/MyDrive/Colab Notebooks/OCTA_CycleGAN


In [40]:
import glob
import random
import os
from PIL import Image
import numpy as np
import time
import datetime
import sys

import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.autograd import Variable
from visdom import Visdom

import argparse
import itertools
import matplotlib.pyplot as plt

from pytorch_wavelets import DWTForward

import pdb
import skimage.metrics

from tqdm import tqdm

import lpips
loss_fn_alex = lpips.LPIPS(net='alex') # best forward scores
# loss_fn_vgg = lpips.LPIPS(net='vgg') # closer to "traditional" perceptual loss, when used for optimization

# import pytorch_fft.fft.autograd as fft

Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.7/dist-packages/lpips/weights/v0.1/alex.pth


In [41]:
# Helper Functions
def tensor2image(tensor):
    image = 127.5*(tensor[0].cpu().float().numpy() + 1.0)
    if image.shape[0] == 1:
        image = np.tile(image, (3,1,1))
    return image.astype(np.uint8)


class ReplayBuffer():
    def __init__(self, max_size=50):
        assert (max_size > 0), 'Empty buffer or trying to create a black hole. Be careful.'
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        to_return = []
        for element in data.data:
            element = torch.unsqueeze(element, 0)
            if len(self.data) < self.max_size:
                self.data.append(element)
                to_return.append(element)
            else:
                if random.uniform(0,1) > 0.5:
                    i = random.randint(0, self.max_size-1)
                    to_return.append(self.data[i].clone())
                    self.data[i] = element
                else:
                    to_return.append(element)
        return Variable(torch.cat(to_return))

class LambdaLR():
    def __init__(self, n_epochs, offset, decay_start_epoch):
        assert ((n_epochs - decay_start_epoch) > 0), "Decay must start before the training session ends!"
        self.n_epochs = n_epochs
        self.offset = offset
        self.decay_start_epoch = decay_start_epoch

    def step(self, epoch):
        return 1.0 - max(0, epoch + self.offset - self.decay_start_epoch)/(self.n_epochs - self.decay_start_epoch)

def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.normal(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm2d') != -1:
        torch.nn.init.normal(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant(m.bias.data, 0.0)


In [42]:
# Dataloader
class ImageDataset(Dataset):
    def __init__(self, root, transforms_A=None, transforms_B=None, unaligned=False, mode='train'):
        self.transformA = transforms.Compose(transforms_A)
        self.transformB = transforms.Compose(transforms_B)

        self.unaligned = unaligned

        self.files_A = sorted(glob.glob(os.path.join(root, 'trainA') + '/*.*'))
        self.files_B = sorted(glob.glob(os.path.join(root, 'trainB') + '/*.*'))

    def __getitem__(self, index):
        img_A = Image.open(self.files_A[index % len(self.files_A)]).convert('L')
        item_A = self.transformA(img_A)

        if self.unaligned:
            item_B = self.transformB(Image.open(self.files_B[random.randint(0, len(self.files_B) - 1)]).convert('L'))
        else:
            item_B = self.transformB(Image.open(self.files_B[index % len(self.files_B)]).convert('L'))

        return {'A': item_A, 'B': item_B}

    def __len__(self):
        return max(len(self.files_A), len(self.files_B))

In [43]:
class Discriminator(nn.Module):
    """Defines a PatchGAN discriminator"""

    def __init__(self, input_nc=1, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        """Construct a PatchGAN discriminator

        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the last conv layer
            n_layers (int)  -- the number of conv layers in the discriminator
            norm_layer      -- normalization layer
        """
        super(Discriminator, self).__init__()
        use_bias = True
        kw = 4
        padw = 1
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output 1 channel prediction map
        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        """Standard forward."""
        return self.model(input)

In [44]:
def set_requires_grad(nets, requires_grad=False):
    """Set requies_grad=Fasle for all the networks to avoid unnecessary computations
    Parameters:
        nets (network list)   -- a list of networks
        requires_grad (bool)  -- whether the networks require gradients or not
    """
    if not isinstance(nets, list):
        nets = [nets]
    for net in nets:
        if net is not None:
            for param in net.parameters():
                param.requires_grad = requires_grad

In [45]:
def save_sample(epoch, tensor, suffix="_real"):
    output = tensor.cpu().detach().numpy().squeeze(0).squeeze(0)
    plt.imsave('./checkpoint_baseline/image_'+str(epoch+1)+suffix+'.jpeg', output, cmap="gray")

In [46]:
#### Defination of local variables
input_nc = 1
output_nc = 1
batchSize = 1
size_A, size_B = 256, 256
lr = 2e-4
n_epochs, epoch, decay_epoch = 60, 0, 10
n_cpu = 2
dataroot = "./dataset/OCTA_baseline"
cuda = True

In [47]:
if torch.cuda.is_available() and not cuda:
    print("WARNING: You have a CUDA device, so you should probably run with --cuda")

In [48]:
class UnetGenerator(nn.Module):
    """Create a Unet-based generator"""

    def __init__(self, input_nc=1, output_nc=1, num_downs=8, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Construct a Unet generator
        Parameters:
            input_nc (int)  -- the number of channels in input images
            output_nc (int) -- the number of channels in output images
            num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7,
                                image of size 128x128 will become of size 1x1 # at the bottleneck
            ngf (int)       -- the number of filters in the last conv layer
            norm_layer      -- normalization layer
        We construct the U-Net from the innermost layer to the outermost layer.
        It is a recursive process.
        """
        super(UnetGenerator, self).__init__()
        # construct unet structure
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)  # add the innermost layer
        for i in range(num_downs - 5):          # add intermediate layers with ngf * 8 filters
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
        # gradually reduce the number of filters from ngf * 8 to ngf
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)  # add the outermost layer

    def forward(self, input):
        """Standard forward"""
        return self.model(input)


class UnetSkipConnectionBlock(nn.Module):
    """Defines the Unet submodule with skip connection.
        X -------------------identity----------------------
        |-- downsampling -- |submodule| -- upsampling --|
    """

    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Construct a Unet submodule with skip connections.
        Parameters:
            outer_nc (int) -- the number of filters in the outer conv layer
            inner_nc (int) -- the number of filters in the inner conv layer
            input_nc (int) -- the number of channels in input images/features
            submodule (UnetSkipConnectionBlock) -- previously defined submodules
            outermost (bool)    -- if this module is the outermost module
            innermost (bool)    -- if this module is the innermost module
            norm_layer          -- normalization layer
            use_dropout (bool)  -- if use dropout layers.
        """
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        use_bias = True
        if input_nc is None:
            input_nc = outer_nc
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
                             stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)

        if outermost:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1)
            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]

            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up

        self.model = nn.Sequential(*model)

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:   # add skip connections
            return torch.cat([x, self.model(x)], 1)

In [49]:
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

In [50]:
###### Definition of variables ######
# Networks

netG_A2B = UnetGenerator()
netG_B2A = UnetGenerator()
netD_A = Discriminator()
netD_B = Discriminator()

if cuda:
    netG_A2B.cuda()
    netG_B2A.cuda()
    netD_A.cuda()
    netD_B.cuda()

netG_A2B.apply(weights_init_normal)
netG_B2A.apply(weights_init_normal)
netD_A.apply(weights_init_normal)
netD_B.apply(weights_init_normal)

# Lossess
criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
# criterion_phase = phase_consistency_loss()
criterion_identity = torch.nn.L1Loss()


# Optimizers & LR schedulers
optimizer_G = torch.optim.Adam(itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()), lr=lr, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(itertools.chain(netD_A.parameters(), netD_B.parameters()), lr=lr, betas=(0.5, 0.999))


lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G, lr_lambda=LambdaLR(n_epochs, epoch, decay_epoch).step)
lr_scheduler_D = torch.optim.lr_scheduler.LambdaLR(optimizer_D, lr_lambda=LambdaLR(n_epochs, epoch, decay_epoch).step)

# Inputs & targets memory allocation
Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor
input_A = Tensor(batchSize, input_nc, size_A, size_A)
# input_B = Tensor(batchSize, output_nc, size_A, size_A)
input_B = Tensor(batchSize, output_nc, size_B, size_B)
target_real = Variable(Tensor(batchSize).fill_(1.0), requires_grad=False)
target_fake = Variable(Tensor(batchSize).fill_(0.0), requires_grad=False)

fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()

# Dataset loader
transforms_A = [ 
                transforms.ToTensor(),
                # transforms.Normalize((0.246), (0.170)),
                transforms.Normalize((0.5), (0.5)),
                # transforms.CenterCrop(size_A),
                transforms.RandomCrop((size_A, size_A))
                ]
                
transforms_B = [ 
                transforms.ToTensor(),
                transforms.Normalize((0.5), (0.5)),
                # transforms.Normalize((0.286), (0.200)),
                # transforms.CenterCrop(size_B),
                transforms.RandomCrop((size_B, size_B))
                ]
dataset = ImageDataset(dataroot, transforms_A=transforms_A, transforms_B=transforms_B, unaligned=True)
print (len(dataset))
dataloader = DataLoader(dataset, batch_size=batchSize, shuffle=True)

# Loss plot
# logger = Logger(n_epochs, len(dataloader))
###################################


354


In [51]:
import warnings
warnings.filterwarnings('ignore')

In [52]:
lr_img = Image.open("./test/6x6_256/270_3.png").convert('L')
hr_img = Image.open("./test/3x3_256/270_6.png").convert('L')

T_1 = transforms.Compose([ transforms.ToTensor(),
                transforms.Normalize((0.5), (0.5)),
                 ])
T_2 = transforms.Compose([ transforms.ToTensor(),                         
                transforms.Normalize((0.5), (0.5))])

lr_img = T_1(lr_img).cuda().unsqueeze(0)
hr_img = T_2(hr_img).cuda().unsqueeze(0)

In [53]:
def eval(model):
  lr = "./test/6x6_256/"
  hr = "./test/3x3_256/"
  num, psnr, ssim = 0, 0, 0
  T_1 = transforms.Compose([ transforms.ToTensor(),
                transforms.Normalize((0.5), (0.5)),
                 ])
  T_2 = transforms.Compose([ transforms.ToTensor(),                         
                  transforms.Normalize((0.5), (0.5))])
  for i in tqdm(range(297)):
    lr_path = os.path.join(lr, str(i)+"_3.png")
    hr_path = os.path.join(hr, str(i)+"_6.png")
    if os.path.isfile(lr_path) and os.path.isfile(hr_path):
      lr_img = Image.open(lr_path).convert('L')
      hr_img = Image.open(hr_path).convert('L')
      
      lr_img = T_1(lr_img).cuda().unsqueeze(0)
      hr_img = T_2(hr_img).cuda().unsqueeze(0)
      
      sr_img = model(lr_img)

      yimg = sr_img.cpu().detach().numpy().squeeze(0).squeeze(0)
      gtimg = hr_img.cpu().detach().numpy().squeeze(0).squeeze(0)
      psnr += (skimage.metrics.peak_signal_noise_ratio(yimg, gtimg))
      ssim += (skimage.metrics.structural_similarity(yimg, gtimg))
      num += 1
  print(" PSNR: %.4f SSIM: %.4f"%(psnr/num, ssim/num))

In [54]:
###### Training ######
for epoch in range(epoch, n_epochs):
    real_out, fake_out = None, None
    for i, batch in enumerate(dataloader):
        real_A = Variable(input_A.copy_(batch['A']))
        real_B = Variable(input_B.copy_(batch['B']))

        ######### (1) forward #########
        fake_B = netG_A2B(real_A)
        recovered_A = netG_B2A(fake_B)
        fake_A = netG_B2A(real_B)
        recovered_B = netG_A2B(fake_A)


        ###### (2) G_A and G_B ######
        set_requires_grad([netD_A, netD_B], False)
        optimizer_G.zero_grad()

        pred_fake = netD_B(fake_B)
        loss_GAN_A2B = criterion_GAN(pred_fake, target_real)

        pred_fake = netD_A(fake_A)
        loss_GAN_B2A = criterion_GAN(pred_fake, target_real)

        idt_A = netG_A2B(real_B)
        loss_idt_B = criterion_identity(idt_A, real_B) * 0.5

        idt_B = netG_B2A(real_A)
        loss_idt_A = criterion_identity(idt_B, real_A) * 0.5
        

        loss_cycle_ABA = criterion_cycle(recovered_A, real_A)*10.0
        loss_cycle_BAB = criterion_cycle(recovered_B, real_B)*10.0

        loss_G = loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB + loss_idt_B + loss_idt_A

        loss_G.backward()        
        optimizer_G.step()

        ###### (3) D_A and D_B ######
        set_requires_grad([netD_A, netD_B], True)
        optimizer_D.zero_grad()

        # Real loss
        pred_real = netD_A(real_A)
        loss_D_real = criterion_GAN(pred_real, target_real)
        # Fake loss
        fake_A = fake_A_buffer.push_and_pop(fake_A)
        pred_fake = netD_A(fake_A.detach())
        loss_D_fake = criterion_GAN(pred_fake, target_fake)
        # Total loss
        loss_D_A = (loss_D_real + loss_D_fake)*0.5
        loss_D_A.backward()


        # Real loss
        pred_real = netD_B(real_B)
        loss_D_real = criterion_GAN(pred_real, target_real)      
        # Fake loss
        fake_B = fake_B_buffer.push_and_pop(fake_B)
        pred_fake = netD_B(fake_B.detach())
        loss_D_fake = criterion_GAN(pred_fake, target_fake)
        # Total loss
        loss_D_B = (loss_D_real + loss_D_fake)*0.5
        loss_D_B.backward()

        optimizer_D.step()
        
        ####################################
        ####################################

        if i == 1:
          x = real_A.detach()
          real_out = x
          fake_out = netG_A2B(x)
      
    save_sample(epoch, real_out, "_input")
    save_sample(epoch, fake_out, "_output")

    # Update learning rates
    lr_scheduler_G.step()
    lr_scheduler_D.step()



    # Save models checkpoints
    # torch.save(LR_encoding.state_dict(), 'output/LR_encoding.pth')
    if epoch%5==4:
      torch.save(netG_A2B.state_dict(), './baseline_output/netG_A2B_epoch'+str(epoch+1)+'.pth')
    print("Epoch (%d/%d) Finished" % (epoch+1, n_epochs))

    sr_img = netG_A2B(lr_img)
    LPIPS = loss_fn_alex(hr_img.cpu(), sr_img.cpu())

    yimg = sr_img.cpu().detach().numpy().squeeze(0).squeeze(0)
    hr_img_cpu = hr_img.cpu().detach().numpy().squeeze(0).squeeze(0)
    psnr = skimage.metrics.peak_signal_noise_ratio(yimg, hr_img_cpu)
    ssim = skimage.metrics.structural_similarity(yimg, hr_img_cpu)

    print(("PSNR: %.4f SSIM: %.4f LPIPS:"%(psnr, ssim)), LPIPS.data)

    if epoch%3 == 0:
      eval(netG_A2B)

Epoch (1/60) Finished
PSNR: 16.3521 SSIM: 0.4478 LPIPS: tensor([[[[0.2085]]]])


100%|██████████| 297/297 [00:12<00:00, 24.01it/s]


 PSNR: 16.0952 SSIM: 0.4495
Epoch (2/60) Finished
PSNR: 17.3119 SSIM: 0.4746 LPIPS: tensor([[[[0.1982]]]])
Epoch (3/60) Finished
PSNR: 17.2580 SSIM: 0.4742 LPIPS: tensor([[[[0.2105]]]])
Epoch (4/60) Finished
PSNR: 17.6370 SSIM: 0.4990 LPIPS: tensor([[[[0.1886]]]])


100%|██████████| 297/297 [00:12<00:00, 24.13it/s]


 PSNR: 16.9378 SSIM: 0.4729
Epoch (5/60) Finished
PSNR: 17.7875 SSIM: 0.4879 LPIPS: tensor([[[[0.1951]]]])
Epoch (6/60) Finished
PSNR: 17.6590 SSIM: 0.4904 LPIPS: tensor([[[[0.1823]]]])
Epoch (7/60) Finished
PSNR: 17.6848 SSIM: 0.4985 LPIPS: tensor([[[[0.1873]]]])


100%|██████████| 297/297 [00:12<00:00, 23.84it/s]


 PSNR: 16.9839 SSIM: 0.4596
Epoch (8/60) Finished
PSNR: 17.8293 SSIM: 0.5048 LPIPS: tensor([[[[0.1838]]]])
Epoch (9/60) Finished
PSNR: 17.8088 SSIM: 0.5114 LPIPS: tensor([[[[0.1891]]]])
Epoch (10/60) Finished
PSNR: 18.4365 SSIM: 0.5279 LPIPS: tensor([[[[0.1870]]]])


100%|██████████| 297/297 [00:12<00:00, 22.93it/s]


 PSNR: 17.3784 SSIM: 0.4840
Epoch (11/60) Finished
PSNR: 18.3157 SSIM: 0.5262 LPIPS: tensor([[[[0.1810]]]])
Epoch (12/60) Finished
PSNR: 18.3677 SSIM: 0.5257 LPIPS: tensor([[[[0.1930]]]])
Epoch (13/60) Finished
PSNR: 18.0692 SSIM: 0.5073 LPIPS: tensor([[[[0.1977]]]])


100%|██████████| 297/297 [00:12<00:00, 24.00it/s]


 PSNR: 17.0293 SSIM: 0.4564
Epoch (14/60) Finished
PSNR: 18.3762 SSIM: 0.5249 LPIPS: tensor([[[[0.1926]]]])
Epoch (15/60) Finished
PSNR: 18.1685 SSIM: 0.5269 LPIPS: tensor([[[[0.1952]]]])
Epoch (16/60) Finished
PSNR: 18.4130 SSIM: 0.5280 LPIPS: tensor([[[[0.1932]]]])


100%|██████████| 297/297 [00:12<00:00, 24.22it/s]


 PSNR: 17.3559 SSIM: 0.4778
Epoch (17/60) Finished
PSNR: 18.3331 SSIM: 0.5227 LPIPS: tensor([[[[0.1881]]]])
Epoch (18/60) Finished
PSNR: 18.4692 SSIM: 0.5246 LPIPS: tensor([[[[0.1923]]]])
Epoch (19/60) Finished
PSNR: 18.3486 SSIM: 0.5221 LPIPS: tensor([[[[0.1991]]]])


100%|██████████| 297/297 [00:12<00:00, 23.96it/s]


 PSNR: 17.2584 SSIM: 0.4755
Epoch (20/60) Finished
PSNR: 18.3981 SSIM: 0.5186 LPIPS: tensor([[[[0.1990]]]])
Epoch (21/60) Finished
PSNR: 18.0214 SSIM: 0.5022 LPIPS: tensor([[[[0.2034]]]])
Epoch (22/60) Finished
PSNR: 18.0872 SSIM: 0.5014 LPIPS: tensor([[[[0.1894]]]])


100%|██████████| 297/297 [00:12<00:00, 24.16it/s]


 PSNR: 17.0334 SSIM: 0.4530
Epoch (23/60) Finished
PSNR: 17.9029 SSIM: 0.4913 LPIPS: tensor([[[[0.1831]]]])
Epoch (24/60) Finished
PSNR: 17.6729 SSIM: 0.4809 LPIPS: tensor([[[[0.1863]]]])
Epoch (25/60) Finished
PSNR: 17.6154 SSIM: 0.4809 LPIPS: tensor([[[[0.1840]]]])


100%|██████████| 297/297 [00:12<00:00, 22.99it/s]


 PSNR: 16.6225 SSIM: 0.4332
Epoch (26/60) Finished
PSNR: 17.5479 SSIM: 0.4751 LPIPS: tensor([[[[0.1826]]]])
Epoch (27/60) Finished
PSNR: 17.8650 SSIM: 0.4813 LPIPS: tensor([[[[0.1902]]]])
Epoch (28/60) Finished
PSNR: 17.6524 SSIM: 0.4703 LPIPS: tensor([[[[0.2012]]]])


100%|██████████| 297/297 [00:12<00:00, 24.25it/s]


 PSNR: 16.6630 SSIM: 0.4296
Epoch (29/60) Finished
PSNR: 17.5750 SSIM: 0.4661 LPIPS: tensor([[[[0.1875]]]])
Epoch (30/60) Finished
PSNR: 17.7025 SSIM: 0.4770 LPIPS: tensor([[[[0.1666]]]])
Epoch (31/60) Finished
PSNR: 17.5422 SSIM: 0.4770 LPIPS: tensor([[[[0.1666]]]])


100%|██████████| 297/297 [00:12<00:00, 24.16it/s]


 PSNR: 16.7556 SSIM: 0.4436
Epoch (32/60) Finished
PSNR: 17.9343 SSIM: 0.4966 LPIPS: tensor([[[[0.1756]]]])
Epoch (33/60) Finished
PSNR: 17.7038 SSIM: 0.4819 LPIPS: tensor([[[[0.1795]]]])
Epoch (34/60) Finished
PSNR: 18.5259 SSIM: 0.5351 LPIPS: tensor([[[[0.2235]]]])


100%|██████████| 297/297 [00:12<00:00, 24.15it/s]


 PSNR: 17.3078 SSIM: 0.4840
Epoch (35/60) Finished
PSNR: 18.3724 SSIM: 0.5277 LPIPS: tensor([[[[0.1839]]]])
Epoch (36/60) Finished
PSNR: 17.9733 SSIM: 0.4984 LPIPS: tensor([[[[0.1715]]]])
Epoch (37/60) Finished
PSNR: 17.7283 SSIM: 0.4811 LPIPS: tensor([[[[0.1798]]]])


100%|██████████| 297/297 [00:12<00:00, 24.24it/s]


 PSNR: 16.9038 SSIM: 0.4482
Epoch (38/60) Finished
PSNR: 17.5662 SSIM: 0.4813 LPIPS: tensor([[[[0.1670]]]])
Epoch (39/60) Finished
PSNR: 17.5300 SSIM: 0.4636 LPIPS: tensor([[[[0.1628]]]])
Epoch (40/60) Finished
PSNR: 17.6465 SSIM: 0.4770 LPIPS: tensor([[[[0.1682]]]])


100%|██████████| 297/297 [00:12<00:00, 23.10it/s]


 PSNR: 16.9249 SSIM: 0.4538
Epoch (41/60) Finished
PSNR: 17.6824 SSIM: 0.4785 LPIPS: tensor([[[[0.1711]]]])
Epoch (42/60) Finished
PSNR: 17.9058 SSIM: 0.4964 LPIPS: tensor([[[[0.1729]]]])
Epoch (43/60) Finished
PSNR: 17.6093 SSIM: 0.4794 LPIPS: tensor([[[[0.1600]]]])


100%|██████████| 297/297 [00:12<00:00, 23.87it/s]


 PSNR: 16.9699 SSIM: 0.4588
Epoch (44/60) Finished
PSNR: 17.3982 SSIM: 0.4667 LPIPS: tensor([[[[0.1670]]]])
Epoch (45/60) Finished
PSNR: 17.4477 SSIM: 0.4645 LPIPS: tensor([[[[0.1781]]]])
Epoch (46/60) Finished
PSNR: 17.5914 SSIM: 0.4729 LPIPS: tensor([[[[0.1645]]]])


100%|██████████| 297/297 [00:12<00:00, 23.81it/s]


 PSNR: 16.8449 SSIM: 0.4486
Epoch (47/60) Finished
PSNR: 17.4547 SSIM: 0.4716 LPIPS: tensor([[[[0.1745]]]])
Epoch (48/60) Finished
PSNR: 17.6427 SSIM: 0.4797 LPIPS: tensor([[[[0.1758]]]])
Epoch (49/60) Finished
PSNR: 17.5663 SSIM: 0.4730 LPIPS: tensor([[[[0.1823]]]])


100%|██████████| 297/297 [00:12<00:00, 23.57it/s]


 PSNR: 16.8798 SSIM: 0.4503
Epoch (50/60) Finished
PSNR: 17.5518 SSIM: 0.4726 LPIPS: tensor([[[[0.1730]]]])
Epoch (51/60) Finished
PSNR: 17.7508 SSIM: 0.4827 LPIPS: tensor([[[[0.1658]]]])
Epoch (52/60) Finished
PSNR: 17.9432 SSIM: 0.4964 LPIPS: tensor([[[[0.1673]]]])


100%|██████████| 297/297 [00:12<00:00, 23.56it/s]


 PSNR: 16.9190 SSIM: 0.4546
Epoch (53/60) Finished
PSNR: 17.7188 SSIM: 0.4782 LPIPS: tensor([[[[0.1713]]]])
Epoch (54/60) Finished
PSNR: 17.6524 SSIM: 0.4759 LPIPS: tensor([[[[0.1703]]]])
Epoch (55/60) Finished
PSNR: 17.6841 SSIM: 0.4793 LPIPS: tensor([[[[0.1663]]]])


100%|██████████| 297/297 [00:12<00:00, 22.89it/s]


 PSNR: 16.8258 SSIM: 0.4489
Epoch (56/60) Finished
PSNR: 17.6535 SSIM: 0.4739 LPIPS: tensor([[[[0.1696]]]])
Epoch (57/60) Finished
PSNR: 17.4588 SSIM: 0.4653 LPIPS: tensor([[[[0.1683]]]])
Epoch (58/60) Finished
PSNR: 17.6741 SSIM: 0.4783 LPIPS: tensor([[[[0.1670]]]])


100%|██████████| 297/297 [00:12<00:00, 23.66it/s]


 PSNR: 16.9675 SSIM: 0.4577
Epoch (59/60) Finished
PSNR: 17.6016 SSIM: 0.4751 LPIPS: tensor([[[[0.1703]]]])
Epoch (60/60) Finished
PSNR: 17.4536 SSIM: 0.4667 LPIPS: tensor([[[[0.1711]]]])


In [55]:
def result_save_sample(epoch, tensor=None, suffix="_real", img=None, img_mode=False):
    if tensor != None:
      output = tensor.cpu().detach().numpy().squeeze(0).squeeze(0)
      plt.imsave('./results/image_baseline_'+str(epoch)+suffix+'.jpeg', output, cmap="gray")
    if img_mode:
      plt.imsave('./results/image_baseline_'+str(epoch)+suffix+'.jpeg', img, cmap="gray")

In [56]:
netG_A2B = torch.load('./baseline_output/netG_A2B_epoch40.pth')
type(netG_A2B)
model = UnetGeneratorA2B(output_nc, input_nc).cuda()
model.load_state_dict(netG_A2B, strict=False)

NameError: ignored

In [None]:
img = dataset[0]['A']
x = img.unsqueeze(0).cuda()
plt.imshow(img.squeeze(0), "gray")
result_save_sample(1, tensor=x, suffix="_input")

In [None]:
_, y = model(x)
yimg = y.cpu().detach().numpy().squeeze(0).squeeze(0)
plt.imshow(yimg, "gray")
result_save_sample(1, tensor=y, suffix="_output")

In [None]:
import cv2
upsample = cv2.resize(img.squeeze(0).cpu().numpy(), dsize=(256, 256), interpolation=cv2.INTER_CUBIC)
upsample.shape
result_save_sample(1, img_mode=True, img=upsample, suffix="_interpolation")

In [None]:
import numpy as np
import cv2
from matplotlib import pyplot as plt

img = y.cpu().detach().numpy().squeeze(0).squeeze(0)
# img = y.cpu().detach().numpy().squeeze(0).squeeze(0)
f = np.fft.fft2(img, axes=(-2, -1))
fshift = np.fft.fftshift(f)
res = np.log(np.abs(fshift))
pha = np.angle(fshift)
plt.figure(figsize=(11, 11))
plt.subplot(331), plt.imshow(img, 'gray'), plt.title('Original Image')
plt.axis('off')
plt.subplot(332), plt.imshow(res, 'gray'), plt.title('Fourier Amplitude')
plt.axis('off')
plt.subplot(333), plt.imshow(pha, 'gray'), plt.title('Fourier Phase')
plt.axis('off')

In [None]:
img = dataset[0]['A']
x = img.unsqueeze(0).cuda()
plt.imshow(img.squeeze(0), "gray")
_, y = model(x)

In [None]:
yimg = y.cpu().detach().numpy().squeeze(0).squeeze(0)
plt.imshow(yimg, "gray")

In [None]:
class Test_ImageDataset(Dataset):
    def __init__(self, root, transforms_test=None, unaligned=True, mode='test'):
        self.transformA = transforms.Compose(transforms_test)
        self.transformB = transforms.Compose(transforms_test)

        self.unaligned = unaligned

        self.files_A = sorted(glob.glob(os.path.join(root, '6x6_256/') + '/*.*'))
        self.files_B = sorted(glob.glob(os.path.join(root, '3x3_256/') + '/*.*'))

    def __getitem__(self, index):
        img_A = Image.open(self.files_A[index % len(self.files_A)]).convert('L')
        item_A = self.transformA(img_A)

        if self.unaligned:
            item_B = self.transformB(Image.open(self.files_B[random.randint(0, len(self.files_B) - 1)]).convert('L'))
        else:
            item_B = self.transformB(Image.open(self.files_B[index % len(self.files_B)]).convert('L'))

        return {'A': item_A, 'B': item_B}

    def __len__(self):
        return max(len(self.files_A), len(self.files_B))

In [None]:
test_path = "./test/"
transforms_test = [ 
                transforms.ToTensor(),
                # transforms.Normalize((0.246), (0.170)),
                transforms.Normalize((0.5), (0.5)) ]
test_dataset = Test_ImageDataset(test_path, transforms_test=transforms_test, unaligned=True)

In [None]:
import cv2
img = test_dataset[2]['A']
img = cv2.resize(img.squeeze(0).cpu().numpy(), dsize=(128, 128), interpolation=cv2.INTER_CUBIC)
x = torch.tensor(img).unsqueeze(0).unsqueeze(0).cuda()
# plt.imshow(img.squeeze(0), "gray")
_, y = model(x)

In [None]:
plt.imshow(img, "gray")

In [None]:
yimg = y.cpu().detach().numpy().squeeze(0).squeeze(0)
plt.imshow(yimg, "gray")

In [None]:
f = np.fft.fft2(yimg, axes=(-2, -1))
fshift = np.fft.fftshift(f)
res = np.log(np.abs(fshift))
pha = np.angle(fshift)
plt.figure(figsize=(11, 11))
plt.subplot(331), plt.imshow(yimg, 'gray'), plt.title('Original Image')
plt.axis('off')
plt.subplot(332), plt.imshow(res, 'gray'), plt.title('Fourier Amplitude')
plt.axis('off')
plt.subplot(333), plt.imshow(pha, 'gray'), plt.title('Fourier Phase')
plt.axis('off')

Pick a pair of test data to evaluate

In [None]:
netG_A2B = torch.load('./baseline_output/netG_A2B_epoch40.pth')
type(netG_A2B)
model = UnetGeneratorA2B(output_nc, input_nc).cuda()
model.load_state_dict(netG_A2B, strict=False)

In [None]:
lr_img = Image.open("./test/6x6_256/270_3.png").convert('L')
hr_img = Image.open("./test/3x3_256/270_6.png").convert('L')
# lr_img = Image.open("./dataset/Colab_centered_OCTA/trainA/STDR403_20181029_101618_Angio (1)_R_001.png").convert('L')
# hr_img = Image.open("./dataset/Colab_centered_OCTA/trainB/STDR403_20181029_101802_Angio (1)_R_001.png").convert('L')
T_1 = transforms.Compose([ transforms.ToTensor(),
                transforms.Normalize((0.5), (0.5)),
                transforms.Resize([128, 128]) ])
T_2 = transforms.Compose([ transforms.ToTensor(),
                transforms.Normalize((0.5), (0.5))])
# lr_img = cv2.resize(np.array(lr_img), dsize=(128, 128), interpolation=cv2.INTER_CUBIC)
# lr_img = torch.tensor(lr_img).unsqueeze(0).unsqueeze(0).cuda()
lr_img = T_1(lr_img).cuda().unsqueeze(0)
hr_img = T_2(hr_img).cuda().unsqueeze(0)
# lr_img.size()
_, sr_img = model(lr_img)

In [None]:
ximg = lr_img.cpu().detach().numpy().squeeze(0).squeeze(0)
print(ximg.shape)
plt.imshow(ximg, "gray")

In [None]:
yimg = sr_img.cpu().detach().numpy().squeeze(0).squeeze(0)
plt.imshow(yimg, "gray")

In [None]:
gtimg = hr_img.cpu().detach().numpy().squeeze(0).squeeze(0)
plt.imshow(gtimg, "gray")

In [None]:
import skimage.metrics
print(skimage.metrics.peak_signal_noise_ratio(yimg, gtimg))
print(skimage.metrics.structural_similarity(yimg, gtimg))

# Result

In [None]:
import skimage.metrics
print(skimage.metrics.peak_signal_noise_ratio(yimg, gtimg))
print(skimage.metrics.structural_similarity(yimg, gtimg))

In [None]:
lr_img = Image.open("./test/6x6_256/270_3.png").convert('L')
hr_img = Image.open("./test/3x3_256/270_6.png").convert('L')
lr_img = T_1(lr_img)
hr_img = T_2(hr_img)

In [None]:
input = lr_img.cuda().unsqueeze(0)
_, output = model(input)
# output = output.cpu().detach().numpy().squeeze(0).squeeze(0)

In [None]:
d = loss_fn_alex(hr_img.cpu(), output.cpu())
print(d)

In [None]:
x = F.interpolate(input, scale_factor=2, mode='nearest')
d = loss_fn_alex(hr_img.cpu(), x.cpu())
print(d)