In [None]:
import os

# Move to the project root so that the `src` package is importable and the relative
# `./configs` and `./data` paths used further down resolve.
# The guard keeps this cell idempotent: re-running it no longer climbs one directory
# further up every time.
if not os.path.isdir("src"):
    os.chdir("..")

In [None]:
from src.experiment import CLIPGraphRun
from src.models.explainer import CLIPGraphExplainer, ArtGraphCaptumExplainer, GraphIntegratedGradients
from src.data import DataDict
from src.utils import load_ruamel
from src.visualization.clip_vis import get_img_path, get_style_by_case_id
from src.visualization.utils import SessionStateKey
from torch_geometric.explain import ThresholdConfig
import torch
import pandas as pd
import json
from neo4j import GraphDatabase
from torch_geometric import seed_everything
# Fix the random seed up front (helper imported from torch_geometric) so repeated runs are comparable.
seed_everything(42)

In [None]:
def get_prediction(run: CLIPGraphRun, case_id: int) -> dict:
    """Score one test image against the graph-encoded features.

    Args:
        run: Experiment run providing ``model``, ``graph``, ``test_loader`` and ``device``.
        case_id: Index of the item to fetch from ``run.test_loader.dataset``.

    Returns:
        A dictionary keyed by ``SessionStateKey`` members holding the logits
        (``img_feats @ class_feats.T``), the image features and the graph features.
    """
    # Encode the graph with normalised outputs.
    # NOTE(review): presumably one feature row per class — confirm against the model.
    class_feats = run.model.encode_graph(
        run.graph.x_dict, run.graph.edge_index_dict, normalize=True
    )
    # Fetch the image of the requested case, add a leading dimension of size 1
    # and move it to the device used by the run.
    img_tensor = (
        run.test_loader.dataset[case_id].get(DataDict.IMAGE).unsqueeze(0).to(run.device)
    )
    img_feats = run.model.encode_image(img_tensor, normalize=True)
    # Dot product between the image features and every row of the graph features.
    prediction = img_feats @ class_feats.T
    return  {
        SessionStateKey.LOGITS: prediction,
        SessionStateKey.IMG_FEATS: img_feats,
        SessionStateKey.GRAPH_FEATS: class_feats,
    }

In [None]:
def get_mappings(root: str, mapping_file: str) -> dict[str, dict[int, str]]:
    """Load the per-entity index-to-name mappings stored as CSV files under ``root``.

    Every regular, non-hidden file in ``root`` is read as a header-less CSV with the
    columns ``idx`` and ``name`` and stored under the part of its file name that
    precedes the first underscore. The ``artwork`` mapping is then re-keyed through
    the JSON dictionary in ``mapping_file``; artworks without an entry there are dropped.

    Args:
        root: Directory holding the mapping CSV files.
        mapping_file: JSON file whose keys (integers serialised as strings) are matched
            against the artwork keys and whose values become the new artwork keys.

    Returns:
        A dictionary of the form ``{file-name prefix: {index: name}}``.

    Raises:
        KeyError: If ``root`` contains no file whose prefix is ``artwork``.
    """
    files = {}
    # Sorted so that, should two files share a prefix, the winner is deterministic.
    for file_name in sorted(os.listdir(root)):
        file_path = os.path.join(root, file_name)
        # Skip sub-directories and hidden entries (e.g. `.ipynb_checkpoints`,
        # `.DS_Store`), which would otherwise make `read_csv` fail.
        if file_name.startswith(".") or not os.path.isfile(file_path):
            continue
        # NOTE(review): for two-column files pandas keeps a default positional index,
        # so the keys below are row numbers rather than the values of the `idx`
        # column — confirm that the two coincide in the mapping files.
        files[file_name.split("_")[0]] = pd.read_csv(
            file_path, names=["idx", "name"]
        ).to_dict()["name"]

    with open(mapping_file, encoding="utf-8") as f:
        # JSON object keys are always strings; convert them back to integers.
        artwork_mapping = {int(k): v for k, v in json.load(f).items()}

    # Re-key the artwork names and drop the artworks missing from the mapping file.
    files["artwork"] = {
        artwork_mapping[k]: v
        for k, v in files["artwork"].items()
        if k in artwork_mapping
    }
    return files

In [None]:
# Load the experiment configuration from YAML and override two settings before building the run.
parameters = load_ruamel('./configs/proposed_model/0_1/sage/normal_clip_graph_sage2.yaml')
# NOTE(review): presumably prevents the run from wiping an existing output directory — confirm in CLIPGraphRun.
parameters["clean_out_dir"] = False
# Force the GNN dropout value to zero for this notebook.
parameters["model"]["gnn"]["params"]["dropout"] = 0.0
run = CLIPGraphRun(parameters=parameters)

In [None]:
# Lookup tables provided by the run; below they are used as class name -> index and index -> class name.
class2idx, idx2class = run.get_class_maps()

In [None]:
# Connection settings for the Neo4j instance used by the explainer. They can be overridden
# through environment variables so that real credentials never have to be written into the
# notebook; the fallbacks reproduce the values that were previously hard-coded here.
neo4j_uri = os.environ.get("NEO4J_URI", "bolt://localhost:7688")
neo4j_auth = (
    os.environ.get("NEO4J_USERNAME", "neo4j"),
    os.environ.get("NEO4J_PASSWORD", "neo4j"),
)
neo4j_database = os.environ.get("NEO4J_DATABASE", "neo4j")

