In [None]:
# Show the NVIDIA GPUs visible to this machine and their current memory/utilisation.
!nvidia-smi

In [None]:
import torch
from torch.utils.data import Subset
# Run on the first CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("device:", device)

import numpy as np

import torch.nn as nn
import torch.nn.functional as F

import torchvision
import torchvision.transforms as transforms

from torch.utils.data import DataLoader
from tqdm import tqdm

In [None]:
class ConcatDataset(torch.utils.data.Dataset):
    """Zip several datasets together index-by-index.

    Despite its name this class does NOT concatenate: item ``i`` is the tuple
    ``(datasets[0][i], datasets[1][i], ...)`` and the length is that of the
    shortest dataset. End-to-end concatenation is provided by
    ``torch.utils.data.ConcatDataset``.
    """

    def __init__(self, *datasets):
        # Stored exactly as received: a tuple of the positional datasets.
        self.datasets = datasets

    def __getitem__(self, i):
        items = []
        for ds in self.datasets:
            items.append(ds[i])
        return tuple(items)

    def __len__(self):
        # Only indices valid for every dataset are exposed.
        return min(map(len, self.datasets))

In [None]:
# CIFAR-10 per-channel normalization constants, shared by both pipelines so the
# train and test preprocessing cannot drift apart.
# NOTE(review): this std tuple is the widely copied one; other sources quote
# roughly (0.247, 0.243, 0.261) as the per-pixel std of the training set.
# Confirm before changing: any change alters the inputs trained models expect.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

# Training pipeline: light augmentation (random 32x32 crop after 4-px padding,
# random horizontal flip), then tensor conversion and normalization.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Evaluation pipeline: no augmentation, identical normalization.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

In [38]:
def _get_targets(dataset):
    """Return the class label of every sample in ``dataset`` as a 1-D numpy array.

    Supports datasets exposing a ``targets`` attribute (e.g. torchvision's
    CIFAR10 / ImageFolder), ``torch.utils.data.Subset`` and
    ``torch.utils.data.ConcatDataset`` (recursively, in concatenation order, so
    positions line up with the combined dataset's indices).

    Raises:
        AttributeError: if the labels cannot be determined.
    """
    if hasattr(dataset, "targets"):
        return np.asarray(dataset.targets)
    if isinstance(dataset, torch.utils.data.Subset):
        parent_targets = _get_targets(dataset.dataset)
        return parent_targets[np.asarray(dataset.indices, dtype=int)]
    if isinstance(dataset, torch.utils.data.ConcatDataset):
        return np.concatenate([_get_targets(d) for d in dataset.datasets])
    raise AttributeError(
        f"cannot determine labels for {type(dataset).__name__}: "
        "expected a `targets` attribute, a Subset or a ConcatDataset"
    )


def get_subset_indices(dataset, percentage, seed=0):
    """Pick a class-balanced random subset of ``dataset``.

    Draws ``int(percentage * len(dataset) / num_classes)`` indices per class,
    without replacement, from a ``np.random.RandomState(seed)`` so the selection
    is reproducible. Classes are visited in ascending label order.

    Args:
        dataset: dataset with readable labels (``.targets``, or a
            ``Subset``/``ConcatDataset`` built from such datasets).
        percentage: fraction of the whole dataset to keep (e.g. 0.01 for 1%).
        seed: seed for the random generator.

    Returns:
        list[int]: indices into ``dataset``, grouped by class.

    Raises:
        AttributeError: if the labels of ``dataset`` cannot be determined.
        ValueError: if a class has fewer samples than the per-class quota
            (raised by ``RandomState.choice`` with ``replace=False``).
    """
    rng = np.random.RandomState(seed)
    targets = _get_targets(dataset)
    # Iterate over the labels actually present instead of assuming 0..K-1;
    # for contiguous 0..K-1 labels the RNG draws are identical to range(K).
    classes = np.unique(targets)
    if classes.size == 0:
        return []
    num_samples_per_class = int(percentage * len(targets) / classes.size)

    print("num_samples_per_class = ", num_samples_per_class)

    indices = []
    for c in classes:
        class_indices = np.flatnonzero(targets == c)
        chosen = rng.choice(class_indices, size=num_samples_per_class, replace=False)
        indices.extend(chosen.tolist())
    return indices

