In [None]:
import logging
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from matplotlib.colors import ListedColormap
from skimage.io import imread
from skimage.io import imsave
from skimage.transform import resize
from torch.utils.data import DataLoader
from tqdm import tqdm

from datasets.dataset_cellseg import CellSeg_dataset
from networks.vit_seg_modeling import CONFIGS as CONFIGS_ViT_seg
from networks.vit_seg_modeling import VisionTransformer as ViT_seg
from utils import test_single_volume, calculate_metric_percase, calculate_iou_ap_per_class, convert_to_uint8, infer_large_image_in_patches, overlay_mask_on_image_and_save


Set Up the Model and Parameters

In [None]:
# Parameters setup
args = {
    'volume_path': './vis/',
    'dataset': 'CellSeg',
    'num_classes': 2,
    'list_dir': 'path_to_list_dir',
    'img_size': 224,
    # Must match the backbone the checkpoint was trained with. The snapshot directory
    # name below ("TU_pretrain_R50-ViT-B_16_skip0_...") presumably encodes
    # vit_name=R50-ViT-B_16 and n_skip=0 — confirm against the training run if
    # load_state_dict reports missing/unexpected keys.
    'vit_name': 'R50-ViT-B_16',
    'vit_patches_size': 16,
    'n_skip': 0,
    'is_pretrain': True,
    'test_save_dir': './vis/',
    'deterministic': True,
    'seed': 1234,
    'batch_size': 1,
    'base_lr': 0.01,
    'max_epochs': 30
}

# Trade cuDNN autotuning speed for run-to-run reproducibility when requested.
if args['deterministic']:
    cudnn.benchmark = False
    cudnn.deterministic = True
else:
    cudnn.benchmark = True
    cudnn.deterministic = False

# Seed every RNG the pipeline may touch.
random.seed(args['seed'])
np.random.seed(args['seed'])
torch.manual_seed(args['seed'])
torch.cuda.manual_seed(args['seed'])

dataset_config = {
    'CellSeg': {
        'Dataset': CellSeg_dataset,
        'volume_path': args['volume_path'],
        'list_dir': args['list_dir'],
        'num_classes': args['num_classes'],
        'z_spacing': 1,
    },
}

dataset_name = args['dataset']
args['num_classes'] = dataset_config[dataset_name]['num_classes']
args['volume_path'] = dataset_config[dataset_name]['volume_path']
args['Dataset'] = dataset_config[dataset_name]['Dataset']
args['list_dir'] = dataset_config[dataset_name]['list_dir']
args['z_spacing'] = dataset_config[dataset_name]['z_spacing']

# Checkpoint location. Set the TRANSUNET_SNAPSHOT environment variable to point
# elsewhere instead of editing this machine-specific absolute path.
snapshot_path = os.environ.get(
    "TRANSUNET_SNAPSHOT",
    "/mnt/parscratch/users/coq20tz/TransUNet/model/TU_CellSeg224/TU_pretrain_R50-ViT-B_16_skip0_epo200_bs128_224_St150_SN100_SEL10000_SF100000_LTpareto_VOS/epoch_199.pth",
)

# Push the run parameters into the backbone config. Without this,
# args['n_skip'] and args['vit_patches_size'] would never be used.
config_vit = CONFIGS_ViT_seg[args['vit_name']]
config_vit.n_classes = args['num_classes']
config_vit.n_skip = args['n_skip']
config_vit.patches.size = (args['vit_patches_size'], args['vit_patches_size'])
if 'R50' in args['vit_name']:
    # Hybrid (ResNet-50 + ViT) backbones: the token grid is img_size / patch_size per side.
    grid_side = args['img_size'] // args['vit_patches_size']
    config_vit.patches.grid = (grid_side, grid_side)

net = ViT_seg(config_vit, img_size=args['img_size'], num_classes=config_vit.n_classes).cuda()
# Deserialize on CPU, then copy into the CUDA parameters; this avoids device-mismatch
# errors if the checkpoint was saved from a different GPU.
net.load_state_dict(torch.load(snapshot_path, map_location='cpu'))
# Inference only: switch dropout / batch-norm layers (if any) to evaluation behaviour.
net.eval()


Perform Inference and Visualization

In [None]:
# Perform Inference
image_folder = args['volume_path']
# The inference loop writes '<case>_pred.png', '<case>_img.png' and '<case>_vis.png'.
# In the default configuration the input and output folders are the same ('./vis/'),
# so skip those names; otherwise a re-run would segment its own earlier outputs.
# (A genuine input whose name ends in one of these suffixes is skipped too.)
OUTPUT_SUFFIXES = ('_pred.png', '_img.png', '_vis.png')
# Sorted so that processing order (and output order) is reproducible across runs.
image_files = sorted(
    name for name in os.listdir(image_folder)
    if name.lower().endswith(('.tif', '.tiff', '.png'))
    and not name.endswith(OUTPUT_SUFFIXES)
)

