In [2]:
from transformers import CLIPProcessor, CLIPModel
from langchain_community.graphs import Neo4jGraph
import pandas as pd
from sklearn.model_selection import train_test_split
import torch
from safetensors.torch import save_file, load_file
from pathlib import Path
from PIL import Image
from tqdm import tqdm
import os
import lovely_tensors as lt
from torchmetrics.classification import Accuracy

lt.monkey_patch()

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
graph = Neo4jGraph()

In [4]:
query = "MATCH (a:Artist) -- (artwork:Artwork) WITH a, COUNT(artwork) AS num_artworks WHERE num_artworks > 50 RETURN a.name"
result = graph.query(query)
artists = [record['a.name'] for record in result]

In [5]:
print(graph.get_schema)

Node properties are the following:
Genre {name: STRING},Style {name: STRING, summary: STRING, wikipedia_url: STRING},Artist {dbpedia_url: STRING, printed_name: STRING, image_url: STRING, birth_date: STRING, wikipedia_url: STRING, name: STRING, biography: STRING, gender: STRING, death_date: STRING, death_place: STRING, birth_place: STRING},Media {name: STRING},Tag {name: STRING},Artwork {date: STRING, title: STRING, name: STRING, image_url: STRING, dimensions: STRING, wikidata_url: STRING, described_at_url: STRING, wikipedia_url: STRING},Movement {name: STRING},Training {name: STRING},Subject {name: STRING},Field {name: STRING},People {name: STRING},Serie {name: STRING},Period {name: STRING},Gallery {name: STRING},City {name: STRING},Country {name: STRING},Emotion {name: STRING}
Relationship properties are the following:
elicits {description: STRING, arousal: INTEGER}
The relationships are the following:
(:Artist)-[:belongsToMovement]->(:Movement),(:Artist)-[:hasSubject]->(:Subject),(:A

In [6]:
artworks = {}
for artist in artists:
    query = f"MATCH (a:Artist {{name: '{artist}'}}) -- (artwork:Artwork) RETURN artwork.name"
    result = graph.query(query)
    artworks[artist] = [record['artwork.name'] for record in result]

In [7]:
images = []
artists = []
for artist in artworks:
    for artwork in artworks[artist]:
        images.append(artwork)
        artists.append(artist)

df = pd.DataFrame({"image": images, "artist": artists})

In [8]:
train_df, test_df = train_test_split(df, test_size=0.2, random_state=42, stratify=df["artist"])

In [9]:
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
device = "cuda" if torch.cuda.is_available() else "cpu"

In [10]:
query = "MATCH (artwork:Artwork) RETURN artwork.name"
result = graph.query(query)
artwork_names = [record['artwork.name'] for record in result]

In [11]:
IMG_DIR = Path("data/images")

In [12]:
if not os.path.exists("clip_features.safetensors"):
    tensors = {}
    batch_size = 128
    with torch.no_grad():
        for i in tqdm(range(0, len(artwork_names), batch_size)):
            batch = artwork_names[i:i+batch_size]
            imgs = [Image.open(IMG_DIR / img) for img in batch]
            inputs = processor(images=imgs, return_tensors="pt")
            outputs = model.get_image_features(**inputs)
            for j, img in enumerate(batch):
                tensors[img] = outputs[j]
else:
    tensors = load_file("clip_features.safetensors")

In [13]:
artists = sorted(list(set(df["artist"])))
artist_to_idx = {artist: idx for idx, artist in enumerate(artists)}
accuracy = Accuracy(task="multiclass", num_classes=len(artists))

In [14]:
prompts = [f"An image of a painting by {artist.replace('-', ' ').title()}" for artist in artists]
inputs = processor(text=prompts, return_tensors="pt", padding=True)
with torch.no_grad():
    text_features = model.get_text_features(**inputs)

In [15]:
df[df["image"] == "alfred-wallis_ship-people-and-animals.jpg"]["artist"].values[0]

'alfred-wallis'

In [16]:
correct = 0

In [17]:
for row in tqdm(test_df.iterrows()):
    key = row[1]["image"]
    tensor = tensors[key]
    similarities = torch.nn.functional.cosine_similarity(tensor.unsqueeze(0), text_features, dim=1)
    pred = torch.argmax(similarities).item()
    artist = df[df["image"] == key]["artist"].values[0]
    target = artist_to_idx[artist]
    if pred == target:
        correct += 1

16236it [00:49, 329.10it/s]


In [18]:
correct / len(test_df)

0.31522542498152256

In [None]:
artist_centroids = {}

In [None]:
for artist in tqdm(artists):
    artist_df = train_df[train_df["artist"] == artist]
    images = artist_df["image"].values
    artist_tensors = torch.stack([tensors[img] for img in images])
    artist_tensor = artist_tensors.mean(dim=0)
    artist_centroids[artist] = artist_tensor

In [None]:
artist_centroids

In [None]:
correct = 0

In [None]:
for row in tqdm(test_df.iterrows()):
    key = row[1]["image"]
    tensor = tensors[key]
    similarities = torch.stack([torch.nn.functional.cosine_similarity(tensor.unsqueeze(0), centroid.unsqueeze(0), dim=1) for centroid in artist_centroids.values()])
    pred = torch.argmax(similarities).item()
    artist = df[df["image"] == key]["artist"].values[0]
    target = artist_to_idx[artist]
    if pred == target:
        correct += 1

In [None]:
correct / len(test_df)

In [19]:
correct = 0

In [22]:
train_tensors = torch.stack([tensors[img] for img in train_df["image"].values])

In [23]:
train_tensors

tensor[64944, 512] n=33251328 (0.1Gb) x∈[-10.952, 3.523] μ=0.004 σ=0.441

In [24]:
correct = 0

In [25]:
for row in tqdm(test_df.iterrows()):
    key = row[1]["image"]
    tensor = tensors[key]
    similarities = torch.nn.functional.cosine_similarity(tensor.unsqueeze(0), train_tensors, dim=1)
    pred = torch.argmax(similarities).item()
    artist = df[df["image"] == key]["artist"].values[0]
    target = artist_to_idx[artist]
    artist_pred = train_df.iloc[pred]["artist"]
    pred = artist_to_idx[artist_pred]
    if pred == target:
        correct += 1

16236it [19:02, 14.21it/s]


In [26]:
correct / len(test_df)

0.47610248829761026