# Phase 1: Data Loading, Inspection & Understanding

This notebook covers the first phase of the MTTL Malaria Detection experimentation pipeline. We load, inspect, and analyze the chosen malaria dataset to understand its structure, contents, and label distribution. In this case, the dataset is the NIH-NLM thin blood smear dataset (`NIH-NLM-ThinBloodSmearsPf`).

## 1. Import Required Libraries

Import all necessary Python libraries for data handling, visualization, and image processing.

In [None]:
import os
import glob
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
import cv2
from collections import Counter
from typing import List, Dict

# Reproducibility: the visualization cells pick images with `random`, so seed
# the generators once here to make a fresh Restart & Run All repeatable.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)

# Matplotlib 3.6 renamed its bundled seaborn style sheets to 'seaborn-v0_8*'.
# Use whichever name this installation provides instead of raising OSError
# on older versions; if neither exists, keep the default style.
_seaborn_styles = [name for name in ('seaborn-v0_8', 'seaborn') if name in plt.style.available]
if _seaborn_styles:
    plt.style.use(_seaborn_styles[0])
sns.set_context('notebook')

## 2. Load Dataset

In [None]:
# Data root dir
NLM_ROOT = os.path.join('..', 'data', 'NIH-NLM-ThinBloodSmearsPf')
POINT_SET_DIR = os.path.join(NLM_ROOT, 'Point Set')
POLYGON_SET_DIR = os.path.join(NLM_ROOT, 'Polygon Set')


def list_patient_folders(root_dir):
    """Return the sorted full paths of the immediate sub-directories of `root_dir`.

    Args:
        root_dir: Directory whose children are the per-patient folders.

    Returns:
        list[str]: Sorted paths of every child that is a directory. Empty when
        `root_dir` does not exist or is not a directory (a notice is printed),
        so the later cells can take their "nothing found" branches instead of
        this cell failing with FileNotFoundError.
    """
    if not os.path.isdir(root_dir):
        print(f"Directory not found: {root_dir}")
        return []
    candidates = (os.path.join(root_dir, name) for name in os.listdir(root_dir))
    return sorted(path for path in candidates if os.path.isdir(path))


# List all patient folders in Point Set and Polygon Set directories
point_set_folders = list_patient_folders(POINT_SET_DIR)
polygon_set_folders = list_patient_folders(POLYGON_SET_DIR)

# Print out findings
print(f"Found {len(point_set_folders)} Point Set folders.")
print(f"Found {len(polygon_set_folders)} Polygon Set folders.")

## 3. Explore Dataset Structure

Inspect the directory structure and file organization. List files, folders, and check for annotation files in the NIH-NLM dataset.

In [None]:
# Dataset Structure: peek inside the first Point Set patient folder (if any)
sample_folder = point_set_folders[0] if point_set_folders else None
if not sample_folder:
    print("No Point Set folders found.")
else:
    print(f"Sample folder: {sample_folder}")
    subfolders = os.listdir(sample_folder)
    print(f"Subfolders in sample folder: {subfolders}")

    # Expected layout: images under 'Img', annotation files under 'GT'.
    img_dir = os.path.join(sample_folder, 'Img')
    gt_dir = os.path.join(sample_folder, 'GT')

    # For each sub-directory that exists, show its first three entries and the total.
    if os.path.isdir(img_dir):
        img_entries = os.listdir(img_dir)
        print(f"Images: {img_entries[:3]} ... (total: {len(img_entries)})")
    if os.path.isdir(gt_dir):
        gt_entries = os.listdir(gt_dir)
        print(f"Annotation files: {gt_entries[:3]} ... (total: {len(gt_entries)})")

## 4. Extract Dataset Information

Parse the annotation files to extract the relevant information for each annotated cell: its class label, its point coordinates, and, for polygon annotations, the outline and its bounding box.