# Image normalization function
def check_image(image):
    """Validate an image array and rescale 8-bit-range intensities to [0, 1].

    Args:
        image: numpy array of pixel values.

    Returns:
        ``image`` itself when its maximum is not above 1; otherwise
        ``image / 255.0`` as a new float array.

    Raises:
        TypeError: if ``image`` is not a ``numpy.ndarray``.
        ValueError: if the maximum pixel value exceeds 255.
    """
    if not isinstance(image, np.ndarray):
        raise TypeError("Image is not a numpy array.")

    peak = image.max()
    # Written as `not peak > 1` (not `peak <= 1`) so a NaN maximum still
    # falls through to returning the array untouched.
    if not peak > 1:
        return image
    if peak > 255:
        raise ValueError("Image pixel values are outside the expected range.")
    return image / 255.0

_LUMA_WEIGHTS = np.array([0.299, 0.587, 0.114])  # ITU-R 601-2, the weights PIL's convert("L") uses


def load_grayscale_uint8(path):
    """Read the image at ``path`` as a 2-D uint8 grayscale array.

    Approximates ``PIL.Image.open(path).convert("L")`` using the already-imported
    skimage ``imread``:
    - RGB(A) is reduced to luma and alpha is ignored.
    - Gray+alpha keeps the gray channel.
    - Boolean (1-bit) images map to 0/255.
    - Values outside [0, 255] are clipped.

    Raises:
        ValueError: for shapes that are not a single 2-D plane (e.g. multi-page stacks).
    """
    raw = imread(path)
    if raw.dtype == bool:
        raw = raw.astype(np.uint8) * 255
    if raw.ndim == 3 and raw.shape[-1] in (3, 4):
        raw = raw[..., :3].astype(np.float64) @ _LUMA_WEIGHTS
    elif raw.ndim == 3 and raw.shape[-1] == 2:
        raw = raw[..., 0]
    if raw.ndim != 2:
        raise ValueError(f"Unsupported image shape {raw.shape} for {path}")
    return np.clip(np.rint(raw), 0, 255).astype(np.uint8)


test_save_path = args['test_save_dir']
os.makedirs(test_save_path, exist_ok=True)
patch_size = [args['img_size'], args['img_size']]  # same tile size the network was built for

for image_file in image_files:
    image_path = os.path.join(image_folder, image_file)
    image = load_grayscale_uint8(image_path)

    # Tiled inference over the full-size image; no autograd graph is needed.
    with torch.no_grad():
        prediction = infer_large_image_in_patches(image, net, patch_size=patch_size, overlap=56, device='cuda')
    img_uint8 = convert_to_uint8(image)
    prd_uint8 = convert_to_uint8(prediction)

    # Save output images. PNG carries no voxel-spacing metadata, so only pixel data
    # is written; check_contrast=False silences warnings for near-binary masks.
    case = os.path.splitext(image_file)[0]
    imsave(os.path.join(test_save_path, f'{case}_pred.png'), prd_uint8, check_contrast=False)
    imsave(os.path.join(test_save_path, f'{case}_img.png'), img_uint8, check_contrast=False)

    # Overlay and save the prediction mask on the original image
    overlay_mask_on_image_and_save(img_uint8, prd_uint8, save_path=os.path.join(test_save_path, f'{case}_vis.png'), alpha=0.3, mask_color='red', dpi=300, threshold=0.5)

    print(f"Inference completed. Results saved in {test_save_path}")

    # Same overlay again, this time only for in-notebook display (save_path=None).
    overlayed_image = overlay_mask_on_image_and_save(img_uint8, prd_uint8, alpha=0.3, mask_color='red', dpi=300, threshold=0.5, save_path=None)

    # Visualize original / mask / overlay side by side.
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    panels = [
        (img_uint8, 'Original Image', 'gray'),
        (prd_uint8, 'Prediction Mask', 'gray'),
        (overlayed_image, 'Overlayed Image', None),
    ]
    for ax, (data, title, cmap) in zip(axes, panels):
        ax.imshow(data, cmap=cmap)
        ax.set_title(title)
        ax.axis('off')
    fig.tight_layout()
    plt.show()
    plt.close(fig)  # release figure memory when looping over many images
    
    