# ColonFormer Inference Notebook

Notebook này cho phép bạn sử dụng model ColonFormer đã được huấn luyện để:
- Test trên tập dữ liệu test có sẵn
- Upload và test ảnh riêng
- Xem kết quả phân đoạn và metrics
- So sánh với ground truth

## 1. Import Libraries và Setup

In [None]:
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from glob import glob
from tqdm import tqdm
import seaborn as sns
from PIL import Image
import warnings
warnings.filterwarnings('ignore')

# Import model
from mmseg.models.segmentors import ColonFormer as UNet

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# Plot styling: reset matplotlib to its defaults and use seaborn's "husl" palette.
plt.style.use('default')
sns.set_palette("husl")

## 2. Define Helper Functions

In [None]:
# Evaluation metrics.
# Small constant added to metric denominators so empty masks do not divide by zero.
epsilon = 1.0e-7

def dice_coeff(y_true, y_pred):
    """Return the Dice coefficient of two masks.

    Computed as 2 * sum(y_true * y_pred) / (sum(y_true) + sum(y_pred)), with the
    module-level ``epsilon`` added to the denominator so that two all-zero
    masks score 0 instead of dividing by zero.
    """
    overlap = np.sum(y_true * y_pred)
    denominator = np.sum(y_true) + np.sum(y_pred) + epsilon
    return 2. * overlap / denominator

def iou_score(y_true, y_pred):
    """Return the intersection-over-union (Jaccard index) of two masks.

    The union is sum(y_true) + sum(y_pred) - intersection; ``epsilon`` is added
    to it so that two empty masks score 0.
    """
    inter = np.sum(y_true * y_pred)
    return inter / (np.sum(y_true) + np.sum(y_pred) - inter + epsilon)

def precision_score(y_true, y_pred):
    """Return precision: the share of predicted-positive pixels that are truly positive.

    ``epsilon`` keeps the result at 0 when nothing is predicted.
    """
    predicted = np.sum(y_pred)
    hits = np.sum(y_true * y_pred)
    return hits / (predicted + epsilon)

def recall_score(y_true, y_pred):
    """Return recall: the share of ground-truth positive pixels that were predicted.

    ``epsilon`` keeps the result at 0 when the ground truth is empty.
    """
    actual = np.sum(y_true)
    hits = np.sum(y_true * y_pred)
    return hits / (actual + epsilon)

def preprocess_image(image_path, target_size=(352, 352)):
    """Load an image from disk and convert it into a model-ready tensor.

    Args:
        image_path: Path of the image file to read.
        target_size: Size passed to ``cv2.resize``, in OpenCV's
            ``(width, height)`` order.

    Returns:
        tuple ``(image_tensor, image, original_shape)``:
            image_tensor: float32 tensor of shape ``(1, 3, H, W)`` (H, W taken
                from ``target_size``) with values scaled to [0, 1].
            image: the RGB image at its original resolution.
            original_shape: ``(height, width)`` of ``image``.

    Raises:
        ValueError: If OpenCV cannot read/decode the file.
    """
    image = cv2.imread(image_path)
    # cv2.imread reports failure by returning None rather than raising; without
    # this check the failure surfaces as an obscure error inside cvtColor.
    if image is None:
        raise ValueError(f"Could not read image: {image_path}")
    # OpenCV loads channels as BGR; convert to RGB for the model and display.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    original_shape = image.shape[:2]

    image_resized = cv2.resize(image, target_size)

    # Scale pixel values to [0, 1]; no mean/std normalisation is applied here.
    image_normalized = image_resized.astype('float32') / 255.0

    # HWC -> CHW, then prepend a batch dimension of 1.
    image_tensor = torch.from_numpy(image_normalized.transpose((2, 0, 1))).unsqueeze(0)

    return image_tensor, image, original_shape

def postprocess_mask(mask_tensor, original_shape, normalize=True):
    """Turn a raw model output into a probability map at the original image size.

    Args:
        mask_tensor: Raw (pre-sigmoid) model output for one image. It is
            squeezed, so presumably shaped ``(1, 1, h, w)`` — TODO confirm.
        original_shape: ``(height, width)`` to resize the map to.
        normalize: If True (default, the original behaviour), min-max rescale
            the sigmoid map to [0, 1]. Note this stretches even a confidently
            empty prediction so its maximum becomes ~1, which can produce
            false positives when thresholding at 0.5. Pass False to keep the
            raw sigmoid probabilities.

    Returns:
        NumPy array of shape ``original_shape`` (for a single-image,
        single-channel output).
    """
    mask = mask_tensor.sigmoid().detach().cpu().numpy().squeeze()
    # cv2.resize takes the destination size as (width, height).
    mask = cv2.resize(mask, (original_shape[1], original_shape[0]))
    if normalize:
        # 1e-8 avoids division by zero when the map is constant.
        mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-8)
    return mask

