In [1]:
import os
from pathlib import Path
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms

In [2]:
class FootprintPatchDataset(Dataset):
    """
    Classification-style dataset for cropped footprint patches.
    Assumes directory structure:
      root_dir/
        class_name0/
          *.jpg
        class_name1/
          *.jpg
        ...

    Class indices are assigned 0..K-1 over the *sorted sub-directory* names, so
    stray files in ``root_dir`` (README, .DS_Store, ...) never create gaps in the
    label space and ``idx_to_class[i]`` is always the class whose label is ``i``.
    Image files inside each class directory are sorted as well, so the sample
    order is the same on every filesystem.

    Args:
        root_dir: path to the dataset root.
        transform: optional callable applied to each RGB PIL image; whatever it
            returns is yielded as the sample by ``__getitem__``.
    """
    # Accepted image extensions (compared case-insensitively).
    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")

    def __init__(self, root_dir, transform=None):
        self.root_dir = Path(root_dir)
        self.transform = transform

        self.image_paths = []
        self.labels = []
        self.class_to_idx = {}
        self.idx_to_class = []

        # Filter to directories *before* enumerating so indices stay contiguous.
        class_names = sorted(
            entry for entry in os.listdir(self.root_dir)
            if (self.root_dir / entry).is_dir()
        )

        for class_idx, class_name in enumerate(class_names):
            class_path = self.root_dir / class_name
            self.class_to_idx[class_name] = class_idx
            self.idx_to_class.append(class_name)

            for fname in sorted(os.listdir(class_path)):
                if fname.lower().endswith(self.IMAGE_EXTENSIONS):
                    self.image_paths.append(class_path / fname)
                    self.labels.append(class_idx)

        print(f"Loaded {len(self.image_paths)} images "
              f"from {self.root_dir}, {len(self.idx_to_class)} classes.")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(transform(image), label)`` for sample ``idx``."""
        img_path = self.image_paths[idx]
        label = self.labels[idx]
        # The context manager closes the file handle promptly; convert() loads
        # the pixel data into a new image before the file is closed.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        if self.transform is not None:
            image = self.transform(image)

        return image, label


class ContrastiveTransform:
    """
    Produce a positive pair for contrastive learning.

    The wrapped ``base_transform`` is invoked twice on the same input; when it is
    stochastic the two results are two independently augmented views.

    Args:
        base_transform: callable applied to each input.

    ``__call__`` returns a 2-tuple ``(view_1, view_2)`` in invocation order.
    """
    def __init__(self, base_transform):
        self.base_transform = base_transform

    def __call__(self, x):
        # Generator is consumed left to right, so view_1 is produced first.
        return tuple(self.base_transform(x) for _ in range(2))


In [3]:
# Root of the cropped footprint-patch dataset; training reads PATCH_ROOT / "train".
# Set the FOOTPRINT_PATCH_ROOT environment variable to run on another machine;
# when it is unset the original cluster location is used.
PATCH_ROOT = Path(
    os.environ.get(
        "FOOTPRINT_PATCH_ROOT",
        "/users/PAS2985/tingle9/dataset/footprint_patches",
    )
)

In [4]:
class BackboneEncoder(nn.Module):
    """
    Feature extractor built from a torchvision backbone with its classifier removed.

    Supported ``name`` values and the width of the flattened feature ``h``
    returned by ``forward`` (also stored in ``self.out_dim``):

    * ``"resnet50"`` - every child except the final ``fc``; 2048-d.
    * ``"vgg16"``    - conv ``features`` + adaptive average pool to 7x7;
      512 * 7 * 7 = 25088-d.
    * ``"vit_b_16"`` - the full ViT with ``heads`` replaced by ``nn.Identity``; 768-d.

    Args:
        name: backbone identifier (see above).
        pretrained: load torchvision's ``DEFAULT`` weights when True.

    Raises:
        ValueError: if ``name`` is not one of the supported backbones.
    """
    def __init__(self, name="resnet50", pretrained=True):
        super().__init__()
        self.name = name

        builders = {
            "resnet50": self._build_resnet50,
            "vgg16": self._build_vgg16,
            "vit_b_16": self._build_vit_b_16,
        }
        build = builders.get(name)
        if build is None:
            raise ValueError(f"Unknown backbone name: {name}")
        build(pretrained)

    def _build_resnet50(self, pretrained):
        # Keep every child up to the global pool; drop the trailing fc classifier.
        weights = models.ResNet50_Weights.DEFAULT if pretrained else None
        net = models.resnet50(weights=weights)
        self.backbone = nn.Sequential(*list(net.children())[:-1])
        self.out_dim = 2048

    def _build_vgg16(self, pretrained):
        weights = models.VGG16_Weights.DEFAULT if pretrained else None
        net = models.vgg16(weights=weights)
        self.backbone = net.features
        # Same 7x7 grid the VGG classifier consumes (512 * 7 * 7 inputs).
        self.pool = nn.AdaptiveAvgPool2d((7, 7))
        self.out_dim = 512 * 7 * 7

    def _build_vit_b_16(self, pretrained):
        weights = models.ViT_B_16_Weights.DEFAULT if pretrained else None
        net = models.vit_b_16(weights=weights)
        # Identity head: forward yields the 768-d representation, not class logits.
        net.heads = nn.Identity()
        self.backbone = net
        self.out_dim = 768

    def forward(self, x):
        """Return the flattened feature ``h`` of shape ``(B, self.out_dim)``."""
        if self.name == "vit_b_16":
            return self.backbone(x)                            # (B, 768)
        if self.name == "vgg16":
            return self.pool(self.backbone(x)).flatten(1)      # (B, 512, 7, 7) -> (B, 25088)
        if self.name == "resnet50":
            return self.backbone(x).flatten(1)                 # (B, 2048, 1, 1) -> (B, 2048)
        return None

In [5]:
class ProjectionHead(nn.Module):
    """
    SimCLR-style MLP head (Linear -> ReLU -> Linear) mapping encoder features
    ``h`` to the embedding ``z`` used by the contrastive loss.

    Args:
        in_dim: width of the input feature vector.
        proj_dim: width of the output embedding.
        hidden_dim: width of the hidden layer. When ``None`` (default) it is
            ``in_dim`` capped at ``MAX_DEFAULT_HIDDEN`` (2048). Without the cap, a
            25088-d VGG16 feature would create a 25088x25088 hidden layer
            (~629M weights, ~2.5 GB in fp32 before gradients and optimizer state).
            Inputs of 2048-d or less (ResNet-50, ViT-B/16) keep the ``in_dim`` width.
    """
    # Upper bound for the default hidden width.
    MAX_DEFAULT_HIDDEN = 2048

    def __init__(self, in_dim, proj_dim=128, hidden_dim=None):
        super().__init__()
        if hidden_dim is None:
            hidden_dim = min(in_dim, self.MAX_DEFAULT_HIDDEN)
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, proj_dim)
        )

    def forward(self, x):
        """Project ``x`` of shape ``(B, in_dim)`` to ``(B, proj_dim)``."""
        return self.net(x)
    
class ContrastiveModel(nn.Module):
    """
    Backbone encoder followed by a projection head, for contrastive pre-training.

    ``forward`` returns ``(h, z)``: ``h`` is the encoder feature vector and ``z``
    its projection (the input to the contrastive loss in ``train_contrastive``).

    Args:
        backbone: backbone name passed to ``BackboneEncoder``.
        proj_dim: output width of the projection head.
        pretrained: forwarded to ``BackboneEncoder``.
    """
    def __init__(self, backbone, proj_dim=128, pretrained=True):
        super().__init__()
        encoder = BackboneEncoder(backbone, pretrained=pretrained)
        self.encoder = encoder
        self.projection = ProjectionHead(encoder.out_dim, proj_dim)

    def forward(self, x):
        features = self.encoder(x)
        return features, self.projection(features)

In [6]:
def nt_xent_loss(z_i, z_j, temperature=0.5):
    """
    NT-Xent (normalized temperature-scaled cross-entropy) loss, as in SimCLR.

    Row k of ``z_i`` and row k of ``z_j`` form a positive pair. Every other row
    of the stacked 2N batch acts as a negative for that anchor.

    Args:
        z_i: (N, d) embeddings of the first view.
        z_j: (N, d) embeddings of the second view.
        temperature: divisor applied to the cosine similarities.

    Returns:
        Scalar tensor: the mean, over all 2N anchors, of
        ``logsumexp_{m != k} sim[k, m] - sim[k, partner(k)]``.
    """
    n = z_i.size(0)
    dev = z_i.device

    embeddings = torch.cat([F.normalize(z_i, dim=1), F.normalize(z_j, dim=1)], dim=0)
    logits = torch.matmul(embeddings, embeddings.T) / temperature

    # An anchor must never be counted as its own candidate.
    self_mask = torch.eye(2 * n, dtype=torch.bool, device=dev)
    logits = logits.masked_fill(self_mask, float('-inf'))

    # Anchor k in the first half pairs with k + N; anchor N + k pairs with k.
    partner = torch.cat([torch.arange(n, 2 * n, device=dev),
                         torch.arange(0, n, device=dev)])
    positive_logits = logits.gather(1, partner.unsqueeze(1)).squeeze(1)

    return (torch.logsumexp(logits, dim=1) - positive_logits).mean()

In [7]:
def train_contrastive(model, dataloader, optimizer, device, epochs=50, temperature=0.5):
    """
    Run contrastive (SimCLR-style) pre-training and report the mean loss per epoch.

    Each batch from ``dataloader`` must unpack as ``((xi, xj), labels)``, where
    ``xi`` and ``xj`` are two augmented views of the same images; labels are
    ignored. ``model(x)`` must return ``(h, z)``; only ``z`` enters ``nt_xent_loss``.

    Args:
        model: module returning ``(h, z)``.
        dataloader: iterable of ``((xi, xj), labels)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        device: device the views are moved to.
        epochs: number of passes over ``dataloader``.
        temperature: NT-Xent temperature.

    Returns:
        list[float]: mean training loss of each epoch (also printed).

    Raises:
        ValueError: if ``dataloader`` yields no batches in an epoch (e.g. the
            dataset is smaller than ``batch_size`` with ``drop_last=True``),
            instead of failing with ZeroDivisionError when averaging.
    """
    model.train()
    epoch_losses = []

    for epoch in range(epochs):
        total_loss = 0.0
        num_batches = 0

        for (xi, xj), _ in dataloader:
            xi = xi.to(device)
            xj = xj.to(device)

            optimizer.zero_grad()

            # Both views pass through the same network; only projections are used.
            _, zi = model(xi)
            _, zj = model(xj)

            loss = nt_xent_loss(zi, zj, temperature=temperature)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        if num_batches == 0:
            raise ValueError(
                "dataloader yielded no batches; check dataset size, "
                "batch_size and drop_last"
            )

        mean_loss = total_loss / num_batches
        epoch_losses.append(mean_loss)
        print(f"Epoch [{epoch+1}/{epochs}] loss={mean_loss:.4f}")

    return epoch_losses

In [8]:
def _build_contrastive_transform(image_size):
    """
    Build the two-view augmentation pipeline used for contrastive pre-training.

    Args:
        image_size: side length of the square crop fed to the backbone.

    Returns:
        ContrastiveTransform mapping one PIL image to two normalized tensors.
    """
    # Backbones start from torchvision's pretrained (ImageNet) weights,
    # so inputs are normalized with ImageNet statistics.
    imagenet_mean = (0.485, 0.456, 0.406)
    imagenet_std = (0.229, 0.224, 0.225)

    base_transform = transforms.Compose([
        transforms.RandomResizedCrop(image_size, scale=(0.6, 1.0)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomApply(
            [transforms.ColorJitter(brightness=0.4, contrast=0.4,
                                    saturation=0.4, hue=0.1)],
            p=0.8,
        ),
        transforms.RandomGrayscale(p=0.2),
        transforms.GaussianBlur(kernel_size=9, sigma=(0.1, 2.0)),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ])
    return ContrastiveTransform(base_transform)


def main(backbones=("resnet50", "vgg16", "vit_b_16"), epochs=50, batch_size=64,
         proj_dim=128, temperature=0.5, lr=1e-3, weight_decay=1e-4,
         num_workers=8, seed=0):
    """
    Pre-train one contrastive model per backbone and save each encoder's weights.

    For every name in ``backbones`` this builds the augmentation pipeline and a
    DataLoader over ``PATCH_ROOT / "train"``, trains a ``ContrastiveModel`` with
    Adam, and writes ``encoder_<backbone>_contrastive.pth`` (encoder state_dict
    only) to the current working directory.

    Args:
        backbones: backbone names understood by ``BackboneEncoder``.
        epochs, batch_size, proj_dim, temperature, lr, weight_decay:
            training hyper-parameters (defaults match the original run).
        num_workers: DataLoader worker processes; 0 loads in the main process.
        seed: value passed to ``torch.manual_seed`` before each backbone is
            trained, for reproducible runs; ``None`` disables seeding.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    for backbone in backbones:
        print(f"Training contrastive model: {backbone}")
        if seed is not None:
            torch.manual_seed(seed)

        # torchvision's ViT-B/16 requires 224x224 inputs; the CNNs use 128x128 crops.
        image_size = 224 if backbone == "vit_b_16" else 128

        train_dataset = FootprintPatchDataset(
            PATCH_ROOT / "train",
            transform=_build_contrastive_transform(image_size),
        )
        loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
            drop_last=True,
            pin_memory=True,
            # persistent_workers is only valid when worker processes exist.
            persistent_workers=num_workers > 0,
        )

        model = ContrastiveModel(backbone=backbone, proj_dim=proj_dim,
                                 pretrained=True).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=lr,
                                     weight_decay=weight_decay)

        train_contrastive(model, loader, optimizer, device,
                          epochs=epochs, temperature=temperature)

        torch.save(model.encoder.state_dict(), f"encoder_{backbone}_contrastive.pth")

        # Free this backbone's model, optimizer state and loader workers before
        # the next iteration builds a new model; otherwise the old one is still
        # referenced (and resident in GPU memory) while the new one is created.
        del model, optimizer, loader, train_dataset
        if device.type == "cuda":
            torch.cuda.empty_cache()


