In [None]:
import torch.nn as nn
import torch.nn.functional as F
import torch

##############################
#           U-NET
##############################


class UNetDown(nn.Module):
    """Encoder stage of the U-Net.

    The stage is a bias-free Conv2d (kernel 4, stride 2, padding 1), which
    halves even spatial sizes, optionally followed by InstanceNorm2d, then
    LeakyReLU(0.2), then an optional Dropout.

    Args:
        in_size: number of input channels.
        out_size: number of output channels.
        normalize: when true, insert InstanceNorm2d after the convolution.
        dropout: dropout probability; a falsy value disables the Dropout layer.
    """

    def __init__(self, in_size, out_size, normalize=True, dropout=0.0):
        super().__init__()
        stages = [nn.Conv2d(in_size, out_size, kernel_size=4, stride=2, padding=1, bias=False)]
        stages += [nn.InstanceNorm2d(out_size)] if normalize else []
        stages += [nn.LeakyReLU(0.2)]
        stages += [nn.Dropout(dropout)] if dropout else []
        # Kept under the attribute name `model` so state_dict keys stay the same.
        self.model = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the down-sampling stage to `x` and return the result."""
        return self.model(x)


class UNetUp(nn.Module):
    """Decoder stage of the U-Net with a skip connection.

    The stage is a bias-free ConvTranspose2d (kernel 4, stride 2, padding 1),
    InstanceNorm2d, ReLU and an optional Dropout. Its output is concatenated
    with the skip tensor along the channel axis, so the result has
    `out_size + skip_input.size(1)` channels.

    Args:
        in_size: number of input channels.
        out_size: number of channels produced by the transposed convolution.
        dropout: dropout probability; a falsy value disables the Dropout layer.
    """

    def __init__(self, in_size, out_size, dropout=0.0):
        super().__init__()
        stages = (
            nn.ConvTranspose2d(in_size, out_size, kernel_size=4, stride=2, padding=1, bias=False),
            nn.InstanceNorm2d(out_size),
            nn.ReLU(inplace=True),
        )
        if dropout:
            stages = stages + (nn.Dropout(dropout),)
        # Kept under the attribute name `model` so state_dict keys stay the same.
        self.model = nn.Sequential(*stages)

    def forward(self, x, skip_input):
        """Up-sample `x` and concatenate `skip_input` onto it along dim 1."""
        upsampled = self.model(x)
        return torch.cat((upsampled, skip_input), dim=1)


class Generator(nn.Module):
    """U-Net generator: six down-sampling stages, five up-sampling stages with
    skip connections, and a final 2x Upsample + 3x3 convolution + Sigmoid.

    Args:
        in_channels: number of channels of the input image.
        out_channels: number of channels of the generated image.
    """

    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()

        # Encoder. Each stage halves the spatial size.
        self.down1 = UNetDown(in_channels, 16)
        self.down2 = UNetDown(16, 32)
        self.down3 = UNetDown(32, 64)
        self.down4 = UNetDown(64, 128)
        self.down5 = UNetDown(128, 128)
        self.down6 = UNetDown(128, 256)

        # Decoder. The input width of each stage equals the previous stage's
        # output plus the channels of the skip tensor concatenated onto it.
        self.up1 = UNetUp(256, 128)
        self.up2 = UNetUp(256, 128)
        self.up3 = UNetUp(256, 64)
        self.up4 = UNetUp(128, 32)
        self.up5 = UNetUp(64, 16)

        # up5 yields 16 channels + 16 skip channels from down1 = 32.
        self.final = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(32, out_channels, 3, 1, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Run the encoder, then decode while consuming the encoder
        activations in reverse order as skip connections."""
        encoder = (self.down1, self.down2, self.down3,
                   self.down4, self.down5, self.down6)
        decoder = (self.up1, self.up2, self.up3, self.up4, self.up5)

        skips = []
        for stage in encoder:
            x = stage(x)
            skips.append(x)

        # The deepest activation is the bottleneck, not a skip connection.
        x = skips.pop()
        for stage in decoder:
            x = stage(x, skips.pop())

        return self.final(x)

