## SummariseEmbeddingsOverGroupings
This script:
- Reads in a pickle file containing a dataframe with one row per sampled image. Each image is associated with a location (lat, lon), a link to the image_file, an embedding, category_scores, and the cluster the image has been assigned to under both a 2-cluster and a 7-cluster solution
- For each LSOA, counts the images assigned to each cluster and the percentage of that LSOA's images falling in each cluster
- Finds the mean/median/max embedding for each LSOA, both within each cluster and over all of the LSOA's images
- Saves this information as a single dataframe in a pickle file

In [2]:
import pickle
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from functools import reduce

from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from sklearn.model_selection import train_test_split, GridSearchCV, cross_val_score, KFold
from sklearn.ensemble import RandomForestRegressor

from joblib import Parallel, delayed

def is_missing_embedding(x):
    """Return whether *x* is a float NaN rather than a real embedding.

    Only float values (including numpy float64) can count as missing. Arrays,
    None and any other non-float object yield False. For floats the result is
    the value of ``np.isnan`` (a numpy bool).
    """
    if not isinstance(x, float):
        return False
    return np.isnan(x)

from functions import *

In [3]:
# Root directory for embedding inputs/outputs. Kept with a trailing slash
# because paths below are built by plain string concatenation.
data_dir = "../../../../data/embeddings/"

In [None]:
# Number of clusters to summarise over: selects the `scene_cluster_{k}`
# column and is embedded in the output file name.
k: int = 7

### Get data

In [4]:
# Load the per-image dataframe with cluster assignments. Later cells read its
# "LSOA21CD", "scene_cluster_{k}" and "embedding" columns.
# pickle.load can execute arbitrary code: only load trusted, locally produced files.
points_data_cache = data_dir + "embedding_summaries/expanded_gdf_withclustering.pkl"
with open(points_data_cache, mode="rb") as cache_file:
    expanded_gdf = pickle.load(cache_file)

In [5]:
# cluster_col = "scene_cluster_7"

# cluster_sizes = expanded_gdf[cluster_col].value_counts()
# target_n = cluster_sizes.min()

# # print(cluster_sizes)
# # print(f"Downsampling each cluster to {target_n} rows")

# balanced_gdf = (expanded_gdf
#     .groupby(cluster_col, group_keys=False)
#     .apply(lambda x: x.sample(n=target_n, random_state=42))
#     .reset_index(drop=True))

# cluster_sizes = balanced_gdf[cluster_col].value_counts()

# def balanced_sample(df, cluster_col, random_state=None):
#     target_n = df[cluster_col].value_counts().min()
#     return (
#         df
#         .groupby(cluster_col, group_keys=False)
#         .apply(lambda x: x.sample(n=target_n, random_state=random_state))
#         .reset_index(drop=True))

# # Example: repeat 20 times
# samples = [balanced_sample(expanded_gdf, "scene_cluster_7", rs) for rs in range(20)]

# Create a dataframe with % of images in each category, in each LSOA 

In [7]:
df = expanded_gdf
category_column = f"scene_cluster_{k}"

# Images per (LSOA, cluster) pair -- only pairs that actually occur.
pair_sizes = df.groupby(["LSOA21CD", category_column]).size()
category_counts = pair_sizes.reset_index(name="count")

# Images per LSOA, over all clusters.
lsoa_sizes = df.groupby("LSOA21CD").size()
total_counts = lsoa_sizes.reset_index(name="total_images")

# Attach each LSOA's total to its pairs, then express each pair count as a
# percentage of that total.
category_counts = category_counts.merge(total_counts, on="LSOA21CD")
category_counts["pct"] = category_counts["count"] / category_counts["total_images"] * 100

# One row per LSOA, one column per cluster, for both counts and percentages.
# Clusters absent from an LSOA become 0 (the fillna makes counts float).
wide_tables = {
    value_col: (
        category_counts
        .pivot(index="LSOA21CD", columns=category_column, values=value_col)
        .fillna(0)
        .add_prefix(f"{value_col}_")
    )
    for value_col in ("count", "pct")
}
counts_wide = wide_tables["count"]
pct_wide = wide_tables["pct"]

# Indexed by LSOA21CD: total_images, then count_<cluster>..., then pct_<cluster>...
lsoa_summary = total_counts.set_index("LSOA21CD").join([counts_wide, pct_wide])

# plt.hist(lsoa_summary['total_images'], bins=20)

lsoa_summary.head()