## 3. Load Pre-trained Model

In [None]:
# Model configuration
BACKBONE = 'b3'  # Backbone suffix; the network uses 'mit_' + BACKBONE
# The checkpoint path is derived from BACKBONE so that changing the backbone
# also switches the snapshot directory (for 'b3' this is
# './snapshots/ColonFormerB3/last.pth'). Override MODEL_PATH directly if your
# trained snapshot lives elsewhere.
MODEL_PATH = f'./snapshots/ColonFormer{BACKBONE.upper()}/last.pth'

def load_model(model_path, backbone='b3'):
    """Build the ColonFormer network and load trained weights into it.

    Args:
        model_path: Path of a file saved with ``torch.save``. It is either a
            dict with a ``'state_dict'`` entry (and optionally ``'epoch'``) or
            a bare state dict.
        backbone: Backbone suffix; the backbone type becomes ``'mit_' + backbone``.

    Returns:
        The model on ``device`` in eval mode, or None if ``model_path`` does
        not exist (a message is printed in that case).
    """
    # Check the checkpoint first so that a bad path does not pay for building
    # the network and moving it to the device.
    if not os.path.exists(model_path):
        print(f"❌ Model file not found: {model_path}")
        return None

    # NOTE(review): in_channels is fixed regardless of `backbone`; confirm it
    # matches the stage widths of the chosen MiT variant before using others.
    model = UNet(
        backbone=dict(
            type='mit_{}'.format(backbone),
            style='pytorch'
        ),
        decode_head=dict(
            type='UPerHead',
            in_channels=[64, 128, 320, 512],
            in_index=[0, 1, 2, 3],
            channels=128,
            dropout_ratio=0.1,
            num_classes=1,
            norm_cfg=dict(type='BN', requires_grad=True),
            align_corners=False,
            decoder_params=dict(embed_dim=768),
            loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)
        ),
        neck=None,
        auxiliary_head=None,
        train_cfg=dict(),
        test_cfg=dict(mode='whole'),
        pretrained=None
    ).to(device)

    # SECURITY: torch.load unpickles the file (unless weights_only is in
    # effect), which can execute arbitrary code — only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=device)
    if 'state_dict' in checkpoint:
        # Training checkpoint: weights are nested under 'state_dict'.
        model.load_state_dict(checkpoint['state_dict'])
        print(f"✅ Loaded checkpoint from epoch {checkpoint.get('epoch', 'unknown')}")
    else:
        # Bare state dict.
        model.load_state_dict(checkpoint)
        print("✅ Loaded model weights")

    model.eval()
    return model

# Instantiate the network and load the configured checkpoint. load_model prints
# its own error and returns None when the file is missing.
model = load_model(MODEL_PATH, BACKBONE)

if model is not None:
    # Element count over every parameter tensor, trainable or not.
    total_params = sum([weights.numel() for weights in model.parameters()])
    print(f"📊 Model loaded successfully with {total_params:,} parameters")

## 4. Visualization Functions

In [None]:
def visualize_results(image, pred_mask, gt_mask=None, title="Prediction Results", threshold=0.5):
    """Display the input image, the prediction, an overlay and optionally the ground truth.

    With ``gt_mask`` a 2x2 grid is drawn (image, ground truth, prediction,
    overlay); without it a 1x3 row is drawn (image, prediction, overlay).

    Args:
        image: RGB image array; also used as the base of the overlay.
        pred_mask: 2-D prediction map with the same height/width as ``image``.
        gt_mask: Optional 2-D ground-truth mask.
        title: Figure super-title.
        threshold: Pixels of ``pred_mask`` strictly above this value are
            painted red in the overlay (default 0.5).
    """
    # Paint predicted pixels pure red on a copy so the caller's image is untouched.
    overlay = image.copy()
    overlay[pred_mask > threshold] = [255, 0, 0]

    # (data, title, colormap) for each panel, in row-major grid order.
    panels = [(image, 'Ảnh Gốc', None)]
    if gt_mask is not None:
        panels.append((gt_mask, 'Ground Truth', 'gray'))
    panels.append((pred_mask, 'Dự Đoán', 'gray'))
    panels.append((overlay, 'Overlay (Đỏ: Dự đoán)', None))

    if gt_mask is not None:
        _, axes = plt.subplots(2, 2, figsize=(12, 10))
    else:
        _, axes = plt.subplots(1, 3, figsize=(15, 5))

    for ax, (data, panel_title, cmap) in zip(axes.flat, panels):
        ax.imshow(data, cmap=cmap)
        ax.set_title(panel_title)
        ax.axis('off')

    plt.suptitle(title, fontsize=16)
    plt.tight_layout()
    plt.show()

