In [6]:
import argparse
import os
import torch
import random 
import numpy as np 
import torch.optim
import torch.nn as nn
import torch.utils.data
import torch.nn.functional as F
from torch.autograd import Variable
import torch.backends.cudnn as cudnn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt 
from itertools import cycle
import pickle
import sys

from resnetv2 import PreActResNet18 as ResNet18  
from utils import Labeled_dataset


parser = argparse.ArgumentParser(description='PyTorch Cifar10_100 Training')
parser.add_argument('--lr', default=0.1, type=float, help='initial learning rate')
parser.add_argument('--data_dir', help='The directory for data', default='trans_data', type=str)
parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
parser.add_argument('--weight_decay', default=5e-4, type=float, help='weight decay')
parser.add_argument('--epochs', default=100, type=int, help='number of total epochs to run')
parser.add_argument('--print_freq', default=50, type=int, help='print frequency')
parser.add_argument('--decreasing_lr', default='60,80', help='decreasing strategy')
parser.add_argument('--save_dir', help='The directory used to save the trained models', default='cifar10_cil', type=str)
parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
parser.add_argument('--seed', type=int, default=None, help='random seed')
parser.add_argument('--batch_size', default=128, type=int, help='batch size')
parser.add_argument('--load_model', default=False, type=eval, choices=[True, False], help='load last checkpoint to continue training')

best_prec1 = 0

In [2]:
def validate(val_loader, model, criterion, if_main=False):

    losses = AverageMeter()
    top1 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    for i, (input, target) in enumerate(val_loader):
        input = input.cuda()
        target = target.long().cuda()

        # compute output
        with torch.no_grad():
            output = model(input, main_fc=if_main)
            loss = criterion(output, target)

        output = output.float()
        loss = loss.float()

        # measure accuracy and record loss
        prec1 = accuracy(output.data, target)[0]
        losses.update(loss.item(), input.size(0))
        top1.update(prec1.item(), input.size(0))

        if i % args.print_freq == 0:
            print('Test: [{0}/{1}]\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Accuracy {top1.val:.3f} ({top1.avg:.3f})'.format(
                      i, len(val_loader), loss=losses, top1=top1))

    print('valid_accuracy {top1.avg:.3f}'.format(top1=top1))

    return top1.avg



In [3]:
def train(rand_loader, new_balance_loader, old_balance_loader, model, criterion, optimizer, epoch):
    
    losses = AverageMeter()
    top1 = AverageMeter()

    coef_old = int(args.batch_size*4/5)/64
    coef_new = int(args.batch_size/5)/64

    # switch to train mode
    model.train()

    new_balance = iter(new_balance_loader)
    old_balance = iter(old_balance_loader)

    for i, (input, target) in enumerate(rand_loader):

        try:
            bal_new_img, bal_new_target = next(new_balance)
        except StopIteration:
            new_balance = iter(new_balance_loader)
            bal_new_img, bal_new_target = next(new_balance)

        try:
            bal_old_img, bal_old_target = next(old_balance)
        except StopIteration:
            old_balance = iter(old_balance_loader)
            bal_old_img, bal_old_target = next(old_balance)

        bal_new_img = bal_new_img.cuda()
        bal_old_img = bal_old_img.cuda()
        input = input.cuda()

        bal_new_target = bal_new_target.long().cuda()
        bal_old_target = bal_old_target.long().cuda()
        target = target.long().cuda()

        # random input
        output_gt = model(input, main_fc=False)
        loss_rand = criterion(output_gt, target)
        
        
        # balance inputs
        output_bal_new = model(bal_new_img, main_fc=True)
        output_bal_old = model(bal_old_img, main_fc=True)
        loss_balance = criterion(output_bal_new, bal_new_target)*coef_new + criterion(output_bal_old, bal_old_target)*coef_old
        
        if tensor_allNaN(output_gt):
            print('output_gt')
            sys.exit()
        if tensor_allNaN(loss_rand):
            print('loss_rand')
            sys.exit()
        if tensor_allNaN(output_bal_new):
            print('output_bal_new')
            sys.exit()
        if tensor_allNaN(output_bal_old):
            print('output_bal_old')
            sys.exit()
        if tensor_allNaN(loss_balance):
            print('loss_balance')  
            sys.exit()
        
        
        loss = (loss_balance + loss_rand)*0.5

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        output = output_gt.float()
        loss = loss.float()
        # measure accuracy and record loss
        prec1 = accuracy(output.data, target)[0]   
        losses.update(loss.item(), input.size(0))
        top1.update(prec1.item(), input.size(0))

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Accuracy {top1.val:.3f} ({top1.avg:.3f})'.format(
                      epoch, i, len(rand_loader), loss=losses, top1=top1))
            

    print('train_accuracy {top1.avg:.3f}'.format(top1=top1))

    return top1.avg



