In [1]:
import os
import numpy as np
import shutil
from tensorflow.keras.models import load_model
from PIL import Image
from tensorflow.keras.preprocessing.image import save_img

In [2]:
# Saved Keras generator used to synthesise extra training images.
generator_model_path = '/kaggle/input/lung-synth-generator/keras/default/1/GAN_generator.keras'

# Source dataset and the working copy that is filled with originals + synthetic images.
BASE_INPUT_DIR = '/kaggle/input/lung-ds-clahe/Full_slice'
BASE_OUTPUT_DIR = '/kaggle/working/lung-ds-synth/Full_slice'

# Class folder names; each name is also converted to an integer class index
# when the one-hot conditioning label is built.
CLASSES = ["0", "1", "2", "3", "4"]
NOISE_DIM = 256  # length of the random noise vector sampled per generated image
NUM_CLASSES = len(CLASSES)  # derived so the one-hot width cannot drift from CLASSES

# Seed the global NumPy RNG so the sampled noise vectors are repeatable when the
# notebook is restarted and run top to bottom.
SEED = 42
np.random.seed(SEED)

splits = ["train", "val"]
generator = load_model(generator_model_path)

In [3]:
# A previous run may have left an output tree behind; delete it so stale
# files never end up mixed into the freshly built dataset.
stale_output_exists = os.path.exists(BASE_OUTPUT_DIR)
if stale_output_exists:
    shutil.rmtree(BASE_OUTPUT_DIR)

In [4]:
# Pre-create one output folder per (split, class) pair; exist_ok keeps the
# cell safe to re-run.
for split_name in splits:
    split_root = os.path.join(BASE_OUTPUT_DIR, split_name)
    for class_name in CLASSES:
        os.makedirs(os.path.join(split_root, class_name), exist_ok=True)

