**LOADING DEPENDENCIES**

In [None]:
import os
import sys
import cv2
import json
import math
import torch
import random
import numpy as np
import pandas as pd
import openslide
import seaborn as sns
import tensorflow as tf
import torchvision
import torch.optim as optim
import torch.nn as nn
from tqdm import tqdm
from PIL import Image
from scipy import ndimage
import matplotlib.pyplot as plt
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from tqdm.notebook import tqdm
from sklearn.metrics import confusion_matrix
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from torchvision import transforms, datasets
from matplotlib.patches import Polygon
from skimage.transform import rotate, AffineTransform
from matplotlib import patches
import xml.etree.ElementTree as ET 
from sklearn.model_selection import GroupKFold
from torchvision.transforms import ToTensor
import torchvision.models as models
from collections import defaultdict
from sklearn.metrics import roc_curve, auc, classification_report

**EXTRACTING PATCHES**

In [None]:
# function to return patches along with labels

def get_patch_info(wsi_nos, save_dir,
                   wsi_dir='/kaggle/input/bach-breast-cancer-histology-images/ICIAR2018_BACH_Challenge/ICIAR2018_BACH_Challenge/WSI',
                   patch_size=500, out_size=224, margin=0.1):
    """Tile every annotated region of each WSI into patches, save them as PNGs and collect metadata.

    For each ``<Region>`` in ``<wsi_dir>/<id>.xml`` the vertex bounding box is shrunk by ``margin``
    of its width/height on every side and then tiled with non-overlapping
    ``patch_size`` x ``patch_size`` windows read at level 0. Each window is resized to
    ``out_size`` x ``out_size`` and stored as RGB.

    Args:
        wsi_nos: WSI identifiers; ``<wsi_dir>/<id>.svs`` and ``<wsi_dir>/<id>.xml`` are read.
        save_dir: Output directory for the PNG patches (created if missing).
        wsi_dir: Directory holding the slide and annotation files.
        patch_size: Side length, in level-0 pixels, of each window read from the slide.
        out_size: Side length the saved patch is resized to.
        margin: Fraction of the bounding box ignored on each side.

    Returns:
        Four parallel lists ``(wsi_ids, ann_ids, patch_paths, patch_labels)``, one entry per patch.
        The label is the region's ``Attribute@Value``, or ``Region@Text`` when it has no Attribute.
        Regions without vertices are skipped but still consume an annotation number.
    """
    os.makedirs(save_dir, exist_ok=True)  # Create the directory if it doesn't exist

    # read_region takes level-0 coordinates, so the stride below is only consistent at level 0.
    level = 0

    wsi_ids, ann_ids, patch_paths, patch_labels = [], [], [], []

    for wsi_no in wsi_nos:
        print(wsi_no)

        img_path = os.path.join(wsi_dir, f'{wsi_no}.svs')
        ann_path = os.path.join(wsi_dir, f'{wsi_no}.xml')

        root = ET.parse(ann_path).getroot()
        slide = openslide.OpenSlide(img_path)
        try:
            for ann_no, region in enumerate(root.findall('.//Region'), start=1):
                # Prefer the class stored on the Attribute tag, fall back to the Region text
                attribute = region.find('.//Attribute')
                region_value = attribute.get('Value') if attribute is not None else region.get('Text')

                vertices = region.findall('.//Vertex')
                if not vertices:
                    # No vertices means no bounding box (int(inf) would raise below).
                    print('Annotation', ann_no, 'has no vertices, skipped')
                    continue
                xs = [float(v.get('X')) for v in vertices]
                ys = [float(v.get('Y')) for v in vertices]
                min_x, max_x, min_y, max_y = min(xs), max(xs), min(ys), max(ys)

                # Inner bounding box: ignore `margin` of the box on each side
                box_width, box_height = max_x - min_x, max_y - min_y
                inner_min_x = min_x + margin * box_width
                inner_max_x = max_x - margin * box_width
                inner_min_y = min_y + margin * box_height
                inner_max_y = max_y - margin * box_height

                patch_no = 1
                for y in range(int(inner_min_y), int(inner_max_y), patch_size):
                    for x in range(int(inner_min_x), int(inner_max_x), patch_size):
                        tile = slide.read_region((x, y), level, (patch_size, patch_size))
                        # Resize, then drop alpha. Pixels stay uint8 throughout: a float round-trip
                        # such as x / 255.0 * 255 can land just below an integer and be truncated.
                        patch_image = tile.resize((out_size, out_size)).convert('RGB')

                        patch_path = os.path.join(save_dir, f'{wsi_no}_ann{ann_no}_patch_no{patch_no}.png')
                        patch_image.save(patch_path)

                        patch_paths.append(patch_path)
                        patch_labels.append(region_value)
                        wsi_ids.append(wsi_no)
                        ann_ids.append(ann_no)
                        patch_no += 1

                print('Annotation', ann_no, 'done')
        finally:
            # Release the slide even if reading or saving a patch fails
            slide.close()

    return wsi_ids, ann_ids, patch_paths, patch_labels

