In [56]:
import torch
from torch import nn
from sklearn.datasets import fetch_lfw_pairs
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
import torch.optim as optim
import clip
from PIL import Image
import numpy as np
import os

# Preprocess and Encode Images
def compute_clip_features(dataset, clip_model, preprocess, device):
    clip_features = []
    for image_pair in dataset:
        image1, image2 = image_pair[0], image_pair[1]

        # Preprocess images
        image1 = Image.fromarray(image1.astype('uint8'), 'RGB')
        image2 = Image.fromarray(image2.astype('uint8'), 'RGB')
        preprocessed_image1 = preprocess(image1).unsqueeze(0).to(device)
        preprocessed_image2 = preprocess(image2).unsqueeze(0).to(device)

        # Encode images
        with torch.no_grad():
            features1 = clip_model.encode_image(preprocessed_image1)
            features2 = clip_model.encode_image(preprocessed_image2)

        # Flatten and concatenate features for each image pair
        features1 = features1.view(-1)
        features2 = features2.view(-1)
        concatenated_features = torch.cat((features1, features2))

        clip_features.append(concatenated_features.cpu())

    return clip_features


class LFWDatasetFeatures(Dataset):
    def __init__(self, features, labels):
        self.features = features
        self.labels = labels

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        features = self.features[idx]
        label = self.labels[idx]
        return features, label

class CustomClassifierConcat(nn.Module):
    def __init__(self):
        super(CustomClassifierConcat, self).__init__()

        # Assuming each CLIP feature vector is of size 512, concatenated size is 1024
        self.fc1 = nn.Linear(1024, 2048)  # Input size for concatenated features
        self.bn1 = nn.BatchNorm1d(2048)
        self.fc2 = nn.Linear(2048, 1024)
        self.bn2 = nn.BatchNorm1d(1024)
        self.fc3 = nn.Linear(1024, 1)
        self.relu = nn.ReLU()

    def forward(self, features):
        out = self.fc1(features)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.fc2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.fc3(out)
        return out

# Training and evaluation functions
def train(model, train_loader, criterion, optimizer, device):
    model.train()
    total_loss = 0
    correct = 0
    total_samples = 0

    for features, labels in train_loader:
        features = features.to(device).float()  # Ensure features are float32
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(features)
        loss = criterion(outputs.squeeze(), labels.float())
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        preds = torch.sigmoid(outputs).round()
        correct += (preds.squeeze() == labels).sum().item()
        total_samples += labels.size(0)

    average_loss = total_loss / len(train_loader)
    accuracy = correct / total_samples
    return average_loss, accuracy


def evaluate(model, test_loader, criterion, device):
    model.eval()
    total_loss = 0
    correct = 0
    with torch.no_grad():
        for features, labels in test_loader:
            features, labels = features.to(device).float(), labels.to(device)
            outputs = model(features)
            loss = criterion(outputs.squeeze(), labels.float())
            total_loss += loss.item()
            preds = torch.sigmoid(outputs).round()
            correct += (preds.squeeze() == labels).sum().item()
    accuracy = correct / len(test_loader.dataset)
    return total_loss / len(test_loader), accuracy


In [57]:
from sklearn.model_selection import train_test_split

def create_small_dataset(features, labels, subset_ratio=0.1):
    """
    Create a smaller subset of the dataset for quick prototyping.
    :param features: The original features.
    :param labels: The corresponding labels.
    :param subset_ratio: The fraction of the dataset to use.
    :return: A tuple of (small_features, small_labels).
    """
    small_features, _, small_labels, _ = train_test_split(
        features, labels, test_size=subset_ratio, random_state=42
    )
    return small_features, small_labels


In [58]:
# !rm test_features.pt train_features.pt

In [59]:
# Main execution starts here
device = "mps" if torch.backends.mps.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

