In [1]:
import sys
sys.path.append("..")

In [2]:
import torch
import torch.nn as nn
from llavart.utils.dirutils import get_model_checkpoints_dir, get_data_dir
from llavart.data.graph.ops import get_nodes_connections, get_nodes_by_id
from llavart.models.graph_retriever.ops import load_model, entities_to_ids
from llavart.models.graph_retriever.modeling import entity_constrainer
from safetensors.torch import load_file, save_file
import lovely_tensors as lt
from langchain_community.graphs import Neo4jGraph
from tqdm import tqdm
from pykeen.triples import TriplesFactory
from pykeen.predict import predict_target
from pykeen import models
lt.monkey_patch()

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
model = load_model(get_model_checkpoints_dir() / "graph_retriever" / "results" / "TransR")

In [4]:
features = load_file(get_data_dir() / "images-hd_features.safetensors")
features = features["features"]

In [5]:
graph = Neo4jGraph()
query = """
MATCH (a:Artwork)
RETURN ID(a) as id, a.name as name
ORDER BY id
"""
result = graph.query(query)
artworks_id2name = {row["id"]: row["name"] for row in result}

keys = list(artworks_id2name.keys())
keys = [str(key) for key in keys]

In [6]:
triples_factory = TriplesFactory.from_path(get_data_dir() / "graph" / "graph.tsv")
training, testing, validation = triples_factory.split(
    [0.8, 0.1, 0.1], random_state=42
)

In [7]:
entities2ids, missing_entities, entities2indexes = entities_to_ids(
    entities=keys,
    triples_factory=training,
    return_missing_dict=True,
    return_index_dict=True,
)

In [8]:
artworks2artists = get_nodes_connections(graph, [int(key) for key in keys], "createdBy")
artworks2artists

