In [None]:
import torch
import torch.nn as nn
from torchvision import transforms
import os
from torch.utils.data import Dataset, DataLoader
import numpy as np
import cv2
import shutil
from PIL import Image
import sys
from torchvision.datasets import ImageFolder
import torch.optim as optim
from tqdm import tqdm
from torchvision.utils import save_image
#import albumentations as A
#from albumentations.pytorch import ToTensorV2
import matplotlib.pyplot as plt

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(DEVICE)
# Hyperparameters and run switches. NOTE(review): the code that reads them is
# not in this part of the notebook; the meanings noted below are inferred from
# the names and should be confirmed.
BATCH_SIZE = 1
LEARNING_RATE = 1e-5
# Presumably the loss weights for the identity and cycle-consistency terms — TODO confirm.
LAMBDA_IDENTITY = 0.0
LAMBDA_CYCLE = 10
NUM_WORKERS = 2
NUM_EPOCHS = 30
# Presumably toggles for restoring / saving the checkpoints named below — TODO confirm.
LOAD_MODEL = False
SAVE_MODEL = True
# Checkpoint file names (relative paths).
CHECKPOINT_GEN_S = "gens.pth.tar"
CHECKPOINT_GEN_R = "genr.pth.tar"
CHECKPOINT_CRITIC_S = "critics.pth.tar"
CHECKPOINT_CRITIC_R = "criticr.pth.tar"

In [None]:
# Colab only: mount Google Drive at /content/drive; every data path used in
# this notebook lives under that mount point.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
def getFileNames(rootDir, ext='.png'):
    """Recursively collect the paths of all files under ``rootDir`` with a given extension.

    Args:
        rootDir: Directory to walk; sub-directories are included. A missing
            directory yields an empty list.
        ext: Extension to keep, including the leading dot. The comparison is
            case-insensitive, so the default ``'.png'`` also matches ``'.PNG'``.

    Returns:
        list[str]: One path per matching file, in the order ``os.walk`` yields
        them (the list is not sorted).
    """
    wanted = ext.lower()
    fileNames = []
    # os.walk() yields (directory, sub-directory names, file names) for every level.
    for dirName, _subDirList, fileList in os.walk(rootDir):
        for fname in fileList:
            # os.path.splitext() splits "name.ext" into ("name", ".ext").
            if os.path.splitext(fname)[1].lower() == wanted:
                fileNames.append(os.path.join(dirName, fname))
    return fileNames

In [None]:
# Root folder of the project data on Google Drive, and the T1 sub-folder inside it.
base = '/content/drive/MyDrive/Project/data'
path = base + '/T1'

In [None]:
# Sanity check: count the PNG files found under the rgbd folder.
# Note that this reassigns `path`.
path = '/content/drive/MyDrive/Project/data/rgbd'
img_path = getFileNames(path)
print(len(img_path))
# Show a short preview only; printing every path floods the cell output.
print(img_path[:5])

## Data processing

In [None]:
def readImgAndMove(imgPath, img_destination=None, depth_destination=None):
    """Copy frames and depth maps from per-sequence folders into two flat folders.

    Every sub-folder of ``imgPath`` is scanned. A ``.png`` file whose stem
    contains ``'Depth'`` is copied to ``depth_destination``; any other ``.png``
    whose stem contains ``'FrameBuffer'`` is copied to ``img_destination``.
    Each copy is named ``<sub-folder>_<original file name>`` so that files with
    the same name in different sub-folders do not overwrite each other.

    Args:
        imgPath: Folder holding one sub-folder per sequence. Entries that are
            not directories are ignored.
        img_destination: Target folder for the 'FrameBuffer' images. Defaults
            to the project's ``syntheic`` folder on Drive.
        depth_destination: Target folder for the 'Depth' images. Defaults to
            the project's ``depth map`` folder on Drive.

    Returns:
        None. The function only writes files.
    """
    base = '/content/drive/MyDrive/Project/data'
    if img_destination is None:
        img_destination = base+'/syntheic'
    if depth_destination is None:
        depth_destination = base+'/depth map'

    # The targets must exist as directories: shutil.copy() onto a missing path
    # would create a single *file* with that name instead of filling a folder.
    os.makedirs(img_destination, exist_ok=True)
    os.makedirs(depth_destination, exist_ok=True)

    for file in sorted(os.listdir(imgPath)):
        file_path = os.path.join(imgPath, file)
        if not os.path.isdir(file_path):
            # Stray files next to the sequence folders are not sequences.
            continue
        for item in sorted(os.listdir(file_path)):
            stem, suffix = os.path.splitext(item)
            if suffix != '.png':
                continue
            # 'Depth' is tested first, matching the original if/elif order.
            if 'Depth' in stem:
                destination = depth_destination
            elif 'FrameBuffer' in stem:
                destination = img_destination
            else:
                continue
            # Copy straight to the prefixed name (no copy-then-rename step).
            shutil.copy(os.path.join(file_path, item),
                        os.path.join(destination, f'{file}_{item}'))




In [None]:
# Copy the frames and depth maps found under the T1 folder into the flat
# destination folders that readImgAndMove uses by default.
readImgAndMove('/content/drive/MyDrive/Project/data/T1')

In [None]:
def gen_rgbd(rgb_path, depth_path, save_path="/content/drive/MyDrive/Project/data/rgbd_order"):
    """Merge every colour image with its depth map into one 4-channel PNG.

    Both folders are listed and sorted, and the i-th colour image is paired
    with the i-th depth image. The first three channels of the colour image are
    copied unchanged (in the order OpenCV loaded them) and the depth image
    becomes the fourth channel.

    Args:
        rgb_path: Folder with the colour images (3 or more channels).
        depth_path: Folder with the depth images; must hold as many files as
            ``rgb_path``.
        save_path: Output folder; created if missing.

    Raises:
        ValueError: If the two folders hold a different number of files, or an
            image cannot be read by OpenCV.

    Returns:
        None. The merged images are written to ``save_path``.
    """
    rgb_path_list = sorted(os.listdir(rgb_path))
    depth_path_list = sorted(os.listdir(depth_path))

    # Pairing is purely positional, so unequal counts mean missing or
    # misaligned pairs. Fail before writing anything.
    if len(rgb_path_list) != len(depth_path_list):
        raise ValueError(
            f"Cannot pair images by position: {len(rgb_path_list)} file(s) in "
            f"{rgb_path!r} but {len(depth_path_list)} in {depth_path!r}"
        )

    os.makedirs(save_path, exist_ok=True)

    for rgb_name, depth_name in zip(rgb_path_list, depth_path_list):
        rgb_path_single = os.path.join(rgb_path, rgb_name)
        depth_path_single = os.path.join(depth_path, depth_name)

        # IMREAD_UNCHANGED keeps every stored channel; only the first three
        # channels of the colour image are used below.
        rgb = cv2.imread(rgb_path_single, cv2.IMREAD_UNCHANGED)
        depth = cv2.imread(depth_path_single, cv2.IMREAD_UNCHANGED)
        # cv2.imread() signals failure by returning None instead of raising.
        if rgb is None:
            raise ValueError(f"Could not read image: {rgb_path_single}")
        if depth is None:
            raise ValueError(f"Could not read image: {depth_path_single}")

        # Size the canvas from the input instead of assuming 256x256.
        height, width = rgb.shape[:2]
        rgbd = np.zeros((height, width, 4), dtype=np.uint8)
        rgbd[:, :, :3] = rgb[:, :, :3]
        # NOTE(review): assumes a single-channel depth image that fits in 8 bits — confirm.
        rgbd[:, :, 3] = depth

        # Output name = first 17 characters of the stem + its last 5 characters.
        stem = os.path.splitext(rgb_name)[0]
        number = stem[-5:]
        category = stem[:17]
        # NOTE(review): the doubled '.png' suffix looks accidental. It is kept
        # because renaming would leave duplicates beside files already generated.
        img_name = category + number +'.png'+'.png'

        # Save the merged image with cv2.imwrite.
        cv2.imwrite(os.path.join(save_path, img_name), rgbd)
        




In [None]:
# Merge each image in the 'syntheic' folder with the depth image at the same
# sorted position in the 'depth map' folder.
rgb_path = "/content/drive/MyDrive/Project/data/syntheic"
depth_path = "/content/drive/MyDrive/Project/data/depth map"
gen_rgbd(rgb_path, depth_path)

In [None]:
def Resize_Image(root_path, w=256, h=256, save_path="/content/drive/MyDrive/Project/data/real_resized_specfic"):
    """Resize every image found in the sub-folders of ``root_path`` and save it as PNG.

    Sub-folders and the files inside them are processed in sorted order. The
    outputs are named ``1.png``, ``2.png``, ... with one counter shared by all
    sub-folders, so images from different sub-folders never overwrite each other.

    Args:
        root_path: Folder holding one sub-folder of images per group. Entries
            that are not directories are ignored.
        w: Output width in pixels.
        h: Output height in pixels.
        save_path: Output folder; created if missing.

    Returns:
        None. The resized images are written to ``save_path``.
    """
    # Create the target folder for the resized images once, up front.
    os.makedirs(save_path, exist_ok=True)

    # One counter for the whole run: restarting it per sub-folder would make
    # every sub-folder write 1.png, 2.png, ... over the previous one.
    number = 1
    for folder in sorted(os.listdir(root_path)):
        folder_path = os.path.join(root_path, folder)
        if not os.path.isdir(folder_path):
            continue
        for filename in sorted(os.listdir(folder_path)):
            source = os.path.join(folder_path, filename)
            img = cv2.imread(source, cv2.IMREAD_COLOR)
            # cv2.imread() returns None for files it cannot decode.
            if img is None:
                print(f"Skipping unreadable image: {source}")
                continue

            new_img = cv2.resize(img, (w, h), interpolation=cv2.INTER_CUBIC)
            # Save the resized image with cv2.imwrite.
            cv2.imwrite(os.path.join(save_path, str(number)+'.png'), new_img)
            number += 1


In [None]:
# Resize the images in every sub-folder of kvasir-dataset-v2 to 256x256.
Resize_Image('/content/drive/MyDrive/Project/data/kvasir-dataset-v2',w=256,h=256)

## Generate data set

In [None]:
class gen_dataset(Dataset):
    """Dataset yielding one RGBD image and one real image per index.

    Each item also carries the file paths of the comparison colour image, the
    comparison depth image and the RGBD image itself. The file lists of all
    four folders are sorted so the index-to-file mapping is reproducible.

    Args:
        root_rgbd: Folder with the RGBD images.
        root_real: Folder with the real images.
        root_compare_rgb: Folder with the comparison colour images.
        root_compare_depth: Folder with the comparison depth images.
        transform: Optional callable applied to the real image, invoked as
            ``transform(image=array)["image"]``.
        transform_rgbd: Optional callable applied to the RGBD image, invoked
            the same way.
    """

    def __init__(self, root_rgbd, root_real, root_compare_rgb, root_compare_depth, transform=None, transform_rgbd=None):
        self.root_rgbd = root_rgbd
        self.root_real = root_real
        self.root_compare_rgb = root_compare_rgb
        self.root_compare_depth = root_compare_depth
        self.transform = transform
        self.transform_rgbd = transform_rgbd

        # os.listdir() returns entries in arbitrary order; sort every list.
        self.rgbd_images = sorted(os.listdir(root_rgbd))
        self.real_images = sorted(os.listdir(root_real))
        self.compare_rgb_images = sorted(os.listdir(root_compare_rgb))
        self.compare_depth_images = sorted(os.listdir(root_compare_depth))

        self.rgbd_length = len(self.rgbd_images)
        self.real_length = len(self.real_images)
        self.compare_rgb_length = len(self.compare_rgb_images)
        self.compare_depth_length = len(self.compare_depth_images)
        # One epoch covers the smaller of the two image sets.
        self.length_dataset = min(self.rgbd_length, self.real_length)

    def __len__(self):
        """Return the number of items: the smaller of the RGBD and real image counts."""
        return self.length_dataset

    def __getitem__(self, index):
        """Load the images for ``index`` and return them with the related paths.

        Returns:
            tuple: ``(rgbd_img, real_img, compare_rgb_path, compare_depth_path,
            rgbd_path)``. The two images are NumPy arrays (or whatever the
            transforms return); the real image is converted to RGB first.
        """
        # Every list wraps around on its own length.
        rgbd_name = self.rgbd_images[index % self.rgbd_length]
        real_name = self.real_images[index % self.real_length]
        # NOTE(review): presumably the sorted comparison folders line up with
        # the sorted RGBD folder entry by entry — confirm.
        compare_rgb_name = self.compare_rgb_images[index % self.compare_rgb_length]
        compare_depth_name = self.compare_depth_images[index % self.compare_depth_length]

        rgbd_path = os.path.join(self.root_rgbd, rgbd_name)
        real_path = os.path.join(self.root_real, real_name)
        compare_rgb_path = os.path.join(self.root_compare_rgb, compare_rgb_name)
        compare_depth_path = os.path.join(self.root_compare_depth, compare_depth_name)

        # Only the paths of the comparison images are returned, so those files
        # are not decoded here.
        with Image.open(rgbd_path) as rgbd_file:
            rgbd_img = np.array(rgbd_file)
        with Image.open(real_path) as real_file:
            real_img = np.array(real_file.convert("RGB"))

        if self.transform:
            real_img = self.transform(image=real_img)["image"]
        if self.transform_rgbd:
            rgbd_img = self.transform_rgbd(image=rgbd_img)["image"]

        return rgbd_img, real_img, compare_rgb_path, compare_depth_path, rgbd_path

In [None]:
# The instance is bound to `dataset`, not `Dataset`: reusing the name `Dataset`
# would shadow the imported torch.utils.data.Dataset base class and make the
# class-definition cell fail when it is re-run.
dataset = gen_dataset(root_rgbd = "/content/drive/MyDrive/Project/data/rgbd_order",
                      root_real = "/content/drive/MyDrive/Project/data/real_resized_specfic",
                      root_compare_rgb = "/content/drive/MyDrive/Project/data/syntheic",
                      root_compare_depth = "/content/drive/MyDrive/Project/data/depth map",
                      transform=None,
                      transform_rgbd=None)

# Batch size and worker count come from the configuration cell (1 and 2).
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)

In [None]:
# Sanity check: print the file paths returned with the first few batches only.
# Iterating the whole loader would load every image just to print its path.
MAX_PREVIEW_BATCHES = 3
for batch_idx, data in enumerate(loader):
    if batch_idx >= MAX_PREVIEW_BATCHES:
        break
    # The third and fourth entries are file paths, not image tensors.
    rgbd_img, real_img, compare_rgb_path, compare_depth_path, rgbd_path = data
    print(compare_rgb_path)
    print(compare_depth_path)
    print(rgbd_path)

## FID

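References:

- PyTorch-Ignite blog notebook, "GAN evaluation using FID and IS": https://colab.research.google.com/github/pytorch-ignite/pytorch-ignite.ai/blob/gh-pages/blog/2021-08-11-GAN-evaluation-using-FID-and-IS.ipynb#scrollTo=b2r7OHYEGxLz
- `pytorch-fid` package, used in the cells below: https://github.com/mseitzer/pytorch-fid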

In [None]:
pip install pytorch-fid==0.1.1

In [None]:
import pytorch_fid.fid_score

In [None]:
# FID between the two folders of resized real images.
# The third argument was the typo 'cude'. NOTE(review): pytorch-fid 0.1.1
# (pinned above) appears to treat it as a truthy "use CUDA" flag, so a boolean
# is passed; releases from 0.2 on expect a device string such as 'cuda' or
# 'cpu' instead — confirm against the installed version.
fid_value = pytorch_fid.fid_score.calculate_fid_given_paths(
    [
        '/content/drive/MyDrive/Project/data/real_resized_specfic',
        '/content/drive/MyDrive/Project/data/real_resized',
    ],
    1,                          # batch size
    torch.cuda.is_available(),  # run on the GPU only when one is available
    2048,                       # dimensionality of the Inception features
)
fid_value