In [1]:
# Pinned to the TensorFlow version this notebook was last run with (2.19.0);
# %pip installs into the environment of the running kernel.
%pip install tensorflow==2.19.0

Collecting tensorflow
  Downloading tensorflow-2.19.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.1 kB)
Collecting astunparse>=1.6.0 (from tensorflow)
  Downloading astunparse-1.6.3-py2.py3-none-any.whl.metadata (4.4 kB)
Collecting flatbuffers>=24.3.25 (from tensorflow)
  Downloading flatbuffers-25.2.10-py2.py3-none-any.whl.metadata (875 bytes)
Collecting google-pasta>=0.1.1 (from tensorflow)
  Downloading google_pasta-0.2.0-py3-none-any.whl.metadata (814 bytes)
Collecting libclang>=13.0.0 (from tensorflow)
  Downloading libclang-18.1.1-py2.py3-none-manylinux2010_x86_64.whl.metadata (5.2 kB)
Collecting tensorboard~=2.19.0 (from tensorflow)
  Downloading tensorboard-2.19.0-py3-none-any.whl.metadata (1.8 kB)
Collecting tensorflow-io-gcs-filesystem>=0.23.1 (from tensorflow)
  Downloading tensorflow_io_gcs_filesystem-0.37.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (14 kB)
