In [1]:
%load_ext autoreload
%autoreload 2

import os, argparse, sys, time
sys.path.append('..')
from pathlib import Path
import torch
from torch import nn, optim
from tensorboardX import SummaryWriter

# from torchvision import datasets, transforms
#from VIB.model import Net
from SimpleClass.DataLoader import loadCIFAR10
from VIB.vgg_model import VGG_IB
from VIB.training import train
from VIB.evaluation import validate
from VIB.default_params import load_parser

In [2]:
# Build the project's argument parser and parse an empty argv so that every
# option takes its default value (there is no command line inside a notebook).
parser = load_parser(argparse.ArgumentParser())
args = parser.parse_args([])
print(args)

# Use the first CUDA device when one is visible, otherwise stay on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

Namespace(ban_crop=False, ban_flip=False, batch_norm=False, batchsize=128, cfg='D0', data_set='cifar10', epochs=300, gpu=0, ib_lr=-1, ib_wd=-1, init_var=0.01, kl_fac=1e-06, lr=0.1, lr_epoch=30, lr_fac=0.5, mag=9, momentum=0.9, no_ib=False, opt='sgd', print_freq=50, reg_weight=0, resume='', resume_vgg_pt='', resume_vgg_vib='', sample_test=0, sample_train=1, save_dir='../models/ib_vgg', tb_path='../tb/ib_vgg', threshold=0, val=False, weight_decay=0.0001, workers=1)


In [3]:
import torchvision
from torchvision import datasets, transforms

In [4]:
    # A value of -1 for the IB-specific learning rate / weight decay means
    # "inherit the value used for the rest of the network".
    if args.ib_lr == -1:
        args.ib_lr = args.lr
    if args.ib_wd == -1:
        args.ib_wd = args.weight_decay

    # Make sure the TensorBoard directory exists *before* the writer is
    # pointed at it; exist_ok avoids the racy exists()/makedirs() pair.
    os.makedirs(args.tb_path, exist_ok=True)
    writer = SummaryWriter(args.tb_path)

    # Any data_set value other than 'cifar10' is treated as CIFAR-100.
    is_cifar10 = args.data_set == 'cifar10'
    n_cls = 10 if is_cifar10 else 100
    # Kept only as a descriptive label; the class itself is selected directly
    # below rather than by eval()-ing this string.
    dset_string = 'datasets.CIFAR10' if is_cifar10 else 'datasets.CIFAR100'
    dataset_cls = datasets.CIFAR10 if is_cifar10 else datasets.CIFAR100

    # NOTE(review): these mean/std values look like the commonly used ImageNet
    # statistics rather than CIFAR-specific ones -- confirm this is intended.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    # Training pipeline: optional random crop (32x32, padding 4), then optional
    # horizontal flip, then tensor conversion and normalisation.
    train_tfms = [transforms.ToTensor(), normalize]
    if not args.ban_flip:
        train_tfms = [transforms.RandomHorizontalFlip()] + train_tfms
    if not args.ban_crop:
        train_tfms = [transforms.RandomCrop(32, 4)] + train_tfms
    train_transform = transforms.Compose(train_tfms)
    # Validation uses no augmentation. (The variable name keeps its original
    # spelling so any other cell that refers to it continues to work.)
    val_transorm = transforms.Compose([transforms.ToTensor(), normalize])

    # The training dataset is built first with download=True, so the files are
    # already on disk when the validation split is constructed.
    train_loader = torch.utils.data.DataLoader(
        dataset_cls(root='../data', train=True, transform=train_transform, download=True),
        batch_size=args.batchsize, shuffle=True, num_workers=args.workers, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        dataset_cls(root='../data', train=False, transform=val_transorm),
        batch_size=args.batchsize, shuffle=False, num_workers=args.workers, pin_memory=True)

Files already downloaded and verified


