In [1]:
import sys

# Make the parent directory (presumably the repository root) importable so the
# `llavart` imports in the next cell resolve; relative to the kernel's CWD.
sys.path.extend([".."])

In [2]:
import torch
import torch.nn as nn
from llavart.utils.dirutils import get_model_checkpoints_dir, get_data_dir
from llavart.data.graph.ops import get_nodes_connections, get_nodes_by_id
from llavart.models.graph_retriever.ops import load_model, entities_to_ids
from llavart.models.graph_retriever.modeling import entity_constrainer
from safetensors.torch import load_file, save_file
import lovely_tensors as lt
from langchain_community.graphs import Neo4jGraph
from pykeen.triples import TriplesFactory
from pykeen.predict import predict_target
from pykeen import models
# Patch tensor reprs so tensors display as compact summaries (shape, range,
# mean/std, grad, device) — the `tensor[...]` outputs throughout this notebook.
lt.monkey_patch()

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Restore the trained TransR model (architecture shown in the `model` cell).
transr_results_dir = get_model_checkpoints_dir() / "graph_retriever" / "results" / "TransR"
model = load_model(transr_results_dir)

In [4]:
# Image feature tensor stored under the "features" key of the safetensors file.
# NOTE(review): rows are presumably aligned with the artwork order used by
# `entities_to_ids` (its index dict is used to slice this tensor) — confirm.
features = load_file(get_data_dir() / "images-hd_features.safetensors")["features"]

In [5]:
# Fetch every Artwork node's internal Neo4j id and name, ordered by id.
graph = Neo4jGraph()
query = """
MATCH (a:Artwork)
RETURN ID(a) as id, a.name as name
ORDER BY id
"""
result = graph.query(query)
artworks_id2name = dict((row["id"], row["name"]) for row in result)

# PyKEEN entity labels are strings, so stringify the node ids (id order kept).
keys = [str(artwork_id) for artwork_id in artworks_id2name]

In [6]:
graph_tsv_path = get_data_dir() / "graph" / "graph.tsv"
triples_factory = TriplesFactory.from_path(graph_tsv_path)

# NOTE(review): ratios and seed presumably mirror the run that trained `model`,
# so `training` reproduces its training triples (what "missing" means below
# depends on this) — confirm against the training script.
training, testing, validation = triples_factory.split([0.8, 0.1, 0.1], random_state=42)

In [7]:
# Map the artwork labels in `keys` onto entity ids of the training split.
# NOTE(review): presumably `missing_entities` are artworks absent from
# `training`, and `entities2indexes` gives each artwork's position in `keys`
# (used below to slice `features`) — confirm in entities_to_ids.
entities2ids, missing_entities, entities2indexes = entities_to_ids(
    triples_factory=training,
    entities=keys,
    return_index_dict=True,
    return_missing_dict=True,
)

In [8]:
# Neo4j ids of the node(s) each artwork links to via `hasStyle`.
artworks2styles = get_nodes_connections(graph, list(map(int, keys)), "hasStyle")
artworks2styles

