In [169]:
## Dataset Things

from data.volumes import Volume, Page, Line
from data.graphset import Graphset
from data.graph_sampler import GraphSampler, AttributeSampler
from data.image_dataset import ImageDataset


import torch_geometric.utils as tutils


import data.volumes as dv

## Model Things
from models import gnn_encoders as gnn
from models import visual_encoders as cnn
from models import edge_visual_encoders as EVE
from models.graph_construction_model import MMGCM


### Utils
import utils 
import visualizations as visu


## Pipelines
import pipelines as pipes

## tasks
from tasks import record_linkage as rl

import warnings
warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')

import torch
import torch.functional as F
import numpy as np
import requests
import cv2
import torchvision
from PIL import Image
from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image
from pytorch_grad_cam import GradCAM

from hydra import initialize, initialize_config_module, initialize_config_dir, compose
from omegaconf import OmegaConf

from typing import *
from torch.utils.data import DataLoader

In [170]:
device = "cuda" if torch.cuda.is_available() else "cpu"

## Hydra Configuration

In [177]:
overrides = [
    "data.dataset.path=../data/CED/SFLL",
    "models.edge_visual_encoder.add_attention=False",
    "models.add_language=False",    
]

In [178]:
with initialize(version_base="1.3.2", config_path="./configs"):
    CFG = compose(config_name="eval", overrides=overrides, return_hydra_config=True)
    

In [179]:
cfg_models = CFG.models
cfg_data = CFG.data
cfg_setup = CFG.setup

In [174]:
#! The utility of this dictionary is to relate the groundtruth with the visual information AD-HOC
pk = {"Noms_harmo":"nom", "cognom1_harmo":"cognom_1", "cognom2_harmo":"cognom_2", "parentesc_har":"parentesc", "ocupacio":"ocupacio"}

#  ^ Hydra things

batch_size = cfg_data.collator.batch_size
shuffle = cfg_data.collator.shuffle
number_volum_years = len(cfg_data.dataset.volumes) 
checkpoint_name = cfg_models.name_checkpoint

In [180]:
checkpoint_name

'MMGC_Experiment_1_New_Edge_PE_no_Attention_language'

In [181]:
cfg_models.edge_visual_encoder.add_attention

False

## Extract the Data

In [62]:
## & Extract the dataset and the information in this case 
volumes = pipes.load_volumes(cfg=cfg_data) 
image_dataset = ImageDataset(Volumes=volumes, cfg=cfg_data.dataset)
df_transcriptions = image_dataset._total_gt
n_different_individuals = image_dataset._total_individual_nodes
graphset = Graphset(total_nodes=n_different_individuals,
                    df_transcriptions=df_transcriptions,
                    n_volumes=len(volumes),
                    graph_configuration=cfg_data.graph_configuration,
                    auxiliar_entities_pk = pk)

sampler = AttributeSampler(graph=graphset._graph, batch_size=batch_size, shuffle=shuffle)


print("Generating DataLoader")

total_loader = DataLoader(dataset=image_dataset, 
                        batch_size = batch_size,
                        collate_fn=image_dataset.collate_fn,
                        num_workers=0,
                        shuffle=True,
                        pin_memory=True)

print("DATA LOADED SUCCESFULLY")

STARTING DOWNLOADING VOLUMES: VOLUME- ../data/CED/SFLL/1889
050000120052048,0014.jpg
050000120052048,0031.jpg
050000120052048,0060.jpg
050000120052048,0061.jpg
050000120052048,0074.jpg
050000120052048,0082.jpg
050000120052048,0088.jpg
050000120052048,0091.jpg
STARTING DOWNLOADING VOLUMES: VOLUME- ../data/CED/SFLL/1906
050000120052053,0007.jpg
050000120052053,0022.jpg
050000120052053,0058.jpg
050000120052053,0059.jpg
050000120052053,0111.jpg
STARTING DOWNLOADING VOLUMES: VOLUME- ../data/CED/SFLL/1910
050000120052054,0007.jpg
050000120052054,0021.jpg
050000120052054,0032.jpg
050000120052054,0035.jpg
050000120052054,0040.jpg
050000120052054,0055.jpg
050000120052054,0068.jpg
050000120052054,0073.jpg
STARTING DOWNLOADING VOLUMES: VOLUME- ../data/CED/SFLL/1915
050000120052055,0007.jpg
050000120052055,0008.jpg
050000120052055,0012.jpg
050000120052055,0013.jpg
050000120052055,0015.jpg
050000120052055,0031.jpg
050000120052055,0037.jpg
050000120052055,0048.jpg
050000120052055,0074.jpg
0500001200

