In [8]:
import os
import random
import shutil
from PIL import Image, ImageEnhance
import numpy as np
import matplotlib.pyplot as plt

In [9]:
BASE_DIR = "data/CT Scan"

TRAIN_DIR = os.path.join(BASE_DIR, "Train")
VAL_DIR   = os.path.join(BASE_DIR, "Validation")
TEST_DIR  = os.path.join(BASE_DIR, "Test")

OUTPUT_DIR = "data/augmented"

AUG_PER_IMAGE = 5   

IMG_SIZE = (224,224)

In [10]:
def augment_image(img):

    # rotation
    if random.random() < 0.7:
        angle = random.randint(-20,20)
        img = img.rotate(angle)

    # flip
    if random.random() < 0.5:
        img = img.transpose(Image.FLIP_LEFT_RIGHT)

    # brightness
    if random.random() < 0.5:
        enhancer = ImageEnhance.Brightness(img)
        img = enhancer.enhance(random.uniform(0.7,1.3))

    # contrast
    if random.random() < 0.5:
        enhancer = ImageEnhance.Contrast(img)
        img = enhancer.enhance(random.uniform(0.7,1.3))

    # zoom
    if random.random() < 0.5:
        w,h = img.size
        scale = random.uniform(0.8,1.0)
        nw,nh = int(w*scale), int(h*scale)
        left = random.randint(0,w-nw)
        top  = random.randint(0,h-nh)
        img = img.crop((left,top,left+nw,top+nh))
        img = img.resize((w,h))

    # noise
    if random.random() < 0.5:
        arr = np.array(img)
        noise = np.random.normal(0,5,arr.shape)
        arr = arr + noise
        arr = np.clip(arr,0,255).astype(np.uint8)
        img = Image.fromarray(arr)

    return img

In [11]:
def copy_folder(src, dst):
    if os.path.exists(dst):
        shutil.rmtree(dst)
    shutil.copytree(src,dst)

copy_folder(VAL_DIR , os.path.join(OUTPUT_DIR,"Validation"))
copy_folder(TEST_DIR, os.path.join(OUTPUT_DIR,"Test"))

In [12]:
train_out = os.path.join(OUTPUT_DIR,"Train")

if os.path.exists(train_out):
    shutil.rmtree(train_out)

for class_name in os.listdir(TRAIN_DIR):

    class_path = os.path.join(TRAIN_DIR,class_name)
    save_class = os.path.join(train_out,class_name)

    os.makedirs(save_class,exist_ok=True)

    images = os.listdir(class_path)

    for img_name in images:

        path = os.path.join(class_path,img_name)

        img = Image.open(path).convert("RGB")
        img = img.resize(IMG_SIZE)

        # save original
        img.save(os.path.join(save_class,img_name))

        # create augmented copies
        for i in range(AUG_PER_IMAGE):
            aug = augment_image(img)
            aug.save(os.path.join(
                save_class,
                f"{img_name.split('.')[0]}_aug{i}.jpg"
            ))

In [15]:
train_out = os.path.join(OUTPUT_DIR,"Train")

def count_images(folder):
    total = 0
    for cls in os.listdir(folder):
        total += len(os.listdir(os.path.join(folder,cls)))
    return total

orig = count_images(TRAIN_DIR)
aug  = count_images(train_out)

print("Original Train:", orig)
print("Augmented Train:", aug)

Original Train: 3693
Augmented Train: 21358
