In [1]:
import os
import shutil
from torchvision import datasets, transforms
from torch.utils.data import random_split
from torchvision.datasets import ImageFolder
from torchvision.utils import save_image
from tqdm import tqdm
import torch



In [2]:
import os
from torchvision import transforms

# Input image tree and the root under which the split copies are written
source_dir = 'Chest_X-Ray_Image'
target_dir = 'Dataset-1'

# Fraction of the full dataset assigned to each split
train_ratio, val_ratio, test_ratio = 0.7, 0.15, 0.15

# Loader settings; a batch of one so every image is saved individually
batch_size, num_workers = 1, 2


def _plain_pipeline():
    """Return a fresh resize-then-tensor pipeline with no random augmentation."""
    return transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])


# Training pipeline: resize, random horizontal flip, random rotation, tensor
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
])

# Validation and test share the same steps but get separate Compose objects
val_transform = _plain_pipeline()
test_transform = _plain_pipeline()


In [3]:
# Index the source tree. This instance backs the training split, so it carries
# the augmenting transform.
full_dataset = ImageFolder(root=source_dir, transform=train_transform)
class_names = full_dataset.classes

# Calculate sizes; the test split absorbs whatever the int() truncation leaves.
total_size = len(full_dataset)
train_size = int(train_ratio * total_size)
val_size = int(val_ratio * total_size)
test_size = total_size - train_size - val_size  # Remaining samples

# Seed the split so re-running the notebook puts the same images in the same
# splits.
split_seed = 42
train_dataset, _val_split, _test_split = random_split(
    full_dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(split_seed),
)

# Every Subset returned by random_split wraps the *same* ImageFolder object, so
# assigning `<subset>.dataset.transform` once per split just overwrites a single
# shared attribute and the last assignment applies to all three splits.
# Instead, give val and test their own ImageFolder (each with its own
# transform) and reuse the indices chosen above.
# NOTE(review): this relies on ImageFolder indexing the same root in the same
# order every time, so the indices refer to the same files — confirm against
# the torchvision version in use.
val_dataset = torch.utils.data.Subset(
    ImageFolder(root=source_dir, transform=val_transform),
    _val_split.indices,
)
test_dataset = torch.utils.data.Subset(
    ImageFolder(root=source_dir, transform=test_transform),
    _test_split.indices,
)

In [4]:
def save_images(dataset, split_name):
    """Write each sample of ``dataset`` to disk as a PNG file.

    Files are placed in ``<target_dir>/<split_name>/<class name>/`` and named
    ``<split_name>_<i>.png``, where ``i`` is the position of the batch in the
    loader. Reads the notebook-level ``target_dir``, ``batch_size``,
    ``num_workers`` and ``class_names``.

    Args:
        dataset: Dataset yielding ``(image, label)`` pairs.
        split_name: Name of the split; used for the folder, the file prefix
            and the progress-bar caption.
    """
    batches = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )
    progress = tqdm(batches, desc=f"Saving {split_name} data")

    for index, (image, target) in enumerate(progress):
        # `.item()` needs a single-element label tensor, i.e. one image per batch.
        out_dir = os.path.join(target_dir, split_name, class_names[target.item()])
        os.makedirs(out_dir, exist_ok=True)
        save_image(image, os.path.join(out_dir, f"{split_name}_{index}.png"))

In [5]:
# Export the three splits in order: train, val, test.
for _split_name, _split_data in (
    ("train", train_dataset),
    ("val", val_dataset),
    ("test", test_dataset),
):
    save_images(_split_data, _split_name)

Saving train data: 100%|██████████| 3045/3045 [03:07<00:00, 16.25it/s]
Saving val data: 100%|██████████| 652/652 [00:50<00:00, 12.79it/s]
Saving test data: 100%|██████████| 653/653 [00:53<00:00, 12.30it/s]


In [19]:
import os
from torchvision import transforms

# Input tree (already divided into train/ and test/) and the output root
source_dir, target_dir = 'Data', 'Dataset-2'

