In [1]:
import os
import sys
import gc

In [2]:
XMem_path = os.path.abspath('./XMem')
sys.path.append(XMem_path)
# !wget -P ./saves/ https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem.pth

In [3]:
from inspect import getsource
from pathlib import Path
from os import path

import cv2
import numpy as np
from PIL import Image
from skimage import io
from tqdm import tqdm
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from torchmetrics.classification import BinaryJaccardIndex
from torchmetrics.classification import Dice

# from inference.data.test_datasets import LongTestDataset, DAVISTestDataset, YouTubeVOSTestDataset
# from inference.data.mask_mapper import MaskMapper
from model.network import XMem
from inference.inference_core import InferenceCore
from inference.data.mask_mapper import MaskMapper

from inference.interact.interactive_utils import image_to_torch, index_numpy_to_one_hot_torch, torch_prob_to_numpy_mask

from progressbar import progressbar

torch.set_grad_enabled(False)

# default configuration
config = {
    'top_k': 30,
    'mem_every': 5,
    'deep_update_every': -1,
    'enable_long_term': True,
    'enable_long_term_count_usage': True,
    'num_prototypes': 128,
    'min_mid_term_frames': 5,
    'max_mid_term_frames': 10,
    'max_long_term_elements': 10000,
}

if torch.cuda.is_available():
  print('Using GPU')
  device = 'cuda'
else:
  print('CUDA not available. Please connect to a GPU instance if possible.')
  device = 'cpu'

# network = XMem(config, '../XMem/saves/XMem.pth').eval().to(device)

Using GPU


In [4]:
!nvidia-smi

Thu Nov 16 23:46:28 2023       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA L40                     On  | 00000000:E1:00.0 Off |                    0 |
| N/A   37C    P0              78W / 300W |   1556MiB / 46068MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [5]:
torch.cuda.empty_cache()

COLOR = (3, 192, 60)

# processor = InferenceCore(network, config=config)
# NUM_OBJECTS = 1 # Binary Segmentation
# processor.set_all_labels(range(1, NUM_OBJECTS+1))

main_folder = Path("./data")

VIDEOS_PATH = main_folder/"endo17_binary"/"frames"
MASKS_PATH = main_folder/"endo17_binary"/"masks"


In [6]:
def binary2color(binary_mask, color):
    binary_mask = torch_prob_to_numpy_mask(binary_mask)
    pred_mask = np.tile(binary_mask[..., np.newaxis], (1,1,3)) # Make it 3 Channel
    mask = np.where(pred_mask == (1,)*3, color, 0).astype('uint8') # Convert Prediction with Color
    return mask

In [7]:
def frames2video(frames_dict, folder_save_path, video_name, FPS=5):
    frame = frames_dict[list(frames_dict.keys())[-1]]
    size1,size2,_ = frame.shape
    out = cv2.VideoWriter(f'{folder_save_path}/{video_name}_{FPS}FPS.mp4', cv2.VideoWriter_fourcc(*'mp4v'), FPS, (size2, size1), True)
    # Sorting the frames according to frame number eg: frame_007.png
    for _,i in sorted(frames_dict.items(), key=lambda x: x[0]):
        out_img = cv2.cvtColor(i, cv2.COLOR_BGR2RGB)
        out.write(out_img)
    out.release()

In [8]:
def getIoU(pred_frames, gt_path):
    metric = BinaryJaccardIndex()
    dice_metric = Dice()
    
    IoU = []
    dice = []
    
    for frame_name, mask in pred_frames.items():
        mask = torch_prob_to_numpy_mask(mask)
        try:
            truth_mask = io.imread(gt_path/frame_name)
        except FileNotFoundError:
            continue
        truth_mask = np.where(truth_mask == 255, 1, truth_mask)
        if np.sum(truth_mask) == 0:
            continue
        truth_mask = torch.tensor(truth_mask)
        IoU.append(metric(torch.tensor(mask), truth_mask).item())
        dice.append(dice_metric(torch.tensor(mask), truth_mask).item())
        
    meanIoU = sum(IoU)/len(IoU)
    meanDice = sum(dice)/len(dice)
    
    return meanIoU, IoU, meanDice, dice

In [9]:
im_normalization = transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]
                )

def resize_mask(mask, size):
        mask = mask.unsqueeze(0).unsqueeze(0)
        h, w = mask.shape[-2:]
        min_hw = min(h, w)
        return F.interpolate(mask, (int(h/min_hw*size), int(w/min_hw*size)), 
                    mode='nearest')[0]