Generating Adj for attributes: 462it [00:00, 816.42it/s] 
Generating Adj for attributes: 1086it [00:02, 383.86it/s]
Generating Adj for attributes: 1338it [00:00, 4723.54it/s]
Generating Adj for Same As: 5486it [00:00, 41770.58it/s]
Generating Adj for families: 3045it [00:00, 19591.04it/s]

Generating DataLoader
DATA LOADED SUCCESFULLY





## Load Model

In [63]:
## Model Things
from models import gnn_encoders as gnn
from models import visual_encoders as cnn
from models import edge_visual_encoders as EVE
from models.graph_construction_model import MMGCM
import random
import copy
import os

H = image_dataset.general_height 
W = image_dataset.general_width // 16

cfg_models.edge_visual_encoder.input_atention_mechanism = H*W

model = MMGCM(visual_encoder=cnn.LineFeatureExtractor, gnn_encoder=gnn.AttributeGNN, edge_encoder=EVE.EdgeAttFeatureExtractor, cfg=cfg_models).to(device)
model_name = f"../checkpoints/{checkpoint_name}.pt"
name_embeddings = f"{checkpoint_name}"

model.load_state_dict(torch.load(model_name))
model.to(device)
print("MODEL LOADED SUCCESFULLY")

MODEL LOADED SUCCESFULLY


In [176]:
print("EXTRACTING LANGUAGE EMBEDDINGS")
filepath  = f"../embeddings/language_embeddings_{number_volum_years}_entities_{len(cfg_setup.configuration.compute_loss_on)}.pkl"
language_embeddings = utils.read_pickle(filepath)

print("EXTRACTING VISUAL EMBEDDINGS")
attribute_embeddings = utils.read_pickle(filepath=f"../embeddings/{checkpoint_name}.pkl")
attribute_embeddings = attribute_embeddings.cpu()


print("Embeddings Loaded Succesfully")

graph = graphset._graph
graph.x_language = language_embeddings
graph.epoch_populations = image_dataset._population_per_volume

EXTRACTING LANGUAGE EMBEDDINGS
EXTRACTING VISUAL EMBEDDINGS
Embeddings Loaded Succesfully


## Plotting distributions and analysi of the distances

In [182]:
base_save_plot = f"../plots/evaluation/{checkpoint_name}/"
base_save_metrics = f"../metrics/evaluation/{checkpoint_name}/"
os.makedirs(base_save_plot, exist_ok=True)

### Filter the most common transcriptions and the least common transcriptions

In [183]:
filtered_name_attributes = {label: indexes for label, indexes in graph["nom"].map_attribute_index.items() if len(indexes) > 20}
filtered_surname_attributes = {label: indexes for label, indexes in graph["cognom_1"].map_attribute_index.items() if len(indexes) > 20}
filtered_ssurname_attributes = {label: indexes for label, indexes in graph["cognom_2"].map_attribute_index.items() if len(indexes) > 20} 

In [184]:
### Get a random name, surname and second surname
random_name = random.sample((filtered_name_attributes.keys()), 1)[0]
random_surname = random.sample((filtered_surname_attributes.keys()), 1)[0]
random_second_surname = random.sample((filtered_ssurname_attributes.keys()), 1)[0]

In [185]:
random_name

'bonaventura'

## Plot the drift of the whole population before and after the model feature extraction

#### Name

In [186]:
print(attribute_embeddings.shape)

torch.Size([13829, 3, 128])


In [187]:
population = graph["nom"].map_attribute_index[random_name]
nom_idx = list(graph.map_attribute_nodes.values()).index("nom")
nom_embeddings = attribute_embeddings[:, nom_idx, :].numpy()
previous_embeddings = copy.copy(graph.x_attributes).numpy()

In [188]:
population

