Dataset

In [None]:
import torch
from torch.utils.data import Dataset
from torchvision import datasets
import torchvision.transforms as transforms
import numpy as np


In [None]:
class TransformTwice:
    """Callable wrapper that applies ``transform`` to one input two times.

    Each call invokes the wrapped transform twice on the same input and
    returns both results as a 2-tuple, so a stochastic transform can yield
    two different views of the same sample.
    """

    def __init__(self, transform):
        self.transform = transform

    def __call__(self, input):
        # The generator is consumed left-to-right, so the transform runs
        # twice in order: first view, then second view.
        first_view, second_view = (self.transform(input) for _ in range(2))
        return first_view, second_view


class datasets_labeled(datasets):
    """Labeled (optionally index-restricted) view of an image dataset.

    NOTE(review): ``datasets`` here is the ``torchvision.datasets`` *module*
    imported at the top of the file, not a dataset class. Using a module as a
    base class raises ``TypeError`` when this ``class`` statement executes.
    The intended base (presumably a concrete dataset class such as
    ``torchvision.datasets.CIFAR10``) needs to be confirmed and substituted.
    """

    def __init__(self, root, indexs=None, train=True,
                 transform=None, target_transform=None,
                 download=False):
        # Forward the standard constructor arguments to the base dataset class.
        super(datasets_labeled, self).__init__(root, train=train,
                 transform=transform, target_transform=target_transform,
                 download=download)
        # Keep only the requested sample indices; all samples when None.
        if indexs is not None:
            self.data = self.data[indexs]
            self.targets = np.array(self.targets)[indexs]
        # NOTE(review): torchvision.transforms does not appear to expose
        # module-level ``normalize`` / ``transpose`` functions (it has the
        # ``Normalize`` class and ``transforms.functional.normalize``), so this
        # line is expected to raise AttributeError. It looks like it should call
        # file-level helpers that normalize the raw array and reorder it to
        # channel-first — TODO confirm intended helpers.
        self.data = transforms.transpose(transforms.normalize(self.data))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], self.targets[index]

        # Optional per-sample transforms, applied only when provided.
        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target
    

class datasets_unlabeled(datasets_labeled):
    """Same samples as ``datasets_labeled``, but every target replaced by -1.

    Construction is delegated to ``datasets_labeled``; afterwards the real
    labels are overwritten with the sentinel value -1.
    """

    def __init__(self, root, indexs, train=True,
                 transform=None, target_transform=None,
                 download=False):
        super().__init__(root, indexs, train=train,
                         transform=transform,
                         target_transform=target_transform,
                         download=download)
        # Mask every label with -1, one entry per retained sample.
        n_samples = len(self.targets)
        self.targets = np.array([-1] * n_samples)

def get_datasets(root, n_labeled, datasets,
                 transform_train=None, transform_val=None,
                 download=True):
    """Build labeled, unlabeled, validation and test datasets.

    Args:
        root: Dataset root directory, forwarded to every dataset constructor.
        n_labeled: Total number of labeled training samples. It is split
            evenly across classes, truncated toward zero per class.
        datasets: Dataset class used to load the base training set, whose
            ``targets`` drive the per-class split.
        transform_train: Transform for the labeled split. The unlabeled split
            wraps it in ``TransformTwice`` so each sample yields two views.
        transform_val: Transform for the validation and test splits.
        download: Whether to download the data if missing. Forwarded to
            every dataset constructor.

    Returns:
        tuple: ``(train_labeled, train_unlabeled, val, test)`` datasets.

    Raises:
        ValueError: If the base dataset has no targets.
    """
    # Load the full training set once, only to read its labels for the split.
    base_dataset = datasets(root, train=True, download=download)
    # Derive the class count from the data rather than assuming 10 classes.
    num_classes = len(np.unique(base_dataset.targets))
    if num_classes == 0:
        raise ValueError("get_datasets: base dataset has no targets")
    train_labeled_idxs, train_unlabeled_idxs, val_idxs = train_val_split(
        base_dataset.targets, int(n_labeled / num_classes))

    # NOTE(review): the splits below are built by datasets_labeled /
    # datasets_unlabeled, which never receive ``datasets``; they load whatever
    # base class those wrappers derive from. Confirm it matches ``datasets``.
    train_labeled_dataset = datasets_labeled(
        root, train_labeled_idxs, train=True,
        transform=transform_train, download=download)
    train_unlabeled_dataset = datasets_unlabeled(
        root, train_unlabeled_idxs, train=True,
        transform=TransformTwice(transform_train), download=download)
    val_dataset = datasets_labeled(
        root, val_idxs, train=True,
        transform=transform_val, download=download)
    test_dataset = datasets_labeled(
        root, train=False, transform=transform_val, download=download)

    print(f"#Labeled: {len(train_labeled_idxs)} #Unlabeled: {len(train_unlabeled_idxs)} #Val: {len(val_idxs)}")
    return train_labeled_dataset, train_unlabeled_dataset, val_dataset, test_dataset


