In [None]:
import os
import numpy as np
import shutil
from tensorflow.keras.models import load_model
from PIL import Image

In [None]:
# Paths: source dataset (one sub-folder per class), balanced output copy,
# and the trained conditional generator.
input_dir = 'dataset'
output_dir = 'balanced_dataset'
generator_model_path = 'model.keras'

# Seed NumPy's global RNG: generate_sample draws its latent noise from it, so
# seeding makes the synthetic images (and thus the balanced dataset) reproducible.
SEED = 42
np.random.seed(SEED)

generator = load_model(generator_model_path)

In [None]:
# Remove any output left by a previous run so the class counts below start fresh.
stale_output_exists = os.path.exists(output_dir)
if stale_output_exists:
    shutil.rmtree(output_dir)

In [None]:
# Class labels are the folder names '0'..'4'; create one output folder per label.
classes = list(map(str, range(5)))
os.makedirs(output_dir, exist_ok=True)
for label in classes:
    os.makedirs(os.path.join(output_dir, label), exist_ok=True)

In [None]:
def is_image_file(filename, extensions=('.png',)):
    """Return True if ``filename`` ends in one of ``extensions`` (case-insensitive).

    Args:
        filename: File name or path to test.
        extensions: A single extension string or an iterable of them, each
            including the leading dot. Defaults to PNG only, which is also the
            format the synthetic samples are saved in, so copied and generated
            files are counted the same way.

    Returns:
        bool: whether the file's extension is in ``extensions``.
    """
    if isinstance(extensions, str):
        allowed = {extensions.lower()}
    else:
        allowed = {ext.lower() for ext in extensions}
    return os.path.splitext(filename)[1].lower() in allowed

# Copy the original images from input_dir into the matching class folder of output_dir.
for cls in classes:
    in_cls_dir = os.path.join(input_dir, cls)
    out_cls_dir = os.path.join(output_dir, cls)
    # isdir rather than exists: a stray *file* named like a class would pass an
    # exists() check and then make os.listdir raise.
    if not os.path.isdir(in_cls_dir):
        print(f"Warning: Directory {in_cls_dir} does not exist. Skipping.")
        continue
    # sorted() gives a deterministic copy order across platforms.
    for file in sorted(os.listdir(in_cls_dir)):
        src_path = os.path.join(in_cls_dir, file)
        # Only copy regular files: a sub-directory named e.g. "x.png" would
        # otherwise make shutil.copy2 fail.
        if is_image_file(file) and os.path.isfile(src_path):
            dst_path = os.path.join(out_cls_dir, file)
            shutil.copy2(src_path, dst_path)

In [None]:
# Count the images now in each output class folder; the largest count is the
# target every class will be topped up to.
class_counts = {}
for cls in classes:
    png_files = [name for name in os.listdir(os.path.join(output_dir, cls)) if is_image_file(name)]
    class_counts[cls] = len(png_files)
    print(f"Class {cls}: {class_counts[cls]} samples")

max_count = max(class_counts.values())
print(f"\nMaximum number of samples among classes: {max_count}")

In [None]:
def generate_sample(class_idx, model=None, latent_dim=256, num_classes=5,
                    value_range=None, rng=None):
    """Generate one synthetic image for ``class_idx`` with the conditional generator.

    The generator input is a standard-normal noise vector of length
    ``latent_dim`` concatenated with a one-hot class vector of length
    ``num_classes`` (256 + 5 = 261 values by default). It is batched to shape
    ``(1, latent_dim + num_classes)``.

    Args:
        class_idx: Integer class label in ``[0, num_classes)``.
        model: Object with a Keras-style ``predict(x, verbose=...)``. Defaults
            to the notebook's global ``generator``.
        latent_dim: Length of the noise vector.
        num_classes: Length of the one-hot conditioning vector.
        value_range: ``(low, high)`` range of the generator's raw output,
            e.g. ``(-1, 1)`` for a tanh head or ``(0, 1)`` for sigmoid.
            ``None`` keeps the legacy per-sample guess: rescale from [-1, 1]
            only if *this* sample has a negative value. That guess can scale
            samples inconsistently, so pass the range explicitly when known.
        rng: Optional ``np.random.Generator``. Defaults to the global
            ``np.random`` state, so ``np.random.seed`` controls the noise.

    Returns:
        ``np.uint8`` array of the generated sample with the batch axis
        removed, values in [0, 255].

    Raises:
        ValueError: if ``class_idx`` is outside ``[0, num_classes)``.
    """
    if model is None:
        model = generator
    # Negative indices would otherwise silently select a class from the end.
    if not 0 <= class_idx < num_classes:
        raise ValueError(f"class_idx must be in [0, {num_classes}), got {class_idx}")

    source = np.random if rng is None else rng
    noise = source.normal(0.0, 1.0, size=(latent_dim,))

    one_hot = np.zeros(num_classes)
    one_hot[class_idx] = 1.0

    gen_input = np.concatenate([noise, one_hot])[np.newaxis, :].astype(np.float32)

    # verbose=0: otherwise Keras prints a progress bar for every single sample.
    generated = np.squeeze(model.predict(gen_input, verbose=0), axis=0)

    if value_range is not None:
        low, high = value_range
        generated = (generated - low) / (high - low)
    elif generated.min() < 0:
        # Legacy heuristic: assume a tanh output in [-1, 1].
        generated = (generated + 1) / 2

    # Round rather than truncate so e.g. 0.999 maps to 255, not 254.
    return np.rint(np.clip(generated, 0, 1) * 255).astype(np.uint8)

In [None]:
# Top up every class to max_count images with synthetic samples from the generator.
for cls in classes:
    current_count = class_counts[cls]
    samples_to_generate = max_count - current_count
    cls_dir = os.path.join(output_dir, cls)
    print(f"\nClass {cls}: Generating {samples_to_generate} synthetic samples...")

    # Next candidate suffix for generated_<k>.png. Names already present (e.g.
    # copied from input_dir) are skipped so no existing image is overwritten.
    # Overwriting would leave the class short of max_count.
    name_idx = 0
    for i in range(samples_to_generate):
        synthetic_img = generate_sample(int(cls))

        # A grayscale generator yields (H, W, 1); PIL only accepts (H, W) for a
        # single channel, so drop the trailing singleton axis.
        if synthetic_img.ndim == 3 and synthetic_img.shape[-1] == 1:
            synthetic_img = synthetic_img[..., 0]

        while os.path.exists(os.path.join(cls_dir, f"generated_{name_idx}.png")):
            name_idx += 1
        file_path = os.path.join(cls_dir, f"generated_{name_idx}.png")
        name_idx += 1

        img = Image.fromarray(synthetic_img)
        img.save(file_path)

print("\nSynthetic sample generation complete.")


In [None]:
# Sanity check: after augmentation every class should report max_count images.
for cls in classes:
    n_png = sum(map(is_image_file, os.listdir(os.path.join(output_dir, cls))))
    print(f"After augmentation, class {cls} has {n_png} samples.")