In [None]:
!pip install timm

In [None]:
!pip install chromadb==0.3.29

In [None]:
## HELPER FUNCTIONS AND IMPORTS
import timm
import chromadb
import torchvision
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from PIL import Image
import pandas as pd
import numpy as np
import torch
import gc
from collections import Counter

def make_embeddings(dataloader, model, classes, device, data_source="", apped_str_to_ids=""):
    """Embed every sample of ``dataloader`` with ``model`` and collect the results.

    Args:
        dataloader: iterable yielding ``(X, y)`` batches; ``X`` is moved to
            ``device`` and fed to the model, ``y`` holds one integer class
            index per sample.
        model: object exposing ``forward_features(images)`` and
            ``forward_head(features, pre_logits=True)``.
        classes: sequence mapping a class index to its label string.
        device: device the image batch is moved to before the forward pass.
        data_source: free-form string stored in each row's metadata
            (e.g. ``"torch_cifar10"``). Defaults to ``""``.
        apped_str_to_ids: suffix appended to every generated id; useful when
            several datasets are added to the same DB. Defaults to ``""``.

    Returns:
        ``pandas.DataFrame`` with the columns ``document``, ``embedding``
        (flat list of floats), ``metadata`` (dict with ``label`` and
        ``data_source``), ``id`` (``"id_b<batch>_id<position>"`` plus suffix)
        and ``class_id``. The columns are present even when the loader is empty.
    """
    columns = ['document', 'embedding', 'metadata', 'id', 'class_id']
    rows = []

    # Pure inference: without no_grad() every forward pass would record an
    # autograd graph that is never used.
    with torch.no_grad():
        for batch, (X, y) in enumerate(dataloader):
            images = X.to(device)

            features = model.forward_features(images)
            features = model.forward_head(features, pre_logits=True)

            # Flatten each sample and move the whole batch to the host once,
            # instead of one transfer per sample.
            embeddings = features.flatten(start_dim=1).cpu().tolist()
            class_ids = torch.as_tensor(y).tolist()

            for position, (embedding, class_id) in enumerate(zip(embeddings, class_ids)):
                rows.append(
                    {
                        'document': f"X[{position}]",  # currently irrelevant; could carry raw input data
                        'embedding': embedding,
                        'metadata': {"label": classes[class_id], "data_source": data_source},  # "label" is used for the accuracy calculation
                        'id': f"id_b{batch}_id{position}{apped_str_to_ids}",
                        'class_id': class_id
                    }
                )

            # Drop the batch tensors before the next iteration to keep peak memory low.
            del X, y, images, features
            gc.collect()
            torch.cuda.empty_cache()

    return pd.DataFrame(rows, columns=columns)

def add_train_test_df_to_chroma_collection(chroma_collection, df):
    """Upsert every row of ``df`` into ``chroma_collection`` and print the new size.

    ``df`` must provide the columns ``document``, ``embedding``, ``metadata``
    and ``id``; each column is handed to ``upsert`` as a plain list.
    """
    # upsert keyword -> dataframe column that feeds it
    column_for_field = {
        'documents': 'document',
        'embeddings': 'embedding',
        'metadatas': 'metadata',
        'ids': 'id',
    }
    payload = {field: list(df[column]) for field, column in column_for_field.items()}
    chroma_collection.upsert(**payload)

    element_count = chroma_collection.count()  # number of items now in the collection
    print("Elements in DB after add:", element_count)

# PHASE 3 (inference)
def eval_collection_against_df(collection_loaded, df_to_test, classes, k = 10, distance_weighted = False):
    """Classify every row of ``df_to_test`` by kNN against a collection and return the accuracy.

    Args:
        collection_loaded: collection queried for the ``k`` nearest neighbours
            of each test embedding.
        df_to_test: dataframe with an ``embedding`` column (flat list of
            floats) and a ``class_id`` column (index into ``classes``).
        classes: sequence mapping a class index to its label string.
        k: number of neighbours requested per query.
        distance_weighted: use ``knn_distance_weighted`` when truthy,
            ``knn_majority_vote`` otherwise.

    Returns:
        Fraction of correctly classified rows, rounded to three decimals.
        ``0.0`` when ``df_to_test`` has no rows (previously a
        ``ZeroDivisionError``).
    """
    n_samples = len(df_to_test.index)
    if n_samples == 0:
        # Nothing to classify; avoid dividing by zero below.
        return 0.0

    total_correct = 0
    for total_instances, sample in enumerate(df_to_test.itertuples()):
        print("Testing " + str(total_instances) + " out of " + str(n_samples), end="\r")
        result = collection_loaded.query(
            query_embeddings=[list(sample.embedding)],  # explicit batch holding one query
            n_results=k,
            include=["metadatas", "documents", "distances"]
        )

        if distance_weighted:
            classification = knn_distance_weighted(result, k)
        else:
            classification = knn_majority_vote(result, k)

        if classification == classes[sample.class_id]:
            total_correct += 1

    return round(total_correct / n_samples, 3)

