#Import packages

In [None]:
# Import packages
import torch
import numpy as np
import torchvision.transforms as transforms
import random
import PIL
import PIL.ImageOps
import PIL.ImageEnhance
import PIL.ImageDraw
from PIL import Image
import torchvision.datasets as torchdivision_datasets
import pickle
import os
import torch.nn as nn
import torch.nn.functional as F
import copy
import torch.backends.cudnn as cudnn
from scipy.optimize import linear_sum_assignment as linear_assignment
import math

#Change the working directory

In [None]:
import os

# Change the current working directory to a directory in Google Drive.
# NOTE(review): the path is hard-coded and presumably requires Drive to be
# mounted at /content/drive first -- confirm before running elsewhere.
new_directory_path = "/content/drive/My Drive/Mercy college/Thesis"
# Relative paths used later ('checkpoint/...', './data/...') resolve against this directory.
os.chdir(new_directory_path)

In [None]:
# Verify the current working directory by reading it back and printing it.
current_directory = os.getcwd()
print("Current Working Directory:", current_directory)

Current Working Directory: /content/drive/My Drive/Mercy college/Thesis


#Common code

##BasicBlock

In [None]:
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions, each followed by BatchNorm.

    Args:
        in_planes: number of input channels.
        planes: number of channels produced by both convolutions.
        stride: stride of the first convolution (and of the projection shortcut).
        is_last: when True, ``forward`` also returns the pre-activation tensor.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1, is_last=False):
        super().__init__()
        self.is_last = is_last
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # Identity shortcut when the shapes already agree, otherwise a
        # 1x1 convolution + BatchNorm projection.
        if stride == 1 and in_planes == out_planes:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        """Return ``relu(residual + shortcut)``; with ``is_last`` also the value before the ReLU."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(hidden))
        preact = branch + self.shortcut(x)
        activated = F.relu(preact)
        return (activated, preact) if self.is_last else activated

##Bottleneck

In [None]:
class Bottleneck(nn.Module):
    """Residual block with a 1x1 -> 3x3 -> 1x1 convolution stack.

    The block outputs ``expansion * planes`` channels.

    Args:
        in_planes: number of input channels.
        planes: width of the two inner convolutions.
        stride: stride of the 3x3 convolution (and of the projection shortcut).
        is_last: when True, ``forward`` also returns the pre-activation tensor.
    """

    expansion = 4

    def __init__(self, in_planes, planes, stride=1, is_last=False):
        super().__init__()
        self.is_last = is_last
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        # Identity shortcut when the shapes already agree, otherwise a
        # 1x1 convolution + BatchNorm projection.
        if stride == 1 and in_planes == out_planes:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        """Return ``relu(residual + shortcut)``; with ``is_last`` also the value before the ReLU."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = F.relu(self.bn2(self.conv2(hidden)))
        branch = self.bn3(self.conv3(hidden))
        preact = branch + self.shortcut(x)
        activated = F.relu(preact)
        return (activated, preact) if self.is_last else activated

##ResNet

