# Sperm Morphology Colorized Images Dataset Classification

# Dataloading and creating dataframe

In [4]:
import os
import pandas as pd
import logging

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

def create_dataframe(data_dir, save_dir):
    """
    Create DataFrame for the Sperm Morphology dataset based on specified directories.
    
    Args:
        data_dir (str): Root path to dataset directories (containing 'train', 'validation', 'test').
        save_dir (str): Directory to save DataFrame and debug logs.
    
    Returns:
        pd.DataFrame: DataFrame with image paths, class labels, and split information.
    """
    # Define the splits and classes
    splits = ['train', 'validation', 'test']
    classes = ['Normal_Sperm', 'Non-Sperm', 'Abnormal_Sperm']
    
    # Valid image extensions
    valid_extensions = {'.jpg', '.jpeg', '.png'}
    
    data = []
    
    # Iterate through each split and class
    for split in splits:
        split_dir = os.path.join(data_dir, split)
        if not os.path.isdir(split_dir):
            logger.error(f"Split directory does not exist: {split_dir}")
            raise FileNotFoundError(f"Split directory not found: {split_dir}")
        
        for class_name in classes:
            class_dir = os.path.join(split_dir, class_name)
            if not os.path.isdir(class_dir):
                logger.error(f"Class directory does not exist: {class_dir}")
                raise FileNotFoundError(f"Class directory not found: {class_dir}")
            
            # Collect image paths
            for img_file in os.listdir(class_dir):
                if os.path.splitext(img_file)[1].lower() in valid_extensions:
                    img_path = os.path.join(class_dir, img_file)
                    data.append({
                        "image_path": img_path,
                        "class_label": class_name,
                        "split": split
                    })
                else:
                    logger.warning(f"Invalid image extension for file: {img_file}")
    
    # Create DataFrame
    df = pd.DataFrame(data)
    if df.empty:
        logger.error("No valid images found in dataset!")
        raise ValueError("No valid images found.")
    
    # Log dataset statistics
    logger.info(f"Created DataFrame with {len(df)} images")
    for split in splits:
        split_count = df[df['split'] == split].shape[0]
        logger.info(f"{split.capitalize()} set: {split_count} images")
        for class_name in classes:
            class_count = df[(df['split'] == split) & (df['class_label'] == class_name)].shape[0]
            logger.info(f"  {class_name}: {class_count} images")
    
    # Save DataFrame
    os.makedirs(save_dir, exist_ok=True)
    df_path = os.path.join(save_dir, "dataset.csv")
    df.to_csv(df_path, index=False)
    logger.info(f"DataFrame saved to {df_path}")
    
    return df

# Example usage
data_dir = r"O:\O drive\AI\my project\medical image projects\Sperm Morphology Colorized\Dataset\archive (1)\dataset"
save_dir = r"O:\O drive\AI\my project\medical image projects\Sperm Morphology Colorized\plots"
df = create_dataframe(data_dir, save_dir)

2025-06-01 18:50:31,264 - INFO - Created DataFrame with 4200 images
2025-06-01 18:50:31,270 - INFO - Train set: 3000 images
2025-06-01 18:50:31,272 - INFO -   Normal_Sperm: 1000 images
2025-06-01 18:50:31,276 - INFO -   Non-Sperm: 1000 images
2025-06-01 18:50:31,280 - INFO -   Abnormal_Sperm: 1000 images
2025-06-01 18:50:31,283 - INFO - Validation set: 600 images
2025-06-01 18:50:31,286 - INFO -   Normal_Sperm: 200 images
2025-06-01 18:50:31,289 - INFO -   Non-Sperm: 200 images
2025-06-01 18:50:31,290 - INFO -   Abnormal_Sperm: 200 images
2025-06-01 18:50:31,290 - INFO - Test set: 600 images
2025-06-01 18:50:31,290 - INFO -   Normal_Sperm: 200 images
2025-06-01 18:50:31,297 - INFO -   Non-Sperm: 200 images
2025-06-01 18:50:31,297 - INFO -   Abnormal_Sperm: 200 images


2025-06-01 18:50:31,333 - INFO - DataFrame saved to O:\O drive\AI\my project\medical image projects\Sperm Morphology Colorized\plots\dataset.csv


# Data Analytics

In [5]:
import os
import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import logging

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

