In [1]:
from __future__ import print_function, absolute_import
import argparse
import os.path as osp
import random
import numpy as np
import sys

import torch
from torch import nn
from torch.backends import cudnn
from torch.utils.data import DataLoader

from UDAsbs import datasets
from UDAsbs import models
from UDAsbs.trainers import PreTrainer, PreTrainer_multi
from UDAsbs.evaluators import Evaluator
from UDAsbs.utils.data import IterLoader
from UDAsbs.utils.data import transforms as T
from UDAsbs.utils.data.sampler import RandomMultipleGallerySampler
from UDAsbs.utils.data.preprocessor import Preprocessor
from UDAsbs.utils.logging import Logger
from UDAsbs.utils.serialization import load_checkpoint, save_checkpoint, copy_state_dict
from UDAsbs.utils.lr_scheduler import WarmupMultiStepLR


start_epoch = best_mAP = 0

2.3.0


In [2]:
def get_data(name, data_dir, height, width, batch_size, workers, num_instances, iters=200):
    root = osp.join(data_dir)

    dataset = datasets.create(name, root)

    normalizer = T.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])

    train_set = dataset.train
    num_classes = dataset.num_train_pids

    train_transformer = T.Compose([
             T.Resize((height, width), interpolation=3),
             T.RandomHorizontalFlip(p=0.5),
             T.Pad(10),
             T.RandomCrop((height, width)),
             # T.AugMix(),
             T.ToTensor(),
             normalizer
         ])


    test_transformer = T.Compose([
             T.Resize((height, width), interpolation=3),
             T.ToTensor(),
             normalizer
         ])

    rmgs_flag = num_instances > 0
    if rmgs_flag:
        sampler = RandomMultipleGallerySampler(train_set, num_instances)
    else:
        sampler = None

    train_loader = IterLoader(
                DataLoader(Preprocessor(train_set, root=dataset.images_dir,
                                        transform=train_transformer),
                            batch_size=batch_size, num_workers=workers, sampler=sampler,
                            shuffle=not rmgs_flag, pin_memory=True, drop_last=True), length=iters)

    test_loader = DataLoader(
        Preprocessor(list(set(dataset.query) | set(dataset.gallery)),
                     root=dataset.images_dir, transform=test_transformer),
        batch_size=batch_size, num_workers=workers,
        shuffle=False, pin_memory=True)

    return dataset, num_classes, train_loader, test_loader

In [3]:
parser = argparse.ArgumentParser(description="Pre-training on the source domain")
# data
parser.add_argument('-ds', '--dataset-source', type=str, default='market1501',
                    choices=datasets.names())
parser.add_argument('-dt', '--dataset-target', type=str, default='dukemtmc',
                    choices=datasets.names())
parser.add_argument('-b', '--batch-size', type=int, default=64)
parser.add_argument('-j', '--workers', type=int, default=4)
parser.add_argument('--height', type=int, default=256, help="input height")
parser.add_argument('--width', type=int, default=128, help="input width")
parser.add_argument('--num-instances', type=int, default=4,
                    help="each minibatch consist of "
                         "(batch_size // num_instances) identities, and "
                         "each identity has num_instances instances, "
                         "default: 0 (NOT USE)")
# model
parser.add_argument('-a', '--arch', type=str, default='resnet50',
                    choices=models.names())
parser.add_argument('--features', type=int, default=0)
parser.add_argument('--dropout', type=float, default=0)
# optimizer
parser.add_argument('--lr', type=float, default=0.00035,
                    help="learning rate of new parameters, for pretrained ")
parser.add_argument('--momentum', type=float, default=0.9)
parser.add_argument('--weight-decay', type=float, default=5e-4)
parser.add_argument('--warmup-step', type=int, default=10)
parser.add_argument('--milestones', nargs='+', type=int, default=[40, 70], help='milestones for the learning rate decay')
# training configs
parser.add_argument('--resume', type=str, default="", metavar='PATH')
#logs/market1501TOdukemtmc/resnet50-pretrain-1_gempooling/model_best.pth.tar
parser.add_argument('--evaluate', action='store_true',
                    help="evaluation only")
parser.add_argument('--eval-step', type=int, default=40)
parser.add_argument('--rerank', action='store_true',
                    help="evaluation only")
parser.add_argument('--epochs', type=int, default=80)
parser.add_argument('--iters', type=int, default=200)
parser.add_argument('--seed', type=int, default=1)
parser.add_argument('--print-freq', type=int, default=100)
parser.add_argument('--margin', type=float, default=0.0, help='margin for the triplet loss with batch hard')
# path
working_dir = osp.dirname(osp.abspath(''))
parser.add_argument('--data-dir', type=str, metavar='PATH',
                    default=osp.join(working_dir, '/home/jun/ReID_Dataset/'))
parser.add_argument('--logs-dir', type=str, metavar='PATH',
                    default=osp.join(working_dir, 'logs/demo'))

args = parser.parse_args('')

