# Test All Text-Guided Segmentation Models

This notebook tests all 14 text-guided segmentation models to verify they:
1. Load correctly
2. Accept the expected inputs (images + text prompts)
3. Produce the expected outputs (logits with correct shape)

## Models to test:
1. CLIPSeg (CVPR 2022)
2. LSeg (ICLR 2022)
3. GroupViT (CVPR 2022)
4. SAN (CVPR 2023)
5. FC-CLIP (NeurIPS 2023)
6. OVSeg (CVPR 2023)
7. CAT-Seg (CVPR 2024)
8. SED (CVPR 2024)
9. MAFT+ (ECCV 2024 Oral)
10. X-Decoder (CVPR 2023)
11. OpenSeeD (ICCV 2023)
12. ODISE (CVPR 2023)
13. TagAlign (arXiv 2023)
14. Semantic-SAM (ECCV 2024)

In [1]:
# Setup and imports
import sys
import os

# Add parent directory to path
WORKSPACE = "/mnt/e3dbc9b9-6856-470d-84b1-ff55921cd906/Datasets/Nikhil/Histopathology_Work"
sys.path.insert(0, WORKSPACE)

import torch
import torch.nn as nn
import numpy as np
from typing import Dict, List
import time

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

PyTorch version: 2.9.1+cu128
CUDA available: True
CUDA device: NVIDIA RTX A5000


In [14]:
# Import TextGuidedSegmentation package
# First, clear any cached imports
import importlib
import sys

# Remove cached TextGuidedSegmentation modules
modules_to_remove = [m for m in sys.modules.keys() if 'TextGuidedSegmentation' in m]
for m in modules_to_remove:
    del sys.modules[m]

# Now import fresh
from TextGuidedSegmentation import (
    get_model,
    list_models,
    print_model_summary,
    MODEL_INFO,
    DEFAULT_TEXT_PROMPTS,
)

# Print available models
print_model_summary()


Text-Guided Segmentation Models for Histopathology

#   Model           Venue                Description                             
--------------------------------------------------------------------------------
1   CLIPSeg         CVPR 2022            Uses CLIP features with FiLM condition..
2   LSeg            ICLR 2022            Dense Prediction Transformer with CLIP..
3   GroupViT        CVPR 2022            Hierarchical grouping mechanism with c..
4   SAN             CVPR 2023            Side adapter network preserving CLIP c..
5   FC-CLIP         NeurIPS 2023         Fully convolutional CLIP for dense pre..
6   OVSeg           CVPR 2023            Mask-adapted CLIP with region-level cl..
7   CAT-Seg         CVPR 2024            Cost aggregation with spatial semantic..
8   SED             CVPR 2024            Simple encoder-decoder with category-g..
9   MAFT+           ECCV 2024 Oral       Multi-modal adapters with cross-modal ..
10  X-Decoder       CVPR 2023            Unifi

In [15]:
# Test configuration
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 2
IMAGE_SIZE = 256
NUM_CLASSES = 5

# Default text prompts for histopathology
TEXT_PROMPTS = [
    "neoplastic cells",
    "inflammatory cells",
    "connective tissue cells",
    "dead cells",
    "epithelial cells",
]

# Create dummy input
dummy_images = torch.randn(BATCH_SIZE, 3, IMAGE_SIZE, IMAGE_SIZE).to(DEVICE)
print(f"Dummy input shape: {dummy_images.shape}")
print(f"Text prompts: {TEXT_PROMPTS}")

Dummy input shape: torch.Size([2, 3, 256, 256])
Text prompts: ['neoplastic cells', 'inflammatory cells', 'connective tissue cells', 'dead cells', 'epithelial cells']


