In [1]:
from sequence_detection.models_mae_finetune import MaskedAutoencoderViTClassify
from functools import partial
import torch
import torch.nn as nn
# Labels for the MRI sequence classes. The classification head is built with one
# extra output (len(detect_sequence) + 1) — presumably a catch-all class; TODO
# confirm against the training setup.
detect_sequence = ['T1_T1flair', 'T2', 'T2flair_flair', 'PD', 'T2star_hemo', 'T2star_swi', 'DTI_DWI']
model = MaskedAutoencoderViTClassify(
        patch_size=16, embed_dim=768, depth=12, num_heads=12,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), num_classes=len(detect_sequence) + 1, mode='cls')

pretrain_path = './finetune_models/sequence_detection/mae_encoder_cls_token_e1000.pt'

# map_location='cpu': the freshly built model still lives on the CPU, and this lets a
# checkpoint saved from a GPU run load on machines without CUDA.
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
# strict=False ignores unmatched keys; they are printed below so a partial load is visible.
missing, unexpected = model.load_state_dict(
    torch.load(pretrain_path, map_location='cpu')['model_state_dict'], strict=False)
print("Missing keys:", missing)
print("Unexpected keys:", unexpected)
print("Total parameters:", sum(p.numel() for p in model.parameters()))
print("Parameters with requires_grad:", sum(p.numel() for p in model.parameters() if p.requires_grad))
def get_model_size_in_mb(model):
    """
    Return the storage footprint of ``model`` in MB (here 1 MB = 1024 ** 2 bytes).

    Sums ``nelement() * element_size()`` over every parameter and every registered
    buffer, regardless of which device the tensors live on. Gradients, activations
    and optimizer state are not included.
    """
    all_tensors = list(model.parameters()) + list(model.buffers())
    total_bytes = sum(t.nelement() * t.element_size() for t in all_tensors)
    return total_bytes / (1024 ** 2)

def get_trainable_model_size_in_mb(model):
    """
    Return the storage of ``model``'s trainable parameters in MB (1 MB = 1024 ** 2 bytes).

    Only parameters with ``requires_grad=True`` are counted; frozen parameters and
    buffers are skipped. The device the tensors live on does not matter.
    """
    trainable_bytes = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        trainable_bytes += param.nelement() * param.element_size()
    return trainable_bytes / (1024 ** 2)



# Use the GPU when one is available and fall back to the CPU otherwise, so this
# cell does not fail on machines without CUDA.
model = model.to('cuda' if torch.cuda.is_available() else 'cpu')
def freeze_mae_encoder_and_decoder(model):
    """
    Freeze the MAE encoder/decoder parameters of ``model`` in place.

    A parameter gets ``requires_grad = False`` when its name starts with
    ``patch_embed``, ``blocks``, ``norm`` or ``decoder``, or is exactly
    ``cls_token``, ``pos_embed`` or ``mask_token``. Every other parameter keeps its
    current ``requires_grad`` flag. Returns None.
    """
    # Prefix match: e.g. "norm" also catches any other name that begins with "norm".
    frozen_prefixes = ("patch_embed", "blocks", "norm", "decoder")
    frozen_names = {"cls_token", "pos_embed", "mask_token"}
    for param_name, param in model.named_parameters():
        if param_name.startswith(frozen_prefixes) or param_name in frozen_names:
            param.requires_grad = False
freeze_mae_encoder_and_decoder(model)
# get_trainable_model_size_in_mb only counts parameters that still require grad,
# so after freezing it reports the trainable part, not the whole model. Label
# it as such and report the full parameter + buffer footprint alongside it.
model_size_mb = get_trainable_model_size_in_mb(model)
total_model_size_mb = get_model_size_in_mb(model)
print(f"Model size (parameters + buffers): {total_model_size_mb:.2f} MB")
print(f"Trainable parameter size: {model_size_mb:.2f} MB")

Missing keys: []
Unexpected keys: []
111913992
111661832
Model size on GPU: 0.02 MB


In [2]:
def predict_sequence(model, image_path, detect_sequence, device='cuda', cls_strategy='cls_token',
                     unknown_label='unknown'):
    """
    Predict the MRI sequence label of a single slice image.

    The image is loaded as grayscale, min-max normalized to [0, 1], replicated to
    3 channels, cast back to uint8, resized to 224x224 and converted to a
    (1, 3, 224, 224) float tensor before being passed to ``model.forward_cls``.

    Args:
        model: Trained model with a ``.forward_cls(x, cls_strategy=...)`` method.
            It is switched to eval mode and moved to ``device`` in place.
        image_path: Path to the input image (e.g. a PNG slice).
        detect_sequence: Sequence labels indexed by the model's class indices.
        device: Device to run inference on, e.g. 'cuda' or 'cpu'.
        cls_strategy: Classification strategy forwarded to ``forward_cls``.
        unknown_label: Label returned when the predicted index has no entry in
            ``detect_sequence`` (e.g. the extra class of a head built with
            ``len(detect_sequence) + 1`` outputs — presumably an "other" class;
            confirm against training).

    Returns:
        (predicted_index, predicted_label)
    """
    import torchvision.transforms as T
    from PIL import Image
    import numpy as np
    import torch

    # === Load grayscale image; the context manager releases the file handle.
    with Image.open(image_path) as img:
        img_np = np.array(img.convert('L')).astype(np.float32)

    # === Min-max normalize to [0, 1]
    lo, hi = img_np.min(), img_np.max()
    if hi > lo:
        img_np = (img_np - lo) / (hi - lo)
    else:
        img_np[:] = 0.0  # constant image: avoid division by zero

    # === Replicate to 3 channels: (H, W) -> (H, W, 3)
    img_3ch = np.stack([img_np] * 3, axis=-1)

    # === Back to an 8-bit PIL image, then resize and convert to a float tensor.
    # The truncating uint8 cast is intentionally left unchanged — presumably it
    # mirrors the training preprocessing; keep the two in sync.
    img_pil = Image.fromarray((img_3ch * 255).astype(np.uint8))
    transform = T.Compose([
        T.Resize((224, 224)),
        T.ToTensor(),  # (3, 224, 224), float32 in [0, 1]
    ])
    img_tensor = transform(img_pil).unsqueeze(0).to(device)  # add batch dimension

    # === Run inference
    model.eval()
    model.to(device)
    with torch.no_grad():
        logits = model.forward_cls(img_tensor, cls_strategy=cls_strategy)
        pred_index = torch.argmax(logits, dim=1).item()

    # The head may have more outputs than there are labels; map such indices to
    # unknown_label instead of raising IndexError.
    if pred_index < len(detect_sequence):
        pred_label = detect_sequence[pred_index]
    else:
        pred_label = unknown_label
    return pred_index, pred_label


In [3]:
image_path = './demo_data/sequence_detection/T2star_swi/2.png'

# Run on whatever device the model's weights already live on instead of
# hard-coding 'cuda', so this cell also works on CPU-only machines.
inference_device = next(model.parameters()).device

# Predict
pred_idx, pred_label = predict_sequence(model, image_path, detect_sequence=detect_sequence, device=inference_device)
print(f"Predicted index: {pred_idx}")
print(f"Predicted label: {pred_label}")


Predicted index: 5
Predicted label: T2star_swi