In [None]:
import torch
from torch import nn
class Discriminator(nn.Module):
    """Patch discriminator.

    Four stride-2 stages (Conv2d 4x4 -> BatchNorm2d -> LeakyReLU(0.2)) widen
    the channels 32 -> 64 -> 128 -> 256, followed by a 3x3 convolution that
    maps to a single-channel score map. No activation is applied to the
    output.

    Args:
        in_channel: number of channels of the input image.
    """

    def __init__(self, in_channel=1):
        super().__init__()

        def down_stage(c_in, c_out):
            # One stride-2 stage; layer order fixes the state_dict indices.
            return nn.Sequential(
                nn.Conv2d(in_channels=c_in, out_channels=c_out,
                          kernel_size=(4, 4), stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2),
            )

        self.t1 = down_stage(in_channel, 32)
        self.t2 = down_stage(32, 64)
        self.t3 = down_stage(64, 128)
        self.t4 = down_stage(128, 256)
        self.t5 = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=1,
                      kernel_size=(3, 3), stride=1, padding=1),
        )

    def forward(self, x):
        """Return the per-patch score map for the image batch `x`."""
        for stage in (self.t1, self.t2, self.t3, self.t4, self.t5):
            x = stage(x)
        return x

In [None]:
def weights_init_normal(m):
    """Initialise one module in place; intended for `nn.Module.apply`.

    Modules whose class name contains "Conv" get weights drawn from
    N(0, 0.02). Modules whose class name contains "BatchNorm2d" get weights
    drawn from N(1, 0.02) and a zero bias. Every other module is left alone.

    Modules that match by name but carry no weight tensor (for example
    `BatchNorm2d(affine=False)`, or a container class that merely has "Conv"
    in its name) are skipped instead of raising `AttributeError`.

    Args:
        m: the module to initialise.
    """
    classname = m.__class__.__name__
    weight = getattr(m, "weight", None)
    if weight is None:
        # Nothing to initialise on this module.
        return

    if "Conv" in classname:
        torch.nn.init.normal_(weight.data, 0.0, 0.02)
    elif "BatchNorm2d" in classname:
        torch.nn.init.normal_(weight.data, 1.0, 0.02)
        bias = getattr(m, "bias", None)
        if bias is not None:
            torch.nn.init.constant_(bias.data, 0.0)

In [None]:
import os
import numpy as np
import time
import datetime
import sys
from PIL import Image
from skimage import io, img_as_ubyte, exposure
from glob import glob

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.utils.data import Dataset
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

# ---- Run configuration ----
epoch = 0                  # first epoch to run; non-zero resumes from the checkpoint saved under that number
n_epochs = 200             # the training loop runs epochs [epoch, n_epochs)
dataset_name = 'HeLa_3D'   # sub-folder name used under images/ and saved_models/
dataset_folder = "dataset/train"
val_folder = "dataset/val"
batch_size = 8
lr = 0.0002                # Adam learning rate (both optimizers)
b1 = 0.5                   # Adam beta1
b2 = 0.999                 # Adam beta2
decay_epoch = 100          # NOTE(review): not referenced in the visible notebook - confirm whether a scheduler is missing
n_cpu = 8                  # NOTE(review): not referenced in the visible notebook - the data loaders use num_workers=0
img_size = 256             # side length of the square crops fed to the networks
channels = 3               # NOTE(review): not referenced in the visible notebook - the models use their 1-channel defaults
sample_interval = 200      # save a validation sample every this many batches
checkpoint_interval = 5    # save model weights every this many epochs (-1 disables)
num_critic = 50            # the discriminator is updated once every num_critic batches

os.makedirs("images/%s" % dataset_name, exist_ok=True)
os.makedirs("saved_models/%s" % dataset_name, exist_ok=True)

cuda = torch.cuda.is_available()

# Loss functions
criterion_GAN = torch.nn.MSELoss()
criterion_pixelwise = torch.nn.L1Loss()

# Weight of the L1 pixel-wise loss; the adversarial term gets (1 - lambda_pixel)
lambda_pixel = 0.95

