  from .autonotebook import tqdm as notebook_tqdm


{'device': 'cpu', 'torch': '2.2.0'}


In [2]:
# Data loading
BATCH_SIZE = 256

def get_dataloaders(batch_size=BATCH_SIZE):
    tf = transforms.Compose([
        transforms.ToTensor(),
        # Keep in [0,1] since we want direct pixel-coefficient interpretability
    ])
    train_ds = datasets.MNIST(root=os.path.join('~', '.torch', 'datasets'), train=True, download=True, transform=tf)
    test_ds = datasets.MNIST(root=os.path.join('~', '.torch', 'datasets'), train=False, download=True, transform=tf)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=torch.cuda.is_available())
    test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=torch.cuda.is_available())
    return train_loader, test_loader

train_loader, test_loader = get_dataloaders()
len_train, len_test = len(train_loader.dataset), len(test_loader.dataset)
print({'train': len_train, 'test': len_test})


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to /Users/sun/.torch/datasets/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [01:18<00:00, 127013.52it/s]


Extracting /Users/sun/.torch/datasets/MNIST/raw/train-images-idx3-ubyte.gz to /Users/sun/.torch/datasets/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to /Users/sun/.torch/datasets/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 104851.88it/s]


Extracting /Users/sun/.torch/datasets/MNIST/raw/train-labels-idx1-ubyte.gz to /Users/sun/.torch/datasets/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to /Users/sun/.torch/datasets/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:04<00:00, 371152.37it/s]


Extracting /Users/sun/.torch/datasets/MNIST/raw/t10k-images-idx3-ubyte.gz to /Users/sun/.torch/datasets/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to /Users/sun/.torch/datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 13118.29it/s]

Extracting /Users/sun/.torch/datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz to /Users/sun/.torch/datasets/MNIST/raw

{'train': 60000, 'test': 10000}





In [3]:
# Ring (hole) feature computation utilities
import collections

def compute_ring_feature(img: torch.Tensor, threshold: float = 0.5) -> int:
    """
    Given a grayscale MNIST image tensor of shape [1, 28, 28] in [0,1],
    returns 1 if there exists at least one background (0) connected component
    fully enclosed by foreground (1), else 0.
    We threshold > threshold as foreground.
    """
    assert img.ndim == 3 and img.shape[0] == 1, "Expected [1,H,W]"
    h, w = img.shape[1], img.shape[2]
    x = (img[0] > threshold).cpu().numpy().astype(np.uint8)  # 1 for foreground strokes

    # Background mask (0 where foreground, 1 where background)
    bg = (x == 0).astype(np.uint8)

    # Flood-fill background from border to mark non-hole background
    visited = np.zeros_like(bg, dtype=np.uint8)
    dq = collections.deque()

    # Push all border background pixels
    for i in range(h):
        if bg[i, 0] and not visited[i, 0]:
            visited[i, 0] = 1
            dq.append((i, 0))
        if bg[i, w - 1] and not visited[i, w - 1]:
            visited[i, w - 1] = 1
            dq.append((i, w - 1))
    for j in range(w):
        if bg[0, j] and not visited[0, j]:
            visited[0, j] = 1
            dq.append((0, j))
        if bg[h - 1, j] and not visited[h - 1, j]:
            visited[h - 1, j] = 1
            dq.append((h - 1, j))

    # 4-connected BFS
    OFFSETS = [(1,0), (-1,0), (0,1), (0,-1)]
    while dq:
        i, j = dq.popleft()
        for di, dj in OFFSETS:
            ni, nj = i + di, j + dj
            if 0 <= ni < h and 0 <= nj < w and bg[ni, nj] and not visited[ni, nj]:
                visited[ni, nj] = 1
                dq.append((ni, nj))

    # Any background pixel not visited is a hole pixel
    holes = (bg == 1) & (visited == 0)
    return int(holes.any())