# Explainer shared by the image and graph explanation cells below.
explainer = CLIPGraphExplainer(
    device=run.device,
    image_preprocess="clip",
    mappings=get_mappings(
        root="./data/external/artgraph2bestemotions/mapping/",
        mapping_file="./data/processed/graph/0_1/artwork_mapping.json",
    ),
    neo4j_driver=GraphDatabase.driver(uri=neo4j_uri, auth=neo4j_auth),
    neo4j_db=neo4j_database,
)

## CHOOSE THE TEST INSTANCE

In [None]:
from ipywidgets import IntSlider

# Slider covering every valid index of the test dataset (0 .. len - 1), starting at 0.
test_instance = IntSlider(
    value=0,
    min=0,
    max=len(run.test_loader.dataset)-1,
    step=1,
    description="Choose the test instance",
    orientation="horizontal",
)
# `display` is not imported in this notebook; presumably the IPython built-in is relied upon.
display(test_instance)

In [None]:
# Pin the case programmatically; this overrides whatever value the slider above was set to.
test_instance.value = 19

In [None]:
# Style looked up for the selected case; shown as the cell output.
# NOTE(review): presumably the ground-truth label, since the loop at the end of the notebook compares it with the predicted class — confirm.
image_style = get_style_by_case_id(run=run, case_id=test_instance.value)
image_style

In [None]:
# Path of the image file for the selected case; reused by the display and explanation cells below.
img_pth = get_img_path(run=run, case_id=test_instance.value)
img_pth

In [None]:
from PIL import Image
# Load the image as RGB and show a copy enlarged 3x per side as the cell output;
# `img` itself keeps its original size because `resize` returns a new image.
img = Image.open(img_pth).convert("RGB")
img.resize((img.width*3, img.height*3))

### EXPLAIN THE IMAGE

In [None]:
# Predict the class of the selected case, then ask the explainer for an image explanation of that class.
out = get_prediction(run=run, case_id=test_instance.value)
# Position of the highest logit, converted to a plain Python int.
pred_idx = out[SessionStateKey.LOGITS].argmax().cpu().item()
pred_class = idx2class[pred_idx]
# The graph-side features serve as the reference the image is explained against.
# NOTE(review): `overlayed=True` presumably returns the attribution drawn over the image — confirm in CLIPGraphExplainer.
overlayed_image = explainer.explain_image(
            img_path=img_pth,
            model=run.model,
            reference_feats=out[SessionStateKey.GRAPH_FEATS],
            target=pred_idx,
            overlayed=True,
        )

In [None]:
# Interpret the explainer output as an RGB array and display a copy enlarged 3x per side.
ov = Image.fromarray(overlayed_image, "RGB")
ov.resize((ov.width*3, ov.height*3))

In [None]:
# Caption for the overlay: names the predicted class that the image explanation targets.
f"Explanation of the image for class {pred_class}"

### EXPLAIN THE GRAPH

In [None]:
# Explain the graph side of the prediction for the node of the predicted class.
explanation = explainer.explain_graph(
    model=run.model,
    graph=run.graph,
    # Image features act as the reference here; detached so they carry no autograd history.
    reference_feats=out[SessionStateKey.IMG_FEATS].detach(),
    target_node=class2idx[pred_class],
    # Project wrapper around the project-provided GraphIntegratedGradients algorithm.
    algo=ArtGraphCaptumExplainer(GraphIntegratedGradients, graph=run.graph),
    algo_kwargs=dict(
        explanation_type="model",
        node_mask_type="attributes",
        edge_mask_type="object",
        model_config=dict(
            mode="multiclass_classification",
            task_level="node",
            return_type="raw",
        ),
        # NOTE(review): `topk` with value 100 presumably keeps only the 100 highest-scoring mask entries — confirm against the PyG docs.
        threshold_config=ThresholdConfig(threshold_type="topk", value=100),
    ),
)
explainer.plot_explanation(explanation, run.graph.metadata(), run.graph)


In [None]:
from tqdm import tqdm

# Collect the ids of the test cases whose predicted class equals the style returned by
# `get_style_by_case_id`.
# Loop-local names carry a `case_` prefix so that this cell does not overwrite the
# `image_style`, `out`, `pred_idx` and `pred_class` globals that the explanation cells
# above rely on (re-running those cells after this one would otherwise silently explain
# the last test case instead of the selected one).
correct_case_ids = set()
# Only the arg-max of the logits is needed here, so autograd bookkeeping is switched off
# to save time and memory over the whole test set.
with torch.no_grad():
    for case_id in tqdm(range(len(run.test_loader.dataset))):
        case_style = get_style_by_case_id(run=run, case_id=case_id)
        case_out = get_prediction(run=run, case_id=case_id)
        case_pred_idx = case_out[SessionStateKey.LOGITS].argmax().cpu().item()
        if case_style == idx2class[case_pred_idx]:
            correct_case_ids.add(case_id)

    

In [None]:
# Ids of the test cases whose predicted class matched their style.
correct_case_ids

In [None]:
import random
# Pick one correctly classified case at random; raises IndexError if the set is empty.
random.choice(list(correct_case_ids))

In [None]:
# Number of correctly classified test cases.
len(correct_case_ids)