In [5]:
    # Instantiate the VGG network with information-bottleneck layers and move
    # it to the selected device.
    model = VGG_IB(config=args.cfg, mag=args.mag, batch_norm=args.batch_norm,
                   threshold=args.threshold, init_var=args.init_var,
                   sample_in_training=args.sample_train, sample_in_testing=args.sample_test,
                   n_cls=n_cls, no_ib=args.no_ib, device=device).to(device)

    # Split parameters into two groups by name so that the IB parameters
    # (names containing 'z_mu' or 'z_logD') can get their own learning rate
    # and weight decay; everything else goes into the second group.
    ib_param_list, ib_name_list, cnn_param_list, cnn_name_list = [], [], [], []
    for name, param in model.named_parameters():
        if 'z_mu' in name or 'z_logD' in name:
            ib_param_list.append(param)
            ib_name_list.append(name)
        else:
            cnn_param_list.append(param)
            cnn_name_list.append(name)
    print('detected VIB params ({}): {}'.format(len(ib_name_list), ib_name_list))
    print('detected VGG params ({}): {}'.format(len(cnn_name_list), cnn_name_list))
    print('Learning rate of IB: {}, learning rate of others: {}'.format(args.ib_lr, args.lr))

    # Group order matters: main() rewrites param_groups[0] (IB) and
    # param_groups[1] (remaining parameters) every epoch.
    param_groups = [
        {'params': ib_param_list, 'lr': args.ib_lr, 'weight_decay': args.ib_wd},
        {'params': cnn_param_list, 'lr': args.lr, 'weight_decay': args.weight_decay},
    ]
    opt_name = args.opt.lower()
    if opt_name == 'sgd':
        optimizer = torch.optim.SGD(param_groups, momentum=args.momentum)
    elif opt_name == 'adam':
        optimizer = torch.optim.Adam(param_groups)
    else:
        # Fail here instead of leaving `optimizer` undefined (or stale from an
        # earlier run of this cell) and hitting a confusing error in main().
        raise ValueError("Unsupported optimizer {!r}; expected 'sgd' or 'adam'".format(args.opt))

