In [1]:
#!git clone https://github.com/jcpeterson/cifar-10h
%cd cifar-10h

/data/user-data/sa25729/cifar-10h


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
from torchvision import models
import torchvision.transforms as transforms
import os
import argparse
import copy
import random
import numpy as np
device = 'cuda' if torch.cuda.is_available() else 'cpu'
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
def seed_everything(seed=12):
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
parser = argparse.ArgumentParser(description='CIFAR-10H Training')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
parser.add_argument('--lr_schedule', default=0, type=int, help='lr scheduler')
parser.add_argument('--batch_size', default=1024, type=int, help='batch size')
parser.add_argument('--test_batch_size', default=2048, type=int, help='batch size')
parser.add_argument('--num_epoch', default=100, type=int, help='epoch number')
parser.add_argument('--num_classes', type=int, default=10, help='number classes')
args = parser.parse_args(args=[])

In [3]:
def train(model, trainloader, criterion, optimizer):
    model.train()
    for batch_idx, (inputs, targets, ad) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets,conf_score)
        loss.backward()
        optimizer.step()

def test(model, testloader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    return correct / total

In [4]:
from PIL import Image
import numpy as np
import torchvision

class CIFAR10H(torchvision.datasets.CIFAR10):

    def __init__(self, root,  rand_number=0, train=False, transform=None, target_transform=None,
                 download=False):
        super(CIFAR10H, self).__init__(root, train, transform, target_transform, download) 
        self.transform = transform
        self.target_transform = target_transform
        self.ad = np.load(os.path.join(root,'cifar10h-probs.npy'))

    def __getitem__(self, index: int):
        img, target = self.data[index], self.targets[index]
        img = Image.fromarray(img)
        ad = self.ad[index]
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target, ad

class CELossWithLS_CCA(torch.nn.Module):
    
    def __init__(self, classes= args.num_classes, smoothing=0.16, ignore_index=-1):
        super(CELossWithLS_CCA, self).__init__()
        self.smoothing = smoothing
        self.complement = 1.0 - smoothing
        self.cls = classes
        self.log_softmax = torch.nn.LogSoftmax(dim=1)
        self.ignore_index = ignore_index

    def forward(self, logits, target, conf_score):
        with torch.no_grad():
            new_smoothing  = self.smoothing - conf_score/10
            new_complement = 1 - new_smoothing
            oh_labels = F.one_hot(target.to(torch.int64), num_classes = self.cls).contiguous()
            smoothen_ohlabel = oh_labels * new_complement + new_smoothing / self.cls
        
        logs = self.log_softmax(logits[target!=self.ignore_index])
        return -torch.sum(logs * smoothen_ohlabel[target!=self.ignore_index], dim=1).mean()


In [5]:
seed_everything()
conf_score = torch.tensor([0.8265, 0.8410, 0.7920, 0.7833, 0.7851, 0.8231, 0.8496, 0.8212, 0.8126,
        0.8997])
conf_score = conf_score.to(device)
mean_cifar10, std_cifar10 = (0.5071, 0.4866, 0.4409), (0.2009, 0.1984, 0.2023)
transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(), transforms.ToTensor(),
            transforms.Normalize(mean_cifar10, std_cifar10), ])
transform_test = transforms.Compose([transforms.ToTensor(),
    transforms.Normalize(mean_cifar10, std_cifar10),])

train_dataset = CIFAR10H(root='./data', train=False, download=True, transform=transform_train)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_test)
print('train samples:',len(train_dataset), 'test samples:',len(test_dataset))
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.test_batch_size, shuffle=False, num_workers=2)

model = models.resnet34(pretrained=True).to(device)
model.fc = nn.Linear(model.fc.in_features, args.num_classes)
model = model.to(device)

optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, nesterov=False, weight_decay=0.0001)
#criterion = nn.CrossEntropyLoss()
criterion = CELossWithLS_CCA().to(device)

best_epoch, best_acc = 0.0, 0
for epoch in range(args.num_epoch):
    train(model, train_loader, criterion, optimizer)
    accuracy = test(model, test_loader)
    if accuracy > best_acc:
        patience = 0
        best_acc = accuracy
        best_epoch = epoch
        best_model = copy.deepcopy(model)
        torch.save(best_model.state_dict(), 'best_model_cifar10h.pth_LS_cca_LS.tar')
    print('epoch: {}  acc: {:.4f}  best epoch: {}  best acc: {:.4f}'.format(
            epoch, accuracy, best_epoch, best_acc, optimizer.param_groups[0]['lr']))


Files already downloaded and verified
Files already downloaded and verified
train samples: 10000 test samples: 50000
epoch: 0  acc: 0.4281  best epoch: 0  best acc: 0.4281
epoch: 1  acc: 0.5009  best epoch: 1  best acc: 0.5009
epoch: 2  acc: 0.6504  best epoch: 2  best acc: 0.6504
epoch: 3  acc: 0.6152  best epoch: 2  best acc: 0.6504
epoch: 4  acc: 0.6039  best epoch: 2  best acc: 0.6504
epoch: 5  acc: 0.6633  best epoch: 5  best acc: 0.6633
epoch: 6  acc: 0.7363  best epoch: 6  best acc: 0.7363
epoch: 7  acc: 0.7502  best epoch: 7  best acc: 0.7502
epoch: 8  acc: 0.7327  best epoch: 7  best acc: 0.7502
epoch: 9  acc: 0.7601  best epoch: 9  best acc: 0.7601
epoch: 10  acc: 0.7657  best epoch: 10  best acc: 0.7657
epoch: 11  acc: 0.7514  best epoch: 10  best acc: 0.7657
epoch: 12  acc: 0.7539  best epoch: 10  best acc: 0.7657
epoch: 13  acc: 0.7574  best epoch: 10  best acc: 0.7657
epoch: 14  acc: 0.7617  best epoch: 10  best acc: 0.7657
epoch: 15  acc: 0.7634  best epoch: 10  best acc

