# Setting up the Notebook

In [6]:
# Imports

import argparse
import csv
import os

import albumentations as A
import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from transformers import MaskFormerImageProcessor, MaskFormerForInstanceSegmentation


# Loading in Scott's Model

In [8]:
# Class names for the segmentation head, keyed by integer label id (0..3).
# Label 0 ("unlabeled") is the id later passed to the preprocessor as ignore_index.
id2label = dict(enumerate(("unlabeled", "background", "disc", "cup")))


def color_palette():
    """Return the visualisation colours as a list of [R, G, B] lists.

    The palette has three pure colours, in order: red, green, blue. These are
    the same colours that ``convert_to_single_channel`` decodes as labels
    1, 2 and 3 respectively, so entry ``i`` presumably corresponds to label
    ``i + 1`` (background, disc, cup in ``id2label``). NOTE(review): confirm
    against the plotting code that indexes this palette. There is no colour
    for label 0 ("unlabeled").

    A new list is built on every call, so callers may mutate the result.
    """
    red = [255, 0, 0]
    green = [0, 255, 0]
    blue = [0, 0, 255]
    return [red, green, blue]

# Normalisation statistics (per channel) applied by the evaluation pipeline.
# NOTE(review): despite the ADE_ prefix these are not the usual ADE20k/ImageNet
# values; presumably they were computed on the fundus training images — confirm.
ADE_MEAN, ADE_STD = (
    np.array((0.709, 0.439, 0.287)),
    np.array((0.210, 0.220, 0.199)),
)

# Colours used when visualising masks.
palette = color_palette()

# Evaluation pipeline: resize to a 512x512 input, then normalise with the stats above.
test_transform = A.Compose(
    [
        A.Resize(height=512, width=512),
        A.Normalize(mean=ADE_MEAN, std=ADE_STD),
    ]
)

# BEGIN TEST

# Use the first GPU when available; otherwise fall back to CPU so the cell
# still runs on a machine without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Start from the ADE20k-pretrained MaskFormer and replace its classification head
# with one sized for id2label; ignore_mismatched_sizes lets the mismatched head
# weights be re-initialised instead of raising.
model = MaskFormerForInstanceSegmentation.from_pretrained(
    "facebook/maskformer-swin-base-ade",
    id2label=id2label,
    ignore_mismatched_sizes=True,
).to(device)

# Paths to the fine-tuned weights.
model_savedir = '/sddata/projects/Conformal_Uncertainty_Quantification/models'  # NOTE(review): not used in this cell
model_filename = '/sddata/projects/glaucoma_segmentation_scott/best_maskformer_weights/best_model.pt'

# map_location puts every tensor on `device`, regardless of which device the
# checkpoint was saved from. torch.load unpickles the file, so only load
# checkpoints from a trusted source.
state_dict = torch.load(model_filename, map_location=device)

# Strict load (the default): raises if any key is missing or unexpected.
model.load_state_dict(state_dict)

