In [None]:
# Render figures inline, and automatically reload edited local modules
# (e.g. create_toybrains, utils.vizutils) before each cell is executed.
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
# standard python packages
import os, sys
from glob import glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt 
import seaborn as sns
import matplotlib.image as mpimg
from sklearn import datasets, linear_model
from tqdm.notebook import tqdm
import random
import math
import json

from create_toybrains import ToyBrainsData
from utils.vizutils import *

## Misc. / README figures

##### Image-attributes plot for the README

In [None]:
# README figure: one row per generative image attribute (`gen_*` columns of
# `df`), showing example images as that single attribute is varied.
# NOTE(review): `df` is not created anywhere in this section; presumably it is
# the generated toybrains attribute table loaded in an earlier cell - confirm,
# otherwise Restart & Run All fails here.
N_COLS = 7                      # example images shown per attribute row
IMG_DIR = "toybrains/images"    # folder holding the rendered <index>.jpg images
fs = 12

gen_cols = df.filter(regex='^gen_').columns
atrs = sorted(col for col in gen_cols if 'shape' not in col)
# of the shape attributes, only keep those whose name contains 'botr', so that
# a single shape is shown instead of all of them
atrs += [col for col in gen_cols if 'botr' in col]
assert atrs, "df has no gen_ attribute columns to plot"

# scoped style: the black background does not leak into later figures
with plt.rc_context({'axes.facecolor': 'black', 'savefig.facecolor': 'black'}):
    f = plt.figure(constrained_layout=True, figsize=(7, 1.3*len(atrs)))
    f.suptitle("Modifying different image attributes:",
               fontsize=fs+2, ha='right', x=0.1, fontweight='heavy')
    # one subfigure per attribute row so every row gets its own title;
    # squeeze=False keeps the result iterable even for a single attribute
    subfigs = f.subfigures(len(atrs), 1, squeeze=False).ravel()

    for atr, subfig in zip(atrs, subfigs):
        # sorted distinct values; NaN dropped because it can never match `==`
        atr_vals = df[atr].dropna().sort_values().unique()
        if len(atr_vals) > N_COLS:
            # too many values: show the smallest, the largest and evenly spaced
            # values in between (deterministic and without repeats)
            idx = np.linspace(0, len(atr_vals) - 1, N_COLS).round().astype(int)
            atr_vals = atr_vals[idx]

        subfig.suptitle(atr.replace('gen_', ''),
                        fontsize=fs, color='r', ha='right', x=0.12, fontweight='heavy')
        axes_row = subfig.subplots(nrows=1, ncols=N_COLS)

        # sort by this attribute first, then by attributes of the same group
        # (same token after the first '_'), then by all remaining attributes,
        # so the images picked for this row are as similar as possible
        group = atr.split('_')[1]
        related_atrs = [c for c in atrs if group in c and c != atr]
        other_atrs = [c for c in atrs if c != atr and c not in related_atrs]
        df_sorted = df.sort_values(by=[atr] + related_atrs + other_atrs, axis=0)

        for j, ax in enumerate(axes_row):
            if j < len(atr_vals):
                atr_val = atr_vals[j]
                # first row (in the sort order above) having this value
                sample = df_sorted.loc[df_sorted[atr] == atr_val].iloc[0]
                subID = f"{sample.name:05}"
                ax.imshow(mpimg.imread(f"{IMG_DIR}/{subID}.jpg"))

                if isinstance(atr_val, str):
                    # show the second '-'-separated field, or the whole string
                    parts = atr_val.split('-')
                    label = parts[1] if len(parts) > 1 else atr_val
                elif isinstance(atr_val, float):
                    # integer-valued floats print as ints; others keep 2 decimals
                    val = float(atr_val)
                    label = int(val) if val.is_integer() else round(val, 2)
                else:
                    label = atr_val
                ax.set_title(f"= {label}", fontsize=fs-4, ha='center')

            ax.axis("off")

    # plt.savefig("image_attrs.png", bbox_inches='tight')
    plt.show()

In [None]:
# Same layout as the gen_ attribute figure, but one row per covariate
# (`cov_*` columns of `df`).
# NOTE(review): `df` is not created anywhere in this section; presumably it is
# the generated toybrains attribute table loaded in an earlier cell - confirm.
N_COLS = 7                      # example images shown per covariate row
IMG_DIR = "toybrains/images"    # folder holding the rendered <index>.jpg images
fs = 12

atrs = sorted(df.filter(regex='^cov_').columns)
assert atrs, "df has no cov_ columns to plot"

# scoped style: the black background does not leak into later figures
with plt.rc_context({'axes.facecolor': 'black', 'savefig.facecolor': 'black'}):
    f = plt.figure(constrained_layout=True, figsize=(7, 1.3*len(atrs)))
    # one subfigure per covariate row so every row gets its own title;
    # squeeze=False keeps the result iterable even for a single covariate
    subfigs = f.subfigures(len(atrs), 1, squeeze=False).ravel()

    for atr, subfig in zip(atrs, subfigs):
        # sorted distinct values; NaN dropped because it can never match `==`
        atr_vals = df[atr].dropna().sort_values().unique()
        if len(atr_vals) > N_COLS:
            # too many values: show the smallest, the largest and evenly spaced
            # values in between (deterministic and without repeats)
            idx = np.linspace(0, len(atr_vals) - 1, N_COLS).round().astype(int)
            atr_vals = atr_vals[idx]

        # strip the 'cov_' prefix (the previous 'gen_' replace was a no-op here)
        subfig.suptitle(atr.replace('cov_', ''),
                        fontsize=fs, color='blue', ha='right', x=0.12, fontweight='heavy')
        axes_row = subfig.subplots(nrows=1, ncols=N_COLS)

        # sort by this covariate first, then by columns sharing its name token
        # (the part after the first '_'), then by the rest, so the images
        # picked for this row are as similar as possible
        group = atr.split('_')[1]
        related_atrs = [c for c in atrs if group in c and c != atr]
        other_atrs = [c for c in atrs if c != atr and c not in related_atrs]
        df_sorted = df.sort_values(by=[atr] + related_atrs + other_atrs, axis=0)

        for j, ax in enumerate(axes_row):
            if j < len(atr_vals):
                atr_val = atr_vals[j]
                # first row (in the sort order above) having this value
                sample = df_sorted.loc[df_sorted[atr] == atr_val].iloc[0]
                subID = f"{sample.name:05}"
                ax.imshow(mpimg.imread(f"{IMG_DIR}/{subID}.jpg"))

                if isinstance(atr_val, str):
                    # show the second '-'-separated field, or the whole string
                    parts = atr_val.split('-')
                    label = parts[1] if len(parts) > 1 else atr_val
                elif isinstance(atr_val, float):
                    # integer-valued floats print as ints; others keep 2 decimals
                    val = float(atr_val)
                    label = int(val) if val.is_integer() else round(val, 2)
                else:
                    label = atr_val
                ax.set_title(f"= {label}", fontsize=fs-4, ha='center')

            ax.axis("off")

    # NOTE(review): this reuses the `image_attrs` file name of the gen_ figure;
    # pick a distinct name (e.g. docs/image_covs.png) if both are saved.
    # plt.savefig("docs/image_attrs.png", bbox_inches='tight')
    plt.show()