def knn_majority_vote(data, k):
    """Return the most frequent label among the ``k`` closest neighbours.

    ``data`` is a query result for a single query: ``data['distances'][0]``
    holds the neighbour distances and ``data['metadatas'][0]`` the matching
    metadata dicts, each carrying a ``'label'`` entry.
    """
    neighbour_distances = data['distances'][0]
    neighbour_labels = [meta['label'] for meta in data['metadatas'][0]]

    # Rank neighbour positions by ascending distance; the sort is stable, so
    # equally distant neighbours keep their order in the result.
    ranked = sorted(enumerate(neighbour_distances), key=lambda pair: pair[1])

    # Count one vote per label over the k best-ranked neighbours.
    votes = Counter(neighbour_labels[position] for position, _ in ranked[:k])

    return votes.most_common(1)[0][0]

def knn_distance_weighted(data, k):
    """Return the label with the largest inverse-distance vote among the ``k`` closest neighbours.

    ``data`` is a query result for a single query: ``data['distances'][0]``
    holds the neighbour distances and ``data['metadatas'][0]`` the matching
    metadata dicts, each carrying a ``'label'`` entry.

    Each neighbour votes with weight ``1 / distance``. Neighbours at distance
    ``<= 0`` (exact matches, or tiny negative values caused by floating-point
    rounding) have no finite inverse. When any such neighbour is present, only
    those neighbours vote, each with weight 1. Previously an exact match got
    weight 1.0, which is lower than any neighbour at distance < 1, and a
    slightly negative distance produced a large negative weight.
    """
    distances = data['distances'][0]
    labels = [metadata['label'] for metadata in data['metadatas'][0]]

    # Positions of the k nearest neighbours, closest first.
    nearest_neighbors = sorted(range(len(distances)), key=lambda x: distances[x])[:k]

    has_exact_match = any(distances[i] <= 0 for i in nearest_neighbors)

    label_weights = {}
    for i in nearest_neighbors:
        if has_exact_match:
            if distances[i] > 0:
                # An exact match outranks every non-identical neighbour.
                continue
            weight = 1.0
        else:
            weight = 1.0 / distances[i]
        label_weights[labels[i]] = label_weights.get(labels[i], 0.0) + weight

    # Label with the largest accumulated weight.
    return max(label_weights, key=label_weights.get)

# Let's define a custom Dataset class for our data
class datasetISIC2018DiseaseClassification(Dataset):
    def __init__(self, csv_file, class_list, transform=None, subset = "train"):
        self.df = pd.read_csv(csv_file)
        self.transform = transform
        self.class_list = class_list
        self.subset = subset
        self.classes = ["MEL","NV","BCC","AKIEC","BKL","DF","VASC"]

    def __len__(self):
        return self.df.shape[0]

    # def __getitem__(self, index):
    #     image = Image.open("assets/datasets/ISIC2018/Task3-DiseaseClassification/"+ self.subset +"/" + self.df.image[index] + ".jpg")
    #     #print(self.df.label[index])
    #     #print(self.class_list[self.df.label[index]])
    #     #label = self.class_list[self.df.label[index]]
    #     label = self.df.label[index]

    #     if self.transform:
    #         image = self.transform(image)
    #     return image, label
    def __getitem__(self, index):
        image = Image.open("assets/datasets/ISIC2018/Task3-DiseaseClassification/" + self.subset + "/" + self.df.image[index] + ".jpg")
        
        label_row = self.df.iloc[index, 1:]
        label = label_row.idxmax()
        label_index = self.classes.index(label)

        if self.transform:
            image = self.transform(image)
            
        # print(f"Loaded image: {self.df.image[index]}, Label: {label_index}")
        
        return image, label_index