Unnamed: 0_level_0,total_images,count_1,count_2,count_3,count_4,count_5,count_6,count_7,pct_1,pct_2,pct_3,pct_4,pct_5,pct_6,pct_7
LSOA21CD,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1
E01004766,64,5.0,18.0,10.0,11.0,8.0,11.0,1.0,7.8125,28.125,15.625,17.1875,12.5,17.1875,1.5625
E01004767,72,11.0,15.0,5.0,12.0,7.0,17.0,5.0,15.277778,20.833333,6.944444,16.666667,9.722222,23.611111,6.944444
E01004768,44,3.0,9.0,3.0,16.0,0.0,1.0,12.0,6.818182,20.454545,6.818182,36.363636,0.0,2.272727,27.272727
E01004769,40,9.0,7.0,2.0,10.0,1.0,2.0,9.0,22.5,17.5,5.0,25.0,2.5,5.0,22.5
E01004770,40,5.0,12.0,2.0,13.0,0.0,2.0,6.0,12.5,30.0,5.0,32.5,0.0,5.0,15.0


# Find mean/median/max embedding in each LSOA, also by category

In [8]:
# Aggregation functions. Each receives the "embedding" values of one group,
# stacks them with np.stack (so they must share a shape -- presumably 1-D
# vectors, TODO confirm) and reduces over the image axis (axis 0).
def mean_embed(series):
    """Return the element-wise mean of the stacked embeddings in *series*."""
    return np.mean(np.stack(series.values), axis=0)


def max_embed(series):
    """Return the element-wise maximum of the stacked embeddings in *series*."""
    return np.max(np.stack(series.values), axis=0)


def median_embed(series):
    """Return the element-wise median of the stacked embeddings in *series*."""
    return np.median(np.stack(series.values), axis=0)


agg_funcs = {"mean": mean_embed, "max": max_embed, "median": median_embed}

# Cluster labels in order of first appearance; this order fixes the output
# column order within each aggregation.
categories = df[category_column].unique()

# Split by category and group by LSOA once, instead of re-filtering the whole
# frame for every (aggregation, category) pair. GroupBy objects can be
# applied repeatedly.
grouped_by_cat = {
    cat: df[df[category_column] == cat].groupby("LSOA21CD")["embedding"]
    for cat in categories
}
# All images in each LSOA. Uses `df` -- the same frame as the per-category
# summaries -- so both halves of the summary come from one source.
grouped_overall = df.groupby("LSOA21CD")["embedding"]

# One DataFrame per aggregation; merged together at the end.
all_dfs = []

for agg_name, func in agg_funcs.items():
    dfs = []

    # Per-category embeddings: one column "<cat>_<agg_name>".
    for cat, grouped in grouped_by_cat.items():
        emb_cat = grouped.apply(func).reset_index()
        emb_cat = emb_cat.rename(columns={"embedding": f"{cat}_{agg_name}"})
        dfs.append(emb_cat)

    # Outer merge: an LSOA with no images in some category is kept, with NaN
    # in that category's column.
    merged = reduce(lambda left, right: pd.merge(left, right, on="LSOA21CD", how="outer"), dfs)

    # Overall embedding (all images in the LSOA): column "overall_<agg_name>".
    overall = grouped_overall.apply(func).reset_index()
    overall = overall.rename(columns={"embedding": f"overall_{agg_name}"})

    merged = merged.merge(overall, on="LSOA21CD", how="left")

    all_dfs.append(merged)

# Merge mean, max and median summaries into a single DataFrame keyed by LSOA21CD.
final_df = reduce(lambda left, right: pd.merge(left, right, on="LSOA21CD", how="outer"), all_dfs)

# Missing per-category embeddings are deliberately left as NaN (not
# zero-filled); call final_df.fillna(0) here if downstream code needs zeros.

### Save

In [9]:
# Display the per-LSOA embedding summary (the cell's last expression is rendered).
final_df

