# In [2]:
import os
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# Dataset root: the "data" directory directly under the current working
# directory. Built with os.path.join so the separator is right on every OS.
data_folder = os.path.join(os.getcwd(), "data")

def preprocess_data(data_folder, image_size=(256, 256), batch_size=16):
    """Build one ``DataLoader`` per expected image subfolder of *data_folder*.

    Each image is resized to *image_size*, converted to a tensor, and
    normalized per channel with mean 0.5 and std 0.5, i.e. ``(x - 0.5) / 0.5``.

    Args:
        data_folder: Directory that contains the ``vangogh_train``,
            ``vangogh_test``, ``landscape_train`` and ``landscape_test``
            subdirectories.
        image_size: Target size passed to ``transforms.Resize``.
        batch_size: Number of images per batch.

    Returns:
        A dict mapping each subfolder name to its ``DataLoader``. Loaders
        whose subfolder name contains ``"train"`` shuffle; the others do not.

    Raises:
        FileNotFoundError: If an expected subfolder is missing or is not a
            directory.
    """
    subfolders = ('vangogh_train', 'vangogh_test', 'landscape_train', 'landscape_test')

    # Validate the whole layout before building anything, so a missing folder
    # is reported immediately rather than after earlier datasets were scanned.
    # isdir (not exists) also rejects a regular file sitting at that path.
    subfolder_paths = {}
    for folder in subfolders:
        subfolder_path = os.path.join(data_folder, folder)
        if not os.path.isdir(subfolder_path):
            raise FileNotFoundError(f"Directory not found: {subfolder_path}")
        subfolder_paths[folder] = subfolder_path

    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])

    dataloaders = {}
    for folder, subfolder_path in subfolder_paths.items():
        dataset = datasets.ImageFolder(root=subfolder_path, transform=transform)
        # Only the training splits are shuffled.
        dataloaders[folder] = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=('train' in folder),
        )

    return dataloaders


# testing
def test_preprocessing(dataloaders, image_size=(256, 256)):
    """Sanity-check the first batch of every dataloader and preview images.

    For each ``(name, dataloader)`` pair, the first batch is fetched and
    checked for spatial size and value range. Up to four of its images are
    then un-normalized and shown with matplotlib.

    Args:
        dataloaders: Mapping of subset name to an iterable that yields
            ``(images, labels)`` batches. Indexing with ``shape[2:]`` and
            ``permute(1, 2, 0)`` assumes ``images`` is laid out as
            (batch, channel, height, width).
        image_size: Expected (height, width) of every image. A list is
            accepted as well as a tuple.

    Raises:
        AssertionError: If the batch's spatial dimensions differ from
            *image_size* or any pixel value lies outside ``[-1, 1]``.
    """
    print("Testing preprocessing...")

    # A list never compares equal to the tuple-like shape, so convert it first.
    expected_size = tuple(image_size) if isinstance(image_size, list) else image_size

    for subset, dataloader in dataloaders.items():
        print(f"Testing subset: {subset}")

        # Only the first batch is inspected; a single batch is enough for a
        # smoke test and avoids loading the whole dataset.
        first_batch = next(iter(dataloader), None)
        if first_batch is None:
            print("  - No batches to check")
            continue

        images, _ = first_batch
        print("Batch 1:")

        # Check image dimensions. Raised explicitly rather than via `assert`
        # so the check still runs under `python -O`.
        if tuple(images.shape[2:]) != expected_size:
            raise AssertionError(f"Images are not resized to {image_size}")
        print(f"  - Batch size: {images.shape[0]}")
        print(f"  - Image dimensions: {images.shape[2:]}")

        # Check normalization; compute each reduction once and reuse it.
        lowest = images.min().item()
        highest = images.max().item()
        if not (lowest >= -1 and highest <= 1):
            raise AssertionError("Images are not normalized to [-1, 1]")
        print(f"  - Pixel range: [{lowest:.2f}, {highest:.2f}]")

        # Display up to four images from the batch.
        for i in range(min(4, images.size(0))):
            # Undo the mean=0.5 / std=0.5 normalization to get back to [0, 1].
            image = (images[i] * 0.5 + 0.5).clamp(0, 1)
            # Move the channel axis last, which is the layout imshow expects.
            plt.imshow(image.permute(1, 2, 0).numpy())
            plt.title(f"Subset: {subset}, Image: {i + 1}")
            plt.axis('off')
            plt.show()


# The name starts with "test_" but the function needs an argument; tell pytest
# not to collect it as a test case.
test_preprocessing.__test__ = False

# Disabled usage example, kept as a bare string literal so it never executes:
# it builds the dataloaders from `data_folder` and runs the sanity checks.
'''
The following is used to test the images

dataloaders = preprocess_data(data_folder, image_size=(256, 256), batch_size=16)
test_preprocessing(dataloaders, image_size=(256, 256))
'''

'\nThe following is used to test the images\n\ndataloaders = preprocess_data(data_folder, image_size=(256, 256), batch_size=16)\ntest_preprocessing(dataloaders, image_size=(256, 256))\n'