Collecting wheel<1.0,>=0.23.0 (from astunparse>=1.6.0->tensorflow

In [3]:
import numpy as np
import pickle
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
import os
from tqdm.auto import tqdm

# Seed NumPy's and TensorFlow's global random generators from one shared
# value so the notebook is reproducible from a fresh kernel.
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Load the data
def load_data(data_dir='/content/drive/MyDrive/Colab Notebooks/processed_data'):
    """Read the train/valid/test arrays and the class-name list from disk.

    Parameters
    ----------
    data_dir : str
        Directory containing ``X_train.npy``, ``y_train.npy``, ``X_valid.npy``,
        ``y_valid.npy``, ``X_test.npy``, ``y_test.npy`` and ``class_names.pkl``.

    Returns
    -------
    tuple
        ``(X_train, y_train, X_valid, y_valid, X_test, y_test, class_names)``,
        where the first six items are what ``np.load`` returns and
        ``class_names`` is whatever object was pickled.
    """
    def _read_array(stem):
        # Every array is stored as "<stem>.npy" directly inside data_dir.
        return np.load(os.path.join(data_dir, f'{stem}.npy'))

    # Split-major, X-before-y ordering matches the returned tuple layout.
    arrays = [
        _read_array(f'{kind}_{split}')
        for split in ('train', 'valid', 'test')
        for kind in ('X', 'y')
    ]

    # NOTE: unpickling can execute arbitrary code; only load files you trust.
    names_path = os.path.join(data_dir, 'class_names.pkl')
    with open(names_path, 'rb') as handle:
        class_names = pickle.load(handle)

    return (*arrays, class_names)

# Analyze class distribution
def analyze_class_distribution(y_train, class_names):
    """Count the training samples per class, print the counts and return them.

    Parameters
    ----------
    y_train : array-like of int
        Class index of every training sample; each value is used to index
        ``class_names``.
    class_names : sequence of str
        Class names, positioned by class index.

    Returns
    -------
    dict
        Maps class name to its sample count (a NumPy integer). Only classes
        that occur in ``y_train`` appear, in ascending class-index order.
    """
    labels, frequencies = np.unique(y_train, return_counts=True)

    class_distribution = {}
    for label, frequency in zip(labels, frequencies):
        class_distribution[class_names[label]] = frequency

    print("Class distribution in training data:")
    for name in class_distribution:
        print(f"{name}: {class_distribution[name]}")

    return class_distribution

# Visualize sample images
def visualize_samples(X_train, y_train, class_names, num_samples=3):
    """Save a grid of example images, one row per class, to ``sample_images.png``.

    Parameters
    ----------
    X_train : np.ndarray
        Training images; each sample must be reshapeable to ``(64, 64)``.
    y_train : np.ndarray
        Class index of every training sample.
    class_names : sequence of str
        Class names, positioned by class index; one grid row per entry.
    num_samples : int, default 3
        Number of grid columns, i.e. the maximum images shown per class.

    Returns
    -------
    None
        The figure is written to ``sample_images.png`` in the working
        directory and then closed.
    """
    # squeeze=False keeps `axes` two-dimensional even with a single class or
    # num_samples == 1, so the axes[row, col] indexing below is always valid.
    fig, axes = plt.subplots(
        len(class_names), num_samples, figsize=(10, 12), squeeze=False
    )

    # Hide every frame up front so panels that receive no image (a class with
    # fewer than num_samples samples) stay blank instead of showing bare ticks.
    for ax in axes.flat:
        ax.axis('off')

    for class_idx, class_name in enumerate(class_names):
        # Indices of the samples belonging to this class.
        indices = np.where(y_train == class_idx)[0]

        # Slicing caps the row at num_samples and copes with smaller classes.
        for col, sample_idx in enumerate(indices[:num_samples]):
            ax = axes[class_idx, col]
            ax.imshow(X_train[sample_idx].reshape(64, 64), cmap='gray')
            ax.set_title(f"{class_name}")

    fig.tight_layout()
    fig.savefig('sample_images.png')
    # Close this specific figure rather than whichever one is current.
    plt.close(fig)

def safe_dental_contrast(image):
    """Apply a small random contrast and brightness jitter, clipped to [0, 1].

    Steps, in order:

    1. Convert ``image`` to float32 with ``tf.image.convert_image_dtype``.
    2. Scale the contrast by a factor drawn uniformly from [0.85, 1.15).
       This is a single global factor for the whole image, not a local
       (CLAHE-style) equalisation.
    3. Shift the brightness by a delta drawn from a normal distribution
       (mean 0, stddev 0.05) and clipped to [-0.1, 0.1].
    4. Clip the result to [0.0, 1.0].

    Returns the adjusted image as a float32 TensorFlow tensor.

    NOTE(review): the final clip presumes pixel values are already scaled to
    [0, 1] -- confirm against the stored arrays.
    """
    as_float = tf.image.convert_image_dtype(image, tf.float32)

    # The uniform draw already lies inside [0.8, 1.2]; the clip only acts as
    # a hard bound should the sampling range ever be widened.
    contrast_factor = tf.clip_by_value(
        tf.random.uniform([], 0.85, 1.15),
        0.8,
        1.2,
    )
    contrasted = tf.image.adjust_contrast(as_float, contrast_factor=contrast_factor)

    # Clipping at +/-0.1 truncates the normal draw at two standard deviations.
    brightness_delta = tf.clip_by_value(
        tf.random.normal([], mean=0.0, stddev=0.05),
        -0.1,
        0.1,
    )
    brightened = tf.image.adjust_brightness(contrasted, delta=brightness_delta)

    # Keep the output inside the valid pixel range.
    return tf.clip_by_value(brightened, 0.0, 1.0)


# Create augmentation generator
def create_augmentation_generator():
    """Build the ``ImageDataGenerator`` shared by every augmentation step.

    The configuration combines mild geometric jitter (rotation, shifts, shear,
    zoom, horizontal flip), zero-valued constant padding for pixels exposed by
    those transforms, and ``safe_dental_contrast`` as the per-image
    preprocessing function. The generator's own ``brightness_range`` is left
    disabled.

    Returns
    -------
    ImageDataGenerator
        A freshly constructed generator; no state is shared between calls.
    """
    geometric = dict(
        rotation_range=15,
        width_shift_range=0.05,
        height_shift_range=0.05,
        shear_range=0.05,
        zoom_range=[0.9, 1.1],
        horizontal_flip=True,
        vertical_flip=False,
    )
    # Pixels uncovered by a transform are filled with the constant 0.0.
    padding = dict(fill_mode='constant', cval=0.0)
    # Intensity jitter is delegated entirely to safe_dental_contrast.
    intensity = dict(
        preprocessing_function=safe_dental_contrast,
        brightness_range=None,
    )
    return ImageDataGenerator(**geometric, **padding, **intensity)

# Generate augmented samples for minority classes
def generate_augmented_data(X_train, y_train, class_distribution, class_names, target_samples=10000):
    """Top up under-represented classes with augmented images.

    Every class listed in ``class_distribution`` with fewer than
    ``target_samples`` samples receives ``target_samples - count`` augmented
    images. The augmented images are drawn from *all* source images of that
    class: the generator flows over the whole class with shuffling, so each
    pass uses every source image once before any is reused.

    Parameters
    ----------
    X_train : np.ndarray
        Training images. Expected as ``(N, 64, 64, 1)`` so they can be
        concatenated with the augmented images, which have that layout.
    y_train : np.ndarray
        Class index of every training sample.
    class_distribution : dict
        Maps class name to its current sample count (as returned by
        ``analyze_class_distribution``). Classes missing from it are skipped.
    class_names : sequence of str
        Class names, positioned by class index.
    target_samples : int, default 10000
        Minimum number of samples each augmented class should end up with.

    Returns
    -------
    (np.ndarray, np.ndarray)
        Images and labels: all original samples first, followed by the
        augmented samples class by class in ``class_names`` order.
    """
    images = np.asarray(X_train)
    labels = np.asarray(y_train)
    datagen = create_augmentation_generator()
    batch_size = 32

    # Originals come first; augmented batches are appended per class and
    # everything is concatenated once at the end.
    image_parts = [images]
    label_parts = [labels]

    for class_idx, class_name in enumerate(class_names):
        if class_name not in class_distribution:
            continue
        class_count = class_distribution[class_name]

        # Skip majority class or classes with sufficient samples.
        if class_count >= target_samples:
            print(f"Skipping {class_name} (already has {class_count} samples)")
            continue

        class_indices = np.where(labels == class_idx)[0]
        if len(class_indices) == 0:
            # Distribution and labels disagree; there is nothing to augment.
            print(f"Skipping {class_name} (no samples found in y_train)")
            continue

        num_to_generate = target_samples - class_count
        print(f"Generating {num_to_generate} augmented samples for {class_name}")

        # datagen.flow() yields batches indefinitely. Flowing over the whole
        # class (instead of one picked image) is what spreads the augmented
        # samples across every source image of the class.
        class_images = images[class_indices].reshape(-1, 64, 64, 1)
        flow = datagen.flow(class_images, batch_size=batch_size, shuffle=True)

        remaining = num_to_generate
        while remaining > 0:
            # Trim the final batch so the class lands exactly on the target.
            batch = next(flow)[:remaining]
            image_parts.append(batch)
            label_parts.append(np.full(len(batch), class_idx, dtype=labels.dtype))
            remaining -= len(batch)

    augmented_X = np.concatenate(image_parts, axis=0)
    augmented_y = np.concatenate(label_parts, axis=0)

    return augmented_X, augmented_y

# Visualize augmented samples
def visualize_augmented_samples(original_sample, class_name):
    """Save a 3x3 grid: the original image plus eight augmented variants.

    Parameters
    ----------
    original_sample : np.ndarray
        One image holding 64*64 values (for example shaped ``(64, 64, 1)``).
    class_name : str
        Used in the figure title and in the output file name.

    Returns
    -------
    None
        The figure is written to ``augmentation_example_<class_name>.png`` in
        the working directory and then closed.
    """
    datagen = create_augmentation_generator()

    fig, axes = plt.subplots(3, 3, figsize=(10, 10))
    axes[0, 0].imshow(original_sample.reshape(64, 64), cmap='gray')
    axes[0, 0].set_title("Original")
    axes[0, 0].axis('off')

    # Same (1, 64, 64, 1) batch layout that generate_augmented_data feeds the
    # generator, so flat or 2-D samples work here as well.
    sample = original_sample.reshape(1, 64, 64, 1)
    flow = datagen.flow(sample, batch_size=1)

    # Grid cell 0 holds the original; augmented images fill cells 1 through 8.
    for i in range(1, 9):
        batch = next(flow)
        row, col = divmod(i, 3)
        axes[row, col].imshow(batch[0].reshape(64, 64), cmap='gray')
        axes[row, col].set_title(f"Aug {i}")
        axes[row, col].axis('off')

    fig.suptitle(f"Augmentation examples for class: {class_name}")
    fig.tight_layout()
    fig.savefig(f'augmentation_example_{class_name}.png')
    # Close this specific figure rather than whichever one is current.
    plt.close(fig)

# Save augmented dataset
def save_augmented_dataset(X_train_aug, y_train_aug, X_valid, y_valid, X_test, y_test, output_dir='/content/drive/MyDrive/Colab Notebooks/augmented_data'):
    """Write the augmented training split and the untouched valid/test splits.

    Parameters
    ----------
    X_train_aug, y_train_aug : np.ndarray
        Augmented training images and labels, saved as
        ``X_train_augmented.npy`` and ``y_train_augmented.npy``.
    X_valid, y_valid, X_test, y_test : np.ndarray
        Validation and test arrays, saved under their own names.
    output_dir : str
        Destination directory; created (with parents) if it does not exist.

    Returns
    -------
    None
    """
    os.makedirs(output_dir, exist_ok=True)

    # File name -> array, in the order the files are written.
    outputs = {
        'X_train_augmented.npy': X_train_aug,
        'y_train_augmented.npy': y_train_aug,
        'X_valid.npy': X_valid,
        'y_valid.npy': y_valid,
        'X_test.npy': X_test,
        'y_test.npy': y_test,
    }
    for filename, array in outputs.items():
        np.save(os.path.join(output_dir, filename), array)

    print(f"Augmented dataset saved to {output_dir}")

def main(target_samples=10000):
    """Run the whole pipeline: load, inspect, augment, report and save.

    Parameters
    ----------
    target_samples : int, default 10000
        Minimum number of training samples each class should have after
        augmentation. Lower it if memory is tight.

    Side effects
    ------------
    Prints progress, writes ``sample_images.png`` and one
    ``augmentation_example_<class>.png`` per class that will be augmented,
    and saves the augmented dataset through ``save_augmented_dataset``.
    """
    # Load data
    print("Loading data...")
    X_train, y_train, X_valid, y_valid, X_test, y_test, class_names = load_data()

    # Analyze original class distribution
    print("\nAnalyzing class distribution...")
    class_distribution = analyze_class_distribution(y_train, class_names)

    # Visualize samples from each class
    print("\nVisualizing sample images...")
    visualize_samples(X_train, y_train, class_names)

    # Show augmentation examples for every class that will be augmented. The
    # selection mirrors generate_augmented_data (present in the distribution
    # and below the target), so it adapts to whatever dataset is loaded
    # instead of relying on a fixed list of class names.
    print("\nVisualizing augmentation examples...")
    for class_idx, class_name in enumerate(class_names):
        class_count = class_distribution.get(class_name)
        if class_count is None or class_count >= target_samples:
            continue
        class_samples = np.where(y_train == class_idx)[0]
        if len(class_samples) == 0:
            continue
        visualize_augmented_samples(X_train[class_samples[0]], class_name)

    # Generate augmented data
    print(f"\nGenerating augmented data (target: {target_samples} samples per class)...")
    X_train_aug, y_train_aug = generate_augmented_data(
        X_train, y_train, class_distribution, class_names, target_samples
    )

    # Analyze augmented class distribution
    print("\nClass distribution after augmentation:")
    unique_classes, counts = np.unique(y_train_aug, return_counts=True)
    for i, count in zip(unique_classes, counts):
        print(f"{class_names[i]}: {count}")

    # Save augmented dataset
    print("\nSaving augmented dataset...")
    save_augmented_dataset(X_train_aug, y_train_aug, X_valid, y_valid, X_test, y_test)

    print("\nData augmentation completed successfully!")
    print(f"Original training set: {X_train.shape[0]} samples")
    print(f"Augmented training set: {X_train_aug.shape[0]} samples")

# Run the pipeline only when this code is executed directly (as a script or
# notebook cell), not when it is imported as a module.
if __name__ == "__main__":
    main()

Loading data...

Analyzing class distribution...
Class distribution in training data:
Cavity: 3343
Fillings: 5262
Impacted Tooth: 1032
Implant: 1784
Normal: 17116

Visualizing sample images...

Visualizing augmentation examples...

Generating augmented data (target: 10000 samples per class)...
Generating 6657 augmented samples for Cavity
Generating 4738 augmented samples for Fillings
Generating 8968 augmented samples for Impacted Tooth
Generating 8216 augmented samples for Implant
Skipping Normal (already has 17116 samples)

Class distribution after augmentation:
Cavity: 10000
Fillings: 10000
Impacted Tooth: 10000
Implant: 10000
Normal: 17116

Saving augmented dataset...
Augmented dataset saved to /content/drive/MyDrive/Colab Notebooks/augmented_data

Data augmentation completed successfully!
Original training set: 28537 samples
Augmented training set: 57116 samples