def analyze_dataset(df, save_dir="plots", num_samples=3):
    """
    Analyze and visualize the Sperm Morphology Colorized dataset.

    Args:
        df (pd.DataFrame): DataFrame with image paths, class labels, and splits.
        save_dir (str): Directory to save plots and statistics.
        num_samples (int): Number of sample images to visualize per class.

    Returns:
        pd.DataFrame: DataFrame with image statistics (dimensions, brightness, etc.).
    """
    os.makedirs(save_dir, exist_ok=True)

    # --- DataFrame Inspection ---
    logger.info("DataFrame Analysis:")
    logger.info(f"DataFrame shape: {df.shape}")
    logger.info(f"Columns: {list(df.columns)}")
    
    # Check for NaN values
    nan_counts = df.isna().sum()
    logger.info("NaN values per column:")
    for col, count in nan_counts.items():
        logger.info(f"  {col}: {count}")
    
    # Describe DataFrame
    logger.info("DataFrame description:")
    logger.info(df.describe(include='all'))
    
    # --- Class Distribution Analysis ---
    class_counts = df.groupby(['split', 'class_label']).size().unstack(fill_value=0)
    logger.info("Class distribution by split:")
    logger.info(class_counts)
    
    # Plot 1: Class distribution bar plot
    plt.figure(figsize=(10, 6))
    class_counts.plot(kind='bar', stacked=False)
    plt.title("Class Distribution by Split")
    plt.xlabel("Split")
    plt.ylabel("Number of Images")
    plt.legend(title="Class")
    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, "class_distribution_bar.png"))
    plt.close()
    logger.info("Saved class distribution bar plot")

    # Plot 2: Class distribution pie chart (overall)
    overall_class_counts = df['class_label'].value_counts()
    plt.figure(figsize=(8, 8))
    plt.pie(overall_class_counts, labels=overall_class_counts.index, autopct='%1.1f%%', startangle=140)
    plt.title("Overall Class Distribution")
    plt.savefig(os.path.join(save_dir, "class_distribution_pie.png"))
    plt.close()
    logger.info("Saved class distribution pie plot")

    # --- Image Properties Analysis ---
    image_stats = []
    brightness_values = []
    channel_means = {'R': [], 'G': [], 'B': []}
    
    for idx, row in df.iterrows():
        img = cv2.imread(row['image_path'])
        if img is None:
            logger.warning(f"Failed to load image: {row['image_path']}")
            continue
        
        # Image dimensions
        height, width = img.shape[:2]
        
        # Brightness (mean intensity across all channels)
        brightness = np.mean(img)
        brightness_values.append(brightness)
        
        # Channel-wise means (BGR to RGB)
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        channel_means['R'].append(np.mean(img_rgb[:, :, 0]))
        channel_means['G'].append(np.mean(img_rgb[:, :, 1]))
        channel_means['B'].append(np.mean(img_rgb[:, :, 2]))
        
        image_stats.append({
            "image_path": row['image_path'],
            "class_label": row['class_label'],
            "split": row['split'],
            "width": width,
            "height": height,
            "brightness": brightness
        })
    
    # Convert to DataFrame
    stats_df = pd.DataFrame(image_stats)
    
    # Save statistics
    stats_path = os.path.join(save_dir, "image_stats.csv")
    stats_df.to_csv(stats_path, index=False)
    logger.info(f"Saved image statistics to {stats_path}")
    
    # Log summary statistics
    if not stats_df.empty:
        logger.info("Image statistics summary:")
        logger.info(f"Mean dimensions: {stats_df['width'].mean():.0f}x{stats_df['height'].mean():.0f} pixels")
        logger.info(f"Mean brightness: {stats_df['brightness'].mean():.2f} ± {stats_df['brightness'].std():.2f}")
        logger.info(f"RGB channel means: R={np.mean(channel_means['R']):.2f}, G={np.mean(channel_means['G']):.2f}, B={np.mean(channel_means['B']):.2f}")

    # Plot 3: Image dimensions scatter plot
    plt.figure(figsize=(10, 6))
    sns.scatterplot(data=stats_df, x='width', y='height', hue='class_label', style='split')
    plt.title("Image Dimensions by Class and Split")
    plt.xlabel("Width (pixels)")
    plt.ylabel("Height (pixels)")
    plt.savefig(os.path.join(save_dir, "image_dimensions_scatter.png"))
    plt.close()
    logger.info("Saved image dimensions scatter plot")

    # Plot 4: Brightness distribution
    plt.figure(figsize=(10, 6))
    sns.histplot(data=stats_df, x='brightness', hue='class_label', multiple='stack', bins=30)
    plt.title("Brightness Distribution by Class")
    plt.xlabel("Mean Brightness")
    plt.ylabel("Frequency")
    plt.savefig(os.path.join(save_dir, "brightness_distribution.png"))
    plt.close()
    logger.info("Saved brightness distribution plot")

    # Plot 5: RGB channel distribution
    plt.figure(figsize=(10, 6))
    for channel, values in channel_means.items():
        sns.kdeplot(values, label=channel)
    plt.title("RGB Channel Intensity Distribution")
    plt.xlabel("Mean Intensity")
    plt.ylabel("Density")
    plt.legend()
    plt.savefig(os.path.join(save_dir, "rgb_channel_distribution.png"))
    plt.close()
    logger.info("Saved RGB channel distribution plot")

    # Plot 6: Sample visualizations (num_samples per class)
    classes = df['class_label'].unique()
    for class_name in classes:
        class_df = df[df['class_label'] == class_name].sample(min(num_samples, len(df[df['class_label'] == class_name])), random_state=42)
        plt.figure(figsize=(15, 5))
        for idx, (_, row) in enumerate(class_df.iterrows()):
            img = cv2.imread(row['image_path'])
            if img is None:
                logger.warning(f"Skipping visualization for {row['image_path']}")
                continue
            img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            plt.subplot(1, num_samples, idx + 1)
            plt.imshow(img_rgb)
            plt.title(f"{class_name}\n{os.path.basename(row['image_path'])}")
            plt.axis('off')
        plt.tight_layout()
        plt.savefig(os.path.join(save_dir, f"sample_visualization_{class_name}.png"))
        plt.close()
        logger.info(f"Saved sample visualization for {class_name}")

    return stats_df