In [16]:
def test_model(model_name: str, images: torch.Tensor, text_prompts: List[str]) -> Dict:
    """
    Test a single model.
    
    Returns:
        Dict with test results
    """
    result = {
        'model_name': model_name,
        'status': 'unknown',
        'error': None,
        'output_shape': None,
        'num_params': None,
        'inference_time': None,
    }
    
    try:
        # Load model
        print(f"\n{'='*60}")
        print(f"Testing: {model_name}")
        print(f"{'='*60}")
        
        start_time = time.time()
        model = get_model(
            model_name,
            num_classes=NUM_CLASSES,
            image_size=IMAGE_SIZE,
            device=DEVICE,
        )
        load_time = time.time() - start_time
        print(f"✓ Model loaded in {load_time:.2f}s")
        
        # Count parameters
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        result['num_params'] = total_params
        print(f"✓ Total params: {total_params:,}")
        print(f"✓ Trainable params: {trainable_params:,}")
        
        # Forward pass
        model.eval()
        with torch.no_grad():
            start_time = time.time()
            outputs = model(images, text_prompts)
            inference_time = time.time() - start_time
        
        result['inference_time'] = inference_time
        print(f"✓ Inference time: {inference_time:.3f}s")
        
        # Check outputs
        assert 'logits' in outputs, "Output must contain 'logits'"
        logits = outputs['logits']
        result['output_shape'] = tuple(logits.shape)
        
        expected_shape = (BATCH_SIZE, NUM_CLASSES, IMAGE_SIZE, IMAGE_SIZE)
        print(f"✓ Output shape: {logits.shape}")
        print(f"  Expected shape: {expected_shape}")
        
        assert logits.shape == expected_shape, f"Shape mismatch: {logits.shape} != {expected_shape}"
        print(f"✓ Shape matches expected!")
        
        # Check for NaN/Inf
        assert not torch.isnan(logits).any(), "Output contains NaN values"
        assert not torch.isinf(logits).any(), "Output contains Inf values"
        print(f"✓ No NaN/Inf values")
        
        # Check predicted mask
        if 'pred_mask' in outputs:
            pred_mask = outputs['pred_mask']
            print(f"✓ Predicted mask shape: {pred_mask.shape}")
            print(f"  Unique classes: {torch.unique(pred_mask).tolist()}")
        
        result['status'] = 'PASSED'
        print(f"\n✅ {model_name}: PASSED")
        
        # Clean up
        del model
        torch.cuda.empty_cache() if torch.cuda.is_available() else None
        
    except Exception as e:
        result['status'] = 'FAILED'
        result['error'] = str(e)
        print(f"\n❌ {model_name}: FAILED")
        print(f"   Error: {e}")
        import traceback
        traceback.print_exc()
    
    return result

In [17]:
# Get all model names
all_models = list_models()
print(f"Total models to test: {len(all_models)}")
print(f"Models: {all_models}")

Total models to test: 18
Models: ['clipseg', 'clipseg_rd64', 'clipseg_rd128', 'lseg', 'lseg_vit_l', 'groupvit', 'san', 'fc_clip', 'fc_clip_convnext', 'ovseg', 'cat_seg', 'sed', 'maft_plus', 'x_decoder', 'openseed', 'odise', 'tagalign', 'semantic_sam']


In [18]:
# Test all models
results = []

for model_name in all_models:
    result = test_model(model_name, dummy_images, TEXT_PROMPTS)
    results.append(result)
    
    # Clear CUDA cache between models
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


Testing: clipseg
✓ Model loaded in 1.42s
✓ Total params: 150,078,466
✓ Trainable params: 457,729
✓ Inference time: 0.030s
✓ Output shape: torch.Size([2, 5, 256, 256])
  Expected shape: (2, 5, 256, 256)
✓ Shape matches expected!
✓ No NaN/Inf values
✓ Predicted mask shape: torch.Size([2, 256, 256])
  Unique classes: [0, 1, 2, 3, 4]

✅ clipseg: PASSED

Testing: clipseg_rd64
✓ Model loaded in 1.26s
✓ Total params: 150,078,466
✓ Trainable params: 457,729
✓ Inference time: 0.022s
✓ Output shape: torch.Size([2, 5, 256, 256])
  Expected shape: (2, 5, 256, 256)
✓ Shape matches expected!
✓ No NaN/Inf values
✓ Predicted mask shape: torch.Size([2, 256, 256])
  Unique classes: [0, 1, 2, 3, 4]

✅ clipseg_rd64: PASSED

Testing: clipseg_rd128
✓ Model loaded in 1.39s
✓ Total params: 150,761,474
✓ Trainable params: 1,140,737
✓ Inference time: 0.031s
✓ Output shape: torch.Size([2, 5, 256, 256])
  Expected shape: (2, 5, 256, 256)
