In [None]:
import torch
import os
import numpy as np
import random
from openslide import open_slide
from PIL import Image
from torch.utils.data import Dataset

In [None]:
# Lift PIL's pixel-count limit (None disables the decompression-bomb size check),
# presumably so very large slide images can be opened without an error.
Image.MAX_IMAGE_PIXELS = None

### Data functions

In [None]:
class Patch:
    '''
    Store properties of each patch.

    Attributes:
    image: the patch pixels (as produced by image_to_patches_with_positions)
    position (tuple): (top, left) pixel offset of the patch in the source image
    size (int): side length of the square patch in pixels
    probability: model probability for class 1, or None until it has been set
    prediction: class label derived from probability, or None until computed
    is_background (bool): whether the patch was judged to be background
    '''

    def __init__(self, image, position, size=256):
        self.image = image
        self.position = position
        self.size = size
        self.probability = None
        self.prediction = None
        self.is_background = False

    def set_probability(self, probability):
        '''Store the class-1 probability for this patch.'''
        self.probability = probability

    def get_prediction(self, threshold=0.5):
        '''
        Threshold the stored probability into a 0/1 class label.

        Parameters:
        threshold (float): probabilities at or above this value become class 1

        Returns:
        int: the label, which is also stored on self.prediction

        Raises:
        TypeError: if no probability has been set yet (probability is None)
        '''
        # float() accepts Python numbers as well as 0-dim / single-element tensors
        self.prediction = 1 if float(self.probability) >= threshold else 0
        return self.prediction

### Helper functions

In [None]:
def scale_tensor(tensor: torch.Tensor):
    '''
    Min-max scale a tensor to the range [0, 1].

    Integer tensors are promoted to float32 first; floating-point tensors keep
    their dtype. A constant tensor (max == min) has no range to scale by, so an
    all-zeros tensor is returned instead of the NaNs a 0/0 division would give.
    '''
    if not tensor.is_floating_point():
        tensor = tensor.float()
    minn = tensor.min()
    value_range = tensor.max() - minn
    if value_range == 0:
        return torch.zeros_like(tensor)
    tensor = (tensor - minn) / value_range
    # Guard against tiny floating-point overshoot outside [0, 1]
    return torch.clamp(tensor, 0, 1)

def image_to_patches_with_positions(image, patch_size: int, stride: int):
    '''
    Split an image into square patches and record where each one came from.

    Parameters:
    image (np.ndarray): input image, either (H, W) or (H, W, C)
    patch_size (int): side length of each square patch
    stride (int): step between neighbouring patches; stride == patch_size gives no overlap

    Returns:
    patches (torch.Tensor): min-max scaled to [0, 1]; shape
        (num_patches, C, patch_size, patch_size) for an (H, W, C) image, or
        (num_patches, patch_size, patch_size) for an (H, W) image
    positions (list of tuple): (top, left) pixel offset of each patch, in the
        same order as `patches`

    Pixels in the right/bottom strips that do not fill a whole patch are dropped.
    '''
    # Convert to a tensor and scale to [0, 1]
    im = scale_tensor(torch.from_numpy(image))

    # Unfold rows, then columns -> (n_rows, n_cols, [C,] patch_size, patch_size).
    # Rows are outermost, so the flattened order matches `positions` below.
    patches = im.unfold(0, patch_size, stride).unfold(1, patch_size, stride)
    if image.ndim == 3:
        # Use the real channel count (e.g. 4 for RGBA) rather than assuming 3;
        # a wrong count would silently mix channels between patches in view().
        num_channels = image.shape[2]
        # .contiguous() is required before .view() on the unfolded tensor
        patches = patches.contiguous().view(-1, num_channels, patch_size, patch_size)
    else:
        patches = patches.contiguous().view(-1, patch_size, patch_size)

    # Top-left corner of every patch, matching the unfold grid
    height, width = image.shape[:2]
    num_patches_h = (height - patch_size) // stride + 1
    num_patches_w = (width - patch_size) // stride + 1
    positions = [
        (row * stride, col * stride)
        for row in range(num_patches_h)
        for col in range(num_patches_w)
    ]

    return patches, positions

def get_patch_objects(patches, positions):
    '''
    Wrap each patch and its (top, left) position in a Patch object.

    Patches and positions are paired in order; if one sequence is longer, its
    extra items are ignored (zip semantics).
    '''
    return [Patch(image=img, position=pos) for img, pos in zip(patches, positions)]

