In [9]:
import os
import argparse
from solver import Solver
from data_loader import get_loader
from torch.backends import cudnn
from data_loader import CelebA
from torchvision import transforms as T
import matplotlib.pyplot as plt
import torch
from PIL import Image
from torch.utils.data import DataLoader
def str2bool(v):
    return v.lower() in ('true')

def main(config):
    # For fast training.
    cudnn.benchmark = True

    # Create directories if not exist.
    if not os.path.exists(config.log_dir):
        os.makedirs(config.log_dir)
    if not os.path.exists(config.model_save_dir):
        os.makedirs(config.model_save_dir)
    if not os.path.exists(config.sample_dir):
        os.makedirs(config.sample_dir)
    if not os.path.exists(config.result_dir):
        os.makedirs(config.result_dir)

    # Data loader.
    celeba_loader = None
    rafd_loader = None

    if config.dataset in ['CelebA', 'Both']:
        celeba_loader = get_loader(config.celeba_image_dir, config.attr_path, config.selected_attrs,
                                   config.celeba_crop_size, config.image_size, config.batch_size,
                                   'CelebA', config.mode, config.num_workers)
    if config.dataset in ['RaFD', 'Both']:
        rafd_loader = get_loader(config.rafd_image_dir, None, None,
                                 config.rafd_crop_size, config.image_size, config.batch_size,
                                 'RaFD', config.mode, config.num_workers)
    

    # Solver for training and testing StarGAN.
    solver = Solver(celeba_loader, rafd_loader, config)

    if config.mode == 'train':
        if config.dataset in ['CelebA', 'RaFD']:
            solver.train()
        elif config.dataset in ['Both']:
            solver.train_multi()
    elif config.mode == 'test':
        if config.dataset in ['CelebA', 'RaFD']:
            solver.test()
        elif config.dataset in ['Both']:
            solver.test_multi()

In [10]:
parser = argparse.ArgumentParser()

# Model configuration.
parser.add_argument('--c_dim', type=int, default=5, help='dimension of domain labels (1st dataset)')
parser.add_argument('--c2_dim', type=int, default=8, help='dimension of domain labels (2nd dataset)')
parser.add_argument('--celeba_crop_size', type=int, default=178, help='crop size for the CelebA dataset')
parser.add_argument('--rafd_crop_size', type=int, default=256, help='crop size for the RaFD dataset')
parser.add_argument('--image_size', type=int, default=128, help='image resolution')
parser.add_argument('--g_conv_dim', type=int, default=64, help='number of conv filters in the first layer of G')
parser.add_argument('--d_conv_dim', type=int, default=64, help='number of conv filters in the first layer of D')
parser.add_argument('--g_repeat_num', type=int, default=6, help='number of residual blocks in G')
parser.add_argument('--d_repeat_num', type=int, default=6, help='number of strided conv layers in D')
parser.add_argument('--lambda_cls', type=float, default=1, help='weight for domain classification loss')
parser.add_argument('--lambda_rec', type=float, default=10, help='weight for reconstruction loss')
parser.add_argument('--lambda_gp', type=float, default=10, help='weight for gradient penalty')
    
# Training configuration.
parser.add_argument('--dataset', type=str, default='CelebA', choices=['CelebA', 'RaFD', 'Both'])
parser.add_argument('--batch_size', type=int, default=4, help='mini-batch size')
parser.add_argument('--num_iters', type=int, default=200000, help='number of total iterations for training D')
parser.add_argument('--num_iters_decay', type=int, default=100000, help='number of iterations for decaying lr')
parser.add_argument('--g_lr', type=float, default=0.0001, help='learning rate for G')
parser.add_argument('--d_lr', type=float, default=0.0001, help='learning rate for D')
parser.add_argument('--n_critic', type=int, default=5, help='number of D updates per each G update')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for Adam optimizer')
parser.add_argument('--beta2', type=float, default=0.999, help='beta2 for Adam optimizer')
parser.add_argument('--resume_iters', type=int, default=None, help='resume training from this step')
parser.add_argument('--selected_attrs', '--list', nargs='+', help='selected attributes for the CelebA dataset',
                    default=['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young'])