In [5]:
def is_image_file(filename, extensions=('.png',)):
    """Return True if ``filename`` has one of the accepted image extensions.

    The comparison is case-insensitive and looks only at the final extension
    as reported by ``os.path.splitext`` (so a dot-file named ``.png`` has no
    extension and is rejected).

    Args:
        filename: File name or path to test.
        extensions: Accepted extensions including the leading dot. May be a
            single string or any iterable of strings. Defaults to PNG only,
            which matches the dataset this notebook copies.

    Returns:
        bool: Whether the file's extension is in ``extensions``.
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(extensions, str):
        extensions = (extensions,)
    allowed = {ext.lower() for ext in extensions}
    return os.path.splitext(filename)[1].lower() in allowed

# Mirror every original image from the input dataset into the working copy,
# keeping the <split>/<class>/ folder layout.
for split_name in splits:
    for class_name in CLASSES:
        in_cls_dir = os.path.join(BASE_INPUT_DIR, split_name, class_name)
        if not os.path.exists(in_cls_dir):
            print(f"Warning: Directory {in_cls_dir} does not exist. Skipping.")
            continue

        target_dir = os.path.join(BASE_OUTPUT_DIR, split_name, class_name)
        # Only files accepted by is_image_file are copied; everything else is ignored.
        image_names = [name for name in os.listdir(in_cls_dir) if is_image_file(name)]
        for image_name in image_names:
            shutil.copy2(
                os.path.join(in_cls_dir, image_name),
                os.path.join(target_dir, image_name),
            )

In [7]:
# Count the images currently in each training class folder; the largest class
# sets the baseline that the generation step tops every class up from.
train_root = os.path.join(BASE_OUTPUT_DIR, "train")
class_counts = {}
for cls in CLASSES:
    count = sum(
        1 for entry in os.listdir(os.path.join(train_root, cls)) if is_image_file(entry)
    )
    class_counts[cls] = count
    print(f"Class {cls}: {count} samples")

max_count = max(class_counts.values())
print(f"\nMaximum number of samples among classes: {max_count}")

Class 0: 196 samples
Class 1: 359 samples
Class 2: 887 samples
Class 3: 329 samples
Class 4: 119 samples

Maximum number of samples among classes: 887


In [8]:
def generate_sample(generator, class_label, noise_dim=NOISE_DIM, num_classes=NUM_CLASSES, rng=None):
    """Generate one synthetic image for ``class_label`` with the conditional generator.

    The generator input is a single row made of ``noise_dim`` standard-normal
    values followed by a ``num_classes``-wide one-hot label.

    Args:
        generator: Object exposing ``predict(array, verbose=0)``; called with
            an array of shape ``(1, noise_dim + num_classes)``.
        class_label: Class index, as an int or anything ``int()`` accepts
            (the notebook passes the class folder name, e.g. ``"3"``).
        noise_dim: Length of the noise vector.
        num_classes: Width of the one-hot label.
        rng: Optional ``numpy.random.Generator`` to draw the noise from. When
            omitted the global NumPy RNG is used, as before.

    Returns:
        numpy.ndarray: The first (only) image of the prediction, multiplied by
        255, clipped to [0, 255] and cast to ``uint8``.

    Raises:
        IndexError: If ``class_label`` is not in ``[0, num_classes)``.
        ValueError: If ``class_label`` cannot be converted to ``int``.
    """
    class_index = int(class_label)
    # Reject negatives explicitly: NumPy would otherwise wrap label[-1] to the
    # last class and silently condition the generator on the wrong label.
    if not 0 <= class_index < num_classes:
        raise IndexError(
            f"class_label {class_label!r} is out of range for {num_classes} classes"
        )

    if rng is None:
        noise = np.random.normal(0, 1, noise_dim)
    else:
        noise = rng.normal(0, 1, noise_dim)

    label = np.zeros(num_classes)
    label[class_index] = 1

    # Add a leading batch axis of size 1 for predict().
    gen_input = np.concatenate([noise, label], axis=0)
    gen_input = np.expand_dims(gen_input, axis=0)

    gen_img = generator.predict(gen_input, verbose=0)[0]

    # NOTE(review): scaling by 255 assumes the generator emits values in
    # [0, 1]; anything outside that range is clipped — confirm against the
    # generator's output activation.
    return (gen_img * 255).clip(0, 255).astype(np.uint8)

In [9]:
# Every training class is topped up to max_count plus an extra half of
# max_count, so even the largest class receives some synthetic images.
overshoot = int(max_count * 0.5)
target_count = max_count + overshoot

for cls in CLASSES:
    current_count = class_counts[cls]
    num_to_generate = target_count - current_count

    if num_to_generate <= 0:
        print(f"\nClass {cls} already has {current_count} samples. No generation needed.")
        continue

    print(f"\nGenerating {num_to_generate} samples for class {cls} ...")
    class_out_dir = os.path.join(BASE_OUTPUT_DIR, "train", cls)
    for i in range(num_to_generate):
        synthetic_img = generate_sample(generator, class_label=cls)
        # NOTE(review): save_img is called with its default arguments; Keras
        # documents a `scale` option there — confirm the default is the
        # intended behaviour for an array that is already uint8.
        save_img(os.path.join(class_out_dir, f"generated_{i}.png"), synthetic_img)

    print(f"Finished generating samples for class {cls}.")

print("\nDataset balancing complete!")


Generating 1134 samples for class 0 ...
Finished generating samples for class 0.

Generating 971 samples for class 1 ...
Finished generating samples for class 1.

Generating 443 samples for class 2 ...
Finished generating samples for class 2.

Generating 1001 samples for class 3 ...
Finished generating samples for class 3.

Generating 1211 samples for class 4 ...
Finished generating samples for class 4.

Dataset balancing complete!


In [13]:
# Re-count the images in each training class folder to confirm the classes
# now hold the same number of samples.
train_root = os.path.join(BASE_OUTPUT_DIR, "train")
for cls in CLASSES:
    entries = os.listdir(os.path.join(train_root, cls))
    count = len(list(filter(is_image_file, entries)))
    print(f"After augmentation, class {cls} has {count} samples.")

After augmentation, class 0 has 1330 samples.
After augmentation, class 1 has 1330 samples.
After augmentation, class 2 has 1330 samples.
After augmentation, class 3 has 1330 samples.
After augmentation, class 4 has 1330 samples.


In [14]:
# Zip the dataset root (the parent folder of BASE_OUTPUT_DIR) into lung-ds.zip
# in the current working directory. The root is derived from BASE_OUTPUT_DIR
# so the archive always matches the folder the cells above wrote to; shutil
# and os are already imported in the first cell.
shutil.make_archive("lung-ds", 'zip', os.path.dirname(BASE_OUTPUT_DIR))

'/kaggle/working/lung-ds.zip'