In [0]:
from torchvision import datasets,transforms
import torch

In [0]:
def train_test_dataloaders(seed, batch_size, workers, train_transforms, test_transforms, root='./data'):
  """Build seeded CIFAR-10 train/test DataLoaders.

  Args:
    seed: passed to ``torch.manual_seed`` (and ``torch.cuda.manual_seed`` when
      CUDA is available) so shuffling and random augmentation are reproducible.
    batch_size: batch size used by both loaders.
    workers: ``num_workers`` for both loaders; only applied when CUDA is
      available (on CPU the loaders use DataLoader's default).
    train_transforms: transform applied to each training image.
    test_transforms: transform applied to each test image.
    root: directory where CIFAR-10 is stored / downloaded to.

  Returns:
    (trainloader, testloader): the train loader reshuffles every epoch; the
    test loader always iterates in dataset order.
  """
  # CUDA?
  cuda = torch.cuda.is_available()
  print("CUDA Available?", cuda)

  # For reproducibility
  torch.manual_seed(seed)
  if cuda:
    torch.cuda.manual_seed(seed)

  # Options shared by both loaders. Worker processes and pinned host memory are
  # only requested when a GPU is present, matching the original behaviour.
  loader_args = dict(batch_size=batch_size)
  if cuda:
    loader_args.update(num_workers=workers, pin_memory=True)

  trainset = datasets.CIFAR10(root=root, train=True,
                              download=True, transform=train_transforms)
  trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, **loader_args)

  # The test set is never shuffled, on CPU or GPU, so evaluation order is stable.
  testset = datasets.CIFAR10(root=root, train=False,
                             download=True, transform=test_transforms)
  testloader = torch.utils.data.DataLoader(testset, shuffle=False, **loader_args)

  return trainloader, testloader

In [0]:
def transformations(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
  """Return the (train, test) torchvision transform pipelines for CIFAR-10.

  Both pipelines end with ``ToTensor`` followed by ``Normalize(mean, std)``.
  The training pipeline first applies random augmentation: a 32x32 random crop
  taken after padding each border by 4 pixels, then a random horizontal flip.

  Args:
    mean: per-channel means passed to ``transforms.Normalize``.
    std: per-channel standard deviations passed to ``transforms.Normalize``.
      With the 0.5 defaults, inputs in [0, 1] are mapped to [-1, 1].

  Returns:
    (train_transforms, test_transforms): two ``transforms.Compose`` objects.
  """
  def tensor_and_normalize():
    # Fresh instances per pipeline so the two Compose objects share no objects.
    return [transforms.ToTensor(), transforms.Normalize(mean, std)]

  # Train phase: augmentation first, then tensor conversion + normalization.
  train_transforms = transforms.Compose([
      transforms.RandomCrop(32, padding=4),
      transforms.RandomHorizontalFlip(),
  ] + tensor_and_normalize())

  # Test phase: no augmentation, only tensor conversion + normalization.
  test_transforms = transforms.Compose(tensor_and_normalize())
  return train_transforms, test_transforms