In [1]:
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import os
import sys
import time
import logging
import argparse
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
import torchvision.datasets as dst

from utils import AverageMeter, accuracy, transform_time
from utils import load_pretrained_model, save_checkpoint
from utils import create_exp_dir, count_parameters_in_MB
from network import define_tsnet

In [2]:
parser = argparse.ArgumentParser(description="test r20 net")

In [3]:
# various path
parser.add_argument(
    "--save_root", type=str, default="./results", help="models and logs are saved here"
)
parser.add_argument(
    "--img_root", type=str, default="./datasets", help="path name of image dataset"
)

_StoreAction(option_strings=['--img_root'], dest='img_root', nargs=None, const=None, default='./datasets', type=<class 'str'>, choices=None, required=False, help='path name of image dataset', metavar=None)

In [4]:
# training hyper parameters
parser.add_argument(
    "--print_freq",
    type=int,
    default=50,
    help="frequency of showing training results on console",
)
parser.add_argument(
    "--epochs", type=int, default=200, help="number of total epochs to run"
)
parser.add_argument("--batch_size", type=int, default=128, help="The size of batch")
parser.add_argument("--lr", type=float, default=0.1, help="initial learning rate")
parser.add_argument("--momentum", type=float, default=0.9, help="momentum")
parser.add_argument("--weight_decay", type=float, default=1e-4, help="weight decay")
parser.add_argument("--num_class", type=int, default=100, help="number of classes")
parser.add_argument("--cuda", type=int, default=1)

_StoreAction(option_strings=['--cuda'], dest='cuda', nargs=None, const=None, default=1, type=<class 'int'>, choices=None, required=False, help=None, metavar=None)

In [5]:
# others
parser.add_argument("--seed", type=int, default=2, help="random seed")
parser.add_argument("--note", type=str, default="test-c100-r20", help="note for this run")

_StoreAction(option_strings=['--note'], dest='note', nargs=None, const=None, default='test-c100-r20', type=<class 'str'>, choices=None, required=False, help='note for this run', metavar=None)

In [6]:
# net and dataset choosen
parser.add_argument(
    "--data_name", type=str, default="cifar100", required=True, help="name of dataset"
)  # cifar10/cifar100
parser.add_argument(
    "--net_name", type=str, default="resnet20",required=True, help="name of basenet"
)  # resnet20/resnet110

_StoreAction(option_strings=['--net_name'], dest='net_name', nargs=None, const=None, default='resnet20', type=<class 'str'>, choices=None, required=True, help='name of basenet', metavar=None)

In [7]:
args, unparsed = parser.parse_known_args()

args.save_root = os.path.join(args.save_root, args.note)
create_exp_dir(args.save_root)

log_format = "%(message)s"
logging.basicConfig(stream=sys.stdout, level=logging.INFO, format=log_format)
fh = logging.FileHandler(os.path.join(args.save_root, "log.txt"))
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)

usage: ipykernel_launcher.py [-h] [--save_root SAVE_ROOT]
                             [--img_root IMG_ROOT] [--print_freq PRINT_FREQ]
                             [--epochs EPOCHS] [--batch_size BATCH_SIZE]
                             [--lr LR] [--momentum MOMENTUM]
                             [--weight_decay WEIGHT_DECAY]
                             [--num_class NUM_CLASS] [--cuda CUDA]
                             [--seed SEED] [--note NOTE] --data_name DATA_NAME
                             --net_name NET_NAME
ipykernel_launcher.py: error: the following arguments are required: --data_name, --net_name


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [8]:
def test(test_loader, net, criterion):
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    net.eval()

    for i, (img, target) in enumerate(test_loader, start=1):
        if args.cuda:
            img = img.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

        with torch.no_grad():
            _, _, _, _, _, out = net(img)
            loss = criterion(out, target)

        prec1, prec5 = accuracy(out, target, topk=(1, 5))
        losses.update(loss.item(), img.size(0))
        top1.update(prec1.item(), img.size(0))
        top5.update(prec5.item(), img.size(0))

    f_l = [losses.avg, top1.avg, top5.avg]
    logging.info("Loss: {:.4f}, Prec@1: {:.2f}, Prec@5: {:.2f}".format(*f_l))

    return top1.avg, top5.avg

In [9]:
def main():
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        cudnn.enabled = True
        cudnn.benchmark = True
    logging.info("args = %s", args)
    logging.info("unparsed_args = %s", unparsed)

    logging.info("----------- Network Initialization --------------")

    # 加载模型
    net = define_tsnet(name=args.net_name, num_class=args.num_class, cuda=args.cuda)
    checkpoint = torch.load("/home/corner/Knowledge-Distillation-Zoo/results/base/test-c100-r20/initial_r20.pth.tar")
    net.load_state_dict(checkpoint["net"])
    logging.info("%s", net)
    logging.info("param size = %fMB", count_parameters_in_MB(net))
    logging.info("-----------------------------------------------")

    # save initial parameters
    logging.info("Saving initial parameters......")
    save_path = os.path.join(
        args.save_root, "initial_r{}.pth.tar".format(args.net_name[6:])
    )
    torch.save(
        {
            "epoch": 0,
            "net": net.state_dict(),
            "prec@1": 0.0,
            "prec@5": 0.0,
        },
        save_path,
    )

    # define loss functions
    if args.cuda:
        criterion = torch.nn.CrossEntropyLoss().cuda()
    else:
        criterion = torch.nn.CrossEntropyLoss()

    # define transforms
    if args.data_name == "cifar100":
        dataset = dst.CIFAR100
        mean = (0.5071, 0.4865, 0.4409)
        std = (0.2673, 0.2564, 0.2762)
    else:
        raise Exception("Invalid dataset name...")

    train_transform = transforms.Compose(
        [
            transforms.Pad(4, padding_mode="reflect"),
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ]
    )
    test_transform = transforms.Compose(
        [
            transforms.CenterCrop(32),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ]
    )

    # define data loader
    train_loader = torch.utils.data.DataLoader(
        dataset(
            root=args.img_root, transform=train_transform, train=True, download=True
        ),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
    )
    test_loader = torch.utils.data.DataLoader(
        dataset(
            root=args.img_root, transform=test_transform, train=False, download=True
        ),
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
    )

    for epoch in range(1, args.epochs + 1):
        # evaluate on testing set
        logging.info("Testing the models......")
        test_top1, test_top5 = test(test_loader, net, criterion)

        # save model
        logging.info(f"Saving models for epoch{epoch}......")
        save_path = os.path.join(
            args.save_root, "epoch_{}.pth.tar".format(epoch)
        )
        torch.save({
            "epoch": epoch,
            "net": net.state_dict(),
            "prec@1": test_top1,
            "prec@5": test_top5,
        }, save_path)