In [51]:
import os
import numpy as np
import matplotlib.pyplot as plt

In [52]:
# @title Download the data

import os, requests, tarfile

fnames = ["kay_labels.npy", "kay_labels_val.npy", "kay_images.npz"]
urls = ["https://osf.io/r638s/download",
        "https://osf.io/yqb3e/download",
        "https://osf.io/ymnjv/download"]

# Seconds requests may wait on a socket operation before giving up; without a
# timeout a stalled connection would block the notebook indefinitely.
DOWNLOAD_TIMEOUT = 60

for fname, url in zip(fnames, urls):
  if os.path.isfile(fname):
    continue  # already downloaded by a previous run
  print(f"Downloading {fname}...")
  try:
    r = requests.get(url, timeout=DOWNLOAD_TIMEOUT)
  except requests.RequestException:
    # RequestException covers timeouts as well as connection errors.
    print("!!! Failed to download data !!!")
    continue
  if r.status_code != requests.codes.ok:
    print("!!! Failed to download data !!!")
    continue
  with open(fname, "wb") as fid:
    fid.write(r.content)
  print(f"Download {fname} completed!")

In [53]:
# Load the image/response archive by its explicit name instead of relying on
# the `fname` loop variable left over from the download cell.
with np.load("kay_images.npz") as dobj:
  dat = dict(**dobj)
labels = np.load('kay_labels.npy')
val_labels = np.load('kay_labels_val.npy')

In [54]:
# Names of the arrays stored in the loaded archive.
field_names = dat.keys()
print(field_names)

dict_keys(['stimuli', 'stimuli_test', 'responses', 'responses_test', 'roi', 'roi_names'])


`dat` has the following fields:  
- `stimuli`: stim x i x j array of grayscale stimulus images
- `stimuli_test`: stim x i x j array of grayscale stimulus images in the test set  
- `responses`: stim x voxel array of z-scored BOLD response amplitude
- `responses_test`:  stim x voxel array of z-scored BOLD response amplitude in the test set  
- `roi`: array of voxel labels
- `roi_names`: array of names corresponding to voxel labels

In [55]:
# stim x i x j, per the field description above.
print(np.shape(dat["stimuli"]))

(1750, 128, 128)


In [56]:
# stim x voxel, per the field description above.
print(np.shape(dat["responses"]))

(1750, 8428)


The next cell counts the number of voxels in each ROI. Note that voxels labeled `"Other"` have been removed from this version of the dataset:

In [57]:
X = dat["stimuli"]  # images handed to the labeling tool below

# Voxels per ROI: np.bincount counts voxels carrying each integer ROI label,
# paired with roi_names by position (assumes roi_names[k] names label k).
# Kept as the last expression so Jupyter actually displays it.
dict(zip(dat["roi_names"], np.bincount(dat["roi"])))

In [63]:
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import widgets
from IPython.display import display, clear_output
import io
import json
from google.colab import drive
import os
import pandas as pd