{8142: [46],
 8143: [46],
 8144: [46],
 8145: [46],
 8146: [46],
 8147: [46],
 8148: [47],
 8149: [47],
 8150: [47],
 8151: [47],
 8152: [47],
 8153: [47],
 8154: [47],
 8155: [47],
 8156: [47],
 8157: [47],
 8158: [47],
 8159: [47],
 8160: [47],
 8161: [47],
 8162: [47],
 8163: [47],
 8164: [47],
 8165: [47],
 8166: [47],
 8167: [47],
 8168: [47],
 8169: [47],
 8170: [47],
 8171: [47],
 8172: [47],
 8173: [47],
 8174: [47],
 8175: [47],
 8176: [47],
 8177: [47],
 8178: [47],
 8179: [47],
 8180: [47],
 8181: [47],
 8182: [47],
 8183: [47],
 8184: [47],
 8185: [47],
 8186: [47],
 8187: [47],
 8188: [47],
 8189: [47],
 8190: [47],
 8191: [47],
 8192: [47],
 8193: [47],
 8194: [47],
 8195: [47],
 8196: [47],
 8197: [47],
 8198: [47],
 8199: [47],
 8200: [47],
 8201: [47],
 8202: [47],
 8203: [47],
 8204: [47],
 8205: [47],
 8206: [47],
 8207: [47],
 8208: [47],
 8209: [47],
 8210: [47],
 8211: [47],
 8212: [47],
 8213: [47],
 8214: [47],
 8215: [47],
 8216: [47],
 8217: [47],
 8218: [47],

In [9]:
# Candidate styles: the distinct first style of every artwork. Sorted so the
# column order used downstream (styles2ids -> style_embs -> idx2ids) is
# explicit and reproducible rather than an artifact of set iteration order.
styles = sorted({linked_styles[0] for linked_styles in artworks2styles.values()})
styles

[18,
 19,
 20,
 21,
 22,
 23,
 24,
 25,
 26,
 27,
 28,
 29,
 30,
 31,
 32,
 33,
 34,
 35,
 36,
 37,
 38,
 39,
 40,
 41,
 42,
 43,
 44,
 45,
 46,
 47,
 48,
 49]

In [10]:
# Style node id -> PyKEEN entity id (entity labels are stringified node ids).
# NOTE(review): uses the unsplit factory's mapping while `entities_to_ids` used
# `training`; PyKEEN split factories are expected to share it — confirm.
styles2ids = dict(
    (style_node, triples_factory.entity_to_id[str(style_node)]) for style_node in styles
)
styles2ids

{18: 38950,
 19: 39902,
 20: 40878,
 21: 41853,
 22: 42818,
 23: 43784,
 24: 44747,
 25: 45718,
 26: 46681,
 27: 47638,
 28: 48570,
 29: 49534,
 30: 50506,
 31: 51445,
 32: 52422,
 33: 53385,
 34: 54355,
 35: 55321,
 36: 56285,
 37: 57250,
 38: 58210,
 39: 59160,
 40: 60127,
 41: 61096,
 42: 62038,
 43: 62996,
 44: 63947,
 45: 64905,
 46: 65844,
 47: 66787,
 48: 67725,
 49: 68688}

In [11]:
# Restrict the label -> index lookup to the missing artworks, in iteration order.
missing_entities2indexes = dict((entity, entities2indexes[entity]) for entity in missing_entities)

In [12]:
# Feature rows for the missing artworks, L2-normalised row-wise; row i matches
# the i-th entry of `missing_entities`.
artwork_embs = nn.functional.normalize(
    features[list(missing_entities2indexes.values())], p=2, dim=1
)
artwork_embs

tensor[17472, 768] n=13418496 (51Mb) x∈[-0.520, 0.561] μ=0.001 σ=0.036

In [13]:
# Entity embeddings of the candidate styles, one row per style in styles2ids order.
style_entity_ids = torch.tensor(list(styles2ids.values()))
style_embs = model.entity_representations[0](indices=style_entity_ids)
style_embs

tensor[32, 768] n=24576 (96Kb) x∈[-0.188, 0.212] μ=0.001 σ=0.036 grad ViewBackward0 cuda:0

In [14]:
# PyKEEN relation id of the `hasStyle` relation.
has_style = triples_factory.relation_to_id["hasStyle"]
has_style

8

In [15]:
# TransR relation representations for hasStyle: index 0 is the (1, 30) relation
# vector, index 1 the (1, 768, 30) projection matrix (see output below).
has_style_index = torch.tensor([has_style])
has_style_emb = [
    model.relation_representations[k](indices=has_style_index) for k in (0, 1)
]
has_style_emb_0, has_style_emb_1 = has_style_emb
has_style_emb

[tensor[1, 30] x∈[-0.265, 0.256] μ=-0.002 σ=0.185 grad ViewBackward0 cuda:0,
 tensor[1, 768, 30] n=23040 (90Kb) x∈[-0.843, 0.853] μ=-0.001 σ=0.208 grad ViewBackward0 cuda:0]

In [16]:
# Place the artwork embeddings on the same device as the style embeddings
# (rather than hard-coding `.cuda()`) so the interaction can combine them.
artwork_embs = artwork_embs.to(style_embs.device)
artwork_embs

tensor[17472, 768] n=13418496 (51Mb) x∈[-0.520, 0.561] μ=0.001 σ=0.036 cuda:0

In [17]:
# Inspect the model: entity embedding size and the two relation representations.
model

TransR(
  (loss): MarginRankingLoss(
    (margin_activation): ReLU()
  )
  (interaction): TransRInteraction()
  (entity_representations): ModuleList(
    (0): Embedding(
      (_embeddings): Embedding(117147, 768)
    )
  )
  (relation_representations): ModuleList(
    (0): Embedding(
      (_embeddings): Embedding(17, 30)
    )
    (1): Embedding(
      (_embeddings): Embedding(17, 23040)
    )
  )
  (weight_regularizers): ModuleList()
)

In [18]:
# Score every (artwork, hasStyle, style) triple with the TransR interaction,
# giving one score vector over all artworks per candidate style.
# Inference only: without no_grad, each iteration's result would keep an
# autograd graph over the full (n_artworks, 768) batch alive, because the
# embeddings used here require grad.
scores = []
with torch.no_grad():
    for style_emb in style_embs:
        # Tile this style's embedding so there is one tail per artwork row.
        tails = style_emb.unsqueeze(0).repeat(artwork_embs.size(0), 1)
        scores.append(model.interaction(artwork_embs, has_style_emb, tails))

In [19]:
# Stack per-style scores into (n_artworks, n_styles) and pick the highest-scoring
# style column per artwork. A new name keeps the cell re-runnable: rebinding
# `scores` to the stacked tensor would break a second run of torch.stack.
score_matrix = torch.stack(scores, dim=1)
best = torch.argmax(score_matrix, dim=1)
best

tensor[17472] i64 0.1Mb x∈[0, 31] μ=18.025 σ=8.961 cuda:0

In [20]:
# lovely_tensors verbose view: summary line followed by the raw values.
best.v

tensor[17472] i64 0.1Mb x∈[0, 31] μ=18.025 σ=8.961 cuda:0
tensor([29, 29, 29,  ...,  3,  3,  3], device='cuda:0')

In [21]:
# Predicted style column (position in styles2ids order) for each missing artwork.
classes = best

In [22]:
# Inverse lookup: PyKEEN entity id -> style node id. (The previous leading
# `styles2ids.values()` was a no-op expression and has been dropped.)
ids2styles = {entity_id: style_node for style_node, entity_id in styles2ids.items()}

In [23]:
# Inspect the style lists/mappings together before scoring accuracy.
styles, styles2ids, artworks2styles, ids2styles

([18,
  19,
  20,
  21,
  22,
  23,
  24,
  25,
  26,
  27,
  28,
  29,
  30,
  31,
  32,
  33,
  34,
  35,
  36,
  37,
  38,
  39,
  40,
  41,
  42,
  43,
  44,
  45,
  46,
  47,
  48,
  49],
 {18: 38950,
  19: 39902,
  20: 40878,
  21: 41853,
  22: 42818,
  23: 43784,
  24: 44747,
  25: 45718,
  26: 46681,
  27: 47638,
  28: 48570,
  29: 49534,
  30: 50506,
  31: 51445,
  32: 52422,
  33: 53385,
  34: 54355,
  35: 55321,
  36: 56285,
  37: 57250,
  38: 58210,
  39: 59160,
  40: 60127,
  41: 61096,
  42: 62038,
  43: 62996,
  44: 63947,
  45: 64905,
  46: 65844,
  47: 66787,
  48: 67725,
  49: 68688},
 {8142: [46],
  8143: [46],
  8144: [46],
  8145: [46],
  8146: [46],
  8147: [46],
  8148: [47],
  8149: [47],
  8150: [47],
  8151: [47],
  8152: [47],
  8153: [47],
  8154: [47],
  8155: [47],
  8156: [47],
  8157: [47],
  8158: [47],
  8159: [47],
  8160: [47],
  8161: [47],
  8162: [47],
  8163: [47],
  8164: [47],
  8165: [47],
  8166: [47],
  8167: [47],
  8168: [47],
  8169: [47]

In [24]:
# Inspect the entity id -> style node id inverse mapping.
ids2styles

{38950: 18,
 39902: 19,
 40878: 20,
 41853: 21,
 42818: 22,
 43784: 23,
 44747: 24,
 45718: 25,
 46681: 26,
 47638: 27,
 48570: 28,
 49534: 29,
 50506: 30,
 51445: 31,
 52422: 32,
 53385: 33,
 54355: 34,
 55321: 35,
 56285: 36,
 57250: 37,
 58210: 38,
 59160: 39,
 60127: 40,
 61096: 41,
 62038: 42,
 62996: 43,
 63947: 44,
 64905: 45,
 65844: 46,
 66787: 47,
 67725: 48,
 68688: 49}

In [25]:
# Score column index -> style entity id; columns follow style_embs, which was
# built from the same styles2ids.values() order.
idx2ids = dict(enumerate(styles2ids.values()))

In [27]:
# Count artworks whose predicted style equals their first `hasStyle` target.
# Predictions are copied to the host once; calling `.item()` per row on a CUDA
# tensor forces one device sync per artwork.
predicted_columns = classes.tolist()
assert len(predicted_columns) == len(missing_entities), "predictions/artworks misaligned"

correct = 0
for entity, column in zip(missing_entities, predicted_columns):
    predicted_style = ids2styles[idx2ids[column]]
    true_style = artworks2styles[int(entity)][0]
    if predicted_style == true_style:
        correct += 1
correct

9645

In [28]:
# Style-prediction accuracy over the missing artworks (derived, not hard-coded,
# so it stays correct when the data or model changes).
correct / len(missing_entities)

0.5520260989010989