Using structure [(64, 1), (64, 1), 'M', (128, 1), (128, 1), 'M', (256, 1), (256, 1), (256, 1), 'M', (512, 1), (512, 1), (512, 1), 'M', (512, 1), (512, 1), (512, 1), 'M']
detected VIB params (45): ['conv_layers.2.prior_z_logD', 'conv_layers.2.post_z_mu', 'conv_layers.2.post_z_logD', 'conv_layers.5.prior_z_logD', 'conv_layers.5.post_z_mu', 'conv_layers.5.post_z_logD', 'conv_layers.9.prior_z_logD', 'conv_layers.9.post_z_mu', 'conv_layers.9.post_z_logD', 'conv_layers.12.prior_z_logD', 'conv_layers.12.post_z_mu', 'conv_layers.12.post_z_logD', 'conv_layers.16.prior_z_logD', 'conv_layers.16.post_z_mu', 'conv_layers.16.post_z_logD', 'conv_layers.19.prior_z_logD', 'conv_layers.19.post_z_mu', 'conv_layers.19.post_z_logD', 'conv_layers.22.prior_z_logD', 'conv_layers.22.post_z_mu', 'conv_layers.22.post_z_logD', 'conv_layers.26.prior_z_logD', 'conv_layers.26.post_z_mu', 'conv_layers.26.post_z_logD', 'conv_layers.29.prior_z_logD', 'conv_layers.29.post_z_mu', 'conv_layers.29.post_z_logD', 'conv_layer

In [6]:
    # Classification loss, placed on the same device as the model.
    criterion = torch.nn.CrossEntropyLoss().to(device)
    # Turn on cuDNN's benchmark mode.
    torch.backends.cudnn.benchmark = True

In [7]:
def main():
    """Optionally restore weights, then either validate once or run training.

    Reads the notebook-level globals ``args``, ``model``, ``optimizer``,
    ``criterion``, ``train_loader``, ``val_loader``, ``writer`` and ``device``.

    Steps:
      * ``args.resume``: restore model (and optimizer, when stored) from an
        interrupted run and continue from the stored epoch.
      * ``args.resume_vgg_pt``: copy weights from a checkpoint by key position,
        using a different per-layer stride for source and destination.
      * ``args.resume_vgg_vib``: copy weights from a checkpoint whose keys are
        indexed with the same positions as this model's.
      * ``args.val``: run a single validation pass and return.
      * otherwise train from ``start_epoch`` to ``args.epochs``, writing
        ``last_epoch.pth`` every epoch and ``best_prune_acc.pth`` whenever the
        validation accuracy improves, both under ``args.save_dir``.
    """
    os.makedirs(args.save_dir, exist_ok=True)
    start_epoch = 0
    if args.resume != '':
        # resume from interrupted training
        state_dict = torch.load(args.resume, map_location=lambda storage, loc: storage)
        model.load_state_dict(state_dict['state_dict'])
        if 'opt_state_dict' in state_dict:
            optimizer.load_state_dict(state_dict['opt_state_dict'])
        model.print_compression_ratio(args.threshold)
        start_epoch = state_dict['epoch']
        print('loaded checkpoint {} at epoch {} with acc {}'.format(args.resume, state_dict['epoch'], state_dict['prec1']))
    if args.resume_vgg_pt:
        # VGG model trained without IB params
        state_dict = torch.load(args.resume_vgg_pt, map_location='cpu')
        try:
            print('loaded pretraind model with acc {}'.format(state_dict['best_prec1']))
        except KeyError:
            # The accuracy entry is optional; its absence is not an error.
            pass
        # match the state dicts
        model_state, src_state = model.state_dict(), state_dict['state_dict']
        # dict views cannot be indexed on Python 3, so materialise the keys.
        ib_keys, vgg_keys = list(model_state.keys()), list(src_state.keys())
        # NOTE(review): presumably 13 conv blocks with 9 state entries each in
        # this model versus 6 in the source, of which the first 6 line up --
        # confirm against the VGG_IB definition.
        for i in range(13):
            for j in range(6):
                model_state[ib_keys[i*9+j]].copy_(src_state[vgg_keys[i*6+j]])
        ib_offset, vgg_offset = 9*13, 6*13
        # NOTE(review): presumably 3 trailing layers with 5 entries each here
        # versus 2 in the source -- confirm.
        for i in range(3):
            for j in range(2):
                model_state[ib_keys[ib_offset + i*5 + j]].copy_(src_state[vgg_keys[vgg_offset + i*2+j]])
    if args.resume_vgg_vib:
        # Checkpoint indexed with this model's own key layout (presumably a
        # VGG model trained *with* IB params -- confirm).
        # map_location='cpu' matches the other loads and lets a checkpoint
        # saved on a GPU be read on a CPU-only machine; copy_ below moves the
        # values onto whatever device the model lives on.
        state_dict = torch.load(args.resume_vgg_vib, map_location='cpu')
        print('loaded pretraind model with acc {}'.format(state_dict['prec1']))
        # match the state dicts
        model_state, src_state = model.state_dict(), state_dict['state_dict']
        ib_keys, vgg_keys = list(model_state.keys()), list(src_state.keys())
        for i in range(13):
            for j in range(6):
                model_state[ib_keys[i*9+j]].copy_(src_state[ib_keys[i*9+j]])
        ib_offset, vgg_offset = 9*13, 6*13
        # NOTE(review): only 2 trailing layers are copied here, versus 3 in the
        # resume_vgg_pt branch -- confirm whether skipping the last one is
        # intentional.
        for i in range(2):
            for j in range(2):
                model_state[ib_keys[ib_offset + i*5 + j]].copy_(src_state[vgg_keys[ib_offset + i*5 + j]])
    if args.val:
        model.eval()
        validate(val_loader, model, criterion, 0, None, device, args)
        return
    # NOTE(review): best_acc is not restored when resuming, so the first epoch
    # after a resume always overwrites best_prune_acc.pth.
    best_acc = -1
    for epoch in range(start_epoch, args.epochs):
        # Step decay: multiply the base rates by lr_fac once every lr_epoch
        # epochs. Group 0 holds the IB parameters, group 1 everything else.
        decay = args.lr_fac ** (epoch//args.lr_epoch)
        optimizer.param_groups[0]['lr'] = args.ib_lr * decay
        optimizer.param_groups[1]['lr'] = args.lr * decay
        train(train_loader, model, criterion, optimizer, epoch, writer, device, args)
        model.print_compression_ratio(args.threshold, writer, epoch)
        prune_acc = validate(val_loader, model, criterion, epoch, writer, device, args)
        writer.add_scalar('test_acc', prune_acc, epoch)
        # One snapshot serves both files: when this epoch is the new best,
        # best_acc equals prune_acc, so the stored 'prec1' is the same.
        checkpoint = {
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'opt_state_dict': optimizer.state_dict(),
            'prec1': prune_acc,
        }
        if prune_acc > best_acc:
            best_acc = prune_acc
            torch.save(checkpoint, os.path.join(args.save_dir, 'best_prune_acc.pth'))
        torch.save(checkpoint, os.path.join(args.save_dir, 'last_epoch.pth'))
    print('Best accuracy: {}'.format(best_acc))

In [8]:
# Run the training / validation entry point defined in the previous cell.
# NOTE(review): the saved output of this cell ends in a RuntimeError about
# torch.FloatTensor input vs torch.cuda.FloatTensor weights. Presumably a
# batch is not moved to `device` before the forward pass inside the imported
# train/validate/model code -- confirm there.
main()

kl fac:1e-06


RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same