In [None]:
import torch
torch.cuda.is_available()
import numpy as np
import utils
from torch.utils.data import Dataset
from collections import Counter
import matplotlib.pyplot as plt

In [None]:
def make_imb_data(max_num, class_num, gamma, inv=False, bal=False):
    """Per-class sample counts that decay geometrically from ``max_num``.

    Class i gets ``int(max_num * mu**i)`` samples with ``mu = (1/gamma)**(1/(class_num-1))``,
    so the largest/smallest class ratio is ``gamma`` (up to ``int`` truncation).

    Args:
        max_num (int): size of the largest class.
        class_num (int): number of classes. With a single class there is nothing to decay,
            so the result is ``[int(max_num)]``; with zero classes it is ``[]``.
        gamma (float): imbalance ratio between the largest and the smallest class.
        inv (bool): reverse the order so that the last class is the largest.
        bal (bool): replace every count by the truncated mean count, i.e. a balanced split
            with (about) the same total.

    Returns:
        list[int]: number of samples for each class.
    """
    print('max_num={}'.format(max_num))
    # 1/(class_num - 1) is undefined for a single class; no decay is needed there.
    mu = np.power(1 / gamma, 1 / (class_num - 1)) if class_num > 1 else 1.0
    exponents = range(class_num - 1, -1, -1) if inv else range(class_num)
    class_num_list = [int(max_num * np.power(mu, e)) for e in exponents]
    if bal and class_num_list:
        per_class = int(sum(class_num_list) / class_num)
        class_num_list = [per_class] * class_num
    print(class_num_list)
    return class_num_list

import os

# Long-tailed CIFAR-10 setup, with the numbers named so they can be tuned in one place.
# The largest class has MAX_LABELED_PER_CLASS labeled images. Class sizes decay
# geometrically down to about 1/IMBALANCE_RATIO of that. The unlabeled split follows the
# same profile, scaled by UNLABELED_RATIO.
NUM_CLASSES = 10
MAX_LABELED_PER_CLASS = 1250
UNLABELED_RATIO = 3
IMBALANCE_RATIO = 100
# Dataset root: set CIFAR10_ROOT in the environment instead of editing an absolute local path.
DATA_ROOT = os.environ.get('CIFAR10_ROOT', '/home/apoorva/Datasets')

N_SAMPLES_PER_CLASS = make_imb_data(MAX_LABELED_PER_CLASS, NUM_CLASSES, IMBALANCE_RATIO,
                                   inv=False, bal=False)
U_SAMPLES_PER_CLASS = make_imb_data(UNLABELED_RATIO * MAX_LABELED_PER_CLASS, NUM_CLASSES,
                                   IMBALANCE_RATIO, inv=False, bal=False)
N_SAMPLES_PER_CLASS_T = torch.Tensor(N_SAMPLES_PER_CLASS)

train_labeled_set, train_unlabeled_set, test_set = utils.get_cifar10(
    DATA_ROOT, N_SAMPLES_PER_CLASS, U_SAMPLES_PER_CLASS)

In [None]:
# Target distribution: row c holds the sampling weight of each pseudo-label for labeled class c.
# This experiment sends every labeled class to pseudo-label 0.
# Other settings tried:
#   uniform rows: torch.nn.functional.normalize(torch.ones([10,10]))
#   identity:     torch.diag(torch.ones(10))
target_dist = torch.zeros([10, 10])
target_dist[:, 0] = 1
labeled_loader = torch.utils.data.DataLoader(
    train_labeled_set,
    batch_size=128,
    num_workers=8,
    shuffle=True,
)

In [None]:
# Stand-in pseudo-labels: the unlabeled set's stored targets copied into a plain list
# (presumably to be replaced by model predictions later — TODO confirm).
pseudo_labels = list(train_unlabeled_set.targets)

