In [None]:
!pip install timm

In [None]:
!pip install chromadb==0.3.29

In [None]:
## HELPER FUNCTIONS AND IMPORTS
import timm
import chromadb
import torchvision
from torch.utils.data import DataLoader
import pandas as pd
import torch
import gc
from collections import Counter

def make_embeddings(dataloader, model, classes, device):
    """Embed every sample yielded by ``dataloader`` and collect the results in a DataFrame.

    Each batch is moved to ``device`` and passed through ``model.forward_features``
    followed by ``model.forward_head(..., pre_logits=True)``. The whole loop runs under
    ``torch.no_grad()``: this is pure inference, so no autograd graph has to be built
    or kept alive for the batch outputs.

    Args:
        dataloader: iterable yielding ``(X, y)`` batches, where ``X`` is the input batch
            tensor and ``y`` holds one integer class index per sample.
        model: object exposing ``forward_features`` and ``forward_head``.
        classes: sequence mapping a class index to its label.
        device: device the input batch is moved to before the forward pass.

    Returns:
        pandas.DataFrame with one row per sample and the columns ``document``,
        ``embedding``, ``metadata``, ``id`` and ``class_id``.
    """
    data_source = ""  # E.g. "torch_cifar10": string added as metadata to DB. Currently not further used, but might be helpful to determine data origin.
    append_str_to_ids = ""  # string concatenated to each id string in DB. Helpful if multiple datasets should be added to the same DB
    records = []

    with torch.no_grad():
        for batch, (X, y) in enumerate(dataloader):
            images = X.to(device)

            features = model.forward_features(images)
            feature_maps = model.forward_head(features, pre_logits=True)

            for sample_idx, feature_map in enumerate(feature_maps):
                class_id = y[sample_idx].item()
                records.append(
                    {
                        'document': "X[" + str(sample_idx) + "]",  # document is currently irrelevant. Might be used to add raw input data to DB
                        'embedding': feature_map.flatten().tolist(),  # the vector that is stored and queried in the DB
                        'metadata': {"label": classes[class_id], "data_source": data_source},  # can add any metadata. "label" is used for the accuracy calculation on test data
                        'id': "id_b" + str(batch) + "_id" + str(sample_idx) + append_str_to_ids,  # unique per (batch, position in batch)
                        'class_id': class_id
                    }
                )

            # Release the per-batch tensors before the next batch is loaded.
            del X, y, images, features, feature_maps
            gc.collect()
            torch.cuda.empty_cache()

    return pd.DataFrame(records)

def add_train_test_df_to_chroma_collection(chroma_collection, df):
    """Upsert all rows of ``df`` into ``chroma_collection`` and print the resulting item count.

    Args:
        chroma_collection: collection object providing ``upsert`` and ``count``.
        df: frame with the columns ``document``, ``embedding``, ``metadata`` and ``id``.
    """
    # upsert() keyword -> name of the frame column that supplies its values
    column_for_argument = {
        "documents": "document",
        "embeddings": "embedding",
        "metadatas": "metadata",
        "ids": "id",
    }
    payload = {argument: list(df[column]) for argument, column in column_for_argument.items()}
    chroma_collection.upsert(**payload)

    # count() returns the number of items in the collection
    # (collection.peek() can be used to look at the first stored items while debugging)
    print("Elements in DB after add:", chroma_collection.count())

# PHASE 3 (inference)
def eval_collection_against_df(collection_loaded, df_to_test, classes, k = 10, distance_weighted = False):
    """Classify every row of ``df_to_test`` by a k-NN vote over ``collection_loaded`` and return the accuracy.

    For each row the ``k`` nearest entries are queried from the collection and turned
    into a predicted label by ``knn_distance_weighted`` (if ``distance_weighted`` is
    truthy) or ``knn_majority_vote`` (otherwise). The prediction counts as correct
    when it equals ``classes[row.class_id]``.

    Args:
        collection_loaded: collection object whose ``query`` method returns the result
            structure consumed by the two voting helpers.
        df_to_test: DataFrame with an ``embedding`` and a ``class_id`` column.
        classes: sequence mapping a class index to its label.
        k: number of neighbours requested per query.
        distance_weighted: use the distance-weighted vote instead of the plain majority vote.

    Returns:
        Fraction of correctly classified rows, rounded to three decimals.
        ``0.0`` when ``df_to_test`` has no rows (instead of raising ``ZeroDivisionError``).
    """
    total_instances = len(df_to_test.index)
    if total_instances == 0:
        # Nothing to evaluate: return early rather than dividing by zero below.
        return 0.0

    classify = knn_distance_weighted if distance_weighted else knn_majority_vote

    total_correct = 0
    for position, sample in enumerate(df_to_test.itertuples()):
        # "\r" keeps the progress message on a single, continuously overwritten line.
        print("Testing " + str(position) + " out of " + str(total_instances), end="\r")
        result = collection_loaded.query(
            query_embeddings=list(sample.embedding),
            n_results=k,
            include=["metadatas", "documents", "distances"]
        )

        if classify(result, k) == classes[sample.class_id]:
            total_correct += 1

    return round(total_correct / total_instances, 3)