array([   59,    73,  1481,  1483,  1715,  1902,  2272,  2379,  2484,
        3860,  3935,  4598,  4709,  6186,  6448,  6499,  6532,  6681,
        6909,  7392,  7467,  8837,  9363,  9920, 10322, 10534, 11585,
       12479, 12522, 12622, 13286])

In [189]:
visu.plot_shift(embedding_matrix_1=previous_embeddings[population, nom_idx,:],
                embedding_matrix_2= nom_embeddings[population],
                fig_path=os.path.join(base_save_plot, "nom", f"shift_{random_name}.jpg"))



In [190]:
visu.plot_embedding_distribution(nom_embeddings, os.path.join(base_save_plot, "nom", f"distribution.jpg"))


#### Surname

In [191]:
population = graph["cognom_1"].map_attribute_index[random_surname]
cognom_idx = list(graph.map_attribute_nodes.values()).index("cognom_1")
cognom_embeddings = attribute_embeddings[:, cognom_idx, :].numpy()

In [192]:
visu.plot_shift(embedding_matrix_1=previous_embeddings[population, cognom_idx,:],
                embedding_matrix_2= cognom_embeddings[population],
                fig_path=os.path.join(base_save_plot, "cognom_1", f"shift_{random_surname}.jpg"))

In [193]:
visu.plot_embedding_distribution(cognom_embeddings, os.path.join(base_save_plot, "cognom_1", f"distribution.jpg"))


#### Second Surname

In [194]:
population = graph["cognom_2"].map_attribute_index[random_second_surname]
cognom2_idx = list(graph.map_attribute_nodes.values()).index("cognom_2")
cognom2_embeddings = attribute_embeddings[:, cognom2_idx, :].numpy()

In [195]:
visu.plot_shift(embedding_matrix_1=previous_embeddings[population, cognom2_idx,:],
                embedding_matrix_2 = cognom2_embeddings[population],
                fig_path=os.path.join(base_save_plot, "cognom_2", f"shift_{random_second_surname}.jpg"))

In [196]:
visu.plot_embedding_distribution(cognom2_embeddings, os.path.join(base_save_plot, "cognom_2", f"distribution.jpg"))


## Evaluate the distance metri space from each of the attributes

### Evaluate the cluster distances

In [197]:
computed_freq_name = {name: len(values) for name, values in graph["nom"].map_attribute_index.items() if len(values) > 10}
computed_freq_surname = {name: len(values) for name, values in graph["cognom_1"].map_attribute_index.items() if len(values) > 10}
computed_freq_ssurname = {name: len(values) for name, values in graph["cognom_2"].map_attribute_index.items() if len(values) > 10}
    

In [198]:
name_freq  = sorted(list(computed_freq_name.items()), key= lambda k: (k[1]), reverse=True)
surname_freq  = sorted(list(computed_freq_surname.items()), key= lambda k: (k[1]), reverse=True)
second_surnname_freq  = sorted(list(computed_freq_ssurname.items()), key= lambda k: (k[1]), reverse=True)

most_common_names, less_common_names = name_freq[:10], name_freq[-10:]
most_common_surnames, less_common_surnames = surname_freq[:10], surname_freq[-10:]
most_common_ssurnames, less_common_ssurnames = second_surnname_freq[:10], second_surnname_freq[-10:]


In [199]:
most_common_names

[('josep', 1114),
 ('maria', 999),
 ('joan', 845),
 ('josefa', 536),
 ('jaume', 491),
 ('dolores', 471),
 ('teresa', 401),
 ('rosa', 393),
 ('francisco', 391),
 ('antonio', 378)]

In [200]:
less_common_names

[('avelina', 11),
 ('clotilde', 11),
 ('eusebi', 11),
 ('llucia', 11),
 ('pasqual', 11),
 ('rafela', 11),
 ('remigia', 11),
 ('serafina', 11),
 ('sever', 11),
 ('ursicina', 11)]

#### Name

In [201]:
kepp_indexes_name = [(label, graph["nom"].map_attribute_index[label]) for label, _ in most_common_names + less_common_names  ]
visu.plot_violin_plot_from_freq_attribute_distances(specific_attribute_embeddings=attribute_embeddings[:,0,:], dic_realtion_names=kepp_indexes_name,
                                                    file_path=os.path.join(base_save_plot, "nom", f"common_distances_distribution.jpg"))