lfw_pairs_train, lfw_pairs_test = fetch_lfw_pairs(subset='train', color=True), fetch_lfw_pairs(subset='test', color=True)
X_train_full, y_train_full, X_test_full, y_test_full = lfw_pairs_train.pairs, lfw_pairs_train.target, lfw_pairs_test.pairs, lfw_pairs_test.target
small_train, small_train_y = create_small_dataset(X_train_full, y_train_full)
small_test, small_test_y = create_small_dataset(X_test_full, y_test_full)

# Optionally create a smaller dataset for quick prototyping
use_small_dataset = False  # Set to False to use the full dataset

if use_small_dataset:
    X_train_full, y_train_full, X_test_full, y_test_full = small_train, small_train_y, small_test, small_test_y
    
# Check if precomputed features are already saved
def load_features(filename):
    if os.path.exists(filename):
        return torch.load(filename)
    return None

train_features = load_features('train_features.pt')
test_features = load_features('test_features.pt')

if train_features is None or test_features is None:
    print("Computing features...")
    train_features = compute_clip_features(X_train_full, model, preprocess, device)
    test_features = compute_clip_features(X_test_full, model, preprocess, device)
    torch.save(train_features, 'train_features.pt')
    torch.save(test_features, 'test_features.pt')
else:
    print("Loaded features from disk.")

use_small_dataset = False

if use_small_dataset:
    small_train_features, small_train_labels = create_small_dataset(train_features, y_train_full)
    small_test_features, small_test_labels = create_small_dataset(test_features, y_test_full)

    train_dataset = LFWDatasetFeatures(small_train_features, small_train_labels)
    test_dataset = LFWDatasetFeatures(small_test_features, small_test_labels)
else:
    train_dataset = LFWDatasetFeatures(train_features, y_train_full)
    test_dataset = LFWDatasetFeatures(test_features, y_test_full)

# Proceed with DataLoader, Model, Training, and Evaluation
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

Loaded features from disk.


In [61]:
device = "mps"

custom_model_concat = CustomClassifierConcat().to(device).float()
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(custom_model_concat.parameters(), lr=0.0005)

# Main training loop
num_epochs = 1000
for epoch in range(num_epochs):
    train_loss, train_accuracy = train(custom_model_concat, train_loader, criterion, optimizer, device)
    test_loss, test_accuracy = evaluate(custom_model_concat, test_loader, criterion, device)
    print(f'Epoch {epoch + 1}: Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}, Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}')

Epoch 1: Train Loss: 0.7352, Train Accuracy: 0.50, Test Loss: 0.7113, Test Accuracy: 0.49
Epoch 2: Train Loss: 0.7069, Train Accuracy: 0.49, Test Loss: 0.7391, Test Accuracy: 0.49
Epoch 3: Train Loss: 0.7079, Train Accuracy: 0.51, Test Loss: 0.7303, Test Accuracy: 0.49
Epoch 4: Train Loss: 0.7022, Train Accuracy: 0.52, Test Loss: 0.7047, Test Accuracy: 0.48
Epoch 5: Train Loss: 0.7028, Train Accuracy: 0.51, Test Loss: 0.7016, Test Accuracy: 0.49
Epoch 6: Train Loss: 0.7030, Train Accuracy: 0.51, Test Loss: 0.7326, Test Accuracy: 0.48
Epoch 7: Train Loss: 0.7082, Train Accuracy: 0.50, Test Loss: 0.7235, Test Accuracy: 0.50
Epoch 8: Train Loss: 0.6954, Train Accuracy: 0.52, Test Loss: 0.7194, Test Accuracy: 0.49
Epoch 9: Train Loss: 0.6951, Train Accuracy: 0.49, Test Loss: 0.7471, Test Accuracy: 0.48
Epoch 10: Train Loss: 0.7018, Train Accuracy: 0.51, Test Loss: 0.7055, Test Accuracy: 0.51
Epoch 11: Train Loss: 0.6967, Train Accuracy: 0.50, Test Loss: 0.7021, Test Accuracy: 0.49
Epoch 12

KeyboardInterrupt: 