In [1]:
import os
import ast
import cv2
import torch
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image, ImageStat
from torchvision.io import read_image, ImageReadMode
from torchvision.models import convnext_large, ConvNeXt_Large_Weights

from neuralpredictors.measures.np_functions import corr, fev

In [2]:
def custom_agg(series):
    """Average list-valued entries element-wise.

    Parameters
    ----------
    series : iterable
        Each entry is either a string holding a Python literal (for example
        ``"[0.1, 0.2]"``) or an already-parsed sequence/array. All entries
        must have the same length so they can be stacked.

    Returns
    -------
    list
        The element-wise mean over all entries, converted with ``tolist()``.
    """
    # Only strings need parsing: ast.literal_eval raises on list/array objects,
    # so already-parsed entries are passed through unchanged.
    parsed = [ast.literal_eval(item) if isinstance(item, str) else item for item in series]
    return np.mean(np.array(parsed), axis=0).tolist()

In [3]:
def voc_to_yolo_bbox(bbox, w, h):
    """Convert a corner-format box into centre/size format.

    ``bbox`` is indexed as ``(xmin, ymin, xmax, ymax)``. Horizontal quantities
    are divided by ``w`` and vertical quantities by ``h``.

    Returns ``[x_center, y_center, width, height]``.
    """
    xmin, ymin, xmax, ymax = bbox[0], bbox[1], bbox[2], bbox[3]
    return [
        ((xmax + xmin) / 2) / w,  # x_center
        ((ymax + ymin) / 2) / h,  # y_center
        (xmax - xmin) / w,        # width
        (ymax - ymin) / h,        # height
    ]

In [4]:
# Dataset folder names: the five "pretrain_" entries come first, followed by
# the "sensorium_" and "sensorium+_" entries.
datasets = [
    "pretrain_21067-10-18",
    "pretrain_23343-5-17",
    "pretrain_22846-10-16",
    "pretrain_23656-14-22",
    "pretrain_23964-4-22",
    "sensorium_26872-17-20",
    "sensorium+_27204-5-13",
]
# The token after the first underscore is used as the key for each dataset.
data_keys = [name.split("_")[1] for name in datasets]

In [5]:
# For each of the first five data keys, load the array stored in that
# dataset's meta/trials/frame_image_id.npy file.
frame_image_id = {
    data_key: np.load(f"./dataset/pretrain_{data_key}/meta/trials/frame_image_id.npy")
    for data_key in data_keys[:5]
}

In [6]:
# For each of the first five data keys, read the submission (prediction) CSV
# and the ground-truth CSV, then left-join the ground truth onto the
# predictions using the three shared id columns.
preds_gt = {}
merge_keys = ["trial_indices", "image_ids", "neuron_ids"]
for data_key in data_keys[:5]:
    predictions = pd.read_csv(f"../sensorium/preds_gt/{data_key}/submission_file_live_test.csv")
    ground_truth = pd.read_csv(f"../sensorium//preds_gt/{data_key}/ground_truth_file_test.csv")
    preds_gt[data_key] = predictions.merge(ground_truth, how="left", on=merge_keys)

In [7]:
# Collapse rows that share an image id into one row per image, then attach a
# per-image correlation and the matching index into frame_image_id.
avg_preds_gt = preds_gt.copy()  # shallow copy: entries are replaced below, preds_gt keeps its frames
list_columns = ["prediction", "responses", "neuron_ids"]
for data_key in data_keys[:5]:
    # Element-wise mean of every list-valued column over the rows of each image.
    per_image = (
        avg_preds_gt[data_key]
        .groupby("image_ids")
        .agg({column: custom_agg for column in list_columns})
        .reset_index()
    )

    # One row per image in both stacked matrices; corr(..., axis=1) is stored
    # as one value per row.
    per_image["correlation"] = corr(
        np.vstack(per_image.responses),
        np.vstack(per_image.prediction),
        axis=1,
    )

    # Position of the first entry in frame_image_id that equals each image id
    # (raises IndexError when an id does not occur).
    id_lookup = frame_image_id[data_key]
    per_image["true_image_ids"] = [
        np.where(id_lookup == image_id)[0][0] for image_id in per_image.image_ids
    ]

    avg_preds_gt[data_key] = per_image

In [8]:
def summary(x: list):
    """Return ``[min, median, max, mean, std]`` of ``x``, each computed with NumPy."""
    statistics = (np.min, np.median, np.max, np.mean, np.std)
    return [statistic(x) for statistic in statistics]

# Statistically summarize the responses of every row of the merged
# prediction/ground-truth tables (all five datasets stacked together).
# The combined table is built with a single concat: growing a frame inside the
# loop re-copies all previously collected rows on every iteration and starts
# from an empty placeholder frame.
merge_preds_gt = pd.concat(
    [preds_gt[data_key].assign(dataset=data_key) for data_key in data_keys[:5]],
    ignore_index=True,
)

# Each `responses` cell is parsed with ast.literal_eval before being summarised.
response_summary = [summary(ast.literal_eval(x)) for x in merge_preds_gt.responses.values]

In [9]:
# Turn the per-row summaries into named columns and export them together with
# the image id and dataset of each row.
cols = ["response_min", "response_median", "response_max", "response_mean", "response_std"]
response_summary = pd.DataFrame(np.array(response_summary), columns=cols)

# Column-wise concat: both frames are aligned on their row index.
summary_table = pd.concat([merge_preds_gt, response_summary], axis=1)
summary_table[["image_ids", "dataset"] + cols].to_csv("image_response_summary.csv", index=False)

