In [None]:
import os
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 is normally a debugging aid (it asks CUDA to run
# kernel launches synchronously) - confirm it is still wanted for full training runs.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
# This project aims to convert a 256x256 RGB image from one domain into a 256x256 RGB image in the other domain.
# The model is the same as in GAN.ipynb, but the training method is different.
# The code is heavily based on aladdinpersson's project.

# reference: https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
#            https://github.com/aladdinpersson/Machine-Learning-Collection
#            https://github.com/yunjey/pytorch-tutorial



In [None]:
# read data from img256

from PIL import Image
import os
from torch.utils.data import Dataset
import numpy as np

class DirtyCleanDataset(Dataset):
    """Dataset that serves one "dirty" and one "clean" RGB image per index.

    The two folders may hold different numbers of files: the dataset length is
    the larger of the two counts and the shorter listing is cycled with a
    modulo, so the two images returned together are not necessarily a
    matching pair.

    Args:
        root_dirty: Directory containing the dirty-domain images.
        root_clean: Directory containing the clean-domain images.
        transform: Optional callable invoked as
            ``transform(image=dirty, image0=clean)``; it must return a mapping
            with the keys ``"image"`` and ``"image0"``.

    Raises:
        ValueError: If either directory contains no regular files.
    """

    def __init__(self, root_dirty, root_clean, transform = None):
        self.root_dirty = root_dirty
        self.root_clean = root_clean
        self.transform = transform

        self.dirty_images = self._list_files(root_dirty)
        self.clean_images = self._list_files(root_clean)
        self.dirty_len = len(self.dirty_images)
        self.clean_len = len(self.clean_images)
        # Fail here with a clear message instead of a ZeroDivisionError from
        # the modulo in __getitem__ later on.
        if self.dirty_len == 0 or self.clean_len == 0:
            raise ValueError(
                "Both image folders must contain at least one file "
                f"(found {self.dirty_len} in {root_dirty!r}, "
                f"{self.clean_len} in {root_clean!r})"
            )
        self.length_dataset = max(self.dirty_len, self.clean_len)

    @staticmethod
    def _list_files(root):
        """Return the names of the regular files directly inside ``root``, sorted.

        ``os.listdir`` gives no ordering guarantee and also returns
        sub-directories (for example ``.ipynb_checkpoints``), which cannot be
        opened as images; sorting makes the index -> file mapping reproducible.
        """
        return sorted(
            name for name in os.listdir(root)
            if os.path.isfile(os.path.join(root, name))
        )

    @staticmethod
    def _load_rgb(path):
        """Read the image at ``path`` and return it as an RGB numpy array."""
        # The context manager closes the underlying file as soon as the
        # pixels have been copied into the array.
        with Image.open(path) as img:
            return np.array(img.convert("RGB"))

    def __len__(self):
        """Number of samples: the size of the larger of the two folders."""
        return self.length_dataset

    def __getitem__(self, index):
        """Return ``(dirty, clean)`` for ``index``, after ``transform`` if one is set."""
        # The modulo wraps the shorter listing around so every index is valid.
        dname = self.dirty_images[index % self.dirty_len]
        cname = self.clean_images[index % self.clean_len]

        dimage = self._load_rgb(os.path.join(self.root_dirty, dname))
        cimage = self._load_rgb(os.path.join(self.root_clean, cname))

        if self.transform:
            augmentations = self.transform(image = dimage, image0 = cimage)
            dimage = augmentations["image"]
            cimage = augmentations["image0"]

        return dimage, cimage


In [None]:
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
import torchvision.utils as vutils

dataroot = '/workspace/datasets/Kaggle'

image_size = 256   # images are resized to image_size x image_size
batch_size = 32
workers = 3        # passed to the DataLoader as num_workers

# Resize, normalise each channel with mean 0.5 / std 0.5 (0..255 pixels end up in
# roughly [-1, 1]) and convert to a tensor. "image0" is registered as a second
# "image" target so the clean image goes through the same pipeline as the dirty one.
trans = A.Compose(
    [
        # Both sides come from image_size (height used to be a hard-coded 256).
        A.Resize(width = image_size, height = image_size),
        A.Normalize(mean = [0.5, 0.5, 0.5], std = [0.5, 0.5, 0.5], max_pixel_value = 255),
        ToTensorV2(),
    ],
    additional_targets={"image0":"image"}
)

dataset = DirtyCleanDataset(root_dirty = dataroot+'/train', root_clean = dataroot+'/train_cleaned',transform=trans)

dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle = True, num_workers = workers)

import matplotlib.pyplot as plt

