In [2]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
import random


In [4]:
class CustomImageDataset(Dataset):
    """Fixed-length dataset that cycles over the files of a single directory.

    The dataset reports ``target_size`` items regardless of how many files
    exist on disk; index ``idx`` is mapped onto the file list with a modulo,
    so the source images are reused as often as needed to reach that length.

    Args:
        image_dir: Directory whose regular files are used as source images
            (sub-directories are ignored; no extension filtering is done).
        transform: Optional callable applied to each loaded RGB PIL image.
        target_size: Number of items the dataset reports via ``len()``.

    Raises:
        ValueError: If ``image_dir`` contains no regular files.
    """

    def __init__(self, image_dir, transform=None, target_size=10000):
        self.image_dir = image_dir
        # Sorted so the index -> file mapping does not depend on the
        # arbitrary order in which os.listdir returns directory entries.
        self.image_files = sorted(
            f for f in os.listdir(image_dir)
            if os.path.isfile(os.path.join(image_dir, f))
        )
        # Fail early with a clear message instead of a ZeroDivisionError
        # from the modulo in __getitem__.
        if not self.image_files:
            raise ValueError(f"No files found in image directory: {image_dir!r}")
        self.transform = transform
        self.target_size = target_size

    def __len__(self):
        """Return the configured (oversampled) dataset length."""
        return self.target_size

    def __getitem__(self, idx):
        """Load the image for ``idx`` as RGB and apply ``transform`` if set."""
        # Wrap around the file list so any idx < target_size is valid.
        img_idx = idx % len(self.image_files)
        img_path = os.path.join(self.image_dir, self.image_files[img_idx])

        # The context manager closes the underlying file promptly;
        # convert() returns a fully loaded image that outlives it.
        with Image.open(img_path) as source:
            image = source.convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image


In [5]:
# Augmentation pipeline applied to every image drawn from the dataset.
# Every step except the first resize and the final ToTensor is randomised,
# so repeated reads of the same source file give different outputs.
transform = transforms.Compose(
    [
        # Bring every input to a common 128x128 size first.
        transforms.Resize(size=(128, 128)),
        # Mirror left/right with the default probability.
        transforms.RandomHorizontalFlip(p=0.5),
        # Rotate by a random angle (degrees=30).
        transforms.RandomRotation(degrees=30),
        # Randomly perturb brightness, contrast, saturation and hue.
        transforms.ColorJitter(
            brightness=0.5,
            contrast=0.5,
            saturation=0.5,
            hue=0.5,
        ),
        # Crop a random region (scale 0.8-1.0) and resize it back to 128.
        transforms.RandomResizedCrop(size=128, scale=(0.8, 1.0)),
        # Convert the PIL image to a tensor for the DataLoader.
        transforms.ToTensor(),
    ]
)


In [11]:
# Tunable settings for this oversampling run.
SEED = 42            # fixed seed so a re-run reproduces the same images
TARGET_SIZE = 10000  # number of augmented images to generate
BATCH_SIZE = 4

# Seed torch's global RNG, which drives DataLoader shuffling and
# (in torchvision >= 0.8) the random transforms.
torch.manual_seed(SEED)

# Build paths with os.path.join rather than hard-coded backslashes so the
# cell also works on non-Windows systems (on Windows the result is unchanged).
image_path = os.path.join('data', 'train-image', 'image', 'cancer')
dataset = CustomImageDataset(image_dir=image_path, transform=transform, target_size=TARGET_SIZE)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

save_dir = os.path.join('data', 'train-image', 'image', 'oversampledCancer')
os.makedirs(save_dir, exist_ok=True)

# Create the tensor -> PIL converter once instead of once per image.
to_pil = transforms.ToPILImage()
for i, batch in enumerate(dataloader):
    for j, image in enumerate(batch):
        img = to_pil(image)
        # Derive the index from the real batch size so file numbers are
        # contiguous (0 .. TARGET_SIZE - 1) and stay in sync with BATCH_SIZE.
        img.save(os.path.join(save_dir, f'image_{i * BATCH_SIZE + j}.jpg'))

print(f'Oversampled images saved to {save_dir}')



Oversampled images saved to data\train-image\image\oversampledCancer