In [None]:
class ResNet(nn.Module):
    """ResNet feature extractor: a 3x3 stem, four residual stages and global average pooling.

    ``forward`` returns the flattened pooled features; there is no classifier layer.

    Args:
        block: residual block class; must expose an ``expansion`` attribute and
            accept ``(in_planes, planes, stride)``.
        num_blocks: sequence with the number of blocks in each of the four stages.
        in_channel: number of input image channels.
        zero_init_residual: zero the last BatchNorm weight of every residual
            branch so that each block starts out as an identity mapping.
    """

    def __init__(self, block, num_blocks, in_channel=3, zero_init_residual=False):
        super().__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(in_channel, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        # (width, stride of the first block) for layer1..layer4, registered in order.
        stage_layout = ((64, 1), (128, 2), (256, 2), (512, 2))
        for position, (width, first_stride) in enumerate(stage_layout):
            stage = self._make_layer(block, width, num_blocks[position], stride=first_stride)
            setattr(self, 'layer{}'.format(position + 1), stage)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                continue
            if isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

        # Zero-initialize the last BN in each residual branch, so that the
        # residual branch starts with zeros and each block behaves like an
        # identity. Reported to improve accuracy by 0.2~0.3%:
        # https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for module in self.modules():
                if isinstance(module, Bottleneck):
                    nn.init.constant_(module.bn3.weight, 0)
                elif isinstance(module, BasicBlock):
                    nn.init.constant_(module.bn2.weight, 0)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first uses ``stride``, the rest use 1."""
        stacked = []
        for position in range(num_blocks):
            block_stride = stride if position == 0 else 1
            stacked.append(block(self.in_planes, planes, block_stride))
            # The next block consumes what this one produced.
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stacked)

    def forward(self, x):
        """Return the pooled, flattened feature tensor for a batch of images."""
        out = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        return torch.flatten(self.avgpool(out), 1)


##Resnet_CIFAR

In [None]:
def Resnet_CIFAR():
    """Build the backbone used in this notebook: a ResNet of BasicBlocks, two per stage."""
    blocks_per_stage = [2, 2, 2, 2]
    return ResNet(BasicBlock, blocks_per_stage)

##ContrastiveModel

In [None]:
class ContrastiveModel(nn.Module):
    """Backbone followed by a projection head that emits L2-normalised embeddings.

    Args:
        backbone: module mapping a batch of inputs to 512-dimensional features.
        head: ``'linear'`` for a single linear layer, ``'mlp'`` for
            Linear -> ReLU -> Linear.
        features_dim: dimensionality of the output embedding.

    Raises:
        ValueError: if ``head`` is neither ``'linear'`` nor ``'mlp'``.
    """

    def __init__(self, backbone, head='mlp', features_dim=128):
        super().__init__()
        self.backbone = backbone
        self.backbone_dim = 512
        self.head = head

        if head == 'linear':
            projector = nn.Linear(self.backbone_dim, features_dim)
        elif head == 'mlp':
            projector = nn.Sequential(
                nn.Linear(self.backbone_dim, self.backbone_dim),
                nn.ReLU(),
                nn.Linear(self.backbone_dim, features_dim),
            )
        else:
            raise ValueError('Invalid head {}'.format(head))
        self.contrastive_head = projector

    def forward(self, x):
        """Return the projected features, normalised to unit length along dim 1."""
        projected = self.contrastive_head(self.backbone(x))
        return F.normalize(projected, dim=1)


##ClusteringModel

In [None]:
class ClusteringModel(nn.Module):
    """Backbone followed by a single linear cluster head.

    The head is kept inside an ``nn.ModuleList`` of length one, so its
    parameters are named ``cluster_head.0.*`` in the state dict.

    Args:
        backbone: module mapping a batch of inputs to 512-dimensional features.
        class_num: number of cluster logits to produce.
    """

    def __init__(self, backbone, class_num):
        super().__init__()
        self.backbone = backbone
        self.cluster_head = nn.ModuleList([nn.Linear(512, class_num)])

    def forward(self, x):
        """Return the logits of the first (and only) cluster head."""
        features = self.backbone(x)
        logits_per_head = [head(features) for head in self.cluster_head]
        return logits_per_head[0]

##RandAugmentMC

In [None]:
# Denominator of the magnitude scale: _float_parameter/_int_parameter map a
# magnitude v to v * max_v / PARAMETER_MAX.
PARAMETER_MAX = 10

def AutoContrast(img, **kwarg):
    """Apply ``PIL.ImageOps.autocontrast`` to ``img``; extra keyword arguments are ignored."""
    result = PIL.ImageOps.autocontrast(img)
    return result


def Brightness(img, v, max_v, bias=0):
    """Enhance brightness by the factor ``v * max_v / PARAMETER_MAX + bias``."""
    factor = _float_parameter(v, max_v) + bias
    enhancer = PIL.ImageEnhance.Brightness(img)
    return enhancer.enhance(factor)


def Color(img, v, max_v, bias=0):
    """Enhance colour by the factor ``v * max_v / PARAMETER_MAX + bias``."""
    factor = _float_parameter(v, max_v) + bias
    enhancer = PIL.ImageEnhance.Color(img)
    return enhancer.enhance(factor)


def Contrast(img, v, max_v, bias=0):
    """Enhance contrast by the factor ``v * max_v / PARAMETER_MAX + bias``."""
    factor = _float_parameter(v, max_v) + bias
    enhancer = PIL.ImageEnhance.Contrast(img)
    return enhancer.enhance(factor)


def Cutout(img, v, max_v, bias=0):
    """Grey out a square whose side is a fraction of the shorter image side.

    The fraction is ``v * max_v / PARAMETER_MAX + bias``; a magnitude of 0
    returns ``img`` untouched.
    """
    if v == 0:
        return img
    fraction = _float_parameter(v, max_v) + bias
    side = int(fraction * min(img.size))
    return CutoutAbs(img, side)


def CutoutAbs(img, v, **kwarg):
    """Return a copy of ``img`` with a grey rectangle of up to ``v`` pixels per side.

    The rectangle is centred on a uniformly sampled point and clipped to the
    image bounds; the input image itself is not modified.
    """
    width, height = img.size
    center_x = np.random.uniform(0, width)
    center_y = np.random.uniform(0, height)
    left = int(max(0, center_x - v / 2.))
    top = int(max(0, center_y - v / 2.))
    right = int(min(width, left + v))
    bottom = int(min(height, top + v))
    gray = (127, 127, 127)
    patched = img.copy()
    PIL.ImageDraw.Draw(patched).rectangle((left, top, right, bottom), gray)
    return patched


def Equalize(img, **kwarg):
    """Apply ``PIL.ImageOps.equalize`` to ``img``; extra keyword arguments are ignored."""
    result = PIL.ImageOps.equalize(img)
    return result


def Identity(img, **kwarg):
    """No-op augmentation: hand ``img`` back unchanged, ignoring any keyword arguments."""
    unchanged = img
    return unchanged


def Invert(img, **kwarg):
    """Apply ``PIL.ImageOps.invert`` to ``img``; extra keyword arguments are ignored."""
    result = PIL.ImageOps.invert(img)
    return result


def Posterize(img, v, max_v, bias=0):
    """Posterize ``img`` keeping ``int(v * max_v / PARAMETER_MAX) + bias`` bits per channel."""
    bits = _int_parameter(v, max_v) + bias
    result = PIL.ImageOps.posterize(img, bits)
    return result


def Rotate(img, v, max_v, bias=0):
    """Rotate ``img`` by ``int(v * max_v / PARAMETER_MAX) + bias`` degrees, sign flipped with probability 0.5."""
    magnitude = _int_parameter(v, max_v) + bias
    angle = -magnitude if random.random() < 0.5 else magnitude
    return img.rotate(angle)


def Sharpness(img, v, max_v, bias=0):
    """Enhance sharpness by the factor ``v * max_v / PARAMETER_MAX + bias``."""
    factor = _float_parameter(v, max_v) + bias
    enhancer = PIL.ImageEnhance.Sharpness(img)
    return enhancer.enhance(factor)


def ShearX(img, v, max_v, bias=0):
    """Shear ``img`` along x with an affine transform; the sign is flipped with probability 0.5."""
    magnitude = _float_parameter(v, max_v) + bias
    shear = -magnitude if random.random() < 0.5 else magnitude
    affine_coeffs = (1, shear, 0, 0, 1, 0)
    return img.transform(img.size, PIL.Image.AFFINE, affine_coeffs)


def ShearY(img, v, max_v, bias=0):
    """Shear ``img`` along y with an affine transform; the sign is flipped with probability 0.5."""
    magnitude = _float_parameter(v, max_v) + bias
    shear = -magnitude if random.random() < 0.5 else magnitude
    affine_coeffs = (1, 0, 0, shear, 1, 0)
    return img.transform(img.size, PIL.Image.AFFINE, affine_coeffs)


def Solarize(img, v, max_v, bias=0):
    """Solarize ``img`` with threshold ``256 - (int(v * max_v / PARAMETER_MAX) + bias)``."""
    strength = _int_parameter(v, max_v) + bias
    threshold = 256 - strength
    return PIL.ImageOps.solarize(img, threshold)


def SolarizeAdd(img, v, max_v, bias=0, threshold=128):
    """Shift every pixel by a signed offset, clip to [0, 255], then solarize.

    Args:
        img: PIL image to augment.
        v: magnitude; the offset is ``int(v * max_v / PARAMETER_MAX) + bias``.
        max_v: value the magnitude scale maps ``PARAMETER_MAX`` to.
        bias: constant added to the scaled magnitude.
        threshold: threshold handed to ``PIL.ImageOps.solarize``.

    Returns:
        A new PIL image.
    """
    offset = _int_parameter(v, max_v) + bias
    if random.random() < 0.5:
        offset = -offset
    # ``np.int`` was only an alias of the builtin ``int`` and has been removed
    # from NumPy; the builtin gives the same dtype. A wide integer type is
    # needed so the addition cannot wrap around like uint8 would.
    img_np = np.array(img).astype(int)
    img_np = np.clip(img_np + offset, 0, 255).astype(np.uint8)
    img = Image.fromarray(img_np)
    return PIL.ImageOps.solarize(img, threshold)


def TranslateX(img, v, max_v, bias=0):
    """Translate ``img`` horizontally by a fraction of its width; the sign is flipped with probability 0.5."""
    fraction = _float_parameter(v, max_v) + bias
    if random.random() < 0.5:
        fraction = -fraction
    shift = int(fraction * img.size[0])
    affine_coeffs = (1, 0, shift, 0, 1, 0)
    return img.transform(img.size, PIL.Image.AFFINE, affine_coeffs)


def TranslateY(img, v, max_v, bias=0):
    """Translate ``img`` vertically by a fraction of its height; the sign is flipped with probability 0.5."""
    fraction = _float_parameter(v, max_v) + bias
    if random.random() < 0.5:
        fraction = -fraction
    shift = int(fraction * img.size[1])
    affine_coeffs = (1, 0, 0, 0, 1, shift)
    return img.transform(img.size, PIL.Image.AFFINE, affine_coeffs)


def _float_parameter(v, max_v):
    """Scale magnitude ``v`` to ``float(v) * max_v / PARAMETER_MAX``."""
    scaled = float(v) * max_v / PARAMETER_MAX
    return scaled


def _int_parameter(v, max_v):
    """Scale magnitude ``v`` to ``v * max_v / PARAMETER_MAX`` and truncate to ``int``."""
    scaled = v * max_v / PARAMETER_MAX
    return int(scaled)


def augment_pool():
    """Return the list of candidate augmentations as ``(op, max_v, bias)`` tuples.

    ``max_v`` and ``bias`` are ``None`` for operations that take no magnitude.
    """
    # FixMatch paper
    return [
        (AutoContrast, None, None),
        (Brightness, 0.99, 0.01),
        (Color, 0.99, 0.01),
        (Contrast, 0.99, 0.01),
        (Equalize, None, None),
        (Identity, None, None),
        (Posterize, 1, 8),
        (Rotate, 45, -45),
        (Sharpness, 0.99, 0.01),
        (ShearX, 0.3, -0.3),
        (ShearY, 0.3, -0.3),
        (Solarize, 256, 0),
        (TranslateX, 0.3, -0.3),
        (TranslateY, 0.3, -0.3),
    ]


class RandAugmentMC(object):
    """Random augmentation policy followed by a fixed-size cutout.

    On every call ``n`` distinct operations are drawn from ``augment_pool()``;
    each one is applied with probability 0.5 at a freshly sampled magnitude,
    and finally a 16-pixel ``CutoutAbs`` is always applied.

    Args:
        n: number of operations drawn per call (at least 1).
        m: magnitude bound, ``1 <= m <= 10``. Magnitudes are drawn from
            ``[1, m)``, so the largest magnitude used is ``m - 1``, except for
            ``m == 1``, where the magnitude is fixed at 1.
    """

    def __init__(self, n, m):
        assert n >= 1
        assert 1 <= m <= 10
        self.n = n
        self.m = m
        self.augment_pool = augment_pool()

    def __call__(self, img):
        """Return an augmented copy of the PIL image ``img``."""
        ops = random.sample(self.augment_pool, k=self.n)
        for op, max_v, bias in ops:
            # np.random.randint excludes its upper bound, so randint(1, 1)
            # raises ValueError even though m == 1 passes the constructor
            # check; fall back to the only admissible magnitude in that case.
            v = np.random.randint(1, self.m) if self.m > 1 else 1
            if random.random() < 0.5:
                img = op(img, v=v, max_v=max_v, bias=bias)
        img = CutoutAbs(img, 16)
        return img

##AverageMeter

In [None]:
class AverageMeter(object):
    """Keep the latest value together with a weighted running sum, count and mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the running average."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

##_hungarian_match

In [None]:
def _hungarian_match(flat_preds, flat_targets, num_samples, class_num):
    num_k = class_num
    num_correct = np.zeros((num_k, num_k))

    for c1 in range(0, num_k):
        for c2 in range(0, num_k):
        # elementwise, so each sample contributes once
            votes = int(((flat_preds == c1) * (flat_targets == c2)).sum())
            num_correct[c1, c2] = votes

    # num_correct is small
    match = linear_assignment(num_samples - num_correct)

    # return as list of tuples, out_c to gt_c
    res = []
    for i in range(len(match[0])):
        out_c = match[0][i]
        gt_c = match[1][i]
        res.append((out_c, gt_c))

    return res

##test

In [None]:
def test(net, testloader, device, class_num):
    """Evaluate clustering accuracy after Hungarian matching of clusters to classes.

    Args:
        net: model returning one row of logits per input sample.
        testloader: iterable of ``(inputs, _, _, targets, indexes)`` batches.
        device: device on which inference and the bookkeeping tensors live.
        class_num: number of clusters / classes.

    Returns:
        ``(acc, reordered_preds)``: accuracy in percent and a float tensor
        holding, per sample, the class id its predicted cluster was matched to.
    """
    net.eval()
    predicted_all = []
    targets_all = []
    # Pure inference: no autograd graph is needed for the forward passes.
    with torch.no_grad():
        for inputs, _, _, targets, _ in testloader:
            targets, inputs = targets.to(device), inputs.to(device)
            output = net(inputs)
            predicted_all.append(torch.argmax(output, 1))
            targets_all.append(targets)

    flat_predict = torch.cat(predicted_all).to(device)
    flat_target = torch.cat(targets_all).to(device)
    num_samples = flat_predict.shape[0]
    match = _hungarian_match(flat_predict, flat_target, num_samples, class_num)
    reordered_preds = torch.zeros(num_samples).to(device)

    # Relabel every predicted cluster with the class it was matched to. The
    # masks come from the untouched flat_predict, so assignments cannot chain.
    for pred_i, target_i in match:
        reordered_preds[flat_predict == pred_i] = int(target_i)

    acc = int((reordered_preds == flat_target.float()).sum()) / float(num_samples) * 100

    return acc, reordered_preds

##test_ruc

In [None]:
def test_ruc(net, net2, testloader, device, class_num):
    net.eval()
    net2.eval()

    predicted_all = [[] for i in range(0,3)]
    targets_all = []
    acc_list = []
    p_label_list = []

    for batch_idx, (inputs, _, _, targets, indexes) in enumerate(testloader):
        batchSize = inputs.size(0)
        targets, inputs = targets.to(device), inputs.to(device)
        logit = net(inputs)
        logit2 = net2(inputs)
        _, predicted = torch.max(logit, 1)
        _, predicted2 = torch.max(logit2, 1)
        _, predicted3 = torch.max(logit + logit2, 1)

        predicted_all[0].append(predicted)
        predicted_all[1].append(predicted2)
        predicted_all[2].append(predicted3)
        targets_all.append(targets)

    for i in range(0, 3):
        flat_predict = torch.cat(predicted_all[i]).to(device)
        flat_target = torch.cat(targets_all).to(device)
        num_samples = flat_predict.shape[0]
        acc = int((flat_predict.float() == flat_target.float()).sum()) / float(num_samples) * 100
        acc_list.append(acc)
        p_label_list.append(flat_predict)

    return acc_list, p_label_list

#CIFAR20

##Default variables

In [None]:
# Default variables

# -- optimisation --
lr = 1e-2             # base learning rate handed to SGD and to the cosine schedule
momentum = 0.9
weight_decay = 0.0005
epochs = 200
batch_size = 250

# -- clean-sample selection --
s_thr = 0.99          # threshold passed to extract_confidence
n_num = 100           # number of neighbours passed to extract_metric

# -- checkpoints (relative to the working directory) --
o_model = 'checkpoint/selflabel_cifar-20.pth.tar'
e_model = 'checkpoint/simclr_cifar-20.pth.tar'

# -- reproducibility --
seed = 1567010775

In [None]:
# Seed every random number generator this notebook draws from. The
# augmentation code calls Python's ``random`` module (random.random /
# random.sample) in addition to NumPy and torch, so it is seeded as well.
# NOTE(review): DataLoader worker processes may reseed their own generators;
# this covers the main process.
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)