In [4]:
def save_checkpoint(state, filename='weight.pt'):
    """
    Save the training model
    """
    torch.save(state, filename)



In [5]:
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count



In [4]:
# output.shape = [batch size, num_class]
# a = torch.tensor([[0.1, 0.9, 0.1], [0.8, 0.2, 0.1], [0.7, 0.3, 0.9]])
# b = torch.tensor([1, 1, 2])
# accuracy(a, b, topk=(, 2))
# [tensor(66.6667), tensor(100.)]


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0)
        
        res.append(correct_k.mul_(100.0 / batch_size))
    return res



In [None]:
##### TO-DO #######
# output.shape = [batch size, num_class]

def confusion_classWise(output, target):
    '''class wise confusion matrix'''
    
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0)
        
        res.append(correct_k.mul_(100.0 / batch_size))
    return res



In [7]:
def setup_seed(seed): 
    torch.manual_seed(seed) 
    torch.cuda.manual_seed_all(seed) 
    np.random.seed(seed) 
    random.seed(seed) 
    torch.backends.cudnn.deterministic = True 



In [8]:
def tensor_allNaN(tensor):
    all_Nan = torch.isnan(tensor).all()
    if all_Nan:
        return True
    else:
        return False

# ==== below ================

### prepare data

In [9]:
global args, best_prec1

In [10]:
# jupyter notebook input workaround
# args = parser.parse_args()
args = parser.parse_args(args=['--save_dir', 'output', 
                               '--data_dir', 'trans_data', 
                               '--gpu', '2', 
                               '--epochs', '200', 
                               '--load_model', 'False',
                               '--seed', '1'
                              ])

print(args)

Namespace(batch_size=128, data_dir='trans_data', decreasing_lr='60,80', epochs=200, gpu=2, load_model=False, lr=0.1, momentum=0.9, print_freq=50, save_dir='output', seed=1, weight_decay=0.0005)


In [11]:
path_head = args.data_dir
train_path = os.path.join(path_head,'4500_labeled_images_cifar10_train.pkl')
old_img_path = os.path.join(path_head,'100_labeled_images_cifar10_train.pkl')
val_path = os.path.join(path_head,'500_labeled_images_cifar10_val.pkl')
test_path = os.path.join(path_head,'labeled_images_cifar10_test.pkl')
sequence = np.random.permutation(10)
print('class sequence: ', sequence)

torch.cuda.set_device(int(args.gpu))

if args.seed:
    setup_seed(args.seed)

os.makedirs(args.save_dir, exist_ok=True)

train_trans = transforms.Compose([
        transforms.ToPILImage(),
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, 4),
        transforms.ToTensor()
    ])

val_trans = transforms.Compose([
        transforms.ToPILImage(),
        transforms.ToTensor()
    ])

criterion = nn.CrossEntropyLoss()
criterion = criterion.cuda()
decreasing_lr = list(map(int, args.decreasing_lr.split(',')))

model = ResNet18(num_classes=10)
model.cuda()

starting_epoch = 0

optimizer = torch.optim.SGD(model.parameters(), args.lr,
                            momentum=args.momentum,
                            weight_decay=args.weight_decay)

scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=decreasing_lr, gamma=0.1)


#prepare dataset
train_dataset = Labeled_dataset(train_path, train_trans, sequence[8:10], offset=8)
train_old_dataset = Labeled_dataset(old_img_path, train_trans, sequence[:8], offset=0)
val_dataset = Labeled_dataset(val_path, val_trans, sequence[:10], offset=0)

train_random_dataset = torch.utils.data.dataset.ConcatDataset((train_dataset,train_old_dataset))

class sequence:  [9 1 7 3 5 0 4 6 8 2]
target list =  [8 2]
target list =  [9 1 7 3 5 0 4 6]
target list =  [9 1 7 3 5 0 4 6 8 2]