In [6]:

stats_df = analyze_dataset(df, save_dir)

2025-06-01 18:55:30,034 - INFO - DataFrame Analysis:
2025-06-01 18:55:30,037 - INFO - DataFrame shape: (4200, 3)
2025-06-01 18:55:30,039 - INFO - Columns: ['image_path', 'class_label', 'split']
2025-06-01 18:55:30,039 - INFO - NaN values per column:
2025-06-01 18:55:30,039 - INFO -   image_path: 0
2025-06-01 18:55:30,039 - INFO -   class_label: 0
2025-06-01 18:55:30,039 - INFO -   split: 0
2025-06-01 18:55:30,039 - INFO - DataFrame description:
2025-06-01 18:55:30,064 - INFO -                                                image_path   class_label  split
count                                                4200          4200   4200
unique                                               4200             3      3
top     O:\O drive\AI\my project\medical image project...  Normal_Sperm  train
freq                                                    1          1400   3000
2025-06-01 18:55:30,079 - INFO - Class distribution by split:
2025-06-01 18:55:30,079 - INFO - class_label  Abnormal_Sperm 

<Figure size 1000x600 with 0 Axes>

# DataPrepration and Loadering

In [21]:
import os
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import logging
import random
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2

# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('data_preparation.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# Verify PyTorch installation
try:
    logger.info(f"PyTorch version: {torch.__version__}")
    logger.info(f"CUDA available: {torch.cuda.is_available()}")
except AttributeError as e:
    logger.error(f"PyTorch installation issue: {e}")
    raise
import cv2
import pandas as pd
import albumentations as A
from albumentations.pytorch import ToTensorV2
import logging

logger = logging.getLogger(__name__)

