In [1]:
from resnet import resnet32
from cifar import get_cifar10
import utils 
import pencil as pencil_lib

import torchvision.transforms as transforms
import torch
import torch.optim as optim
import torch.nn as nn
from scipy.special import softmax

import copy as cp
import os
import pickle
import numpy as np

# Enumerate GPUs in PCI bus order and expose only device "6" to this process.
# NOTE(review): the device index is hard-coded for one particular machine.
os.environ.update(
    {
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        "CUDA_VISIBLE_DEVICES": "6",
    }
)

if True:  # Define hyper parameters
    # --- SGD / schedule ---
    epochs, batch_size = 320, 128
    lr, momentum = 0.06, 0.9
    # Originally written as 10e-4, which is 1e-3 (not 1e-4).
    # NOTE(review): confirm 1e-3 is the intended weight decay.
    weight_decay = 1e-3
    n_classes = 10

    # --- PENCIL settings (all forwarded to pencil_lib.PENCIL) ---
    # The three entries of pencil_epochs add up to `epochs`.
    pencil_epochs = [70, 130, 120]
    pencil_lrs = [0.06, 0.06, 0.06]
    pencil_alpha = 0.1
    pencil_beta = 0.4
    pencil_gamma = 600
    pencil_KL = True
    pencil_K = 10

    # --- checkpointing / resume ---
    # `checkpoint` is only referenced by the commented-out save code;
    # a `start_epoch` above 0 triggers loading of saved state.
    checkpoint = 5
    start_epoch = 0

    # --- data ---
    seed = 1  # fed to np.random.seed before the datasets are built
    train_ratio = 0.9
    asym_noise = True
    noise_ratio = 0.2
    temp_path = 'PATH_HERE'  # prefix for checkpoint/result files
    cifar_path = './../data'

if True:  # Define dataset and transforms
    # Training images get random crop + horizontal flip augmentation;
    # validation images are only converted to tensors and normalised.
    # Both pipelines share the same per-channel mean / std.
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    transform_val = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    # Seed NumPy right before the datasets are built.
    # NOTE(review): presumably get_cifar10 draws the train/val split and the
    # label noise from NumPy's global RNG, which would make them identical on
    # every run (including resumed ones) — confirm in cifar.py.
    np.random.seed(seed)
    trainset, valset = get_cifar10(
        cifar_path,
        train_ratio=train_ratio,
        asym=asym_noise,
        percent=noise_ratio,
        train=True,
        download=False,
        transform_train=transform_train,
        transform_val=transform_val,
    )
    print('trainset', len(trainset))

    # Build the loaders unconditionally. They used to be created only when
    # start_epoch == 0 (the pickle-based reload branch was commented out), so
    # resuming from a checkpoint failed with a NameError on `trainloader`.
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4)
    valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False, num_workers=4)

if True:  # Define Model and training hyper parameters

    # Network: build resnet32 and move it to the GPU in one step.
    print("==> creating preact_resnet")
    model = resnet32().cuda()
    n_params = sum(p.numel() for p in model.parameters())
    print('Total params: %.2fM' % (n_params / 1000000.0))

    # Loss and optimiser. `criterion` is not referenced by the training loop
    # in this cell, which takes its loss from pencil.get_loss instead.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(
        model.parameters(),
        lr=lr,
        momentum=momentum,
        weight_decay=weight_decay,
    )

    # PENCIL helper, sized to the number of training samples and classes.
    pencil = pencil_lib.PENCIL(
        len(trainset), n_classes, pencil_epochs, pencil_lrs,
        pencil_alpha, pencil_beta, pencil_gamma,
        save_losses=True, use_KL=pencil_KL, K=pencil_K,
    )

if True:  # Define dict and list for storing results
    # One list per metric for each of the two subsets; a value is appended
    # to each list after every epoch.
    metrics = ['acc', 'auprc', 'auroc', 'loss']
    results = {
        subset: {metric: [] for metric in metrics}
        for subset in ('train', 'val')
    }

    # Extra bookkeeping about the labels PENCIL is estimating.
    for key in ('pencil_labels', 'labels', 'true_labels',
                'pencil_label_drift', 'pencil_label_correct'):
        results[key] = []

    # Record, for every training sample, element 1 and element 3 of the
    # dataset item (stored under 'labels' and 'true_labels' respectively).
    for sample_idx in range(len(trainset)):
        sample = trainset[sample_idx]
        results['labels'].append(sample[1])
        results['true_labels'].append(sample[3])

if start_epoch > 0:  # Load checkpoint data
    # Restore the model weights, the results dict and the PENCIL object that
    # were saved under `temp_path` for epoch `start_epoch`.
    # NOTE(review): the optimizer state is not restored here.
    # NOTE(review): pickle.load executes arbitrary code — only load files
    # written by this script.
    ckpt_tag = str(start_epoch)
    model.load_state_dict(torch.load(temp_path + 'epoch' + ckpt_tag))
    with open(temp_path + "results" + ckpt_tag + ".pkl", "rb") as fh:
        results = pickle.load(fh)
    with open(temp_path + "pencil" + ckpt_tag + ".pkl", "rb") as fh:
        pencil = pickle.load(fh)
    # Override the pickled alpha / beta with the values configured above.
    pencil.alpha, pencil.beta = pencil_alpha, pencil_beta
    
