<a href="https://colab.research.google.com/github/Shinyrose-A/MSc-project/blob/main/Cycle_Gan_rgbd.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## packages and parameters

In [None]:
!pip install albumentations==0.4.6

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting albumentations==0.4.6
  Downloading albumentations-0.4.6.tar.gz (117 kB)
[K     |████████████████████████████████| 117 kB 7.0 MB/s 
Building wheels for collected packages: albumentations
  Building wheel for albumentations (setup.py) ... [?25l[?25hdone
  Created wheel for albumentations: filename=albumentations-0.4.6-py3-none-any.whl size=65174 sha256=896b06b312330bc00f0d3e6e1e87d7d927904b8c2ba0566a84140c74a197a588
  Stored in directory: /root/.cache/pip/wheels/cf/34/0f/cb2a5f93561a181a4bcc84847ad6aaceea8b5a3127469616cc
Successfully built albumentations
Installing collected packages: albumentations
  Attempting uninstall: albumentations
    Found existing installation: albumentations 1.2.1
    Uninstalling albumentations-1.2.1:
      Successfully uninstalled albumentations-1.2.1
Successfully installed albumentations-0.4.6


In [None]:
import torch
import torch.nn as nn
from torchvision import transforms
import os
from torch.utils.data import Dataset, DataLoader
import numpy as np
import cv2
import shutil
from PIL import Image
import sys
from torchvision.datasets import ImageFolder
import torch.optim as optim
from tqdm import tqdm
from torchvision.utils import save_image
import albumentations as A
from albumentations.pytorch import ToTensorV2
import matplotlib.pyplot as plt

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(DEVICE)
# Hyper-parameters and switches read by main() / train_fn() further down.
BATCH_SIZE = 1
LEARNING_RATE = 1e-5
# Weight of the identity-loss terms in train_fn (those terms are fixed at 0 there).
LAMBDA_IDENTITY = 0.0
# Weight applied to each cycle-consistency L1 term in train_fn.
LAMBDA_CYCLE = 10
NUM_WORKERS = 2
NUM_EPOCHS = 20
# LOAD_MODEL: restore the four checkpoints before training; SAVE_MODEL: write them after every epoch.
LOAD_MODEL = False
SAVE_MODEL = True
# Checkpoint files for the two generators (S, R) and the two discriminators.
CHECKPOINT_GEN_S = "/content/drive/MyDrive/Project/data/checkpoints/gens.pth.tar"
CHECKPOINT_GEN_R = "/content/drive/MyDrive/Project/data/checkpoints/genr.pth.tar"
CHECKPOINT_CRITIC_S = "/content/drive/MyDrive/Project/data/checkpoints/critics.pth.tar"
CHECKPOINT_CRITIC_R = "/content/drive/MyDrive/Project/data/checkpoints/criticr.pth.tar"

# Augmentation pipeline for the 4-channel RGB-D images: resize to 256x256,
# random horizontal flip, normalise each channel with mean 0.5 / std 0.5
# (relative to a max pixel value of 255), then convert to a torch tensor.
TRANSFORMS_rgbd = A.Compose(
    [
        A.Resize(width=256, height=256),
        A.HorizontalFlip(p=0.5),
        A.Normalize(mean=[0.5, 0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5, 0.5], max_pixel_value=255),
        ToTensorV2(),
     ],
)

# Same pipeline for the 3-channel real images.
TRANSFORMS = A.Compose(
    [
        A.Resize(width=256, height=256),
        A.HorizontalFlip(p=0.5),
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255),
        ToTensorV2(),
     ],
)

cuda


In [None]:
def save_checkpoint(model, optimizer, PATH="my_checkpoint.pth.tar"):
    """Serialise a model and its optimizer into a single checkpoint file.

    Args:
        model: module whose ``state_dict()`` is stored under ``"state_dict"``.
        optimizer: optimizer whose ``state_dict()`` is stored under ``"optimizer"``.
        PATH: destination file handed to ``torch.save``.
    """
    print("=> Saving checkpoint")
    torch.save(
        {"state_dict": model.state_dict(), "optimizer": optimizer.state_dict()},
        PATH,
    )


def load_checkpoint(PATH, model, optimizer, lr):
    """Restore a model and optimizer from a file written by ``save_checkpoint``.

    Args:
        PATH: checkpoint file containing ``"state_dict"`` and ``"optimizer"``.
        model: module that receives the stored weights.
        optimizer: optimizer that receives the stored state.
        lr: learning rate written into every param group after loading, so
            the stored learning rate never overrides the configured one.
    """
    print("=> Loading checkpoint")
    checkpoint = torch.load(PATH, map_location=DEVICE)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    # ``capturable=True`` is a workaround for resuming Adam on CUDA; on CPU the
    # optimizer rejects it at step() time. Apply it only when running on CUDA,
    # and to every param group rather than just the first one.
    on_cuda = str(DEVICE).startswith("cuda")
    for param_group in optimizer.param_groups:
        if on_cuda or "capturable" in param_group:
            param_group["capturable"] = on_cuda
        param_group["lr"] = lr

## Discriminator

dense resnet

In [None]:
class Block(nn.Module):
    """Discriminator building block: 4x4 convolution, instance norm, LeakyReLU(0.2).

    Args:
        in_channels: number of input feature maps.
        out_channels: number of output feature maps.
        stride: stride of the convolution.
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        convolution = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=4,
            stride=stride,
            padding=1,
            bias=True,
            padding_mode="reflect",
        )
        self.conv = nn.Sequential(
            convolution,
            nn.InstanceNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        """Apply the conv / norm / activation stack to ``x``."""
        return self.conv(x)

In [None]:
class Discriminator(nn.Module):
    """Convolutional discriminator that returns a sigmoid score map.

    Args:
        in_channels: channels of the input image (3 and 4 are used in this
            notebook).
        features: output width of each stage. The first entry feeds the
            initial stride-2 convolution; every later entry becomes a
            ``Block``. Only the block in the final position uses stride 1.
    """

    def __init__(self, in_channels = 3, features = (64, 128, 256, 512)):
        super().__init__()
        # Tuple default avoids a shared mutable default; copy so any sequence works.
        features = list(features)
        self.initial = nn.Sequential(
            nn.Conv2d(in_channels, features[0], kernel_size=4, stride=2, padding=1, padding_mode="reflect"),
            nn.LeakyReLU(0.2),
        )

        layers = []
        in_channels = features[0]
        last_position = len(features) - 1
        for position, feature in enumerate(features[1:], start=1):
            # Choose stride by position, not by value: comparing against
            # features[-1] would give stride 1 to every block whose width
            # happens to equal the last width.
            stride = 1 if position == last_position else 2
            layers.append(Block(in_channels, feature, stride = stride))
            in_channels = feature
        # Single output channel: one real/fake score per spatial position.
        layers.append(nn.Conv2d(in_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect"))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the sigmoid-activated score map for image batch ``x``."""
        x = self.initial(x)
        return torch.sigmoid(self.model(x))

## Generator

In [None]:
class ConvBlock(nn.Module):
    """Convolution (or transposed convolution) + instance norm + optional ReLU.

    Args:
        in_channels: number of input feature maps.
        out_channels: number of output feature maps.
        down: when True use ``nn.Conv2d`` with reflect padding; otherwise use
            ``nn.ConvTranspose2d``.
        use_act: when True finish with an in-place ReLU, otherwise with
            ``nn.Identity``.
        **kwargs: forwarded unchanged to the chosen convolution layer.
    """

    def __init__(self, in_channels, out_channels, down=True, use_act=True, **kwargs):
        super().__init__()
        if down:
            convolution = nn.Conv2d(in_channels, out_channels, padding_mode="reflect", **kwargs)
        else:
            convolution = nn.ConvTranspose2d(in_channels, out_channels, **kwargs)

        if use_act:
            activation = nn.ReLU(inplace=True)
        else:
            activation = nn.Identity()

        self.conv = nn.Sequential(convolution, nn.InstanceNorm2d(out_channels), activation)

    def forward(self, x):
        """Run ``x`` through the convolution, normalisation and activation."""
        return self.conv(x)

In [None]:
class Generator(nn.Module):
    """Encoder / residual / decoder image-to-image generator with tanh output.

    Pipeline: 7x7 stem convolution, two stride-2 ``ConvBlock`` down-samplers,
    ``num_residuals`` ``ResidualBlock`` layers, two transposed-convolution
    ``ConvBlock`` up-samplers, and a final 7x7 convolution followed by tanh.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the generated image.
        num_features: width of the stem; the down path doubles it twice.
        num_residuals: number of residual blocks. The default is 18, while
            the main() functions in this notebook pass 9.

    NOTE(review): ``ResidualBlock`` is not defined anywhere in the visible
    notebook - it must be defined before this class is instantiated.
    """

    def __init__(self, in_channels, out_channels, num_features = 64, num_residuals=18):
        super().__init__()
        width = num_features

        self.initial = nn.Sequential(
            nn.Conv2d(in_channels, width, kernel_size=7, stride=1, padding=3, padding_mode="reflect"),
            nn.ReLU(inplace=True),
        )

        # Each down block halves the resolution and doubles the channel count.
        down_kwargs = dict(kernel_size=3, stride=2, padding=1)
        self.down_blocks = nn.ModuleList(
            [
                ConvBlock(width, width*2, **down_kwargs),
                ConvBlock(width*2, width*4, **down_kwargs),
            ]
        )

        bottleneck = [ResidualBlock(width*4) for _ in range(num_residuals)]
        self.residual_blocks = nn.Sequential(*bottleneck)

        # Each up block doubles the resolution and halves the channel count.
        up_kwargs = dict(down=False, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.up_blocks = nn.ModuleList(
            [
                ConvBlock(width*4, width*2, **up_kwargs),
                ConvBlock(width*2, width, **up_kwargs),
            ]
        )

        self.last = nn.Conv2d(width, out_channels, kernel_size=7, stride=1, padding=3, padding_mode="reflect")

    def forward(self,x):
        """Translate image batch ``x``; tanh bounds the output to [-1, 1]."""
        hidden = self.initial(x)
        for down in self.down_blocks:
            hidden = down(hidden)
        hidden = self.residual_blocks(hidden)
        for up in self.up_blocks:
            hidden = up(hidden)
        return torch.tanh(self.last(hidden))

In [None]:
def test():
    """Smoke test: push a random 2x4x256x256 batch through a 4->3 channel Generator and print the output shape."""
    batch_size, channels_in, channels_out, side = 2, 4, 3, 256
    sample = torch.randn((batch_size, channels_in, side, side))
    model = Generator(channels_in, channels_out)
    print(model(sample).shape)

test()

## Dataset

In [None]:
# Mount Google Drive so the dataset and checkpoint paths under /content/drive resolve (Colab only).
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


### Generate rgbd image

In [None]:
def gen_rgbd(rgb_path, depth_path, save_path="/content/drive/MyDrive/Project/data/rgbd"):
    """Merge colour images and depth maps into 4-channel PNG files.

    Both directories are listed and sorted by file name, and the i-th colour
    file is paired with the i-th depth file. For each pair, the first three
    channels of the colour image and the depth map are stacked into a
    ``uint8`` array of shape (H, W, 4) and written to
    ``<save_path>/<n>.png`` with ``n`` counting from 1.

    Args:
        rgb_path: directory with the colour images (extra channels such as
            alpha are dropped).
        depth_path: directory with the depth maps; each must load as a 2-D
            array with the same height and width as its colour image.
        save_path: output directory, created if missing. Defaults to the
            location the notebook used originally.

    Raises:
        ValueError: if there are fewer depth files than colour files.
        IOError: if an input image cannot be decoded or an output cannot be
            written.
    """
    rgb_names = sorted(os.listdir(rgb_path))
    depth_names = sorted(os.listdir(depth_path))
    if len(depth_names) < len(rgb_names):
        # Fail before writing anything instead of hitting an IndexError midway.
        raise ValueError(
            f"{len(rgb_names)} colour images but only {len(depth_names)} depth maps"
        )

    os.makedirs(save_path, exist_ok=True)

    for number, (rgb_name, depth_name) in enumerate(zip(rgb_names, depth_names), start=1):
        rgb_file = rgb_path + '/' + rgb_name
        depth_file = depth_path + '/' + depth_name

        # IMREAD_UNCHANGED keeps every stored channel, so the colour image may
        # carry an alpha plane; only its first three channels are used.
        rgb = cv2.imread(rgb_file, cv2.IMREAD_UNCHANGED)
        depth = cv2.imread(depth_file, cv2.IMREAD_UNCHANGED)
        if rgb is None:
            raise IOError(f"could not read image: {rgb_file}")
        if depth is None:
            raise IOError(f"could not read image: {depth_file}")

        # Size the canvas from the input rather than assuming 256x256.
        height, width = rgb.shape[:2]
        rgbd = np.zeros((height, width, 4), dtype=np.uint8)
        rgbd[:, :, :3] = rgb[:, :, :3]
        rgbd[:, :, 3] = depth

        save_img = save_path + '/' + str(number) + '.png'
        if not cv2.imwrite(save_img, rgbd):
            raise IOError(f"could not write image: {save_img}")
        




In [None]:
# Source folders for building the RGB-D images.
# NOTE(review): this folder is spelled "syntheic" here but "synthetic" in the test main() - confirm which exists on Drive.
rgb_path = "/content/drive/MyDrive/Project/data/syntheic"
depth_path = "/content/drive/MyDrive/Project/data/depth map"
# One-off generation step, left disabled; uncomment to (re)build the RGB-D images.
#gen_rgbd(rgb_path, depth_path)

In [None]:
def getFileNames(rootDir):
    """Return the paths of every ``.png`` file found under ``rootDir``.

    The tree is traversed with ``os.walk``. A file is kept when
    ``os.path.splitext`` reports exactly ``'.png'`` as its extension
    (case-sensitive). Each result is ``<directory>/<file name>``, in the order
    ``os.walk`` yields them.
    """
    return [
        folder + '/' + entry
        for folder, _subfolders, entries in os.walk(rootDir)
        for entry in entries
        if os.path.splitext(entry)[1] == '.png'
    ]

In [None]:
# Sanity check: count and list the generated RGB-D PNG files.
path = '/content/drive/MyDrive/Project/data/rgbd'
img_path = getFileNames(path)
print(len(img_path))
# Sorting is left disabled, so the list is printed in os.walk order.
#img_path.sort()
# NOTE(review): this prints every path (thousands of entries); consider showing a short slice instead.
print(img_path)

5460
['/content/drive/MyDrive/Project/data/rgbd/271.png', '/content/drive/MyDrive/Project/data/rgbd/272.png', '/content/drive/MyDrive/Project/data/rgbd/273.png', '/content/drive/MyDrive/Project/data/rgbd/274.png', '/content/drive/MyDrive/Project/data/rgbd/275.png', '/content/drive/MyDrive/Project/data/rgbd/276.png', '/content/drive/MyDrive/Project/data/rgbd/277.png', '/content/drive/MyDrive/Project/data/rgbd/278.png', '/content/drive/MyDrive/Project/data/rgbd/279.png', '/content/drive/MyDrive/Project/data/rgbd/280.png', '/content/drive/MyDrive/Project/data/rgbd/281.png', '/content/drive/MyDrive/Project/data/rgbd/282.png', '/content/drive/MyDrive/Project/data/rgbd/283.png', '/content/drive/MyDrive/Project/data/rgbd/284.png', '/content/drive/MyDrive/Project/data/rgbd/286.png', '/content/drive/MyDrive/Project/data/rgbd/287.png', '/content/drive/MyDrive/Project/data/rgbd/285.png', '/content/drive/MyDrive/Project/data/rgbd/288.png', '/content/drive/MyDrive/Project/data/rgbd/289.png', '/cont

### Load images

In [None]:
class gen_dataset(Dataset):
    """Unpaired dataset yielding one RGB-D image and one real image per index.

    File names come from ``os.listdir`` on each root, in whatever order it
    returns them. The dataset length is the larger of the two folder sizes;
    the shorter folder is cycled with a modulo on the index.

    Args:
        root_rgbd: directory holding the RGB-D images.
        root_real: directory holding the real images.
        transform: optional callable applied to the real image as
            ``transform(image=...)["image"]``.
        transform_rgbd: optional callable applied to the RGB-D image the
            same way.
    """

    def __init__(self, root_rgbd, root_real, transform=None, transform_rgbd=None):
        self.root_rgbd, self.root_real = root_rgbd, root_real
        self.transform, self.transform_rgbd = transform, transform_rgbd

        self.rgbd_images = os.listdir(root_rgbd)
        self.real_images = os.listdir(root_real)

        self.rgbd_length = len(self.rgbd_images)
        self.real_length = len(self.real_images)
        self.length_dataset = max(self.rgbd_length, self.real_length)

    def __len__(self):
        """Number of samples: the size of the larger folder."""
        return self.length_dataset

    def __getitem__(self, index):
        """Load and return ``(rgbd_img, real_img)`` for ``index``."""
        rgbd_name = self.rgbd_images[index % self.rgbd_length]
        real_name = self.real_images[index % self.real_length]

        # The RGB-D file is read with its stored channels; the real image is forced to RGB.
        rgbd_img = np.array(Image.open(os.path.join(self.root_rgbd, rgbd_name)))
        real_img = np.array(Image.open(os.path.join(self.root_real, real_name)).convert("RGB"))

        if self.transform:
            real_img = self.transform(image=real_img)["image"]
        if self.transform_rgbd:
            rgbd_img = self.transform_rgbd(image=rgbd_img)["image"]

        return rgbd_img, real_img

In [None]:
# Module-level dataset for ad-hoc inspection. main() builds its own dataset,
# so nothing else in the visible notebook reads this one.
dataset = gen_dataset(root_rgbd = "/content/drive/MyDrive/Project/data/rgbd",
                      root_real = "/content/drive/MyDrive/Project/data/real_resized",
                      transform=TRANSFORMS,
                      transform_rgbd=TRANSFORMS_rgbd)

In [None]:
# Module-level loader over `dataset` (batch of 10, shuffled); main() creates a separate loader for training.
loader = DataLoader(dataset, batch_size = 10, shuffle=True, num_workers=2)

## Train

In [None]:
def train_fn(disc_S, disc_R, gen_S, gen_R, loader, opt_disc, opt_gen, l1, mse, d_scaler, g_scaler, epoch):
    """Train both generators and both discriminators for one epoch.

    Each batch first updates the discriminators (real -> 1, detached fake ->
    0), then updates the generators with the adversarial loss plus the
    cycle-consistency loss weighted by ``LAMBDA_CYCLE``. Both phases run
    under CUDA autocast with their own gradient scaler.

    Args:
        disc_S: discriminator for the RGB-D ("S") domain.
        disc_R: discriminator for the real ("R") domain.
        gen_S: generator mapping real images to RGB-D images.
        gen_R: generator mapping RGB-D images to real images.
        loader: iterable yielding ``(rgbd_img, real_img)`` batches.
        opt_disc: optimizer over both discriminators.
        opt_gen: optimizer over both generators.
        l1: loss used for the cycle-consistency terms.
        mse: loss used for the adversarial terms.
        d_scaler: gradient scaler for the discriminator step.
        g_scaler: gradient scaler for the generator step.
        epoch: epoch number, used in the progress bar and sample file names.

    Returns:
        ``(D_loss_avg, G_loss_avg)``: mean discriminator and generator loss
        over the epoch as floats, or ``(0.0, 0.0)`` for an empty loader.
    """
    # Keep running totals as plain floats. Adding the loss tensors themselves
    # would keep every iteration's autograd history reachable for the whole epoch.
    D_loss_all = 0.0
    G_loss_all = 0.0
    D_loss_avg = 0.0
    G_loss_avg = 0.0
    loop = tqdm(loader, leave=True)

    for idx, (rgbd_img, real_img) in enumerate(loop):
        input_img = rgbd_img.float().to(DEVICE)
        real_img = real_img.float().to(DEVICE)

        # Train Discriminators
        with torch.cuda.amp.autocast():
            # Disc R: real images vs. images generated from RGB-D input.
            fake_R = gen_R(input_img)
            D_R_real = disc_R(real_img)
            D_R_fake = disc_R(fake_R.detach())

            D_R_real_loss = mse(D_R_real, torch.ones_like(D_R_real))
            D_R_fake_loss = mse(D_R_fake, torch.zeros_like(D_R_fake))
            D_R_loss = D_R_real_loss + D_R_fake_loss

            # Disc S: RGB-D inputs vs. RGB-D images generated from real images.
            fake_S = gen_S(real_img)
            D_S_real = disc_S(input_img)
            D_S_fake = disc_S(fake_S.detach())

            D_S_real_loss = mse(D_S_real, torch.ones_like(D_S_real))
            D_S_fake_loss = mse(D_S_fake, torch.zeros_like(D_S_fake))
            D_S_loss = D_S_real_loss + D_S_fake_loss

            # put it together
            D_loss = (D_R_loss + D_S_loss)/2

        opt_disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()
        D_loss_all += D_loss.item()

        # Train Generators
        with torch.cuda.amp.autocast():
            # Adversarial loss: the (non-detached) fakes should be scored as real.
            D_R_fake = disc_R(fake_R)
            D_S_fake = disc_S(fake_S)
            loss_G_R = mse(D_R_fake, torch.ones_like(D_R_fake))
            loss_G_S = mse(D_S_fake, torch.ones_like(D_S_fake))

            # Cycle loss: translating there and back should reproduce the input.
            cycle_S = gen_S(fake_R)
            cycle_R = gen_R(fake_S)
            cycle_S_loss = l1(input_img, cycle_S)
            cycle_R_loss = l1(real_img, cycle_R)

            # No identity loss: the original placeholders were constant zeros,
            # so dropping them leaves the total unchanged.
            G_loss = (
                loss_G_S
                + loss_G_R
                + cycle_S_loss * LAMBDA_CYCLE
                + cycle_R_loss * LAMBDA_CYCLE
            )

        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()
        G_loss_all += G_loss.item()

        # Periodically save a generated sample, mapped back from [-1, 1] to [0, 1].
        if idx % 1000 == 0:
            save_image(fake_R[0]*0.5+0.5, f"/content/drive/MyDrive/Project/data/saved_images_rgbd/real_{epoch}_{idx}.png")

        D_loss_avg = D_loss_all/(idx+1)
        G_loss_avg = G_loss_all/(idx+1)

        loop.set_postfix(D_loss = D_loss_avg, G_loss = G_loss_avg, epoch = epoch)

    return D_loss_avg, G_loss_avg

In [None]:
def main():
    """Build the CycleGAN, optionally resume from checkpoints, and train.

    Creates the two discriminators and two generators on ``DEVICE``, one Adam
    optimizer per pair, and a shuffled loader over the RGB-D / real training
    folders. Trains for ``NUM_EPOCHS`` epochs, saving all four checkpoints
    after each epoch when ``SAVE_MODEL`` is set, and finally prints the
    per-epoch discriminator and generator loss averages.

    NOTE(review): a later cell defines another ``main``; once that cell has
    run, calling ``main()`` no longer reaches this function.
    """
    disc_S = Discriminator(in_channels=4).to(DEVICE)
    disc_R = Discriminator(in_channels=3).to(DEVICE)
    gen_S = Generator(in_channels=3, out_channels=4, num_residuals=9).to(DEVICE)
    gen_R = Generator(in_channels=4, out_channels=3, num_residuals=9).to(DEVICE)

    adam_betas = (0.5, 0.999)
    disc_params = list(disc_S.parameters())+list(disc_R.parameters())
    gen_params = list(gen_S.parameters())+list(gen_R.parameters())
    opt_disc = optim.Adam(disc_params, lr = LEARNING_RATE, betas = adam_betas)
    opt_gen = optim.Adam(gen_params, lr = LEARNING_RATE, betas = adam_betas)

    # Cycle-consistency uses L1; the adversarial terms use MSE.
    L1 = nn.L1Loss()
    mse = nn.MSELoss()

    # (file, network, optimizer) triples, in the order they are loaded and saved.
    checkpoints = (
        (CHECKPOINT_GEN_S, gen_S, opt_gen),
        (CHECKPOINT_GEN_R, gen_R, opt_gen),
        (CHECKPOINT_CRITIC_S, disc_S, opt_disc),
        (CHECKPOINT_CRITIC_R, disc_R, opt_disc),
    )

    if LOAD_MODEL:
        for ckpt_path, network, optimizer in checkpoints:
            load_checkpoint(ckpt_path, network, optimizer, LEARNING_RATE)

    train_set = gen_dataset(
        root_rgbd = "/content/drive/MyDrive/Project/data/rgbd",
        root_real = "/content/drive/MyDrive/Project/data/real_resized_specfic",
        transform=TRANSFORMS,
        transform_rgbd=TRANSFORMS_rgbd,
    )
    loader = DataLoader(train_set, batch_size = BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory = True)

    g_scaler = torch.cuda.amp.GradScaler()
    d_scaler = torch.cuda.amp.GradScaler()

    D_loss_list, G_loss_list = [], []
    for epoch in range(NUM_EPOCHS):
        D_loss_avg, G_loss_avg = train_fn(
            disc_S, disc_R, gen_S, gen_R, loader, opt_disc, opt_gen, L1, mse, d_scaler, g_scaler, epoch
        )
        D_loss_list.append(D_loss_avg)
        G_loss_list.append(G_loss_avg)

        if SAVE_MODEL:
            for ckpt_path, network, optimizer in checkpoints:
                save_checkpoint(network, optimizer, PATH=ckpt_path)

    print(D_loss_list)
    print(G_loss_list)

In [None]:
# Entry point for the training run: calls whichever main() is currently defined.
if __name__ == "__main__":
    main()

100%|██████████| 5462/5462 [34:28<00:00,  2.64it/s, D_loss=0.252, G_loss=2.83, epoch=0]
100%|██████████| 5462/5462 [11:38<00:00,  7.82it/s, D_loss=0.219, G_loss=2.5, epoch=1]
100%|██████████| 5462/5462 [11:54<00:00,  7.64it/s, D_loss=0.197, G_loss=2.52, epoch=2]
100%|██████████| 5462/5462 [12:03<00:00,  7.55it/s, D_loss=0.201, G_loss=2.48, epoch=3]
100%|██████████| 5462/5462 [12:02<00:00,  7.56it/s, D_loss=0.196, G_loss=2.47, epoch=4]
100%|██████████| 5462/5462 [12:03<00:00,  7.55it/s, D_loss=0.197, G_loss=2.44, epoch=5]
100%|██████████| 5462/5462 [12:11<00:00,  7.47it/s, D_loss=0.2, G_loss=2.4, epoch=6]
100%|██████████| 5462/5462 [12:15<00:00,  7.42it/s, D_loss=0.194, G_loss=2.38, epoch=7]
100%|██████████| 5462/5462 [11:54<00:00,  7.65it/s, D_loss=0.189, G_loss=2.38, epoch=8]
100%|██████████| 5462/5462 [11:55<00:00,  7.63it/s, D_loss=0.19, G_loss=2.35, epoch=9]
100%|██████████| 5462/5462 [11:58<00:00,  7.60it/s, D_loss=0.184, G_loss=2.34, epoch=10]
100%|██████████| 5462/5462 [12:00<00

[0.2523498237133026, 0.21854223310947418, 0.19657599925994873, 0.20077016949653625, 0.19608400762081146, 0.1972602754831314, 0.19957491755485535, 0.19447419047355652, 0.188841313123703, 0.19031089544296265, 0.18399553000926971, 0.18347066640853882, 0.18051044642925262, 0.1777505874633789, 0.18767398595809937, 0.18624868988990784, 0.18276447057724, 0.18218521773815155, 0.18894615769386292, 0.19062072038650513]
[2.8267698287963867, 2.4991700649261475, 2.517655611038208, 2.4787521362304688, 2.4659016132354736, 2.437272071838379, 2.398709774017334, 2.3812484741210938, 2.376523971557617, 2.3484930992126465, 2.3414769172668457, 2.3205103874206543, 2.3202571868896484, 2.310685396194458, 2.266965627670288, 2.2522621154785156, 2.2624168395996094, 2.2578296661376953, 2.2345690727233887, 2.2148542404174805]


Known problems with CycleGAN: https://zhuanlan.zhihu.com/p/45164258

20+13



```
100%|██████████| 5462/5462 [34:28<00:00,  2.64it/s, D_loss=0.252, G_loss=2.83, epoch=0]
100%|██████████| 5462/5462 [11:38<00:00,  7.82it/s, D_loss=0.219, G_loss=2.5, epoch=1]
100%|██████████| 5462/5462 [11:54<00:00,  7.64it/s, D_loss=0.197, G_loss=2.52, epoch=2]
100%|██████████| 5462/5462 [12:03<00:00,  7.55it/s, D_loss=0.201, G_loss=2.48, epoch=3]
100%|██████████| 5462/5462 [12:02<00:00,  7.56it/s, D_loss=0.196, G_loss=2.47, epoch=4]
100%|██████████| 5462/5462 [12:03<00:00,  7.55it/s, D_loss=0.197, G_loss=2.44, epoch=5]
100%|██████████| 5462/5462 [12:11<00:00,  7.47it/s, D_loss=0.2, G_loss=2.4, epoch=6]
100%|██████████| 5462/5462 [12:15<00:00,  7.42it/s, D_loss=0.194, G_loss=2.38, epoch=7]
100%|██████████| 5462/5462 [11:54<00:00,  7.65it/s, D_loss=0.189, G_loss=2.38, epoch=8]
100%|██████████| 5462/5462 [11:55<00:00,  7.63it/s, D_loss=0.19, G_loss=2.35, epoch=9]
100%|██████████| 5462/5462 [11:58<00:00,  7.60it/s, D_loss=0.184, G_loss=2.34, epoch=10]
100%|██████████| 5462/5462 [12:00<00:00,  7.58it/s, D_loss=0.183, G_loss=2.32, epoch=11]
100%|██████████| 5462/5462 [11:57<00:00,  7.61it/s, D_loss=0.181, G_loss=2.32, epoch=12]
100%|██████████| 5462/5462 [11:58<00:00,  7.60it/s, D_loss=0.178, G_loss=2.31, epoch=13]
100%|██████████| 5462/5462 [11:57<00:00,  7.61it/s, D_loss=0.188, G_loss=2.27, epoch=14]
100%|██████████| 5462/5462 [12:03<00:00,  7.55it/s, D_loss=0.186, G_loss=2.25, epoch=15]
100%|██████████| 5462/5462 [12:10<00:00,  7.47it/s, D_loss=0.183, G_loss=2.26, epoch=16]
100%|██████████| 5462/5462 [12:23<00:00,  7.34it/s, D_loss=0.182, G_loss=2.26, epoch=17]
100%|██████████| 5462/5462 [12:27<00:00,  7.31it/s, D_loss=0.189, G_loss=2.23, epoch=18]
100%|██████████| 5462/5462 [12:24<00:00,  7.33it/s, D_loss=0.191, G_loss=2.21, epoch=19]
[0.2523498237133026, 0.21854223310947418, 0.19657599925994873, 0.20077016949653625, 0.19608400762081146, 0.1972602754831314, 0.19957491755485535, 0.19447419047355652, 0.188841313123703, 0.19031089544296265, 0.18399553000926971, 0.18347066640853882, 0.18051044642925262, 0.1777505874633789, 0.18767398595809937, 0.18624868988990784, 0.18276447057724, 0.18218521773815155, 0.18894615769386292, 0.19062072038650513]
[2.8267698287963867, 2.4991700649261475, 2.517655611038208, 2.4787521362304688, 2.4659016132354736, 2.437272071838379, 2.398709774017334, 2.3812484741210938, 2.376523971557617, 2.3484930992126465, 2.3414769172668457, 2.3205103874206543, 2.3202571868896484, 2.310685396194458, 2.266965627670288, 2.2522621154785156, 2.2624168395996094, 2.2578296661376953, 2.2345690727233887, 2.2148542404174805]
```



100%|██████████| 5462/5462 [12:37<00:00,  7.21it/s, D_loss=0.185, G_loss=2.22, epoch=19]

100%|██████████| 5462/5462 [13:57<00:00,  6.52it/s, D_loss=0.21, G_loss=2.04, epoch=12]

## Test

In [None]:
class gen_dataset_test(Dataset):
    """Evaluation dataset that also reports which source files belong to a sample.

    Besides the RGB-D and real images, every item carries the paths of the
    colour image and depth map to compare against, plus the path of the
    RGB-D file itself.

    Args:
        root_rgbd: directory holding the RGB-D images (``os.listdir`` order).
        root_real: directory holding the real images (``os.listdir`` order).
        root_compare_rgb: directory of the original colour images, sorted by name.
        root_compare_depth: directory of the original depth maps, sorted by name.
        transform: optional callable applied to the real image as
            ``transform(image=...)["image"]``.
        transform_rgbd: optional callable applied to the RGB-D image the
            same way.

    NOTE(review): the comparison lists are sorted but ``rgbd_images`` is not,
    so the i-th RGB-D file matches the i-th comparison file only if
    ``os.listdir(root_rgbd)`` happens to return that order - confirm.
    """

    def __init__(self, root_rgbd, root_real, root_compare_rgb, root_compare_depth, transform=None, transform_rgbd=None):
        self.root_rgbd = root_rgbd
        self.root_real = root_real
        self.root_compare_rgb = root_compare_rgb
        self.root_compare_depth = root_compare_depth
        self.transform = transform
        self.transform_rgbd = transform_rgbd

        self.rgbd_images = os.listdir(root_rgbd)
        self.real_images = os.listdir(root_real)

        self.length_dataset = max(len(self.rgbd_images), len(self.real_images))
        self.rgbd_length = len(self.rgbd_images)
        self.real_length = len(self.real_images)

        self.compare_rgb_images = os.listdir(root_compare_rgb)
        self.compare_rgb_images.sort()
        self.compare_depth_images = os.listdir(root_compare_depth)
        self.compare_depth_images.sort()

    def __len__(self):
        """Number of samples: the size of the larger of the RGB-D / real folders."""
        return self.length_dataset

    def __getitem__(self, index):
        """Return ``(rgbd_img, real_img, compare_rgb_path, compare_depth_path, rgbd_path)``."""
        # The comparison files belong to the RGB-D sample, so they use the
        # RGB-D position. Wrapping them by the real-image count selects the
        # wrong files whenever the two folders differ in size.
        rgbd_index = index % self.rgbd_length
        rgbd_img = self.rgbd_images[rgbd_index]
        real_img = self.real_images[index % self.real_length]
        compare_rgb_img = self.compare_rgb_images[rgbd_index]
        compare_depth_img = self.compare_depth_images[rgbd_index]

        rgbd_path = os.path.join(self.root_rgbd, rgbd_img)
        real_path = os.path.join(self.root_real, real_img)
        compare_rgb_path = os.path.join(self.root_compare_rgb, compare_rgb_img)
        compare_depth_path = os.path.join(self.root_compare_depth, compare_depth_img)

        rgbd_img = np.array(Image.open(rgbd_path))
        real_img = np.array(Image.open(real_path).convert("RGB"))

        if self.transform:
            real_img = self.transform(image=real_img)["image"]
        if self.transform_rgbd:
            rgbd_img = self.transform_rgbd(image=rgbd_img)["image"]

        return rgbd_img, real_img, compare_rgb_path, compare_depth_path, rgbd_path

In [None]:
def test_fn(disc_S, disc_R, gen_S, gen_R, loader, opt_disc, opt_gen, epoch):
    """Generate real-domain images for every second batch and save them.

    For each even batch index, the RGB-D batch is translated by ``gen_R`` and
    the first result is written to the output folder. The matching original
    colour image and depth map are copied next to it for comparison.

    Args:
        disc_S, disc_R, gen_S, opt_disc, opt_gen: accepted so the signature
            mirrors ``train_fn``; not used here.
        gen_R: generator mapping RGB-D images to real images.
        loader: iterable yielding ``(rgbd_img, real_img, compare_rgb_path,
            compare_depth_path, rgbd_path)`` batches.
        epoch: number embedded in the saved file names and the progress bar.
    """
    loop = tqdm(loader, leave=True)

    # Inference only: disabling autograd avoids building a graph per batch.
    with torch.no_grad():
        for idx, (rgbd_img, real_img, compare_rgb_path, compare_depth_path, rgbd_path) in enumerate(loop):
            # Only even batches are saved, so the generator runs only for those.
            if idx % 2 == 0:
                input_img = rgbd_img.float().to(DEVICE)
                fake_R = gen_R(input_img)

                # Map the tanh output from [-1, 1] back to [0, 1] before saving.
                save_image(fake_R[0]*0.5+0.5, f"/content/drive/MyDrive/Project/data/result_rgbd_original/output/output_{epoch}_{idx}.png")

                rgb_img = cv2.imread((compare_rgb_path[0]), cv2.IMREAD_COLOR)
                depth_img = cv2.imread((compare_depth_path[0]), cv2.IMREAD_GRAYSCALE)
                cv2.imwrite(f"/content/drive/MyDrive/Project/data/result_rgbd_original/synthetic/input_rgb_{epoch}_{idx}.png", rgb_img)
                cv2.imwrite(f"/content/drive/MyDrive/Project/data/result_rgbd_original/depth/input_depth_{epoch}_{idx}.png", depth_img)

            loop.set_postfix(epoch = epoch)
 

In [None]:
def main():
    """Load the trained CycleGAN and write generated images for the test set.

    Rebuilds the four networks and their optimizers, restores all four
    checkpoints, and runs ``test_fn`` once (epoch 0) over an unshuffled loader.

    Evaluation uses its own transforms without the random horizontal flip of
    the training pipelines; with the flip, saved outputs would be randomly
    mirrored relative to the comparison inputs that ``test_fn`` saves.

    NOTE(review): this definition replaces the training ``main`` defined
    earlier in the notebook.
    """
    disc_S = Discriminator(in_channels=4).to(DEVICE)
    disc_R = Discriminator(in_channels=3).to(DEVICE)
    gen_S = Generator(in_channels=3, out_channels=4, num_residuals=9).to(DEVICE)
    gen_R = Generator(in_channels=4, out_channels=3, num_residuals=9).to(DEVICE)

    # load_checkpoint requires an optimizer, so they are built even though no step is taken.
    opt_disc = optim.Adam(
        list(disc_S.parameters())+list(disc_R.parameters()),
        lr = LEARNING_RATE,
        betas = (0.5, 0.999)
    )
    opt_gen = optim.Adam(
        list(gen_S.parameters())+list(gen_R.parameters()),
        lr = LEARNING_RATE,
        betas = (0.5, 0.999)
    )

    # Testing always starts from the saved weights.
    load_checkpoint(CHECKPOINT_GEN_S, gen_S, opt_gen, LEARNING_RATE)
    load_checkpoint(CHECKPOINT_GEN_R, gen_R, opt_gen, LEARNING_RATE)
    load_checkpoint(CHECKPOINT_CRITIC_S, disc_S, opt_disc, LEARNING_RATE)
    load_checkpoint(CHECKPOINT_CRITIC_R, disc_R, opt_disc, LEARNING_RATE)

    # Deterministic counterparts of TRANSFORMS / TRANSFORMS_rgbd (same resize
    # and normalisation, no flip).
    eval_transform = A.Compose(
        [
            A.Resize(width=256, height=256),
            A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255),
            ToTensorV2(),
        ],
    )
    eval_transform_rgbd = A.Compose(
        [
            A.Resize(width=256, height=256),
            A.Normalize(mean=[0.5, 0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5, 0.5], max_pixel_value=255),
            ToTensorV2(),
        ],
    )

    test_set = gen_dataset_test(root_rgbd = "/content/drive/MyDrive/Project/data/rgbd_order",
                      root_real = "/content/drive/MyDrive/Project/data/real_resized",
                      root_compare_rgb = "/content/drive/MyDrive/Project/data/synthetic",
                      root_compare_depth = "/content/drive/MyDrive/Project/data/depth map",
                      transform=eval_transform,
                      transform_rgbd=eval_transform_rgbd)

    loader = DataLoader(test_set, batch_size = BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory = True)

    # One pass over the test set; the epoch number only labels the output files.
    test_fn(disc_S, disc_R, gen_S, gen_R, loader, opt_disc, opt_gen, 0)


In [None]:
# Entry point for the evaluation run: calls the most recently defined main().
if __name__ == "__main__":
    main()

=> Loading checkpoint
=> Loading checkpoint
=> Loading checkpoint
=> Loading checkpoint


100%|██████████| 5460/5460 [36:44<00:00,  2.48it/s, epoch=0]


## Metrics

In [None]:
pip install pytorch-fid==0.1.1

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pytorch-fid==0.1.1
  Downloading pytorch-fid-0.1.1.tar.gz (9.3 kB)
Building wheels for collected packages: pytorch-fid
  Building wheel for pytorch-fid (setup.py) ... [?25l[?25hdone
  Created wheel for pytorch-fid: filename=pytorch_fid-0.1.1-py3-none-any.whl size=10186 sha256=ee9623418f729f75874e96439679da404f2d9633175a1aa2989a58a3718f3ff5
  Stored in directory: /root/.cache/pip/wheels/02/c5/ed/58ac12fce449ae1c1501c2e676988975e1afc852f6967ceb6b
Successfully built pytorch-fid
Installing collected packages: pytorch-fid
Successfully installed pytorch-fid-0.1.1


In [None]:
import pytorch_fid.fid_score

In [None]:
# FID between the real images and the generated outputs (batch size 1, 2048-d features).
# The third argument was the misspelled string 'cude', which only worked
# because any non-empty string is truthy; pass an explicit availability check.
# NOTE(review): assumes the pinned pytorch-fid 0.1.1 signature
# (paths, batch_size, cuda, dims) with a boolean `cuda` flag - confirm.
pytorch_fid.fid_score.calculate_fid_given_paths(
    ['/content/drive/MyDrive/Project/data/real_resized_specfic', '/content/drive/MyDrive/Project/data/result_rgbd_original/output'],
    1,
    torch.cuda.is_available(),
    2048,
)

  f"The parameter '{pretrained_param}' is deprecated since 0.13 and will be removed in 0.15, "
Downloading: "https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth" to /root/.cache/torch/hub/checkpoints/pt_inception-2015-12-05-6726825d.pth


  0%|          | 0.00/91.2M [00:00<?, ?B/s]

100%|██████████| 1050/1050 [00:29<00:00, 35.75it/s]
100%|██████████| 546/546 [00:12<00:00, 43.99it/s]


144.10242892130026