# Pull one batch to eyeball the inputs: real_batch[0] holds dirty images, real_batch[1] clean ones.
real_batch = next(iter(dataloader))
# clean image
#plt.imshow(np.transpose(real_batch[1][30].cpu(),(1,2,0)))
# dirty image
#plt.imshow(np.transpose(real_batch[0][30].cpu(),(1,2,0)))
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
# Slice by batch_size (previously a literal 32) so the preview follows the configured batch.
grid = vutils.make_grid(real_batch[0][:batch_size], padding=2, normalize=True)
plt.imshow(np.transpose(grid.cpu(),(1,2,0)))

In [None]:
# Shape of a single transformed dirty image (element 30 of the dirty half of the sample batch).
real_batch[0][30].shape

In [None]:
# Build the model

import torch.nn as nn

loop = 5

class Discriminator(nn.Module):
    """Convolutional discriminator that maps an image batch to one score per image.

    Layout: Conv(k=4, s=2, p=1) + LeakyReLU, then ``loop`` blocks built by
    ``_conv`` that each double the channel count, then a final
    Conv(k=4, s=2, p=0) down to one channel followed by a Sigmoid.
    ``loop`` is the module-level constant defined just above this class.
    With ``loop = 5`` the notebook's ``test()`` expects a 256x256 input to
    produce an output of shape ``(N, 1, 1, 1)``.

    Args:
        input_channels: Number of channels of the input images.
        features_d: Number of channels produced by the first convolution.
    """

    def __init__(self, input_channels, features_d):
        super(Discriminator,self).__init__()
        modules = []
        modules.append(nn.Conv2d(input_channels, features_d, kernel_size = 4, stride = 2, padding = 1))
        modules.append(nn.LeakyReLU(0.2))
        for i in range(loop):
            modules.append(self._conv(features_d * (2**i), features_d * (2**(i+1)),4,2,1))
        modules.append(nn.Conv2d(features_d * (2**loop), 1, kernel_size = 4, stride = 2, padding = 0))
        modules.append(nn.Sigmoid())
        self.disc = nn.Sequential(*modules)

    def _conv(self, in_channel, out_channel, kernel_size, stride, padding):
        """Build one block: Conv2d (no bias) -> BatchNorm2d -> LeakyReLU(0.2)."""
        # The block used to place nn.ReLU() in front of nn.LeakyReLU(0.2). ReLU zeroes
        # every negative value, so the LeakyReLU that followed never saw one and acted
        # as the identity. Only the LeakyReLU is kept; activations hold no parameters,
        # so the state_dict keys are unchanged.
        return nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size, stride, padding, bias = False),
            nn.BatchNorm2d(out_channel),
            nn.LeakyReLU(0.2)
            )

    def forward(self, x):
        """Run ``x`` through the whole stack and return the sigmoid scores."""
        return self.disc(x)
    
    
class Generator(nn.Module):
    """Encoder-decoder generator built from stride-2 convolutions.

    ``loop + 1`` down-sampling blocks (``_dconv``) are followed by ``loop + 1``
    up-sampling blocks (``_uconv``); ``loop`` is the module-level constant.
    The channel count starts at ``features_g``, doubles on the way down and
    halves on the way up, and the last block maps back to ``input_channels``.
    Every block, including the last one, ends with BatchNorm2d + LeakyReLU(0.2).
    The notebook's ``test()`` expects a 256x256 input to come back with the
    same shape.

    Args:
        input_channels: Number of channels of the input (and output) images.
        features_g: Number of channels produced by the first block.
    """

    def __init__(self, input_channels, features_g):
        super().__init__()
        # Channel widths visited by the encoder: f, 2f, 4f, ..., (2**loop) * f.
        encoder_widths = [features_g * (2 ** level) for level in range(loop + 1)]
        decoder_widths = encoder_widths[::-1]

        layers = [self._dconv(input_channels, encoder_widths[0], 4, 2, 1)]
        layers += [
            self._dconv(narrow, wide, 4, 2, 1)
            for narrow, wide in zip(encoder_widths[:-1], encoder_widths[1:])
        ]
        layers += [
            self._uconv(wide, narrow, 4, 2, 1)
            for wide, narrow in zip(decoder_widths[:-1], decoder_widths[1:])
        ]
        layers.append(self._uconv(decoder_widths[-1], input_channels, 4, 2, 1))
        self.gen = nn.Sequential(*layers)

    def _dconv(self, in_channel, out_channel, kernel_size, stride, padding):
        """Down-sampling block: Conv2d (no bias) -> BatchNorm2d -> LeakyReLU(0.2)."""
        conv = nn.Conv2d(in_channel, out_channel, kernel_size, stride, padding, bias = False)
        return nn.Sequential(conv, nn.BatchNorm2d(out_channel), nn.LeakyReLU(0.2))

    def _uconv(self, in_channel, out_channel, kernel_size, stride, padding):
        """Up-sampling block: ConvTranspose2d (no bias) -> BatchNorm2d -> LeakyReLU(0.2)."""
        deconv = nn.ConvTranspose2d(in_channel, out_channel, kernel_size, stride, padding, bias = False)
        return nn.Sequential(deconv, nn.BatchNorm2d(out_channel), nn.LeakyReLU(0.2))

    def forward(self,x):
        """Translate the batch ``x`` by running it through every block in order."""
        return self.gen(x)