In [None]:
# Parse Annotation Files
def parse_point_set_annotations(patient_folders):
    """Parse per-image annotation text files into a flat list of cell records.

    Every folder in `patient_folders` that contains both a 'GT' and an 'Img'
    sub-directory is scanned. Each '*.txt' file under 'GT' (extension matched
    case-insensitively) is read; its first line is skipped as a header and each
    remaining line is split on commas and interpreted positionally:

        parts[1] -> cell type
        parts[3] -> shape ('Point' or 'Polygon'; other shapes are ignored)
        parts[4] -> number of points (used for polygons)
        parts[5:] -> x1, y1, x2, y2, ...

    The image path is the annotation file's stem plus '.jpg' inside 'Img'; it
    is not checked for existence here.

    Args:
        patient_folders: Iterable of patient folder paths.

    Returns:
        list[dict]: One record per annotated cell. Point records carry
        'image_path', 'cell_type', 'shape', 'x', 'y'. Polygon records carry
        'image_path', 'cell_type', 'shape', 'polygon' (list of (x, y) tuples)
        and 'bbox' ([min_x, min_y, max_x, max_y]).

    Lines with fewer than 7 fields are ignored. Lines whose coordinates cannot
    be parsed, or polygons without a single complete (x, y) pair, are skipped
    and counted in a printed notice rather than aborting the whole parse.
    """
    annotations = []
    n_malformed = 0
    for folder in patient_folders:
        gt_dir = os.path.join(folder, 'GT')
        img_dir = os.path.join(folder, 'Img')
        if not (os.path.isdir(gt_dir) and os.path.isdir(img_dir)):
            continue
        # os.listdir order is arbitrary; sort so the record order is reproducible.
        for ann_file in sorted(os.listdir(gt_dir)):
            stem, ext = os.path.splitext(ann_file)
            if ext.lower() != '.txt':
                continue
            # Build the image name from the stem so '.TXT' files and names that
            # contain '.txt' elsewhere still map to '<stem>.jpg'.
            img_path = os.path.join(img_dir, stem + '.jpg')
            with open(os.path.join(gt_dir, ann_file), 'r', encoding='utf-8') as f:
                lines = f.read().strip().split('\n')
            for line in lines[1:]:  # lines[0] is the header row
                parts = [part.strip() for part in line.split(',')]
                if len(parts) < 7:
                    continue
                cell_type = parts[1]
                shape = parts[3]
                try:
                    if shape == 'Point':
                        record = {
                            'image_path': img_path,
                            'cell_type': cell_type,
                            'shape': shape,
                            'x': float(parts[5]),
                            'y': float(parts[6]),
                        }
                    elif shape == 'Polygon':
                        n_points = int(parts[4])
                        coords = [float(v) for v in parts[5:5 + 2 * n_points]]
                        xy = list(zip(coords[::2], coords[1::2]))
                        if not xy:
                            # No complete vertex: nothing to draw and no bbox to compute.
                            n_malformed += 1
                            continue
                        # Bounding box as [min(x), min(y), max(x), max(y)].
                        xs, ys = zip(*xy)
                        record = {
                            'image_path': img_path,
                            'cell_type': cell_type,
                            'shape': shape,
                            'polygon': xy,
                            'bbox': [min(xs), min(ys), max(xs), max(ys)],
                        }
                    else:
                        continue
                except ValueError:
                    # Non-numeric point count or coordinate.
                    n_malformed += 1
                    continue
                annotations.append(record)
    if n_malformed:
        print(f"Skipped {n_malformed} malformed annotation line(s).")
    return annotations

# Parse every Point Set patient folder, then report the record count and one example.
annotations = parse_point_set_annotations(point_set_folders)
n_parsed = len(annotations)
print(f"Parsed {n_parsed} cell annotations from Point Set.")
if n_parsed:
    print("Sample annotation:", annotations[0])

## 5. Identify Classes and Labels

Identify all available classes and labels in the dataset. 

In [None]:
# Collect the distinct cell-type labels present in the parsed annotations
if not annotations:
    print("No annotation data to extract classes from.")
else:
    # One label per annotation record, in record order (reused by later cells).
    all_labels = []
    for row in annotations:
        all_labels.append(row['cell_type'])
    unique_classes = sorted(set(all_labels))
    print(f"Found {len(unique_classes)} unique classes:")
    for c in unique_classes:
        print(f"- {c}")

## 6. Calculate Class Distribution and Statistics

Compute the number of samples per class and display class distribution statistics.

In [None]:
# Compute class distribution: one row per cell type, most frequent first
if annotations:
    # Count straight from the records so this cell does not rely on a list
    # built in another cell.
    class_counts = Counter(ann['cell_type'] for ann in annotations)
    class_dist_df = pd.DataFrame({
        'Class': list(class_counts.keys()),
        'Count': list(class_counts.values())
    }).sort_values('Count', ascending=False)
    print(class_dist_df)