{8142: [850],
 8143: [850],
 8144: [850],
 8145: [850],
 8146: [850],
 8147: [850],
 8148: [1853],
 8149: [1853],
 8150: [1853],
 8151: [1853],
 8152: [1853],
 8153: [1853],
 8154: [1853],
 8155: [1853],
 8156: [1853],
 8157: [1853],
 8158: [1853],
 8159: [1853],
 8160: [1853],
 8161: [1853],
 8162: [1853],
 8163: [1853],
 8164: [1853],
 8165: [1853],
 8166: [1853],
 8167: [1853],
 8168: [1853],
 8169: [1853],
 8170: [1853],
 8171: [1853],
 8172: [1853],
 8173: [1853],
 8174: [1853],
 8175: [1853],
 8176: [1853],
 8177: [1853],
 8178: [1853],
 8179: [1853],
 8180: [1853],
 8181: [1853],
 8182: [1853],
 8183: [1853],
 8184: [1853],
 8185: [1853],
 8186: [1853],
 8187: [1853],
 8188: [1853],
 8189: [1853],
 8190: [1853],
 8191: [1853],
 8192: [1853],
 8193: [1853],
 8194: [1853],
 8195: [1853],
 8196: [1853],
 8197: [1853],
 8198: [1853],
 8199: [1853],
 8200: [1853],
 8201: [1853],
 8202: [1853],
 8203: [1853],
 8204: [1853],
 8205: [1853],
 8206: [1853],
 8207: [1853],
 8208: [1853],
 

In [9]:
artists = list(set(value[0] for value in artworks2artists.values()))
artists

[50,
 51,
 52,
 53,
 54,
 55,
 56,
 57,
 58,
 59,
 60,
 61,
 62,
 63,
 64,
 65,
 66,
 67,
 68,
 69,
 70,
 71,
 72,
 73,
 74,
 75,
 76,
 77,
 78,
 79,
 80,
 81,
 82,
 83,
 84,
 85,
 86,
 87,
 88,
 89,
 90,
 91,
 92,
 93,
 94,
 95,
 96,
 97,
 98,
 99,
 100,
 101,
 102,
 103,
 104,
 105,
 106,
 107,
 108,
 109,
 110,
 111,
 112,
 113,
 114,
 115,
 116,
 117,
 118,
 119,
 120,
 121,
 122,
 123,
 124,
 125,
 126,
 127,
 128,
 129,
 130,
 131,
 132,
 133,
 134,
 135,
 136,
 137,
 138,
 139,
 140,
 141,
 142,
 143,
 144,
 145,
 146,
 147,
 148,
 149,
 150,
 151,
 152,
 153,
 154,
 155,
 156,
 157,
 158,
 159,
 160,
 161,
 162,
 163,
 164,
 165,
 166,
 167,
 168,
 169,
 170,
 171,
 172,
 173,
 174,
 175,
 176,
 177,
 178,
 179,
 180,
 181,
 182,
 183,
 184,
 185,
 186,
 187,
 188,
 189,
 190,
 191,
 192,
 193,
 194,
 195,
 196,
 197,
 198,
 199,
 200,
 201,
 202,
 203,
 204,
 205,
 206,
 207,
 208,
 209,
 210,
 211,
 212,
 213,
 214,
 215,
 216,
 217,
 218,
 219,
 220,
 221,
 222,
 223,
 224,


In [30]:
# remove artists that have less than a 100 associated artworks
query = f"""
MATCH (a:Artwork)-[:createdBy]->(b:Artist)
WHERE ID(b) IN {artists}
WITH b, count(a) as count
WHERE count >= 100
RETURN ID(b) as id
"""
result = graph.query(query)
result

[{'id': 59},
 {'id': 60},
 {'id': 64},
 {'id': 73},
 {'id': 85},
 {'id': 90},
 {'id': 91},
 {'id': 99},
 {'id': 111},
 {'id': 123},
 {'id': 135},
 {'id': 138},
 {'id': 163},
 {'id': 176},
 {'id': 195},
 {'id': 199},
 {'id': 211},
 {'id': 217},
 {'id': 222},
 {'id': 229},
 {'id': 231},
 {'id': 235},
 {'id': 236},
 {'id': 247},
 {'id': 258},
 {'id': 272},
 {'id': 278},
 {'id': 284},
 {'id': 296},
 {'id': 300},
 {'id': 311},
 {'id': 312},
 {'id': 313},
 {'id': 316},
 {'id': 331},
 {'id': 336},
 {'id': 353},
 {'id': 356},
 {'id': 358},
 {'id': 370},
 {'id': 389},
 {'id': 391},
 {'id': 455},
 {'id': 456},
 {'id': 457},
 {'id': 461},
 {'id': 462},
 {'id': 481},
 {'id': 516},
 {'id': 518},
 {'id': 527},
 {'id': 551},
 {'id': 553},
 {'id': 566},
 {'id': 568},
 {'id': 595},
 {'id': 597},
 {'id': 599},
 {'id': 613},
 {'id': 618},
 {'id': 622},
 {'id': 659},
 {'id': 661},
 {'id': 666},
 {'id': 670},
 {'id': 681},
 {'id': 683},
 {'id': 691},
 {'id': 693},
 {'id': 716},
 {'id': 725},
 {'id': 728},


In [36]:
az = set(row["id"] for row in result)
artists = [artist for artist in artists if artist in az]

In [37]:
artists2ids = {}
for artist in artists:
    try:
        artists2ids[artist] = triples_factory.entity_to_id[str(artist)]
    except KeyError:
        continue

In [38]:
missing_entities2indexes = {entity: entities2indexes[entity] for entity in missing_entities}

In [39]:
artwork_embs = features[list(missing_entities2indexes.values())]
artwork_embs = nn.functional.normalize(artwork_embs, p=2, dim=1)
artwork_embs

tensor[17472, 768] n=13418496 (51Mb) x∈[-0.520, 0.561] μ=0.001 σ=0.036

In [40]:
artist_embs = model.entity_representations[0](indices=torch.tensor(list(artists2ids.values())))
artist_embs

tensor[242, 768] n=185856 (0.7Mb) x∈[-0.150, 0.137] μ=-0.001 σ=0.036 grad ViewBackward0 cuda:0

In [41]:
has_artist = triples_factory.relation_to_id["createdBy"]
has_artist

3

In [42]:
has_artist_emb_0 = model.relation_representations[0](indices=torch.tensor([has_artist]))
has_artist_emb_1 = model.relation_representations[1](indices=torch.tensor([has_artist]))
has_artist_emb = [has_artist_emb_0, has_artist_emb_1]
has_artist_emb

[tensor[1, 30] x∈[-0.256, 0.268] μ=-0.021 σ=0.184 grad ViewBackward0 cuda:0,
 tensor[1, 768, 30] n=23040 (90Kb) x∈[-0.860, 0.898] μ=0.002 σ=0.208 grad ViewBackward0 cuda:0]

In [43]:
# sum has_artist_emb to all artwork_embs rows
artwork_embs = artwork_embs.cuda()
artwork_embs

tensor[17472, 768] n=13418496 (51Mb) x∈[-0.520, 0.561] μ=0.001 σ=0.036 cuda:0

In [44]:
model

TransR(
  (loss): MarginRankingLoss(
    (margin_activation): ReLU()
  )
  (interaction): TransRInteraction()
  (entity_representations): ModuleList(
    (0): Embedding(
      (_embeddings): Embedding(117147, 768)
    )
  )
  (relation_representations): ModuleList(
    (0): Embedding(
      (_embeddings): Embedding(17, 30)
    )
    (1): Embedding(
      (_embeddings): Embedding(17, 23040)
    )
  )
  (weight_regularizers): ModuleList()
)

In [45]:
scores = []
with torch.no_grad():
    for artist_emb in artist_embs:
        # ss = model.interaction(**model.interaction._prepare_hrt_for_functional(artwork_embs, hge, ge))
        ge = artist_emb.unsqueeze(0).repeat(artwork_embs.size(0), 1)
        ss = model.interaction(artwork_embs, has_artist_emb, ge)
        scores.append(ss)

In [46]:
scores = torch.stack(scores, dim=1)
best = torch.argmax(scores, dim=1)
best

tensor[17472] i64 0.1Mb x∈[0, 241] μ=120.110 σ=74.152 cuda:0

In [47]:
best.v

tensor[17472] i64 0.1Mb x∈[0, 241] μ=120.110 σ=74.152 cuda:0
tensor([240, 240,  76,  ..., 231,  58,  23], device='cuda:0')

In [48]:
classes = best

In [49]:
artists2ids.values()
ids2artists = {v: k for k, v in artists2ids.items()}

In [50]:
artists, artists2ids, artworks2artists, ids2artists

([59,
  60,
  64,
  73,
  85,
  90,
  91,
  99,
  111,
  123,
  135,
  138,
  163,
  176,
  195,
  199,
  211,
  217,
  222,
  229,
  231,
  235,
  236,
  247,
  258,
  272,
  278,
  284,
  296,
  300,
  311,
  312,
  313,
  316,
  331,
  336,
  353,
  356,
  358,
  370,
  389,
  391,
  455,
  456,
  457,
  461,
  462,
  481,
  516,
  518,
  527,
  551,
  553,
  566,
  568,
  595,
  597,
  599,
  613,
  618,
  622,
  659,
  661,
  666,
  670,
  681,
  683,
  691,
  693,
  716,
  725,
  728,
  729,
  744,
  757,
  759,
  777,
  798,
  804,
  836,
  853,
  859,
  869,
  871,
  874,
  902,
  932,
  937,
  946,
  959,
  979,
  986,
  993,
  1008,
  1010,
  1011,
  1032,
  1035,
  1037,
  1049,
  1050,
  1057,
  1065,
  1070,
  1074,
  1079,
  1080,
  1087,
  1090,
  1094,
  1099,
  1109,
  1120,
  1145,
  1173,
  1174,
  1175,
  1192,
  1193,
  1199,
  1210,
  1266,
  1281,
  1293,
  1296,
  1300,
  1301,
  1305,
  1311,
  1323,
  1333,
  1349,
  1354,
  1355,
  1356,
  1381,
  1383,
  139

In [51]:
ids2artists

{78235: 59,
 79211: 60,
 83060: 64,
 91591: 73,
 103004: 85,
 107691: 90,
 108626: 91,
 116215: 99,
 10392: 111,
 21743: 123,
 34587: 135,
 34920: 138,
 37321: 163,
 38573: 176,
 40387: 195,
 40783: 199,
 41947: 211,
 42525: 217,
 43012: 222,
 43688: 229,
 43884: 231,
 44268: 235,
 44365: 236,
 45428: 247,
 46486: 258,
 47828: 272,
 48381: 278,
 48957: 284,
 50113: 296,
 50507: 300,
 51547: 311,
 51638: 312,
 51733: 313,
 52034: 316,
 53482: 331,
 53962: 336,
 55616: 353,
 55908: 356,
 56098: 358,
 57251: 370,
 59065: 389,
 59257: 391,
 65391: 455,
 65483: 456,
 65576: 457,
 65947: 461,
 66039: 462,
 67824: 481,
 71182: 516,
 71361: 518,
 72213: 527,
 74502: 551,
 74685: 553,
 75949: 566,
 76141: 568,
 78729: 595,
 78923: 597,
 79114: 599,
 80472: 613,
 80958: 618,
 81346: 622,
 84881: 659,
 85069: 661,
 85530: 666,
 85920: 670,
 86969: 681,
 87151: 683,
 87900: 691,
 88095: 693,
 90247: 716,
 91116: 725,
 91399: 728,
 91497: 729,
 92911: 744,
 94129: 757,
 94320: 759,
 96047: 777,
 98

In [52]:
idx2ids = {i: v for i, v in enumerate(artists2ids.values())}

In [54]:
# compute accuracy using artwork_artist
correct = 0
total = 0
for i, entity in enumerate(missing_entities):
    cs = classes[i]
    cs = ids2artists[idx2ids[cs.item()]]
    c = artworks2artists[int(entity)][0]
    if c not in artists:
        continue
    total += 1
    if c == cs:
        correct += 1
correct

2748

In [56]:
total

9146