In [1]:
# !pip3 install pyarrow pillow --upgrade --user
import pyarrow.parquet as pq
from datasets import Dataset
import pandas as pd
import os

In [2]:
# Load the Arrow data to a Dataset

current_directory = os.getcwd()


def load_split(split_name):
    """Load one split's Arrow shard from ``raw/<split_name>/`` under the working directory.

    Args:
        split_name: Sub-directory of ``raw`` holding the shard
            (this notebook uses "train", "validation" and "test").

    Returns:
        The ``datasets.Dataset`` read from ``data-00000-of-00001.arrow``.
    """
    # Imported here under an alias because a later cell rebinds the bare name
    # `Dataset` to torch.utils.data.Dataset; with the alias this cell can be
    # re-run at any point without hitting the wrong class.
    from datasets import Dataset as HFDataset

    # Separate path components let os.path.join choose the separator, so the
    # same code works on Windows and POSIX (the previous "raw\\train\\"
    # literal only resolved on Windows).
    arrow_path = os.path.join(
        current_directory, "raw", split_name, "data-00000-of-00001.arrow"
    )
    return HFDataset.from_file(arrow_path)


dataset_train = load_split("train")
dataset_validation = load_split("validation")
dataset_test = load_split("test")

dataset_train

Dataset({
    features: ['image_file_path', 'image', 'labels'],
    num_rows: 1034
})

In [5]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
import torch.nn as nn

# Reproducibility settings.
# Turn off the cuDNN autotuner and request deterministic cuDNN kernels so
# repeated runs pick the same algorithms.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
# Fix the global PyTorch RNG seed (this also drives DataLoader shuffling when
# no explicit generator is passed).
torch.manual_seed(0)

class CustomDataset(Dataset):
    """Map-style dataset that applies an image transform to each sample on access.

    Each element of ``dataset`` is indexed with the keys ``'image'`` and
    ``'labels'``; the image goes through ``self.transform`` and the label is
    returned untouched.

    Args:
        dataset: Indexable collection supporting ``len()`` whose items expose
            ``'image'`` and ``'labels'`` keys.
        image_size: Passed straight to ``transforms.Resize`` when the default
            pipeline is built. Defaults to ``(500, 500)``.
        transform: Optional callable applied to ``sample['image']``. When
            given it replaces the default pipeline entirely and
            ``image_size`` is ignored (useful for split-specific
            preprocessing).
    """

    def __init__(self, dataset, image_size=(500, 500), transform=None):
        self.data = dataset
        if transform is None:
            # Default pipeline: resize -> tensor -> per-channel normalisation
            # with the mean/std conventionally used for ImageNet-pretrained
            # torchvision models.
            # NOTE(review): Normalize is given 3 channel statistics, so this
            # assumes every image is 3-channel RGB -- confirm for the data.
            transform = transforms.Compose([
                transforms.Resize(image_size),  # Resize to our desired size
                transforms.ToTensor(),          # Convert PIL Image to PyTorch tensor
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # Normalize RGB channels
            ])
        self.transform = transform

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(transformed_image, label)`` for the sample at ``idx``."""
        sample = self.data[idx]
        image = self.transform(sample['image'])
        label = sample['labels']

        return image, label
    

# Wrap every raw split in the transform-applying dataset class
custom_train, custom_validation, custom_test = (
    CustomDataset(split)
    for split in (dataset_train, dataset_validation, dataset_test)
)

# Batch each split; only the training loader shuffles, the validation and
# test loaders keep the stored order
train_loader = DataLoader(dataset=custom_train, shuffle=True, batch_size=32)
validation_loader = DataLoader(dataset=custom_validation, shuffle=False, batch_size=32)
test_loader = DataLoader(dataset=custom_test, shuffle=False, batch_size=32)

In [6]:
# Create the output directory; exist_ok=True makes this a no-op when it is
# already there and avoids the check-then-create race of a separate
# os.path.exists() test.
os.makedirs('dataloaders', exist_ok=True)

# Save the DataLoader to a file
# NOTE(review): torch.save pickles each DataLoader object as a whole,
# including the dataset it wraps. Loading the files back therefore needs the
# CustomDataset class to be defined in the loading process, and recent
# PyTorch releases need torch.load(..., weights_only=False) for arbitrary
# pickled objects. Only load such files from a trusted source; consider
# rebuilding the loaders from the raw data instead of persisting them.
torch.save(train_loader, 'dataloaders/train_loader.pt')
torch.save(validation_loader, 'dataloaders/validation_loader.pt')
torch.save(test_loader, 'dataloaders/test_loader.pt')