# as for DCGAN    
def initialize_weights(model):
    """Re-draw conv, transposed-conv and batch-norm weights from N(0, 0.02).

    Walks every sub-module of ``model`` and overwrites the ``weight`` of each
    ``nn.Conv2d``, ``nn.ConvTranspose2d`` and ``nn.BatchNorm2d`` in place.
    Biases are left untouched.

    Args:
        model: The ``nn.Module`` whose layers should be re-initialised.
    """
    for m in model.modules():
        if isinstance(m,(nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
            # This used to read ``m.init.normal_``; modules have no ``init``
            # attribute, so it raised AttributeError on the first matching layer.
            # The initialisers live in ``torch.nn.init``.
            nn.init.normal_(m.weight.data, 0.0, 0.02)
            
def weights_init(m):
    """Per-module initialiser meant to be used as ``model.apply(weights_init)``.

    Dispatches on the class name of ``m``:
    names containing ``Conv`` get weights drawn from N(0, 0.02);
    names containing ``BatchNorm`` get weights drawn from N(1, 0.02) and a zero bias;
    any other module is left as it is.
    """
    layer_kind = type(m).__name__
    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return
    if 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        nn.init.constant_(m.bias.data, val=0)
        
def test():
    """Shape smoke test: push a random 8x3x256x256 batch through both networks.

    Raises:
        AssertionError: If the discriminator output is not ``(8, 1, 1, 1)`` or
            the generator output does not have the shape of its input.
    """
    batch, channels, height, width = 8, 3, 256, 256
    image_shape = (batch, channels, height, width)

    disc_input = torch.randn(image_shape)
    disc = Discriminator(channels, 8)
    assert disc(disc_input).shape == (batch, 1, 1, 1), "Discriminator test failed"

    gen = Generator(channels, 8)
    gen_input = torch.randn(image_shape)
    assert gen(gen_input).shape == image_shape, "Generator test failed"


# Run the shape checks immediately so a broken architecture fails in this cell.
test()

In [None]:
from torchsummary import summary

# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.autograd.set_detect_anomaly(True)

# Two generators with identical architecture; .to() and .apply() both return the
# module, so construction, device move and weight init are chained.
D2C = Generator(input_channels = 3, features_g = 52).to(device).apply(weights_init)
C2D = Generator(input_channels = 3, features_g = 52).to(device).apply(weights_init)

# Layer-by-layer summary for a 3x256x256 input.
summary(C2D, (3, 256, 256))

In [None]:
# Two discriminators with identical architecture; .to() and .apply() both return
# the module, so construction, device move and weight init are chained.
DD = Discriminator(input_channels = 3, features_d = 52).to(device).apply(weights_init)
DC = Discriminator(input_channels = 3, features_d = 52).to(device).apply(weights_init)

# Layer-by-layer summary for a 3x256x256 input.
summary(DC, (3, 256, 256))

In [None]:
# read test image

# root_dirty still points at the training folder; only the second element of each
# sample (root_clean -> '/test') comes from the test images.
Test = DirtyCleanDataset(root_dirty = dataroot+'/train', root_clean = dataroot+'/test',transform=trans)

Testloader = torch.utils.data.DataLoader(Test, batch_size=32, shuffle = True, num_workers = workers)

test_batch = next(iter(Testloader))


plt.figure(figsize=(8,8))
plt.axis("off")
# The grid shows test_batch[1], i.e. images from the '/test' folder, so the title
# says so (it used to read "Training Images").
plt.title("Test Images")
plt.imshow(np.transpose(vutils.make_grid(test_batch[1].cpu(), padding=2, normalize=True).cpu(),(1,2,0)))

In [None]:
# Train the model

from torch.utils.tensorboard import SummaryWriter
import torch.optim as optim
import torchvision

# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# ---- Hyper-parameters ----
Learning_rate = 1e-4
Channels_img = 3
Features_dis = 52
Feature_gen = 52
Num_epochs = 10
beta1 = 0.5                  # first Adam beta; the second one is fixed at 0.999 below
lambda_cycleloss = 0.1       # weight of the two cycle-consistency terms
lambda_identityloss = 0.1    # weight of the two identity terms

# ---- Optimisers: one Adam for both generators, one for both discriminators ----
opt_Gen = optim.Adam(
    [*D2C.parameters(), *C2D.parameters()],
    lr = Learning_rate,
    betas = (beta1, 0.999),
)
opt_Dis = optim.Adam(
    [*DD.parameters(), *DC.parameters()],
    lr = Learning_rate,
    betas = (beta1, 0.999),
)

# ---- Mixed precision: a separate gradient scaler per optimiser ----
Gen_scaler = torch.cuda.amp.GradScaler()
Dis_scaler = torch.cuda.amp.GradScaler()

# ---- Loss functions ----
criterion_ = nn.BCEWithLogitsLoss()   # not referenced by the training loop below
criterion1 = nn.MSELoss()             # adversarial terms
criterion2 = nn.L1Loss()              # cycle and identity terms

#writer_dirty = SummaryWriter(f"/workspace/practice/GAN/dirty")
#writer_clean = SummaryWriter(f"/workspace/practice/GAN/clean")

# ---- Bookkeeping containers ----
G_losses, D_losses, img_list = [], [], []
iters = 0

print("Starting Train: ...")

for epoch in range(Num_epochs):
    for batch_idx, (dirty,clean) in enumerate(dataloader, 0):
        clean = clean.to(device)
        dirty = dirty.to(device)

        # ---- Train the two discriminators DC and DD ----
        # DC: tell true clean images from generated (fake) clean images
        # DD: tell true dirty images from generated (fake) dirty images
        with torch.cuda.amp.autocast():

            # Loss DC: real clean -> 1, D2C(dirty) -> 0. The fake is detached so
            # this backward pass does not reach the generator.
            dirty_clean = D2C(dirty)
            disC_fake = DC(dirty_clean.detach())
            disC_real = DC(clean)
            disC_real_loss = criterion1(disC_real, torch.ones_like(disC_real))
            disC_fake_loss = criterion1(disC_fake, torch.zeros_like(disC_fake))
            DC_loss = disC_real_loss + disC_fake_loss

            # Loss DD: real dirty -> 1, C2D(clean) -> 0.
            clean_dirty = C2D(clean)
            disD_fake = DD(clean_dirty.detach())
            disD_real = DD(dirty)
            disD_real_loss = criterion1(disD_real, torch.ones_like(disD_real))
            # The target is now shaped after DD's own output (it used to be built
            # from disC_fake, a copy-paste slip from the DC branch).
            disD_fake_loss = criterion1(disD_fake, torch.zeros_like(disD_fake))
            DD_loss = disD_real_loss + disD_fake_loss

            Dis_loss = (DD_loss + DC_loss)/2

        opt_Dis.zero_grad()
        Dis_scaler.scale(Dis_loss).backward()
        Dis_scaler.step(opt_Dis)
        Dis_scaler.update()


        # ---- Train the two generators D2C and C2D ----
        with torch.cuda.amp.autocast():

            # Adversarial loss: the generators want their fakes scored as 1.
            DisC_fake = DC(dirty_clean)
            DisD_fake = DD(clean_dirty)
            loss_D2C_DC = criterion1(DisC_fake, torch.ones_like(DisC_fake))
            loss_C2D_DD = criterion1(DisD_fake, torch.ones_like(DisD_fake))

            # Cycle-consistency loss: translating there and back should give the input.
            re_dirty = C2D(dirty_clean)
            re_clean = D2C(clean_dirty)
            cycle_dirty_loss = criterion2(dirty, re_dirty)
            cycle_clean_loss = criterion2(clean, re_clean)

            # Identity loss: a generator fed an image of its target domain should keep it.
            identity_dirty = C2D(dirty)
            identity_clean = D2C(clean)
            identity_dirty_loss = criterion2(dirty, identity_dirty)
            identity_clean_loss = criterion2(clean, identity_clean)

            Gen_loss = (loss_D2C_DC + loss_C2D_DD
                        + lambda_cycleloss* cycle_dirty_loss + lambda_cycleloss * cycle_clean_loss 
                        + lambda_identityloss * identity_dirty_loss + lambda_identityloss * identity_clean_loss)

        opt_Gen.zero_grad()
        Gen_scaler.scale(Gen_loss).backward()
        Gen_scaler.step(opt_Gen)
        # The call parentheses were missing, so update() never ran and the
        # generator's loss scale was never adjusted.
        Gen_scaler.update()

        # Keep the per-batch loss history in the lists created for it.
        D_losses.append(Dis_loss.item())
        G_losses.append(Gen_loss.item())

        # Logged every batch (the old guard `batch_idx % 1 == 0` was always true).
        print(f"Epoch: [{epoch}/{Num_epochs}],  Batch: [{batch_idx}/{len(dataloader)}],  Loss_D: {Dis_loss:.4f},  Loss_G: {Gen_loss:.4f}")



In [None]:
# Persist only the weights (state_dict) of the four networks, in the original order.
for network, checkpoint_path in (
    (DD, './DD.pth'),
    (DC, './DC.pth'),
    (D2C, './D2C.pth'),
    (C2D, './C2D.pth'),
):
    torch.save(network.state_dict(), checkpoint_path)