In [1]:
import os

# Point torch.hub's cache (pretrained weights) at the project folder so
# downloads land next to the dataset instead of the user profile.
_torch_home = r"O:\O drive\AI\my project\medical image projects\Sperm Morphology Image Data Set (SMIDS)"
os.environ["TORCH_HOME"] = _torch_home

import torch

# Sanity check: the hub directory should now sit under TORCH_HOME.
for _path in (torch.hub.get_dir(), os.environ["TORCH_HOME"]):
    print(_path)


O:\O drive\AI\my project\medical image projects\Sperm Morphology Image Data Set (SMIDS)\hub
O:\O drive\AI\my project\medical image projects\Sperm Morphology Image Data Set (SMIDS)


In [2]:

import warnings

# The CUDA caching allocator reads PYTORCH_CUDA_ALLOC_CONF when it is first
# initialised, so set it before any CUDA work happens in this kernel. If CUDA
# is already initialised (e.g. this cell is re-run), the new value has no
# effect until the kernel is restarted.
if torch.cuda.is_initialized():
    warnings.warn(
        "CUDA is already initialized; PYTORCH_CUDA_ALLOC_CONF takes effect "
        "only after a kernel restart."
    )
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

# Release unused cached GPU memory held by this process (e.g. from earlier runs).
torch.cuda.empty_cache()

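# Sperm Morphology Image Data Set (SMIDS)

# Three-class image classification (Abnormal_Sperm, Non-Sperm, Normal_Sperm)
# by fine-tuning an ImageNet-pretrained ResNet18.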

In [31]:
import os
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report, roc_auc_score
from sklearn.utils import resample
import random
import logging
from torchvision import models
from torchvision.models import ResNet18_Weights

# --- Reproducibility ---------------------------------------------------------
# Seed torch, NumPy and the stdlib ``random`` module with the same value.
SEED = 42
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(SEED)

# --- Logging -----------------------------------------------------------------
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

# --- Image decoding ----------------------------------------------------------
# Let PIL load truncated/corrupted image files instead of raising on them.
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

# --- Compute device ----------------------------------------------------------
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
logger.info(f"Using device: {device}")

# Step 1: Create DataFrames
def create_dataframes(root_dir, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15, random_state=42):
    """
    Create stratified train/val/test DataFrames from the SMIDS folder layout.

    Each class is expected in a sub-directory of ``root_dir`` named after the
    class; every file there with a recognised image extension becomes one row.

    Args:
        root_dir (str): Root directory of the SMIDS dataset.
        train_ratio, val_ratio, test_ratio (float): Split fractions. Each must be
            positive and together they must sum to 1.
        random_state (int): Seed for both stratified splits (default 42).

    Returns:
        tuple: (train_df, val_df, test_df, class_to_idx). Each DataFrame has the
        columns ``file_path``, ``label`` and ``class_name`` and a fresh RangeIndex.

    Raises:
        ValueError: If the split ratios are invalid or no images are found.
    """
    ratios = (train_ratio, val_ratio, test_ratio)
    # Fail fast: the second split only uses the relative size of val vs. test,
    # so ratios that do not sum to 1 would otherwise be silently rescaled.
    if min(ratios) <= 0 or abs(sum(ratios) - 1.0) > 1e-6:
        raise ValueError(f"Split ratios must be positive and sum to 1, got {ratios}")

    valid_extensions = {'.png', '.jpg', '.jpeg', '.tif', '.tiff', '.bmp'}
    class_names = ['Abnormal_Sperm', 'Non-Sperm', 'Normal_Sperm']
    class_to_idx = {name: idx for idx, name in enumerate(class_names)}

    # Verify directories (logged for diagnostics only).
    actual_dirs = sorted(
        d for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))
    )
    logger.info(f"Directories found in {root_dir}: {actual_dirs}")

    records = []
    for class_name in class_names:
        class_dir = os.path.join(root_dir, class_name)
        if not os.path.isdir(class_dir):
            logger.warning(f"Directory {class_dir} does not exist!")
            continue
        label = class_to_idx[class_name]
        # os.listdir order is filesystem-dependent; sorting keeps the seeded
        # splits below reproducible across machines.
        for file in sorted(os.listdir(class_dir)):
            if os.path.splitext(file)[1].lower() in valid_extensions:
                records.append({
                    "file_path": os.path.join(class_dir, file),
                    "label": label,
                    "class_name": class_name,
                })

    df = pd.DataFrame(records)
    if df.empty:
        logger.error("No images found in dataset!")
        raise ValueError("No valid images found.")

    # Stratified split: train vs. remainder, then remainder -> val / test.
    train_df, temp_df = train_test_split(
        df, train_size=train_ratio, stratify=df['label'], random_state=random_state
    )
    # Fraction of the held-out remainder that goes to validation.
    val_ratio_adj = val_ratio / (val_ratio + test_ratio)
    val_df, test_df = train_test_split(
        temp_df, train_size=val_ratio_adj, stratify=temp_df['label'], random_state=random_state
    )

    train_df, val_df, test_df = (
        split_df.reset_index(drop=True) for split_df in (train_df, val_df, test_df)
    )

    for split_df, split in zip([train_df, val_df, test_df], ["Train", "Validation", "Test"]):
        logger.info(f"{split} Dataset: {len(split_df)} images")
        logger.info(f"Class distribution:\n{split_df['class_name'].value_counts()}")

    return train_df, val_df, test_df, class_to_idx

# Step 2: Resample Train DataFrame
def resample_train_df(train_df, class_to_idx, save_dir, random_state=42):
    """
    Balance the training DataFrame by oversampling minority classes.

    Every original row is kept. Each class smaller than the largest one is
    topped up with rows drawn with replacement from that same class, so every
    class ends with the majority count. The result is shuffled, and a bar chart
    of the new distribution is saved to ``save_dir``.

    Args:
        train_df (pd.DataFrame): Training DataFrame with ``label`` column.
        class_to_idx (dict): Class name to index mapping (used for plot labels).
        save_dir (str): Directory to save plots (created if missing).
        random_state (int): Seed for top-up sampling and shuffling (default 42).

    Returns:
        pd.DataFrame: Balanced, shuffled training DataFrame with a fresh RangeIndex.
    """
    class_counts = train_df['label'].value_counts()
    majority_count = class_counts.max()

    balanced_parts = []
    for label, count in class_counts.items():
        df_class = train_df[train_df['label'] == label]
        balanced_parts.append(df_class)
        deficit = majority_count - count
        if deficit > 0:
            # Draw only the missing rows. Re-drawing the whole class with
            # replacement would silently drop a large share of its unique
            # images (about 1/3 when class sizes are similar).
            balanced_parts.append(
                df_class.sample(n=deficit, replace=True, random_state=random_state)
            )

    train_df_balanced = (
        pd.concat(balanced_parts)
        .sample(frac=1, random_state=random_state)
        .reset_index(drop=True)
    )

    logger.info("New class distribution after oversampling:")
    logger.info(f"{train_df_balanced['label'].value_counts()}")

    _plot_balanced_distribution(train_df_balanced, class_to_idx, save_dir)
    return train_df_balanced


def _plot_balanced_distribution(train_df_balanced, class_to_idx, save_dir):
    """Save a bar chart of per-class counts to ``save_dir/train_balanced_class_distribution.png``."""
    idx_to_class = {idx: name for name, idx in class_to_idx.items()}
    counts = train_df_balanced['label'].value_counts().sort_index()

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.bar(counts.index, counts.values)
    ax.set_xlabel('Class')
    ax.set_ylabel('Count')
    ax.set_title('Balanced Train Set Class Distribution')
    # Label each bar by its own index so ticks stay correct even if a class is absent.
    ax.set_xticks(counts.index)
    ax.set_xticklabels([idx_to_class.get(i, str(i)) for i in counts.index], rotation=45)

    os.makedirs(save_dir, exist_ok=True)
    fig.savefig(os.path.join(save_dir, "train_balanced_class_distribution.png"))
    plt.close(fig)

# Step 3: Data Transformations
def get_transforms(split):
    """
    Build the Albumentations pipeline for one dataset split.

    All splits resize to 224x224, normalise with ImageNet statistics and
    convert to a tensor. Only ``split == "train"`` adds light augmentation
    (horizontal flip, small rotation, brightness/contrast jitter); any other
    value yields the deterministic evaluation pipeline.

    Args:
        split (str): Dataset split ('train', 'val', 'test').

    Returns:
        A.Compose: Transformation pipeline.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    steps = [A.Resize(224, 224)]
    if split == "train":
        steps += [
            A.HorizontalFlip(p=0.5),
            A.Rotate(limit=15, p=0.5),
            A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=0.5),
        ]
    steps += [
        A.Normalize(mean=imagenet_mean, std=imagenet_std),
        ToTensorV2(),
    ]
    return A.Compose(steps)

class AlbumentationsTransform:
    """
    Adapter that exposes an Albumentations pipeline as a plain ``transform(img)`` callable.

    The input is converted with ``np.array``, clipped to [0, 255], passed to
    the pipeline as the ``image`` keyword, and only the ``'image'`` entry of
    the returned dict is handed back.
    """

    def __init__(self, transform):
        # Wrapped pipeline, called as transform(image=array) -> {'image': ...}.
        self.transform = transform

    def __call__(self, img):
        """Convert ``img`` to an array, clip it to [0, 255] and return the transformed image."""
        pixels = np.clip(np.array(img), 0, 255)
        return self.transform(image=pixels)['image']

# Step 4: Custom Dataset
class DataFrameDataset(Dataset):
    """
    Map-style dataset over a DataFrame with ``file_path``, ``label`` and
    ``class_name`` columns.

    Rows whose file is missing, whose label is not an integer in
    ``range(len(self.classes))``, or whose image cannot be decoded are dropped
    once, at construction time.

    Args:
        df (pd.DataFrame): Source rows; the caller's frame is not modified.
        transform (callable, optional): Applied to the PIL RGB image. Its result
            is checked with ``torch.isnan``/``torch.isinf``, so it must return a tensor.
        classes (list[str], optional): Class names indexed by label. When omitted
            they are read from ``df`` and ordered by label value, so
            ``classes[i]`` names label ``i`` regardless of alphabetical order.
            Pass it explicitly when a split may not contain every label.
    """

    def __init__(self, df, transform=None, classes=None):
        self.df = df.reset_index(drop=True)
        self.transform = transform
        self.classes = list(classes) if classes is not None else self._classes_from_df(self.df)
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
        self.df = self._validate_and_filter_dataset()

    @staticmethod
    def _classes_from_df(df):
        """
        Return class names ordered by their integer label.

        Raises:
            ValueError: If one label is paired with more than one class_name.
        """
        pairs = df[['label', 'class_name']].drop_duplicates()
        if pairs['label'].duplicated().any():
            raise ValueError("Each label must map to exactly one class_name")
        label_to_name = dict(zip(pairs['label'], pairs['class_name']))
        return [label_to_name[label] for label in sorted(label_to_name)]

    def _validate_and_filter_dataset(self):
        """
        Return ``self.df`` restricted to usable rows, with a fresh RangeIndex.

        Raises:
            ValueError: If no row survives filtering.
        """
        valid_rows = []
        num_classes = len(self.classes)
        for idx, (img_path, label) in enumerate(zip(self.df['file_path'], self.df['label'])):
            if not os.path.isfile(img_path):
                logger.warning(f"Invalid file path at index {idx}: {img_path}")
                continue
            if not isinstance(label, (int, np.integer)) or not 0 <= label < num_classes:
                logger.warning(f"Invalid label at index {idx}: {label}")
                continue
            try:
                # The context manager closes the file even when decoding fails.
                with Image.open(img_path) as img:
                    img_array = np.array(img.convert('RGB'))
                if np.isnan(img_array).any() or np.isinf(img_array).any():
                    logger.warning(f"NaN/Inf pixels in image at index {idx}: {img_path}")
                    continue
                if img_array.max() > 255 or img_array.min() < 0:
                    logger.warning(f"Out-of-range pixels in image at index {idx}: {img_path}")
            except Exception as e:
                logger.warning(f"Corrupted image at index {idx}: {img_path}, error: {e}")
                continue
            valid_rows.append(idx)

        if not valid_rows:
            raise ValueError("No valid items in dataset after filtering")
        filtered_df = self.df.iloc[valid_rows].reset_index(drop=True)
        logger.info(f"Filtered out {len(self.df) - len(filtered_df)} invalid rows")
        return filtered_df

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """
        Return ``(image, label_tensor)`` for row ``idx``.

        Returns ``None`` instead when the image cannot be loaded or the
        transformed tensor contains NaN/Inf. PyTorch's default collate cannot
        batch ``None``, so pair this dataset with a collate_fn that drops it.
        """
        row = self.df.iloc[idx]
        img_path = row['file_path']
        label = row['label']
        try:
            # convert() returns an independent image, so the file can be closed here.
            with Image.open(img_path) as raw:
                img = raw.convert('RGB')
        except Exception as e:
            logger.error(f"Error loading image {img_path}: {e}")
            return None
        if self.transform:
            img = self.transform(img)
            if torch.isnan(img).any() or torch.isinf(img).any():
                logger.error(f"NaN/Inf in transformed image at index {idx}: {img_path}")
                return None
        label_tensor = torch.tensor(label, dtype=torch.long)
        return img, label_tensor

# Step 5: Create DataLoaders
def create_data_loaders(train_df, val_df, test_df, batch_size=64):
    """
    Create DataLoaders for the dataset using ImageNet normalization.
    
    Args:
        train_df, val_df, test_df (pd.DataFrame): DataFrames for each split.
        batch_size (int): Batch size.
    
    Returns:
        tuple: (train_loader, val_loader, test_loader, num_classes)
    """
    train_transform = AlbumentationsTransform(get_transforms("train"))
    val_test_transform = AlbumentationsTransform(get_transforms("val"))

    train_dataset = DataFrameDataset(train_df, transform=train_transform)
    val_dataset = DataFrameDataset(val_df, transform=val_test_transform)
    test_dataset = DataFrameDataset(test_df, transform=val_test_transform)

    num_classes = len(train_dataset.classes)

    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True
    )
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True
    )
    test_loader = DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True
    )

    logger.info(f"Number of classes: {num_classes}")
    logger.info(f"Training dataset size: {len(train_loader.dataset)}")
    logger.info(f"Validation dataset size: {len(val_loader.dataset)}")
    logger.info(f"Test dataset size: {len(test_loader.dataset)}")
    return train_loader, val_loader, test_loader, num_classes

# Step 6: Model Loader
def get_model(model_name, num_classes, device):
    """
    Load an ImageNet-pretrained ResNet18 with a fresh classifier head for full fine-tuning.

    Args:
        model_name (str): Must be 'resnet18'.
        num_classes (int): Number of output classes.
        device (torch.device): Device to move the model to.

    Returns:
        nn.Module: Configured model on ``device``.

    Raises:
        ValueError: If ``model_name`` is not supported.
    """
    logger.info(f"Loading model: {model_name} with {num_classes} classes on {device}")
    if model_name != "resnet18":
        raise ValueError(f"Model {model_name} not supported.")

    model = models.resnet18(weights=ResNet18_Weights.DEFAULT)
    # Keep nn.Linear's default weight initialisation for the new head.
    # kaiming_normal_(mode='fan_out', nonlinearity='relu') is a conv-layer
    # recipe: for a Linear weight of shape (num_classes, in_features),
    # fan_out == num_classes, so std = sqrt(2 / num_classes) (~0.82 for 3
    # classes), which yields very large initial logits.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    nn.init.zeros_(model.fc.bias)
    return model.to(device)

# Step 7: Early Stopping
class EarlyStopping:
    """
    Stop training when validation loss stops improving, checkpointing the best weights.

    Each improvement writes ``{model_name_prefix}_epoch_{epoch + 1}.pth`` into
    ``save_dir`` and records its path in ``best_model_path``.

    Args:
        patience (int): Consecutive non-improving epochs before ``early_stop`` is set.
        verbose (bool): Log checkpoint saves and the stop trigger.
        save_dir (str): Checkpoint directory (required; created if missing).
        min_delta (float): Minimum decrease in validation loss that counts as
            an improvement. The default 0.0 means any strict decrease.
    """

    def __init__(self, patience=10, verbose=True, save_dir=None, min_delta=0.0):
        self.patience = patience
        self.verbose = verbose
        self.min_delta = min_delta
        self.save_dir = save_dir
        self.counter = 0
        self.best_loss = float('inf')
        self.best_epoch = None
        # Path of the most recent checkpoint written by save_best_weights.
        self.best_model_path = None
        self.early_stop = False

    def __call__(self, val_loss, epoch, model_weights, model_name_prefix):
        """Record one epoch's validation loss: checkpoint on improvement, otherwise count toward patience."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.best_epoch = epoch
            self.counter = 0
            self.save_best_weights(model_weights, model_name_prefix)
            return

        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True
            if self.verbose:
                logger.info(f"Early stopping triggered after {self.counter} epochs.")

    def save_best_weights(self, model_weights, model_name_prefix):
        """
        Save ``model_weights`` for ``self.best_epoch`` and remember the path.

        Raises:
            ValueError: If no ``save_dir`` was configured.
        """
        if self.save_dir is None:
            raise ValueError("EarlyStopping requires save_dir to checkpoint the best weights")
        os.makedirs(self.save_dir, exist_ok=True)
        model_path = os.path.join(self.save_dir, f"{model_name_prefix}_epoch_{self.best_epoch + 1}.pth")
        torch.save(model_weights, model_path)
        self.best_model_path = model_path
        if self.verbose:
            logger.info(f"Best model weights saved to {model_path}")

# Step 8: Training and Validation
def _train_one_epoch(model, loader, optimizer, criterion, device):
    """
    Run one optimisation pass over ``loader``.

    Batches that are ``(None, None)`` or whose loss is NaN/Inf are skipped.

    Returns:
        tuple[float, float]: (mean loss, accuracy) over the samples actually
        used for an update, or ``(nan, nan)`` if every batch was skipped.
    """
    model.train()
    running_loss = 0.0
    correct_preds = 0
    total_preds = 0

    for batch_idx, (inputs, labels) in enumerate(loader):
        if inputs is None or labels is None:
            logger.debug(f"Skipping invalid training batch {batch_idx}")
            continue

        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()

        outputs = model(inputs)
        # Clamp logits to [-100, 100] as a guard against overflow in the loss.
        outputs = torch.clamp(outputs, min=-100, max=100)
        loss = criterion(outputs, labels)

        if torch.isnan(loss) or torch.isinf(loss):
            logger.error(f"NaN/Inf loss at batch {batch_idx}: loss={loss.item()}")
            logger.debug(f"Input range: {inputs.min().item():.4f} to {inputs.max().item():.4f}")
            logger.debug(f"Output range: {outputs.min().item():.4f} to {outputs.max().item():.4f}")
            continue

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        running_loss += loss.item() * inputs.size(0)
        _, preds = torch.max(outputs, 1)
        correct_preds += (preds == labels).sum().item()
        total_preds += labels.size(0)

        if batch_idx % 10 == 0:
            logger.info(f"Batch {batch_idx}: loss={loss.item():.4f}")

    if total_preds == 0:
        return float('nan'), float('nan')
    return running_loss / total_preds, correct_preds / total_preds


def _evaluate(model, loader, criterion, device):
    """
    Compute (mean loss, accuracy) of ``model`` on ``loader`` without gradients.

    Skips ``(None, None)`` batches. Returns ``(nan, nan)`` if no samples were evaluated.
    """
    model.eval()
    running_loss = 0.0
    correct_preds = 0
    total_preds = 0

    with torch.no_grad():
        for inputs, labels in loader:
            if inputs is None or labels is None:
                continue
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = torch.clamp(model(inputs), min=-100, max=100)
            loss = criterion(outputs, labels)
            running_loss += loss.item() * inputs.size(0)
            _, preds = torch.max(outputs, 1)
            correct_preds += (preds == labels).sum().item()
            total_preds += labels.size(0)

    if total_preds == 0:
        return float('nan'), float('nan')
    return running_loss / total_preds, correct_preds / total_preds


def _plot_history(history, save_dir):
    """Save side-by-side loss and accuracy curves to ``save_dir/metrics.png``."""
    epochs_run = range(1, len(history['train_loss']) + 1)
    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 5))

    ax_loss.plot(epochs_run, history['train_loss'], label='Train Loss')
    ax_loss.plot(epochs_run, history['val_loss'], label='Val Loss')
    ax_loss.set_xlabel('Epoch')
    ax_loss.set_ylabel('Loss')
    ax_loss.legend()

    ax_acc.plot(epochs_run, history['train_acc'], label='Train Acc')
    ax_acc.plot(epochs_run, history['val_acc'], label='Val Acc')
    ax_acc.set_xlabel('Epoch')
    ax_acc.set_ylabel('Accuracy')
    ax_acc.legend()

    fig.tight_layout()
    fig.savefig(os.path.join(save_dir, 'metrics.png'))
    plt.close(fig)


def train_and_validate(model, train_loader, val_loader, optimizer, scheduler, model_name_prefix,
                       epochs=25, device=None, early_stopping=None, save_dir=None):
    """
    Train ``model`` with cross-entropy loss, validating after every epoch.

    Each epoch runs one optimisation pass (gradient-norm clipping at 1.0,
    non-finite-loss batches skipped) and one validation pass. It then steps
    the scheduler on the validation loss and, if given, runs an early-stopping
    check. If early stopping recorded a best epoch, that checkpoint is
    reloaded before returning. Loss/accuracy curves go to ``save_dir/metrics.png``.

    Args:
        model (nn.Module): Model to train (updated in place).
        train_loader, val_loader (DataLoader): Yield ``(inputs, labels)``;
            ``(None, None)`` batches are skipped.
        optimizer (torch.optim.Optimizer): Optimiser over the model's parameters.
        scheduler: Stepped once per epoch with the validation loss
            (a ``ReduceLROnPlateau``-style scheduler).
        model_name_prefix (str): Checkpoint filename prefix used by ``early_stopping``.
        epochs (int): Maximum number of epochs.
        device (torch.device, optional): Defaults to the device of the model's parameters.
        early_stopping (EarlyStopping, optional): Checkpoints and stops on plateau.
        save_dir (str): Directory for the metrics plot (created if missing).

    Returns:
        nn.Module: The trained model (best checkpoint if one was saved).
    """
    if device is None:
        device = next(model.parameters()).device
    os.makedirs(save_dir, exist_ok=True)
    criterion = nn.CrossEntropyLoss()
    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}

    for epoch in range(epochs):
        logger.info(f"Epoch {epoch + 1}/{epochs}")

        # Losses are averaged over samples actually processed, so skipped
        # batches do not bias them downward.
        train_loss, train_acc = _train_one_epoch(model, train_loader, optimizer, criterion, device)
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        logger.info(f"Training Loss: {train_loss:.4f}, Accuracy: {train_acc:.4f}")

        val_loss, val_acc = _evaluate(model, val_loader, criterion, device)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        logger.info(f"Validation Loss: {val_loss:.4f}, Accuracy: {val_acc:.4f}")

        scheduler.step(val_loss)
        if early_stopping is not None:
            early_stopping(val_loss, epoch, model.state_dict(), model_name_prefix)
            if early_stopping.early_stop:
                logger.info(f"Early stopping at epoch {epoch + 1}")
                break

        torch.cuda.empty_cache()

    if early_stopping is not None and early_stopping.best_epoch is not None:
        # Load from the directory EarlyStopping actually wrote to, which may differ from save_dir.
        checkpoint_dir = early_stopping.save_dir or save_dir
        best_model_path = os.path.join(
            checkpoint_dir, f"{model_name_prefix}_epoch_{early_stopping.best_epoch + 1}.pth"
        )
        # The checkpoint is a plain state_dict: map it onto `device` and refuse arbitrary pickled objects.
        model.load_state_dict(torch.load(best_model_path, map_location=device, weights_only=True))
        logger.info(f"Loaded best model from epoch {early_stopping.best_epoch + 1}")

    _plot_history(history, save_dir)
    return model

# Step 9: Testing
def test_model(model, test_loader, device, save_dir):
    """
    Evaluate ``model`` on the test set and write diagnostics to ``save_dir``.

    Logs loss, accuracy and per-class one-vs-rest AUC-ROC. Saves a confusion
    matrix heatmap (``test_confusion_matrix.png``) and a classification report
    (``test_classification_report.txt``).

    Args:
        model (nn.Module): Trained model.
        test_loader (DataLoader): Yields ``(inputs, labels)``; ``(None, None)``
            batches are skipped. ``test_loader.dataset.classes[i]`` must name label ``i``.
        device (torch.device): Device to run inference on.
        save_dir (str): Output directory (created if missing).

    Returns:
        dict: ``{"loss": float, "accuracy": float, "auc": {class_name: float}}``.
        An AUC is NaN when the test labels hold only one value for that class.

    Raises:
        ValueError: If no test samples were evaluated.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()
    class_names = list(test_loader.dataset.classes)
    class_indices = list(range(len(class_names)))
    running_loss = 0.0
    correct_preds = 0
    total_preds = 0
    all_labels, all_preds, all_probs = [], [], []

    with torch.no_grad():
        for inputs, labels in test_loader:
            if inputs is None or labels is None:
                continue
            inputs, labels = inputs.to(device), labels.to(device)
            # Same logit clamp as used during training/validation.
            outputs = torch.clamp(model(inputs), min=-100, max=100)
            loss = criterion(outputs, labels)
            running_loss += loss.item() * inputs.size(0)
            probs = torch.softmax(outputs, dim=1)
            _, preds = torch.max(outputs, 1)
            correct_preds += (preds == labels).sum().item()
            total_preds += labels.size(0)
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())

    if total_preds == 0:
        logger.error("No test samples were evaluated.")
        raise ValueError("Test loader produced no valid samples.")

    # Average over the samples actually evaluated, so skipped batches do not bias the result.
    test_loss = running_loss / total_preds
    test_accuracy = correct_preds / total_preds
    logger.info(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}")

    all_probs = np.array(all_probs)
    all_labels = np.array(all_labels)
    all_preds = np.array(all_preds)

    auc_scores = {}
    for i, class_name in enumerate(class_names):
        binary_labels = (all_labels == i).astype(int)
        if binary_labels.min() == binary_labels.max():
            # roc_auc_score raises when y_true contains a single class.
            logger.warning(f"AUC-ROC undefined for {class_name}: test labels contain only one value")
            auc_scores[class_name] = float('nan')
        else:
            auc_scores[class_name] = roc_auc_score(binary_labels, all_probs[:, i])
    logger.info(f"Test AUC-ROC Scores: {auc_scores}")

    os.makedirs(save_dir, exist_ok=True)

    # Fixing labels keeps rows/columns aligned with class_names even if some
    # class never appears among the labels or the predictions.
    cm = confusion_matrix(all_labels, all_preds, labels=class_indices)
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                xticklabels=class_names, yticklabels=class_names, ax=ax)
    ax.set_title("Test Confusion Matrix")
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    fig.savefig(os.path.join(save_dir, "test_confusion_matrix.png"))
    plt.close(fig)

    report = classification_report(all_labels, all_preds, labels=class_indices,
                                   target_names=class_names, digits=4)
    report_path = os.path.join(save_dir, "test_classification_report.txt")
    with open(report_path, "w", encoding="utf-8") as f:
        f.write(report)
    logger.info(f"Test Classification Report saved to {report_path}")

    return {"loss": test_loss, "accuracy": test_accuracy, "auc": auc_scores}


2025-05-31 14:32:14,253 - INFO - Using device: cuda


In [32]:

# Main Execution
if __name__ == "__main__":
    # --- Paths -----------------------------------------------------------------
    root_dir = r"O:\O drive\AI\my project\medical image projects\Sperm Morphology Image Data Set (SMIDS)\dataset\archive\SMIDS"
    save_dir = r"O:\O drive\AI\my project\medical image projects\Sperm Morphology Image Data Set (SMIDS)\plots"

    # --- Data: stratified split -> oversample train -> loaders -----------------
    train_df, val_df, test_df, class_to_idx = create_dataframes(root_dir)
    train_df_balanced = resample_train_df(train_df, class_to_idx, save_dir)
    loaders = create_data_loaders(train_df_balanced, val_df, test_df, batch_size=64)
    train_loader, val_loader, test_loader, num_classes = loaders

    # --- Model, optimiser, LR schedule, early stopping -------------------------
    model_name_prefix = "resnet18_sperm"
    model = get_model("resnet18", num_classes, device)
    optimizer = optim.AdamW(model.parameters(), lr=1e-5, weight_decay=1e-2)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
    early_stopping = EarlyStopping(patience=10, verbose=True, save_dir=save_dir)

    # --- Train (reloads best checkpoint), then evaluate on the held-out test set
    model = train_and_validate(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=optimizer,
        scheduler=scheduler,
        model_name_prefix=model_name_prefix,
        epochs=25,
        device=device,
        early_stopping=early_stopping,
        save_dir=save_dir,
    )
    test_model(model=model, test_loader=test_loader, device=device, save_dir=save_dir)

2025-05-31 14:32:16,868 - INFO - Directories found in O:\O drive\AI\my project\medical image projects\Sperm Morphology Image Data Set (SMIDS)\dataset\archive\SMIDS: ['Abnormal_Sperm', 'Non-Sperm', 'Normal_Sperm']
2025-05-31 14:32:16,890 - INFO - Train Dataset: 2100 images
2025-05-31 14:32:16,891 - INFO - Class distribution:
class_name
Normal_Sperm      715
Abnormal_Sperm    703
Non-Sperm         682
Name: count, dtype: int64
2025-05-31 14:32:16,891 - INFO - Validation Dataset: 450 images
2025-05-31 14:32:16,893 - INFO - Class distribution:
class_name
Normal_Sperm      153
Abnormal_Sperm    151
Non-Sperm         146
Name: count, dtype: int64
2025-05-31 14:32:16,893 - INFO - Test Dataset: 450 images
2025-05-31 14:32:16,895 - INFO - Class distribution:
class_name
Normal_Sperm      153
Abnormal_Sperm    151
Non-Sperm         146
Name: count, dtype: int64
2025-05-31 14:32:16,903 - INFO - New class distribution after oversampling:
2025-05-31 14:32:16,906 - INFO - label
0    715
1    715
2   