class SpermMorphologyDataset:
    def __init__(self, df, transform=None, phase='train', expected_size=(224, 224)):
        """
        Initialize the Sperm Morphologyology Dataset.

        Args:
            df (pd.DataFrame): DataFrame with image paths and labels
            transform (A.Compose, optional): Albumentations transformation pipeline
            phase (str): Dataset phase ('train', 'validation', 'test')
            expected_size (tuple): Expected image size (height, width)
        """
        self.df = df.reset_index(drop=True)
        self.transform = transform
        self.phase = phase
        self.expected_size = expected_size
        self.class_to_idx = {'Normal_Sperm': 0, 'Non-Sperm': 1, 'Abnormal_Sperm': 2}
        valid_indices, invalid_images = self._validate_images()
        self.df = self.df.iloc[valid_indices].reset_index(drop=True)
        self.invalid_images = invalid_images  # Store invalid_images as instance variable
        if self.invalid_images:
            logger.warning(f"Found {len(self.invalid_images)} invalid images in {self.phase} set")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        img_path = self.df.iloc[idx]['image_path']
        label_str = self.df.iloc[idx]['class_label']
        try:
            label = self.class_to_idx[label_str]
        except KeyError:
            logger.error(f"Unknown class label '{label_str}' at index {idx}")
            raise ValueError(f"Class label '{label_str}' not found in class_to_idx")

        img = cv2.imread(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        if self.transform:
            augmented = self.transform(image=img)
            img = augmented['image']

        return img, label

    def _validate_images(self):
        valid_indices = []
        invalid_images = []
        validate_transform = A.Compose([
            A.Resize(height=self.expected_size[0], width=self.expected_size[1]),
            ToTensorV2()
        ])
        for idx in range(len(self.df)):
            img_path = self.df.iloc[idx]['image_path']
            try:
                img = cv2.imread(img_path)
                if img is None:
                    raise ValueError("Failed to load image")
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                transformed = validate_transform(image=img)
                img_transformed = transformed['image']
                if img_transformed.shape[1:] != self.expected_size:
                    raise ValueError(f"Transformed image size {img_transformed.shape[1:]} != {self.expected_size}")
                valid_indices.append(idx)
            except Exception as e:
                logger.warning(f"Invalid image at index {idx}: {img_path}, error: {e}")
                invalid_images.append({'index': idx, 'path': img_path, 'error': str(e)})
        logger.info(f"Validated {len(valid_indices)}/{len(self.df)} images in {self.phase} set")
        return valid_indices, invalid_images

        
def get_transforms(phase='train', image_size=(224, 224)):
    if phase == 'train':
        transform = A.Compose([
            A.Resize(height=image_size[0], width=image_size[1]),  # Ensure all images are 224x224
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.Rotate(limit=30, p=0.5),
            A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.5),
            A.GaussNoise(var_limit=(10.0, 50.0), p=0.3),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2()
        ])
    else:  # validation or test
        transform = A.Compose([
            A.Resize(height=image_size[0], width=image_size[1]),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2()
        ])
    return transform

