<a href="https://colab.research.google.com/github/AstrakhantsevaAA/confidence_estimation_resnet/blob/master/confidence_train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
%%writefile requirements.txt
# Python packages used by this notebook (versions are not pinned).
# TensorBoard support is installed separately (tf-nightly) in a later cell.
torch
torchvision
tqdm
pillow
matplotlib

Writing requirements.txt


In [0]:
# %pip (unlike !pip) installs into the environment the running kernel uses.
%pip install -r requirements.txt

In [0]:
# -y answers pip's "Proceed (y/n)?" prompt up front: a notebook shell cannot
# reliably take interactive input, so without it the uninstall can hang or abort.
%pip uninstall -y tensorboard tensorflow

In [0]:
# NOTE(review): presumably installed to get a TensorBoard build recent enough for
# torch.utils.tensorboard -- TODO confirm this is still needed.
# %pip targets the kernel's environment (a bare !pip may not).
%pip install --ignore-installed tf-nightly

In [0]:
# Register the %tensorboard magic used to show the TensorBoard dashboard inline.
%load_ext tensorboard

In [0]:
import numpy as np
from tqdm import tqdm
from PIL import Image
from io import BytesIO
import matplotlib.pyplot as plt
import os
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms, models
from google.colab import files

# from model import resnet 
import resnet #upload resnet.py to Google Colab


In [0]:
def encode_onehot(labels, n_classes):
    """Return a float32 one-hot encoding of integer class labels.

    Args:
        labels: integer tensor of class indices, each in [0, n_classes).
        n_classes: number of classes (width of the encoding).

    Returns:
        A (labels.numel(), n_classes) float32 tensor on the same device as
        `labels`, holding 1. in each label's column and 0. elsewhere.
    """
    index = labels.detach().view(-1, 1)
    # Allocate directly on labels.device: `.cuda()` would pick the *current* GPU
    # (not necessarily the labels' one) and ignores non-CUDA devices entirely.
    onehot = torch.zeros(index.size(0), n_classes, dtype=torch.float32,
                         device=labels.device)
    return onehot.scatter_(1, index, 1.)

In [0]:
# Reproducibility. torch.manual_seed seeds the CPU generator (model init,
# DataLoader shuffling, CPU-side sampling) *and* every CUDA device, whereas
# torch.cuda.manual_seed alone leaves the CPU generator off the chosen seed.
np.random.seed(0)
torch.manual_seed(0)
cudnn.deterministic = True
cudnn.benchmark = False

best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch

# baseline=True trains the classifier without the confidence term (budget 0);
# otherwise train_model adapts the confidence-loss weight toward `budget`.
baseline = False
budget = 0. if baseline else 0.9

filename = f'CIFAR_resnet_{budget}'

# Per-channel normalization statistics.
# NOTE(review): these are the ImageNet mean/std, not CIFAR-10's; kept as-is so
# results stay comparable with earlier runs.
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

In [0]:
# Image preprocessing: the training set gets light augmentation (random 32x32
# crop from a 4-pixel padded image, random horizontal flip); the test set is only
# converted to a tensor. Both are normalized with the config-cell mean/std.
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

# CIFAR-10 train/test splits, downloaded into data/ if not already present.
num_classes = 10
train_dataset = datasets.CIFAR10(
    root='data/', train=True, transform=train_transform, download=True)
test_dataset = datasets.CIFAR10(
    root='data/', train=False, transform=test_transform, download=True)

# Data loaders (input pipeline). Only the training loader shuffles, so the
# evaluation order is fixed from epoch to epoch.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=128, shuffle=True,
    pin_memory=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset, batch_size=100, shuffle=False,
    pin_memory=True, num_workers=2)


In [0]:
# Optional: resume from ./checkpoint/ckpt.pth. Off by default so the notebook runs
# top-to-bottom on a fresh kernel. To resume, set `resume = True` and run this cell
# AFTER the model-construction cell, because it loads the weights into `cnn`.
resume = False
if resume:
    print('==> Resuming from checkpoint..')
    assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
    assert 'cnn' in globals(), 'Error: build `cnn` (model cell) before resuming!'
    # map_location='cpu' lets a GPU-saved checkpoint load on any machine;
    # load_state_dict then copies the tensors onto the model's own device.
    checkpoint = torch.load('./checkpoint/ckpt.pth', map_location='cpu')
    cnn.load_state_dict(checkpoint['net'])
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']

In [0]:
# ResNet-18 from the local resnet.py; train_model expects its forward pass to
# return a pair (class scores, confidence score).
cnn = resnet.resnet18(num_classes=num_classes)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
cnn = cnn.to(device)
# Compare device.type: a torch.device never compares equal to the string 'cuda'
# (and this one is 'cuda:0'), so `device == 'cuda'` is always False. Only wrap when
# there are several GPUs to split each batch across; note that wrapping prefixes
# state_dict keys with 'module.'.
if device.type == 'cuda' and torch.cuda.device_count() > 1:
    cnn = torch.nn.DataParallel(cnn)