##Check if the GPU is available

In [None]:
# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

'cuda'

##cifar100_to_cifar20

In [None]:
# CIFAR-100 fine label -> CIFAR-20 label. Position i of the tuple holds the
# label for fine class i. Built once at import time instead of on every
# lookup (the function is called for every sample that is fetched).
_CIFAR100_TO_CIFAR20 = dict(enumerate((
    4, 1, 14, 8, 0, 6, 7, 7, 18, 3,          # 0-9
    3, 14, 9, 18, 7, 11, 3, 9, 7, 11,        # 10-19
    6, 11, 5, 10, 7, 6, 13, 15, 3, 15,       # 20-29
    0, 11, 1, 10, 12, 14, 16, 9, 11, 5,      # 30-39
    5, 19, 8, 8, 15, 13, 14, 17, 18, 10,     # 40-49
    16, 4, 17, 4, 2, 0, 17, 4, 18, 17,       # 50-59
    10, 3, 2, 12, 12, 16, 12, 1, 9, 19,      # 60-69
    2, 10, 0, 1, 16, 12, 9, 13, 15, 13,      # 70-79
    16, 19, 2, 4, 6, 19, 5, 5, 8, 19,        # 80-89
    18, 1, 2, 15, 6, 0, 17, 8, 14, 13,       # 90-99
)))


