In [8]:
import os
from pathlib import Path
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models

In [9]:
class FootprintPatchDataset(Dataset):
    """Image-folder dataset of footprint patches laid out as ``root_dir/<class_name>/<image>``.

    Every sub-directory of ``root_dir`` is one class. Class indices are assigned
    contiguously (0..K-1) in sorted order of the directory names, so
    ``idx_to_class[label]`` always recovers the class name and
    ``len(idx_to_class)`` is a valid ``num_classes``. Non-directory entries in
    ``root_dir`` (stray files) are skipped without consuming an index.
    Only ``.jpg`` / ``.jpeg`` / ``.png`` files (case-insensitive) are collected,
    in sorted filename order so the sample order is reproducible.

    Args:
        root_dir: path to a split directory (e.g. ``PATCH_ROOT / "train"``).
        transform: optional callable applied to the RGB ``PIL.Image``.

    Each item is ``(image, label)`` where ``label`` is an ``int`` class index and
    ``image`` is the (optionally transformed) RGB image.
    """

    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")

    def __init__(self, root_dir, transform=None):
        self.root_dir = Path(root_dir)
        self.transform = transform

        self.image_paths = []
        self.labels = []
        self.class_to_idx = {}
        self.idx_to_class = []

        for class_name in sorted(os.listdir(self.root_dir)):
            class_path = self.root_dir / class_name
            if not class_path.is_dir():
                continue

            # Index by the number of classes accepted so far, not by position in
            # the raw listing, so skipped non-directory entries leave no gaps.
            class_idx = len(self.idx_to_class)
            self.class_to_idx[class_name] = class_idx
            self.idx_to_class.append(class_name)

            # Sorted for a deterministic sample order across filesystems/runs.
            for fname in sorted(os.listdir(class_path)):
                if fname.lower().endswith(self.IMAGE_EXTENSIONS):
                    self.image_paths.append(class_path / fname)
                    self.labels.append(class_idx)

        print(f"Loaded {len(self.image_paths)} images "
              f"from {self.root_dir}, {len(self.idx_to_class)} classes.")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        label = self.labels[idx]
        # The context manager closes the underlying file; convert() returns an
        # independent, fully loaded image, so it stays valid after closing.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        if self.transform is not None:
            image = self.transform(image)

        return image, label

In [10]:
class FootprintEncoder(nn.Module):
    """Small four-stage CNN that maps an RGB image batch to a ``feature_dim`` vector.

    Each stage is Conv2d(3x3, stride 1, padding 1) -> BatchNorm2d -> ReLU followed
    by a 2x2 max-pool (stages 1-3) or a global average pool to 1x1 (stage 4);
    channel widths grow 3 -> 32 -> 64 -> 128 -> 256 and a final Linear maps the
    256 pooled channels to ``feature_dim``.

    The submodule names ``conv1``..``conv4`` and ``fc`` are part of the
    state-dict layout used by saved checkpoints, so they must not change.
    """

    @staticmethod
    def _stage(c_in, c_out, pool):
        # Conv -> BN -> ReLU -> pool, indexed 0..3 inside the Sequential.
        return nn.Sequential(
            nn.Conv2d(c_in, c_out, 3, 1, 1),
            nn.BatchNorm2d(c_out),
            nn.ReLU(inplace=True),
            pool,
        )

    def __init__(self, feature_dim=256):
        super().__init__()
        self.conv1 = self._stage(3, 32, nn.MaxPool2d(2))
        self.conv2 = self._stage(32, 64, nn.MaxPool2d(2))
        self.conv3 = self._stage(64, 128, nn.MaxPool2d(2))
        self.conv4 = self._stage(128, 256, nn.AdaptiveAvgPool2d((1, 1)))
        self.fc = nn.Linear(256, feature_dim)

    def forward(self, x):
        """Encode ``x`` (B, 3, H, W) into a (B, feature_dim) feature tensor."""
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = stage(x)
        pooled = x.view(x.shape[0], -1)   # (B, 256, 1, 1) -> (B, 256)
        return self.fc(pooled)

