# SegResNet Inference and Visualization

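This notebook loads a pre-trained SegResNet model, runs inference on multi-modal brain tumor MRI volumes (T1, T1ce, T2 and FLAIR), visualizes the predicted segmentation, and writes a run-length-encoded (RLE) submission file.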

## Contents
1. Imports and Setup
2. SegResNet Model Architecture
3. Load Pre-trained Model
4. Inference Functions
5. Visualization
6. Generate Submission


---
## 1. Imports and Setup


In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import nibabel as nib
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"Device: {DEVICE}")

---
## 2. SegResNet Model Architecture

The architecture below must match the one used during training exactly (same layer names and channel sizes); otherwise the saved weights cannot be loaded into it.


In [None]:
class ResidualBlock(nn.Module):
    """Two conv-BN stages whose result is added to a (possibly projected) copy of the input.

    Main path: Conv3d(3x3x3, stride) -> BN -> ReLU -> Conv3d(3x3x3) -> BN.
    Skip path: identity when the shape is unchanged; otherwise a 1x1x1 strided
    Conv3d + BN projection so it can be added to the main path.
    The sum is passed through a final ReLU.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # Attribute names are state_dict keys; keep them stable so checkpoints load.
        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride)
        self.bn1 = nn.BatchNorm3d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm3d(out_channels)

        needs_projection = stride != 1 or in_channels != out_channels
        projection_layers = (
            [
                nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm3d(out_channels),
            ]
            if needs_projection
            else []
        )
        # An empty Sequential acts as the identity.
        self.shortcut = nn.Sequential(*projection_layers)

    def forward(self, x):
        skip = self.shortcut(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return self.relu(out + skip)


class SegResNet(nn.Module):
    """U-shaped residual encoder/decoder for volumetric segmentation.

    Encoder: a stem conv followed by four residual stages; stages 2-4 and the
    bottleneck use stride 2 (each roughly halves the spatial size). Decoder: four
    transposed-conv upsamplings, each concatenated with the matching encoder output
    and fused by a residual block. A final 1x1x1 conv maps to ``out_channels`` logits.
    Submodule names are state_dict keys and must stay unchanged.
    """

    def __init__(self, in_channels=4, out_channels=4, init_filters=32):
        super().__init__()
        f = init_filters

        # Encoder
        self.enc1 = nn.Sequential(
            nn.Conv3d(in_channels, f, kernel_size=3, padding=1),
            nn.BatchNorm3d(f),
            nn.ReLU(inplace=True),
        )
        self.res_block1 = ResidualBlock(f, f)
        self.res_block2 = ResidualBlock(f, 2 * f, stride=2)
        self.res_block3 = ResidualBlock(2 * f, 4 * f, stride=2)
        self.res_block4 = ResidualBlock(4 * f, 8 * f, stride=2)

        # Bottleneck
        self.bottleneck = ResidualBlock(8 * f, 16 * f, stride=2)

        # Decoder: dec_k takes the upsampled tensor concatenated with the skip,
        # hence twice the channels of up_k's output.
        self.up4 = nn.ConvTranspose3d(16 * f, 8 * f, kernel_size=2, stride=2)
        self.dec4 = ResidualBlock(16 * f, 8 * f)

        self.up3 = nn.ConvTranspose3d(8 * f, 4 * f, kernel_size=2, stride=2)
        self.dec3 = ResidualBlock(8 * f, 4 * f)

        self.up2 = nn.ConvTranspose3d(4 * f, 2 * f, kernel_size=2, stride=2)
        self.dec2 = ResidualBlock(4 * f, 2 * f)

        self.up1 = nn.ConvTranspose3d(2 * f, f, kernel_size=2, stride=2)
        self.dec1 = ResidualBlock(2 * f, f)

        self.final = nn.Conv3d(f, out_channels, kernel_size=1)

    @staticmethod
    def _merge(upsampled, skip):
        """Resize ``upsampled`` to ``skip``'s spatial size if needed, then concat on channels."""
        # Odd input sizes make the strided encoder round up, so 2x upsampling can
        # overshoot by one voxel; nearest-neighbour resize realigns the tensors.
        if upsampled.shape != skip.shape:
            upsampled = F.interpolate(upsampled, size=skip.shape[2:])
        return torch.cat((upsampled, skip), dim=1)

    def forward(self, x):
        skip1 = self.res_block1(self.enc1(x))
        skip2 = self.res_block2(skip1)
        skip3 = self.res_block3(skip2)
        skip4 = self.res_block4(skip3)

        out = self.bottleneck(skip4)

        decoder_stages = (
            (self.up4, self.dec4, skip4),
            (self.up3, self.dec3, skip3),
            (self.up2, self.dec2, skip2),
            (self.up1, self.dec1, skip1),
        )
        for upsample, decode, skip in decoder_stages:
            out = decode(self._merge(upsample(out), skip))

        return self.final(out)

---
## 3. Load Pre-trained Model


In [None]:
CHECKPOINT_PATH = "/kaggle/input/segresnet-checkpoint/segresnet_best.pth"

# Initialize model
model = SegResNet(in_channels=4, out_channels=4, init_filters=32).to(DEVICE)


def _extract_state_dict(checkpoint):
    """Return a plain parameter mapping from a bare or wrapped checkpoint.

    Accepts either a bare state_dict or a dict that nests it under one of the keys
    ``"model_state_dict"``, ``"state_dict"`` or ``"model"``. If every key carries a
    ``"module."`` prefix (as written by ``nn.DataParallel``), the prefix is stripped.
    A bare SegResNet state_dict passes through unchanged.
    """
    if isinstance(checkpoint, dict):
        for wrapper_key in ("model_state_dict", "state_dict", "model"):
            nested = checkpoint.get(wrapper_key)
            if isinstance(nested, dict):
                checkpoint = nested
                break
    prefix = "module."
    if checkpoint and all(k.startswith(prefix) for k in checkpoint):
        checkpoint = {k[len(prefix):]: v for k, v in checkpoint.items()}
    return checkpoint


# Load weights.
# NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources.
try:
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
    model.load_state_dict(_extract_state_dict(checkpoint))
except Exception as e:
    # Continuing with randomly initialised weights would silently produce a
    # meaningless submission, so report the error and stop here.
    print(f"Error loading weights: {e}")
    raise
else:
    print("Model weights loaded successfully!")

model.eval()

---
## 4. Inference Functions


In [None]:
def normalize_volume(volume, low_pct=0.5, high_pct=99.5):
    """Robust z-score normalization of a single intensity volume.

    Voxels with intensity > 0 form the foreground. The whole volume is clipped to the
    [``low_pct``, ``high_pct``] percentiles of the foreground. It is then standardized
    with the foreground mean and standard deviation, and background voxels are reset to 0.

    Args:
        volume: NumPy array of intensities (not modified).
        low_pct: lower clipping percentile, in [0, 100].
        high_pct: upper clipping percentile, in [0, 100].

    Returns:
        A new normalized array. Floating-point inputs keep their dtype, so float32
        volumes stay float32 (``np.percentile`` yields float64 scalars, which would
        otherwise promote the result to float64 under NumPy 2). Non-float inputs
        return float64. If there are no foreground voxels, ``volume`` itself is returned.
    """
    mask = volume > 0
    if not mask.any():
        return volume

    p_low, p_high = np.percentile(volume[mask], [low_pct, high_pct])
    clipped = np.clip(volume, p_low, p_high)

    foreground = clipped[mask]
    mean, std = foreground.mean(), foreground.std()
    result = (clipped - mean) / (std + 1e-8)  # epsilon guards constant foregrounds
    result[~mask] = 0

    out_dtype = volume.dtype if np.issubdtype(volume.dtype, np.floating) else np.float64
    return result.astype(out_dtype, copy=False)


def _window_starts(size, patch, step):
    """Start offsets along one axis such that windows of length ``patch`` cover [0, size)."""
    if size <= patch:
        # A single (truncated) window covers the whole axis.
        return [0]
    starts = list(range(0, size - patch + 1, step))
    if starts[-1] != size - patch:
        # Add a window flush with the end so the tail is not left unpredicted.
        starts.append(size - patch)
    return starts


def sliding_window_inference(model, image, patch_size=(96, 96, 96), overlap=0.5):
    """Sliding-window inference over a 5-D volume tensor.

    Args:
        model: network mapping a (B, C, d, h, w) patch to (B, K, d, h, w) logits.
        image: tensor of shape (B, C, D, H, W).
        patch_size: window size per spatial axis. Axes shorter than the window use a
            single truncated window, so the model must accept smaller inputs there.
        overlap: fraction of the window shared by neighbouring windows.

    Returns:
        float32 tensor (B, K, D, H, W) of softmax probabilities, averaged over all
        windows covering each voxel. K is taken from the model output, and every
        voxel is covered by at least one window.
    """
    model.eval()

    batch_size = image.shape[0]
    image_size = tuple(image.shape[2:])

    # Guard against a zero step (overlap close to 1), which range() rejects.
    steps = [max(1, int(p * (1 - overlap))) for p in patch_size]
    axis_starts = [_window_starts(s, p, st) for s, p, st in zip(image_size, patch_size, steps)]

    # Autocast follows the tensor's device; it is disabled off-GPU.
    use_amp = image.device.type == "cuda"
    amp_device = "cuda" if use_amp else "cpu"

    output_probs = None  # allocated once the number of output classes is known
    count_map = torch.zeros((batch_size, 1, *image_size), device=image.device, dtype=torch.float32)

    with torch.no_grad():
        for z in axis_starts[0]:
            for y in axis_starts[1]:
                for x in axis_starts[2]:
                    region = (
                        slice(None),
                        slice(None),
                        slice(z, z + patch_size[0]),
                        slice(y, y + patch_size[1]),
                        slice(x, x + patch_size[2]),
                    )
                    with torch.autocast(device_type=amp_device, enabled=use_amp):
                        pred = F.softmax(model(image[region]), dim=1)

                    if output_probs is None:
                        output_probs = torch.zeros(
                            (batch_size, pred.shape[1], *image_size),
                            device=image.device,
                            dtype=torch.float32,
                        )
                    output_probs[region] += pred
                    count_map[region] += 1  # single channel, broadcasts over classes

    output_probs /= count_map.clamp(min=1)
    return output_probs


def simple_resize_inference(model, image, target_size=(96, 96, 96)):
    """Fast inference by resizing the whole volume to ``target_size``.

    The input (B, C, D, H, W) is resized trilinearly to ``target_size`` and run
    through the model in one pass. The logits are resized back to (D, H, W) and
    converted to softmax probabilities.

    Returns:
        float32 tensor (B, K, D, H, W) of class probabilities.
    """
    model.eval()
    original_shape = tuple(image.shape[2:])

    # Autocast follows the tensor's device; it is disabled off-GPU.
    use_amp = image.device.type == "cuda"
    amp_device = "cuda" if use_amp else "cpu"

    with torch.no_grad():
        small_image = F.interpolate(image, size=target_size, mode='trilinear', align_corners=False)

        with torch.autocast(device_type=amp_device, enabled=use_amp):
            logits = model(small_image)

        # Autocast may return half-precision logits; upsample and softmax in float32.
        logits = F.interpolate(logits.float(), size=original_shape, mode='trilinear', align_corners=False)
        return F.softmax(logits, dim=1)

---
## 5. Visualization


In [None]:
def visualize_test_prediction(model, test_path, patient_id=None):
    """Plot the four input modalities and the predicted label map for one test patient.

    Args:
        model: segmentation network, run through ``simple_resize_inference``.
        test_path: directory with one sub-directory per patient.
        patient_id: sub-directory to show; defaults to the first one in sorted order.

    The slice along the last axis with the most predicted non-zero voxels is shown.
    If the prediction is empty, the middle slice is shown instead. Prints a message
    and returns early if there are no patients or a modality file is missing.
    """
    model.eval()

    if patient_id is None:
        patients = sorted(
            p for p in os.listdir(test_path) if os.path.isdir(os.path.join(test_path, p))
        )
        if not patients:
            print("No patients found!")
            return
        patient_id = patients[0]

    patient_path = os.path.join(test_path, patient_id)
    print(f"Processing patient: {patient_id}")

    # Load modalities. 't1ce' is tried before 't1' because 't1' is a substring of it.
    modalities = {}
    for f in os.listdir(patient_path):
        if f.endswith('.nii.gz'):
            for mod in ['t1ce', 't1', 't2', 'flair']:
                if mod in f.lower():
                    if mod == 't1' and 't1ce' in f.lower():
                        continue
                    modalities[mod] = nib.load(os.path.join(patient_path, f)).get_fdata()
                    break

    # Same channel order as generate_submission uses for the model input.
    channel_order = ['t1', 't1ce', 't2', 'flair']
    missing = [m for m in channel_order if m not in modalities]
    if missing:
        print(f"Missing modalities for {patient_id}: {missing}")
        return

    # Normalize and stack
    image = np.stack(
        [normalize_volume(modalities[m].astype(np.float32)) for m in channel_order],
        axis=0,
    )

    # Inference; [0] drops only the batch axis (squeeze() could drop a size-1 spatial axis too).
    image_tensor = torch.from_numpy(image).unsqueeze(0).to(DEVICE)
    probs = simple_resize_inference(model, image_tensor)
    prediction = torch.argmax(probs, dim=1)[0].cpu().numpy()

    # Pick the slice with the most predicted foreground; fall back to the middle one.
    tumor_per_slice = (prediction > 0).sum(axis=(0, 1))
    if tumor_per_slice.any():
        best_slice = int(np.argmax(tumor_per_slice))
    else:
        best_slice = prediction.shape[2] // 2

    # Plot
    fig, axes = plt.subplots(1, 5, figsize=(20, 4))

    titles = ['T1', 'T1ce', 'T2', 'FLAIR']
    for ax, mod, title in zip(axes[:4], channel_order, titles):
        ax.imshow(np.rot90(modalities[mod][:, :, best_slice]), cmap='gray')
        ax.set_title(title)
        ax.axis('off')

    cmap = mcolors.ListedColormap(['black', 'red', 'green', 'yellow'])
    axes[4].imshow(np.rot90(prediction[:, :, best_slice]), cmap=cmap, vmin=0, vmax=3)
    axes[4].set_title('Prediction')
    axes[4].axis('off')

    plt.suptitle(f"Patient: {patient_id} | Slice: {best_slice}", fontsize=14)
    plt.tight_layout()
    plt.show()


# Root of the test set; each patient is expected in its own sub-directory
# containing the modality *.nii.gz files.
TEST_PATH = "/kaggle/input/instant-odc-ai-hackathon/test"
# Uncomment to preview the first patient's prediction:
# visualize_test_prediction(model, TEST_PATH)

---
## 6. Generate Submission


In [None]:
def rle_encode(mask):
    """Encode a mask as a run-length string of ``"start length"`` pairs.

    Any non-zero (or True) element counts as foreground. A mask containing labels
    other than 0/1 would otherwise produce a malformed, odd-length run list.
    The mask is flattened in C (row-major) order, and run starts are 1-indexed.

    Returns:
        Space-separated ``start length start length ...``; ``""`` for an empty mask.
    """
    foreground = np.asarray(mask).ravel() > 0
    # Pad with background on both ends so every run has a start and an end transition.
    padded = np.concatenate(([False], foreground, [False]))
    runs = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    runs[1::2] -= runs[::2]  # convert end positions into run lengths
    return ' '.join(str(x) for x in runs)


def generate_submission(model, test_path, output_path="submission.csv"):
    """Predict every patient under ``test_path`` and write an RLE submission CSV.

    For each patient sub-directory, the t1/t1ce/t2/flair ``*.nii.gz`` volumes are
    normalized and segmented with ``simple_resize_inference``. The resulting label
    map is turned into three binary masks, each written as one row:

    * ``<patient>_WT``: any label > 0
    * ``<patient>_TC``: labels 1 or 3
    * ``<patient>_ET``: label 3

    These are presumably BraTS whole tumor / tumor core / enhancing tumor, with
    label 4 remapped to 3 during training; confirm against the training pipeline.
    Patients missing a modality are skipped with a warning.

    Returns:
        DataFrame with columns ``id`` and ``rle``, also written to ``output_path``.
    """
    model.eval()

    channel_order = ['t1', 't1ce', 't2', 'flair']
    submission_rows = []
    patients = sorted(
        p for p in os.listdir(test_path) if os.path.isdir(os.path.join(test_path, p))
    )

    for patient_id in tqdm(patients, desc="Processing"):
        patient_path = os.path.join(test_path, patient_id)

        # 't1ce' is tried before 't1' because 't1' is a substring of it.
        modalities = {}
        for f in os.listdir(patient_path):
            if f.endswith('.nii.gz'):
                for mod in ['t1ce', 't1', 't2', 'flair']:
                    if mod in f.lower():
                        if mod == 't1' and 't1ce' in f.lower():
                            continue
                        modalities[mod] = nib.load(os.path.join(patient_path, f)).get_fdata()
                        break

        missing = [m for m in channel_order if m not in modalities]
        if missing:
            # tqdm.write keeps the progress bar intact.
            tqdm.write(f"Skipping {patient_id}: missing modalities {missing}")
            continue

        image = np.stack(
            [normalize_volume(modalities[m].astype(np.float32)) for m in channel_order],
            axis=0,
        )

        image_tensor = torch.from_numpy(image).unsqueeze(0).to(DEVICE)
        probs = simple_resize_inference(model, image_tensor)
        # [0] drops only the batch axis; squeeze() could also drop a size-1 spatial axis.
        prediction = torch.argmax(probs, dim=1)[0].cpu().numpy()

        region_masks = {
            'WT': prediction > 0,
            'TC': (prediction == 1) | (prediction == 3),
            'ET': prediction == 3,
        }
        for suffix, region in region_masks.items():
            submission_rows.append(
                {'id': f'{patient_id}_{suffix}', 'rle': rle_encode(region.astype(np.uint8))}
            )

    # Explicit columns keep the CSV header even when no patient produced rows.
    df = pd.DataFrame(submission_rows, columns=['id', 'rle'])
    df.to_csv(output_path, index=False)
    print(f"Submission saved to {output_path}")
    return df


# Generate submission
# df = generate_submission(model, TEST_PATH)