In [7]:
import torch
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

from data.c_cifar import ContrastiveCIFAR10
from data.c_mnist import ContrastiveMNIST
from data.utils import save_data

## CIFAR-10 Data

In [2]:
# Preprocessing applied to every split: convert to a tensor, then apply
# (x - 0.5) / 0.5 to each of the three channels.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5],
                         std=[0.5, 0.5, 0.5])
])

full_train_dataset = datasets.CIFAR10(root='./cifar-data', train=True, download=True, transform=transform)

# Hold out 10% of the official training set for validation.
train_size = int(0.9 * len(full_train_dataset))
val_size = len(full_train_dataset) - train_size

# Seed the split explicitly: without a generator, random_split draws from the
# global RNG, so every kernel restart would yield a different train/val partition.
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(
    full_train_dataset, [train_size, val_size], generator=split_generator
)

test_dataset = datasets.CIFAR10(root='./cifar-data', train=False, download=True, transform=transform)

batch_size = 128

# Only the training loader shuffles; validation and test keep dataset order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

In [3]:
# Bundle the three splits under fixed keys and pass them to the
# project-local save_data helper together with the target name.
loaders = {
    'train': train_loader,
    'val': val_loader,
    'test': test_loader,
}
save_data(loaders, 'cifar_data_loaders.pkl')

## Contrastive CIFAR-10

In [8]:
# Per-channel statistics shared by both pipelines: Normalize computes
# (x - mean) / std, i.e. (x - 0.5) / 0.5 for each of the three channels.
norm_mean = [0.5, 0.5, 0.5]
norm_std = [0.5, 0.5, 0.5]

# Plain preprocessing with no augmentation.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=norm_mean, std=norm_std),
])

# Augmentation pipeline: random crop/flip/colour perturbations first,
# then the same tensor conversion and normalisation as above.
augmentations = [
    transforms.RandomResizedCrop(size=32, scale=(0.2, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
    transforms.RandomGrayscale(p=0.2),
]
cifar_transform = transforms.Compose(augmentations + [
    transforms.ToTensor(),
    transforms.Normalize(mean=norm_mean, std=norm_std),
])

In [9]:
# NOTE: this cell reads `transform` and `cifar_transform` from the preceding
# cell. `transform` is rebound by the MNIST cells further down, so re-run the
# preceding cell first if those have been executed in this kernel.

# No transform on the raw training set: the augmentation pipeline is handed
# to ContrastiveCIFAR10 below instead.
full_train_dataset = datasets.CIFAR10(root='./cifar-data', train=True, download=True)

# Hold out 10% of the official training set for validation.
train_size = int(0.9 * len(full_train_dataset))
val_size = len(full_train_dataset) - train_size

# Seed the split explicitly: without a generator, random_split draws from the
# global RNG, so every kernel restart would yield a different train/val partition.
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(
    full_train_dataset, [train_size, val_size], generator=split_generator
)

# The test set uses the plain (non-augmenting) `transform`.
test_dataset = datasets.CIFAR10(root='./cifar-data', train=False, download=True, transform=transform)

# Both the train and validation subsets are wrapped with the same augmentation pipeline.
c_cifar_train = ContrastiveCIFAR10(train_dataset, cifar_transform)
c_cifar_val = ContrastiveCIFAR10(val_dataset, cifar_transform)

batch_size = 128

# Only the training loader shuffles; validation and test keep dataset order.
train_loader = DataLoader(c_cifar_train, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(c_cifar_val, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

In [10]:
# Bundle the three splits under fixed keys and pass them to the
# project-local save_data helper together with the target name.
loaders = {
    'train': train_loader,
    'val': val_loader,
    'test': test_loader,
}
save_data(loaders, 'contrastive_cifar_data_loaders.pkl')

## MNIST

In [3]:
# MNIST only gets tensor conversion; no normalisation is applied here.
transform = transforms.ToTensor()

full_train_dataset = datasets.MNIST(root='./mnist-data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./mnist-data', train=False, download=True, transform=transform)

# Hold out 10% of the official training set for validation.
train_size = int(0.9 * len(full_train_dataset))
val_size = len(full_train_dataset) - train_size

# Seed the split explicitly: without a generator, random_split draws from the
# global RNG, so every kernel restart would yield a different train/val partition.
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(
    full_train_dataset, [train_size, val_size], generator=split_generator
)

# Only the training loader shuffles; validation and test keep dataset order.
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

In [4]:
# Bundle the three splits under fixed keys and pass them to the
# project-local save_data helper together with the target name.
loaders = {
    'train': train_loader,
    'val': val_loader,
    'test': test_loader,
}
save_data(loaders, 'mnist_data_loaders.pkl')

## Contrastive MNIST

In [10]:
# Augmentation pipeline handed to ContrastiveMNIST: a mild random crop back to
# 28x28, a rotation of up to 15 degrees, a translation of up to 10% per axis,
# then tensor conversion.
transform = transforms.Compose([
    transforms.RandomResizedCrop(28, scale=(0.8, 1.0)),
    transforms.RandomRotation(degrees=15),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    transforms.ToTensor()
])

# No transform on the raw training set: the augmentation pipeline is passed to
# ContrastiveMNIST below. The test set only gets tensor conversion.
full_train_dataset = datasets.MNIST(root='./mnist-data', train=True, download=True)
test_dataset = datasets.MNIST(root='./mnist-data', train=False, download=True, transform=transforms.ToTensor())

# Hold out 10% of the official training set for validation.
train_size = int(0.9 * len(full_train_dataset))
val_size = len(full_train_dataset) - train_size

# Seed the split explicitly: without a generator, random_split draws from the
# global RNG, so every kernel restart would yield a different train/val partition.
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(
    full_train_dataset, [train_size, val_size], generator=split_generator
)

# Both the train and validation subsets are wrapped with the same augmentation pipeline.
c_mnist_train = ContrastiveMNIST(train_dataset, transform)
c_mnist_val = ContrastiveMNIST(val_dataset, transform)

# Only the training loader shuffles; validation and test keep dataset order.
train_loader = DataLoader(c_mnist_train, batch_size=128, shuffle=True)
val_loader = DataLoader(c_mnist_val, batch_size=128, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

In [11]:
# Bundle the three splits under fixed keys and pass them to the
# project-local save_data helper together with the target name.
loaders = {
    'train': train_loader,
    'val': val_loader,
    'test': test_loader,
}
save_data(loaders, 'contrastive_mnist_data_loaders.pkl')