#### Surname

In [202]:
kepp_indexes_surname = [(label, graph["cognom_1"].map_attribute_index[label]) for label, _ in most_common_surnames + less_common_surnames  ]
visu.plot_violin_plot_from_freq_attribute_distances(specific_attribute_embeddings=attribute_embeddings[:,1,:], dic_realtion_names=kepp_indexes_surname,
                                                    file_path=os.path.join(base_save_plot, "cognom_1", f"common_distances_distribution.jpg"))


#### Second Surname

In [203]:
kepp_indexes_ssurname = [(label, graph["cognom_2"].map_attribute_index[label]) for label, _ in most_common_ssurnames + less_common_ssurnames  ]
visu.plot_violin_plot_from_freq_attribute_distances(specific_attribute_embeddings=attribute_embeddings[:,2,:], dic_realtion_names=kepp_indexes_ssurname,
                                                    file_path=os.path.join(base_save_plot, "cognom_2", f"common_distances_distribution.jpg"))


In [204]:
visu.plot_attribute_metric_space(attribute_embedding_space=attribute_embeddings.cpu(),
                                    fig_name=os.path.join(base_save_plot, "Attributes_Distribution.jpg"))


## See the nearest Neighboors of the different attributes

In [205]:
def r_precision(relevant_documents, retrieved_documents):
    """
    Calculate the R-Precision metric.
    
    Parameters:
    relevant_documents (list): List of relevant document IDs.
    retrieved_documents (list): List of retrieved document IDs in ranked order.

    Returns:
    float: R-Precision score.
    """
    # The number of relevant documents in the ground truth
    R = len(relevant_documents)
    
    # Take the top-R retrieved documents
    top_r_retrieved = retrieved_documents[:R]
    
    # Count how many of the top-R retrieved documents are relevant
    relevant_retrieved_count = len(set(top_r_retrieved) & set(relevant_documents))
    
    # R-Precision is the ratio of relevant documents retrieved in top-R to R
    return relevant_retrieved_count / (R) 



@torch.no_grad()
def extract_intra_cluster_nearest_neighboors_at_k(x,
                                                  top_k=10):
    #Graph = utils.read_pickle("pickles/graphset_3_volumes_128_entities_4.pkl")

    dict_nearest_neighbors = {}
    distances_dict = {}
    for idx in range(x.shape[1]):
        embeddings = x[:, idx]
        distances = torch.cdist(embeddings, embeddings, p=2)
        distances_dict[idx] = distances
        
        # Set the diagonal to a large positive number to avoid self-matching
        distances.fill_diagonal_(float('inf'))
        # Get the top K nearest neighbors for each embedding (smallest distances)
        top_k_values, top_k_indices = torch.topk(distances, top_k, dim=1, largest=False)

        # Create a dictionary mapping each index to its top K nearest neighbors
        nearest_neighbors = {ix: top_k_indices[ix].tolist() for ix in range(distances.size(0))}

        dict_nearest_neighbors[idx] = nearest_neighbors
        
    return dict_nearest_neighbors, distances_dict
    

## Metric space of the language model

In [206]:
language_embeddings = torch.from_numpy(np.array(language_embeddings)).unsqueeze(1)
language_embeddings.shape

torch.Size([2948, 1, 128])

In [207]:
image_lines = image_dataset._line_paths
nn, distances = extract_intra_cluster_nearest_neighboors_at_k(language_embeddings, top_k=language_embeddings.shape[0])

In [91]:
list_ocrs = {"nom": [], "cognom_1":[], "cognom_2": []}
for nom, cog, cog2 in (image_dataset._ocrs):
    
    list_ocrs["nom"].append(image_dataset._map_ocr[nom])
    list_ocrs["cognom_1"].append(image_dataset._map_ocr[cog])
    list_ocrs["cognom_2"].append(image_dataset._map_ocr[cog2])
    

In [92]:
print(nn[0][0])