In [10]:
# Statistically summarize responses and predictions for each image after
# merging the repeats (one row per image and dataset).
# The combined table is built with a single concat instead of growing a frame
# inside the loop, which re-copies all previous rows on every iteration.
merge_avg_preds_gt = pd.concat(
    [avg_preds_gt[data_key].assign(dataset=data_key) for data_key in data_keys[:5]],
    ignore_index=True,
)

# --- responses -------------------------------------------------------------
cols = ["response_min", "response_median", "response_max", "response_mean", "response_std"]
response_summary = pd.DataFrame(
    np.array([summary(x) for x in merge_avg_preds_gt.responses.values]),
    columns=cols,
)
# Column-wise concat aligns on the row index shared by both frames.
response_table = pd.concat([merge_avg_preds_gt, response_summary], axis=1)
response_table[["image_ids", "dataset"] + cols].to_csv("image_mergeRep_response_summary.csv", index=False)

# --- predictions -----------------------------------------------------------
cols = ["preds_min", "preds_median", "preds_max", "preds_mean", "preds_std"]
preds_summary = pd.DataFrame(
    np.array([summary(x) for x in merge_avg_preds_gt.prediction.values]),
    columns=cols,
)
preds_table = pd.concat([merge_avg_preds_gt, preds_summary], axis=1)
preds_table[["image_ids", "dataset"] + cols].to_csv("image_mergeRep_preds_summary.csv", index=False)

In [13]:
# Correlation for the test images in each dataset, exported as one table with
# the columns image_ids, correlation and dataset.
# NOTE(review): this rebinds `merge_preds_gt`, which cell In [8] defines as a
# different table that In [9] reads; re-running In [9] after this cell would
# pick up this three-column frame instead.
# Only the needed columns are kept, and the frames are stacked with a single
# concat rather than by growing a frame inside the loop.
merge_preds_gt = pd.concat(
    [
        avg_preds_gt[data_key][["image_ids", "correlation"]].assign(dataset=data_key)
        for data_key in data_keys[:5]
    ],
    ignore_index=True,
)
merge_preds_gt.to_csv("outputs_model_with_image/image_performance.csv", index=False)

In [14]:
# Per-image descriptors for the images listed in avg_preds_gt["21067-10-18"]:
#   Complexity - mean Sobel gradient magnitude (spatial information), see
#                https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=6603194
#   Brightness - mean of the first band reported by PIL's ImageStat
#   Contrast   - standard deviation of the first band reported by ImageStat
# NOTE(review): complexity is computed from ./outputs_model_with_image/images
# while brightness/contrast are computed from ./test_images -- confirm that
# both folders hold the same images.
# NOTE(review): only band 0 is used for brightness/contrast; for multi-band
# (e.g. RGB) files that is the first channel only -- confirm the PNGs are
# single-band.
SI_means = []
brightness = []
contrast = []
for image_id in avg_preds_gt["21067-10-18"].image_ids:
    complexity_path = f"./outputs_model_with_image/images/{image_id}.png"
    img = cv2.imread(complexity_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread does not raise on a missing/unreadable file; it returns
        # None, which would otherwise surface as an obscure error in cv2.Sobel.
        raise FileNotFoundError(f"Could not read image: {complexity_path}")

    sobelx = cv2.Sobel(img, cv2.CV_64F, 1, 0, ksize=5)
    sobely = cv2.Sobel(img, cv2.CV_64F, 0, 1, ksize=5)

    # Mean gradient magnitude over all pixels.
    SI_r = np.sqrt(sobelx**2 + sobely**2)
    SI_means.append(np.mean(SI_r))

    # The context manager releases the file handle as soon as the statistics
    # have been read.
    with Image.open(f"./test_images/{image_id}.png") as pil_img:
        stat = ImageStat.Stat(pil_img)
        brightness.append(stat.mean[0])
        contrast.append(stat.stddev[0])

image_complexity = pd.DataFrame({"image_ids": avg_preds_gt["21067-10-18"].image_ids,
                                 "Complexity": SI_means,
                                 "Brightness": brightness,
                                 "Contrast": contrast})
image_complexity.to_csv("./outputs_model_with_image/image_complexity.csv", index=False)

In [23]:
# generate category info for the test images
# still need manual input...
# For every image, a pretrained ConvNeXt-Large classifier is run and the names
# of its three highest-scoring classes are stored as one comma-separated string.
weights = ConvNeXt_Large_Weights.DEFAULT
model = convnext_large(weights=weights)
model.eval()
preprocess = weights.transforms()
category_names = weights.meta["categories"]

categories = []
# Inference only: disabling autograd avoids recording a graph for every
# forward pass, which would otherwise cost memory for no benefit.
with torch.no_grad():
    for image_id in avg_preds_gt["21067-10-18"].true_image_ids:
        img = read_image(f"./images_png/pretrain_21067-10-18/data/images/{image_id}.png",
                         ImageReadMode.RGB)

        batch = preprocess(img).unsqueeze(0)

        prediction = model(batch).squeeze(0).softmax(0)
        # A single top-k call; indices come back ordered from highest to
        # lowest score.
        class_ids = torch.topk(prediction, k=3).indices.tolist()
        categories.append(", ".join(category_names[class_id] for class_id in class_ids))



In [24]:
# Pair every image id of dataset 21067-10-18 with its inferred category string
# and write the table to disk.
# NOTE(review): the column is named "trial_ids" but is filled from `image_ids`;
# confirm that consumers of category.csv expect image ids here.
category_columns = {
    "trial_ids": avg_preds_gt["21067-10-18"].image_ids,
    "inferred_category": categories,
}
image_category = pd.DataFrame(category_columns)
image_category.to_csv("./outputs_model_with_image/category.csv", index=False)