In [2]:
import os
from pathlib import Path
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms


In [3]:
class FootprintPatchDataset(Dataset):
    """
    Classification-style dataset for cropped footprint patches.

    Assumes directory structure:
      root_dir/
        class_name0/
          *.jpg
        class_name1/
          *.jpg
        ...

    Every sub-directory of ``root_dir`` is one class. Class indices are
    assigned in sorted directory-name order and are contiguous
    (0 .. num_classes - 1), so ``idx_to_class[label]`` is always the class
    name of a sample. Non-directory entries in ``root_dir`` are ignored.

    Args:
        root_dir: path (str or Path) of the folder holding the class folders.
        transform: optional callable applied to the RGB PIL image in
            ``__getitem__``; its return value is passed through unchanged.

    Attributes:
        image_paths: list of Path objects, one per image found.
        labels: list of int class indices, parallel to ``image_paths``.
        class_to_idx: dict mapping class name -> class index.
        idx_to_class: list mapping class index -> class name.
    """

    # Lower-case file extensions that are treated as images.
    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")

    def __init__(self, root_dir, transform=None):
        self.root_dir = Path(root_dir)
        self.transform = transform

        self.image_paths = []
        self.labels = []
        self.class_to_idx = {}
        self.idx_to_class = []

        # Keep only directories BEFORE enumerating: enumerating the raw
        # listing would let stray files (e.g. desktop.ini, .DS_Store) consume
        # an index, leaving gaps in the labels and shifting them out of step
        # with idx_to_class.
        class_names = sorted(
            entry for entry in os.listdir(self.root_dir)
            if (self.root_dir / entry).is_dir()
        )

        for class_idx, class_name in enumerate(class_names):
            class_path = self.root_dir / class_name

            self.class_to_idx[class_name] = class_idx
            self.idx_to_class.append(class_name)

            # os.listdir order is filesystem-dependent; sort so that a given
            # dataset index refers to the same file on every machine.
            for fname in sorted(os.listdir(class_path)):
                if fname.lower().endswith(self.IMAGE_EXTENSIONS):
                    self.image_paths.append(class_path / fname)
                    self.labels.append(class_idx)

        print(f"Loaded {len(self.image_paths)} images "
              f"from {self.root_dir}, {len(self.idx_to_class)} classes.")

    def __len__(self):
        """Number of images found under ``root_dir``."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """
        Load one sample.

        Returns:
            (image, label): the image converted to RGB (after ``transform``
            if one was given) and its integer class index.
        """
        img_path = self.image_paths[idx]
        label = self.labels[idx]

        # Image.open is lazy and keeps the file open; the context manager
        # closes the handle as soon as the pixel data has been converted.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        if self.transform is not None:
            image = self.transform(image)

        return image, label


class ContrastiveTransform:
    """
    Turn one sample into a pair of views for contrastive learning.

    The wrapped ``base_transform`` is called twice on the same input and the
    two results are returned as a tuple ``(view_1, view_2)``. When the base
    transform is random, the two views will generally differ.
    """
    def __init__(self, base_transform):
        self.base_transform = base_transform

    def __call__(self, x):
        augment = self.base_transform
        # Two separate calls (evaluated left to right), not one result reused.
        return augment(x), augment(x)


In [4]:
# Root folder of the cropped footprint patches (expects a "train" sub-folder).
# Set the FOOTPRINT_PATCH_ROOT environment variable to run this notebook on
# another machine; when it is unset the original local path is used.
PATCH_ROOT = Path(
    os.environ.get(
        "FOOTPRINT_PATCH_ROOT",
        r"C:\Users\Kdbro\OneDrive\Desktop\OSU\Fall 2025\Neural Networks\Contrastive-and-Attribute-Aligned-Representation-Learning-for-Animal-Footprint\dataset\footprint_patches",
    )
)
# Side length (pixels) of the square crops fed to the encoder.
image_size = 128

# Random augmentation pipeline applied independently to produce each view.
base_transform = transforms.Compose([
    transforms.RandomResizedCrop(image_size, scale=(0.6, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomApply([
        transforms.ColorJitter(brightness=0.4, contrast=0.4,
                               saturation=0.4, hue=0.1)
    ], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    # Note: the blur is applied to every image (it is not wrapped in RandomApply).
    transforms.GaussianBlur(kernel_size=9, sigma=(0.1, 2.0)),
    transforms.ToTensor(),
])

# Each dataset item becomes ((view_i, view_j), label).
contrastive_transform = ContrastiveTransform(base_transform)

batch_size = 64

train_dataset_contrastive = FootprintPatchDataset(
    PATCH_ROOT / "train",
    transform=contrastive_transform
)

contrastive_loader = DataLoader(
    train_dataset_contrastive,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
    # Drop the final short batch so every batch has exactly batch_size samples.
    drop_last=True
)

# Sanity check: pull one batch and confirm the tensor shapes.
batch_views, labels = next(iter(contrastive_loader))
xi, xj = batch_views
print(xi.shape, xj.shape)  # should be [batch_size, 3, 128, 128]
print(labels.shape)        # [batch_size]

Loaded 12575 images from C:\Users\Kdbro\OneDrive\Desktop\OSU\Fall 2025\Neural Networks\Contrastive-and-Attribute-Aligned-Representation-Learning-for-Animal-Footprint\dataset\footprint_patches\train, 117 classes.
torch.Size([64, 3, 128, 128]) torch.Size([64, 3, 128, 128])
torch.Size([64])


In [5]:
class FootprintEncoder(nn.Module):
    """
    Small four-stage convolutional encoder for footprint images.

    Input:  (B, 3, 128, 128)
    Output: (B, feature_dim)

    NOTE: this notebook defines ``FootprintEncoder`` a second time in the
    following cell; when the cells are run in order, that later definition
    is the one in effect.
    """
    def __init__(self, feature_dim=256):
        super().__init__()
        self.feature_dim = feature_dim

        def conv_stage(c_in, c_out, pool):
            # 3x3 same-padding conv -> batch norm -> ReLU -> pooling layer.
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
                pool,
            )

        # The first three stages halve the resolution: 128 -> 64 -> 32 -> 16.
        self.conv1 = conv_stage(3, 32, nn.MaxPool2d(kernel_size=2))
        self.conv2 = conv_stage(32, 64, nn.MaxPool2d(kernel_size=2))
        self.conv3 = conv_stage(64, 128, nn.MaxPool2d(kernel_size=2))
        # The last stage collapses the spatial grid to 1x1 by global averaging.
        self.conv4 = conv_stage(128, 256, nn.AdaptiveAvgPool2d((1, 1)))

        # Linear map from the 256 pooled channels to the feature vector.
        self.fc = nn.Linear(256, feature_dim)

    def forward(self, x):
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = stage(x)
        pooled = x.view(x.size(0), -1)  # (B, 256, 1, 1) -> (B, 256)
        return self.fc(pooled)


In [6]:
class FootprintEncoder(nn.Module):
    """
    Simple CNN encoder for footprint images.

    Input:  (B, in_channels, 128, 128)
    Output: (B, feature_dim)

    Args:
        feature_dim: size of the output feature vector.
        in_channels: number of channels of the input images. Defaults to 3
            (RGB); pass 1 to encode single-channel (grayscale) patches.

    The shape comments below assume 128x128 inputs. Because the last block
    ends in global average pooling, the output shape does not depend on the
    input resolution.
    """
    def __init__(self, feature_dim=256, in_channels=3):
        super().__init__()
        self.feature_dim = feature_dim
        self.in_channels = in_channels

        # Conv block 1: in_channels -> 32
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2)  # 128 -> 64
        )
        # Conv block 2: 32 -> 64
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2)  # 64 -> 32
        )
        # Conv block 3: 64 -> 128
        self.conv3 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2)  # 32 -> 16
        )
        # Conv block 4: 128 -> 256
        self.conv4 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            # Global average pooling to (B, 256, 1, 1)
            nn.AdaptiveAvgPool2d((1, 1))
        )

        # Final linear layer to get a compact feature vector
        self.fc = nn.Linear(256, feature_dim)

    def forward(self, x):
        x = self.conv1(x)   # (B, 32, 64, 64)
        x = self.conv2(x)   # (B, 64, 32, 32)
        x = self.conv3(x)   # (B, 128, 16, 16)
        x = self.conv4(x)   # (B, 256, 1, 1)
        x = x.view(x.size(0), -1)  # (B, 256)
        x = self.fc(x)      # (B, feature_dim)
        return x


In [7]:
class ProjectionHead(nn.Module):
    """
    Two-layer MLP applied on top of the encoder features.

    Linear(feature_dim -> feature_dim) -> ReLU -> Linear(feature_dim -> projection_dim)

    Input:  (B, feature_dim)
    Output: (B, projection_dim)
    """
    def __init__(self, feature_dim=256, projection_dim=128):
        super().__init__()
        hidden_layer = nn.Linear(feature_dim, feature_dim)
        activation = nn.ReLU(inplace=True)
        output_layer = nn.Linear(feature_dim, projection_dim)
        self.net = nn.Sequential(hidden_layer, activation, output_layer)

    def forward(self, x):
        projected = self.net(x)
        return projected


class ContrastiveModel(nn.Module):
    """
    Encoder followed by a projection head.

    ``forward`` returns both the encoder features and their projection as a
    tuple ``(h, z)``: ``h`` is the ``FootprintEncoder`` output and ``z`` is
    ``ProjectionHead`` applied to ``h``.
    """
    def __init__(self, feature_dim=256, projection_dim=128):
        super().__init__()
        head_kwargs = dict(feature_dim=feature_dim, projection_dim=projection_dim)
        self.encoder = FootprintEncoder(feature_dim=feature_dim)
        self.projection_head = ProjectionHead(**head_kwargs)

    def forward(self, x):
        features = self.encoder(x)
        return features, self.projection_head(features)


In [8]:
def nt_xent_loss(z_i, z_j, temperature=0.5):
    """
    SimCLR-style NT-Xent (normalized temperature-scaled cross entropy) loss.

    Args:
        z_i: (N, d) projections of the first view.
        z_j: (N, d) projections of the second view; row k pairs with z_i[k].
        temperature: divisor applied to the cosine similarities.

    Returns:
        Scalar tensor: the mean over all 2N rows of
        ``logsumexp(similarities to every other row) - similarity to the partner``.
    """
    n = z_i.size(0)
    total = 2 * n
    dev = z_i.device

    # L2-normalise each view so the dot products below are cosine similarities,
    # then stack them: rows [0, n) are view i, rows [n, 2n) are view j.
    embeddings = torch.cat(
        [F.normalize(z_i, dim=1), F.normalize(z_j, dim=1)],
        dim=0,
    )  # (2N, d)

    logits = torch.matmul(embeddings, embeddings.T) / temperature

    # A row must never be compared with itself: set the diagonal to -inf so
    # it contributes nothing to the logsumexp.
    self_mask = torch.eye(total, dtype=torch.bool, device=dev)
    logits = logits.masked_fill(self_mask, float('-inf'))

    # The partner of row r is row r + n (wrapping around), i.e. the same
    # sample in the other view.
    rows = torch.arange(total, device=dev)
    partner = (rows + n) % total

    pos_logits = logits[rows, partner]
    denom = torch.logsumexp(logits, dim=1)

    return (denom - pos_logits).mean()


In [13]:
def train_contrastive(model,
                      dataloader,
                      optimizer,
                      device,
                      epochs=50,
                      temperature=0.5,
                      log_every=20):
    """
    Train ``model`` with the NT-Xent contrastive loss.

    Args:
        model: module whose forward returns ``(features, projection)``; only
            the projection is used for the loss.
        dataloader: iterable yielding ``((xi, xj), labels)`` batches, where
            ``xi`` and ``xj`` are the two augmented views. Labels are ignored.
        optimizer: optimizer over the model's parameters.
        device: device the views are moved to before the forward pass.
        epochs: number of passes over ``dataloader``.
        temperature: temperature passed to ``nt_xent_loss``.
        log_every: print the current batch loss every ``log_every`` batches;
            ``0`` or ``None`` disables per-batch logging.

    Returns:
        list[float]: the average loss of each epoch, in order, so the
        training curve can be inspected or plotted afterwards.
    """
    model.train()
    print(">>> Entered train_contrastive")  # DEBUG

    epoch_losses = []

    for epoch in range(epochs):
        total_loss = 0.0
        num_batches = 0
        print(f"Starting epoch {epoch+1}/{epochs}")  # DEBUG

        for batch_idx, (batch_views, _) in enumerate(dataloader):
            xi, xj = batch_views
            xi = xi.to(device)
            xj = xj.to(device)

            optimizer.zero_grad()

            # The two views go through the model in separate forward passes.
            _, zi = model(xi)
            _, zj = model(xj)

            loss = nt_xent_loss(zi, zj, temperature=temperature)
            loss.backward()
            optimizer.step()

            # Read the scalar once; it is reused for the running total and
            # for logging.
            loss_value = loss.item()
            total_loss += loss_value
            num_batches += 1

            # Periodic progress line so long epochs show signs of life.
            if log_every and (batch_idx + 1) % log_every == 0:
                print(f"  Epoch {epoch+1}, batch {batch_idx+1}, "
                      f"loss={loss_value:.4f}")

        # max(1, ...) keeps an empty dataloader from dividing by zero.
        avg_loss = total_loss / max(1, num_batches)
        epoch_losses.append(avg_loss)
        print(f"Epoch [{epoch+1}/{epochs}] - Contrastive loss: {avg_loss:.4f}")

    return epoch_losses


In [15]:
def main():
    """
    Run contrastive pre-training and save the encoder weights.

    Builds a ``ContrastiveModel`` and an Adam optimizer, trains on the
    notebook-level ``contrastive_loader``, and writes only the encoder's
    ``state_dict`` to ``footprint_encoder_contrastive.pth``.
    """
    # Hyper-parameters for this run.
    hparams = {
        "feature_dim": 256,
        "projection_dim": 128,
        "temperature": 0.5,
        "epochs": 50,
    }

    # Prefer the GPU when one is available.
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    model = ContrastiveModel(
        feature_dim=hparams["feature_dim"],
        projection_dim=hparams["projection_dim"],
    )
    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

    train_contrastive(
        model,
        contrastive_loader,
        optimizer,
        device,
        epochs=hparams["epochs"],
        temperature=hparams["temperature"],
    )

    # Only the encoder is saved; the projection head is not written to disk.
    encoder_weights = model.encoder.state_dict()
    torch.save(encoder_weights, "footprint_encoder_contrastive.pth")


main()


>>> Entered train_contrastive
Starting epoch 1/50
  Epoch 1, batch 20, loss=4.0820
  Epoch 1, batch 40, loss=4.0583
  Epoch 1, batch 60, loss=4.1995
  Epoch 1, batch 80, loss=4.1837
  Epoch 1, batch 100, loss=4.0254
  Epoch 1, batch 120, loss=3.9178
  Epoch 1, batch 140, loss=3.9302
  Epoch 1, batch 160, loss=3.7761
  Epoch 1, batch 180, loss=3.8408
Epoch [1/50] - Contrastive loss: 4.0086
Starting epoch 2/50
  Epoch 2, batch 20, loss=3.8157
  Epoch 2, batch 40, loss=3.7793
  Epoch 2, batch 60, loss=3.7698
  Epoch 2, batch 80, loss=3.7840
  Epoch 2, batch 100, loss=3.7207
  Epoch 2, batch 120, loss=3.7616
  Epoch 2, batch 140, loss=3.6072
  Epoch 2, batch 160, loss=3.6127
  Epoch 2, batch 180, loss=3.6929
Epoch [2/50] - Contrastive loss: 3.7394
Starting epoch 3/50
  Epoch 3, batch 20, loss=3.5766
  Epoch 3, batch 40, loss=3.6986
  Epoch 3, batch 60, loss=3.6591
  Epoch 3, batch 80, loss=3.6491
  Epoch 3, batch 100, loss=3.5934
  Epoch 3, batch 120, loss=3.5656
  Epoch 3, batch 140, loss