class ColabImageLabelingTool:
    """Interactive ipywidgets tool for assigning one label to each image in ``X``.

    Images are shown one at a time, grouped into batches of ``batch_size`` for
    the status display. Clicking a label button records that label for the
    current image and advances to the next one. Labels live in a list parallel
    to ``X`` (``"unlabeled"`` marks images without a label). The Save, Load and
    Delete actions persist that list as JSON at ``save_path`` after mounting
    Google Drive.

    Args:
        X: indexable sequence of images; each ``X[i]`` is passed to
            ``plt.imshow(..., cmap='gray')`` (presumably 2-D grayscale arrays).
        batch_size: number of images per batch (navigation/status only).
        label_options: labels offered as buttons, in display order.
        save_path: JSON file used to save, load and delete progress.
    """

    DEFAULT_SAVE_PATH = '/content/drive/MyDrive/image_labeling_progress.json'

    def __init__(self, X, batch_size=25, label_options=("not_face", "face"),
                 save_path=DEFAULT_SAVE_PATH):
        self.images = X
        self.batch_size = batch_size
        self.num_batches = (len(X) + batch_size - 1) // batch_size
        self.current_batch = 0
        self.current_index = 0
        self.label_options = list(label_options)
        self.save_path = save_path
        # load_progress may move current_batch/current_index to the first
        # unlabeled image of a saved session.
        self.labels = self.load_progress()
        self.create_widgets()

    # --- position helpers -------------------------------------------------

    def _overall_index(self):
        """Index of the current image within ``self.images``."""
        return self.current_batch * self.batch_size + self.current_index

    def _set_overall_index(self, overall_index):
        """Move to ``overall_index``, recomputing the batch and in-batch position."""
        self.current_batch, self.current_index = divmod(overall_index, self.batch_size)

    def _is_finished(self):
        """True once the position has moved past the last image."""
        return self._overall_index() >= len(self.images)

    @staticmethod
    def _first_unlabeled_index(labels):
        """Index of the first ``"unlabeled"`` entry, or ``len(labels)`` if there is none."""
        return next((i for i, label in enumerate(labels) if label == "unlabeled"), len(labels))

    # --- widgets ----------------------------------------------------------

    def create_widgets(self):
        """Build and display the widget layout, then render the current image."""
        self.output = widgets.Output()
        self.img_widget = widgets.Image(format='png')
        self.status_label = widgets.Label()
        self.progress_bar = widgets.FloatProgress(min=0, max=1, description='Progress:')

        self.category_buttons = [widgets.Button(description=label, button_style='info') for label in self.label_options]
        for btn, label in zip(self.category_buttons, self.label_options):
            # Default argument binds the current label (avoids late binding).
            btn.on_click(lambda b, l=label: self.label_image(l))

        self.save_button = widgets.Button(description='Save Progress', button_style='success')
        self.save_button.on_click(self.save_progress)

        self.back_button = widgets.Button(description='Back', button_style='warning')
        self.back_button.on_click(self.go_back)

        self.delete_button = widgets.Button(description='Delete Progress', button_style='danger')
        self.delete_button.on_click(self.confirm_delete)

        # Confirm/Cancel start hidden and are revealed by confirm_delete.
        self.confirm_delete_button = widgets.Button(description='Confirm Delete', button_style='danger', layout=widgets.Layout(display='none'))
        self.confirm_delete_button.on_click(self.delete_progress)

        self.cancel_delete_button = widgets.Button(description='Cancel', button_style='info', layout=widgets.Layout(display='none'))
        self.cancel_delete_button.on_click(self.cancel_delete)

        self.widget = widgets.VBox([
            self.img_widget,
            widgets.HBox(self.category_buttons),
            widgets.HBox([self.back_button, self.save_button, self.delete_button]),
            widgets.HBox([self.confirm_delete_button, self.cancel_delete_button]),
            self.status_label,
            self.progress_bar,
            self.output
        ])

        display(self.widget)
        self.show_image()

    def show_image(self):
        """Render the current image and status, or the summary once past the last image."""
        if self._is_finished():
            # Without this guard, Save/Delete/Cancel after completion (or a
            # fully-labeled resume) would index past the end of self.images.
            self._show_summary()
            return

        overall_index = self._overall_index()
        img = self.images[overall_index]

        fig, ax = plt.subplots(figsize=(4, 2))
        ax.imshow(img, cmap='gray')
        ax.set_xticks([])
        ax.set_yticks([])
        img_buf = io.BytesIO()
        fig.savefig(img_buf, format='png')
        plt.close(fig)
        self.img_widget.value = img_buf.getvalue()

        # The last batch may be shorter than batch_size.
        batch_len = min(self.batch_size, len(self.images) - self.current_batch * self.batch_size)
        self.status_label.value = (
            f'Batch {self.current_batch + 1}/{self.num_batches}, '
            f'Image {self.current_index + 1}/{batch_len} '
            f'(Overall: {overall_index + 1}/{len(self.images)})'
        )
        self.progress_bar.value = (overall_index + 1) / len(self.images)

    def _show_summary(self):
        """Replace the output area with per-label counts."""
        self.progress_bar.value = 1.0
        with self.output:
            clear_output()
            print("All images labeled!")
            print("\nLabel counts:")
            for label in self.label_options + ["unlabeled"]:
                count = self.labels.count(label)
                print(f"{label}: {count}")

    # --- button callbacks -------------------------------------------------

    def label_image(self, label):
        """Record ``label`` for the current image and advance; no-op after the last image."""
        if self._is_finished():
            return
        self.labels[self._overall_index()] = label
        self.animate_button(label)
        self.next_image()

    def animate_button(self, label):
        """Highlight the button for ``label`` and reset the others."""
        for btn in self.category_buttons:
            if btn.description == label:
                btn.button_style = 'success'
            else:
                btn.button_style = 'info'

    def next_image(self):
        """Advance one image; show_image displays the summary when the end is reached."""
        self._set_overall_index(self._overall_index() + 1)
        self.show_image()

    def go_back(self, button=None):
        """Step back one image (staying on the first image if already there)."""
        # Clamp to len(images) first so stepping back from past the end lands on
        # the last real image even when the final batch is partial.
        previous = min(self._overall_index(), len(self.images)) - 1
        self._set_overall_index(max(previous, 0))
        self.show_image()

    def save_progress(self, button=None):
        """Mount Drive and write all labels to ``self.save_path`` as JSON."""
        drive.mount('/content/drive')
        with open(self.save_path, 'w') as f:
            json.dump(self.labels, f)
        print(f"Progress saved to {self.save_path}")
        # Continue from where it left off
        self.show_image()

    def load_progress(self):
        """Return saved labels (moving to the first unlabeled image) or a fresh list.

        Loaded labels are padded with ``"unlabeled"`` when shorter than the
        image list, and a warning is printed for other mismatches so progress
        from a different labeling task is not mixed in silently.
        """
        drive.mount('/content/drive')
        n_images = len(self.images)
        if not os.path.exists(self.save_path):
            return ["unlabeled"] * n_images

        with open(self.save_path, 'r') as f:
            labels = json.load(f)
        print(f"Progress loaded from {self.save_path}")

        if len(labels) < n_images:
            print(f"Warning: progress file has {len(labels)} labels for {n_images} images; "
                  f"padding with 'unlabeled'.")
            labels = labels + ["unlabeled"] * (n_images - len(labels))
        elif len(labels) > n_images:
            print(f"Warning: progress file has {len(labels)} labels for {n_images} images.")

        unknown = set(labels) - set(self.label_options) - {"unlabeled"}
        if unknown:
            print(f"Warning: progress file contains labels not in {self.label_options}: "
                  f"{sorted(map(str, unknown))}")

        # Resume at the first unlabeled image; if none, the position is past
        # the end and show_image displays the summary instead of crashing.
        self._set_overall_index(self._first_unlabeled_index(labels))
        return labels

    def _set_delete_confirmation(self, confirming):
        """Show Confirm/Cancel (hiding Delete) when ``confirming``, else the reverse."""
        shown, hidden = ('block', 'none') if confirming else ('none', 'block')
        self.confirm_delete_button.layout.display = shown
        self.cancel_delete_button.layout.display = shown
        self.delete_button.layout.display = hidden

    def confirm_delete(self, button):
        """Ask for confirmation before deleting the progress file."""
        self.status_label.value = "Are you sure you want to delete all progress? This action cannot be undone."
        self._set_delete_confirmation(True)

    def cancel_delete(self, button):
        """Abort deletion and restore the normal button row."""
        self.status_label.value = "Deletion cancelled."
        self._set_delete_confirmation(False)
        self.show_image()  # Restore the original status label

    def delete_progress(self, button):
        """Delete the progress file (if any), reset labels to unlabeled and restart at image 0."""
        drive.mount('/content/drive')
        if os.path.exists(self.save_path):
            os.remove(self.save_path)
            message = f"Progress file deleted: {self.save_path}"
            self.labels = ["unlabeled"] * len(self.images)
            self._set_overall_index(0)
        else:
            message = "No progress file found to delete."

        self._set_delete_confirmation(False)
        self.show_image()
        # show_image() rewrites status_label, so report the result afterwards;
        # otherwise the message is overwritten before the user can see it.
        self.status_label.value = message