def get_weights(pseudo_labels, lbl, target_dist, class_counts=None):
    """Returns the sampling weights for each data instance in the dataset

    Instance j with pseudo-label t gets weight ``target_dist[lbl][t] / class_counts[t]``.
    With the default counts (the pseudo-label frequencies), the weights of all instances
    labeled t sum to ``target_dist[lbl][t]``. A WeightedRandomSampler built from them
    therefore draws pseudo-label t with probability proportional to ``target_dist[lbl][t]``.

    Args:
        pseudo_labels (list): pseudo-label of each instance, in dataset order
        lbl (int or tensor): labeled class whose row of target_dist is used
        target_dist (tensor): target distribution obtained by dM/dCij; row lbl is indexed
            by pseudo-label
        class_counts (sequence or mapping, optional): per-class normaliser indexed by
            pseudo-label. Defaults to the frequencies in ``pseudo_labels``. Pass
            N_SAMPLES_PER_CLASS to normalise by the labeled-set counts instead.

    Returns:
        list: list of instance-wise weights to be passed into the sampler
    """
    labels = [int(t) for t in pseudo_labels]
    if class_counts is None:
        # Normalise by how often each pseudo-label actually occurs. Fixed labeled-set counts
        # would skew the sampled mix whenever pseudo-labels are distributed differently.
        class_counts = Counter(labels)
    # Read the target row once instead of indexing the tensor for every instance.
    target_row = [float(v) for v in target_dist[int(lbl)]]
    return [target_row[t] / class_counts[t] for t in labels]

In [None]:
def plot_distribution(counter_class, title=None):
    """Bar chart of a label -> count mapping (e.g. a collections.Counter).

    Args:
        counter_class (Mapping): label -> count; each key is used as both the bar
            position and its tick label.
        title (str, optional): title for the axes.

    Returns:
        matplotlib.axes.Axes: the axes the bars were drawn on, for further styling.
    """
    names = list(counter_class.keys())
    values = list(counter_class.values())

    # Explicit figure/axes instead of the pyplot state machine, so repeated calls never
    # draw onto a stale current axes.
    fig, ax = plt.subplots()
    ax.bar(names, values, tick_label=names)
    ax.set_xlabel('label')
    ax.set_ylabel('count')
    if title is not None:
        ax.set_title(title)
    plt.show()
    return ax

In [None]:
def get_loaders(unlabeled_train, num_classes, pseudo_labels, target_dist, num_workers=8):
    """
    Returns a dict of (num_classes) single-item dataloaders depending on the target distribution

    Loader ``str(i)`` draws items of ``unlabeled_train`` with replacement, weighted by
    ``get_weights(pseudo_labels, i, target_dist)``. Because ``batch_size=None``, each step
    yields one un-batched dataset item.

    Args:
        unlabeled_train (Dataset): unlabeled dataset. The weights are positional, so
            pseudo_labels[j] is taken to describe unlabeled_train[j].
        num_classes (int): number of dataloaders (equal to number of classes in the dataset)
        pseudo_labels (list): list of pseudo-labels generated by the model
        target_dist (tensor): target distribution obtained by dM/dCij
        num_workers (int, optional): worker processes per loader. Defaults to 8.

    Returns:
        dict: maps str(class index) to its DataLoader
    """
    loader_dict = {}
    for i in range(num_classes):
        sampler = torch.utils.data.WeightedRandomSampler(
            weights=get_weights(pseudo_labels, i, target_dist),
            num_samples=len(pseudo_labels),
            replacement=True,
        )
        loader_dict[str(i)] = torch.utils.data.DataLoader(
            unlabeled_train, batch_size=None, num_workers=num_workers, sampler=sampler)
    return loader_dict
    

# loaders = get_loaders(train_unlabeled_set, 10, pseudo_labels, target_dist)
# for i in range(10):
#     ldr = loaders[f'{i}']
#     labels = [int(i) for _, i, _ in ldr]
#     print(Counter(labels))


In [None]:
class UnlabeledDatasetCollate(Dataset):
    """
    In-memory dataset over the (image, label) pairs produced by target-distribution
    dependent sampling, so the sampled pairs can be fed straight into a DataLoader.
    """
    def __init__(self, labels_list, images_list, transform=None, target_transform=None):
        """
        Args:
            labels_list (list): sampled labels; labels_list[idx] pairs with images_list[idx].
            images_list (list): sampled images, aligned with labels_list.
            transform (callable, optional): applied to an image on access, if truthy.
            target_transform (callable, optional): applied to a label on access, if truthy.
        """
        self.labels_list, self.images_list = labels_list, images_list
        self.transform, self.target_transform = transform, target_transform

    def __len__(self):
        # The labels define the length; images_list is expected to be aligned with it.
        return len(self.labels_list)

    def __getitem__(self, idx):
        sample = self.images_list[idx]
        target = self.labels_list[idx]
        sample = self.transform(sample) if self.transform else sample
        target = self.target_transform(target) if self.target_transform else target
        return sample, target