class BackboneEncoder(nn.Module):
    """
    Turn a torchvision classification network into a plain feature extractor.

    ``name`` selects the architecture; ``out_dim`` records the size of the
    per-sample vector returned by ``forward``:

    * ``"resnet50"``: every child module except the final ``fc``, then
      flatten; ``out_dim = 2048``.
    * ``"vgg16"``: the convolutional ``features`` stack, a 7x7 adaptive average
      pool, then flatten; ``out_dim = 512 * 7 * 7`` (the VGG classifier input size).
    * ``"vit_b_16"``: the whole ViT with its classification head replaced by
      ``nn.Identity``; ``out_dim = 768``.

    With ``pretrained=True`` the torchvision ``DEFAULT`` weights are loaded,
    otherwise the network is randomly initialised. The wrapped network is
    stored as ``self.backbone`` (state-dict keys are prefixed ``"backbone."``).

    Raises:
        ValueError: if ``name`` is not one of the supported values.
    """

    def __init__(self, name="resnet50", pretrained=True):
        super().__init__()
        self.name = name

        if name not in ("resnet50", "vgg16", "vit_b_16"):
            raise ValueError(f"Unknown backbone name: {name}")

        if name == "resnet50":
            weights = models.ResNet50_Weights.DEFAULT if pretrained else None
            net = models.resnet50(weights=weights)
            *trunk, _classifier = net.children()   # discard the final fc layer
            self.backbone = nn.Sequential(*trunk)
            self.out_dim = 2048
        elif name == "vgg16":
            weights = models.VGG16_Weights.DEFAULT if pretrained else None
            net = models.vgg16(weights=weights)
            self.backbone = net.features
            self.pool = nn.AdaptiveAvgPool2d((7, 7))
            self.out_dim = 512 * 7 * 7
        else:  # "vit_b_16"
            weights = models.ViT_B_16_Weights.DEFAULT if pretrained else None
            net = models.vit_b_16(weights=weights)
            net.heads = nn.Identity()              # drop the classification head
            self.backbone = net
            self.out_dim = 768

    def forward(self, x):
        """Return a (B, out_dim) feature tensor for the image batch ``x``."""
        if self.name == "vit_b_16":
            return self.backbone(x)                # (B, 768)
        if self.name == "vgg16":
            fmap = self.pool(self.backbone(x))     # (B, 512, 7, 7)
            return fmap.flatten(1)                 # (B, 25088)
        if self.name == "resnet50":
            return self.backbone(x).flatten(1)     # (B, 2048, 1, 1) -> (B, 2048)
        return None

# Projector & Text encoder

class ImageProjector(nn.Module):
    """Two-layer MLP head: Linear(in_dim, hidden_dim) -> ReLU -> Linear(hidden_dim, out_dim).

    Args:
        in_dim: size of the incoming feature vector (e.g. ``encoder.out_dim``).
        out_dim: size of the projected embedding (default 128).
        hidden_dim: width of the hidden layer. ``None`` (default) means
            ``in_dim``, which is the original architecture and keeps existing
            ``rq2_*_image_projector.pth`` checkpoints loadable. Pass a smaller
            value for very wide inputs: with the VGG16 wrapper
            (``in_dim = 512 * 7 * 7 = 25088``) the default hidden layer alone is a
            25088 x 25088 weight matrix (~629M parameters).
    """
    def __init__(self, in_dim, out_dim=128, hidden_dim=None):
        super().__init__()
        if hidden_dim is None:
            hidden_dim = in_dim
        # Layer indices (net.0 / net.2) are part of the saved state-dict keys.
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_dim)
        )

    def forward(self, x):
        """Project ``x`` of shape (..., in_dim) to (..., out_dim)."""
        return self.net(x)

