# Generate Bags

In [1]:
import numpy as np
from numpy.random import dirichlet, multinomial
from sklearn.utils import shuffle
import random

class InsufficientDataPoints(Exception):
	pass


class InvalidAlpha(Exception):
	pass

In [2]:

def make_bags_dirichlet(train_y, num_classes, bag_size, num_bags, alpha):
	if len(alpha) != num_classes:
		raise InvalidAlpha("the dirichlet distribution's parameter should have length equal to num_classes")

	multinomial_param = dirichlet(alpha, num_bags)
	bag_arr = np.zeros(multinomial_param.shape)
	for row_num in range(bag_arr.shape[0]):
		bag_arr[row_num, :] = multinomial(bag_size, multinomial_param[row_num, :])
	bag_arr = bag_arr.astype(np.int64)
	return _make_bags_counts(train_y, num_classes, bag_arr)

def _make_bags_counts(train_y, num_classes, lp_arr):
	train_y = np.array(train_y, dtype=np.int64)  # y has to be integers starting from 0

	# first need to verify the number of data points
	total_label_counts = {}
	for label in range(num_classes):
		total_label_counts[int(label)] = (train_y == label).astype(int).sum()
	expected_label_counts = {i: np.sum(lp_arr[:, i]) for i in range(num_classes)}
	for label in range(num_classes):
		if total_label_counts[label] < expected_label_counts[label]:
			raise InsufficientDataPoints("Requested data points > total number of data points")
	# done checking

	label2indices = {}
	for i in range(len(train_y)):
		label = int(train_y[i])
		if label not in label2indices.keys():
			label2indices[label] = set({})
		label2indices[label].add(i)

	bag2indices, bag2size, bag2prop = {}, {}, {}
	for bag_idx in range(lp_arr.shape[0]):
		bag2indices[bag_idx] = []
		for label in range(num_classes):
			class_indices = random.sample(list(label2indices[label]), lp_arr[bag_idx, label])
			label2indices[label] -= set(class_indices)
			bag2indices[bag_idx].extend(class_indices)
		bag2size[bag_idx] = len(bag2indices[bag_idx])
		bag2prop[bag_idx] = np.zeros((num_classes,))
		for j in range(num_classes):
			bag2prop[bag_idx][j] = np.sum(train_y[bag2indices[bag_idx]] == j) / bag2size[bag_idx]
	return bag2indices, bag2size, bag2prop

def truncate_data(data, bag2indices):
	idx_list = []
	for bag_id in bag2indices.keys():
		idx_list.extend(bag2indices[bag_id])
	idx_list.sort()
	data_truncated = data[idx_list]
	idx2new = {idx_list[i]: i for i in range(len(idx_list))}
	bag2new = {bag_id: list(map(idx2new.get, bag2indices[bag_id])) for bag_id in bag2indices.keys()}
	return data_truncated, bag2new

In [3]:
# download the fully labeled data
import torchvision
train_dataset = torchvision.datasets.CIFAR10(root=".", train=True, download=True)
labels = train_dataset.targets

Files already downloaded and verified


In [4]:
n_classes = 10
bag_size = 32
n_bags = 1000
bag2indices, bag2size, bag2prop = make_bags_dirichlet(labels, num_classes=n_classes, bag_size=bag_size, num_bags=n_bags, alpha = tuple([1 for _ in range(n_classes)]))
training_data, bag2indices = truncate_data(train_dataset.data, bag2indices)

# bag2indices maps a bag id to indices of feature vectors
# bag2size maps a bag id to its size
# bag2prop maps a bag id to its label proportions

# Model and Dataloaders

In [5]:
import torch
from PIL import Image
from torch.utils.data import Sampler
import numpy as np

In [6]:
# helper functions to load the data for training

def truncate_data_group(x, y, instance2group):
	idx_list = []
	for i in range(x.shape[0]):
		if instance2group[i] != -1:
			idx_list.append(i)
	x_truncated = x[idx_list]
	y_truncated = y[idx_list]
	idx2new = {idx_list[i]: i for i in range(len(idx_list))}
	instance2group_new = {}
	for old, new in idx2new.items():
		instance2group_new[new] = instance2group[old]
	new2idx = {idx2new[idx]: idx for idx in idx2new.keys()}
	return x_truncated, y_truncated, instance2group_new, new2idx


