# This notebook generates CLIP embeddings for News Nav photos on disk

To convert this notebook to a Python script, run:
`jupyter nbconvert --to script generate_CLIP_embeddings.ipynb`

In [None]:
### FOR CLIP EMBEDDINGS
from sentence_transformers import SentenceTransformer, util
from IPython.display import Image as IPImage
from IPython.display import display
import PIL.Image
import torch

import pandas as pd
import numpy as np
import glob
import math
import time
import csv
import sys
import os

In [None]:
def chunk(file_list, n_chunks):
    """Split ``file_list`` into ``n_chunks`` consecutive slices (e.g. one per worker process).

    Slices are taken in order with a stride of ``ceil(len(file_list) / n_chunks)``.
    The first ``n_chunks - 1`` slices hold at most that many items. The final slice
    receives everything left over, so trailing slices may be shorter or even empty.

    Args:
        file_list: Sequence to split (any sliceable sequence).
        n_chunks: Number of slices to produce.

    Returns:
        A list of ``n_chunks`` slices of ``file_list``.
    """
    stride = math.ceil(float(len(file_list)) / n_chunks)

    # All but the last slice have a fixed width of `stride`.
    pieces = [file_list[k * stride:(k + 1) * stride] for k in range(n_chunks - 1)]

    # The last slice is open-ended so no trailing items are dropped.
    pieces.append(file_list[(n_chunks - 1) * stride:])
    return pieces

In [None]:
# Load the CLIP ViT-B/32 checkpoint through sentence-transformers
# (https://huggingface.co/sentence-transformers/clip-ViT-B-32).
# `model` is a module-level global that generate_embeddings() uses below.
model = SentenceTransformer(model_name_or_path='clip-ViT-B-32')
print("Loaded model!")

In [None]:
def generate_embeddings(year, file_list, out_dir="embeddings", encoder=None):
    """Encode each image in ``file_list`` with CLIP and save one ``.npy`` file per image.

    Each embedding is written to ``<out_dir>/<year>_photos/<image stem>.npy``.
    The output directory is created if it does not already exist.

    Args:
        year: Year label used to build the output subdirectory name.
        file_list: Paths of the image files to encode.
        out_dir: Root directory for embedding files. The default ``"embeddings"``
            matches the previously hard-coded location.
        encoder: Object with an ``encode(image)`` method. Defaults to the
            module-level ``model`` loaded earlier in the notebook.

    Files that PIL cannot open or decode (e.g. stray non-image files picked up
    by a ``*`` glob) are skipped with a message instead of aborting the run.
    Progress is printed every 1000 files.
    """
    if encoder is None:
        encoder = model  # module-level SentenceTransformer from the model-loading cell

    year_dir = os.path.join(out_dir, str(year) + "_photos")
    # np.save does not create parent directories, so make sure the target exists.
    os.makedirs(year_dir, exist_ok=True)

    for i, local_fp in enumerate(file_list):
        # Swap the real extension (whatever its case) for .npy.
        # str.replace(".jpg", ...) missed .JPG/.jpeg/.png and could hit ".jpg" mid-name.
        stem = os.path.splitext(os.path.basename(local_fp))[0]
        npy_filepath = os.path.join(year_dir, stem + ".npy")

        try:
            # The context manager releases the file handle PIL keeps open.
            # convert() loads the pixel data before the handle closes and
            # normalises grayscale/palette/RGBA scans to 3-channel RGB.
            with PIL.Image.open(local_fp) as img:
                image = img.convert("RGB")
        except OSError as err:  # includes PIL.UnidentifiedImageError and missing files
            print("Skipping unreadable image " + local_fp + ": " + str(err))
            continue

        embedding = encoder.encode(image)
        np.save(npy_filepath, np.asarray(embedding))

        if i % 1000 == 0:
            print(i)

In [None]:
# Script entry point: for each year 1913-1922 (inclusive), embed every file in
# ../datasets/<year>_photos. No multiprocessing is used. The guard only keeps
# the loop from running on import after `jupyter nbconvert --to script`.
if __name__ == '__main__':

    for year in range(1913, 1923):
        print("PROCESSING YEAR: " + str(year))

        # glob's result order depends on the filesystem. Sorting makes the
        # processing order (and therefore the progress counter) reproducible.
        files = sorted(glob.glob(os.path.join('./../datasets', str(year) + "_photos", '*')))

        print("DONE GLOBBING")

        generate_embeddings(year, files)

        print("DONE EMBEDDINGS")