else:
    # Keep the name defined: later cells pass `class_dist_df` around and test it
    # against None, which would otherwise raise NameError on an empty dataset.
    class_dist_df = None
    print("No annotation data to compute class distribution.")

## 7. Plot Sample Images with Annotations

Visualize one randomly chosen sample image from the dataset with its corresponding annotations (points, polygons, bounding boxes, and class labels) overlaid.

In [None]:
# Visualize an image sample with annotations
from collections import Counter, defaultdict
import matplotlib.patches as patches

# Per-cell-type drawing colour and one-letter tag used on the overlays.
# Insertion order matters: the legend pairs colours with labels positionally.
cell_type_colors = dict(
    Parasitized='red',
    Uninfected='lime',
    White_Blood_Cell='blue',
)
cell_type_short = dict(
    Parasitized='P',
    Uninfected='U',
    White_Blood_Cell='W',
)

if not annotations:
    print("No annotation data to visualize.")
else:
    # Group the annotation records by the image they belong to.
    image_to_anns = defaultdict(list)
    for ann in annotations:
        image_to_anns[ann['image_path']].append(ann)
    image_paths = list(image_to_anns.keys())

    if not image_paths:
        print("No images found in annotations.")
    else:
        # Pick one annotated image at random and overlay all of its records.
        img_path = random.choice(image_paths)
        anns = image_to_anns[img_path]
        if not os.path.exists(img_path):
            print(f"Image not found: {img_path}")
        else:
            img = np.array(Image.open(img_path).convert('RGB'))
            fig, ax = plt.subplots(figsize=(7, 7))
            ax.imshow(img)
            for ann in anns:
                # Unknown cell types fall back to yellow with a '?' tag.
                color = cell_type_colors.get(ann['cell_type'], 'yellow')
                short_label = cell_type_short.get(ann['cell_type'], '?')
                label_box = dict(facecolor='white', alpha=0.7, edgecolor=color, boxstyle='round,pad=0.2')
                if ann['shape'] == 'Polygon':
                    # Solid outline plus a dashed bounding box, tagged above its top-left corner.
                    poly = np.array(ann['polygon'])
                    patch = patches.Polygon(poly, closed=True, fill=False, edgecolor=color, linewidth=2)
                    ax.add_patch(patch)
                    min_x, min_y, max_x, max_y = ann['bbox']
                    rect = patches.Rectangle((min_x, min_y), max_x-min_x, max_y-min_y, fill=False, edgecolor=color, linewidth=1, linestyle='dashed')
                    ax.add_patch(rect)
                    ax.text(min_x, min_y-10, short_label, color=color, fontsize=13, weight='bold',
                            bbox=label_box)
                elif ann['shape'] == 'Point':
                    # Marker at the point, with the tag offset up and to the right.
                    ax.plot(ann['x'], ann['y'], 'o', color=color, markersize=10, markeredgewidth=2, markeredgecolor='black')
                    ax.text(ann['x']+8, ann['y']-8, short_label, color=color, fontsize=13, weight='bold',
                            bbox=label_box)
            ax.axis('off')
            # One legend swatch per known cell type; the explicit `labels` below
            # are the text actually shown.
            handles = []
            for lbl, clr in zip(['P','U','W'], cell_type_colors.values()):
                handles.append(patches.Patch(color=clr, label=lbl))
            plt.legend(handles=handles, labels=['P: Parasitized','U: Uninfected','W: WBC'], loc='lower center', frameon=True, fontsize=12, bbox_to_anchor=(0.5, -0.08), ncol=3)
            plt.show()

## Visualize More Images with Annotations

Display a grid of random images from the dataset, each with bounding boxes and class labels overlaid. This helps visually inspect annotation quality and class diversity.

In [None]:
# Display a grid of random images with colored annotations and minimal labels (by cell type)
import matplotlib.patches as patches

# Colour and one-letter tag for each cell type, built from parallel tuples.
# The order is significant: the legend pairs colours with labels positionally.
_cell_types = ('Parasitized', 'Uninfected', 'White_Blood_Cell')
cell_type_colors = dict(zip(_cell_types, ('red', 'lime', 'blue')))
cell_type_short = dict(zip(_cell_types, ('P', 'U', 'W')))