# Test configuration.
parser.add_argument('--test_iters', type=int, default=200000, help='test model from this step')

# Miscellaneous.
parser.add_argument('--num_workers', type=int, default=8)
parser.add_argument('--mode', type=str, default='train', choices=['train', 'test'])
parser.add_argument('--use_tensorboard', type=str2bool, default=True)

# Directories.
parser.add_argument('--celeba_image_dir', type=str, default='data/celeba/images')
parser.add_argument('--attr_path', type=str, default='data/celeba/list_attr_celeba.txt')
parser.add_argument('--rafd_image_dir', type=str, default='data/RaFD/train')
parser.add_argument('--log_dir', type=str, default='stargan/logs')
parser.add_argument('--model_save_dir', type=str, default='stargan/models')
parser.add_argument('--sample_dir', type=str, default='stargan/samples')
parser.add_argument('--result_dir', type=str, default='stargan/results')

# Step size.
parser.add_argument('--log_step', type=int, default=10)
parser.add_argument('--sample_step', type=int, default=1000)
parser.add_argument('--model_save_step', type=int, default=10000)
parser.add_argument('--lr_update_step', type=int, default=1000)

config = parser.parse_args('')

In [3]:
print(config)

Namespace(c_dim=5, c2_dim=8, celeba_crop_size=178, rafd_crop_size=256, image_size=128, g_conv_dim=64, d_conv_dim=64, g_repeat_num=6, d_repeat_num=6, lambda_cls=1, lambda_rec=10, lambda_gp=10, dataset='CelebA', batch_size=4, num_iters=200000, num_iters_decay=100000, g_lr=0.0001, d_lr=0.0001, n_critic=5, beta1=0.5, beta2=0.999, resume_iters=None, selected_attrs=['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young'], test_iters=200000, num_workers=8, mode='train', use_tensorboard=True, celeba_image_dir='data/celeba/images', attr_path='data/celeba/list_attr_celeba.txt', rafd_image_dir='data/RaFD/train', log_dir='stargan/logs', model_save_dir='stargan/models', sample_dir='stargan/samples', result_dir='stargan/results', log_step=10, sample_step=1000, model_save_step=10000, lr_update_step=1000)


In [11]:
# For fast training.
cudnn.benchmark = True

# Create directories if not exist.
if not os.path.exists(config.log_dir):
    os.makedirs(config.log_dir)
if not os.path.exists(config.model_save_dir):
    os.makedirs(config.model_save_dir)
if not os.path.exists(config.sample_dir):
    os.makedirs(config.sample_dir)
if not os.path.exists(config.result_dir):
    os.makedirs(config.result_dir)

In [12]:
transform = []