Some weights of MaskFormerForInstanceSegmentation were not initialized from the model checkpoint at facebook/maskformer-swin-base-ade and are newly initialized because the shapes did not match:
- class_predictor.weight: found shape torch.Size([151, 256]) in the checkpoint and torch.Size([5, 256]) in the model instantiated
- class_predictor.bias: found shape torch.Size([151]) in the checkpoint and torch.Size([5]) in the model instantiated
- criterion.empty_weight: found shape torch.Size([151]) in the checkpoint and torch.Size([5]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


<All keys matched successfully>

# Loading in the Testing Data

In [None]:
def convert_to_single_channel(image_path):
    """Convert a colour-coded mask image into a 2-D integer label map.

    Each pixel is labelled by whichever RGB channel is largest:
    red -> 1, green -> 2, blue -> 3.

    Args:
        image_path: Path to the mask image. It is converted to RGB before
            labelling, so any PIL mode is accepted.

    Returns:
        np.ndarray of shape (H, W) holding labels in {1, 2, 3}, with the
        integer dtype returned by ``np.argmax``.

    Note:
        ``np.argmax`` returns the first maximal channel on ties, so ties go to
        the earliest channel (red, then green). In particular, pure-black
        pixels (0, 0, 0) are labelled 1.
    """
    # The context manager closes the file even if decoding or conversion fails.
    with Image.open(image_path) as image:
        rgb = np.array(image.convert('RGB'))

    # Channel index 0/1/2 -> label 1/2/3.
    return np.argmax(rgb, axis=2) + 1

def convert_grayscale_disc_image(image_path):
    """Convert a binary optic-disc mask image into a 2-D label map.

    The image is converted to 8-bit grayscale. Pixels with intensity 0 become
    1 (background); every non-zero pixel becomes 2 (disc). Note the threshold
    is "non-zero", not 255, so any faint non-zero pixel (e.g. compression
    noise) is counted as disc.

    Args:
        image_path: Path to the mask image (any PIL mode).

    Returns:
        np.ndarray of shape (H, W), dtype uint8, with values in {1, 2}.
    """
    # The context manager closes the file even if decoding or conversion fails.
    with Image.open(image_path) as image:
        gray = np.array(image.convert('L'), dtype=np.uint8)

    # A single np.where avoids the order dependence of two in-place masked
    # assignments (swapping them would have turned every pixel into 2).
    return np.where(gray > 0, 2, 1).astype(np.uint8)


class ImageSegmentationDataset(Dataset):
    """Dataset of (image, segmentation mask) pairs read from disk.

    Each item is the 4-tuple
    ``(image, segmentation_map, original_image, original_segmentation_map)``:

    * ``image`` - the transformed image, transposed with axes (2, 0, 1)
      (HWC -> CHW);
    * ``segmentation_map`` - the transformed mask (``transform``'s ``'mask'``);
    * ``original_image`` - the untransformed image as an RGB array;
    * ``original_segmentation_map`` - the untransformed label map.

    Args:
        images: Sequence of image paths/names, one per sample.
        masks: Sequence of mask paths/names, aligned index-by-index with ``images``.
        transform: Callable invoked as ``transform(image=..., mask=...)`` that
            returns a dict with ``'image'`` and ``'mask'`` keys (e.g. an
            albumentations ``Compose``).
        data_root_dir: Root directory prefix. It is concatenated directly with
            file names, so it should end with a path separator.
        convert_disc: If True, masks are read as grayscale disc masks via
            ``convert_grayscale_disc_image``; otherwise as colour-coded masks
            via ``convert_to_single_channel``.
        flat_layout: If True, only the file name (text after the last '/') of
            each entry is used and it is appended directly to ``data_root_dir``.
            If False, images are read from ``<root>images/`` and masks from
            ``<root>labels/``. If None (default), the flat layout is used only
            when ``data_root_dir`` equals ``_RIMONE_ROOT``.

    Raises:
        ValueError: If ``images`` and ``masks`` have different lengths.
    """

    # Root directory whose files are stored flat (no images/ or labels/ subdirs).
    _RIMONE_ROOT = '/projects/skinder@xsede.org/seg_paper/test_datasets/rimone/'

    def __init__(self, images, masks, transform, data_root_dir, convert_disc=False, flat_layout=None):
        if len(images) != len(masks):
            raise ValueError(
                f"images and masks must have the same length, got {len(images)} and {len(masks)}"
            )
        self.images = images
        self.masks = masks
        self.transform = transform
        self.data_root_dir = data_root_dir
        self.convert_disc = convert_disc
        if flat_layout is None:
            flat_layout = data_root_dir == self._RIMONE_ROOT
        self.flat_layout = bool(flat_layout)

    def __len__(self):
        return len(self.images)

    def _resolve_paths(self, idx):
        """Return ``(image_path, mask_path)`` for sample ``idx``."""
        if self.flat_layout:
            # Entries may be full paths from another machine; keep only the file name.
            image_name = self.images[idx].split('/')[-1]
            mask_name = self.masks[idx].split('/')[-1]
            return self.data_root_dir + image_name, self.data_root_dir + mask_name
        return (
            self.data_root_dir + 'images/' + self.images[idx],
            self.data_root_dir + 'labels/' + self.masks[idx],
        )

    def __getitem__(self, idx):
        path_to_img, path_to_mask = self._resolve_paths(idx)

        # Force 3-channel RGB so grayscale / palette / RGBA files work with the
        # 3-channel normalisation and the HWC -> CHW transpose below.
        with Image.open(path_to_img) as pil_image:
            original_image = np.array(pil_image.convert('RGB'))

        if self.convert_disc:
            original_segmentation_map = convert_grayscale_disc_image(path_to_mask)
        else:
            original_segmentation_map = convert_to_single_channel(path_to_mask)

        transformed = self.transform(image=original_image, mask=original_segmentation_map)
        # convert to C, H, W
        image = transformed['image'].transpose(2, 0, 1)
        segmentation_map = transformed['mask']

        return image, segmentation_map, original_image, original_segmentation_map


# Load the CSV listing the test images and their disc masks.
# NOTE(review): these lines were commented out, which left test_image_paths /
# test_mask_paths undefined on a fresh kernel (NameError on Restart & Run All).
# Confirm this CSV and the subsampling step are the intended test selection.
TEST_CSV_PATH = '/scratch/alpine/skinder@xsede.org/glaucomachris/r1.csv'
RIMONE_ROOT = '/projects/skinder@xsede.org/seg_paper/test_datasets/rimone/'
# Keep every TEST_SUBSAMPLE_STEP-th row (the original note: "use ::5 if you use
# the test set"); set to 1 to use every row.
TEST_SUBSAMPLE_STEP = 5

data_df = pd.read_csv(TEST_CSV_PATH)
test_image_paths = data_df['Cropped_Images'].tolist()[::TEST_SUBSAMPLE_STEP]
test_mask_paths = data_df['Cropped_Disks'].tolist()[::TEST_SUBSAMPLE_STEP]

test_dataset = ImageSegmentationDataset(
    test_image_paths,
    test_mask_paths,
    transform=test_transform,
    data_root_dir=RIMONE_ROOT,
    convert_disc=True,
)

# Create a preprocessor. Resizing and normalisation are already done by
# test_transform, so they are disabled here. Label 0 is ignored, and labels are
# used as-is (do_reduce_labels replaces the deprecated reduce_labels kwarg).
preprocessor = MaskFormerImageProcessor(
    ignore_index=0,
    do_reduce_labels=False,
    do_resize=False,
    do_rescale=False,
    do_normalize=False,
)
