## Configuración

In [None]:
import collections, math
import torch
import torch.nn as nn
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
import ipywidgets as widgets

from PIL import Image, ImageDraw

### Funciones utilitarias

In [None]:
def encontrar_device():
    """Pick the best available torch device, print it and return it.

    Preference order is CUDA, then Apple MPS (when this torch build exposes it), then CPU.
    """
    if torch.cuda.is_available():
        backend = "cuda"
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        backend = "mps"
    else:
        backend = "cpu"
    device = torch.device(backend)
    print("Device encontrado:", device)
    return device

def muestrear_estandar(tam_imagen, num_objetos, radio_nominal, ruido_radio=0.1, distancia_min=5, margen=5, max_iter=10_000):
    """Randomly place ``num_objetos`` non-overlapping circles in a square image.

    Each radius is ``radio_nominal`` plus Gaussian noise with standard deviation
    ``ruido_radio * radio_nominal``, rounded to an integer. Each centre is drawn
    uniformly so that the circle stays ``margen`` pixels inside the image. Every circle
    after the first must be more than ``distancia_min`` pixels (edge to edge) away from
    all previously accepted circles. Up to ``max_iter`` candidates are tried for it.

    Returns:
        (x, y, r): numpy integer scalars for one object, 1-D integer arrays for more,
        (None, None, None) for zero objects. Returns None when some object could not be
        placed within ``max_iter`` attempts.
    """
    def _proponer():
        # Radius first, then a centre whose circle fits inside the image with the margin.
        radio = np.round(radio_nominal + ruido_radio * radio_nominal * np.random.normal()).astype(int)
        bajo, alto = radio + margen, tam_imagen - radio - margen
        cx = np.round(np.random.uniform(bajo, alto)).astype(int)
        cy = np.round(np.random.uniform(bajo, alto)).astype(int)
        return cx, cy, radio

    x = y = r = None
    for k in range(num_objetos):
        intentos = 0
        while intentos < max_iter:
            xi, yi, ri = _proponer()
            if k == 0:
                # The first circle cannot collide with anything.
                x, y, r = xi, yi, ri
                break
            separacion = np.sqrt((x - xi) ** 2 + (y - yi) ** 2)
            if np.all(separacion > r + ri + distancia_min):
                x, y, r = np.append(x, xi), np.append(y, yi), np.append(r, ri)
                break
            intentos += 1
            if intentos == max_iter:
                return None
    return x, y, r

def muestrear_area_constante(tam_imagen, num_objetos, area_total, **kwargs):
    """Sample circles whose nominal radius keeps the summed area close to ``area_total``.

    The nominal radius is the radius of a disk of area ``area_total / num_objetos``.
    Extra keyword arguments are forwarded unchanged to ``muestrear_estandar``.
    """
    area_por_objeto = area_total / num_objetos
    radio = np.sqrt(area_por_objeto / np.pi)
    return muestrear_estandar(tam_imagen, num_objetos, radio, **kwargs)

def muestrear_casco_convexo(tam_imagen, num_objetos, radio_nominal, ruido_radio=0.1, margen=5, distancia_min=5, max_iter=50_000, radio_casco=85, tam_hull=5):
    """Place non-overlapping circles whose convex hull is approximately fixed.

    The first ``tam_hull`` objects go on a ring around the image centre. The ring has
    radius ``radio_casco`` plus Gaussian jitter with standard deviation 5 px. The objects
    sit at evenly spaced angles with a random common rotation, in shuffled order, so they
    define the hull. The remaining objects go at a uniformly random angle, at a distance
    drawn uniformly from [0, radio_casco - 2 * radio_nominal) from the centre.

    A candidate is accepted only if it meets two conditions:
    - the whole circle plus ``margen`` pixels fits inside the image (as in
      ``muestrear_estandar``);
    - it is more than ``distancia_min`` pixels (edge to edge) away from every accepted
      circle.

    Returns:
        (x, y, r): numpy integer scalars for one object, 1-D integer arrays for more,
        (None, None, None) for zero objects. Returns None when some object could not be
        placed within ``max_iter`` attempts.
    """
    x, y, r = None, None, None
    cx = cy = tam_imagen / 2
    # linspace guarantees exactly tam_hull angles. arange with a float step can return an
    # extra angle (~2*pi, i.e. a duplicate of 0) due to rounding, creating colliding hull points.
    theta_hull = np.linspace(0, 2 * np.pi, tam_hull, endpoint=False)
    theta_hull += np.random.uniform(high=np.pi)
    np.random.shuffle(theta_hull)
    for i in range(num_objetos):
        for _ in range(max_iter):
            if i < tam_hull:
                theta = theta_hull[i]
                radio = radio_casco + 5 * np.random.normal()
            else:
                theta = np.random.uniform(high=2 * np.pi)
                radio = np.random.uniform(high=radio_casco - 2 * radio_nominal)
            xi = np.round(cx + radio * np.cos(theta)).astype(int)
            yi = np.round(cy + radio * np.sin(theta)).astype(int)
            ri = np.round(radio_nominal + ruido_radio * radio_nominal * np.random.normal()).astype(int)
            # Honour ``margen``: previously it was accepted but ignored.
            bajo, alto = ri + margen, tam_imagen - ri - margen
            if not (bajo <= xi <= alto and bajo <= yi <= alto):
                continue
            if x is None:
                x, y, r = xi, yi, ri
                break
            dists = np.sqrt((x - xi) ** 2 + (y - yi) ** 2)
            if np.all(dists > r + ri + distancia_min):
                x = np.append(x, xi)
                y = np.append(y, yi)
                r = np.append(r, ri)
                break
        else:
            # max_iter candidates were tried without a valid placement.
            return None
    return x, y, r

def verificar_densidad_control(x, y, r, dist_min=90, dist_max=100):
    """Check that the mean pairwise centre distance lies in [dist_min, dist_max].

    This is the acceptance test for the control stimuli. ``r`` is accepted to match the
    verification-callback signature but is not used. With fewer than two objects there
    is no pairwise distance, so the check returns False (previously it averaged an empty
    array, yielding NaN plus a warning).

    Returns:
        bool
    """
    # Import the submodule explicitly: ``import scipy`` alone does not guarantee that
    # ``scipy.spatial`` is loaded on every SciPy version.
    from scipy.spatial.distance import pdist

    coords = np.column_stack((np.atleast_1d(x), np.atleast_1d(y)))
    if len(coords) < 2:
        return False
    # pdist returns exactly the upper-triangle (i < j) distances.
    avg_dist = pdist(coords, 'euclidean').mean()
    return bool(dist_min <= avg_dist <= dist_max)

def generar_fondo_uniforme(size, A):
    """Return a uniform uint8 background image.

    Args:
        size: (height, width, channels); channels must be >= 1.
        A: grey level in [0, 255], expected to be an integer, written to every pixel
            and channel.

    Returns:
        np.ndarray of shape (height, width, channels) and dtype uint8. The original
        returned a 2-D array for channel counts other than 1 or 3; any count is now
        honoured.

    Raises:
        ValueError: if channels < 1 or A is outside [0, 255].
    """
    h, w, ch = size
    if ch < 1:
        raise ValueError(f"channels must be >= 1, got {ch}")
    if not 0 <= A <= 255:
        raise ValueError(f"grey level A must be in [0, 255], got {A}")
    return np.full((h, w, ch), A, dtype=np.uint8)

def dibujar_circulo(img, x, y, r):
    """Return a new uint8 array: ``img`` with a filled white disk of radius ``r`` at (x, y).

    ``img`` may be (H, W), (H, W, 1) or a multi-channel (H, W, C) array that
    ``PIL.Image.fromarray`` can wrap. The result has the same shape as ``img``.
    """
    single_channel = img.ndim == 3 and img.shape[2] == 1
    # PIL has no (H, W, 1) mode, so draw on the squeezed 2-D view.
    base = img[..., 0] if single_channel else img
    pil_img = Image.fromarray(base)
    draw = ImageDraw.Draw(pil_img)
    # On multi-band images PIL reads a bare int fill as a packed pixel value (255 -> red in
    # RGB), so spell out white for every band.
    bandas = len(pil_img.getbands())
    blanco = 255 if bandas == 1 else (255,) * bandas
    draw.ellipse([x - r, y - r, x + r, y + r], fill=blanco)
    out = np.array(pil_img, dtype=np.uint8)
    return out[..., None] if single_channel else out

def dibujar_forma_aleatoria(img, x, y, r):
    """Return a new uint8 array: ``img`` with one random white shape of size ~``r`` at (x, y).

    The shape is chosen uniformly from four options:
    - circle of radius r;
    - axis-aligned rectangle with half-sides ~[0.7r, r];
    - ellipse with semi-axes ~[0.3r, r];
    - triangle with vertices (x + r1, y), (x, y - r2), (x - r3, y), where r1..r3 ~ [0.7r, r].

    ``img`` may be (H, W), (H, W, 1) or multi-channel (H, W, C); the result has the
    same shape.
    """
    single_channel = img.ndim == 3 and img.shape[2] == 1
    base = img[..., 0] if single_channel else img
    pil_img = Image.fromarray(base)
    draw = ImageDraw.Draw(pil_img)
    # A bare int fill on multi-band images is a packed pixel value (red in RGB); use
    # explicit white per band there.
    bandas = len(pil_img.getbands())
    blanco = 255 if bandas == 1 else (255,) * bandas
    ss = np.random.choice(range(4))
    if ss == 0:  # circle
        bbox = [x - r, y - r, x + r, y + r]
        draw.ellipse(bbox, fill=blanco)
    elif ss == 1:  # rectangle
        r1 = int(np.random.uniform(0.7, 1.0) * r)
        r2 = int(np.random.uniform(0.7, 1.0) * r)
        draw.rectangle([x - r1, y - r2, x + r1, y + r2], fill=blanco)
    elif ss == 2:  # ellipse
        r1 = int(np.random.uniform(0.3, 1.0) * r)
        r2 = int(np.random.uniform(0.3, 1.0) * r)
        bbox = [x - r1, y - r2, x + r1, y + r2]
        draw.ellipse(bbox, fill=blanco)
    elif ss == 3:  # triangle
        r1 = int(np.random.uniform(0.7, 1.0) * r)
        r2 = int(np.random.uniform(0.7, 1.0) * r)
        r3 = int(np.random.uniform(0.7, 1.0) * r)
        pts = [(x + r1, y), (x, y - r2), (x - r3, y)]
        draw.polygon(pts, fill=blanco)
    out = np.array(pil_img, dtype=np.uint8)
    return out[..., None] if single_channel else out

def generar_conjunto(numerosidades, repeticiones, fn_muestreo, args_muestreo, fn_verificacion, fn_dibujo, tam_imagen, max_iter, nivel_fondo=50):
    """Generate a shuffled set of single-channel stimulus images with known numerosity.

    For every n in ``numerosidades``, it creates ``repeticiones`` images of shape
    (tam_imagen, tam_imagen, 1) with background ``nivel_fondo``. For n > 0 it proceeds
    as follows:
    - call ``fn_muestreo(tam_imagen, n, **args_muestreo)``, up to ``max_iter`` times;
    - retry when the sampler returns None (its failure signal);
    - for n > 1, also retry until ``fn_verificacion(x, y, r)`` accepts the layout;
    - draw every object with ``fn_dibujo(img, x, y, r)``.

    Returns:
        (S, Q): stacked images and their numerosities, shuffled jointly. Returns
        (None, None) when some image got no acceptable layout within ``max_iter``
        attempts.
    """
    S, Q = [], []
    for n in numerosidades:
        for _ in range(repeticiones):
            img = generar_fondo_uniforme((tam_imagen, tam_imagen, 1), A=nivel_fondo)
            if n > 0:
                for _ in range(max_iter):
                    muestra = fn_muestreo(tam_imagen, n, **args_muestreo)
                    # Samplers return None when placement fails; retry rather than crash
                    # while unpacking.
                    if muestra is None:
                        continue
                    x, y, r = muestra
                    # The layout check only applies to n > 1.
                    if n == 1 or fn_verificacion(x, y, r):
                        break
                else:
                    return None, None
                # A single object comes back as scalars; make every coordinate iterable.
                for xi, yi, ri in zip(np.atleast_1d(x), np.atleast_1d(y), np.atleast_1d(r)):
                    img = fn_dibujo(img, xi, yi, ri)
            S.append(img)
            Q.append(n)
    S = np.array(S)
    Q = np.array(Q)
    randperm = np.random.permutation(len(Q))
    return S[randperm], Q[randperm]

def generar_estimulos(num_reps=40, rango_Q=None, radio_punto=18, area_total=1200, tam_hull=3, tam_imagen=224):
    """Build the complete stimulus set from three conditions.

    The conditions (C labels) are:
        0: standard dots of radius ~``radio_punto``.
        1: constant total dot area (``area_total``) with mean pairwise distance checked
           by ``verificar_densidad_control``. Each image is then rescaled to the lowest
           mean intensity in the set.
        2: random shapes (circle/rectangle/ellipse/triangle) whose first ``tam_hull``
           objects fix the convex hull.

    Args:
        num_reps: images per (numerosity, condition).
        rango_Q: numerosities to generate; defaults to [0, 1, 2, 3, 4].

    Returns:
        S: float32 array (N, 3, tam_imagen, tam_imagen), values divided by 255.
        Q: int array (N,) with numerosities.
        C: int array (N,) with condition labels.
        All three are shuffled jointly.

    Raises:
        RuntimeError: if one of the sets could not be generated.
    """
    if rango_Q is None:
        rango_Q = np.array([0, 1, 2, 3, 4])
    comunes = {
        'numerosidades': rango_Q,
        'repeticiones': num_reps,
        'tam_imagen': tam_imagen,
        'nivel_fondo': 50,
    }

    def _generar(nombre, **kwargs):
        # generar_conjunto signals failure with (None, None); fail loudly here.
        imagenes, numeros = generar_conjunto(**comunes, **kwargs)
        if imagenes is None:
            raise RuntimeError(f"No se pudo generar el conjunto de estímulos '{nombre}'")
        return imagenes, numeros

    Ss, Qs = _generar(
        'estandar',
        fn_muestreo=muestrear_estandar,
        args_muestreo={'radio_nominal': radio_punto},
        fn_verificacion=lambda x, y, r: True,
        fn_dibujo=dibujar_circulo,
        max_iter=1000,
    )
    Sc, Qc = _generar(
        'control',
        fn_muestreo=muestrear_area_constante,
        args_muestreo={'area_total': area_total},
        fn_verificacion=verificar_densidad_control,
        fn_dibujo=dibujar_circulo,
        max_iter=10_000,
    )
    # Scale every control image so that its mean intensity equals the lowest one in the set.
    mean_por_imagen = Sc.reshape((Sc.shape[0], -1)).mean(axis=1)
    Sc = mean_por_imagen.min() * (Sc / mean_por_imagen[:, None, None, None])
    Sss, Qss = _generar(
        'forma',
        fn_muestreo=muestrear_casco_convexo,
        args_muestreo={'radio_nominal': radio_punto, 'tam_hull': tam_hull},
        fn_verificacion=lambda x, y, r: True,
        fn_dibujo=dibujar_forma_aleatoria,
        max_iter=1000,
    )

    S = np.concatenate((Ss, Sc, Sss))
    Q = np.concatenate((Qs, Qc, Qss))
    C = np.concatenate((np.full_like(Qs, 0), np.full_like(Qc, 1), np.full_like(Qss, 2)))
    # (N, H, W, 1) -> (N, H, W, 3) -> channels-first (N, 3, H, W) for the CNN.
    S = np.tile(S, (1, 1, 1, 3)).transpose((0, 3, 1, 2))
    randperm = np.random.permutation(len(Q))
    S, Q, C = S[randperm], Q[randperm], C[randperm]
    return S.astype(np.float32) / 255.0, Q.astype(int), C.astype(int)

### Funciones de graficado

In [None]:
def visualizar_activaciones(activationes, capa):
    """Show every channel of the first sample's activations for layer ``capa`` in a grid.

    ``activationes[capa][0]`` is indexed as ``[channel] -> 2-D map``. Presumably it is a
    (batch, channels, H, W) hook output; verify against the caller.

    Raises:
        ValueError: if the layer has no channels to plot.
    """
    act = activationes[capa][0]
    n = act.shape[0]
    if n == 0:
        raise ValueError(f"layer {capa!r} has no channels to plot")

    cols = math.ceil(math.sqrt(n))
    rows = math.ceil(n / cols)

    # squeeze=False keeps a 2-D array of Axes even for a 1x1 grid. Without it,
    # plt.subplots returns a bare Axes and .flatten() fails when n == 1.
    fig, axes = plt.subplots(rows, cols, figsize=(2*cols, 2*rows), squeeze=False)
    for i, ax in enumerate(axes.ravel()):
        if i < n:
            ax.imshow(act[i], cmap="gray")
        ax.axis("off")

    plt.suptitle(f"Activaciones en {capa}")
    plt.show()

def visualizar_pesos(conv_layer, in_channel=0, max_plots=64):
    """Plot the 2-D kernels ``weight[i, in_channel]`` of the first ``max_plots`` filters.

    The diverging "bwr" colormap uses limits symmetric about zero for each kernel, so
    white is 0, red positive and blue negative. With default autoscaling, white would
    sit at the kernel's midrange instead.

    Raises:
        ValueError: if there is nothing to plot (no filters or max_plots < 1).
    """
    W = conv_layer.weight.detach().cpu()
    n = min(W.shape[0], max_plots)
    if n < 1:
        raise ValueError("visualizar_pesos: no filters to plot")
    cols = math.ceil(math.sqrt(n))
    rows = math.ceil(n / cols)

    # squeeze=False: always a 2-D Axes array, even for a single filter.
    fig, axes = plt.subplots(rows, cols, figsize=(2*cols, 2*rows), squeeze=False)
    for i, ax in enumerate(axes.ravel()):
        if i < n:
            ker = W[i, in_channel].numpy()
            lim = float(np.abs(ker).max()) or 1.0  # all-zero kernel: any symmetric range
            ax.imshow(ker, cmap="bwr", vmin=-lim, vmax=lim)
        ax.axis("off")
    plt.suptitle(f"Pesos {conv_layer.__class__.__name__}, canal entrada {in_channel}")
    plt.show()

### CORnet-Z

In [None]:
class Flatten(nn.Module):
  """
  Flatten every dimension except the batch one, e.g. (N, C, 1, 1) -> (N, C), for Linear layers.

  torch.flatten is used instead of ``x.view(x.size(0), -1)`` because ``view`` fails on
  non-contiguous tensors and on empty batches, where ``-1`` cannot be inferred.
  """
  def forward(self, x):
    return torch.flatten(x, 1)

class Identity(nn.Module):
  """
  Pass-through module that returns its input object unchanged.

  It gives a tensor a named attachment point inside a model (e.g. ``block.output``),
  so the value can be reached by name, for example with a forward hook.
  """
  def forward(self, x):
    salida = x  # no computation: this module is only a named tap point
    return salida

class CORblock_Z(nn.Module):
  """
  One CORnet-Z stage: Conv2d -> ReLU (in place) -> MaxPool2d(3, stride 2, padding 1).

  The result goes through ``self.output``, an identity tap, so it is reachable as an
  attribute. The submodule names (conv, nonlin, pool, output) define the state_dict keys.
  """
  def __init__(self, in_channels, out_channels, kernel_size=3, stride=1):
    super().__init__()
    relleno = kernel_size // 2  # "same"-style padding for odd kernels
    self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=relleno)
    self.nonlin = nn.ReLU(inplace=True)
    self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
    self.output = Identity()  # named tap for this block's output

  def forward(self, inp):
    return self.output(self.pool(self.nonlin(self.conv(inp))))


def CORnet_Z(map_location=None):
  """
  Build CORnet-Z (V1, V2, V4, IT and a linear decoder to 1000 outputs) with pretrained weights.

  Args:
    map_location: device used to load the checkpoint and to hold the model's parameters.
      Defaults to the notebook-level global ``device`` (from ``encontrar_device()``),
      which the original version used for the checkpoint only.

  Returns:
    The model wrapped in ``nn.DataParallel``; the bare ``nn.Sequential`` is ``.module``.
  """
  if map_location is None:
    map_location = device  # global chosen earlier in the notebook

  model = nn.Sequential(collections.OrderedDict([
    ('V1', CORblock_Z(3, 64, kernel_size=7, stride=2)),
    ('V2', CORblock_Z(64, 128)),
    ('V4', CORblock_Z(128, 256)),
    ('IT', CORblock_Z(256, 512)),
    ('decoder', nn.Sequential(collections.OrderedDict([
      ('avgpool', nn.AdaptiveAvgPool2d(1)),
      ('flatten', Flatten()),
      ('linear', nn.Linear(512, 1000)),
      ('output', Identity())
    ])))
  ]))

  # weight initialization (overwritten below by the pretrained checkpoint)
  for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.Linear)):
      torch.nn.init.xavier_uniform_(m.weight)
      if m.bias is not None:
          torch.nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.BatchNorm2d):
      m.weight.data.fill_(1)
      m.bias.data.zero_()

  # Move the parameters to the target device. Previously they stayed on the CPU while the
  # notebook sent inputs to ``device``, so forward passes on CUDA/MPS failed with a device
  # mismatch. load_state_dict below copies the values into these already-placed parameters.
  model = model.to(map_location)
  # The checkpoint is loaded into the wrapper, so its keys are expected to carry the
  # DataParallel ``module.`` prefix.
  model = nn.DataParallel(model)

  url = 'https://s3.amazonaws.com/cornet-models/cornet_z-5c427c9c.pth'
  ckpt_data = torch.hub.load_state_dict_from_url(url, map_location=map_location)
  model.load_state_dict(ckpt_data['state_dict'])

  return model

In [None]:
# Choose the compute device first: CORnet_Z() reads the global ``device`` when loading its checkpoint.
device = encontrar_device()
cnn = CORnet_Z()
# cnn is an nn.DataParallel wrapper; .module is the underlying nn.Sequential (V1, V2, V4, IT, decoder).
print(cnn.module)

In [None]:
# Numerosities studied in the notebook and the plot colour assigned to each of them.
rango_Q = np.array([0, 1, 2, 3, 4])
colores_Q = ['red', 'orange', 'green', 'blue', 'purple']

# Fixed seed so that the generated stimulus set is reproducible.
np.random.seed(12345)
# S: stimuli (N, 3, H, W); Q: numerosity of each stimulus; C: stimulus-set label (0, 1 or 2).
S, Q, C = generar_estimulos(num_reps=40, rango_Q=rango_Q, radio_punto=18, tam_hull=3)

print("Forma de S:", S.shape)
print("Forma de Q:", Q.shape)
print("Forma de C:", C.shape)

In [None]:
@widgets.interact(i=(0, len(S)-1))
def visualizar_dataset(i):
    """Display stimulus ``i`` together with its numerosity (Q) and condition (C) labels."""
    imagen = np.transpose(S[i], (1, 2, 0))  # channels-first -> channels-last for imshow
    fig, ax = plt.subplots(1, 1, figsize=(2, 2))
    ax.set_axis_off()
    ax.imshow(imagen)
    ax.set_title(f"Q: {Q[i]}, C: {C[i]}")
    plt.show()
    plt.close(fig)

In [None]:
# Filled by forward hooks: layer name -> detached CPU copy of that layer's latest output.
activations = {}

def make_hook(name):
    """Return a forward hook that stores the module's output as ``activations[name]``."""
    def _hook(m, i, o):
        activations[name] = o.detach().cpu()
    return _hook

hooks = [
    cnn.module.V1.register_forward_hook(make_hook('V1')),
    cnn.module.V2.register_forward_hook(make_hook('V2')),
    cnn.module.V4.register_forward_hook(make_hook('V4')),
    cnn.module.IT.register_forward_hook(make_hook('IT'))
]

# Run the bare Sequential on whatever device holds its weights, so the input always matches
# the model. no_grad: this is inference only, so no autograd graph is kept.
net = cnn.module
s = torch.from_numpy(S[0]).unsqueeze(0).to(next(net.parameters()).device)
with torch.no_grad():
    out = net(s)

@widgets.interact(capa=['V1', 'V2', 'V4', 'IT'])
def simular_inferencia(capa):
    """Plot the activations of layer ``capa`` recorded for stimulus S[0]."""
    visualizar_activaciones(activations, capa)

In [None]:
# Record the activations for every stimulus. The pass runs in mini-batches without autograd
# to bound device memory. Per-layer chunks are concatenated afterwards, so activations[layer]
# covers all of S in its original order.
activations.clear()
net = cnn.module  # bare Sequential; the registered hooks live on its submodules
dispositivo = next(net.parameters()).device
tam_lote = 64
trozos = collections.defaultdict(list)
salidas = []
with torch.no_grad():
    for inicio in range(0, len(S), tam_lote):
        lote = torch.from_numpy(S[inicio:inicio + tam_lote]).to(dispositivo)
        salidas.append(net(lote).cpu())
        for capa, act in activations.items():
            trozos[capa].append(act)
activations.clear()
activations.update({capa: torch.cat(partes) for capa, partes in trozos.items()})
out = torch.cat(salidas)  # decoder outputs for all stimuli (on CPU)
print(activations["IT"].shape)

In [None]:
# One row per stimulus, one column per IT unit (channel x spatial position).
IT = torch.flatten(activations["IT"], start_dim=1).numpy()
print(IT.shape)