class RingMNIST(Dataset):
    def __init__(self, base: datasets.MNIST, threshold: float = 0.5):
        self.base = base
        self.threshold = threshold

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx: int):
        img, label = self.base[idx]  # img: [1,28,28] float in [0,1]
        ring = compute_ring_feature(img, threshold=self.threshold)
        # Flatten pixels and append ring feature as extra dimension
        pixels = img.view(-1)
        feat = torch.cat([pixels, torch.tensor([float(ring)], dtype=pixels.dtype)])
        return feat, label, torch.tensor(ring, dtype=torch.float32)

INPUT_DIM = 28 * 28 + 1
NUM_CLASSES = 10


In [4]:
# Wrap loaders with RingMNIST and fix dataset path expansion
from pathlib import Path

DATA_ROOT = str(Path.home() / '.torch' / 'datasets')


def get_ring_loaders(batch_size=BATCH_SIZE, threshold: float = 0.5):
    tf = transforms.Compose([
        transforms.ToTensor(),
    ])
    base_train = datasets.MNIST(root=DATA_ROOT, train=True, download=True, transform=tf)
    base_test = datasets.MNIST(root=DATA_ROOT, train=False, download=True, transform=tf)

    ring_train = RingMNIST(base_train, threshold=threshold)
    ring_test = RingMNIST(base_test, threshold=threshold)

    def collate(batch):
        feats, labels, rings = zip(*batch)
        return torch.stack(feats), torch.tensor(labels, dtype=torch.long), torch.stack(rings)

    train_loader = DataLoader(ring_train, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=torch.cuda.is_available(), collate_fn=collate)
    test_loader = DataLoader(ring_test, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=torch.cuda.is_available(), collate_fn=collate)
    return train_loader, test_loader

train_loader, test_loader = get_ring_loaders()
xb, yb, rb = next(iter(train_loader))
print({'batch_feats': tuple(xb.shape), 'batch_labels': tuple(yb.shape), 'batch_ring': tuple(rb.shape)})


AttributeError: Can't pickle local object 'get_ring_loaders.<locals>.collate'

In [None]:
# Sanity-check ring feature distribution
from collections import Counter

sample_loader, _ = get_ring_loaders(batch_size=1024)
ring_counts = Counter()
num_batches = 10
for i, (_, _, rb) in enumerate(sample_loader):
    ring_counts.update(rb.int().tolist())
    if i >= num_batches - 1:
        break
print({'ring_feature_counts_over_first_batches': dict(ring_counts)})


In [None]:
# Models: linear softmax and 1-hidden-layer MLP
class LinearSoftmax(nn.Module):
    def __init__(self, in_dim: int, num_classes: int):
        super().__init__()
        self.W = nn.Linear(in_dim, num_classes, bias=True)

    def forward(self, x):
        return self.W(x)

class OneHiddenMLP(nn.Module):
    def __init__(self, in_dim: int, hidden_dim: int, num_classes: int):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        h = F.relu(self.fc1(x))
        return self.fc2(h)

# Regularization: encourage reliance on ring feature
# For linear model, directly L2-boost the last input weight per class
# For MLP, encourage first-layer weights on ring feature and possibly path strength

def ring_weight_penalty_linear(model: LinearSoftmax, alpha: float = 1.0):
    # last input is ring feature
    W = model.W.weight  # [C, D]
    ring_w = W[:, -1]   # [C]
    return -alpha * torch.mean(ring_w.abs())  # negative to increase magnitude


def ring_weight_penalty_mlp(model: OneHiddenMLP, beta: float = 1.0):
    # Encourage large absolute weights from ring input into hidden units
    W1 = model.fc1.weight  # [H, D]
    ring_w1 = W1[:, -1]    # [H]
    return -beta * torch.mean(ring_w1.abs())