transform.append(T.RandomHorizontalFlip())
transform.append(T.CenterCrop(256))
transform.append(T.Resize(224))
transform.append(T.ToTensor())
transform.append(T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
transform = T.Compose(transform)

dataset = CelebA(image_dir = config.celeba_image_dir, 
                 attr_path = config.attr_path, 
                 selected_attrs = config.selected_attrs, 
                 transform = transform, 
                 mode = 'train')

Finished preprocessing the CelebA dataset...


In [22]:
dataset.__dict__.keys()

dict_keys(['image_dir', 'attr_path', 'selected_attrs', 'transform', 'mode', 'train_dataset', 'test_dataset', 'attr2idx', 'idx2attr', 'num_images'])

In [18]:
dataset.selected_attrs

['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young']

In [13]:
if config.dataset in ['CelebA', 'Both']:
        celeba_loader = get_loader(config.celeba_image_dir, config.attr_path, config.selected_attrs,
                                   config.celeba_crop_size, config.image_size, config.batch_size,
                                   'CelebA', config.mode, config.num_workers)
if config.dataset in ['RaFD', 'Both']:
    rafd_loader = get_loader(config.rafd_image_dir, None, None,
                                config.rafd_crop_size, config.image_size, config.batch_size,
                                'RaFD', config.mode, config.num_workers)

Finished preprocessing the CelebA dataset...


In [98]:
data = next(iter(dataset))

In [100]:
data[0].shape, data[1].shape

(torch.Size([3, 224, 224]), torch.Size([5]))

In [14]:
image_data = data[0] * 0.5 + 0.5  # Normalized tensors with mean 0.5

# Corresponding labels tensor
labels = data[1]  # Random integer labels for demonstration

# Visualize the first tensor in the batch
def visualize_tensor(tensor, title="Tensor"):
    tensor = tensor.permute(1, 2, 0)  # Permute to shape (H, W, C) for matplotlib
    plt.imshow(tensor)
    plt.title(title)
    plt.axis('off')
    plt.show()

# Plot all tensors in the batch
for i in range(image_data.shape[0]):
    visualize_tensor(image_data[i], title=f"Tensor {i+1} with label {labels[i]}")

print(f"Data shape: {image_data.shape}")
print(f"Labels shape: {labels.shape}")
print(f"Labels: {labels}")

NameError: name 'data' is not defined

In [26]:
data[1].shape

torch.Size([4, 5])

In [19]:
solver = Solver(celeba_loader, None, None, config)

Generator(
  (main): Sequential(
    (0): Conv2d(8, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
    (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ResidualBlock(
      (main): Sequential(
        (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), paddi

2024-06-17 13:23:09.477645: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-06-17 13:23:09.499737: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-06-17 13:23:09.995611: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-06-17 13:23:09.996320: I external/local_xla/x

In [None]:
data_iter = iter(celeba_loader)
x_fixed, c_org = next(data_iter)
x_fixed = x_fixed.to(solver.device)
c_fixed_list = solver.create_labels(c_org, solver.c_dim, solver.dataset, solver.selected_attrs)

In [47]:
data_loader = celeba_loader
data_iter = iter(data_loader)
x_real, label_org = next(data_iter)

In [48]:
# Generate target domain labels randomly.
rand_idx = torch.randperm(label_org.size(0))
label_trg = label_org[rand_idx]

In [51]:
rand_idx, label_org, label_trg

(tensor([0, 2, 3, 1]),
 tensor([[0., 0., 0., 0., 1.],
         [1., 0., 0., 1., 1.],
         [1., 0., 0., 1., 1.],
         [0., 0., 0., 0., 1.]]),
 tensor([[0., 0., 0., 0., 1.],
         [1., 0., 0., 1., 1.],
         [0., 0., 0., 0., 1.],
         [1., 0., 0., 1., 1.]]))

In [53]:
c_org = label_org.clone()
c_trg = label_trg.clone()
x_real = x_real.to(solver.device)           # Input images.
c_org = c_org.to(solver.device)             # Original domain labels.
c_trg = c_trg.to(solver.device)             # Target domain labels.
label_org = label_org.to(solver.device)     # Labels for computing classification loss.
label_trg = label_trg.to(solver.device)     # Labels for computing classification loss.

# =================================================================================== #
#                             2. Train the discriminator                              #
# =================================================================================== #

# Compute loss with real images.
out_src, out_cls = solver.D(x_real)
d_loss_real = - torch.mean(out_src)

In [57]:
out_src.shape , out_cls.shape

(torch.Size([4, 1, 2, 2]), torch.Size([4, 5]))

In [59]:
d_loss_real

tensor(0.0004, device='cuda:0', grad_fn=<NegBackward0>)

In [42]:
x_fixed.shape, c_org.shape

(torch.Size([4, 3, 128, 128]), torch.Size([4, 5]))

In [152]:
import os.path as osp
import numpy as np
from __future__ import print_function, absolute_import
import os.path as osp
import glob
import re
import urllib
import zipfile

class BaseDataset(object):
    """
    Base class of reid dataset
    """

    def get_imagedata_info(self, data):
        pids, cams = [], []
        for item in data:
            pids += [item[1]]
            cams += [item[-1]]
        pids = set(pids)
        cams = set(cams)
        num_pids = len(pids)
        num_cams = len(cams)
        num_imgs = len(data)
        return num_pids, num_imgs, num_cams

    def get_videodata_info(self, data, return_tracklet_stats=False):
        pids, cams, tracklet_stats = [], [], []
        for img_paths, pid, camid in data:
            pids += [pid]
            cams += [camid]
            tracklet_stats += [len(img_paths)]
        pids = set(pids)
        cams = set(cams)
        num_pids = len(pids)
        num_cams = len(cams)
        num_tracklets = len(data)
        if return_tracklet_stats:
            return num_pids, num_tracklets, num_cams, tracklet_stats
        return num_pids, num_tracklets, num_cams

    def print_dataset_statistics(self):
        raise NotImplementedError

    @property
    def images_dir(self):
        return None

class BaseImageDataset(BaseDataset):
    """
    Base class of image reid dataset
    """

    def print_dataset_statistics(self, train, query, gallery):
        num_train_pids, num_train_imgs, num_train_cams = self.get_imagedata_info(train)
        num_query_pids, num_query_imgs, num_query_cams = self.get_imagedata_info(query)
        num_gallery_pids, num_gallery_imgs, num_gallery_cams = self.get_imagedata_info(gallery)

        print("Dataset statistics:")
        print("  ----------------------------------------")
        print("  subset   | # ids | # images | # cameras")
        print("  ----------------------------------------")
        print("  train    | {:5d} | {:8d} | {:9d}".format(num_train_pids, num_train_imgs, num_train_cams))
        print("  query    | {:5d} | {:8d} | {:9d}".format(num_query_pids, num_query_imgs, num_query_cams))
        print("  gallery  | {:5d} | {:8d} | {:9d}".format(num_gallery_pids, num_gallery_imgs, num_gallery_cams))
        print("  ----------------------------------------")

class Market1501(BaseImageDataset):
    """
    Market1501
    Reference:
    Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015.
    URL: http://www.liangzheng.org/Project/project_reid.html

    Dataset statistics:
    # identities: 1501 (+1 for background)
    # images: 12936 (train) + 3368 (query) + 15913 (gallery)
    """
    dataset_dir = 'Market-1501-v15.09.15'

    def __init__(self, root, ncl=1, verbose=True, **kwargs):
        super(Market1501, self).__init__()
        self.dataset_dir = osp.join(root, self.dataset_dir)
        self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train')
        self.query_dir = osp.join(self.dataset_dir, 'query')
        self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test')

        self.ncl = ncl
        self._check_before_run()

        train = self._process_dir(self.train_dir, relabel=True)
        query = self._process_dir(self.query_dir, relabel=False)
        gallery = self._process_dir(self.gallery_dir, relabel=False)

        if verbose:
            print("=> Market1501 loaded")
            self.print_dataset_statistics(train, query, gallery)

        self.train = train
        self.query = query
        self.gallery = gallery

        self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
        self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
        self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)

    def _check_before_run(self):
        """Check if all files are available before going deeper"""
        if not osp.exists(self.dataset_dir):
            raise RuntimeError("'{}' is not available".format(self.dataset_dir))
        if not osp.exists(self.train_dir):
            raise RuntimeError("'{}' is not available".format(self.train_dir))
        if not osp.exists(self.query_dir):
            raise RuntimeError("'{}' is not available".format(self.query_dir))
        if not osp.exists(self.gallery_dir):
            raise RuntimeError("'{}' is not available".format(self.gallery_dir))

    def _process_dir(self, dir_path, relabel=False):
        img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
        pattern = re.compile(r'([-\d]+)_c(\d)')

        pid_container = set()
        for img_path in img_paths:
            pid, _ = map(int, pattern.search(img_path).groups())
            if pid == -1: continue  # junk images are just ignored
            pid_container.add(pid)
        pid2label = {pid: label for label, pid in enumerate(pid_container)}

        dataset = []
        for img_path in img_paths:
            pid, camid = map(int, pattern.search(img_path).groups())
            if pid == -1: continue  # junk images are just ignored
            assert 0 <= pid <= 1501  # pid == 0 means background
            assert 1 <= camid <= 6
            camid -= 1  # index starts from 0

            if relabel: pid = pid2label[pid]
            pids=()
            for _ in range(self.ncl):
                pids=(pid,)+pids
            item=(img_path,) + pids + (camid,)
            dataset.append(item)
            # if relabel: pid = pid2label[pid]
            # dataset.append((img_path, pid, camid))

        return dataset
    
class DukeMTMC(BaseImageDataset):
    """
    DukeMTMC-reID
    Reference:
    1. Ristani et al. Performance Measures and a Data Set for Multi-Target, Multi-Camera Tracking. ECCVW 2016.
    2. Zheng et al. Unlabeled Samples Generated by GAN Improve the Person Re-identification Baseline in vitro. ICCV 2017.
    URL: https://github.com/layumi/DukeMTMC-reID_evaluation

    Dataset statistics:
    # identities: 1404 (train + query)
    # images:16522 (train) + 2228 (query) + 17661 (gallery)
    # cameras: 8
    """
    dataset_dir = '.'

    def __init__(self, root, ncl=1, verbose=True, **kwargs):
        super(DukeMTMC, self).__init__()
        self.dataset_dir = osp.join(root, self.dataset_dir)
        self.dataset_url = 'http://vision.cs.duke.edu/DukeMTMC/data/misc/DukeMTMC-reID.zip'
        self.train_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/bounding_box_train')
        self.query_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/query')
        self.gallery_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/bounding_box_test')
        self.camstyle_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/bounding_box_train_camstyle')

        self._download_data()
        self._check_before_run()

        self.ncl=ncl
        self.num_cam= 8
        # camstytrain = self._process_dir(self.camstylegallery_dir, relabel=True)
        train = self._process_dir(self.train_dir, relabel=True)
        query = self._process_dir(self.query_dir, relabel=False)
        gallery = self._process_dir(self.gallery_dir, relabel=False)

        if verbose:
            print("=> DukeMTMC-reID loaded")
            self.print_dataset_statistics(train, query, gallery)

        self.train = train
        self.query = query
        self.gallery = gallery

        self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
        self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
        self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)

    def _download_data(self):
        if osp.exists(self.dataset_dir):
            print("This dataset has been downloaded.")
            return

        print("Creating directory {}".format(self.dataset_dir))
        mkdir_if_missing(self.dataset_dir)
        fpath = osp.join(self.dataset_dir, osp.basename(self.dataset_url))

        print("Downloading DukeMTMC-reID dataset")
        urllib.request.urlretrieve(self.dataset_url, fpath)

        print("Extracting files")
        zip_ref = zipfile.ZipFile(fpath, 'r')
        zip_ref.extractall(self.dataset_dir)
        zip_ref.close()

    def _check_before_run(self):
        """Check if all files are available before going deeper"""
        if not osp.exists(self.dataset_dir):
            raise RuntimeError("'{}' is not available".format(self.dataset_dir))
        if not osp.exists(self.train_dir):
            raise RuntimeError("'{}' is not available".format(self.train_dir))
        if not osp.exists(self.query_dir):
            raise RuntimeError("'{}' is not available".format(self.query_dir))
        if not osp.exists(self.gallery_dir):
            raise RuntimeError("'{}' is not available".format(self.gallery_dir))

    def _process_dir(self, dir_path, relabel=False):
        img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
        pattern = re.compile(r'([-\d]+)_c(\d)')

        pid_container = set()
        for img_path in img_paths:
            pid, _ = map(int, pattern.search(img_path).groups())
            pid_container.add(pid)
        pid2label = {pid: label for label, pid in enumerate(pid_container)}

        dataset = []
        for img_path in img_paths:
            pid, camid = map(int, pattern.search(img_path).groups())
            assert 1 <= camid <= 8
            camid -= 1  # index starts from 0
            if relabel: pid = pid2label[pid]
            pids=()
            for _ in range(self.ncl):
                pids = (pid,)+pids
            item = (img_path,) + pids + (camid,)
            dataset.append(item)

        return dataset

In [225]:
market = Market1501('/home/jun/ReID_Dataset')
duke = DukeMTMC('/home/jun/ReID_Dataset')

=> Market1501 loaded
Dataset statistics:
  ----------------------------------------
  subset   | # ids | # images | # cameras
  ----------------------------------------
  train    |   751 |    12936 |         6
  query    |   750 |     3368 |         6
  gallery  |   751 |    15913 |         6
  ----------------------------------------
This dataset has been downloaded.
=> DukeMTMC-reID loaded
Dataset statistics:
  ----------------------------------------
  subset   | # ids | # images | # cameras
  ----------------------------------------
  train    |   702 |    16522 |         8
  query    |   702 |     2228 |         8
  gallery  |  1110 |    17661 |         8
  ----------------------------------------


In [117]:
market.num_train_cams

6

In [224]:
transform = []
transform.append(T.RandomHorizontalFlip())
transform.append(T.Resize((256,256), Image.BICUBIC))
transform.append(T.RandomCrop((224, 224)))
transform.append(T.ToTensor())
transform.append(T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
transform = T.Compose(transform)

In [162]:
next(iter(duke.train))

('/home/jun/ReID_Dataset/./DukeMTMC-reID/bounding_box_train/0161_c2_f0082950.jpg',
 117,
 1)

In [215]:
from torch.utils.data import  Dataset

class Preprocessor(Dataset):
    def __init__(self, dataset1, dataset2, root=None, transform=None, mutual=False):
        super().__init__()
        self.num_cams = dataset1.num_train_cams + dataset2.num_train_cams
        self.dataset = []#dataset
        for inds, item in enumerate(dataset1.train):
            self.dataset.append((item[0],item[1],item[2],inds))
        len_dataset1 = len(self.dataset)
        for inds, item in enumerate(dataset2.train):
            self.dataset.append((item[0],item[1],item[2]+dataset1.num_train_cams,inds+len_dataset1))
        self.root = root
        self.transform = transform
        self.mutual = mutual

    
    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, indices):
        if self.mutual:
            return self._get_mutual_item(indices)
        else:
            return self._get_single_item(indices)

    def _get_single_item(self, index):
        items = self.dataset[index] # fname, pid,pid1,pid2, camid, inds
        fname, camid, inds =items[0],items[-2],items[-1]
        pids = []
        for i, pid in enumerate(items[1:-2]):
            pids.append(pid)

        fpath = fname
        if self.root is not None:
            fpath = osp.join(self.root, fname)

        img = Image.open(fpath).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)
        camid_one_hot_vector = torch.zeros(self.num_cams)
        camid_one_hot_vector[camid] = 1 
        return [img, fname]+ pids+[camid_one_hot_vector, inds]

class Preprocessor_starGAN(Dataset):
    def __init__(self, dataset1, dataset2, root=None, transform=None, mutual=False):
        super().__init__()
        self.num_cams = dataset1.num_train_cams + dataset2.num_train_cams
        self.dataset = []#dataset
        for inds, item in enumerate(dataset1.train):
            self.dataset.append((item[0],item[1],item[2],inds))
        len_dataset1 = len(self.dataset)
        for inds, item in enumerate(dataset2.train):
            self.dataset.append((item[0],item[1],item[2]+dataset1.num_train_cams,inds+len_dataset1))
        self.root = root
        self.transform = transform
        self.mutual = mutual

    
    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, indices):
        if self.mutual:
            return self._get_mutual_item(indices)
        else:
            return self._get_single_item(indices)

    def _get_single_item(self, index):
        items = self.dataset[index] # fname, pid,pid1,pid2, camid, inds
        fname, camid, inds =items[0],items[-2],items[-1]
        pids = []
        for i, pid in enumerate(items[1:-2]):
            pids.append(pid)

        fpath = fname
        if self.root is not None:
            fpath = osp.join(self.root, fname)

        img = Image.open(fpath).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)
        camid_one_hot_vector = torch.zeros(self.num_cams)
        camid_one_hot_vector[camid] = 1 
        return img, camid_one_hot_vector



In [226]:
dataloader = Preprocessor_starGAN(market, duke, root=market.images_dir,
                            transform=transform)

In [211]:
for data in dataloader: 
    if data[1].find('Duke') != -1:
        print(data[0].shape, data[1], data[2], data[3], data[4])

torch.Size([3, 224, 224]) /home/jun/ReID_Dataset/./DukeMTMC-reID/bounding_box_train/0161_c2_f0082950.jpg 117 tensor([0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.]) 200600
torch.Size([3, 224, 224]) /home/jun/ReID_Dataset/./DukeMTMC-reID/bounding_box_train/1631_c3_f0098658.jpg 647 tensor([0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]) 200601
torch.Size([3, 224, 224]) /home/jun/ReID_Dataset/./DukeMTMC-reID/bounding_box_train/0582_c1_f0154002.jpg 421 tensor([0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.]) 200602
torch.Size([3, 224, 224]) /home/jun/ReID_Dataset/./DukeMTMC-reID/bounding_box_train/1989_c8_f0161670.jpg 688 tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]) 200603
torch.Size([3, 224, 224]) /home/jun/ReID_Dataset/./DukeMTMC-reID/bounding_box_train/4472_c6_f0103169.jpg 261 tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.]) 200604
torch.Size([3, 224, 224]) /home/jun/ReID_Dataset/./DukeMTMC-reID/bounding_box_train/0281_c5_f01

