In [None]:
import os
from pathlib import Path

from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

import matplotlib.pyplot as plt


In [None]:
# Configuration (update these!)
# Dataset root; the data cell reads the "train" and "val" sub-folders of it.
PATCH_ROOT = Path("/path/to/footprint_patches")

# Encoder checkpoint (a state_dict file loaded into FootprintEncoder)
ENCODER_CKPT = "footprint_encoder_contrastive.pth"

# Hyper-parameters
image_size = 128       # side length images are resized / cropped to
feature_dim = 256      # size of the encoder's output embedding
batch_size = 64
num_epochs = 20
learning_rate = 1e-3

# Seed torch's global RNG so that head initialisation, DataLoader shuffling
# and the random train-time augmentations are repeatable between runs.
seed = 42
torch.manual_seed(seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device


In [None]:
class FootprintPatchDataset(Dataset):
    """
    Classification-style dataset for cropped footprint patches.
    Assumes directory structure:
      root_dir/
        class_name0/
          *.jpg / *.jpeg / *.png
        class_name1/
          *.jpg / *.jpeg / *.png
        ...

    Class indices are assigned contiguously (0 .. num_classes - 1) in sorted
    order of the class directory names; entries of ``root_dir`` that are not
    directories are ignored and do not consume an index. Files inside each
    class directory are collected in sorted name order so that the sample
    order is the same on every run.

    Args:
        root_dir: path (str or Path) of the folder holding one sub-folder per class.
        transform: optional callable applied to the loaded RGB PIL image.

    Attributes set by the constructor:
        image_paths: list of Path objects, one per sample.
        labels: list of int class indices, parallel to ``image_paths``.
        class_to_idx: dict mapping class name -> class index.
        idx_to_class: list mapping class index -> class name.
    """

    # File extensions (compared case-insensitively) that count as images.
    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")

    def __init__(self, root_dir, transform=None):
        self.root_dir = Path(root_dir)
        self.transform = transform

        self.image_paths = []
        self.labels = []
        self.class_to_idx = {}
        self.idx_to_class = []

        # Filter to directories BEFORE numbering: enumerating the raw listing
        # would let stray files (e.g. ".DS_Store") create gaps in the labels,
        # pushing them outside the range [0, len(idx_to_class)).
        class_dirs = sorted(
            (entry for entry in self.root_dir.iterdir() if entry.is_dir()),
            key=lambda entry: entry.name,
        )

        for class_idx, class_path in enumerate(class_dirs):
            self.class_to_idx[class_path.name] = class_idx
            self.idx_to_class.append(class_path.name)

            # Sorted for a deterministic sample order across runs/platforms.
            for img_path in sorted(class_path.iterdir(), key=lambda p: p.name):
                if img_path.name.lower().endswith(self.IMAGE_EXTENSIONS):
                    self.image_paths.append(img_path)
                    self.labels.append(class_idx)

        print(f"Loaded {len(self.image_paths)} images "
              f"from {self.root_dir}, {len(self.idx_to_class)} classes.")

    def __len__(self):
        """Return the number of collected image files."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is loaded as RGB and passed through ``transform`` when one
        was given; ``label`` is the integer class index.
        """
        img_path = self.image_paths[idx]
        label = self.labels[idx]

        # convert() produces an independent in-memory image, so the file
        # handle can be released as soon as the context manager exits.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        if self.transform is not None:
            image = self.transform(image)

        return image, label


In [None]:
# Train-time augmentation: mild random crop + horizontal flip.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(image_size, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
])

# Evaluation uses a plain deterministic resize.
eval_transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
])

train_dataset = FootprintPatchDataset(
    PATCH_ROOT / "train",
    transform=train_transform
)

val_dataset = FootprintPatchDataset(
    PATCH_ROOT / "val",
    transform=eval_transform
)

# Each split derives its own class -> index mapping from its folder names.
# If the two disagree, validation labels would refer to different classes
# than the ones the head was trained on and the accuracy would be meaningless.
assert val_dataset.class_to_idx == train_dataset.class_to_idx, (
    "train/ and val/ must contain the same class folders; got "
    f"{train_dataset.idx_to_class} vs {val_dataset.idx_to_class}"
)