In [28]:
def dataloader_with_seed_perc_conc(conc_trainset, seed, perc, batch_size=256,
                                   shuffle=True, num_workers=2):
    """Build a DataLoader over a per-source, class-balanced subset of a concatenated dataset.

    Every component of ``conc_trainset.datasets`` is subsampled independently
    with ``get_subset_indices`` (same ``seed`` and ``perc``), and the resulting
    subsets are concatenated in their original order.

    Note: ``perc`` is applied to each source separately, so the mix between
    sources follows their relative sizes (a large source dominates a small one).

    Args:
        conc_trainset: object exposing a ``datasets`` sequence, e.g. a
            ``torch.utils.data.ConcatDataset``.
        seed: RNG seed, converted to ``int`` and forwarded to ``get_subset_indices``.
        perc: fraction of each source dataset to keep.
        batch_size: forwarded to ``DataLoader`` (default 256).
        shuffle: forwarded to ``DataLoader`` (default True).
        num_workers: forwarded to ``DataLoader`` (default 2).

    Returns:
        torch.utils.data.DataLoader over the concatenated subsets.

    Raises:
        ValueError: if ``conc_trainset.datasets`` is empty.
    """
    seed = int(seed)
    subsets = [
        Subset(ds, get_subset_indices(ds, perc, seed))
        for ds in conc_trainset.datasets
    ]
    if not subsets:
        raise ValueError("conc_trainset.datasets must contain at least one dataset")

    # A single source is used directly; several are joined in one flat
    # ConcatDataset (same sample order as chaining them pairwise).
    if len(subsets) == 1:
        subset = subsets[0]
    else:
        subset = torch.utils.data.ConcatDataset(subsets)

    return torch.utils.data.DataLoader(
        subset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)

In [29]:
def dataloader_with_seed_perc(trainset, seed, perc, batch_size=256,
                              shuffle=True, num_workers=2):
    """Build a DataLoader over a class-balanced random subset of ``trainset``.

    Args:
        trainset: dataset whose labels ``get_subset_indices`` can read.
        seed: RNG seed, converted to ``int`` and forwarded to ``get_subset_indices``.
        perc: fraction of ``trainset`` to keep.
        batch_size: forwarded to ``DataLoader`` (default 256).
        shuffle: forwarded to ``DataLoader`` (default True).
        num_workers: forwarded to ``DataLoader`` (default 2).

    Returns:
        torch.utils.data.DataLoader over ``Subset(trainset, indices)``.
    """
    indices = get_subset_indices(trainset, perc, int(seed))
    train_subset = Subset(trainset, indices)

    return torch.utils.data.DataLoader(
        train_subset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)

In [30]:
# Data locations. NOTE(review): absolute, user-specific paths; adjust (or read
# from a config/env var) before running on another machine.
# Generated images, presumably BigGAN+DiffAug samples of CIFAR-10 judging by the
# path, in ImageFolder layout (one sub-directory per class).
GID_data_path = "/home/hz271/PyTorch-StudioGAN/biggan_images/samples/CIFAR10-BigGAN-DiffAug-train-2022_02_11_07_23_15/fake"
# Shared root for the real CIFAR-10 train and test splits (downloaded if missing).
CIFAR10_data_path = '/home/hz271/OOD/Saved/data/'

# Datasets
# NOTE(review): ImageFolder assigns labels by sorted sub-directory name; this
# assumes that ordering matches CIFAR-10's label indices. Verify via
# generated_train_dataset.class_to_idx.
generated_train_dataset = torchvision.datasets.ImageFolder(
    root=GID_data_path,
    transform=transform_train,
)

original_train_dataset = torchvision.datasets.CIFAR10(
    root=CIFAR10_data_path, train=True, download=True, transform=transform_train)

# End-to-end concatenation: indices [0, len(generated_train_dataset)) are the
# generated images, the remaining ones are real CIFAR-10 training images.
train_set = torch.utils.data.ConcatDataset(
    [generated_train_dataset, original_train_dataset]
)

test_set = torchvision.datasets.CIFAR10(
    root=CIFAR10_data_path, train=False, download=True, transform=transform_test)

Files already downloaded and verified
Files already downloaded and verified


In [40]:
# Loader over 1% of each source in train_set (class-balanced per source, seed 0).
dt = dataloader_with_seed_perc_conc(train_set, seed=0, perc=0.01)

num_samples_per_class =  1000
num_samples_per_class =  50


In [None]:
# Sanity-check a single batch from `dt`: report shape, dtype and labels instead
# of dumping the full normalized image tensor into the output.
inputs, targets = next(iter(dt))
print("inputs:", tuple(inputs.shape), inputs.dtype)
print("targets:", targets)
print("label counts:", torch.bincount(targets))

In [None]:
# Loader over 0.1% of the combined (generated + real) training set, seed 0.
trainloader = dataloader_with_seed_perc(train_set, seed=0, perc=0.001)

In [None]:
# Sanity-check a single batch from `trainloader`: report shape, dtype and labels
# instead of dumping the full normalized image tensor into the output.
inputs, targets = next(iter(trainloader))
print("inputs:", tuple(inputs.shape), inputs.dtype)
print("targets:", targets)
print("label counts:", torch.bincount(targets))