Copy and move BenthicNet data over to local node storage

In [None]:
import subprocess

move_script_path = "./slurm/copy_and_extract_data.sh"
# check=False keeps the notebook going if the copy script fails, so surface a
# non-zero exit status explicitly instead of silently discarding it.
move_result = subprocess.run(["bash", move_script_path], check=False)
if move_result.returncode != 0:
    print(
        f"WARNING: {move_script_path} exited with status {move_result.returncode}; "
        "data may be missing from local storage."
    )

Display system GPU resources

In [None]:
import torch


def get_available_gpus():
    """Return the device names of all visible CUDA GPUs.

    Returns:
        A list of name strings, one per CUDA device; an empty list when CUDA
        is not available.
    """
    if not torch.cuda.is_available():
        return []
    device_ids = range(torch.cuda.device_count())
    return list(map(torch.cuda.get_device_name, device_ids))


# List the visible GPUs, numbered from 1.
available_gpus = get_available_gpus()
if not available_gpus:
    print("No GPUs available on the system.")
else:
    print("Available GPUs:")
    for gpu_number, gpu in enumerate(available_gpus, start=1):
        print(f"GPU {gpu_number}: {gpu}")

Set arguments and parameters

In [None]:
# Name of the pretrained encoder checkpoint. Kept distinct from `model`, which
# is later bound to the constructed network object.
model_name = "hp_imagenet_v2_rn50"
train_cfg_path = "./cfgs/cnn/resnet50_hl.json"
graph_path = "./graph_info/finalized_output.csv"
model_checkpoint = f"./pretrained_encoders/{model_name}.ckpt"
data_csv_path = "./data_csv/benthicnet_nn.csv"
# Cluster-specific absolute path (passed to BenthicNetDataset as tar_dir);
# adjust for the system the notebook runs on.
tar_dir = "/gpfs/project/6012565/become_labelled/compiled_labelled_512px/tar"

batch_size = 256
num_workers = 4

# Annotation head to evaluate; used as a key into _HEAD_TO_COLUMN_MAP
# ("biota", "substrate", "relief" or "bedforms").
test_head = "substrate"

# Forwarded to model.predict_head_step as random_out.
test_random = False

Define graph-related helper functions

In [None]:
import ast

import networkx as nx
import numpy as np

_HEAD_TO_COLUMN_MAP = {
    "biota": "CATAMI Biota",
    "substrate": "CATAMI Substrate",
    "relief": "CATAMI Relief",
    "bedforms": "CATAMI Bedforms",
}

_HEAD_TO_MASK_MAP = {
    "biota": "Biota Mask",
    "substrate": "Substrate Mask",
    "relief": "Relief Mask",
    "bedforms": "Bedforms Mask",
}


# Utilities for graph functions
def check_elements(list_a, list_b):
    """Return True if any element of ``list_a`` also occurs in ``list_b``.

    Both arguments may be any iterables of hashable items (lists, sets,
    1-D NumPy arrays). ``list_b`` is turned into a set once, so the check is
    linear in the combined size. The original rebuilt that set for every
    element of ``list_a``.

    Returns:
        A Python bool; False when either input is empty.
    """
    return not set(list_b).isdisjoint(list_a)


# General graph related functions
def create_subgraph(G, radius):
    """Collect the part of ``G`` that lies within ``radius`` hops of its roots.

    Self-loops are removed on a copy, so ``G`` itself is not modified. Root
    nodes are the nodes with in-degree 0 after that removal. Distances are
    shortest-path hop counts that follow edge direction from each root.

    Args:
        G: directed graph. It must be acyclic apart from self-loops, because
            ``nx.topological_sort`` raises otherwise.
        radius: maximum hop distance from a root.

    Returns:
        (pruned_G, nodes_within_radius, nodes_exact_radius):
            pruned_G: DiGraph union of the per-root ego graphs of ``radius``.
            nodes_within_radius: nodes at distance <= radius, concatenated
                root by root. A node reachable from several roots appears once
                per such root.
            nodes_exact_radius: nodes at distance == radius, concatenated the
                same way. A root itself only appears here when radius == 0.
    """
    G = G.copy()
    G.remove_edges_from(nx.selfloop_edges(G))

    root_nodes = [node for node in nx.topological_sort(G) if G.in_degree(node) == 0]

    pruned_G = nx.DiGraph()
    nodes_within_radius = []
    nodes_exact_radius = []
    for start_node in root_nodes:
        sub_G = nx.ego_graph(G=G, n=start_node, radius=radius)

        # cutoff stops the BFS at `radius` instead of walking the whole graph;
        # distances for the nodes it does return are unchanged.
        nodes_and_distances = nx.single_source_shortest_path_length(
            G, start_node, cutoff=radius
        )

        nodes_within_radius += [
            node for node, distance in nodes_and_distances.items() if distance <= radius
        ]
        nodes_exact_radius += [
            node for node, distance in nodes_and_distances.items() if distance == radius
        ]
        pruned_G.update(sub_G)

    return pruned_G, nodes_within_radius, nodes_exact_radius