✓ Shape matches expected!
✓ No NaN/Inf values
✓ Predicted mask shape: torch

Traceback (most recent call last):
  File "/tmp/ipykernel_3306561/2823249585.py", line 44, in test_model
    outputs = model(images, text_prompts)
  File "/home/miglab/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/miglab/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/e3dbc9b9-6856-470d-84b1-ff55921cd906/Datasets/Nikhil/Histopathology_Work/TextGuidedSegmentation/models/lseg.py", line 356, in forward
    visual_features = self.encode_image(image)  # (B, D, H', W')
  File "/mnt/e3dbc9b9-6856-470d-84b1-ff55921cd906/Datasets/Nikhil/Histopathology_Work/TextGuidedSegmentation/models/lseg.py", line 327, in encode_image
    dense_features = self.dpt_head(layer_features, cls_token, patch_h, patch_w)
  File "/home/miglab/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 177

✓ Model loaded in 1.50s
✓ Total params: 175,964,802
✓ Trainable params: 26,344,065
✓ Inference time: 0.011s
✓ Output shape: torch.Size([2, 5, 256, 256])
  Expected shape: (2, 5, 256, 256)
✓ Shape matches expected!
✓ No NaN/Inf values
✓ Predicted mask shape: torch.Size([2, 256, 256])
  Unique classes: [0, 4]

✅ groupvit: PASSED

Testing: san
✓ Model loaded in 1.47s
✓ Total params: 181,669,121
✓ Trainable params: 32,048,384
✓ Inference time: 0.034s
✓ Output shape: torch.Size([2, 5, 256, 256])
  Expected shape: (2, 5, 256, 256)
✓ Shape matches expected!
✓ No NaN/Inf values
✓ Predicted mask shape: torch.Size([2, 256, 256])
  Unique classes: [0, 1, 2, 3, 4]

✅ san: PASSED

Testing: fc_clip
✓ Model loaded in 1.37s
✓ Total params: 154,019,586
✓ Trainable params: 4,398,849
✓ Inference time: 0.037s
✓ Output shape: torch.Size([2, 5, 256, 256])
  Expected shape: (2, 5, 256, 256)
✓ Shape matches expected!
✓ No NaN/Inf values
✓ Predicted mask shape: torch.Size([2, 256, 256])
  Unique classes: [2]



Traceback (most recent call last):
  File "/tmp/ipykernel_3306561/2823249585.py", line 44, in test_model
    outputs = model(images, text_prompts)
  File "/home/miglab/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/miglab/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/e3dbc9b9-6856-470d-84b1-ff55921cd906/Datasets/Nikhil/Histopathology_Work/TextGuidedSegmentation/models/maft_plus.py", line 355, in forward
    visual_enhanced, text_enhanced = self.cross_modal(visual_flat, text_features)
  File "/home/miglab/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/miglab/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args,

✓ Model loaded in 1.25s
✓ Total params: 162,475,521
✓ Trainable params: 12,854,784

❌ x_decoder: FAILED
   Error: The size of tensor a (256) must match the size of tensor b (65536) at non-singleton dimension 1

Testing: openseed


Traceback (most recent call last):
  File "/tmp/ipykernel_3306561/2823249585.py", line 44, in test_model
    outputs = model(images, text_prompts)
  File "/home/miglab/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/miglab/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/e3dbc9b9-6856-470d-84b1-ff55921cd906/Datasets/Nikhil/Histopathology_Work/TextGuidedSegmentation/models/x_decoder.py", line 356, in forward
    queries = self.decoder(decoder_feats, text_feats, pos_flat)  # (B, Q, d_model)
  File "/home/miglab/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/miglab/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args

✓ Model loaded in 1.26s
✓ Total params: 158,622,721
✓ Trainable params: 9,001,984
✓ Inference time: 0.025s
✓ Output shape: torch.Size([2, 5, 256, 256])
  Expected shape: (2, 5, 256, 256)
✓ Shape matches expected!
✓ No NaN/Inf values
✓ Predicted mask shape: torch.Size([2, 256, 256])
  Unique classes: [4]

