In [2]:
import os
import numpy as np
from PIL import Image
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, precision_recall_curve, auc, confusion_matrix
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from torchvision.transforms import Compose, ToTensor, Resize, Normalize

In [3]:
class EyeDataset(Dataset):
    """Map-style dataset pairing image files with their class labels.

    Args:
        img_paths: sequence of image file paths.
        img_labels: sequence of labels, parallel to ``img_paths``.
        transform: optional callable applied to each RGB PIL image; skipped when falsy.
    """

    def __init__(self, img_paths, img_labels, transform=None):
        self.img_paths, self.img_labels, self.transform = img_paths, img_labels, transform

    def __len__(self):
        # The number of samples is defined by the path list.
        return len(self.img_paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` as ``(image, label)``, converting the image to RGB first."""
        # convert() returns a new image, so the opened file can be closed right away.
        with Image.open(self.img_paths[idx]) as raw_image:
            rgb_image = raw_image.convert('RGB')
        target = self.img_labels[idx]

        if not self.transform:
            return rgb_image, target
        return self.transform(rgb_image), target

In [4]:
def load_data(img_dir):
    """Collect image paths and binary labels from a flat directory.

    Files are labelled by their (case-sensitive) filename suffix:
    ``*_bacterial.JPG`` -> 0 and ``*_fungal.JPG`` -> 1. Every other file is ignored.

    Args:
        img_dir: directory containing the image files.

    Returns:
        ``(img_paths, img_labels)``: parallel lists of file paths and int labels,
        ordered by filename.
    """
    suffix_to_label = (('_bacterial.JPG', 0), ('_fungal.JPG', 1))
    img_paths = []
    img_labels = []

    # os.listdir() order is arbitrary and filesystem-dependent; sorting makes the
    # seeded train/val/test split built from these lists reproducible.
    for img_file in sorted(os.listdir(img_dir)):
        for suffix, label in suffix_to_label:
            if img_file.endswith(suffix):
                img_paths.append(os.path.join(img_dir, img_file))
                img_labels.append(label)
                break

    print(f"Loaded {len(img_paths)} images.")  # Debugging line
    return img_paths, img_labels


In [5]:
def get_dataloaders(img_dir, batch_size, num_workers=4, seed=42):
    """Build stratified train/val/test DataLoaders for the eye-image dataset.

    The data is split roughly 72% / 8% / 20% (train / val / test), stratified by
    label: 28% is held out first, then 0.7143 (~5/7) of that hold-out becomes the
    test set and the rest the validation set.

    Args:
        img_dir: directory passed to ``load_data``.
        batch_size: batch size used by all three loaders.
        num_workers: worker processes per DataLoader (default 4). Set to 0 if
            worker processes cannot import classes defined in the notebook (this
            can happen with spawn-based multiprocessing, e.g. on Windows/macOS).
        seed: ``random_state`` for both ``train_test_split`` calls (default 42).

    Returns:
        ``(train_loader, val_loader, test_loader)``; only the train loader shuffles.

    Raises:
        ValueError: if no images were loaded.
    """
    img_paths, img_labels = load_data(img_dir)

    # Verify the dataset lengths
    print(f"Total images: {len(img_paths)}")
    print(f"Total labels: {len(img_labels)}")

    if not img_paths or not img_labels:
        raise ValueError("No images were loaded. Check the directory and filenames.")

    # Split the data: 28% hold-out, then ~5/7 of the hold-out goes to test.
    train_paths, temp_paths, train_labels, temp_labels = train_test_split(
        img_paths, img_labels, test_size=0.28, stratify=img_labels, random_state=seed
    )
    val_paths, test_paths, val_labels, test_labels = train_test_split(
        temp_paths, temp_labels, test_size=0.7143, stratify=temp_labels, random_state=seed
    )

    # Verify the split lengths
    print(f"Train set size: {len(train_paths)}")
    print(f"Validation set size: {len(val_paths)}")
    print(f"Test set size: {len(test_paths)}")

    # Fixed-size resize plus the standard ImageNet per-channel mean/std.
    transform = Compose([
        Resize((300, 300)),
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    train_dataset = EyeDataset(train_paths, train_labels, transform=transform)
    val_dataset = EyeDataset(val_paths, val_labels, transform=transform)
    test_dataset = EyeDataset(test_paths, test_labels, transform=transform)

    loader_kwargs = {'batch_size': batch_size, 'num_workers': num_workers}
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

    return train_loader, val_loader, test_loader

In [6]:
class CNNModel(nn.Module):
    """Small convolutional binary classifier that emits one logit per image.

    Architecture:
      * ``features``: three stages of Conv2d -> BatchNorm2d -> ReLU -> Dropout(0.1)
        with 10, 20 and 30 output channels and kernel sizes 5, 10 and 20;
      * ``global_avg_pool``: adaptive average pool to 1x1;
      * ``classifier``: MLP 30 -> 15 -> 5 -> 1 with ReLU + Dropout(0.1) between layers.

    NOTE(review): there is no spatial downsampling before the global pool, so every
    convolution runs at full input resolution (300x300 as produced by
    get_dataloaders). The 20x20 conv is very expensive there, which is probably why
    CPU training is so slow.
    """

    def __init__(self):
        super().__init__()

        def conv_stage(in_channels, out_channels, kernel):
            # padding = kernel // 2 keeps the spatial size for odd kernels and
            # grows it by one pixel for the even kernels (10 and 20) used here.
            return [
                nn.Conv2d(in_channels, out_channels, kernel_size=kernel, padding=kernel // 2),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
                nn.Dropout(0.1),
            ]

        stage_specs = [(3, 10, 5), (10, 20, 10), (20, 30, 20)]
        self.features = nn.Sequential(
            *[layer for spec in stage_specs for layer in conv_stage(*spec)]
        )
        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))

        head_layers = []
        widths = [30, 15, 5]
        for width_in, width_out in zip(widths, widths[1:]):
            head_layers += [nn.Linear(width_in, width_out), nn.ReLU(inplace=True), nn.Dropout(0.1)]
        head_layers.append(nn.Linear(5, 1))
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        # (N, 3, H, W) -> (N, 30, h, w) -> (N, 30, 1, 1) -> (N, 30) -> (N, 1)
        pooled = self.global_avg_pool(self.features(x))
        return self.classifier(pooled.flatten(start_dim=1))

In [7]:
def train_model(model, criterion, optimizer, dataloaders, num_epochs, device):
    best_model_wts = model.state_dict()
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0
            all_labels = []
            all_preds = []
            all_probs = []

            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device).float().view(-1)

                # Zero the parameter gradients
                optimizer.zero_grad()

                # Forward
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs).view(-1)
                    probs = torch.sigmoid(outputs)
                    preds = probs.round()
                    criterion = nn.BCEWithLogitsLoss()
                    loss = criterion(outputs, labels)

                    # Backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # Statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
                all_labels.extend(labels.cpu().detach().numpy())
                all_preds.extend(preds.cpu().detach().numpy())
                all_probs.extend(probs.cpu().detach().numpy())

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

            # Compute metrics for validation phase
            if phase == 'val':
                cm = confusion_matrix(all_labels, all_preds)
                TP = cm[1, 1]
                TN = cm[0, 0]
                FP = cm[0, 1]
                FN = cm[1, 0]

                sensitivity = TP / (TP + FN) if (TP + FN) > 0 else 0
                specificity = TN / (TN + FP) if (TN + FP) > 0 else 0
                precision = TP / (TP + FP) if TP + FP > 0 else 0
                f1_score = 2 * (precision * sensitivity) / (precision + sensitivity) if precision + sensitivity > 0 else 0
                roc_auc = roc_auc_score(all_labels, all_probs)
                precision_vals, recall_vals, _ = precision_recall_curve(all_labels, all_probs)
                pr_auc = auc(recall_vals, precision_vals)

                print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
                print(f'Sensitivity: {sensitivity:.4f}, Specificity: {specificity:.4f}')
                print(f'Precision: {precision:.4f}, F1-Score: {f1_score:.4f}')
                print(f'ROC-AUC: {roc_auc:.4f}, PR-AUC: {pr_auc:.4f}')

                # Deep copy the model
                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_model_wts = model.state_dict()

        print()

    print(f'Best val Acc: {best_acc:.4f}')

    # Load best model weights
    model.load_state_dict(best_model_wts)
    return model

In [8]:
def evaluate_model(model, dataloader, device):
    model.eval()
    running_corrects = 0
    all_labels = []
    all_preds = []
    all_probs = []

    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device).float()

            outputs = model(inputs).view(-1)
            probs = torch.sigmoid(outputs)
            preds = probs.round()

            running_corrects += torch.sum(preds == labels.data)
            all_labels.extend(labels.cpu().detach().numpy())
            all_preds.extend(preds.cpu().detach().numpy())
            all_probs.extend(probs.cpu().detach().numpy())

    accuracy = running_corrects.double() / len(dataloader.dataset)
    cm = confusion_matrix(all_labels, all_preds)
    TP = cm[1, 1]
    TN = cm[0, 0]
    FP = cm[0, 1]
    FN = cm[1, 0]

    sensitivity = TP / (TP + FN)
    specificity = TN / (TN + FP)
    precision = TP / (TP + FP) if TP + FP > 0 else 0
    f1_score = 2 * (precision * sensitivity) / (precision + sensitivity) if precision + sensitivity > 0 else 0
    roc_auc = roc_auc_score(all_labels, all_probs)
    precision_vals, recall_vals, _ = precision_recall_curve(all_labels, all_probs)
    pr_auc = auc(recall_vals, precision_vals)

    print(f'Test Accuracy: {accuracy:.4f}')
    print(f'Sensitivity: {sensitivity:.4f}, Specificity: {specificity:.4f}')
    print(f'Precision: {precision:.4f}, F1-Score: {f1_score:.4f}')
    print(f'ROC-AUC: {roc_auc:.4f}, PR-AUC: {pr_auc:.4f}')

In [12]:
def main():
    img_dir = 'drive/MyDrive/bct_fng'  # Replace with your directory path
    batch_size = 15
    num_epochs = 150
    learning_rate = 0.001

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Get data loaders
    train_loader, val_loader, test_loader = get_dataloaders(img_dir, batch_size)
    dataloaders = {'train': train_loader, 'val': val_loader}

    # Initialize the model, loss function, and optimizer
    model = CNNModel().to(device)
    class_weights = torch.tensor([1.5, 1.0]).to(device)  # Adjust as per class weights
    criterion = nn.BCEWithLogitsLoss(pos_weight=class_weights)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Train the model
    model = train_model(model, criterion, optimizer, dataloaders, num_epochs, device)

    # Evaluate the model on the test set
    evaluate_model(model, test_loader, device)

In [13]:
if __name__ == "__main__":
    # Inside a Jupyter/IPython kernel __name__ is "__main__", so running this
    # cell always launches the full train + evaluate pipeline defined in main().
    main()

Using device: cpu
Loaded 6528 images.
Total images: 6528
Total labels: 6528
Train set size: 4700
Validation set size: 522
Test set size: 1306
Epoch 0/149
----------


KeyboardInterrupt: 