In [1]:
import os
from pathlib import Path

import numpy as np
import torch
import torchvision
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

In [4]:
# Dataset root. The default is a machine-specific local path; set the NDI_DATA_ROOT
# environment variable to point the notebook at another copy without editing this cell.
NDI_DATA_ROOT = os.environ.get('NDI_DATA_ROOT', 'E:/datasets/NDI_images/Integreted')
# Observed (input) images; file names are parsed as '<index>.jpg' by the datasets below.
ORIGINAL_IMAGE = NDI_DATA_ROOT + '/Observed/circle_binaryzation_pics'
# Calculated (target) images, looked up by the same file name as the observed image.
TARGET_IMAGE = NDI_DATA_ROOT + '/Calculated/'

# Seed every RNG the notebook draws from (np.random.permutation, torch random_split and
# DataLoader shuffling) so Restart & Run All reproduces the same splits and batches.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

In [11]:
class ThreeChannelNDIDatasetContrastiveLearning(Dataset):
    """In-memory dataset of (observed image, target image, label) triples.

    Every '*.jpg' in the observed directory is paired with the file of the same name in the
    target directory. All images are decoded to tensors in ``__init__``, so indexing does no
    file I/O. The label is the integer file stem minus one (e.g. '12.jpg' -> 11).
    """

    def __init__(self, original_dir=None, target_dir=None):
        """Load every observed/target pair and store them in a random order.

        Args:
            original_dir: directory of observed images; defaults to ``ORIGINAL_IMAGE``.
            target_dir: directory of target images; defaults to ``TARGET_IMAGE``.
        """
        super().__init__()
        original_root = Path(ORIGINAL_IMAGE if original_dir is None else original_dir)
        target_root = Path(TARGET_IMAGE if target_dir is None else target_dir)
        to_tensor = torchvision.transforms.ToTensor()

        origins, targets, labels = [], [], []
        for original_image in original_root.glob('*.jpg'):
            # `with` closes each file handle deterministically instead of leaving it to PIL/GC.
            with Image.open(original_image) as image:
                origins.append(to_tensor(image.convert('RGB')))
            # The target is converted to a tensor as stored (no mode conversion).
            with Image.open(target_root / original_image.name) as image:
                targets.append(to_tensor(image))
            labels.append(int(original_image.name.split('.')[0]) - 1)

        # Store samples in a random order drawn from numpy's global RNG.
        order = np.random.permutation(len(origins))
        self.origins = [origins[i] for i in order]
        self.targets = [targets[i] for i in order]
        self.labels = [labels[i] for i in order]

    def __getitem__(self, idx):
        """Return the ``(origin, target, label)`` triple stored at position ``idx``."""
        return self.origins[idx], self.targets[idx], self.labels[idx]

    def __len__(self):
        """Number of loaded image pairs."""
        return len(self.origins)

In [12]:
dataset_a = ThreeChannelNDIDatasetContrastiveLearning()
# Hold out a fixed number of samples for validation and train on the rest. Deriving the
# training size from len(dataset_a) keeps random_split valid (its lengths must sum to the
# dataset size) if images are added or removed; with 184 images this is the 160/24 split.
VAL_SIZE_A = 24
train_dataset_a, val_dataset_a = torch.utils.data.random_split(
    dataset_a, [len(dataset_a) - VAL_SIZE_A, VAL_SIZE_A])
train_iter_a = DataLoader(train_dataset_a, batch_size=16, shuffle=True, drop_last=True)
# The whole validation split is served as a single batch.
val_iter_a = DataLoader(val_dataset_a, batch_size=len(val_dataset_a))

In [13]:
# Pull the first (shuffled) training batch for a quick sanity check:
# d = observed images, e = target images, f = labels.
d, e, f = next(iter(train_iter_a))
f

tensor([  6,  82,  75, 167,  16,  77,  52, 163,  66, 105,   7,  86,  54, 123,
         95,  25])

In [None]:
to_pil_image = transforms.ToPILImage()
# First sample of the batch fetched above; squeeze(0) only drops a leading channel
# dimension when it has size 1 (e.g. a single-channel image).
origin, target = (to_pil_image(batch[0].squeeze(0)) for batch in (d, e))
label = f[0]
origin

In [22]:
def split_train_validation_randomly(original_path, target_path, val_size=24, generator=None):
    """Randomly split observed/target image pairs into training and validation subsets.

    Each '*.jpg' under ``original_path`` is paired with the file of the same name under
    ``target_path``. Matching by name (rather than zipping two independently sorted
    listings) means an extra or missing target file cannot silently shift every later pair.

    Args:
        original_path: directory containing the observed images.
        target_path: directory containing the target (calculated) images.
        val_size: number of pairs held out for validation; the rest go to training.
            The default of 24 gives the 160/24 split for 184 images.
        generator: optional ``torch.Generator`` for a reproducible split; torch's global
            RNG is used when omitted.

    Returns:
        ``(train_images, val_images)``: two ``torch.utils.data.Subset`` objects whose items
        are ``(original_path_str, target_path_str)`` tuples.

    Raises:
        FileNotFoundError: if an observed image has no same-named target image.
        ValueError: if ``val_size`` is negative or larger than the number of pairs.
    """
    target_root = Path(target_path)
    # Lexicographic order of the path strings, so a given RNG state yields the same split.
    original_images = sorted(str(p) for p in Path(original_path).glob('*.jpg'))

    images = []
    for original_image in original_images:
        target_image = target_root / Path(original_image).name
        if not target_image.is_file():
            raise FileNotFoundError(
                f'No target image for {original_image}: expected {target_image}')
        images.append((original_image, str(target_image)))

    if not 0 <= val_size <= len(images):
        raise ValueError(f'val_size must be between 0 and {len(images)}, got {val_size}')
    lengths = [len(images) - val_size, val_size]
    split_kwargs = {} if generator is None else {'generator': generator}
    train_images, val_images = torch.utils.data.random_split(images, lengths, **split_kwargs)
    return train_images, val_images

