In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import cv2
import glob

In [None]:
# Load images and labels

# NOTE(review): labels are paired with images purely by position, but glob.glob
# returns paths in filesystem-dependent order. Sorting makes the pairing
# deterministic; this assumes centroids.txt rows follow sorted-filename order -- TODO confirm.
image_paths = sorted(glob.glob("data/*.png"))     # change to .jpg or .tif if needed
if not image_paths:
    raise FileNotFoundError("no images matched data/*.png")

images = []
for file in image_paths:
    img = cv2.imread(file, cv2.IMREAD_GRAYSCALE)
    if img is None:  # cv2.imread returns None instead of raising on unreadable files
        raise OSError(f"could not read image: {file}")

    # ensure image is 20x20
    # NOTE(review): if source images are larger, centroids.txt must already be in the
    # resized 20x20 coordinate frame -- confirm.
    img = cv2.resize(img, (20, 20))
    img = img.astype(np.float32)
    peak = np.max(img)
    if peak != 0:            # an all-black image would otherwise become 0/0 = NaN
        img = img / peak     # normalize to a peak value of 1

    images.append(img)

images = np.array(images)   # shape: (N, 20, 20)
print("Loaded images:", images.shape)


# ndmin=2 keeps a one-line file as shape (1, 2) instead of (2,), so column indexing works
centroids = np.loadtxt("centroids.txt", dtype=int, ndmin=2)
if centroids.shape != (len(images), 2):
    raise ValueError(
        f"centroids.txt has shape {centroids.shape}; expected ({len(images)}, 2) "
        "-- one (x, y) row per image"
    )
label_x, label_y = centroids[:, 0], centroids[:, 1]
if not ((label_x >= 0) & (label_x < 20) & (label_y >= 0) & (label_y < 20)).all():
    raise ValueError("centroid coordinates must lie inside the 20x20 grid")

labels = label_y * 20 + label_x    # flatten (x, y) to a row-major class index 0-399
print("Labels shape:", labels.shape)


# Convert to tensors. Flatten each image to one 400-long row so X is (N, 400), which is
# what nn.Linear(400, ...) expects and what CrossEntropyLoss needs for (N,) class targets.
X = torch.tensor(images.reshape(len(images), 400), dtype=torch.float32)
Y = torch.tensor(labels, dtype=torch.long)


In [None]:
                                          # SHNN-50 Model