# EXPERIMENT SPECIFIC METHODS
def copy_collection(collection,chroma_client, name_collection):
    """Copy all items of ``collection`` into the collection ``<name_collection>_copy``.

    The target is fetched or created through ``chroma_client`` and every
    document, metadata entry, embedding and id of the source is upserted into
    it. Returns the target collection.

    NOTE(review): the target is always requested with cosine distance,
    whatever the source collection uses -- confirm both are meant to share
    the same distance function.
    """
    print("Elements in DB1:", collection.count())  # size of the source collection

    target = chroma_client.get_or_create_collection(
        name=name_collection+"_copy",
        metadata={"hnsw:space": "cosine"},  # cosine similarity as distance function
    )

    source_items = collection.get(include=['documents','metadatas','embeddings'])
    copied_fields = ('embeddings', 'metadatas', 'documents', 'ids')
    target.upsert(**{field: source_items[field] for field in copied_fields})

    print("Elements in DB2 after copy:", target.count())  # size of the copy
    return target

def copy_collection_delete_most_significant_n_and_eval(id_frequencies, n, collection, chroma_client, name_collection, df_test, classes):
    """Copy ``collection``, drop the ``n`` most frequently used ids per label, and re-evaluate.

    Args:
        id_frequencies: dict mapping each label to a ``Counter`` of element ids.
        n: number of top ids taken from every label's counter.
        collection: source collection; it is copied, not modified.
        chroma_client: client used to create the copy and to persist the deletion.
        name_collection: base name; the copy is created by ``copy_collection``.
        df_test: test dataframe passed on to ``eval_collection_against_df``.
        classes: sequence mapping a class index to its label string.

    Returns:
        Accuracy of ``df_test`` against the reduced copy.
    """
    ids_key_elements_to_be_deleted = [
        element_id
        for frequencies in id_frequencies.values()
        for element_id, _count in frequencies.most_common(n)
    ]

    collection_copied = copy_collection(collection, chroma_client, name_collection)
    # NOTE(review): an empty id list is kept away from delete(); some Chroma
    # releases appear to treat an id-less delete as "no filter" -- confirm
    # against the installed version.
    if ids_key_elements_to_be_deleted:
        collection_copied.delete(ids_key_elements_to_be_deleted)
    chroma_client.persist()
    print("Elements in DB2 after delete:", collection_copied.count())
    return eval_collection_against_df(collection_copied, df_test, classes)

def copy_collection_delete_random_n_and_eval(label_arrays, n, collection, chroma_client, name_collection, df_test, classes, ids_random_elements_already_deleted):
    """Copy ``collection``, drop ``n`` randomly chosen ids, and re-evaluate.

    The ids are drawn without replacement from the pooled ids of all labels,
    excluding ``ids_random_elements_already_deleted``. The draw is seeded
    with 5, so identical inputs give identical selections.

    Args:
        label_arrays: dict mapping each label to an iterable of element ids.
        n: total number of ids to delete (across all labels, not per label).
        collection: source collection; it is copied, not modified.
        chroma_client: client used to create the copy and to persist the deletion.
        name_collection: base name; the copy is created by ``copy_collection``.
        df_test: test dataframe passed on to ``eval_collection_against_df``.
        classes: sequence mapping a class index to its label string.
        ids_random_elements_already_deleted: ids that must not be drawn.

    Returns:
        Tuple ``(deleted_ids, accuracy)``: the ids removed in this call (plain
        Python values) and the accuracy of ``df_test`` against the reduced copy.

    Raises:
        ValueError: if fewer than ``n`` ids remain after the exclusion.
    """
    # Set lookup keeps the exclusion filter linear in the number of ids.
    already_deleted = set(ids_random_elements_already_deleted)

    # Flatten the ids of all classes, skipping the excluded ones.
    all_ids = [
        element_id
        for ids in label_arrays.values()
        for element_id in ids
        if element_id not in already_deleted
    ]

    if len(all_ids) < n:
        raise ValueError(f"Not enough elements to delete. Needed: {n}, Available: {len(all_ids)}")

    # A private generator with the fixed seed draws the same ids as reseeding
    # the global NumPy generator, without clobbering that global state.
    rng = np.random.RandomState(5)
    ids_random_elements_to_be_deleted = rng.choice(all_ids, n, replace=False).tolist()

    print("Random IDs to be deleted: n = " + str(len(ids_random_elements_to_be_deleted)))
    print(ids_random_elements_to_be_deleted)

    # Copy the collection
    collection_copied = copy_collection(collection, chroma_client, name_collection)

    # Delete the selected elements from the copy; an empty selection (n == 0)
    # is never passed to delete().
    if ids_random_elements_to_be_deleted:
        collection_copied.delete(ids_random_elements_to_be_deleted)
    chroma_client.persist()

    print("Elements in DB2 after delete:", collection_copied.count())

    return ids_random_elements_to_be_deleted, eval_collection_against_df(collection_copied, df_test, classes)
    
