In [1]:
import torch
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader
import os.path as osp
from glob import glob
from PIL import Image
import random
from tqdm.notebook import tqdm

In [2]:
class DISC21Definition(object):
    """Index of the DISC21 image folders found below a dataset root.

    Scans ``<root>/train``, ``<root>/validation`` (gallery) and
    ``<root>/test`` (query) for ``*.jpg`` files. For every split it stores a
    list of ``(full_path, file_name)`` tuples plus the number of distinct
    numeric ids parsed from the file names.
    """

    def __init__(self, root):
        """Record the split directories below ``root`` and index all of them.

        Args:
            root: dataset directory containing the ``train``, ``validation``
                and ``test`` sub-directories.
        """
        self.dataset_dir = root
        self.train_dir = osp.join(self.dataset_dir, 'train')
        self.gallery_dir = osp.join(self.dataset_dir, 'validation')
        self.query_dir = osp.join(self.dataset_dir, 'test')
        self.train = []
        self.gallery = []
        self.query = []
        self.num_train_pids = 0
        self.num_gallery_pids = 0
        self.num_query_pids = 0
        self.has_time_info = False
        self.load()

    def preprocess(self, splitter='T', fpaths=None):
        """List the ``*.jpg`` files of one split directory.

        Args:
            splitter: character that precedes the numeric id in each file
                name; the name (extension removed) is split on it and the
                second piece is parsed as an integer.
            fpaths: directory to scan. ``None`` means ``self.train_dir``.

        Returns:
            ``(data, num_ids)`` where ``data`` is a list of
            ``(full_path, file_name)`` tuples in sorted path order and
            ``num_ids`` is the count of distinct ids seen.

        Raises:
            IndexError: a file name does not contain ``splitter``.
            ValueError: the text after ``splitter`` is not an integer.
        """
        image_dir = self.train_dir if fpaths is None else fpaths
        # glob returns files in arbitrary order; sort so the index is reproducible.
        image_paths = sorted(glob(osp.join(image_dir, '*.jpg')))

        data = []
        seen_ids = set()
        for fpath in image_paths:
            fname = osp.basename(fpath)
            stem = osp.splitext(fname)[0]
            seen_ids.add(int(stem.split(splitter)[1]))
            # Store the path that was actually globbed. Rebuilding it from
            # train_dir would make gallery/query entries point at missing files.
            data.append((fpath, fname))
        return data, len(seen_ids)

    def load(self):
        """Index the train, gallery and query splits and print a summary table."""
        self.train, self.num_train_pids = self.preprocess('T', self.train_dir)
        self.gallery, self.num_gallery_pids = self.preprocess('R', self.gallery_dir)
        self.query, self.num_query_pids = self.preprocess('Q', self.query_dir)
        print(self.__class__.__name__, "dataset loaded")
        print("  subset   | # ids | # images")
        print("  ---------------------------")
        print("  train    | {:6d} | {:8d}".format(self.num_train_pids, len(self.train)))
        print("  gallery  | {:6d} | {:8d}".format(self.num_gallery_pids, len(self.gallery)))
        print("  query    | {:6d} | {:8d}".format(self.num_query_pids, len(self.query)))