def create_data_loaders(df, batch_size=32, image_size=(224, 224), random_seed=42, num_workers=0):
    """
    Create PyTorch DataLoaders for training, validation, and testing.

    Args:
        df (pd.DataFrame): DataFrame with image paths, class labels, and splits.
        batch_size (int): Batch size for data loaders.
        image_size (tuple): Target image size (height, width).
        random_seed (int): Seed for reproducibility.
        num_workers (int): Number of subprocesses for data loading.

    Returns:
        dict: Dictionary of DataLoader objects for 'train', 'val', and 'test'.
    """
    # Set random seeds
    random.seed(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(random_seed)

    # Split DataFrame by pre-defined splits
    splits = ['train', 'validation', 'test']
    datasets = {}
    invalid_images_all = []
    for split in splits:
        split_df = df[df['split'] == split].reset_index(drop=True)
        if split_df.empty:
            logger.error(f"No data found for {split} split")
            raise ValueError(f"Empty {split} split")
        datasets[split] = SpermMorphologyDataset(
            split_df,
            transform=get_transforms(phase='train' if split == 'train' else 'val', image_size=image_size),
            phase=split,
            expected_size=image_size
        )
        invalid_images_all.extend(datasets[split].invalid_images)
        logger.info(f"{split.capitalize()} set: {len(datasets[split])} images")

    # Save invalid images report
    if invalid_images_all:
        invalid_df = pd.DataFrame(invalid_images_all)
        invalid_df.to_csv(os.path.join(save_dir, 'invalid_images.csv'), index=False)
        logger.info(f"Saved invalid images report to {os.path.join(save_dir, 'invalid_images.csv')}")

    # Create data loaders
    data_loaders = {}
    for split in splits:
        data_loaders[split] = DataLoader(
            datasets[split],
            batch_size=batch_size,
            shuffle=(split == 'train'),
            num_workers=num_workers,
            pin_memory=True,
            drop_last=(split == 'train')  # Drop last incomplete batch for training
        )

    return data_loaders

def visualize_samples(data_loaders, save_dir, num_samples=3):
    """
    Visualize sample images from each DataLoader.

    Args:
        data_loaders (dict): Dictionary of DataLoaders.
        save_dir (str): Directory to save visualizations.
        num_samples (int): Number of samples per class to visualize.
    """
    os.makedirs(save_dir, exist_ok=True)
    class_names = ['Normal_Sperm', 'Non_Sperm', 'Abnormal_Sperm']
    
    for phase in data_loaders:
        plt.figure(figsize=(15, 5*num_samples))
        samples_per_class = {0: 0, 1: 0, 2: 0}
        plot_idx = 1
        
        for images, labels in data_loaders[phase]:
            for img, lbl in zip(images, labels):
                if samples_per_class[lbl.item()] < num_samples:
                    # Denormalize for visualization
                    img_np = img.permute(1, 2, 0).numpy()
                    mean = np.array([0.485, 0.456, 0.406])
                    std = np.array([0.229, 0.224, 0.225])
                    img_np = std * img_np + mean
                    img_np = np.clip(img_np, 0, 1)
                    
                    plt.subplot(num_samples, 3, plot_idx)
                    plt.imshow(img_np)
                    plt.title(f"{phase.capitalize()}: {class_names[lbl.item()]}")
                    plt.axis('off')
                    samples_per_class[lbl.item()] += 1
                    plot_idx += 1
                
                if all(count >= num_samples for count in samples_per_class.values()):
                    break
            if all(count >= num_samples for count in samples_per_class.values()):
                break
        
        plt.tight_layout()
        plt.savefig(os.path.join(save_dir, f"{phase}_samples.png"))
        plt.close()
        logger.info(f"Saved {phase} sample visualization")

def validate_data_loaders(data_loaders):
    """
    Validate DataLoader outputs and log batch statistics.

    Args:
        data_loaders (dict): Dictionary of DataLoaders.
    """
    for phase, loader in data_loaders.items():
        try:
            batch_count = 0
            class_counts = {0: 0, 1: 0, 2: 0}
            for images, labels in loader:
                batch_count += 1
                logger.info(f"{phase.capitalize()} batch {batch_count} - Images shape: {images.shape}, Labels shape: {labels.shape}")
                for lbl in labels:
                    class_counts[lbl.item()] += 1
                if batch_count == 1:  # Log first batch only for brevity
                    break
            logger.info(f"{phase.capitalize()} class distribution in first batch: {class_counts}")
        except Exception as e:
            logger.error(f"Error validating {phase} DataLoader: {e}")
            raise

def save_preparation_report(df, data_loaders, save_dir):
    """
    Save a report summarizing data preparation.

    Args:
        df (pd.DataFrame): Input DataFrame.
        data_loaders (dict): Dictionary of DataLoaders.
        save_dir (str): Directory to save report.
    """
    os.makedirs(save_dir, exist_ok=True)
    report_path = os.path.join(save_dir, "data_preparation_report.txt")
    
    with open(report_path, 'w') as f:
        f.write("Sperm Morphology Dataset Preparation Report\n")
        f.write("=" * 50 + "\n\n")
        f.write(f"Total images: {len(df)}\n")
        for split in ['train', 'validation', 'test']:
            split_count = len(df[df['split'] == split])
            f.write(f"{split.capitalize()} set: {split_count} images\n")
            for cls in ['Normal_Sperm', 'Non_Sperm', 'Abnormal_Sperm']:
                cls_count = len(df[(df['split'] == split) & (df['class_label'] == cls)])
                f.write(f"  {cls}: {cls_count} images\n")
        f.write("\nBatch sizes:\n")
        for phase, loader in data_loaders.items():
            f.write(f"{phase.capitalize()}: {loader.batch_size}\n")
    
    logger.info(f"Saved preparation report to {report_path}")


2025-06-01 19:23:42,140 - INFO - PyTorch version: 2.6.0+cu126
2025-06-01 19:23:42,141 - INFO - CUDA available: True


In [22]:

# df_path = r"O:\O drive\AI\my project\medical image projects\Sperm Morphology Colorized\plots\dataset.csv"
# save_dir = r"O:\O drive\AI\my project\medical image projects\Sperm Morphology Colorized\plots"

# try:
#     df = pd.read_csv(df_path)
# except FileNotFoundError:
#     logger.error(f"Dataset CSV not found at {df_path}")
#     raise

# Create DataLoaders
data_loaders = create_data_loaders(
    df,
    batch_size=32,  # Reverted to 32 for stability
    image_size=(224, 224),
    random_seed=42,
    num_workers=0
)

# Validate DataLoaders
validate_data_loaders(data_loaders)

# Visualize samples
visualize_samples(data_loaders, save_dir, num_samples=3)

# Save preparation report
save_preparation_report(df, data_loaders, save_dir)

2025-06-01 19:23:44,062 - INFO - Validated 3000/3000 images in train set
2025-06-01 19:23:44,062 - INFO - Train set: 3000 images
2025-06-01 19:23:44,424 - INFO - Validated 600/600 images in validation set
2025-06-01 19:23:44,430 - INFO - Validation set: 600 images
2025-06-01 19:23:44,779 - INFO - Validated 600/600 images in test set
2025-06-01 19:23:44,783 - INFO - Test set: 600 images
2025-06-01 19:23:46,036 - INFO - Train batch 1 - Images shape: torch.Size([32, 3, 224, 224]), Labels shape: torch.Size([32])
2025-06-01 19:23:46,074 - INFO - Train class distribution in first batch: {0: 10, 1: 14, 2: 8}
2025-06-01 19:23:46,144 - INFO - Validation batch 1 - Images shape: torch.Size([32, 3, 224, 224]), Labels shape: torch.Size([32])
2025-06-01 19:23:46,144 - INFO - Validation class distribution in first batch: {0: 32, 1: 0, 2: 0}
2025-06-01 19:23:46,206 - INFO - Test batch 1 - Images shape: torch.Size([32, 3, 224, 224]), Labels shape: torch.Size([32])
2025-06-01 19:23:46,206 - INFO - Test 

In [23]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import logging
import os
import matplotlib.pyplot as plt
import pandas as pd
from datetime import datetime
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
import seaborn as sns
import torchvision.models as models

# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('training.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# Check device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")
logger.info(f"PyTorch version: {torch.__version__}")

# Model Definitions
def get_efficientnet_b0(num_classes=3):
    """Initialize EfficientNet-B0 with pre-trained weights."""
    model = models.efficientnet_b0(weights='IMAGENET1K_V1')
    # Modify classifier for 3 classes
    model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
    return model

def get_mobilenet_v3_small(num_classes=3):
    """Initialize MobileNetV3-Small with pre-trained weights."""
    model = models.mobilenet_v3_small(weights='IMAGENET1K_V1')
    # Modify classifier for 3 classes
    model.classifier[3] = nn.Linear(model.classifier[3].in_features, num_classes)
    return model

# Evaluation Metrics
def compute_metrics(outputs, labels):
    """Compute accuracy and F1-score."""
    _, preds = torch.max(outputs, 1)
    preds = preds.cpu().numpy()
    labels = labels.cpu().numpy()
    accuracy = accuracy_score(labels, preds)
    f1 = f1_score(labels, preds, average='weighted')
    return accuracy, f1, preds, labels

# Visualization Helper Functions
def plot_metrics(history, save_dir, model_name):
    """Plot training and validation metrics."""
    os.makedirs(save_dir, exist_ok=True)
    epochs = range(1, len(history['train_loss']) + 1)

    plt.figure(figsize=(15, 5))

    # Plot Loss
    plt.subplot(1, 3, 1)
    plt.plot(epochs, history['train_loss'], label='Train Loss')
    plt.plot(epochs, history['val_loss'], label='Val Loss')
    plt.title(f'{model_name} Loss per Epoch')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)

    # Plot Accuracy
    plt.subplot(1, 3, 2)
    plt.plot(epochs, history['val_accuracy'], label='Val Accuracy')
    plt.title(f'{model_name} Accuracy per Epoch')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.grid(True)

    # Plot F1 Score
    plt.subplot(1, 3, 3)
    plt.plot(epochs, history['val_f1'], label='Val F1')
    plt.title(f'{model_name} F1 Score per Epoch')
    plt.xlabel('Epoch')
    plt.ylabel('F1 Score')
    plt.legend()
    plt.grid(True)

    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, f'{model_name}_training_metrics.png'))
    plt.close()