def eval_collection_against_df_count_most_significant(collection_loaded, df_to_test, classes, k = 10, distance_weighted = False):
    """Evaluate kNN accuracy and count how often each stored element supports a correct result.

    For every correctly classified test row, each of the returned neighbours
    is credited once in the counter of that neighbour's own label.

    Args:
        collection_loaded: collection queried for the ``k`` nearest neighbours
            of each test embedding.
        df_to_test: dataframe with an ``embedding`` column (flat list of
            floats) and a ``class_id`` column (index into ``classes``).
        classes: sequence mapping a class index to its label string; every
            label stored in the collection metadata must appear in it.
        k: number of neighbours requested per query.
        distance_weighted: use ``knn_distance_weighted`` when truthy,
            ``knn_majority_vote`` otherwise.

    Returns:
        Tuple ``(acc, id_frequencies)``: the accuracy rounded to three
        decimals (``0.0`` for an empty ``df_to_test``, previously a
        ``ZeroDivisionError``) and a dict mapping each label to a ``Counter``
        of element ids.
    """
    # One Counter of element ids per label.
    id_frequencies = {label: Counter() for label in classes}

    n_samples = len(df_to_test.index)
    total_correct = 0

    for total_instances, sample in enumerate(df_to_test.itertuples()):
        print("Testing " + str(total_instances) + " out of " + str(n_samples), end="\r")
        result = collection_loaded.query(
            query_embeddings=[list(sample.embedding)],  # explicit batch holding one query
            n_results=k,
            include=["metadatas", "documents", "distances"]
        )

        if distance_weighted:
            classification = knn_distance_weighted(result, k)
        else:
            classification = knn_majority_vote(result, k)

        if classification == classes[sample.class_id]:
            total_correct += 1
            # Credit every neighbour that took part in this correct decision.
            for metadata, neighbour_id in zip(result['metadatas'][0], result['ids'][0]):
                id_frequencies[metadata['label']][neighbour_id] += 1

    # Guard against an empty test set instead of dividing by zero.
    acc = round(total_correct / n_samples, 3) if n_samples else 0.0
    print("Total test instances: " + str(n_samples))
    print("Total correct of test instances: " + str(total_correct))
    print("Accuracy: " +  str(acc))
    return acc, id_frequencies

In [None]:
from os import remove
# SETUP PARAMETERS Melanoma diminishing support set (MVF delete)
import ast


# Pick the compute device: CUDA first, then Apple MPS if this torch build
# reports it as available, otherwise the CPU. Previously "mps" was selected
# whenever CUDA was missing, even on machines without MPS.
if torch.cuda.is_available():
    device = "cuda:0"
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print("Using device: " + device)
dataset = "melanoma" # TODO choose your dataset (ours: "melanoma", "pneumonia")
pretrain = False  # False: compute embeddings with the model instead of loading cached CSVs
db_name = "privacy-aware-image-classification-with-kNN" # TODO name that allows you to identify db in the future
db_collection_name = f"{db_name}_ex3_2" # TODO name that allows you to identify your collection in the future
db_persistent_directory = "chroma/databases/" + db_name # TODO adjust to your own chroma storage path

# DATABASE CHROMA
chroma_client = chromadb.Client(chromadb.config.Settings(
            chroma_db_impl = "duckdb+parquet",
            persist_directory = db_persistent_directory
        ))
# NOTE(review): no distance function is configured here, while copy_collection
# requests cosine distance for its copies -- confirm both should match.
chroma_collection = chroma_client.get_or_create_collection(name=db_collection_name)
print("Elements in loaded collection: " + str(chroma_collection.count())) # should typically be empty

# MODEL PARAMETERS
# any backbone from https://huggingface.co/timm
backbone = "vit_large_patch14_dinov2.lvd142m" # TODO choose your timm backbone (ours: "vit_small_patch14_dinov2.lvd142m","vit_large_patch14_dinov2.lvd142m","vit_base_patch16_clip_224.openai","vit_large_patch14_clip_336.openai")

# MODEL
# num_classes=0 builds the backbone without a classification layer.
model = timm.create_model(backbone, pretrained=True, num_classes=0).to(device)
model = model.eval()

