# Exploring Art across Culture and Medium with Fast, Conditional, k-Nearest Neighbors

<img src="https://mmlspark.blob.core.windows.net/graphics/art/cross_cultural_matches.jpg"  width="600"/>

This notebook serves as a guide to finding matches with k-nearest neighbors (kNN). The code below sets up queries over the cultures and mediums of artworks collected from the Metropolitan Museum of Art in NYC and the Rijksmuseum in Amsterdam.

In [18]:
```python
%%pyspark
import mmlspark
mmlspark.__spark_package_version__
```



'1.0.0-rc3-6-a862d6b1-SNAPSHOT'

### Overview of the BallTree
The data structure behind the kNN model is a BallTree: a recursive binary tree in which each node (or "ball") holds a subset of the data points, described by a center and a radius that encloses them. Building a BallTree repeatedly splits each ball's points between two child balls, assigning every point to the child whose center it is closest to in feature space. Because each ball bounds all of its points, a query can skip any ball whose nearest possible point is farther away than the k-th best match found so far, so the k nearest neighbors are often found by examining only a few leaves rather than every point.
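To make the construction and query steps concrete, here is a minimal pure-Python sketch of a ball tree. It is not MMLSpark's implementation: the Euclidean metric, the far-apart-pivot splitting rule, and the names `Ball`, `build`, and `knn` are all assumptions chosen for illustration.

```python
import heapq
import math
from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple

Point = Tuple[float, ...]


def dist(a: Point, b: Point) -> float:
    """Euclidean distance (an assumption; the library's metric may differ)."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


@dataclass
class Ball:
    center: Point
    radius: float                 # every point in this subtree lies within radius of center
    points: List[int]             # point indices; only kept at leaves
    left: Optional["Ball"] = None
    right: Optional["Ball"] = None


def build(data: Sequence[Point], idx: List[int], leaf_size: int = 8) -> Ball:
    """Recursively split the points in idx into two child balls."""
    pts = [data[i] for i in idx]
    center = tuple(sum(c) / len(pts) for c in zip(*pts))
    node = Ball(center, max(dist(center, p) for p in pts), idx)
    if len(idx) <= leaf_size:
        return node
    # Pick two far-apart pivots; each point joins the pivot it is closest to.
    a = max(idx, key=lambda i: dist(data[i], center))
    b = max(idx, key=lambda i: dist(data[i], data[a]))
    left = [i for i in idx if dist(data[i], data[a]) <= dist(data[i], data[b])]
    left_set = set(left)
    right = [i for i in idx if i not in left_set]
    if not right:  # all points coincide; keep this node as a leaf
        return node
    node.left = build(data, left, leaf_size)
    node.right = build(data, right, leaf_size)
    node.points = []
    return node


def knn(root: Ball, data: Sequence[Point], q: Point, k: int) -> List[int]:
    """Return indices of the k (>= 1) points nearest to q, closest first."""
    heap: List[Tuple[float, int]] = []  # max-heap of (-distance, index)

    def visit(node: Ball) -> None:
        # dist(q, center) - radius is a lower bound on q's distance to any point in the ball.
        if len(heap) == k and dist(q, node.center) - node.radius >= -heap[0][0]:
            return  # prune: nothing in this ball can beat the current k-th best
        if node.left is None:  # leaf: compare against its points directly
            for i in node.points:
                d = dist(q, data[i])
                if len(heap) < k:
                    heapq.heappush(heap, (-d, i))
                elif d < -heap[0][0]:
                    heapq.heapreplace(heap, (-d, i))
            return
        # Visit the nearer child first so the pruning bound tightens early.
        for child in sorted((node.left, node.right), key=lambda ch: dist(q, ch.center)):
            visit(child)

    visit(root)
    return [i for _, i in sorted((-neg_d, i) for neg_d, i in heap)]


if __name__ == "__main__":
    pts = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0), (5.0, 5.0), (6.0, 5.0)]
    tree = build(pts, list(range(len(pts))), leaf_size=2)
    print(knn(tree, pts, (0.2, 0.1), k=2))  # → [0, 1]
```

Comparing the result against a brute-force sort over all points is a quick way to confirm that the pruning step never discards a true neighbor.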

#### Setup
Import the necessary Python libraries and prepare the dataset.

In [19]:
```python
from pyspark.sql.types import *
from pyspark.ml.feature import Normalizer
from pyspark.sql.functions import lit, array, array_contains, udf, col, struct
from mmlspark.nn import ConditionalKNN, ConditionalKNNModel
from PIL import Image
from io import BytesIO

import requests
import numpy as np
import matplotlib.pyplot as plt
```




Our dataset comes from a table containing artwork information from both the Met and the Rijksmuseum. The schema is as follows:

- **id**: A unique identifier for a piece of art
  - Sample Met id: *388395* 
  - Sample Rijks id: *SK-A-2344* 
- **Title**: Art piece title, as written in the museum's database
- **Artist**: Art piece artist, as written in the museum's database
- **Thumbnail_Url**: Location of a JPEG thumbnail of the art piece
- **Image_Url**: Location of an image of the art piece hosted on the Met/Rijks website
- **Culture**: Category of culture that the art piece falls under
  - Sample culture categories: *latin american*, *egyptian*, etc.
- **Classification**: Category of medium that the art piece falls under
  - Sample medium categories: *woodwork*, *paintings*, etc.
- **Museum_Page**: Link to the work of art on the Met/Rijks website
- **Norm_Features**: Embedding of the art piece image
- **Museum**: Specifies which museum the piece originated from
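The column name `Norm_Features` suggests the embeddings are L2-normalized; that is an assumption, since the schema above only calls them embeddings. If so, one useful consequence is that for unit vectors ‖a - b‖² = 2 - 2(a · b), so ranking neighbors by smallest Euclidean distance gives the same order as ranking by largest inner product (cosine similarity). The sketch below checks this on random toy vectors; the helper names `l2_normalize`, `dot`, and `euclidean` are hypothetical.

```python
import math
import random


def l2_normalize(v):
    """Scale v to unit Euclidean length."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def euclidean(a, b):
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


# Toy stand-ins for image embeddings (random data, not the real Norm_Features).
rng = random.Random(42)
vecs = [l2_normalize([rng.gauss(0, 1) for _ in range(16)]) for _ in range(100)]
query = l2_normalize([rng.gauss(0, 1) for _ in range(16)])

# For unit vectors, ||a - b||^2 = 2 - 2 * (a . b), so both orderings agree.
by_distance = sorted(range(len(vecs)), key=lambda i: euclidean(query, vecs[i]))
by_inner_product = sorted(range(len(vecs)), key=lambda i: -dot(query, vecs[i]))
print(by_distance == by_inner_product)  # → True
```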

In [20]:
```python
# Load the combined Met/Rijks dataset; drop the bulky embedding column for display
df = spark.read.parquet("wasbs://publicwasb@mmlspark.blob.core.windows.net/met_and_rijks.parquet")
display(df.drop("Norm_Features"))
```


#### Define categories to be queried on
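We will use two kNN models: one conditioned on culture and one on medium. The categories for each grouping are defined below. To keep the demo fast, the data is restricted to a few categories of each kind, plus three specific artworks (selected by id) that serve as queries.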

In [21]:
```python
from pyspark.sql.functions import col

#mediums = ['prints', 'drawings', 'ceramics', 'textiles', 'paintings', "musical instruments","glass", 'accessories', 'photographs',  "metalwork",
#           "sculptures", "weapons", "stone", "precious", "paper", "woodwork", "leatherwork", "uncategorized"]

mediums = ['paintings', 'glass', 'ceramics']

#cultures = ['african (general)', 'american', 'ancient american', 'ancient asian', 'ancient european', 'ancient middle-eastern', 'asian (general)',
#            'austrian', 'belgian', 'british', 'chinese', 'czech', 'dutch', 'egyptian', 'european (general)', 'french', 'german', 'greek',
#            'iranian', 'italian', 'japanese', 'latin american', 'middle eastern', 'roman', 'russian', 'south asian', 'southeast asian',
#            'spanish', 'swiss', 'various']

cultures = ['japanese', 'american', 'african (general)']

# Uncomment the full lists above for larger-scale searches.

classes = cultures + mediums

# IDs of the artworks used as queries later on
selected_ids = {"AK-RBK-17525-2", "AK-MAK-1204", "AK-RAK-2015-2-9"}

# Keep rows in a selected medium or culture, plus the query artworks themselves.
# Built-in isin() filters avoid the serialization overhead of a Python UDF.
small_df = df.where(
    col("Classification").isin(mediums)
    | col("Culture").isin(cultures)
    | col("id").isin(list(selected_ids))
)

small_df.count()
```


85709

### Define and fit ConditionalKNN models
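Below, we create a ConditionalKNN model for each of the medium and culture columns. Each model takes an output column (where matches are written), a features column (the feature vector to compare), a values column (the value returned for each match, here the thumbnail URL), and a label column (the category the kNN search is conditioned on). At query time, matches are restricted to the labels listed in a `conditioner` column, as `add_matches` does further down.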
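Conceptually, conditioning restricts the neighbor search to rows whose label appears in the query's `conditioner` list. The brute-force sketch below shows only those semantics; ConditionalKNN itself uses the BallTree described above, and the function name `conditional_knn`, its toy data, and the `distance`/`label` keys in its output are hypothetical.

```python
import heapq
import math


def conditional_knn(features, values, labels, query, conditioner, k=5):
    """Brute-force conditional kNN: the k rows nearest to `query` whose label is in `conditioner`.

    Returns dicts ordered nearest first. The "value" key mirrors how the notebook
    reads matches (row[...][0]["value"]); the other keys are illustrative.
    """
    allowed = set(conditioner)
    candidates = (
        (math.dist(query, f), v, lab)
        for f, v, lab in zip(features, values, labels)
        if lab in allowed  # the conditioning step: skip rows with other labels
    )
    best = heapq.nsmallest(k, candidates, key=lambda t: t[0])
    return [{"value": v, "distance": d, "label": lab} for d, v, lab in best]


# Hypothetical toy rows standing in for (Norm_Features, Thumbnail_Url, Classification).
features = [(0.0, 0.0), (0.1, 0.0), (1.0, 1.0), (0.05, 0.0)]
values = ["p1.jpg", "p2.jpg", "g1.jpg", "g2.jpg"]
labels = ["paintings", "paintings", "glass", "glass"]

matches = conditional_knn(features, values, labels, (0.0, 0.0), ["glass"], k=1)
print(matches[0]["value"])  # → g2.jpg
```

Even though `p1.jpg` sits exactly at the query point, it is never returned when the conditioner only allows `glass`; this is what lets one trained model answer "nearest painting", "nearest glass piece", and so on.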

In [22]:
```python
medium_cknn = (ConditionalKNN()
  .setOutputCol("Matches")
  .setFeaturesCol("Norm_Features")
  .setValuesCol("Thumbnail_Url")
  .setLabelCol("Classification")
  .fit(small_df))
```




In [23]:
```python
culture_cknn = (ConditionalKNN()
  .setOutputCol("Matches")
  .setFeaturesCol("Norm_Features")
  .setValuesCol("Thumbnail_Url")
  .setLabelCol("Culture")
  .fit(small_df))
```




#### Define matching and visualizing methods

After the initial dataset and category setup, we prepare methods that query and visualize the conditional kNN's results.

`add_matches()` returns a DataFrame with one `Matches_<label>` column per category, each holding the nearest matches restricted to that category.

In [24]:
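```python
def add_matches(classes, cknn, df):
  """For each label, run the conditional kNN restricted to that label
  and store the results in a new column named Matches_<label>."""
  results = df
  for label in classes:
    # The conditioner column lists the labels a match is allowed to have
    results = (cknn.transform(results.withColumn("conditioner", array(lit(label))))
                 .withColumnRenamed("Matches", "Matches_{}".format(label)))
  return results
```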




### Putting it all together
Below, we take the data, the two CKNN models, and the art IDs to query on, and collect the first returned match for each medium and culture. Both models were trained in the cells above.

In [25]:
```python
# Query the selected artworks against both CKNN models and collect the first match per category

test_df = small_df.where(col("id").isin(list(selected_ids)))

results_df_medium = add_matches(mediums, medium_cknn, test_df)
results_df_culture = add_matches(cultures, culture_cknn, results_df_medium)

results = results_df_culture.collect()

original_urls = [row["Thumbnail_Url"] for row in results]

# Each grid: the first row holds the query thumbnails, then one row of matches per category
culture_urls = [[row["Matches_{}".format(label)][0]["value"] for row in results] for label in cultures]
culture_url_arr = np.array([original_urls] + culture_urls)

medium_urls = [[row["Matches_{}".format(label)][0]["value"] for row in results] for label in mediums]
medium_url_arr = np.array([original_urls] + medium_urls)
```





### Demo
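Batched queries for the selected image IDs produce grids of matches like the one visualized below, pairing each query artwork with its closest matches in other cultures and mediums. The final cell stops the Spark session.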


<img src="https://mmlspark.blob.core.windows.net/graphics/art/cross_cultural_matches.jpg"  width="600"/>

In [26]:
```python
spark.stop()
```