# TODO: check that the background threshold is appropriate for this dataset
def check_if_background(patch, threshold=220):
    '''
    Given a patch, return whether it should be classified as a background patch.

    A patch is background when its mean pixel intensity, on a 0-255 scale, is at
    least `threshold` (i.e. the patch is near-white).

    Parameters:
    patch: the patch pixels, as a PIL image, a torch tensor or a numpy array.
        Floating-point tensors/arrays whose values are all <= 1 (such as the
        patches from image_to_patches_with_positions, which are scaled to
        [0, 1]) are multiplied by 255 before the comparison.
    threshold (float): mean intensity at or above which the patch is background

    Returns:
    bool: True for a background patch
    '''
    if isinstance(patch, torch.Tensor):
        pixels = patch.detach().cpu().numpy()
    elif hasattr(patch, 'convert'):
        # PIL image: normalise the mode to RGB before measuring intensity
        pixels = np.asarray(patch.convert(mode='RGB'))
    else:
        pixels = np.asarray(patch)

    rescale = (
        np.issubdtype(pixels.dtype, np.floating)
        and pixels.size > 0
        and pixels.max() <= 1.0
    )
    pixels = pixels.astype(np.float64)
    if rescale:
        # NOTE(review): [0, 1] data here comes from min-max scaling of the whole
        # slide, so x255 only approximates the original intensities - confirm.
        pixels = pixels * 255.0
    return bool(pixels.mean() >= threshold)

def choose_random_image(directory, seed):
    '''
    Choose an SVS file at random (reproducibly for a given seed) to perform
    inference on and produce a heatmap for.

    Parameters:
    directory (str): folder containing the .svs files
    seed: seed for a private random generator (the global `random` state is untouched)

    Returns:
    case_code (str): file name with the 'TCGA-' prefix and '-01Z-00-DX1' suffix removed
    slide (np.ndarray): the slide rendered at pyramid level 1 via get_thumbnail

    Raises:
    FileNotFoundError: if `directory` contains no usable .svs files
    '''
    # Skip macOS AppleDouble metadata files ("._name.svs"): they share the
    # extension but are not slides, so open_slide would fail on them.
    # Sorting removes the dependence on os.listdir's arbitrary order, so the
    # same seed picks the same file on every machine.
    svs_files = sorted(
        file for file in os.listdir(directory)
        if file.endswith('.svs') and not file.startswith('._')
    )
    if not svs_files:
        raise FileNotFoundError(f'No .svs files found in {directory!r}')

    random_svs = random.Random(seed).choice(svs_files)
    print(random_svs)

    name = random_svs[:-len('.svs')]
    case_code = name.split('.')[0].replace('TCGA-', '').replace('-01Z-00-DX1', '')

    sld = open_slide(os.path.join(directory, random_svs))
    try:
        # Pyramid level 1 - presumably 10X magnification for these scans (confirm)
        slide_width, slide_height = sld.level_dimensions[1]
        slide = np.array(sld.get_thumbnail(size=(slide_width, slide_height)))
    finally:
        # Release the underlying OpenSlide handle
        sld.close()

    return case_code, slide

def check_seg_accuracy(label_directory, case_code):
    '''
    Placeholder: compare patch predictions against the ground-truth labels for
    `case_code` stored in `label_directory`.

    Not implemented yet. Open problem: the labels may not contain the same
    number of background patches as this pipeline produces, so predictions and
    labels need to be aligned (for example by patch position) before an
    accuracy can be computed.

    Raises:
    NotImplementedError: always, until the comparison is written.
    '''
    raise NotImplementedError('check_seg_accuracy is not implemented yet')

def get_prediction(patch, output):
    '''
    Convert raw model logits into a class-1 probability and a predicted class,
    stored in place on `patch`.

    Parameters:
    patch: object (normally a Patch) whose `probability` and `prediction`
        attributes are overwritten
    output (torch.Tensor): logits of shape (num_classes,) or (batch, num_classes)

    For a single sample the stored values are Python scalars (float probability,
    int class); for a batch they are 1-D tensors with one entry per sample.
    '''
    if output.dim() == 1:
        # Treat bare logits as a batch of one
        output = output.unsqueeze(0)
    probabilities = torch.softmax(output, dim=1)
    # Column 1 is the probability of class 1 (the class index, not the batch index)
    positive = probabilities[:, 1]
    predicted_class = torch.argmax(probabilities, dim=1)
    if positive.numel() == 1:
        patch.probability = positive.item()
        patch.prediction = int(predicted_class.item())
    else:
        patch.probability = positive.detach()
        patch.prediction = predicted_class

def create_heatmap(image_size, patches):
    '''
    Generate the heatmap array based on patch predicted probabilities.

    Parameters:
    image_size (tuple): (height, width[, channels]) of the source image; only
        height and width are used, so the heatmap is always 2-D
    patches: iterable of objects with `position` (top, left), `size` (int for a
        square patch, or a (height, width) pair) and a numeric `probability`

    Returns:
    np.ndarray: float array of shape (height, width); pixels covered by no patch
    stay 0, and where patches overlap the later patch wins
    '''
    heat_map = np.zeros(tuple(image_size)[:2], dtype=float)
    for patch in patches:
        top, left = patch.position
        size = patch.size
        if isinstance(size, (tuple, list)):
            height, width = size
        else:
            height = width = size
        # float() also accepts 0-dim / single-element torch tensors
        heat_map[top:top + height, left:left + width] = float(patch.probability)
    return heat_map

