Input File(s)

In [28]:
# The two input images for this pipeline. They share a directory and an ID
# prefix and differ only in the numeric suffix.
# NOTE(review): the suffixes look like frame indices of one recording — confirm.
img1Path, img2Path = (
    "./pipelineInput/0X7AFF5B2DEEE839B9_35.png",
    "./pipelineInput/0X7AFF5B2DEEE839B9_47.png",
)

Center Model

In [29]:
import cv2
import numpy as np
import torch

#prepare input
def loadImage(imgPath, expected_size=None):
    """Read an image file as a single-channel float tensor batch of one.

    The file is decoded by OpenCV in grayscale mode, scaled from [0, 255] to
    [0.0, 1.0], and given leading batch and channel axes, mimicking what a
    DataLoader would yield. The image is NOT resized: the output spatial size
    is whatever the file's size is.

    Args:
        imgPath: path of the image file to read.
        expected_size: optional ``(height, width)`` the decoded image must have.
            When given and the size differs, a ValueError is raised here rather
            than a less readable shape error later inside a model. ``None``
            (default) skips the check.

    Returns:
        torch.float32 tensor of shape ``[1, 1, H, W]`` with values in [0, 1].

    Raises:
        FileNotFoundError: if ``cv2.imread`` returns ``None``. This happens for
            a missing file and also for an unreadable or undecodable file.
        ValueError: if ``expected_size`` is given and does not match.
    """
    image = cv2.imread(imgPath, cv2.IMREAD_GRAYSCALE)  # ensure 1-channel, 2-D [H, W]
    if image is None:
        raise FileNotFoundError(f"Image not found: {imgPath}")

    if expected_size is not None and tuple(image.shape) != tuple(expected_size):
        raise ValueError(
            f"Image {imgPath} has size {tuple(image.shape)}, expected {tuple(expected_size)}"
        )

    # Normalize manually because no torchvision transform is applied.
    image = image.astype(np.float32) / 255.0
    # [H, W] -> [1, 1, H, W]: a channel axis plus a batch axis of one.
    return torch.from_numpy(image[np.newaxis, np.newaxis])

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class AttentionNet(nn.Module):
    def __init__(self, image_size=112, patch_size=16, embed_dim=64, num_heads=4, depth=2):
        super().__init__()
        assert image_size % patch_size == 0, "Image size must be divisible by patch size"

        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2
        self.embed_dim = embed_dim

        self.patch_embed = nn.Conv2d(1, embed_dim, kernel_size=patch_size, stride=patch_size)

        self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches, embed_dim))

        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)

        self.head = nn.Sequential(
        nn.LayerNorm(embed_dim),
        nn.Linear(embed_dim, 2),
        )

    def forward(self, x):
        x = self.patch_embed(x)
        x = x.flatten(2).transpose(1, 2)
        x = x + self.pos_embed

        x = self.transformer(x)
        x = x.mean(dim=1)

        out = self.head(x)
        return out

In [None]:
import torch
import torch.optim as optim
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Input side length the center model is built for. The raw model outputs
# are multiplied by it below.
# NOTE(review): assumes the model predicts (x, y) normalized to [0, 1] of a
# 112x112 image — confirm against the training code.
CENTER_IMG_SIZE = 112

model = AttentionNet(image_size=CENTER_IMG_SIZE).to(device)

weightsPath = "./center_model.pth"
# torch.load unpickles the checkpoint, which can execute arbitrary code, so
# only load trusted files. This is a plain state_dict, so weights_only=True
# (torch >= 1.13) would be a safe hardening.
model.load_state_dict(torch.load(weightsPath, map_location=device))
# Inference only. nn.TransformerEncoderLayer applies dropout in the default
# train mode, which would randomly perturb every prediction.
model.eval()


def predictCenter(net, image, scale=CENTER_IMG_SIZE):
    """Run `net` on a preprocessed image batch and scale the result.

    Args:
        net: the loaded model, already switched to eval mode.
        image: input tensor (e.g. from `loadImage`) on any device. It is moved
            to the device of `net`'s parameters.
        scale: factor applied to the raw model output.

    Returns:
        `net(image) * scale`, computed without autograd history, on the
        model's device.
    """
    target_device = next(net.parameters()).device
    # No autograd graph is needed for inference.
    with torch.no_grad():
        return net(image.to(target_device)) * scale


img1 = loadImage(img1Path)
coord1 = predictCenter(model, img1)

img2 = loadImage(img2Path)
coord2 = predictCenter(model, img2)

MedSAM