# Install libraries and set up environment

In [2]:
# Notebook-wide settings read by the cells below.
IMG_PX_SIZE = 128  # height and width, in pixels, that the read_* helpers resize images to
PROJECT_LOCATION = 'D:/Niket'  # same prefix as the dataset paths used in the example cells
SEG_TASK = 19 # index into the stacked label volumes; 19 is R femur
TEST_IMAGE_IDX = 25  # index of the test image picked for visualisation

In [3]:
import math
import itertools
import sys
import torch
import os
import pydicom

import numpy as np
import matplotlib.pyplot as plt
import einops as E
import nibabel as nib
import random

from tqdm.auto import tqdm
from torchvision import transforms
from PIL import Image, ImageOps
from skimage.transform import resize
from collections import defaultdict

# Functions for reading in NIfTI images

In [4]:
# Make the local UniverSeg checkout importable before importing from it.
sys.path.append('UniverSeg')
from universeg import universeg

# Pick the GPU when one is available and fall back to the CPU otherwise.
# The device must not be hard-coded to 'cuda': model.to('cuda') fails on a
# machine without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
n_gpu = torch.cuda.device_count()  # 0 when no GPU is available

# Load the pretrained UniverSeg model and move it to the selected device.
model = universeg(pretrained=True)
_ = model.to(device)