def singleVideoInference(images_paths, first_mask, processor, size = -1):
    predictions = {}
    frames = {}
    with torch.cuda.amp.autocast(enabled=True):

        images_paths = sorted(images_paths)

        # First Frame
        frame = io.imread(images_paths[0])
        shape = frame.shape[:2]
        if size < 0:
            im_transform = transforms.Compose([
                transforms.ToTensor(),
                im_normalization,
            ])
        else:
            im_transform = transforms.Compose([
                transforms.ToTensor(),
                im_normalization,
                transforms.Resize(size, interpolation=InterpolationMode.BILINEAR),
            ])
            
        frame_torch = im_transform(frame).to(device)
        first_mask = first_mask.astype(np.uint8)
        if size > 0:
            first_mask = torch.tensor(first_mask).to(device)
            first_mask = resize_mask(first_mask, size)
        else:
            NUM_OBJECTS = 1 # Binary Segmentation
            first_mask = index_numpy_to_one_hot_torch(first_mask, NUM_OBJECTS+1).to(device)
            first_mask = first_mask[1:]
            
        prediction = processor.step(frame_torch, first_mask)
        
        for image_path in tqdm(images_paths[1:]):
            frame = io.imread(image_path)
            # convert numpy array to pytorch tensor format
            frame_torch = im_transform(frame).to(device)
            
            prediction = processor.step(frame_torch)
            # Upsample to original size if needed
            if size > 0:
                prediction = F.interpolate(prediction.unsqueeze(1), shape, mode='bilinear', align_corners=False)[:,0]
            predictions[image_path.name] = prediction
            frames[image_path.name] = frame

    return frames, predictions

In [10]:
def firstMaskGT(image_files, mask_folder):
    image_files = sorted(image_files)

    for idx, image_path in enumerate(image_files):        
        
        # Getting the Path to Mask Ground Truth using RGB Image path
        mask_path = mask_folder/image_path.parent.name/image_path.name
        
        mask = io.imread(mask_path)
        # All 255 Values replaced with 1, other values remain as it is.
        mask = np.where(mask == 255, 1, mask)
        
        if np.sum(mask) > 0:
            return mask, idx
            
    return None, -1

In [11]:
def doInference(network_path, config, frames_folder, mask_folder,
                subset = None, pred_mask_folder = None, size = -1):
    overallIoU = []
    overallDice = []
    for video_folder in sorted(frames_folder.iterdir()):

        if subset is not None and video_folder.name not in subset:
            continue
    
        # Clearing GPU Cache
        torch.cuda.empty_cache()
        network = XMem(config, network_path).eval().to(device)
        processor = InferenceCore(network, config=config)
        NUM_OBJECTS = 1 # Binary Segmentation
        processor.set_all_labels(range(1, NUM_OBJECTS+1))
        
        # All Images
        image_files = sorted(list(video_folder.iterdir()))
        if pred_mask_folder:
            mask_path = [i for i in pred_mask_folder.iterdir() if video_folder.name in i.name][0]
            mask = io.imread(mask_path)
            # All 0 pixel is 0, everything else(which is mask) is 1
            mask = np.where(mask == 0, 0, 1)
            # seq_01_0.png -> Two Splits, one on '_', other on '.'
            start_idx = int((mask_path.name.split('_')[-1]).split('.')[0])
        else: # Ground Truth
            mask, start_idx = firstMaskGT(image_files, mask_folder)
    
        print(f"Running Inference on {video_folder.name}...")
        frames, predictions = singleVideoInference(image_files[start_idx:], mask,
                                                  processor, size = size)
        IoU, _, dice, _ = getIoU(predictions, mask_folder/video_folder.name)
        print(f"Video \"{video_folder.name}\", mean IoU is: {IoU}")
        print(f"Video \"{video_folder.name}\", mean dice is: {dice}")
        
        overallIoU.append(IoU)
        overallDice.append(dice)
        print()

        del network, processor
        torch.cuda.empty_cache()
        gc.collect()
    
    print(f"Average IoU over all videos is: {sum(overallIoU)/len(overallIoU)}.")
    print(f"Average Dice over all videos is: {sum(overallDice)/len(overallDice)}.")

    return overallIoU, overallDice

## Check all .pth files and select best

In [None]:
paths = []
for network_path in Path("./saves/Nov15_20.28.13_EndoVis17_Binary").iterdir():
    if 'checkpoint' in network_path.name or '.pth' not in network_path.name:
        continue
    paths.append(network_path)

IoUs = {}
for network_path in sorted(paths, key = lambda x: int(x.name.split('_')[-1].split('.')[0])):
    print(network_path.name)
    test_subset = {i.name for i in VIDEOS_PATH.iterdir() if 'test' in i.name}
    overallIoU, overallDice = doInference(network_path, config, VIDEOS_PATH, MASKS_PATH,
                                          subset = test_subset, size = 384)
    IoUs[network_path.name] = sum(overallIoU)/len(overallIoU)
    print('*'*100)