class LLPFC_DATASET_BASE(torch.utils.data.Dataset):
	def __init__(self, data, noisy_y, group2transition, instance2weight, instance2group, transform):
		self.data, self.noisy_y, self.instance2group, self.new2idx = truncate_data_group(data, noisy_y, instance2group)
		self.group2transition = group2transition
		self.instance2weight = instance2weight
		self.transform = transform

	def __len__(self):
		return len(self.data)


class FORWARD_CORRECT_CIFAR10(LLPFC_DATASET_BASE):
	def __getitem__(self, index):
		img, y_ = self.data[index], self.noisy_y[index]
		trans_m = self.group2transition[self.instance2group[index]]
		weight = self.instance2weight[self.new2idx[index]]
		img = Image.fromarray(img)
		if self.transform is not None:
			img = self.transform(img)
		return img, int(y_), torch.tensor(trans_m, dtype=None), weight

We use the resnet implementation from https://github.com/kevinorjohn/LLP-VAT.
You can use other networks.

In [7]:
# MIT License
#
# Copyright (c) 2020 Kuen-Han Tsai
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


def wide_resnet_d_w(d, w, **kwargs):
    net = WideResNet(d, w, **kwargs)
    net.apply(conv_init)
    return net


class GaussianNoise(nn.Module):
    """ add gasussian noise into feature """
    def __init__(self, std):
        super(GaussianNoise, self).__init__()
        self.std = std

    def forward(self, x):
        zeros_ = torch.zeros_like(x)
        n = torch.normal(zeros_, std=self.std)
        return x + n


def conv3x3(in_planes, out_planes, stride=1):
    return nn.Conv2d(in_planes,
                     out_planes,
                     kernel_size=3,
                     stride=stride,
                     padding=1,
                     bias=True)


def conv_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.xavier_uniform_(m.weight, gain=np.sqrt(2))
        nn.init.constant_(m.bias, 0)
    elif classname.find('BatchNorm') != -1:
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)


class WideBasic(nn.Module):
    def __init__(self, in_planes, planes, dropout_rate, stride=1):
        super(WideBasic, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes,
                               planes,
                               kernel_size=3,
                               padding=1,
                               bias=True)
        self.dropout = nn.Dropout(p=dropout_rate)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes,
                               planes,
                               kernel_size=3,
                               stride=stride,
                               padding=1,
                               bias=True)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes,
                          planes,
                          kernel_size=1,
                          stride=stride,
                          bias=True), )

    def forward(self, x):
        out = self.dropout(self.conv1(F.relu(self.bn1(x))))
        out = self.conv2(F.relu(self.bn2(out)))
        out += self.shortcut(x)

        return out


class WideResNet(nn.Module):
    def __init__(self, depth, widen_factor, dropout_rate, num_classes, in_channel, image_size, return_features=False):
        super(WideResNet, self).__init__()
        self.in_planes = 16
        if image_size == 32:  # CIFAR10, SVHN
            self.pool_size = 8
        elif image_size == 28:  # EMNIST
            self.pool_size = 7
        assert ((depth - 4) % 6 == 0), 'Wide-resnet depth should be 6n+4'
        n = (depth - 4) // 6
        k = widen_factor

        print('| Wide-Resnet %dx%d' % (depth, k))
        nStages = [16, 16 * k, 32 * k, 64 * k]

        self.conv1 = conv3x3(in_channel, nStages[0])
        self.layer1 = self._wide_layer(WideBasic,
                                       nStages[1],
                                       n,
                                       dropout_rate,
                                       stride=1)
        self.layer2 = self._wide_layer(WideBasic,
                                       nStages[2],
                                       n,
                                       dropout_rate,
                                       stride=2)
        self.layer3 = self._wide_layer(WideBasic,
                                       nStages[3],
                                       n,
                                       dropout_rate,
                                       stride=2)
        self.bn1 = nn.BatchNorm2d(nStages[3], momentum=0.9)
        self.linear = nn.Linear(nStages[3], num_classes)

        self.return_features = return_features

    def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []

        for stride in strides:
            layers.append(block(self.in_planes, planes, dropout_rate, stride))
            self.in_planes = planes

        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = F.relu(self.bn1(out))
        out = F.avg_pool2d(out, self.pool_size)
        features = out.view(out.size(0), -1)
        out = self.linear(features)
        if self.return_features:
            return out, features
        return out