# Shape of the discriminator output (PatchGAN): four stride-2 stages -> img_size / 2**4
patch = (1, img_size // 2 ** 4, img_size // 2 ** 4)

# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

if cuda:
    generator = generator.cuda()
    discriminator = discriminator.cuda()
    criterion_GAN.cuda()
    criterion_pixelwise.cuda()

if epoch != 0:
    # Load pretrained models. map_location lets a checkpoint written on a GPU
    # machine be restored on a CPU-only one (and vice versa).
    checkpoint_device = "cuda" if cuda else "cpu"
    generator.load_state_dict(
        torch.load("saved_models/%s/generator_%d.pth" % (dataset_name, epoch),
                   map_location=checkpoint_device)
    )
    discriminator.load_state_dict(
        torch.load("saved_models/%s/discriminator_%d.pth" % (dataset_name, epoch),
                   map_location=checkpoint_device)
    )
else:
    # Initialize weights
    generator.apply(weights_init_normal)
    discriminator.apply(weights_init_normal)

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))


class ImageDataset(Dataset):
    """Image folder dataset that yields `(image, erased_image)` pairs.

    Each file matching `root/*.*` is read, optionally intensity-rescaled to
    [0, 1], converted to an 8-bit PIL image and passed through `transform`.
    A copy with `n_erase` randomly placed rectangles erased is returned next
    to the untouched tensor.

    Args:
        root: directory containing the image files.
        transform: list of torchvision transforms (wrapped in `Compose`), or
            an already-composed callable. It must end up producing a tensor,
            because `RandomErasing` only accepts tensors. `None` falls back
            to a plain `ToTensor()`; previously `None` produced
            `Compose(None)`, which failed on the first item access.
        re_norm: when true, stretch every image's min..max range to 0..1
            before the 8-bit conversion.
        erase_scale: `scale` range handed to `RandomErasing`.
        erase_ratio: `ratio` range handed to `RandomErasing`.
        n_erase: how many erasing passes are applied per image.
    """

    def __init__(self, root, transform=None, re_norm=True,
                 erase_scale=(0.02, 0.05), erase_ratio=(0.25, 2), n_erase=3):
        self.re_norm = re_norm
        if transform is None:
            transform = [transforms.ToTensor()]
        # Accept a ready-made callable (e.g. a Compose) as well as a list.
        self.transform = transform if callable(transform) else transforms.Compose(transform)
        self.files = sorted(glob(root+"/*.*"))
        # Built once instead of three times per sample. The random rectangle
        # is still drawn afresh on every call.
        self.eraser = transforms.RandomErasing(p=1, scale=erase_scale, ratio=erase_ratio)
        self.n_erase = n_erase

    def __getitem__(self, index):
        """Return `(img, img_erased)` for the file at position `index`."""
        img = io.imread(self.files[index])
        if self.re_norm:
            # Percentiles 0 and 100 are the image minimum and maximum.
            img = exposure.rescale_intensity(img, in_range=(np.percentile(img, 0), np.percentile(img, 100)), out_range=(0, 1))
        img = Image.fromarray(img_as_ubyte(img))
        img = self.transform(img)

        img_erased = img
        for _ in range(self.n_erase):
            img_erased = self.eraser(img_erased)
        return img, img_erased

    def __len__(self):
        """Number of files found under `root`."""
        return len(self.files)

# Configure dataloaders
# Both loaders share the same batching options.
_loader_options = dict(batch_size=batch_size, shuffle=True, num_workers=0)

# Training pipeline: grayscale -> random affine augmentation -> random crop -> tensor
transform = [
    transforms.Grayscale(num_output_channels=1),
    transforms.RandomAffine(180, scale=(0.75, 1.5), shear=45),
    transforms.RandomCrop(img_size),
    transforms.ToTensor(),
]
dataset = ImageDataset(root=dataset_folder, transform=transform)
dataloader = DataLoader(dataset, **_loader_options)

# Validation pipeline: no affine augmentation, and a centre crop instead of a random one
val_transform = [
    transforms.Grayscale(num_output_channels=1),
    transforms.CenterCrop(img_size),
    transforms.ToTensor(),
]
val_dataset = ImageDataset(root=val_folder, transform=val_transform)
val_dataloader = DataLoader(val_dataset, **_loader_options)