def plot_image_grid_with_annotations(annotations, n_images=8, ncols=4):
    """Show a grid of randomly sampled images with their annotations overlaid.

    Records are grouped by 'image_path'; up to `n_images` distinct images are
    sampled with `random.sample` and drawn `ncols` per row. 'Point' records are
    drawn as a marker, 'Polygon' records as an outline plus a dashed bounding
    box computed from the polygon vertices. Colours and one-letter tags come
    from the notebook-level `cell_type_colors` / `cell_type_short` maps; unknown
    cell types are drawn in yellow with a '?' tag.

    Args:
        annotations: List of annotation dicts with 'image_path', 'cell_type',
            'shape', and either 'x'/'y' (Point) or 'polygon' (Polygon).
        n_images: Maximum number of images to show.
        ncols: Number of grid columns.

    Returns:
        None. Prints a notice and returns early when `annotations` is empty.

    Raises:
        ValueError: If `n_images` or `ncols` is less than 1 (previously this
            surfaced as a ZeroDivisionError or a degenerate empty figure).
    """
    if not annotations:
        print("No annotations to display.")
        return
    if n_images < 1 or ncols < 1:
        raise ValueError("n_images and ncols must both be >= 1")

    # Group records by image; a plain dict keeps first-seen order.
    image_to_anns = {}
    for ann in annotations:
        image_to_anns.setdefault(ann['image_path'], []).append(ann)

    image_paths = list(image_to_anns)
    sample_paths = random.sample(image_paths, min(n_images, len(image_paths)))
    nrows = (len(sample_paths) + ncols - 1) // ncols  # ceiling division

    # squeeze=False always yields a 2-D array of axes, even for a 1x1 grid.
    fig, axes = plt.subplots(nrows, ncols, figsize=(ncols * 4, nrows * 4), squeeze=False)
    grid = list(axes.flat)
    # Hide every frame up front so unused cells and images that fail to load
    # stay blank instead of showing empty axes.
    for ax in grid:
        ax.axis('off')

    for ax, img_path in zip(grid, sample_paths):
        try:
            # convert() returns a new image, so the file can be closed right away.
            with Image.open(img_path) as raw:
                img = raw.convert('RGB')
        except Exception as e:
            # Best effort: report the problem and leave this grid cell blank.
            print(f"Error loading {img_path}: {e}")
            continue
        ax.imshow(img)
        for ann in image_to_anns[img_path]:
            color = cell_type_colors.get(ann['cell_type'], 'yellow')
            short_label = cell_type_short.get(ann['cell_type'], '?')
            label_box = dict(facecolor='white', alpha=0.7, edgecolor=color, boxstyle='round,pad=0.2')
            if ann['shape'] == 'Point':
                ax.plot(ann['x'], ann['y'], 'o', color=color, markersize=8, markeredgewidth=2, markeredgecolor='black')
                ax.text(ann['x']+6, ann['y']-6, short_label, color=color, fontsize=11, weight='bold', bbox=label_box)
            elif ann['shape'] == 'Polygon':
                poly = np.array(ann['polygon'])
                ax.add_patch(patches.Polygon(poly, closed=True, fill=False, edgecolor=color, linewidth=2))
                min_x, min_y = poly.min(axis=0)
                max_x, max_y = poly.max(axis=0)
                ax.add_patch(patches.Rectangle((min_x, min_y), max_x-min_x, max_y-min_y, fill=False, edgecolor=color, linewidth=1, linestyle='dashed'))
                ax.text(min_x, min_y-8, short_label, color=color, fontsize=11, weight='bold', bbox=label_box)
        ax.set_title(os.path.basename(img_path))

    # Add legend with minimal labels; swatches follow the order of cell_type_colors.
    handles = [patches.Patch(color=clr) for clr in cell_type_colors.values()]
    fig.legend(handles=handles, labels=['P: Parasitized','U: Uninfected','W: WBC'], loc='lower center', ncol=3, frameon=True, fontsize=12)
    # Reserve the bottom 5% of the figure for the legend.
    fig.tight_layout(rect=[0,0.05,1,1])
    plt.show()

# Show up to four randomly sampled annotated images, two per row.
plot_image_grid_with_annotations(annotations, n_images=4, ncols=2)

## 8. Analyze Label Distribution

Create bar charts and summary statistics to analyze the distribution of labels across the dataset.

