In [None]:
# # Install clip
! pip install torch torchvision -q
! pip install ftfy regex tqdm -q
! pip install git+https://github.com/openai/CLIP.git -q

In [None]:
# Install w&b for visualization
! pip install wandb -q

In [None]:
# Hugging Face's datasets for the stanford dogs ds
! pip install datasets -q

In [36]:
import os
from typing import Tuple, List, Dict
import clip
import torch
import torch.optim as optim
from torch.utils.data import random_split
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import ImageFolder
from tqdm import tqdm
import random
from datasets import load_dataset
import wandb
from datetime import datetime
import gzip
import numpy as np

In [37]:
# Variables
DATA_ROOT = "data/"
# Folder of verified dog matches; read by `ImageFolder` inside DogDataset
# (one sub-folder per dog).
VERIFIED_DOG_DATA = os.path.join(DATA_ROOT, "certain_matches")
# Backward-compatible alias for the original, misspelled name used by later cells.
VERIVIED_DOG_DATA = VERIFIED_DOG_DATA

In [38]:
# Load the Stanford Dogs dataset; it only supplies negative examples.
# No `.shuffle()`: DogDataset draws negatives with `random.sample` over row indices,
# so shuffling adds no randomness, only an indices mapping that slows down row access.
STANFORD_DOGS = load_dataset("Alanox/stanford-dogs")

In [39]:
# MODEL_TYPE = 'ViT-L/14@336px'
# ml.r5.24xlarge

In [40]:
# Backbone choice. Larger alternative: 'ViT-L/14@336px'
MODEL_TYPE = "ViT-B/32"

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Pretrained CLIP plus the image preprocessing that matches it.
model, preprocess = clip.load(MODEL_TYPE, device=device)
# Fine-tune in fp32. NOTE(review): presumably CLIP returns reduced-precision weights
# on GPU, which is why this cast exists — confirm.
model = model.float()

print("Using device=", device)

Using device= cuda


In [41]:
# Parameters that stay frozen during fine-tuning, named exactly as reported by
# `model.named_parameters()`. Set to an empty list to train everything.
params_to_freeze = [
    "positional_embedding",
    "text_projection",
    "visual.class_embedding",
    "visual.positional_embedding",
    "visual.proj",
    "visual.conv1.weight",
    "visual.ln_pre.weight",
    "visual.ln_pre.bias",
]

In [42]:
# `None` means the loss logits are not temperature-scaled.
temperature = None

# Freeze the selected parameters. Names that match nothing are reported, because a
# typo in `params_to_freeze` would otherwise silently leave that layer trainable.
frozen_param_names = set()
for name, param in model.named_parameters():
    if name in params_to_freeze:
        print(f"Freezing {name}")
        param.requires_grad = False
        frozen_param_names.add(name)
    # Alternative (disabled): learn the temperature from CLIP's `logit_scale`.
    # NOTE(review): `logit_scale` is presumably a log *inverse* temperature, so it
    # should not be used as a divisor directly — confirm before enabling.
    # if name == "logit_scale":
    #     temperature = param

unmatched_param_names = [n for n in params_to_freeze if n not in frozen_param_names]
if unmatched_param_names:
    print(f"WARNING: no model parameter matched {unmatched_param_names}")

Freezing positional_embedding
Freezing text_projection
Freezing visual.class_embedding
Freezing visual.positional_embedding
Freezing visual.proj
Freezing visual.conv1.weight
Freezing visual.ln_pre.weight
Freezing visual.ln_pre.bias


In [43]:
# Show the temperature that the loss helpers receive (`None` = no temperature scaling).
temperature

