In [1]:
!pip install gdown
!gdown 1m-MczPwDlbz3Vy7Z0AiCTL0Ugo12d-82
!unzip minor.zip

import os
import torch
from PIL import Image
import numpy as np
from torch.utils.data import Dataset, DataLoader
import cv2 as cv
import albumentations as A
import pandas as pd

def _read_id_list(path):
    """Return the non-blank lines of the text file at ``path`` as a list.

    Blank lines (for example the empty entry produced by a trailing newline)
    are skipped so they never end up as bogus image ids.
    """
    # Plain read mode: the split files are only read, never written back.
    with open(path, "r") as f0:
        entries = (line.rstrip("\n") for line in f0)
        return [entry for entry in entries if entry.strip()]


# Image ids of each split, one id per line in the corresponding file.
train_list = _read_id_list("Minor/train.txt")
val_list = _read_id_list("Minor/val.txt")
test_list = _read_id_list("Minor/test.txt")

# Label table; VIPDataset reads its "image_ids" and "label" columns.
labels = pd.read_csv("Minor/labels.csv")


def construct_dataloder(dataset, batch_size, shuffle=True, num_workers=4):
    """Wrap ``dataset`` in a ``DataLoader``.

    Args:
        dataset: Dataset to iterate over.
        batch_size: Number of samples per batch.
        shuffle: Forwarded to ``DataLoader`` (defaults to ``True``).
        num_workers: Forwarded to ``DataLoader`` (defaults to ``4``).

    Returns:
        The configured ``DataLoader``.
    """
    loader_options = {
        "batch_size": batch_size,
        "shuffle": shuffle,
        "num_workers": num_workers,
    }
    return DataLoader(dataset, **loader_options)

class VIPDataset(Dataset):
    """Image dataset yielding ``(augmented image tensor, integer label)`` pairs.

    Args:
        file_ids: Sequence of image file names, relative to ``data_dir``.
        labels: DataFrame with an ``image_ids`` column and a ``label`` column.
        data_dir: Directory that contains the image files.
    """

    def __init__(self, file_ids, labels, data_dir,):
        self.file_ids = file_ids
        self.labels = labels
        self.data_dir = data_dir
        # NOTE(review): the same random pipeline is built for every instance,
        # so validation/test samples are randomly augmented too — confirm this
        # is intended.
        self.aug_pipeline = A.Compose([
            A.RandomScale(scale_limit=[0, 2], interpolation=cv.INTER_LINEAR, p=0.3),
            A.ImageCompression(quality_lower=99, quality_upper=100, p=0.3),
            # Always applied: the fixed 200x200 crop gives every sample the
            # same spatial size, which batching needs in order to stack them.
            A.RandomCrop(width=200, height=200, p=1),
        ])

    def __len__(self):
        """Return the number of samples (one per file id)."""
        return len(self.file_ids)

    def __getitem__(self, index):
        """Return ``(image, label)`` for the sample at position ``index``.

        ``image`` is the augmented picture as a tensor created from the
        height x width x 3 RGB array; ``label`` is a plain ``int``.

        Raises:
            KeyError: If the image id has no row in the label table.
            ValueError: If the image id has more than one row.
        """
        image_id = self.file_ids[index]
        # Force RGB so grayscale, palette and RGBA files all produce three
        # channels and can be stacked into one batch. The context manager
        # closes the underlying file as soon as the pixels are copied out.
        with Image.open(os.path.join(self.data_dir, image_id)) as img:
            x = np.array(img.convert("RGB"))
        x = self.augment(x)
        x = torch.from_numpy(x)
        y = self._label_for(image_id)
        return x, y

    def _label_for(self, image_id):
        """Look up the integer label stored for ``image_id``."""
        matches = self.labels.loc[self.labels["image_ids"] == image_id, "label"]
        if len(matches) == 0:
            raise KeyError(f"no label found for image id {image_id!r}")
        if len(matches) > 1:
            raise ValueError(
                f"{len(matches)} labels found for image id {image_id!r}; expected exactly one"
            )
        # Extract the scalar explicitly: calling int() on a one-element Series
        # is deprecated in pandas.
        return int(matches.iloc[0])

    def augment(self, x):
        """Run the augmentation pipeline on image array ``x`` and return the result."""
        return self.aug_pipeline(image=x)["image"]

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: Minor/data/stylegan2_ffhq_ffhq-0509.png  
  inflating: Minor/data/000000224929.jpg  
  inflating: Minor/data/Taming_coco_samples_67_samples_nopix_002198.png  
  inflating: Minor/data/Taming_coco_samples_1970_samples_nopix_003720.png  
  inflating: Minor/data/000000514362.jpg  
  inflating: Minor/data/000000049491.jpg  
  inflating: Minor/data/000000260932.jpg  
  inflating: Minor/data/Taming_ffhq_014363.png  
  inflating: Minor/data/000000133515.jpg  
  inflating: Minor/data/000000541147.jpg  
  inflating: Minor/data/realffhq_17096.png  
  inflating: Minor/data/000000360617.jpg  
  inflating: Minor/data/Taming_coco_samples_577_samples_nopix_003299.png  
  inflating: Minor/data/000000568156.jpg  
  inflating: Minor/data/000000415790.jpg  
  inflating: Minor/data/realffhq_38876.png  
  inflating: Minor/data/gated_inpainting_id000000083147.png  
  inflating: Minor/data/glide_text2img_annot000000054672.png  
  in

In [5]:
# One dataset per split; all three share the label table and the image folder.
train_ds, val_ds, test_ds = (
    VIPDataset(split_ids, labels, "Minor/data")
    for split_ids in (train_list, val_list, test_list)
)

In [6]:
# Only the training loader shuffles. Validation and test are iterated in the
# order of val_list / test_list so batches line up with the file lists.
train_dl = construct_dataloder(train_ds, 8, shuffle=True, num_workers=2)
val_dl = construct_dataloder(val_ds, 8, shuffle=False, num_workers=2)
test_dl = construct_dataloder(test_ds, 8, shuffle=False, num_workers=2)

In [7]:
# Smoke test: draw a single batch from the training loader and let the
# notebook display it.
next(iter(train_dl))

[tensor([[[[108,  68,  63],
           [ 86,  51,  47],
           [104,  73,  70],
           ...,
           [160, 134, 122],
           [164, 136, 125],
           [173, 134, 126]],
 
          [[104,  64,  59],
           [100,  65,  61],
           [ 77,  47,  42],
           ...,
           [156, 130, 116],
           [142, 114, 103],
           [169, 130, 122]],
 
          [[ 87,  58,  53],
           [ 84,  55,  50],
           [ 83,  57,  51],
           ...,
           [134, 108,  94],
           [156, 129, 119],
           [157, 123, 117]],
 
          ...,
 
          [[ 86,  70,  63],
           [ 89,  74,  71],
           [ 42,  31,  27],
           ...,
           [147, 112, 102],
           [149, 114, 101],
           [135,  97,  85]],
 
          [[ 61,  42,  35],
           [148, 129, 124],
           [ 45,  29,  23],
           ...,
           [114,  87,  83],
           [123,  95,  88],
           [103,  71,  65]],
 
          [[ 60,  41,  34],
           [ 56,  37