In [34]:
import torch
import torch.nn as nn
from torchvision import models, datasets, transforms
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.utils.data import Dataset
from PIL import Image
import os

In [35]:
train_dir = '../../data/training'
gallery_dir = '../../data/test/gallery'
query_dir = '../../data/test/query'

In [32]:
class ImageOnlyDataset(Dataset):
    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        self.image_paths = [os.path.join(image_dir, fname) for fname in os.listdir(image_dir)
                            if fname.lower().endswith(('.png', '.jpg', '.jpeg'))]
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, img_path  # Return path so you can match later


In [None]:
# import timm
# model = timm.create_model("efficientnet_b0", pretrained=True)

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
# Load the pretrained ResNet-18 model
model = models.resnet18(pretrained=True)



Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /Users/anthony/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


100%|██████████| 44.7M/44.7M [00:05<00:00, 7.97MB/s]


In [7]:
# Modify the final fully connected layer for your dataset
num_ftrs = model.fc.in_features  # Get the input features of the last layer
model.fc = nn.Linear(num_ftrs, 2)  # For binary classification (change '2' to the number of classes in your dataset)

In [8]:
# If you want to fine-tune only the final layer, freeze the rest of the model:
for param in model.parameters():
    param.requires_grad = False
# Only the last fully connected layer will be fine-tuned
for param in model.fc.parameters():
    param.requires_grad = True
# Move the model to the GPU (if available)
device = torch.device("mps" if torch.mps.is_available() else "cpu")
model = model.to(device)

# Data preparation

In [9]:
# Define the transformations: resize, normalize, and convert to tensor
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # ResNet expects 224x224 input size
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Pretrained ImageNet normalization
])

In [21]:
# Load the dataset using ImageFolder
train_dataset = datasets.ImageFolder(root = train_dir, transform=transform)
# val_dataset = datasets.ImageFolder(root='/path/to/data/val', transform=transform)

In [36]:
from torchvision import transforms
from torch.utils.data import DataLoader

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # normalize if you did during training
])

gallery_dataset = ImageOnlyDataset(gallery_dir, transform=transform)
query_dataset = ImageOnlyDataset(query_dir, transform=transform)

gallery_loader = DataLoader(gallery_dataset, batch_size=32, shuffle=False)
query_loader = DataLoader(query_dataset, batch_size=32, shuffle=False)


In [11]:
# Create DataLoader for batching
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
# val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# Model

In [16]:

class ResNetEmbedding(nn.Module):
    def __init__(self, resnet_model, embedding_dim):
        super(ResNetEmbedding, self).__init__()
        self.resnet = resnet_model

        # Get the number of input features to the last fully connected layer
        in_features = self.resnet.fc.in_features

        # Replace the final fully connected layer with an identity layer
        self.resnet.fc = nn.Identity()

        # Define a new fully connected layer for embedding
        self.fc = nn.Linear(in_features, embedding_dim)

    def forward(self, x):
        x = self.resnet(x)  # Forward pass through ResNet backbone (up to before the classification layer)
        x = self.fc(x)      # Pass through the embedding layer
        return x

# Define the embedding dimension (e.g., 128)
embedding_dim = 128

# Load ResNet18 with pretrained weights
model = models.resnet18(weights="IMAGENET1K_V1")

# Create the custom model
model = ResNetEmbedding(model, embedding_dim)

# Move to device (GPU or CPU)
device = torch.device("mps" if torch.mps.is_available() else "cpu")
model = model.to(device)

In [25]:
import torch
from torch.utils.data import Dataset
import random
from torchvision import datasets, transforms
from PIL import Image

class TripletDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir (string): Directory with all the class folders.
            transform (callable, optional): Optional transform to be applied on an image.
        """
        self.dataset = datasets.ImageFolder(root_dir, transform=transform)
        self.transform = transform
        self.class_to_idx = self.dataset.class_to_idx
        self.imgs = self.dataset.imgs  # List of (image_path, class_index)
        self.class_indices = {class_name: [] for class_name in self.class_to_idx.keys()}

        for idx, (img_path, class_idx) in enumerate(self.imgs):
            class_name = list(self.class_to_idx.keys())[list(self.class_to_idx.values()).index(class_idx)]
            self.class_indices[class_name].append(idx)

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        anchor_img_path, anchor_label = self.imgs[idx]
        anchor_image = Image.open(anchor_img_path)

        # Apply the transform (e.g., resizing, normalization)
        if self.transform:
            anchor_image = self.transform(anchor_image)

        # Positive: A random image from the same class
        positive_idx = random.choice(self.class_indices[list(self.class_to_idx.keys())[anchor_label]])
        positive_img_path, positive_label = self.imgs[positive_idx]
        positive_image = Image.open(positive_img_path)

        if self.transform:
            positive_image = self.transform(positive_image)

        # Negative: A random image from a different class
        negative_class = random.choice(list(self.class_to_idx.keys()))
        while negative_class == list(self.class_to_idx.keys())[anchor_label]:  # Ensure it's not the same class
            negative_class = random.choice(list(self.class_to_idx.keys()))

        negative_idx = random.choice(self.class_indices[negative_class])
        negative_img_path, negative_label = self.imgs[negative_idx]
        negative_image = Image.open(negative_img_path)

        if self.transform:
            negative_image = self.transform(negative_image)

        # Return the triplet
        return anchor_image, positive_image, negative_image


In [30]:
from torch.utils.data import DataLoader
from torchvision import transforms

# Define the transformation (resize, normalization, etc.)
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Adjust to your image size
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Pretrained weights normalization
])

# Create the triplet dataset and DataLoader
train_dataset = TripletDataset(root_dir=train_dir, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0)

# Now you can use this `train_loader` in your training loop


In [31]:
import torch.optim as optim
import torch.nn.functional as F

# Define the TripletMarginLoss (you can adjust the margin parameter)
triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)

# Set up the optimizer
optimizer = optim.Adam(model.parameters(), lr=0.0001)

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for anchor, positive, negative in train_loader:
        # Move the data to the GPU (if available)
        anchor, positive, negative = anchor.to(device), positive.to(device), negative.to(device)

        # Zero the gradients
        optimizer.zero_grad()

        # Get embeddings
        anchor_emb = model(anchor)
        positive_emb = model(positive)
        negative_emb = model(negative)

        # Compute the triplet loss
        loss = triplet_loss(anchor_emb, positive_emb, negative_emb)
        
        # Backpropagate and optimize
        loss.backward()
        optimizer.step()

        # Track the loss
        running_loss += loss.item()

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss / len(train_loader)}")


Epoch 1/10, Loss: 0.22736740112304688
Epoch 2/10, Loss: 0.1052762046456337
Epoch 3/10, Loss: 0.018141746520996094
Epoch 4/10, Loss: 0.03136635571718216
Epoch 5/10, Loss: 0.0
Epoch 6/10, Loss: 0.7395499497652054
Epoch 7/10, Loss: 0.12428782880306244
Epoch 8/10, Loss: 0.14294615387916565
Epoch 9/10, Loss: 0.007796823978424072
Epoch 10/10, Loss: 0.14474797248840332


In [37]:
model.eval()
with torch.no_grad():
    gallery_embeddings = []
    query_embeddings = []
    gallery_paths = []
    query_paths = []

    # Extract gallery embeddings
    for images, paths in gallery_loader:
        images = images.to(device)
        emb = model(images)
        gallery_embeddings.append(emb.cpu().numpy())
        gallery_paths.extend(paths)

    # Extract query embeddings
    for images, paths in query_loader:
        images = images.to(device)
        emb = model(images)
        query_embeddings.append(emb.cpu().numpy())
        query_paths.extend(paths)

In [38]:
# Convert to numpy arrays
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

gallery_embeddings = np.vstack(gallery_embeddings)
query_embeddings = np.vstack(query_embeddings)

# Compute cosine similarity
similarity_matrix = cosine_similarity(query_embeddings, gallery_embeddings)

# For each query image, get the most similar gallery image
retrieved_indices = np.argmax(similarity_matrix, axis=1)

# Print results
for i, idx in enumerate(retrieved_indices):
    print(f"Query image: {query_paths[i]}")
    print(f"Retrieved gallery image: {gallery_paths[idx]}")
    print()

Query image: ../../data/test/query/4597118805213184.jpg
Retrieved gallery image: ../../data/test/gallery/painting_085_000118.jpg

Query image: ../../data/test/query/n01855672_10973.jpg
Retrieved gallery image: ../../data/test/gallery/n01855672_4393.jpg