def plot_metrics(dice, iou, precision, recall):
    """Draw a labelled bar chart of the four segmentation scores.

    Args:
        dice, iou, precision, recall: Scalar scores, expected to lie in [0, 1].
    """
    names = ['Dice', 'IoU', 'Precision', 'Recall']
    scores = [dice, iou, precision, recall]
    palette = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4']

    _, ax = plt.subplots(figsize=(10, 6))
    bars = ax.bar(names, scores, color=palette, alpha=0.8)
    # Labels are drawn 0.01 above each bar. Leave headroom above 1.0 so labels
    # for high scores do not collide with the top frame and the title.
    ax.set_ylim(0, 1.1)
    ax.set_title('Kết Quả Đánh Giá', fontsize=16, fontweight='bold')
    ax.set_ylabel('Score')

    # Write each score just above its bar.
    for bar, score in zip(bars, scores):
        ax.text(bar.get_x() + bar.get_width() / 2., bar.get_height() + 0.01,
                f'{score:.3f}', ha='center', va='bottom', fontweight='bold')

    ax.grid(axis='y', alpha=0.3)
    plt.tight_layout()
    plt.show()

## 5. Test Single Image Function

In [None]:
def test_single_image(image_path, mask_path=None, show_results=True):
    """Run the globally loaded ``model`` on one image and optionally score it.

    Args:
        image_path: Path of the image to segment.
        mask_path: Optional path of the ground-truth mask, read as grayscale.
            If it is given but missing or unreadable, a warning is printed and
            metrics are skipped.
        show_results: If True, plot the prediction (plus a metrics bar chart
            when a ground truth was available).

    Returns:
        dict with keys:
            'prediction': probability map at the original resolution.
            'binary_prediction': that map thresholded at 0.5, as float32.
            'ground_truth': mask scaled to [0, 1], or None.
            'metrics': dice/iou/precision/recall dict, or None.
            'original_image': the RGB input image.
        Returns None if the image does not exist or no model is loaded.
    """
    if not os.path.exists(image_path):
        print(f"❌ Image not found: {image_path}")
        return None
    # load_model() returns None when the checkpoint is missing. Report it
    # clearly instead of failing with "'NoneType' object is not callable".
    if model is None:
        print("❌ Model is not loaded. Check MODEL_PATH and re-run the model-loading cell.")
        return None

    # Preprocess
    image_tensor, original_image, original_shape = preprocess_image(image_path)
    image_tensor = image_tensor.to(device)

    # Inference
    with torch.no_grad():
        outputs = model(image_tensor)
    # NOTE(review): when the model returns several maps, the first one is
    # taken as the final prediction. Confirm against the model's forward().
    pred_mask = outputs[0] if isinstance(outputs, (list, tuple)) else outputs

    # Postprocess
    pred_mask_np = postprocess_mask(pred_mask, original_shape)
    pred_binary = (pred_mask_np > 0.5).astype(np.float32)

    # Ground truth (optional)
    gt_mask_np = None
    metrics = None

    if mask_path:
        # cv2.imread returns None for unreadable files, so treat that like a missing file.
        gt_mask = cv2.imread(mask_path, 0) if os.path.exists(mask_path) else None
        if gt_mask is None:
            print(f"⚠️ Ground truth not found or unreadable: {mask_path}")
        else:
            # cv2.resize takes (width, height).
            gt_mask_np = cv2.resize(gt_mask, (original_shape[1], original_shape[0]))
            gt_mask_np = gt_mask_np.astype(np.float32) / 255.0
            gt_binary = (gt_mask_np > 0.5).astype(np.float32)

            dice = dice_coeff(gt_binary, pred_binary)
            iou = iou_score(gt_binary, pred_binary)
            precision = precision_score(gt_binary, pred_binary)
            recall = recall_score(gt_binary, pred_binary)

            metrics = {
                'dice': dice,
                'iou': iou,
                'precision': precision,
                'recall': recall
            }

            print("📊 Metrics:")
            print(f"   Dice: {dice:.4f}")
            print(f"   IoU: {iou:.4f}")
            print(f"   Precision: {precision:.4f}")
            print(f"   Recall: {recall:.4f}")

    # Visualization
    if show_results:
        filename = os.path.basename(image_path)
        visualize_results(original_image, pred_mask_np, gt_mask_np, f"Results: {filename}")

        if metrics is not None:
            plot_metrics(metrics['dice'], metrics['iou'],
                         metrics['precision'], metrics['recall'])

    return {
        'prediction': pred_mask_np,
        'binary_prediction': pred_binary,
        'ground_truth': gt_mask_np,
        'metrics': metrics,
        'original_image': original_image
    }