In [None]:
# unique ids for annotated wsis: A01 through A10

ids = [f'A{n:02d}' for n in range(1, 11)]

In [None]:
# obtaining patch info for all wsis
# Output folder for the extracted patches (not called `dir`, which would shadow the builtin `dir()`)

PATCH_DIR = '/kaggle/working/patches'
wsi_ids, ann_ids, patch_paths, patch_labels = get_patch_info(ids, PATCH_DIR)

**EXTRACTING NORMAL PATCHES**

In [None]:
# function to return patches (all labelled "Normal") from the extra Normal annotations

def get_normal_patch_info(wsi_nos, save_dir, curr_ann_count,
                          wsi_dir='/kaggle/input/bach-breast-cancer-histology-images/ICIAR2018_BACH_Challenge/ICIAR2018_BACH_Challenge/WSI',
                          ann_dir='/kaggle/input/bach-dataset-patches/Extra Normal Annotations/Extra Normal Annotations',
                          patch_size=500, out_size=224, margin=0.1):
    """Tile the extra Normal regions of each WSI into patches, save them as PNGs and collect metadata.

    Regions are read from ``<ann_dir>/<id>_norm.xml``. Each vertex bounding box is shrunk by
    ``margin`` of its width/height on every side and then tiled with non-overlapping
    ``patch_size`` x ``patch_size`` windows read at level 0 from ``<wsi_dir>/<id>.svs``. Each window
    is resized to ``out_size`` x ``out_size`` and stored as RGB. Every patch is labelled ``"Normal"``.

    Args:
        wsi_nos: WSI identifiers.
        save_dir: Output directory for the PNG patches (created if missing).
        curr_ann_count: Mapping WSI id -> number already used for its annotations; numbering of the
            Normal regions continues from ``curr_ann_count[id] + 1``.
        wsi_dir: Directory holding the ``.svs`` slides.
        ann_dir: Directory holding the ``*_norm.xml`` annotation files.
        patch_size: Side length, in level-0 pixels, of each window read from the slide.
        out_size: Side length the saved patch is resized to.
        margin: Fraction of the bounding box ignored on each side.

    Returns:
        Four parallel lists ``(wsi_ids, ann_ids, patch_paths, patch_labels)``, one entry per patch.
        Regions without vertices are skipped but still consume an annotation number.
    """
    os.makedirs(save_dir, exist_ok=True)  # Create the directory if it doesn't exist

    # read_region takes level-0 coordinates, so the stride below is only consistent at level 0.
    level = 0

    wsi_ids, ann_ids, patch_paths, patch_labels = [], [], [], []

    for wsi_no in wsi_nos:
        print("WSI ", wsi_no)

        img_path = os.path.join(wsi_dir, wsi_no + '.svs')
        ann_path = os.path.join(ann_dir, wsi_no + '_norm.xml')

        root = ET.parse(ann_path).getroot()
        # Filenames embed ann_no, so continuing after the existing count keeps them distinct from
        # earlier patches of the same WSI (assuming curr_ann_count holds that count - confirm with caller).
        first_ann_no = curr_ann_count[wsi_no] + 1

        slide = openslide.OpenSlide(img_path)
        try:
            for ann_no, region in enumerate(root.findall('.//Region'), start=first_ann_no):
                vertices = region.findall('.//Vertex')
                if not vertices:
                    # No vertices means no bounding box (int(inf) would raise below).
                    print('Annotation', ann_no, 'has no vertices, skipped')
                    continue
                # float() accepts both "123" and "123.5" coordinate strings; int() rejects the latter.
                xs = [float(v.get('X')) for v in vertices]
                ys = [float(v.get('Y')) for v in vertices]
                min_x, max_x, min_y, max_y = min(xs), max(xs), min(ys), max(ys)

                # Inner bounding box: ignore `margin` of the box on each side
                box_width, box_height = max_x - min_x, max_y - min_y
                inner_min_x = min_x + margin * box_width
                inner_max_x = max_x - margin * box_width
                inner_min_y = min_y + margin * box_height
                inner_max_y = max_y - margin * box_height

                patch_no = 1
                for y in range(int(inner_min_y), int(inner_max_y), patch_size):
                    for x in range(int(inner_min_x), int(inner_max_x), patch_size):
                        tile = slide.read_region((x, y), level, (patch_size, patch_size))
                        # Resize, then drop alpha; pixels stay uint8 (no truncating float round-trip).
                        patch_image = tile.resize((out_size, out_size)).convert('RGB')

                        patch_path = os.path.join(save_dir, f'{wsi_no}_ann{ann_no}_patch_no{patch_no}.png')
                        patch_image.save(patch_path)

                        patch_paths.append(patch_path)
                        patch_labels.append("Normal")
                        wsi_ids.append(wsi_no)
                        ann_ids.append(ann_no)
                        patch_no += 1

                print('Annotation', ann_no, 'done')
        finally:
            # Release the slide even if reading or saving a patch fails
            slide.close()

    return wsi_ids, ann_ids, patch_paths, patch_labels