def train_model(model, train_loader, test_loader, epochs=3, lr=1e-2, ring_reg=0.0, model_type='linear'):
    model = model.to(DEVICE)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(1, epochs + 1):
        model.train()
        total_loss = 0.0
        total_correct = 0
        total = 0
        for xb, yb, rb in train_loader:
            xb = xb.to(DEVICE)
            yb = yb.to(DEVICE)
            logits = model(xb)
            loss = criterion(logits, yb)
            if ring_reg > 0:
                if model_type == 'linear':
                    loss = loss + ring_reg * ring_weight_penalty_linear(model)
                else:
                    loss = loss + ring_reg * ring_weight_penalty_mlp(model)
            opt.zero_grad()
            loss.backward()
            opt.step()

            total_loss += float(loss) * xb.size(0)
            total_correct += (logits.argmax(dim=1) == yb).sum().item()
            total += xb.size(0)
        train_loss = total_loss / total
        train_acc = total_correct / total

        # Eval
        model.eval()
        correct = 0
        count = 0
        with torch.no_grad():
            for xb, yb, rb in test_loader:
                xb = xb.to(DEVICE)
                yb = yb.to(DEVICE)
                logits = model(xb)
                correct += (logits.argmax(dim=1) == yb).sum().item()
                count += xb.size(0)
        test_acc = correct / count
        print({'epoch': epoch, 'train_loss': round(train_loss,4), 'train_acc': round(train_acc,4), 'test_acc': round(test_acc,4)})
    return model


In [None]:
# Train linear model with ring-boost regularization
linear = LinearSoftmax(INPUT_DIM, NUM_CLASSES)
linear = train_model(linear, train_loader, test_loader, epochs=5, lr=1e-3, ring_reg=1e-3, model_type='linear')

# Extract explicit formula: logits = W x + b
W_lin = linear.W.weight.detach().cpu().numpy()   # [10, 785]
b_lin = linear.W.bias.detach().cpu().numpy()     # [10]
ring_weights_lin = W_lin[:, -1]
print({'ring_weight_per_class_linear': ring_weights_lin.round(4).tolist()})


In [None]:
# Train 1-hidden-layer MLP with ring-boost regularization
mlp = OneHiddenMLP(INPUT_DIM, hidden_dim=64, num_classes=NUM_CLASSES)
mlp = train_model(mlp, train_loader, test_loader, epochs=5, lr=1e-3, ring_reg=1e-3, model_type='mlp')

# Extract path contribution of ring feature (approx): average |w1[:, -1]| and top contributors
with torch.no_grad():
    w1 = mlp.fc1.weight.detach().cpu()
    w2 = mlp.fc2.weight.detach().cpu()
ring_in_w = w1[:, -1].abs()
print({'mean_abs_ring_to_hidden': float(ring_in_w.mean())})


In [None]:
# Reporting: formulas and visualizations
import seaborn as sns

# Linear model: class c logit = dot(W_c, pixels_with_ring) + b_c
print('Linear softmax explicit form:')
for c in range(NUM_CLASSES):
    print(f"class {c}: logit = b[{c}] + sum_i W[{c},i]*x[i] + W[{c},ring]*ring")

# Show ring weights
plt.figure(figsize=(6,3))
plt.bar(np.arange(NUM_CLASSES), ring_weights_lin)
plt.title('Linear: ring weight per class')
plt.xlabel('class')
plt.ylabel('weight on ring feature')
plt.show()

# Visualize pixel weights per class for linear model
fig, axes = plt.subplots(2, 5, figsize=(12,5))
for c, ax in enumerate(axes.flat):
    ax.imshow(W_lin[c, :-1].reshape(28,28), cmap='coolwarm')
    ax.set_title(f'class {c}')
    ax.axis('off')
plt.suptitle('Linear: pixel weights per class')
plt.tight_layout()
plt.show()

# For MLP, logits = W2 * ReLU(W1 * x) + b2
print('MLP (1 hidden layer, ReLU) form: logits = W2 * ReLU(W1 * x) + b2')
# Ring path strengths approximation: |w1[:, -1]| * ||w2|| per class
ring_to_hidden = w1[:, -1].numpy()
ring_path_strength = np.abs(ring_to_hidden)  # simple proxy
plt.figure(figsize=(6,3))
plt.hist(ring_path_strength, bins=20)
plt.title('MLP: |ring -> hidden| weight distribution')
plt.show()