def filter_data(df, head, depth, leaf_nodes, column):
    """Flag the rows whose active target positions hit any of ``leaf_nodes``.

    Rows with a missing ``column`` value are dropped. For each remaining row,
    the positions where ``column`` equals 1 are checked against
    ``leaf_nodes`` with ``check_elements``. The outcome is stored in a new
    boolean column ``f"{head}_depth_{depth}_applicable"``.

    Args:
        df: DataFrame containing ``column``.
        head: annotation head name, used only to build the new column name.
        depth: depth value, used only to build the new column name.
        leaf_nodes: node positions to look for.
        column: column holding per-row 0/1 vectors.

    Returns:
        A new DataFrame; the input ``df`` is not modified.
    """
    # dropna already returns a new frame; .copy() makes it independent so the
    # column assignment below is not made on a view.
    filtered = df.dropna(subset=[column]).copy()

    # Positions of the bits that are switched on in each row's vector.
    active_positions = filtered[column].apply(lambda x: np.where(np.array(x) == 1)[0])
    filtered[f"{head}_depth_{depth}_applicable"] = active_positions.apply(
        check_elements, args=(leaf_nodes,)
    )
    return filtered


def filter_data_with_hops(df, hops_dict, root_graphs, column):
    """Apply ``filter_data`` once for each (head, hop) pair in ``hops_dict``.

    For each head, the nodes exactly ``hop`` hops from that head's roots
    (from ``create_subgraph``) define which rows are flagged as applicable.

    Args:
        df: DataFrame to annotate.
        hops_dict: mapping head name -> hop distance.
        root_graphs: mapping head name -> graph passed to ``create_subgraph``.
        column: column holding per-row 0/1 vectors.

    Returns:
        (df, nodes_within_radius): the annotated DataFrame and the nodes within
        ``hop`` of the roots for the *last* head processed. If ``hops_dict``
        is empty, ``df`` is returned unchanged together with an empty list.
    """
    # Initialised so an empty hops_dict no longer raises UnboundLocalError.
    nodes_within_radius = []
    for head, hop in hops_dict.items():
        _, nodes_within_radius, nodes_exact_radius = create_subgraph(
            root_graphs[head], hop
        )
        df = filter_data(df, head, hop, nodes_exact_radius, column)
    return df, nodes_within_radius


def find_max_depth(graph):
    """Return the smallest radius at which ``create_subgraph`` reaches every node.

    Radii 0, 1, 2, ... are tried in turn until the nodes reported within the
    radius include every node of ``graph``. For a graph with no nodes the
    result is -1.
    """
    all_nodes = set(graph.nodes())
    if not all_nodes:
        return -1

    depth = 0
    while True:
        _, reached, _ = create_subgraph(graph, depth)
        if all_nodes.issubset(reached):
            return depth
        depth += 1


def remove_relevant_masked_nodes(mask, relevant_nodes):
    """Drop every node listed in ``mask`` from ``relevant_nodes``.

    Args:
        mask: iterable of nodes to exclude.
        relevant_nodes: candidate nodes, in the order they should be kept.

    Returns:
        NumPy array of the surviving nodes in their original order. When
        nothing survives, an empty *integer* array is returned. The result is
        later used to index prediction arrays, and NumPy rejects the float64
        dtype that ``np.array([])`` would otherwise get.
    """
    # Build the lookup once rather than once per candidate node.
    masked = set(mask)
    kept = [node for node in relevant_nodes if node not in masked]
    if not kept:
        return np.array([], dtype=int)
    return np.array(kept)


def extract_relevant_elements(row, col, idx_col):
    """Return ``row[col]`` indexed by the positions stored in ``row[idx_col]``.

    Intended for row-wise ``DataFrame.apply``. The value in ``col`` must
    support fancy indexing with the value in ``idx_col`` (e.g. NumPy arrays).
    """
    values = row[col]
    positions = row[idx_col]
    return values[positions]