In [None]:
# Compute class distribution: one row per cell type, most frequent first
if annotations:
    # Count straight from the records so this cell does not rely on a list
    # built in another cell.
    class_counts = Counter(ann['cell_type'] for ann in annotations)
    class_dist_df = pd.DataFrame({
        'Class': list(class_counts.keys()),
        'Count': list(class_counts.values())
    }).sort_values('Count', ascending=False)
    print(class_dist_df)
else:
    # Keep the name defined so the checks below (and later cells that receive
    # `class_dist_df`) see None instead of raising NameError.
    class_dist_df = None
    print("No annotation data to compute class distribution.")

# Bar chart of class distribution
if class_dist_df is not None:
    fig, ax = plt.subplots(figsize=(10, 5))
    # Passing `palette` without `hue` is deprecated in recent seaborn releases;
    # mapping hue to the same column keeps one colour per bar, and dodge=False
    # keeps the bars full width on older releases.
    sns.barplot(data=class_dist_df, x='Class', y='Count', hue='Class',
                palette='viridis', dodge=False, ax=ax)
    # The hue legend only repeats the x tick labels; drop it if one was drawn.
    hue_legend = ax.get_legend()
    if hue_legend is not None:
        hue_legend.remove()
    ax.set_title('Class Distribution in NIH-NLM Dataset')
    plt.xticks(rotation=45, ha='right')
    fig.tight_layout()
    plt.show()
else:
    print("No class distribution data to plot.")

## 9. Dataset Statistics & Insights

Summarize key statistics about the dataset, such as:
- Total number of images
- Total number of annotated objects
- Number of unique classes
- Annotated objects per class (min, mean, max)
- Any notable class imbalance or annotation issues

Present these as a table or bullet points for quick reference.

In [None]:
# Compute and display dataset statistics & insights
from collections import Counter

def dataset_statistics(annotations, class_dist_df):
    """Print headline statistics for the parsed annotations.

    Reports the number of distinct images, the number of annotation records,
    the number of classes, min/mean/max of annotated objects per class (from
    `class_dist_df`), min/mean/max of distinct images per class (from
    `annotations`), the per-class counts, and a warning when the largest class
    has more than twice as many objects as the smallest.

    Args:
        annotations: List of annotation dicts; each must have 'image_path'.
            Records that also have 'cell_type' contribute to the
            images-per-class figures.
        class_dist_df: DataFrame with 'Class' and 'Count' columns, or None.

    Returns:
        None. Everything is printed. Prints "No annotation data available."
        when there are no annotations or no (non-empty) distribution table.
    """
    if not annotations or class_dist_df is None or class_dist_df.empty:
        print("No annotation data available.")
        return

    total_images = len({ann['image_path'] for ann in annotations})
    total_objects = len(annotations)
    unique_classes = class_dist_df['Class'].nunique()
    class_counts = class_dist_df.set_index('Class')['Count'].to_dict()

    # `Count` holds annotation records per class, i.e. objects, not images.
    min_per_class = class_dist_df['Count'].min()
    max_per_class = class_dist_df['Count'].max()
    mean_per_class = class_dist_df['Count'].mean()

    # Distinct images in which each class appears at least once.
    images_by_class = {}
    for ann in annotations:
        cell_type = ann.get('cell_type')
        if cell_type is not None:
            images_by_class.setdefault(cell_type, set()).add(ann['image_path'])
    image_counts = [len(paths) for paths in images_by_class.values()]

    print(f"Total images: {total_images}")
    print(f"Total annotated objects: {total_objects}")
    print(f"Unique classes: {unique_classes}")
    print(f"Annotated objects per class (min/mean/max): {min_per_class} / {mean_per_class:.2f} / {max_per_class}")
    if image_counts:
        mean_images = sum(image_counts) / len(image_counts)
        print(f"Images per class (min/mean/max): {min(image_counts)} / {mean_images:.2f} / {max(image_counts)}")
    print("\nClass counts:")
    for cls, count in class_counts.items():
        print(f"  {cls}: {count}")

    # Class imbalance check on object counts; a class with zero objects counts
    # as infinitely imbalanced.
    imbalance_ratio = max_per_class / min_per_class if min_per_class > 0 else float('inf')
    if imbalance_ratio > 2:
        print(f"\nWarning: Class imbalance detected (max/min ratio = {imbalance_ratio:.2f})")
    else:
        print("\nNo severe class imbalance detected.")

# Summarize the parsed annotations together with the class-distribution table built earlier.
dataset_statistics(annotations, class_dist_df)