In [59]:

def display_json_progress(json_path='/content/drive/MyDrive/image_labeling_progress.json'):
    """Print a summary of a saved labeling-progress JSON file.

    Mounts Google Drive (forcing a remount), loads the JSON list of labels and
    prints the total, per-label counts, the first 10 labels and the
    labeled/unlabeled percentages. Problems are reported as printed messages
    rather than raised.

    Args:
        json_path: progress file to read; defaults to the path the labeling
            tool in this notebook saves to.
    """
    drive.mount('/content/drive', force_remount=True)

    try:
        with open(json_path, 'r') as f:
            labels = json.load(f)

        total = len(labels)
        print(f"Total labels: {total}")
        if total == 0:
            # Avoid dividing by zero in the percentage statistics below.
            print("The progress file contains no labels.")
            return

        df = pd.DataFrame({"Label": labels})

        print("\nLabel counts:")
        print(df['Label'].value_counts())

        print("\nFirst 10 labels:")
        print(df.head(10))

        print("\nLabel statistics:")
        unlabeled = df['Label'].eq("unlabeled").sum()
        labeled = total - unlabeled
        print(f"Labeled: {labeled} ({labeled/total:.2%})")
        print(f"Unlabeled: {unlabeled} ({unlabeled/total:.2%})")

    except FileNotFoundError:
        print(f"No progress file found at {json_path}")
    except json.JSONDecodeError:
        print(f"Error decoding JSON from {json_path}. The file may be corrupted.")
    except Exception as e:
        print(f"An error occurred: {str(e)}")



In [64]:
# Launch the labeling UI over all stimulus images; it resumes from the saved
# progress file on Drive when one exists.
tool = ColabImageLabelingTool(X, batch_size=25)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


VBox(children=(Image(value=b''), HBox(children=(Button(button_style='info', description='not_face', style=Butt…

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Progress saved to /content/drive/MyDrive/image_labeling_progress.json


In [61]:
# Print a summary of the saved progress file: total labels, per-label counts,
# the first 10 labels, and the labeled/unlabeled percentages.
display_json_progress()

Mounted at /content/drive
Total labels: 1750

Label counts:
Label
unlabeled    1745
nonanimal       3
animal          2
Name: count, dtype: int64

First 10 labels:
       Label
0     animal
1  nonanimal
2  nonanimal
3     animal
4  nonanimal
5  unlabeled
6  unlabeled
7  unlabeled
8  unlabeled
9  unlabeled

Label statistics:
Labeled: 5 (0.29%)
Unlabeled: 1745 (99.71%)