# Grouping funcions for LLPFC

In [8]:
import numpy as np
import random
from scipy.spatial import ConvexHull
from scipy.special import factorial
from numpy.linalg import matrix_rank
from scipy.optimize import minimize, Bounds, LinearConstraint


class InvalidChoiceOfWeights(Exception):
    pass


class InvalidChoiceOfNoisyPrior(Exception):
    pass

In [9]:
def approx_noisy_prior(gamma_m, clean_prior):  # use the solution as noisy prior for LLPFC-approx
    def ls_error(x, A, b):
        return 0.5 * np.sum((np.matmul(A, x) - b) ** 2)

    def grad(x, A, b):
        return np.matmul(np.matmul(np.transpose(A), A), x) - np.matmul(np.transpose(A), b)

    def hess(x, A, b):
        return np.matmul(np.transpose(A), A)

    x0 = np.random.rand(clean_prior.shape[0])
    x0 /= np.sum(x0)

    res = minimize(ls_error,
                   x0,
                   args=(np.transpose(gamma_m), clean_prior),
                   method='trust-constr',
                   jac=grad,
                   hess=hess,
                   bounds=Bounds(np.zeros(x0.shape), np.ones(x0.shape)),
                   constraints=LinearConstraint(np.ones(x0.shape), np.ones(1), np.ones(1)),
                   )
    return res.x


def make_a_group(num_classes, clean_prior, bag_ids, bag2prop, noisy_prior_choice):
    bags_list = random.sample(list(bag_ids), num_classes)
    gamma_m = np.zeros((num_classes, num_classes))
    for row_idx in range(num_classes):
        gamma_m[row_idx, :] = bag2prop[bags_list[row_idx]]
    if noisy_prior_choice == 'approx':
        noisy_prior_approx = approx_noisy_prior(np.transpose(gamma_m), clean_prior)
    elif noisy_prior_choice == 'uniform':
        noisy_prior_approx = np.ones((num_classes,)) / num_classes
    else:
        raise InvalidChoiceOfNoisyPrior("Unknown choice of noisy prior: %s" % noisy_prior_choice)
    assert np.all(noisy_prior_approx >= 0)
    assert (np.sum(noisy_prior_approx) - 1) < 1e-4
    clean_prior_approx = np.matmul(np.transpose(gamma_m), noisy_prior_approx)

    transition_m = np.zeros((num_classes, num_classes))
    for i in range(num_classes):
        for j in range(num_classes):
            transition_m[i, j] = gamma_m[i, j] * noisy_prior_approx[i] / clean_prior_approx[j]  # clean_prior can't be 0 in this case

    if matrix_rank(transition_m) != num_classes:
        print("singular transition matrix")
    if np.any(noisy_prior_approx < 0):
        print("negative prior of noisy labels")
    return bags_list, noisy_prior_approx, transition_m


def _pow_normalize(x, t):
    """
    returns normalized x**t
    this function is used to control the probability of bag assignment
    """
    exp = x ** t
    return exp / np.sum(exp, axis=0)