def plot_confusion_matrix(all_preds, all_labels, class_names, save_dir, model_name):
    """Plot and save confusion matrix."""
    os.makedirs(save_dir, exist_ok=True)
    cm = confusion_matrix(all_labels, all_preds)
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
    plt.title(f'{model_name} Confusion Matrix')
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.savefig(os.path.join(save_dir, f'{model_name}_confusion_matrix.png'))
    plt.close()

# Training Function
def train_model(model, model_name, train_loader, val_loader, test_loader, num_epochs=50, lr=1e-3, patience=10, save_dir="models"):
    """Train the model with early stopping, scheduler, and evaluation."""
    os.makedirs(save_dir, exist_ok=True)
    try:
        model = model.to(device)
        logger.info(f"{model_name} moved to {device}")
    except Exception as e:
        logger.error(f"Failed to move {model_name} to {device}: {e}")
        raise

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-2)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, verbose=True)
    best_val_loss = float('inf')
    patience_counter = 0
    history = {'train_loss': [], 'val_loss': [], 'val_accuracy': [], 'val_f1': []}

    for epoch in range(num_epochs):
        # Training
        model.train()
        train_loss = 0.0
        train_bar = tqdm(train_loader, desc=f"{model_name} Epoch {epoch+1}/{num_epochs} [Train]")
        for images, labels in train_bar:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # Gradient clipping
            optimizer.step()
            train_loss += loss.item() * images.size(0)
            train_bar.set_postfix(loss=loss.item())

        train_loss /= len(train_loader.dataset)
        history['train_loss'].append(train_loss)

        # Validation
        model.eval()
        val_loss = 0.0
        val_accuracy = 0.0
        val_f1 = 0.0
        all_preds = []
        all_labels = []
        val_bar = tqdm(val_loader, desc=f"{model_name} Epoch {epoch+1}/{num_epochs} [Val]")
        with torch.no_grad():
            for images, labels in val_bar:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item() * images.size(0)
                accuracy, f1, preds, lbls = compute_metrics(outputs, labels)
                val_accuracy += accuracy * images.size(0)
                val_f1 += f1 * images.size(0)
                all_preds.extend(preds)
                all_labels.extend(lbls)
                val_bar.set_postfix(loss=loss.item(), accuracy=accuracy)

        val_loss /= len(val_loader.dataset)
        val_accuracy /= len(val_loader.dataset)
        val_f1 /= len(val_loader.dataset)
        history['val_loss'].append(val_loss)
        history['val_accuracy'].append(val_accuracy)
        history['val_f1'].append(val_f1)

        logger.info(f"{model_name} Epoch {epoch+1}/{num_epochs} - Train Loss: {train_loss:.4f}")
        logger.info(f"Validation Loss: {val_loss:.4f}, Accuracy: {val_accuracy:.4f}, F1: {val_f1:.4f}")

        # Save metrics to CSV
        metrics_df = pd.DataFrame(history)
        metrics_df.to_csv(os.path.join(save_dir, f'{model_name}_training_metrics.csv'), index=False)

        # Plot metrics
        plot_metrics(history, save_dir, model_name)

        # Plot confusion matrix for validation set in final epoch
        if epoch == num_epochs - 1:
            plot_confusion_matrix(all_preds, all_labels, ['Normal_Sperm', 'Non_Sperm', 'Abnormal_Sperm'], save_dir, model_name)

        # Scheduler step
        scheduler.step(val_loss)

        # Early stopping and checkpointing
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            torch.save(model.state_dict(), os.path.join(save_dir, f"{model_name}_best_model.pth"))
            logger.info(f"Saved best {model_name} model")
        else:
            patience_counter += 1
            if patience_counter >= patience:
                logger.info(f"Early stopping triggered for {model_name}")
                break

    # Evaluate on test set
    model.eval()
    test_loss = 0.0
    test_accuracy = 0.0
    test_f1 = 0.0
    all_preds = []
    all_labels = []
    test_bar = tqdm(test_loader, desc=f"{model_name} [Test]")
    with torch.no_grad():
        for images, labels in test_bar:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            test_loss += loss.item() * images.size(0)
            accuracy, f1, preds, lbls = compute_metrics(outputs, labels)
            test_accuracy += accuracy * images.size(0)
            test_f1 += f1 * images.size(0)
            all_preds.extend(preds)
            all_labels.extend(lbls)
            test_bar.set_postfix(loss=loss.item(), accuracy=accuracy)

    test_loss /= len(test_loader.dataset)
    test_accuracy /= len(test_loader.dataset)
    test_f1 /= len(test_loader.dataset)
    logger.info(f"{model_name} Test Loss: {test_loss:.4f}, Accuracy: {test_accuracy:.4f}, F1: {test_f1:.4f}")

    # Save test confusion matrix
    plot_confusion_matrix(all_preds, all_labels, ['Normal_Sperm', 'Non_Sperm', 'Abnormal_Sperm'], save_dir, f"{model_name}_test")

    # Save test metrics
    test_metrics = {
        'test_loss': test_loss,
        'test_accuracy': test_accuracy,
        'test_f1': test_f1
    }
    pd.DataFrame([test_metrics]).to_csv(os.path.join(save_dir, f'{model_name}_test_metrics.csv'), index=False)

    return model, history