data_config = timm.data.resolve_model_data_config(model)
transform = timm.data.create_transform(**data_config, is_training=False) # get the required transform for the given backbone
# Build the support set for the selected dataset, store it in the Chroma
# collection, and embed the test split. Both dataset branches define
# `classes`, `removed_features_per_class`, `df_train` and `df_test`.
if dataset == "melanoma":
        print("Melanoma dataset")
        # Numbers of elements removed per class in the ablation loop that follows this cell section.
        removed_features_per_class = [1,5,10,20,30,40,50,60,70,80,90,100]
        # load the csv file
        csv_file_path = "assets/datasets/ISIC2018/Task3-DiseaseClassification/isic_labels_unified_TRAIN.csv"
        # `df` is only used for the preview print below; the dataset class reads the CSV itself.
        df = pd.read_csv(csv_file_path)
        print(df.head())

        # Define and map the class label
        class_labels = ["MEL","NV","BCC","AKIEC","BKL","DF","VASC"]
        #{'MEL': 0, 'NV': 1, 'BCC': 2, 'AKIEC': 3, 'BKL': 4, 'DF': 5, 'VASC': 6}
        # melanoma, melanocytic nevus, basal cell carcinoma, actinic keratosis / Bowen’s disease, benign keratosis, dermatofibroma, and vascular lesion
        # index -> label lookup; not read anywhere else in this branch.
        class_labels_map = {}
        for indx, label in enumerate(class_labels):
            class_labels_map[indx] = label

        # Lets create an object from our custom dataset class
        train_dataset = datasetISIC2018DiseaseClassification(csv_file_path, class_labels, transform, "train")

        classes = train_dataset.classes
        print(f"classes: {classes} type:{type(classes)}")
        # shuffle=False keeps the batch/position based ids produced by make_embeddings reproducible.
        train_loader = DataLoader(dataset=train_dataset, batch_size=4, shuffle=False, num_workers=2)
        # Sanity check: print the shape and targets of the first batch only.
        for batch_idx, (data, target) in enumerate(train_loader):
            print(f"Batch {batch_idx}: data shape = {data.shape}, target = {target}")
            if batch_idx == 0:  # Print only the first batch for verification
                break

        # GET TRAIN EMBEDDINGS (support set)
        print('Calculate feature maps for support set embeddings...')
        df_train = make_embeddings(train_loader, model, classes, device)
        # Cache the embeddings on disk as CSV.
        df_train.to_csv("isic_labels_unified_TRAIN_embeddings.csv", index=False)
        print(df_train.head())
        add_train_test_df_to_chroma_collection(chroma_collection,df_train)
        chroma_client.persist() # data is only stored in Chroma DB when persist() is called

        # Holds only if the collection contains no ids other than the ones just upserted.
        assert chroma_collection.count() == (len(df_train))

        ## SUPPORT SET HAS BEEN ADDED TO DB ##
        # EVALUATE ACCURACY MELANOMA
        # load the csv file
        csv_file_path = "assets/datasets/ISIC2018/Task3-DiseaseClassification/isic_labels_unified_TEST.csv"
        # As above, `df` is loaded but not read again in this branch.
        df = pd.read_csv(csv_file_path)

        # Define and map the class label
        class_labels = ["MEL","NV","BCC","AKIEC","BKL","DF","VASC"]
        #{'MEL': 0, 'NV': 1, 'BCC': 2, 'AKIEC': 3, 'BKL': 4, 'DF': 5, 'VASC': 6}
        # melanoma, melanocytic nevus, basal cell carcinoma, actinic keratosis / Bowen’s disease, benign keratosis, dermatofibroma, and vascular lesion
        class_labels_map = {}
        for indx, label in enumerate(class_labels):
            class_labels_map[indx] = label

        # Lets create an object from our custom dataset class
        test_dataset = datasetISIC2018DiseaseClassification(csv_file_path, class_labels, transform, "test")

        classes = test_dataset.classes
        test_loader = DataLoader(dataset=test_dataset, batch_size=4, shuffle=False, num_workers=2)
        # Test embeddings are cached as CSV but not added to the collection.
        df_test = make_embeddings(test_loader, model, classes, device)
        df_test.to_csv("isic_labels_unified_TEST_embeddings.csv", index=False)
        print(df_test.head())