num_classes = len(train_dataset.idx_to_class)
print("Number of classes:", num_classes)

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4
)

val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4
)


In [None]:
class FootprintEncoder(nn.Module):
    """
    Small four-stage convolutional encoder for footprint images.

    Every stage is Conv(3x3, stride 1, padding 1) -> BatchNorm -> ReLU -> pool.
    The first three stages pool with a 2x2 max-pool (halving height/width);
    the last stage pools adaptively down to 1x1, after which a linear layer
    maps the 256 channels to ``feature_dim``.

    Input: (B, 3, 128, 128)
    Output: (B, feature_dim)
    """
    def __init__(self, feature_dim=256):
        super().__init__()
        self.feature_dim = feature_dim

        # Attribute names and in-stage layer order define the state_dict keys
        # (conv1.0.weight, conv1.1.running_mean, ...), so they are kept fixed.
        self.conv1 = self._stage(3, 32, nn.MaxPool2d(kernel_size=2))      # 128 -> 64
        self.conv2 = self._stage(32, 64, nn.MaxPool2d(kernel_size=2))     # 64 -> 32
        self.conv3 = self._stage(64, 128, nn.MaxPool2d(kernel_size=2))    # 32 -> 16
        self.conv4 = self._stage(128, 256, nn.AdaptiveAvgPool2d((1, 1)))  # 16 -> 1

        self.fc = nn.Linear(256, feature_dim)

    @staticmethod
    def _stage(in_channels, out_channels, pool):
        """Build one Conv -> BatchNorm -> ReLU -> ``pool`` stage."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            pool,
        )

    def forward(self, x):
        """Encode a batch of images into ``(B, feature_dim)`` embeddings."""
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = stage(x)
        # (B, 256, 1, 1) -> (B, 256) before the projection layer.
        pooled = x.flatten(1)
        return self.fc(pooled)


In [None]:
encoder = FootprintEncoder(feature_dim=feature_dim)

# torch.load unpickles the file; weights_only=True restricts unpickling to
# tensors and plain containers so a tampered checkpoint cannot run arbitrary
# code. The checkpoint is expected to be a bare state_dict.
state_dict = torch.load(ENCODER_CKPT, map_location=device, weights_only=True)
encoder.load_state_dict(state_dict)

# Freeze encoder parameters
encoder.requires_grad_(False)

encoder = encoder.to(device)
# Eval mode makes BatchNorm use its stored running statistics.
encoder.eval()

print("Encoder loaded and frozen.")


In [None]:
class LinearProbe(nn.Module):
    """
    Frozen encoder followed by a single trainable linear classification head.

    The encoder is run under ``torch.no_grad()`` and is always kept in eval
    mode, so neither its weights nor its BatchNorm running statistics change
    while the head is being trained.

    Args:
        encoder: module mapping an image batch to ``(B, feature_dim)`` features.
        feature_dim: size of the encoder's output features.
        num_classes: number of output logits.
    """
    def __init__(self, encoder, feature_dim=256, num_classes=10):
        super().__init__()
        self.encoder = encoder
        self.fc = nn.Linear(feature_dim, num_classes)

    def train(self, mode=True):
        """Switch the head's mode while pinning the encoder to eval mode.

        ``nn.Module.train`` recurses into every sub-module. For the encoder
        that would make BatchNorm normalise with batch statistics and update
        its running averages (this happens even under ``no_grad``), silently
        un-freezing it. ``eval()`` is routed through this method as well.
        """
        super().train(mode)
        self.encoder.eval()
        return self

    def forward(self, x):
        """Return class logits of shape ``(B, num_classes)`` for batch ``x``."""
        # No autograd graph is built through the frozen encoder.
        with torch.no_grad():
            features = self.encoder(x)   # (B, feature_dim)
        logits = self.fc(features)       # (B, num_classes)
        return logits

# Wrap the frozen encoder with a fresh classification head and move it to the
# selected device.
model = LinearProbe(encoder=encoder, feature_dim=feature_dim, num_classes=num_classes)
model = model.to(device)

criterion = nn.CrossEntropyLoss()

# ONLY the classifier head's parameters are given to the optimizer.
optimizer = torch.optim.Adam(model.fc.parameters(), lr=learning_rate, weight_decay=1e-4)

model


In [None]:
def evaluate(model, dataloader, device):
    """Compute top-1 accuracy of ``model`` over ``dataloader``.

    The model is switched to eval mode (and left in it) and gradients are
    disabled while iterating.

    Args:
        model: module mapping an image batch to per-class scores.
        dataloader: iterable yielding ``(images, labels)`` tensor pairs.
        device: device each batch is moved to before the forward pass.

    Returns:
        Fraction of correctly classified samples as a float, or ``0.0`` when
        the dataloader yields no samples.
    """
    model.eval()
    hits, seen = 0, 0

    with torch.no_grad():
        for batch_images, batch_labels in dataloader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            # Predicted class = index of the largest score per sample.
            predictions = model(batch_images).argmax(dim=1)

            hits += predictions.eq(batch_labels).sum().item()
            seen += batch_labels.size(0)

    if seen == 0:
        return 0.0
    return hits / seen


def train_linear_probe(model, train_loader, val_loader, optimizer,
                       criterion, device, num_epochs=10):
    """Train ``model`` for ``num_epochs`` epochs and track accuracy.

    After every epoch the model is scored on ``val_loader`` with
    ``evaluate`` and a one-line summary (loss, train/val accuracy) is printed.

    Direct child modules whose parameters are all frozen
    (``requires_grad=False``) are kept in eval mode during training so that
    their BatchNorm running statistics are not altered.

    Args:
        model: module to train; called as ``model(images)``.
        train_loader: iterable of ``(images, labels)`` training batches.
        val_loader: iterable of ``(images, labels)`` validation batches.
        optimizer: optimizer holding the parameters to update.
        criterion: loss callable taking ``(outputs, labels)``.
        device: device each batch is moved to.
        num_epochs: number of passes over ``train_loader``.

    Returns:
        ``(train_acc_history, val_acc_history)``: two lists with one accuracy
        value (fraction in [0, 1]) per epoch.
    """
    train_acc_history = []
    val_acc_history = []

    for epoch in range(num_epochs):
        # evaluate() leaves the model in eval mode, so re-enable training
        # mode at the start of every epoch.
        model.train()
        # model.train() recurses into every sub-module. A frozen feature
        # extractor must stay in eval mode, otherwise its BatchNorm layers
        # normalise with batch statistics and update their running averages
        # (which happens even when the forward pass runs under no_grad).
        for child in model.children():
            child_params = list(child.parameters())
            if child_params and not any(p.requires_grad for p in child_params):
                child.eval()

        running_loss = 0.0
        correct = 0
        total = 0

        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Weight the batch loss by batch size so that the epoch average
            # is per-sample even when the last batch is smaller.
            running_loss += loss.item() * labels.size(0)
            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

        train_loss = running_loss / total if total > 0 else 0.0
        train_acc = correct / total if total > 0 else 0.0

        val_acc = evaluate(model, val_loader, device)

        train_acc_history.append(train_acc)
        val_acc_history.append(val_acc)

        print(f"Epoch [{epoch+1}/{num_epochs}] "
              f"Train Loss: {train_loss:.4f}  "
              f"Train Acc: {train_acc*100:.2f}%  "
              f"Val Acc: {val_acc*100:.2f}%")

    return train_acc_history, val_acc_history


In [None]:
# Train the classification head and collect per-epoch accuracies.
train_acc_hist, val_acc_hist = train_linear_probe(
    model, train_loader, val_loader, optimizer, criterion, device,
    num_epochs=num_epochs,
)

# 1-based epoch numbers for the x-axis.
epochs = range(1, num_epochs + 1)

fig, ax = plt.subplots(figsize=(6, 4))
ax.plot(epochs, train_acc_hist, label="Train Acc")
ax.plot(epochs, val_acc_hist, label="Val Acc")
ax.set_xlabel("Epoch")
ax.set_ylabel("Accuracy")
ax.set_title("Linear Probe Accuracy")
ax.legend()
ax.grid(True)
plt.show()