# Leukemia Type and Grade Classification Model

This notebook builds and trains a multi-output Convolutional Neural Network (CNN) to classify leukemia images into four types (ALL, AML, CLL, CML) and three disease grades (Chronic, Accelerated, Blast Crisis).

## 1. Imports and Configuration

First, we'll import the necessary libraries and define the main configuration parameters for our model and data.

In [None]:
import os
import shutil
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from tensorflow.keras.preprocessing.image import save_img
from tensorflow.keras.utils import to_categorical, plot_model
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

# --- Configuration ---
IMG_SIZE = (128, 128)
BATCH_SIZE = 32
EPOCHS = 20 # Keep it low for a quick test run, increase for real training
NUM_CLASSES_TYPE = 4
NUM_CLASSES_GRADE = 3
BASE_DATA_DIR = "data"

## 2. Dummy Data Generation

This section defines a helper that creates a fake dataset mimicking the expected directory structure (`data/ALL`, `data/AML`, etc.), so the whole pipeline can be tested without the real dataset. It also writes a `labels.csv` file; the directory name only encodes the leukemia type, so this file is what ties each image to its grade.

In [None]:
def generate_dummy_data(base_dir, num_images_per_class=50):
    """
    Generates a dummy image dataset and a corresponding labels CSV file.
    """
    if os.path.exists(base_dir) and os.path.exists('labels.csv'):
        print(f"Directory '{base_dir}' and labels.csv already exist. Skipping data generation.")
        return pd.read_csv('labels.csv')
    if os.path.exists(base_dir):
        # The CSV is written together with the images, so both are regenerated.
        print(f"Directory '{base_dir}' exists but labels.csv is missing. Regenerating data.")
    
    print("Generating dummy data...")
    leukemia_types = ["ALL", "AML", "CLL", "CML"]
    grades = ["Chronic", "Accelerated", "Blast_Crisis"]
    
    records = []

    for l_type in leukemia_types:
        class_dir = os.path.join(base_dir, l_type)
        os.makedirs(class_dir, exist_ok=True)
        for i in range(num_images_per_class):
            # Create a random noise image
            dummy_image_array = np.random.rand(IMG_SIZE[0], IMG_SIZE[1], 3) * 255
            
            # Assign a random grade
            grade = np.random.choice(grades)
            
            # Define file path
            img_filename = f"{l_type}_{i}_{grade}.png"
            img_filepath = os.path.join(class_dir, img_filename)

            # Save the image
            save_img(img_filepath, dummy_image_array)

            # Store record for CSV
            records.append({
                "filepath": img_filepath,
                "leukemia_type": l_type,
                "grade": grade
            })

    # Create and save the labels DataFrame
    labels_df = pd.DataFrame(records)
    labels_df.to_csv("labels.csv", index=False)
    print("Dummy data generation complete.")
    return labels_df

## 3. Data Loading and Preprocessing

Here, we define functions to create an efficient `tf.data.Dataset`. This involves:
1.  Reading the image files.
2.  Decoding, resizing, and normalizing the images.
3.  Mapping string labels (like 'ALL', 'Chronic') to one-hot encoded vectors.
4.  Shuffling, batching, and prefetching the data for optimal performance during training.
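
Step 3 can be illustrated outside TensorFlow. The following is a minimal NumPy sketch of the string-to-one-hot mapping, assuming a label map ordered like the one built in section 6.2; the `one_hot` helper is hypothetical and only mirrors the effect of `to_categorical`.

```python
import numpy as np

# Hypothetical label map, ordered like the one built in section 6.2 (sorted class names).
type_map = {"ALL": 0, "AML": 1, "CLL": 2, "CML": 3}

def one_hot(labels, label_map):
    """Turn a list of string labels into a (n_samples, n_classes) one-hot matrix."""
    indices = np.array([label_map[label] for label in labels])
    # Row i of the identity matrix is the one-hot vector for class i.
    return np.eye(len(label_map), dtype="float32")[indices]

encoded = one_hot(["CLL", "ALL"], type_map)
print(encoded)
```

Each row has a single 1 at the position given by the label map, which is the target format `categorical_crossentropy` expects.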

In [None]:
def load_and_preprocess_image(filepath, label_type, label_grade):
    """
    Loads an image from filepath, decodes it, resizes, and normalizes it.
    """
    img = tf.io.read_file(filepath)
    # decode_image detects the format (PNG, JPEG, BMP, GIF) from the file contents, not the extension.
    img = tf.image.decode_image(img, channels=3, expand_animations=False)
    img = tf.image.resize(img, [IMG_SIZE[0], IMG_SIZE[1]])
    img = img / 255.0  # Normalize to [0, 1]
    return img, (label_type, label_grade)

def create_dataset(df, type_map, grade_map):
    """
    Creates a TensorFlow Dataset from a pandas DataFrame.
    """
    filepaths = df['filepath'].values
    
    # Map string labels to integer indices
    type_labels = df['leukemia_type'].map(type_map).values
    grade_labels = df['grade'].map(grade_map).values

    # One-hot encode the labels
    type_labels_one_hot = to_categorical(type_labels, num_classes=NUM_CLASSES_TYPE)
    grade_labels_one_hot = to_categorical(grade_labels, num_classes=NUM_CLASSES_GRADE)

    # Create the dataset
    dataset = tf.data.Dataset.from_tensor_slices((filepaths, type_labels_one_hot, grade_labels_one_hot))
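    # Shuffle the lightweight (path, label) tuples before decoding so the buffer never holds full images
    dataset = dataset.shuffle(buffer_size=len(df))
    dataset = dataset.map(load_and_preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(BATCH_SIZE).prefetch(buffer_size=tf.data.AUTOTUNE)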
    
    return dataset

## 4. Model Architecture (Multi-Output CNN)

This is the core of the project. We build a multi-output model:
- **Shared Base:** A stack of `Conv2D` and `MaxPooling2D` layers that learn to extract common visual features from the images.
- **Type Branch:** A dedicated dense head that classifies the leukemia type.
- **Grade Branch:** A second dense head that classifies the disease grade from the same shared features; cell density is one cue we expect the base to pick up.
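
To get a feel for the size of this architecture, the sketch below works through the shape and parameter arithmetic by hand. It assumes the 128x128 RGB input from the configuration and the layer widths used in the model cell (three 3x3 convolutions with 32, 64, and 128 filters, each followed by 2x2 pooling, then a `Dense(256)` layer per branch); it is a back-of-the-envelope check, not a replacement for `model.summary()`.

```python
# Input size and layer widths are taken from the notebook's configuration and model cell.
height, width, channels = 128, 128, 3
conv_filters = [32, 64, 128]
kernel = 3
dense_units = 256
num_types, num_grades = 4, 3

params = 0
in_channels = channels
for filters in conv_filters:
    # Conv2D weights: kernel * kernel * in_channels * filters, plus one bias per filter.
    params += kernel * kernel * in_channels * filters + filters
    in_channels = filters
    # padding='same' keeps the spatial size; each 2x2 max-pool halves it.
    height, width = height // 2, width // 2

flat_features = height * width * in_channels

# Each branch: Dense(256) on the flattened features, then a softmax layer.
for num_classes in (num_types, num_grades):
    params += flat_features * dense_units + dense_units
    params += dense_units * num_classes + num_classes

print(f"Feature map before Flatten: {height}x{width}x{in_channels}")
print(f"Flattened features: {flat_features}")
print(f"Total trainable parameters: {params:,}")
```

Nearly all of the weights sit in the two `Dense(256)` layers (about 8.4 million each), which is worth keeping in mind if `IMG_SIZE` is increased.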

In [None]:
def build_multi_output_model(input_shape, num_types, num_grades):
    """
    Builds a multi-output CNN model.
    """
    # Input Layer
    inputs = Input(shape=input_shape, name="input_layer")

    # --- Shared Convolutional Base ---
    # This part of the network learns common features from the images.
    x = Conv2D(32, (3, 3), activation='relu', padding='same')(inputs)
    x = MaxPooling2D((2, 2))(x)
    x = Conv2D(64, (3, 3), activation='relu', padding='same')(x)
    x = MaxPooling2D((2, 2))(x)
    x = Conv2D(128, (3, 3), activation='relu', padding='same')(x)
    x = MaxPooling2D((2, 2))(x)
    x = Flatten()(x)
    x = Dropout(0.5)(x)

    # --- Branch 1: Leukemia Type Classification ---
    type_branch = Dense(256, activation='relu')(x)
    type_branch = Dropout(0.5)(type_branch)
    type_output = Dense(num_types, activation='softmax', name='type_output')(type_branch)

    # --- Branch 2: Disease Grade Classification ---
    # The grade head reuses the shared features; cell density is one cue we expect them to capture.
    grade_branch = Dense(256, activation='relu')(x)
    grade_branch = Dropout(0.5)(grade_branch)
    grade_output = Dense(num_grades, activation='softmax', name='grade_output')(grade_branch)

    # --- Create and Compile the Model ---
    model = Model(inputs=inputs, outputs=[type_output, grade_output], name="leukemia_classifier")

    # We define separate losses for each output branch
    losses = {
        'type_output': 'categorical_crossentropy',
        'grade_output': 'categorical_crossentropy'
    }
    
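    # Metrics are keyed by output name so the logged names are
    # 'type_output_accuracy' and 'grade_output_accuracy'.
    model.compile(optimizer='adam',
                  loss=losses,
                  metrics={'type_output': 'accuracy', 'grade_output': 'accuracy'})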
    
    return model

## 5. Visualization Function

A helper function to plot the training and validation accuracy and loss for both output branches. This helps in diagnosing issues like overfitting.

In [None]:
def plot_history(history):
    """Plots training and validation loss and accuracy."""
    plt.style.use('seaborn-v0_8-darkgrid')
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    fig.suptitle('Model Training History', fontsize=16)

    # Plot Type Accuracy
    axes[0, 0].plot(history.history['type_output_accuracy'], label='Train Type Acc')
    axes[0, 0].plot(history.history['val_type_output_accuracy'], label='Val Type Acc')
    axes[0, 0].set_title('Leukemia Type Accuracy')
    axes[0, 0].set_ylabel('Accuracy')
    axes[0, 0].set_xlabel('Epoch')
    axes[0, 0].legend(loc='lower right')

    # Plot Grade Accuracy
    axes[0, 1].plot(history.history['grade_output_accuracy'], label='Train Grade Acc')
    axes[0, 1].plot(history.history['val_grade_output_accuracy'], label='Val Grade Acc')
    axes[0, 1].set_title('Disease Grade Accuracy')
    axes[0, 1].set_ylabel('Accuracy')
    axes[0, 1].set_xlabel('Epoch')
    axes[0, 1].legend(loc='lower right')

    # Plot Type Loss
    axes[1, 0].plot(history.history['type_output_loss'], label='Train Type Loss')
    axes[1, 0].plot(history.history['val_type_output_loss'], label='Val Type Loss')
    axes[1, 0].set_title('Leukemia Type Loss')
    axes[1, 0].set_ylabel('Loss')
    axes[1, 0].set_xlabel('Epoch')
    axes[1, 0].legend(loc='upper right')

    # Plot Grade Loss
    axes[1, 1].plot(history.history['grade_output_loss'], label='Train Grade Loss')
    axes[1, 1].plot(history.history['val_grade_output_loss'], label='Val Grade Loss')
    axes[1, 1].set_title('Disease Grade Loss')
    axes[1, 1].set_ylabel('Loss')
    axes[1, 1].set_xlabel('Epoch')
    axes[1, 1].legend(loc='upper right')

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.savefig("training_history.png")
    print("\nSaved training history plot to 'training_history.png'")
    plt.show()

## 6. Main Execution

This is where we bring everything together. We'll execute the steps in order.

### 6.1. Generate or Load Data

We call the function to generate dummy data. If you are using your own data, you should **comment out this cell** and instead load your own CSV file with `pd.read_csv('your_labels.csv')`.
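
If you do bring your own CSV, it needs the same three columns the rest of the notebook reads: `filepath`, `leukemia_type`, and `grade`. The sketch below shows one way to check that before going further. The `check_labels` helper and the example rows are hypothetical; a small in-memory DataFrame stands in for the result of `pd.read_csv('your_labels.csv')`.

```python
import pandas as pd

# Columns the data pipeline in this notebook reads from the labels table.
REQUIRED_COLUMNS = {"filepath", "leukemia_type", "grade"}

def check_labels(df):
    """Raise ValueError if the DataFrame lacks a column the pipeline reads."""
    missing = REQUIRED_COLUMNS - set(df.columns)
    if missing:
        raise ValueError(f"labels file is missing columns: {sorted(missing)}")
    return df

# Tiny in-memory stand-in for pd.read_csv('your_labels.csv'); the paths are made up.
example = pd.DataFrame({
    "filepath": ["data/ALL/img_0.png", "data/CML/img_1.png"],
    "leukemia_type": ["ALL", "CML"],
    "grade": ["Chronic", "Blast_Crisis"],
})
labels_df = check_labels(example)
print(labels_df.shape)
```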

In [None]:
labels_df = generate_dummy_data(BASE_DATA_DIR)
print(f"Loaded {len(labels_df)} image records.")
labels_df.head()

### 6.2. Create Label Mappings and Split Data

The model needs numerical labels, so we create dictionaries to map our string labels (e.g., 'ALL') to integers (e.g., 0). We then split our data into training, validation, and testing sets to ensure we can evaluate the model on data it has never seen before.
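
The two-stage split can be checked with simple arithmetic. The sketch below assumes the dummy dataset defaults (4 types x 50 images) and that the held-out count is rounded up, which is our understanding of how `train_test_split` treats a fractional `test_size`; the `split_sizes` helper is hypothetical.

```python
import math

def split_sizes(n_samples, test_fraction):
    """Return (n_train, n_test) with the held-out count rounded up."""
    n_test = math.ceil(test_fraction * n_samples)
    return n_samples - n_test, n_test

total = 4 * 50  # four leukemia types x default num_images_per_class
remaining, n_test = split_sizes(total, 0.2)    # first split: hold out the test set
n_train, n_val = split_sizes(remaining, 0.2)   # second split: carve validation out of the rest
print(n_train, n_val, n_test)
```

Because the second 20% is taken from the remaining 80%, the validation set ends up at 16% of the full dataset rather than 20%.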

In [None]:
# Create label-to-integer mappings
type_classes = sorted(labels_df['leukemia_type'].unique())
grade_classes = sorted(labels_df['grade'].unique())

type_map = {label: i for i, label in enumerate(type_classes)}
grade_map = {label: i for i, label in enumerate(grade_classes)}

# Create integer-to-label reverse mappings for prediction
inv_type_map = {i: label for label, i in type_map.items()}
inv_grade_map = {i: label for label, i in grade_map.items()}

print("Type Map:", type_map)
print("Grade Map:", grade_map)

# Split data into training, validation, and test sets
# Stratify ensures that the class distribution is similar across all splits
train_df, test_df = train_test_split(labels_df, test_size=0.2, random_state=42, stratify=labels_df['leukemia_type'])
train_df, val_df = train_test_split(train_df, test_size=0.2, random_state=42, stratify=train_df['leukemia_type'])

print("\nDataset split:")
print(f"Training samples:   {len(train_df)}")
print(f"Validation samples: {len(val_df)}")
print(f"Test samples:       {len(test_df)}")

### 6.3. Create TensorFlow Datasets

We now convert our pandas DataFrames into `tf.data.Dataset` objects. This is TensorFlow's standard input pipeline API, and its prefetching lets data loading overlap with training.

In [None]:
train_dataset = create_dataset(train_df, type_map, grade_map)
val_dataset = create_dataset(val_df, type_map, grade_map)
test_dataset = create_dataset(test_df, type_map, grade_map)

print("Datasets created successfully.")

### 6.4. Build and Compile the Model

We instantiate the model architecture we defined earlier and print a summary to visualize its layers and parameters.

In [None]:
model = build_multi_output_model(
    input_shape=(IMG_SIZE[0], IMG_SIZE[1], 3),
    num_types=NUM_CLASSES_TYPE,
    num_grades=NUM_CLASSES_GRADE
)
model.summary()

# Optionally, save a plot of the model architecture
plot_model(model, to_file='model_architecture.png', show_shapes=True)

### 6.5. Train the Model

Now, we fit the model to our training data. We also provide the validation data so we can monitor the model's performance on a separate data split after each epoch.

In [None]:
print("\n--- Starting Model Training ---")
history = model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=EPOCHS
)
print("--- Model Training Complete ---")

### 6.6. Evaluate the Model

After training, we evaluate the final model's performance on the test set. This gives us a final, unbiased measure of how well our model is likely to perform on new, unseen data.

In [None]:
print("\n--- Evaluating Model on Test Data ---")
# return_dict=True labels each value, so nothing depends on the order of the returned list
results = model.evaluate(test_dataset, return_dict=True)

for name, value in results.items():
    print(f"Test {name}: {value:.4f}")

### 6.7. Visualize Training History

Plotting the training history helps us understand how the model learned over time. We look for the convergence of training and validation curves, and signs of overfitting (where validation performance gets worse while training performance improves).
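
The same check can be done numerically. The following sketch assumes a history dictionary shaped like `history.history`, with a training key and a matching `val_` key; the `summarize_overfitting` helper is hypothetical and the numbers are made up for illustration only.

```python
def summarize_overfitting(history_dict, loss_key):
    """Return (best_epoch, final_gap) for one output's loss curves.

    best_epoch is the 1-based epoch with the lowest validation loss;
    final_gap is validation loss minus training loss at the last epoch.
    """
    train = history_dict[loss_key]
    val = history_dict[f"val_{loss_key}"]
    best_epoch = min(range(len(val)), key=val.__getitem__) + 1
    return best_epoch, val[-1] - train[-1]

# Made-up numbers for illustration only; in the notebook, pass history.history.
toy_history = {
    "type_output_loss": [1.40, 1.10, 0.80, 0.50, 0.30],
    "val_type_output_loss": [1.38, 1.20, 1.15, 1.25, 1.40],
}
best_epoch, gap = summarize_overfitting(toy_history, "type_output_loss")
print(f"Lowest validation loss at epoch {best_epoch}; final gap {gap:.2f}")
```

A best epoch well before the last one, combined with a large positive gap, is the numeric counterpart of the diverging curves described above.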

In [None]:
plot_history(history)

### 6.8. Run an Example Prediction

Finally, let's take a single image from our test set and see what the model predicts for both the leukemia type and disease grade.
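
The decoding step is just an `argmax` over each softmax output followed by a lookup in the inverse label map. The sketch below shows it in isolation, assuming inverse maps ordered like those from section 6.2 (classes sorted alphabetically); the `decode` helper and the probability vectors are hypothetical, not real model output.

```python
import numpy as np

# Index-to-label maps ordered like those built in section 6.2 (classes sorted alphabetically).
inv_type_map = {0: "ALL", 1: "AML", 2: "CLL", 3: "CML"}
inv_grade_map = {0: "Accelerated", 1: "Blast_Crisis", 2: "Chronic"}

def decode(probabilities, inv_map):
    """Return (label, confidence) for one softmax row of shape (1, n_classes)."""
    idx = int(np.argmax(probabilities, axis=-1)[0])
    return inv_map[idx], float(probabilities[0, idx])

# Hypothetical model outputs for a single image, not real predictions.
type_probs = np.array([[0.10, 0.70, 0.15, 0.05]])
grade_probs = np.array([[0.20, 0.30, 0.50]])

print(decode(type_probs, inv_type_map))
print(decode(grade_probs, inv_grade_map))
```

Reporting the winning probability alongside the label makes it easier to spot predictions the model is unsure about.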

In [None]:
print("\n--- Running Example Prediction ---")
# Take one batch from the test set for demonstration
for images, (labels_type, labels_grade) in test_dataset.take(1):
    sample_image = images[0:1] # Get the first image, keep batch dimension
    true_label_type = inv_type_map[np.argmax(labels_type[0])]
    true_label_grade = inv_grade_map[np.argmax(labels_grade[0])]

    predictions = model.predict(sample_image)
    pred_type_idx = np.argmax(predictions[0])
    pred_grade_idx = np.argmax(predictions[1])

    predicted_type = inv_type_map[pred_type_idx]
    predicted_grade = inv_grade_map[pred_grade_idx]

    print(f"\nSample Image Analysis:")
    print(f"  > True Leukemia Type: {true_label_type}")
    print(f"  > Predicted Leukemia Type: {predicted_type}")
    print("-" * 20)
    print(f"  > True Disease Grade: {true_label_grade}")
    print(f"  > Predicted Disease Grade: {predicted_grade}")
    
    # Display the image being predicted
    plt.figure(figsize=(4,4))
    plt.imshow(sample_image[0])
    plt.title(f"Predicted Type: {predicted_type}\nPredicted Grade: {predicted_grade.replace('_', ' ')}")
    plt.axis('off')
    plt.savefig("prediction_example.png")
    print("\nSaved example prediction image to 'prediction_example.png'")
    plt.show()

## 7. Clean Up (Optional)

Uncomment and run this cell to remove the dummy data directory. Note that `labels.csv` is left in place.

In [None]:
# shutil.rmtree(BASE_DATA_DIR)
# print(f"Cleaned up dummy data directory: '{BASE_DATA_DIR}'")