elif dataset == "pneumonia":
        print("Pneumonia dataset")
        # Numbers of elements removed per class in the ablation loop that follows this cell section.
        removed_features_per_class = [1,5,10,20,50,100,200,300,500,700,900]
        # DATASET # TODO choose your support dataset (ours: "melanoma", "pneumonia")
        train_dataset = torchvision.datasets.ImageFolder('assets/datasets/Pneumonia/train', transform=transform)
        # With cached embeddings the class names are fixed; otherwise they come from the folder names.
        classes = ['NORMAL', 'PNEUMONIA'] if pretrain else train_dataset.classes
        print(f"classes: {classes} type:{type(classes)}")
        train_loader = DataLoader(dataset=train_dataset, batch_size=4, shuffle=False, num_workers=2)

        # GET TRAIN EMBEDDINGS (support set)
        print('Calculate feature maps for support set embeddings...')
        if pretrain:
            # Reload cached embeddings; the list/dict columns were written as strings and are parsed back.
            df_train = pd.read_csv("pneumonia_train_embeddings.csv")
            df_train['embedding'] = df_train['embedding'].apply(ast.literal_eval)
            df_train['metadata'] = df_train['metadata'].apply(ast.literal_eval)

        else:
            df_train = make_embeddings(train_loader, model, classes, device)
            df_train.to_csv("pneumonia_train_embeddings.csv", index=False)
        print(df_train.head())
        add_train_test_df_to_chroma_collection(chroma_collection,df_train)
        chroma_client.persist() # data is only stored in Chroma DB when persist() is called

        # Holds only if the collection contains no ids other than the ones just upserted.
        assert chroma_collection.count() == (len(df_train))

        ## SUPPORT SET HAS BEEN ADDED TO DB ##
        # EVALUATE ACCURACY PNEUMONIA
        # DATASET # TODO choose your test dataset (ours: "melanoma", "pneumonia")
        test_dataset = torchvision.datasets.ImageFolder('assets/datasets/Pneumonia/test', transform=transform)

        classes = ['NORMAL', 'PNEUMONIA'] if pretrain else test_dataset.classes
        test_loader = DataLoader(dataset=test_dataset, batch_size=4, shuffle=False, num_workers=2)
        if pretrain:
          # Reload cached test embeddings and parse the stringified columns back.
          df_test = pd.read_csv("pneumonia_test_embeddings.csv")
          df_test['embedding'] = df_test['embedding'].apply(ast.literal_eval)
          df_test['metadata'] = df_test['metadata'].apply(ast.literal_eval)
        else:
          df_test = make_embeddings(test_loader, model, classes, device)
          df_test.to_csv("pneumonia_test_embeddings.csv", index=False)
        
        print(df_test.head())
else:
        print("Unknown dataset")
        # NOTE(review): exit() ends the interpreter session -- confirm this is
        # the intended way to stop when the cell runs inside a notebook.
        exit()

# Baseline accuracy on the full support set. The same pass counts, per label,
# how often each stored element was among the neighbours of a correct result.
acc_baseline, id_frequencies = eval_collection_against_df_count_most_significant(
    chroma_collection, df_test, classes
)
print(f"{backbone}:Classes {df_test.class_id.unique()}: acc_baseline= {acc_baseline}")

# For each step size i: delete the i most frequently counted elements of every
# class from a fresh copy of the collection and record the resulting accuracy.
results = {}
for i in removed_features_per_class:
    last_acc = copy_collection_delete_most_significant_n_and_eval(
        id_frequencies,
        i,
        chroma_collection,
        chroma_client,
        db_collection_name,
        df_test,
        classes,
    )
    print(f"{backbone}:Delete {i} most significant elements from each class. Acc= {last_acc}")
    results[str(i)] = last_acc

print(results)

In [None]:
# SETUP PARAMETERS Melanoma diminishing support set (random delete)
import ast
# Pick the compute device the same way as the MVF-delete setup: CUDA first,
# then Apple MPS if available, otherwise the CPU. Previously "cuda:0" was
# hard-coded, which fails on machines without CUDA.
if torch.cuda.is_available():
    device = "cuda:0"
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print("Using device: " + device)
dataset = "melanoma" # TODO choose your dataset (ours: "melanoma", "pneumonia")
pretrain = True  # True: load cached embedding CSVs instead of recomputing them
db_name = "privacy-aware-image-classification-with-kNN" # TODO name that allows you to identify db in the future
db_collection_name = db_name + "_" + "ex3_2" # TODO name that allows you to identify your collection in the future
db_persistent_directory = "chroma/databases/" + db_name # TODO adjust to your own chroma storage path
# DATABASE CHROMA
chroma_client = chromadb.Client(chromadb.config.Settings(
            chroma_db_impl = "duckdb+parquet",
            persist_directory = db_persistent_directory
        ))
