In [None]:
import lightning as pl
import torch
from PIL import Image
from os.path import join as pjoin 
import os
import random
import torchvision.transforms as T
import matplotlib.pyplot as plt

In [None]:
# Seed the global RNGs through Lightning so the random sampling and
# augmentations further down are repeatable on a fresh kernel.
pl.seed_everything(seed=0)

In [None]:
class MIDVHoloDataset:
    """Triplet (anchor, positive, negative) image dataset driven by list files.

    The data root must contain the category folders ``fraud/copy_without_holo``,
    ``fraud/photo_holo_copy``, ``fraud/pseudo_holo_copy`` and ``origins``.
    Every category is described by a list of entries (called "video names" in
    the code); each entry is the path, relative to the category folder, of a
    text file that lists one frame file name per line. The frames themselves
    live in the same directory as that text file.

    ``dataset[idx]`` returns ``([anchor, positive, negative], label)`` where
    every image went through ``transform`` and ``label`` is the frame's
    directory relative to the data root (``<category>/<video dir>``):

    * origins: the positive is the anchor image transformed a second time and
      the negative is an adjacent frame (the next one when it belongs to the
      same video, otherwise the previous one).
    * frauds: the positive is a random frame of the same video and the negative
      is a random frame of the same-named video in another fraud category.
    """

    # Candidate PIL transpose operations for the optional flip/rotation step.
    IMAGES_TRANSFORM = [Image.FLIP_LEFT_RIGHT, Image.FLIP_TOP_BOTTOM, Image.ROTATE_90, Image.ROTATE_180, Image.ROTATE_270]

    def __init__(self, input_dir, transform, split_dir="", split_file="train.txt", only_label=None, flip_rot=True) -> None:
        """Index every frame listed for the requested split.

        Args:
            input_dir: Data root containing the category folders.
            transform: Callable applied to each opened PIL image.
            split_dir: When non-empty, entries are read from
                ``<split_dir>/<category>/<split_file>``; otherwise from
                ``<input_dir>/<category>/<category basename>.lst``.
            split_file: Split list file name (may contain sub-directories).
            only_label: ``None`` keeps everything, ``True`` keeps only origins,
                ``False`` keeps only frauds.
            flip_rot: When true, each triplet is flipped/rotated with
                probability 0.5 (same operation for the three images).
        """
        self.transform = transform
        # category -> {label -> [frame file names]}
        self.labels_dict = {"fraud/copy_without_holo":{}, "fraud/photo_holo_copy":{}, "fraud/pseudo_holo_copy":{}, "origins":{}}
        # category folder basename -> category path relative to the data root
        self.shorttopath = {"copy_without_holo":"fraud/copy_without_holo", "photo_holo_copy":"fraud/photo_holo_copy", "pseudo_holo_copy":"fraud/pseudo_holo_copy", "origins":"origins"}
        self.fraud_names = [k for k in self.labels_dict if k != "origins"]
        self.files = []
        self.labels = []
        self.input_dir = os.path.normpath(input_dir)
        self.only_label = only_label
        for category in self.labels_dict:
            files_tmp, labels_tmp = self.getFilesSplit(pjoin(self.input_dir, category), split_dir, split_file)
            self.files += files_tmp
            self.labels += labels_tmp
        # Misspelled attribute kept for backward compatibility; the class itself
        # relies on len(self.files).
        self.lenght = len(self.files)
        self.flip_rot = flip_rot
        if self.flip_rot:
            print("random flip and rotation")

    @staticmethod
    def _read_list(path):
        """Return the non-empty lines of the text file at ``path``."""
        with open(path) as handle:
            return [line for line in handle.read().split("\n") if line != ""]

    def randomFlipRotation(self, imgs):
        """Apply one randomly chosen flip/rotation to every image in ``imgs``.

        Returns a list with the transposed images, in the same order.
        """
        op = random.choice(self.IMAGES_TRANSFORM)
        return [img.transpose(op) for img in imgs]

    def getFilesSplit(self, input_dir, split_dir, split_file=""):
        """Collect the frame names and labels of one category folder.

        Args:
            input_dir: Path of the category folder (its basename selects the
                category).
            split_dir: Root of the split lists, or an empty string to use the
                ``<basename>.lst`` file inside ``input_dir``.
            split_file: Split list file name, used only when ``split_dir`` is
                non-empty.

        Returns:
            ``(images, labels)``: two parallel lists with one frame file name
            and one label per frame. Both are empty when the category is
            excluded by ``only_label``. As a side effect the frames of every
            entry are recorded in ``self.labels_dict``.
        """
        general_type = os.path.basename(input_dir)
        if len(split_dir):
            list_path = pjoin(split_dir, self.shorttopath[general_type], split_file)
        else:
            list_path = pjoin(input_dir, f"{general_type}.lst")
        # Blank lines are ignored in both cases, so a trailing newline (or the
        # lack of one) never produces an empty entry nor drops a real one.
        video_names = self._read_list(list_path)

        images = []
        labels = []
        is_origin = general_type == "origins"
        # only_label: True keeps only origins, False keeps only frauds.
        if self.only_label is not None and is_origin != self.only_label:
            return images, labels

        name = general_type if is_origin else "fraud/" + general_type
        for vn in video_names:
            label = f"{name}/{os.path.dirname(vn)}"
            frames = self._read_list(pjoin(input_dir, vn))
            images += frames
            labels += [label] * len(frames)
            self.labels_dict[name][label] = frames
        return images, labels

    def _neighbor_index(self, idx):
        """Return the index of the frame used as negative for origin item ``idx``.

        ``idx`` must be non-negative. The next item is preferred when it has
        the same label; otherwise the previous item is used. At ``idx == 0``
        there is no previous item, so the next item (or ``idx`` itself for a
        one-element dataset) is returned instead of wrapping around to the end.
        """
        n_items = len(self.files)
        if idx + 1 < n_items and self.labels[idx + 1] == self.labels[idx]:
            return idx + 1
        if idx > 0:
            return idx - 1
        return idx + 1 if idx + 1 < n_items else idx

    def __getitem__(self, idx: int):
        """Return ``([anchor, positive, negative], label)`` for item ``idx``.

        Raises:
            IndexError: If ``idx`` is out of range.
            KeyError: For a fraud item whose video exists in no other fraud
                category, so no negative can be drawn.
        """
        file_name = self.files[idx]  # raises IndexError when idx is out of range
        label = self.labels[idx]
        if idx < 0:
            # Normalise so the neighbour arithmetic works on real positions.
            idx += len(self.files)

        anchor = Image.open(pjoin(self.input_dir, label, file_name))
        parts = label.split("/")

        if parts[0] == "origins":
            n_idx = self._neighbor_index(idx)
            negative = Image.open(pjoin(self.input_dir, self.labels[n_idx], self.files[n_idx]))

            if self.flip_rot and random.random() < 0.5:
                anchor, negative = self.randomFlipRotation((anchor, negative))

            # The positive is the anchor itself, transformed independently.
            return [self.transform(anchor), self.transform(anchor), self.transform(negative)], label

        fraud = "/".join(parts[:2])
        video = "/".join(parts[2:])
        positive_name = random.choice(self.labels_dict[fraud][label])
        positive = Image.open(pjoin(self.input_dir, label, positive_name))

        # Only fraud categories that actually hold frames for this video can
        # provide the negative.
        possible_frauds = [
            k for k in self.fraud_names
            if k != fraud and self.labels_dict[k].get(k + "/" + video)
        ]
        if not possible_frauds:
            raise KeyError(f"no other fraud category contains frames for '{video}'")
        fraud_n = random.choice(possible_frauds)
        k_n = fraud_n + "/" + video
        negative_name = random.choice(self.labels_dict[fraud_n][k_n])
        negative = Image.open(pjoin(self.input_dir, k_n, negative_name))

        if self.flip_rot and random.random() < 0.5:
            anchor, positive, negative = self.randomFlipRotation((anchor, positive, negative))

        return [self.transform(anchor), self.transform(positive), self.transform(negative)], label

    def __len__(self) -> int:
        """Number of indexed frames."""
        return len(self.files)

