In [None]:
import os
import shutil
import numpy as np
import matplotlib.pyplot as plt
from sklearn.utils import resample
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Define paths.
# Set the STEATOSIS_DATA_PATH / STEATOSIS_BALANCED_PATH environment variables to
# point at your data instead of editing these machine-specific defaults.
data_path = os.environ.get("STEATOSIS_DATA_PATH", "D:/DATASET/CNN/steatosis/trying")  # source images: <data_path>/<class>/...
balanced_data_path = os.environ.get("STEATOSIS_BALANCED_PATH", "D:/DATASET/CNN/steatosis/balanced_dataset")  # output of the augmentation step

In [None]:
# Create the output root directory if it doesn't exist
os.makedirs(balanced_data_path, exist_ok=True)

# Target number of samples per class after resampling
desired_samples = 4000

# Collect the image paths of each class from <data_path>/<class index>/.
# os.listdir returns entries in filesystem-dependent order, so the names are
# sorted to keep the seeded resampling in the next cell reproducible.
# Sub-directories are skipped because they cannot be read as images.
class_samples = {}
for i in range(4):
    source_dir = os.path.join(data_path, str(i))
    class_samples[i] = [
        os.path.join(source_dir, file)
        for file in sorted(os.listdir(source_dir))
        if os.path.isfile(os.path.join(source_dir, file))
    ]

In [None]:
# Remove unwanted images from classes 1 and 2 (without lipid vacuoles).
# Match on the file name only, so a directory component of data_path can never
# make every image pass the filter.
for i in (1, 2):
    class_samples[i] = [img for img in class_samples[i] if "lipid_vacuoles" in os.path.basename(img)]

# Resample every class to exactly `desired_samples` images.
# Classes smaller than the target are oversampled with replacement. Larger
# classes are undersampled WITHOUT replacement, so no image is duplicated while
# other images are dropped.
for i in range(4):
    if not class_samples[i]:
        raise ValueError(f"Class {i} has no images left to resample; check the contents of {data_path}")
    needs_oversampling = len(class_samples[i]) < desired_samples
    class_samples[i] = resample(class_samples[i], replace=needs_oversampling, n_samples=desired_samples, random_state=42)

In [None]:
# Data augmentation: random geometric transforms plus horizontal/vertical flips,
# applied to every resampled image in the generation step.
datagen = ImageDataGenerator(
    # geometric transforms
    rotation_range=20,
    shear_range=0.2,
    zoom_range=0.2,
    width_shift_range=0.2,
    height_shift_range=0.2,
    # mirroring
    horizontal_flip=True,
    vertical_flip=True,
    # how pixels exposed by the transforms are filled in
    fill_mode='nearest',
)

In [None]:
# Generate augmented images for each class.
# Each resampled path yields `augmentations_per_image` augmented PNGs, so every
# class should end up with desired_samples * augmentations_per_image files.
augmentations_per_image = 4

# The per-class output folders are wiped below, so refuse to run if that would
# touch the source data.
if os.path.abspath(balanced_data_path) == os.path.abspath(data_path):
    raise ValueError("balanced_data_path must differ from data_path")

for i in range(4):
    class_dir = os.path.join(balanced_data_path, str(i))
    # Start from an empty folder. Otherwise re-running this cell piles new files
    # on top of the previous run and inflates the counts below.
    shutil.rmtree(class_dir, ignore_errors=True)
    os.makedirs(class_dir, exist_ok=True)
    for k, img_path in enumerate(class_samples[i]):
        img_stem = os.path.splitext(os.path.basename(img_path))[0]
        img = plt.imread(img_path)
        if img.ndim == 2:
            # Grayscale images come back as (H, W); flow() needs a channel axis.
            img = img[..., np.newaxis]
        img = img[np.newaxis, ...]  # add the batch axis -> (1, H, W, C)
        for a in range(augmentations_per_image):
            # flow() names files '<prefix>_<index>_<random hash>'. The hash alone
            # is not guaranteed unique, so with a shared prefix later images can
            # overwrite earlier ones. Oversampling also repeats paths, so the
            # draw index k and augmentation index a make every file name unique.
            flow = datagen.flow(img, batch_size=1, save_to_dir=class_dir,
                                save_prefix=f"aug_{k}_{a}_{img_stem}", save_format='png')
            next(flow)  # drawing one batch writes one augmented image to class_dir

# Class distribution before resampling (files in the source folders) and after
# augmentation (files written to the balanced folders). The "before" counts must
# come from disk: class_samples has already been resampled to desired_samples.
class_distribution_before = {
    str(i): sum(
        os.path.isfile(os.path.join(data_path, str(i), f))
        for f in os.listdir(os.path.join(data_path, str(i)))
    )
    for i in range(4)
}
class_distribution_after = {str(i): len(os.listdir(os.path.join(balanced_data_path, str(i)))) for i in range(4)}

In [None]:
# Plot the class distribution before and after resampling side by side.
# Everything stays in one cell. The inline backend closes the current figure at
# the end of each cell, so subplots spread over several cells end up in separate,
# partly empty figures.
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
panels = [
    (axes[0], class_distribution_before, 'Class Distribution Before Resampling'),
    (axes[1], class_distribution_after, 'Class Distribution After Resampling'),
]
for ax, distribution, title in panels:
    ax.bar(list(distribution.keys()), list(distribution.values()))
    ax.set_xlabel('Class')
    ax.set_ylabel('Number of Images')
    ax.set_title(title)

fig.tight_layout()
plt.show()