prediction_criterion = nn.CrossEntropyLoss()
cnn_optimizer = torch.optim.SGD(cnn.parameters(), lr=1.0e-3,
                                momentum=0.9, weight_decay=5e-4)
# Multiply the learning rate by 0.2 at epochs 60, 120 and 160.
scheduler = MultiStepLR(cnn_optimizer, milestones=[60, 120, 160], gamma=0.2)

In [0]:
# TensorBoard logging: with no log_dir, SummaryWriter writes event files under
# ./runs/, which is the directory the inline dashboard below reads from.
writer = SummaryWriter()
%tensorboard --logdir=runs

In [0]:
def _run_phase(model, loss, optimizer, phase, dataloader, lmbda):
    """Run one pass of `model` over `dataloader` in the 'train' or 'val' phase.

    Reads the notebook globals `device`, `baseline`, `budget` and `num_classes`.

    Args:
        model: network whose forward returns (class scores, confidence score);
            the confidence is assumed to be shaped (batch, 1) -- TODO confirm.
        loss: classification criterion, applied to log-probabilities.
        optimizer: stepped once per batch, in the 'train' phase only.
        phase: 'train' (hints, backward pass, lambda updates) or 'val'.
        dataloader: iterable of (images, labels) batches.
        lmbda: current weight of the confidence loss.

    Returns:
        (stats, lmbda). `stats` holds the batch-averaged 'loss' and 'conf_loss',
        the per-sample accuracy 'acc', and 'confidences' (flat list of per-sample
        confidences, filled only in 'val'). `lmbda` is the confidence-loss weight
        after this phase's budget adjustments.
    """
    is_train = phase == 'train'
    model.train(is_train)  # train(False) is eval mode
    eps = 1e-12

    running_loss = 0.
    running_conf_loss = 0.
    n_correct = 0
    n_seen = 0
    confidences = []

    for images, labels in tqdm(dataloader):
        images = images.to(device)
        labels = labels.to(device)

        if is_train:
            optimizer.zero_grad()

        with torch.set_grad_enabled(is_train):
            scores, conf_scores = model(images)
            # Clamp in both phases so the log() calls below stay finite.
            probs = torch.clamp(F.softmax(scores, dim=-1), eps, 1. - eps)
            confidence = torch.clamp(torch.sigmoid(conf_scores), eps, 1. - eps)

            if is_train and not baseline:
                # Hints: each sample independently gets b = 1 with probability 0.5.
                # b = 0 forces conf to 1 (no hint); b = 1 interpolates the
                # prediction toward the one-hot target by (1 - confidence).
                labels_onehot = encode_onehot(labels, num_classes)
                b = torch.bernoulli(torch.full_like(confidence, 0.5))
                conf = confidence * b + (1 - b)
                probs = probs * conf + labels_onehot * (1 - conf)

            # Each row of `probs` sums to ~1, so log_softmax(log(probs)) == log(probs):
            # CrossEntropyLoss on these log-probabilities equals their NLL.
            log_probs = torch.log(probs)
            xentropy_loss = loss(log_probs, labels)
            confidence_loss = torch.mean(-torch.log(confidence))

            if is_train:
                if baseline:
                    total_loss = xentropy_loss
                else:
                    total_loss = xentropy_loss + (lmbda * confidence_loss)
                    # Budget controller: lower lmbda while the confidence loss is
                    # below `budget`, raise it once it reaches the budget.
                    if budget > confidence_loss.item():
                        lmbda = lmbda / 1.01
                    elif budget <= confidence_loss.item():
                        lmbda = lmbda / 0.99
                total_loss.backward()
                optimizer.step()

        # Accuracy uses the un-hinted scores: hinted predictions leak the label.
        pred_idx = scores.argmax(dim=1)
        running_loss += xentropy_loss.item()
        running_conf_loss += confidence_loss.item()
        n_correct += (pred_idx == labels).sum().item()
        n_seen += labels.size(0)
        if not is_train:
            confidences.extend(confidence.detach().cpu().flatten().tolist())

    n_batches = max(len(dataloader), 1)
    stats = {
        'loss': running_loss / n_batches,
        'conf_loss': running_conf_loss / n_batches,
        'acc': n_correct / max(n_seen, 1),
        'confidences': confidences,
    }
    return stats, lmbda