In [None]:
# Per-channel ImageNet statistics; only referenced by the disabled
# Normalize step below.
IMAGENET_NORMALIZE = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Augmentation pipeline used for the figures in this notebook.
# ToTensor/Normalize stay disabled so the pipeline returns PIL images that
# imshow can display directly:
#     T.ToTensor(),
#     T.Normalize(mean=IMAGENET_NORMALIZE["mean"], std=IMAGENET_NORMALIZE["std"]),
visu_transform = T.Compose([
    T.Resize(size=256),
    T.RandomResizedCrop(size=(224, 224), scale=(0.8, 1)),
    # NOTE(review): kernel_size=[3, 11] is a fixed 3x11 kernel, not a range
    # (only sigma is sampled from [2, 10]) - confirm this is intended.
    T.RandomApply(
        torch.nn.ModuleList([T.GaussianBlur(kernel_size=[3, 11], sigma=[2, 10])]),
        p=0.4,
    ),
    T.RandomApply(
        torch.nn.ModuleList([T.ColorJitter(brightness=0.3, contrast=0.1, saturation=0.05)]),
        p=0.4,
    ),
])

data_dir = "data/midv-holo/crop_ovds/"
split_dir = "../data/splits_kfold_s0/k0"
data = MIDVHoloDataset(
    input_dir=data_dir,
    transform=visu_transform,
    split_dir=split_dir,
    split_file="trainval/train_train.txt",
)

### Visualizing training samples with transformations 

