# Inference with a custom-trained SegFormer

In [None]:
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.patches import Patch
from PIL import Image
from safetensors import safe_open
from torchvision import transforms

sys.path.append("../")
from src.datamodule import SegmentationDataset
from src.pl_module import SegformerConfig, SegformerForSemanticSegmentation
from utils import LABEL_COLORS, label2rgb

%load_ext autoreload
%autoreload 2

In [None]:
# Pick the compute device: use the GPU when CUDA is available, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

In [None]:
# Read the sample image from disk and turn it into a batched model input.
img_dir = "../data/025_08.jpg"  # path to a single image file, despite the name

pil_img = Image.open(img_dir)
print(pil_img)

# Reuse the image transforms defined on the dataset class.
img_transforms = SegmentationDataset.segformer_img_tfms

# Apply the transforms, then prepend a batch dimension.
img_in = img_transforms(pil_img).unsqueeze(0)
print(img_in.shape)

In [None]:
# Rebuild the model from its saved config, then restore the trained weights.
log_dir = Path("../logs_pl/mit-b0/version_0")
ckpt_file = "best_model.safetensors"

# The architecture comes from the JSON config; the weights are loaded below.
config = SegformerConfig.from_json_file(log_dir / "model_config.json")
model = SegformerForSemanticSegmentation(config)

# Weights are read with safetensors instead of torch.load, because torch.load
# has security issues. An earlier torch.load-based variant read
# ckpt["state_dict"] and stripped the "model." prefix from its keys.
with safe_open(log_dir / ckpt_file, framework="pt", device="cpu") as f:
    state_dict = {key: f.get_tensor(key) for key in f.keys()}

# Debugging aid: print(model.state_dict().keys()) to compare key names.
model.load_state_dict(state_dict)

In [None]:
# Run inference
model.eval()  # switch the module to evaluation mode before predicting
img_in = img_in.to(device)
model.to(device)

# Gradients are not needed for prediction, so autograd tracking is disabled to
# avoid building a graph and holding extra memory. The module is called
# directly (rather than via .forward) so that any registered hooks also run.
with torch.no_grad():
    outs = model(pixel_values=img_in, labels=None)
logits = outs.logits  # shape (batch_size, num_labels, ~height/4, ~width/4)

print(logits.shape)
print(pil_img.size)  # PIL reports size as (width, height)

In [None]:
# Adjust logits
# Resize the output to match the input image dimensions.
# PIL's Image.size is (width, height), while interpolate expects the spatial
# size as (height, width), so the tuple is reversed. Passing it as-is would
# give a transposed-shape mask for any non-square image.
upsampled_logits = torch.nn.functional.interpolate(
    logits,
    size=pil_img.size[::-1],  # H x W
    mode="bilinear",
    align_corners=False,
)

# Get label masks: most likely class per pixel, first (only) item of the batch.
labels = upsampled_logits.argmax(dim=1)[0]

# Move to CPU to visualize in matplotlib.
labels_viz = labels.cpu().numpy()
print(labels_viz.shape)

In [None]:
# Build one legend entry per class from the label -> color mapping.
# NOTE(review): colors are passed to matplotlib unscaled; this assumes
# LABEL_COLORS already stores RGB values in [0, 1] — confirm against utils.
legend_patches = [
    Patch(color=np.array(rgb), label=label) for label, rgb in LABEL_COLORS.items()
]

# Side-by-side figure: input image on the left, predicted mask on the right.
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
axs: list[plt.Axes]  # annotation only, for editor hints
image_ax, mask_ax = axs

image_ax.imshow(pil_img)
image_ax.set_title("Image")

mask_ax.imshow(label2rgb(labels_viz))
mask_ax.set_title("Pred Mask")
# Anchor the legend just outside the top-right corner of the mask axes.
mask_ax.legend(
    handles=legend_patches,
    bbox_to_anchor=(1.05, 1.05),
    loc="upper left",
    title="Classes",
)
plt.show()