In [None]:
import numpy as np 
import torch

import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.transforms import ToTensor
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision.datasets import CIFAR10

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# Seed PyTorch's and NumPy's global RNGs so runs are repeatable.
# torch.manual_seed returns the (now seeded) default torch.Generator; that
# object is bound to `random_seed` and is used below as data_loader's
# default `random_seed` argument.
random_seed = torch.manual_seed(0)
np.random.seed(0)

def data_loader(data_dir, batch_size, random_seed=random_seed, valid_size=0.1, shuffle=True, test=False):
    """Build CIFAR-10 data loaders with images resized to 224x224 and normalized.

    Args:
        data_dir: Root directory for the CIFAR-10 files; the dataset is
            downloaded there if it is missing.
        batch_size: Number of images per batch.
        random_seed: Seed for the train/validation split. Either an int, a
            ``torch.Generator`` (such as the object returned by
            ``torch.manual_seed``), whose ``initial_seed()`` is used, or
            ``None`` for a non-reproducible split.
        valid_size: Fraction of the training split held out for validation;
            must lie in [0, 1].
        shuffle: Whether to shuffle the training indices before splitting.
            If False, the first ``valid_size`` fraction becomes validation.
        test: If True, return a single unshuffled loader over the CIFAR-10
            test split instead of the train/validation pair.

    Returns:
        A ``DataLoader`` over the test split when ``test`` is True, otherwise a
        ``(train_loader, valid_loader)`` tuple whose samplers draw disjoint
        subsets of the training split.

    Raises:
        ValueError: If ``valid_size`` is outside [0, 1].
    """
    normalize = transforms.Normalize(
        mean=[0.4914, 0.4822, 0.4465],
        # NOTE(review): these std values are widely copied for CIFAR-10 but
        # reportedly differ from the per-channel pixel std (~0.247, 0.243,
        # 0.261). Kept unchanged for compatibility — confirm before changing.
        std=[0.2023, 0.1994, 0.2010],
    )

    # Resize to 224x224, presumably to match the input size of the model
    # trained elsewhere in this notebook — TODO confirm.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ])

    if test:
        # The held-out CIFAR-10 test split (train=False), in a fixed order.
        test_dataset = CIFAR10(root=data_dir, train=False, download=True, transform=transform)
        return DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    if not 0.0 <= valid_size <= 1.0:
        raise ValueError(f"valid_size must be in [0, 1], got {valid_size!r}")

    # Train and validation use the same transform, so one dataset object
    # serves both; the samplers below keep their index sets disjoint.
    train_dataset = CIFAR10(root=data_dir, train=True, download=True, transform=transform)

    num_train = len(train_dataset)
    split = int(np.floor(valid_size * num_train))

    if isinstance(random_seed, torch.Generator):
        # Use the generator's seed instead of drawing from it, so the split
        # neither consumes nor depends on the global torch RNG state.
        seed = random_seed.initial_seed()
    else:
        seed = random_seed

    if shuffle:
        indices = np.random.default_rng(seed).permutation(num_train).tolist()
    else:
        indices = list(range(num_train))

    valid_idx, train_idx = indices[:split], indices[split:]

    # SubsetRandomSampler re-shuffles its subset every epoch. DataLoader does
    # not allow `shuffle=True` together with a sampler, so shuffle is omitted.
    train_dataloader = DataLoader(
        train_dataset, batch_size=batch_size, sampler=SubsetRandomSampler(train_idx)
    )
    valid_dataloader = DataLoader(
        train_dataset, batch_size=batch_size, sampler=SubsetRandomSampler(valid_idx)
    )

    return (train_dataloader, valid_dataloader)