def train_val_split(labels, n_labeled_per_class, n_val_per_class=500):
    """Split sample indices per class into labeled, unlabeled and validation sets.

    For every distinct label, that class's indices are shuffled:
    - the first ``n_labeled_per_class`` become labeled training samples;
    - the last ``n_val_per_class`` become validation samples;
    - whatever lies in between becomes unlabeled training data.

    Each resulting list is then shuffled again so classes are interleaved.

    Args:
        labels: Sequence of class labels, one per sample (anything
            ``np.array`` accepts).
        n_labeled_per_class: Number of labeled samples taken from each class.
        n_val_per_class: Number of validation samples taken from the end of
            each class (default 500).

    Returns:
        tuple: ``(train_labeled_idxs, train_unlabeled_idxs, val_idxs)``, each a
        list of integer indices into ``labels``.

    Note:
        Uses NumPy's global random state; seed ``np.random`` for reproducible
        splits. If a class has fewer than
        ``n_labeled_per_class + n_val_per_class`` samples, its labeled and
        validation indices can overlap.
    """
    labels = np.array(labels)
    train_labeled_idxs = []
    train_unlabeled_idxs = []
    val_idxs = []

    # Iterate over the distinct labels actually present, so any set of
    # class ids works (``range`` cannot take an array).
    for cls in np.unique(labels):
        idxs = np.where(labels == cls)[0]
        np.random.shuffle(idxs)
        # Compute the validation boundary explicitly: a ``-n`` slice would
        # select the whole class when n_val_per_class == 0.
        val_start = max(len(idxs) - n_val_per_class, 0)
        train_labeled_idxs.extend(idxs[:n_labeled_per_class])
        train_unlabeled_idxs.extend(idxs[n_labeled_per_class:val_start])
        val_idxs.extend(idxs[val_start:])
    np.random.shuffle(train_labeled_idxs)
    np.random.shuffle(train_unlabeled_idxs)
    np.random.shuffle(val_idxs)

    return train_labeled_idxs, train_unlabeled_idxs, val_idxs

def get_mean_std(datasets):
    """Compute the per-channel mean and standard deviation of a dataset's images.

    Args:
        datasets: Either an object exposing the data as a ``trainset``
            attribute, or an iterable of ``(image, label)`` pairs directly.
            Each image must be a ``torch.Tensor`` with channels on dim 0,
            i.e. ``(C, H, W)``. All images must share one shape so they can
            be stacked.

    Returns:
        tuple: ``(mean, std)``, each a tuple with one NumPy scalar per
        channel. ``std`` is the population standard deviation (NumPy's
        default ``ddof=0``).
    """
    # Accept a wrapper holding the data under ``trainset`` (the original
    # calling convention) or a bare dataset.
    source = getattr(datasets, "trainset", datasets)
    imgs = [item[0] for item in source]  # item[0] and item[1] are image and its label
    imgs = torch.stack(imgs, dim=0).numpy()  # (N, C, H, W)

    n_channels = imgs.shape[1]
    # Reduce one channel slice at a time, matching the original r/g/b
    # computation exactly, but for any number of channels.
    mean = tuple(imgs[:, c, :, :].mean() for c in range(n_channels))
    print(*mean)

    std = tuple(imgs[:, c, :, :].std() for c in range(n_channels))
    print(*std)

    return mean, std

# Per-channel dataset statistics, computed when this module-level line runs.
# NOTE(review): ``datasets`` at module level is the ``torchvision.datasets``
# module, which is neither a dataset of (image, label) pairs nor has a
# ``trainset`` attribute, so this call will fail here. Pass the actual training
# dataset (or an object exposing it as ``trainset``) — TODO confirm which one.
datasets_mean, datasets_std = get_mean_std(datasets)




