In [2]:
# Experiment configuration shared by every cell below.
seed = 1234  # seeds the pair sampler
n_folds = 3  # number of fold<k> sub-directories under dataset_dir
dataset_dir = '../dataset/project_dataset_corel'
training_frac = 0.4  # keep this fraction of each fold's *training* images (stratified)

In [2]:
import os, os.path as osp
from sklearn.model_selection import train_test_split
from collections import defaultdict
import itertools
import random
import pandas as pd


def list_join_dir(dir):
    """Return the full path of every entry in ``dir``, sorted lexicographically."""
    entry_names = os.listdir(dir)
    full_paths = [osp.join(dir, name) for name in entry_names]
    full_paths.sort()
    return full_paths

def list_images_gts(img_dir, reduce_factor=None, random_state=1234):
    """List the files in ``img_dir`` together with their integer class ids.

    The class id is parsed from the file name: the part before the first
    ``.`` is split on ``_`` and its first field is converted to ``int``
    (so a name like ``'12_0034.png'`` belongs to class ``12``).

    Args:
        img_dir: directory to list; only base file names are returned.
        reduce_factor: if truthy, the fraction of files to *drop* via a
            stratified ``train_test_split``; the kept part is returned.
            ``None`` (or ``0``) keeps every file.
        random_state: seed for the stratified subsampling (defaults to the
            value previously hard-coded here).

    Returns:
        ``(images, gts)``: file names and their class ids, aligned by position.
    """
    images, gts = [], []
    # os.listdir order is filesystem-dependent; sort it so the seeded split
    # below selects the same files on every machine.
    for img_name in sorted(os.listdir(img_dir)):
        class_id = int(img_name.split('.')[0].split('_')[0])
        images.append(img_name)
        gts.append(class_id)

    if reduce_factor:
        # train_test_split returns (X_keep, X_drop, y_keep, y_drop); test_size
        # is the dropped fraction, stratified so class proportions are kept.
        images, _, gts, _ = train_test_split(
            images, gts, test_size=reduce_factor, stratify=gts, random_state=random_state)

    return images, gts

def create_img_pairs(imgs, gts, random_state=None):
    """Build a label-balanced list of image pairs.

    Every unordered pair of images that share a class becomes a "same" pair
    (label 0). The same number of "different" pairs (label 1) is then drawn:
    the first image uniformly from ``imgs`` (with replacement), the second
    from images of *other* classes, each weighted by the size of its class.

    Args:
        imgs: image identifiers (file names).
        gts: class id for each entry of ``imgs``.
        random_state: seed for the sampler; ``None`` uses the notebook-level
            ``seed`` constant.

    Returns:
        ``(pair_files, pair_labels)``: all "same" pairs first, then the
        "different" ones, with labels 0 (same class) / 1 (different class).

    Raises:
        ValueError: if same-class pairs exist but there is only one class,
            so no different-class partner can be drawn.
    """
    if random_state is None:
        random_state = seed
    generator = random.Random(random_state)

    img_per_class = defaultdict(list)
    for img, gt in zip(imgs, gts):
        img_per_class[gt].append(img)

    # Same label: every unordered combination within each class.
    pair_files = []
    for class_imgs in img_per_class.values():
        pair_files.extend(itertools.combinations(class_imgs, 2))
    same_len = len(pair_files)

    if same_len and len(img_per_class) < 2:
        raise ValueError(
            'create_img_pairs needs at least two classes to draw different-label pairs')

    # Cumulative sampling weights for the second image, one list per class of
    # the first image: 0 for images of that same class, otherwise the size of
    # the candidate's class. Built once per class instead of on every draw;
    # Random.choices(cum_weights=...) yields exactly the draws weights= would.
    # NOTE(review): weighting each image by |its class| makes a class be
    # chosen with probability ~ |class|^2; confirm this (vs. uniform) is intended.
    cum_weights_per_class = {
        label: list(itertools.accumulate(
            0 if gt == label else len(img_per_class[gt]) for _, gt in zip(imgs, gts)))
        for label in img_per_class
    }

    # Different label: same number of pairs as the same-label ones.
    first_imgs = generator.choices(list(zip(imgs, gts)), k=same_len)
    for f_img, f_gt in first_imgs:
        s_img = generator.choices(imgs, cum_weights=cum_weights_per_class[f_gt])[0]
        pair_files.append((f_img, s_img))

    pair_labels = [0] * same_len + [1] * same_len
    return pair_files, pair_labels