# One image per batch so each can be written to its own file
batch_size = 1
num_workers = 2

# Training pipeline: resize, random horizontal flip, random rotation, tensor
train_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
    ]
)

# Validation and test pipelines: two separate but identical resize-and-tensor
# Compose objects, with no random augmentation
val_transform, test_transform = (
    transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
    for _ in range(2)
)


In [20]:
# Load datasets. `full_train_dataset` backs the training split, so it carries
# the augmenting transform; the test folder is loaded with its own transform.
train_root = os.path.join(source_dir, 'train')
full_train_dataset = ImageFolder(root=train_root, transform=train_transform)
test_dataset = ImageFolder(root=os.path.join(source_dir, 'test'), transform=test_transform)

# Refresh the class list for this dataset; otherwise `class_names` would still
# hold whatever an earlier cell assigned for a different source directory.
class_names = full_train_dataset.classes

# Split sizes
train_ratio = 0.85  # 85% for training, 15% for validation
train_size = int(train_ratio * len(full_train_dataset))
val_size = len(full_train_dataset) - train_size

# Seed the split so re-running the notebook puts the same images in the same
# splits.
split_seed = 42
train_dataset, _val_split = random_split(
    full_train_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(split_seed),
)

# Both Subsets returned by random_split wrap the *same* ImageFolder, so setting
# `<subset>.dataset.transform` per split overwrites one shared attribute and the
# training split would end up with the validation transform. Give validation
# its own ImageFolder and reuse the indices chosen above.
# NOTE(review): this relies on ImageFolder indexing the same root in the same
# order every time, so the indices refer to the same files — confirm against
# the torchvision version in use.
val_dataset = torch.utils.data.Subset(
    ImageFolder(root=train_root, transform=val_transform),
    _val_split.indices,
)

In [21]:
def save_images(dataset, split_name):
    """Write each sample of ``dataset`` to disk as a PNG file.

    Files are placed in ``<target_dir>/<split_name>/<class name>/`` and named
    ``<split_name>_<i>.png``. Class-folder names come from the dataset being
    saved (its ``classes`` attribute, looked up through any ``.dataset``
    wrappers), so they cannot go stale when the notebook switches to a
    different source directory. The notebook-level ``class_names`` is used only
    when the dataset exposes no ``classes``.

    Reads the notebook-level ``target_dir``, ``batch_size`` and ``num_workers``.
    ``batch_size`` must be 1: the label is read with ``.item()``, which needs a
    single-element tensor.

    Args:
        dataset: Dataset yielding ``(image, label)`` pairs, e.g. an
            ``ImageFolder`` or a ``Subset`` of one.
        split_name: Name of the split; used for the folder, the file prefix
            and the progress-bar caption.
    """
    # Walk through wrapper datasets until one that owns a class list is found.
    base = dataset
    while not hasattr(base, "classes") and hasattr(base, "dataset"):
        base = base.dataset
    names = base.classes if hasattr(base, "classes") else class_names

    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    # Remember which class folders already exist so makedirs is not called for
    # every single image.
    ready_dirs = set()
    for i, (img, label) in enumerate(tqdm(loader, desc=f"Saving {split_name} data")):
        class_name = names[label.item()]
        save_dir = os.path.join(target_dir, split_name, class_name)
        if save_dir not in ready_dirs:
            os.makedirs(save_dir, exist_ok=True)
            ready_dirs.add(save_dir)
        save_path = os.path.join(save_dir, f"{split_name}_{i}.png")
        save_image(img, save_path)

In [22]:
# Export the three splits in order: train, val, test.
_splits_to_save = {"train": train_dataset, "val": val_dataset, "test": test_dataset}
for _split_name in _splits_to_save:
    save_images(_splits_to_save[_split_name], _split_name)

Saving train data: 100%|██████████| 4372/4372 [04:15<00:00, 17.09it/s]
Saving val data: 100%|██████████| 772/772 [00:51<00:00, 15.07it/s]
Saving test data: 100%|██████████| 1288/1288 [00:59<00:00, 21.72it/s]