In [12]:
train_loader_random = torch.utils.data.DataLoader(
    train_random_dataset,
    batch_size=args.batch_size, shuffle=True,
    num_workers=2, pin_memory=True)

train_loader_balance_new = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=int(args.batch_size/5), shuffle=True,
    num_workers=2, pin_memory=True)

train_loader_balance_old = torch.utils.data.DataLoader(
    train_old_dataset,
    batch_size=int(args.batch_size*4/5), shuffle=True,
    num_workers=2, pin_memory=True)

val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=args.batch_size, shuffle=False,
    num_workers=2, pin_memory=True)

### start to train

In [13]:
# if load a checkpoint or not
if args.load_model:
    checkpoint =  torch.load(os.path.join(args.save_dir, 'checkpoint.pt'))
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    scheduler.load_state_dict(checkpoint['scheduler'])
    starting_epoch = checkpoint['epoch']
    best_prec1 = checkpoint['best_prec1']
    print('starting_epoch: ', starting_epoch)
    print('best_prec1: ', best_prec1)



# index, acc
train_acc = [[],[]]
ta_bal = [[],[]]
ta_imba = [[],[]]


for epoch in range(starting_epoch, starting_epoch + args.epochs):
    print('='*50)
    print("The learning rate is {}".format(optimizer.param_groups[0]['lr']))

    train_accuracy = train(train_loader_random, train_loader_balance_new, train_loader_balance_old, model, criterion, optimizer, epoch)
    
    prec1_bal = validate(val_loader, model, criterion, if_main=True)
    prec1_imba = validate(val_loader, model, criterion, if_main=False)
    

    train_acc[0].append(epoch)
    ta_bal[0].append(epoch)
    ta_imba[0].append(epoch)
    
    train_acc[1].append(train_accuracy)
    ta_bal[1].append(prec1_bal)
    ta_imba[1].append(prec1_imba)

    scheduler.step()

    # remember best prec@1 and save checkpoint
    is_best = prec1_bal > best_prec1
    best_prec1 = max(prec1_bal, best_prec1)

    if is_best:
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
        }, filename=os.path.join(args.save_dir, 'best_model.pt'))

    save_checkpoint({
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'best_prec1': best_prec1,
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
    }, filename=os.path.join(args.save_dir, 'checkpoint.pt'))

    plt.plot(train_acc[0], train_acc[1], label='train_acc')
    plt.plot(ta_bal[0], ta_bal[1], label='TA_bal')
    plt.plot(ta_imba[0], ta_imba[1], label='TA_imba')
    plt.legend()
    plt.grid(True)
    plt.savefig(os.path.join(args.save_dir, 'net_train.png'))
    plt.close()




The learning rate is 0.1
Epoch: [0][0/77]	Loss 3.3693 (3.3693)	Accuracy 20.312 (20.312)
Epoch: [0][50/77]	Loss 4.8590 (19.7423)	Accuracy 39.062 (44.531)
train_accuracy 45.449
Test: [0/40]	Loss 9.1246 (9.1246)	Accuracy 0.000 (0.000)
valid_accuracy 8.920
Test: [0/40]	Loss 6.2905 (6.2905)	Accuracy 0.000 (0.000)
valid_accuracy 10.140
The learning rate is 0.1
Epoch: [1][0/77]	Loss 8.5237 (8.5237)	Accuracy 35.156 (35.156)
Epoch: [1][50/77]	Loss 2.9866 (7.6616)	Accuracy 46.094 (48.284)
train_accuracy 47.786
Test: [0/40]	Loss 2.2665 (2.2665)	Accuracy 0.000 (0.000)
valid_accuracy 10.220
Test: [0/40]	Loss 4.4859 (4.4859)	Accuracy 0.000 (0.000)
valid_accuracy 10.000
The learning rate is 0.1
Epoch: [2][0/77]	Loss 2.9102 (2.9102)	Accuracy 44.531 (44.531)
Epoch: [2][50/77]	Loss 2.8086 (2.8577)	Accuracy 40.625 (45.389)
train_accuracy 45.235
Test: [0/40]	Loss 2.2977 (2.2977)	Accuracy 0.000 (0.000)
valid_accuracy 10.340
Test: [0/40]	Loss 4.4767 (4.4767)	Accuracy 0.000 (0.000)
valid_accuracy 10.400
The 