<a href="https://colab.research.google.com/github/pariyamd/CPATH_TTA/blob/main/2_Validation_TvN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# !pip install staintools

In [2]:
# !pip install spams

In [1]:
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from torch.optim import Adam
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from tqdm import tqdm
import os
import numpy as np
from PIL import Image, ImageOps
import staintools
from statistics import median

In [4]:
# from google.colab import drive
# drive.mount('/content/drive')

In [5]:
# !unzip "/content/drive/MyDrive/COMP499/Project/Validation data/val_dataset_2_norm.zip" -d "/content/drive/MyDrive/COMP499/Project/Validation data/Extracted_Norm"
# !unzip "/content/drive/MyDrive/COMP499/Project/Validation data/val_dataset_2_tu.zip" -d "/content/drive/MyDrive/COMP499/Project/Validation data/Extracted_Tu"

### **SET PARAMETERS**

In [6]:
###Path to directory with models (one or several models to be tested)
model_dir = 'Models'
###DIRECTORY WITH IMAGES
#Tumor
base_dir_tu = 'Validation data/Extracted_Tu/tu'
#Benign
base_dir_norm = 'Validation data/Extracted_Norm/norm'
###OUTPUT DIRECTORY FOR RESULT FILES
result_dir = 'Validation data/Validation Result'

## **GENERATE LIST AND NAMES**




In [7]:
###GENERATE LIST OF MODELS (if several models are tested)
model_names = sorted(os.listdir(model_dir))
###MODEL PATCH SIZES: define the patch size to use within models
#here for example two models in list, each working with 350px patches
m_p_s_list = [350, 350]

## **INITIALIZE STAIN NORMALIZER**

In [14]:
#Standartization image
st = staintools.read_image('/content/drive/MyDrive/COMP499/Project/Validation data/standard_he_stain_small.jpg')
standardizer = staintools.LuminosityStandardizer.standardize(st)
#Inititate StainNormalizer "macenko"
stain_norm = staintools.StainNormalizer(method='macenko')
#Read Hematoxylin/Eosin staining schema from Standartization image
stain_norm.fit(st)

In [21]:
def process_and_standardize_image(image_path, stain_norm):
    # Load, optionally standardize brightness, and apply stain normalization
    image = staintools.read_image(image_path)
    image = staintools.LuminosityStandardizer.standardize(image)
    normalized_image = stain_norm.transform(image)
    return normalized_image

## **FUNCTIONS**

In [31]:
import torch
import torchvision.transforms.functional as TF

#Implementation of the Strategy C8 (derivates of the main image, s. Methods)
#as a function
#As input: native version of the patch
def gateway_median(patch):
    # Ensure the input patch is a PyTorch tensor
    if not isinstance(patch, torch.Tensor):
        patch = TF.to_tensor(patch)

    # Ensure patch is in the correct format
    if patch.dim() == 2:  # Grayscale, add a channel dimension
        patch = patch.unsqueeze(0)
    elif patch.dim() == 4:  # Batch dimension present, remove it
        patch = patch.squeeze(0)

    pred_list = []

    # Process patch and its derivatives
    # Original and rotated patches
    pred_list.append(pred(patch.unsqueeze(0)))  # Original
    for angle in [90, 180, 270]:
        rotated_patch = TF.rotate(patch, angle)
        pred_list.append(pred(rotated_patch.unsqueeze(0)))

    # Flipped and rotated patches
    r90 = TF.rotate(patch, 90)
    r270 = TF.rotate(patch, 270)
    r90_VF = TF.vflip(r90)
    r270_VF = TF.vflip(r270)
    pred_list.append(pred(r90_VF.unsqueeze(0)))
    pred_list.append(pred(r270_VF.unsqueeze(0)))

    # Vertically and horizontally flipped patches
    VF = TF.vflip(patch)
    HF = TF.hflip(patch)
    pred_list.append(pred(VF.unsqueeze(0)))
    pred_list.append(pred(HF.unsqueeze(0)))

    # Convert numpy arrays in pred_list to tensors before stacking
    pred_list_tensors = [torch.tensor(item, dtype=torch.float32) for item in pred_list]  # Conversion

    # Stack the predictions for median calculation
    pred_stack = torch.stack(pred_list_tensors)

    # Calculate the median of predictions
    preds_med, _ = torch.median(pred_stack, dim=0)

    # Convert the tensor of medians back to a numpy array if necessary
    preds_med = preds_med.squeeze().cpu().numpy()  # Ensure to move to CPU if it was on GPU

    return preds_med