In [None]:
# checking for consistent lengths: the four lists are parallel (one entry per patch)

print(len(wsi_ids), len(ann_ids), len(patch_paths), len(patch_labels))
assert len(wsi_ids) == len(ann_ids) == len(patch_paths) == len(patch_labels), "patch metadata lists are out of sync"

In [None]:
# Encode with binary class labels: 0 = 'Benign' or 'Normal', 1 = any other region label.
# (Despite its name, `label_to_onehot` maps each name to a single 0/1 target, not a one-hot vector.)

label_to_onehot = {
    'Benign': 0,
    'Normal': 0
}

# Values that are already 0/1 pass through unchanged, so re-running this cell is a no-op.
# Otherwise the `.get(label, 1)` fallback would turn every already-encoded label into 1.
patch_labels = [label if label in (0, 1) else label_to_onehot.get(label, 1) for label in patch_labels]

**CLASS DISTRIBUTION** 

In [None]:
# class distribution of the encoded patch labels (whole dataset, before any fold split)

def plot_bag_labels(ax, title, labels, class_names=('Benign or Normal', 'Insitu Carcinoma or Invasive Carcinoma')):
    """Draw a bar chart on ``ax`` with the number of labels in each class.

    Args:
        ax: Matplotlib-style axes to draw on.
        title: Axes title.
        labels: Iterable of integer class indices (list, tuple, numpy array, ...).
        class_names: Bar names; the i-th name is the bar for class index i.

    Returns:
        The same ``ax``, for chaining.
    """
    from collections import Counter

    # Counter works on any iterable (numpy arrays have no .count method, unlike lists)
    counts = Counter(labels)
    class_counts = [counts[i] for i in range(len(class_names))]

    ax.bar(list(class_names), class_counts)
    ax.set_title(title)
    ax.set_ylabel('Count')
    return ax

# Plotting the overall distribution of encoded patch labels on a single axes
fig, axs = plt.subplots(figsize=(5, 4))
plot_bag_labels(axs, 'BAG LABELS', patch_labels)
fig.tight_layout()
plt.show()

**CUSTOM DATASETS** 

In [None]:
# Defining the custom dataset class

class CustomDataset(Dataset):
    """Map-style dataset over saved patch image files and their integer labels.

    Each item is ``(image_tensor, label_tensor)`` with the label as ``torch.long``.
    By default the image is loaded as RGB, resized to 224x224, converted with ``ToTensor`` and then
    L2-normalised over ``dim=0`` (the channel axis), i.e. per pixel. Pass ``transform`` to replace
    that default pipeline; it receives the RGB PIL image and must return a tensor.

    NOTE(review): torchvision's pretrained ViT weights are documented with ImageNet mean/std
    normalisation; confirm the per-pixel L2 normalisation of the default pipeline is intended.
    """

    def __init__(self, patches, labels, transform=None):
        assert len(patches) == len(labels), "Mismatch in number of patches and labels"
        self.patches = patches  # image file paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # Check for valid index range (patches and labels have equal length, see __init__)
        if idx >= len(self):
            raise IndexError(f"Index {idx} out of range for dataset with length {len(self)}")

        # The context manager closes the file handle; convert() returns an independent, loaded image.
        with Image.open(self.patches[idx]) as img:
            patch = img.convert("RGB")

        if self.transform is not None:
            patch = self.transform(patch)
        else:
            patch = patch.resize((224, 224))
            patch = ToTensor()(patch)
            patch = nn.functional.normalize(patch, dim=0, p=2)

        label = torch.tensor(self.labels[idx], dtype=torch.long)

        return patch, label