In [3]:
class DISC21(Dataset):
    """Torch dataset over one split of a DISC21 index object.

    In train mode each item is a triplet: the anchor image, a "positive"
    built from the same image, and a "negative" drawn at random from a
    different position of the same split. In gallery/query mode each item
    is just the (optionally transformed) image.

    Note: when ``augmentations`` is not given it falls back to ``transform``,
    so with a deterministic ``transform`` the positive equals the anchor.
    """

    def __init__(self, df, train=True, gallery=True, transform=None, augmentations=None):
        """Select the split to serve and remember the image pipelines.

        Args:
            df: object exposing ``train``, ``gallery`` and ``query`` lists of
                ``(full_path, file_name)`` tuples.
            train: serve ``df.train`` and return triplets.
            gallery: when ``train`` is false, serve ``df.gallery`` if true,
                otherwise ``df.query``.
            transform: callable applied to the anchor image (or ``None``).
            augmentations: callable applied to the positive and negative
                images; defaults to ``transform``.
        """
        self.is_train = train
        self.is_gallery = gallery
        self.transform = transform
        self.augmentations = transform if augmentations is None else augmentations

        if self.is_train:
            self.images = df.train
        elif self.is_gallery:
            self.images = df.gallery
        else:
            self.images = df.query

    def __len__(self):
        """Return the number of images in the selected split."""
        return len(self.images)

    @staticmethod
    def _load_rgb(path):
        """Open ``path``, force 3-channel RGB, and close the file handle."""
        # convert() loads the pixel data and returns a new image, so the
        # result stays usable after the underlying file is closed.
        with Image.open(path) as img:
            return img.convert('RGB')

    def _sample_negative_index(self, index):
        """Pick a uniformly random position different from ``index``.

        Raises:
            ValueError: the split holds fewer than two images, so no
                distinct negative exists.
        """
        count = len(self.images)
        if count < 2:
            raise ValueError(
                "need at least two images to sample a negative, got {}".format(count)
            )
        anchor_pos = index % count  # normalise negative indices
        # Draw from count-1 slots and skip over the anchor's slot: uniform
        # over all other positions without a retry loop.
        candidate = random.randrange(count - 1)
        return candidate + 1 if candidate >= anchor_pos else candidate

    def __getitem__(self, index):
        """Return the item at ``index``.

        Returns:
            Train mode: ``(anchor, positive, negative, file_name)``.
            Otherwise: the anchor image alone.
        """
        full_name, name = self.images[index]
        anchor_img = self._load_rgb(full_name)

        if not self.is_train:
            if self.transform is not None:
                anchor_img = self.transform(anchor_img)
            return anchor_img

        # The positive starts from the untransformed anchor image.
        positive_img = anchor_img
        negative_full_name, _ = self.images[self._sample_negative_index(index)]
        negative_img = self._load_rgb(negative_full_name)

        if self.transform is not None:
            anchor_img = self.transform(anchor_img)
        # Applied independently of ``transform`` so that an explicit
        # ``augmentations`` pipeline is honoured even without a transform.
        if self.augmentations is not None:
            positive_img = self.augmentations(positive_img)
            negative_img = self.augmentations(negative_img)

        return anchor_img, positive_img, negative_img, name

In [4]:
# Deterministic preprocessing pipeline applied to every loaded image.
transformation_chain = transforms.Compose([
    # An integer size rescales the shorter side to 256 px and keeps the
    # aspect ratio (the result is not necessarily 256x256).
    transforms.Resize(size=256),
    # Keep the central 224x224 patch.
    transforms.CenterCrop(size=224),
    # PIL image -> tensor.
    transforms.ToTensor(),
    # Per-channel standardisation; these values match the widely used
    # ImageNet channel means and standard deviations.
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

In [5]:
# Index the dataset folders below the (machine-specific) dataset root.
train_df = DISC21Definition(root='/media/augustinas/T7/DISC2021/SmallData/images/')
# Dataset over the training split, using the shared preprocessing pipeline.
train_ds = DISC21(df=train_df, train=True, transform=transformation_chain)

DISC21Definition dataset loaded
  subset   | # ids | # images
  ---------------------------
  train    | 100000 |   100000
  gallery  | 100000 |   100000
  query    |  10000 |    10000


In [6]:
# Hyper-parameters; embedding_dims is not referenced by the cells shown here.
embedding_dims = 2
batch_size = 32
# Shuffled mini-batches of training items, loaded by four worker processes.
train_loader = DataLoader(
    dataset=train_ds,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
)

In [8]:
# Dry run: walk through the loader for one epoch without doing any work per
# batch, which only exercises the data-loading pipeline.
for epoch in tqdm(range(1), desc="Epochs"):
    running_loss = []  # placeholder; nothing is appended in this cell
    for step, (anchor_img, positive_img, negative_img, anchor_name) in enumerate(
        tqdm(train_loader, desc="Training", leave=False)
    ):
        pass

Epochs:   0%|          | 0/1 [00:00<?, ?it/s]

Training:   0%|          | 0/3125 [00:00<?, ?it/s]

KeyboardInterrupt: 