In [1]:
import torch
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import MNIST
from PIL import Image
import os
import numpy as np
from tqdm import tqdm

In [None]:

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 64
EPOCHS = 5
LR = 1e-3


In [None]:

class MNISTWithCircleMask(Dataset):
    def __init__(self, mnist_dataset, mask_dir):
        self.mnist = mnist_dataset
        self.mask_dir = mask_dir
        self.transform = transforms.ToTensor()

    def __len__(self):
        return len(self.mnist)

    def __getitem__(self, idx):
        img, label = self.mnist[idx]
        img_tensor = self.transform(img)

        mask_path = os.path.join(self.mask_dir, f"{idx:05d}_circle.png")
        mask = Image.open(mask_path).convert("L")
        mask_tensor = self.transform(mask)

        return img_tensor, label, mask_tensor

In [None]:

class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 16, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(32 * 7 * 7, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        )

    def forward(self, x):
        return self.fc(self.conv(x))

In [None]:

train_mnist = MNIST(root="./data", train=True, download=False)
test_mnist = MNIST(root="./data", train=False, download=False)

train_dataset = MNISTWithCircleMask(train_mnist, "output/circular_localization")
test_dataset = MNISTWithCircleMask(test_mnist, "output/circular_localization")

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

In [None]:
model = SimpleCNN().to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
loss_fn = nn.CrossEntropyLoss()

for epoch in range(EPOCHS):
    model.train()
    total_loss = 0
    for imgs, labels, _ in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
        imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
        preds = model(imgs)
        loss = loss_fn(preds, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    print(f"Epoch {epoch+1} Loss: {total_loss/len(train_loader):.4f}")

Epoch 1/5: 100%|██████████| 938/938 [04:59<00:00,  3.13it/s]


Epoch 1 Loss: 0.2237


Epoch 2/5: 100%|██████████| 938/938 [02:48<00:00,  5.58it/s]


Epoch 2 Loss: 0.0611


Epoch 3/5: 100%|██████████| 938/938 [01:21<00:00, 11.49it/s]


Epoch 3 Loss: 0.0413


Epoch 4/5: 100%|██████████| 938/938 [01:23<00:00, 11.20it/s]


Epoch 4 Loss: 0.0321


Epoch 5/5: 100%|██████████| 938/938 [01:21<00:00, 11.58it/s]

Epoch 5 Loss: 0.0263





In [None]:

def estimate_center_radius(img_tensor):
    img_np = img_tensor.squeeze().cpu().numpy()
    coords = np.argwhere(img_np > 0.1)
    if coords.size == 0:
        return (14, 14), 10
    y_center, x_center = coords.mean(axis=0)
    y_center, x_center = int(y_center), int(x_center)
    distances = np.sqrt(((coords - [y_center, x_center])**2).sum(axis=1))
    radius = int(np.percentile(distances, 90))
    return (x_center, y_center), radius

def generate_circular_mask(shape, center, radius):
    H, W = shape
    Y, X = np.ogrid[:H, :W]
    dist_from_center = np.sqrt((X - center[0])**2 + (Y - center[1])**2)
    mask = (dist_from_center <= radius).astype(np.uint8)
    return torch.tensor(mask, dtype=torch.float32).unsqueeze(0).to(DEVICE)

In [None]:

def compute_iou(pred_mask, gt_mask):
    pred_bin = (pred_mask > 0.5).float()
    gt_bin = (gt_mask > 0.5).float()
    intersection = (pred_bin * gt_bin).sum()
    union = ((pred_bin + gt_bin) > 0).float().sum()
    return (intersection / union).item() if union != 0 else 0.0


In [None]:
model.eval()
total_iou = 0.0
correct_cls = 0

with torch.no_grad():
    for img, label, gt_mask in tqdm(test_loader, desc="Evaluating"):
        img, label = img.to(DEVICE), label.to(DEVICE)
        gt_mask = gt_mask.to(DEVICE)

        output = model(img)
        pred_label = output.argmax(dim=1)

        if pred_label.item() == label.item():
            correct_cls += 1

            center, radius = estimate_center_radius(img)
            pred_mask = generate_circular_mask((28, 28), center, radius)
            iou = compute_iou(pred_mask, gt_mask)
        else:
            iou = 0.0

        total_iou += iou

avg_iou = total_iou / len(test_loader)
accuracy = correct_cls / len(test_loader)

print(f"\nClassification Accuracy: {accuracy * 100:.2f}%")
print(f"Average IoU (with 0 for misclassification): {avg_iou:.4f}")

Evaluating: 100%|██████████| 10000/10000 [00:47<00:00, 210.96it/s]


Classification Accuracy: 98.69%
Average IoU (with 0 for misclassification): 0.7180