In [None]:
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

**FINAL MODEL** 

In [None]:
class Pipeline(nn.Module):
    """Pretrained torchvision ViT-B/32 backbone followed by a small MLP head.

    ``forward`` returns ``(logits, probabilities)`` where ``probabilities = sigmoid(logits)``;
    the raw logits are what ``BCEWithLogitsLoss`` consumes.

    Args:
        num_classes: Number of output units of the head (1 for binary classification).
    """

    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes

        # Backbone keeps its own classification head, so it emits 1000 values per image,
        # which is the input width of the MLP below.
        self.base = models.vit_b_32(weights='DEFAULT')

        # Fine-tune the whole backbone (trainable by default; made explicit here)
        for param in self.base.parameters():
            param.requires_grad = True

        self.flatten = nn.Flatten()

        self.head = nn.Sequential(
            nn.Linear(1000, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, num_classes),  # honours num_classes (previously always 1)
        )

    def forward(self, X):
        X = self.base(X)
        X = self.flatten(X)
        logits = self.head(X)
        # torch.sigmoid replaces the deprecated F.sigmoid
        return logits, torch.sigmoid(logits)

# Positive-class weight for BCEWithLogitsLoss, allocated directly on the training device
pos_weight = torch.tensor([0.6], device=device)

In [None]:
# Set in the kernel's own environment; a `!export` shell escape only changes a throw-away subshell.
# NOTE: CUDA reads this at context initialisation, so it only takes effect if this runs before the
# first CUDA operation (e.g. move it into the imports cell).
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

**K FOLD CROSS VALIDATION** 

In [None]:
# Setting some hyperparameters

SEED = 42
batch_size = 16
num_epochs = 10
num_folds = 5

# Seed every RNG the training loop draws from (init of the new head layers, dropout,
# DataLoader shuffling) so repeated runs of the cross-validation are reproducible.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Folds are grouped by WSI id in the loop below
group_kfold = GroupKFold(n_splits=num_folds)

# Per-fold results keyed 1..num_folds: per-epoch curves plus last-epoch validation ids/labels/predictions
metrics = {
    fold + 1: {
        'train_loss': [], 'train_accuracy': [], 'train_roc_auc': [],
        'valid_loss': [], 'valid_accuracy': [], 'valid_roc_auc': [],
        'wsi_nos': [], 'ann_nos': [], 'y_pred':[], 'y_true':[]
    } for fold in range(num_folds)
}

In [None]:
# K Fold cross validation loop
# GroupKFold splits by wsi_ids, so no slide contributes patches to both sides of a fold.