✅ openseed: PASSED

Testing: odise
✓ Model loaded in 1.32s
✓ Total params: 162,532,417
✓ Trainable params: 12,911,680
✓ Inference time: 0.046s
✓ Output shape: torch.Size([2, 5, 256, 256])
  Expected shape: (2, 5, 256, 256)
✓ Shape matches expected!
✓ No NaN/Inf values
✓ Predicted mask shape: torch.Size([2, 256, 256])
  Unique classes: [0, 2, 3, 4]

✅ odise: PASSED

Testing: tagalign
✓ Model loaded in 1.24s
✓ Total params: 151,934,210
✓ Trainable params: 2,313,473

❌ tagalign: FAILED
   Error: Sizes of tensors must match except in dimension 1. Expected size 15 but got size 16 for tensor number 1 in the list.

Testing: semantic_sam


Traceback (most recent call last):
  File "/tmp/ipykernel_3306561/2823249585.py", line 44, in test_model
    outputs = model(images, text_prompts)
  File "/home/miglab/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/miglab/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/e3dbc9b9-6856-470d-84b1-ff55921cd906/Datasets/Nikhil/Histopathology_Work/TextGuidedSegmentation/models/tagalign.py", line 305, in forward
    visual_features = self.encode_image(image, text_features)
  File "/mnt/e3dbc9b9-6856-470d-84b1-ff55921cd906/Datasets/Nikhil/Histopathology_Work/TextGuidedSegmentation/models/tagalign.py", line 277, in encode_image
    multi_scale = self.multi_granular(aligned)
  File "/home/miglab/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
   

✓ Model loaded in 1.25s
✓ Total params: 155,403,522
✓ Trainable params: 5,782,785
✓ Inference time: 0.026s
✓ Output shape: torch.Size([2, 5, 256, 256])
  Expected shape: (2, 5, 256, 256)
✓ Shape matches expected!
✓ No NaN/Inf values
✓ Predicted mask shape: torch.Size([2, 256, 256])
  Unique classes: [3]

✅ semantic_sam: PASSED


In [19]:
# Summary
print("\n" + "="*80)
print("TEST SUMMARY")
print("="*80)

passed = [r for r in results if r['status'] == 'PASSED']
failed = [r for r in results if r['status'] == 'FAILED']

print(f"\n✅ Passed: {len(passed)}/{len(results)}")
print(f"❌ Failed: {len(failed)}/{len(results)}")

print("\n" + "-"*80)
print(f"{'Model':<20} {'Status':<10} {'Params (M)':<12} {'Time (s)':<10}")
print("-"*80)

for r in results:
    params_m = r['num_params'] / 1e6 if r['num_params'] else 0
    time_s = r['inference_time'] if r['inference_time'] else 0
    status = "✅ PASS" if r['status'] == 'PASSED' else "❌ FAIL"
    print(f"{r['model_name']:<20} {status:<10} {params_m:>10.2f}M {time_s:>10.3f}s")

print("-"*80)

if failed:
    print("\nFailed models:")
    for r in failed:
        print(f"  - {r['model_name']}: {r['error']}")


TEST SUMMARY

✅ Passed: 14/18
❌ Failed: 4/18

--------------------------------------------------------------------------------
Model                Status     Params (M)   Time (s)  
--------------------------------------------------------------------------------
clipseg              ✅ PASS         150.08M      0.030s
clipseg_rd64         ✅ PASS         150.08M      0.022s
clipseg_rd128        ✅ PASS         150.76M      0.031s
lseg                 ✅ PASS         178.93M      0.032s
lseg_vit_l           ❌ FAIL         456.99M      0.000s
groupvit             ✅ PASS         175.96M      0.011s
san                  ✅ PASS         181.67M      0.034s
fc_clip              ✅ PASS         154.02M      0.037s
fc_clip_convnext     ✅ PASS         155.20M      0.037s
ovseg                ✅ PASS         155.28M      0.024s
cat_seg              ✅ PASS         151.50M      0.144s
sed                  ✅ PASS         156.64M      0.022s
maft_plus            ❌ FAIL         158.62M      0.000s
x_decod

In [None]:
# Test a single model in detail (for debugging)
# Change MODEL_TO_TEST to test a specific model