# Main loop: one training pass (with PENCIL label updates), one validation
# pass, then a progress report, for every epoch from start_epoch to epochs-1.
for epoch in range(start_epoch, epochs):

    if True:  # Train
        pencil.set_lr(optimizer, epoch)
        model.train()
        preds, label_list = [], []
        # Snapshot the current label estimates at selected epochs.
        # NOTE(review): pencil_epochs sums to `epochs`, so its entries look
        # like per-stage durations; if so, this membership test fires at
        # epochs 70, 120 and 130 rather than at the stage boundaries — confirm.
        if epoch in pencil_epochs:
            results['pencil_labels'].append(cp.deepcopy(pencil.y_tilde))

        # gtrue_labels is unpacked but not used during training.
        for inputs, labels, inds, gtrue_labels in trainloader:
            inputs, labels = inputs.cuda(), labels.cuda()
            outputs = model(inputs)
            loss = pencil.get_loss(epoch, outputs, labels, inds)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            pencil.update_y_tilde(epoch, inds)
            preds.append(outputs.detach())
            label_list.append(labels.cpu())

        # Collect the epoch's logits and labels and turn logits into probabilities.
        label_list = np.vstack(torch.cat(label_list).numpy())
        preds = softmax(np.vstack(torch.cat(preds).cpu().numpy()), axis=1)
        train_results = utils.get_performance_metrics(num_classes=n_classes, preds=preds, label_list=label_list)
        for result in train_results.keys():
            results['train'][result].append(train_results[result])

        # Hard labels implied by the current y_tilde. Computed once per epoch
        # (y_tilde is not modified again until the next training pass) and
        # compared with vectorised NumPy ops instead of Python-level sum().
        y_tilde_hard = np.argmax(pencil.y_tilde, axis=1)
        label_drift = round(np.mean(np.asarray(results['labels']) == y_tilde_hard), 4)
        label_true_corr = round(np.mean(np.asarray(results['true_labels']) == y_tilde_hard), 4)
        results['pencil_label_drift'].append(label_drift)
        results['pencil_label_correct'].append(label_true_corr)

    if True:  # Val
        model.eval()
        val_preds, val_labels = [], []
        with torch.no_grad():
            for inputs, labels in valloader:
                inputs = inputs.cuda()  # labels stay on the CPU
                outputs = model(inputs)
                val_preds.append(outputs)
                val_labels.append(labels)
        val_labels = np.vstack(torch.cat(val_labels).numpy())
        val_preds = softmax(np.vstack(torch.cat(val_preds).cpu().numpy()), axis=1)
        val_results = utils.get_performance_metrics(num_classes=n_classes, preds=val_preds, label_list=val_labels)
        for result in val_results.keys():
            results['val'][result].append(val_results[result])

    if True:  # print results
        print('Epoch', epoch, 'Current LR', round(optimizer.param_groups[0]['lr'],5))
        if epoch >= pencil_epochs[0]:
            # Compare against the first stored snapshot of y_tilde.
            first_snapshot_hard = np.argmax(results['pencil_labels'][0], axis=1)
            print('Percent of estimated labels that have remained the same since end of phase 1',round(np.mean(first_snapshot_hard == y_tilde_hard),4))
        print('% y_tilde equal to noisy label',results['pencil_label_drift'][-1])
        print('% y_tilde equal to true label',results['pencil_label_correct'][-1])
        print('Train acc %.3f, Train auroc %.3f' % (results['train']['acc'][-1], results['train']['auroc'][-1]))
        print('Val acc %.3f, Val auroc %.3f' % (results['val']['acc'][-1], results['val']['auroc'][-1]))
         
#     if checkpoint!=None and (epoch%checkpoint==0 or epoch+1 in pencil_epochs): #Save model for this epoch
#         if epoch%checkpoint==0 and epoch>start_epoch+checkpoint: 
#             os.remove(temp_path+'epoch'+str(epoch-checkpoint))
#             os.remove(temp_path+"results"+str(epoch-checkpoint)+".pkl")
#             os.remove(temp_path+"pencil"+str(epoch-checkpoint)+".pkl")
#         torch.save(model.state_dict(), temp_path+'epoch'+str(epoch))
#         with open(temp_path+"results"+str(epoch)+".pkl", "wb") as file:
#             pickle.dump(results,file)
#         with open(temp_path+"pencil"+str(epoch)+".pkl", "wb") as file:
#             pickle.dump(pencil,file)

Train: 45000 Val: 5000
trainset 45000
==> creating preact_resnet
Total params: 0.47M
Epoch 0 Current LR 0.06
% y_tilde equal to noisy label 1.0
% y_tilde equal to true label 0.9
Train acc 0.359, Train auroc 0.819
Val acc 0.386, Val auroc 0.862
Epoch 1 Current LR 0.06
% y_tilde equal to noisy label 1.0
% y_tilde equal to true label 0.9
Train acc 0.530, Train auroc 0.898
Val acc 0.443, Val auroc 0.865


KeyboardInterrupt: 