In [None]:
# Sanity check: for every labeled batch, draw one unlabeled item per labeled example from the
# target-distribution-weighted samplers and tally the labels that come out.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
list_dist = np.zeros(10)

# pseudo_labels and target_dist do not change inside the loop, so the per-class loaders are
# built once instead of once per labeled batch.
loaders = get_loaders(train_unlabeled_set, 10, pseudo_labels, target_dist) # todo: train_unlabeled_set replaced by a megabatch
# One persistent index stream per class, read from each loader's weighted sampler.
# Calling next(iter(loader)) for every draw would start a fresh set of DataLoader
# worker processes for each single sample.
index_streams = {}


def _draw_unlabeled(class_key):
    """Return one (image, label, idx) item of train_unlabeled_set picked by the sampler of
    loaders[class_key]; the sampler is restarted once its stream is exhausted."""
    stream = index_streams.get(class_key)
    if stream is None:
        stream = index_streams[class_key] = iter(loaders[class_key].sampler)
    idx = next(stream, None)
    if idx is None:
        stream = index_streams[class_key] = iter(loaders[class_key].sampler)
        idx = next(stream)
    return train_unlabeled_set[idx]


for imgs, lbls, idxs in labeled_loader:
    sl_list = []
    image_list = []
    # Only the label values are needed, so keep them on the CPU as plain ints.
    for lbl in lbls.tolist():
        image, sampled_label, _ = _draw_unlabeled(f'{lbl}')
        sl_list.append(int(sampled_label))
        image_list.append(image)
    ulb_dataset_sampled = UnlabeledDatasetCollate(sl_list, image_list)
    # The data is already in memory; worker processes would only add start-up cost.
    ulb_loader = torch.utils.data.DataLoader(ulb_dataset_sampled, batch_size=128, num_workers=0, shuffle=True)
    img, lbl_ulb = next(iter(ulb_loader))
    lbl_ulb = lbl_ulb.numpy()
    np.add.at(list_dist, lbl_ulb, 1)
    batch_counts = Counter(lbl_ulb)
    print(batch_counts)
    plot_distribution(batch_counts)

print(list_dist)