def prepare_fold(fold_dir):
    """Generate the pair CSVs for one cross-validation fold.

    For each of the ``train``, ``val`` and ``test`` sub-directories of
    ``fold_dir``, the images are listed with ``list_images_gts`` and paired
    with ``create_img_pairs``. The pairs are written into ``fold_dir`` as
    ``ct_train.csv`` / ``ct_val.csv`` / ``ct_test.csv`` with columns
    ``img_1, img_2, label`` and no header or index.

    Only the training split is subsampled: ``1 - training_frac`` of its
    images are dropped before pairing.
    """
    # (sub-directory, output file, fraction of images to drop before pairing)
    split_specs = [
        ('train', 'ct_train.csv', 1 - training_frac),
        ('val', 'ct_val.csv', None),
        ('test', 'ct_test.csv', None),
    ]

    built_frames = []
    for sub_dir, csv_name, drop_frac in split_specs:
        split_imgs, split_gts = list_images_gts(osp.join(fold_dir, sub_dir), reduce_factor=drop_frac)
        files, labels = create_img_pairs(split_imgs, split_gts)
        frame = pd.DataFrame(files, columns=['img_1', 'img_2'])
        frame['label'] = labels
        built_frames.append((csv_name, frame))

    # Write only after all three splits were built, so a failure part-way
    # leaves no CSV from this call behind.
    for csv_name, frame in built_frames:
        frame.to_csv(osp.join(fold_dir, csv_name), header=None, index=False)


In [3]:
# Write the ct_*.csv pair files for every fold<k> directory of the dataset.
for fold in range(n_folds):
    fold_dir = osp.join(dataset_dir, f'fold{fold}')
    prepare_fold(fold_dir)

In [7]:
import pandas as pd
import typing
import os.path as osp
from PIL import Image

class CtDataset():
    """Image-pair dataset for one split of one fold.

    Pairs come from ``<fold_dir>/ct_<mode>.csv`` (header-less rows of
    ``img_1, img_2, label``); image names are resolved relative to
    ``<fold_dir>/<mode>``. Every referenced image is decoded once in
    ``__init__`` and cached in ``img_dict``.
    """

    def __init__(self, fold_dir, mode: typing.Literal['test', 'val', 'train'], transforms=None) -> None:
        """
        Args:
            fold_dir: fold directory holding the split sub-directories and
                the ``ct_*.csv`` pair files.
            mode: which split to load.
            transforms: optional callable applied to each image in
                ``__getitem__``.
        """
        self.fold_dir = fold_dir
        self.mode = mode
        self.transforms = transforms

        self.img_dir = osp.join(fold_dir, mode)
        self.pair_df = pd.read_csv(osp.join(fold_dir, f'ct_{self.mode}.csv'), header=None)

        self.img_dict = {}
        # Unique image names across both pair columns.
        img_list = pd.concat((self.pair_df[0], self.pair_df[1]), ignore_index=True).drop_duplicates().reset_index(drop=True)

        for item in img_list:
            img_path = osp.join(self.img_dir, item)
            # Image.open is lazy and holds the file open until pixels are read;
            # load() now so the cache does not keep one open descriptor per
            # image (which can exhaust the process's open-file limit).
            with Image.open(img_path) as img:
                img.load()
            self.img_dict[item] = img

    def __len__(self):
        """Number of pairs (rows of the pair CSV)."""
        return len(self.pair_df)

    def __getitem__(self, idx):
        """Return ``(img1, img2, label)`` for the ``idx``-th pair.

        ``idx`` is positional (negative values count from the end). An
        out-of-range ``idx`` raises IndexError, which also lets plain ``for``
        iteration over the dataset stop at the end.
        """
        row = self.pair_df.iloc[idx]
        img1 = self.img_dict[row.iloc[0]]
        img2 = self.img_dict[row.iloc[1]]
        if self.transforms:
            img1 = self.transforms(img1)
            img2 = self.transforms(img2)
        return img1, img2, row.iloc[2]


In [17]:
# Smoke test: load fold 0's training pairs and print two samples.
d = CtDataset(osp.join(dataset_dir, 'fold0'), 'train')
for sample_idx in (0, 500):
    print(d[sample_idx])

(<PIL.PngImagePlugin.PngImageFile image mode=RGB size=224x224 at 0x7F5652FCC6A0>, <PIL.PngImagePlugin.PngImageFile image mode=RGB size=224x224 at 0x7F5652FCC430>, 0)
(<PIL.PngImagePlugin.PngImageFile image mode=RGB size=224x224 at 0x7F5652FCCA60>, <PIL.PngImagePlugin.PngImageFile image mode=RGB size=224x224 at 0x7F5652FCD840>, 1)


(<PIL.PngImagePlugin.PngImageFile image mode=RGB size=224x224>,
 <PIL.PngImagePlugin.PngImageFile image mode=RGB size=224x224>,
 0)