Import the BenthicNet dataset class

In [None]:
import os

import PIL

from utils.benthicnet_dataset import BenthicNetDataset

Construct and load model

In [None]:
from json import loads

from omegaconf import OmegaConf

from utils.utils import construct_model, gen_R_mat, gen_root_graphs

# Read the JSON training config and wrap it in OmegaConf for construct_model.
with open(train_cfg_path, "r") as f:
    train_cfg_content = f.read()
train_cfg = loads(train_cfg_content)
train_kwargs = OmegaConf.create(train_cfg)

# Per-head hierarchy graphs with index <-> node lookups, plus one R
# (adjacency-derived) matrix per head graph from gen_R_mat.
root_graphs, idx_to_node, node_to_idx = gen_root_graphs(graph_path)
Rs = dict((root, gen_R_mat(graph)) for root, graph in root_graphs.items())

# Build the network from the config and the pretrained encoder checkpoint.
model = construct_model(train_kwargs, Rs, model_checkpoint, test_mode=True)

Prepare dataloader

In [None]:
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader

from utils.utils import get_augs, get_df, process_data_df

# Load the annotation table and keep only rows labelled for the head under test.
df = get_df(data_csv_path)
df = df[df[_HEAD_TO_COLUMN_MAP[test_head]].notna()]

# Evaluate on the held-out test partition only.
test_df = df[df["partition"] == "test"]

# get_augs returns a pair; only the second (evaluation) transform is used here.
_, val_transform = get_augs(False)

test_df = process_data_df(test_df, Rs)

test_dataset = BenthicNetDataset(
    tar_dir=tar_dir,
    annotations=test_df,
    transform=val_transform,
)

# shuffle=False keeps the sample order deterministic across runs.
test_dataloader = DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=False,
    pin_memory=True,
)

In [None]:
# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)

Run the test loop and collect the model outputs into a DataFrame

In [None]:
# Inference only. Autograd is disabled for this loop via torch.no_grad()
# rather than with torch.set_grad_enabled(False), which would leave gradients
# off for every later cell in the kernel session.
model.eval()

preds = []
tgts = []
masks = []

total_batches = len(test_dataloader)
with torch.no_grad():
    for i, batch in enumerate(test_dataloader):
        print(f"Processing batch {i+1}/{total_batches}...", end="\r")

        # Re-seed per batch. This only matters if predict_head_step draws
        # random numbers (presumably when random_out=True) - TODO confirm.
        torch.manual_seed(i)

        # Only the image tensor is moved to the device; `data` is passed as is.
        inputs, data = batch
        inputs = inputs.to(device)

        batch = [inputs, data]
        batch_preds, batch_tgts, batch_masks = model.predict_head_step(
            batch, head=test_head, random_out=test_random
        )

        assert len(batch_preds) == len(batch_tgts) == len(batch_masks)
        preds.extend(batch_preds.tolist())
        tgts.extend(batch_tgts.tolist())
        masks.extend(batch_masks.tolist())

model_output_dict = {
    "preds": preds,
    "tgts": tgts,
    "masks": masks,
}

model_output_df = pd.DataFrame(model_output_dict)

Define accuracy metrics and label-translation helpers

In [None]:
def depth_df_correct(row):
    """Return True when every relevant predicted bit matches its target bit.

    ``row`` must provide equal-length ``"relevant_preds"`` and
    ``"relevant_tgts"`` sequences. An empty pair counts as correct.
    """
    predicted = row["relevant_preds"]
    expected = row["relevant_tgts"]
    assert len(predicted) == len(expected)
    return all(p_bit == t_bit for p_bit, t_bit in zip(predicted, expected))


def depth_df_avg_bit_f1(row):
    """Average of the per-state F1 scores for bit values 0 and 1.

    For each state s in {0, 1}, the bits whose target equals s are the
    positives. Precision and recall are set to 0 when their denominator is 0,
    and F1 is 0 when precision + recall is 0. Without this, NumPy integer
    0 / 0 yields NaN: a state absent from both target and prediction made the
    whole row's score NaN.

    Args:
        row: mapping with equal-length NumPy arrays ``"relevant_preds"`` and
            ``"relevant_tgts"``.

    Returns:
        The mean of the two per-state F1 scores.
    """
    pred = row["relevant_preds"]
    tgt = row["relevant_tgts"]
    assert len(pred) == len(tgt)

    state_f1s = []
    for state in (0, 1):
        # Bits the target marks as `state`, and what pred says for them
        state_mask = tgt == state
        tp = np.sum(pred[state_mask] == state)
        fn = np.sum(pred[state_mask] != state)
        fp = np.sum(pred[~state_mask] == state)

        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0

        if precision + recall == 0:
            state_f1s.append(0.0)
        else:
            state_f1s.append(2 * precision * recall / (precision + recall))

    return np.mean(state_f1s)


