In [9]:
! pip install torch



In [4]:
import os
from torchvision import datasets, transforms
from torch.utils.data import ConcatDataset, DataLoader
from torchvision.utils import save_image

class ImageDataHandler:
    """Load train/val image folders, wrap them in DataLoaders, and export transformed images.

    Each split is read with ``torchvision.datasets.ImageFolder`` once per
    transform pipeline listed in :meth:`init_transforms`, and those datasets
    are concatenated. With N pipelines, every image therefore appears N times
    (once per pipeline).

    Args:
        image_height: Target height passed to ``transforms.Resize``.
        image_width: Target width passed to ``transforms.Resize``.
        train_data_path: Root of the training split, in the directory layout
            ``ImageFolder`` expects (one sub-directory per class).
        val_data_path: Root of the validation split, same layout.
        train_save_dir: Output directory intended for the exported training
            images; created if missing.
        val_save_dir: Output directory intended for the exported validation
            images; created if missing.
        batch_size: Batch size used by both DataLoaders. Defaults to 32, the
            value that used to be hard-coded.
    """

    def __init__(self, image_height, image_width, train_data_path, val_data_path,
                 train_save_dir, val_save_dir, batch_size=32):
        self.image_height = image_height
        self.image_width = image_width
        self.train_data_path = train_data_path
        self.val_data_path = val_data_path
        self.train_save_dir = train_save_dir
        self.val_save_dir = val_save_dir
        self.batch_size = batch_size

        # Create directories for saving images up front so the default
        # export targets always exist.
        os.makedirs(train_save_dir, exist_ok=True)
        os.makedirs(val_save_dir, exist_ok=True)

        # Initialize transformations (sets self.train_transforms / self.val_transforms).
        self.init_transforms()

        # Load datasets: one ImageFolder per transform pipeline, concatenated.
        self.train_dataset = self.create_concat_dataset(train_data_path, self.train_transforms)
        self.val_dataset = self.create_concat_dataset(val_data_path, self.val_transforms)

        # Create DataLoaders; only the training loader is shuffled so that
        # validation order stays stable between runs.
        self.train_loader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True)
        self.val_loader = DataLoader(self.val_dataset, batch_size=batch_size, shuffle=False)

    def init_transforms(self):
        """Define the transform pipelines for each split.

        Sets ``self.train_transforms`` and ``self.val_transforms`` to lists of
        ``transforms.Compose`` pipelines. Every entry in a list becomes a full
        extra copy of that split (see :meth:`create_concat_dataset`). Append
        augmentation pipelines to ``train_transforms`` to enlarge the
        training set.
        """
        self.train_transforms = [
            transforms.Compose([transforms.Resize((self.image_height, self.image_width)), transforms.ToTensor()]),
            # Add other transformations as needed
        ]

        self.val_transforms = [
            transforms.Compose([transforms.Resize((self.image_height, self.image_width)), transforms.ToTensor()]),
            # Add other transformations as needed
        ]

    def create_concat_dataset(self, data_path, transforms_list):
        """Build one ``ImageFolder`` over *data_path* per transform and concatenate them.

        Args:
            data_path: Root directory of the split.
            transforms_list: Transform pipelines. Each one is applied to its
                own copy of the whole folder.

        Returns:
            A ``ConcatDataset`` whose length is
            ``len(transforms_list) * <number of images>``.
        """
        dataset_list = [datasets.ImageFolder(root=data_path, transform=transform) for transform in transforms_list]
        return ConcatDataset(dataset_list)

    def save_dataset(self, dataset, save_dir, prefix='image', ext='jpg'):
        """Write every ``(image, label)`` sample of *dataset* to *save_dir*.

        Files are named ``{prefix}_{index}_label_{label}.{ext}``. Here *index* is
        the sample's position in *dataset*, so names are unique within one call.
        For ``ImageFolder``-based datasets, *label* is the class index.

        Args:
            dataset: Iterable of ``(image_tensor, label)`` pairs, e.g.
                ``self.train_dataset`` or ``self.val_dataset``.
            save_dir: Output directory. It is created (including parents) if
                it does not exist yet, so any directory can be used, not only
                the two created in ``__init__``.
            prefix: Filename prefix.
            ext: File extension; a leading dot is ignored. ``save_image``
                chooses the encoder from it. The default ``'jpg'`` is lossy;
                pass ``'png'`` to keep the transformed pixels exact.
        """
        os.makedirs(save_dir, exist_ok=True)
        ext = ext.lstrip('.')
        for i, (image, label) in enumerate(dataset):
            image_path = os.path.join(save_dir, f"{prefix}_{i}_label_{label}.{ext}")
            save_image(image, image_path)

# Usage
# Every path hangs off one data root. Set the DATA_ROOT environment variable
# to run this notebook against data on another machine without editing the
# cell; the default is the original location.
image_height, image_width = 224, 224
data_root = os.environ.get('DATA_ROOT', '/Users/shuai/Desktop/data')
train_data_path = os.path.join(data_root, 'train')
val_data_path = os.path.join(data_root, 'val')
train_save_dir = os.path.join(data_root, 'train_tf')
val_save_dir = os.path.join(data_root, 'val_tf')

data_handler = ImageDataHandler(image_height, image_width, train_data_path, val_data_path, train_save_dir, val_save_dir)

# The loaders are ready for training, e.g.:
#     for images, labels in data_handler.train_loader:
#         ...  # training step

# Save transformed images (every sample of each split, once per transform pipeline).
data_handler.save_dataset(data_handler.train_dataset, data_handler.train_save_dir, prefix='train')
data_handler.save_dataset(data_handler.val_dataset, data_handler.val_save_dir, prefix='val')