def cifar100_to_cifar20(target):
    """
    CIFAR100 to CIFAR 20 dictionary.
    This function is from IIC github.

    Args:
        target: CIFAR-100 fine label, an integer in ``[0, 99]``.

    Returns:
        The CIFAR-20 label (``0``-``19``) that ``target`` is mapped to.

    Raises:
        KeyError: if ``target`` is not one of the 100 fine labels.
    """
    return _CIFAR100_TO_CIFAR20[target]

##CIFAR20RUC

In [None]:
class CIFAR20RUC(torchdivision_datasets.CIFAR100):
    """CIFAR-100 images with CIFAR-20 targets, served as several views per sample.

    Each item holds three (optionally four) transformed versions of the same
    image, followed by the CIFAR-20 target and the sample index.

    Args:
        root: dataset root directory.
        transform: transform producing the first view.
        transform2: transform producing the second view.
        transform3: transform producing the third view.
        transform4: optional transform producing a fourth view.
        target_transform: accepted for signature compatibility; not used.
        train: load the training list when True, the test list otherwise.
        download: download the archive first when True.

    Raises:
        RuntimeError: if the dataset files are missing or fail the integrity check.

    Note:
        The parent constructor is deliberately not called; the attributes the
        inherited helpers read are set here directly.
    """

    def __init__(self, root, transform, transform2, transform3, transform4=None, target_transform=None,train=True, download = False):
        self.root = root
        self.train = train  # training set or test set
        self.transform = transform
        self.transform2 = transform2
        self.transform3 = transform3
        self.transform4 = transform4

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        downloaded_list = self.train_list if self.train else self.test_list

        self.data = []
        self.targets = []

        # now load the pickled numpy arrays
        for file_name, checksum in downloaded_list:
            file_path = os.path.join(self.root, self.base_folder, file_name)
            with open(file_path, 'rb') as f:
                # NOTE(review): pickle executes code from the file it reads;
                # only load archives that passed the integrity check above.
                entry = pickle.load(f, encoding='latin1')
                self.data.append(entry['data'])
                if 'labels' in entry:
                    self.targets.extend(entry['labels'])
                else:
                    self.targets.extend(entry['fine_labels'])

        self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
        self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC
        self._load_meta()

    @staticmethod
    def _apply(transform, img):
        """Run ``transform`` on ``img``; a missing transform returns the image unchanged."""
        return img if transform is None else transform(img)

    def __getitem__(self, index) :
        """Return ``(img1, img2, img3[, img4], target, index)`` for sample ``index``."""
        img, target = self.data[index], cifar100_to_cifar20(self.targets[index])
        img = Image.fromarray(img)

        # Every view is derived from the same source image. A transform left
        # as None passes the image through instead of leaving the view
        # undefined (which previously surfaced as UnboundLocalError).
        img1 = self._apply(self.transform, img)
        img2 = self._apply(self.transform2, img)
        img3 = self._apply(self.transform3, img)

        if self.transform4 is not None:
            img4 = self.transform4(img)
            return img1, img2, img3, img4, target, index
        return img1, img2, img3, target, index

##preprocess