In [13]:
#Function for generation of prediction for single patches (used in C8)
def pred(patch):
    # Check if the input is a PIL Image and convert it to tensor if necessary
    if not isinstance(patch, torch.Tensor):
        patch = TF.to_tensor(patch)

    # Normalize the patch
    patch = TF.normalize(patch, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    # Add a batch dimension if it's missing
    if patch.dim() == 3:
        patch = patch.unsqueeze(0)

    # Ensure the patch is on the same device as the model
    patch = patch.to(next(model.parameters()).device)

    # Perform the prediction
    with torch.no_grad():
        preds = model(patch)

    # Convert predictions to probabilities using softmax
    preds = torch.softmax(preds, dim=1)

    # Convert the predictions to a numpy array and return
    return preds.cpu().numpy()

In [28]:
#Loop for analysis of patches from validation dataset with "tumor" label
#as function
def processor_tu(m_p_s):
    global model
    global output_tu_C1, output_tu_C8

    work_dir = base_dir_tu

    model.eval()

    # Define transformations
    transform = transforms.Compose([
        transforms.Resize((m_p_s, m_p_s)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    fnames = sorted(os.listdir(work_dir))

    for fname in fnames:
        filename = os.path.join(work_dir, fname)
        processed_image = process_and_standardize_image(filename, stain_norm)
         # Convert to PIL for compatibility with torchvision transforms
        im_pil = Image.fromarray(processed_image.astype('uint8'), 'RGB')

        # Apply PyTorch transformations
        x = transform(im_pil).unsqueeze(0)  # Add batch dimension
        # Move tensor to CUDA if available
        if torch.cuda.is_available():
            x = x.cuda()

        # Make a prediction
        with torch.no_grad():
            preds = model(x).squeeze().cpu().numpy()  # Ensure to move the tensor back to CPU if necessary

        pr_1, pr_2, pr_3 = np.round(preds, 3)
        output = f"{fname}\t{pr_1}\t{pr_2}\t{pr_3}\n"

        # Write down output of C1
        with open(output_tu_C1, "a+") as results:
            results.write(output)

        # Additional analysis using C8 strategy
        if 0.2 < preds[2] < 0.5:
            print(f"{fname} was misclassified, but trying C8")
            # Ensure 'gateway_median' is adapted for PyTorch tensors and correctly processes 'processed_image'
            im_pil_for_c8 = Image.fromarray(processed_image.astype('uint8'), 'RGB')  # Convert to PIL for compatibility
            preds_C8 = gateway_median(im_pil_for_c8)

            output_C8 = f"{fname}\t{preds_C8[0]}\t{preds_C8[1]}\t{preds_C8[2]}\n"
            with open(output_tu_C8, "a+") as results:
                results.write(output_C8)

            if preds_C8[2] > 0.5:
                print(f"{fname} was reclassified")
            else:
                print(f"{fname} was not reclassified")
        else:
            status = "is a tumor" if preds[2] >= 0.5 else "was misclassified"
            print(f"{fname} {status}")

In [29]:
#Loop for analysis of patches from validation dataset with "benign" label
#as function
#Analogous to tumor processor above
def processor_norm(m_p_s):
    global model, output_n_C1, output_n_C8

    # Define the working directory for "benign" labeled images
    work_dir = base_dir_norm

    # Ensure the model is in evaluation mode
    model.eval()

    # Define the transformation pipeline for the images
    transform = transforms.Compose([
        transforms.Resize((m_p_s, m_p_s)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    fnames = sorted(os.listdir(work_dir))

    # Process each file in the directory
    for fname in fnames:
        filename = os.path.join(work_dir, fname)
        processed_image = process_and_standardize_image(filename, stain_norm)
        im_pil = Image.fromarray(processed_image.astype('uint8'), 'RGB')  # Convert to PIL image for transforms

        x = transform(im_pil).unsqueeze(0)  # Apply transformations and add batch dimension

        # Move tensor to CUDA if available, otherwise use CPU
        x = x.cuda() if torch.cuda.is_available() else x

        with torch.no_grad():
            preds = model(x).squeeze().cpu().numpy()  # Prediction

        pr_1, pr_2, pr_3 = np.round(preds, 3)

        output = f"{fname}\t{pr_1}\t{pr_2}\t{pr_3}\n"
        with open(output_n_C1, "a+") as results:
            results.write(output)

        # Additional C8 strategy analysis if needed
        if preds[2] > 0.5:
            print(f"{fname} was misclassified, but trying C8")
            preds_C8 = gateway_median(processed_image)  # Ensure this is the correct format for gateway_median

            output_C8 = f"{fname}\t{preds_C8[0]}\t{preds_C8[1]}\t{preds_C8[2]}\n"
            with open(output_n_C8, "a+") as results:
                results.write(output_C8)

            if preds_C8[2] < 0.5:
                print(f"{fname} was reclassified")
            else:
                print(f"{fname} was not reclassified")
        else:
            status = "is normal" if preds[2] < 0.5 else "was misclassified"
            print(f"{fname} {status}")

## **MAIN LOOP**

In [32]:
import os

i = 0
for model_name in model_names:
    print("Loading model: ", model_name, " ...")
    path_model = os.path.join(model_dir, model_name)

    # Skip if the model path is not a file (e.g., it's a directory)
    if not os.path.isfile(path_model):
        print(f"Skipped {model_name} since it's not a file.")
        continue  # Skip the rest of the loop for this iteration

    model = torch.load(path_model)  # Load the model
    model = model.cuda()

    # Create paths to results files
    output_tu_C1 = os.path.join(result_dir, f"{model_name}__C1__tu.txt")
    output_n_C1 = os.path.join(result_dir, f"{model_name}__C1__norm.txt")
    output_tu_C8 = os.path.join(result_dir, f"{model_name}__C8__tu.txt")
    output_n_C8 = os.path.join(result_dir, f"{model_name}__C8__norm.txt")

    # Start analysis
    processor_tu(m_p_s_list[i])
    processor_norm(m_p_s_list[i])

    # Increment of i for m_p_s
    i += 1
    print("Ready! Going to the next model.")  # feedback


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
norm.8598.jpg is normal
norm.8599.jpg is normal
norm.86.jpg is normal
norm.860.jpg is normal
norm.8600.jpg is normal
norm.8601.jpg is normal
norm.8602.jpg was misclassified, but trying C8
norm.8602.jpg was not reclassified
norm.8603.jpg is normal
norm.8604.jpg is normal
norm.8605.jpg is normal
norm.8606.jpg is normal
norm.8607.jpg is normal
norm.8608.jpg is normal
norm.8609.jpg is normal
norm.861.jpg is normal
norm.8610.jpg is normal
norm.8611.jpg was misclassified, but trying C8
norm.8611.jpg was reclassified
norm.8612.jpg is normal
norm.8613.jpg is normal
norm.8614.jpg is normal
norm.8615.jpg is normal
norm.8616.jpg is normal
norm.8617.jpg was misclassified, but trying C8
norm.8617.jpg was not reclassified
norm.8618.jpg is normal
norm.8619.jpg was misclassified, but trying C8
norm.8619.jpg was reclassified
norm.862.jpg is normal
norm.8620.jpg is normal
norm.8621.jpg is normal
norm.8622.jpg is normal
norm.8623.jpg is nor

KeyboardInterrupt: 