In [44]:
class DogDataset(Dataset):
    """Pairs of images of the same dog, plus randomly drawn negative dogs.

    Positives come from an ``ImageFolder`` at ``root`` (one sub-folder per dog). Only
    folders holding at least two images are used. Negatives are sampled from the
    ``"full"`` split of ``negative_ds``, which has no matches of its own.

    Args:
        root: Directory laid out for ``torchvision.datasets.ImageFolder``.
        negative_ds: Mapping whose ``"full"`` entry exposes ``num_rows`` and rows with
            an ``"image"`` field.
        transform: Transform applied to both positive and negative images.
        num_negatives: Number of negatives drawn for *each* of the two positives.
    """

    def __init__(self, root, negative_ds, transform, num_negatives=5):
        self.dataset = ImageFolder(root, transform=transform)
        self.num_negatives = num_negatives
        self.negative_ds = negative_ds
        self.dog_indices = self.get_dog_indices()
        self.transform = transform

    def __getitem__(
        self, index
    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
        """Return ``(positive_1, positive_2, negatives_1, negatives_2)`` for one dog.

        ``index`` selects a dog (label), not an image, so every dog is equally likely
        no matter how many photos it has. ``positive_1`` and ``positive_2`` are two
        different images of that dog, chosen at random. ``negatives_1`` and
        ``negatives_2`` each hold ``num_negatives`` transformed images drawn from
        ``negative_ds``, which serve only as negatives. Sampling is random, so repeated
        calls with the same index return different samples.
        """
        labels = list(self.dog_indices)
        dog_label = labels[index % len(labels)]
        pos_1_idx, pos_2_idx = random.sample(self.dog_indices[dog_label], 2)
        positive_image_1, _ = self.dataset[pos_1_idx]
        positive_image_2, _ = self.dataset[pos_2_idx]
        negative_images_1, negative_images_2 = self.get_negative_images()
        assert len(negative_images_1) == len(negative_images_2)
        return positive_image_1, positive_image_2, negative_images_1, negative_images_2

    def get_dog_indices(self) -> Dict[int, List[int]]:
        """Group image indices by class label, keeping only labels with 2+ images.

        Grouping by label lets ``__getitem__`` sample dogs uniformly rather than images.
        Single-image dogs are dropped because a positive pair needs two images.
        """
        dog_indices: Dict[int, List[int]] = {}
        for idx, (_, label) in enumerate(self.dataset.imgs):
            dog_indices.setdefault(label, []).append(idx)
        return {
            label: indices for label, indices in dog_indices.items() if len(indices) > 1
        }

    def get_negative_images(self) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        """Draw ``2 * num_negatives`` distinct negative images and split them in half.

        Returns:
            Two lists of ``num_negatives`` transformed images each, one per positive.

        Raises:
            ValueError: (from ``random.sample``) if the negative split has fewer than
                ``2 * num_negatives`` rows.
        """
        negative_split = self.negative_ds["full"]
        indices = random.sample(range(negative_split.num_rows), self.num_negatives * 2)
        images = [self.transform(negative_split[idx]["image"]) for idx in indices]
        return images[: self.num_negatives], images[self.num_negatives :]

    def __len__(self):
        """Number of usable dogs (labels with at least two images)."""
        return len(self.dog_indices)


In [59]:
# Negatives sampled per positive image; each example therefore carries 2 * NUM_NEGATIVES negatives.
NUM_NEGATIVES = 10

In [60]:
# Verified positive pairs from disk, with Stanford Dogs images as negatives.
dog_dataset = DogDataset(
    root=VERIVIED_DOG_DATA,
    transform=preprocess,
    negative_ds=STANFORD_DOGS,
    num_negatives=NUM_NEGATIVES,
)

In [61]:
# Train - test split. Each sample of `dog_dataset` is one dog, so the split is by dog.

BATCH_SIZE = 16
# Fixed seed so the train/test partition is the same on every run of the notebook.
SPLIT_SEED = 42

train_proportion = 0.8
test_proportion = 1 - train_proportion

num_samples = len(dog_dataset)
num_train_samples = int(train_proportion * num_samples)
num_test_samples = num_samples - num_train_samples
# Both splits must be non-empty: later cells average losses over len(loader).
assert num_train_samples > 0 and num_test_samples > 0, (
    f"Need at least one train and one test dog, got {num_samples} usable dogs"
)

train_set, test_set = random_split(
    dog_dataset,
    [num_train_samples, num_test_samples],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False)


In [62]:
# Create a new, timestamped run directory for the test set & checkpoints.

NEW_RUN_NAME = f"RUN_{datetime.now().strftime('%d_%b_%y_%H-%M-%S')}"
NEW_RUN_DIR = os.path.join("artifacts", NEW_RUN_NAME)
PATH_TO_CHECKPOINTS = os.path.join(NEW_RUN_DIR, "checkpoints")

# makedirs also creates missing parents (`artifacts/` on a fresh checkout, and the run
# dir itself). It still raises FileExistsError if PATH_TO_CHECKPOINTS already exists.
os.makedirs(PATH_TO_CHECKPOINTS)

In [63]:
# Persist the held-out split so performance can later be evaluated on the test set only.
# NOTE(review): this pickles the whole `Subset`, including the parent DogDataset and
# the negative dataset it references; loading it needs the same class definitions.
test_set_path = os.path.join(NEW_RUN_DIR, 'test_data.pth.gz')
with gzip.open(test_set_path, mode='wb') as compressed_out:
    torch.save(test_set, compressed_out)


In [64]:
# Set training params
learning_rate = 1e-6  # alternative value kept from earlier: 5e-6

# Give Adam only the parameters that can receive gradients. The frozen ones
# (requires_grad=False) would otherwise just sit idle in its param groups.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=learning_rate)
num_epochs = 5

In [None]:
# Init w&b: record this run's hyper-parameters next to its logged metrics.
wandb_run_config = {
    "batch_size": BATCH_SIZE,
    "learning_rate": learning_rate,
    "num_negatives": dog_dataset.num_negatives,
    "temperature": temperature,
    "model_type": MODEL_TYPE,
    "run_name": NEW_RUN_NAME,
    "frozen_params": params_to_freeze,
}
wandb.init(project="dogfinder", config=wandb_run_config, save_code=True)

In [66]:
# Clear cache: return unoccupied memory held by PyTorch's CUDA caching allocator
# before training starts (only relevant when running on the GPU).
if device == "cuda":
    torch.cuda.empty_cache()

In [67]:
# Helper funcs
def normalize_features(features):
    """Scale every vector along the last dimension to unit L2 norm.

    No epsilon is added, so an all-zero vector produces NaNs.
    """
    norms = features.norm(dim=-1, keepdim=True)
    return features / norms

def normalized_features_to_logits(anchor_features, positive_features, negatives_features, temperature=None):
    """Contrastive logits with the positive pair in column 0.

    Args:
        anchor_features: Unit-normalised anchor embeddings, shaped ``(..., D)``.
        positive_features: Embeddings matching the anchors, same shape.
        negatives_features: Sequence of K tensors, each shaped like the anchors.
        temperature: Optional divisor. Any falsy value (``None`` or ``0``) leaves the
            logits unscaled.

    Returns:
        ``(..., 1 + K)`` dot products. Column 0 is anchor·positive; columns 1..K are
        anchor·negative_k, in the order given.
    """
    positive_logit = (anchor_features * positive_features).sum(dim=-1, keepdim=True)
    stacked_negatives = torch.stack(list(negatives_features), dim=-2)  # (..., K, D)
    negative_logits = (anchor_features.unsqueeze(-2) * stacked_negatives).sum(dim=-1)
    logits = torch.cat((positive_logit, negative_logits), dim=-1)
    if temperature:
        logits = logits / temperature
    return logits



def run_through_loop(pos_1, pos_2, negs_1, negs_2):
    """Symmetric contrastive loss for one batch of positive pairs with negatives.

    Uses the notebook-level ``model``, ``device`` and ``temperature``.

    Args:
        pos_1, pos_2: Image batches whose rows show the same dogs, row for row.
        negs_1, negs_2: Lists of image batches. ``negs_2`` supplies the negatives when
            ``pos_1`` is the anchor, and ``negs_1`` when ``pos_2`` is the anchor.

    Returns:
        Scalar loss: the mean of the two directional cross-entropies, with the true
        match always at logit index 0.
    """

    def embed(images):
        # Encode in fp32, then project onto the unit sphere.
        return normalize_features(model.encode_image(images.float()))

    feats_1 = embed(pos_1)
    feats_2 = embed(pos_2)
    neg_feats_1 = [embed(batch) for batch in negs_1]
    neg_feats_2 = [embed(batch) for batch in negs_2]

    # pos_1 anchors must pick out their pos_2 match among negs_2 ...
    logits_1_to_2 = normalized_features_to_logits(
        anchor_features=feats_1,
        positive_features=feats_2,
        negatives_features=neg_feats_2,
        temperature=temperature,
    )
    # ... and pos_2 anchors their pos_1 match among negs_1.
    logits_2_to_1 = normalized_features_to_logits(
        anchor_features=feats_2,
        positive_features=feats_1,
        negatives_features=neg_feats_1,
        temperature=temperature,
    )

    # The true match sits in column 0 for every row.
    targets = torch.zeros(logits_1_to_2.shape[0], dtype=torch.long, device=device)
    loss_1_to_2 = torch.nn.functional.cross_entropy(logits_1_to_2, targets)
    loss_2_to_1 = torch.nn.functional.cross_entropy(logits_2_to_1, targets)
    return (loss_1_to_2 + loss_2_to_1) / 2.0


In [68]:
# Custom calculations to run on the validation set

def similarity(embedding1, embedding2):
    """Cosine similarity between two embeddings, returned as a Python float.

    ``.item()`` needs the result to hold exactly one element, so the inputs are
    expected to describe a single pair (e.g. both shaped ``(1, D)``).
    """
    cosine = torch.nn.functional.cosine_similarity(embedding1, embedding2, dim=1)
    return cosine.item()

def get_top_n_stats(row_embeddings, col_embeddings):
    """Top-n retrieval accuracy of matching each row embedding to its column counterpart.

    Row ``i`` and column ``i`` are two images of the same dog. For every row, the
    columns are ranked by cosine similarity (descending), and the 0-based rank of the
    true match (column ``i``) is recorded. Ties keep the lower column index first, as
    a stable sort would.

    Args:
        row_embeddings: Sequence of N embeddings, each ``(1, D)`` or ``(D,)``.
        col_embeddings: Sequence of at least N embeddings, aligned with the rows.

    Returns:
        ``(total_dp, top_1_count, top_5_count, top_10_count)``: the number of rows, and
        how many rows have their true match within the first 1 / 5 / 10 columns.

    Raises:
        ValueError: if there are fewer column embeddings than row embeddings.
    """
    total_dp = len(row_embeddings)
    if total_dp == 0:
        print("Total datapoints: 0")
        return 0, 0, 0, 0
    if len(col_embeddings) < total_dp:
        raise ValueError(
            f"Need a column embedding for every row: got {total_dp} rows "
            f"but only {len(col_embeddings)} columns"
        )

    # Stack into (N, D) / (M, D) and compare all pairs with one matmul of unit vectors
    # (cosine similarity), instead of N*M separate `.item()` device syncs.
    rows = torch.cat([e.reshape(1, -1) for e in row_embeddings]).float()
    cols = torch.cat([e.reshape(1, -1) for e in col_embeddings]).float()
    sims = (
        torch.nn.functional.normalize(rows, dim=-1)
        @ torch.nn.functional.normalize(cols, dim=-1).T
    )

    # 0-based rank of the true match = number of columns ranked ahead of it: strictly
    # more similar, or equally similar with a smaller column index.
    row_idx = torch.arange(total_dp, device=sims.device).unsqueeze(1)  # (N, 1)
    col_idx = torch.arange(sims.shape[1], device=sims.device).unsqueeze(0)  # (1, M)
    true_sims = sims.gather(1, row_idx)  # (N, 1): sims[i, i]
    ranked_ahead = (sims > true_sims) | ((sims == true_sims) & (col_idx < row_idx))
    ranks = ranked_ahead.sum(dim=1)

    # Rank r means the match is inside the top-k list for every k > r (ranks are 0-based).
    top_1_count = int((ranks < 1).sum().item())
    top_5_count = int((ranks < 5).sum().item())
    top_10_count = int((ranks < 10).sum().item())
    indices_of_correct_image_in_predictions = ranks.tolist()

    print(f"Total datapoints: {total_dp}")
    print(f"# of answer is in top 1: {top_1_count} = {top_1_count/total_dp:.1%}")
    print(f"# of answer is in top 5: {top_5_count} = {top_5_count/total_dp:.1%}")
    print(f"# of answer is in top 10: {top_10_count} = {top_10_count/total_dp:.1%}")

    print("indices_of_correct_image_in_predictions", indices_of_correct_image_in_predictions)
    return total_dp, top_1_count, top_5_count, top_10_count


In [None]:
# Top n baseline: retrieval quality of the pretrained model before any fine-tuning.
# Switch to inference mode first; the training loop switches back with model.train().
model.eval()
with torch.no_grad():
    row_embeddings = []
    col_embeddings = []
    # Only the positive pairs are needed for top-n retrieval; the negatives that the
    # loader also yields are ignored here.
    for pos_1, pos_2, negs_1, negs_2 in tqdm(test_loader):
        pos_1 = pos_1.to(device)
        pos_2 = pos_2.to(device)

        pos1_features = model.encode_image(pos_1.float())
        pos2_features = model.encode_image(pos_2.float())
        row_embeddings.extend([data_point.unsqueeze(0) for data_point in pos1_features])
        col_embeddings.extend([data_point.unsqueeze(0) for data_point in pos2_features])

    total_dp, top_1_count, top_5_count, top_10_count = get_top_n_stats(
        row_embeddings, col_embeddings
    )

# Log the metrics to W&B (epoch -1 marks the pre-training baseline)
wandb.log({"total_dp": total_dp,
           "top_1_count": top_1_count,
           "top_5_count": top_5_count,
           "top_10_count": top_10_count,
           "epoch": -1}
          )

In [None]:
for epoch in range(num_epochs):
    # ---- Train for one epoch ----
    model.train()
    train_loss = 0.0
    for pos_1, pos_2, negs_1, negs_2 in tqdm(train_loader):
        pos_1 = pos_1.to(device)
        pos_2 = pos_2.to(device)
        negs_1 = [neg.to(device) for neg in negs_1]
        negs_2 = [neg.to(device) for neg in negs_2]
        optimizer.zero_grad()
        loss = run_through_loop(pos_1, pos_2, negs_1, negs_2)

        # Update the model parameters
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
    # Average over batches once, after the loop. Dividing inside the loop repeatedly
    # shrank the running sum and under-weighted every batch except the last.
    train_loss /= max(len(train_loader), 1)

    # ---- Validation: contrastive loss + top-n retrieval on held-out dogs ----
    model.eval()
    with torch.no_grad():
        validation_loss = 0.0

        row_embeddings = []
        col_embeddings = []
        for pos_1, pos_2, negs_1, negs_2 in tqdm(test_loader):
            pos_1 = pos_1.to(device)
            pos_2 = pos_2.to(device)
            negs_1 = [neg.to(device) for neg in negs_1]
            negs_2 = [neg.to(device) for neg in negs_2]

            val_loss = run_through_loop(pos_1, pos_2, negs_1, negs_2)
            validation_loss += val_loss.item()

            # Raw (un-normalised) positive embeddings for the retrieval metrics.
            pos1_features = model.encode_image(pos_1.float())
            pos2_features = model.encode_image(pos_2.float())
            row_embeddings.extend([data_point.unsqueeze(0) for data_point in pos1_features])
            col_embeddings.extend([data_point.unsqueeze(0) for data_point in pos2_features])

        validation_loss /= max(len(test_loader), 1)
        total_dp, top_1_count, top_5_count, top_10_count = get_top_n_stats(
            row_embeddings, col_embeddings
        )

    print(
        f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_loss:.4f}, Validation Loss: {validation_loss:.4f}"
    )
    # Log the metrics to W&B in one call, so all of this epoch's values share one step.
    wandb.log(
        {
            "train_loss": train_loss,
            "validation_loss": validation_loss,
            "total_dp": total_dp,
            "top_1_count": top_1_count,
            "top_5_count": top_5_count,
            "top_10_count": top_10_count,
            "epoch": epoch + 1,
        }
    )

    # Save checkpoint. The file holds only the model's state_dict (optimizer state and
    # losses are not persisted), so it loads directly with model.load_state_dict().
    torch.save(
        model.state_dict(),
        os.path.join(PATH_TO_CHECKPOINTS, f"checkpoint_epoch_{epoch + 1}.pth"),
    )
