## packages and parameters

In [None]:
!pip install albumentations==0.4.6

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import os
from torch.utils.data import Dataset, DataLoader
import numpy as np
import cv2
import shutil
from PIL import Image
import sys
from torchvision.datasets import ImageFolder
import torch.optim as optim
from tqdm import tqdm
from torchvision.utils import save_image
import albumentations as A
from albumentations.pytorch import ToTensorV2
import matplotlib.pyplot as plt

In [None]:
# --- Runtime / training configuration -------------------------------------
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(DEVICE)
BATCH_SIZE = 1
# Base learning rate; the training main() multiplies it per optimizer.
LEARNING_RATE = (1e-6)
# Weights applied to the identity and cycle-consistency terms of the generator loss.
LAMBDA_IDENTITY = 0.0
LAMBDA_CYCLE = 10
# Number of DataLoader worker processes.
NUM_WORKERS = 2
NUM_EPOCHS = 20
# LOAD_MODEL: restore the checkpoints below before training.
# SAVE_MODEL: write the checkpoints below after every epoch.
LOAD_MODEL = False
SAVE_MODEL = True
# Checkpoint files (on the mounted Google Drive) for the two generators and two discriminators.
CHECKPOINT_GEN_S = "/content/drive/MyDrive/Project/data/checkpoints/gens_pix_rgb_PATCH70_noskip.pth.tar"
CHECKPOINT_GEN_R = "/content/drive/MyDrive/Project/data/checkpoints/genr_pix_rgb_PATCH70_noskip.pth.tar"
CHECKPOINT_CRITIC_S = "/content/drive/MyDrive/Project/data/checkpoints/critics_pix_rgb_PATCH70_noskip.pth.tar"
CHECKPOINT_CRITIC_R = "/content/drive/MyDrive/Project/data/checkpoints/criticr_pix_rgb_PATCH70_noskip.pth.tar"

# Four-channel variant of TRANSFORMS: same resize / random flip, but the
# normalization statistics have four entries (one per channel).
TRANSFORMS_rgbd = A.Compose(
    [
        A.Resize(width=256, height=256),
        A.HorizontalFlip(p=0.5),
        A.Normalize(mean=[0.5, 0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5, 0.5], max_pixel_value=255),
        ToTensorV2(),
     ],
)

# Three-channel pipeline: resize to 256x256, random horizontal flip, scale
# pixel values from [0, 255] to [-1, 1], then convert HWC array -> CHW tensor.
# Saved images undo the normalization with `x * 0.5 + 0.5`.
TRANSFORMS = A.Compose(
    [
        A.Resize(width=256, height=256),
        A.HorizontalFlip(p=0.5),
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255),
        ToTensorV2(),
     ],
)

In [None]:
def save_checkpoint(model, optimizer, PATH="my_checkpoint.pth.tar"):
    """Serialize a model and its optimizer into a single checkpoint file.

    Args:
        model: module whose ``state_dict()`` is stored under ``"state_dict"``.
        optimizer: optimizer whose ``state_dict()`` is stored under ``"optimizer"``.
        PATH: destination handed to ``torch.save``.
    """
    print("=> Saving checkpoint")
    torch.save(
        {"state_dict": model.state_dict(), "optimizer": optimizer.state_dict()},
        PATH,
    )


def load_checkpoint(PATH, model, optimizer, lr):
    """Restore model and optimizer state from a checkpoint written by ``save_checkpoint``.

    Args:
        PATH: checkpoint file to read; tensors are mapped onto ``DEVICE``.
        model: module that receives ``checkpoint["state_dict"]``.
        optimizer: optimizer that receives ``checkpoint["optimizer"]``.
        lr: learning rate written into every param group after loading, so the
            value stored in the checkpoint is overridden.
    """
    print("=> Loading checkpoint")
    checkpoint = torch.load(PATH, map_location=DEVICE)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    # `capturable=True` is only valid when parameters and step counters live on
    # a CUDA device; forcing it on a CPU run makes Adam's step() assert. Apply
    # it to every param group (not just the first) and only on CUDA.
    on_cuda = str(DEVICE).startswith("cuda")
    for param_group in optimizer.param_groups:
        if on_cuda:
            param_group["capturable"] = True
        param_group["lr"] = lr

## Discriminator

