In [None]:
# Colab setup: OpenSlide Python bindings via pip, and the native OpenSlide
# tools via apt (which also installs libopenslide0).
!pip install openslide-python
!apt-get install -y openslide-tools

Collecting openslide-python
  Downloading openslide-python-1.3.1.tar.gz (358 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m359.0/359.0 kB[0m [31m6.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: openslide-python
  Building wheel for openslide-python (setup.py) ... [?25l[?25hdone
  Created wheel for openslide-python: filename=openslide_python-1.3.1-cp310-cp310-linux_x86_64.whl size=33554 sha256=78e8d821ef351592afee82cb61124e87a44ff29e7b788dab25057d73a09f3db1
  Stored in directory: /root/.cache/pip/wheels/79/79/fa/29a0087493c69dff7fd0b70fab5d6771002a531010161d2d97
Successfully built openslide-python
Installing collected packages: openslide-python
Successfully installed openslide-python-1.3.1
Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
The following additional packages will be installed:
  libopenslide0
Suggested packages:
  libt

Slide-level parameters: feature extraction for a machine-learning classifier

In [None]:
import sys
import os
import numpy as np
import openslide
import PIL.Image as Image
import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.models as models
from torch.utils.data import DataLoader
import torch.nn.functional as F
from argparse import Namespace
import csv
from skimage.color import rgb2hed, hed2rgb
from skimage import img_as_ubyte
from tqdm import tqdm

class MILdataset(data.Dataset):
    """Tile dataset over every slide listed in a MIL library file.

    ``libraryfile`` is loaded with ``torch.load`` and must provide ``'slides'``
    (slide paths), ``'grid'`` (one list of (x, y) tile corners per slide) and
    ``'targets'``. All per-slide grids are flattened into one tile index; each
    item is the 224x224 region at that corner, read at level 0, converted to
    RGB and passed through ``transform`` when one is given.
    """

    def __init__(self, libraryfile='', transform=None):
        lib = torch.load(libraryfile)
        handles = [openslide.OpenSlide(path) for path in lib['slides']]
        print('\nDone loading slides.')

        per_slide_grids = lib['grid']
        # Flatten the grids; owner[k] is the index of the slide that tile k belongs to.
        flat_coords = [coord for tiles in per_slide_grids for coord in tiles]
        owner = [slide_no for slide_no, tiles in enumerate(per_slide_grids) for _ in tiles]

        self.slidenames = lib['slides']
        self.slides = handles
        self.targets = lib['targets']
        self.grid = flat_coords
        self.slideIDX = owner
        self.transform = transform
        self.mode = None

    def setmode(self, mode):
        """Remember ``mode``; item retrieval does not currently depend on it."""
        self.mode = mode

    def __getitem__(self, index):
        slide = self.slides[self.slideIDX[index]]
        tile = slide.read_region(self.grid[index], 0, (224, 224)).convert('RGB')
        return tile if self.transform is None else self.transform(tile)

    def __len__(self):
        return len(self.grid)

def inference(loader, model):
    """Score every tile in ``loader`` with ``model`` (no gradients).

    ``loader`` must iterate its dataset in order (``shuffle=False``) and its
    dataset must expose ``slideIDX`` and ``grid`` lists aligned with the items.

    Returns:
        (probs, tile_info): ``probs`` is a NumPy array with the softmax
        probability of class 1 for every tile in loader order, and
        ``tile_info`` is a list of ``(slideIDX, coord, prob)`` tuples in the
        same order.
    """
    model.eval()
    # Run on the device the model lives on instead of hard-coding CUDA, so the
    # function also works for CPU models / CPU-only kernels.
    try:
        device = next(model.parameters()).device
    except StopIteration:  # parameter-less model: fall back to the old behaviour
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    dataset = loader.dataset
    probs = torch.zeros(len(dataset))  # zero-initialised, unlike torch.FloatTensor(n)
    tile_info = []  # (slideIDX, coord, prob) per tile
    offset = 0  # running position; does not assume every batch has batch_size items
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            batch_probs = F.softmax(model(batch), dim=1)[:, 1].cpu()
            count = batch_probs.size(0)
            probs[offset:offset + count] = batch_probs
            for j, prob in enumerate(batch_probs.tolist()):
                idx = offset + j
                tile_info.append((dataset.slideIDX[idx], dataset.grid[idx], prob))
            offset += count
    return probs.numpy(), tile_info

def get_ground_truth(slide_name):
    """Return the ground-truth label stored as the last character of the file stem.

    ``'/data/case12_1.svs'`` -> ``1``. The extension is removed with
    ``os.path.splitext`` rather than a fixed ``[-5]`` offset, so ``.tif``,
    ``.tiff``, ``.ndpi`` etc. work as well as ``.svs``.

    Raises:
        ValueError: if the last character of the stem is not a digit.
    """
    stem = os.path.splitext(os.path.basename(slide_name))[0]
    return int(stem[-1])

def get_predicted_label(positive_tiles):
    """MIL slide call: 1 when at least one tile was classified positive, otherwise 0."""
    return int(positive_tiles > 0)

def color_separate(ihc_rgb):
    """Split an RGB IHC image into per-stain RGB renderings via HED deconvolution.

    Returns ``(ihc_h, ihc_d, ihc_e)``: each is ``hed2rgb`` applied to the HED
    stack with only one channel kept (the others zeroed), converted with
    ``img_as_ubyte``.
    """
    hed = rgb2hed(ihc_rgb)
    zeros = np.zeros_like(hed[:, :, 0])

    def render(channel):
        # Keep only `channel` of the HED stack and map it back to 8-bit RGB.
        planes = [zeros, zeros, zeros]
        planes[channel] = hed[:, :, channel]
        return img_as_ubyte(hed2rgb(np.stack(planes, axis=-1)))

    # HED channel order: 0 = hematoxylin, 1 = eosin, 2 = DAB.
    return render(0), render(2), render(1)

def main(args, slide_directory, model_path, data_folder, output_folder):
    """Score every ``.svs`` slide in a folder and write one feature CSV per slide.

    For each slide this builds a tile grid, saves it as a single-slide MIL
    library in ``data_folder``, scores every tile with the 2-class ResNet-34
    checkpoint and writes ``<slide>_prediction_results.csv`` to
    ``output_folder``. Tiles with class-1 probability >= 0.5 count as
    positive. Each slide row holds:

    * tile counts, positive percentage, ground truth and predicted label;
    * over the positive tiles only: mean probability, mean R/G/B intensities
      (also combined as "R,G,B" in ``mean_intensity``) and mean HED-stain
      intensities;
    * distances between consecutive positive tiles.

    Slides without positive tiles get 'N/A' for those means and 0 for the
    distances.

    Args:
        args: Namespace with ``batch_size``, ``workers``, ``tile_size``,
            ``overlap``, ``mult`` and ``level``.
        slide_directory: folder scanned (non-recursively) for ``.svs`` files.
        model_path: checkpoint dict whose ``'state_dict'`` matches a ResNet-34
            with a 2-way ``fc`` head.
        data_folder: destination of the per-slide ``*_data.pt`` library files.
        output_folder: destination of the per-slide CSV files.
    """
    os.makedirs(data_folder, exist_ok=True)
    # The CSVs are written here, so this folder must exist too; previously only
    # data_folder was created and the first open() could fail.
    os.makedirs(output_folder, exist_ok=True)

    positive_threshold = 0.5
    # Per-tile scalar features averaged over a slide's positive tiles.
    scalar_features = ('mean_intensity_R', 'mean_intensity_G', 'mean_intensity_B', 'mean_prob',
                       'mean_intensity_ihc_h', 'mean_intensity_ihc_d', 'mean_intensity_ihc_e')
    fieldnames = ['slide_name', 'positive_tiles', 'negative_tiles', 'positive_percentage', 'ground_truth', 'predicted_label',
                  'mean_intensity_R', 'mean_intensity_G', 'mean_intensity_B', 'mean_intensity', 'mean_prob',
                  'mean_intensity_ihc_h', 'mean_intensity_ihc_d', 'mean_intensity_ihc_e', 'mean_distance', 'max_distance', 'min_distance']

    def generate_grid(slide_path, tile_size=224, overlap=0, level=0):
        """Top-left (x, y) tile corners covering pyramid ``level``, in x-major order."""
        step = tile_size - overlap
        if step <= 0:
            raise ValueError(f"overlap ({overlap}) must be smaller than tile_size ({tile_size})")
        slide = openslide.OpenSlide(slide_path)
        try:
            width, height = slide.level_dimensions[level]
        finally:
            slide.close()  # only the dimensions are needed; don't leak the handle
        # NOTE(review): tiles are later read with read_region(coord, 0, ...), which
        # takes level-0 coordinates, so this grid is only consistent for level=0.
        return [(x, y) for x in range(0, width, step) for y in range(0, height, step)]

    def extract_label_from_filename(filename):
        """Int made of the digits in the token after the last '_' (up to its first '.'); None if none."""
        try:
            label_str = filename.split('_')[-1].split('.')[0]
            return int(''.join(filter(str.isdigit, label_str)))
        except ValueError:
            print(f"Failed to extract label from filename: {filename}")
            return None

    def prepare_mil_input_data(slide_path, tile_size=224, overlap=0, mult=1.0, level=0):
        """Single-slide MIL library dict in the layout MILdataset expects."""
        return {
            'slides': [slide_path],
            'grid': [generate_grid(slide_path, tile_size, overlap, level)],
            'targets': [extract_label_from_filename(os.path.basename(slide_path))],
            'mult': mult,
            'level': level,
        }

    def summarize_tiles(dset, tile_info):
        """Aggregate (slideIDX, coord, prob) tuples into one feature dict per slide name."""
        slide_info = {}
        per_tile = {}  # slide_name -> {feature: [value per positive tile]}
        for slide_idx, coord, prob in tile_info:
            slide_name = dset.slidenames[slide_idx]
            if slide_name not in slide_info:
                slide_info[slide_name] = {'positive_tiles': 0, 'negative_tiles': 0}
            info = slide_info[slide_name]
            if prob >= positive_threshold:
                info['positive_tiles'] += 1
                img_array = np.array(dset.slides[slide_idx].read_region(coord, 0, (224, 224)).convert('RGB'))
                ihc_h, ihc_d, ihc_e = color_separate(img_array)
                if slide_name not in per_tile:
                    per_tile[slide_name] = {key: [] for key in ('mean_intensity', 'coords') + scalar_features}
                feats = per_tile[slide_name]
                feats['mean_intensity'].append(np.mean(img_array, axis=(0, 1)))  # per channel (R, G, B)
                feats['mean_intensity_R'].append(np.mean(img_array[:, :, 0]))
                feats['mean_intensity_G'].append(np.mean(img_array[:, :, 1]))
                feats['mean_intensity_B'].append(np.mean(img_array[:, :, 2]))
                feats['mean_prob'].append(prob)
                feats['mean_intensity_ihc_h'].append(np.mean(ihc_h))
                feats['mean_intensity_ihc_d'].append(np.mean(ihc_d))
                feats['mean_intensity_ihc_e'].append(np.mean(ihc_e))
                feats['coords'].append(coord)
            else:
                info['negative_tiles'] += 1

        for slide_name, info in slide_info.items():
            feats = per_tile.get(slide_name)
            coords = []
            if feats is not None:
                info['mean_intensity'] = np.mean(feats['mean_intensity'], axis=0)
                for key in scalar_features:
                    info[key] = np.mean(feats[key])
                coords = feats['coords']
            # Distances are between *consecutive* positive tiles in grid (x-major)
            # order, not nearest-neighbour distances.
            if len(coords) > 1:
                distances = np.linalg.norm(np.diff(coords, axis=0), axis=1)
                info['mean_distance'] = np.mean(distances)
                info['max_distance'] = np.max(distances)
                info['min_distance'] = np.min(distances)
            else:
                info['mean_distance'] = info['max_distance'] = info['min_distance'] = 0
        return slide_info

    def write_csv(output_csv, slide_info):
        """Write one row per slide; means missing for slides without positive tiles become 'N/A'."""
        with open(output_csv, 'w', newline='') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()
            for slide_name, info in slide_info.items():
                total_tiles = info['positive_tiles'] + info['negative_tiles']
                if total_tiles > 0:
                    positive_percentage = round((info['positive_tiles'] / total_tiles) * 100, 2)
                else:
                    positive_percentage = 0
                row = {
                    'slide_name': slide_name,
                    'ground_truth': get_ground_truth(slide_name),
                    'predicted_label': get_predicted_label(info['positive_tiles']),
                    'positive_tiles': info['positive_tiles'],
                    'negative_tiles': info['negative_tiles'],
                    'positive_percentage': positive_percentage,
                    # The three channel means share a single cell as "R,G,B".
                    'mean_intensity': ','.join(map(str, info.get('mean_intensity', ['N/A']))),
                    'mean_distance': info['mean_distance'],
                    'max_distance': info['max_distance'],
                    'min_distance': info['min_distance'],
                }
                row.update({key: info.get(key, 'N/A') for key in scalar_features})
                writer.writerow(row)

    # Load model. Every weight is overwritten by the checkpoint, so build the
    # architecture without downloading ImageNet weights first.
    model = models.resnet34()
    model.fc = nn.Linear(model.fc.in_features, 2)
    checkpoint = torch.load(model_path)
    model.load_state_dict(checkpoint['state_dict'])
    model = model.cuda()

    # Data transformations (ImageNet normalisation statistics).
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    trans = transforms.Compose([transforms.ToTensor(), normalize])

    for filename in sorted(os.listdir(slide_directory)):
        if not filename.endswith(".svs"):
            continue
        stem = os.path.splitext(filename)[0]
        slide_path = os.path.join(slide_directory, filename)
        data_file = os.path.join(data_folder, f"{stem}_data.pt")

        library = prepare_mil_input_data(slide_path, args.tile_size, args.overlap, args.mult, args.level)
        torch.save(library, data_file)

        dset = MILdataset(data_file, trans)
        try:
            loader = DataLoader(dset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True)
            dset.setmode(1)
            _, tile_info = inference(loader, model)
            slide_info = summarize_tiles(dset, tile_info)
        finally:
            # MILdataset opens one OpenSlide handle per slide and never closes it.
            for slide in dset.slides:
                slide.close()

        write_csv(os.path.join(output_folder, f"{stem}_prediction_results.csv"), slide_info)

if __name__ == "__main__":
    # Inference / tiling settings (tile_size, overlap and level define the grid).
    args = Namespace(batch_size=256, workers=4, tile_size=224, overlap=0, mult=1.0, level=0)

    # Inputs: the validation WSIs and the trained MIL checkpoint.
    slide_directory = '/content/drive/MyDrive/AI_IHC_LANA_Positivity/02.17.24_MIL_analysis/validation_WSI/'
    model_path = '/content/drive/MyDrive/AI_IHC_LANA_Positivity/02.17.24_MIL_analysis/MIL_output/MIL_checkpoint_best_10epochs_k10.pth'

    # Outputs: per-slide *_data.pt library files and per-slide *_prediction_results.csv files.
    data_folder = '/content/drive/MyDrive/AI_IHC_LANA_Positivity/02.17.24_MIL_analysis/data_loader/'
    output_folder = '/content/drive/MyDrive/AI_IHC_LANA_Positivity/02.17.24_MIL_analysis/ML/testing/'

    main(args, slide_directory=slide_directory, model_path=model_path,
         data_folder=data_folder, output_folder=output_folder)

In [None]:
# @title Hidden Cell
import sys
import os
import numpy as np
import openslide
import PIL.Image as Image
import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.models as models
from torch.utils.data import DataLoader
import torch.nn.functional as F
from argparse import Namespace
import csv
from skimage.color import rgb2hed, hed2rgb
from skimage import img_as_ubyte

class MILdataset(data.Dataset):
    """Tile dataset over every slide listed in a MIL library file.

    ``libraryfile`` is loaded with ``torch.load`` and must provide ``'slides'``
    (slide paths), ``'grid'`` (one list of (x, y) tile corners per slide) and
    ``'targets'``. All per-slide grids are flattened into one tile index; each
    item is the 224x224 region at that corner, read at level 0, converted to
    RGB and passed through ``transform`` when one is given.
    """

    def __init__(self, libraryfile='', transform=None):
        lib = torch.load(libraryfile)
        handles = [openslide.OpenSlide(path) for path in lib['slides']]
        print('\nDone loading slides.')

        per_slide_grids = lib['grid']
        # Flatten the grids; owner[k] is the index of the slide that tile k belongs to.
        flat_coords = [coord for tiles in per_slide_grids for coord in tiles]
        owner = [slide_no for slide_no, tiles in enumerate(per_slide_grids) for _ in tiles]

        self.slidenames = lib['slides']
        self.slides = handles
        self.targets = lib['targets']
        self.grid = flat_coords
        self.slideIDX = owner
        self.transform = transform
        self.mode = None

    def setmode(self, mode):
        """Remember ``mode``; item retrieval does not currently depend on it."""
        self.mode = mode

    def __getitem__(self, index):
        slide = self.slides[self.slideIDX[index]]
        tile = slide.read_region(self.grid[index], 0, (224, 224)).convert('RGB')
        return tile if self.transform is None else self.transform(tile)

    def __len__(self):
        return len(self.grid)

def inference(loader, model):
    """Score every tile in ``loader`` with ``model`` (no gradients).

    ``loader`` must iterate its dataset in order (``shuffle=False``) and its
    dataset must expose ``slideIDX`` and ``grid`` lists aligned with the items.

    Returns:
        (probs, tile_info): ``probs`` is a NumPy array with the softmax
        probability of class 1 for every tile in loader order, and
        ``tile_info`` is a list of ``(slideIDX, coord, prob)`` tuples in the
        same order.
    """
    model.eval()
    # Run on the device the model lives on instead of hard-coding CUDA, so the
    # function also works for CPU models / CPU-only kernels.
    try:
        device = next(model.parameters()).device
    except StopIteration:  # parameter-less model: fall back to the old behaviour
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    dataset = loader.dataset
    probs = torch.zeros(len(dataset))  # zero-initialised, unlike torch.FloatTensor(n)
    tile_info = []  # (slideIDX, coord, prob) per tile
    offset = 0  # running position; does not assume every batch has batch_size items
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            batch_probs = F.softmax(model(batch), dim=1)[:, 1].cpu()
            count = batch_probs.size(0)
            probs[offset:offset + count] = batch_probs
            for j, prob in enumerate(batch_probs.tolist()):
                idx = offset + j
                tile_info.append((dataset.slideIDX[idx], dataset.grid[idx], prob))
            offset += count
    return probs.numpy(), tile_info

def get_ground_truth(slide_name):
    """Return the ground-truth label stored as the last character of the file stem.

    ``'/data/case12_1.svs'`` -> ``1``. The extension is removed with
    ``os.path.splitext`` rather than a fixed ``[-5]`` offset, so ``.tif``,
    ``.tiff``, ``.ndpi`` etc. work as well as ``.svs``.

    Raises:
        ValueError: if the last character of the stem is not a digit.
    """
    stem = os.path.splitext(os.path.basename(slide_name))[0]
    return int(stem[-1])

def get_predicted_label(positive_tiles):
    """MIL slide call: 1 when at least one tile was classified positive, otherwise 0."""
    return int(positive_tiles > 0)

def color_separate(ihc_rgb):
    """Split an RGB IHC image into per-stain RGB renderings via HED deconvolution.

    Returns ``(ihc_h, ihc_d, ihc_e)``: each is ``hed2rgb`` applied to the HED
    stack with only one channel kept (the others zeroed), converted with
    ``img_as_ubyte``.
    """
    hed = rgb2hed(ihc_rgb)
    zeros = np.zeros_like(hed[:, :, 0])

    def render(channel):
        # Keep only `channel` of the HED stack and map it back to 8-bit RGB.
        planes = [zeros, zeros, zeros]
        planes[channel] = hed[:, :, channel]
        return img_as_ubyte(hed2rgb(np.stack(planes, axis=-1)))

    # HED channel order: 0 = hematoxylin, 1 = eosin, 2 = DAB.
    return (render(0), render(2), render(1))

def main(args, model_path, data_path, output_csv, output_folder):
    """Score every tile of a prepared MIL library and write per-slide features to one CSV.

    Tiles with class-1 probability >= 0.5 count as positive. Each slide row
    holds:

    * tile counts, positive percentage, ground truth and predicted label;
    * over the positive tiles only: mean probability, mean RGB intensity
      ("R,G,B" in one cell) and mean HED-stain intensities;
    * distances between consecutive positive tiles.

    Slides without positive tiles get 'N/A' for those means and 0 for the
    distances.

    Args:
        args: Namespace with ``batch_size`` and ``workers``.
        model_path: checkpoint dict whose ``'state_dict'`` matches a ResNet-34
            with a 2-way ``fc`` head.
        data_path: MIL library file read by ``MILdataset``.
        output_csv: path of the CSV to write; its parent directory is created
            if missing.
        output_folder: currently unused (it was the target of a disabled
            per-tile PNG export); kept so existing calls keep working.
    """
    positive_threshold = 0.5
    # Per-tile scalar features averaged over a slide's positive tiles.
    scalar_features = ('mean_prob', 'mean_intensity_ihc_h', 'mean_intensity_ihc_d', 'mean_intensity_ihc_e')
    fieldnames = ['slide_name', 'positive_tiles', 'negative_tiles', 'positive_percentage', 'ground_truth', 'predicted_label', 'mean_intensity', 'mean_prob', 'mean_intensity_ihc_h', 'mean_intensity_ihc_d', 'mean_intensity_ihc_e', 'mean_distance', 'max_distance', 'min_distance']

    def summarize_tiles(dset, tile_info):
        """Aggregate (slideIDX, coord, prob) tuples into one feature dict per slide name."""
        slide_info = {}
        per_tile = {}  # slide_name -> {feature: [value per positive tile]}
        for slide_idx, coord, prob in tile_info:
            slide_name = dset.slidenames[slide_idx]
            if slide_name not in slide_info:
                slide_info[slide_name] = {'positive_tiles': 0, 'negative_tiles': 0}
            info = slide_info[slide_name]
            if prob >= positive_threshold:
                info['positive_tiles'] += 1
                img_array = np.array(dset.slides[slide_idx].read_region(coord, 0, (224, 224)).convert('RGB'))
                ihc_h, ihc_d, ihc_e = color_separate(img_array)
                if slide_name not in per_tile:
                    per_tile[slide_name] = {key: [] for key in ('mean_intensity', 'coords') + scalar_features}
                feats = per_tile[slide_name]
                feats['mean_intensity'].append(np.mean(img_array, axis=(0, 1)))  # per channel (R, G, B)
                feats['mean_prob'].append(prob)
                feats['mean_intensity_ihc_h'].append(np.mean(ihc_h))
                feats['mean_intensity_ihc_d'].append(np.mean(ihc_d))
                feats['mean_intensity_ihc_e'].append(np.mean(ihc_e))
                feats['coords'].append(coord)
            else:
                info['negative_tiles'] += 1

        # One pass per slide (replaces the old O(slides x tiles) list.index scan).
        for slide_name, info in slide_info.items():
            feats = per_tile.get(slide_name)
            coords = []
            if feats is not None:
                info['mean_intensity'] = np.mean(feats['mean_intensity'], axis=0)
                for key in scalar_features:
                    info[key] = np.mean(feats[key])
                coords = feats['coords']
            # Distances are between *consecutive* positive tiles in grid order,
            # not nearest-neighbour distances.
            if len(coords) > 1:
                distances = np.linalg.norm(np.diff(coords, axis=0), axis=1)
                info['mean_distance'] = np.mean(distances)
                info['max_distance'] = np.max(distances)
                info['min_distance'] = np.min(distances)
            else:
                info['mean_distance'] = info['max_distance'] = info['min_distance'] = 0
        return slide_info

    def write_csv(slide_info):
        """Write one row per slide to output_csv; missing means become 'N/A'."""
        with open(output_csv, 'w', newline='') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()
            for slide_name, info in slide_info.items():
                total_tiles = info['positive_tiles'] + info['negative_tiles']
                if total_tiles > 0:
                    positive_percentage = round((info['positive_tiles'] / total_tiles) * 100, 2)
                else:
                    positive_percentage = 0
                row = {
                    'slide_name': slide_name,
                    'ground_truth': get_ground_truth(slide_name),
                    'predicted_label': get_predicted_label(info['positive_tiles']),
                    'positive_tiles': info['positive_tiles'],
                    'negative_tiles': info['negative_tiles'],
                    'positive_percentage': positive_percentage,
                    # The three channel means share a single cell as "R,G,B".
                    'mean_intensity': ','.join(map(str, info.get('mean_intensity', ['N/A']))),
                    'mean_distance': info['mean_distance'],
                    'max_distance': info['max_distance'],
                    'min_distance': info['min_distance'],
                }
                row.update({key: info.get(key, 'N/A') for key in scalar_features})
                writer.writerow(row)

    # Load model. Every weight is overwritten by the checkpoint, so build the
    # architecture without downloading ImageNet weights first.
    model = models.resnet34()
    model.fc = nn.Linear(model.fc.in_features, 2)
    checkpoint = torch.load(model_path)
    model.load_state_dict(checkpoint['state_dict'])
    model = model.cuda()

    # Data transformations (ImageNet normalisation statistics).
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    trans = transforms.Compose([transforms.ToTensor(), normalize])

    dset = MILdataset(data_path, trans)
    try:
        loader = DataLoader(dset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True)
        dset.setmode(1)
        _, tile_info = inference(loader, model)
        slide_info = summarize_tiles(dset, tile_info)
    finally:
        # MILdataset opens one OpenSlide handle per slide and never closes it.
        for slide in dset.slides:
            slide.close()

    out_dir = os.path.dirname(output_csv)
    if out_dir:  # dirname is '' for a bare file name; makedirs('') would fail
        os.makedirs(out_dir, exist_ok=True)
    write_csv(slide_info)

# Example usage: score the dummy testing library and write its feature CSV.
# NOTE(review): data_path is produced by the grid-preparation cell further down
# (it saves "mil_testing_dummy_data.pt" into the working directory, presumably
# /content on Colab), so on a fresh kernel that cell has to run first.
args = Namespace(batch_size=100, workers=4)
model_path = '/content/drive/MyDrive/AI_IHC_LANA_Positivity/02.17.24_MIL_analysis/MIL_output/MIL_checkpoint_best_25epochs_k60.pth'
data_path = '/content/mil_testing_dummy_data.pt'
output_csv = '/content/drive/MyDrive/AI_IHC_LANA_Positivity/02.17.24_MIL_analysis/testing_dummy/testing_dummy_prediction_results.csv'
output_folder = '/content/drive/MyDrive/AI_IHC_LANA_Positivity/02.17.24_MIL_analysis/testing_dummy/'

# The CSV lives inside output_folder, so make sure the folder exists.
os.makedirs(output_folder, exist_ok=True)
main(args, model_path=model_path, data_path=data_path,
     output_csv=output_csv, output_folder=output_folder)

In [None]:
# @title Hidden Cell
import os
import torch
import openslide
from tqdm import tqdm

# Configuration for building the MIL library of the dummy testing slides.
slide_directory = "/content/drive/MyDrive/AI_IHC_LANA_Positivity/02.17.24_MIL_analysis/testing_dummy"
tile_size, overlap = 224, 0  # square tile edge and overlap between neighbouring tiles, in pixels
mult = 1.0  # scale factor recorded in the library (1.0 = no scaling)
level = 0  # OpenSlide pyramid level whose dimensions define the grid

def generate_grid(slide_path, tile_size=224, overlap=0, level=0):
    """Return the top-left (x, y) corners of a tile grid covering one pyramid level.

    Corners are listed x-major (every y for the first x, then the next x, ...),
    with step ``tile_size - overlap``; edge tiles may extend past the border.
    ``level`` used to be read from the module-level config; it is now a
    parameter so ``prepare_mil_input_data`` can pass its own ``level`` through.

    Raises:
        ValueError: if ``overlap >= tile_size`` (the step would not be positive).
    """
    step = tile_size - overlap
    if step <= 0:
        raise ValueError(f"overlap ({overlap}) must be smaller than tile_size ({tile_size})")
    slide = openslide.OpenSlide(slide_path)
    try:
        width, height = slide.level_dimensions[level]
    finally:
        slide.close()  # only the dimensions are needed; don't leak the handle
    # NOTE(review): MILdataset reads tiles with read_region(coord, 0, ...), which
    # takes level-0 coordinates, so this grid is only consistent for level=0.
    return [(x, y) for x in range(0, width, step) for y in range(0, height, step)]


def extract_label_from_filename(filename):
    """Return the integer label that follows the last '_' in ``filename``.

    The token after the last '_' is cut at its first '.', e.g.
    ``'case12_1.svs'`` -> ``1``.

    Raises:
        ValueError: if that token is not an integer.
    """
    return int(filename.split('_')[-1].split('.')[0])


def prepare_mil_input_data(slide_directory, tile_size=224, overlap=0, mult=1.0, level=0):
    """Build the MIL library dict for every ``.svs`` slide in ``slide_directory``.

    Slides are processed in sorted file-name order so the library is
    reproducible. The returned dict has ``'slides'`` (full paths), ``'grid'``
    (one tile-corner list per slide), ``'targets'`` (labels parsed from file
    names), ``'mult'`` and ``'level'``.
    """
    slide_names = sorted(f for f in os.listdir(slide_directory) if f.endswith('.svs'))
    full_paths = [os.path.join(slide_directory, f) for f in slide_names]

    grid = []
    targets = []
    for slide_path in tqdm(full_paths, desc="Generating grids and targets"):
        # Pass `level` through so the grid matches the level stored in the library.
        grid.append(generate_grid(slide_path, tile_size, overlap, level))
        targets.append(extract_label_from_filename(os.path.basename(slide_path)))

    return {
        'slides': full_paths,
        'grid': grid,
        'targets': targets,
        'mult': mult,
        'level': level,
    }

# Main script: build the library for every slide and save it into the working
# directory. The result is kept in `mil_library` rather than `data` so the
# notebook-global `data` (torch.utils.data, imported by the other cells) is
# not clobbered. Re-running a `class MILdataset(data.Dataset)` cell afterwards
# would otherwise fail.
if __name__ == "__main__":
    mil_library = prepare_mil_input_data(slide_directory, tile_size, overlap, mult, level)
    torch.save(mil_library, "mil_testing_dummy_data.pt")

In [None]:
# @title Hidden Cell
import sys
import os
import numpy as np
import openslide
import PIL.Image as Image
import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.models as models
from torch.utils.data import DataLoader
import torch.nn.functional as F
from argparse import Namespace
import csv
from skimage.color import rgb2hed, hed2rgb
from skimage import img_as_ubyte
from tqdm import tqdm

class MILdataset(data.Dataset):
    """Flattened tile dataset over one or more whole-slide images.

    ``libraryfile`` is loaded with ``torch.load`` and must provide these keys:
    - 'slides': slide paths.
    - 'grid': one list of tile coordinates per slide.
    - 'targets': slide labels.

    Every grid entry becomes one dataset item, and ``slideIDX[k]`` is the
    index of the slide that item ``k`` belongs to. ``torch.load`` unpickles
    its input, so only load trusted library files.
    """

    def __init__(self, libraryfile='', transform=None):
        lib = torch.load(libraryfile)
        # One open OpenSlide handle per slide, in library order.
        self.slides = [openslide.OpenSlide(path) for path in lib['slides']]
        print('\nDone loading slides.')

        flat_coords = []
        owner_index = []
        for slide_no, coords in enumerate(lib['grid']):
            flat_coords.extend(coords)
            owner_index.extend([slide_no] * len(coords))

        self.slidenames = lib['slides']
        self.targets = lib['targets']
        self.grid = flat_coords
        self.slideIDX = owner_index
        self.transform = transform
        self.mode = None  # stored by setmode(); __getitem__ does not consult it

    def setmode(self, mode):
        self.mode = mode

    def __getitem__(self, index):
        # 224x224 RGB tile read at pyramid level 0, then optionally transformed.
        tile = (self.slides[self.slideIDX[index]]
                .read_region(self.grid[index], 0, (224, 224))
                .convert('RGB'))
        return tile if self.transform is None else self.transform(tile)

    def __len__(self):
        return len(self.grid)

def inference(loader, model):
    """Score every tile yielded by ``loader`` with ``model``.

    Args:
        loader: a non-shuffled DataLoader. Each batch is an input tensor, and
            ``loader.dataset`` exposes ``slideIDX`` and ``grid`` lists aligned
            with dataset indices.
        model: a classifier returning 2-class logits.

    Returns:
        ``(probs, tile_info)``:
        - ``probs``: a float32 numpy array of length ``len(loader.dataset)``
          holding the softmax probability of class 1.
        - ``tile_info``: a list of ``(slideIDX, coord, prob)`` tuples in
          dataset order.
    """
    model.eval()
    # Follow the model's device instead of hard-coding .cuda(), so CPU models work too.
    try:
        device = next(model.parameters()).device
    except StopIteration:
        device = torch.device('cpu')

    dataset = loader.dataset
    probs = torch.empty(len(dataset), dtype=torch.float32)
    tile_info = []
    # Dataset index of the first item in the current batch. This relies on
    # shuffle=False but, unlike i * loader.batch_size, it also works when
    # batch_size is None (custom batch_sampler).
    offset = 0
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device, non_blocking=True)
            batch_probs = F.softmax(model(batch), dim=1)[:, 1].cpu()
            n = batch_probs.size(0)
            probs[offset:offset + n] = batch_probs
            for idx, prob in zip(range(offset, offset + n), batch_probs.tolist()):
                tile_info.append((dataset.slideIDX[idx], dataset.grid[idx], prob))
            offset += n
    return probs.numpy(), tile_info

def get_ground_truth(slide_name):
    """Return the integer label encoded at the end of a slide's file name.

    The label is the run of trailing digits in the file stem (the name without
    directory or extension). For example ``/x/KS_12_1.svs`` -> 1 and
    ``slide_10.svs`` -> 10. Reading a fixed character position would break
    multi-digit labels and extensions that are not 3 letters long.

    Raises:
        ValueError: if the stem does not end in a digit.
    """
    stem = os.path.splitext(os.path.basename(slide_name))[0]
    label = stem[len(stem.rstrip('0123456789')):]
    if not label:
        raise ValueError(f"No trailing numeric label in slide name: {slide_name!r}")
    return int(label)

def get_predicted_label(positive_tiles):
    """Slide-level call: 1 when at least one tile is positive, otherwise 0."""
    return 1 if positive_tiles > 0 else 0

def color_separate(ihc_rgb):
    """Split an RGB IHC image into single-stain RGB renderings.

    The image is converted to HED space with ``rgb2hed``. Each stain is
    rendered back to RGB with ``hed2rgb`` after zeroing the other two
    channels, and the result is converted to uint8 via ``img_as_ubyte``.

    Returns:
        ``(hematoxylin, dab, eosin)`` uint8 RGB arrays, in that order.
    """
    hed = rgb2hed(ihc_rgb)
    blank = np.zeros_like(hed[:, :, 0])

    def _single_stain_rgb(channel):
        # Keep only `channel` of the HED stack (0=H, 1=E, 2=D).
        planes = [blank, blank, blank]
        planes[channel] = hed[:, :, channel]
        return img_as_ubyte(hed2rgb(np.stack(planes, axis=-1)))

    hematoxylin = _single_stain_rgb(0)
    dab = _single_stain_rgb(2)
    eosin = _single_stain_rgb(1)
    return hematoxylin, dab, eosin

def main(args, slide_directory, model_path, output_folder):
    """Run tile-level MIL inference on each ``.svs`` slide in ``slide_directory``
    and write one slide-level feature CSV per slide into ``output_folder``.

    For each slide:
    - A single-slide library is built and saved to ``<stem>_data.pt`` in the
      current working directory, then loaded through ``MILdataset``.
    - Tiles with probability >= 0.5 are re-read to compute mean RGB intensity
      and mean H / DAB / eosin intensity (via ``color_separate``).
    - These values, the mean positive probability, and distances between
      consecutive positive tiles are written to
      ``<stem>_prediction_results.csv``.

    Args:
        args: namespace with ``batch_size``, ``workers``, ``tile_size``,
            ``overlap``, ``mult`` and ``level``.
        slide_directory: folder scanned (non-recursively) for ``.svs`` files.
        model_path: checkpoint whose ``'state_dict'`` fits a ResNet-34 with a
            2-way ``fc`` head.
        output_folder: destination for the CSV files (created if missing).
    """
    def generate_grid(slide_path, tile_size=224, overlap=0, level=0):
        # Regular grid of top-left (x, y) coordinates over `level`, in x-major
        # order. Edge tiles may run past the slide border.
        step = tile_size - overlap
        if step <= 0:
            raise ValueError(f"overlap ({overlap}) must be smaller than tile_size ({tile_size})")
        slide = openslide.OpenSlide(slide_path)
        try:
            width, height = slide.level_dimensions[level]
        finally:
            slide.close()  # only the dimensions are needed here
        return [(x, y) for x in range(0, width, step) for y in range(0, height, step)]

    def extract_label_from_filename(filename):
        # Label = text after the last '_' up to the first '.', e.g. 'KS_12_1.svs' -> 1.
        return int(filename.split('_')[-1].split('.')[0])

    def prepare_mil_input_data(slide_path, tile_size=224, overlap=0, mult=1.0, level=0):
        # Single-slide library dict in the layout MILdataset reads.
        return {
            'slides': [slide_path],
            'grid': [generate_grid(slide_path, tile_size, overlap, level)],
            'targets': [extract_label_from_filename(os.path.basename(slide_path))],
            'mult': mult,
            'level': level,
        }

    def summarise_tiles(dset, tile_info):
        # Count positive/negative tiles per slide. Features come from the
        # positive (prob >= 0.5) tiles only.
        slide_info = {}
        positives = {}  # slide_name -> per-positive-tile measurements, in tile order
        for slide_idx, coord, prob in tile_info:
            slide_name = dset.slidenames[slide_idx]
            info = slide_info.setdefault(slide_name, {'positive_tiles': 0, 'negative_tiles': 0})
            if prob >= 0.5:
                info['positive_tiles'] += 1
                tile = np.array(dset.slides[slide_idx].read_region(coord, 0, (224, 224)).convert('RGB'))
                ihc_h, ihc_d, ihc_e = color_separate(tile)
                acc = positives.setdefault(
                    slide_name, {'rgb': [], 'prob': [], 'h': [], 'd': [], 'e': [], 'coords': []})
                acc['rgb'].append(np.mean(tile, axis=(0, 1)))  # per-channel mean (R, G, B)
                acc['prob'].append(prob)
                acc['h'].append(np.mean(ihc_h))
                acc['d'].append(np.mean(ihc_d))
                acc['e'].append(np.mean(ihc_e))
                acc['coords'].append(coord)
            else:
                info['negative_tiles'] += 1

        for slide_name, info in slide_info.items():
            acc = positives.get(slide_name)
            if acc is not None:
                info['mean_intensity'] = np.mean(acc['rgb'], axis=0)
                info['mean_prob'] = np.mean(acc['prob'])
                info['mean_intensity_ihc_h'] = np.mean(acc['h'])
                info['mean_intensity_ihc_d'] = np.mean(acc['d'])
                info['mean_intensity_ihc_e'] = np.mean(acc['e'])
            coords = acc['coords'] if acc is not None else []
            if len(coords) > 1:
                # Euclidean distances between *consecutive* positive tiles in
                # grid order (not all pairs).
                steps = np.linalg.norm(np.diff(coords, axis=0), axis=1)
                info['mean_distance'] = np.mean(steps)
                info['max_distance'] = np.max(steps)
                info['min_distance'] = np.min(steps)
            else:
                info['mean_distance'] = 0
                info['max_distance'] = 0
                info['min_distance'] = 0
        return slide_info

    def write_results_csv(output_csv, slide_info):
        fieldnames = ['slide_name', 'positive_tiles', 'negative_tiles', 'positive_percentage', 'ground_truth',
                      'predicted_label', 'mean_intensity', 'mean_prob', 'mean_intensity_ihc_h',
                      'mean_intensity_ihc_d', 'mean_intensity_ihc_e', 'mean_distance', 'max_distance',
                      'min_distance']
        with open(output_csv, 'w', newline='') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()
            for slide_name, info in slide_info.items():
                total_tiles = info['positive_tiles'] + info['negative_tiles']
                positive_percentage = (round(info['positive_tiles'] / total_tiles * 100, 2)
                                       if total_tiles > 0 else 0)
                writer.writerow({
                    'slide_name': slide_name,
                    'ground_truth': get_ground_truth(slide_name),
                    'predicted_label': get_predicted_label(info['positive_tiles']),
                    'positive_tiles': info['positive_tiles'],
                    'negative_tiles': info['negative_tiles'],
                    'positive_percentage': positive_percentage,
                    'mean_prob': info.get('mean_prob', 'N/A'),
                    # Slides without positive tiles get 'N/A' for every feature.
                    'mean_intensity': ','.join(map(str, info.get('mean_intensity', ['N/A']))),
                    'mean_intensity_ihc_h': info.get('mean_intensity_ihc_h', 'N/A'),
                    'mean_intensity_ihc_e': info.get('mean_intensity_ihc_e', 'N/A'),
                    'mean_intensity_ihc_d': info.get('mean_intensity_ihc_d', 'N/A'),
                    'mean_distance': info['mean_distance'],
                    'max_distance': info['max_distance'],
                    'min_distance': info['min_distance'],
                })

    # All weights come from the checkpoint, so skip the ImageNet download.
    model = models.resnet34(pretrained=False)
    model.fc = nn.Linear(model.fc.in_features, 2)
    # map_location='cpu' lets a checkpoint saved on any device load before .cuda().
    checkpoint = torch.load(model_path, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    model = model.cuda()

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    trans = transforms.Compose([transforms.ToTensor(), normalize])

    os.makedirs(output_folder, exist_ok=True)

    for filename in sorted(os.listdir(slide_directory)):
        if not filename.endswith(".svs"):
            continue
        stem = os.path.splitext(filename)[0]
        slide_path = os.path.join(slide_directory, filename)
        output_csv = os.path.join(output_folder, f"{stem}_prediction_results.csv")
        # Intermediate library file; written to the current working directory.
        library_path = f"{stem}_data.pt"

        mil_library = prepare_mil_input_data(slide_path, args.tile_size, args.overlap, args.mult, args.level)
        torch.save(mil_library, library_path)

        dset = MILdataset(library_path, trans)
        loader = DataLoader(dset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers,
                            pin_memory=True)
        dset.setmode(1)
        try:
            _, tile_info = inference(loader, model)
            slide_info = summarise_tiles(dset, tile_info)
        finally:
            # MILdataset opens one OpenSlide handle per slide; release them per iteration.
            for slide in dset.slides:
                slide.close()

        write_results_csv(output_csv, slide_info)

if __name__ == "__main__":
    # Slides to score, the trained MIL checkpoint, and the folder that receives the per-slide CSVs.
    slide_directory = '/content/drive/MyDrive/AI_IHC_LANA_Positivity/02.17.24_MIL_analysis/validation_WSI'
    model_path = '/content/drive/MyDrive/AI_IHC_LANA_Positivity/02.17.24_MIL_analysis/MIL_output/MIL_checkpoint_best_25epochs_k60.pth'  # Update this path
    output_folder = '/content/drive/MyDrive/AI_IHC_LANA_Positivity/02.17.24_MIL_analysis/testing_dummy/'  # per-slide result CSVs go here
    args = Namespace(batch_size=100, workers=4, tile_size=224, overlap=0, mult=1.0, level=0)

    main(args=args, slide_directory=slide_directory, model_path=model_path, output_folder=output_folder)


In [None]:
# @title Hidden Cell
import sys
import os
import numpy as np
import openslide
import PIL.Image as Image
import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.models as models
from torch.utils.data import DataLoader
import torch.nn.functional as F
from argparse import Namespace
import csv

class MILdataset(data.Dataset):
    """Flattened tile dataset over one or more whole-slide images.

    ``libraryfile`` is loaded with ``torch.load`` and must provide these keys:
    - 'slides': slide paths.
    - 'grid': one list of tile coordinates per slide.
    - 'targets': slide labels.

    Every grid entry becomes one dataset item, and ``slideIDX[k]`` is the
    index of the slide that item ``k`` belongs to. ``torch.load`` unpickles
    its input, so only load trusted library files.
    """

    def __init__(self, libraryfile='', transform=None):
        lib = torch.load(libraryfile)
        # One open OpenSlide handle per slide, in library order.
        self.slides = [openslide.OpenSlide(path) for path in lib['slides']]
        print('\nDone loading slides.')

        flat_coords = []
        owner_index = []
        for slide_no, coords in enumerate(lib['grid']):
            flat_coords.extend(coords)
            owner_index.extend([slide_no] * len(coords))

        self.slidenames = lib['slides']
        self.targets = lib['targets']
        self.grid = flat_coords
        self.slideIDX = owner_index
        self.transform = transform
        self.mode = None  # stored by setmode(); __getitem__ does not consult it

    def setmode(self, mode):
        self.mode = mode

    def __getitem__(self, index):
        # 224x224 RGB tile read at pyramid level 0, then optionally transformed.
        tile = (self.slides[self.slideIDX[index]]
                .read_region(self.grid[index], 0, (224, 224))
                .convert('RGB'))
        return tile if self.transform is None else self.transform(tile)

    def __len__(self):
        return len(self.grid)

def inference(loader, model):
    """Score every tile yielded by ``loader`` with ``model``.

    Args:
        loader: a non-shuffled DataLoader. Each batch is an input tensor, and
            ``loader.dataset`` exposes ``slideIDX`` and ``grid`` lists aligned
            with dataset indices.
        model: a classifier returning 2-class logits.

    Returns:
        ``(probs, tile_info)``:
        - ``probs``: a float32 numpy array of length ``len(loader.dataset)``
          holding the softmax probability of class 1.
        - ``tile_info``: a list of ``(slideIDX, coord, prob)`` tuples in
          dataset order.
    """
    model.eval()
    # Follow the model's device instead of hard-coding .cuda(), so CPU models work too.
    try:
        device = next(model.parameters()).device
    except StopIteration:
        device = torch.device('cpu')

    dataset = loader.dataset
    probs = torch.empty(len(dataset), dtype=torch.float32)
    tile_info = []
    # Dataset index of the first item in the current batch. This relies on
    # shuffle=False but, unlike i * loader.batch_size, it also works when
    # batch_size is None (custom batch_sampler).
    offset = 0
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device, non_blocking=True)
            batch_probs = F.softmax(model(batch), dim=1)[:, 1].cpu()
            n = batch_probs.size(0)
            probs[offset:offset + n] = batch_probs
            for idx, prob in zip(range(offset, offset + n), batch_probs.tolist()):
                tile_info.append((dataset.slideIDX[idx], dataset.grid[idx], prob))
            offset += n
    return probs.numpy(), tile_info

def get_ground_truth(slide_name):
    """Return the integer label encoded at the end of a slide's file name.

    The label is the run of trailing digits in the file stem (the name without
    directory or extension). For example ``/x/KS_12_1.svs`` -> 1 and
    ``slide_10.svs`` -> 10. Reading a fixed character position would break
    multi-digit labels and extensions that are not 3 letters long.

    Raises:
        ValueError: if the stem does not end in a digit.
    """
    stem = os.path.splitext(os.path.basename(slide_name))[0]
    label = stem[len(stem.rstrip('0123456789')):]
    if not label:
        raise ValueError(f"No trailing numeric label in slide name: {slide_name!r}")
    return int(label)

def get_predicted_label(positive_tiles):
    """Slide-level call: 1 when at least one tile is positive, otherwise 0."""
    return 1 if positive_tiles > 0 else 0

def main(args, model_path, data_path, output_csv):
    """Score every tile in the MIL library at ``data_path`` and write per-slide
    tile counts, positive percentage, ground truth and predicted label to
    ``output_csv``.

    Args:
        args: namespace with ``batch_size`` and ``workers``.
        model_path: checkpoint whose ``'state_dict'`` fits a ResNet-34 with a
            2-way ``fc`` head.
        data_path: library file readable by ``MILdataset``.
        output_csv: CSV destination; its parent directory is created if missing.
    """
    # All weights come from the checkpoint, so skip the ImageNet download.
    model = models.resnet34(pretrained=False)
    model.fc = nn.Linear(model.fc.in_features, 2)
    # map_location='cpu' lets a checkpoint saved on any device load before .cuda().
    checkpoint = torch.load(model_path, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    model = model.cuda()

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    trans = transforms.Compose([transforms.ToTensor(), normalize])

    dset = MILdataset(data_path, trans)
    loader = DataLoader(dset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True)
    dset.setmode(1)
    try:
        _, tile_info = inference(loader, model)
    finally:
        # MILdataset keeps one OpenSlide handle per slide open; release them.
        for slide in dset.slides:
            slide.close()

    slide_info = {}
    for slide_idx, _coord, prob in tile_info:
        slide_name = dset.slidenames[slide_idx]
        if slide_name not in slide_info:
            # Announce each slide once rather than once per tile.
            print(f"Processing slide: {slide_name}")
            slide_info[slide_name] = {'positive_tiles': 0, 'negative_tiles': 0}
        key = 'positive_tiles' if prob >= 0.5 else 'negative_tiles'
        slide_info[slide_name][key] += 1

    csv_dir = os.path.dirname(output_csv)
    if csv_dir:
        os.makedirs(csv_dir, exist_ok=True)

    fieldnames = ['slide_name', 'positive_tiles', 'negative_tiles', 'positive_percentage', 'ground_truth', 'predicted_label']
    with open(output_csv, 'w', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()
        for slide_name, counts in slide_info.items():
            total_tiles = counts['positive_tiles'] + counts['negative_tiles']
            positive_percentage = round(counts['positive_tiles'] / total_tiles * 100, 2) if total_tiles > 0 else 0
            writer.writerow({'slide_name': slide_name,
                             'positive_tiles': counts['positive_tiles'],
                             'negative_tiles': counts['negative_tiles'],
                             'positive_percentage': positive_percentage,
                             'ground_truth': get_ground_truth(slide_name),
                             # Slide is positive if any tile is positive.
                             'predicted_label': get_predicted_label(counts['positive_tiles'])})

# Example usage: score the library built earlier and write per-slide tile counts.
data_path = '/content/mil_testing_dummy_data.pt'  # Update this path
model_path = '/content/drive/MyDrive/KS_deepLearning/MIL_output/MIL_checkpoint_best_25epochs_k60.pth'  # Update this path
output_csv = '/content/drive/MyDrive/AI_IHC_LANA_Positivity/02.17.24_MIL_analysis/testing_output/testing_dummy_prediction_results.csv'  # Output CSV file path
args = Namespace(batch_size=100, workers=4)

main(args=args, model_path=model_path, data_path=data_path, output_csv=output_csv)

Generating Heatmap

In [None]:
# @title Hidden Cell
import sys
import os
import numpy as np
import openslide
from PIL import Image
import torch
import torch.nn as nn
import torch.utils.data as data
from torchvision import transforms, models
from torch.utils.data import DataLoader
import torch.nn.functional as F
from argparse import Namespace
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
from matplotlib.cm import ScalarMappable

class MILdataset(data.Dataset):
    """Flattened tile dataset for heatmap generation.

    ``libraryfile`` is loaded with ``torch.load`` and must provide 'slides'
    (slide paths) and 'grid' (one coordinate list per slide). Every grid entry
    becomes one item, and ``slideIDX[k]`` is the slide index of item ``k``.
    """

    def __init__(self, libraryfile='', transform=None):
        lib = torch.load(libraryfile)
        self.slides = []
        for path in lib['slides']:
            self.slides.append(openslide.OpenSlide(path))
        self.grid = []
        self.slideIDX = []
        for slide_no, coords in enumerate(lib['grid']):
            for xy in coords:
                self.grid.append(xy)
                self.slideIDX.append(slide_no)
        self.transform = transform
        self.mode = None  # stored by setmode(); __getitem__ does not consult it
        print(f'Loaded {len(lib["slides"])} slides with a total of {len(self.grid)} tiles.')

    def setmode(self, mode):
        self.mode = mode

    def __getitem__(self, index):
        # 224x224 RGB tile read at pyramid level 0, then optionally transformed.
        tile = (self.slides[self.slideIDX[index]]
                .read_region(self.grid[index], 0, (224, 224))
                .convert('RGB'))
        return tile if self.transform is None else self.transform(tile)

    def __len__(self):
        return len(self.grid)

def inference(loader, model):
    """Score every tile yielded by ``loader`` with ``model``.

    Args:
        loader: a non-shuffled DataLoader. Each batch is an input tensor, and
            ``loader.dataset`` exposes ``slideIDX`` and ``grid`` lists aligned
            with dataset indices.
        model: a classifier returning 2-class logits.

    Returns:
        ``(probs, tile_info)``:
        - ``probs``: a float32 numpy array of length ``len(loader.dataset)``
          holding the softmax probability of class 1.
        - ``tile_info``: a list of ``(slideIDX, coord, prob)`` tuples in
          dataset order.
    """
    model.eval()
    # Follow the model's device instead of hard-coding .cuda(), so CPU models work too.
    try:
        device = next(model.parameters()).device
    except StopIteration:
        device = torch.device('cpu')

    dataset = loader.dataset
    probs = torch.empty(len(dataset), dtype=torch.float32)
    tile_info = []
    # Dataset index of the first item in the current batch. This relies on
    # shuffle=False but, unlike i * loader.batch_size, it also works when
    # batch_size is None (custom batch_sampler).
    offset = 0
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device, non_blocking=True)
            batch_probs = F.softmax(model(batch), dim=1)[:, 1].cpu()
            n = batch_probs.size(0)
            probs[offset:offset + n] = batch_probs
            for idx, prob in zip(range(offset, offset + n), batch_probs.tolist()):
                tile_info.append((dataset.slideIDX[idx], dataset.grid[idx], prob))
            offset += n
    return probs.numpy(), tile_info

def overlay_heatmap_on_slide(slide, tile_info, output_path, tile_size=224, threshold=0.5):
    """Draw positive-tile probabilities over the slide's lowest-resolution level.

    The original thumbnail and a 50/50 blend with a 'Reds' heatmap are shown
    side by side, saved to ``output_path``, and displayed.

    Args:
        slide: an open OpenSlide object.
        tile_info: iterable of ``(coord, prob)`` pairs. ``coord`` is a tile's
            top-left corner in the coordinates passed to ``read_region``
            (level 0).
        output_path: file the figure is saved to (format inferred from the
            extension by matplotlib).
        tile_size: tile edge length in level-0 pixels.
        threshold: only tiles with ``prob >= threshold`` are painted; all
            other pixels stay 0.
    """
    level = len(slide.level_dimensions) - 1
    level_dimensions = slide.level_dimensions[level]
    level_downsample = slide.level_downsamples[level]
    width, height = level_dimensions
    heatmap = np.zeros((height, width))

    # At least one pixel per tile so tiles do not vanish at large downsamples.
    tile_size_level = max(1, int(tile_size / level_downsample))
    for (x, y), prob in tile_info:
        if prob >= threshold:
            x_level, y_level = int(x / level_downsample), int(y / level_downsample)
            y_end = min(y_level + tile_size_level, height)
            x_end = min(x_level + tile_size_level, width)
            heatmap[y_level:y_end, x_level:x_end] = np.maximum(heatmap[y_level:y_end, x_level:x_end], prob)

    norm = Normalize(vmin=0, vmax=1, clip=True)
    colored_heatmap = ScalarMappable(norm=norm, cmap='Reds').to_rgba(heatmap, bytes=True)[:, :, :3]

    # get_thumbnail preserves aspect ratio, so its size can differ from
    # level_dimensions by a pixel. Resize so it blends with the heatmap without
    # a shape mismatch.
    thumbnail = slide.get_thumbnail(level_dimensions).convert("RGB")
    if thumbnail.size != tuple(level_dimensions):
        thumbnail = thumbnail.resize(tuple(level_dimensions))
    slide_thumbnail = np.array(thumbnail)
    overlay = Image.fromarray((slide_thumbnail * 0.5 + colored_heatmap * 0.5).astype('uint8'))

    fig, axs = plt.subplots(1, 2, figsize=(12, 6))
    axs[0].imshow(slide_thumbnail)
    axs[0].axis('off')
    axs[0].set_title('Original Slide')

    axs[1].imshow(overlay)
    axs[1].axis('off')
    axs[1].set_title('Overlay with Heatmap')

    fig.tight_layout()
    fig.savefig(output_path)  # output_path used to be accepted but never written
    plt.show()
    plt.close(fig)  # avoid accumulating open figures when looping over slides

def main(args, model_path, data_path):
    """Score all tiles in the MIL library at ``data_path``, report overall
    positive/negative tile counts (printed and as a bar chart), and call
    ``overlay_heatmap_on_slide`` once per slide with output path
    ``heatmap_slide_<index>.png``.

    Args:
        args: namespace with ``batch_size`` and ``workers``.
        model_path: checkpoint whose ``'state_dict'`` fits a ResNet-34 with a
            2-way ``fc`` head.
        data_path: library file readable by ``MILdataset``.
    """
    # All weights come from the checkpoint, so skip the ImageNet download.
    model = models.resnet34(pretrained=False)
    model.fc = nn.Linear(model.fc.in_features, 2)
    # map_location='cpu' lets a checkpoint saved on any device load before .cuda().
    model.load_state_dict(torch.load(model_path, map_location='cpu')['state_dict'])
    model = model.cuda()

    trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
    dset = MILdataset(data_path, trans)
    loader = DataLoader(dset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True)
    dset.setmode(1)

    try:
        _, tile_info = inference(loader, model)

        positive_count = sum(prob >= 0.5 for _, _, prob in tile_info)
        negative_count = len(tile_info) - positive_count
        total_predictions = positive_count + negative_count
        # Guard against an empty library (no tiles -> no division).
        positive_percentage = (positive_count / total_predictions) * 100 if total_predictions > 0 else 0

        print(f"Total 1's (Positive Predictions): {positive_count}")
        print(f"Total 0's (Negative Predictions): {negative_count}")
        print(f"Percentage of Predicted Positive Tiles: {positive_percentage:.2f}%")

        fig, ax = plt.subplots(figsize=(4, 4))
        ax.bar(['Positive Predictions', 'Negative Predictions'], [positive_count, negative_count], color=['blue', 'red'])
        ax.set_title('Count of Positive and Negative Predictions')
        ax.set_ylabel('Count')
        plt.show()
        plt.close(fig)

        per_slide = {}
        for slide_idx, coord, prob in tile_info:
            per_slide.setdefault(slide_idx, []).append((coord, prob))
        for slide_idx, info in per_slide.items():
            overlay_heatmap_on_slide(dset.slides[slide_idx], info, f'heatmap_slide_{slide_idx}.png')
    finally:
        # MILdataset keeps one OpenSlide handle per slide open; release them.
        for slide in dset.slides:
            slide.close()

# Heatmap run: trained checkpoint and the MIL library to visualise.
model_path = '/content/drive/MyDrive/KS_deepLearning/MIL_output/MIL_checkpoint_best_25epochs_k60.pth'  # Update with your path
data_path = '/content/mil_testing_dummy_data.pt'  # Update with your path
args = Namespace(batch_size=256, workers=4)

main(args=args, model_path=model_path, data_path=data_path)


In [None]:
# @title Hidden Cell
import sys
import os
import numpy as np
import openslide
import PIL.Image as Image
import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.models as models
from torch.utils.data import DataLoader
import torch.nn.functional as F
from argparse import Namespace
import csv

class MILdataset(data.Dataset):
    """Flattened tile dataset over one or more whole-slide images.

    ``libraryfile`` is loaded with ``torch.load`` and must provide these keys:
    - 'slides': slide paths.
    - 'grid': one list of tile coordinates per slide.
    - 'targets': slide labels.

    Every grid entry becomes one dataset item, and ``slideIDX[k]`` is the
    index of the slide that item ``k`` belongs to. ``torch.load`` unpickles
    its input, so only load trusted library files.
    """

    def __init__(self, libraryfile='', transform=None):
        lib = torch.load(libraryfile)
        # One open OpenSlide handle per slide, in library order.
        self.slides = [openslide.OpenSlide(path) for path in lib['slides']]
        print('\nDone loading slides.')

        flat_coords = []
        owner_index = []
        for slide_no, coords in enumerate(lib['grid']):
            flat_coords.extend(coords)
            owner_index.extend([slide_no] * len(coords))

        self.slidenames = lib['slides']
        self.targets = lib['targets']
        self.grid = flat_coords
        self.slideIDX = owner_index
        self.transform = transform
        self.mode = None  # stored by setmode(); __getitem__ does not consult it

    def setmode(self, mode):
        self.mode = mode

    def __getitem__(self, index):
        # 224x224 RGB tile read at pyramid level 0, then optionally transformed.
        tile = (self.slides[self.slideIDX[index]]
                .read_region(self.grid[index], 0, (224, 224))
                .convert('RGB'))
        return tile if self.transform is None else self.transform(tile)

    def __len__(self):
        return len(self.grid)

def inference(loader, model):
    """Score every tile yielded by ``loader`` with ``model``.

    Args:
        loader: a non-shuffled DataLoader. Each batch is an input tensor, and
            ``loader.dataset`` exposes ``slideIDX`` and ``grid`` lists aligned
            with dataset indices.
        model: a classifier returning 2-class logits.

    Returns:
        ``(probs, tile_info)``:
        - ``probs``: a float32 numpy array of length ``len(loader.dataset)``
          holding the softmax probability of class 1.
        - ``tile_info``: a list of ``(slideIDX, coord, prob)`` tuples in
          dataset order.
    """
    model.eval()
    # Follow the model's device instead of hard-coding .cuda(), so CPU models work too.
    try:
        device = next(model.parameters()).device
    except StopIteration:
        device = torch.device('cpu')

    dataset = loader.dataset
    probs = torch.empty(len(dataset), dtype=torch.float32)
    tile_info = []
    # Dataset index of the first item in the current batch. This relies on
    # shuffle=False but, unlike i * loader.batch_size, it also works when
    # batch_size is None (custom batch_sampler).
    offset = 0
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device, non_blocking=True)
            batch_probs = F.softmax(model(batch), dim=1)[:, 1].cpu()
            n = batch_probs.size(0)
            probs[offset:offset + n] = batch_probs
            for idx, prob in zip(range(offset, offset + n), batch_probs.tolist()):
                tile_info.append((dataset.slideIDX[idx], dataset.grid[idx], prob))
            offset += n
    return probs.numpy(), tile_info

def get_ground_truth(slide_name):
    """Return the integer label encoded at the end of a slide's file name.

    The label is the run of trailing digits in the file stem (the name without
    directory or extension). For example ``/x/KS_12_1.svs`` -> 1 and
    ``slide_10.svs`` -> 10. Reading a fixed character position would break
    multi-digit labels and extensions that are not 3 letters long.

    Raises:
        ValueError: if the stem does not end in a digit.
    """
    stem = os.path.splitext(os.path.basename(slide_name))[0]
    label = stem[len(stem.rstrip('0123456789')):]
    if not label:
        raise ValueError(f"No trailing numeric label in slide name: {slide_name!r}")
    return int(label)

def get_predicted_label(positive_tiles):
    """Slide-level call: 1 when at least one tile is positive, otherwise 0."""
    return 1 if positive_tiles > 0 else 0

def main(args, model_path, data_path, output_csv, output_folder):
    """Score every tile in the MIL library at ``data_path``.

    Each positive tile (prob >= 0.5) is saved as
    ``<slide file stem>_<x>_<y>.png`` in ``output_folder``. Per-slide tile
    counts, positive percentage, ground truth and predicted label are written
    to ``output_csv``.

    Args:
        args: namespace with ``batch_size`` and ``workers``.
        model_path: checkpoint whose ``'state_dict'`` fits a ResNet-34 with a
            2-way ``fc`` head.
        data_path: library file readable by ``MILdataset``.
        output_csv: CSV destination; its parent directory is created if missing.
        output_folder: folder for positive-tile PNGs (created if missing).
    """
    # All weights come from the checkpoint, so skip the ImageNet download.
    model = models.resnet34(pretrained=False)
    model.fc = nn.Linear(model.fc.in_features, 2)
    # map_location='cpu' lets a checkpoint saved on any device load before .cuda().
    checkpoint = torch.load(model_path, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    model = model.cuda()

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    trans = transforms.Compose([transforms.ToTensor(), normalize])

    os.makedirs(output_folder, exist_ok=True)

    dset = MILdataset(data_path, trans)
    loader = DataLoader(dset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True)
    dset.setmode(1)

    slide_info = {}
    try:
        _, tile_info = inference(loader, model)
        for slide_idx, coord, prob in tile_info:
            slide_name = dset.slidenames[slide_idx]
            counts = slide_info.setdefault(slide_name, {'positive_tiles': 0, 'negative_tiles': 0})
            if prob >= 0.5:
                counts['positive_tiles'] += 1
                tile = dset.slides[slide_idx].read_region(coord, 0, (224, 224)).convert('RGB')
                # Library slide names may be absolute paths. os.path.join drops
                # output_folder when its second argument is absolute, so use
                # only the file stem here.
                slide_stem = os.path.splitext(os.path.basename(slide_name))[0]
                tile.save(os.path.join(output_folder, f"{slide_stem}_{coord[0]}_{coord[1]}.png"))
            else:
                counts['negative_tiles'] += 1
    finally:
        # MILdataset keeps one OpenSlide handle per slide open; release them.
        for slide in dset.slides:
            slide.close()

    csv_dir = os.path.dirname(output_csv)
    if csv_dir:
        os.makedirs(csv_dir, exist_ok=True)

    fieldnames = ['slide_name', 'positive_tiles', 'negative_tiles', 'positive_percentage', 'ground_truth', 'predicted_label']
    with open(output_csv, 'w', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()
        for slide_name, counts in slide_info.items():
            total_tiles = counts['positive_tiles'] + counts['negative_tiles']
            positive_percentage = round(counts['positive_tiles'] / total_tiles * 100, 2) if total_tiles > 0 else 0
            writer.writerow({'slide_name': slide_name,
                             'positive_tiles': counts['positive_tiles'],
                             'negative_tiles': counts['negative_tiles'],
                             'positive_percentage': positive_percentage,
                             'ground_truth': get_ground_truth(slide_name),
                             # Slide is positive if any tile is positive.
                             'predicted_label': get_predicted_label(counts['positive_tiles'])})

# Example usage: score the library, save positive tiles, and write per-slide counts.
data_path = '/content/mil_testing_dummy_data.pt'  # Update this path
model_path = '/content/drive/MyDrive/KS_deepLearning/MIL_output/MIL_checkpoint_best_25epochs_k60.pth'  # Update this path
output_folder = '/content/drive/MyDrive/KS_deepLearning/testing_dummy/'  # Output folder path for positive tiles
output_csv = '/content/drive/MyDrive/AI_IHC_LANA_Positivity/02.17.24_MIL_analysis/testing_output/testing_dummy_prediction_results.csv'  # Output CSV file path
args = Namespace(batch_size=100, workers=4)

# The tile folder must exist before main() saves PNGs into it.
os.makedirs(output_folder, exist_ok=True)

main(args=args, model_path=model_path, data_path=data_path, output_csv=output_csv, output_folder=output_folder)

In [None]:
# @title Hidden Cell
import numpy as np
import openslide
import PIL.Image as Image
import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.models as models
from torch.utils.data import DataLoader
import torch.nn.functional as F
import csv
from skimage.color import rgb2hed, hed2rgb
from skimage import img_as_ubyte
import matplotlib.pyplot as plt

class MILdataset(data.Dataset):
    """Tile-level dataset over one or more whole-slide images.

    ``libraryfile`` is loaded with ``torch.load`` and must provide the keys
    ``'slides'`` (paths opened with ``openslide.OpenSlide``), ``'grid'`` (one list
    of tile origins per slide) and ``'targets'``. Tiles of all slides are flattened
    into a single index space; ``slideIDX[k]`` is the slide that tile ``k`` came from.
    """

    def __init__(self, libraryfile='', transform=None):
        # NOTE(review): torch.load unpickles the file -- only load trusted library files.
        lib = torch.load(libraryfile)
        self.slidenames = lib['slides']
        self.slides = [openslide.OpenSlide(path) for path in self.slidenames]
        print('\nDone loading slides.')

        # Flatten the per-slide grids, remembering which slide each tile belongs to.
        self.grid = []
        self.slideIDX = []
        for slide_number, slide_coords in enumerate(lib['grid']):
            self.grid.extend(slide_coords)
            self.slideIDX.extend([slide_number] * len(slide_coords))

        self.targets = lib['targets']
        self.transform = transform
        self.mode = None

    def setmode(self, mode):
        """Store ``mode``; it is not consulted by ``__getitem__``."""
        self.mode = mode

    def __getitem__(self, index):
        """Read the 224x224 level-0 tile at ``grid[index]`` as RGB, transformed if a transform is set."""
        source_slide = self.slides[self.slideIDX[index]]
        tile = source_slide.read_region(self.grid[index], 0, (224, 224)).convert('RGB')
        return tile if self.transform is None else self.transform(tile)

    def __len__(self):
        """Total number of tiles across all slides."""
        return len(self.grid)

def inference(loader, model):
    """Score every tile in ``loader`` and record which slide and position it came from.

    The loader must iterate its dataset sequentially (``shuffle=False``) so that the
    k-th score belongs to dataset item k; ``loader.dataset`` must expose ``slideIDX``
    and ``grid`` lists indexed like the dataset.

    Returns:
        (probs, tile_info): ``probs`` is a float32 numpy array with the softmax
        probability of class 1 for every tile; ``tile_info`` is a list of
        ``(slide_index, coord, prob)`` tuples in the same order, ``prob`` a Python float.
    """
    model.eval()
    # Send inputs to wherever the model lives instead of hard-coding .cuda(),
    # so the same function also works for a CPU model.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    dataset = loader.dataset
    probs = torch.FloatTensor(len(dataset))
    tile_info = []  # (slide_index, coord, prob) per tile

    offset = 0  # running position in the dataset; also correct for a custom batch_sampler
    with torch.no_grad():
        for batch in loader:
            output = F.softmax(model(batch.to(device)), dim=1)
            batch_probs = output[:, 1].cpu()  # one device->host copy per batch
            n_tiles = batch_probs.size(0)
            probs[offset:offset + n_tiles] = batch_probs
            # tolist() converts the whole batch at once instead of one .item() sync per tile.
            for j, prob in enumerate(batch_probs.tolist()):
                idx = offset + j
                tile_info.append((dataset.slideIDX[idx], dataset.grid[idx], prob))
            offset += n_tiles
    return probs.cpu().numpy(), tile_info

def get_ground_truth(slide_name):
    """Ground-truth label taken from the last character of the slide's file-name stem.

    The stem (file name without directory and extension) must end in a single-digit
    label, e.g. ``/path/case_017_1.svs`` -> 1. Stripping the extension with
    ``os.path.splitext`` gives the same result as the old ``filename[-5]`` for
    ``.svs``/``.tif`` files, and also works for extensions of other lengths
    (``.ndpi``, ``.tiff``).

    Raises:
        ValueError: if the last character of the stem is not a digit.
    """
    stem = os.path.splitext(os.path.basename(slide_name))[0]
    return int(stem[-1])

def get_predicted_label(positive_tiles):
    """Slide-level call: 1 when ``positive_tiles`` is greater than zero, otherwise 0."""
    return 1 if positive_tiles > 0 else 0

def color_separate(ihc_rgb):
    """Split an RGB IHC image into single-stain RGB renderings.

    The image is unmixed into stain concentrations with ``rgb2hed`` (channel 0 =
    Hematoxylin, 1 = Eosin, 2 = DAB). Each stain is then rendered back to RGB on
    its own, with the other two concentration channels set to zero, and converted
    to uint8 so it can be saved directly as an image.

    Returns:
        (ihc_h, ihc_d, ihc_e): uint8 RGB images for the Hematoxylin, DAB and Eosin stains.
    """
    hed = rgb2hed(ihc_rgb)

    def render_single_stain(channel):
        # Keep only one concentration channel and map it back to RGB.
        isolated = np.zeros_like(hed)
        isolated[:, :, channel] = hed[:, :, channel]
        return img_as_ubyte(hed2rgb(isolated))

    return (render_single_stain(0), render_single_stain(2), render_single_stain(1))

def plot_image_with_distances(slide, coords, mean_distance, min_distance, max_distance):
    """Show a slide image with its closest and farthest consecutive tile pairs marked.

    Args:
        slide: image accepted by ``Axes.imshow`` (e.g. a PIL image of the slide).
        coords: sequence of (x, y) tile positions in the pixel space of ``slide``,
            in the order used to compute the distances (consecutive pairs).
        mean_distance, min_distance, max_distance: values shown in the legend.

    The red/orange dashed segments join the consecutive pair of ``coords`` with the
    smallest/largest separation, recomputed from ``coords`` itself. With fewer than
    two coordinates only the image and the legend are drawn.
    """
    fig, ax = plt.subplots()
    ax.imshow(slide)

    points = np.asarray(coords, dtype=float).reshape(-1, 2)

    # A mean distance has no single segment to draw, so it is a legend-only entry
    # (previously a first-to-last-tile line was mislabelled as the mean).
    ax.plot([], [], color='green', linestyle='dashed', label=f'Mean Dist: {mean_distance:.2f}')

    # Locate the extreme consecutive pairs from the coordinates themselves; the old
    # code used argmin/argmax over [min_distance, max_distance] (always 0 or 1) as
    # an index into coords, which pointed at arbitrary tiles.
    steps = np.linalg.norm(np.diff(points, axis=0), axis=1) if len(points) >= 2 else None
    for colour, label, pick in (('red', f'Min Dist: {min_distance:.2f}', np.argmin),
                                ('orange', f'Max Dist: {max_distance:.2f}', np.argmax)):
        if steps is None:
            ax.plot([], [], color=colour, linestyle='dashed', label=label)
            continue
        pair = int(pick(steps))
        ax.plot(points[pair:pair + 2, 0], points[pair:pair + 2, 1],
                color=colour, linestyle='dashed', label=label)

    ax.set_xlabel('X-coordinate')
    ax.set_ylabel('Y-coordinate')
    ax.set_title('Image with Distances')
    ax.legend()
    plt.show()
    plt.close(fig)  # called once per slide; avoid accumulating open figures

def main(args, model_path, data_path, output_csv, output_folder, thumbnail_size=2048):
    """Run MIL tile inference on a library file and write slide-level features to CSV.

    Every tile in ``data_path`` is scored; tiles with probability >= 0.5 count as
    positive. Each CSV row (one per slide) holds:
        - the positive/negative tile counts and the positive percentage;
        - the ground truth parsed from the file name and the slide-level prediction;
        - averages over positive tiles: mean RGB intensity, mean probability, and
          the mean intensity of the H, D and E single-stain images;
        - the mean/max/min distance between consecutive positive tiles, in dataset order.

    Slides with at least one positive tile also get an overview plot.

    Args:
        args: namespace providing ``batch_size`` and ``workers`` for the DataLoader.
        model_path: checkpoint with a ``'state_dict'`` for ResNet-34 with a 2-class head.
        data_path: library file loaded by ``MILdataset``.
        output_csv: path of the CSV file to write.
        output_folder: currently unused (tile saving is disabled); kept so existing
            calls keep working.
        thumbnail_size: maximum width/height of the slide thumbnail used for the
            overview plot. ``None`` restores the old behaviour of reading the whole
            level-0 image with ``read_region``. That allocates an RGBA image of the
            full slide size and is only feasible for small images.
    """
    # load_state_dict is strict by default, so every parameter comes from the
    # checkpoint; downloading ImageNet weights first (pretrained=True) was wasted work.
    model = models.resnet34()
    model.fc = nn.Linear(model.fc.in_features, 2)
    checkpoint = torch.load(model_path)
    model.load_state_dict(checkpoint['state_dict'])
    model = model.cuda()

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    trans = transforms.Compose([transforms.ToTensor(), normalize])

    dset = MILdataset(data_path, trans)
    loader = DataLoader(dset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True)
    dset.setmode(1)
    _, tile_info = inference(loader, model)

    slide_info = {}  # slide name -> counts, later the aggregated features
    positives = {}   # slide name -> per-positive-tile measurements, in tile order
    for slide_idx, coord, prob in tile_info:
        slide_name = dset.slidenames[slide_idx]
        if slide_name not in slide_info:
            slide_info[slide_name] = {'positive_tiles': 0, 'negative_tiles': 0}
        info = slide_info[slide_name]

        if prob >= 0.5:
            info['positive_tiles'] += 1
            img_array = np.array(dset.slides[slide_idx].read_region(coord, 0, (224, 224)).convert('RGB'))
            ihc_h, ihc_d, ihc_e = color_separate(img_array)

            if slide_name not in positives:
                positives[slide_name] = {'coords': [], 'rgb': [], 'prob': [], 'h': [], 'd': [], 'e': []}
            tiles = positives[slide_name]
            tiles['coords'].append(coord)
            tiles['rgb'].append(np.mean(img_array, axis=(0, 1)))  # per-channel (R, G, B) mean
            tiles['prob'].append(prob)
            tiles['h'].append(np.mean(ihc_h))
            tiles['d'].append(np.mean(ihc_d))
            tiles['e'].append(np.mean(ihc_e))
        else:
            info['negative_tiles'] += 1

    for slide_name, info in slide_info.items():
        tiles = positives.get(slide_name)
        coords = tiles['coords'] if tiles else []
        if tiles:
            info['mean_intensity'] = np.mean(tiles['rgb'], axis=0)
            info['mean_prob'] = np.mean(tiles['prob'])
            info['mean_intensity_ihc_h'] = np.mean(tiles['h'])
            info['mean_intensity_ihc_d'] = np.mean(tiles['d'])
            info['mean_intensity_ihc_e'] = np.mean(tiles['e'])

        # Distances between consecutive positive tiles (collected above in tile order,
        # so no per-slide rescan of tile_info is needed).
        if len(coords) > 1:
            distances = np.linalg.norm(np.diff(coords, axis=0), axis=1)
            info['mean_distance'] = np.mean(distances)
            info['max_distance'] = np.max(distances)
            info['min_distance'] = np.min(distances)
        else:
            info['mean_distance'] = info['max_distance'] = info['min_distance'] = 0

        # Only plot slides that have positive tiles: with empty coords the plot
        # helper indexes coords[0] and raised IndexError.
        if coords:
            slide = dset.slides[dset.slidenames.index(slide_name)]
            if thumbnail_size is None:
                image = slide.read_region((0, 0), 0, slide.level_dimensions[0])
                plot_coords = coords
            else:
                # Plot on a thumbnail and scale the level-0 tile coordinates to it.
                image = slide.get_thumbnail((thumbnail_size, thumbnail_size))
                full_width, full_height = slide.level_dimensions[0]
                scale_x = image.size[0] / full_width
                scale_y = image.size[1] / full_height
                plot_coords = [(x * scale_x, y * scale_y) for x, y in coords]
            plot_image_with_distances(image, plot_coords, info['mean_distance'],
                                      info['min_distance'], info['max_distance'])

    fieldnames = ['slide_name', 'positive_tiles', 'negative_tiles', 'positive_percentage', 'ground_truth', 'predicted_label', 'mean_intensity', 'mean_prob', 'mean_intensity_ihc_h', 'mean_intensity_ihc_d', 'mean_intensity_ihc_e', 'mean_distance', 'max_distance', 'min_distance']
    with open(output_csv, 'w', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()
        for slide_name, info in slide_info.items():
            total_tiles = info['positive_tiles'] + info['negative_tiles']
            if total_tiles > 0:
                positive_percentage = round((info['positive_tiles'] / total_tiles) * 100, 2)
            else:
                positive_percentage = 0

            writer.writerow({
                'slide_name': slide_name,
                'positive_tiles': info['positive_tiles'],
                'negative_tiles': info['negative_tiles'],
                'positive_percentage': positive_percentage,
                'ground_truth': get_ground_truth(slide_name),
                'predicted_label': get_predicted_label(info['positive_tiles']),
                # Slides without positive tiles have no aggregated features -> 'N/A'.
                'mean_intensity': ','.join(map(str, info.get('mean_intensity', ['N/A']))),
                'mean_prob': info.get('mean_prob', 'N/A'),
                'mean_intensity_ihc_h': info.get('mean_intensity_ihc_h', 'N/A'),
                'mean_intensity_ihc_d': info.get('mean_intensity_ihc_d', 'N/A'),
                'mean_intensity_ihc_e': info.get('mean_intensity_ihc_e', 'N/A'),
                'mean_distance': info['mean_distance'],
                'max_distance': info['max_distance'],
                'min_distance': info['min_distance'],
            })

# Example usage: slide-level feature extraction on the dummy test library.
# This cell's import block does not include os or Namespace; import them here so
# the cell also runs on a fresh kernel without relying on an earlier cell.
import os
from argparse import Namespace

args = Namespace(batch_size=100, workers=4)
model_path = '/content/drive/MyDrive/AI_IHC_LANA_Positivity/02.17.24_MIL_analysis/MIL_output/MIL_checkpoint_best_25epochs_k60.pth'  # Update this path
data_path = '/content/mil_testing_dummy_data.pt'  # Update this path
output_csv = '/content/drive/MyDrive/AI_IHC_LANA_Positivity/02.17.24_MIL_analysis/testing_dummy/testing_dummy_prediction_results.csv'  # Output CSV file path
output_folder = '/content/drive/MyDrive/AI_IHC_LANA_Positivity/02.17.24_MIL_analysis/testing_dummy/'  # Output folder path for positive tiles

# Create the output folder if it doesn't exist (output_csv is written inside it)
os.makedirs(output_folder, exist_ok=True)

main(args, model_path, data_path, output_csv, output_folder)



Tile-level parameters: feature extraction for the machine-learning classifier

In [1]:
import sys
import os
import numpy as np
import openslide
import PIL.Image as Image
import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.models as models
from torch.utils.data import DataLoader
import torch.nn.functional as F
from argparse import Namespace
import csv
from skimage.color import rgb2hed, hed2rgb
from skimage import img_as_ubyte
from tqdm import tqdm
from scipy.spatial.distance import cdist

class MILdataset(data.Dataset):
    """Tile-level dataset over one or more whole-slide images.

    ``libraryfile`` is loaded with ``torch.load`` and must provide the keys
    ``'slides'`` (paths opened with ``openslide.OpenSlide``), ``'grid'`` (one list
    of tile origins per slide) and ``'targets'``. Tiles of all slides are flattened
    into a single index space; ``slideIDX[k]`` is the slide that tile ``k`` came from.
    """

    def __init__(self, libraryfile='', transform=None):
        # NOTE(review): torch.load unpickles the file -- only load trusted library files.
        lib = torch.load(libraryfile)
        self.slidenames = lib['slides']
        self.slides = [openslide.OpenSlide(path) for path in self.slidenames]
        print('\nDone loading slides.')

        # Flatten the per-slide grids, remembering which slide each tile belongs to.
        self.grid = []
        self.slideIDX = []
        for slide_number, slide_coords in enumerate(lib['grid']):
            self.grid.extend(slide_coords)
            self.slideIDX.extend([slide_number] * len(slide_coords))

        self.targets = lib['targets']
        self.transform = transform
        self.mode = None

    def setmode(self, mode):
        """Store ``mode``; it is not consulted by ``__getitem__``."""
        self.mode = mode

    def __getitem__(self, index):
        """Read the 224x224 level-0 tile at ``grid[index]`` as RGB, transformed if a transform is set."""
        source_slide = self.slides[self.slideIDX[index]]
        tile = source_slide.read_region(self.grid[index], 0, (224, 224)).convert('RGB')
        return tile if self.transform is None else self.transform(tile)

    def __len__(self):
        """Total number of tiles across all slides."""
        return len(self.grid)

def inference(loader, model):
    """Score every tile in ``loader`` and record which slide and position it came from.

    The loader must iterate its dataset sequentially (``shuffle=False``) so that the
    k-th score belongs to dataset item k; ``loader.dataset`` must expose ``slideIDX``
    and ``grid`` lists indexed like the dataset.

    Returns:
        (probs, tile_info): ``probs`` is a float32 numpy array with the softmax
        probability of class 1 for every tile; ``tile_info`` is a list of
        ``(slide_index, coord, prob)`` tuples in the same order, ``prob`` a Python float.
    """
    model.eval()
    # Send inputs to wherever the model lives instead of hard-coding .cuda(),
    # so the same function also works for a CPU model.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    dataset = loader.dataset
    probs = torch.FloatTensor(len(dataset))
    tile_info = []  # (slide_index, coord, prob) per tile

    offset = 0  # running position in the dataset; also correct for a custom batch_sampler
    with torch.no_grad():
        for batch in loader:
            output = F.softmax(model(batch.to(device)), dim=1)
            batch_probs = output[:, 1].cpu()  # one device->host copy per batch
            n_tiles = batch_probs.size(0)
            probs[offset:offset + n_tiles] = batch_probs
            # tolist() converts the whole batch at once instead of one .item() sync per tile.
            for j, prob in enumerate(batch_probs.tolist()):
                idx = offset + j
                tile_info.append((dataset.slideIDX[idx], dataset.grid[idx], prob))
            offset += n_tiles
    return probs.cpu().numpy(), tile_info

def get_ground_truth(slide_name):
    """Ground-truth label taken from the last character of the slide's file-name stem.

    The stem (file name without directory and extension) must end in a single-digit
    label, e.g. ``/path/case_017_1.svs`` -> 1. Stripping the extension with
    ``os.path.splitext`` gives the same result as the old ``filename[-5]`` for
    ``.svs``/``.tif`` files, and also works for extensions of other lengths
    (``.ndpi``, ``.tiff``).

    Raises:
        ValueError: if the last character of the stem is not a digit.
    """
    stem = os.path.splitext(os.path.basename(slide_name))[0]
    return int(stem[-1])

def get_predicted_label(positive_tiles):
    """Slide-level call: 1 when ``positive_tiles`` is greater than zero, otherwise 0."""
    return 1 if positive_tiles > 0 else 0

def color_separate(ihc_rgb):
    """Render the Hematoxylin and DAB stains of an RGB IHC image as separate RGB images.

    The image is unmixed with ``rgb2hed`` (channel 0 = Hematoxylin, 2 = DAB); each
    of those stains is rendered back to RGB on its own, with every other
    concentration channel set to zero, and converted to uint8.

    Returns:
        (ihc_h, ihc_d): uint8 RGB images for the Hematoxylin and DAB stains.
    """
    hed = rgb2hed(ihc_rgb)

    def render_single_stain(channel):
        # Keep only one concentration channel and map it back to RGB.
        isolated = np.zeros_like(hed)
        isolated[:, :, channel] = hed[:, :, channel]
        return img_as_ubyte(hed2rgb(isolated))

    return render_single_stain(0), render_single_stain(2)

def compute_distances(coords):
    """Nearest-neighbour and mean distance from each point to all other points.

    Args:
        coords: (N, 2) array-like of (x, y) positions. A flat 1-D array is reshaped
            to (-1, 2); lists of tuples are accepted as well.

    Returns:
        (min_distances, mean_distances): two length-N float arrays holding, for each
        point, the Euclidean distance to its closest other point and the mean
        distance to all other points. With a single point there is no other point,
        so both entries are ``inf``. Empty input returns two empty arrays.
    """
    coords = np.asarray(coords)
    if coords.ndim == 1:
        coords = coords.reshape(-1, 2)  # Reshape to ensure 2 dimensions
    n_points = len(coords)
    if n_points == 0:
        return np.array([]), np.array([])  # Return empty arrays if coords is empty
    if n_points == 1:
        return np.array([np.inf]), np.array([np.inf])

    dist_matrix = cdist(coords, coords)
    # Mean over the n-1 other points. The zero self-distance on the diagonal adds
    # nothing to the row sum. This must happen before the diagonal is set to inf;
    # the old code averaged after that step, so every mean came out as inf.
    mean_distances = dist_matrix.sum(axis=1) / (n_points - 1)
    np.fill_diagonal(dist_matrix, np.inf)  # ignore self-distance for the minimum
    min_distances = dist_matrix.min(axis=1)
    return min_distances, mean_distances

def main(args, slide_directory, model_path, data_folder, output_folder):
    """Score every ``.svs`` slide in a directory and write one tile-feature CSV per slide.

    For each slide a tile grid is generated, saved as a single-slide MIL library in
    ``data_folder`` and scored with the ResNet-34 MIL checkpoint. For the (up to) 15
    positive tiles (probability >= 0.5) with the highest probability, a row is written
    to ``output_folder/<slide stem>.csv`` with:
        - the tile's mean RGB and H/D stain intensities, plus their slide-wide
          standard deviations;
        - its distance to the nearest other positive tile, plus the slide-wide
          standard deviation of those distances;
        - its probability and the slide's ground-truth label.

    Slides without positive tiles get a header-only CSV.

    Args:
        args: namespace with ``batch_size``, ``workers``, ``tile_size``, ``overlap``,
            ``mult`` and ``level``.
        slide_directory: folder scanned (non-recursively) for ``*.svs`` files.
        model_path: checkpoint with a ``'state_dict'`` for ResNet-34 with a 2-class head.
        data_folder: where the per-slide ``*_data.pt`` library files are written.
        output_folder: where the per-slide CSV files are written (created if missing).
    """
    os.makedirs(data_folder, exist_ok=True)  # Create data folder if it doesn't exist
    # The CSVs are written into output_folder, so it has to exist too.
    os.makedirs(output_folder, exist_ok=True)

    def generate_grid(slide_path, tile_size=224, overlap=0):
        """Top-left (x, y) tile origins covering the slide at ``args.level``, x-major order."""
        slide = openslide.OpenSlide(slide_path)
        try:
            width, height = slide.level_dimensions[args.level]
        finally:
            slide.close()  # only the dimensions are needed; release the handle

        # Very wide slides are only tiled over their left-most 18000 px.
        width = min(width, 18000)
        step = tile_size - overlap
        return [(x, y) for x in range(0, width, step) for y in range(0, height, step)]

    def extract_label_from_filename(filename):
        """Integer formed by the digits after the last '_' (before the first '.'); None if there are none."""
        try:
            label_str = filename.split('_')[-1].split('.')[0]
            return int(''.join(filter(str.isdigit, label_str)))
        except ValueError:
            print(f"Failed to extract label from filename: {filename}")
            return None

    def prepare_mil_input_data(slide_path, tile_size=224, overlap=0, mult=1.0, level=0):
        """Single-slide MIL library dict in the layout MILdataset loads."""
        return {
            'slides': [slide_path],
            'grid': [generate_grid(slide_path, tile_size, overlap)],
            'targets': [extract_label_from_filename(os.path.basename(slide_path))],
            'mult': mult,
            'level': level,
        }

    # load_state_dict is strict by default, so every parameter comes from the
    # checkpoint; downloading ImageNet weights first (pretrained=True) was wasted work.
    model = models.resnet34()
    model.fc = nn.Linear(model.fc.in_features, 2)
    checkpoint = torch.load(model_path)
    model.load_state_dict(checkpoint['state_dict'])
    model = model.cuda()

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    trans = transforms.Compose([transforms.ToTensor(), normalize])

    fieldnames = ['tile_index', 'intensity_R', 'intensity_G', 'intensity_B', 'intensity_H', 'intensity_D', 'stddev_R', 'stddev_G', 'stddev_B', 'stddev_H', 'stddev_D', 'min_distance', 'stddev_min_distance', 'probability', 'ground_truth']
    top_k = 15

    for filename in sorted(os.listdir(slide_directory)):  # sorted: deterministic order
        if not filename.endswith(".svs"):
            continue
        slide_path = os.path.join(slide_directory, filename)
        data_file = os.path.join(data_folder, f"{os.path.splitext(filename)[0]}_data.pt")

        mil_data = prepare_mil_input_data(slide_path, args.tile_size, args.overlap, args.mult, args.level)
        torch.save(mil_data, data_file)

        dset = MILdataset(data_file, trans)
        try:
            loader = DataLoader(dset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True)
            dset.setmode(1)
            _, tile_info = inference(loader, model)

            # Measurements of positive tiles (prob >= 0.5) per slide; all lists share
            # the same tile order, so an index into one is valid for all of them.
            slide_info = {}
            for slide_idx, coord, prob in tile_info:
                slide_name = dset.slidenames[slide_idx]
                if slide_name not in slide_info:
                    slide_info[slide_name] = {'tile_intensities': [], 'ihc_intensities': [], 'tile_coords': [],
                                              'probabilities': [], 'ground_truth': get_ground_truth(slide_name)}
                if prob >= 0.5:
                    img_array = np.array(dset.slides[slide_idx].read_region(coord, 0, (224, 224)).convert('RGB'))
                    ihc_h, ihc_d = color_separate(img_array)
                    info = slide_info[slide_name]
                    info['tile_intensities'].append(np.mean(img_array, axis=(0, 1)))
                    info['ihc_intensities'].append((np.mean(ihc_h), np.mean(ihc_d)))
                    info['tile_coords'].append(coord)
                    info['probabilities'].append(prob)

            for slide_name, info in slide_info.items():
                tile_intensities = info['tile_intensities']
                ihc_intensities = info['ihc_intensities']
                probabilities = info['probabilities']  # Python floats
                # Every stored tile already has prob >= 0.5, so all of them take part.
                min_distances, _ = compute_distances(np.array(info['tile_coords']))
                top_indices = np.argsort(probabilities)[::-1][:top_k]  # highest probability first

                # slide_name is the full slide path: joining it onto output_folder
                # would discard output_folder, so keep only the file stem.
                slide_stem = os.path.splitext(os.path.basename(slide_name))[0]
                output_csv = os.path.join(output_folder, f"{slide_stem}.csv")
                with open(output_csv, 'w', newline='') as csvfile:
                    writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
                    writer.writeheader()
                    if len(top_indices) > 0:
                        # Slide-wide spreads are identical for every row; compute once.
                        stddev_rgb = np.std(tile_intensities, axis=0)
                        stddev_ihc = np.std(ihc_intensities, axis=0)
                        stddev_min_distance = np.std(min_distances)
                        for rank, idx in enumerate(top_indices):
                            writer.writerow({
                                'tile_index': rank,
                                'intensity_R': tile_intensities[idx][0],
                                'intensity_G': tile_intensities[idx][1],
                                'intensity_B': tile_intensities[idx][2],
                                'intensity_H': ihc_intensities[idx][0],
                                'intensity_D': ihc_intensities[idx][1],
                                'stddev_R': stddev_rgb[0],
                                'stddev_G': stddev_rgb[1],
                                'stddev_B': stddev_rgb[2],
                                'stddev_H': stddev_ihc[0],
                                'stddev_D': stddev_ihc[1],
                                'min_distance': min_distances[idx],
                                'stddev_min_distance': stddev_min_distance,
                                'probability': probabilities[idx],
                                'ground_truth': info['ground_truth'],
                            })
        finally:
            # MILdataset opens every slide and never closes it; release the handles
            # before moving on to the next slide.
            for slide in dset.slides:
                slide.close()

# Run tile-level feature extraction for every .svs slide in slide_directory.
# In a Jupyter/IPython kernel __name__ is "__main__", so this executes when the cell runs.
if __name__ == "__main__":
    # `level` selects which pyramid level's dimensions generate_grid tiles over, while
    # MILdataset reads tiles at level 0. NOTE(review): confirm level=0 is intended.
    args = Namespace(batch_size=256, workers=4, tile_size=224, overlap=0, mult=1.0, level=0)
    # Folder scanned for *.svs slides.
    slide_directory = '/content/drive/MyDrive/AI_IHC_LANA_Positivity/02.17.24_MIL_analysis/testing_WSI/'
    # NOTE(review): this is a different checkpoint file (..._10epochs_k10.pth) than the
    # slide-level cells use (..._25epochs_k60.pth) -- confirm which one is intended.
    model_path = '/content/drive/MyDrive/AI_IHC_LANA_Positivity/02.17.24_MIL_analysis/MIL_output/MIL_checkpoint_best_10epochs_k10.pth'
    data_folder = '/content/drive/MyDrive/AI_IHC_LANA_Positivity/02.17.24_MIL_analysis/data_loader/'  # Folder to save data files
    # Output folder passed to main for the per-slide tile CSVs.
    output_folder = '/content/drive/MyDrive/AI_IHC_LANA_Positivity/02.17.24_MIL_analysis/ML/testing/'

    main(args, slide_directory, model_path, data_folder, output_folder)


ModuleNotFoundError: No module named 'openslide'