if __name__ == "__main__":
    main()

Using device: cuda
Training contrastive model: resnet50
Loaded 12575 images from /users/PAS2985/tingle9/dataset/footprint_patches/train, 117 classes.
Epoch [1/50] loss=3.3910
Epoch [2/50] loss=3.1803
Epoch [3/50] loss=3.1376
Epoch [4/50] loss=3.1168
Epoch [5/50] loss=3.1132
Epoch [6/50] loss=3.1095
Epoch [7/50] loss=3.0993
Epoch [8/50] loss=3.0975
Epoch [9/50] loss=3.0958
Epoch [10/50] loss=3.0939
Epoch [11/50] loss=3.0886
Epoch [12/50] loss=3.0891
Epoch [13/50] loss=3.0916
Epoch [14/50] loss=3.0880
Epoch [15/50] loss=3.0842
Epoch [16/50] loss=3.0801
Epoch [17/50] loss=3.0806
Epoch [18/50] loss=3.0792
Epoch [19/50] loss=3.0782
Epoch [20/50] loss=3.0745
Epoch [21/50] loss=3.0708
Epoch [22/50] loss=3.0723
Epoch [23/50] loss=3.0712
Epoch [24/50] loss=3.0686
Epoch [25/50] loss=3.0636
Epoch [26/50] loss=3.0648
Epoch [27/50] loss=3.0615
Epoch [28/50] loss=3.0587
Epoch [29/50] loss=3.0588
Epoch [30/50] loss=3.0562
Epoch [31/50] loss=3.0581
Epoch [32/50] loss=3.0546
Epoch [33/50] loss=3.0532
E