In [1]:
import sys
sys.path.append('..')

import os
import pickle
import zipfile
from tqdm import tqdm
from concept import ConceptModel
from sentence_transformers import util
from PIL import Image
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer
import pandas as pd
import numpy as np

## Prepare Images

In [2]:
# Make sure the Unsplash photo set (about 25k images) is available locally.
# Nothing is done when the photo folder already exists and is non-empty.
img_folder = 'photos/'
photos_present = os.path.exists(img_folder) and len(os.listdir(img_folder)) > 0
if not photos_present:
    os.makedirs(img_folder, exist_ok=True)

    # Re-use a previously downloaded archive when one sits next to the notebook.
    photo_filename = 'unsplash-25k-photos.zip'
    archive_cached = os.path.exists(photo_filename)
    if not archive_cached:
        util.http_get('http://sbert.net/datasets/' + photo_filename, photo_filename)

    # Unpack member by member so tqdm can report extraction progress.
    with zipfile.ZipFile(photo_filename, 'r') as archive:
        for entry in tqdm(archive.infolist(), desc='Extracting'):
            archive.extract(entry, img_folder)

## Use pre-computed embeddings

In [3]:
# Fetch the pre-computed image embeddings once, then load them from disk.
emb_filename = 'unsplash-25k-photos-embeddings.pkl'
embeddings_cached = os.path.exists(emb_filename)
if not embeddings_cached:
    util.http_get('http://sbert.net/datasets/' + emb_filename, emb_filename)

# NOTE(security): unpickling can execute arbitrary code, and this file is
# downloaded over plain HTTP -- only run this cell against a source you trust.
# The pickle is unpacked into a pair: image file names and their embeddings.
with open(emb_filename, 'rb') as emb_file:
    img_names, img_embeddings = pickle.load(emb_file)

In [4]:
# Number of photos (and matching embeddings) to use in this run.
# A single constant keeps the three slices below the same length.
N_IMAGES = 5000

# Take the same leading slice of names and embeddings
# (presumably the two sequences are index-aligned -- verify against the pickle).
image_names = img_names[:N_IMAGES]
image_embeddings = img_embeddings[:N_IMAGES]

# Open the photos from the folder populated by the download cell (`img_folder`).
# NOTE(review): PIL's Image.open is lazy and keeps the underlying file open
# until the pixel data is loaded, so this holds up to N_IMAGES file handles;
# lower N_IMAGES if the OS reports "Too many open files".
images = [Image.open(os.path.join(img_folder, filepath)) for filepath in tqdm(image_names)]

100%|████████████████████████████████████████████████████████████████████████████| 5000/5000 [00:01<00:00, 2675.67it/s]


## Extract docs

In [9]:
# Upper bound on how many vocabulary entries (highest IDF first) to keep.
N_WORDS = 50_000

# Raw 20 Newsgroups posts with headers, footers and quoted replies stripped.
docs = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))['data']

# Learn a unigram + bigram vocabulary and its IDF weights from the posts.
vectorizer = TfidfVectorizer(ngram_range=(1, 2)).fit(docs)

# `get_feature_names` was deprecated in scikit-learn 1.0 and removed in 1.2;
# prefer `get_feature_names_out` and fall back only on older releases.
if hasattr(vectorizer, 'get_feature_names_out'):
    vocabulary = vectorizer.get_feature_names_out()
else:
    vocabulary = vectorizer.get_feature_names()

# Keep the entries with the largest IDF values. argpartition returns them in
# no particular order. Clamp the count so a vocabulary smaller than N_WORDS
# does not make argpartition raise.
n_keep = min(N_WORDS, len(vocabulary))
top_idf_indices = np.argpartition(vectorizer.idf_, -n_keep)[-n_keep:]
words = [vocabulary[index] for index in top_idf_indices]
# NOTE(review): `words` is not passed to the model in the cells visible here
# (fit_transform receives `docs`) -- confirm whether it is still needed.

## Concept Modeling

In [11]:
# Build a concept model with its default settings and fit it on the photo
# subset, their pre-computed embeddings, and the newsgroup documents.
# NOTE(review): presumably `docs` supplies the text used to label the image
# concepts -- confirm against the ConceptModel documentation.
concept_model = ConceptModel()
concepts = concept_model.fit_transform(
    images=images,
    docs=docs,
    image_names=image_names,
    image_embeddings=image_embeddings,
)

Batches:   0%|          | 0/4201 [00:00<?, ?it/s]

In [17]:
# Display the fitted model's concept visualization; requires the fit cell above to have run.
concept_model.visualize_concepts()