In [34]:
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader
import os.path as osp
from glob import glob
from PIL import Image
import random
from tqdm.notebook import tqdm
import torch
from torchvision.models import ResNet101_Weights
import torch.optim as optim
import numpy as np
import augly.image as imaugs

In [35]:
class DISC21Definition(object):
    def __init__(self, root):
        self.dataset_dir = root
        self.train_dir = osp.join(self.dataset_dir, 'train')
        self.gallery_dir = osp.join(self.dataset_dir, 'validation')
        self.query_dir = osp.join(self.dataset_dir, 'test')
        self.train = []
        self.gallery = []
        self.query = []
        self.num_train_pids = 0
        self.num_gallery_pids = 0
        self.num_query_pids = 0
        self.has_time_info = False
        self.load()

    def preprocess(self, splitter='T', fpaths=None):
        if fpaths is None:
            fpaths = glob(osp.join(self.train_dir, '*.jpg'))
        else:
            fpaths = glob(osp.join(fpaths, '*.jpg'))
        data = []
        all_pids = {}
        for fpath in fpaths:
            fname = osp.basename(fpath)
            pid = int(fname[:-4].split(splitter)[1])
            if pid not in all_pids:
                all_pids[pid] = len(all_pids)
            data.append((self.train_dir + '/' + fname, fname))
        return data, int(len(all_pids))

    def load(self):
        self.train, self.num_train_pids = self.preprocess('T', self.train_dir)
        self.gallery, self.num_gallery_pids = self.preprocess('R', self.gallery_dir)
        self.query, self.num_query_pids = self.preprocess('Q', self.query_dir)
        print(self.__class__.__name__, "dataset loaded")
        print("  subset   | # ids | # images")
        print("  ---------------------------")
        print("  train    | {:6d} | {:8d}".format(self.num_train_pids, len(self.train)))
        print("  gallery  | {:6d} | {:8d}".format(self.num_gallery_pids, len(self.gallery)))
        print("  query    | {:6d} | {:8d}".format(self.num_query_pids, len(self.query)))

In [36]:
class DISC21(Dataset):
    def __init__(self, df, train=True, gallery=True, transform=None, augmentations=None):
        self.is_train = train
        self.is_gallery = gallery
        self.transform = transform
        self.augmentations = transform if augmentations is None else augmentations

        if self.is_train:
            self.images = df.train
        elif self.is_gallery:
            self.images = df.gallery
        else:
            self.images = df.query

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        full_name, name = self.images[index]
        anchor_img = Image.open(full_name)

        if self.is_train:
            positive_img = anchor_img

            negative_index = index
            while negative_index == index:
                negative_index = random.randrange(len(self.images))
            negative_full_name, negative_name = self.images[negative_index]
            negative_img = Image.open(negative_full_name)

            if self.transform:
                anchor_img = self.transform(anchor_img)
                positive_img = self.augmentations(positive_img)
                negative_img = self.augmentations(negative_img)

            return anchor_img, positive_img, negative_img, name
        else:
            if self.transform:
                anchor_img = self.transform(anchor_img)
            return anchor_img

In [42]:
transformation_chain = transforms.Compose(
    [
        # We first resize the input image to 256x256, and then we take center crop.
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ]
)

augmentation_chain = transforms.Compose(
    [
        # We first resize the input image to 256x256, and then we take center crop.
        transforms.Resize(256),
        transforms.CenterCrop(224),
        imaugs.Brightness(factor=2.0),
        imaugs.RandomRotation(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ]
)

In [43]:
train_df = DISC21Definition('/media/augustinas/T7/DISC2021/SmallData/images/')
train_ds = DISC21(train_df, train=True, transform=transformation_chain, augmentations=augmentation_chain)

DISC21Definition dataset loaded
  subset   | # ids | # images
  ---------------------------
  train    | 100000 |   100000
  gallery  | 100000 |   100000
  query    |  10000 |    10000


In [11]:
embedding_dims = 2
batch_size = 32
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4)

In [7]:
print(torch.cuda.is_available())
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

True
cuda:0


In [8]:
# Load model
model = models.resnet101(weights=ResNet101_Weights.DEFAULT).to(device)
model.fc = torch.nn.Identity()
model.avgpool = torch.nn.Identity()
model.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [9]:
epoch_count = 10
lr = 1e-3
optimizer = optim.Adam(model.parameters(), lr=lr)
loss_func = torch.nn.TripletMarginLoss()

In [13]:
model.train()
for epoch in tqdm(range(epoch_count), desc="Epochs"):
    running_loss = []
    for step, (anchor_img, positive_img, negative_img, anchor_label) in enumerate(tqdm(train_loader, desc="Training", leave=False)):
        anchor_img = anchor_img.to(device)
        positive_img = positive_img.to(device)
        negative_img = negative_img.to(device)

        anchor_out = model(anchor_img)
        positive_out = model(positive_img)
        negative_out = model(negative_img)

        loss = loss_func(anchor_out, positive_out, negative_out)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        running_loss.append(loss.cpu().detach().numpy())
    print("Epoch: {}/{} - Loss: {:.4f}".format(epoch+1, epoch_count, np.mean(running_loss)))

Epochs:   0%|          | 0/10 [00:00<?, ?it/s]

Training:   0%|          | 0/3125 [00:00<?, ?it/s]

OutOfMemoryError: CUDA out of memory. Tried to allocate 98.00 MiB (GPU 0; 7.92 GiB total capacity; 7.00 GiB already allocated; 12.88 MiB free; 7.24 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF