In [None]:
import torch
import torch.nn as nn
import torchvision
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
import torch.utils.data
from utils import *

import numpy as np
import math
from numpy.random import default_rng
import pandas as pd
import xarray as xr
from scipy.stats import ks_2samp
import seaborn as sns
from tqdm import tqdm
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import os
import sys
from collections import namedtuple

from models.resnet_imagenet_continuoustopo import ResNet18
from torch.utils.data import Dataset, DataLoader

import spacetorch.analyses.core as core
from spacetorch.analyses.sine_gratings import (
    get_sine_tissue,
    add_sine_colorbar,
    METRIC_DICT,
    get_smoothed_map,
)
from spacetorch.datasets import DatasetRegistry
from spacetorch.datasets.ringach_2002 import load_ringach_data
from spacetorch.utils import (
    figure_utils,
    plot_utils,
    spatial_utils,
    array_utils,
    seed_str,
)
from spacetorch.maps.pinwheel_detector import PinwheelDetector
from spacetorch.maps.screenshot_maps import (
    NauhausOrientationTissue,
    NauhausSFTissue,
    LivingstoneColorTissue,
)


from utils import get_model
import argparse

from load_model import *
from spacetorch.feature_extractor import FeatureExtractor
from einops import reduce, rearrange
from spacetorch.datasets import sine_gratings
from spacetorch.maps import v1_map


# Selectivity Map

In [None]:
# Flag selecting the feature post-processing branch inside the loop below.
# It is False here, so only the channel-average (`else`) branch runs in this cell.
llc = False
model = load_model(pool_type='gaussian', kap_kernelsize=0.23, continuous=True, local_conv=False, expname='gaussian_0.23_continuous_prog_t', epoch=100, sel_range=10)
layers, layers_names = load_layers_names1(model)


# Sine-grating stimuli, served in batches of 10. Batches are shuffled; the
# feature extractor below is asked to return the labels alongside the features.
dataloader = DataLoader(
        DatasetRegistry.get("SineGrating2019"), batch_size=10, shuffle=True, num_workers=1, pin_memory=True
    )

from typing import Optional, List
from spacetorch.datasets.sine_gratings import SineResponses, SineGrating2019, Metric
# Metric descriptors of the sine-grating dataset, as a list and as a dict;
# `angle_metric` is the entry stored under the "angles" key.
metrics: List[Metric] = SineGrating2019.get_metrics()
metric_dict = SineGrating2019.get_metrics(as_dict=True)
angle_metric = metric_dict["angles"]


# One V1Map per entry of `layers`, appended in iteration order.
V1_tissues = []

for layer in layers:
    # 1. Extract the features and labels for this single layer
    features, _, labels = FeatureExtractor(dataloader, 32).extract_features(model, [layer], return_inputs_and_labels = True)

    # 2. Collapse the features to one value per unit
    if llc:
        # NOTE(review): this branch is not executed while `llc` is False above.
        f_shape = features.shape
        avg_features = torch.from_numpy(features).reshape(f_shape[0], 9, f_shape[2]* f_shape[3])#.permute(0, 1, 3, 2, 4).reshape(f_shape[0], 1, f_shape[-2]*3 * f_shape[-2]*3).numpy()

        def car_to_polar():
            """Return polar angle, radius and a running index for a 56 x 56 integer grid centred on 0."""
            phis =[]
            rhos =[]
            ind = []
            for i in range(-56//2, 56//2):
                for j in range(-56//2, 56//2):
                    rho = np.sqrt(i**2 + j**2)
                    phi = np.arctan2(j, i)
                    phis.append(phi)
                    rhos.append(rho)

            ind = range(0, len(phis))
            return phis, rhos, ind

        phis, rhos, ind = car_to_polar()


        # Two stable sorts: first by radius, then by angle, so the resulting index
        # list is ordered by angle with ties broken by radius.
        polar_coord = sorted(zip(phis, rhos, ind), key=lambda pair: pair[1])
        polar_coord = [p for _, _, p in sorted(polar_coord, key=lambda pair: pair[0])]

        # NOTE(review): `temp` is the same object as `avg_features` (no copy is made),
        # so the loop below overwrites positions that later iterations still read from.
        temp = avg_features # create the tensor
        for i in range(len(polar_coord)):
            temp[:, :, i] = avg_features[:, :, polar_coord[i]]

        # NOTE(review): the reshaped result is stored in `temp` only; `avg_features`
        # (still 3-D here) is what step 3 and step 4 consume.
        temp =  temp.reshape(f_shape[0], 3, 3, 56, 56).permute(0, 3, 1, 4, 2).reshape(f_shape[0], -1)

    else:

        # Mean over the two spatial axes: (b, c, h, w) -> (b, c)
        avg_features = reduce(features, 'b c h w -> b c', 'mean')


    # 3. Unit positions
    # presumably `get_closest_factors` returns two integers whose product is the
    # number of units, and `get_coord1` lays them out on that grid - TODO confirm
    kw, kh = get_closest_factors(avg_features.shape[1])
    coord = get_coord1(kw, kh)

    # 4. Bundle the per-unit responses with their stimulus labels
    avg_response = sine_gratings.SineResponses(avg_features, labels)

    # 5. Build the map object from unit coordinates and responses
    V1_tissue = v1_map.V1Map(coord, avg_response)
    V1_tissues.append(V1_tissue)



In [None]:
def make_parameter_map2(
    tissue,
    axis,
    metric: Metric = angle_metric,
    scale_points=True,
    num_colors=None,
    final_psm=1.0,
    final_s: Optional[float] = 1,
    **kwargs,
):
    """
    Draw a parameter map for the given metric, e.g., "angles", "sfs", "colors".

    The per-unit preferences are folded onto a 2-D grid and rendered with
    ``axis.imshow`` using ``metric.colormap``.

    Args:
        tissue: object exposing ``positions``, ``responses``,
            ``get_unit_colors(metric=...)`` and ``get_preferences(metric)``.
        axis: matplotlib axes to draw into.
        metric: metric descriptor; ``metric.name`` and ``metric.colormap`` are read.
        scale_points: if True, derive a per-unit selectivity from the responses;
            otherwise use a constant selectivity of 1 for every unit.
        num_colors: accepted for signature compatibility; not used.
        final_psm: multiplier applied to the point sizes.
        final_s: accepted for signature compatibility; not used.
        **kwargs: accepted for signature compatibility (callers pass scatter-style
            options such as ``rasterized``); not used.

    Returns:
        The image handle returned by ``axis.imshow``.
    """
    # Kept from the original: colors and point sizes are computed but the
    # imshow-based rendering below does not consume them.
    colors = tissue.get_unit_colors(metric=metric)

    if scale_points:
        if metric.name == "angles":
            selectivity = 1.0 - tissue.responses.circular_variance
        else:
            selectivity = tissue.responses.get_peak_heights(metric.name)

        # Treat undefined selectivity as 0, then rescale to the range [0.5, 1.5].
        selectivity = np.where(np.isnan(selectivity), 0, selectivity)
        selectivity = (selectivity - np.min(selectivity)) / np.ptp(selectivity) + 0.5
    else:
        # Fixed: this branch used `self.positions`, but there is no `self` in a
        # free function, so `scale_points=False` raised NameError.
        selectivity = np.ones((len(tissue.positions),))

    point_sizes = final_psm * 300 * selectivity

    # Fold the flat preference vector onto the (kw, kh) grid and show it
    # transposed, one pixel per unit.
    preferences = tissue.get_preferences(metric)
    kw, kh = get_closest_factors(tissue.positions.shape[0])
    ori = rearrange(preferences, '(d1 d2)  -> d1 d2', d1=kw, d2=kh)

    mappable = axis.imshow(ori.T, cmap=metric.colormap, interpolation="nearest")

    return mappable

In [None]:

# Grid layout: one row per metric in METRIC_DICT and one column per tissue,
# plus a leading column that this cell leaves empty (its axes are only hidden).
ncols = len(V1_tissues) + 1
nrows = len(METRIC_DICT)

# Pick the canvas size before creating the figure. Previously a (50, 9) figure
# was always created and, when `llc` was set, a second (30, 90) one replaced it,
# leaving the first behind as an empty open figure.
figsize = (30, 90) if llc else (50, 9)
fig, ax_rows = plt.subplots(
    ncols=ncols, nrows=nrows, figsize=figsize, gridspec_kw={"hspace": 0.05}
)

for ax in ax_rows.ravel():
    ax.axis("off")

# One column of axes per tissue, skipping the first column.
for axes, tissue in zip(ax_rows.T[1:], V1_tissues):

    # One map per metric (the rows of this column).
    for (metric_name, metric), ax in zip(METRIC_DICT.items(), axes):

        scatter_handle = make_parameter_map2(
            tissue,
            ax,
            metric=metric,
            scale_points=True,
            final_psm=48 * 2.5 / tissue.positions.shape[0],
            rasterized=True,
            linewidths=0.03,
            edgecolor=(0, 0, 0, 0.5),
        )

        # Add a colorbar only when this column is the last one of the grid.
        if axes[0] is ax_rows[0, -1]:
            cbar = add_sine_colorbar(fig, ax, metric, label=metric.xlabel)
            cbar.ax.tick_params(labelsize=5)
            cbar.ax.set_yticklabels(
                [metric.xticklabels[0], metric.xticklabels[-1]], rotation=90
            )
            cbar.set_label(label="", fontsize=8)

# Spacing does not depend on the loop variables, so it is applied once.
plt.subplots_adjust(hspace=0.1, wspace=0)

plt.savefig("V1/V1_connected.pdf", format="pdf")
plt.savefig("V1/V1_connected.jpg", format="jpg")



# Pairwise Correlation

In [None]:
from numpy import diff

from spacetorch.utils.array_utils import lower_tri, midpoints_from_bin_edges
from spacetorch.utils.spatial_utils import agg_by_distance
from scipy.spatial.distance import pdist, squareform




def compute_V1(layers, dataloader, model):
    """Build one V1Map per consecutive group of four layers.

    For each group, the features of its four layers are extracted together,
    averaged over the spatial axes, and concatenated along the unit axis before
    being wrapped in a ``SineResponses`` / ``V1Map`` pair.

    Args:
        layers: indexable collection of layers; it is read four entries at a
            time, so an incomplete final group raises ``IndexError``.
        dataloader: stimulus loader handed to ``FeatureExtractor``.
        model: model the features are extracted from.

    Returns:
        List of ``V1Map`` objects, one per group of four layers.
    """
    group_size = 4
    maps = []

    for start in range(0, len(layers), group_size):
        # Explicit indexing (not slicing) so a short final group still errors out.
        layer_group = [layers[start + offset] for offset in range(group_size)]

        extractor = FeatureExtractor(dataloader, 32)
        features, _, labels = extractor.extract_features(
            model, layer_group, return_inputs_and_labels=True
        )

        # Spatial mean per layer, then stack the layers' units side by side.
        pooled = [reduce(layer_features, 'b c h w -> b c', 'mean') for layer_features in features]
        avg_features = np.concatenate(pooled, axis=1)

        # Unit positions: the factor pair is reshaped to (kw / 2, kh * 2)
        # before being passed to `get_coord0`.
        kw, kh = get_closest_factors(avg_features.shape[1])
        coord = get_coord0(int(kw / 2), int(kh * 2))

        avg_response = sine_gratings.SineResponses(avg_features, labels)
        maps.append(v1_map.V1Map(coord, avg_response))

    return maps

def tissue_list(model_list):
    """Compute V1 maps for several models and regroup them by layer group.

    Args:
        model_list: iterable of models; each is passed to
            ``load_layers_names_forcontinuous2`` and ``compute_V1``.

    Returns:
        A list of tuples. Entry ``k`` holds the k-th map of every model, in the
        order the models were given (the per-model lists are transposed with zip).
    """
    dataloader = DataLoader(
        DatasetRegistry.get("SineGrating2019"),
        batch_size=10,
        shuffle=True,
        num_workers=1,
        pin_memory=True,
    )

    per_model_maps = []
    for model in model_list:
        layers, _layer_names = load_layers_names_forcontinuous2(model)
        per_model_maps.append(compute_V1(layers, dataloader, model))

    # Transpose: from "maps per model" to "models per map position".
    return list(zip(*per_model_maps))


def plot_correlation_between_pairwise_distances(tissues_lists):
    """Plot pairwise response correlation against pairwise unit distance.

    For every entry of ``tissues_lists`` a new figure is created, and one curve
    is drawn per tissue in that entry: the correlation between units' responses,
    aggregated over distance bins, with a shaded band of +/- the spread.

    Args:
        tissues_lists: iterable of tissue groups. Each tissue must expose
            ``responses._data`` and ``positions``.

    Notes:
        - Legend labels come from the module-level ``expname_list``, indexed by
          the tissue's position within its group.
        - The color and marker lists below have three entries, so a group with
          more than three tissues raises ``IndexError``.
        - The figure is saved to the same two paths on every outer iteration, so
          with several groups only the last figure remains on disk.
        - Returns nothing.
    """

    # NOTE(review): never filled or returned.
    smoothness = []
    for tissue in tissues_lists:

        fig, ax = plt.subplots(figsize=(2.5, 1.5), gridspec_kw={"wspace": 0.6},)

        for i,t in enumerate(tissue):
            # Transposed so that np.corrcoef correlates rows of `activation`.
            # presumably `_data` is (stimuli, units), making rows units - TODO confirm
            activation = t.responses._data.T

            # Lower-triangle entries of the correlation and distance matrices,
            # i.e. one value per pair.
            correlation = lower_tri(np.corrcoef(activation))
            dists = lower_tri(squareform(pdist(t.positions)))

            # Drop pairs whose correlation is undefined (NaN)
            nan_mask = np.isnan(correlation)
            correlation = correlation[~nan_mask]
            dists = dists[~nan_mask]

            # 14 edges between 0 and 20 give 13 bins.
            # NOTE(review): `num_bins=10` is passed alongside explicit edges; which
            # one `agg_by_distance` honours is not visible here.
            means, spreads, bin_edges = agg_by_distance(
                    dists,
                    correlation,
                    num_bins=10,
                    bin_edges = np.linspace(0, 20, 14)
                )

            # Express bin midpoints as a percentage of the maximum distance (20).
            midpoints = midpoints_from_bin_edges(bin_edges)/20 * 100

            # Drawn once per figure: a flat reference line at zero correlation
            # with a +/- 0.02 band.
            if i ==0:
                line_handle = ax.plot(
                midpoints, np.zeros(13),
                label = expname_list[i],
                color = "#ccc",
                mec = 'k',
                mfc = "#ccc",
                marker='X',
                markevery=2,
                markersize = 4,
                mew=0.5,
                lw=1.2,
                )
                ax.fill_between(
                midpoints,
                np.zeros(13) - 0.02,
                np.zeros(13) + 0.02,
                alpha=0.3,
                facecolor=line_handle[0].get_color(),
                )

            # NOTE(review): `delta` and `std` are computed but not used afterwards.
            y_dataset = correlation
            delta = diff(y_dataset)#/diff(x_dataset)

            std = np.std(delta)
            #ax.scatter(dists + np.random.uniform(-0.4, 0.4, len(dists)), correlation, alpha=0.7, c=['#7402E5','#ccc'][i] ,vmin=min(correlation), vmax=max(correlation))
            # Mean correlation per distance bin for this tissue.
            line_handle = ax.plot(
            midpoints, means,
            label = expname_list[i],
            color = [ "k", "k", '#7402E5'][i],
            mec = 'k',
            mfc = "#ccc",
            marker=['p', 'P', ''][i],
            markevery=2,
            markersize = 4,
            mew=0.5,
            lw=1.2,

            )

            # Shaded band of mean +/- spread in the line's color.
            ax.fill_between(
            midpoints,
            means - spreads,
            means + spreads,
            alpha=0.3,
            facecolor=line_handle[0].get_color(),
            )



        # The legend is created and immediately removed, so none is shown.
        ax.legend().remove()
        ax.set_yticks([0,  1])


        ax.set_ylabel("Pairwise Correlation", labelpad=-1)
        ax.set_xlabel("Pairwise Distance(%)")

        plot_utils.remove_spines(ax)

        plt.savefig("corr/V1_pairwise_activation_correlation_plot.png", bbox_inches='tight')
        plt.savefig("corr/V1_pairwise_activation_correlation_plot.pdf", format='pdf', bbox_inches='tight')


In [None]:
# Models to compare. The order must match `expname_list` below, which
# `plot_correlation_between_pairwise_distances` reads by position for labels.
model_list = [
    #load_model(pool_type='gaussian', kap_kernelsize=0.10, continuous=True, local_conv=False, expname='gaussian_0.1_continuous_prog', epoch=92, sel_range=5),
    load_model(pool_type='mexicanhat', kap_kernelsize=5.0, continuous=True, local_conv=False, expname='mexicanhat_5_continuous_prog_fixed_p', epoch=57, sel_range=5),
    load_model(pool_type='mean', kap_kernelsize=0.1, continuous=True, local_conv=False, expname='mean_01_continuous_prog_tt', epoch=100, sel_range=1),
    load_model(pool_type='gaussian', kap_kernelsize=0.23, continuous=True, local_conv=False, expname='gaussian_0.23_continuous_prog_t', epoch=100, sel_range=10),
]

expname_list = [
    'Mexicanhat',
    'Mean',
    'Gaussian',
]

# Store the result under its own name. Assigning it to `tissue_list` replaced
# the function with its return value, so re-running this cell raised
# "TypeError: 'list' object is not callable".
tissues_by_layer_group = tissue_list(model_list)

plot_correlation_between_pairwise_distances(tissues_by_layer_group)

# Pairwise Selectivity

In [None]:

# Use the first map in `V1_tissues` for the analyses below. `V1_tissues` is
# filled by an earlier cell, so that cell must have been run first.
V1_tissue = V1_tissues[0]



def get_curves(
    tissue,
    metric_name: str,
    shuffle: bool = False,
    num_samples: int = 320,  # number of images
    sample_size: int = 64,
    verbose: bool = False,
):
    """Preference difference as a function of pairwise distance for one metric.

    Args:
        tissue: object providing ``metric_difference_over_distance``.
        metric_name: name of the metric, forwarded as ``metric``.
        shuffle: forwarded unchanged to the tissue method.
        num_samples: forwarded unchanged to the tissue method.
        sample_size: accepted but not used by this function.
        verbose: forwarded unchanged to the tissue method.

    Returns:
        ``(distances, curves)`` where ``distances`` are the bin midpoints divided
        by the maximum distance (20), and ``curves`` is the second value returned
        by ``tissue.metric_difference_over_distance``.
    """
    # Largest distance considered; also used as the distance cutoff.
    max_dist = 20

    # 14 evenly spaced edges from 0 to max_dist, i.e. 13 bins.
    bin_edges = np.linspace(0, max_dist, 14)
    midpoints = array_utils.midpoints_from_bin_edges(bin_edges)

    _, curves = tissue.metric_difference_over_distance(
        distance_cutoff=max_dist,
        metric=metric_name,
        num_samples=num_samples,
        bin_edges=bin_edges,
        shuffle=shuffle,
        verbose=verbose,
    )

    # Midpoints as a fraction of the maximum distance.
    normalized_midpoints = midpoints / 20
    return normalized_midpoints, curves


In [None]:
# For each metric: compute the observed curves and a shuffled ("chance") set,
# then divide every observed curve by the mean of all chance values.
curve_dict = {}
for metric_name in METRIC_DICT:
    distances, curves = get_curves(V1_tissue, metric_name, shuffle=False)
    _, chance_curves = get_curves(V1_tissue, metric_name, shuffle=True)

    # Single scalar chance level, ignoring NaNs.
    chance_mean = np.nanmean(np.concatenate(chance_curves))
    norm_curves = [curve / chance_mean for curve in curves]
    all_curves = list(norm_curves)

    curve_dict[metric_name] = {
        "Distances": distances * 100,  # convert to percentages
        "Curves": all_curves,
    }


fig, ax = plt.subplots(figsize=(2.5, 1.5), gridspec_kw={"wspace": 0.6})

# Chance reference: a flat line at 1 (curves were divided by the chance mean)
# with a +/- 0.02 band. Its length follows the distance axis instead of a
# hard-coded 13, so it stays valid if the binning changes.
chance_x = curve_dict["angles"]["Distances"]
chance_y = np.ones(len(chance_x))

line_handle = ax.plot(
    chance_x, chance_y,
    color="#ccc",
    mec='k',
    mfc="#ccc",
    marker='X',
    markevery=2,
    markersize=4,
    mew=0.5,
    lw=1.2,
)
ax.fill_between(
    chance_x,
    chance_y - 0.02,
    chance_y + 0.02,
    alpha=0.3,
    facecolor=line_handle[0].get_color(),
)

# One mean curve per metric, in the insertion order of `curve_dict`.
# Fixed: the loop called the undefined name `enumerated` (NameError).
for i, item in enumerate(curve_dict.items()):

    res = item[1]
    curves = np.stack(res["Curves"])

    mn_curve = np.mean(curves, axis=0)

    # The color and label lists have three entries, one per expected metric.
    ax.plot(
        res["Distances"], mn_curve,
        color=["#5875e1", "#ea7b60", 'gold'][i],
        label=["Orientation", "Spatial Freq", "Color"][i],
        mec='k',
        mfc="#ccc",
        markevery=2,
        markersize=4,
        mew=0.5,
        lw=2,
    )

ax.set_yticks([0, 1])

ax.legend()
ax.set_ylabel("Δ Preference", labelpad=-1)
ax.set_xlabel("Pairwise Distance(%)")
plot_utils.remove_spines(ax)

plt.savefig("V1_p.png", bbox_inches='tight')
plt.savefig("V1_p.pdf", format='pdf', bbox_inches='tight')

# Orientation Diff

In [None]:
def ori_quant(pre, n_units=64, row_stride=4, radius=2):
    """Histogram of preference differences between each unit and its neighbours.

    For every unit ``i`` and every offset ``dy * row_stride + dx`` with ``dy`` and
    ``dx`` in ``[-radius, radius]`` (the zero offset excluded), the absolute
    difference ``|pre[i + offset] - pre[i]|`` is collected whenever the neighbour
    index lies in ``[0, min(n_units, len(pre)))``. The differences are plotted
    as a percentage histogram and saved to ``V1/v1_q.png`` and ``V1/V1_q.pdf``.

    Args:
        pre: 1-D sequence of per-unit preferences (degrees, per the axis label).
        n_units: upper bound (exclusive) on neighbour indices. The default of 64
            reproduces the previous hard-coded value ("layer 1 only").
        row_stride: flat-index step used for one step in ``dy``. Default 4, as
            previously hard-coded.
        radius: half-width of the neighbourhood. Default 2, as previously
            hard-coded.

    Notes:
        - The neighbour limit is clamped to ``len(pre)``, so inputs shorter than
          ``n_units`` no longer raise ``IndexError``.
        - NOTE(review): with the defaults, different (dy, dx) pairs map to the
          same flat offset (e.g. ``0*4 + 2 == 1*4 - 2``), so some pairs are
          counted twice, and nothing stops a neighbourhood from wrapping across
          a row edge. Confirm that ``row_stride`` matches the real grid width.
        - NOTE(review): differences are plain ``abs`` values, not circular
          differences on the orientation range.
        - Every call writes to the same two paths, overwriting earlier output.
    """
    limit = min(n_units, len(pre))

    # Flat-index offsets in the original visiting order (dy outer, dx inner).
    offsets = [
        dy * row_stride + dx
        for dy in range(-radius, radius + 1)
        for dx in range(-radius, radius + 1)
        if dy * row_stride + dx != 0
    ]

    deltas = []
    for i, v in enumerate(pre):
        for offset in offsets:
            j = i + offset
            if 0 <= j < limit:
                deltas.append(abs(pre[j] - v))

    # Equal weights summing to 100 so the bar heights read as percentages.
    weights = np.ones(len(deltas)) / len(deltas) * 100

    fig, ax = plt.subplots(figsize=(2.5, 1.5))
    ax.hist(deltas, weights=weights, color='#5875e1', ec='black', bins=30)
    ax.set_xlabel("Δ orientation (Degree)")
    ax.set_ylabel("Percentage (%)")
    plot_utils.remove_spines(ax)
    fig.savefig("V1/v1_q.png", bbox_inches='tight')
    fig.savefig("V1/V1_q.pdf", format="pdf", bbox_inches='tight')
    # Close rather than clear, so repeated calls do not leave empty figures open.
    plt.close(fig)


In [None]:

# Run the orientation-difference histogram for every map in `V1_tissues`.
# `ori_quant` saves to fixed file names, so each iteration overwrites the files
# written by the previous one; only the last map's histogram stays on disk.
for tissue in V1_tissues:

        pre = tissue.get_preferences(angle_metric)
        ori_quant(pre)

# Orthogonality of Orientation and Spatial-Frequency Contours

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm


# Overlay contour lines of the orientation map (blue) and the spatial-frequency
# map (red) of `V1_tissue`.
# `kw`, `kh` and `V1_tissue` are module-level values left by earlier cells, so
# those cells must have been run first.
xlist = np.linspace(0, kw, kw)
ylist = np.linspace(0, kh, kh)

X, Y = np.meshgrid(xlist, ylist)

# NOTE(review): X and Y have shape (kh, kw) while the maps are reshaped to
# (kw, kh). The shapes only agree when kw == kh; confirm the intended
# orientation for non-square grids.
Z1 = V1_tissue.get_preferences(metric_dict["angles"]).reshape(kw, kh)

fig, ax = plt.subplots(figsize=(5, 5))

CS = ax.contour(X, Y, Z1, cmap=cm.Blues)
plt.clabel(CS, inline=1, fontsize=5)

Z2 = V1_tissue.get_preferences(metric_dict["sfs"]).reshape(kw, kh)

CS = ax.contour(X, Y, Z2, cmap=cm.Reds)
plt.clabel(CS, inline=1, fontsize=5, colors="red")

plt.xticks([], [])
plt.yticks([], [])

# Save before showing. Once `plt.show()` has displayed the figure it may no
# longer be the current one, and saving afterwards can write a blank file.
fig.savefig("V1/V1_orth_connected.pdf", format="pdf")
fig.savefig("V1/V1_orth_connected.png", format="png")
plt.show()



def unit_vector(vector):
    """Scale ``vector`` to unit length by dividing it by its norm."""
    length = np.linalg.norm(vector)
    return vector / length

def angle_between(v1, v2):
    """Angle in radians between ``v1`` and ``v2``.

    Both vectors are normalized first; their dot product is clipped to
    ``[-1, 1]`` before ``arccos`` so rounding error cannot push it out of range.
    """
    cosine = np.dot(unit_vector(v1), unit_vector(v2))
    return np.arccos(np.clip(cosine, -1.0, 1.0))



# Locally Connected

In [None]:
# Locally-connected variant: `llc` is True, so the first branch inside the loop
# runs and the loop stops after its first iteration (see the `break` at the end).
llc=True
model = load_model_LC_car(pool_type='gaussian', kap_kernelsize=11, continuous=False, local_conv=False, expname='gaussian_11_LLC_car', epoch=90, sel_range=10)


layers, layers_names = load_layers_names1(model)


dataloader = DataLoader(
        DatasetRegistry.get("SineGrating2019"), batch_size=10, shuffle=True, num_workers=1, pin_memory=True
    )

from typing import Optional, List
from spacetorch.datasets.sine_gratings import SineResponses, SineGrating2019, Metric
# Metric descriptors of the sine-grating dataset, as a list and as a dict;
# `angle_metric` is the entry stored under the "angles" key.
metrics: List[Metric] = SineGrating2019.get_metrics()
metric_dict = SineGrating2019.get_metrics(as_dict=True)
angle_metric = metric_dict["angles"]

# Rebinds `V1_tissues`; maps appended by earlier cells are no longer reachable
# through this name.
V1_tissues = []


# The first three entries of `layers` are skipped.
for layer in layers[3:]:
    # 1. Extract the features and labels for this single layer
    features, _, labels = FeatureExtractor(dataloader, 32).extract_features(model, [layer], return_inputs_and_labels = True)

    # 2. Rearrange the features into one flat vector per stimulus

    if llc:
        f_shape = features.shape
        # Move channels last, split them into a 3 x 3 block, interleave the block
        # axes with the spatial axes, and flatten to shape (b, 1, N).
        # assumes the channel axis has exactly 9 entries - TODO confirm
        avg_features = torch.from_numpy(features).permute(0, 2, 3, 1).reshape(f_shape[0], f_shape[2], f_shape[3], 3, 3).permute(0, 1, 3, 2, 4).reshape(f_shape[0], 1, f_shape[-2]*3 * f_shape[-2]*3).numpy()

        def car_to_polar():
            """Return polar angle, radius and a running index for a (56*3) x (56*3) integer grid centred on 0."""
            phis =[]
            rhos =[]
            ind = []
            for i in range(-56*3//2, 56*3//2):
                for j in range(-56*3//2, 56*3//2):
                    rho = np.sqrt(i**2 + j**2)
                    phi = np.arctan2(j, i)
                    phis.append(phi)
                    rhos.append(rho)
            ind = range(0, len(phis))
            return phis, rhos, ind

        phis, rhos, ind = car_to_polar()


        # Two stable sorts: by radius, then by angle.
        # NOTE(review): `polar_coord` is not used afterwards, because the
        # reordering loop below is commented out.
        polar_coord = sorted(zip(phis, rhos, ind), key=lambda pair: pair[1])
        polar_coord = [p for _, _, p in sorted(polar_coord, key=lambda pair: pair[0])]

        # `temp` refers to the same (b, 1, N) array as `avg_features`.
        temp = avg_features # create the tensor
        #for i in range(len(polar_coord)):
            #temp[ :, :, i] = avg_features[ :, :, polar_coord[i]]

        # Flatten (b, 1, N) to (b, N).
        avg_features = rearrange(temp, 'b h w -> b (h w)') #c

        # NOTE(review): the 4-D array built on the next line is overwritten
        # immediately by the line after it, which repeats the flattening of `temp`.
        avg_features =  torch.from_numpy(features).permute(0, 2, 3, 1).reshape(f_shape[0], 56, 56, 3, 3).permute(0, 1, 3, 2, 4).reshape(f_shape[0], 1, 56*3, 56*3).numpy()
        avg_features = rearrange(temp, 'b h w -> b (h w)')

    else:
        # Mean over the two spatial axes: (b, c, h, w) -> (b, c)
        avg_features = reduce(features, 'b c h w -> b c', 'mean')


    # 3. Unit positions
    # presumably `get_closest_factors` returns two integers whose product is the
    # number of units - TODO confirm
    kw, kh = get_closest_factors(avg_features.shape[1])
    coord = get_coord1(kw, kh)

    # 4. Bundle the per-unit responses with their stimulus labels
    avg_response = sine_gratings.SineResponses(avg_features, labels)

    # 5. Build the map object from unit coordinates and responses
    V1_tissue = v1_map.V1Map(coord, avg_response)
    V1_tissues.append(V1_tissue)

    # Only the first selected layer is processed in the locally-connected case.
    if llc:
        break


def make_parameter_map2(
    tissue,
    axis,
    metric: Metric = angle_metric,
    scale_points=True,
    num_colors=None,
    final_psm=1.0,
    final_s: Optional[float] = 1,
    **kwargs,
):
    """
    Draw a parameter map for the given metric, e.g., "angles", "sfs", "colors".

    The per-unit preferences are folded onto a 2-D grid and rendered with
    ``axis.imshow`` using ``metric.colormap``.

    NOTE(review): a function with this name is also defined earlier in the
    notebook; this later definition replaces it when the cell runs.

    Args:
        tissue: object exposing ``positions``, ``responses``,
            ``get_unit_colors(metric=...)`` and ``get_preferences(metric)``.
        axis: matplotlib axes to draw into.
        metric: metric descriptor; ``metric.name`` and ``metric.colormap`` are read.
        scale_points: if True, derive a per-unit selectivity from the responses;
            otherwise use a constant selectivity of 1 for every unit.
        num_colors: accepted for signature compatibility; not used.
        final_psm: multiplier applied to the point sizes.
        final_s: accepted for signature compatibility; not used.
        **kwargs: accepted for signature compatibility (callers pass scatter-style
            options such as ``rasterized``); not used.

    Returns:
        The image handle returned by ``axis.imshow``.
    """
    # Kept from the original: colors and point sizes are computed but the
    # imshow-based rendering below does not consume them.
    colors = tissue.get_unit_colors(metric=metric)

    if scale_points:
        if metric.name == "angles":
            selectivity = 1.0 - tissue.responses.circular_variance
        else:
            selectivity = tissue.responses.get_peak_heights(metric.name)

        # Treat undefined selectivity as 0, then rescale to the range [0.5, 1.5].
        selectivity = np.where(np.isnan(selectivity), 0, selectivity)
        selectivity = (selectivity - np.min(selectivity)) / np.ptp(selectivity) + 0.5
    else:
        # Fixed: this branch used `self.positions`, but there is no `self` in a
        # free function, so `scale_points=False` raised NameError.
        selectivity = np.ones((len(tissue.positions),))

    point_sizes = final_psm * 300 * selectivity

    # Fold the flat preference vector onto the (kw, kh) grid and show it
    # transposed, one pixel per unit.
    preferences = tissue.get_preferences(metric)
    kw, kh = get_closest_factors(tissue.positions.shape[0])
    ori = rearrange(preferences, '(d1 d2)  -> d1 d2', d1=kw, d2=kh)

    mappable = axis.imshow(ori.T, cmap=metric.colormap, interpolation="nearest")

    return mappable



# Grid layout: one row per metric in METRIC_DICT and one column per tissue,
# plus a leading column that this cell leaves empty (its axes are only hidden).
ncols = len(V1_tissues) + 1
nrows = len(METRIC_DICT)

# Pick the canvas size before creating the figure. Previously a (50, 9) figure
# was always created and, when `llc` was set, a second (30, 90) one replaced it,
# leaving the first behind as an empty open figure.
figsize = (30, 90) if llc else (50, 9)
fig, ax_rows = plt.subplots(
    ncols=ncols, nrows=nrows, figsize=figsize, gridspec_kw={"hspace": 0.05}
)

for ax in ax_rows.ravel():
    ax.axis("off")

# One column of axes per tissue, skipping the first column.
for axes, tissue in zip(ax_rows.T[1:], V1_tissues):

    # One map per metric (the rows of this column).
    for (metric_name, metric), ax in zip(METRIC_DICT.items(), axes):
        scatter_handle = make_parameter_map2(
            tissue,
            ax,
            metric=metric,
            scale_points=True,
            final_psm=48 * 2.5 / tissue.positions.shape[0],
            rasterized=True,
            linewidths=0.03,
            edgecolor=(0, 0, 0, 0.5),
        )

        # Add a colorbar only when this column is the last one of the grid.
        if axes[0] is ax_rows[0, -1]:
            cbar = add_sine_colorbar(fig, ax, metric, label=metric.xlabel)
            cbar.ax.tick_params(labelsize=5)
            cbar.ax.set_yticklabels(
                [metric.xticklabels[0], metric.xticklabels[-1]], rotation=90
            )
            cbar.set_label(label="", fontsize=8)

# Spacing does not depend on the loop variables, so it is applied once.
plt.subplots_adjust(hspace=0.1, wspace=0)

# NOTE(review): these are the same output paths used by the earlier
# selectivity-map cell, so running this cell overwrites those files.
plt.savefig("V1/V1_connected.pdf", format="pdf")
plt.savefig("V1/V1_connected.jpg", format="jpg")



# Pinwheels

In [None]:
from typing import List

import numpy as np
from skimage.measure import label


def circdiff(x, y) -> float:
    """Signed circular difference ``x - y`` for angles with a period of 180.

    The plain difference is wrapped into the half-open interval [-90, 90).

    For example:
        circdiff(5, 3) = 2
        circdiff(179, 5) = -6
        circdiff(5, 179) = 6
    """
    return (x - y + 90) % 180 - 90


def increments(angles) -> List[float]:
    """
    Circular step from each angle to the next one in ``angles``.

    Returns a list with ``len(angles) - 1`` entries; entry ``i`` is
    ``circdiff(angles[i + 1], angles[i])``. No step from the last angle back to
    the first is included.
    """
    steps = []
    for idx in range(1, len(angles)):
        steps.append(circdiff(angles[idx], angles[idx - 1]))
    return steps

In [None]:
def _get_winding_numbers(ori):
    """Per-pixel sum of orientation increments around the 8-neighbourhood.

    For every interior pixel, the eight neighbours are visited in the order
    NW, N, NE, E, SE, S, SW, W. The circular increments between consecutive
    neighbours are summed (in radians), divided by ``2 * pi``, and multiplied by
    1000 before being stored. Border pixels stay 0.

    Args:
        ori: orientation map in degrees. If its element count is a perfect
            square it is reshaped to a square grid (so the 168 x 168 case that
            was hard-coded before behaves as it did); otherwise it must already
            be 2-D.

    Returns:
        Array with the same shape and dtype as the (reshaped) input.

    Raises:
        ValueError: if ``ori`` is neither square-sized nor 2-D.

    Notes:
        - NOTE(review): the output inherits the dtype of ``ori`` via
          ``np.zeros_like``. For an integer map the stored values are truncated
          to integers, which may be why the factor of 1000 exists. Confirm
          before relying on fractional values.
        - NOTE(review): the neighbour ring is not closed (no W -> NW step), so
          the sum is over 7 increments rather than a full loop.
    """
    ori = np.asarray(ori)

    side = int(round(np.sqrt(ori.size)))
    if side * side == ori.size:
        ori = ori.reshape(side, side)
    elif ori.ndim != 2:
        raise ValueError(
            f"cannot arrange {ori.size} values with shape {ori.shape} on a 2-D grid"
        )

    rows, cols = ori.shape
    winding_numbers = np.zeros_like(ori)

    # (row, col) offsets of the neighbours, in visiting order.
    ring = [
        (-1, -1),  # NW
        (-1, 0),   # N
        (-1, 1),   # NE
        (0, 1),    # E
        (1, 1),    # SE
        (1, 0),    # S
        (1, -1),   # SW
        (0, -1),   # W
    ]

    for row in range(1, rows - 1):
        for col in range(1, cols - 1):
            values = [ori[row + d_row, col + d_col] for d_row, d_col in ring]
            rad_incs = np.radians(increments(values))
            wn = sum(rad_incs) / (2 * np.pi)
            winding_numbers[row, col] = wn * 1000

    return winding_numbers

In [None]:
import scipy.stats
def _circstd(arr: np.ndarray, high: float=180):
        if len(arr) == 0:
            return np.nan
        cs = scipy.stats.circstd(arr, high=high - 1)
        return cs / len(arr)


In [None]:
def count_pinwheels(
        ori, winding_numbers, min_px_count: int = 2, thresh: float = 0, var_thresh: float = 1
    ):
    """Count connected regions of positive and negative winding number.

    A pixel passes when ``_circstd(ori)`` is below ``var_thresh`` and its
    winding number is above ``thresh`` (positive) or below ``-thresh``
    (negative). Passing pixels are grouped into connected regions with
    ``skimage.measure.label``; regions with at least ``min_px_count`` pixels are
    counted.

    Args:
        ori: orientation map handed to ``_circstd``. The result is a single
            value for the whole map, so the variance test is all-or-nothing.
        winding_numbers: array of winding numbers, e.g. from
            ``_get_winding_numbers``.
        min_px_count: minimum number of pixels for a region to be counted.
        thresh: magnitude a winding number must exceed to pass.
        var_thresh: upper bound on the circular-spread value.

    Returns:
        ``[n_positive, n_negative]``.

    Notes:
        Fixed: ``label`` marks non-passing pixels with 0, and that background
        label used to be treated as a region, inflating each count by one
        whenever the background had at least ``min_px_count`` pixels. It is now
        skipped. The debug prints that dumped whole arrays were removed, as were
        the region centres, which were computed but never returned.
    """
    background_label = 0  # value `label` assigns to pixels outside every region

    var = _circstd(ori)
    low_variance = var < var_thresh
    pos_passing = low_variance & (winding_numbers > thresh)
    neg_passing = low_variance & (winding_numbers < (-thresh))

    counts = [0, 0]
    for idx, mask in enumerate([pos_passing, neg_passing]):
        islands = label(mask)

        for lab in np.unique(islands):
            if lab == background_label:
                continue
            if np.count_nonzero(islands == lab) < min_px_count:
                continue
            counts[idx] += 1

    return counts

In [None]:
import pickle
# Load a previously saved orientation map from the working directory.
# SECURITY: unpickling can execute arbitrary code; only load this file if it was
# produced by a trusted run of this project.
with open('v1_s90.pkl', 'rb') as f:
    ori = pickle.load(f)

In [None]:
# Count positive and negative pinwheel regions in the loaded map, using the
# default thresholds of `count_pinwheels`.
pos, neg = count_pinwheels(ori, _get_winding_numbers(ori))

(168, 168)
0.0029742917196110374 1 -2 0
[[False False False ... False False False]
 [False False False ... False False False]
 [False False False ... False False False]
 ...
 [False False False ... False False False]
 [False False False ... False False False]
 [False False False ... False False False]]