class TextTokenEncoder(nn.Module):
    """
    Learnable lookup table standing in for a text encoder.

    Each class ID is treated as a single "text token" and mapped to its own
    ``out_dim``-dimensional embedding; species names are not used.
    """
    def __init__(self, num_classes, out_dim=128):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=num_classes, embedding_dim=out_dim)

    def forward(self, labels):
        """Look up integer class IDs; output shape is ``labels.shape + (out_dim,)``."""
        table = self.embedding
        return table(labels)

In [11]:
def clip_loss(image_z, text_z, temperature=0.07):
    """Symmetric InfoNCE (CLIP-style) loss for a batch of paired embeddings.

    Row i of ``image_z`` is paired with row i of ``text_z``; every other row in
    the batch is a negative. Both inputs are L2-normalised along dim 1, their
    cosine-similarity matrix is divided by ``temperature``, and the result is the
    mean of the image->text and text->image cross-entropies.

    Args:
        image_z: (B, D) image embeddings.
        text_z: (B, D) text embeddings.
        temperature: softmax temperature (smaller = sharper distribution).

    Returns:
        Scalar loss tensor.
    """
    img_unit = F.normalize(image_z, dim=1)
    txt_unit = F.normalize(text_z, dim=1)
    sim = torch.matmul(img_unit, txt_unit.T) / temperature

    # The matching pair for row i sits on the diagonal, i.e. target class i.
    diag = torch.arange(sim.shape[0], device=sim.device)
    return (F.cross_entropy(sim, diag) + F.cross_entropy(sim.T, diag)) / 2

In [12]:
def load_contrastive_encoder(backbone, path, device):
    """Restore a contrastively pre-trained ``BackboneEncoder`` and freeze it.

    Builds an un-pretrained ``BackboneEncoder(name=backbone)`` on ``device``,
    loads the state dict stored at ``path`` (mapped onto ``device``), switches
    it to eval mode and disables gradients for every parameter.

    Note: ``torch.load`` unpickles its input; only load trusted checkpoints.

    Returns:
        The frozen encoder.
    """
    encoder = BackboneEncoder(name=backbone, pretrained=False)
    encoder = encoder.to(device)
    checkpoint = torch.load(path, map_location=device)
    encoder.load_state_dict(checkpoint)
    encoder.eval()
    encoder.requires_grad_(False)
    print(f"Loaded contrastive encoder : {backbone}")
    return encoder

In [13]:
def train_rq2_alignment(
    encoder, img_proj, txt_enc,
    loader, device,
    epochs=20, lr=1e-3,
    temperature=0.07,
    freeze_encoder=True
):
    """Align image features with class-token embeddings using ``clip_loss``.

    For every batch ``(imgs, labels)`` from ``loader`` the image side is
    ``img_proj(encoder(imgs))`` and the text side is ``txt_enc(labels)``; the
    pair is optimised with ``clip_loss`` at ``temperature`` using Adam
    (``lr``, weight_decay=1e-4).

    Args:
        encoder: image encoder. With ``freeze_encoder=True`` it is put in eval
            mode, its parameters get ``requires_grad=False`` and its forward pass
            runs without building an autograd graph; otherwise it is fine-tuned
            in train mode together with the heads.
        img_proj, txt_enc: heads that are always trained (set to train mode).
        loader: iterable of ``(imgs, labels)`` batches.
        device: device the modules and batches are moved to.
        epochs, lr, temperature: training hyper-parameters.
        freeze_encoder: keep the encoder fixed (default) or fine-tune it.

    Returns:
        list[float]: mean training loss of each epoch (the same values that
        are printed).

    Raises:
        ValueError: if an epoch yields no batches (e.g. a dataset smaller than
            ``batch_size`` with ``drop_last=True``), instead of failing with a
            ZeroDivisionError while averaging.
    """
    encoder.to(device)
    img_proj.to(device)
    txt_enc.to(device)
    img_proj.train()
    txt_enc.train()

    if freeze_encoder:
        encoder.eval()
        encoder.requires_grad_(False)
        params = list(img_proj.parameters()) + list(txt_enc.parameters())
    else:
        encoder.train()
        encoder.requires_grad_(True)
        params = (list(encoder.parameters())
                  + list(img_proj.parameters())
                  + list(txt_enc.parameters()))

    opt = torch.optim.Adam(params, lr=lr, weight_decay=1e-4)
    history = []

    for ep in range(epochs):
        total = 0.0
        n_batches = 0
        for imgs, labels in loader:
            imgs = imgs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            # A frozen encoder needs no autograd graph for its forward pass.
            with torch.set_grad_enabled(not freeze_encoder):
                feats = encoder(imgs)      # (B, encoder output dim)
            image_z = img_proj(feats)      # (B, proj dim)
            text_z = txt_enc(labels)       # (B, proj dim)

            loss = clip_loss(image_z, text_z, temperature)

            opt.zero_grad()
            loss.backward()
            opt.step()

            total += loss.item()
            n_batches += 1

        if n_batches == 0:
            raise ValueError(
                "train_rq2_alignment: loader yielded no batches "
                "(dataset smaller than batch_size with drop_last=True?)"
            )
        mean_loss = total / n_batches
        history.append(mean_loss)
        print(f"[RQ2 Align] epoch {ep+1}/{epochs}, loss={mean_loss:.4f}")

    return history

In [14]:
# RQ2 driver: align each contrastively pre-trained image encoder with learnable
# class-ID "text" tokens, then save the trained heads for every encoder.

# --- configuration -----------------------------------------------------------
# Dataset root; can be overridden via FOOTPRINT_PATCH_ROOT instead of editing here.
PATCH_ROOT = Path(os.environ.get(
    "FOOTPRINT_PATCH_ROOT", "/users/PAS2985/tingle9/dataset/footprint_patches"
))

proj_dim = 128
batch_size = 64
RQ2_EPOCHS = 50
RQ2_LR = 1e-3
NUM_WORKERS = 8
SEED = 0

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD  = (0.229, 0.224, 0.225)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch so augmentation draws, shuffling and head initialisation are
# reproducible across kernel restarts (up to nondeterministic GPU kernels).
torch.manual_seed(SEED)


def make_rq2_transform(image_size, normalize):
    """RQ2 training augmentation at ``image_size``; ``normalize`` appends
    ImageNet mean/std normalisation (used for the torchvision backbones)."""
    steps = [
        transforms.RandomResizedCrop(image_size, scale=(0.6, 1.0)),
        transforms.RandomHorizontalFlip(0.5),
        transforms.ColorJitter(0.2, 0.2, 0.2, 0.05),
        transforms.RandomGrayscale(p=0.1),
        transforms.ToTensor(),
    ]
    if normalize:
        steps.append(transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD))
    return transforms.Compose(steps)


def make_rq2_loader(dataset):
    """Shuffled, drop-last training DataLoader shared by every RQ2 run."""
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=NUM_WORKERS,
        drop_last=True,
        pin_memory=True,
        # persistent_workers is only valid when worker processes are used.
        persistent_workers=NUM_WORKERS > 0,
    )


def run_rq2_alignment(tag, encoder, in_dim, loader, n_classes):
    """Train fresh RQ2 heads on top of a frozen ``encoder`` and save them as
    ``rq2_{tag}_image_projector.pth`` / ``rq2_{tag}_text_token_encoder.pth``.

    Returns the trained ``(img_proj, txt_enc)`` pair."""
    proj = ImageProjector(in_dim=in_dim, out_dim=proj_dim)
    tokens = TextTokenEncoder(n_classes, proj_dim)
    train_rq2_alignment(
        encoder, proj, tokens,
        loader, device,
        epochs=RQ2_EPOCHS, lr=RQ2_LR,
        freeze_encoder=True
    )
    torch.save(proj.state_dict(), f"rq2_{tag}_image_projector.pth")
    torch.save(tokens.state_dict(), f"rq2_{tag}_text_token_encoder.pth")
    return proj, tokens