In [13]:
print("Inference Completed")

Inference Completed


In [18]:
sorted(IoUs.items(), key=lambda x: x[1], reverse=True)[0]

('Nov15_20.28.13_EndoVis17_Binary_1850.pth', 0.9230308244198298)

In [22]:
network_path = "./saves/Nov15_20.28.13_EndoVis17_Binary/XMem_Binary_Endo17.pth"
test_subset = {i.name for i in VIDEOS_PATH.iterdir() if 'test' in i.name}
overallIoU, overallDice = doInference(network_path, config, VIDEOS_PATH, MASKS_PATH,
                                          subset = test_subset, size = 384)

Hyperparameters read from the model weights: C^k=64, C^v=512, C^h=64
Single object mode: False
Running Inference on instrument_dataset_01_test...


100%|██████████| 74/74 [00:24<00:00,  3.00it/s]


Video "instrument_dataset_01_test", mean IoU is: 0.8832564353942871
Video "instrument_dataset_01_test", mean dice is: 0.9917536160430392

Hyperparameters read from the model weights: C^k=64, C^v=512, C^h=64
Single object mode: False
Running Inference on instrument_dataset_02_test...


100%|██████████| 74/74 [00:21<00:00,  3.46it/s]


Video "instrument_dataset_02_test", mean IoU is: 0.8809122201558706
Video "instrument_dataset_02_test", mean dice is: 0.9911449205231022

Hyperparameters read from the model weights: C^k=64, C^v=512, C^h=64
Single object mode: False
Running Inference on instrument_dataset_03_test...


100%|██████████| 74/74 [00:22<00:00,  3.27it/s]


Video "instrument_dataset_03_test", mean IoU is: 0.9524074527057441
Video "instrument_dataset_03_test", mean dice is: 0.9929651432746166

Hyperparameters read from the model weights: C^k=64, C^v=512, C^h=64
Single object mode: False
Running Inference on instrument_dataset_04_test...


100%|██████████| 74/74 [00:24<00:00,  3.07it/s]


Video "instrument_dataset_04_test", mean IoU is: 0.9484399033559335
Video "instrument_dataset_04_test", mean dice is: 0.9902801425070376

Hyperparameters read from the model weights: C^k=64, C^v=512, C^h=64
Single object mode: False
Running Inference on instrument_dataset_05_test...


100%|██████████| 74/74 [00:21<00:00,  3.40it/s]


Video "instrument_dataset_05_test", mean IoU is: 0.9179814574686257
Video "instrument_dataset_05_test", mean dice is: 0.9921518354802519

Hyperparameters read from the model weights: C^k=64, C^v=512, C^h=64
Single object mode: False
Running Inference on instrument_dataset_06_test...


100%|██████████| 74/74 [00:20<00:00,  3.58it/s]


Video "instrument_dataset_06_test", mean IoU is: 0.9220134760882404
Video "instrument_dataset_06_test", mean dice is: 0.9865542835480458

Hyperparameters read from the model weights: C^k=64, C^v=512, C^h=64
Single object mode: False
Running Inference on instrument_dataset_07_test...


100%|██████████| 74/74 [00:20<00:00,  3.56it/s]


Video "instrument_dataset_07_test", mean IoU is: 0.929388814278551
Video "instrument_dataset_07_test", mean dice is: 0.9836911430230012

Hyperparameters read from the model weights: C^k=64, C^v=512, C^h=64
Single object mode: False
Running Inference on instrument_dataset_08_test...


100%|██████████| 74/74 [00:20<00:00,  3.59it/s]


Video "instrument_dataset_08_test", mean IoU is: 0.9555812856635532
Video "instrument_dataset_08_test", mean dice is: 0.9922536684049142

Hyperparameters read from the model weights: C^k=64, C^v=512, C^h=64
Single object mode: False
Running Inference on instrument_dataset_09_test...


100%|██████████| 299/299 [01:37<00:00,  3.06it/s]


Video "instrument_dataset_09_test", mean IoU is: 0.909089560692127
Video "instrument_dataset_09_test", mean dice is: 0.9907494801342687

Hyperparameters read from the model weights: C^k=64, C^v=512, C^h=64
Single object mode: False
Running Inference on instrument_dataset_10_test...


100%|██████████| 299/299 [01:29<00:00,  3.33it/s]


Video "instrument_dataset_10_test", mean IoU is: 0.9312376383953669
Video "instrument_dataset_10_test", mean dice is: 0.9918697667759796

Average IoU over all videos is: 0.9230308244198298.
Average Dice over all videos is: 0.9903413999714257.
