In [1]:
# ────────────── Setup ────────────── #
import os
import random
import torch
from transformers import AutoModel, AutoTokenizer
from data_loader import OASISDataset
from torch.utils.data import DataLoader
from segmentation_models_pytorch import Unet
from IPython.display import Image
from test_with_bert_updated import VisualGuidedCrossAttention, reshape_tensor, dice_coeff_and_loss, save_prediction_vs_mask_gif, save_static_comparison_plot


  from .autonotebook import tqdm as notebook_tqdm


In [7]:
from test_with_bert_updated import VisualGuidedCrossAttention, reshape_tensor, dice_coeff_and_loss, save_prediction_vs_mask_gif, save_static_comparison_plot

In [3]:
# Number of brain parts for this run. It is interpolated into the checkpoint
# directory name (`checkpoints-{N}`) and the test split name (`test-{N}`), and
# is passed to the dataset as `num_brain_parts`.
BRAIN_PARTS = 4

In [4]:
# ────────────── Parameters ────────────── #
checkpoint_dir = f"checkpoints-{BRAIN_PARTS}"  # one checkpoint folder per BRAIN_PARTS setting
data_dir = "oasis-redefined"
model_name = "dmis-lab/biobert-base-cased-v1.1"
seq_len = 10       # fixed token length the prompt is padded/truncated to
proj_dim = 128     # shared projection width inside the cross-attention module
encoded_dim = 512  # UNet ResNet34 encoder last layer dim
text_dim = 768     # width of the text embeddings fed to the key/value projections

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ────────────── Load Model and Tokenizer ────────────── #
tokenizer = AutoTokenizer.from_pretrained(model_name)
bert_model = AutoModel.from_pretrained(model_name).to(device).eval()

# encoder_weights=None: every parameter is overwritten by the checkpoint's
# strict load_state_dict below, so fetching ImageNet weights here would only
# add a needless download (and a network dependency) to the demo.
model = Unet("resnet34", encoder_weights=None, in_channels=1, classes=1).to(device)
cross_attn = VisualGuidedCrossAttention(encoded_dim, text_dim, proj_dim).to(device)

# Load best model checkpoint.
# weights_only=True restricts unpickling to tensors and plain containers, which
# avoids executing arbitrary pickled code from the checkpoint file and silences
# the torch.load FutureWarning about the changing default.
# NOTE(review): assumes the checkpoint holds only state dicts / primitive values;
# if it stores custom Python objects this call will raise — confirm.
checkpoint = torch.load(
    os.path.join(checkpoint_dir, "best_model.pth"),
    map_location=device,
    weights_only=True,
)
model.load_state_dict(checkpoint["model"])
cross_attn.load_state_dict(checkpoint["cross_attn"])

# Switch both modules to inference mode.
model.eval()
cross_attn.eval()

  checkpoint = torch.load(os.path.join(checkpoint_dir, "best_model.pth"), map_location=device)


VisualGuidedCrossAttention(
  (query_proj): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))
  (key_proj): Linear(in_features=768, out_features=128, bias=True)
  (value_proj): Linear(in_features=768, out_features=128, bias=True)
  (out_proj): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1))
)

In [5]:
# ────────────── Test Split Loader ────────────── #
# Batches of one sample, served in fixed dataset order (shuffling is disabled,
# so the sample drawn from this loader is deterministic, not random).
split_name = f"test-{BRAIN_PARTS}"
test_ds = OASISDataset(data_dir, split_name, num_brain_parts=BRAIN_PARTS)
test_loader = DataLoader(dataset=test_ds, batch_size=1, shuffle=False)

Found 323 files in 'test-4' split


In [8]:
# Take the first batch the loader yields and unpack it into images, masks and
# the reference text that ships with the sample.
first_batch = next(iter(test_loader))
imgs, masks, ref_txts = first_batch

# Move the tensors onto the inference device.
imgs = imgs.to(device)
masks = masks.to(device)

# Reshape via the project helper into the layout used by the model and the
# visualisation functions.
imgs_r = reshape_tensor(imgs)
masks_r = reshape_tensor(masks)

In [9]:
# ────────────── Custom Text Prompt ────────────── #
# Structure names listed for reference; the prompt actually used is set below.
custom_prompts = [
    "Left-Cerebral-White-Matter",
    "Right-Cerebral-White-Matter",
    "Left-Thalamus",
    "Right-Thalamus",
]
custom_prompt = "segment Left-Cerebral-White-Matter"

# Tokenise to a fixed length of `seq_len` tokens (padded or truncated).
tok = tokenizer(
    custom_prompt,
    max_length=seq_len,
    padding="max_length",
    truncation=True,
    return_tensors="pt",
)
input_ids = tok["input_ids"].to(device)
attention_mask = tok["attention_mask"].to(device)

# ────────────── Text Encoding + Inference ────────────── #
with torch.no_grad():
    # Per-token embeddings of the prompt from the text encoder.
    bert_out = bert_model(input_ids=input_ids, attention_mask=attention_mask)
    text_emb = bert_out.last_hidden_state.detach()

    # Encode the images, then replace the deepest feature map with its
    # text-conditioned version before decoding.
    feats = model.encoder(imgs_r)
    feat_batch = feats[-1].shape[0]
    # Broadcast the single prompt embedding across the feature batch dimension.
    feats[-1] = cross_attn(feats[-1], text_emb.expand(feat_batch, -1, -1))
    decoded = model.decoder(feats)
    logits = model.segmentation_head(decoded)

# ────────────── Save Visualizations ────────────── #
os.makedirs("demo_outputs", exist_ok=True)

# Animated prediction-vs-mask comparison.
save_prediction_vs_mask_gif(
    imgs_r.cpu(),
    logits.cpu(),
    masks_r.cpu(),
    save_path="demo_outputs/result.gif",
    caption=custom_prompt,
)

# Static comparison figure, with frame_stride=4.
save_static_comparison_plot(
    imgs_r.cpu(),
    logits.cpu(),
    masks_r.cpu(),
    save_path="demo_outputs/comparison.png",
    caption=custom_prompt,
    frame_stride=4,
)

GIF saved to demo_outputs/result.gif
Static plot saved to demo_outputs/comparison.png
