In [1]:
import os
from glob import glob

import pandas
import matplotlib.pyplot as plt
import seaborn
from tqdm import tqdm
import numpy as np

import torch
from torch import nn
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
import torchvision.transforms.functional as F
from torchvision.utils import make_grid

import albumentations as A

In [None]:
class CustomSynthiaDS(Dataset):
    def __init__(self, root = "/media/mountHDD2/synthia-sf", mode = "Left"):
        """Index the image, mask and depth files found under ``root``.

        Args:
            root: Dataset root directory. Every immediate sub-directory is
                searched for the ``RGB{mode}``, ``GT{mode}Debug`` and
                ``DepthDebug{mode}`` folders. The default is a machine-specific
                absolute path; pass your own location when running elsewhere.
            mode: Name fragment substituted into the folder names above
                (default ``"Left"``).
        """
        self.mode = mode

        # Each listing is sorted independently; index i is only the same frame
        # in all three lists if the folders hold identically named files
        # (assumption — compare num_img / num_msk / num_dpt to verify).
        self.imgs = sorted(glob(os.path.join(root, "*", f"RGB{self.mode}", "*")))
        self.masks = sorted(glob(os.path.join(root, "*", f"GT{self.mode}Debug", "*")))
        self.depths = sorted(glob(os.path.join(root, "*", f"DepthDebug{self.mode}", "*")))

        # Augmentation pipeline: horizontal flip applied with probability 0.2.
        self.transform = A.Compose(
            [
                # A.Resize(256, 256),
                A.HorizontalFlip(p=0.2),
            ]
        )

        # Normalisation pipeline configured with mean 0.5 and std 0.5 per channel.
        self.outer_transform = A.Compose(
            [
                A.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
            ]
        )

    def __len__(self):
        return len(self.imgs)

    @property
    def num_img(self):
        return len(self.imgs)

    @property
    def num_msk(self):
        return len(self.masks)

    @property
    def num_dpt(self):
        return len(self.depths)

    @staticmethod
    def process_mask(x):
        uniques = torch.unique(x, sorted = True)
        for i, v in enumerate(uniques):
            x[x == v] = i
        
        x = x.to(dtype=torch.long)
        onehot = F.one_hot(x.squeeze(1), 3).permute(0, 3, 1, 2)[0].float()
        return onehot

    def __getitem__(self, idx):
        img_path = self.imgs[idx]
        msk_path = self.masks[idx]
        dpt_path = self.depths[idx]

        