def knn_majority_vote(data, k):
    """Return the most frequent label among the ``k`` nearest neighbours of a query result.

    Args:
        data: query result; ``data['distances'][0]`` holds the neighbour distances and
            ``data['metadatas'][0]`` the matching metadata dicts, each with a ``'label'`` key.
        k: number of nearest neighbours that take part in the vote.

    Returns:
        The label with the highest vote count.
    """
    distances = data['distances'][0]
    neighbor_labels = [entry['label'] for entry in data['metadatas'][0]]

    # Rank neighbour positions by ascending distance (stable, so equal distances keep their order).
    ranked = sorted(enumerate(distances), key=lambda pair: pair[1])

    # Count one vote per neighbour among the k closest ones.
    votes = Counter(neighbor_labels[position] for position, _ in ranked[:k])

    winner, _ = votes.most_common(1)[0]
    return winner

def knn_distance_weighted(data, k):
    """Return the label with the largest inverse-distance-weighted vote among the ``k`` nearest neighbours.

    Every neighbour votes for its label with weight ``1 / distance``. If one or more of
    the ``k`` nearest neighbours have distance 0 (exact matches), only those exact
    matches vote, each with weight 1. This keeps an exact match from being out-voted
    by neighbours that are merely close: giving it a fixed weight of 1.0 while others
    get ``1 / distance`` would rank it below any neighbour closer than distance 1.

    Args:
        data: query result; ``data['distances'][0]`` holds the neighbour distances and
            ``data['metadatas'][0]`` the matching metadata dicts, each with a ``'label'`` key.
        k: number of nearest neighbours that take part in the vote.

    Returns:
        The label with the highest accumulated weight.
    """
    distances = data['distances'][0]
    labels = [metadata['label'] for metadata in data['metadatas'][0]]

    # Positions of the k nearest neighbours, closest first.
    nearest_neighbors = sorted(range(len(distances)), key=lambda idx: distances[idx])[:k]

    # Exact matches (distance 0) would have infinite weight; let them decide alone.
    exact_matches = [idx for idx in nearest_neighbors if distances[idx] == 0]

    label_weights = {}
    if exact_matches:
        for idx in exact_matches:
            label_weights[labels[idx]] = label_weights.get(labels[idx], 0.0) + 1.0
    else:
        for idx in nearest_neighbors:
            label_weights[labels[idx]] = label_weights.get(labels[idx], 0.0) + 1.0 / distances[idx]

    # Label with the maximum weighted vote.
    return max(label_weights, key=label_weights.get)


In [None]:
# SETUP PARAMETERS
# Use the first CUDA GPU when one is available and fall back to the CPU otherwise,
# so the notebook does not crash on hosts without CUDA.
device = "cuda:0" if torch.cuda.is_available() else "cpu" # TODO choose your ML computing device
db_name = "privacy-aware-image-classification-with-kNN" # TODO name that allows you to identify db in the future
db_collection_name = db_name + "_" + "ex1_STL10" # TODO name that allows you to identify your collection in the future
db_persistent_directory = "chroma/databases/" + db_name # TODO adjust to your own chroma storage path

# DATABASE CHROMA
# Client configured with the duckdb+parquet implementation and `db_persistent_directory` as storage path.
chroma_client = chromadb.Client(chromadb.config.Settings(
            chroma_db_impl = "duckdb+parquet",
            persist_directory = db_persistent_directory
        ))
chroma_collection = chroma_client.get_or_create_collection(name=db_collection_name)
print("Elements in loaded collection: " + str(chroma_collection.count())) # should typically be empty

# MODEL PARAMETERS
# any backbone from https://huggingface.co/timm
backbone = "vit_small_patch14_dinov2.lvd142m" # TODO choose your timm backbone (ours: "vit_small_patch14_dinov2.lvd142m","vit_large_patch14_dinov2.lvd142m","vit_base_patch16_clip_224.openai","vit_large_patch14_clip_336.openai")

# MODEL
model = timm.create_model(backbone, pretrained=True, num_classes=0).to(device)
model = model.eval()

data_config = timm.data.resolve_model_data_config(model)
transform = timm.data.create_transform(**data_config, is_training=False) # get the required transform for the given backbone

# DATASET # TODO choose your support dataset (ours: "STL10", "CIFAR-10", "CIFAR-100")
train_dataset = torchvision.datasets.STL10(root='./data', split="train", download=True, transform=transform)
classes = train_dataset.classes
train_loader = DataLoader(dataset=train_dataset, batch_size=4, shuffle=False, num_workers=2)

# GET TRAIN EMBEDDINGS (support set)
print('Calculate feature maps for support set embeddings...')
df_train = make_embeddings(train_loader, model, classes, device)
add_train_test_df_to_chroma_collection(chroma_collection,df_train)
chroma_client.persist() # data is only stored in Chroma DB when persist() is called

# Sanity check: the collection must hold exactly one entry per support-set embedding.
assert chroma_collection.count() == len(df_train), (
    "Collection holds " + str(chroma_collection.count()) + " items, expected " + str(len(df_train))
)

## SUPPORT SET HAS BEEN ADDED TO DB

In [None]:
# EVALUATE ACCURACY
# Embed the test split with the same model and transform, then classify every test
# embedding by a k-NN vote over the entries stored in `chroma_collection`.
# DATASET # TODO choose your test dataset (ours: "STL10", "CIFAR-10", "CIFAR-100")
test_dataset = torchvision.datasets.STL10(
    root='./data',
    split="test",
    download=True,
    transform=transform,
)
classes = test_dataset.classes
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False, num_workers=2)
df_test = make_embeddings(test_loader, model, classes, device)

# Plain majority vote over the 10 nearest neighbours.
acc = eval_collection_against_df(chroma_collection, df_test, classes, k=10)

# Report the backbone, the class ids present in the test frame and the accuracy.
print(f"{backbone}:Classes {df_test['class_id'].unique()}: acc= {acc}")