def _run_epoch(model, loader, loss_fn, optimizer=None):
    """Run one pass of ``model`` over ``loader``.

    Trains (zero_grad / backward / step) when ``optimizer`` is given; otherwise evaluates with
    gradients disabled and dropout off.

    Returns:
        ``(mean_batch_loss, accuracy_percent, y_true, y_score)``; ``y_true`` and ``y_score`` are
        1-D numpy arrays (labels and sigmoid scores) in loader order.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()
    total_loss, correct, total = 0.0, 0, 0
    y_true, y_score = [], []
    with torch.set_grad_enabled(training):
        for images, labels in tqdm(loader):
            images = images.to(device)
            labels = labels.to(device)
            if training:
                optimizer.zero_grad()
            logits, probs = model(images)
            # One output unit per sample: (B, 1) -> (B,) to line up with `labels`
            logits, probs = logits.squeeze(-1), probs.squeeze(-1)
            loss = loss_fn(logits, labels.float())
            if training:
                loss.backward()
                optimizer.step()
            total_loss += loss.item()
            correct += ((probs >= 0.5).long() == labels).sum().item()
            total += labels.size(0)
            y_true.append(labels.cpu().numpy())
            y_score.append(probs.detach().cpu().numpy())

    mean_loss = total_loss / len(loader) if len(loader) else float('nan')
    accuracy = 100 * correct / total if total else float('nan')
    y_true = np.concatenate(y_true) if y_true else np.array([], dtype=int)
    y_score = np.concatenate(y_score) if y_score else np.array([])
    return mean_loss, accuracy, y_true, y_score


def _roc_auc(y_true, y_score):
    """ROC-AUC of ``y_score`` against binary ``y_true``; NaN when only one class is present."""
    if len(np.unique(y_true)) < 2:
        return float('nan')
    fpr, tpr, _ = roc_curve(y_true, y_score)
    return auc(fpr, tpr)


_EPOCH_METRIC_KEYS = ('train_loss', 'train_accuracy', 'train_roc_auc',
                      'valid_loss', 'valid_accuracy', 'valid_roc_auc')

for fold, (train_idx, test_idx) in enumerate(group_kfold.split(patch_paths, patch_labels, groups=wsi_ids)):
    fold_metrics = metrics[fold + 1]
    # Start each fold's history empty so re-running this cell doesn't append onto a previous run
    # (plot_metrics expects one entry per epoch).
    for key in _EPOCH_METRIC_KEYS:
        fold_metrics[key] = []

    # Splitting patches and labels for this fold (plain lists, no numpy conversion)
    train_patch_paths = [patch_paths[i] for i in train_idx]
    test_patch_paths = [patch_paths[i] for i in test_idx]
    train_patch_labels = [patch_labels[i] for i in train_idx]
    test_patch_labels = [patch_labels[i] for i in test_idx]

    # Per-sample ids of the validation set, aligned with y_true / y_pred stored below
    fold_metrics['wsi_nos'] = [wsi_ids[i] for i in test_idx]
    fold_metrics['ann_nos'] = [ann_ids[i] for i in test_idx]

    train_loader = DataLoader(CustomDataset(train_patch_paths, train_patch_labels), batch_size=batch_size, shuffle=True)
    valid_loader = DataLoader(CustomDataset(test_patch_paths, test_patch_labels), batch_size=batch_size, shuffle=False)

    # Output the fold details
    print(f"\nFold {fold + 1}")
    print(f"Train WSI IDs: {set(np.array(wsi_ids)[train_idx])}")
    print(f"Test WSI IDs: {set(np.array(wsi_ids)[test_idx])}")
    print(f"Train size: {len(train_loader.dataset)}, Test size: {len(valid_loader.dataset)}")

    # Fresh loss, model and optimizer per fold so no weights carry over between folds
    loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    model = Pipeline(1).to(device)
    optimizer = optim.Adam(params=model.parameters(), lr=0.000025)

    y_true_val = np.array([], dtype=int)
    y_pred_val = np.array([], dtype=int)
    for epoch in range(num_epochs):
        train_loss, train_acc, y_true_train, y_scores_train = _run_epoch(model, train_loader, loss_fn, optimizer)
        valid_loss, valid_acc, y_true_val, y_scores_val = _run_epoch(model, valid_loader, loss_fn)
        roc_auc_train = _roc_auc(y_true_train, y_scores_train)
        roc_auc_val = _roc_auc(y_true_val, y_scores_val)

        # Accuracies are percentages; loss and ROC-AUC are not
        print(f'Epoch {epoch + 1}, Train Accuracy: {train_acc:.2f}%, Train Loss: {train_loss:.4f}, '
              f'Train ROC-AUC: {roc_auc_train:.4f}, Val Accuracy: {valid_acc:.2f}%, '
              f'Val Loss: {valid_loss:.4f}, Val ROC-AUC: {roc_auc_val:.4f}')

        # Classification report
        y_pred_val = (y_scores_val >= 0.5).astype(int)
        print("Validation Classification Report:")
        print(classification_report(y_true_val, y_pred_val, zero_division=1))

        # Append metrics for current epoch and fold
        epoch_values = (train_loss, train_acc, roc_auc_train, valid_loss, valid_acc, roc_auc_val)
        for key, value in zip(_EPOCH_METRIC_KEYS, epoch_values):
            fold_metrics[key].append(value)

    # Last-epoch validation labels and predictions (1-D), consumed by the WSI-level metrics cell
    fold_metrics['y_pred'] = y_pred_val
    fold_metrics['y_true'] = y_true_val

In [None]:
# Plotting function for losses and accuracies by fold

def plot_metrics(metrics, num_folds):
    """Plot each per-epoch metric as its own figure, with one line per fold.

    Args:
        metrics: dict keyed by fold number (1..num_folds); each value maps the metric names below
            to a list with one value per completed epoch.
        num_folds: Number of folds to draw.

    The x-axis comes from each fold's own history length, so runs with a different or
    interrupted epoch count still plot instead of raising a length mismatch.
    """
    metric_names = ['train_loss', 'valid_loss', 'train_accuracy', 'valid_accuracy', 'train_roc_auc', 'valid_roc_auc']

    for metric_name in metric_names:
        pretty_name = metric_name.replace('_', ' ').title()
        fig, ax = plt.subplots(figsize=(10, 5))
        for fold in range(1, num_folds + 1):
            fold_metric = metrics[fold][metric_name]
            ax.plot(range(1, len(fold_metric) + 1), fold_metric, label=f'Fold {fold}')
        ax.set_title(f'{pretty_name} over Epochs')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(pretty_name)
        ax.legend()
        plt.show()
        plt.close(fig)  # one figure per metric is created in this loop; free each after display

# One figure per metric, one line per fold
plot_metrics(metrics=metrics, num_folds=num_folds)

**WSI LEVEL METRICS** 

In [None]:
# Function to calculate results on a WSI Level

def calculate_wsi_level_metrics(wsi_ids, ann_ids, ground_labels, predicted_labels):
    """Compute, print and return per-WSI accuracy and per-class F1 for binary 0/1 predictions.

    Args:
        wsi_ids: WSI identifier of each sample.
        ann_ids: Annotation number of each sample (not used in the computation; kept for the call signature).
        ground_labels: True 0/1 label of each sample (scalars or 1-element arrays).
        predicted_labels: Predicted 0/1 label of each sample (scalars or 1-element arrays).

    Returns:
        dict mapping each WSI id to ``(accuracy_percent, f1_class_0, f1_class_1)``.
        Samples whose ground or predicted label is not 0/1 are ignored.
    """
    # Per WSI, confusion counts with class "0" and class "1" each treated as the positive class
    counts = defaultdict(lambda: {"0": {"TP": 0, "FP": 0, "TN": 0, "FN": 0},
                                  "1": {"TP": 0, "FP": 0, "TN": 0, "FN": 0}})

    for wsi, ground, pred in zip(wsi_ids, ground_labels, predicted_labels):
        # Accept scalars as well as 1-element arrays (e.g. rows of an (n, 1) prediction array)
        ground, pred = int(np.squeeze(ground)), int(np.squeeze(pred))
        if ground not in (0, 1) or pred not in (0, 1):
            continue
        c0, c1 = counts[wsi]["0"], counts[wsi]["1"]
        # A correct prediction is a TP for its class and a TN for the other class
        if ground == 1 and pred == 1:
            c1["TP"] += 1
            c0["TN"] += 1
        elif ground == 0 and pred == 1:
            c1["FP"] += 1
            c0["FN"] += 1
        elif ground == 0 and pred == 0:
            c0["TP"] += 1
            c1["TN"] += 1
        else:  # ground == 1 and pred == 0
            c1["FN"] += 1
            c0["FP"] += 1

    wsi_scores = {}
    for wsi, wsi_counts in counts.items():
        # With TN counted, (TP + TN) / total for class 1 is the WSI's binary accuracy
        c1 = wsi_counts["1"]
        total_1 = c1["TP"] + c1["FP"] + c1["TN"] + c1["FN"]
        accuracy_1 = ((c1["TP"] + c1["TN"]) / total_1) * 100 if total_1 > 0 else 0

        f1_scores = {}
        for class_label in ("0", "1"):
            tp, fp, fn = (wsi_counts[class_label][k] for k in ("TP", "FP", "FN"))
            precision = tp / (tp + fp) if (tp + fp) > 0 else 0
            recall = tp / (tp + fn) if (tp + fn) > 0 else 0
            f1_scores[class_label] = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

        wsi_scores[wsi] = (accuracy_1, f1_scores["0"], f1_scores["1"])

    for wsi, (accuracy_1, f1_0, f1_1) in wsi_scores.items():
        print(f"WSI {wsi}: Class 1 Accuracy = {accuracy_1:.2f}%, F1 Score Class 0 = {f1_0:.2f}, F1 Score Class 1 = {f1_1:.2f}")

    return wsi_scores

In [None]:
# Displaying WSI level metrics
# Fold-local names are used so the global `wsi_ids` / `ann_ids` from patch extraction
# (consumed by the cross-validation cell) are not overwritten.

wsi_level_metrics = {}

for fold in range(1, num_folds + 1):
    # Validation-set WSI ids, annotation ids, true labels and last-epoch predictions of this fold
    fold_metrics = metrics[fold]
    print(f"Fold {fold}")
    wsi_level_metrics[fold] = calculate_wsi_level_metrics(
        fold_metrics['wsi_nos'],
        fold_metrics['ann_nos'],
        fold_metrics['y_true'],
        fold_metrics['y_pred'],
    )