# Tensor type used to move batches onto the training device
if cuda:
    Tensor = torch.cuda.FloatTensor
else:
    Tensor = torch.FloatTensor


def sample_images(batches_done):
    """Saves a generated sample from the validation set.

    One batch is drawn from `val_dataloader`. The erased images, the
    generator's reconstructions and the original images are stacked along
    the height axis and written to `images/<dataset_name>/<batches_done>.png`.

    Args:
        batches_done: global batch counter, used as the output file name.
    """
    original, erased = next(iter(val_dataloader))
    real_A = erased.type(Tensor)
    real_B = original.type(Tensor)
    # Inference only: skip autograd bookkeeping so no graph is built or kept.
    with torch.no_grad():
        fake_B = generator(real_A)
    img_sample = torch.cat((real_A, fake_B, real_B), -2)
    save_image(img_sample, "images/%s/%s.png" % (dataset_name, batches_done), nrow=4, normalize=True)


# ----------
#  Training
# ----------

prev_time = time.time()

for epoch in range(epoch, n_epochs):
    for i, batch in enumerate(dataloader):

        # Model inputs: each batch is (original image, erased image)
        real_A = batch[1].type(Tensor)  # generator input: image with erased regions
        real_B = batch[0].type(Tensor)  # target: untouched image

        # ------------------
        #  Train Generators
        # ------------------

        optimizer_G.zero_grad()

        # GAN loss
        fake_B = generator(real_A)
        pred_fake = discriminator(fake_B)

        # Adversarial ground truths, shaped after the discriminator output so
        # they always match it regardless of the crop size.
        valid = torch.ones_like(pred_fake)
        fake = torch.zeros_like(pred_fake)

        loss_GAN = criterion_GAN(pred_fake, valid)
        # Pixel-wise loss
        loss_pixel = criterion_pixelwise(fake_B, real_B)

        # Total loss
        loss_G = (1-lambda_pixel)*loss_GAN + lambda_pixel*loss_pixel

        loss_G.backward()

        optimizer_G.step()

        # The discriminator is updated only once every num_critic batches.
        # Between updates `loss_D` keeps the value of the most recent update,
        # and that stale value is what the log line below prints.
        if i % num_critic == 0:

            # ---------------------
            #  Train Discriminator
            # ---------------------

            optimizer_D.zero_grad()

            # Real loss
            pred_real = discriminator(real_B)
            loss_real = criterion_GAN(pred_real, valid)

            # Fake loss (detach so no gradient flows back into the generator)
            pred_fake = discriminator(fake_B.detach())
            loss_fake = criterion_GAN(pred_fake, fake)

            # Total loss
            loss_D = 0.5 * (loss_real + loss_fake)

            loss_D.backward()
            optimizer_D.step()

        # --------------
        #  Log Progress
        # --------------

        # Determine approximate time left
        batches_done = epoch * len(dataloader) + i
        batches_left = n_epochs * len(dataloader) - batches_done
        time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
        prev_time = time.time()

        # Print log
        sys.stdout.write(
            "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, pixel: %f, adv: %f] ETA: %s"
            % (
                epoch,
                n_epochs,
                i,
                len(dataloader),
                loss_D.item(),
                loss_G.item(),
                loss_pixel.item(),
                loss_GAN.item(),
                time_left,
            )
        )

        # If at sample interval save image
        if batches_done % sample_interval == 0:
            sample_images(batches_done)

    # Save model checkpoints every checkpoint_interval epochs, and always
    # after the last epoch. Otherwise the final weights are lost whenever
    # (n_epochs - 1) is not a multiple of the interval.
    is_last_epoch = epoch == n_epochs - 1
    if checkpoint_interval != -1 and (epoch % checkpoint_interval == 0 or is_last_epoch):
        torch.save(generator.state_dict(), "saved_models/%s/generator_%d.pth" % (dataset_name, epoch))
        torch.save(discriminator.state_dict(), "saved_models/%s/discriminator_%d.pth" % (dataset_name, epoch))