The following code performs data augmentation on the multiclass (Tumor and Bacterial proteins) dataset. We had about 1,500 samples in total (including training, test and validation data). With the following code we stack the transformations and generate five randomly different images per sample, which gives about 7,500 augmented samples; the code then merges these with the original images before splitting into training, validation and test sets.

In [19]:
import os
import random
import torch
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, random_split
from torchvision.utils import save_image
from PIL import Image

# --- Configuration -----------------------------------------------------------
# Number of augmented variants generated for every source image.
# CHANGE THIS TO CONTROL HOW MANY VARIANTS PER IMAGE
num_augmentations_per_image = 5

# --- Reproducibility ---------------------------------------------------------
# Seed both Python's `random` module and torch's global RNG.
SEED = 42
torch.manual_seed(SEED)
random.seed(SEED)

# --- Input / output locations ------------------------------------------------
base_dir = "/Users/marcobenavides/Documents/Columbia University/Spring 2025/DL Biomedical Imaging/Project"
# Source images, one sub-folder per class.
data_dir = os.path.join(base_dir, "multiclass_dataset/multiclass")
# Destination for the augmented PNG files.
augmented_data_dir = os.path.join(base_dir, "data_augmented_multiclass_dataset/multiclass")
# Destination for the serialized train/val/test splits.
augmented_image_datasets_dir = os.path.join(base_dir, "data_augmented_multiclass_dataset/image_datasets")

for output_dir in (augmented_data_dir, augmented_image_datasets_dir):
    os.makedirs(output_dir, exist_ok=True)

# --- Per-class output folders ------------------------------------------------
# One sub-folder per class inside the augmented image directory.
class_folders = [
    "tumor_immunogenic",
    "tumor_non_immunogenic",
    "bacterial_immunogenic",
    "bacterial_non_immunogenic",
]
for folder_name in class_folders:
    class_output_dir = os.path.join(augmented_data_dir, folder_name)
    os.makedirs(class_output_dir, exist_ok=True)

# Per-channel statistics used by both pipelines. With mean = std = 0.5 the
# [0, 1] output of ToTensor is mapped onto [-1, 1].
_CHANNEL_MEAN = [0.5, 0.5, 0.5]
_CHANNEL_STD = [0.5, 0.5, 0.5]

# Random augmentation pipeline: geometric and colour perturbations are
# stacked in this order, then the image is converted to a normalised tensor.
_augmentation_steps = [
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(30),
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
    transforms.RandomAffine(30, translate=(0.1, 0.1), scale=(0.8, 1.2)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
]
augmentation_transforms = transforms.Compose(_augmentation_steps)

# Deterministic pipeline for non-augmented images: resize to 224x224, convert
# to a tensor and apply the same normalisation as above.
_original_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
]
original_transform = transforms.Compose(_original_steps)

# Load datasets
# Both datasets read the same folder tree. `original_dataset` applies
# `original_transform` to every image; `raw_dataset` applies no transform so
# the random augmentation pipeline can be run on the untransformed images.
original_dataset = ImageFolder(root=data_dir, transform=original_transform)
raw_dataset = ImageFolder(root=data_dir)

# Create an output folder for every class ImageFolder actually discovered, so
# saving cannot fail on a class that has no pre-created folder.
for class_name in raw_dataset.classes:
    os.makedirs(os.path.join(augmented_data_dir, class_name), exist_ok=True)