MODEL_TO_TEST = "clipseg"  # Change this to test different models

print(f"\nDetailed test for: {MODEL_TO_TEST}")
print("="*60)

model = get_model(
    MODEL_TO_TEST,
    num_classes=NUM_CLASSES,
    image_size=IMAGE_SIZE,
    device=DEVICE,
)

# Print model architecture
print("\nModel architecture:")
print(model)

# Forward pass
model.eval()
with torch.no_grad():
    outputs = model(dummy_images, TEXT_PROMPTS)

print("\nOutput keys:", outputs.keys())
for key, value in outputs.items():
    if isinstance(value, torch.Tensor):
        print(f"  {key}: shape={value.shape}, dtype={value.dtype}")
    else:
        print(f"  {key}: {type(value)}")

In [None]:
# Visualize predictions
import matplotlib.pyplot as plt

model = get_model("clipseg", num_classes=NUM_CLASSES, image_size=IMAGE_SIZE, device=DEVICE)
model.eval()

with torch.no_grad():
    outputs = model(dummy_images, TEXT_PROMPTS)

logits = outputs['logits']
pred_mask = logits.argmax(dim=1)

# Plot
fig, axes = plt.subplots(BATCH_SIZE, 3, figsize=(12, 4*BATCH_SIZE))

for i in range(BATCH_SIZE):
    # Input image (denormalized for visualization)
    img = dummy_images[i].cpu().permute(1, 2, 0).numpy()
    img = (img - img.min()) / (img.max() - img.min())  # Normalize to [0, 1]
    
    # Prediction
    pred = pred_mask[i].cpu().numpy()
    
    # Logits heatmap for first class
    heatmap = logits[i, 0].cpu().numpy()
    
    if BATCH_SIZE > 1:
        axes[i, 0].imshow(img)
        axes[i, 0].set_title(f'Input {i}')
        axes[i, 0].axis('off')
        
        axes[i, 1].imshow(pred, cmap='tab10', vmin=0, vmax=NUM_CLASSES-1)
        axes[i, 1].set_title(f'Prediction {i}')
        axes[i, 1].axis('off')
        
        im = axes[i, 2].imshow(heatmap, cmap='hot')
        axes[i, 2].set_title(f'Neoplastic logits {i}')
        axes[i, 2].axis('off')
        plt.colorbar(im, ax=axes[i, 2])
    else:
        axes[0].imshow(img)
        axes[0].set_title('Input')
        axes[0].axis('off')
        
        axes[1].imshow(pred, cmap='tab10', vmin=0, vmax=NUM_CLASSES-1)
        axes[1].set_title('Prediction')
        axes[1].axis('off')
        
        im = axes[2].imshow(heatmap, cmap='hot')
        axes[2].set_title('Neoplastic logits')
        axes[2].axis('off')
        plt.colorbar(im, ax=axes[2])

plt.tight_layout()
plt.show()

In [None]:
# Memory usage comparison
if torch.cuda.is_available():
    print("\nGPU Memory Usage Comparison")
    print("="*60)
    
    memory_results = []
    
    for model_name in all_models:
        try:
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()
            
            model = get_model(model_name, num_classes=NUM_CLASSES, image_size=IMAGE_SIZE, device=DEVICE)
            
            model.eval()
            with torch.no_grad():
                _ = model(dummy_images, TEXT_PROMPTS)
            
            peak_memory = torch.cuda.max_memory_allocated() / 1e9  # GB
            memory_results.append((model_name, peak_memory))
            
            del model
            torch.cuda.empty_cache()
            
        except Exception as e:
            memory_results.append((model_name, -1))
    
    # Sort by memory
    memory_results.sort(key=lambda x: x[1])
    
    print(f"\n{'Model':<20} {'Peak Memory (GB)':<15}")
    print("-"*35)
    for name, mem in memory_results:
        if mem >= 0:
            print(f"{name:<20} {mem:>10.2f} GB")
        else:
            print(f"{name:<20} {'FAILED':>10}")

In [None]:
print("\n" + "="*60)
print("All tests completed!")
print("="*60)
print(f"\nTo use these models in training, see: train_text_guided_unified.ipynb")