def make_groups_forward(num_classes, bag2indices, bag2size, bag2prop, noisy_prior_choice, weights):
    bag_ids = set(bag2indices.keys())
    num_groups = len(bag_ids) // num_classes
    assert num_groups > 0

    clean_prior = np.zeros((num_classes, ))
    for bag_id in bag2size.keys():
        clean_prior += bag2prop[bag_id] * bag2size[bag_id]
    clean_prior /= np.sum(clean_prior)

    group2bag = {}
    group2noisyp = {}
    group2transition = {}
    group_id = 0
    groups = []
    while len(bag_ids) >= num_classes:
        bags_list, noisy_prior, transition_m = make_a_group(num_classes,
                                                            clean_prior,
                                                            bag_ids,
                                                            bag2prop,
                                                            noisy_prior_choice)
        bag_ids = bag_ids - set(bags_list)
        group2bag[group_id], group2noisyp[group_id], group2transition[group_id] = bags_list, noisy_prior, transition_m
        groups.append(group_id)
        group_id += 1
    group2bag[-1] = list(bag_ids)  # bags that are not in a group
    groups.append(-1)

    instance2group = {instance_id: group_id for group_id in groups for bag_id in group2bag[group_id] for
                      instance_id in bag2indices[bag_id]}

    # calculate the weights of groups
    if weights == 'uniform':
        group2weights = {group_id: 1.0 for group_id, trans_m in group2transition.items()}
    else:
        raise InvalidChoiceOfWeights("Unknown way to determine weights %s, use either ch_vol or uniform" % weights)

    # set the noisy labels
    noisy_y = -np.ones((sum([len(instances) for instances in bag2indices.values()]),))
    instance2weight = np.zeros((sum([len(instances) for instances in bag2indices.values()]),))
    for group_id in groups:
        if group_id == -1:
            continue

        noisy_prior = group2noisyp[group_id]
        noisy_prop = np.zeros((num_classes, ))
        for noisy_class, bag_id in enumerate(group2bag[group_id]):
            noisy_prop[noisy_class] = bag2size[bag_id]
        noisy_prop /= np.sum(noisy_prop)
        weights = np.divide(noisy_prior, noisy_prop)
        weights /= np.sum(weights)

        for noisy_class, bag_id in enumerate(group2bag[group_id]):
            for instance_id in bag2indices[bag_id]:
                noisy_y[instance_id] = noisy_class
                instance2weight[instance_id] = weights[noisy_class] * group2weights[group_id]

    return instance2group, group2transition, instance2weight, noisy_y

# Train an LLPFC model

In [10]:
def test_model(model, test_loader, criterion, device):
	# test a model with fully label dataset
	model.eval()
	with torch.no_grad():
		correct = 0
		total = 0
		total_loss = 0
		for images, labels in test_loader:
			images = images.to(device)
			labels = labels.to(device)
			outputs = model(images)
			_, predicted = torch.max(outputs.data, 1)
			total += labels.size(0)
			correct += (predicted == labels).sum().item()

			prob = nn.functional.softmax(outputs, dim=1)
			loss = criterion(prob, labels, device)
			total_loss += loss.item()
	return correct / total, total_loss / total


def validate_model_forward(model, loss_f_val, val_loader, device):
	model.eval()
	total_loss = 0
	total = 0
	for i, (images, noisy_y, trans_m, weights) in enumerate(val_loader):
		total_loss += compute_forward_loss_on_minibatch(model, loss_f_val, images, noisy_y, trans_m, weights, device).item()
		total += noisy_y.size(0)
	return total_loss / total


def train_model_forward_one_epoch(model, loss_f, optimizer, train_loader, device, epoch, scheduler):
	# train the model one epoch with forward correction
	# label input of loss_f must be an integer
	model.train()
	total_step = len(train_loader)
	for i, (images, noisy_y, trans_m, weights) in enumerate(train_loader):
		loss = compute_forward_loss_on_minibatch(model, loss_f, images, noisy_y, trans_m, weights, device)
		# Backward pass
		optimizer.zero_grad()
		loss.backward()
		optimizer.step()
		if (i + 1) % 100 == 0:
			print('				Step [{}/{}], Loss: {:.4f}'.format(i + 1, total_step, loss.item()))
		if type(scheduler) == torch.optim.lr_scheduler.CosineAnnealingWarmRestarts:
			scheduler.step(epoch + i / total_step)
	if type(scheduler) == torch.optim.lr_scheduler.MultiStepLR:
		scheduler.step()
	elif type(scheduler) == torch.optim.lr_scheduler.ReduceLROnPlateau:
		scheduler.step(validate_model_forward(model, loss_f, train_loader, device))


def compute_forward_loss_on_minibatch(model, loss_f, images, noisy_y, trans_m, weights, device):
	# Move tensors to the configured device
	images = images.to(device)
	noisy_y = noisy_y.to(device)
	trans_m = trans_m.to(device)
	weights = weights.to(device)
	# Forward pass
	outputs = model(images)
	prob = nn.functional.softmax(outputs, dim=1)
	prob_corrected = torch.bmm(trans_m.float(), prob.reshape(prob.shape[0], -1, 1)).reshape(prob.shape[0], -1)
	loss = loss_f(prob_corrected, noisy_y, weights, device)
	return loss

In [11]:
import torch
import torch.nn as nn
from torch.distributions.constraints import simplex
from torch.utils.data import SubsetRandomSampler