In [None]:
class CNNBlock(nn.Module):
    """Discriminator building block: 4x4 conv -> InstanceNorm -> LeakyReLU(0.2).

    The convolution uses reflect padding of 1 and the stride given by the caller.
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        convolution = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=4,
            stride=stride,
            padding=1,
            bias=True,
            padding_mode="reflect",
        )
        normalization = nn.InstanceNorm2d(out_channels)
        activation = nn.LeakyReLU(0.2)
        # Keep the attribute name and layer order: checkpoints key on `conv.<idx>`.
        self.conv = nn.Sequential(convolution, normalization, activation)

    def forward(self, x):
        """Apply conv, normalization and activation to ``x``."""
        return self.conv(x)

In [None]:
class Discriminator(nn.Module):
    """Convolutional discriminator that returns a one-channel map of sigmoid scores.

    Args:
        in_channels: number of channels of the input image.
        features: output widths of the successive stages. The first width is
            used by the un-normalized initial conv; every later width becomes a
            ``CNNBlock``. All blocks use stride 2 except the last, which uses
            stride 1.
    """

    def __init__(self, in_channels = 3, features = (64, 128, 256, 512)):
        super().__init__()
        self.initial = nn.Sequential(
            nn.Conv2d(in_channels, features[0], kernel_size=4, stride=2, padding=1, padding_mode="reflect"),
            nn.LeakyReLU(0.2),
        )

        layers = []
        in_channels = features[0]
        last_position = len(features) - 1
        for position, feature in enumerate(features[1:], start=1):
            # Pick the stride-1 stage by position, not by value: comparing
            # `feature == features[-1]` also matched earlier stages whenever a
            # width was repeated (e.g. [64, 128, 512, 512]).
            stride = 1 if position == last_position else 2
            layers.append(CNNBlock(in_channels, feature, stride = stride))
            in_channels = feature
        # Single output channel: one real/fake score per spatial location.
        layers.append(nn.Conv2d(in_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect"))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return sigmoid-activated scores for ``x``."""
        x = self.initial(x)
        return torch.sigmoid(self.model(x))

In [None]:
def test():
    """Smoke test: push a random 5x3x256x256 batch through the discriminator and print the output shape."""
    batch = torch.randn((5, 3, 256, 256))
    disc = Discriminator(in_channels=3)
    print(disc(batch).shape)

test()

## U-net Generator

In [None]:
class Block(nn.Module):
    """Generator stage: stride-2 (transposed) conv -> BatchNorm -> activation, with optional dropout.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the stage.
        down: ``True`` uses a reflect-padded ``Conv2d`` (halves the spatial
            size); ``False`` uses a ``ConvTranspose2d`` (doubles it).
        act: ``"relu"`` selects an in-place ReLU; any other value selects
            ``LeakyReLU(0.2)``.
        use_dropout: when true, ``Dropout(0.5)`` is applied to the output.
    """

    def __init__(self, in_channels, out_channels, down=True, act="relu", use_dropout=False):
        super().__init__()
        if down:
            resample = nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False, padding_mode="reflect")
        else:
            resample = nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False)
        norm = nn.BatchNorm2d(out_channels)
        if act == "relu":
            activation = nn.ReLU(inplace = True)
        else:
            activation = nn.LeakyReLU(0.2)
        # Attribute name and layer order are part of the checkpoint key layout.
        self.conv = nn.Sequential(resample, norm, activation)

        self.use_dropout = use_dropout
        self.dropout = nn.Dropout(0.5)
        self.down = down

    def forward(self, x):
        """Run the stage; dropout is applied only when ``use_dropout`` is set."""
        out = self.conv(x)
        if self.use_dropout:
            return self.dropout(out)
        return out