In [None]:
def preprocess():
    """Build the CIFAR-20 datasets and loaders used for training and evaluation.

    Both datasets are created with ``CIFAR20RUC``'s default ``train=True``.
    The training dataset yields a deterministic view, two weakly augmented
    views and one strongly augmented view per sample; the evaluation dataset
    yields three deterministic views. The training loader uses the
    module-level ``batch_size``.

    Returns:
        ``(trainset, trainloader, testset, evalloader, 20)``; the final
        element is the number of classes.
    """
    mean = [0.5071, 0.4867, 0.4408]
    std = [0.2675, 0.2565, 0.2761]
    data_root = "./data/cifar-20"

    def random_crop_and_flip():
        return [
            transforms.RandomResizedCrop(size=32, scale=(0.2, 1.)),
            transforms.RandomHorizontalFlip(),
        ]

    def to_normalized_tensor():
        return [
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ]

    weak_view = transforms.Compose(random_crop_and_flip() + to_normalized_tensor())
    plain_view = transforms.Compose(
        [transforms.Resize(32), transforms.CenterCrop(32)] + to_normalized_tensor())
    strong_view = transforms.Compose(
        random_crop_and_flip() + [RandAugmentMC(n=2, m=2)] + to_normalized_tensor())

    trainset = CIFAR20RUC(
        root=data_root,
        transform=plain_view,
        transform2=weak_view,
        transform3=weak_view,
        transform4=strong_view,
        download=True,
    )
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=False)

    testset = CIFAR20RUC(
        root=data_root,
        transform=plain_view,
        transform2=plain_view,
        transform3=plain_view,
        download=False,
    )
    evalloader = torch.utils.data.DataLoader(
        testset, batch_size=100, shuffle=False, num_workers=4)

    return trainset, trainloader, testset, evalloader, 20

##Download and split dataset

In [None]:
# Build the datasets and loaders; the last value returned by preprocess() is the number of classes (20).
trainset, trainloader, testset, evalloader, class_num  = preprocess()

Files already downloaded and verified


In [None]:
# First clustering network: ResNet backbone plus one linear head with class_num outputs.
net = ClusteringModel(Resnet_CIFAR(), class_num)

In [None]:
# Second clustering network: a deep copy, so it starts from net's weights but shares no parameters with it.
net2 = copy.deepcopy(net)

In [None]:
# Third independent copy; it later receives the o_model checkpoint and produces the initial pseudo-labels.
net_uc = copy.deepcopy(net)

In [None]:
# Embedding network with ContrastiveModel's defaults ('mlp' head, 128-d output); it later receives the e_model checkpoint.
net_embd = ContrastiveModel(Resnet_CIFAR())

In [None]:
# Load the pretrained checkpoints. Loading stays best-effort (the notebook
# continues on failure), but the actual error is now reported instead of being
# hidden behind a bare ``except:``, which also swallowed KeyboardInterrupt.
try:
    # map_location lets GPU-saved checkpoints load on a CPU-only machine.
    state_dict = torch.load(o_model, map_location=device)
    state_dict2 = torch.load(e_model, map_location=device)
    net_uc.load_state_dict(state_dict)
    net_embd.load_state_dict(state_dict2, strict = True)
    net.load_state_dict(state_dict, strict = False)
    net2.load_state_dict(state_dict, strict = False)
    # The two networks to be trained keep the loaded weights but get freshly
    # initialised cluster heads.
    net.cluster_head = nn.ModuleList([nn.Linear(512, class_num) for _ in range(1)])
    net2.cluster_head = nn.ModuleList([nn.Linear(512, class_num) for _ in range(1)])
except Exception as err:
    print("Check Model Directory!")
    print("Checkpoint loading failed: {!r}".format(err))
    # exit(0)

In [None]:
print(device)
if device == 'cuda':
    # Wrap each model in DataParallel over all visible GPUs.
    net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))
    net2 = torch.nn.DataParallel(net2, device_ids=range(torch.cuda.device_count()))
    net_uc = torch.nn.DataParallel(net_uc, device_ids=range(torch.cuda.device_count()))
    # Wrap net_embd itself. This line used to wrap net_uc a second time, which
    # replaced the contrastive embedding model by the clustering model.
    net_embd = torch.nn.DataParallel(net_embd, device_ids=range(torch.cuda.device_count()))
    # Let cuDNN auto-tune its convolution algorithms.
    cudnn.benchmark = True

cuda


In [None]:
# Move every model to the selected device. The final call is the last
# expression of the cell, so its return value is what the notebook displays.
net.to(device)
net2.to(device)
net_uc.to(device)
net_embd.to(device)