Unnamed: 0,LSOA21CD,1_mean,6_mean,5_mean,7_mean,3_mean,4_mean,2_mean,overall_mean,1_max,...,2_max,overall_max,1_median,6_median,5_median,7_median,3_median,4_median,2_median,overall_median
0,E01004766,"[0.033103533, -0.07074865, -0.0038302832, -0.0...","[0.03522002, -0.053018257, 0.013809818, -0.003...","[0.023067513, -0.051904604, 0.005900901, -0.01...","[0.06578761, -0.07596135, -0.018992858, 0.0056...","[0.034151368, -0.04616829, -0.01352316, -0.010...","[0.022992719, -0.064196885, 0.017334871, -0.00...","[0.04612167, -0.063937575, -0.0060228854, -0.0...","[0.034810767, -0.058544807, 0.0016876708, -0.0...","[0.046476994, -0.042404536, 0.011767318, 0.004...",...,"[0.065536425, -0.035994805, 0.027840978, 0.022...","[0.06868579, 0.02436501, 0.04058686, 0.0227269...","[0.034778804, -0.08205054, -0.004037957, -0.00...","[0.03324594, -0.06337941, 0.013070603, -0.0017...","[0.03259358, -0.050204903, 0.0049343845, -0.00...","[0.06578761, -0.07596135, -0.018992858, 0.0056...","[0.035405725, -0.04771714, -0.016125362, -0.01...","[0.022437582, -0.06504305, 0.017818907, -0.007...","[0.048336804, -0.062283844, -0.006093976, -0.0...","[0.037450068, -0.060527682, 0.0038238708, -0.0..."
1,E01004767,"[0.0155904, -0.06697523, -0.008416992, 0.00601...","[0.036190536, -0.041433036, 0.01250577, -0.011...","[0.03722277, -0.05215759, -0.011467544, -0.018...","[0.031461813, -0.06416202, -0.025430348, 0.007...","[0.022189489, -0.058764618, -0.016952509, 0.00...","[0.026041085, -0.06665549, 0.013026682, -0.007...","[0.042382684, -0.06791675, -0.0048367633, -0.0...","[0.03144143, -0.05888115, -0.0012278779, -0.00...","[0.0440901, -0.03508209, 0.007759808, 0.041439...",...,"[0.079137824, -0.040383887, 0.02246309, 0.0106...","[0.079137824, 0.007119606, 0.05769042, 0.04143...","[0.014381738, -0.062272016, -0.005507076, -0.0...","[0.037343822, -0.038623177, 0.010105682, -0.01...","[0.031542774, -0.061148256, -0.0093991645, -0....","[0.027648307, -0.051163804, -0.0267478, 0.0071...","[0.014970649, -0.056541987, -0.01938628, 0.010...","[0.023301795, -0.056862667, 0.011832323, -0.00...","[0.0469197, -0.06351485, -0.007057649, -0.0105...","[0.02811988, -0.057094224, -0.002435438, -0.00..."
2,E01004768,"[0.037889328, -0.081487596, -0.0024430014, 0.0...","[0.0106770545, -0.05135075, 0.022684604, -0.01...",,"[0.053052247, -0.0632091, -0.014393783, 0.0062...","[0.035934847, -0.047909573, -0.008653186, -0.0...","[0.033699464, -0.062361263, 0.012203996, -0.01...","[0.038975507, -0.065425605, -0.002837063, -0.0...","[0.039971534, -0.063287765, -0.0003090683, -0....","[0.06717749, -0.04954933, 0.009416995, 0.03412...",...,"[0.058850233, -0.042929746, 0.022226406, 0.007...","[0.08859621, -0.029185485, 0.035205305, 0.0341...","[0.034779258, -0.09420764, -0.0073540835, 0.00...","[0.0106770545, -0.05135075, 0.022684604, -0.01...",,"[0.052555893, -0.057254165, -0.018139582, 0.00...","[0.03992245, -0.04233761, -0.013213266, -0.013...","[0.03190714, -0.06591379, 0.0129283685, -0.017...","[0.03650768, -0.06322911, -0.0031270299, -0.01...","[0.036196973, -0.062007807, -0.0030021807, -0...."
3,E01004769,"[0.037865788, -0.0547471, -0.015366654, -0.006...","[0.061572693, -0.05962702, 0.012585461, -0.015...","[0.044616487, -0.06895213, -0.009721331, 0.008...","[0.047364842, -0.052075706, -0.026688226, -0.0...","[0.05049605, -0.06860547, -0.01464322, -0.0088...","[0.034732457, -0.06680633, 0.0067776283, -0.01...","[0.023240106, -0.06275923, -0.011262887, -0.00...","[0.038645875, -0.059855007, -0.010084866, -0.0...","[0.055508982, -0.026112689, 0.00014157066, 0.0...",...,"[0.04269879, -0.04647873, 0.0036678074, 0.0076...","[0.076253936, -0.026112689, 0.024038304, 0.018...","[0.038979694, -0.059747826, -0.016038358, -0.0...","[0.061572693, -0.05962702, 0.012585461, -0.015...","[0.044616487, -0.06895213, -0.009721331, 0.008...","[0.042015992, -0.04918718, -0.026024027, -0.00...","[0.05049605, -0.06860547, -0.01464322, -0.0088...","[0.031517163, -0.07080994, 0.010840595, -0.014...","[0.02536327, -0.0641494, -0.0074477354, -0.009...","[0.039836742, -0.058352925, -0.011052389, -0.0..."
4,E01004770,"[0.019877654, -0.047637712, -0.00015878165, 0....","[0.034624744, -0.035400465, 0.017509863, -0.01...",,"[0.024157813, -0.06713715, -0.01565353, 0.0050...","[0.025436353, -0.053490806, -0.017503206, -0.0...","[0.030130852, -0.06452417, 0.004716233, -0.010...","[0.047155026, -0.053241987, -0.013663166, -0.0...","[0.03305047, -0.057412803, -0.0049337186, -0.0...","[0.040416043, -0.021657884, 0.014084354, 0.012...",...,"[0.06972161, -0.03507684, 0.019781169, 0.00565...","[0.06972161, -0.021657884, 0.03964949, 0.02837...","[0.019066552, -0.053142264, -0.00491653, 0.006...","[0.034624744, -0.035400465, 0.017509863, -0.01...",,"[0.024650805, -0.064850114, -0.008979902, 0.00...","[0.025436353, -0.053490806, -0.017503206, -0.0...","[0.032980166, -0.06863478, 0.0043842085, -0.01...","[0.050275862, -0.052111566, -0.013365662, -0.0...","[0.0318129, -0.059660837, -0.0050424268, -0.00..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1690,E01005901,,,,,,"[0.02158466, -0.06535083, 0.002541018, -0.0120...","[0.027051007, -0.056315992, -0.009303177, -0.0...","[0.024317836, -0.06083342, -0.0033810788, -0.0...",,...,"[0.06334833, -0.03230919, 0.02370188, 0.005372...","[0.06334833, -0.03230919, 0.02370188, 0.015132...",,,,,,"[0.024750924, -0.06543581, 0.002196563, -0.014...","[0.028607313, -0.055382513, -0.010520168, -0.0...","[0.026195362, -0.060958337, -0.003408682, -0.0..."
1691,E01005939,,,,,,"[0.024516538, -0.060917716, 0.010709342, -0.02...","[0.031052334, -0.053518843, -0.008302648, -0.0...","[0.027784435, -0.057218272, 0.0012033462, -0.0...",,...,"[0.06353953, -0.033973522, 0.0065727094, -0.00...","[0.06353953, -0.033973522, 0.035203513, -0.004...",,,,,,"[0.02477602, -0.0622021, 0.007916713, -0.02307...","[0.02641822, -0.05552365, -0.010795528, -0.014...","[0.026053756, -0.058995366, -0.0010066541, -0...."
1692,E01005980,,,,,,"[0.025245028, -0.059243917, 0.006049343, -0.01...","[0.032022417, -0.054652877, 0.007733837, -0.02...","[0.028633721, -0.0569484, 0.0068915905, -0.019...",,...,"[0.07222642, -0.044546586, 0.029492278, -0.003...","[0.07222642, -0.04029348, 0.029492278, -0.0033...",,,,,,"[0.02644438, -0.062499933, 0.005549832, -0.016...","[0.029433219, -0.05518974, 0.006745371, -0.022...","[0.027387286, -0.05704549, 0.0055932594, -0.02..."
1693,E01006287,,,,,,"[0.014723813, -0.06370592, 0.019477382, 0.0004...","[0.016617486, -0.058633998, -0.0058315084, -0....","[0.01567065, -0.06116996, 0.0068229362, -0.001...",,...,"[0.03474071, -0.04220454, 0.0017510265, 0.0065...","[0.03474071, -0.04220454, 0.023362806, 0.01884...",,,,,,"[0.01513793, -0.06447614, 0.018985176, 0.00378...","[0.022894237, -0.05684855, 0.0002531857, 0.001...","[0.019345962, -0.06159339, 0.009163696, 0.0014..."


In [32]:
# Attach the per-LSOA image counts and cluster percentages to the embedding
# summaries. `on="LSOA21CD"` matches final_df's LSOA21CD column against
# lsoa_summary's LSOA21CD index (inner join).
# The guard keeps this cell idempotent. Re-running it would otherwise merge a
# second time and duplicate every total/count/pct column with _x/_y suffixes.
if set(lsoa_summary.columns).isdisjoint(final_df.columns):
    final_df = final_df.merge(lsoa_summary, on="LSOA21CD")

# NOTE(review): the file name says "resampled1", but the balanced-resampling
# code in this notebook is commented out and `df = expanded_gdf` uses the full
# data. Confirm the intended name before relying on it downstream.
file_ending = f'kmeanscluster{k}_resampled1'
final_df.to_pickle(data_dir + f"embedding_summaries/big_summary_df_{file_ending}.pkl")