import numpy as np

In [12]:
def loss_f(x, y, weights, device, epsilon=1e-8):
    assert torch.all(simplex.check(x))
    x = torch.clamp(x, epsilon, 1 - epsilon)
    unweighted = nn.functional.nll_loss(torch.log(x), y, reduction='none')
    weights /= weights.sum()
    return (unweighted * weights).sum()


def loss_f_val(x, y, weights, device, epsilon=1e-8):
    assert torch.all(simplex.check(x))
    x = torch.clamp(x, epsilon, 1 - epsilon)
    unweighted = nn.functional.nll_loss(torch.log(x), y, reduction='none')
    return (unweighted * weights).sum()


def loss_f_test(x, y, device, epsilon=1e-8):
    x = torch.clamp(x, epsilon, 1 - epsilon)
    return nn.functional.nll_loss(torch.log(x), y, reduction='sum')

In [13]:
import torch.optim as optim
import torchvision.transforms as transforms

transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),  # mean-std of cifar10
        ])
transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),  # mean-std of cifar10
        ])

train_batch_size = 64
test_dataset = torchvision.datasets.CIFAR10(root=".", train=False, transform=transform_test, download=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=256, shuffle=False)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

image_size = 32
in_channel = 3

model = WideResNet(depth=22, widen_factor=8, dropout_rate=0.3, num_classes=n_classes, in_channel=in_channel, image_size=image_size).to(device)
optimizer = optim.Adamax(model.parameters(), lr=1e-4, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)

total_epochs = 100
num_epoch_regroup = 5
noisy_prior_choice = "approx"  # change it to "uniform" to use LLPFC-uniform
weights = "uniform"

num_regroup = 0
for epoch in range(total_epochs):
        if epoch % num_epoch_regroup == 0:
            instance2group, group2transition, instance2weight, noisy_y = make_groups_forward(n_classes,
                                                                                             bag2indices,
                                                                                             bag2size,
                                                                                             bag2prop,
                                                                                             noisy_prior_choice,
                                                                                             weights)
            fc_train_dataset = FORWARD_CORRECT_CIFAR10(training_data,
                                                        noisy_y,
                                                        group2transition,
                                                        instance2weight,
                                                        instance2group,
                                                        transform_train)
            llp_train_loader = torch.utils.data.DataLoader(dataset=fc_train_dataset, shuffle=True,
                                                               batch_size=train_batch_size)
            num_regroup += 1
        print(f"Regroup-{num_regroup} Epoch-{epoch}")
        print(f"		lr: {optimizer.param_groups[0]['lr']}")
        train_model_forward_one_epoch(model, loss_f, optimizer, llp_train_loader, device, epoch, None)
        acc, test_error = test_model(model, test_loader, loss_f_test, device)
        print(f"      test_error = {test_error}, accuracy = {100 * acc}%")

Files already downloaded and verified
| Wide-Resnet 22x8
Regroup-1 Epoch-0
		lr: 0.0001
				Step [100/500], Loss: 2.1731
				Step [200/500], Loss: 2.2123
				Step [300/500], Loss: 2.2620
				Step [400/500], Loss: 2.1908
				Step [500/500], Loss: 2.2343
      test_error = 1.963569681930542, accuracy = 29.14%
Regroup-1 Epoch-1
		lr: 0.0001
				Step [100/500], Loss: 2.1883
				Step [200/500], Loss: 2.0571
				Step [300/500], Loss: 2.3043
				Step [400/500], Loss: 2.2584
				Step [500/500], Loss: 2.0798
      test_error = 1.5885544761657715, accuracy = 43.22%
Regroup-1 Epoch-2
		lr: 0.0001
				Step [100/500], Loss: 2.1805
				Step [200/500], Loss: 2.0959
				Step [300/500], Loss: 2.1227
				Step [400/500], Loss: 2.0922
				Step [500/500], Loss: 2.0739
      test_error = 1.672159308242798, accuracy = 41.02%
Regroup-1 Epoch-3
		lr: 0.0001
				Step [100/500], Loss: 2.2770
				Step [200/500], Loss: 2.1911
				Step [300/500], Loss: 2.1219
				Step [400/500], Loss: 2.1167
				Step [500/500], Los