DataParallel(
  (module): DataParallel(
    (module): ClusteringModel(
      (backbone): ResNet(
        (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (layer1): Sequential(
          (0): BasicBlock(
            (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (shortcut): Sequential()
          )
          (1): BasicBlock(
            (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running

##linear_rampup

In [None]:
def linear_rampup(current, rampup_length=200):
    """Return ``current / rampup_length`` clipped to ``[0.1, 1.0]`` as a float.

    A ``rampup_length`` of 0 disables the ramp and always yields 1.0.
    """
    if rampup_length == 0:
        return 1.0
    progress = current / rampup_length
    return float(np.clip(progress, 0.1, 1.0))

##criterion_rb

In [None]:
class criterion_rb(object):
    """Callable computing a cross-entropy term and a ramped squared-error term."""

    def __call__(self, outputs_x, targets_x, outputs_u, targets_u, epoch):
        """Return ``(Lx, Lu)``.

        ``Lx`` is the soft-target cross entropy between ``outputs_x`` (logits)
        and ``targets_x``. ``Lu`` is 100 times the mean squared error between
        ``softmax(outputs_u)`` and ``targets_u``, scaled by
        ``linear_rampup(epoch)``.
        """
        # Clean sample Loss
        log_probs_x = F.log_softmax(outputs_x, dim=1)
        Lx = -torch.mean(torch.sum(log_probs_x * targets_x, dim=1))

        # Pseudo Label Loss
        probs_u = torch.softmax(outputs_u, dim=1)
        squared_error = (probs_u - targets_u) ** 2
        Lu = linear_rampup(epoch) * (100 * torch.mean(squared_error))

        return Lx, Lu

In [None]:
# One Nesterov-momentum SGD optimizer per network, with identical hyper-parameters.
optimizer1 = torch.optim.SGD(net.parameters(), lr, momentum=momentum, weight_decay=weight_decay, nesterov=True)
optimizer2 = torch.optim.SGD(net2.parameters(), lr, momentum=momentum, weight_decay=weight_decay, nesterov=True)
# Loss callable returning the (Lx, Lu) pair; passed to train() below.
criterion = criterion_rb()

##extract_confidence

In [None]:
def extract_confidence(net, p_label, evalloader, threshold, device=None):
    """Flag samples whose maximum softmax probability reaches ``threshold``.

    Also prints ``correct_num clean_num``: how many flagged samples have a
    pseudo-label equal to the ground truth, and how many samples were flagged.

    Args:
        net: model returning logits.
        p_label: 1-D tensor of pseudo-labels, indexed by dataset index.
        evalloader: iterable of ``(inputs, _, _, targets, indexes)`` batches.
        threshold: minimum confidence for a sample to be flagged.
        device: device to run on. Defaults to the device of ``net``'s first
            parameter (CUDA if available when ``net`` has no parameters).

    Returns:
        Float tensor of 0/1 flags, concatenated in loader order.
    """
    net.eval()
    if device is None:
        first_param = next(net.parameters(), None)
        if first_param is not None:
            device = first_param.device
        else:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    masks = []
    clean_num = 0
    correct_num = 0
    # Inference only: no_grad replaces the in-place detach of the logits.
    with torch.no_grad():
        for inputs1, _, _, targets, indexes in evalloader:
            inputs1 = inputs1.to(device)
            targets = targets.to(device).float()
            labels = p_label[indexes].float().to(device)
            prob = torch.softmax(net(inputs1), dim=-1)
            max_probs, _ = torch.max(prob, dim=-1)
            mask = max_probs.ge(threshold).float()
            # Collected in a list and concatenated once, rather than growing
            # a tensor with torch.cat on every batch.
            masks.append(mask)
            s_idx = (mask == 1)
            clean_num += labels[s_idx].shape[0]
            correct_num += torch.sum((labels[s_idx] == targets[s_idx])).item()

    print(correct_num, clean_num)
    if not masks:
        return torch.tensor([], device=device)
    return torch.cat(masks)

##extract_metric

In [None]:
def extract_metric(net, p_label, evalloader, n_num, class_num=20, device=None):
    """Flag samples whose pseudo-label agrees with the vote of their nearest neighbours.

    Every sample is embedded with ``net``. For each sample the ``n_num``
    entries of the feature bank with the largest dot product are looked up and
    their pseudo-labels counted; the sample is flagged when the most frequent
    neighbour label equals its own pseudo-label. Prints
    ``correct_num clean_num`` (flagged samples whose pseudo-label equals the
    ground truth, and number of flagged samples).

    The loader is traversed once; features, targets and indexes are cached, so
    the network is not evaluated a second time for the similarity search.

    Args:
        net: embedding model.
        p_label: 1-D tensor of pseudo-labels. Positions in the feature bank
            are used as indices into it, so the loader must yield the samples
            in dataset order.
        evalloader: iterable of ``(inputs, _, _, targets, indexes)`` batches.
        n_num: number of neighbours that vote.
        class_num: number of pseudo-label classes (default 20).
        device: device to run on. Defaults to the device of ``net``'s first
            parameter (CUDA if available when ``net`` has no parameters).

    Returns:
        Bool tensor of flags, concatenated in loader order.
    """
    net.eval()
    if device is None:
        first_param = next(net.parameters(), None)
        if first_param is not None:
            device = first_param.device
        else:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    batch_features = []
    batch_targets = []
    batch_indexes = []
    with torch.no_grad():
        for inputs1, _, _, targets, indexes in evalloader:
            batch_features.append(net(inputs1.to(device)))
            batch_targets.append(targets)
            batch_indexes.append(indexes)

        if not batch_features:
            print(0, 0)
            return torch.zeros(0, dtype=torch.bool, device=device)

        # (feature_dim, num_samples): one column per sample, in loader order.
        feature_bank = torch.cat(batch_features, dim=0).t().contiguous()
        feature_labels = p_label.to(device)

        clean_num = 0
        correct_num = 0
        same_label_masks = []
        for features, targets, indexes in zip(batch_features, batch_targets, batch_indexes):
            rows = features.size(0)
            labels = feature_labels[indexes.to(device)].long()

            sim_matrix = torch.mm(features, feature_bank)
            _, sim_indices = sim_matrix.topk(k=n_num, dim=-1)
            sim_labels = torch.gather(feature_labels.expand(rows, -1), dim=-1, index=sim_indices)

            # counts for each class
            one_hot_label = torch.zeros(rows * sim_indices.size(1), class_num, device=device)
            one_hot_label = one_hot_label.scatter(dim=-1, index=sim_labels.view(-1, 1).long(), value=1.0)
            pred_scores = torch.sum(one_hot_label.view(rows, -1, class_num), dim=1)
            top_label = pred_scores.argsort(dim=-1, descending=True)[:, 0]

            # Check whether prediction and current label are same
            s_idx = (top_label == labels)
            clean_num += int(s_idx.sum())
            correct_num += torch.sum(labels[s_idx].float() == targets.to(device).float()[s_idx]).item()
            same_label_masks.append(s_idx)

        print(correct_num, clean_num)
        return torch.cat(same_label_masks, dim=0)

##extract_hybrid

In [None]:
def extract_hybrid(devide1, devide2, p_label, evalloader):
    """Combine two clean-sample masks: a sample is kept only if both flag it.

    Prints ``correct_num clean_num``: how many kept samples have a
    pseudo-label equal to the ground truth, and how many samples were kept.

    Args:
        devide1: first mask (0/1 or bool), indexed by dataset index.
        devide2: second mask, same length and device.
        p_label: 1-D tensor of pseudo-labels, indexed by dataset index.
        evalloader: iterable of ``(inputs, _, _, targets, indexes)`` batches;
            only ``targets`` and ``indexes`` are read.

    Returns:
        Bool tensor that is True where both masks are set.
    """
    devide = (devide1.float() + devide2.float() == 2)
    # Bookkeeping runs wherever the masks live; the images are never needed,
    # so they are no longer copied to the GPU.
    device = devide.device
    clean_num = 0
    correct_num = 0
    for _, _, _, targets, indexes in evalloader:
        targets = targets.to(device).float()
        labels = p_label[indexes].float().to(device)
        s_idx = devide[indexes]
        clean_num += labels[s_idx].shape[0]
        correct_num += torch.sum((labels[s_idx] == targets[s_idx])).item()

    print(correct_num, clean_num)
    return devide

In [None]:
# Extract Pseudo Label
# 1) Pseudo-labels come from net_uc's predictions after Hungarian matching;
#    acc_uc is the resulting accuracy in percent.
acc_uc, p_label = test(net_uc, evalloader, device, class_num)
print(acc_uc)
# 2) Confidence-based selection: samples whose max softmax probability is >= s_thr.
devide1 = extract_confidence(net_uc, p_label, evalloader, s_thr)
# 3) Neighbour-based selection: samples whose pseudo-label matches the vote of
#    their n_num nearest neighbours in net_embd's feature space.
devide2 = extract_metric(net_embd, p_label, evalloader, n_num)
# 4) Keep only the samples selected by both criteria.
devide = extract_hybrid(devide1, devide2, p_label, evalloader)

50.605999999999995
21029 35331
24247 45335
21007 35225


In [None]:
# Per-sample confidence buffers exchanged between the two networks by the
# training loop; both start at zero. 50000 entries, one per index.
conf1, conf2 = torch.zeros(50000), torch.zeros(50000)

##LabelSmoothLoss

In [None]:
class LabelSmoothLoss(nn.Module):
    """Cross entropy against label-smoothed one-hot targets.

    The target class receives weight ``1 - smoothing``; every other class
    receives ``smoothing / (num_classes - 1)``.

    Args:
        smoothing: probability mass spread over the non-target classes.
    """

    def __init__(self, smoothing=0.0):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, input, target):
        """Return the mean smoothed cross entropy of logits ``input`` w.r.t. class ids ``target``."""
        off_target = torch.ones_like(input) * self.smoothing / (input.size(-1) - 1.)
        class_ids = target.unsqueeze(-1).long()
        weight = off_target.scatter_(-1, class_ids, (1. - self.smoothing))
        log_prob = F.log_softmax(input, dim=-1)
        per_sample = (-weight * log_prob).sum(dim=-1)
        return per_sample.mean()

In [None]:
# Module-level label-smoothing loss (smoothing=0.5); train() reads it as a global.
LSloss = LabelSmoothLoss(0.5)

##Train

In [None]:
def adjust_learning_rate(lr, epochs, optimizer, epoch):
    """Set every parameter group of ``optimizer`` to the cosine-annealed rate.

    The rate is ``lr * 0.5 * (1 + cos(pi * epoch / epochs))``: ``lr`` at
    epoch 0, decaying to 0 at ``epoch == epochs``.
    """
    # cosine learning rate schedule
    cosine_factor = 0.5 * (1. + math.cos(math.pi * epoch / epochs))
    scheduled_lr = lr * cosine_factor

    for group in optimizer.param_groups:
        group['lr'] = scheduled_lr

In [None]:
def get_threshold(current):
    """Return the confidence threshold for ``current``: 0.9 plus 0.02 per completed block of 40."""
    completed_blocks = int(current / 40)
    return 0.9 + 0.02 * completed_blocks

In [None]:
def train(epoch, net, net2, trainloader, optimizer, criterion_rb, devide, p_label, conf, batch_size):
    """Train ``net`` for one epoch with ``net2`` acting as its co-teacher.

    Samples flagged in ``devide`` are treated as clean (their pseudo-label is
    used, blended with ``net2``'s prediction according to ``conf``); the
    remaining samples get guessed soft targets averaged over both networks.
    Confident predictions on unflagged samples are promoted to clean during
    the epoch, which updates ``devide`` and ``p_label`` in place.

    Relies on the module-level ``lr``, ``epochs`` and ``LSloss`` and requires
    CUDA.

    Args:
        epoch: current epoch, used for the learning-rate schedule, the
            promotion threshold and the loss ramp-up.
        net: network being optimised.
        net2: peer network; only used under ``torch.no_grad``.
        trainloader: yields ``(inputs1, inputs2, inputs3, inputs4, targets, indexes)``.
        optimizer: optimizer holding ``net``'s parameters.
        criterion_rb: callable returning ``(Lx, Lu)``.
        devide: mask of clean samples, indexed by dataset index.
        p_label: pseudo-labels, indexed by dataset index.
        conf: per-sample confidence of the peer network, indexed by dataset index.
        batch_size: loader batch size, used to compute iterations per epoch.

    Returns:
        ``(average loss, devide, p_label, conf_self)`` where ``conf_self`` is
        this network's min-max normalised per-sample confidence.
    """
    train_loss = AverageMeter()
    net.train()
    net2.train()

    num_samples = len(trainloader.dataset)
    num_iter = (num_samples // batch_size) + 1
    # adjust learning rate
    adjust_learning_rate(lr, epochs, optimizer, epoch)
    correct_u = 0
    unsupervised = 0
    conf_self = torch.zeros(num_samples)
    # Loop-invariant: move the peer confidences to the GPU once.
    conf = conf.cuda()
    for batch_idx, (inputs1 , inputs2, inputs3, inputs4, targets, indexes) in enumerate(trainloader):
        inputs1, inputs2, inputs3, inputs4, targets = inputs1.float().cuda(), inputs2.float().cuda(), inputs3.float().cuda(), inputs4.float().cuda(), targets.cuda().long()
        s_idx = (devide[indexes] == 1)
        u_idx = (devide[indexes] == 0)
        # Indexing already yields a 1-D tensor. The former torch.tensor(...)
        # wrapper raised a UserWarning, and .squeeze() collapsed a single clean
        # sample to a 0-d tensor, which made labels_x.shape[0] fail.
        labels_x = p_label[indexes][s_idx].detach().long().cpu()

        logit_o, logit_w1, logit_w2, logit_s = net(inputs1), net(inputs2), net(inputs3), net(inputs4)
        # One-hot pseudo-label targets, sized from the logits themselves.
        num_classes = logit_o.size(1)
        target_x = torch.zeros(labels_x.shape[0], num_classes).scatter_(1, labels_x.view(-1,1), 1).float().cuda()
        logit_s = logit_s[s_idx]
        max_probs, _ = torch.max(torch.softmax(logit_o, dim=1), dim=-1)
        conf_self[indexes] = max_probs.detach().cpu()
        optimizer.zero_grad()

        with torch.no_grad():
            # compute guessed labels of unlabeled samples: average the
            # predictions of both networks on both weak views
            outputs_u11 = logit_w1[u_idx]
            outputs_u21  = logit_w2[u_idx]
            logit_o2 = net2(inputs1)
            logit_w12 = net2(inputs2)
            logit_w22 = net2(inputs3)
            outputs_u12 = logit_w12[u_idx]
            outputs_u22  = logit_w22[u_idx]
            pu = (torch.softmax(outputs_u11, dim=1) + torch.softmax(outputs_u21, dim=1) + torch.softmax(outputs_u12, dim=1) + torch.softmax(outputs_u22, dim=1)) / 4
            ptu = pu**(1/0.5) # temperature sharpening
            target_u = ptu / ptu.sum(dim=1, keepdim=True) # normalize
            target_u = target_u.detach().float()

            # refine the clean targets: blend the pseudo-label with the peer
            # network's prediction, weighted by the peer's confidence
            px = torch.softmax(logit_o2[s_idx], dim=1)
            indexes = indexes.cuda()
            w_x = conf[indexes][s_idx]
            w_x = w_x.view(-1,1).float().cuda()
            px = (1-w_x)*target_x + w_x*px
            ptx = px**(1/0.5) # temperature sharpening
            target_x = ptx / ptx.sum(dim=1, keepdim=True) # normalize
            target_x = target_x.detach().float()

            # promote confident unflagged samples to the clean set
            if logit_o[u_idx].shape[0] > 0:
                max_probs, targets_u1 = torch.max(torch.softmax(logit_o[u_idx], dim=1), dim=-1)
                thr = get_threshold(epoch)
                mask_u = max_probs.ge(thr).float()
                u_idx2 = (mask_u == 1)
                unsupervised += torch.sum(mask_u).item()
                correct_u += torch.sum((targets_u1[u_idx2] == targets[u_idx][u_idx2])).item()
                update = indexes[u_idx][u_idx2]
                devide[update] = True
                p_label[update] = targets_u1[u_idx2].float()


        # MixUp-style interpolation; max() keeps the mix closer to the first operand
        l = np.random.beta(4.0, 4.0)
        l = max(l, 1-l)

        all_inputs = torch.cat([inputs2[s_idx], inputs3[s_idx], inputs2[u_idx], inputs3[u_idx]],dim=0)
        all_targets = torch.cat([target_x, target_x, target_u, target_u], dim=0)
        idx = torch.randperm(all_inputs.size(0))

        input_a, input_b = all_inputs, all_inputs[idx]
        target_a, target_b = all_targets, all_targets[idx]

        mixed_input = l * input_a + (1 - l) * input_b
        mixed_target = l * target_a + (1 - l) * target_b

        logits = net(mixed_input)
        # The first 2 * n_clean rows are the clean views, the rest the
        # unflagged ones. Named separately so the batch_size parameter is no
        # longer shadowed.
        n_clean = target_x.shape[0]

        Lx, Lu = criterion_rb(logits[:n_clean*2], mixed_target[:n_clean*2], logits[n_clean*2:], mixed_target[n_clean*2:], epoch+batch_idx/num_iter)
        total_loss = Lx + Lu + LSloss(logit_s, labels_x.cuda())

        total_loss.backward()
        train_loss.update(total_loss.item(), inputs2.size(0))
        optimizer.step()

        if batch_idx % 100 == 0:
            print('Epoch: [{epoch}][{elps_iters}/{tot_iters}] '
                  'Train loss: {train_loss.val:.4f} ({train_loss.avg:.4f}) '.format(
                      epoch=epoch, elps_iters=batch_idx,tot_iters=len(trainloader),
                      train_loss=train_loss))
    # Min-max normalise the confidences; if they are all equal the range is
    # zero and the division would produce NaN, so fall back to zeros.
    conf_min = conf_self.min()
    conf_range = conf_self.max() - conf_min
    if conf_range > 0:
        conf_self = (conf_self - conf_min) / conf_range
    else:
        conf_self = torch.zeros_like(conf_self)
    return train_loss.avg, devide, p_label, conf_self

In [None]:
# Co-training loop: each epoch trains net with net2 as its peer, then net2 with
# net as its peer, evaluates both networks and their ensemble, and overwrites
# the checkpoint.
for epoch in range(epochs):
  print("== Train RUC ==")
  # net is trained using net2's confidences (conf2) and produces conf1 ...
  loss, devide, p_label, conf1 = train(epoch, net, net2, trainloader, optimizer1, criterion, devide, p_label, conf2, batch_size)
  # ... which net2's training step consumes right away, producing the next conf2.
  loss, devide, p_label, conf2 = train(epoch, net2, net, trainloader, optimizer2, criterion, devide, p_label, conf1, batch_size)
  # acc holds three accuracies: net, net2 and the summed-logit ensemble.
  acc, p_list = test_ruc(net, net2, evalloader, device, class_num)
  print("accuracy: {}\n".format(acc))

  # Save both networks every epoch (the same file is overwritten each time).
  state = {'net1': net.state_dict(),
            'net2': net2.state_dict() }
  torch.save(state, './checkpoint/ruc_cifar20.t7')

== Train RUC ==


  labels_x = torch.tensor(p_label[indexes][s_idx]).squeeze().long().cpu()


Epoch: [0][0/200] Train loss: 6.1148 (6.1148) 
Epoch: [0][100/200] Train loss: 4.7774 (4.8717) 
Epoch: [0][0/200] Train loss: 6.1383 (6.1383) 
Epoch: [0][100/200] Train loss: 4.1738 (4.8568) 
accuracy: [51.562, 51.518, 51.742]

== Train RUC ==
Epoch: [1][0/200] Train loss: 4.1537 (4.1537) 
Epoch: [1][100/200] Train loss: 4.6410 (4.5229) 
Epoch: [1][0/200] Train loss: 4.7681 (4.7681) 
Epoch: [1][100/200] Train loss: 4.6218 (4.5474) 
accuracy: [51.653999999999996, 51.172, 51.57000000000001]

== Train RUC ==
Epoch: [2][0/200] Train loss: 3.9441 (3.9441) 
Epoch: [2][100/200] Train loss: 4.1757 (4.4641) 
Epoch: [2][0/200] Train loss: 4.7928 (4.7928) 
Epoch: [2][100/200] Train loss: 4.6518 (4.5313) 
accuracy: [51.605999999999995, 51.902, 51.912000000000006]

== Train RUC ==
Epoch: [3][0/200] Train loss: 4.7140 (4.7140) 
Epoch: [3][100/200] Train loss: 4.6651 (4.4295) 
Epoch: [3][0/200] Train loss: 3.8848 (3.8848) 
Epoch: [3][100/200] Train loss: 4.6252 (4.4238) 
accuracy: [51.57000000000001,