In [5]:
def visualize_tensors(tensors, col_wrap=8, col_names=None, title=None):
    """Plot several groups of image tensors on one grid of subplots.

    Args:
        tensors: Mapping from a group label to a sequence of tensors. The
            number of tensors per group is taken from the first group.
        col_wrap: Number of columns in the grid; longer groups wrap onto
            further rows.
        col_names: Optional titles for the top row, indexed by column.
        title: Optional figure title.
    """
    n_groups = len(tensors)
    n_items = len(next(iter(tensors.values())))

    cols = col_wrap
    rows = math.ceil(n_items / cols) * n_groups

    cell = 2.5
    # squeeze=False always yields a 2-D array of axes, so the [row, col]
    # indexing below also works for a single row and/or a single column.
    fig, axes = plt.subplots(
        rows, cols, figsize=(cell * cols, cell * rows), squeeze=False
    )

    for g, (grp, group_tensors) in enumerate(tensors.items()):
        for k, tensor in enumerate(group_tensors):
            col = k % cols
            # Groups are interleaved: each wrap of `cols` items starts a new
            # band of n_groups rows.
            row = g + n_groups * (k // cols)
            x = tensor.detach().cpu().numpy().squeeze()
            ax = axes[row, col]
            if len(x.shape) == 2:
                ax.imshow(x, vmin=0, vmax=1, cmap='gray')
            else:
                # Anything that is not 2-D is treated as channel-first.
                ax.imshow(E.rearrange(x, 'C H W -> H W C'))
            if col == 0:
                ax.set_ylabel(grp, fontsize=16)
            if col_names is not None and row == 0:
                ax.set_title(col_names[col])

    # Strip grid lines and ticks from every cell, including unused ones.
    for ax in axes.flat:
        ax.grid(False)
        ax.set_xticks([])
        ax.set_yticks([])

    if title:
        plt.suptitle(title, fontsize=20)

    plt.tight_layout()


def visualize_single_tensor(tensor, title=None):
    """Show one tensor as a small grayscale image and display the figure.

    Args:
        tensor: Tensor to draw; it is detached, moved to the CPU and squeezed.
        title: Optional title placed above the image.
    """
    side = 2.5
    fig, ax = plt.subplots(1, 1, figsize=(side, side))

    pixels = tensor.detach().cpu().numpy().squeeze()
    # 2-D and higher-rank arrays are drawn with the same call; channel-first
    # input is not rearranged here.
    ax.imshow(pixels, vmin=0, vmax=1, cmap='gray')

    ax.grid(False)
    for clear_ticks in (ax.set_xticks, ax.set_yticks):
        clear_ticks([])

    if title:
        plt.title(title, fontsize=16)

    plt.tight_layout()
    plt.show()


def files_to_tensor(directory_path, file_type):
    """Read every matching file in a directory and stack the results.

    Args:
        directory_path: Directory to scan (not recursive).
        file_type: 'dicom' (every entry is read), 'jpg' (names ending in
            .jpg/.jpeg) or 'nifti' (names ending in .nii.gz).

    Returns:
        The per-file tensors stacked along a new first dimension, ordered by
        file name.

    Raises:
        ValueError: If file_type is not supported, or no matching file exists.
    """
    if file_type == 'dicom':
        suffixes = None
        read_function = read_dicom
    elif file_type == 'jpg':
        suffixes = ('.jpg', '.jpeg')
        read_function = read_jpg
    elif file_type == 'nifti':
        suffixes = ('.nii.gz',)
        read_function = read_nifti
    else:
        raise ValueError("Unsupported file type")

    # os.listdir returns names in arbitrary order. Sort them so positions in
    # the stack (which callers index by number) are reproducible.
    files = sorted(
        name for name in os.listdir(directory_path)
        if suffixes is None or name.endswith(suffixes)
    )
    if not files:
        raise ValueError(f"No {file_type} files found in {directory_path!r}")

    return torch.stack(
        [read_function(os.path.join(directory_path, name)) for name in files]
    )


def read_dicom(file_path):
    """Load a DICOM file as a tensor resized to IMG_PX_SIZE x IMG_PX_SIZE.

    Pixel values are divided by the image maximum. An image whose maximum is
    zero is left unscaled rather than turned into NaNs by a 0/0 division.

    Args:
        file_path: Path of the DICOM file.

    Returns:
        The result of apply_transforms on the scaled image tensor.
    """
    ds = pydicom.dcmread(file_path)
    pixel_array = ds.pixel_array.astype(np.float32)
    pixel_array = resize(pixel_array, (IMG_PX_SIZE, IMG_PX_SIZE))
    peak = pixel_array.max()
    if peak != 0:
        pixel_array /= peak
    image_tensor = torch.from_numpy(pixel_array)
    return apply_transforms(image_tensor)


def read_jpg(file_path):
    """Load a JPEG as an inverted grayscale tensor of IMG_PX_SIZE x IMG_PX_SIZE.

    Args:
        file_path: Path of the image file.

    Returns:
        The result of apply_transforms on the ToTensor output.
    """
    # convert() returns an independent image, so the file handle can be
    # released straight away.
    with Image.open(file_path) as source:
        img = source.convert('L')  # 'L' mode for gray-scale images
    img = ImageOps.invert(img)
    img = img.resize((IMG_PX_SIZE, IMG_PX_SIZE))
    image_tensor = transforms.ToTensor()(img)
    return apply_transforms(image_tensor)


def read_nifti(file_path):
    """Load a NIfTI file, min-max normalise it and resize it.

    Args:
        file_path: Path of the NIfTI file.

    Returns:
        The result of apply_transforms on the ToTensor output of the resized,
        normalised array.
    """
    volume = nib.load(file_path).get_fdata().astype(np.float32)

    lowest = np.min(volume)
    highest = np.max(volume)
    span = highest - lowest

    if span == 0:
        # A constant volume has no range to normalise over; use zeros instead
        # of dividing by zero.
        volume = np.zeros_like(volume)
    else:
        volume = (volume - lowest) / span

    # Only the first two axes are given a target size.
    volume = resize(volume, (IMG_PX_SIZE, IMG_PX_SIZE))
    return apply_transforms(transforms.ToTensor()(volume))


def apply_transforms(image_tensor):
    """Run the shared post-load transform pipeline on one image tensor.

    The pipeline is currently empty; it is the single place to add steps that
    every read_* helper should apply.
    """
    steps = [
        # Add your transformations here if needed
    ]
    pipeline = transforms.Compose(steps)
    return pipeline(image_tensor)

# Reading Images Examples

In [36]:
# All example volumes live in the TotalSegmentator dataset folder under the
# project root, so build the paths from PROJECT_LOCATION instead of repeating
# the absolute prefix.
dataset_root = PROJECT_LOCATION + '/Datasets/Totalsegmentator_dataset_v201'

nni_adrenal = files_to_tensor(dataset_root + '/s0000/label', "nifti")
nni_adrenal = nni_adrenal.squeeze()

nni_images_training = files_to_tensor(dataset_root + '/s0000/original', "nifti")
nni_images_training = nni_images_training.squeeze()

nni_labels_training = files_to_tensor(dataset_root + '/s0000/segmentations', "nifti")
nni_labels_training = nni_labels_training.squeeze()

nni_images_test = files_to_tensor(dataset_root + '/s0001/original', "nifti")
nni_images_test = nni_images_test.squeeze()

# Keep only the label volume of the structure selected by SEG_TASK.
nni_labels_training_single_task = nni_labels_training[SEG_TASK]

# Draw up to 15 image/label pairs at matching indices. random.sample raises
# when asked for more items than exist, so cap the count.
n_examples = min(15, len(nni_images_training))
random_indices = random.sample(range(len(nni_images_training)), n_examples)
random_images = [nni_images_training[i] for i in random_indices]
random_labels = [nni_labels_training_single_task[i] for i in random_indices]

visualize_tensors({
    'NNI Training Image': random_images,
    'NNI Training Label': random_labels
}, col_wrap=5, title='Support Set Examples')

# Pick one test image and give it a leading dimension of size 1.
test_images = nni_images_test
image = test_images[TEST_IMAGE_IDX]
image = torch.reshape(image, (1, image.shape[0], image.shape[1]))

visualize_single_tensor(image)

torch.Size([179, 128, 128])


# Functions for Searching

In [6]:
### For the Notebook
import sys
import datetime
import warnings
import torch
import pathlib
import os
import einops
import math
import imagehash
import random
import cv2

from tqdm.auto import tqdm
from collections import defaultdict
from typing import Tuple
from PIL import Image
from torchvision import transforms
from pylab import *
from sentence_transformers import SentenceTransformer, util
from medsimilarity import utils
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from torchmetrics.image import StructuralSimilarityIndexMeasure as SSIM
from torch.nn.utils import prune

import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import albumentations as A

# Make modules under ./src importable.
sys.path.append('src')
# Suppress all warnings for the rest of the session.
warnings.filterwarnings('ignore')
# Seed the torch and NumPy random generators; the stdlib `random` module is
# not seeded here.
torch.manual_seed(156)
np.random.seed(156)

# Shared state used by structural_similarity / lpips_similarity below.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Metric modules, switched to eval mode and moved to the selected device.
ssim = SSIM().eval().to(device)
# NOTE(review): normalize=True presumably means inputs are expected in [0, 1]
# (what ToTensor yields for 8-bit images) — confirm against the installed
# torchmetrics version.
lpips = LearnedPerceptualImagePatchSimilarity(net_type='squeeze',normalize=True).eval().to(device)
# PIL image -> tensor conversion applied before either metric.
transform = transforms.Compose([transforms.ToTensor()])

def process_img(path: pathlib.Path, size: Tuple[int, int]):
    """Load an image as a float32 grayscale array of the requested size.

    Args:
        path: Image file to read.
        size: Target size passed to PIL's resize.

    Returns:
        float32 array of the resized image converted to mode "L"; values are
        not rescaled here.
    """
    # resize() returns an independent image, so the file can be closed as soon
    # as it has been resampled.
    with Image.open(path) as source:
        img = source.resize(size, resample=Image.BILINEAR)
    img = img.convert("L")
    return np.array(img).astype(np.float32)

def process_seg(path: pathlib.Path, size: Tuple[int, int]):
    """Load a segmentation image as three stacked binary masks.

    Args:
        path: Segmentation file to read.
        size: Target size passed to PIL's resize.

    Returns:
        float32 array whose first axis holds the masks for pixel values 0,
        128 and 255, in that order.
    """
    # Nearest-neighbour resampling keeps the label values intact. The array
    # copy is made inside the block so the file can be closed afterwards.
    with Image.open(path) as source:
        seg = np.array(source.resize(size, resample=Image.NEAREST))
    masks = np.stack([seg == 0, seg == 128, seg == 255])
    return masks.astype(np.float32)

def load_folder(path: pathlib.Path, size: Tuple[int, int] = (128, 128)):
    """Load every ``*.bmp`` image in a folder with its ``.png`` segmentation.

    Args:
        path: Folder to scan; files are visited in sorted order.
        size: Size handed to process_img and process_seg.

    Returns:
        List of ``(image / 255.0, segmentation)`` pairs, one per bmp file.
    """
    return [
        (
            process_img(bmp_file, size=size) / 255.0,
            process_seg(bmp_file.with_suffix(".png"), size=size),
        )
        for bmp_file in sorted(path.glob("*.bmp"))
    ]

def dice_score(y_pred: torch.Tensor, y_true: torch.Tensor) -> float:
    """Return the Dice coefficient 2*|A.B| / (|A| + |B|) as a Python float.

    No guard is applied when both inputs sum to zero; the division result is
    returned as is.
    """
    overlap = (y_pred * y_true).sum()
    total = y_pred.sum() + y_true.sum()
    return (2 * overlap / total).item()

@torch.no_grad()
def inference(model, image, label, support_images, support_labels,device):
    """Segment one image with a support set and score it against its label.

    Args:
        model: Callable invoked as model(image, support_images, support_labels)
            on inputs that each carry an added leading batch dimension.
        image: Query image tensor.
        label: Ground-truth tensor for the query image.
        support_images: Support image tensor.
        support_labels: Support label tensor.
        device: Device every input is moved to.

    Returns:
        Dict with keys 'Image', 'Soft Prediction', 'Prediction',
        'Ground Truth' and 'score' (Dice of the rounded prediction).
    """
    image = image.to(device)
    label = label.to(device)
    support_images = support_images.to(device)
    support_labels = support_labels.to(device)

    # The model returns a batch; keep its single element. Outputs are logits.
    with torch.inference_mode():
        logits = model(image[None], support_images[None], support_labels[None])[0].detach()

    soft_pred = torch.sigmoid(logits)
    hard_pred = soft_pred.round().clip(0, 1)

    # Drop the logits and release cached GPU memory before scoring.
    del logits
    torch.cuda.empty_cache()

    result = {
        'Image': image,
        'Soft Prediction': soft_pred,
        'Prediction': hard_pred,
        'Ground Truth': label,
    }
    result['score'] = dice_score(hard_pred, label)
    return result

def plot_tensor(arr1):
    """Show the first channel of every image in a batch on an 8-column grid.

    Args:
        arr1: Tensor indexed as [image, channel, ...]; channel 0 of each image
            is drawn in grayscale.
    """
    num_cols = 8
    count = arr1.size(0)
    # Keep the usual 4-row layout, but grow it when the batch holds more than
    # 32 images so that indexing the grid cannot run past its last row.
    num_rows = max(4, -(-count // num_cols))
    fig, axs = plt.subplots(num_rows, num_cols, figsize=(12, 1.5 * num_rows))

    for i in range(count):
        ax = axs[i // num_cols, i % num_cols]
        # detach()/cpu() make the conversion work for GPU tensors and for
        # tensors that track gradients.
        image_array = arr1[i, 0].detach().cpu().numpy()
        ax.imshow(image_array, cmap='gray')  # Assuming it's a grayscale image
        ax.set_title(f"Image {i+1}")
        ax.axis('off')

    # Adjust layout to prevent clipping
    plt.tight_layout()
    plt.show()

@torch.no_grad()
def inference_mix_prec(model, image, label, support_images, support_labels,device):
    """Run `inference` under bfloat16 autocast and score the prediction.

    Args:
        model: Callable invoked as model(image, support_images, support_labels)
            on inputs that each carry an added leading batch dimension.
        image: Query image tensor.
        label: Ground-truth tensor for the query image.
        support_images: Support image tensor.
        support_labels: Support label tensor.
        device: Device the inputs are moved to; a string such as 'cuda' or
            'cuda:0', or a torch.device.

    Returns:
        Dict with keys 'Image', 'Soft Prediction', 'Prediction',
        'Ground Truth' and 'score' (Dice of the rounded prediction).
    """
    image, label = image.to(device), label.to(device)
    support_images, support_labels = support_images.to(device), support_labels.to(device)

    # autocast wants the bare device type ('cuda' / 'cpu'), not an indexed
    # name or a torch.device object, so normalise whatever was passed in.
    device_type = torch.device(device).type
    with torch.inference_mode(), torch.autocast(device_type, dtype=torch.bfloat16):
        # Outputs are logits; cast back to float32 for the steps below.
        logits = model(image[None], support_images[None], support_labels[None])[0].detach().float()

    soft_pred = torch.sigmoid(logits)
    hard_pred = soft_pred.round().clip(0, 1)
    del logits
    torch.cuda.empty_cache()

    score = dice_score(hard_pred, label)

    return {'Image': image,
            'Soft Prediction': soft_pred,
            'Prediction': hard_pred,
            'Ground Truth': label,
            'score': score}

def structural_similarity(img1, img2):
    """Return the SSIM of two PIL images compared at 112x112 in grayscale.

    The result is whatever the module-level `ssim` metric returns.
    """
    # Convert both images unless both are already single-channel.
    if not (img1.mode == 'L' and img2.mode == 'L'):
        img1 = img1.convert('L')
        img2 = img2.convert('L')

    # Bring both images to a common size before comparing.
    target = (112, 112)
    batch1 = transform(img1.resize(target)).unsqueeze(0).to(device)
    batch2 = transform(img2.resize(target)).unsqueeze(0).to(device)

    with torch.inference_mode():
        score = ssim(batch1, batch2)
    return score

def lpips_similarity(img1, img2):
    """Return 1 - LPIPS for two PIL images compared at 112x112 in RGB.

    The distance comes from the module-level `lpips` metric; subtracting it
    from 1 makes larger values mean "more similar".
    """
    # Convert both images unless both are already RGB.
    if not (img1.mode == 'RGB' and img2.mode == 'RGB'):
        img1 = img1.convert('RGB')
        img2 = img2.convert('RGB')

    # Bring both images to a common size before comparing.
    target = (112, 112)
    batch1 = transform(img1.resize(target)).unsqueeze(0).to(device)
    batch2 = transform(img2.resize(target)).unsqueeze(0).to(device)

    with torch.inference_mode():
        distance = lpips(batch1, batch2)
    return 1 - distance

def lpips_comparison(img, dataset, top_k = 50):
    """Rank dataset images by LPIPS similarity to a query image.

    Args:
        img: Path of the query image.
        dataset: Sequence of paths of the candidate images.
        top_k: Maximum number of matches to return.

    Returns:
        Object array of ``[file name, score]`` rows, best match first; an
        empty ``(0, 2)`` array when the dataset is empty.
    """
    matches = []
    with Image.open(img) as img_test:
        for path in dataset:
            # Open one candidate at a time so a large dataset does not hold a
            # file handle per image.
            with Image.open(path) as candidate:
                score = float(lpips_similarity(candidate, img_test))
            matches.append([utils.get_filename(path), score])

    if not matches:
        return np.empty((0, 2), dtype=object)

    matches = np.array(matches, dtype=object)
    return matches[np.argsort(matches[:, 1])][::-1][:top_k]

def __structural_comparison_worker(img1, img2):
    """Return ``[file name of img1, SSIM score]`` for two image paths.

    The score is returned exactly as structural_similarity produces it.
    """
    # Close both files once the score has been computed.
    with Image.open(img1) as first, Image.open(img2) as second:
        score = structural_similarity(first, second)
    return [utils.get_filename(img1), score]

def structural_comparison1(img, dataset, top_k = 50):
    """Rank dataset images by SSIM against a query image.

    Args:
        img: Path of the query image.
        dataset: Sequence of paths of the candidate images.
        top_k: Maximum number of matches to return.

    Returns:
        Object array of ``[file name, score]`` rows, best match first; an
        empty ``(0, 2)`` array when the dataset is empty.
    """
    matches = []
    with Image.open(img) as img_test:
        for path in dataset:
            # Open one candidate at a time so a large dataset does not hold a
            # file handle per image.
            with Image.open(path) as candidate:
                score = float(structural_similarity(candidate, img_test))
            matches.append([utils.get_filename(path), score])

    if not matches:
        return np.empty((0, 2), dtype=object)

    matches = np.array(matches, dtype=object)
    return matches[np.argsort(matches[:, 1])][::-1][:top_k]

def image_hash_comparison(img1,img2):
    """Turn the difference of two image hashes into a similarity score.

    Args:
        img1: First hash; ``img1 - img2`` must yield the distance.
        img2: Second hash.

    Returns:
        ``1.0 - distance / 1024``.
    """
    # NOTE(review): the divisor 1024 only maps onto the full 0..1 range if the
    # hashes are 1024 bits long — confirm against the hash size actually used.
    max_distance = 1024
    distance = img1 - img2
    return 1.0 - (distance / max_distance)

def hash_comparison(img, dataset, top_k = 50):
    """Rank dataset images by perceptual-hash similarity to a query image.

    Args:
        img: Path of the query image.
        dataset: Sequence of paths of the candidate images.
        top_k: Maximum number of matches to return.

    Returns:
        Object array of ``[file name, score]`` rows, best match first; an
        empty ``(0, 2)`` array when the dataset is empty.
    """
    with Image.open(img) as img_test:
        query_hash = imagehash.phash(img_test)

    matches = []
    for path in dataset:
        # Hash one candidate at a time and close it straight away.
        with Image.open(path) as candidate:
            candidate_hash = imagehash.phash(candidate)
        score = image_hash_comparison(candidate_hash, query_hash)
        matches.append([utils.get_filename(path), score])

    if not matches:
        return np.empty((0, 2), dtype=object)

    matches = np.array(matches, dtype=object)
    return matches[np.argsort(matches[:, 1])][::-1][:top_k]

def dense_vector_comparison(
  img,
  dataset,
  top_k = 50,
  use_multiprocessing = True,
  device = None
):
    """Rank dataset images by CLIP embedding similarity to a query image.

    Args:
        img: Path of the query image.
        dataset: Sequence of paths of the candidate images.
        top_k: Passed to paraphrase_mining_embeddings as its top_k.
        use_multiprocessing: Encode through a multi-process pool when True,
            in-process otherwise.
        device: Device for the embedding model; defaults to 'cuda' when
            available, else 'cpu'.

    Returns:
        Object array of ``[file name, score]`` rows for the pairs that involve
        the query image.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    model = SentenceTransformer('clip-ViT-B-32', device=device)

    # The query image goes first, so it has index 0 in the embeddings.
    sentences = [Image.open(img)] + [Image.open(path) for path in dataset]

    try:
        if use_multiprocessing:
            pool = model.start_multi_process_pool()
            try:
                embds = model.encode_multi_process(sentences, pool)
            finally:
                # Always shut the worker processes down, even when encoding
                # fails; otherwise every call leaves a pool behind.
                model.stop_multi_process_pool(pool)
        else:
            embds = model.encode(sentences)
    finally:
        # The pixels are no longer needed once encoding is over.
        for opened in sentences:
            opened.close()

    scores = util.paraphrase_mining_embeddings(
      embds,
      top_k = top_k
    )
    scores = np.array(scores, dtype=object)
    # Keep the rows whose second column is 0 (the query image), then reorder
    # the columns to (other index, score).
    scores = (scores[np.where(scores[:,1] == 0)[0]])[:,[2,0]]
    # Embedding index k corresponds to dataset[k - 1] because of the query
    # image in front.
    matches = [
        [utils.get_filename(dataset[int(idx)-1]), score]
        for idx, score in scores
    ]
    del embds,scores,model,sentences
    torch.cuda.empty_cache()
    return np.array(matches, dtype=object)

def combined_score(x_ssim, x_dvrs):
    """Combine scores from both methods.

    Returns ``sqrt(x_ssim) * x_dvrs ** 2``, computed with NumPy.
    """
    ssim_part = np.sqrt(x_ssim)
    dense_part = np.power(x_dvrs, 2)
    return ssim_part * dense_part

# Function to get a list of files in a folder

# Function to import support images and perform structural comparison
def import_support_image_structural(num,img_path,mask_path,final_image_path,device,size = (200,200)):
    """Build a support set from the images that SSIM ranks closest to a query.

    NOTE: this function is defined a second time further down in this file;
    at import time the later definition replaces this one.

    Parameters:
    - num (int): Number of support image/mask pairs to return.
    - img_path (str): Folder holding the candidate support images.
    - mask_path (str): Folder holding the masks; each mask is looked up under
      the same file name as its image.
    - final_image_path (str): Path of the query image.
    - device: Not used by this function; kept for call compatibility.
    - size (tuple): Size passed to transforms.Resize for every loaded image.

    Returns:
    - tuple: (support images, support masks, query image); the first two are
      stacks of `num` single-channel tensors.
    """
    to_tensor = transforms.Compose([transforms.ToTensor(), transforms.Resize(size)])
    to_gray = transforms.Grayscale(num_output_channels=1)

    def load(path):
        # Close the file once its pixels are in a tensor; keep at most the
        # first three channels before the grayscale step.
        with Image.open(path) as pil_img:
            return to_gray(to_tensor(pil_img)[:3, :, :])

    names = get_files_in_folder(img_path)
    if isinstance(names, str):
        # Guard against an error message string coming back instead of a list.
        raise OSError(f"Could not list {img_path!r}: {names}")
    candidates = [img_path + '/' + name for name in names]

    # Ask for a few spare matches so `num` are left after dropping the top one.
    ranked = structural_comparison1(final_image_path, candidates, top_k=num + 5)
    # The best match is skipped — presumably the query image itself when it is
    # stored in img_path; TODO confirm.
    chosen = ranked[1:, 0]

    arr1 = torch.stack([load(img_path + "/" + chosen[i]) for i in range(num)], dim=0)
    arr2 = torch.stack([load(mask_path + "/" + chosen[i]) for i in range(num)], dim=0)
    img_ran = load(final_image_path)
    return arr1,arr2,img_ran
# Function to import support images and perform dense vector comparison

def import_support_image_dense_vector(num,img_path,mask_path,final_image_path,device,size = (200,200)):
    """Build a support set from the images that CLIP embeddings rank closest.

    NOTE: this function is defined a second time further down in this file;
    at import time the later definition replaces this one.

    Parameters:
    - num (int): Number of support image/mask pairs to return.
    - img_path (str): Folder holding the candidate support images.
    - mask_path (str): Folder holding the masks; each mask is looked up under
      the same file name as its image.
    - final_image_path (str): Path of the query image.
    - device: Not used by this function; kept for call compatibility.
    - size (tuple): Size passed to transforms.Resize for every loaded image.

    Returns:
    - tuple: (support images, support masks, query image), all on the CPU;
      the first two are stacks of `num` single-channel tensors.
    """
    to_tensor = transforms.Compose([transforms.ToTensor(), transforms.Resize(size)])
    to_gray = transforms.Grayscale(num_output_channels=1)

    def load(path):
        # Close the file once its pixels are in a tensor; keep at most the
        # first three channels before the grayscale step.
        with Image.open(path) as pil_img:
            return to_gray(to_tensor(pil_img)[:3, :, :])

    names = get_files_in_folder(img_path)
    if isinstance(names, str):
        # Guard against an error message string coming back instead of a list.
        raise OSError(f"Could not list {img_path!r}: {names}")
    candidates = [img_path + '/' + name for name in names]

    # Ask for a few spare matches so `num` are left after dropping the top one.
    ranked = dense_vector_comparison(final_image_path, candidates, top_k=num + 5)
    # The best match is skipped — presumably the query image itself when it is
    # stored in img_path; TODO confirm.
    chosen = ranked[1:, 0]

    arr1 = torch.stack([load(img_path + "/" + chosen[i]) for i in range(num)], dim=0)
    arr2 = torch.stack([load(mask_path + "/" + chosen[i]) for i in range(num)], dim=0)
    img_ran = load(final_image_path)
    return arr1.to('cpu'),arr2.to('cpu'),img_ran.to('cpu')



# Function to import support images and perform LPIPS comparison

def import_support_image_lpips(num,img_path,mask_path,final_image_path,device,size = (200,200)):
    """Build a support set from the images that LPIPS ranks closest to a query.

    NOTE: this function is defined a second time further down in this file;
    at import time the later definition replaces this one.

    Parameters:
    - num (int): Number of support image/mask pairs to return.
    - img_path (str): Folder holding the candidate support images.
    - mask_path (str): Folder holding the masks; each mask is looked up under
      the same file name as its image.
    - final_image_path (str): Path of the query image.
    - device: Not used by this function; kept for call compatibility.
    - size (tuple): Size passed to transforms.Resize for every loaded image.

    Returns:
    - tuple: (support images, support masks, query image); the first two are
      stacks of `num` single-channel tensors.
    """
    to_tensor = transforms.Compose([transforms.ToTensor(), transforms.Resize(size)])
    to_gray = transforms.Grayscale(num_output_channels=1)

    def load(path):
        # Close the file once its pixels are in a tensor; keep at most the
        # first three channels before the grayscale step.
        with Image.open(path) as pil_img:
            return to_gray(to_tensor(pil_img)[:3, :, :])

    names = get_files_in_folder(img_path)
    if isinstance(names, str):
        # Guard against an error message string coming back instead of a list.
        raise OSError(f"Could not list {img_path!r}: {names}")
    candidates = [img_path + '/' + name for name in names]

    # Ask for a few spare matches so `num` are left after dropping the top one.
    ranked = lpips_comparison(final_image_path, candidates, top_k=num + 5)
    # The best match is skipped — presumably the query image itself when it is
    # stored in img_path; TODO confirm.
    chosen = ranked[1:, 0]

    arr1 = torch.stack([load(img_path + "/" + chosen[i]) for i in range(num)], dim=0)
    arr2 = torch.stack([load(mask_path + "/" + chosen[i]) for i in range(num)], dim=0)
    img_ran = load(final_image_path)
    return arr1,arr2,img_ran

# Function to get a list of files in a folder
def get_files_in_folder(folder_path):
    """Return the names of the regular files directly inside a folder.

    Sub-directories are left out. A folder that cannot be listed raises
    instead of handing back an error message: callers index and iterate the
    result as a list of names, and a message string would only fail later
    with a misleading error.

    Args:
        folder_path: Folder to list.

    Returns:
        List of file names (not full paths), in os.listdir order.

    Raises:
        OSError: If the folder does not exist or cannot be read.
    """
    return [
        entry for entry in os.listdir(folder_path)
        if os.path.isfile(os.path.join(folder_path, entry))
    ]

# Function to import support images and perform hash comparison
def import_support_image_hash(num,img_path,mask_path,final_image_path,device,size = (200,200)):
    """
    Imports support images, performs hash comparison, and returns processed tensors.

    Parameters:
    - num (int): Number of support image/mask pairs to return.
    - img_path (str): Path to the support images folder.
    - mask_path (str): Path to the support masks folder; each mask is looked
      up under the same file name as its image.
    - final_image_path (str): Path to the query image.
    - device: Not used by this function; kept for call compatibility.
    - size (tuple): Size passed to transforms.Resize for every loaded image.

    Returns:
    - tuple: (support images, support masks, query image); the first two are
      stacks of `num` single-channel tensors.
    """
    to_tensor = transforms.Compose([transforms.ToTensor(), transforms.Resize(size)])
    to_gray = transforms.Grayscale(num_output_channels=1)

    def load(path):
        # Close the file once its pixels are in a tensor; keep at most the
        # first three channels before the grayscale step.
        with Image.open(path) as pil_img:
            return to_gray(to_tensor(pil_img)[:3, :, :])

    names = get_files_in_folder(img_path)
    if isinstance(names, str):
        # Guard against an error message string coming back instead of a list.
        raise OSError(f"Could not list {img_path!r}: {names}")
    candidates = [img_path + '/' + name for name in names]

    # Ask for a few spare matches so `num` are left after dropping the top one.
    ranked = hash_comparison(final_image_path, candidates, top_k=num + 5)
    # The best match is skipped — presumably the query image itself when it is
    # stored in img_path; TODO confirm.
    chosen = ranked[1:, 0]

    arr1 = torch.stack([load(img_path + "/" + chosen[i]) for i in range(num)], dim=0)
    arr2 = torch.stack([load(mask_path + "/" + chosen[i]) for i in range(num)], dim=0)
    img_ran = load(final_image_path)
    return arr1,arr2,img_ran

# Function to import support images and perform structural comparison

def import_support_image_structural(num,img_path,mask_path,final_image_path,device,size = (200,200)):
    """Build a support set from the images that SSIM ranks closest to a query.

    Parameters:
    - num (int): Number of support image/mask pairs to return.
    - img_path (str): Folder holding the candidate support images.
    - mask_path (str): Folder holding the masks; each mask is looked up under
      the same file name as its image.
    - final_image_path (str): Path of the query image.
    - device: Not used by this function; kept for call compatibility.
    - size (tuple): Size passed to transforms.Resize for every loaded image.

    Returns:
    - tuple: (support images, support masks, query image); the first two are
      stacks of `num` single-channel tensors.
    """
    to_tensor = transforms.Compose([transforms.ToTensor(), transforms.Resize(size)])
    to_gray = transforms.Grayscale(num_output_channels=1)

    def load(path):
        # Close the file once its pixels are in a tensor; keep at most the
        # first three channels before the grayscale step.
        with Image.open(path) as pil_img:
            return to_gray(to_tensor(pil_img)[:3, :, :])

    names = get_files_in_folder(img_path)
    if isinstance(names, str):
        # Guard against an error message string coming back instead of a list.
        raise OSError(f"Could not list {img_path!r}: {names}")
    candidates = [img_path + '/' + name for name in names]

    # Ask for a few spare matches so `num` are left after dropping the top one.
    ranked = structural_comparison1(final_image_path, candidates, top_k=num + 5)
    # The best match is skipped — presumably the query image itself when it is
    # stored in img_path; TODO confirm.
    chosen = ranked[1:, 0]

    arr1 = torch.stack([load(img_path + "/" + chosen[i]) for i in range(num)], dim=0)
    arr2 = torch.stack([load(mask_path + "/" + chosen[i]) for i in range(num)], dim=0)
    img_ran = load(final_image_path)
    return arr1,arr2,img_ran

# Function to import support images and perform dense vector comparison
def import_support_image_dense_vector(num,img_path,mask_path,final_image_path,device,size = (200,200)):
    """Build a support set from the images that CLIP embeddings rank closest.

    Parameters:
    - num (int): Number of support image/mask pairs to return.
    - img_path (str): Folder holding the candidate support images.
    - mask_path (str): Folder holding the masks; each mask is looked up under
      the same file name as its image.
    - final_image_path (str): Path of the query image.
    - device: Not used by this function; kept for call compatibility.
    - size (tuple): Size passed to transforms.Resize for every loaded image.

    Returns:
    - tuple: (support images, support masks, query image), all on the CPU;
      the first two are stacks of `num` single-channel tensors.
    """
    to_tensor = transforms.Compose([transforms.ToTensor(), transforms.Resize(size)])
    to_gray = transforms.Grayscale(num_output_channels=1)

    def load(path):
        # Close the file once its pixels are in a tensor; keep at most the
        # first three channels before the grayscale step.
        with Image.open(path) as pil_img:
            return to_gray(to_tensor(pil_img)[:3, :, :])

    names = get_files_in_folder(img_path)
    if isinstance(names, str):
        # Guard against an error message string coming back instead of a list.
        raise OSError(f"Could not list {img_path!r}: {names}")
    candidates = [img_path + '/' + name for name in names]

    # Ask for a few spare matches so `num` are left after dropping the top one.
    ranked = dense_vector_comparison(final_image_path, candidates, top_k=num + 5)
    # The best match is skipped — presumably the query image itself when it is
    # stored in img_path; TODO confirm.
    chosen = ranked[1:, 0]

    arr1 = torch.stack([load(img_path + "/" + chosen[i]) for i in range(num)], dim=0)
    arr2 = torch.stack([load(mask_path + "/" + chosen[i]) for i in range(num)], dim=0)
    img_ran = load(final_image_path)
    return arr1.to('cpu'),arr2.to('cpu'),img_ran.to('cpu')

# Function to import support images, perform hash comparison with augmentation
def import_support_image_hash_aug(num,img_path,mask_path,final_image_path,device,size = (200,200)):
    """Build a support set of augmented copies of one closely matching image.

    The candidates are ranked, the top match is skipped, and the next one is
    augmented `num` times together with its mask.

    NOTE(review): despite the name, candidates are ranked with
    dense_vector_comparison rather than hash_comparison — confirm which one
    was intended.

    Parameters:
    - num (int): Number of augmented image/mask pairs to generate.
    - img_path (str): Folder holding the candidate support images.
    - mask_path (str): Folder holding the masks; the mask is looked up under
      the same file name as its image.
    - final_image_path (str): Path of the query image.
    - device: Not used by this function; kept for call compatibility.
    - size (tuple): Size used for cv2.resize, the augmentation crop and
      transforms.Resize.

    Returns:
    - tuple: (augmented images, augmented masks, query image); the first two
      are stacks of `num` single-channel tensors.

    Raises:
    - FileNotFoundError: If the selected image or its mask cannot be read.
    """
    to_tensor = transforms.Compose([transforms.ToTensor(), transforms.Resize(size)])
    to_gray = transforms.Grayscale(num_output_channels=1)

    aug = A.Compose([
        A.RandomSizedCrop(min_max_height=(size[0] - 15, size[0] - 5),height=size[0],width=size[1], p=0.5),
        A.CLAHE(p=0.8),
        A.RandomBrightnessContrast(p=0.8),
        A.RandomGamma(p=0.8)])

    names = get_files_in_folder(img_path)
    if isinstance(names, str):
        # Guard against an error message string coming back instead of a list.
        raise OSError(f"Could not list {img_path!r}: {names}")
    candidates = [img_path + '/' + name for name in names]
    ranked = dense_vector_comparison(final_image_path, candidates, top_k=5)

    # Keep at most three channels before the grayscale step, like the other
    # import_support_image_* helpers, and close the file afterwards.
    with Image.open(final_image_path) as query:
        img_ran = to_gray(to_tensor(query)[:3, :, :])

    # The best match is skipped — presumably the query image itself when it is
    # stored in img_path; TODO confirm.
    best_name = ranked[1:, 0][0]
    image_file = img_path + "/" + best_name
    mask_file = mask_path + "/" + best_name
    image = cv2.imread(image_file, cv2.IMREAD_GRAYSCALE)
    masks = cv2.imread(mask_file, cv2.IMREAD_GRAYSCALE)
    # cv2.imread signals an unreadable file by returning None.
    if image is None:
        raise FileNotFoundError(image_file)
    if masks is None:
        raise FileNotFoundError(mask_file)
    image = cv2.resize(image, size)
    masks = cv2.resize(masks, size)

    exp1 = []
    exp2 = []
    for _ in range(num):
        # Image and mask go through the same random augmentation.
        augmented = aug(image=image, mask=masks)
        exp1.append(to_gray(to_tensor(augmented['image'])[:3, :, :]))
        exp2.append(to_gray(to_tensor(augmented['mask'])[:3, :, :]))
    arr2 = torch.stack(exp2, dim=0)
    arr1 = torch.stack(exp1, dim=0)
    return arr1,arr2,img_ran

# Function to import support images and perform LPIPS comparison
def import_support_image_lpips(num,img_path,mask_path,final_image_path,device,size = (200,200)):
    """Build a support set from the images that LPIPS ranks closest to a query.

    Parameters:
    - num (int): Number of support image/mask pairs to return.
    - img_path (str): Folder holding the candidate support images.
    - mask_path (str): Folder holding the masks; each mask is looked up under
      the same file name as its image.
    - final_image_path (str): Path of the query image.
    - device: Not used by this function; kept for call compatibility.
    - size (tuple): Size passed to transforms.Resize for every loaded image.

    Returns:
    - tuple: (support images, support masks, query image); the first two are
      stacks of `num` single-channel tensors.
    """
    to_tensor = transforms.Compose([transforms.ToTensor(), transforms.Resize(size)])
    to_gray = transforms.Grayscale(num_output_channels=1)

    def load(path):
        # Close the file once its pixels are in a tensor; keep at most the
        # first three channels before the grayscale step.
        with Image.open(path) as pil_img:
            return to_gray(to_tensor(pil_img)[:3, :, :])

    names = get_files_in_folder(img_path)
    if isinstance(names, str):
        # Guard against an error message string coming back instead of a list.
        raise OSError(f"Could not list {img_path!r}: {names}")
    candidates = [img_path + '/' + name for name in names]

    # Ask for a few spare matches so `num` are left after dropping the top one.
    ranked = lpips_comparison(final_image_path, candidates, top_k=num + 5)
    # The best match is skipped — presumably the query image itself when it is
    # stored in img_path; TODO confirm.
    chosen = ranked[1:, 0]

    arr1 = torch.stack([load(img_path + "/" + chosen[i]) for i in range(num)], dim=0)
    arr2 = torch.stack([load(mask_path + "/" + chosen[i]) for i in range(num)], dim=0)
    img_ran = load(final_image_path)
    return arr1,arr2,img_ran


def get_total_parameters_count(model):
    """Return the number of elements in the model's trainable parameters."""
    # An explicit loop does not depend on which `sum` is bound in this
    # namespace (the file star-imports pylab, which may rebind it).
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


def get_pruned_parameters_count(pruned_model):
    """Return the number of non-zero entries across all model parameters."""
    remaining = 0
    for param in pruned_model.parameters():
        if param is None:
            continue
        # torch.nonzero yields one row per non-zero element.
        remaining += len(torch.nonzero(param))
    return remaining

def prune_model(model,reduce_size = 0.25):
    """Apply global L1 unstructured pruning to the model's conv weights.

    The pruned layers are the target/support convs of encoder block 0 and the
    cross/target/support convs of encoder blocks 1-3 and decoder blocks 0-2.
    The pruning is made permanent with prune.remove, and the non-zero
    parameter counts before and after are printed.

    Args:
        model: Model exposing `enc_blocks` and `dec_blocks` with the layer
            attributes accessed below.
        reduce_size: `amount` passed to prune.global_unstructured.

    Returns:
        The same model object, pruned in place.
    """
    # Non-zero count before pruning, used as the baseline in the report.
    total_params_count = get_pruned_parameters_count(model)

    # Encoder block 0 contributes no cross conv; every other listed block does.
    first_block = model.enc_blocks[0]
    convs = [first_block.target.vmapped.conv, first_block.support.vmapped.conv]
    later_blocks = [model.enc_blocks[i] for i in (1, 2, 3)]
    later_blocks += [model.dec_blocks[i] for i in (0, 1, 2)]
    for block in later_blocks:
        convs.append(block.cross.cross_conv)
        convs.append(block.target.vmapped.conv)
        convs.append(block.support.vmapped.conv)

    prune.global_unstructured(
        tuple((conv, 'weight') for conv in convs),
        pruning_method=prune.L1Unstructured,
        amount=reduce_size,
    )

    # Make the pruning permanent on every layer, in the same order.
    for conv in convs:
        prune.remove(conv, 'weight')

    pruned_model_param_count = get_pruned_parameters_count(model)
    print('Original Model paramete count:', total_params_count)
    print('Pruned Model parameter count:', pruned_model_param_count)
    print(f'Compressed Percentage: {(100 - (pruned_model_param_count / total_params_count) * 100)}%')
    return model

def visualize_tensors(tensors, col_wrap=8, col_names=None, title=None):
    """Plot several groups of image tensors on one grid of subplots.

    Args:
        tensors: Mapping from a group label to a sequence of tensors. The
            number of tensors per group is taken from the first group.
        col_wrap: Number of columns in the grid; longer groups wrap onto
            further rows.
        col_names: Optional titles for the top row, indexed by column.
        title: Optional figure title.
    """
    n_groups = len(tensors)
    n_items = len(next(iter(tensors.values())))

    cols = col_wrap
    rows = math.ceil(n_items / cols) * n_groups

    cell = 2.5
    # squeeze=False always yields a 2-D array of axes, so the [row, col]
    # indexing below also works for a single row and/or a single column.
    fig, axes = plt.subplots(
        rows, cols, figsize=(cell * cols, cell * rows), squeeze=False
    )

    for g, (grp, group_tensors) in enumerate(tensors.items()):
        for k, tensor in enumerate(group_tensors):
            col = k % cols
            # Groups are interleaved: each wrap of `cols` items starts a new
            # band of n_groups rows.
            row = g + n_groups * (k // cols)
            x = tensor.detach().cpu().numpy().squeeze()
            ax = axes[row, col]
            if len(x.shape) == 2:
                ax.imshow(x, vmin=0, vmax=1, cmap='gray')
            else:
                # Anything that is not 2-D is treated as channel-first.
                ax.imshow(E.rearrange(x, 'C H W -> H W C'))
            if col == 0:
                ax.set_ylabel(grp, fontsize=16)
            if col_names is not None and row == 0:
                ax.set_title(col_names[col])

    # Strip grid lines and ticks from every cell, including unused ones.
    for ax in axes.flat:
        ax.grid(False)
        ax.set_xticks([])
        ax.set_yticks([])

    if title:
        plt.suptitle(title, fontsize=20)

    plt.tight_layout()

# Converting NIfTI to JPG files

In [80]:
import os
import nibabel as nib
import numpy as np
from PIL import Image

# Function to convert a NIfTI file to JPG slices
def nifti_to_jpg_slices(nifti_path, output_folder):
    """Write every slice along the third axis of a NIfTI volume as a JPG.

    Each slice is min-max normalised on its own to 0..1 (a constant slice
    becomes all zeros), scaled to 8 bits and saved in ``output_folder`` as
    ``slice_XXXX.jpg``, where ``XXXX`` is the zero-padded slice index.
    """
    volume = nib.load(nifti_path).get_fdata()

    # The destination may not exist yet.
    os.makedirs(output_folder, exist_ok=True)

    for index in range(volume.shape[2]):
        plane = volume[:, :, index]

        lo, hi = plane.min(), plane.max()
        span = hi - lo
        if span == 0:
            # Flat slice: nothing to stretch, emit black.
            normalised = np.zeros_like(plane)
        else:
            normalised = (plane - lo) / span

        pixels = (normalised * 255).astype(np.uint8)
        Image.fromarray(pixels).save(
            os.path.join(output_folder, f'slice_{index:04d}.jpg')
        )

# Root folder holding the NIfTI inputs; JPG outputs are written beside them
base_dir = 'D:/Niket/Datasets/BTCV_subset/Training-Training'

# Sub-folder names for the NIfTI inputs and for the JPG outputs
original_dirs = {'img': 'img', 'label': 'label'}
jpg_dirs = {'img': 'img_jpg', 'label': 'label_jpg'}

# Handle the 'img' tree first, then the 'label' tree
for key in original_dirs:
    original_base_dir = os.path.join(base_dir, original_dirs[key])
    jpg_base_dir = os.path.join(base_dir, jpg_dirs[key])
    os.makedirs(jpg_base_dir, exist_ok=True)

    for num_folder in os.listdir(original_base_dir):
        num_folder_path = os.path.join(original_base_dir, num_folder)
        if not os.path.isdir(num_folder_path):
            continue  # only sub-folders are converted

        # Mirror the numbered folder on the JPG side
        jpg_num_folder_path = os.path.join(jpg_base_dir, num_folder)
        os.makedirs(jpg_num_folder_path, exist_ok=True)

        for nii_file in os.listdir(num_folder_path):
            if not nii_file.endswith('.nii.gz'):
                continue

            nii_file_path = os.path.join(num_folder_path, nii_file)
            # Strip both extensions: "name.nii.gz" -> "name"
            nii_file_name = os.path.splitext(os.path.splitext(nii_file)[0])[0]
            nii_jpg_folder_path = os.path.join(jpg_num_folder_path, nii_file_name)

            # One output folder of JPG slices per volume
            nifti_to_jpg_slices(nii_file_path, nii_jpg_folder_path)


# Experiment 1 - Using only the external dataset (TotalSegmentator: reference databases of 5, 50, and 500 fully annotated patients; segment the remaining patients)

In [8]:
# This code creates reference databases. Just change the db_size variable. 
# The goal is to find the best searching algorithm, best support size, and best database reference size for each organ modality
import os
import shutil

# How many patient folders go into the reference database (edit as needed)
db_size = 500

# Root of the dataset; every sub-directory is a candidate source folder
dataset_base_dir = 'D:/Niket/Datasets/Totalsegmentator_dataset_v201'

# Keep the first `db_size` sub-directories in sorted path order
source_dirs = sorted(
    filter(
        os.path.isdir,
        (os.path.join(dataset_base_dir, entry) for entry in os.listdir(dataset_base_dir)),
    )
)[:db_size]

# The target folder name encodes the database size
target_base_dir = f'D:/Niket/Experiments/Experiment1/{db_size}_patient_db/reference_db'

def create_task_structure(base_dir, tasks):
    """Create ``<base_dir>/<task>/image`` and ``<base_dir>/<task>/label`` per task.

    Directories that already exist are left as they are.
    """
    for task in tasks:
        for leaf in ('image', 'label'):
            os.makedirs(os.path.join(base_dir, task, leaf), exist_ok=True)

def move_files(source_dirs, target_base_dir, tasks):
    """Copy each source folder's JPG slices into the per-task layout.

    For every source directory and every task, ``.jpg`` files from
    ``<source>/original_jpg`` are copied to ``<target>/<task>/image`` and
    ``.jpg`` files from ``<source>/segmentations_jpg/<task>`` are copied to
    ``<target>/<task>/label``. Copies are named
    ``<source folder name>_<file name>``. Despite the function name, the
    originals are copied, not moved.
    """
    def copy_jpgs(src_folder, dst_folder, prefix):
        # Copy every .jpg in src_folder into dst_folder as "<prefix>_<name>".
        for name in os.listdir(src_folder):
            if not name.endswith('.jpg'):
                continue
            shutil.copyfile(
                os.path.join(src_folder, name),
                os.path.join(dst_folder, f"{prefix}_{name}"),
            )

    for source_dir in source_dirs:
        source_name = os.path.basename(source_dir)
        for task in tasks:
            # Images first, then that task's labels
            copy_jpgs(
                os.path.join(source_dir, 'original_jpg'),
                os.path.join(target_base_dir, task, 'image'),
                source_name,
            )
            copy_jpgs(
                os.path.join(source_dir, 'segmentations_jpg', task),
                os.path.join(target_base_dir, task, 'label'),
                source_name,
            )

# Task names are the sub-folders of the first source folder's segmentations_jpg directory
segmentations_jpg_dir = os.path.join(source_dirs[0], 'segmentations_jpg')
tasks = list(filter(
    lambda entry: os.path.isdir(os.path.join(segmentations_jpg_dir, entry)),
    os.listdir(segmentations_jpg_dir),
))

# Build the <task>/image and <task>/label folders under the target directory
create_task_structure(target_base_dir, tasks)

# Fill them with copies of the source JPGs
move_files(source_dirs, target_base_dir, tasks)

print(f"Files have been successfully moved to {target_base_dir} and the structure has been created.")

Files have been successfully moved to D:/Niket/Experiments/Experiment1/500_patient_db/reference_db and the structure has been created.


In [None]:
import os
import torch
import datetime
from PIL import Image
import numpy as np
from skimage.metrics import structural_similarity as ssim_sk

# Constants
# Root of this experiment: tasks.txt and the reference_db/ folder are read from
# here, and each run writes into a timestamped sub-folder of it.
base_output_dir = "D:/Niket/Experiments/Experiment1/5_patient_db"
# Dataset whose s* folders supply the test slices and their masks.
dataset_dir = "D:/Niket/Datasets/Totalsegmentator_dataset_v201"
# Passed as the last argument to every retrieval function.
# NOTE(review): presumably the (height, width) images are resized to - confirm.
size = (256, 256)
# Passed as the first argument to every retrieval function.
# NOTE(review): presumably the number of support examples to retrieve - confirm.
support_size = 8

# Define task filter
tasks_to_include = ["liver"]  # Specify tasks to include (e.g., only "liver")

# (name, function) pairs. The name becomes the output sub-folder for that
# method; the functions are not defined in this cell and must already exist
# in the notebook namespace.
methods_dic = [
    ('import_support_image_structural', import_support_image_structural),
    ('import_support_image_dense_vector', import_support_image_dense_vector),
    ('import_support_image_hash', import_support_image_hash),
    ('import_support_image_hash_aug', import_support_image_hash_aug),
    ('import_support_image_lpips', import_support_image_lpips)
]

def infer_and_save(original_image, similar_images, similar_labels, model, device):
    """Run the model on one target image with its support set.

    Despite the name, nothing is written to disk here; the caller saves the
    returned tensors.

    Args:
        original_image: Target image tensor; a leading batch axis is added.
        similar_images: Support image tensor; a leading batch axis is added.
        similar_labels: Support label tensor; a leading batch axis is added.
        model: Callable invoked as ``model(image, support_images, support_labels)``.
        device: Device the three inputs are moved to before the call.

    Returns:
        ``(original_image, pred)``: the batched target image moved back to
        the CPU, and the sigmoid of the first element of the model output,
        also on the CPU.
    """
    original_image = original_image[None].to(device)
    similar_images = similar_images[None].to(device)
    similar_labels = similar_labels[None].to(device)

    # Inference only: with autograd disabled no graph is recorded, so the
    # forward-pass activations are not kept alive for a backward pass.
    with torch.no_grad():
        logits = model(original_image, similar_images, similar_labels)[0]
        pred = torch.sigmoid(logits)

    original_image = original_image.detach().cpu()
    pred = pred.detach().cpu()
    # Drop the device-side copies before asking CUDA to release cached blocks.
    del logits, similar_images, similar_labels
    torch.cuda.empty_cache()

    return original_image, pred

def save_image(tensor_or_array, path):
    """Save a tensor or array as an 8-bit image at ``path``.

    Float and boolean inputs whose maximum is <= 1 are treated as 0..1 data
    and scaled to 0..255. Everything else, including all integer inputs, is
    taken to be on the 0..255 scale already and is only cast to ``uint8``.

    Args:
        tensor_or_array: ``torch.Tensor`` or array-like in a layout accepted
            by ``PIL.Image.fromarray``.
        path: Destination file; the image format follows from its extension.
    """
    if isinstance(tensor_or_array, torch.Tensor):
        image = tensor_or_array.detach().cpu().numpy()
    else:
        image = np.asarray(tensor_or_array)

    # The "max <= 1 means normalised" guess is only sound for float/bool data.
    # Applied to uint8 it misfires on near-empty images: a soft prediction
    # holding only 0s and 1s (out of 255) would be blown up to 0/255.
    looks_normalised = image.dtype.kind in 'fb' and image.max() <= 1
    if looks_normalised:
        image = (image * 255).astype('uint8')
    else:
        image = image.astype('uint8')

    Image.fromarray(image).save(path)

def calculate_similarity(mask_image, pred_soft, pred_hard, data_range=255.0):
    """Score a soft and a hard prediction against a mask with SSIM.

    Args:
        mask_image: Reference mask (array-like).
        pred_soft: Soft prediction (array-like), same shape as the mask.
        pred_hard: Thresholded prediction (array-like), same shape as the mask.
        data_range: Difference between the largest and smallest *possible*
            pixel value. The default of 255.0 matches the 8-bit arrays the
            notebook passes in; pass 1.0 for 0..1 data.

    Returns:
        ``(soft_similarity, hard_similarity)`` as returned by
        ``skimage.metrics.structural_similarity``.
    """
    # SSIM is computed on float32 copies of all three inputs
    mask_image = np.array(mask_image).astype(np.float32)
    pred_soft = np.array(pred_soft).astype(np.float32)
    pred_hard = np.array(pred_hard).astype(np.float32)

    # Use the nominal intensity span rather than each prediction's observed
    # max - min. A constant prediction (e.g. an empty hard mask) has an
    # observed span of 0, which zeroes SSIM's stabilising constants and can
    # yield NaN; a nearly-flat soft prediction makes the score unstable.
    soft_similarity = ssim_sk(mask_image, pred_soft, data_range=data_range)
    hard_similarity = ssim_sk(mask_image, pred_hard, data_range=data_range)

    return soft_similarity, hard_similarity

def create_output_structure(base_output_dir, s_folder, slice_name, task, method_name):
    """Create the output folders for one slice / task / method combination.

    Returns:
        ``(slice_task_dir, similar_images_dir, similar_labels_dir)`` where
        ``slice_task_dir`` is
        ``<base_output_dir>/<s_folder>/<slice_name>/<task>/<method_name>`` and
        the other two are its ``similar_images`` and ``similar_labels``
        sub-folders. All of them exist on return.
    """
    slice_task_dir = os.path.join(base_output_dir, s_folder, slice_name, task, method_name)
    support_dirs = tuple(
        os.path.join(slice_task_dir, leaf)
        for leaf in ("similar_images", "similar_labels")
    )
    # Creating the leaves also creates every missing parent.
    for support_dir in support_dirs:
        os.makedirs(support_dir, exist_ok=True)
    return (slice_task_dir, *support_dirs)

# Read tasks.txt (one task per line) and keep only the tasks in tasks_to_include
tasks_file = os.path.join(base_output_dir, 'tasks.txt')
with open(tasks_file, 'r') as f:
    tasks = [name for name in map(str.strip, f) if name in tasks_to_include]

# Every run writes into its own timestamped folder
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
run_output_dir = os.path.join(base_output_dir, timestamp)
os.makedirs(run_output_dir, exist_ok=True)

# Assuming 'model' and 'device' are already defined

# Skip the first 500 entries whose name starts with 's' (sorted by name)
s_folders = sorted(
    filter(lambda entry: entry.startswith('s'), os.listdir(dataset_dir))
)[500:]

# Loop over each S folder
# For every test folder, task and slice: retrieve a support set with each
# retrieval method, run the model, score the prediction against the slice's
# mask with SSIM, and write images plus scores under run_output_dir.
for s_folder in s_folders:
    print(f"Processing folder: {s_folder}")
    s_folder_path = os.path.join(dataset_dir, s_folder)
    # s_folders is filtered by name only, so plain files are skipped here
    if not os.path.isdir(s_folder_path):
        continue

    original_jpg_dir = os.path.join(s_folder_path, 'original_jpg')
    segmentations_jpg_dir = os.path.join(s_folder_path, 'segmentations_jpg')

    for task in tasks:
        task_segmentations_dir = os.path.join(segmentations_jpg_dir, task)
        # The support pool for this task comes from the reference database
        support_images_folder = f"{base_output_dir}/reference_db/{task}/image/"
        support_masks_folder = f"{base_output_dir}/reference_db/{task}/label/"

        # Process each slice in the original_jpg directory
        for slice_file in os.listdir(original_jpg_dir):
            if not slice_file.endswith('.jpg'):
                continue

            slice_name = os.path.splitext(slice_file)[0]
            test_image_path = os.path.join(original_jpg_dir, slice_file)
            # The mask is expected under the task folder with the same file name
            mask_image_path = os.path.join(task_segmentations_dir, slice_file)

            # Ensure the mask image exists
            if not os.path.exists(mask_image_path):
                print(f"Mask image not found for slice: {slice_file}")
                continue

            # Load mask image
            # 'L' = 8-bit grayscale, so the array holds values in 0..255
            mask_image = Image.open(mask_image_path).convert('L')
            mask_image_np = np.array(mask_image)

            for method_name, method_function in methods_dic:
                # Get support images and labels
                similar_images, similar_labels, original_image = method_function(
                    support_size, support_images_folder, support_masks_folder, test_image_path, device, size
                )
                original_image = original_image.to(torch.float32)
                similar_images = similar_images.to(torch.float32)
                similar_labels = similar_labels.to(torch.float32)

                # Run inference
                # Only original_image is rebound here; similar_images and
                # similar_labels keep the un-batched tensors returned above.
                original_image, pred = infer_and_save(original_image, similar_images, similar_labels, model, device)

                # Generate predictions
                # Soft = probabilities scaled to 0..255; hard = thresholded at 0.5
                pred_soft = (pred * 255).numpy().astype('uint8')
                pred_hard = ((pred > 0.5) * 255).numpy().astype('uint8')
                original_image_np = (original_image.numpy() * 255).astype('uint8')

                # Flatten dimensions for compatibility
                original_image_np = original_image_np.squeeze()
                pred_soft = pred_soft.squeeze()
                pred_hard = pred_hard.squeeze()

                # Calculate similarity using the mask image
                # NOTE(review): the mask is used at its stored resolution while
                # the predictions come from the model - presumably both are
                # `size`; confirm, since the shapes must match.
                soft_similarity, hard_similarity = calculate_similarity(
                    mask_image_np, pred_soft, pred_hard
                )

                # Create output directories
                slice_task_dir, similar_images_dir, similar_labels_dir = create_output_structure(
                    run_output_dir, s_folder, slice_name, task, method_name
                )

                # Save images and scores
                save_image(original_image.squeeze(), os.path.join(slice_task_dir, "original_image.jpg"))
                save_image(pred_soft, os.path.join(slice_task_dir, "pred_soft.jpg"))
                save_image(pred_hard, os.path.join(slice_task_dir, "pred_hard.jpg"))
                save_image(mask_image_np, os.path.join(slice_task_dir, "mask_image.jpg"))

                # Save similarity scores
                similarity_file_path = os.path.join(slice_task_dir, "similarity_scores.txt")
                with open(similarity_file_path, 'w') as f:
                    f.write(f"Soft Prediction Similarity: {soft_similarity}\n")
                    f.write(f"Hard Prediction Similarity: {hard_similarity}\n")

                # Save similar images and labels
                # Iterating the support tensors yields one entry per support example
                for i, img in enumerate(similar_images):
                    save_image(img.squeeze(), os.path.join(similar_images_dir, f"similar_image_{i:04d}.jpg"))
                for i, label in enumerate(similar_labels):
                    save_image(label.squeeze(), os.path.join(similar_labels_dir, f"similar_label_{i:04d}.jpg"))

                # Free memory
                del similar_images, similar_labels, original_image, pred, pred_soft, pred_hard
                torch.cuda.empty_cache()

print("Processing and saving completed.")


Processing folder: s0593


# Experiment 2 - Vary Support Size (4, 8, 16, 32) with TotalSegRef

In [23]:
# This code creates the sub-experiments 
import os
import shutil

# CHANGE THIS CODE
# Source folders for the reference database; note the IDs are not contiguous
source_dirs = [
    f'D:/Niket/Datasets/Totalsegmentator_dataset_v201/s{folder_id:04d}'
    for folder_id in (0, 1, 2, 3, 4, 6, 9, 10, 11, 12)
]
target_base_dir = 'D:/Niket/Experiments/Experiment2/support_size_4/reference_db'

def create_task_structure(base_dir, tasks):
    """Create ``<base_dir>/<task>/image`` and ``<base_dir>/<task>/label`` per task.

    Directories that already exist are left as they are.
    """
    for task in tasks:
        for leaf in ('image', 'label'):
            os.makedirs(os.path.join(base_dir, task, leaf), exist_ok=True)

def move_files(source_dirs, target_base_dir, tasks):
    """Copy each source folder's JPG slices into the per-task layout.

    For every source directory and every task, ``.jpg`` files from
    ``<source>/original_jpg`` are copied to ``<target>/<task>/image`` and
    ``.jpg`` files from ``<source>/segmentations_jpg/<task>`` are copied to
    ``<target>/<task>/label``. Copies are named
    ``<source folder name>_<file name>``. Despite the function name, the
    originals are copied, not moved.
    """
    def copy_jpgs(src_folder, dst_folder, prefix):
        # Copy every .jpg in src_folder into dst_folder as "<prefix>_<name>".
        for name in os.listdir(src_folder):
            if not name.endswith('.jpg'):
                continue
            shutil.copyfile(
                os.path.join(src_folder, name),
                os.path.join(dst_folder, f"{prefix}_{name}"),
            )

    for source_dir in source_dirs:
        source_name = os.path.basename(source_dir)
        for task in tasks:
            # Images first, then that task's labels
            copy_jpgs(
                os.path.join(source_dir, 'original_jpg'),
                os.path.join(target_base_dir, task, 'image'),
                source_name,
            )
            copy_jpgs(
                os.path.join(source_dir, 'segmentations_jpg', task),
                os.path.join(target_base_dir, task, 'label'),
                source_name,
            )

# Task names are the sub-folders of the first source folder's segmentations_jpg directory
segmentations_jpg_dir = os.path.join(source_dirs[0], 'segmentations_jpg')
tasks = list(filter(
    lambda entry: os.path.isdir(os.path.join(segmentations_jpg_dir, entry)),
    os.listdir(segmentations_jpg_dir),
))

# Build the <task>/image and <task>/label folders under the target directory
create_task_structure(target_base_dir, tasks)

# Fill them with copies of the source JPGs
move_files(source_dirs, target_base_dir, tasks)

print("Files have been successfully moved and the structure has been created.")

Files have been successfully moved and the structure has been created.


In [25]:
import os
import torch
import datetime
from PIL import Image
import numpy as np
from skimage.metrics import structural_similarity as ssim_sk

# CHANGE THIS CODE
# Passed as the last argument to every retrieval function.
# NOTE(review): presumably the (height, width) images are resized to - confirm.
size = (256, 256)
# Passed as the first argument to every retrieval function and embedded in the
# output path below, so each support size gets its own experiment folder.
support_size = 4
# Root of this sub-experiment: tasks.txt and reference_db/ are read from here
# and each run writes into a timestamped sub-folder of it.
base_output_dir = f"D:/Niket/Experiments/Experiment2/support_size_{support_size}/"
# Dataset whose sub-folders supply the test slices.
dataset_dir = "D:/Niket/Datasets/Totalsegmentator_dataset_v201"

def infer_and_save(original_image, similar_images, similar_labels, model, device, output_folder):
    """Run the model on one image and write the results to ``output_folder``.

    Files written: ``original_image.jpg``, ``pred_soft.jpg`` (sigmoid
    output), ``pred_hard.jpg`` (sigmoid output thresholded at 0.5) and
    ``similarity_scores.txt``.

    Note that the similarity scores compare each prediction with the saved
    *input image*, not with a ground-truth mask.

    Args:
        original_image: Target image tensor; a leading batch axis is added.
        similar_images: Support image tensor; a leading batch axis is added.
        similar_labels: Support label tensor; a leading batch axis is added.
        model: Callable invoked as ``model(image, support_images, support_labels)``.
        device: Device the inputs are moved to for the forward pass.
        output_folder: Existing directory that receives the four files.
    """
    # Inference only: with autograd disabled no graph is recorded, so the
    # forward-pass activations are not kept alive for a backward pass.
    with torch.no_grad():
        logits = model(
            original_image[None].to(device),
            similar_images[None].to(device),
            similar_labels[None].to(device),
        )[0]
        pred = torch.sigmoid(logits)

    # Save original image, pred (soft), and pred (hard)
    original_image_path = os.path.join(output_folder, 'original_image.jpg')
    pred_soft_path = os.path.join(output_folder, 'pred_soft.jpg')
    pred_hard_path = os.path.join(output_folder, 'pred_hard.jpg')
    similarity_file_path = os.path.join(output_folder, 'similarity_scores.txt')

    save_image(original_image.squeeze(), original_image_path)
    save_image(pred.squeeze(), pred_soft_path)
    save_image((pred > 0.5).squeeze(), pred_hard_path)

    # Calculate and save similarity scores from the files just written
    calculate_and_save_similarity(original_image_path, pred_soft_path, pred_hard_path, similarity_file_path)

def save_image(tensor, path):
    """Scale a tensor by 255, cast it to ``uint8`` and save it as an image.

    The tensor is detached and moved to the CPU first; the image format
    follows from the extension of ``path``.
    """
    pixels = (tensor.detach().cpu().numpy() * 255).astype('uint8')
    Image.fromarray(pixels).save(path)

def calculate_and_save_similarity(original_image_path, pred_soft_path, pred_hard_path, output_file, data_range=255.0):
    """Compute SSIM of two saved predictions against a saved image and log it.

    All three images are loaded as 8-bit grayscale and compared as float32.

    Args:
        original_image_path: Image the predictions are compared with.
        pred_soft_path: Saved soft prediction.
        pred_hard_path: Saved hard (thresholded) prediction.
        output_file: Text file that receives the two scores (overwritten).
        data_range: Difference between the largest and smallest *possible*
            pixel value; 255.0 for the 8-bit grayscale images loaded here.
    """
    def load_gray(image_path):
        # The context manager closes the underlying file once it is decoded.
        with Image.open(image_path) as img:
            return np.array(img.convert('L')).astype(np.float32)

    original_image = load_gray(original_image_path)
    pred_soft_image = load_gray(pred_soft_path)
    pred_hard_image = load_gray(pred_hard_path)

    # Use the nominal intensity span rather than each prediction's observed
    # max - min. A constant prediction (e.g. an all-black hard mask) has an
    # observed span of 0, which zeroes SSIM's stabilising constants and can
    # yield NaN; a nearly-flat soft prediction makes the score unstable.
    soft_similarity = ssim_sk(original_image, pred_soft_image, data_range=data_range)
    hard_similarity = ssim_sk(original_image, pred_hard_image, data_range=data_range)

    with open(output_file, 'w') as f:
        f.write(f"Soft Prediction Similarity: {soft_similarity}\n")
        f.write(f"Hard Prediction Similarity: {hard_similarity}\n")

def create_output_structure(base_output_dir, s_folder, slice_name, task, methods):
    """Create the output folders for one slice / task, one per method.

    Under ``<base_output_dir>/<s_folder>/<slice_name>/<task>`` every method
    gets its own folder containing ``similar_images`` and ``similar_labels``.

    Returns:
        Dict mapping each method name to its folder path.
    """
    slice_task_dir = os.path.join(base_output_dir, s_folder, slice_name, task)
    os.makedirs(slice_task_dir, exist_ok=True)

    method_dirs = {name: os.path.join(slice_task_dir, name) for name in methods}
    for method_dir in method_dirs.values():
        for leaf in ('similar_images', 'similar_labels'):
            os.makedirs(os.path.join(method_dir, leaf), exist_ok=True)

    return method_dirs

# Setup
# Names of the retrieval methods; each one gets its own output sub-folder
methods = [
    'import_support_image_structural',
    'import_support_image_dense_vector',
    'import_support_image_hash',
    'import_support_image_hash_aug',
    'import_support_image_lpips',
]

# Read tasks.txt: one task name per line, surrounding whitespace stripped
tasks_file = os.path.join(base_output_dir, 'tasks.txt')
with open(tasks_file, 'r') as f:
    tasks = [line.strip() for line in f]

# Every run writes into its own timestamped folder
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
run_output_dir = os.path.join(base_output_dir, timestamp)
os.makedirs(run_output_dir, exist_ok=True)

# Assuming 'model' and 'device' are already defined

# Loop over each S folder (e.g., s0000, s0001, etc.)
# For every dataset folder, task and slice: retrieve a support set with each
# retrieval method, save it, then run the model and save its outputs.
# NOTE(review): every sub-directory of dataset_dir is processed, including the
# folders that were copied into the reference database - confirm this is intended.
for s_folder in os.listdir(dataset_dir):
    s_folder_path = os.path.join(dataset_dir, s_folder)
    if not os.path.isdir(s_folder_path):
        continue

    original_jpg_dir = os.path.join(s_folder_path, 'original_jpg')

    # Loop over each task
    for task in tasks:
        # The support pool for this task comes from the reference database
        support_images_folder = f"{base_output_dir}/reference_db/{task}/image/"
        support_masks_folder = f"{base_output_dir}/reference_db/{task}/label/"

        # Process each slice in the original_jpg directory of the current S folder
        for slice_file in os.listdir(original_jpg_dir):
            if slice_file.endswith('.jpg'):
                slice_name = os.path.splitext(slice_file)[0]
                test_image_path = os.path.join(original_jpg_dir, slice_file)

                # Create output structure for this slice and task within the run's directory
                method_dirs = create_output_structure(run_output_dir, s_folder, slice_name, task, methods)

                # Handling each method
                # The names below must match the entries of `methods`, because
                # they are used as keys into method_dirs.
                for method_name, method_function in [
                    ('import_support_image_structural', import_support_image_structural),
                    ('import_support_image_dense_vector', import_support_image_dense_vector),
                    ('import_support_image_hash', import_support_image_hash),
                    ('import_support_image_hash_aug', import_support_image_hash_aug),
                    ('import_support_image_lpips', import_support_image_lpips)
                ]:
                    similar_images, similar_labels, original_image = method_function(support_size, support_images_folder, support_masks_folder, test_image_path, device, size)
                    original_image = original_image.to(torch.float32)
                    similar_images = similar_images.to(torch.float32)
                    similar_labels = similar_labels.to(torch.float32)

                    # Save similar images
                    # Iterating the support tensors yields one entry per support example
                    sim_images_dir = os.path.join(method_dirs[method_name], 'similar_images')
                    sim_labels_dir = os.path.join(method_dirs[method_name], 'similar_labels')
                    for i, img in enumerate(similar_images):
                        save_image(img.squeeze(), os.path.join(sim_images_dir, f'similar_image_{i:04d}.jpg'))
                    for i, label in enumerate(similar_labels):
                        save_image(label.squeeze(), os.path.join(sim_labels_dir, f'similar_label_{i:04d}.jpg'))

                    # Run inference and save outputs
                    infer_and_save(original_image, similar_images, similar_labels, model, device, method_dirs[method_name])

print("Processing and saving completed.")

KeyboardInterrupt: 

# Experiment 2.5 - Vary Support Size (4, 8, 16, 32) with BTCV

In [None]:
1 Modulate the reference DB size with the TotalSegmentator DB: 5 and 50 patients (liver only)
2 Second experiment uses BTCV as the reference DB (still segmenting patients 501 and onwards)
3 ...
4 ...

# Leftover Code

In [None]:
def infer_and_vis(original_image, similar_images, similar_labels, model, device):
    """Run the model on one image with its support set and plot the result.

    Shows three panels: the input image, the sigmoid output, and the sigmoid
    output thresholded at 0.5.

    Args:
        original_image: Target image tensor; a leading batch axis is added.
        similar_images: Support image tensor; a leading batch axis is added.
        similar_labels: Support label tensor; a leading batch axis is added.
        model: Callable invoked as ``model(image, support_images, support_labels)``.
        device: Unused; the inputs are passed to the model on whatever device
            they already live on.
    """
    # run inference without recording an autograd graph; nothing is back-propagated
    with torch.no_grad():
        logits = model(original_image[None], similar_images[None], similar_labels[None])[0]
        pred = torch.sigmoid(logits)

    # visualize
    res = {'data': [original_image, pred, pred > 0.5]}
    # TODO: visualize ground truth as well
    titles = ['image', 'pred (soft)', 'pred (hard)']
    visualize_tensors(res, col_wrap=3, col_names=titles)

size = (256,256)
support_size = 8
task = "adrenal_gland_left"
support_images_folder = f"D:/Niket/Experiments/Experiments1/reference_db/{task}/image/"
support_masks_folder = f"D:/Niket/Experiments/Experiments1/reference_db/{task}/label/"
test_image_path = "D:/Niket/Datasets/Totalsegmentator_dataset_v201/s0002/original_jpg/slice_0000.jpg"

# Same pipeline for every retrieval method, in this order:
# retrieve the support set -> cast to float32 -> run inference and plot
for retrieve_support in (
    import_support_image_structural,
    import_support_image_dense_vector,
    import_support_image_hash,
    import_support_image_hash_aug,
    import_support_image_lpips,
):
    similar_images, similar_labels, original_image = retrieve_support(
        support_size, support_images_folder, support_masks_folder, test_image_path, device, size
    )
    original_image = original_image.to(torch.float32)
    similar_images = similar_images.to(torch.float32)
    similar_labels = similar_labels.to(torch.float32)
    infer_and_vis(original_image, similar_images, similar_labels, model, device)