class ThreeChannelNDIDatasetContrastiveLearningWithAug(Dataset):
    """Lazily loaded dataset of (observed image, target image, label) triples.

    Items are ``(origin_path, target_path)`` tuples taken from the training or validation
    half of ``images``; the image files are read on every ``__getitem__`` call.
    Both images are stacked into a single tensor before ``self.transforms`` runs, presumably
    so that a random augmentation is applied identically to the pair (confirm for each
    transform that is enabled). The label is the observed image's integer file stem minus
    one (e.g. '12.jpg' -> 11).
    """

    def __init__(self, images, evaluate=False, transform=None):
        """
        Args:
            images: ``(train_images, val_images)``, each a sequence of
                ``(origin_path, target_path)`` tuples.
            evaluate: use ``images[1]`` when true, otherwise ``images[0]``.
            transform: callable applied to the stacked ``(2, C, H, W)`` pair tensor.
                Defaults to an empty ``transforms.Compose`` (identity).
        """
        super().__init__()
        self.images = images[1] if evaluate else images[0]
        if transform is None:
            # Identity for now. Candidate augmentations:
            #   transforms.CenterCrop(100), transforms.Resize(224),
            #   transforms.RandomHorizontalFlip(0.5), transforms.RandomVerticalFlip(0.5),
            #   transforms.RandomRotation(90)
            transform = transforms.Compose([])
        self.transforms = transform

    @staticmethod
    def _label_from_path(path):
        """Return the 0-based label encoded in an image file name ('12.jpg' -> 11).

        ``Path`` handles the platform's native separator; splitting on a literal
        backslash only worked for Windows paths.
        """
        return int(Path(path).name.split('.')[0]) - 1

    def __getitem__(self, idx):
        """Load pair ``idx`` and return ``(origin, target, label)`` with tensor images."""
        origin_path, target_path = self.images[idx]
        to_tensor = transforms.ToTensor()
        # `with` closes the file handles deterministically (matters with many DataLoader calls).
        with Image.open(origin_path) as origin_image:
            origin = to_tensor(origin_image.convert('RGB'))
        with Image.open(target_path) as target_image:
            target = to_tensor(target_image)
        # torch.cat needs both (1, C, H, W) tensors to agree in C, H and W, i.e. the target
        # must decode to the same channel count and size as the RGB observed image.
        pair = torch.cat((origin.unsqueeze(0), target.unsqueeze(0)), dim=0)
        origin, target = self.transforms(pair)
        label = self._label_from_path(origin_path)
        return origin, target, label

    def __len__(self):
        """Number of image pairs in the selected split."""
        return len(self.images)

In [23]:
images_b = split_train_validation_randomly(ORIGINAL_IMAGE, TARGET_IMAGE)
# Build both views over the same split: evaluate=False -> training half, True -> validation half.
train_dataset_b, val_dataset_b = [
    ThreeChannelNDIDatasetContrastiveLearningWithAug(images_b, evaluate=flag)
    for flag in (False, True)
]
train_iter_b = DataLoader(train_dataset_b, batch_size=16, shuffle=True, drop_last=True)
# The whole validation split is served as a single batch.
val_iter_b = DataLoader(val_dataset_b, batch_size=len(val_dataset_b))

['E:\\datasets\\NDI_images\\Integreted\\Observed\\circle_binaryzation_pics\\1.jpg', 'E:\\datasets\\NDI_images\\Integreted\\Observed\\circle_binaryzation_pics\\10.jpg', 'E:\\datasets\\NDI_images\\Integreted\\Observed\\circle_binaryzation_pics\\100.jpg', 'E:\\datasets\\NDI_images\\Integreted\\Observed\\circle_binaryzation_pics\\101.jpg', 'E:\\datasets\\NDI_images\\Integreted\\Observed\\circle_binaryzation_pics\\102.jpg', 'E:\\datasets\\NDI_images\\Integreted\\Observed\\circle_binaryzation_pics\\103.jpg', 'E:\\datasets\\NDI_images\\Integreted\\Observed\\circle_binaryzation_pics\\104.jpg', 'E:\\datasets\\NDI_images\\Integreted\\Observed\\circle_binaryzation_pics\\105.jpg', 'E:\\datasets\\NDI_images\\Integreted\\Observed\\circle_binaryzation_pics\\106.jpg', 'E:\\datasets\\NDI_images\\Integreted\\Observed\\circle_binaryzation_pics\\107.jpg', 'E:\\datasets\\NDI_images\\Integreted\\Observed\\circle_binaryzation_pics\\108.jpg', 'E:\\datasets\\NDI_images\\Integreted\\Observed\\circle_binaryzatio

In [27]:
# Training batches per epoch; train_iter_a uses batch_size=16 with drop_last=True, so a
# final partial batch is discarded.
len(train_iter_a)

10

In [26]:
# Training batches per epoch for the second pipeline (batch_size=16, drop_last=True),
# for comparison with len(train_iter_a).
len(train_iter_b)

10