2025-06-01 19:24:57,261 - INFO - Using device: cuda
2025-06-01 19:24:57,261 - INFO - PyTorch version: 2.6.0+cu126


# Efficientnet_b0 Model

In [24]:

# # Paths
# df_path = r"O:\O drive\AI\my project\medical image projects\Sperm Morphology Colorized\plots\dataset.csv"
# save_dir = r"O:\O drive\AI\my project\medical image projects\Sperm Morphology Colorized\models"

# # Load DataFrame
# try:
#     df = pd.read_csv(df_path)
# except FileNotFoundError:
#     logger.error(f"Dataset CSV not found at {df_path}")
#     raise

# # Create DataLoaders
# data_loaders = create_data_loaders(df, batch_size=32, image_size=(224, 224), random_seed=42, num_workers=0)

# Train EfficientNet-B0
efficientnet = get_efficientnet_b0(num_classes=3)
efficientnet, eff_history = train_model(
    efficientnet,
    model_name="EfficientNetB0",
    train_loader=data_loaders['train'],
    val_loader=data_loaders['validation'],
    test_loader=data_loaders['test'],
    num_epochs=50,
    lr=1e-3,
    patience=10,
    save_dir=save_dir
)



2025-06-01 19:25:15,258 - INFO - EfficientNetB0 moved to cuda
EfficientNetB0 Epoch 1/50 [Train]: 100%|██████████| 93/93 [00:28<00:00,  3.25it/s, loss=0.363]
EfficientNetB0 Epoch 1/50 [Val]: 100%|██████████| 19/19 [00:01<00:00, 11.37it/s, accuracy=0.917, loss=0.101]
2025-06-01 19:25:45,588 - INFO - EfficientNetB0 Epoch 1/50 - Train Loss: 0.5678
2025-06-01 19:25:45,588 - INFO - Validation Loss: 0.5976, Accuracy: 0.7917, F1: 0.8694
2025-06-01 19:25:46,144 - INFO - Saved best EfficientNetB0 model
EfficientNetB0 Epoch 2/50 [Train]: 100%|██████████| 93/93 [00:27<00:00,  3.39it/s, loss=0.376]
EfficientNetB0 Epoch 2/50 [Val]: 100%|██████████| 19/19 [00:01<00:00, 12.06it/s, accuracy=0.875, loss=0.38]  
2025-06-01 19:26:15,126 - INFO - EfficientNetB0 Epoch 2/50 - Train Loss: 0.4279
2025-06-01 19:26:15,126 - INFO - Validation Loss: 0.3196, Accuracy: 0.8917, F1: 0.9386
2025-06-01 19:26:15,683 - INFO - Saved best EfficientNetB0 model
EfficientNetB0 Epoch 3/50 [Train]: 100%|██████████| 93/93 [00:27<

# Mobilenet_v3 Model

In [25]:
# Train MobileNetV3-Small
mobilenet = get_mobilenet_v3_small(num_classes=3)
mobilenet, mob_history = train_model(
    mobilenet,
    model_name="MobileNetV3Small",
    train_loader=data_loaders['train'],
    val_loader=data_loaders['validation'],
    test_loader=data_loaders['test'],
    num_epochs=50,
    lr=1e-3,
    patience=10,
    save_dir=save_dir
)

Downloading: "https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth" to C:\Users\alira/.cache\torch\hub\checkpoints\mobilenet_v3_small-047dcff4.pth
100%|██████████| 9.83M/9.83M [00:04<00:00, 2.37MB/s]
2025-06-01 19:38:20,619 - INFO - MobileNetV3Small moved to cuda
MobileNetV3Small Epoch 1/50 [Train]: 100%|██████████| 93/93 [00:21<00:00,  4.35it/s, loss=0.726]
MobileNetV3Small Epoch 1/50 [Val]: 100%|██████████| 19/19 [00:01<00:00, 14.39it/s, accuracy=0.667, loss=1.21] 
2025-06-01 19:38:43,324 - INFO - MobileNetV3Small Epoch 1/50 - Train Loss: 0.5407
2025-06-01 19:38:43,324 - INFO - Validation Loss: 0.6395, Accuracy: 0.8133, F1: 0.8781
2025-06-01 19:38:44,420 - INFO - Saved best MobileNetV3Small model
MobileNetV3Small Epoch 2/50 [Train]: 100%|██████████| 93/93 [00:21<00:00,  4.29it/s, loss=0.322]
MobileNetV3Small Epoch 2/50 [Val]: 100%|██████████| 19/19 [00:01<00:00, 14.24it/s, accuracy=0.917, loss=0.192]
2025-06-01 19:39:07,437 - INFO - MobileNetV3Small Epoch 2/50 - Train 