def depth_df_bit_mcc(row):
    """Matthews correlation coefficient between a row's relevant preds and targets.

    Bit value 0 is treated as the positive class. The result is 0 when the
    coefficient is undefined, i.e. when the product of the four marginal
    counts is 0.

    Args:
        row: mapping with equal-length NumPy arrays ``"relevant_preds"`` and
            ``"relevant_tgts"``.
    """
    pred = row["relevant_preds"]
    tgt = row["relevant_tgts"]
    assert len(pred) == len(tgt)

    positive = tgt == 0
    pred_on_positive = pred[positive]
    pred_on_negative = pred[~positive]

    tp = np.sum(pred_on_positive == 0)
    fn = np.sum(pred_on_positive != 0)
    fp = np.sum(pred_on_negative == 0)
    tn = np.sum(pred_on_negative != 0)

    denominator = np.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    if denominator == 0:
        return 0
    return (tp * tn - fp * fn) / denominator


def set_depth_idx_to_text_map(relevant_nodes, idx_to_text_maps, head):
    """Map each position in ``relevant_nodes`` to that node's text label.

    Args:
        relevant_nodes: sequence of node indices.
        idx_to_text_maps: mapping head -> (node index -> text label).
        head: which head's map to use.

    Returns:
        dict {position in relevant_nodes: text label}.
    """
    head_map = idx_to_text_maps[head]
    depth_map = {}
    for position, node_idx in enumerate(relevant_nodes):
        depth_map[position] = head_map[node_idx]
    return depth_map


# Define translation functions
def translate_idx_to_text(idx_array, idx_to_text_map, head):
    """Translate positions to text labels, keeping only the most specific ones.

    Args:
        idx_array: iterable of keys into ``idx_to_text_map``, or the string
            form of a list of them (e.g. ``"[0, 2]"``). Strings containing
            NumPy 2 scalar reprs such as ``"[np.int64(0)]"`` are accepted too.
        idx_to_text_map: mapping position -> text label.
        head: unused; kept for call compatibility.

    Returns:
        The labels, in input order, that are not a substring of any other
        label in the set. This drops ancestors when labels are " > "-joined
        paths. NOTE(review): this is a plain substring test, so an unrelated
        label that happens to be contained in another would also be dropped.
    """
    import re

    if isinstance(idx_array, str):
        # Under NumPy 2, str() of a list of NumPy ints yields "np.int64(3)",
        # which ast.literal_eval rejects; strip the wrapper to the bare literal.
        cleaned = re.sub(r"np\.\w+\(([^()]*)\)", r"\1", idx_array)
        idx_array = ast.literal_eval(cleaned)

    text_array = [idx_to_text_map[idx] for idx in idx_array]
    unique_texts = set(text_array)
    return [
        text
        for text in text_array
        if not any(text in other for other in unique_texts if other != text)
    ]


def translate_idx_to_text_row(row, col_to_translate, map_column, head):
    """Row-wise wrapper around ``translate_idx_to_text``.

    The indices (or their string form) are read from ``row[col_to_translate]``
    and the per-row position -> text map from ``row[map_column]``. Delegating
    removes the verbatim copy of the translation logic that lived here.

    Returns:
        The list returned by ``translate_idx_to_text``.
    """
    return translate_idx_to_text(row[col_to_translate], row[map_column], head)

Main processing loop for depth metrics for fused/superposition labels

In [None]:
import matplotlib.pyplot as plt


def plot_graph(G, size, color):
    """Draw ``G`` with labelled nodes using a seeded spring layout, then show it.

    Args:
        G: networkx graph to draw.
        size: node size passed to ``nx.draw``.
        color: node colour passed to ``nx.draw``.
    """
    # Fixed seed so the layout is identical between runs.
    layout = nx.spring_layout(G, seed=42)
    nx.draw(G, pos=layout, node_size=size, node_color=color, with_labels=True)
    plt.show()