# Augment and save images
# `augmented_metadata` collects (file path, label) for every augmented variant
# that exists on disk after this loop, whether it was written now or earlier.
augmented_metadata = []
for idx, (image, label) in enumerate(raw_dataset):
    class_name = raw_dataset.classes[label]
    class_folder = os.path.join(augmented_data_dir, class_name)

    for i in range(num_augmentations_per_image):
        img_name = f"augmented_{idx}_{i}.png"
        img_path = os.path.join(class_folder, img_name)
        if os.path.exists(img_path):
            # Do not regenerate, but still record the file below so that a
            # re-run does not silently drop it from the dataset.
            print(f"Skipping existing file: {img_path}")
        else:
            aug_img = augmentation_transforms(image)
            # `augmentation_transforms` ends with Normalize(mean=0.5, std=0.5),
            # so the tensor lies in [-1, 1]. save_image clamps to [0, 1], which
            # would turn every negative value into black. Undo the
            # normalisation (x * std + mean) before writing the PNG.
            # NOTE: PNGs written without this step contain clipped pixels;
            # delete them if they should be regenerated.
            save_image(aug_img * 0.5 + 0.5, img_path)
            print(f"Saved: {img_path}")
        augmented_metadata.append((img_path, label))

# Reload augmented images
# Read every recorded augmented PNG back from disk and run it through
# `original_transform`, so augmented samples have the same tensor format as
# the samples produced by `original_dataset`.
augmented_dataset = []
for img_path, label in augmented_metadata:
    # The context manager closes the file handle as soon as the pixel data has
    # been converted, instead of leaving that to garbage collection.
    with Image.open(img_path) as img:
        tensor = original_transform(img.convert("RGB"))
    augmented_dataset.append((tensor, label))

# Merge datasets
# Materialise the original samples as (tensor, label) pairs and append the
# augmented ones, giving a single in-memory list.
full_dataset = list(original_dataset) + augmented_dataset

# Split
# 75 % train, 5 % validation, and whatever remains (about 20 %) for test.
# Sizes are truncated to integers; the test set absorbs the rounding remainder.
train_size = int(0.75 * len(full_dataset))
val_size = int(0.05 * len(full_dataset))
test_size = len(full_dataset) - train_size - val_size

# Use a dedicated seeded generator for the split. The global torch RNG has
# already been advanced by the random augmentations, and how far depends on
# how many files were skipped, so relying on it makes the split vary between
# runs even though SEED is fixed.
# NOTE(review): the split is drawn per sample after augmentation, so variants
# of one source image (and the source image itself) can land in different
# splits. If the test set must be independent of training data, split the
# originals first and augment only the training portion.
split_generator = torch.Generator().manual_seed(SEED)
train_set, val_set, test_set = random_split(
    full_dataset,
    [train_size, val_size, test_size],
    generator=split_generator,
)

# Save splits
# Each split is serialized with torch.save into the image_datasets folder.
torch.save(train_set, os.path.join(augmented_image_datasets_dir, "train_set.pth"))
torch.save(val_set, os.path.join(augmented_image_datasets_dir, "val_set.pth"))
torch.save(test_set, os.path.join(augmented_image_datasets_dir, "test_set.pth"))

# Print summary
print("\nDataset saved.")
print(f"Total samples (original + augmented): {len(full_dataset)}")
print(f"Train: {len(train_set)} | Val: {len(val_set)} | Test: {len(test_set)}")


Skipping existing file: /Users/marcobenavides/Documents/Columbia University/Spring 2025/DL Biomedical Imaging/Project/data_augmented_multiclass_dataset/multiclass/bacterial_immunogenic/augmented_0_0.png
Skipping existing file: /Users/marcobenavides/Documents/Columbia University/Spring 2025/DL Biomedical Imaging/Project/data_augmented_multiclass_dataset/multiclass/bacterial_immunogenic/augmented_0_1.png
Skipping existing file: /Users/marcobenavides/Documents/Columbia University/Spring 2025/DL Biomedical Imaging/Project/data_augmented_multiclass_dataset/multiclass/bacterial_immunogenic/augmented_0_2.png
Saved: /Users/marcobenavides/Documents/Columbia University/Spring 2025/DL Biomedical Imaging/Project/data_augmented_multiclass_dataset/multiclass/bacterial_immunogenic/augmented_0_3.png
Saved: /Users/marcobenavides/Documents/Columbia University/Spring 2025/DL Biomedical Imaging/Project/data_augmented_multiclass_dataset/multiclass/bacterial_immunogenic/augmented_0_4.png
Skipping existing f