def inference(patches, model, image_size=None):
    '''
    Takes in Patch objects and makes predictions for each patch if it is not
    classified as a background patch, then uses those predictions to create a heatmap.

    Background patches (per check_if_background) are not sent to the model; they
    get probability 0 and prediction 0. Every other patch is passed to `model`
    individually and its logits are converted by get_prediction. All results are
    stored on the Patch objects in place.

    Parameters:
    patches: iterable of Patch objects (see get_patch_objects)
    model: presumably a torch nn.Module returning logits of shape (batch, num_classes)
    image_size (tuple, optional): (height, width[, channels]) of the heatmap;
        defaults to the smallest area that covers every patch

    Returns:
    np.ndarray: the 2-D heatmap from create_heatmap
    '''
    patches = list(patches)  # iterated more than once below
    model.eval()
    # Inference only: no autograd graph is needed
    with torch.no_grad():
        for patch in patches:
            patch.is_background = check_if_background(patch.image)
            if patch.is_background:
                patch.probability = 0
                patch.prediction = 0
                continue
            # The model expects a leading batch dimension: (C, H, W) -> (1, C, H, W).
            # NOTE(review): the input stays on the CPU; move it to the model's
            # device if the model lives on a GPU.
            output = model(patch.image.unsqueeze(0))
            get_prediction(patch, output)

    if image_size is None:
        height = width = 0
        for patch in patches:
            top, left = patch.position
            size = patch.size
            patch_h, patch_w = size if isinstance(size, (tuple, list)) else (size, size)
            height = max(height, top + patch_h)
            width = max(width, left + patch_w)
        image_size = (height, width)

    return create_heatmap(image_size, patches)

def visualise_heatmap(image, heatmap, alpha=0.5, colour=(255, 0, 0)):
    '''
    Overlay a probability heatmap on an image as a coloured tint.

    Each pixel is blended towards `colour` in proportion to alpha * heatmap
    value: probability 0 leaves the pixel unchanged, probability 1 applies the
    full `alpha` tint.

    Parameters:
    image (np.ndarray): (H, W) greyscale or (H, W, C>=3) image with values in
        0-255, such as the slide array; channels beyond RGB are dropped
    heatmap (np.ndarray): (H, W) probabilities; values are clipped to [0, 1]
    alpha (float): maximum opacity of the tint
    colour (tuple): RGB colour of the tint

    Returns:
    np.ndarray: (H, W, 3) uint8 RGB overlay

    Raises:
    ValueError: if the heatmap's height/width differ from the image's
    '''
    img = np.asarray(image, dtype=np.float64)
    if img.ndim == 2:
        img = np.stack([img] * 3, axis=-1)
    img = img[..., :3]

    heat = np.clip(np.asarray(heatmap, dtype=np.float64), 0.0, 1.0)
    if heat.shape != img.shape[:2]:
        raise ValueError(
            f'heatmap shape {heat.shape} does not match image shape {img.shape[:2]}'
        )

    # Per-pixel blend weight, broadcast across the RGB channels
    weight = alpha * heat[..., np.newaxis]
    blended = img * (1.0 - weight) + np.asarray(colour, dtype=np.float64) * weight
    return np.clip(np.rint(blended), 0, 255).astype(np.uint8)

In [None]:
# Side length, in pixels, of each square patch
PATCH_SIZE = 256
# Step between patch origins; equal to PATCH_SIZE, so patches tile without overlap
STRIDE = PATCH_SIZE
# Number of classes the patch classifier distinguishes
num_classes = 2

In [None]:
# Seed for reproducible slide selection
SEED = 42
# Location of the whole-slide .svs images and of the processed label files
images_directory = '/Volumes/AlexS/MastersData/SVS files/'
labels_directory = '/Volumes/AlexS/MastersData/processed/labels/'

In [None]:
# Load a randomly chosen slide; choose_random_image returns (case_code, slide_array)
case_code, image = choose_random_image(images_directory, SEED)
image_size = image.shape
# Split the slide into patches, keeping each patch's (top, left) position
patches, positions = image_to_patches_with_positions(image, PATCH_SIZE, STRIDE)
# Wrap the patches in Patch objects so predictions can be stored on them
patch_objects = get_patch_objects(patches, positions)
# Classify every patch.
# NOTE: `model` (the trained patch classifier) must be loaded in an earlier cell.
inference(patch_objects, model)
# Build the heatmap over the full slide area so it can be overlaid on `image`
heatmap = create_heatmap(image_size, patch_objects)



- TODO: overlay the generated heatmap on the original slide.
- TODO: also build a patch-level heatmap of the slide's ground-truth classes, to visualise the reference segmentation alongside the predictions.