In [1]:
import os

import numpy as np
import pandas as pd

from purple_flamingo.datasets import CNSDataset
from purple_flamingo.instruct.to_filter_a_dataset import PCASorter
from purple_flamingo.instruct.to_filter_a_dataset import ImageLabelingApp
from purple_flamingo.instruct.to_filter_a_dataset import get_resnet_features

In [2]:
# Seed NumPy's legacy global RNG: the shuffle of dataset indices below draws from it,
# so this fixes which images end up in the labeling sample.
np.random.seed(seed=314)

In [3]:
# Root directory of the figure corpus. Set the CNS_DATA_ROOT environment variable to
# point the notebook at another copy without editing this cell; the default is the
# cluster path the notebook was written against.
data_root = os.environ.get(
    "CNS_DATA_ROOT",
    "/gpfs/data/oermannlab/private_data/TheMedScrolls/FiguresJadenTextract",
)
dataset = CNSDataset(data_root, "dataset.json")

In [4]:
# Random visiting order over every dataset position, drawn from the seeded global RNG.
# np.random.permutation(n) builds arange(n) and shuffles it in place, so this is the
# same draw as an explicit arange followed by np.random.shuffle.
indeces = np.random.permutation(len(dataset))

In [5]:
# Number of randomly sampled figures to show in the labeling app. It is capped at the
# dataset size so that later steps sized by this value (e.g. reshaping the collected
# labels) stay consistent when the dataset holds fewer than 500 items.
images_to_label = min(500, len(dataset))

In [6]:
# Collect the images and captions at the first `images_to_label` shuffled positions.
# Element 0 of each sample is used as the image and element 1 as its caption.
labeled_images = []
labeled_captions = []
for i in indeces[:images_to_label]:
    # Fetch each sample once. Indexing the dataset twice would load the same item
    # twice and, if loading is non-deterministic, could pair an image with a caption
    # from a different load.
    sample = dataset[i]
    labeled_images.append(sample[0])
    labeled_captions.append(sample[1])

In [7]:
# ResNet feature matrix with exactly one row per labeled image.
# The row count is derived from the number of images rather than a hard-coded feature
# width of 1000. With 1000-wide features the result is the same (n_images, 1000)
# matrix. With any other width, a fixed reshape(-1, 1000) would either error or
# silently split/merge images across rows.
transformed_labeled_images = np.array(
    [get_resnet_features(image) for image in labeled_images]
).reshape(len(labeled_images), -1)

In [8]:
# Fit the sorter on the ResNet feature matrix only. The fitted ordering is then applied
# to every per-image collection (presumably a PCA-based reordering; confirm in
# PCASorter), so images, captions and features stay aligned by position.
pca_sorter = PCASorter()
pca_sorter.fit(transformed_labeled_images)

# map() is consumed by the unpacking, calling transform on images, captions and
# features in that order.
sorted_labeled_images, sorted_labeled_captions, sorted_transformed_images = map(
    pca_sorter.transform,
    (labeled_images, labeled_captions, transformed_labeled_images),
)

In [9]:
# Build the interactive labeling app over the sorted images and their captions.
# Both sequences were reordered by the same fitted PCASorter, so they are expected
# to stay aligned by position.
app = ImageLabelingApp(sorted_labeled_images, sorted_labeled_captions)

# Start the app. Keep this call as the cell's last expression: Jupyter auto-displays
# its return value. The saved cell output (`Output()`) suggests start() returns the
# widget container the labeling UI renders into (TODO: confirm in ImageLabelingApp.start).
# NOTE(review): the labels read back later depend on manual interaction here, so
# "Restart & Run All" cannot reproduce them by itself.
app.start()

Output()

In [10]:
# One label per sampled image, shaped as a column vector so it can be appended to the
# feature matrix. reshape(-1, 1) replaces a fixed row count of `images_to_label`,
# which fails whenever fewer images were actually sampled (e.g. a small dataset).
labels = np.array(app.get_results()).reshape(-1, 1)
# Fail loudly with a clear message if labeling is incomplete. Otherwise the mismatch
# only surfaces later as an opaque concatenate shape error.
assert len(labels) == len(labeled_images), (
    f"expected {len(labeled_images)} labels (one per image), got {len(labels)}; "
    "finish labeling in the app before exporting"
)

In [11]:
# Write the sorted feature matrix with the label appended as the final column.
# For 2-D inputs, np.hstack joins along axis 1. Columns keep pandas' default integer
# names, so the label column is named after the feature count.
(
    pd.DataFrame(np.hstack((sorted_transformed_images, labels)))
    .to_csv('./transformed_labeled_images.csv')
)