In [4]:
global start_epoch, best_mAP

cudnn.benchmark = True

if not args.evaluate:
    sys.stdout = Logger(osp.join(args.logs_dir, 'log.txt'))
else:
    log_dir = osp.dirname(args.resume)
    sys.stdout = Logger(osp.join(log_dir, 'log_test.txt'))
print("==========\nArgs:{}\n==========".format(args))

Args:Namespace(dataset_source='market1501', dataset_target='dukemtmc', batch_size=64, workers=4, height=256, width=128, num_instances=4, arch='resnet50', features=0, dropout=0, lr=0.00035, momentum=0.9, weight_decay=0.0005, warmup_step=10, milestones=[40, 70], resume='', evaluate=False, eval_step=40, rerank=False, epochs=80, iters=200, seed=1, print_freq=100, margin=0.0, data_dir='/home/jun/ReID_Dataset/', logs_dir='/home/jun/logs/demo')


In [12]:
# Create data loaders
iters = args.iters if (args.iters>0) else None
iters = None
dataset_source, num_classes, train_loader_source, test_loader_source = \
    get_data(args.dataset_source, args.data_dir, args.height,
                args.width, args.batch_size, args.workers, args.num_instances, iters)

dataset_target, _, train_loader_target, test_loader_target = \
    get_data(args.dataset_target, args.data_dir, args.height,
                args.width, args.batch_size, args.workers, 0, iters)

In [13]:
len(train_loader_source)

46

In [6]:
# Create model
print(f'Creating {args.arch} model with num_features = {args.features}, dropout = {args.dropout}, num_classes = {[num_classes]}')
model = models.create(args.arch, num_features=args.features, dropout=args.dropout,
                      num_classes=[num_classes])
model.cuda()
model = nn.DataParallel(model)

In [38]:
from torch.autograd import Variable
input = Variable(torch.FloatTensor(32, 3, 256, 128)).cuda()
model.train()
output = model(input, training=True)


In [40]:
len(output)

4

In [39]:
output[0].shape, output[1][0].shape

(torch.Size([32, 2048]), torch.Size([32, 751]))

In [7]:
# Load from checkpoint
if args.resume:
    checkpoint = load_checkpoint(args.resume)
    copy_state_dict(checkpoint['state_dict'], model)
    start_epoch = checkpoint['epoch']
    best_mAP = checkpoint['best_mAP']
    print("=> Start epoch {}  best mAP {:.1%}"
          .format(start_epoch, best_mAP))

In [8]:
# Evaluator
evaluator = Evaluator(model)
# args.evaluate=True
if args.evaluate:
    print("Test on source domain:")
    evaluator.evaluate(test_loader_source, dataset_source.query, dataset_source.gallery, cmc_flag=True, rerank=args.rerank)
    print("Test on target domain:")
    evaluator.evaluate(test_loader_target, dataset_target.query, dataset_target.gallery, cmc_flag=True, rerank=args.rerank)
    

In [9]:
params = []
for key, value in model.named_parameters():
    if not value.requires_grad:
        continue
    params += [{"params": [value], "lr": args.lr, "weight_decay": args.weight_decay}]
optimizer = torch.optim.Adam(params)
lr_scheduler = WarmupMultiStepLR(optimizer, args.milestones, gamma=0.1, warmup_factor=0.01,
                                 warmup_iters=args.warmup_step)

# Trainer
trainer = PreTrainer(model, num_classes, margin=args.margin) if 'multi' not in args.arch else PreTrainer_multi(model, num_classes, margin=args.margin)

In [25]:
for epoch in range(start_epoch, args.epochs):

    train_loader_source.new_epoch()
    train_loader_target.new_epoch()

    trainer.train(epoch, train_loader_source, train_loader_target, optimizer,
                train_iters=len(train_loader_source), print_freq=args.print_freq)
    lr_scheduler.step()
    if ((epoch+1)%args.eval_step==0 or (epoch==args.epochs-1)):

        _, mAP = evaluator.evaluate(test_loader_source, dataset_source.query,
                                    dataset_source.gallery, cmc_flag=True)

        is_best = mAP > best_mAP
        best_mAP = max(mAP, best_mAP)
        save_checkpoint({
            'state_dict': model.state_dict(),
            'epoch': epoch + 1,
            'best_mAP': best_mAP,
        }, is_best, fpath=osp.join(args.logs_dir, 'checkpoint.pth.tar'))

        print('\n * Finished epoch {:3d}  source mAP: {:5.1%}  best: {:5.1%}{}\n'.
                format(epoch, mAP, best_mAP, ' *' if is_best else ''))
print("Test on source domain:")
evaluator.evaluate(test_loader_source, dataset_source.query, dataset_source.gallery, cmc_flag=True, rerank=args.rerank)
print("Test on target domain:")
evaluator.evaluate(test_loader_target, dataset_target.query, dataset_target.gallery, cmc_flag=True, rerank=args.rerank)

KeyboardInterrupt: 