In [None]:
max_depth = find_max_depth(root_graphs[test_head])

depth_dict = {
    test_head: 0,
}

output_df = model_output_df.copy()

# Per-depth macro metrics. A depth with no applicable samples is skipped and
# keeps NaN here, with a sample size of 0.
accs = np.full(max_depth + 1, np.nan)
f1s = np.full(max_depth + 1, np.nan)
mccs = np.full(max_depth + 1, np.nan)
sample_sizes = np.zeros(max_depth + 1)

for depth in range(max_depth + 1):
    # Set depth-wise column names
    relevant_nodes_col = f"relevant_{test_head}_nodes_at_depth_{depth}"
    relevant_map_col = f"idx_to_node_{test_head}_at_depth_{depth}"

    print("\n" + "=" * 60)
    print(f"Processing depth {depth} of {max_depth} for {test_head}...")
    depth_dict[test_head] = depth
    depth_df, nodes_within_radius = filter_data_with_hops(
        output_df, depth_dict, root_graphs, column="tgts"
    )
    # Keep samples whose target has an active node exactly at this depth.
    # .copy() gives an independent frame, so the column assignments below do
    # not write into a slice (SettingWithCopyWarning).
    applicable = depth_df[f"{test_head}_depth_{depth}_applicable"].astype(bool)
    depth_df = depth_df.loc[applicable].copy()

    if depth_df.empty:
        # Row-wise DataFrame.apply on an empty frame does not return a Series,
        # so the column assignments below would fail; skip this depth instead.
        print(f"\tNo applicable samples at depth {depth}; skipping.")
        continue

    # Positions whose mask bit is 0; these are removed from the relevant nodes
    depth_df["_masks"] = depth_df["masks"].apply(
        lambda x: np.where(np.array(x) == 0)[0]
    )

    # Determine which nodes are relevant at current depth after filtering for masks
    depth_df[relevant_nodes_col] = depth_df["_masks"].apply(
        remove_relevant_masked_nodes, args=(nodes_within_radius,)
    )
    depth_df[relevant_map_col] = depth_df[relevant_nodes_col].apply(
        set_depth_idx_to_text_map, args=(idx_to_node, test_head)
    )
    depth_df["preds"] = depth_df["preds"].apply(lambda x: np.array(x))
    depth_df["tgts"] = depth_df["tgts"].apply(lambda x: np.array(x))

    # Filter preds and tgts for relevant nodes
    depth_df["relevant_preds"] = depth_df.apply(
        extract_relevant_elements, args=("preds", relevant_nodes_col), axis=1
    )
    depth_df["relevant_tgts"] = depth_df.apply(
        extract_relevant_elements, args=("tgts", relevant_nodes_col), axis=1
    )

    # Per-sample exact-match accuracy, average bit F1 and bit MCC
    depth_df["correct"] = depth_df.apply(depth_df_correct, axis=1)
    depth_df["avg_bit_f1"] = depth_df.apply(depth_df_avg_bit_f1, axis=1)
    depth_df["bit_mcc"] = depth_df.apply(depth_df_bit_mcc, axis=1)

    # Hashable string keys of the active positions. .tolist() yields plain
    # Python ints, so the text is "[0, 2]" under both NumPy 1.x and 2.x
    # (NumPy 2 would otherwise print "np.int64(0)", which literal_eval rejects).
    depth_df["relevant_tgts_str"] = depth_df["relevant_tgts"].apply(
        lambda x: str(np.where(x == 1)[0].tolist())
    )
    depth_df["relevant_preds_str"] = depth_df["relevant_preds"].apply(
        lambda x: str(np.where(x == 1)[0].tolist())
    )
    depth_acc = []
    depth_bit_f1 = []
    depth_bit_mcc = []

    for label in depth_df["relevant_tgts_str"].unique():
        label_df = depth_df[depth_df["relevant_tgts_str"] == label]
        num_samples_label = len(label_df)
        label_acc = np.sum(label_df["correct"]) / num_samples_label
        label_bit_f1 = np.sum(label_df["avg_bit_f1"]) / num_samples_label
        label_bit_mcc = np.sum(label_df["bit_mcc"]) / num_samples_label

        # Most common prediction for this label and the share of samples given it
        pred_counts = label_df["relevant_preds_str"].value_counts()
        most_common_pred = pred_counts.idxmax()
        most_common_pct = int(pred_counts.max()) / num_samples_label

        # Position -> text map from the first sample that received the most
        # common prediction.
        # NOTE(review): positions index into each row's own mask-filtered
        # relevant-node list, so rows grouped under one label string may in
        # principle use different maps - confirm masks are uniform per label.
        idx_map = label_df.loc[
            label_df["relevant_preds_str"] == most_common_pred, relevant_map_col
        ].iloc[0]

        most_common_pred_text = translate_idx_to_text(
            most_common_pred, idx_map, head=test_head
        )
        label_text = translate_idx_to_text(label, idx_map, head=test_head)

        print("\n\t" + "-" * 40)
        print(f"\tLabel: {label_text}")
        print("\tAcc:", label_acc)
        print("\tF1:", label_bit_f1)
        print("\tMCC:", label_bit_mcc)
        print("\tMost common pred:", most_common_pred_text)
        print("\tPredicted frequency:", most_common_pct)
        print("\tNumber of samples for label:", num_samples_label)

        depth_acc.append(label_acc)
        depth_bit_f1.append(label_bit_f1)
        depth_bit_mcc.append(label_bit_mcc)

    mean_depth_acc = np.mean(depth_acc)
    mean_depth_bit_f1 = np.mean(depth_bit_f1)
    mean_depth_bit_mcc = np.mean(depth_bit_mcc)
    sample_count = len(depth_df)
    print("\n\t" + "-" * 60)
    print(f"\tDepth {depth} acc: {mean_depth_acc}")
    print(f"\tDepth {depth} bit F1: {mean_depth_bit_f1}")
    print(f"\tDepth {depth} bit MCC: {mean_depth_bit_mcc}")
    print(f"\tDepth {depth} sample count: {sample_count}")

    accs[depth] = mean_depth_acc
    f1s[depth] = mean_depth_bit_f1
    mccs[depth] = mean_depth_bit_mcc
    sample_sizes[depth] = sample_count