In [None]:
# Fixed sample so the figure is repeatable. For a random sample use
# i = random.randrange(len(data)); randint(0, len(data)) may return
# len(data), which is out of range.
pl.seed_everything(0)
i = 4660
ims, l = data[i]
f, ax = plt.subplots(1, 3, figsize=(10, 4))
f.suptitle(f"path: {l} ({i}th image)")
# One panel per triplet member; the loop also avoids the stray `[]` echo of a
# trailing set_yticks expression.
for panel, img, title in zip(ax, ims, ("anchor img", "positive img", "negative img")):
    panel.set_title(title)
    panel.imshow(img)
    panel.set_xticks([])
    panel.set_yticks([])

In [None]:
pl.seed_everything(6)
n = 6
f, ax = plt.subplots(n, 3, figsize=(7, 14))
for j in range(n):
    # randrange(len(data)) yields 0..len(data)-1. randint(0, len(data)) also
    # returns len(data) itself, which makes data[i] raise IndexError.
    i = random.randrange(len(data))
    ims, l = data[i]
    ax[j, 0].set_title("FAKE" if "fraud" in l else "LEGIT")
    # Columns: anchor, positive, negative.
    for col, img in enumerate(ims):
        ax[j, col].imshow(img)
        ax[j, col].set_xticks([])
        ax[j, col].set_yticks([])

In [None]:
# Seed before drawing the sample: the triplet and its augmentations are picked
# inside data[i], so seeding afterwards does not make this figure repeatable.
pl.seed_everything(0)
# Bind the index explicitly; the title previously printed a stale `i` left
# over from the preceding cell.
i = 11932
ims2, l = data[i]
f, ax = plt.subplots(1, 3, figsize=(10, 4))
f.suptitle(f"path: {l} ({i}th image)")
for panel, img, title in zip(ax, ims2, ("anchor img", "positive img", "negative img")):
    panel.set_title(title)
    panel.imshow(img)
    panel.set_xticks([])
    panel.set_yticks([])

### Only frauds

In [None]:
# only_label=False keeps the fraud categories only.
data2 = MIDVHoloDataset(
    input_dir=data_dir,
    transform=visu_transform,
    split_dir=split_dir,
    split_file="trainval/train_train.txt",
    only_label=False,
)

In [None]:
# Size of the fraud-only dataset next to the size of the full dataset.
tuple(len(dataset) for dataset in (data2, data))

In [None]:
# To pin the sample, uncomment one of:
# pl.seed_everything(1)
# i = 4403
# randrange excludes len(data2); randint(0, len(data2)) could return it and
# make data2[i] raise IndexError.
i = random.randrange(len(data2))
ims2, l = data2[i]
f, ax = plt.subplots(1, 3, figsize=(10, 4))
f.suptitle(f"path: {l} ({i}th image)")
for panel, img, title in zip(ax, ims2, ("anchor img", "positive img", "negative img")):
    panel.set_title(title)
    panel.imshow(img)
    panel.set_xticks([])
    panel.set_yticks([])

### Only originals
Used for the ablation study

In [None]:
# only_label=True keeps the "origins" category only.
data_origins = MIDVHoloDataset(
    input_dir=data_dir,
    transform=visu_transform,
    split_dir=split_dir,
    split_file="trainval/train_train.txt",
    only_label=True,
)

In [None]:
pl.seed_everything(0)
# Draw the index from the dataset that is actually indexed. The bound used to
# be len(data2), the fraud-only dataset, and randint included the upper bound,
# so data_origins[i] could raise IndexError.
i = random.randrange(len(data_origins))
# i = 4403
ims2, l = data_origins[i]
f, ax = plt.subplots(1, 3, figsize=(10, 4))
f.suptitle(f"path: {l} ({i}th image)")
for panel, img, title in zip(ax, ims2, ("anchor img", "positive img", "negative img")):
    panel.set_title(title)
    panel.imshow(img)
    panel.set_xticks([])
    panel.set_yticks([])

### No flip/rotation
Used for the ablation study

In [None]:
# Same split as `data`, with the random flip/rotation step switched off.
data_noflip = MIDVHoloDataset(
    input_dir=data_dir,
    transform=visu_transform,
    split_dir=split_dir,
    split_file="trainval/train_train.txt",
    flip_rot=False,
)

In [None]:
pl.seed_everything(6)
n = 6
f, ax = plt.subplots(n, 3, figsize=(7, 14))
for j in range(n):
    # randrange(len(data_noflip)) stays within 0..len-1. randint(0, len) also
    # returns len itself, which makes data_noflip[i] raise IndexError.
    i = random.randrange(len(data_noflip))
    ims, l = data_noflip[i]
    ax[j, 0].set_title("FAKE" if "fraud" in l else "LEGIT")
    # Columns: anchor, positive, negative.
    for col, img in enumerate(ims):
        ax[j, col].imshow(img)
        ax[j, col].set_xticks([])
        ax[j, col].set_yticks([])