# NOTE(review): no distance function is configured here, while copy_collection
# requests cosine distance for its copies -- confirm both should match.
chroma_collection = chroma_client.get_or_create_collection(name=db_collection_name)
print("Elements in loaded collection: " + str(chroma_collection.count())) # should typically be empty

# MODEL PARAMETERS
# any backbone from https://huggingface.co/timm
backbone = "vit_large_patch14_dinov2.lvd142m" # TODO choose your timm backbone (ours: "vit_small_patch14_dinov2.lvd142m","vit_large_patch14_dinov2.lvd142m","vit_base_patch16_clip_224.openai","vit_large_patch14_clip_336.openai")

# MODEL
# num_classes=0 builds the backbone without a classification layer.
model = timm.create_model(backbone, pretrained=True, num_classes=0).to(device)
model = model.eval()

data_config = timm.data.resolve_model_data_config(model)
transform = timm.data.create_transform(**data_config, is_training=False) # get the required transform for the given backbone

def _load_cached_embeddings(csv_path):
    """Read an embeddings CSV written with ``DataFrame.to_csv`` and restore its list/dict columns."""
    cached = pd.read_csv(csv_path)
    # to_csv stored the embedding lists and metadata dicts as text; parse them back.
    cached['embedding'] = cached['embedding'].apply(ast.literal_eval)
    cached['metadata'] = cached['metadata'].apply(ast.literal_eval)
    return cached


# Build (or reload) the support set for the selected dataset, store it in the
# Chroma collection, and build (or reload) the test embeddings. Every dataset
# branch defines `classes`, `removed_features_per_class`, `df_train`, `df_test`.
if dataset == "melanoma":
    print("Melanoma dataset")
    # load the csv file
    csv_file_path = "assets/datasets/ISIC2018/Task3-DiseaseClassification/isic_labels_unified_TRAIN.csv"
    df = pd.read_csv(csv_file_path)
    print(df.head())

    # Define and map the class label
    class_labels = ["MEL","NV","BCC","AKIEC","BKL","DF","VASC"]
    removed_features_per_class = [1,5,10,20,30,40,50,60,70,80,90,100]
    #{'MEL': 0, 'NV': 1, 'BCC': 2, 'AKIEC': 3, 'BKL': 4, 'DF': 5, 'VASC': 6}
    # melanoma, melanocytic nevus, basal cell carcinoma, actinic keratosis / Bowen's disease, benign keratosis, dermatofibroma, and vascular lesion
    class_labels_map = {}
    for indx, label in enumerate(class_labels):
        class_labels_map[indx] = label

    # Define `classes` on every path. With cached embeddings no dataset object
    # is created, so the evaluation previously depended on a `classes`
    # variable left over from an earlier cell.
    classes = class_labels

    # GET TRAIN EMBEDDINGS (support set)
    print('Calculate feature maps for support set embeddings...')
    if pretrain:
        print("Using pretrain")
        df_train = _load_cached_embeddings("isic_labels_unified_TRAIN_embeddings.csv")
    else:
        # Lets create an object from our custom dataset class
        train_dataset = datasetISIC2018DiseaseClassification(csv_file_path, class_labels, transform, "train")

        classes = train_dataset.classes
        print(f"classes: {classes} type:{type(classes)}")
        train_loader = DataLoader(dataset=train_dataset, batch_size=4, shuffle=False, num_workers=2)
        # Sanity check: print the shape and targets of the first batch only.
        for batch_idx, (data, target) in enumerate(train_loader):
            print(f"Batch {batch_idx}: data shape = {data.shape}, target = {target}")
            break
        df_train = make_embeddings(train_loader, model, classes, device)
        df_train.to_csv("isic_labels_unified_TRAIN_embeddings.csv", index=False)
    print(df_train.head())

    add_train_test_df_to_chroma_collection(chroma_collection, df_train)
    chroma_client.persist() # data is only stored in Chroma DB when persist() is called

    # Holds only if the collection contains no ids other than the ones just upserted.
    assert chroma_collection.count() == (len(df_train))

    ## SUPPORT SET HAS BEEN ADDED TO DB ##
    # EVALUATE ACCURACY MELANOMA
    if pretrain:
        df_test = _load_cached_embeddings("isic_labels_unified_TEST_embeddings.csv")
    else:
        # load the csv file
        csv_file_path = "assets/datasets/ISIC2018/Task3-DiseaseClassification/isic_labels_unified_TEST.csv"
        df = pd.read_csv(csv_file_path)

        # Lets create an object from our custom dataset class
        test_dataset = datasetISIC2018DiseaseClassification(csv_file_path, class_labels, transform, "test")

        classes = test_dataset.classes
        test_loader = DataLoader(dataset=test_dataset, batch_size=4, shuffle=False, num_workers=2)
        df_test = make_embeddings(test_loader, model, classes, device)
        df_test.to_csv("isic_labels_unified_TEST_embeddings.csv", index=False)
    print(df_test.head())