# Skipped depths hold NaN with zero weight. nansum/nanmean stop them from
# turning the head-level results into NaN (NaN * 0 is still NaN).
head_acc = np.nansum(accs * sample_sizes) / np.sum(sample_sizes)
macro_head_acc = np.nanmean(accs)
head_f1 = np.nansum(f1s * sample_sizes) / np.sum(sample_sizes)
head_mcc = np.nansum(mccs * sample_sizes) / np.sum(sample_sizes)
print("\n" + "=" * 80)
print(f"Sample weighted accuracy for {test_head}: {head_acc}")
print(f"Macro accuracy for {test_head}: {macro_head_acc}")
print(f"Bit F1 for {test_head}: {head_f1}")
print(f"Bit MCC for {test_head}: {head_mcc}")

Define function to decompose labels and score performance on individual targets

In [None]:
def decompose_and_score_labels(
    row, tp_dict, fp_dict, fn_dict, count_dict, depth, verbose=0
):
    """Score one row's text targets and predictions at a single depth.

    Only labels containing exactly ``depth`` occurrences of " > " are scored.
    Among those labels:
    - a target that is also predicted counts as a true positive;
    - a target that is not predicted counts as a false negative;
    - a predicted label that is not a target counts as a false positive.

    The counts are added to the caller's dicts, which are mutated in place and
    keyed by label text. ``count_dict`` counts target occurrences only.

    Args:
        row: mapping with "text_tgts" and "text_preds" lists of label strings.
        tp_dict, fp_dict, fn_dict, count_dict: accumulators updated in place.
        depth: number of " > " separators a label must have to be scored.
        verbose: if > 0, print rows with more unique predictions than targets.
    """
    targets = row["text_tgts"]
    remaining_preds = set(row["text_preds"])

    if verbose > 0 and len(remaining_preds) > len(targets):
        print(f"\tExtra predictions detected for {targets}.")
        print(f"Targets: {targets}")
        print(f"Predictions: {remaining_preds}")

    def at_depth(text):
        return text.count(" > ") == depth

    # label -> (tp, fp, fn, count) for this row; later entries overwrite earlier
    outcomes = {}
    for target in targets:
        if not at_depth(target):
            continue
        if target in remaining_preds:
            remaining_preds.remove(target)
            outcomes[target] = (1, 0, 0, 1)
        else:
            outcomes[target] = (0, 0, 1, 1)

    for pred in remaining_preds:
        if at_depth(pred) and pred not in targets:
            outcomes[pred] = (0, 1, 0, 0)

    for label, (tp, fp, fn, count) in outcomes.items():
        tp_dict[label] = tp_dict.get(label, 0) + tp
        fp_dict[label] = fp_dict.get(label, 0) + fp
        fn_dict[label] = fn_dict.get(label, 0) + fn
        count_dict[label] = count_dict.get(label, 0) + count

