diff --git a/README.md b/README.md
index 5015ac0..681960a 100644
--- a/README.md
+++ b/README.md
@@ -1 +1,89 @@
-# elastic
\ No newline at end of file
+# Elastic
+This repo contains a PyTorch implementation of Elastic. It is compatible with PyTorch 1.0-stable, PyTorch 1.0-preview, and PyTorch 0.4.1. All released models are exactly the models evaluated in the paper.
+## ImageNet
+We prepare our data following https://github.com/pytorch/examples/tree/master/imagenet
+
+Pretrained models can be downloaded via:
+```
+for a in resnext50 resnext50_elastic resnext101 resnext101_elastic dla60x dla60x_elastic dla102x se_resnext50_elastic densenet201 densenet201_elastic; do
+  wget http://ai2-vision.s3.amazonaws.com/elastic/imagenet_models/"$a".pth.tar
+done
+```
+### Testing
+```
+python classify.py /path/to/imagenet/ --evaluate --resume /path/to/model.pth.tar
+```
+### Training
+```
+python classify.py /path/to/imagenet/
+```
+### Multi-processing distributed training in Docker (recommended)
+We train all models in Docker containers: https://docs.nvidia.com/deeplearning/dgx/pytorch-release-notes/rel_18.07.html
+
+You may need to follow the instructions in the link above to install [docker](https://www.docker.com/) and [nvidia-docker](https://github.com/NVIDIA/nvidia-docker) if you haven't done so.
+
+After pulling the docker image, we run a docker container:
+```
+nvidia-docker run -it -e NVIDIA_VISIBLE_DEVICES=0,1 --ipc=host --rm -v /path/to/code:/path/to/code -v /path/to/imagenet:/path/to/imagenet nvcr.io/nvidia/pytorch:18.07-py3
+```
+Then run the training script inside the docker container:
+```
+python -m apex.parallel.multiproc docker_classify.py /path/to/imagenet
+```
+## MSCOCO
+We extract the data into the structure below and load it with the Python COCO API: https://github.com/cocodataset/cocoapi
+```
+/path/to/mscoco/annotations/instances_train2014.json
+/path/to/mscoco/annotations/instances_val2014.json
+/path/to/mscoco/train2014
+/path/to/mscoco/val2014
+```
+Pretrained models can be downloaded via:
+```
+for a in resnext50 resnext50_elastic resnext101 resnext101_elastic dla60x dla60x_elastic densenet201 densenet201_elastic; do
+  wget http://ai2-vision.s3.amazonaws.com/elastic/coco_models/coco_"$a".pth.tar
+done
+```
+### Testing
+```
+python multilabel_classify.py /path/to/mscoco --resume /path/to/model.pth.tar --evaluate
+```
+### Finetuning or resuming training
+```
+python multilabel_classify.py /path/to/mscoco --resume /path/to/model.pth.tar
+```
+## PASCAL VOC semantic segmentation
+We prepare the PASCAL VOC data following https://github.com/chenxi116/DeepLabv3.pytorch
+
+Pretrained models can be downloaded via:
+```
+for a in resnext50 resnext50_elastic resnext101 resnext101_elastic dla60x dla60x_elastic; do
+  wget http://ai2-vision.s3.amazonaws.com/elastic/pascal_models/deeplab_"$a"_pascal_v3_original_epoch50.pth
+done
+```
+### Testing
+Models should be placed at `data/deeplab_*.pth`
+```
+CUDA_VISIBLE_DEVICES=0 python segment.py --exp original
+```
+### Finetuning or resuming training
+All PASCAL VOC semantic segmentation models are trained on a single GPU.
+```
+CUDA_VISIBLE_DEVICES=0 python segment.py --exp my_exp --train --resume /path/to/model.pth.tar
+```
+## Note
+Distributed training maintains batchnorm statistics on each GPU/worker/process without synchronization, which leads to slightly different performance on different GPUs. At the end of each epoch, our distributed script reports the average performance (top-1, top-5) by evaluating the whole validation set on all GPUs, and saves the model from the first GPU (models on the other GPUs are discarded). As a result, evaluating the saved model after training yields slightly different numbers (within 0.1%, either better or worse). In the paper, we report the average performance for all models. Averaging batchnorm statistics across GPUs before evaluation may lead to marginally better numbers.
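+
+For reference, here is a minimal sketch of what averaging batchnorm statistics could look like. This helper is not part of the repo, and the per-GPU checkpoint paths are hypothetical (the released script only keeps the first GPU's model):
+```
+import torch
+
+# hypothetical checkpoints saved by each worker
+paths = ['rank0.pth.tar', 'rank1.pth.tar']
+states = [torch.load(p, map_location='cpu')['state_dict'] for p in paths]
+
+# average only the batchnorm running statistics; keep all other
+# weights from rank 0 (DDP keeps learned weights identical anyway)
+avg = {k: sum(s[k].float() for s in states) / len(states)
+       if 'running_mean' in k or 'running_var' in k else states[0][k]
+       for k in states[0]}
+torch.save({'state_dict': avg}, 'bn_averaged.pth.tar')
+```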
+
+## Credits
+ImageNet training script is modified from https://github.com/pytorch/pytorch
+
+ImageNet distributed training script is modified from https://github.com/NVIDIA/apex
+
+PASCAL segmentation code is modified from https://github.com/chenxi116/DeepLabv3.pytorch
+
+ResNext model is modified from https://github.com/last-one/tools
+
+DLA models are modified from https://github.com/ucbdrive/dla
+
+DenseNet model is modified from https://github.com/csrhddlam/pytorch-checkpoint
+
diff --git a/classify.py b/classify.py
new file mode 100644
index 0000000..e80b93a
--- /dev/null
+++ b/classify.py
@@ -0,0 +1,321 @@
+import argparse
+import time
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torch.nn.parallel
+import torch.backends.cudnn as cudnn
+import torch.distributed as dist
+import torch.optim
+import torch.utils.data as data
+import torch.utils.data.distributed
+import torchvision.transforms as transforms
+import torchvision.datasets as datasets
+import models
+import os
+import datetime
+from utils import add_flops_counting_methods, accuracy, save_checkpoint, AverageMeter
+
+
+model_names = ['resnext50', 'resnext50_elastic', 'resnext101', 'resnext101_elastic',
+               'dla60x', 'dla60x_elastic', 'dla102x', 'dla102x_elastic',
+               'se_resnext50', 'se_resnext50_elastic', 'densenet201', 'densenet201_elastic']
+
+parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
+parser.add_argument('data', metavar='DIR', help='path to dataset')
+parser.add_argument('--arch', '-a', metavar='ARCH', default='resnext50_elastic', choices=model_names,
+                    help='model architecture: ' + ' | '.join(model_names) + ' (default: resnext50_elastic)')
+parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
+                    help='number of data loading workers (default: 16)')
+parser.add_argument('--epochs', default=120, type=int, metavar='N',
+                    help='number of total epochs to run')
+parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
+                    help='manual epoch number (useful on restarts)')
+parser.add_argument('-b', '--batch-size', default=256, type=int,
+                    metavar='N', help='mini-batch size (default: 256)')
+parser.add_argument('-g', '--num-gpus', default=8, type=int,
+                    metavar='N', help='number of GPUs to match (default: 8)')
+parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
+                    metavar='LR', help='initial learning rate')
+parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
+                    help='momentum')
+parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
+                    metavar='W', help='weight decay (default: 1e-4)')
+parser.add_argument('--print-freq', '-p', default=117, type=int,
+                    metavar='N', help='print frequency (default: 117)')
+parser.add_argument('--resume', default='', type=str, metavar='PATH',
+                    help='path to latest checkpoint (default: none)')
+parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
+                    help='evaluate model on validation set')
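+
+# The flags below configure single-node distributed training through
+# torch.distributed. When fewer physical GPUs than --num-gpus are available,
+# main() compensates by gradient accumulation: iteration_size is set to
+# num_gpus // torch.cuda.device_count(), the learning rate and per-process
+# batch size are divided by it, weight decay is multiplied by it, and
+# optimizer.step() only runs every iteration_size mini-batches.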
+parser.add_argument('--world-size', default=1, type=int,
+                    help='number of distributed processes')
+parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
+                    help='url used to set up distributed training')
+parser.add_argument('--dist-backend', default='gloo', type=str,
+                    help='distributed backend')
+
+best_err1 = 100
+
+
+def main():
+    global args, best_err1
+    args = parser.parse_args()
+    print('config: wd', args.weight_decay, 'lr', args.lr, 'batch_size', args.batch_size, 'num_gpus', args.num_gpus)
+    iteration_size = args.num_gpus // torch.cuda.device_count()  # do multiple iterations
+    assert iteration_size >= 1
+    args.weight_decay = args.weight_decay * iteration_size  # will cancel out with lr
+    args.lr = args.lr / iteration_size
+    args.batch_size = args.batch_size // iteration_size
+    print('real: wd', args.weight_decay, 'lr', args.lr, 'batch_size', args.batch_size, 'iteration_size', iteration_size)
+
+    args.distributed = args.world_size > 1
+
+    if args.distributed:
+        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
+                                world_size=args.world_size)
+
+    # create model
+    print("=> creating model '{}'".format(args.arch))
+    model = models.__dict__[args.arch]()
+
+    # count number of parameters
+    count = 0
+    params = list()
+    for n, p in model.named_parameters():
+        if '.ups.' not in n:
+            params.append(p)
+            count += np.prod(p.size())
+    print('Parameters:', count)
+
+    # count flops
+    model = add_flops_counting_methods(model)
+    model.eval()
+    image = torch.randn(1, 3, 224, 224)
+
+    model.start_flops_count()
+    model(image).sum()
+    model.stop_flops_count()
+    print("GFLOPs", model.compute_average_flops_cost() / 1000000000.0)
+
+    # normal code
+    if not args.distributed:
+        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
+            model.features = torch.nn.DataParallel(model.features)
+            model.cuda()
+        else:
+            model = torch.nn.DataParallel(model).cuda()
+    else:
+        model.cuda()
+        model = torch.nn.parallel.DistributedDataParallel(model)
+
+    # cuda warm up
+    model = model.cuda()
+    image = torch.randn(args.batch_size, 3, 224, 224)
+    image_cuda = image.cuda()
+
+    for i in range(3):
+        start = time.time()
+        model(image_cuda).sum().backward()  # Warmup CUDA memory allocator
+        print(time.time() - start)
+
+    # with torch.autograd.profiler.profile(use_cuda=True) as prof:
+    #     start = time.time()
+    #     model(image_cuda).sum().backward()
+    #     print(time.time() - start)
+    # prof.export_chrome_trace('trace_gpu')
+
+    # import cProfile, pstats, io
+    # pr = cProfile.Profile(time.perf_counter)
+    # pr.enable()
+    # model(image_cuda).sum().backward()
+    # pr.disable()
+    # s = io.StringIO()
+    # sortby = 'cumulative'
+    # ps = pstats.Stats(pr, stream=s).sort_stats(sortby)
+    # ps.print_stats()
+    # print(s.getvalue())
+
+    # define loss function (criterion) and optimizer
+    criterion = nn.CrossEntropyLoss().cuda()
+    optimizer = torch.optim.SGD([{'params': iter(params), 'lr': args.lr},
+                                 ], lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
+
+    # optionally resume from a checkpoint
+    if args.resume:
+        if os.path.isfile(args.resume):
+            print("=> loading checkpoint '{}'".format(args.resume))
+            checkpoint = torch.load(args.resume)
+
+            model.load_state_dict(checkpoint['state_dict'], strict=False) if 'state_dict' in checkpoint else print('no state_dict found')
+            optimizer.load_state_dict(checkpoint['optimizer']) if 'optimizer' in checkpoint else print('no optimizer found')
+            args.start_epoch = checkpoint['epoch'] if 'epoch' in checkpoint else args.start_epoch
+            best_err1 = checkpoint['best_err1'] if 'best_err1' in checkpoint
else best_err1 + + print("=> loaded checkpoint '{}' (epoch {})" + .format(args.resume, checkpoint['epoch'] if 'epoch' in checkpoint else 'unknown')) + else: + print("=> no checkpoint found at '{}'".format(args.resume)) + + cudnn.benchmark = True + + # Data loading code + traindir = os.path.join(args.data, 'train') + valdir = os.path.join(args.data, 'val') + normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + + train_dataset = datasets.ImageFolder( + traindir, + transforms.Compose([ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize, + ])) + + if args.distributed: + train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) + else: + train_sampler = torch.utils.data.sampler.RandomSampler(train_dataset) + + train_loader = torch.utils.data.DataLoader( + train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), + num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True) + + val_loader = torch.utils.data.DataLoader( + datasets.ImageFolder(valdir, transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + normalize, + ])), + batch_size=args.batch_size, shuffle=False, + num_workers=args.workers, pin_memory=True) + + if args.evaluate: + validate(val_loader, model, criterion) + return + + for epoch in range(args.start_epoch, args.epochs): + if args.distributed: + train_sampler.set_epoch(epoch) + adjust_learning_rate(optimizer, epoch) + + # train for one epoch + train(train_loader, model, criterion, optimizer, epoch, iteration_size) + + # evaluate on validation set + err1 = validate(val_loader, model, criterion) + + # remember best err@1 and save checkpoint + is_best = err1 < best_err1 + best_err1 = min(err1, best_err1) + save_checkpoint({ + 'epoch': epoch + 1, + 'arch': args.arch, + 'state_dict': model.state_dict(), + 'best_err1': best_err1, + 'optimizer': optimizer.state_dict(), + }, is_best, filename=args.arch + '_checkpoint.pth.tar') + print(str(float(best_err1))) + + +def train(train_loader, model, criterion, optimizer, epoch, iteration_size): + batch_time = AverageMeter() + data_time = AverageMeter() + losses = AverageMeter() + top1 = AverageMeter() + top5 = AverageMeter() + + # switch to train mode + model.train() + optimizer.zero_grad() + + end = time.time() + for i, (input, target) in enumerate(train_loader): + # measure data loading time + data_time.update(time.time() - end) + target = target.cuda(non_blocking=True) + + # compute output + output = model(input) + loss = criterion(output, target) + + # measure accuracy and record loss + prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) + losses.update(float(loss), input.size(0)) + top1.update(100 - float(prec1), input.size(0)) + top5.update(100 - float(prec5), input.size(0)) + # compute gradient and do SGD step + loss.backward() + + if i % iteration_size == iteration_size - 1: + optimizer.step() + optimizer.zero_grad() + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if i % args.print_freq == 0: + print('Epoch: [{0}][{1}/{2}]\t' + 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' + 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' + 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' + 'Err@1 {top1.val:.3f} ({top1.avg:.3f})\t' + 'Err@5 {top5.val:.3f} ({top5.avg:.3f})'.format( + epoch, i, len(train_loader), batch_time=batch_time, + data_time=data_time, loss=losses, top1=top1, top5=top5)) + + +def 
validate(val_loader, model, criterion): + batch_time = AverageMeter() + losses = AverageMeter() + top1 = AverageMeter() + top5 = AverageMeter() + + # switch to evaluate mode + model.eval() + + end = time.time() + for i, (input, target) in enumerate(val_loader): + target = target.cuda(non_blocking=True) + + # compute output + with torch.no_grad(): + output = model(input) + loss = criterion(output, target) + + # measure accuracy and record loss + prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) + losses.update(float(loss), input.size(0)) + top1.update(100 - float(prec1), input.size(0)) + top5.update(100 - float(prec5), input.size(0)) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if i % args.print_freq == 0: + print('Test: [{0}/{1}]\t' + 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' + 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' + 'Err@1 {top1.val:.3f} ({top1.avg:.3f})\t' + 'Err@5 {top5.val:.3f} ({top5.avg:.3f})'.format( + i, len(val_loader), batch_time=batch_time, loss=losses, + top1=top1, top5=top5)) + + print(str(datetime.datetime.now()) + ' * Err@1 {top1.avg:.3f} Err@5 {top5.avg:.3f}' + .format(top1=top1, top5=top5)) + return top1.avg + + +def adjust_learning_rate(optimizer, epoch): + """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" + lr = args.lr * (0.1 ** (epoch // 30)) + for param_group in optimizer.param_groups: + param_group['lr'] = lr + + +if __name__ == '__main__': + main() diff --git a/data/pascal_seg_colormap.mat b/data/pascal_seg_colormap.mat new file mode 100755 index 0000000..24fa7b6 Binary files /dev/null and b/data/pascal_seg_colormap.mat differ diff --git a/docker_classify.py b/docker_classify.py new file mode 100644 index 0000000..fee75be --- /dev/null +++ b/docker_classify.py @@ -0,0 +1,453 @@ +import argparse +import os +import shutil +import time + +import torch +from torch.autograd import Variable +import torch.nn as nn +import torch.nn.parallel +import torch.backends.cudnn as cudnn +import torch.distributed as dist +import torch.optim +import torch.utils.data +import torch.utils.data.distributed +import torchvision.transforms as transforms +import torchvision.datasets as datasets +# import torchvision.models as models +import models +import numpy as np +import gc + +from utils import add_flops_counting_methods, save_checkpoint, AverageMeter, accuracy + +try: + from apex.parallel import DistributedDataParallel as DDP + from apex.fp16_utils import * +except ImportError: + raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.") + +model_names = ['resnext50', 'resnext50_elastic', 'resnext101', 'resnext101_elastic', + 'dla60x', 'dla60x_elastic', 'dla102x', 'dla102x_elastic', + 'se_resnext50', 'se_resnext50_elastic', 'densenet201', 'densenet201_elastic'] + +parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') +parser.add_argument('data', metavar='DIR', help='path to dataset') +parser.add_argument('--arch', '-a', metavar='ARCH', default='resnext50_elastic', choices=model_names, + help='model architecture: ' + ' | '.join(model_names) + ' (default: resnext50_elastic)') +parser.add_argument('-j', '--workers', default=8, type=int, metavar='N', + help='number of data loading workers (default: 8)') +parser.add_argument('--epochs', default=120, type=int, metavar='N', + help='number of total epochs to run') +parser.add_argument('--start-epoch', default=0, type=int, metavar='N', + help='manual epoch number (useful on restarts)') 
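+
+# The flags below include apex mixed-precision options: --fp16 converts the
+# model with network_to_half(), and --static-loss-scale multiplies the loss
+# before backward() so that small fp16 gradients do not flush to zero
+# (see the scaling step in train()).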
+parser.add_argument('-b', '--batch-size', default=32, type=int, + metavar='N', help='mini-batch size (default: 32)') +parser.add_argument('-g', '--num-gpus', default=8, type=int, + metavar='N', help='number of GPUs we pretend to have (default: 8)') +parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, + metavar='LR', help='initial learning rate') +parser.add_argument('--momentum', default=0.9, type=float, metavar='M', + help='momentum') +parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, + metavar='W', help='weight decay (default: 1e-4)') +parser.add_argument('--print-freq', '-p', default=117, type=int, + metavar='N', help='print frequency (default: 117)') +parser.add_argument('--resume', default='', type=str, metavar='PATH', + help='path to latest checkpoint (default: none)') +parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', + help='evaluate model on validation set') +parser.add_argument('--fp16', action='store_true', + help='Run model fp16 mode.') +parser.add_argument('--static-loss-scale', type=float, default=1, + help='Static loss scale, positive power of 2 values can improve fp16 convergence.') +parser.add_argument('--dist-url', default='file://sync.file', type=str, + help='url used to set up distributed training') +parser.add_argument('--dist-backend', default='nccl', type=str, + help='distributed backend') +parser.add_argument('--world-size', default=1, type=int, + help='Number of GPUs to use. Can either be manually set ' + + 'or automatically set by using \'python -m multiproc\'.') +parser.add_argument('--rank', default=0, type=int, + help='Used for multi-process training. Can either be manually set ' + + 'or automatically set by using \'python -m multiproc\'.') + +cudnn.benchmark = True + + +def fast_collate(batch): + imgs = [img[0] for img in batch] + targets = torch.tensor([target[1] for target in batch], dtype=torch.int64) + w = imgs[0].size[0] + h = imgs[0].size[1] + tensor = torch.zeros( (len(imgs), 3, h, w), dtype=torch.uint8 ) + for i, img in enumerate(imgs): + nump_array = np.asarray(img, dtype=np.uint8) + tens = torch.from_numpy(nump_array) + if(nump_array.ndim < 3): + nump_array = np.expand_dims(nump_array, axis=-1) + nump_array = np.rollaxis(nump_array, 2) + + tensor[i] += torch.from_numpy(nump_array) + + return tensor, targets + + +best_err1 = 100 +args = parser.parse_args() + + +def main(): + global best_err1, args + + iteration_size = args.num_gpus // args.world_size + args.weight_decay = args.weight_decay * iteration_size # will cancel out with lr + args.lr = args.lr / iteration_size + print('real: wd', args.weight_decay, 'lr', args.lr, 'batch_size', args.batch_size, 'iteration_size', iteration_size) + args.distributed = args.world_size > 1 + args.gpu = 0 + if args.distributed: + args.gpu = args.rank % torch.cuda.device_count() + + if args.distributed: + torch.cuda.set_device(args.gpu) + dist.init_process_group(backend=args.dist_backend, + init_method=args.dist_url, + world_size=args.world_size, + rank=args.rank) + + if args.fp16: + assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled." + + # create model + print("=> creating model '{}'".format(args.arch)) + model = models.__dict__[args.arch]() + + # count number of parameters + count = 0 + params = list() + for n, p in model.named_parameters(): + if '.ups.' 
not in n:
+            params.append(p)
+            count += np.prod(p.size())
+    print('Parameters:', count)
+
+    # count flops
+    model = add_flops_counting_methods(model)
+    model.eval()
+    image = torch.randn(1, 3, 224, 224)
+
+    model.start_flops_count()
+    model(image).sum()
+    model.stop_flops_count()
+    print("GFLOPs", model.compute_average_flops_cost() / 1000000000.0)
+
+    model = model.cuda()
+    if args.fp16:
+        model = network_to_half(model)
+    if args.distributed:
+        # shared param turns off bucketing in DDP, for lower latency runs this can improve perf
+        model = DDP(model, shared_param=True)
+
+    global model_params, master_params
+    if args.fp16:
+        model_params, master_params = prep_param_lists(model)
+    else:
+        master_params = list(model.parameters())
+
+    # define loss function (criterion) and optimizer
+    criterion = nn.CrossEntropyLoss().cuda()
+    optimizer = torch.optim.SGD([{'params': iter(params), 'lr': args.lr},
+                                 ], lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
+
+    # optionally resume from a checkpoint
+    if args.resume:
+        if os.path.isfile(args.resume):
+            print("=> loading checkpoint '{}'".format(args.resume))
+            checkpoint = torch.load(args.resume)
+
+            model.load_state_dict(checkpoint['state_dict'], strict=False) if 'state_dict' in checkpoint else print('no state_dict found')
+            optimizer.load_state_dict(checkpoint['optimizer']) if 'optimizer' in checkpoint else print('no optimizer found')
+            args.start_epoch = checkpoint['epoch'] if 'epoch' in checkpoint else args.start_epoch
+            best_err1 = checkpoint['best_err1'] if 'best_err1' in checkpoint else best_err1
+
+            print("=> loaded checkpoint '{}' (epoch {})"
+                  .format(args.resume, checkpoint['epoch'] if 'epoch' in checkpoint else 'unknown'))
+        else:
+            print("=> no checkpoint found at '{}'".format(args.resume))
+
+    # Data loading code
+    traindir = os.path.join(args.data, 'train')
+    valdir = os.path.join(args.data, 'val')
+
+    crop_size = 224
+    val_size = 256
+
+    train_dataset = datasets.ImageFolder(
+        traindir,
+        transforms.Compose([
+            transforms.RandomResizedCrop(crop_size),
+            transforms.RandomHorizontalFlip(),
+            # transforms.ToTensor(), Too slow
+            # normalize,
+        ]))
+    val_dataset = datasets.ImageFolder(valdir, transforms.Compose([
+        transforms.Resize(val_size),
+        transforms.CenterCrop(crop_size),
+    ]))
+    if args.distributed:
+        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
+        val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
+    else:
+        train_sampler = None
+
+    train_loader = torch.utils.data.DataLoader(
+        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
+        num_workers=args.workers, pin_memory=True, sampler=train_sampler, collate_fn=fast_collate, drop_last=True)
+
+    val_loader = torch.utils.data.DataLoader(
+        datasets.ImageFolder(valdir, transforms.Compose([
+            transforms.Resize(val_size),
+            transforms.CenterCrop(crop_size),
+        ])),
+        batch_size=args.batch_size, shuffle=False,
+        num_workers=args.workers, pin_memory=True,
+        collate_fn=fast_collate)
+    # print(len(train_loader), len(val_loader))
+    if args.evaluate:
+        validate(val_loader, model, criterion)
+        return
+
+    for epoch in range(args.start_epoch, args.epochs):
+        if args.distributed:
+            train_sampler.set_epoch(epoch)
+        adjust_learning_rate(optimizer, epoch)
+        print('allocated before', torch.cuda.memory_allocated())
+        print('cached before', torch.cuda.memory_cached())
+        gc.collect()
+        torch.cuda.empty_cache()
+        print('allocated after', torch.cuda.memory_allocated())
+        print('cached after',
torch.cuda.memory_cached()) + # train for one epoch + train(train_loader, model, criterion, optimizer, epoch, iteration_size) + +# # sync models on multiple GPUs +# if args.rank == 0: +# save_checkpoint({ +# 'epoch': epoch + 1, +# 'arch': args.arch, +# 'state_dict': model.state_dict(), +# 'optimizer' : optimizer.state_dict(), +# }, False, 'temp.pth.tar') +# # barrier +# loss = torch.FloatTensor([args.rank]).cuda() +# reduced_loss = reduce_tensor(loss.data) +# print(loss.data, reduced_loss) +# if os.path.isfile('temp.pth.tar'): +# print("=> loading checkpoint '{}'".format('temp.pth.tar')) +# checkpoint = torch.load('temp.pth.tar', map_location = lambda storage, loc: storage.cuda(args.gpu)) +# model.load_state_dict(checkpoint['state_dict'], strict=False) +# optimizer.load_state_dict(checkpoint['optimizer']) +# print("=> loaded checkpoint '{}' (epoch {})" +# .format('temp.pth.tar', checkpoint['epoch'])) +# assert checkpoint['epoch'] == epoch + 1 + + # evaluate on validation set + err1 = validate(val_loader, model, criterion) + # remember best err@1 and save checkpoint + if args.rank == 0: + is_best = err1 < best_err1 + best_err1 = min(err1, best_err1) + save_checkpoint({ + 'epoch': epoch + 1, + 'arch': args.arch, + 'state_dict': model.state_dict(), + 'best_err1': best_err1, + 'optimizer': optimizer.state_dict(), + }, is_best) + print(str(float(best_err1))) + + +class data_prefetcher(): + def __init__(self, loader): + self.loader = iter(loader) + self.stream = torch.cuda.Stream() + self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1,3,1,1) + self.std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1,3,1,1) + if args.fp16: + self.mean = self.mean.half() + self.std = self.std.half() + self.preload() + + def preload(self): + try: + self.next_input, self.next_target = next(self.loader) + except StopIteration: + self.next_input = None + self.next_target = None + return + with torch.cuda.stream(self.stream): + self.next_input = self.next_input.cuda(async=True) + self.next_target = self.next_target.cuda(async=True) + if args.fp16: + self.next_input = self.next_input.half() + else: + self.next_input = self.next_input.float() + self.next_input = self.next_input.sub_(self.mean).div_(self.std) + + def next(self): + torch.cuda.current_stream().wait_stream(self.stream) + input = self.next_input + target = self.next_target + self.preload() + return input, target + + +def train(train_loader, model, criterion, optimizer, epoch, iteration_size): + batch_time = AverageMeter() + data_time = AverageMeter() + losses = AverageMeter() + top1 = AverageMeter() + top5 = AverageMeter() + + # switch to train mode + model.train() + optimizer.zero_grad() + + end = time.time() + + prefetcher = data_prefetcher(train_loader) + input, target = prefetcher.next() + i = -1 + while input is not None: + i += 1 + + # measure data loading time + data_time.update(time.time() - end) + input_var = Variable(input) + target_var = Variable(target) + + # compute output + output = model(input_var) + loss = criterion(output, target_var) + + # measure accuracy and record loss + prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) + + if args.distributed: + reduced_loss = reduce_tensor(loss.data) + prec1 = reduce_tensor(prec1) + prec5 = reduce_tensor(prec5) + else: + reduced_loss = loss.data + + losses.update(to_python_float(reduced_loss), input.size(0)) + top1.update(100 - to_python_float(prec1), input.size(0)) + top5.update(100 - to_python_float(prec5), input.size(0)) + + loss = 
loss*args.static_loss_scale + # compute gradient and do SGD step + loss.backward() + if i % iteration_size == iteration_size - 1: + optimizer.step() + optimizer.zero_grad() + + torch.cuda.synchronize() + # measure elapsed time + batch_time.update(time.time() - end) + + end = time.time() + input, target = prefetcher.next() + if args.rank == 0 and i % args.print_freq == 0: + print('Epoch: [{0}][{1}/{2}]\t' + 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' + 'Speed {3:.3f} ({4:.3f})\t' + 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' + 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' + 'Err@1 {top1.val:.3f} ({top1.avg:.3f})\t' + 'Err@5 {top5.val:.3f} ({top5.avg:.3f})'.format( + epoch, i, len(train_loader), + args.world_size * args.batch_size / batch_time.val, + args.world_size * args.batch_size / batch_time.avg, + batch_time=batch_time, + data_time=data_time, loss=losses, top1=top1, top5=top5)) + + +def validate(val_loader, model, criterion): + batch_time = AverageMeter() + losses = AverageMeter() + top1 = AverageMeter() + top5 = AverageMeter() + + # switch to evaluate mode + model.eval() + + end = time.time() + + prefetcher = data_prefetcher(val_loader) + input, target = prefetcher.next() + i = -1 + while input is not None: + i += 1 + + target = target.cuda(async=True) + input_var = Variable(input) + target_var = Variable(target) + + # compute output + with torch.no_grad(): + output = model(input_var) + loss = criterion(output, target_var) + + # measure accuracy and record loss + prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) + + if args.distributed: + reduced_loss = reduce_tensor(loss.data) + prec1 = reduce_tensor(prec1) + prec5 = reduce_tensor(prec5) + else: + reduced_loss = loss.data + + losses.update(to_python_float(reduced_loss), input.size(0)) + top1.update(100 - to_python_float(prec1), input.size(0)) + top5.update(100 - to_python_float(prec5), input.size(0)) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if args.rank == 0 and i % args.print_freq == 0: + print('Test: [{0}/{1}]\t' + 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' + 'Speed {2:.3f} ({3:.3f})\t' + 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' + 'Err@1 {top1.val:.3f} ({top1.avg:.3f})\t' + 'Err@5 {top5.val:.3f} ({top5.avg:.3f})'.format( + i, len(val_loader), + args.world_size * args.batch_size / batch_time.val, + args.world_size * args.batch_size / batch_time.avg, + batch_time=batch_time, loss=losses, + top1=top1, top5=top5)) + input, target = prefetcher.next() + print(' * Err@1 {top1.avg:.3f} Err@5 {top5.avg:.3f}' + .format(top1=top1, top5=top5)) + return top1.avg + + +def adjust_learning_rate(optimizer, epoch): + """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" + lr = args.lr * (0.1 ** (epoch // 30)) + for param_group in optimizer.param_groups: + param_group['lr'] = lr + + +def reduce_tensor(tensor): + rt = tensor.clone() + dist.all_reduce(rt, op=dist.reduce_op.SUM) + rt /= args.world_size + return rt + + +if __name__ == '__main__': + main() diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..dc46138 --- /dev/null +++ b/models/__init__.py @@ -0,0 +1,4 @@ +from .densenet import * +from .dla import * +from .dla_up import * +from .resnext import * diff --git a/models/densenet.py b/models/densenet.py new file mode 100644 index 0000000..7f20e76 --- /dev/null +++ b/models/densenet.py @@ -0,0 +1,204 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from collections import 
OrderedDict +from .utils import CheckpointFunction, CpBatchNorm2d + + +class _DenseLayerElastic(nn.Module): + def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, efficient=False): + super(_DenseLayerElastic, self).__init__() + self.pool = nn.AvgPool2d(2, stride=2) + self.dummy = nn.Sequential() + self.add_module('conv1_d', nn.Conv2d(num_input_features, bn_size * + growth_rate // 2, kernel_size=1, stride=1, bias=False)), + self.add_module('norm2_d', CpBatchNorm2d(bn_size * growth_rate // 2)), + self.add_module('relu2_d', nn.ReLU(inplace=True)), + self.add_module('conv2_d', nn.Conv2d(bn_size * growth_rate // 2, growth_rate, + kernel_size=3, stride=1, padding=1, bias=False)), + self.add_module('norm1', CpBatchNorm2d(num_input_features)), + self.add_module('relu1', nn.ReLU(inplace=True)), + self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * + growth_rate // 2, kernel_size=1, stride=1, bias=False)), + self.add_module('norm2', CpBatchNorm2d(bn_size * growth_rate // 2)), + self.add_module('relu2', nn.ReLU(inplace=True)), + self.add_module('conv2', nn.Conv2d(bn_size * growth_rate // 2, growth_rate, + kernel_size=3, stride=1, padding=1, bias=False)), + self.drop_rate = drop_rate + self.efficient = efficient + + def forward(self, *prev_features): + concated_features = torch.cat(prev_features, 1) + bottleneck_output = self.relu1(self.norm1(concated_features)) + bottleneck_output_d = bottleneck_output + if prev_features[0].size(2) != 7: + bottleneck_output_d = self.pool(bottleneck_output_d) + bottleneck_output_d = self.conv1_d(bottleneck_output_d) + bottleneck_output = self.conv1(bottleneck_output) + new_features_d = self.conv2_d(self.relu2_d(self.norm2_d(bottleneck_output_d))) + new_features = self.conv2(self.relu2(self.norm2(bottleneck_output))) + if self.drop_rate > 0: + new_features = F.dropout(new_features, p=self.drop_rate, training=self.training) + if prev_features[0].size(2) != 7: + new_features_d = F.upsample(new_features_d, None, 2, 'bilinear', False) + return new_features + new_features_d + + +def _bn_function_factory(norm, relu, conv): + def bn_function(*inputs): + concated_features = torch.cat(inputs, 1) + bottleneck_output = conv(relu(norm(concated_features))) + return bottleneck_output + return bn_function + + +class _DenseLayer(nn.Module): + def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, efficient=False): + super(_DenseLayer, self).__init__() + self.add_module('norm1', CpBatchNorm2d(num_input_features)), + self.add_module('relu1', nn.ReLU(inplace=True)), + self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * + growth_rate, kernel_size=1, stride=1, bias=False)), + self.add_module('norm2', CpBatchNorm2d(bn_size * growth_rate)), + self.add_module('relu2', nn.ReLU(inplace=True)), + self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate, + kernel_size=3, stride=1, padding=1, bias=False)), + self.drop_rate = drop_rate + self.efficient = efficient + + def forward(self, *prev_features): + bn_function = _bn_function_factory(self.norm1, self.relu1, self.conv1) + if self.efficient and any(prev_feature.requires_grad for prev_feature in prev_features): + args = prev_features + tuple(self.norm1.parameters()) + tuple(self.conv1.parameters()) + bottleneck_output = CheckpointFunction.apply(bn_function, len(prev_features), *args) + else: + bottleneck_output = bn_function(*prev_features) + new_features = self.conv2(self.relu2(self.norm2(bottleneck_output))) + if self.drop_rate > 0: + new_features = 
F.dropout(new_features, p=self.drop_rate, training=self.training)
+        return new_features
+
+
+class _Transition(nn.Sequential):
+    def __init__(self, num_input_features, num_output_features):
+        super(_Transition, self).__init__()
+        self.add_module('norm', CpBatchNorm2d(num_input_features))
+        self.add_module('relu', nn.ReLU(inplace=True))
+        self.add_module('conv', nn.Conv2d(num_input_features, num_output_features,
+                                          kernel_size=1, stride=1, bias=False))
+        self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))
+
+
+class _DenseBlock(nn.Module):
+    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate,
+                 efficient=False, dense_layer=_DenseLayer):
+        super(_DenseBlock, self).__init__()
+        for i in range(num_layers):
+            layer = dense_layer(
+                num_input_features + i * growth_rate,
+                growth_rate=growth_rate,
+                bn_size=bn_size,
+                drop_rate=drop_rate,
+                efficient=efficient,
+            )
+            self.add_module('denselayer%d' % (i + 1), layer)
+
+    def forward(self, init_features):
+        features = [init_features]
+        for name, layer in self.named_children():
+            new_features = layer(*features)
+            features.append(new_features)
+        return torch.cat(features, 1)
+
+
+class DenseNet(nn.Module):
+    r"""Densenet-BC model class, based on
+    `"Densely Connected Convolutional Networks" <https://arxiv.org/abs/1608.06993>`_
+    Args:
+        growth_rate (int) - how many filters to add each layer (`k` in paper)
+        block_config (list of 3 or 4 ints) - how many layers in each pooling block
+        num_init_features (int) - the number of filters to learn in the first convolution layer
+        bn_size (int) - multiplicative factor for number of bottle neck layers
+            (i.e. bn_size * k features in the bottleneck layer)
+        drop_rate (float) - dropout rate after each dense layer
+        num_classes (int) - number of classification classes
+        small_inputs (bool) - set to True if images are 32x32. Otherwise assumes images are larger.
+        efficient (bool) - set to True to use checkpointing. Much more memory efficient, but slower.
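+
+    Example (illustrative usage, not from the original source):
+        >>> import torch
+        >>> model = DenseNet(growth_rate=32, block_config=(6, 12, 48, 32))
+        >>> out = model(torch.randn(1, 3, 224, 224))  # logits of shape (1, 1000)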
+ """ + def __init__(self, growth_rate=32, block_config=(16, 16, 16), compression=0.5, + num_init_features=64, bn_size=4, drop_rate=0, + num_classes=1000, small_inputs=False, efficient=True, elastic=False): + + super(DenseNet, self).__init__() + assert 0 < compression <= 1, 'compression of densenet should be between 0 and 1' + self.avgpool_size = 8 if small_inputs else 7 + + # First convolution + if small_inputs: + self.features = nn.Sequential(OrderedDict([ + ('conv0', nn.Conv2d(3, num_init_features, kernel_size=3, stride=1, padding=1, bias=False)), + ])) + else: + self.features = nn.Sequential(OrderedDict([ + ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)), + ])) + self.features.add_module('norm0', CpBatchNorm2d(num_init_features)) + self.features.add_module('relu0', nn.ReLU(inplace=True)) + self.features.add_module('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1, + ceil_mode=False)) + + # Each denseblock + num_features = num_init_features + for i, num_layers in enumerate(block_config): + block = _DenseBlock( + num_layers=num_layers, + num_input_features=num_features, + bn_size=bn_size, + growth_rate=growth_rate, + drop_rate=drop_rate, + efficient=efficient and i == 0, + dense_layer=_DenseLayer if not elastic else _DenseLayerElastic + ) + self.features.add_module('denseblock%d' % (i + 1), block) + num_features = num_features + num_layers * growth_rate + if i != len(block_config) - 1: + trans = _Transition(num_input_features=num_features, + num_output_features=int(num_features * compression)) + self.features.add_module('transition%d' % (i + 1), trans) + num_features = int(num_features * compression) + + # Final batch norm + self.features.add_module('norm_final', CpBatchNorm2d(num_features)) + + # Linear layer + self.classifier = nn.Linear(num_features, num_classes) + + # Initialization + for name, param in self.named_parameters(): + if 'conv' in name and 'weight' in name: + n = param.size(0) * param.size(2) * param.size(3) + param.data.normal_().mul_(math.sqrt(2. 
/ n)) + elif 'norm' in name and 'weight' in name: + param.data.fill_(1) + elif 'norm' in name and 'bias' in name: + param.data.fill_(0) + elif 'classifier' in name and 'bias' in name: + param.data.fill_(0) + + def forward(self, x): + features = self.features(x) + out = F.relu(features, inplace=True) + out = F.avg_pool2d(out, kernel_size=self.avgpool_size).view(features.size(0), -1) + out = self.classifier(out) + return out + + +def densenet201(**kwargs): + model = DenseNet(block_config=(6, 12, 48, 32), elastic=False, **kwargs) + return model + + +def densenet201_elastic(**kwargs): + model = DenseNet(block_config=(10, 20, 40, 30), elastic=True, **kwargs) + return model diff --git a/models/dla.py b/models/dla.py new file mode 100644 index 0000000..eae4350 --- /dev/null +++ b/models/dla.py @@ -0,0 +1,333 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import math +from torch.utils.checkpoint import * +import torch +from torch import nn +import torch.nn.functional as F +import numpy as np +# from .utils import fill_up_weights, CpBatchNorm2d +BatchNorm = nn.BatchNorm2d + + +def conv3x3(in_planes, out_planes, stride=1): + "3x3 convolution with padding" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class BottleneckX(nn.Module): + expansion = 2 + cardinality = 32 + + def __init__(self, inplanes, planes, stride=1, dilation=1): + super(BottleneckX, self).__init__() + cardinality = BottleneckX.cardinality + bottle_planes = planes * cardinality // 32 + self.conv1 = nn.Conv2d(inplanes, bottle_planes, + kernel_size=1, bias=False) + self.bn1 = BatchNorm(bottle_planes) + self.conv2 = nn.Conv2d(bottle_planes, bottle_planes, kernel_size=3, + stride=stride, padding=dilation, bias=False, + dilation=dilation, groups=cardinality) + self.bn2 = BatchNorm(bottle_planes) + self.conv3 = nn.Conv2d(bottle_planes, planes, + kernel_size=1, bias=False) + self.bn3 = BatchNorm(planes) + self.relu = nn.ReLU(inplace=True) + self.stride = stride + + def forward(self, x, residual=None): + if residual is None: + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + out += residual + out = self.relu(out) + + return out + + +class BottleneckXElastic(nn.Module): + expansion = 2 + cardinality = 32 + + def __init__(self, inplanes, planes, stride=1, dilation=1): + super(BottleneckXElastic, self).__init__() + cardinality = BottleneckX.cardinality + self.elastic = (stride == 1 and planes < 1024) + if self.elastic: + # self.ups = nn.ConvTranspose2d( + # inplanes, inplanes, 4, stride=2, padding=1, + # output_padding=0, groups=inplanes, bias=False) + # fill_up_weights(self.ups) + self.down = nn.AvgPool2d(2, stride=2) + self.ups = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) + + bottle_planes = planes * cardinality // 32 + + self.conv1_d = nn.Conv2d(inplanes, bottle_planes // 2, + kernel_size=1, bias=False) + self.bn1_d = BatchNorm(bottle_planes // 2) + self.conv2_d = nn.Conv2d(bottle_planes // 2, bottle_planes // 2, kernel_size=3, + stride=stride, padding=dilation, bias=False, + dilation=dilation, groups=cardinality // 2) + self.bn2_d = BatchNorm(bottle_planes // 2) + self.conv3_d = nn.Conv2d(bottle_planes // 2, planes, + kernel_size=1, bias=False) + + self.conv1 = nn.Conv2d(inplanes, bottle_planes // 2, + kernel_size=1, bias=False) + self.bn1 = BatchNorm(bottle_planes // 2) + self.conv2 = 
nn.Conv2d(bottle_planes // 2, bottle_planes // 2, kernel_size=3, + stride=stride, padding=dilation, bias=False, + dilation=dilation, groups=cardinality // 2) + self.bn2 = BatchNorm(bottle_planes // 2) + self.conv3 = nn.Conv2d(bottle_planes // 2, planes, + kernel_size=1, bias=False) + + self.bn3 = BatchNorm(planes) + self.relu = nn.ReLU(inplace=True) + self.stride = stride + self.__flops__ = 0 + + def forward(self, x, residual=None): + if residual is None: + residual = x + out_d = x + if self.elastic: + if x.size(2) % 2 > 0 or x.size(3) % 2 > 0: + out_d = F.pad(out_d, (0, x.size(3) % 2, 0, x.size(2) % 2), mode='replicate') + out_d = self.down(out_d) + + out_d = self.conv1_d(out_d) + out_d = self.bn1_d(out_d) + out_d = self.relu(out_d) + + out_d = self.conv2_d(out_d) + out_d = self.bn2_d(out_d) + out_d = self.relu(out_d) + + out_d = self.conv3_d(out_d) + if self.elastic: + out_d = self.ups(out_d) + self.__flops__ += np.prod(out_d[0].shape) * 8 + if out_d.size(2) > x.size(2) or out_d.size(3) > x.size(3): + out_d = out_d[:, :, :x.size(2), :x.size(3)] + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + + out = out + out_d + out = self.bn3(out) + + out += residual + out = self.relu(out) + + return out + + +class Root(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, residual): + super(Root, self).__init__() + self.conv = nn.Conv2d( + in_channels, out_channels, 1, + stride=1, bias=False, padding=(kernel_size - 1) // 2) + self.bn = BatchNorm(out_channels) + self.relu = nn.ReLU(inplace=True) + self.residual = residual + + def forward(self, *x): + children = x + x = self.conv(torch.cat(x, 1)) + x = self.bn(x) + if self.residual: + x += children[0] + x = self.relu(x) + + return x + + +class Tree(nn.Module): + def __init__(self, levels, block, in_channels, out_channels, stride=1, + level_root=False, root_dim=0, root_kernel_size=1, + dilation=1, root_residual=False, seg=False): + super(Tree, self).__init__() + if root_dim == 0: + root_dim = 2 * out_channels + if level_root: + root_dim += in_channels + if levels == 1: + self.tree1 = block(in_channels, out_channels, stride, + dilation=dilation) + self.tree2 = block(out_channels, out_channels, 1, + dilation=dilation) + else: + self.tree1 = Tree(levels - 1, block, in_channels, out_channels, + stride, root_dim=0, + root_kernel_size=root_kernel_size, + dilation=dilation, root_residual=root_residual, seg=seg) + self.tree2 = Tree(levels - 1, block, out_channels, out_channels, + root_dim=root_dim + out_channels, + root_kernel_size=root_kernel_size, + dilation=dilation, root_residual=root_residual, seg=seg) + if levels == 1: + self.root = Root(root_dim, out_channels, root_kernel_size, + root_residual) + self.level_root = level_root + self.root_dim = root_dim + self.downsample = None + self.project = None + self.levels = levels + if stride > 1: + self.downsample = nn.MaxPool2d(stride, stride=stride, ceil_mode=seg) + if in_channels != out_channels: + self.project = nn.Sequential( + nn.Conv2d(in_channels, out_channels, + kernel_size=1, stride=1, bias=False), + BatchNorm(out_channels) + ) + + def forward(self, x, residual=None, children=None): + children = [] if children is None else children + bottom = self.downsample(x) if self.downsample else x + residual = self.project(bottom) if self.project else bottom + if self.level_root: + children.append(bottom) + x1 = self.tree1(x, residual) + if self.levels == 1: + x2 = self.tree2(x1) 
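+            # at the last level, Root fuses the two subtree outputs (plus any
+            # level-root children) with a 1x1 conv, batchnorm, an optional
+            # residual connection, and a ReLU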
+ x = self.root(x2, x1, *children) + else: + children.append(x1) + x = self.tree2(x1, children=children) + return x + + +class DLA(nn.Module): + def __init__(self, levels, channels, num_classes=1000, + block=BottleneckX, residual_root=False, return_levels=False, + pool_size=7, linear_root=False, seg=False): + super(DLA, self).__init__() + self.channels = channels + self.seg = seg + self.return_levels = return_levels + self.num_classes = num_classes + self.base_layer = nn.Sequential( + nn.Conv2d(3, channels[0], kernel_size=7, stride=1, + padding=3, bias=False), + BatchNorm(channels[0]), + nn.ReLU(inplace=True)) + self.level0 = self._make_conv_level( + channels[0], channels[0], levels[0]) + self.level1 = self._make_conv_level( + channels[0], channels[1], levels[1], stride=2) + self.level2 = Tree(levels[2], block, channels[1], channels[2], 2, + level_root=False, root_residual=residual_root, seg=seg) + self.level3 = Tree(levels[3], block, channels[2], channels[3], 2, + level_root=True, root_residual=residual_root, seg=seg) + self.level4 = Tree(levels[4], block, channels[3], channels[4], 2, + level_root=True, root_residual=residual_root, seg=seg) + self.level5 = Tree(levels[5], block, channels[4], channels[5], 2, + level_root=True, root_residual=residual_root, seg=seg) + + self.avgpool = nn.AvgPool2d(pool_size) + self.fc = nn.Conv2d(channels[-1], num_classes, kernel_size=1, + stride=1, padding=0, bias=True) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, BatchNorm): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_level(self, block, inplanes, planes, blocks, stride=1): + downsample = None + if stride != 1 or inplanes != planes: + downsample = nn.Sequential( + nn.MaxPool2d(stride, stride=stride, ceil_mode=self.seg), + nn.Conv2d(inplanes, planes, + kernel_size=1, stride=1, bias=False), + BatchNorm(planes), + ) + + layers = [] + layers.append(block(inplanes, planes, stride, downsample=downsample)) + for i in range(1, blocks): + layers.append(block(inplanes, planes)) + + return nn.Sequential(*layers) + + def _make_conv_level(self, inplanes, planes, convs, stride=1, dilation=1): + modules = [] + for i in range(convs): + modules.extend([ + nn.Conv2d(inplanes, planes, kernel_size=3, + stride=stride if i == 0 else 1, + padding=dilation, bias=False, dilation=dilation), + BatchNorm(planes), + nn.ReLU(inplace=True)]) + inplanes = planes + return nn.Sequential(*modules) + + def forward(self, x): + y = [] + x = self.base_layer(x) + for i in range(6): + if self.seg: + x = checkpoint(getattr(self, 'level{}'.format(i)), x) + else: + x = getattr(self, 'level{}'.format(i))(x) + y.append(x) + if self.return_levels: + return y + else: + x = self.avgpool(x) + x = self.fc(x) + x = x.view(x.size(0), -1) + return x + + +def dla60x(**kwargs): + model = DLA([1, 1, 1, 2, 3, 1], [16, 32, 128, 256, 512, 1024], + block=BottleneckX, **kwargs) + return model + + +def dla102x(**kwargs): + model = DLA([1, 1, 1, 3, 4, 1], [16, 32, 128, 256, 512, 1024], + block=BottleneckX, residual_root=True, **kwargs) + return model + + +def dla60x_elastic(**kwargs): + model = DLA([1, 1, 1, 2, 3, 1], [16, 32, 128, 256, 512, 1024], + block=BottleneckXElastic, **kwargs) + return model + + +def dla102x_elastic(**kwargs): + BottleneckX.cardinality = 50 + model = DLA([1, 1, 3, 3, 3, 1], [16, 32, 128, 256, 512, 1024], + block=BottleneckXElastic, residual_root=True, **kwargs) + return model diff --git 
a/models/dla_up.py b/models/dla_up.py new file mode 100644 index 0000000..d8b6a68 --- /dev/null +++ b/models/dla_up.py @@ -0,0 +1,162 @@ +import math + +import numpy as np +import torch +from torch import nn +from . import dla +from .utils import fill_up_weights +BatchNorm = nn.BatchNorm2d + + +class Identity(nn.Module): + def __init__(self): + super(Identity, self).__init__() + + def forward(self, x): + return x + + +class IDAUp(nn.Module): + def __init__(self, node_kernel, out_dim, channels, up_factors): + super(IDAUp, self).__init__() + self.channels = channels + self.out_dim = out_dim + for i, c in enumerate(channels): + if c == out_dim: + proj = Identity() + else: + proj = nn.Sequential( + nn.Conv2d(c, out_dim, + kernel_size=1, stride=1, bias=False), + BatchNorm(out_dim), + nn.ReLU(inplace=True)) + f = int(up_factors[i]) + if f == 1: + up = Identity() + else: + up = nn.ConvTranspose2d( + out_dim, out_dim, f * 2, stride=f, padding=f // 2, + output_padding=0, groups=out_dim, bias=False) + fill_up_weights(up) + setattr(self, 'proj_' + str(i), proj) + setattr(self, 'up_' + str(i), up) + + for i in range(1, len(channels)): + node = nn.Sequential( + nn.Conv2d(out_dim * 2, out_dim, + kernel_size=node_kernel, stride=1, + padding=node_kernel // 2, bias=False), + BatchNorm(out_dim), + nn.ReLU(inplace=True)) + setattr(self, 'node_' + str(i), node) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, BatchNorm): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def forward(self, layers): + assert len(self.channels) == len(layers), \ + '{} vs {} layers'.format(len(self.channels), len(layers)) + layers = list(layers) + for i, l in enumerate(layers): + upsample = getattr(self, 'up_' + str(i)) + project = getattr(self, 'proj_' + str(i)) + layers[i] = upsample(project(l)) + x = layers[0] + y = [] + for i in range(1, len(layers)): + node = getattr(self, 'node_' + str(i)) + x = node(torch.cat([x, layers[i][:, :, :x.size(2), :x.size(3)]], 1)) + y.append(x) + return x, y + + +class DLAUp(nn.Module): + def __init__(self, channels, scales=(1, 2, 4, 8, 16), in_channels=None): + super(DLAUp, self).__init__() + if in_channels is None: + in_channels = channels + self.channels = channels + channels = list(channels) + scales = np.array(scales, dtype=int) + for i in range(len(channels) - 1): + j = -i - 2 + setattr(self, 'ida_{}'.format(i), + IDAUp(3, channels[j], in_channels[j:], + scales[j:] // scales[j])) + scales[j + 1:] = scales[j] + in_channels[j + 1:] = [channels[j] for _ in channels[j + 1:]] + + def forward(self, layers): + layers = list(layers) + assert len(layers) > 1 + for i in range(len(layers) - 1): + ida = getattr(self, 'ida_{}'.format(i)) + x, y = ida(layers[-i - 2:]) + layers[-i - 1:] = y + return x + + +class DLASeg(nn.Module): + def __init__(self, base_name, classes, down_ratio=2): + super(DLASeg, self).__init__() + assert down_ratio in [2, 4, 8, 16] + self.first_level = int(np.log2(down_ratio)) + self.base = dla.__dict__[base_name](return_levels=True, seg=True) + channels = self.base.channels + # print(channels, self.first_level) + scales = [2 ** i for i in range(len(channels[self.first_level:]))] + self.dla_up = DLAUp(channels[self.first_level:], scales=scales) + self.fc = nn.Sequential( + nn.Conv2d(channels[self.first_level], classes, kernel_size=1, + stride=1, padding=0, bias=True) + ) + up_factor = 2 ** self.first_level + if up_factor > 1: + up = 
nn.ConvTranspose2d(classes, classes, up_factor * 2, + stride=up_factor, padding=up_factor // 2, + output_padding=0, groups=classes, + bias=False) + fill_up_weights(up) + up.weight.requires_grad = False + else: + up = Identity() + self.up = up + self.softmax = nn.LogSoftmax(dim=1) + + for m in self.fc.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, BatchNorm): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def forward(self, x): + x = self.base(x) + x = self.dla_up(x[self.first_level:]) + x = self.fc(x) + y = self.softmax(self.up(x)) + return y[:, :, :-1, :-1] + + def optim_parameters(self, memo=None): + for param in self.base.parameters(): + yield param + for param in self.dla_up.parameters(): + yield param + for param in self.fc.parameters(): + yield param + + +def dla60x_seg(classes, **kwargs): + model = DLASeg('dla60x', classes, **kwargs) + return model + + +def dla60x_elastic_seg(classes, **kwargs): + model = DLASeg('dla60x_elastic', classes, **kwargs) + return model diff --git a/models/resnext.py b/models/resnext.py new file mode 100644 index 0000000..0e22380 --- /dev/null +++ b/models/resnext.py @@ -0,0 +1,300 @@ +import torch +import torch.nn as nn +import math +import torch.utils.model_zoo as model_zoo +from torch.utils.checkpoint import * +import torch.nn.functional as F +from torch.nn import init +import numpy as np + + +class ASPP(nn.Module): + def __init__(self, C, depth, num_classes, norm=nn.BatchNorm2d, momentum=0.0003, mult=1): + super(ASPP, self).__init__() + self._C = C + self._depth = depth + self._num_classes = num_classes + self._norm = norm + + self.global_pooling = nn.AdaptiveAvgPool2d(1) + self.relu = nn.ReLU(inplace=True) + self.aspp1 = nn.Conv2d(C, depth, kernel_size=1, stride=1, bias=False) + self.aspp2 = nn.Conv2d(C, depth, kernel_size=3, stride=1, + dilation=int(6*mult), padding=int(6*mult), + bias=False) + self.aspp3 = nn.Conv2d(C, depth, kernel_size=3, stride=1, + dilation=int(12*mult), padding=int(12*mult), + bias=False) + self.aspp4 = nn.Conv2d(C, depth, kernel_size=3, stride=1, + dilation=int(18*mult), padding=int(18*mult), + bias=False) + self.aspp5 = nn.Conv2d(C, depth, kernel_size=1, stride=1, bias=False) + self.aspp1_bn = self._norm(depth, momentum) + self.aspp2_bn = self._norm(depth, momentum) + self.aspp3_bn = self._norm(depth, momentum) + self.aspp4_bn = self._norm(depth, momentum) + self.aspp5_bn = self._norm(depth, momentum) + self.conv2 = nn.Conv2d(depth * 5, depth, kernel_size=1, stride=1, + bias=False) + self.bn2 = self._norm(depth, momentum) + self.conv3 = nn.Conv2d(depth, num_classes, kernel_size=1, stride=1) + + def forward(self, x): + x1 = self.aspp1(x) + x1 = self.aspp1_bn(x1) + x1 = self.relu(x1) + x2 = self.aspp2(x) + x2 = self.aspp2_bn(x2) + x2 = self.relu(x2) + x3 = self.aspp3(x) + x3 = self.aspp3_bn(x3) + x3 = self.relu(x3) + x4 = self.aspp4(x) + x4 = self.aspp4_bn(x4) + x4 = self.relu(x4) + x5 = self.global_pooling(x) + x5 = self.aspp5(x5) + x5 = self.aspp5_bn(x5) + x5 = self.relu(x5) + x5 = nn.Upsample((x.shape[2], x.shape[3]), mode='bilinear', + align_corners=True)(x5) + x = torch.cat((x1, x2, x3, x4, x5), 1) + x = self.conv2(x) + x = self.bn2(x) + x = self.relu(x) + x = self.conv3(x) + + return x + + +class Selayer(nn.Module): + + def __init__(self, inplanes): + super(Selayer, self).__init__() + self.global_avgpool = nn.AdaptiveAvgPool2d(1) + self.conv1 = nn.Conv2d(inplanes, inplanes // 16, kernel_size=1, 
stride=1) + self.conv2 = nn.Conv2d(inplanes // 16, inplanes, kernel_size=1, stride=1) + self.relu = nn.ReLU(inplace=True) + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + + out = self.global_avgpool(x) + + out = self.conv1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.sigmoid(out) + + return x * out + + +class BottleneckX(nn.Module): + expansion = 4 + def __init__(self, inplanes, planes, cardinality, stride=1, downsample=None, dilation=1, norm=None, elastic=False, se=False): + super(BottleneckX, self).__init__() + self.se = se + self.elastic = elastic and stride == 1 and planes < 512 + if self.elastic: + self.down = nn.AvgPool2d(2, stride=2) + self.ups = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) + # half resolution + self.conv1_d = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1_d = norm(planes) + self.conv2_d = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, groups=cardinality // 2, + dilation=dilation, padding=dilation, bias=False) + self.bn2_d = norm(planes) + self.conv3_d = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) + # full resolution + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = norm(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, groups=cardinality // 2, + dilation=dilation, padding=dilation, bias=False) + self.bn2 = norm(planes) + self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) + # after merging + self.bn3 = norm(planes * self.expansion) + if self.se: + self.selayer = Selayer(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + self.__flops__ = 0 + + def forward(self, x): + residual = x + out_d = x + if self.elastic: + if x.size(2) % 2 > 0 or x.size(3) % 2 > 0: + out_d = F.pad(out_d, (0, x.size(3) % 2, 0, x.size(2) % 2), mode='replicate') + out_d = self.down(out_d) + + out_d = self.conv1_d(out_d) + out_d = self.bn1_d(out_d) + out_d = self.relu(out_d) + + out_d = self.conv2_d(out_d) + out_d = self.bn2_d(out_d) + out_d = self.relu(out_d) + + out_d = self.conv3_d(out_d) + + if self.elastic: + out_d = self.ups(out_d) + self.__flops__ += np.prod(out_d[0].shape) * 8 + if out_d.size(2) > x.size(2) or out_d.size(3) > x.size(3): + out_d = out_d[:, :, :x.size(2), :x.size(3)] + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = out + out_d + out = self.bn3(out) + + if self.se: + out = self.selayer(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ResNext(nn.Module): + + def __init__(self, block, layers, num_classes=1000, seg=False, elastic=False, se=False): + self.inplanes = 64 + self.cardinality = 32 + self.seg = seg + self._norm = lambda planes, momentum=0.05 if seg else 0.1: torch.nn.BatchNorm2d(planes, momentum=momentum) + + super(ResNext, self).__init__() + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = self._norm(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0], elastic=elastic, se=se) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, elastic=elastic, se=se) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, elastic=elastic, se=se) + if 
+
+
+class ResNext(nn.Module):
+
+    def __init__(self, block, layers, num_classes=1000, seg=False, elastic=False, se=False):
+        self.inplanes = 64
+        self.cardinality = 32
+        self.seg = seg
+        self._norm = lambda planes, momentum=0.05 if seg else 0.1: torch.nn.BatchNorm2d(planes, momentum=momentum)
+
+        super(ResNext, self).__init__()
+        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
+        self.bn1 = self._norm(64)
+        self.relu = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        self.layer1 = self._make_layer(block, 64, layers[0], elastic=elastic, se=se)
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, elastic=elastic, se=se)
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, elastic=elastic, se=se)
+        if seg:
+            self.layer4 = self._make_mg(block, 512, se=se)
+            self.aspp = ASPP(512 * block.expansion, 256, num_classes, self._norm)
+            for m in self.modules():
+                if isinstance(m, nn.Conv2d):
+                    n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                    m.weight.data.normal_(0, math.sqrt(2. / n))
+                elif isinstance(m, torch.nn.BatchNorm2d):
+                    m.weight.data.fill_(1)
+                    m.bias.data.zero_()
+        else:
+            self.layer4 = self._make_layer(block, 512, layers[3], stride=2, elastic=False, se=se)
+            self.avgpool = nn.AdaptiveAvgPool2d(1)
+            self.fc = nn.Linear(512 * block.expansion, num_classes)
+            init.normal_(self.fc.weight, std=0.01)
+            for n, p in self.named_parameters():
+                if n.split('.')[-1] == 'weight':
+                    if 'conv' in n:
+                        init.kaiming_normal_(p, mode='fan_in', nonlinearity='relu')
+                    if 'bn' in n:
+                        p.data.fill_(1)
+                    if 'bn3' in n:
+                        p.data.fill_(0)
+                elif n.split('.')[-1] == 'bias':
+                    p.data.fill_(0)
+
+    def _make_layer(self, block, planes, blocks, stride=1, elastic=False, se=False):
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                nn.Conv2d(self.inplanes, planes * block.expansion,
+                          kernel_size=1, stride=stride, bias=False),
+                self._norm(planes * block.expansion),
+            )
+
+        layers = list()
+        layers.append(block(self.inplanes, planes, self.cardinality, stride, downsample=downsample, norm=self._norm, elastic=elastic, se=se))
+        self.inplanes = planes * block.expansion
+        for i in range(1, blocks):
+            layers.append(block(self.inplanes, planes, self.cardinality, norm=self._norm, elastic=elastic, se=se))
+        return nn.Sequential(*layers)
+
+    def _make_mg(self, block, planes, dilation=2, multi_grid=(1, 2, 4), se=False):
+        downsample = nn.Sequential(
+            nn.Conv2d(self.inplanes, planes * block.expansion,
+                      kernel_size=1, stride=1, dilation=1, bias=False),
+            self._norm(planes * block.expansion),
+        )
+
+        layers = list()
+        layers.append(block(self.inplanes, planes, self.cardinality, downsample=downsample, dilation=dilation*multi_grid[0], norm=self._norm, se=se))
+        self.inplanes = planes * block.expansion
+        layers.append(block(self.inplanes, planes, self.cardinality, dilation=dilation*multi_grid[1], norm=self._norm, se=se))
+        layers.append(block(self.inplanes, planes, self.cardinality, dilation=dilation*multi_grid[2], norm=self._norm, se=se))
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        size = (x.shape[2], x.shape[3])
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.maxpool(x)
+        if self.seg:
+            for module in self.layer1._modules.values():
+                x = checkpoint(module, x)
+            for module in self.layer2._modules.values():
+                x = checkpoint(module, x)
+            for module in self.layer3._modules.values():
+                x = checkpoint(module, x)
+            for module in self.layer4._modules.values():
+                x = checkpoint(module, x)
+            x = self.aspp(x)
+            x = nn.Upsample(size, mode='bilinear', align_corners=True)(x)
+        else:
+            x = self.layer1(x)
+            x = self.layer2(x)
+            x = self.layer3(x)
+            x = self.layer4(x)
+            x = self.avgpool(x)
+            x = x.view(x.size(0), -1)
+            x = self.fc(x)
+        return x
+
+
+def resnext50(seg=False, **kwargs):
+    model = ResNext(BottleneckX, [3, 4, 6, 3], seg=seg, elastic=False, **kwargs)
+    return model
+
+
+def se_resnext50(seg=False, **kwargs):
+    model = ResNext(BottleneckX, [3, 4, 6, 3], seg=seg, elastic=False, se=True, **kwargs)
+    return model
+
+
+def resnext50_elastic(seg=False, **kwargs):
+    model = ResNext(BottleneckX, [6, 8, 5, 3], seg=seg, elastic=True, **kwargs)
+    return model
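+
+
+# Editor's note: the elastic variants use rebalanced block counts (e.g. [6, 8, 5, 3]
+# here vs. [3, 4, 6, 3] for resnext50), presumably so parameters and FLOPs stay
+# comparable to the non-elastic baselines; the training scripts print both counts
+# right after building the model, which is a quick way to verify the match.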
+
+
+def se_resnext50_elastic(seg=False, **kwargs):
+    model = ResNext(BottleneckX, [6, 8, 5, 3], seg=seg, elastic=True, se=True, **kwargs)
+    return model
+
+
+def resnext101(seg=False, **kwargs):
+    model = ResNext(BottleneckX, [3, 4, 23, 3], seg=seg, elastic=False, **kwargs)
+    return model
+
+
+def resnext101_elastic(seg=False, **kwargs):
+    model = ResNext(BottleneckX, [12, 14, 20, 3], seg=seg, elastic=True, **kwargs)
+    return model
diff --git a/models/utils.py b/models/utils.py
new file mode 100644
index 0000000..4968eb8
--- /dev/null
+++ b/models/utils.py
@@ -0,0 +1,80 @@
+import torch
+import warnings
+import torch.nn.functional as F
+import math
+
+
+class CpBatchNorm2d(torch.nn.BatchNorm2d):
+    def __init__(self, *args, **kwargs):
+        super(CpBatchNorm2d, self).__init__(*args, **kwargs)
+
+    def forward(self, input):
+        self._check_input_dim(input)
+        if input.requires_grad:
+            exponential_average_factor = 0.0
+            if self.training and self.track_running_stats:
+                self.num_batches_tracked += 1
+                if self.momentum is None:  # use cumulative moving average
+                    exponential_average_factor = 1.0 / self.num_batches_tracked.item()
+                else:  # use exponential moving average
+                    exponential_average_factor = self.momentum
+            return F.batch_norm(
+                input, self.running_mean, self.running_var, self.weight, self.bias,
+                self.training or not self.track_running_stats,
+                exponential_average_factor, self.eps)
+        else:
+            return F.batch_norm(
+                input, self.running_mean, self.running_var, self.weight, self.bias,
+                self.training or not self.track_running_stats, 0.0, self.eps)
+
+
+def detach_variable(inputs):
+    if isinstance(inputs, tuple):
+        out = []
+        for inp in inputs:
+            x = inp.detach()
+            x.requires_grad = inp.requires_grad
+            out.append(x)
+        return tuple(out)
+    else:
+        raise RuntimeError(
+            "Only tuple of tensors is supported. Got unsupported input type: ", type(inputs).__name__)
+
+
+def check_backward_validity(inputs):
+    if not any(inp.requires_grad for inp in inputs):
+        warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")
+
+
+class CheckpointFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, run_function, length, *args):
+        ctx.run_function = run_function
+        ctx.input_tensors = list(args[:length])
+        ctx.input_params = list(args[length:])
+        with torch.no_grad():
+            output_tensors = ctx.run_function(*ctx.input_tensors)
+        return output_tensors
+
+    @staticmethod
+    def backward(ctx, *output_grads):
+        for i in range(len(ctx.input_tensors)):
+            temp = ctx.input_tensors[i]
+            ctx.input_tensors[i] = temp.detach()
+            ctx.input_tensors[i].requires_grad = temp.requires_grad
+        with torch.enable_grad():
+            output_tensors = ctx.run_function(*ctx.input_tensors)
+        input_grads = torch.autograd.grad(output_tensors, ctx.input_tensors + ctx.input_params, output_grads, allow_unused=True)
+        return (None, None) + input_grads
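+
+
+# Usage sketch (editor's note; mirrors torch.utils.checkpoint's calling convention):
+# recompute activations in backward instead of storing them. `length` tells apply()
+# how many leading tensors are real inputs; the remaining args are parameters:
+#   out = CheckpointFunction.apply(run_fn, 1, x, *module.parameters())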
+
+
+def fill_up_weights(up):
+    # initialize a ConvTranspose2d weight with a fixed bilinear upsampling kernel
+    w = up.weight.data
+    f = math.ceil(w.size(2) / 2)
+    c = (2 * f - 1 - f % 2) / (2. * f)
+    for i in range(w.size(2)):
+        for j in range(w.size(3)):
+            w[0, 0, i, j] = \
+                (1 - math.fabs(i / f - c)) * (1 - math.fabs(j / f - c))
+    for c in range(1, w.size(0)):
+        w[c, 0, :, :] = w[0, 0, :, :]
\ No newline at end of file
diff --git a/multilabel_classify.py b/multilabel_classify.py
new file mode 100644
index 0000000..f44e8dd
--- /dev/null
+++ b/multilabel_classify.py
@@ -0,0 +1,396 @@
+import argparse
+import time
+import numpy as np
+import pdb
+import torch
+import torch.nn as nn
+import torch.nn.parallel
+import torch.backends.cudnn as cudnn
+import torch.distributed as dist
+import torch.optim
+import torch.utils.data as data
+import torch.utils.data.distributed
+import torchvision.transforms as transforms
+import torchvision.datasets as datasets
+# import torchvision.models as models
+import models
+import os
+from PIL import Image
+from utils import add_flops_counting_methods, save_checkpoint, AverageMeter
+
+model_names = ['resnext50', 'resnext50_elastic', 'resnext101', 'resnext101_elastic',
+               'dla60x', 'dla60x_elastic', 'dla102x', 'dla102x_elastic',
+               'se_resnext50', 'se_resnext50_elastic', 'densenet201', 'densenet201_elastic']
+
+
+class CocoDetection(datasets.coco.CocoDetection):
+    def __init__(self, root, annFile, transform=None, target_transform=None):
+        from pycocotools.coco import COCO
+        self.root = root
+        self.coco = COCO(annFile)
+        self.ids = list(self.coco.imgs.keys())
+        self.transform = transform
+        self.target_transform = target_transform
+        # map sparse COCO category ids onto contiguous indices 0..79
+        self.cat2cat = dict()
+        for cat in self.coco.cats.keys():
+            self.cat2cat[cat] = len(self.cat2cat)
+
+    def __getitem__(self, index):
+        coco = self.coco
+        img_id = self.ids[index]
+        ann_ids = coco.getAnnIds(imgIds=img_id)
+        target = coco.loadAnns(ann_ids)
+
+        output = torch.zeros((3, 80), dtype=torch.long)
+        for obj in target:
+            if obj['area'] < 32 * 32:
+                output[0][self.cat2cat[obj['category_id']]] = 1
+            elif obj['area'] < 96 * 96:
+                output[1][self.cat2cat[obj['category_id']]] = 1
+            else:
+                output[2][self.cat2cat[obj['category_id']]] = 1
+        target = output
+
+        path = coco.loadImgs(img_id)[0]['file_name']
+        img = Image.open(os.path.join(self.root, path)).convert('RGB')
+        if self.transform is not None:
+            img = self.transform(img)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+        return img, target
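+
+
+# Target layout (editor's note): __getitem__ returns a (3, 80) multi-hot tensor;
+# row 0 marks categories present as small objects (area < 32^2), row 1 medium
+# (< 96^2), row 2 large. train_multi/validate_multi collapse it to a single 80-d
+# label vector with target.max(dim=1)[0].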
+
+
+parser = argparse.ArgumentParser(description='PyTorch MSCOCO Training')
+parser.add_argument('data', metavar='DIR', help='path to dataset')
+parser.add_argument('--arch', '-a', metavar='ARCH', default='resnext50_elastic', choices=model_names,
+                    help='model architecture: ' + ' | '.join(model_names) + ' (default: resnext50_elastic)')
+parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
+                    help='number of data loading workers (default: 16)')
+parser.add_argument('--epochs', default=36, type=int, metavar='N',
+                    help='number of total epochs to run')
+parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
+                    help='manual epoch number (useful on restarts)')
+parser.add_argument('-b', '--batch-size', default=96, type=int,
+                    metavar='N', help='mini-batch size (default: 96)')
+parser.add_argument('-g', '--num-gpus', default=4, type=int,
+                    metavar='N', help='number of GPUs to match (default: 4)')
+parser.add_argument('--lr', '--learning-rate', default=0.001, type=float,
+                    metavar='LR', help='initial learning rate')
+parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
+                    help='momentum')
+parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float,
+                    metavar='W', help='weight decay (default: 5e-4)')
+parser.add_argument('--print-freq', '-p', default=117, type=int,
+                    metavar='N', help='print frequency (default: 117)')
+parser.add_argument('--resume', default='', type=str, metavar='PATH',
+                    help='path to latest checkpoint (default: none)')
+parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
+                    help='evaluate model on validation set')
+parser.add_argument('--world-size', default=1, type=int,
+                    help='number of distributed processes')
+parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
+                    help='url used to set up distributed training')
+parser.add_argument('--dist-backend', default='gloo', type=str,
+                    help='distributed backend')
+
+
+def main():
+    global args
+    args = parser.parse_args()
+    print('config: wd', args.weight_decay, 'lr', args.lr, 'batch_size', args.batch_size, 'num_gpus', args.num_gpus)
+    iteration_size = args.num_gpus // torch.cuda.device_count()  # do multiple iterations
+    assert iteration_size >= 1
+    args.weight_decay = args.weight_decay * iteration_size  # will cancel out with lr
+    args.lr = args.lr / iteration_size
+    args.batch_size = args.batch_size // iteration_size
+    print('real: wd', args.weight_decay, 'lr', args.lr, 'batch_size', args.batch_size, 'iteration_size', iteration_size)
+
+    args.distributed = args.world_size > 1
+
+    if args.distributed:
+        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
+                                world_size=args.world_size)
+
+    # create model
+    print("=> creating model '{}'".format(args.arch))
+    model = models.__dict__[args.arch](num_classes=80)
+
+    # count number of parameters
+    count = 0
+    params = list()
+    for n, p in model.named_parameters():
+        if '.ups.' not in n:
+            params.append(p)
+            count += np.prod(p.size())
+    print('Parameters:', count)
+
+    # count flops
+    model = add_flops_counting_methods(model)
+    model.eval()
+    image = torch.randn(1, 3, 224, 224)
+
+    model.start_flops_count()
+    model(image).sum()
+    model.stop_flops_count()
+    print("GFLOPs", model.compute_average_flops_cost() / 1000000000.0)
+
+    # normal code
+    model = torch.nn.DataParallel(model).cuda()
+
+    criterion = nn.BCEWithLogitsLoss().cuda()
+    optimizer = torch.optim.SGD([{'params': iter(params), 'lr': args.lr},
+                                 ], lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
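+
+    # Editor's note: combined with the lr/weight-decay/batch-size rescaling above,
+    # stepping the optimizer only every iteration_size mini-batches (see
+    # train_multi) approximates training with args.num_gpus GPUs on fewer
+    # physical devices.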
+
+    # optionally resume from a checkpoint
+    if args.resume:
+        if os.path.isfile(args.resume):
+            print("=> loading checkpoint '{}'".format(args.resume))
+            checkpoint = torch.load(args.resume)
+
+            resume = ('module.fc.bias' in checkpoint['state_dict'] and
+                      checkpoint['state_dict']['module.fc.bias'].size() == model.module.fc.bias.size()) or \
+                     ('module.classifier.bias' in checkpoint['state_dict'] and
+                      checkpoint['state_dict']['module.classifier.bias'].size() == model.module.classifier.bias.size())
+            if resume:
+                # True resume: resume training on COCO
+                model.load_state_dict(checkpoint['state_dict'], strict=False)
+                if 'optimizer' in checkpoint:
+                    optimizer.load_state_dict(checkpoint['optimizer'])
+                else:
+                    print('no optimizer found')
+                args.start_epoch = checkpoint['epoch'] if 'epoch' in checkpoint else args.start_epoch
+            else:
+                # Fake resume: transfer from ImageNet
+                for n, p in list(checkpoint['state_dict'].items()):
+                    if 'classifier' in n or 'fc' in n:
+                        print(n, 'deleted from state_dict')
+                        del checkpoint['state_dict'][n]
+                model.load_state_dict(checkpoint['state_dict'], strict=False)
+
+            print("=> loaded checkpoint '{}' (epoch {})"
+                  .format(args.resume, checkpoint['epoch'] if 'epoch' in checkpoint else 'unknown'))
+        else:
+            print("=> no checkpoint found at '{}'".format(args.resume))
+
+    cudnn.benchmark = True
+
+    # Data loading code
+    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                     std=[0.229, 0.224, 0.225])
+    train_dataset = CocoDetection(os.path.join(args.data, 'train2014'),
+                                  os.path.join(args.data, 'annotations/instances_train2014.json'),
+                                  transforms.Compose([
+                                      transforms.RandomResizedCrop(224),
+                                      transforms.RandomHorizontalFlip(),
+                                      transforms.ToTensor(),
+                                      normalize,
+                                  ]))
+    val_dataset = CocoDetection(os.path.join(args.data, 'val2014'),
+                                os.path.join(args.data, 'annotations/instances_val2014.json'),
+                                transforms.Compose([
+                                    transforms.Resize((224, 224)),
+                                    transforms.ToTensor(),
+                                    normalize,
+                                ]))
+
+    train_sampler = torch.utils.data.sampler.RandomSampler(train_dataset)
+
+    train_loader = torch.utils.data.DataLoader(
+        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
+        num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True)
+    val_loader = torch.utils.data.DataLoader(
+        val_dataset, batch_size=args.batch_size, shuffle=False,
+        num_workers=args.workers, pin_memory=True)
+
+    if args.evaluate:
+        validate_multi(val_loader, model, criterion)
+        return
+
+    for epoch in range(args.start_epoch, args.epochs):
+        coco_adjust_learning_rate(optimizer, epoch)
+
+        # train for one epoch
+        train_multi(train_loader, model, criterion, optimizer, epoch, iteration_size)
+
+        # evaluate on validation set
+        validate_multi(val_loader, model, criterion)
+        save_checkpoint({
+            'epoch': epoch + 1,
+            'arch': args.arch,
+            'state_dict': model.state_dict(),
+            'optimizer': optimizer.state_dict(),
+        }, False, filename='coco_' + args.arch + '_checkpoint.pth.tar')
+
+
+def train_multi(train_loader, model, criterion, optimizer, epoch, iteration_size):
+    batch_time = AverageMeter()
+    data_time = AverageMeter()
+    losses = AverageMeter()
+    prec = AverageMeter()
+    rec = AverageMeter()
+
+    # switch to train mode
+    model.train()
+    optimizer.zero_grad()
+    end = time.time()
+    tp, fp, fn, tn, count = 0, 0, 0, 0, 0
+    for i, (input, target) in enumerate(train_loader):
+        # measure data loading time
+        data_time.update(time.time() - end)
+
+        target = target.cuda(non_blocking=True)
+        target = target.max(dim=1)[0]
+        # compute output
+        output = model(input)
+        loss = criterion(output, target.float()) * 80.0
+
+        # measure accuracy and record loss
+        pred = output.data.gt(0.0).long()
+
+        tp += (pred + target).eq(2).sum(dim=0)
+        fp += (pred - target).eq(1).sum(dim=0)
+        fn += (pred - target).eq(-1).sum(dim=0)
+        tn += (pred + target).eq(0).sum(dim=0)
+        count += input.size(0)
+
+        this_tp = (pred + target).eq(2).sum()
+        this_fp = (pred - target).eq(1).sum()
+        this_fn = (pred - target).eq(-1).sum()
+        this_tn = (pred + target).eq(0).sum()
+        this_acc = (this_tp + this_tn).float() / (this_tp + this_tn + this_fp + this_fn).float()
+
+        this_prec = this_tp.float() / (this_tp + this_fp).float() * 100.0 if this_tp + this_fp != 0 else 0.0
+        this_rec = this_tp.float() / (this_tp + this_fn).float() * 100.0 if this_tp + this_fn != 0 else 0.0
+
+        losses.update(float(loss), input.size(0))
+        prec.update(float(this_prec), input.size(0))
+        rec.update(float(this_rec), input.size(0))
+        # compute gradient and do SGD step
+        loss.backward()
+
+        if i % iteration_size == iteration_size - 1:
+            optimizer.step()
+            optimizer.zero_grad()
+        # measure elapsed time
+        batch_time.update(time.time() - end)
+        end = time.time()
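+
+        # Metric note (editor's comment): P_C/R_C/F_C below are per-class
+        # precision/recall/F1 averaged over the 80 categories; P_O/R_O/F_O are the
+        # overall (micro-averaged) counterparts computed from pooled tp/fp/fn.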
+
+        p_c = [float(tp[i].float() / (tp[i] + fp[i]).float()) * 100.0 if tp[i] > 0 else 0.0 for i in range(len(tp))]
+        r_c = [float(tp[i].float() / (tp[i] + fn[i]).float()) * 100.0 if tp[i] > 0 else 0.0 for i in range(len(tp))]
+        f_c = [2 * p_c[i] * r_c[i] / (p_c[i] + r_c[i]) if tp[i] > 0 else 0.0 for i in range(len(tp))]
+
+        mean_p_c = sum(p_c) / len(p_c)
+        mean_r_c = sum(r_c) / len(r_c)
+        mean_f_c = sum(f_c) / len(f_c)
+
+        p_o = tp.sum().float() / (tp + fp).sum().float() * 100.0
+        r_o = tp.sum().float() / (tp + fn).sum().float() * 100.0
+        f_o = 2 * p_o * r_o / (p_o + r_o)
+
+        if i % args.print_freq == 0:
+            print('Epoch: [{0}][{1}/{2}]\t'
+                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
+                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
+                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
+                  'Precision {prec.val:.2f} ({prec.avg:.2f})\t'
+                  'Recall {rec.val:.2f} ({rec.avg:.2f})'.format(
+                      epoch, i, len(train_loader), batch_time=batch_time,
+                      data_time=data_time, loss=losses, prec=prec, rec=rec))
+            print('P_C {:.2f} R_C {:.2f} F_C {:.2f} P_O {:.2f} R_O {:.2f} F_O {:.2f}'
+                  .format(mean_p_c, mean_r_c, mean_f_c, p_o, r_o, f_o))
+
+
+def validate_multi(val_loader, model, criterion):
+    batch_time = AverageMeter()
+    losses = AverageMeter()
+    prec = AverageMeter()
+    rec = AverageMeter()
+
+    # switch to evaluate mode
+    model.eval()
+
+    end = time.time()
+    tp, fp, fn, tn, count = 0, 0, 0, 0, 0
+    tp_size, fn_size = 0, 0
+    for i, (input, target) in enumerate(val_loader):
+        target = target.cuda(non_blocking=True)
+        original_target = target
+        target = target.max(dim=1)[0]
+        # compute output
+        with torch.no_grad():
+            output = model(input)
+            loss = criterion(output, target.float())
+
+        # measure accuracy and record loss
+        pred = output.data.gt(0.0).long()
+
+        tp += (pred + target).eq(2).sum(dim=0)
+        fp += (pred - target).eq(1).sum(dim=0)
+        fn += (pred - target).eq(-1).sum(dim=0)
+        tn += (pred + target).eq(0).sum(dim=0)
+        three_pred = pred.unsqueeze(1).expand(-1, 3, -1)  # n, 3, 80
+        tp_size += (three_pred + original_target).eq(2).sum(dim=0)
+        fn_size += (three_pred - original_target).eq(-1).sum(dim=0)
+        count += input.size(0)
+
+        this_tp = (pred + target).eq(2).sum()
+        this_fp = (pred - target).eq(1).sum()
+        this_fn = (pred - target).eq(-1).sum()
+        this_tn = (pred + target).eq(0).sum()
+        this_acc = (this_tp + this_tn).float() / (this_tp + this_tn + this_fp + this_fn).float()
+
+        this_prec = this_tp.float() / (this_tp + this_fp).float() * 100.0 if this_tp + this_fp != 0 else 0.0
+        this_rec = this_tp.float() / (this_tp + this_fn).float() * 100.0 if this_tp + this_fn != 0 else 0.0
+
+        losses.update(float(loss), input.size(0))
+        prec.update(float(this_prec), input.size(0))
+        rec.update(float(this_rec), input.size(0))
+
+        # measure elapsed time
+        batch_time.update(time.time() - end)
+        end = time.time()
+
+        p_c = [float(tp[i].float() / (tp[i] + fp[i]).float()) * 100.0 if tp[i] > 0 else 0.0 for i in range(len(tp))]
+        r_c = [float(tp[i].float() / (tp[i] + fn[i]).float()) * 100.0 if tp[i] > 0 else 0.0 for i in range(len(tp))]
+        f_c = [2 * p_c[i] * r_c[i] / (p_c[i] + r_c[i]) if tp[i] > 0 else 0.0 for i in range(len(tp))]
+
+        mean_p_c = sum(p_c) / len(p_c)
+        mean_r_c = sum(r_c) / len(r_c)
+        mean_f_c = sum(f_c) / len(f_c)
+
+        p_o = tp.sum().float() / (tp + fp).sum().float() * 100.0
+        r_o = tp.sum().float() / (tp + fn).sum().float() * 100.0
+        f_o = 2 * p_o * r_o / (p_o + r_o)
+
+        if i % args.print_freq == 0:
+            print('Test: [{0}/{1}]\t'
+                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
+                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
+                  'Precision {prec.val:.2f} ({prec.avg:.2f})\t'
+                  'Recall {rec.val:.2f} ({rec.avg:.2f})'.format(
+                      i, len(val_loader), batch_time=batch_time, loss=losses,
+                      prec=prec, rec=rec))
+            print('P_C {:.2f} R_C {:.2f} F_C {:.2f} P_O {:.2f} R_O {:.2f} F_O {:.2f}'
+                  .format(mean_p_c, mean_r_c, mean_f_c, p_o, r_o, f_o))
+
+    print('--------------------------------------------------------------------')
+    print(' * P_C {:.2f} R_C {:.2f} F_C {:.2f} P_O {:.2f} R_O {:.2f} F_O {:.2f}'
+          .format(mean_p_c, mean_r_c, mean_f_c, p_o, r_o, f_o))
+    return
+
+
+def coco_adjust_learning_rate(optimizer, epoch):
+    if isinstance(optimizer, torch.optim.Adam):
+        return
+    lr = args.lr
+    # if epoch >= 12:
+    #     lr *= 0.1
+    if epoch >= 24:
+        lr *= 0.1
+    if epoch >= 30:
+        lr *= 0.1
+    for param_group in optimizer.param_groups:
+        param_group['lr'] = lr
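+
+
+# Editor's note: under the defaults (lr=0.001, 36 epochs) the schedule above is a
+# step decay: 1e-3 for epochs 0-23, 1e-4 for epochs 24-29, and 1e-5 from epoch 30.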
+
+
+if __name__ == '__main__':
+    main()
diff --git a/segment.py b/segment.py
new file mode 100644
index 0000000..4ba28ac
--- /dev/null
+++ b/segment.py
@@ -0,0 +1,197 @@
+import argparse
+import os
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import pdb
+from PIL import Image
+from scipy.io import loadmat
+from torch.autograd import Variable
+from utils import AverageMeter, inter_and_union, VOCSegmentation
+import models
+
+model_names = ['resnext50', 'resnext50_elastic', 'resnext101', 'resnext101_elastic', 'dla60x', 'dla60x_elastic']
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--train', action='store_true', default=False,
+                    help='training mode')
+parser.add_argument('--exp', type=str, required=True,
+                    help='name of experiment')
+parser.add_argument('--gpu', type=int, default=0,
+                    help='test time gpu device id')
+parser.add_argument('--arch', '-a', metavar='ARCH', default='resnext50_elastic', choices=model_names,
+                    help='model architecture: ' + ' | '.join(model_names) + ' (default: resnext50_elastic)')
+parser.add_argument('--dataset', type=str, default='pascal',
+                    help='pascal')
+parser.add_argument('--epochs', type=int, default=50,
+                    help='num of training epochs')
+parser.add_argument('--batch_size', type=int, default=16,
+                    help='batch size')
+parser.add_argument('--base_lr', type=float, default=0.007,
+                    help='base learning rate')
+parser.add_argument('--last_mult', type=float, default=1.0,
+                    help='learning rate multiplier for last layers')
+parser.add_argument('--freeze_bn', action='store_true', default=False,
+                    help='freeze batch normalization parameters')
+parser.add_argument('--crop_size', type=int, default=513,
+                    help='image crop size')
+parser.add_argument('--resume', type=str, default=None,
+                    help='path to checkpoint to resume from')
+parser.add_argument('--workers', type=int, default=4,
+                    help='number of data loading workers')
+args = parser.parse_args()
+
+
+def main():
+    assert torch.cuda.is_available()
+    model_fname = 'data/deeplab_{0}_{1}_v3_{2}_epoch%d.pth'.format(
+        args.arch, args.dataset, args.exp)
+    if args.dataset == 'pascal':
+        dataset = VOCSegmentation('data/VOCdevkit',
+                                  train=args.train, crop_size=args.crop_size)
+    else:
+        raise ValueError('Unknown dataset: {}'.format(args.dataset))
+
+    if 'resnext' in args.arch:
+        model = models.__dict__[args.arch](seg=True, num_classes=len(dataset.CLASSES))
+    elif 'dla' in args.arch:
+        model = models.__dict__[args.arch + '_seg'](classes=len(dataset.CLASSES))
+    else:
+        raise ValueError('Unknown arch: {}'.format(args.arch))
+
+    if args.train:
+        criterion = nn.CrossEntropyLoss(ignore_index=255)
+        model = nn.DataParallel(model).cuda()
+        model.train()
+        if args.freeze_bn:
+            for m in model.modules():
+                if isinstance(m, nn.BatchNorm2d):
+                    m.eval()
+                    m.weight.requires_grad = False
+                    m.bias.requires_grad = False
+        if 'resnext' in args.arch:
+            arch_params = (
+                list(model.module.conv1.parameters()) +
+                list(model.module.bn1.parameters()) +
+                list(model.module.layer1.parameters()) +
+                list(model.module.layer2.parameters()) +
+                list(model.module.layer3.parameters()) +
+                list(model.module.layer4.parameters()))
+            last_params = list(model.module.aspp.parameters())
+        else:
+            arch_params = list(model.module.base.parameters())
+            last_params = list()
+            for n, p in model.named_parameters():
+                if 'base' not in n and 'up.weight' not in n:
+                    last_params.append(p)
+
+        optimizer = optim.SGD([
+            {'params': filter(lambda p: p.requires_grad, arch_params)},
+            {'params': filter(lambda p: p.requires_grad, last_params)}],
+            lr=args.base_lr, momentum=0.9, weight_decay=0.0005 if 'resnext' in args.arch else 0.0001)
+        dataset_loader = torch.utils.data.DataLoader(
+            dataset, batch_size=args.batch_size, shuffle=args.train,
+            pin_memory=True, num_workers=args.workers)
+        max_iter = args.epochs * len(dataset_loader)
+        losses = AverageMeter()
+        start_epoch = 0
+
+        if args.resume:
+            if os.path.isfile(args.resume):
+                print('=> loading checkpoint {0}'.format(args.resume))
+                checkpoint = torch.load(args.resume)
+
+                resume = False
+                for n, p in list(checkpoint['state_dict'].items()):
+                    if 'aspp' in n or 'dla_up' in n:
+                        resume = True
+                        break
+                if resume:
+                    # True resume: resume training on pascal
+                    model.load_state_dict(checkpoint['state_dict'], strict=True)
+                    optimizer.load_state_dict(checkpoint['optimizer'])
+                    start_epoch = checkpoint['epoch']
+                else:
+                    # Fake resume: transfer from ImageNet
+                    if 'resnext' in args.arch:
+                        model.load_state_dict(checkpoint['state_dict'], strict=False)
+                    else:
+                        pretrained_dict = {k[7:]: v for k, v in checkpoint['state_dict'].items()}
+                        model.module.base.load_state_dict(pretrained_dict, strict=False)
+                print('=> loaded checkpoint {0} (epoch {1})'.format(
+                    args.resume, start_epoch))
+            else:
+                print('=> no checkpoint found at {0}'.format(args.resume))
+
+        for epoch in range(start_epoch, args.epochs):
+            for i, (inputs, target, _, _, _, _) in enumerate(dataset_loader):
+                # DeepLab "poly" schedule: lr = base_lr * (1 - iter/max_iter)^0.9
+                cur_iter = epoch * len(dataset_loader) + i
+                lr = args.base_lr * (1 - float(cur_iter) / max_iter) ** 0.9
+                optimizer.param_groups[0]['lr'] = lr
+                optimizer.param_groups[1]['lr'] = lr * args.last_mult
+
+                inputs = Variable(inputs.cuda())
+                target = Variable(target.cuda())
+
+                outputs = model(inputs)
+                loss = criterion(outputs, target)
+                if np.isnan(loss.item()) or np.isinf(loss.item()):
+                    pdb.set_trace()
+                losses.update(loss.item(), args.batch_size)
+                loss.backward()
+                optimizer.step()
+                optimizer.zero_grad()
+
+                print('epoch: {0}\t'
+                      'iter: {1}/{2}\t'
+                      'lr: {3:.6f}\t'
+                      'loss: {loss.val:.4f} ({loss.ema:.4f})'.format(
+                          epoch + 1, i + 1, len(dataset_loader), lr, loss=losses))
+
+            if epoch % 10 == 9:
+                torch.save({
+                    'epoch': epoch + 1,
+                    'state_dict': model.state_dict(),
+                    'optimizer': optimizer.state_dict(),
+                }, model_fname % (epoch + 1))
+
+    else:
+        torch.cuda.set_device(args.gpu)
+        model = model.cuda()
+        model.eval()
+        checkpoint = torch.load(model_fname % args.epochs)
+        # strip the 'module.' prefix saved by DataParallel and drop the
+        # num_batches_tracked buffers introduced by newer PyTorch versions
+        state_dict = {k[7:]: v for k, v in checkpoint['state_dict'].items() if 'tracked' not in k}
+        model.load_state_dict(state_dict)
+        cmap = loadmat('data/pascal_seg_colormap.mat')['colormap']
+        cmap = (cmap * 255).astype(np.uint8).flatten().tolist()
+
+        inter_meter = AverageMeter()
+        union_meter = AverageMeter()
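+        # Evaluation sketch (editor's comment): one image at a time -- forward
+        # pass, argmax over classes, then per-class intersection/union counts
+        # accumulated via inter_and_union() from utils.py.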
+        for i in range(len(dataset)):
+            inputs, target, a, b, h, w = dataset[i]
+            inputs = inputs.unsqueeze(0)
+            inputs = Variable(inputs.cuda())
+            outputs = model(inputs)
+            _, pred = torch.max(outputs, 1)
+            pred = pred.data.cpu().numpy().squeeze().astype(np.uint8)
+            mask = target.numpy().astype(np.uint8)
+            imname = dataset.masks[i].split('/')[-1]
+
+            inter, union = inter_and_union(pred, mask, len(dataset.CLASSES))
+            inter_meter.update(inter)
+            union_meter.update(union)
+
+            mask_pred = Image.fromarray(pred[a:a + h, b:b + w])
+            mask_pred.putpalette(cmap)
+            mask_pred.save(os.path.join('data/val', imname))
+            print('eval: {0}/{1}'.format(i + 1, len(dataset)))
+
+        iou = inter_meter.sum / (union_meter.sum + 1e-10)
+        for i, val in enumerate(iou):
+            print('IoU {0}: {1:.2f}'.format(dataset.CLASSES[i], val * 100))
+        print('Mean IoU: {0:.2f}'.format(iou.mean() * 100))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..1c74b8e
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,463 @@
+from __future__ import print_function
+import math
+import random
+import torchvision.transforms as transforms
+import warnings
+from torch.nn import functional as F
+import shutil
+import torch.utils.data as data
+import os
+from PIL import Image
+import torch
+import numpy as np
+
+
+def accuracy(output, target, topk=(1,)):
+    """Computes the precision@k for the specified values of k"""
+    maxk = max(topk)
+    batch_size = target.size(0)
+
+    _, pred = output.topk(maxk, 1, True, True)
+    pred = pred.t()
+    correct = pred.eq(target.view(1, -1).expand_as(pred))
+
+    res = []
+    for k in topk:
+        correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
+        res.append(correct_k.mul_(100.0 / batch_size))
+    return res
+
+
+def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
+    torch.save(state, filename)
+    if is_best:
+        shutil.copyfile(filename, 'model_best.pth.tar')
+
+
+class AverageMeter(object):
+    """Tracks the latest value, running sum/count/average and a 0.99-decay EMA."""
+    def __init__(self):
+        self.val = None
+        self.sum = None
+        self.cnt = None
+        self.avg = None
+        self.ema = None
+        self.initialized = False
+
+    def update(self, val, n=1):
+        if not self.initialized:
+            self.initialize(val, n)
+        else:
+            self.add(val, n)
+
+    def initialize(self, val, n):
+        self.val = val
+        self.sum = val * n
+        self.cnt = n
+        self.avg = val
+        self.ema = val
+        self.initialized = True
+
+    def add(self, val, n):
+        self.val = val
+        self.sum += val * n
+        self.cnt += n
+        self.avg = self.sum / self.cnt
+        self.ema = self.ema * 0.99 + self.val * 0.01
+
+
+def inter_and_union(pred, mask, num_class):
+    pred = np.asarray(pred, dtype=np.uint8).copy()
+    mask = np.asarray(mask, dtype=np.uint8).copy()
+
+    # 255 -> 0
+    pred += 1
+    mask += 1
+    pred = pred * (mask > 0)
+
+    inter = pred * (pred == mask)
+    (area_inter, _) = np.histogram(inter, bins=num_class, range=(1, num_class))
+    (area_pred, _) = np.histogram(pred, bins=num_class, range=(1, num_class))
+    (area_mask, _) = np.histogram(mask, bins=num_class, range=(1, num_class))
+    area_union = area_pred + area_mask - area_inter
+
+    return (area_inter, area_union)
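+
+
+# Editor's note on inter_and_union: pred/mask are uint8, so the += 1 above makes
+# the ignore label 255 wrap around to 0; multiplying pred by (mask > 0) then drops
+# ignored pixels, and histograms over bins 1..num_class count only real classes.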
+
+
+def preprocess(image, mask, flip=False, scale=None, crop=None):
+    if flip:
+        if random.random() < 0.5:
+            image = image.transpose(Image.FLIP_LEFT_RIGHT)
+            mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
+    if scale:
+        w, h = image.size
+        rand_log_scale = math.log(scale[0], 2) + random.random() * (math.log(scale[1], 2) - math.log(scale[0], 2))
+        random_scale = math.pow(2, rand_log_scale)
+        new_size = (int(round(w * random_scale)), int(round(h * random_scale)))
+        image = image.resize(new_size, Image.ANTIALIAS)
+        mask = mask.resize(new_size, Image.NEAREST)
+
+    data_transforms = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+    ])
+    image = data_transforms(image)
+    mask = torch.LongTensor(np.array(mask).astype(np.int64))
+
+    if crop:
+        h, w = image.shape[1], image.shape[2]
+        ori_h, ori_w = image.shape[1], image.shape[2]
+
+        pad_tb = max(0, int((1 + crop[0] - h) / 2))
+        pad_lr = max(0, int((1 + crop[1] - w) / 2))
+        image = torch.nn.ZeroPad2d((pad_lr, pad_lr, pad_tb, pad_tb))(image)
+        mask = torch.nn.ConstantPad2d((pad_lr, pad_lr, pad_tb, pad_tb), 255)(mask)
+
+        h, w = image.shape[1], image.shape[2]
+        i = random.randint(0, h - crop[0])
+        j = random.randint(0, w - crop[1])
+        image = image[:, i:i + crop[0], j:j + crop[1]]
+        mask = mask[i:i + crop[0], j:j + crop[1]]
+
+    return image, mask, pad_tb - j, pad_lr - i, ori_h, ori_w
+
+
+# pascal dataloader
+class VOCSegmentation(data.Dataset):
+    CLASSES = [
+        'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
+        'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
+        'motorbike', 'person', 'potted-plant', 'sheep', 'sofa', 'train',
+        'tv/monitor'
+    ]
+
+    def __init__(self, root, train=True, transform=None, target_transform=None, download=False, crop_size=None):
+        self.root = root
+        _voc_root = os.path.join(self.root, 'VOC2012')
+        _list_dir = os.path.join(_voc_root, 'list')
+        self.transform = transform
+        self.target_transform = target_transform
+        self.train = train
+        self.crop_size = crop_size
+
+        if download:
+            self.download()
+
+        if self.train:
+            _list_f = os.path.join(_list_dir, 'train_aug.txt')
+        else:
+            _list_f = os.path.join(_list_dir, 'val.txt')
+        self.images = []
+        self.masks = []
+        with open(_list_f, 'r') as lines:
+            for line in lines:
+                _image = _voc_root + line.split()[0]
+                _mask = _voc_root + line.split()[1]
+                assert os.path.isfile(_image)
+                assert os.path.isfile(_mask)
+                self.images.append(_image)
+                self.masks.append(_mask)
+
+    def __getitem__(self, index):
+        _img = Image.open(self.images[index]).convert('RGB')
+        _target = Image.open(self.masks[index])
+
+        _img, _target, a, b, h, w = preprocess(_img, _target,
+                                               flip=True if self.train else False,
+                                               scale=(0.5, 2.0) if self.train else None,
+                                               crop=(self.crop_size, self.crop_size))
+
+        if self.transform is not None:
+            _img = self.transform(_img)
+
+        if self.target_transform is not None:
+            _target = self.target_transform(_target)
+
+        return _img, _target, a, b, h, w  # used for visualizing
+
+    def __len__(self):
+        return len(self.images)
+
+    def download(self):
+        raise NotImplementedError('Automatic download not yet implemented.')
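+
+
+# Editor's note: the extra (a, b, h, w) values returned by preprocess/__getitem__
+# let segment.py undo test-time padding; pred[a:a + h, b:b + w] recovers the
+# original image region before the colorized mask is saved.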
For example in "Spatially Adaptive Computatin Time for Residual + Networks" by Figurnov et al multiply-add was counted as two flops. + + This module computes the average flops which is necessary for dynamic networks which + have different number of executed layers. For static networks it is enough to run the network + once and get statistics (above example). + + Implementation: + The module works by adding batch_count to the main module which tracks the sum + of all batch sizes that were run through the network. + + Also each convolutional layer of the network tracks the overall number of flops + performed. + + The parameters are updated with the help of registered hook-functions which + are being called each time the respective layer is executed. + + Parameters + ---------- + net_main_module : torch.nn.Module + Main module containing network + + Returns + ------- + net_main_module : torch.nn.Module + Updated main module with new methods/attributes that are used + to compute flops. + """ + + # adding additional methods to the existing module object, + # this is done this way so that each function has access to self object + net_main_module.start_flops_count = start_flops_count.__get__(net_main_module) + net_main_module.stop_flops_count = stop_flops_count.__get__(net_main_module) + net_main_module.reset_flops_count = reset_flops_count.__get__(net_main_module) + net_main_module.compute_average_flops_cost = compute_average_flops_cost.__get__(net_main_module) + + net_main_module.reset_flops_count() + + # Adding variables necessary for masked flops computation + net_main_module.apply(add_flops_mask_variable_or_reset) + + return net_main_module + + +def compute_average_flops_cost(self): + """ + A method that will be available after add_flops_counting_methods() is called + on a desired net object. + + Returns current mean flops consumption per image. + + """ + + batches_count = self.__batch_counter__ + flops_sum = 0 + for module in self.modules(): + if hasattr(module, '__flops__'): # is_supported_instance(module) + flops_sum += module.__flops__ + + return flops_sum / batches_count + + +def start_flops_count(self): + """ + A method that will be available after add_flops_counting_methods() is called + on a desired net object. + + Activates the computation of mean flops consumption per image. + Call it before you run the network. + + """ + add_batch_counter_hook_function(self) + self.apply(add_flops_counter_hook_function) + + +def stop_flops_count(self): + """ + A method that will be available after add_flops_counting_methods() is called + on a desired net object. + + Stops computing the mean flops consumption per image. + Call whenever you want to pause the computation. + + """ + remove_batch_counter_hook_function(self) + self.apply(remove_flops_counter_hook_function) + + +def reset_flops_count(self): + """ + A method that will be available after add_flops_counting_methods() is called + on a desired net object. + + Resets statistics computed so far. 
+ + """ + add_batch_counter_variables_or_reset(self) + self.apply(add_flops_counter_variable_or_reset) + + +def add_flops_mask(module, mask): + def add_flops_mask_func(module): + if isinstance(module, torch.nn.Conv2d): + module.__mask__ = mask + module.apply(add_flops_mask_func) + + +def remove_flops_mask(module): + module.apply(add_flops_mask_variable_or_reset) + + +# ---- Internal functions +def is_supported_instance(module): + if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.ReLU) \ + or isinstance(module, torch.nn.PReLU) or isinstance(module, torch.nn.ELU) \ + or isinstance(module, torch.nn.LeakyReLU) or isinstance(module, torch.nn.ReLU6) \ + or isinstance(module, torch.nn.Linear) or isinstance(module, torch.nn.MaxPool2d) \ + or isinstance(module, torch.nn.AvgPool2d) or isinstance(module, torch.nn.BatchNorm2d): + return True + + return False + + +def empty_flops_counter_hook(module, input, output): + module.__flops__ += 0 + + +def relu_flops_counter_hook(module, input, output): + input = input[0] + batch_size = input.shape[0] + active_elements_count = batch_size + for val in input.shape[1:]: + active_elements_count *= val + + module.__flops__ += active_elements_count + + +def linear_flops_counter_hook(module, input, output): + input = input[0] + batch_size = input.shape[0] + module.__flops__ += batch_size * input.shape[1] * output.shape[1] + + +def pool_flops_counter_hook(module, input, output): + input = input[0] + module.__flops__ += np.prod(input.shape) + +def bn_flops_counter_hook(module, input, output): + module.affine + input = input[0] + + batch_flops = np.prod(input.shape) + if module.affine: + batch_flops *= 2 + module.__flops__ += batch_flops + +def conv_flops_counter_hook(conv_module, input, output): + # Can have multiple inputs, getting the first one + input = input[0] + + batch_size = input.shape[0] + output_height, output_width = output.shape[2:] + + kernel_height, kernel_width = conv_module.kernel_size + in_channels = conv_module.in_channels + out_channels = conv_module.out_channels + groups = conv_module.groups + + filters_per_channel = out_channels // groups + conv_per_position_flops = kernel_height * kernel_width * in_channels * filters_per_channel + + active_elements_count = batch_size * output_height * output_width + + if conv_module.__mask__ is not None: + # (b, 1, h, w) + flops_mask = conv_module.__mask__.expand(batch_size, 1, output_height, output_width) + active_elements_count = flops_mask.sum() + + overall_conv_flops = conv_per_position_flops * active_elements_count + + bias_flops = 0 + + if conv_module.bias is not None: + + bias_flops = out_channels * active_elements_count + + overall_flops = overall_conv_flops + bias_flops + + conv_module.__flops__ += overall_flops + + +def batch_counter_hook(module, input, output): + # Can have multiple inputs, getting the first one + input = input[0] + batch_size = input.shape[0] + module.__batch_counter__ += batch_size + + +def add_batch_counter_variables_or_reset(module): + + module.__batch_counter__ = 0 + + +def add_batch_counter_hook_function(module): + if hasattr(module, '__batch_counter_handle__'): + return + + handle = module.register_forward_hook(batch_counter_hook) + module.__batch_counter_handle__ = handle + + +def remove_batch_counter_hook_function(module): + if hasattr(module, '__batch_counter_handle__'): + module.__batch_counter_handle__.remove() + del module.__batch_counter_handle__ + + +def add_flops_counter_variable_or_reset(module): + if is_supported_instance(module): + module.__flops__ 
+
+
+def batch_counter_hook(module, input, output):
+    # Can have multiple inputs, getting the first one
+    input = input[0]
+    batch_size = input.shape[0]
+    module.__batch_counter__ += batch_size
+
+
+def add_batch_counter_variables_or_reset(module):
+    module.__batch_counter__ = 0
+
+
+def add_batch_counter_hook_function(module):
+    if hasattr(module, '__batch_counter_handle__'):
+        return
+
+    handle = module.register_forward_hook(batch_counter_hook)
+    module.__batch_counter_handle__ = handle
+
+
+def remove_batch_counter_hook_function(module):
+    if hasattr(module, '__batch_counter_handle__'):
+        module.__batch_counter_handle__.remove()
+        del module.__batch_counter_handle__
+
+
+def add_flops_counter_variable_or_reset(module):
+    if is_supported_instance(module):
+        module.__flops__ = 0
+
+
+def add_flops_counter_hook_function(module):
+    if is_supported_instance(module):
+        if hasattr(module, '__flops_handle__'):
+            return
+
+        if isinstance(module, torch.nn.Conv2d):
+            handle = module.register_forward_hook(conv_flops_counter_hook)
+        elif isinstance(module, torch.nn.ReLU) or isinstance(module, torch.nn.PReLU) \
+                or isinstance(module, torch.nn.ELU) or isinstance(module, torch.nn.LeakyReLU) \
+                or isinstance(module, torch.nn.ReLU6):
+            handle = module.register_forward_hook(relu_flops_counter_hook)
+        elif isinstance(module, torch.nn.Linear):
+            handle = module.register_forward_hook(linear_flops_counter_hook)
+        elif isinstance(module, torch.nn.AvgPool2d) or isinstance(module, torch.nn.MaxPool2d):
+            handle = module.register_forward_hook(pool_flops_counter_hook)
+        elif isinstance(module, torch.nn.BatchNorm2d):
+            handle = module.register_forward_hook(bn_flops_counter_hook)
+        else:
+            handle = module.register_forward_hook(empty_flops_counter_hook)
+        module.__flops_handle__ = handle
+
+
+def remove_flops_counter_hook_function(module):
+    if is_supported_instance(module):
+        if hasattr(module, '__flops_handle__'):
+            module.__flops_handle__.remove()
+            del module.__flops_handle__
+
+
+# --- Masked flops counting
+# Also being run in the initialization
+def add_flops_mask_variable_or_reset(module):
+    if is_supported_instance(module):
+        module.__mask__ = None
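+
+
+# End-to-end sketch (editor's note; mirrors the snippet used in the training
+# scripts):
+#   model = add_flops_counting_methods(model)
+#   model.eval()
+#   model.start_flops_count()
+#   model(torch.randn(1, 3, 224, 224))
+#   model.stop_flops_count()
+#   print('GFLOPs', model.compute_average_flops_cost() / 1e9)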