In [None]:
def anova_two_way(A, B, Y):
    """Two-way ANOVA with interaction, run independently on every column of ``Y``.

    Args:
        A: (N,) labels of the first factor (numerosity in this notebook).
        B: (N,) labels of the second factor (stimulus condition in this notebook).
        Y: (N, num_cells) responses; each column (unit) is tested separately.

    The sums of squares use the balanced-design formulas. Every (A, B) combination must
    therefore contain the same number ``r >= 2`` of samples.

    Returns:
        pA, pB, pAB: (num_cells,) p-values for the main effect of A, the main effect of
        B and the A x B interaction. Units that never respond (all-zero column) get NaN.

    Raises:
        ValueError: if the design is unbalanced or has fewer than 2 samples per cell.
    """
    from scipy import stats  # explicit submodule import (``import scipy`` may not load it)

    A = np.asarray(A)
    B = np.asarray(B)
    # float64 keeps the sums of squares accurate for float32 network activations.
    Y = np.asarray(Y, dtype=np.float64)
    num_cells = Y.shape[1]

    A_levels = np.unique(A); a = len(A_levels)
    B_levels = np.unique(B); b = len(B_levels)

    groups = [[Y[(A == i) & (B == j)] for j in B_levels] for i in A_levels]
    sizes = sorted({g.shape[0] for row in groups for g in row})
    if len(sizes) != 1:
        raise ValueError(f"anova_two_way needs a balanced design; got cell sizes {sizes}")
    r = sizes[0]
    if r < 2:
        raise ValueError("anova_two_way needs at least 2 samples per (A, B) cell")
    Y4D = np.array(groups)  # (a, b, r, num_cells)
    Y = Y4D.reshape((-1, num_cells))

    # Only test units that respond to at least one stimulus, to avoid 0/0 F statistics.
    active_cells = np.flatnonzero(np.abs(Y).max(axis=0) > 0)
    Y4D = Y4D[..., active_cells]
    Y = Y[:, active_cells]

    N = Y.shape[0]

    Y_mean = Y.mean(axis=0)
    Y_mean_A = Y4D.mean(axis=(1, 2))   # (a, cells)
    Y_mean_B = Y4D.mean(axis=(0, 2))   # (b, cells)
    Y_mean_AB = Y4D.mean(axis=2)       # (a, b, cells)

    SSA = r * b * np.sum((Y_mean_A - Y_mean) ** 2, axis=0)
    SSB = r * a * np.sum((Y_mean_B - Y_mean) ** 2, axis=0)
    SSAB = r * np.sum((Y_mean_AB - Y_mean_A[:, None] - Y_mean_B[None, :] + Y_mean) ** 2, axis=(0, 1))
    SSE = np.sum((Y4D - Y_mean_AB[:, :, None]) ** 2, axis=(0, 1, 2))

    DFA = a - 1; DFB = b - 1; DFAB = DFA * DFB
    DFE = N - a * b

    MSE = SSE / DFE
    FA = (SSA / DFA) / MSE
    FB = (SSB / DFB) / MSE
    FAB = (SSAB / DFAB) / MSE

    pA = np.full(num_cells, np.nan)
    pB = np.full(num_cells, np.nan)
    pAB = np.full(num_cells, np.nan)

    pA[active_cells] = stats.f.sf(FA, DFA, DFE)
    pB[active_cells] = stats.f.sf(FB, DFB, DFE)
    pAB[active_cells] = stats.f.sf(FAB, DFAB, DFE)

    return pA, pB, pAB

pN, pC, pNC = anova_two_way(Q, C, IT)
# Number-selective units: significant numerosity effect (p < 0.01) and neither a condition
# effect nor an interaction (both p > 0.01). NaN p-values (silent units) fail every test.
anova_cells = np.flatnonzero((pN < 0.01) & (pC > 0.01) & (pNC > 0.01))

In [None]:
# Responses of the number-selective units only: one column per selected IT unit.
H = np.take(IT, anova_cells, axis=1)
print(H.shape)

In [None]:
def average_tuning_curves(Q, H, rango_Q):
    """Mean response of every unit to each numerosity.

    Args:
        Q: (N,) numerosity of each stimulus.
        H: (N, units) responses.
        rango_Q: numerosities to evaluate; they define the row order of the result.

    Returns:
        (len(rango_Q), units) array. A numerosity absent from Q yields a NaN row.
    """
    return np.array([H[Q == j, :].mean(axis=0) for j in rango_Q])

def preferred_numerosity(Q, H, rango_Q):
    """Return, for each unit (column of H), the numerosity in ``rango_Q`` with the largest mean response.

    The argmax indexes rows of the tuning curves, which follow ``rango_Q`` order. The
    label is therefore taken from ``rango_Q`` itself. Indexing ``np.unique(Q)`` mislabels
    units whenever rango_Q is not exactly the sorted set of numerosities present in Q.
    """
    tuning_curves = average_tuning_curves(Q, H, rango_Q)
    return np.asarray(rango_Q)[np.argmax(tuning_curves, axis=0)]

# Preferred numerosity of every number-selective unit (columns of H); ``R`` was never defined.
pref_num = preferred_numerosity(Q, H, rango_Q)
print(pref_num[0])