In [227]:
data = next(iter(dataloader))

In [229]:
data[0]

tensor([[[-0.5529, -0.5529, -0.5529,  ..., -0.6863, -0.6784, -0.6706],
         [-0.5843, -0.5843, -0.5922,  ..., -0.7020, -0.6941, -0.6863],
         [-0.6078, -0.6078, -0.6157,  ..., -0.7020, -0.6941, -0.6863],
         ...,
         [-0.7098, -0.7098, -0.7020,  ..., -0.3098, -0.3176, -0.3176],
         [-0.7176, -0.7176, -0.7098,  ..., -0.3333, -0.3412, -0.3412],
         [-0.7333, -0.7333, -0.7255,  ..., -0.3804, -0.3804, -0.3882]],

        [[-0.5451, -0.5451, -0.5451,  ..., -0.6471, -0.6471, -0.6392],
         [-0.5765, -0.5765, -0.5843,  ..., -0.6627, -0.6549, -0.6392],
         [-0.6000, -0.6000, -0.6078,  ..., -0.6627, -0.6549, -0.6392],
         ...,
         [-0.7333, -0.7333, -0.7255,  ..., -0.3098, -0.3176, -0.3176],
         [-0.7412, -0.7333, -0.7333,  ..., -0.3333, -0.3412, -0.3412],
         [-0.7569, -0.7490, -0.7412,  ..., -0.3804, -0.3804, -0.3882]],

        [[-0.5059, -0.5059, -0.5059,  ..., -0.6078, -0.5922, -0.5843],
         [-0.5451, -0.5451, -0.5529,  ..., -0

In [230]:
data[1]

tensor([0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.])