class SHNN50(nn.Module):
    """Single-hidden-layer MLP: 400 inputs -> 50 ReLU units -> 400 logits.

    The 400 outputs are unnormalised scores, one per pixel of the flattened
    20x20 grid. No softmax is applied here because ``nn.CrossEntropyLoss``
    performs the log-softmax itself.
    """

    def __init__(self):
        super().__init__()
        # These attribute names become the state_dict keys -- keep them stable.
        self.fc1 = nn.Linear(400, 50)   # flattened pixels -> hidden units
        self.fc2 = nn.Linear(50, 400)   # hidden units -> per-pixel logits
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map inputs whose last dimension is 400 to raw logits with last dimension 400."""
        return self.fc2(self.relu(self.fc1(x)))

# Seed torch's global RNG so weight initialisation -- and the torch.randperm shuffling
# in the training loop, which draws from the same default generator -- is reproducible.
SEED = 0
torch.manual_seed(SEED)

model = SHNN50()
criterion = nn.CrossEntropyLoss()   # takes raw logits plus integer class targets
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [None]:
                                                 # Train model

epochs = 30
batch_size = 128

n_samples = X.size(0)
if n_samples == 0:
    raise ValueError("no training samples in X")

model.train()   # explicit training mode; a no-op for the current Linear/ReLU-only model
for epoch in range(epochs):
    perm = torch.randperm(n_samples)   # fresh shuffle each epoch
    running_loss = 0.0

    for i in range(0, n_samples, batch_size):
        idx = perm[i:i+batch_size]
        x_batch = X[idx]
        y_batch = Y[idx]

        optimizer.zero_grad()
        out = model(x_batch)
        loss = criterion(out, y_batch)
        loss.backward()
        optimizer.step()

        # criterion (CrossEntropyLoss, default reduction="mean") returns the batch mean;
        # weight it by batch size so the epoch figure is the per-sample average and a
        # short final batch does not dominate it.
        running_loss += loss.item() * idx.numel()

    print("Epoch:", epoch+1, "Loss:", running_loss / n_samples)


In [None]:
                                                 # Predict Center Pixel


def predict_center(model, img20):
    """Predict the integer pixel (x, y) holding the spot centre of a 20x20 image.

    The model's outputs are treated as one score per pixel in row-major order
    (index = y * 20 + x, the same flattening used to build the training labels).

    Args:
        model: callable (e.g. SHNN50) that accepts a float32 tensor of shape (400,).
        img20: numpy array with 400 elements (normally a 20x20 image).

    Returns:
        (cx, cy): column and row of the highest-scoring pixel, as Python ints.
    """
    img = torch.tensor(img20.reshape(400), dtype=torch.float32)
    # Inference only: don't build an autograd graph for this forward pass.
    with torch.no_grad():
        out = model(img)
    k = torch.argmax(out).item()

    cy, cx = divmod(k, 20)   # undo the row-major flatten: k = cy * 20 + cx
    return cx, cy


In [None]:
                                             # Subpixel Refinement

def refine_centroid(img, cx, cy, top_k=5):
    """Refine an integer peak location to sub-pixel precision.

    Takes the 3x3 neighbourhood around (cx, cy), clipped to the image bounds, keeps
    the ``top_k`` brightest pixels and returns their intensity-weighted centre of mass.

    Args:
        img: 2-D array indexed as img[y, x].
        cx, cy: integer column / row of the coarse peak.
        top_k: number of brightest neighbourhood pixels to average (default 5).

    Returns:
        (x_sub, y_sub): sub-pixel column and row. If the selected pixels sum to zero
        intensity (or none lie inside the image) the weights are undefined, so
        (float(cx), float(cy)) is returned instead of NaN.
    """
    height, width = img.shape[0], img.shape[1]
    pts = []
    for dy in [-1, 0, 1]:
        for dx in [-1, 0, 1]:
            x = cx + dx
            y = cy + dy
            if 0 <= x < width and 0 <= y < height:
                pts.append((img[y, x], x, y))

    # Tuple sort: brightest first; ties fall back to larger x, then larger y.
    pts.sort(reverse=True)
    top = pts[:top_k]

    I = np.array([p[0] for p in top])
    Xc = np.array([p[1] for p in top])
    Yc = np.array([p[2] for p in top])

    total = np.sum(I)
    if total == 0:   # 0/0 would yield NaN; fall back to the coarse pixel
        return float(cx), float(cy)

    x_sub = np.sum(I * Xc) / total
    y_sub = np.sum(I * Yc) / total
    return x_sub, y_sub


In [None]:
                                             # Final Centroid Function

def SH_centroid(model, img20):
    """Locate the spot centroid in a 20x20 image with sub-pixel precision.

    Pipeline: peak-normalise the image (as the training images were), let the
    network pick the coarse pixel via ``predict_center``, then refine it with the
    intensity-weighted centroid from ``refine_centroid``.

    Args:
        model: trained network accepted by ``predict_center``.
        img20: 2-D numpy image, normally 20x20.

    Returns:
        (x_sub, y_sub): sub-pixel column and row of the centroid.
    """
    peak = np.max(img20)
    if peak != 0:                 # all-zero input: skip 0/0, which would feed NaNs to the model
        img20 = img20 / peak      # normalize (new array; the caller's image is untouched)
    cx, cy = predict_center(model, img20)
    x_sub, y_sub = refine_centroid(img20, cx, cy)
    return x_sub, y_sub

# Demo: centroid of the first loaded image, printed next to its label.
# NOTE: images[0] is also a training sample, so this is a sanity check, not a held-out evaluation.
img = images[0]   # first image
x, y = SH_centroid(model, img)
print("Centroid:", x, y)
print("Label:", *centroids[0])   # the (x, y) row paired with images[0] when building labels