def train_model(model, loss, optimizer, scheduler, num_epochs):
    """Train `model` with the confidence-branch objective, validating every epoch.

    Each epoch runs a 'train' pass over `train_loader`, then a 'val' pass over
    `test_loader`. Epochs are numbered from the global `start_epoch`.

    Args:
        model: network whose forward returns (class scores, confidence score).
        loss: classification criterion, applied to log-probabilities.
        optimizer: optimizer over `model`'s parameters.
        scheduler: LR scheduler, stepped with the epoch index after each train pass.
        num_epochs: number of epochs to run.

    Side effects:
        - logs Loss/ConfLoss/Accuracy scalars per phase to the global `writer`;
        - after every phase, saves that phase's per-epoch history (dict with
          'accuracy', 'loss', 'confidence_loss' lists) to
          ./accs_losses/{phase}_accs_losses_{budget}.pth;
        - when validation accuracy beats the global `best_acc`, updates it and
          saves {'net', 'acc', 'epoch'} to ./checkpoint/ckpt.pth, where 'epoch'
          is the next epoch to run (usable as `start_epoch` when resuming).

    Returns:
        The trained `model` (it is trained in place).
    """
    global best_acc
    lmbda = 0.1  # initial confidence-loss weight, adapted per batch toward `budget`
    # Separate histories per phase. Plain float lists (rather than numpy arrays)
    # keep the saved files loadable with torch.load(..., weights_only=True).
    history = {phase: {'accuracy': [], 'loss': [], 'confidence_loss': []}
               for phase in ('train', 'val')}
    last_epoch = start_epoch + num_epochs - 1

    for epoch in range(start_epoch, start_epoch + num_epochs):
        print(f'Epoch {epoch}/{last_epoch}:', flush=True)

        for phase, dataloader in (('train', train_loader), ('val', test_loader)):
            stats, lmbda = _run_phase(model, loss, optimizer, phase, dataloader, lmbda)

            print(f'\n {phase} Loss: {stats["loss"]:.4f} '
                  f'Confidence Loss: {stats["conf_loss"]:.4f} Acc: {stats["acc"]:.4f}',
                  flush=True)
            if stats['confidences']:
                conf_values = np.asarray(stats['confidences'])
                print(f'conf_min: {conf_values.min():.3f}, conf_max: {conf_values.max():.3f}, '
                      f'conf_avg: {conf_values.mean():.3f}')

            writer.add_scalar(f'Loss/{phase}', stats['loss'], epoch)
            writer.add_scalar(f'ConfLoss/{phase}', stats['conf_loss'], epoch)
            writer.add_scalar(f'Accuracy/{phase}', stats['acc'], epoch)

            phase_history = history[phase]
            phase_history['accuracy'].append(stats['acc'])
            phase_history['loss'].append(stats['loss'])
            phase_history['confidence_loss'].append(stats['conf_loss'])
            os.makedirs('accs_losses', exist_ok=True)
            torch.save(phase_history, f'./accs_losses/{phase}_accs_losses_{budget}.pth')

            if phase == 'train':
                # Epoch index passed explicitly so a resumed run (start_epoch > 0)
                # stays on the same LR schedule.
                scheduler.step(epoch)
            elif stats['acc'] > best_acc:
                best_acc = stats['acc']
                os.makedirs('checkpoint', exist_ok=True)
                torch.save({'net': model.state_dict(), 'acc': best_acc, 'epoch': epoch + 1},
                           './checkpoint/ckpt.pth')

    return model

In [0]:
# Train for 100 epochs (counting from `start_epoch`); train_model returns the same
# model object it trained in place, so rebinding `cnn` changes nothing else.
cnn = train_model(model=cnn, loss=prediction_criterion, optimizer=cnn_optimizer,
                  scheduler=scheduler, num_epochs=100)

In [0]:
# Load this run's saved train/val histories (and download copies from Colab).
assert os.path.isdir('accs_losses'), 'Error: no accs_losses directory found!'
train_path = f'./accs_losses/train_accs_losses_{budget}.pth'
val_path = f'./accs_losses/val_accs_losses_{budget}.pth'
data_train = torch.load(train_path)
data_val = torch.load(val_path)
files.download(train_path)
files.download(val_path)

# Train curves come from the train file, test curves from the val file.
acc_train = data_train['accuracy']
loss_train = data_train['loss']
confloss_train = data_train['confidence_loss']

acc_test = data_val['accuracy']
loss_test = data_val['loss']
confloss_test = data_val['confidence_loss']

In [0]:
# Training history of the baseline run: `baseline = True` sets budget = 0., so its
# files carry the "0.0" suffix.
data_new = torch.load('./accs_losses/train_accs_losses_0.0.pth')
acc_train_budget = data_new['accuracy']
loss_train_budget = data_new['loss']
confloss_train_budget = data_new['confidence_loss']

In [0]:
# Accuracy curves for this run (budget = `budget`) on train and test, next to the
# baseline run's training accuracy (budget = 0.0) loaded in the previous cell.
fig, ax = plt.subplots(figsize=(18, 9))
ax.plot(np.arange(len(acc_train)), acc_train,
        label=f'Train, budget={budget}, Acc: {acc_train[-1]:.4f}')
ax.plot(np.arange(len(acc_test)), acc_test,
        label=f'Test, budget={budget}, Acc: {acc_test[-1]:.4f}')
ax.plot(np.arange(len(acc_train_budget)), acc_train_budget, ls='--',
        label=f'Train, budget=0.0 (baseline), Acc: {acc_train_budget[-1]:.4f}')
ax.set(title=f'Accuracy (CIFAR10 ResNet with confidence branch, budget = {budget})',
       xlabel='epoch', ylabel='accuracy')
ax.legend();

Confidence branch заметно ускоряет обучение и даёт небольшой прирост точности.