Main processing loop for calculating decomposed label accuracy

In [None]:
max_depth = find_max_depth(root_graphs[test_head])

depth_dict = {
    test_head: 0,
}

output_df = model_output_df.copy()

# Per-depth macro metrics over individual target labels. A depth with nothing
# to score keeps NaN here, with a sample size of 0.
precision = np.full(max_depth + 1, np.nan)
recall = np.full(max_depth + 1, np.nan)
f1 = np.full(max_depth + 1, np.nan)
sample_sizes = np.zeros(max_depth + 1)

for depth in range(max_depth + 1):
    # Set depth-wise column names
    relevant_nodes_col = f"relevant_{test_head}_nodes_at_depth_{depth}"
    relevant_map_col = f"idx_to_node_{test_head}_at_depth_{depth}"

    print("\n" + "=" * 60)
    print(f"Processing depth {depth} of {max_depth} for {test_head}...")
    depth_dict[test_head] = depth
    depth_df, nodes_within_radius = filter_data_with_hops(
        output_df, depth_dict, root_graphs, column="tgts"
    )
    # Keep samples whose target has an active node exactly at this depth.
    # .copy() gives an independent frame, so the column assignments below do
    # not write into a slice (SettingWithCopyWarning).
    applicable = depth_df[f"{test_head}_depth_{depth}_applicable"].astype(bool)
    depth_df = depth_df.loc[applicable].copy()

    if depth_df.empty:
        # Row-wise DataFrame.apply on an empty frame does not return a Series,
        # so the column assignments below would fail; skip this depth instead.
        print(f"\tNo applicable samples at depth {depth}; skipping.")
        continue

    # Positions whose mask bit is 0; these are removed from the relevant nodes
    depth_df["_masks"] = depth_df["masks"].apply(
        lambda x: np.where(np.array(x) == 0)[0]
    )

    # Determine which nodes are relevant at current depth after filtering for masks
    depth_df[relevant_nodes_col] = depth_df["_masks"].apply(
        remove_relevant_masked_nodes, args=(nodes_within_radius,)
    )
    depth_df[relevant_map_col] = depth_df[relevant_nodes_col].apply(
        set_depth_idx_to_text_map, args=(idx_to_node, test_head)
    )
    depth_df["preds"] = depth_df["preds"].apply(lambda x: np.array(x))
    depth_df["tgts"] = depth_df["tgts"].apply(lambda x: np.array(x))

    # Filter preds and tgts for relevant nodes
    depth_df["relevant_preds"] = depth_df.apply(
        extract_relevant_elements, args=("preds", relevant_nodes_col), axis=1
    )
    depth_df["relevant_tgts"] = depth_df.apply(
        extract_relevant_elements, args=("tgts", relevant_nodes_col), axis=1
    )

    # Hashable string keys of the active positions. .tolist() yields plain
    # Python ints, so the text is "[0, 2]" under both NumPy 1.x and 2.x
    # (NumPy 2 would otherwise print "np.int64(0)", which literal_eval rejects).
    depth_df["relevant_tgts_str"] = depth_df["relevant_tgts"].apply(
        lambda x: str(np.where(x == 1)[0].tolist())
    )
    depth_df["relevant_preds_str"] = depth_df["relevant_preds"].apply(
        lambda x: str(np.where(x == 1)[0].tolist())
    )

    # Get text labels and preds
    depth_df["text_tgts"] = depth_df.apply(
        translate_idx_to_text_row,
        args=("relevant_tgts_str", relevant_map_col, test_head),
        axis=1,
    )
    depth_df["text_preds"] = depth_df.apply(
        translate_idx_to_text_row,
        args=("relevant_preds_str", relevant_map_col, test_head),
        axis=1,
    )

    # Convert text tgts and preds list to strings
    depth_df["text_tgts_str"] = depth_df["text_tgts"].apply(str)
    depth_df["text_preds_str"] = depth_df["text_preds"].apply(str)

    depth_tp_dict = {}
    depth_fp_dict = {}
    depth_fn_dict = {}
    depth_count_dict = {}

    for label in depth_df["text_tgts_str"].unique():
        label_df = depth_df[depth_df["text_tgts_str"] == label]

        tp_dict = {}
        fp_dict = {}
        fn_dict = {}
        count_dict = {}

        # Explicit loop instead of DataFrame.apply: the scorer is called only
        # for its side effect of updating the dictionaries.
        for row in label_df.to_dict("records"):
            decompose_and_score_labels(
                row, tp_dict, fp_dict, fn_dict, count_dict, depth, verbose=0
            )

        for tgt in tp_dict:
            depth_tp_dict[tgt] = depth_tp_dict.get(tgt, 0) + tp_dict[tgt]
            depth_fp_dict[tgt] = depth_fp_dict.get(tgt, 0) + fp_dict[tgt]
            depth_fn_dict[tgt] = depth_fn_dict.get(tgt, 0) + fn_dict[tgt]
            depth_count_dict[tgt] = depth_count_dict.get(tgt, 0) + count_dict[tgt]

    depth_precisions = []
    depth_recalls = []
    depth_f1s = []
    depth_counts = []

    for tgt in depth_tp_dict:
        tgt_tp = depth_tp_dict[tgt]
        tgt_fp = depth_fp_dict[tgt]
        tgt_fn = depth_fn_dict[tgt]

        # Zero denominators give a score of 0 rather than an exception/NaN
        tgt_precision = tgt_tp / (tgt_tp + tgt_fp) if (tgt_tp + tgt_fp) else 0
        tgt_recall = tgt_tp / (tgt_tp + tgt_fn) if (tgt_tp + tgt_fn) else 0
        if tgt_precision + tgt_recall == 0:
            tgt_f1 = 0
        else:
            tgt_f1 = 2 * (tgt_precision * tgt_recall) / (tgt_precision + tgt_recall)

        print("\n\t" + "-" * 40)
        print(f"\tTarget: {tgt}")
        print(f"\tTP: {tgt_tp}")
        print(f"\tFP: {tgt_fp}")
        print(f"\tFN: {tgt_fn}")
        print(f"\tPrecision: {tgt_precision:.2f}")
        print(f"\tRecall: {tgt_recall:.2f}")
        print(f"\tF1: {tgt_f1:.2f}")
        print(f"\tCount: {depth_count_dict[tgt]}")

        depth_precisions.append(tgt_precision)
        depth_recalls.append(tgt_recall)
        depth_f1s.append(tgt_f1)
        depth_counts.append(depth_count_dict[tgt])

    # np.mean([]) warns and returns NaN; make the "nothing scored" case explicit
    depth_precision = np.mean(depth_precisions) if depth_precisions else np.nan
    depth_recall = np.mean(depth_recalls) if depth_recalls else np.nan
    depth_f1 = np.mean(depth_f1s) if depth_f1s else np.nan
    depth_count = np.sum(depth_counts)

    print("\n\t" + "-" * 60)
    print(f"\tDepth {depth} precision: {depth_precision:.2f}")
    print(f"\tDepth {depth} recall: {depth_recall:.2f}")
    print(f"\tDepth {depth} F1: {depth_f1:.2f}")
    print(f"\tDepth {depth} count: {depth_count}")

    precision[depth] = depth_precision
    recall[depth] = depth_recall
    f1[depth] = depth_f1
    sample_sizes[depth] = depth_count

# Depths without scores hold NaN with zero weight. nansum/nanmean stop them
# from turning the head-level results into NaN (NaN * 0 is still NaN).
head_precision = np.nansum(precision * sample_sizes) / np.sum(sample_sizes)
head_recall = np.nansum(recall * sample_sizes) / np.sum(sample_sizes)
head_f1 = np.nansum(f1 * sample_sizes) / np.sum(sample_sizes)
macro_head_precision = np.nanmean(precision)
macro_head_recall = np.nanmean(recall)
macro_head_f1 = np.nanmean(f1)
print("\n" + "=" * 80)
print(
    f"Precision for {test_head}, sample-weighted: {head_precision:.2f}, macro: {macro_head_precision:.2f}"
)
print(
    f"Recall for {test_head}, sample-weighted: {head_recall:.2f}, macro: {macro_head_recall:.2f}"
)
print(f"F1 for {test_head}, sample-weighted: {head_f1:.2f}, macro: {macro_head_f1:.2f}")

In [None]:
# Draw the label hierarchy of the head being evaluated (test_head), not a
# hardcoded one, so the plot follows the configuration cell.
plot_graph(G=root_graphs[test_head], size=5, color="g")