elif dataset == "pneumonia":
    print("Pneumonia dataset")
    removed_features_per_class = [1,5,10,20,50,100,200,300,500,700,900]

    # DATASET # TODO choose your support dataset (ours: "melanoma", "pneumonia")
    if pretrain:
        classes = ['NORMAL', 'PNEUMONIA']
    else:
        # The dataset must exist before its class names are read; previously
        # `train_dataset.classes` was evaluated before `train_dataset` was created.
        train_dataset = torchvision.datasets.ImageFolder('assets/datasets/Pneumonia/train', transform=transform)
        train_loader = DataLoader(dataset=train_dataset, batch_size=4, shuffle=False, num_workers=2)
        classes = train_dataset.classes
    print(f"classes: {classes} type:{type(classes)}")

    # GET TRAIN EMBEDDINGS (support set)
    print('Calculate feature maps for support set embeddings...')
    if pretrain:
        df_train = _load_cached_embeddings("pneumonia_train_embeddings.csv")
    else:
        df_train = make_embeddings(train_loader, model, classes, device)
        df_train.to_csv("pneumonia_train_embeddings.csv", index=False)
    print(df_train.head())
    add_train_test_df_to_chroma_collection(chroma_collection, df_train)
    chroma_client.persist() # data is only stored in Chroma DB when persist() is called

    # Holds only if the collection contains no ids other than the ones just upserted.
    assert chroma_collection.count() == (len(df_train))

    ## SUPPORT SET HAS BEEN ADDED TO DB ##
    # EVALUATE ACCURACY PNEUMONIA
    if pretrain:
        df_test = _load_cached_embeddings("pneumonia_test_embeddings.csv")
    else:
        # DATASET # TODO choose your test dataset (ours: "melanoma", "pneumonia")
        # Same ordering fix as above: create the dataset, then read its classes.
        test_dataset = torchvision.datasets.ImageFolder('assets/datasets/Pneumonia/test', transform=transform)
        test_loader = DataLoader(dataset=test_dataset, batch_size=4, shuffle=False, num_workers=2)
        classes = test_dataset.classes

        df_test = make_embeddings(test_loader, model, classes, device)
        df_test.to_csv("pneumonia_test_embeddings.csv", index=False)

    print(df_test.head())
else:
    # Stop with an error that names the offending value instead of calling exit().
    raise ValueError(f"Unknown dataset: {dataset}")

## DO FOR RANDOM CHOICE
# Slim view of the support set (id, class_id, labels). It is built from a new
# frame so that df_train itself is not modified.
df_train_n = df_train.drop(['embedding', 'document', 'metadata'], axis=1)
df_train_n['labels'] = [metadata['label'] for metadata in df_train.metadata]

# Get unique labels
unique_labels = df_train_n['labels'].unique()

# One numpy array of element ids per label, in order of first appearance.
label_arrays = {
    label: np.array(df_train_n[df_train_n['labels'] == label]['id'])
    for label in unique_labels
}

# Baseline accuracy on the full support set (the per-label id frequencies
# returned alongside it are not used by the random-delete loop below).
acc_baseline, id_frequencies = eval_collection_against_df_count_most_significant(
    chroma_collection, df_test, classes
)
print(f"{backbone}:Classes {df_test.class_id.unique()}: acc_baseline= {acc_baseline}")

# For each step size i: delete i randomly drawn elements from a fresh copy of
# the collection and record the resulting accuracy. The ids returned by one
# step are excluded from the draw of the next step.
# NOTE(review): the message below says "from each class", but the helper
# draws i ids from the pooled ids of all classes -- confirm which is intended.
results = {}
ids_random_elements_already_deleted = []
for i in removed_features_per_class:
    ids_random_elements_already_deleted, last_acc = copy_collection_delete_random_n_and_eval(
        label_arrays,
        i,
        chroma_collection,
        chroma_client,
        db_collection_name,
        df_test,
        classes,
        ids_random_elements_already_deleted,
    )
    print(f"{backbone}:Delete {i} random elements from each class. Acc= {last_acc}")
    results[str(i)] = last_acc

print(results)