CCA: get confidence score

In [6]:
seed_everything()
mean_cifar10, std_cifar10 = (0.5071, 0.4866, 0.4409), (0.2009, 0.1984, 0.2023)
transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(), transforms.ToTensor(),
            transforms.Normalize(mean_cifar10, std_cifar10), ])
transform_test = transforms.Compose([transforms.ToTensor(),
    transforms.Normalize(mean_cifar10, std_cifar10),])

train_dataset = CIFAR10H(root='./data', train=False, download=True, transform=transform_train)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_test)
print('train samples:',len(train_dataset), 'test samples:',len(test_dataset))
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.test_batch_size, shuffle=False, num_workers=2)

model = models.resnet34(pretrained=True).to(device)
model.fc = nn.Linear(model.fc.in_features, args.num_classes)
model = model.to(device)

optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, nesterov=False, weight_decay=0.0001)
#criterion = nn.CrossEntropyLoss()
criterion = CELossWithLS().to(device)
model.load_state_dict(torch.load('best_model_cifar10h.pth_LS.tar'))

Files already downloaded and verified
Files already downloaded and verified
train samples: 10000 test samples: 50000


<All keys matched successfully>

In [9]:
model.load_state_dict(torch.load('best_model_cifar10h.pth_LS.tar'))
def get_conf_freq (model, dataloader):
    model.eval()
    conf_score = torch.zeros([10]).to(device)
    count = torch.zeros([10]).to(device)
    with torch.no_grad():
        for batch_idx, (inputs, targets, ad) in enumerate(dataloader):
            inputs, targets, ad = inputs.to(device), targets.to(device), ad.to(device)
            outputs = model(inputs)
            softmaxes = F.softmax(outputs, dim=1)

            for i in range (len(targets)):
                confidence = softmaxes[i][targets[i]]
                conf_score[targets[i]] += confidence
                count[targets[i]] += 1
            conf_avg = conf_score/count
    return conf_avg

conf_score = get_conf_freq(model, train_loader)
conf_score


tensor([0.8265, 0.8410, 0.7920, 0.7833, 0.7851, 0.8231, 0.8496, 0.8212, 0.8126,
        0.8997], device='cuda:0')

Train the model with CCA

In [22]:
class CELossWithLS_conf(torch.nn.Module):
    def __init__(self, classes= 10, smoothing=0.16, ignore_index=-1):
        super(CELossWithLS_conf, self).__init__()
        self.smoothing = smoothing
        self.complement = 1.0 - smoothing
        self.cls = classes
        self.log_softmax = torch.nn.LogSoftmax(dim=1)
        self.ignore_index = ignore_index

    def forward(self, logits, target, conf_score):
        with torch.no_grad():
            new_smoothing  = self.smoothing - conf_score/10
            new_complement = 1 - new_smoothing
            oh_labels = F.one_hot(target.to(torch.int64), num_classes = self.cls).contiguous()
            smoothen_ohlabel = oh_labels * new_complement + new_smoothing / self.cls
        
        logs = self.log_softmax(logits[target!=self.ignore_index])
        return -torch.sum(logs * smoothen_ohlabel[target!=self.ignore_index], dim=1).mean()
    
    
def train_cca(model, trainloader, criterion, optimizer, conf_score):
    model.train()
    for batch_idx, (inputs, targets, ad) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets, conf_score)
        loss.backward()
        optimizer.step()


criterion_cca = CELossWithLS_conf().to(device)
best_epoch, best_acc = 0.0, 0
for epoch in range(args.num_epoch):
    train_cca(model, train_loader, criterion_cca, optimizer, conf_score)
    accuracy = test(model, test_loader)
    if accuracy > best_acc:
        patience = 0
        best_acc = accuracy
        best_epoch = epoch
        best_model = copy.deepcopy(model)
        torch.save(best_model.state_dict(), 'best_model_cifar10h.pth_LS_cca.tar')
    print('epoch: {}  acc: {:.4f}  best epoch: {}  best acc: {:.4f}'.format(
            epoch, accuracy, best_epoch, best_acc, optimizer.param_groups[0]['lr']))


epoch: 0  acc: 0.7211  best epoch: 0  best acc: 0.7211
epoch: 1  acc: 0.7010  best epoch: 0  best acc: 0.7211
epoch: 2  acc: 0.7431  best epoch: 2  best acc: 0.7431
epoch: 3  acc: 0.7392  best epoch: 2  best acc: 0.7431
epoch: 4  acc: 0.7680  best epoch: 4  best acc: 0.7680
epoch: 5  acc: 0.7465  best epoch: 4  best acc: 0.7680
epoch: 6  acc: 0.7725  best epoch: 6  best acc: 0.7725
epoch: 7  acc: 0.7721  best epoch: 6  best acc: 0.7725
epoch: 8  acc: 0.7481  best epoch: 6  best acc: 0.7725
epoch: 9  acc: 0.7557  best epoch: 6  best acc: 0.7725