# --- small CNN encoder ---------------------------------------------------------
print("\n RQ2 : small CNN encoder ")
image_size = 128
rq2_train_transform = make_rq2_transform(image_size, normalize=False)
train_ds_cnn = FootprintPatchDataset(PATCH_ROOT / "train", transform=rq2_train_transform)
num_classes = len(train_ds_cnn.idx_to_class)
train_loader = make_rq2_loader(train_ds_cnn)

# 1) Load the contrastively pre-trained CNN encoder
CONTRASTIVE_CKPT = "footprint_encoder_contrastive.pth"
encoder = FootprintEncoder(feature_dim=256)
encoder.load_state_dict(torch.load(CONTRASTIVE_CKPT, map_location="cpu"))

# 2) RQ2 alignment (encoder frozen) + save
img_proj, txt_enc = run_rq2_alignment("CNN", encoder, 256, train_loader, num_classes)
print("saved RQ2 CNN checkpoints. ")

# --- torchvision backbones -------------------------------------------------------
backbones = ["resnet50", "vgg16", "vit_b_16"]

for backbone in backbones:
    print(f"\n RQ2: {backbone} encoder ")
    CONTRASTIVE_CKPT = f"encoder_{backbone}_contrastive.pth"

    # ViT-B/16 runs at 224x224; the convolutional backbones use 128x128 patches.
    image_size = 224 if backbone == "vit_b_16" else 128
    rq2_train_transform = make_rq2_transform(image_size, normalize=True)
    train_ds = FootprintPatchDataset(PATCH_ROOT / "train", transform=rq2_train_transform)
    # The token table is sized from the CNN dataset; label indices must agree.
    assert train_ds.idx_to_class == train_ds_cnn.idx_to_class, \
        f"class set for {backbone} differs from the CNN dataset"
    train_loader = make_rq2_loader(train_ds)

    encoder = load_contrastive_encoder(backbone, CONTRASTIVE_CKPT, device)
    img_proj, txt_enc = run_rq2_alignment(
        backbone, encoder, encoder.out_dim, train_loader, num_classes
    )
    print(f"Saved RQ2 checkpoints for {backbone}.")


 RQ2 : small CNN encoder 
Loaded 12575 images from /users/PAS2985/tingle9/dataset/footprint_patches/train, 117 classes.


[RQ2 Align] epoch 1/50, loss=3.6972
[RQ2 Align] epoch 2/50, loss=3.3970
[RQ2 Align] epoch 3/50, loss=3.2662
[RQ2 Align] epoch 4/50, loss=3.1771
[RQ2 Align] epoch 5/50, loss=3.1082
[RQ2 Align] epoch 6/50, loss=3.0436
[RQ2 Align] epoch 7/50, loss=2.9998
[RQ2 Align] epoch 8/50, loss=2.9606
[RQ2 Align] epoch 9/50, loss=2.9203
[RQ2 Align] epoch 10/50, loss=2.8823
[RQ2 Align] epoch 11/50, loss=2.8459
[RQ2 Align] epoch 12/50, loss=2.8180
[RQ2 Align] epoch 13/50, loss=2.7963
[RQ2 Align] epoch 14/50, loss=2.7623
[RQ2 Align] epoch 15/50, loss=2.7426
[RQ2 Align] epoch 16/50, loss=2.7140
[RQ2 Align] epoch 17/50, loss=2.6981
[RQ2 Align] epoch 18/50, loss=2.6710
[RQ2 Align] epoch 19/50, loss=2.6549
[RQ2 Align] epoch 20/50, loss=2.6441
[RQ2 Align] epoch 21/50, loss=2.6268
[RQ2 Align] epoch 22/50, loss=2.6039
[RQ2 Align] epoch 23/50, loss=2.5867
[RQ2 Align] epoch 24/50, loss=2.5720
[RQ2 Align] epoch 25/50, loss=2.5532
[RQ2 Align] epoch 26/50, loss=2.5501
[RQ2 Align] epoch 27/50, loss=2.5299
[RQ2 Align