In [None]:
# Plot the distribution of preferred numerosities among the number-selective units.
hist = np.array([np.sum(pref_num == q) for q in rango_Q], dtype=float)
hist /= hist.sum()

plt.figure(figsize=(3,3))
plt.bar(rango_Q, 100*hist, width=0.8)
plt.xlabel('Numerosidad preferida')
plt.ylabel('Porcentaje de unidades')
# Percentage relative to ALL IT units. H only holds the selected units, so dividing by
# H.shape[1] always printed 100%.
print('Número de unidades con numerosidad preferida = %i (%0.2f%%)'%(len(anova_cells), 100*len(anova_cells)/IT.shape[1]))

In [None]:
# Plot tuning curves for a few example units
def random_unit(preferred_numerosity):
    """Pick a random unit index whose preferred numerosity (global ``pref_num``) equals the argument."""
    candidatos = np.flatnonzero(pref_num == preferred_numerosity)
    return np.random.choice(candidatos)

# One example unit per preferred numerosity.
rep_units = [random_unit(q) for q in rango_Q]
plt.figure(figsize=(8,1.75))
for i, unit in enumerate(rep_units):
    plt.subplot(1, len(rep_units), i+1)
    # Thin lines: this unit's tuning curve within each stimulus condition.
    for j in np.unique(C):
        tc = np.array([H[(Q==q) & (C==j), unit].mean() for q in rango_Q])
        plt.plot(rango_Q, tc, linewidth=0.5)

    # Thick black line: mean +/- standard error over all conditions. It is drawn once per
    # unit; it was previously mis-indented into the per-condition loop.
    tc = np.array([H[Q==q, unit].mean() for q in rango_Q])
    err = np.array([H[Q==q, unit].std() / np.sqrt(np.sum(Q==q)) for q in rango_Q])
    plt.errorbar(rango_Q, tc, err, color='k', linewidth=1.5)

    plt.xlabel('Numerosidad'); plt.ylabel('Activación')
    plt.title('Unidad %i' %(unit,))
    plt.xticks(rango_Q)

plt.tight_layout()

In [None]:
# Calculate average tuning curve of each number-selective unit (``R`` was never defined; H holds them)
tuning_curves = average_tuning_curves(Q, H, rango_Q)

# Calculate population tuning curves for each preferred numerosity
tuning_mat = np.array([np.mean(tuning_curves[:,pref_num==q], axis=1) for q in rango_Q]) # one row for each pref numerosity
tuning_err = np.array([np.std(tuning_curves[:,pref_num==q], axis=1) / np.sqrt(np.sum(pref_num==q)) # standard error for each point on each tuning curve
                       for q in rango_Q])

# Normalize population tuning curves to the 0-1 range
tmmin = tuning_mat.min(axis=1)[:,None]
tmmax = tuning_mat.max(axis=1)[:,None]
tuning_mat = (tuning_mat-tmmin) / (tmmax-tmmin)
tuning_err = tuning_err / (tmmax-tmmin) # scale standard error to be consistent with above normalization

# Plot population tuning curves on linear scale
plt.figure(figsize=(8,2.75))
plt.subplot(1,2,1)
for i, (tc, err) in enumerate(zip(tuning_mat, tuning_err)):
    plt.errorbar(rango_Q, tc, err, color=colores_Q[i])
plt.xticks(rango_Q)
plt.xlabel('Numerosity')
plt.ylabel('Normalized Neural Activity')

# Plot population tuning curves on log scale
plt.subplot(1,2,2)
for i, (tc, err) in enumerate(zip(tuning_mat, tuning_err)):
    plt.errorbar(rango_Q+1, tc, err, color=colores_Q[i]) # offset x axis by one to avoid taking the log of zero
# Axis setup only needs to happen once, after all curves are drawn.
plt.xscale('log', base=2)
plt.xticks(ticks=rango_Q+1, labels=rango_Q)
plt.xlabel('Numerosidad')
plt.ylabel('Actividad normalizada')

# Average responses of zero-tuned units to numerosities 1, 2, and 3
R01 = tuning_curves[:,pref_num==0][1]
R02 = tuning_curves[:,pref_num==0][2]
R03 = tuning_curves[:,pref_num==0][3]

print('For zero-tuned units: average response to 1 vs average response to 2: p-value = %e'
           %(sp.stats.wilcoxon(R01, R02).pvalue))

print('For zero-tuned units: average response to 2 vs average response to 3: p-value = %e'
           %(sp.stats.wilcoxon(R02, R03).pvalue))