In [None]:
class Generator(nn.Module):
    """U-Net generator with skip connections.

    The encoder is an un-normalized initial conv, six ``Block`` down stages and
    a bottleneck conv, each halving the spatial size. The decoder is seven
    ``Block`` up stages and a final transposed conv with ``Tanh``; every
    decoder stage after the first receives the previous decoder output
    concatenated (along channels) with the matching encoder output.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the generated image.
        features: base channel width; deeper stages use multiples of it.
    """

    def __init__(self, in_channels=3, out_channels=3, features=64):
        super().__init__()
        f = features

        # ---- encoder ----
        self.initial_down = nn.Sequential(
            nn.Conv2d(in_channels, f, 4, 2, 1, padding_mode="reflect"),
            nn.LeakyReLU(0.2),
        )
        self.down1 = Block(f, f * 2, down=True, act="leaky", use_dropout=False)
        self.down2 = Block(f * 2, f * 4, down=True, act="leaky", use_dropout=False)
        self.down3 = Block(f * 4, f * 8, down=True, act="leaky", use_dropout=False)
        self.down4 = Block(f * 8, f * 8, down=True, act="leaky", use_dropout=False)
        self.down5 = Block(f * 8, f * 8, down=True, act="leaky", use_dropout=False)
        self.down6 = Block(f * 8, f * 8, down=True, act="leaky", use_dropout=False)
        self.bottleneck = nn.Sequential(
            nn.Conv2d(f * 8, f * 8, 4, 2, 1),
            nn.ReLU(inplace = True),
        )

        # ---- decoder (input widths are doubled by the skip concatenation) ----
        self.up1 = Block(f * 8, f * 8, down=False, act="relu", use_dropout=False)
        self.up2 = Block(f * 8 * 2, f * 8, down=False, act="relu", use_dropout=False)
        self.up3 = Block(f * 8 * 2, f * 8, down=False, act="relu", use_dropout=False)
        self.up4 = Block(f * 8 * 2, f * 8, down=False, act="relu", use_dropout=False)
        self.up5 = Block(f * 8 * 2, f * 4, down=False, act="relu", use_dropout=False)
        self.up6 = Block(f * 4 * 2, f * 2, down=False, act="relu", use_dropout=False)
        self.up7 = Block(f * 2 * 2, f, down=False, act="relu", use_dropout=False)
        self.final_up = nn.Sequential(
            nn.ConvTranspose2d(f * 2, out_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Encode ``x``, then decode while concatenating the mirrored encoder features."""
        e1 = self.initial_down(x)
        e2 = self.down1(e1)
        e3 = self.down2(e2)
        e4 = self.down3(e3)
        e5 = self.down4(e4)
        e6 = self.down5(e5)
        e7 = self.down6(e6)

        out = self.up1(self.bottleneck(e7))
        # Pair each remaining decoder stage with its encoder skip, deepest first.
        skip_stages = (
            (self.up2, e7),
            (self.up3, e6),
            (self.up4, e5),
            (self.up5, e4),
            (self.up6, e3),
            (self.up7, e2),
        )
        for stage, skip in skip_stages:
            out = stage(torch.cat([out, skip], 1))
        return self.final_up(torch.cat([out, e1], 1))


In [None]:
class Generator_noskip(nn.Module):
    """Encoder-decoder generator with the same stage layout as ``Generator`` but no skip connections.

    Because nothing is concatenated, each decoder stage takes only the previous
    stage's channels. The first three decoder stages apply dropout.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the generated image.
        features: base channel width; deeper stages use multiples of it.
    """

    def __init__(self, in_channels=3, out_channels=3, features=64):
        super().__init__()
        f = features

        # ---- encoder ----
        self.initial_down = nn.Sequential(
            nn.Conv2d(in_channels, f, 4, 2, 1, padding_mode="reflect"),
            nn.LeakyReLU(0.2),
        )
        self.down1 = Block(f, f * 2, down=True, act="leaky", use_dropout=False)
        self.down2 = Block(f * 2, f * 4, down=True, act="leaky", use_dropout=False)
        self.down3 = Block(f * 4, f * 8, down=True, act="leaky", use_dropout=False)
        self.down4 = Block(f * 8, f * 8, down=True, act="leaky", use_dropout=False)
        self.down5 = Block(f * 8, f * 8, down=True, act="leaky", use_dropout=False)
        self.down6 = Block(f * 8, f * 8, down=True, act="leaky", use_dropout=False)
        self.bottleneck = nn.Sequential(
            nn.Conv2d(f * 8, f * 8, 4, 2, 1),
            nn.ReLU(),
        )

        # ---- decoder ----
        self.up1 = Block(f * 8, f * 8, down=False, act="relu", use_dropout=True)
        self.up2 = Block(f * 8, f * 8, down=False, act="relu", use_dropout=True)
        self.up3 = Block(f * 8, f * 8, down=False, act="relu", use_dropout=True)
        self.up4 = Block(f * 8, f * 8, down=False, act="relu", use_dropout=False)
        self.up5 = Block(f * 8, f * 4, down=False, act="relu", use_dropout=False)
        self.up6 = Block(f * 4, f * 2, down=False, act="relu", use_dropout=False)
        self.up7 = Block(f * 2, f, down=False, act="relu", use_dropout=False)
        self.final_up = nn.Sequential(
            nn.ConvTranspose2d(f, out_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Run ``x`` straight through every encoder stage and then every decoder stage."""
        pipeline = (
            self.initial_down,
            self.down1,
            self.down2,
            self.down3,
            self.down4,
            self.down5,
            self.down6,
            self.bottleneck,
            self.up1,
            self.up2,
            self.up3,
            self.up4,
            self.up5,
            self.up6,
            self.up7,
        )
        out = x
        for stage in pipeline:
            out = stage(out)
        return self.final_up(out)

In [None]:
def test():
    """Smoke test: run a random 2x4x256x256 batch through the U-Net generator and print the output shape."""
    in_channels, out_channels, img_size = 4, 3, 256
    batch = torch.randn((2, in_channels, img_size, img_size))
    gen = Generator(in_channels, out_channels)
    print(gen(batch).shape)

test()

## Dataset

In [None]:
# Mount Google Drive (Colab only) so the dataset, checkpoint and output paths
# under /content/drive used throughout this notebook resolve.
from google.colab import drive
drive.mount('/content/drive')

### Load images

In [None]:
class gen_dataset(Dataset):
    """Unpaired image dataset yielding ``(rgbd_img, real_img)`` tuples.

    Files are listed from two directories. The dataset is as long as the larger
    directory, and the shorter one wraps around via modulo indexing.

    Args:
        root_rgbd: directory of the input-domain images.
        root_real: directory of the real-domain images.
        transform: callable invoked as ``transform(image=array)["image"]``;
            applied to the real image, and to the input image unless
            ``transform_rgbd`` is given.
        transform_rgbd: optional callable with the same calling convention,
            used for the input image instead of ``transform``.
    """

    def __init__(self, root_rgbd, root_real, transform=None, transform_rgbd=None):
        self.root_rgbd = root_rgbd
        self.root_real = root_real
        self.transform = transform
        self.transform_rgbd = transform_rgbd

        # os.listdir returns entries in arbitrary order; sort so that a given
        # index always refers to the same file across runs and machines.
        self.rgbd_images = sorted(os.listdir(root_rgbd))
        self.real_images = sorted(os.listdir(root_real))

        self.rgbd_length = len(self.rgbd_images)
        self.real_length = len(self.real_images)
        self.length_dataset = max(self.rgbd_length, self.real_length)

    def __len__(self):
        return self.length_dataset

    @staticmethod
    def _load_rgb(path):
        """Read ``path`` as an RGB uint8 array, closing the file handle afterwards."""
        with Image.open(path) as img:
            return np.array(img.convert("RGB"))

    def __getitem__(self, index):
        rgbd_name = self.rgbd_images[index % self.rgbd_length]
        real_name = self.real_images[index % self.real_length]

        rgbd_img = self._load_rgb(os.path.join(self.root_rgbd, rgbd_name))
        real_img = self._load_rgb(os.path.join(self.root_real, real_name))

        if self.transform:
            real_img = self.transform(image=real_img)["image"]

        # Apply exactly one pipeline to the input image. Previously
        # `transform_rgbd` ran on top of the output of `transform` when both
        # were supplied, i.e. on an already-converted tensor.
        if self.transform_rgbd:
            rgbd_img = self.transform_rgbd(image=rgbd_img)["image"]
        elif self.transform:
            rgbd_img = self.transform(image=rgbd_img)["image"]

        return rgbd_img, real_img

## Train

In [None]:
def train_fn(disc_S, disc_R, gen_S, gen_R, loader, opt_disc, opt_gen, l1, G_loss_func, D_loss_func, d_scaler, g_scaler, epoch):
    """Run one CycleGAN training epoch and return the mean losses.

    ``gen_R`` translates input-domain images to the real domain and ``gen_S``
    translates real images back; ``disc_R`` / ``disc_S`` judge the respective
    domains. Each iteration first updates both discriminators, then both
    generators, under mixed precision.

    Args:
        disc_S, disc_R: discriminators for the input and real domains.
        gen_S, gen_R: generators producing input-domain and real-domain images.
        loader: iterable yielding ``(rgbd_img, real_img)`` batches.
        opt_disc, opt_gen: optimizers for the discriminators / generators.
        l1: loss used for the cycle (and identity) terms.
        G_loss_func, D_loss_func: adversarial losses for generators / discriminators.
        d_scaler, g_scaler: gradient scalers paired with the two optimizers.
        epoch: epoch number, used for the progress bar and sample file names.

    Returns:
        ``(D_loss_avg, G_loss_avg)`` as Python floats; ``(0.0, 0.0)`` when the
        loader yields no batches.
    """
    # Accumulate plain floats. Summing the loss tensors themselves chains every
    # iteration's autograd graph together and keeps it alive for the whole epoch.
    D_loss_all = 0.0
    G_loss_all = 0.0
    # Pre-initialized so an empty loader returns zeros instead of raising.
    D_loss_avg = 0.0
    G_loss_avg = 0.0
    loop = tqdm(loader, leave=True)

    for idx, (rgbd_img, real_img) in enumerate(loop):
        input_img = rgbd_img.float().to(DEVICE)
        real_img = real_img.float().to(DEVICE)

        # ---- Train discriminators ----
        with torch.cuda.amp.autocast():
            # Disc R: real images vs. generated "real" images. The fake is
            # detached so this step does not back-propagate into the generator.
            fake_R = gen_R(input_img)
            D_R_real = disc_R(real_img)
            D_R_fake = disc_R(fake_R.detach())
            D_R_real_loss = D_loss_func(D_R_real, torch.ones_like(D_R_real))
            D_R_fake_loss = D_loss_func(D_R_fake, torch.zeros_like(D_R_fake))
            D_R_loss = D_R_real_loss + D_R_fake_loss

            # Disc S: input-domain images vs. generated input-domain images.
            fake_S = gen_S(real_img)
            D_S_real = disc_S(input_img)
            D_S_fake = disc_S(fake_S.detach())
            D_S_real_loss = D_loss_func(D_S_real, torch.ones_like(D_S_real))
            D_S_fake_loss = D_loss_func(D_S_fake, torch.zeros_like(D_S_fake))
            D_S_loss = D_S_real_loss + D_S_fake_loss

            D_loss = (D_R_loss + D_S_loss)/2

        D_loss_all += D_loss.item()

        opt_disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # ---- Train generators ----
        with torch.cuda.amp.autocast():
            # Adversarial loss: generators want the discriminators to answer "real".
            D_R_fake = disc_R(fake_R)
            D_S_fake = disc_S(fake_S)
            loss_G_R = G_loss_func(D_R_fake, torch.ones_like(D_R_fake))
            loss_G_S = G_loss_func(D_S_fake, torch.ones_like(D_S_fake))

            # Cycle loss: translating there and back should reproduce the input.
            cycle_S = gen_S(fake_R)
            cycle_R = gen_R(fake_S)
            cycle_S_loss = l1(input_img, cycle_S)
            cycle_R_loss = l1(real_img, cycle_R)

            # Identity loss: only computed when it is actually weighted, so the
            # default LAMBDA_IDENTITY = 0 costs no extra forward passes while a
            # non-zero setting is no longer silently ignored.
            if LAMBDA_IDENTITY:
                identity_S_loss = l1(input_img, gen_S(input_img))
                identity_R_loss = l1(real_img, gen_R(real_img))
            else:
                identity_S_loss = 0
                identity_R_loss = 0

            G_loss = (
                loss_G_S
                + loss_G_R
                + cycle_S_loss * LAMBDA_CYCLE
                + cycle_R_loss * LAMBDA_CYCLE
                + identity_R_loss * LAMBDA_IDENTITY
                + identity_S_loss * LAMBDA_IDENTITY
            )

        G_loss_all += G_loss.item()

        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        # Periodically dump a generated sample, mapped from [-1, 1] back to [0, 1].
        if idx % 1000 == 0:
            save_image(fake_R[0]*0.5+0.5, f"/content/drive/MyDrive/Project/data/saved_images_rgb_pix/real_{epoch}_{idx}.png")

        D_loss_avg = D_loss_all/(idx+1)
        G_loss_avg = G_loss_all/(idx+1)

        loop.set_postfix(D_loss = D_loss_avg, G_loss = G_loss_avg, epoch = epoch)

    return D_loss_avg, G_loss_avg

In [None]:
def main():
    """Build the CycleGAN models, optionally resume from checkpoints, and train.

    Trains for ``NUM_EPOCHS`` epochs, saving all four checkpoints after every
    epoch when ``SAVE_MODEL`` is set, and prints the per-epoch mean
    discriminator and generator losses at the end.
    """
    disc_S = Discriminator(in_channels=3).to(DEVICE)
    disc_R = Discriminator(in_channels=3).to(DEVICE)
    gen_S = Generator(in_channels=3, out_channels=3).to(DEVICE)
    gen_R = Generator(in_channels=3, out_channels=3).to(DEVICE)

    # Per-optimizer learning rates, named once so that resuming from a
    # checkpoint continues with the same rates as a fresh run (the loader
    # overwrites the optimizer's lr with the value it is given).
    disc_lr = LEARNING_RATE*2
    gen_lr = LEARNING_RATE*5

    opt_disc = optim.Adam(
        list(disc_S.parameters())+list(disc_R.parameters()),
        lr = disc_lr,
        betas = (0.5, 0.999)
    )
    opt_gen = optim.Adam(
        list(gen_S.parameters())+list(gen_R.parameters()),
        lr = gen_lr,
        betas = (0.5, 0.999)
    )

    # Losses: L1 for the cycle term, MSE for both adversarial terms.
    L1 = nn.L1Loss()
    mse = nn.MSELoss()

    if LOAD_MODEL:
        load_checkpoint(CHECKPOINT_GEN_S, gen_S, opt_gen, gen_lr)
        load_checkpoint(CHECKPOINT_GEN_R, gen_R, opt_gen, gen_lr)
        load_checkpoint(CHECKPOINT_CRITIC_S, disc_S, opt_disc, disc_lr)
        load_checkpoint(CHECKPOINT_CRITIC_R, disc_R, opt_disc, disc_lr)

    # Lower-case local name: `Dataset` is the class imported from torch.utils.data.
    dataset = gen_dataset(root_rgbd = "/content/drive/MyDrive/Project/data/synthetic",
                      root_real = "/content/drive/MyDrive/Project/data/real_resized_specfic",
                      transform=TRANSFORMS,
                      transform_rgbd=None)

    loader = DataLoader(dataset, batch_size = BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory = True)

    g_scaler = torch.cuda.amp.GradScaler()
    d_scaler = torch.cuda.amp.GradScaler()

    D_loss_list = []
    G_loss_list = []
    for epoch in range(NUM_EPOCHS):
        D_loss_avg, G_loss_avg = train_fn(disc_S, disc_R, gen_S, gen_R, loader, opt_disc, opt_gen, L1, mse, mse, d_scaler, g_scaler, epoch)

        D_loss_list.append(D_loss_avg)
        G_loss_list.append(G_loss_avg)

        if SAVE_MODEL:
            save_checkpoint(gen_S, opt_gen, PATH=CHECKPOINT_GEN_S)
            save_checkpoint(gen_R, opt_gen, PATH=CHECKPOINT_GEN_R)
            save_checkpoint(disc_S, opt_disc, PATH=CHECKPOINT_CRITIC_S)
            save_checkpoint(disc_R, opt_disc, PATH=CHECKPOINT_CRITIC_R)

    print(D_loss_list)
    print(G_loss_list)

In [None]:
# Start training: calls the training `main()` defined in the cell directly above.
if __name__ == "__main__":
    main()

## Clean memory


In [None]:
# Release cached, currently unused GPU memory held by PyTorch's allocator.
torch.cuda.empty_cache()

In [None]:
# Peak GPU memory (in bytes) reserved by PyTorch's caching allocator so far.
torch.cuda.max_memory_reserved()

In [None]:
# Show host (CPU) memory statistics of the runtime.
!cat /proc/meminfo

In [None]:
!pip install GPUtil

from GPUtil import showUtilization as gpu_usage
from numba import cuda

def free_gpu_cache():
    """Print GPU utilization, release GPU memory, and print utilization again.

    Memory is released in two steps: PyTorch's cached blocks are emptied, then
    the CUDA context of device 0 is closed and re-selected through numba.
    """
    print("Initial GPU Usage")
    gpu_usage()                             

    # Step 1: hand PyTorch's cached, unused blocks back to the driver.
    torch.cuda.empty_cache()

    # Step 2: close and re-select device 0 via numba.
    # NOTE(review): closing the CUDA context presumably invalidates tensors and
    # models PyTorch already placed on that device in this process - confirm
    # before running this mid-session rather than right before a restart.
    cuda.select_device(0)
    cuda.close()
    cuda.select_device(0)

    print("GPU Usage after emptying the cache")
    gpu_usage()

free_gpu_cache()  

Known problems with CycleGAN: https://zhuanlan.zhihu.com/p/45164258

## Test


In [None]:
class gen_dataset_test(Dataset):
    """Evaluation dataset yielding images plus the paths of their comparison files.

    Each item is ``(rgbd_img, real_img, compare_rgb_path, compare_depth_path,
    rgbd_path)``. The dataset is as long as the larger of the two image
    directories, and shorter lists wrap around via modulo indexing.

    Args:
        root_rgbd: directory of the input-domain images.
        root_real: directory of the real-domain images.
        root_compare_rgb: directory of RGB files reported alongside each input.
        root_compare_depth: directory of depth files reported alongside each input.
        transform: callable invoked as ``transform(image=array)["image"]`` on
            both loaded images.
        transform_rgbd: stored but not used by ``__getitem__``.
    """

    def __init__(self, root_rgbd, root_real, root_compare_rgb, root_compare_depth, transform=None, transform_rgbd=None):
        self.root_rgbd = root_rgbd
        self.root_real = root_real
        self.root_compare_rgb = root_compare_rgb
        self.root_compare_depth = root_compare_depth
        self.transform = transform
        self.transform_rgbd = transform_rgbd

        # Sort every listing: os.listdir order is arbitrary, and the comparison
        # files can only line up with the inputs if all lists share one order.
        self.rgbd_images = sorted(os.listdir(root_rgbd))
        self.real_images = sorted(os.listdir(root_real))

        self.rgbd_length = len(self.rgbd_images)
        self.real_length = len(self.real_images)
        self.length_dataset = max(self.rgbd_length, self.real_length)

        self.compare_rgb_images = sorted(os.listdir(root_compare_rgb))
        self.compare_depth_images = sorted(os.listdir(root_compare_depth))

    def __len__(self):
        return self.length_dataset

    @staticmethod
    def _load_rgb(path):
        """Read ``path`` as an RGB uint8 array, closing the file handle afterwards."""
        with Image.open(path) as img:
            return np.array(img.convert("RGB"))

    def __getitem__(self, index):
        rgbd_index = index % self.rgbd_length
        rgbd_name = self.rgbd_images[rgbd_index]
        real_name = self.real_images[index % self.real_length]
        # The comparison files accompany the *input* image, so they follow the
        # input index and wrap on their own list lengths. Indexing them with
        # `index % real_length` could run past the end of these lists or pick a
        # file unrelated to the input.
        compare_rgb_name = self.compare_rgb_images[rgbd_index % len(self.compare_rgb_images)]
        compare_depth_name = self.compare_depth_images[rgbd_index % len(self.compare_depth_images)]

        rgbd_path = os.path.join(self.root_rgbd, rgbd_name)
        real_path = os.path.join(self.root_real, real_name)
        compare_rgb_path = os.path.join(self.root_compare_rgb, compare_rgb_name)
        compare_depth_path = os.path.join(self.root_compare_depth, compare_depth_name)

        rgbd_img = self._load_rgb(rgbd_path)
        real_img = self._load_rgb(real_path)

        if self.transform:
            real_img = self.transform(image=real_img)["image"]
            rgbd_img = self.transform(image=rgbd_img)["image"]

        return rgbd_img, real_img, compare_rgb_path, compare_depth_path, rgbd_path

In [None]:
def test_fn(disc_S, disc_R, gen_S, gen_R, loader, opt_disc, opt_gen, epoch):
    """Run ``gen_R`` over the loader and save every 10th output with its source RGB image.

    Only ``gen_R``, ``loader`` and ``epoch`` are used; the other arguments are
    accepted so the call signature mirrors the training function. The
    generator's train/eval mode is left exactly as the caller set it.

    Args:
        disc_S, disc_R, gen_S: unused.
        gen_R: generator applied to each input batch.
        loader: iterable yielding ``(rgbd_img, real_img, compare_rgb_path,
            compare_depth_path, rgbd_path)`` batches.
        opt_disc, opt_gen: unused.
        epoch: number embedded in the output file names and progress bar.
    """
    loop = tqdm(loader, leave=True)

    # Pure inference: skip building autograd graphs to save memory and time.
    with torch.no_grad():
        for idx, (rgbd_img, real_img, compare_rgb_path, compare_depth_path, rgbd_path) in enumerate(loop):
            input_img = rgbd_img.float().to(DEVICE)

            fake_R = gen_R(input_img)

            if idx % 10 == 0:
                # Generated image, mapped from [-1, 1] back to [0, 1].
                save_image(fake_R[0]*0.5+0.5, f"/content/drive/MyDrive/Project/data/result_rgb_pix/output/output_{epoch}_{idx}.png")

                # Copy of the comparison RGB file for the first item of the batch.
                rgb_img = cv2.imread((compare_rgb_path[0]), cv2.IMREAD_COLOR)
                cv2.imwrite(f"/content/drive/MyDrive/Project/data/result_rgb_pix/synthetic/input_rgb_{epoch}_{idx}.png", rgb_img)

            loop.set_postfix(epoch = epoch)
 

In [None]:
def main():
    """Restore the trained models from their checkpoints and run one inference pass.

    The optimizers exist only because ``load_checkpoint`` restores optimizer
    state together with the model weights.
    """
    disc_S = Discriminator(in_channels=3).to(DEVICE)
    disc_R = Discriminator(in_channels=3).to(DEVICE)
    gen_S = Generator(in_channels=3, out_channels=3).to(DEVICE)
    gen_R = Generator(in_channels=3, out_channels=3).to(DEVICE)

    opt_disc = optim.Adam(
        list(disc_S.parameters())+list(disc_R.parameters()),
        lr = LEARNING_RATE,
        betas = (0.5, 0.999)
    )
    opt_gen = optim.Adam(
        list(gen_S.parameters())+list(gen_R.parameters()),
        lr = LEARNING_RATE,
        betas = (0.5, 0.999)
    )

    # Evaluation always starts from the saved checkpoints.
    load_checkpoint(CHECKPOINT_GEN_S, gen_S, opt_gen, LEARNING_RATE)
    load_checkpoint(CHECKPOINT_GEN_R, gen_R, opt_gen, LEARNING_RATE)
    load_checkpoint(CHECKPOINT_CRITIC_S, disc_S, opt_disc, LEARNING_RATE)
    load_checkpoint(CHECKPOINT_CRITIC_R, disc_R, opt_disc, LEARNING_RATE)

    # Lower-case local name: `Dataset` is the class imported from torch.utils.data.
    dataset = gen_dataset_test(root_rgbd = "/content/drive/MyDrive/Project/data/synthetic",
                      root_real = "/content/drive/MyDrive/Project/data/real_resized",
                      root_compare_rgb = "/content/drive/MyDrive/Project/data/synthetic",
                      root_compare_depth = "/content/drive/MyDrive/Project/data/depth map",
                      transform=TRANSFORMS,
                      transform_rgbd=TRANSFORMS_rgbd)

    loader = DataLoader(dataset, batch_size = BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory = True)

    # A single pass over the data; kept as a loop so `epoch` labels the outputs.
    num_test_epochs = 1
    for epoch in range(num_test_epochs):
        test_fn(disc_S, disc_R, gen_S, gen_R, loader, opt_disc, opt_gen, epoch)


In [None]:
# Run inference: calls the evaluation `main()` defined in the cell directly above.
if __name__ == "__main__":
    main()

## Metric

In [None]:
pip install pytorch-fid==0.1.1

In [None]:
import pytorch_fid.fid_score

In [None]:
# FID between the real images and the outputs in result_rgb_pix (batch size 1, dims 2048).
# The third positional argument selects the GPU and must be the valid spelling 'cuda'.
# NOTE(review): the pinned pytorch-fid 0.1.1 appears to treat it as a truthy flag,
# while newer releases take a device string - 'cuda' works for both; confirm.
pytorch_fid.fid_score.calculate_fid_given_paths(['/content/drive/MyDrive/Project/data/real_resized_specfic', '/content/drive/MyDrive/Project/data/result_rgb_pix/output'], 1, 'cuda', 2048)

In [None]:
# FID between the real images and the outputs in result_rgb_original (batch size 1, dims 2048).
# The third positional argument selects the GPU and must be the valid spelling 'cuda'.
# NOTE(review): the pinned pytorch-fid 0.1.1 appears to treat it as a truthy flag,
# while newer releases take a device string - 'cuda' works for both; confirm.
pytorch_fid.fid_score.calculate_fid_given_paths(['/content/drive/MyDrive/Project/data/real_resized_specfic', '/content/drive/MyDrive/Project/data/result_rgb_original/output'], 1, 'cuda', 2048)