In [2]:
import pandas as pd

In [3]:
geo_embed = pd.read_csv("/home/leon/Documents/GPML/good_practices_ml/CLIP_Embeddings/Image/geoguessr_embeddings.csv")
aerial_embed = pd.read_csv("/home/leon/Documents/GPML/good_practices_ml/CLIP_Embeddings/Image/aerial_embeddings.csv")
tourist = pd.read_csv("/home/leon/Documents/GPML/good_practices_ml/CLIP_Embeddings/Image/tourist_embeddings.csv")

In [15]:
def print_information(df: pd.DataFrame, name: str):
    print(f"Information about {name}")
    print(f"length: {len(df)}")
    #print(f"Class distribution: {df['label'].value_counts()}")
    print(f"Number of unique classes: {len(df['label'].unique())}")
    print(f"Number of classes below 10: {len(df['label'].value_counts()[df['label'].value_counts() < 10])}")

In [16]:
print_information(geo_embed, "geoguessr")
print_information(aerial_embed, "aerial")
print_information(tourist, "tourist")

Information about geoguessr
length: 49905
Number of unique classes: 120
Number of classes below 10: 25
Information about aerial_map
length: 267
Number of unique classes: 38
Number of classes below 10: 27
Information about bigfoto
length: 2297
Number of unique classes: 33
Number of classes below 10: 4


In [17]:
def balance_data(df: pd.DataFrame, target_path: str, max_images: int = 1000, min_images: int = 10, seed: int = 1234):
    """
    Balance the data in a DataFrame by randomly sampling a maximum of images from each class, dropping classes below the minimum.

    Args:
        df (pd.DataFrame): The DataFrame containing the data.
        REPO_PATH (str): The path to the repository.
        max_images (int, optional): The maximum number of images to sample. Defaults to 1000.
        min_images (int, optional): The minimum number of images to sample. Defaults to 10.
        seed (int, optional): The random seed. Defaults to 1234.

    Returns:
        pd.DataFrame: The balanced DataFrame.
    """
    # Group the DataFrame by label
    grouped = df.groupby("label")
    
    # Create an empty DataFrame to store the balanced data
    balanced_df = pd.DataFrame(columns=df.columns)
    
    # Iterate over each group
    for _ , group_df in grouped:
        # Check if the group has enough images
        if len(group_df) >= min_images:
            # Randomly sample the maximum number of images from the group
            sampled_df = group_df.sample(n=min(len(group_df), max_images), random_state=seed)
            # Append the sampled data to the balanced DataFrame
            balanced_df = pd.concat([balanced_df,sampled_df], ignore_index=True)
    
    # Save the balanced DataFrame
    return balanced_df

In [18]:
import os 
REP_PATH = "/home/leon/Documents/GPML/good_practices_ml"
geo_embed = pd.read_csv(os.path.join(REP_PATH, "CLIP_Embeddings/Image/geoguessr_embeddings.csv"))
balanced_geo_df = balance_data(df=geo_embed, target_path=os.path.join(REP_PATH, "CLIP_Embeddings/Image"), max_images=2000, min_images=10, seed=1234)
# Get all images from geo_df for the classes that have more than 2000 images and that are not in balanced_geo_df
geo_large_classes = geo_embed["label"].value_counts()[geo_embed["label"].value_counts() > 2000].index.tolist()
geo_additional_images = geo_embed[geo_embed["label"].isin(geo_large_classes) & ~geo_embed["path"].isin(balanced_geo_df["path"])]

In [21]:
balanced_geo_df

Unnamed: 0,label,width,height,format,path,Embedding
0,Albania,1536,662,JPEG,/share/temp/bjordan/good_practices_in_machine_...,"tensor([[-1.8170e-01, -1.0691e-01, 1.8980e-01..."
1,Albania,1536,662,JPEG,/share/temp/bjordan/good_practices_in_machine_...,"tensor([[-1.3501e-01, 1.0323e-01, 5.4323e-01..."
2,Albania,1536,662,JPEG,/share/temp/bjordan/good_practices_in_machine_...,"tensor([[-2.8526e-02, -8.6397e-02, 2.2200e-01..."
3,Albania,1536,662,JPEG,/share/temp/bjordan/good_practices_in_machine_...,"tensor([[-1.0155e-01, -2.9273e-01, -1.7680e-01..."
4,Albania,1536,662,JPEG,/share/temp/bjordan/good_practices_in_machine_...,"tensor([[-2.7844e-02, -6.3074e-02, 7.0930e-02..."
...,...,...,...,...,...,...
35591,Vietnam,1536,662,JPEG,/share/temp/bjordan/good_practices_in_machine_...,"tensor([[-6.2175e-02, 7.6111e-02, 4.1519e-01..."
35592,Vietnam,1536,662,JPEG,/share/temp/bjordan/good_practices_in_machine_...,"tensor([[ 8.8583e-02, -3.2490e-01, 2.5532e-01..."
35593,Vietnam,1536,662,JPEG,/share/temp/bjordan/good_practices_in_machine_...,"tensor([[-1.0660e-01, -3.0397e-01, 4.4139e-01..."
35594,Vietnam,1536,662,JPEG,/share/temp/bjordan/good_practices_in_machine_...,"tensor([[ 2.1165e-01, 2.6407e-02, 1.9423e-01..."