[409, 870, 1181, 1216, 1739, 165, 1580, 942, 2417, 1626, 1340, 1047, 2892, 1062, 856, 2073, 1125, 2263, 1817, 1599, 2752, 1801, 1271, 2745, 40, 150, 2035, 1404, 61, 387, 151, 2591, 2777, 607, 1428, 358, 2022, 82, 1896, 1057, 784, 2880, 2307, 2068, 2415, 940, 2424, 1035, 2798, 1990, 212, 2775, 1128, 2152, 168, 808, 1302, 627, 1827, 1657, 1263, 648, 293, 139, 2945, 1796, 1018, 2504, 560, 689, 441, 1256, 1681, 598, 1655, 75, 1036, 558, 1516, 864, 2580, 544, 2054, 1251, 2429, 2122, 1594, 1810, 1237, 952, 1856, 481, 490, 2261, 1220, 561, 226, 2917, 1994, 731, 990, 1520, 919, 1792, 805, 2361, 1793, 2649, 1565, 2146, 1679, 1704, 2653, 236, 2379, 2252, 2327, 1038, 1577, 1713, 987, 112, 1134, 2593, 2803, 2249, 674, 1821, 1940, 1208, 158, 2265, 721, 228, 2298, 1582, 1384, 1575, 2763, 2804, 1299, 2853, 2200, 2847, 1678, 2300, 834, 1978, 2270, 1332, 688, 2195, 2784, 2872, 1068, 2145, 2622, 2808, 1190, 1020, 1794, 1506, 2127, 2744, 923, 16, 1026, 2102, 2633, 1938, 1021, 2174, 532, 2720, 1439, 2490,

In [93]:
metrics_R_precission = {}
for (att_idx, att) in [(0, "nom"), (1, "cognom_1"), (2, "cognom_2")]:
    metrics_R_precission[att] = {}
    att_mean_metric = 0
    order_ocrs_individuals = np.array(list_ocrs[att])
    for ind_idx in range(language_embeddings.shape[0]):
        
        try:
            content_attribute = graph[att].map_index_attribute[ind_idx]
        except:
            continue
        
        relevant_individuals = graph[att].map_attribute_index[content_attribute]
        relevant_languages = order_ocrs_individuals[relevant_individuals]
               
        #if ind_idx == 2948:
        #    print(relevant_languages)
        #    break
        
        value = r_precision(relevant_documents=relevant_languages, retrieved_documents=nn[0][ind_idx])
        att_mean_metric += value
        if metrics_R_precission.get(content_attribute, None) is None:
            metrics_R_precission[att][content_attribute] = [value * 100]
            
        else:
            metrics_R_precission[att][content_attribute].append(value)
            
    metrics_R_precission[att]["mean_recall"] = att_mean_metric/len(metrics_R_precission[att])
        

In [94]:
metrics_R_precission["cognom_2"]["mean_recall"]

0.005385880091916884

In [95]:
#0.010 nom
#0.04 cognom
#0.03 cognom_2

## Metric space of the visual model

In [208]:
nn, distances = extract_intra_cluster_nearest_neighboors_at_k(attribute_embeddings, top_k=attribute_embeddings.shape[0])

In [209]:
attribute_embeddings.shape

torch.Size([13829, 3, 128])

In [210]:
graph["nom"].map_index_attribute[33]

'jaume'

In [211]:
name_freq  = sorted(list(computed_freq_name.items()), key= lambda k: (k[1]), reverse=True)
surname_freq  = sorted(list(computed_freq_surname.items()), key= lambda k: (k[1]), reverse=True)
second_surnname_freq  = sorted(list(computed_freq_ssurname.items()), key= lambda k: (k[1]), reverse=True)

name_high_freq, name_low_freq = [i[0] for i in name_freq if i[1] >= 100],  [i[0] for i in name_freq if i[1] <= 20] 
surname_high_freq, surname_low_freq = [i[0] for i in surname_freq if i[1] >= 100],  [i[0] for i in surname_freq if i[1] <= 20] 
ssurname_high_freq, ssurname_low_freq = [i[0] for i in second_surnname_freq if i[1] >= 100],  [i[0] for i in second_surnname_freq if i[1] <= 20] 


In [212]:
metrics_R_precission = {}
for (att_idx, att, high_freq, low_freq) in [(0, "nom", name_high_freq, name_low_freq), (1, "cognom_1", surname_high_freq, surname_low_freq), (2, "cognom_2", ssurname_high_freq, ssurname_low_freq)]:
    metrics_R_precission[att] = {}
    att_mean_low = 0
    att_mean_high = 0
    att_mean = 0
    count_high = 0
    count_low = 0
    count = 0
    for ind_idx in range(attribute_embeddings.shape[0]):
        
        try:
            content_attribute = graph[att].map_index_attribute[ind_idx]
            
        except:
            continue
        
        if metrics_R_precission.get(content_attribute, None) is None:
            metrics_R_precission[att][content_attribute] = []
        
        
        relevant_individuals = graph[att].map_attribute_index[content_attribute]
        value = r_precision(relevant_documents=relevant_individuals, retrieved_documents=nn[att_idx][ind_idx])
        att_mean += value
        
        if content_attribute in high_freq:
            count_high += 1
            att_mean_high += value           
        
        elif content_attribute in low_freq:
            count_low += 1
            att_mean_low += value           

        metrics_R_precission[att][content_attribute].append(value)
                
    metrics_R_precission[att]["high_mean_recall"] = att_mean_high/(count_high)
    metrics_R_precission[att]["low_mean_recall"] = att_mean_low/(count_low)
    metrics_R_precission[att]["mean_recall"] = att_mean/(attribute_embeddings.shape[0])


            

In [213]:
print(metrics_R_precission["nom"]["mean_recall"])
print(metrics_R_precission["nom"]["high_mean_recall"])
print(metrics_R_precission["nom"]["low_mean_recall"])

0.02878049812218453
0.041112067630115136
0.0008279450032342326


In [214]:
print(metrics_R_precission["cognom_1"]["mean_recall"])
print(metrics_R_precission["cognom_1"]["high_mean_recall"])
print(metrics_R_precission["cognom_1"]["low_mean_recall"])

0.016342247519383804
0.018197421607707962
0.02151953542645649


In [215]:
print(metrics_R_precission["cognom_2"]["mean_recall"])
print(metrics_R_precission["cognom_2"]["high_mean_recall"])
print(metrics_R_precission["cognom_2"]["low_mean_recall"])

0.014176128352925053
0.01633218919423515
0.016796353925177802


## Record Linkage Task

### Evaluate the inter attribute metric space 

In [216]:
knn_metr = utils.evaluate_attribute_metric_space(attribute_embeddings.numpy(), plot_path=os.path.join(base_save_plot, f"inter_attr_distribution.jpg"))


Fold 0:

STARTING THE EVALUATION WITH KNN, n_neighbors = 1

STARTING THE EVALUATION WITH KNN, n_neighbors = 3

STARTING THE EVALUATION WITH KNN, n_neighbors = 5

STARTING THE EVALUATION WITH KNN, n_neighbors = 10
Fold 1:

STARTING THE EVALUATION WITH KNN, n_neighbors = 1

STARTING THE EVALUATION WITH KNN, n_neighbors = 3

STARTING THE EVALUATION WITH KNN, n_neighbors = 5

STARTING THE EVALUATION WITH KNN, n_neighbors = 10
Fold 2:

STARTING THE EVALUATION WITH KNN, n_neighbors = 1

STARTING THE EVALUATION WITH KNN, n_neighbors = 3

STARTING THE EVALUATION WITH KNN, n_neighbors = 5

STARTING THE EVALUATION WITH KNN, n_neighbors = 10
Fold 3:

STARTING THE EVALUATION WITH KNN, n_neighbors = 1

STARTING THE EVALUATION WITH KNN, n_neighbors = 3

STARTING THE EVALUATION WITH KNN, n_neighbors = 5

STARTING THE EVALUATION WITH KNN, n_neighbors = 10
Fold 4:

STARTING THE EVALUATION WITH KNN, n_neighbors = 1

STARTING THE EVALUATION WITH KNN, n_neighbors = 3

STARTING THE EVALUATION WITH KNN, n_n

In [217]:
knn_metr

{'inter_att_dist': {'Accuracy_dist': {1: (1.0, 0.0),
   3: (1.0, 0.0),
   5: (1.0, 0.0),
   10: (1.0, 0.0)},
  'F_score_dist': {1: (1.0, 0.0),
   3: (1.0, 0.0),
   5: (1.0, 0.0),
   10: (1.0, 0.0)}}}

### Record Linkage Task

In [258]:
checkpoint_name = "MMGC_Experiment_1_New_Edge_PE_no_Attention_language/MMGC_Experiment_1_New_Edge_PE_no_Attention_language_train_199"
print("EXTRACTING LANGUAGE EMBEDDINGS")
filepath  = f"../embeddings/language_embeddings_{number_volum_years}_entities_{len(cfg_setup.configuration.compute_loss_on)}.pkl"
language_embeddings = utils.read_pickle(filepath)

print("EXTRACTING VISUAL EMBEDDINGS")
attribute_embeddings = utils.read_pickle(filepath=f"../embeddings/{checkpoint_name}.pkl")
attribute_embeddings = attribute_embeddings.cpu()


print("Embeddings Loaded Succesfully")

graph = graphset._graph
graph.x_language = language_embeddings
graph.epoch_populations = image_dataset._population_per_volume

EXTRACTING LANGUAGE EMBEDDINGS
EXTRACTING VISUAL EMBEDDINGS
Embeddings Loaded Succesfully


In [259]:
candidate_pairs = graph[("individual", "similar", "individual")].negative_sampling
true_pairs = graph[("individual", "similar", "individual")].edge_index

In [260]:
# extract the population from the last time
final_time_gap_population = list(graph.epoch_populations[-2]) + list(graph.epoch_populations[-1]) 
final_time_gap_population = torch.tensor(final_time_gap_population).type(torch.int64)

In [261]:
len(graph.epoch_populations)

4

#### Masking the population that appears in the last pair of periods (Test)

In [262]:
mask = torch.isin(true_pairs[:2, :], final_time_gap_population).all(dim=0)
mask_candidate = torch.isin(candidate_pairs[:2,:], final_time_gap_population).all(dim=0)

In [263]:
specific_true_pairs_subgraph = true_pairs[:, mask]
specific_subgraph_candidate = candidate_pairs[:, mask_candidate]

### Test same as embeddings

In [264]:
X_test_indexes = torch.cat((specific_true_pairs_subgraph, specific_subgraph_candidate), dim=1).type(torch.int32).numpy()
X_test = (attribute_embeddings[X_test_indexes[0], :] - attribute_embeddings[X_test_indexes[1],:]).pow(2).sum(-1).sqrt()
y_test = X_test_indexes[-1]

### Train 

In [265]:
## ^ TRAIN EXTRACTION 
earlies_time_populations = []
for pop in graph.epoch_populations[:-1]: ## Keep the las two periods
    earlies_time_populations += list(pop)

### Extrat the population of the first periods to train

In [266]:
earlies_time_populations = torch.tensor(earlies_time_populations).type(torch.int64)
mask = torch.isin(true_pairs[:2, :], earlies_time_populations).all(dim=0)
mask_candidate_train = torch.isin(candidate_pairs[:2,:], earlies_time_populations).all(dim=0)

specific_true_pairs_subgraph_train = true_pairs[:, mask]
specific_subgraph_candidate_train = candidate_pairs[:, mask_candidate_train]

### Extract the train embeddings

In [267]:
X_train_indexes = torch.cat((specific_true_pairs_subgraph_train, specific_subgraph_candidate_train), dim=1).numpy()
X_train = (attribute_embeddings[X_train_indexes[0], :] - attribute_embeddings[X_train_indexes[1],:]).pow(2).sum(-1).sqrt()

y_train = X_train_indexes[-1] #torch.cat((torch.ones((specific_true_pairs_subgraph_train.shape[1])), torch.zeros((specific_subgraph_candidate_train.shape[1]))), dim=0).numpy()


In [268]:
np.sum(y_train)

3911.0

## Train the Logistic regression

In [269]:
metrics_rl = rl.record_linkage_with_logistic_regression(X_train=X_train, X_test=X_test, y_train=y_train, y_test=y_test, candidate_sa=X_test_indexes,
                                                                                                                random_state=0, penalty="l2", class_weight="balanced", n_jobs=8)

2425


MEAN Probability:  0.5
STD Probability:  0.08280965206569965
Confusion Matrix:
 [[ 154 4484]
 [ 294 2131]]
Accuracy: 0.3235169191561659
Recall:  0.8787628865979381
Precission:  0.3221466364323507
F-score:  0.47146017699115034
Specificity:  0.044746787603930464