In [None]:
class UnlabeledDataLoader():
    """Draws unlabeled samples whose pseudo-labels follow a per-class target distribution.

    For each labeled example of class c, one unlabeled sample is drawn with replacement.
    Its pseudo-label t is chosen with probability proportional to ``target_dist[c][t]``.
    The labels returned are the model's pseudo-labels, never the dataset's stored targets.
    """

    def __init__(self, ulb_dataset, model, target_dist, batch_size, num_workers=8, device=None):
        """
        Args:
            ulb_dataset (Dataset): unlabeled dataset. Each item is assumed to be a tuple whose
                first element is the image (TODO confirm against utils.get_cifar10).
            model (callable): maps a batch of images either to class scores of shape (N, C)
                or to hard labels of shape (N,) / (N, 1).
            target_dist (tensor): square matrix; row c weights the pseudo-labels sampled for
                labeled class c (rows need not be normalised).
            batch_size (int): maximum number of samples returned by get_batch.
            num_workers (int, optional): worker processes for the DataLoaders built here.
            device (torch.device, optional): device to run the model on; inferred from the
                model's parameters when omitted (CPU if it has none).
        """
        super(UnlabeledDataLoader, self).__init__()
        self.model = model
        self.ulb_dataset = ulb_dataset
        self.target_dist = target_dist
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.device = device if device is not None else self._infer_device()
        self.pseudo_labels = self.gen_pl()

    def _infer_device(self):
        """Device of the model's first parameter, or CPU when the model has none."""
        parameters = getattr(self.model, 'parameters', None)
        if callable(parameters):
            for param in parameters():
                return param.device
        return torch.device('cpu')

    def gen_pl(self):
        """Predict one hard pseudo-label per unlabeled sample, in dataset order.

        The loader is deliberately not shuffled: the sampling weights and get_batch look up
        pseudo_labels[j] for ulb_dataset[j], so positions must line up.

        Returns:
            list[int]: pseudo-label of ulb_dataset[j] at position j.
        """
        ulb_loader = torch.utils.data.DataLoader(
            self.ulb_dataset, batch_size=128, num_workers=self.num_workers, shuffle=False)
        was_training = getattr(self.model, 'training', False)
        set_eval = getattr(self.model, 'eval', None)
        if callable(set_eval):
            set_eval()
        list_labels = []
        try:
            with torch.no_grad():
                for batch in ulb_loader:
                    # Items are tuples (image, ...); only the images go through the model.
                    images = batch[0] if isinstance(batch, (list, tuple)) else batch
                    output = self.model(images.to(self.device))
                    if output.dim() == 2 and output.shape[1] > 1:
                        output = output.argmax(dim=1)  # class scores -> hard labels
                    list_labels.extend(output.reshape(-1).long().tolist())
        finally:
            if was_training:
                self.model.train()
        return list_labels

    def _weight_tensor(self, lbl, class_counts=None):
        """Per-instance weights as a CPU double tensor (see get_weights)."""
        pseudo = torch.as_tensor(self.pseudo_labels, dtype=torch.long).reshape(-1)
        if class_counts is None:
            counts = torch.bincount(pseudo).to(torch.double)
        else:
            counts = torch.as_tensor(class_counts, dtype=torch.double, device='cpu')
        row = torch.as_tensor(self.target_dist[int(lbl)], dtype=torch.double, device='cpu')
        return row[pseudo] / counts[pseudo]

    def get_loaders(self):
        """
        Returns a dict of single-item dataloaders, one per row of target_dist

        Loader ``str(i)`` draws ulb_dataset items with replacement, weighted by
        ``self.get_weights(i)``. It yields the raw dataset item (including the dataset's own
        stored target), one item per step (``batch_size=None``).

        Returns:
            dict: maps str(class index) to its DataLoader
        """
        loader_dict = {}
        for i in range(len(self.target_dist)):
            sampler = torch.utils.data.WeightedRandomSampler(
                weights=self.get_weights(i),
                num_samples=len(self.pseudo_labels),
                replacement=True,
            )
            loader_dict[f"{i}"] = torch.utils.data.DataLoader(
                self.ulb_dataset, batch_size=None, num_workers=self.num_workers, sampler=sampler)
        return loader_dict

    def get_weights(self, lbl, class_counts=None):
        """Returns the sampling weights for each data instance in the dataset

        Instance j with pseudo-label t gets ``target_dist[lbl][t] / class_counts[t]``.
        With the default counts (pseudo-label frequencies), pseudo-label t is then drawn
        with probability proportional to ``target_dist[lbl][t]``.

        Args:
            lbl (int or tensor): labeled class whose row of target_dist is used
            class_counts (sequence, optional): per-class normaliser indexed by pseudo-label;
                defaults to the pseudo-label frequencies.

        Returns:
            list: list of instance-wise weights to be passed into the sampler
        """
        return self._weight_tensor(lbl, class_counts).tolist()

    def get_batch(self, labels):
        """Returns a batch of unlabeled images and pseudo-labels sampled per labeled class

        Args:
            labels (iterable of int / 1-D tensor): labels of one labeled minibatch

        Returns:
            img: collated unlabeled images (for CHW image tensors: shape NxCxHxW)
            lbl_ulb: torch.long tensor of the sampled pseudo-labels, shape (N,)
            where N = min(len(labels), batch_size); the batch is returned shuffled.

        Raises:
            ValueError: if labels is empty.
        """
        class_ids = [int(c) for c in labels]
        if not class_ids:
            raise ValueError('labels must contain at least one label')
        images, sampled = [], []
        # Draw all samples for one labeled class at once; draws are independent, so this
        # matches drawing one sample per labeled example.
        for cls, count in Counter(class_ids).items():
            drawn = torch.multinomial(self._weight_tensor(cls), count, replacement=True)
            for j in drawn.tolist():
                images.append(self.ulb_dataset[j][0])
                sampled.append(self.pseudo_labels[j])
        order = torch.randperm(len(sampled))[:self.batch_size].tolist()
        img = torch.utils.data.dataloader.default_collate([images[p] for p in order])
        lbl_ulb = torch.tensor([sampled[p] for p in order], dtype=torch.long)
        return img, lbl_ulb
