In [3]:
from __future__ import print_function, division
from array import array
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import h5py
import random
import scipy as sp
import matplotlib.pyplot as plt
import torch.nn.functional as F
from pathlib import Path
from tqdm.auto import tqdm, trange
import scipy.io
import time
import math
import shutil
from sklearn.decomposition import PCA
import scipy.signal

In [4]:
# Global matplotlib look for every figure in this notebook: 100-dpi figures and
# a compact three-level font hierarchy (small text for fonts, tick labels, axis
# labels and legends; medium for axes titles; large for figure titles).
SMALL_SIZE = 8
MEDIUM_SIZE = 10
BIGGER_SIZE = 12
plt.rcParams.update({
    'figure.dpi': 100,
    'font.size': SMALL_SIZE,
    'axes.titlesize': MEDIUM_SIZE,
    'axes.labelsize': SMALL_SIZE,
    'xtick.labelsize': SMALL_SIZE,
    'ytick.labelsize': SMALL_SIZE,
    'legend.fontsize': SMALL_SIZE,
    'figure.titlesize': BIGGER_SIZE,
})

## Load data

In [5]:
# Directory containing the DY016 Kilosort outputs. Set the DY016_DIR environment
# variable to point elsewhere instead of editing this machine-specific default.
dy_templates_path = os.environ.get(
    'DY016_DIR', '/Users/ankit/Documents/PaninskiLab/contrastive_spikes/DY016/'
)
temp_file = 'kilosort_cleaned_templates.npy'
good_unit_file = 'good_units_kilosort.npy'

# os.path.join works whether or not the directory string ends with a separator.
kilo_temps = np.load(os.path.join(dy_templates_path, temp_file))
good_units = np.load(os.path.join(dy_templates_path, good_unit_file))
# Templates of the "good" units; good_units is used directly as an index
# (integer indices or a boolean mask — TODO confirm which).
kilo_good = kilo_temps[good_units]

In [6]:
# For every unit: the channel holding the largest absolute template value (mcs),
# and the sample index where that channel peaks in absolute value (align_ts).
# Axis 1 of kilo_temps is reduced over and axis 2 is treated as the channel
# axis (presumably (unit, time, channel) — TODO confirm).
abs_temps = np.abs(kilo_temps)
mcs = abs_temps.max(axis=1).argmax(axis=1)
align_ts = abs_temps[np.arange(abs_temps.shape[0]), :, mcs].argmax(axis=1)
del abs_temps  # temporary full-size copy; not needed afterwards

In [7]:
# Channel with the largest peak-to-peak amplitude (max - min over axis 1) per unit.
per_chan_ptp = kilo_temps.max(axis=1) - kilo_temps.min(axis=1)
unit_max_channels = per_chan_ptp.argmax(axis=1)
# For each unit and channel, the index along axis 1 of the maximum value.
max_peak_inds = kilo_temps.argmax(axis=1)
# Each unit's trace on its peak-to-peak max channel: shape (n_units, axis-1 length).
max_chan_templates = kilo_temps[np.arange(kilo_temps.shape[0]), :, unit_max_channels]
del per_chan_ptp

In [8]:
# Peak-to-peak amplitude of each unit on its best channel, and the subset of
# units whose amplitude exceeds PTP_THRESHOLD.
PTP_THRESHOLD = 5
ptps = (kilo_temps.max(axis=1) - kilo_temps.min(axis=1)).max(axis=1)
# flatnonzero always yields an integer index array, so an empty selection
# still indexes cleanly (np.array([]) would be float and make indexing raise).
high_ptp_indices = np.flatnonzero(ptps > PTP_THRESHOLD)
high_ptp_temps = kilo_temps[high_ptp_indices]
max_chan_hptp_temps = max_chan_templates[high_ptp_indices]

In [9]:
# Unit ids used for the classification / plotting experiments below
# (also used as class names by class_scores).
wf_interest_dy = [
    8, 9, 11, 33, 65,
    69, 109, 13, 329, 151,
]
# Paired element-wise with wf_interest_dy; presumably each unit's position in
# another unit list — TODO confirm which list these index into.
wf_dy_idx = [
    0, 1, 2, 10, 22,
    25, 38, 3, 50, 102,
]

In [10]:
# Train / test spike sets for the units of interest. The directory comes from
# dy_templates_path (set in the "Load data" cell) instead of repeating the path.
dy_wfs_interest = np.load(os.path.join(dy_templates_path, 'spikes_train.npy'))
dy_wfs_test = np.load(os.path.join(dy_templates_path, 'spikes_test.npy'))

In [22]:
# Re-load the cleaned template file (the same file name as kilo_temps) so the
# next cell can reduce it to max-channel waveforms, plus the unit ids used to
# subset it. The directory comes from dy_templates_path ("Load data" cell).
dy_wfs = np.load(os.path.join(dy_templates_path, 'kilosort_cleaned_templates.npy'))
unit_ids = np.load(os.path.join(dy_templates_path, 'unit_ids.npy'))

In [23]:
# Reduce each template to its trace on its abs-max channel (mcs), giving a
# 2-D (n_units, n_samples) array. The ndim guard keeps the cell re-runnable:
# without it a second run would try to transpose the already-reduced array.
if dy_wfs.ndim == 3:
    dy_wfs = dy_wfs[np.arange(dy_wfs.shape[0]), :, mcs]
print(dy_wfs.shape)

# Keep only the units listed in unit_ids.
dy_wfs_sel = dy_wfs[unit_ids]
print(dy_wfs_sel.shape)

(482, 121)
(118, 121)


## Contrastive Model definitions

In [29]:
import torch.nn as nn
from collections import OrderedDict
from matplotlib.gridspec import GridSpec
from torch.nn import TransformerEncoder, TransformerEncoderLayer

# Model definition
class SingleChanDenoiser(nn.Module):
    """Small 1-D conv net mapping a single-channel waveform to ``out_size`` features.

    forward: conv1 -> ReLU -> conv2 -> ReLU -> flatten -> fc.

    ``conv3`` is constructed but not used by ``forward``. It is kept because it
    is part of the module's parameters: ``load`` calls ``load_state_dict`` in
    strict mode, so removing it would change which checkpoint keys are required.

    Args:
        n_filters: output channels of conv1, conv2, conv3.
        filter_sizes: kernel sizes of conv1, conv2, conv3.
        spike_size: number of samples per input waveform.
        out_size: size of the output feature vector.
    """

    def __init__(
        self, n_filters=[16, 8, 4], filter_sizes=[5, 11, 21], spike_size=121, out_size=2
    ):
        super(SingleChanDenoiser, self).__init__()
        feat1, feat2, feat3 = n_filters
        size1, size2, size3 = filter_sizes
        self.conv1 = nn.Sequential(nn.Conv1d(1, feat1, size1), nn.ReLU())
        self.conv2 = nn.Sequential(nn.Conv1d(feat1, feat2, size2), nn.ReLU())
        # Unused in forward(); see the class docstring.
        self.conv3 = nn.Sequential(nn.Conv1d(feat2, feat3, size3), nn.ReLU())
        # Two unpadded convolutions shrink the length to spike_size - size1 - size2 + 2.
        n_input_feat = feat2 * (spike_size - size1 - size2 + 2)
        self.fc = nn.Linear(n_input_feat, out_size)

    def forward(self, x):
        """Map x of shape (batch, spike_size) to features of shape (batch, out_size)."""
        x = x[:, None]  # add the single input channel -> (batch, 1, spike_size)
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.shape[0], -1)
        return self.fc(x)

    def load(self, fname_model):
        """Load weights from a checkpoint file and return ``self``.

        Reads ``checkpoint["state_dict"]``. Every key containing "backbone" but
        not "fc" has its first dotted component stripped; the ``fc`` layer is
        taken from ``backbone.fc.2.{weight,bias}``. Raises KeyError if those
        fc entries are missing.
        """
        checkpoint = torch.load(fname_model, map_location="cpu")
        state_dict = checkpoint["state_dict"]
        new_state_dict = OrderedDict()
        for key in state_dict:
            if "backbone" in key and "fc" not in key:
                new_key = '.'.join(key.split('.')[1:])
                new_state_dict[new_key] = state_dict[key]
        new_state_dict["fc.weight"] = state_dict["backbone.fc.2.weight"]
        new_state_dict["fc.bias"] = state_dict["backbone.fc.2.bias"]
        self.load_state_dict(new_state_dict)
        return self

    
class Encoder(nn.Module):
    """Three-stage 1-D CNN encoder, MLP head and Projector.

    Input is (batch, 1, length), as required by the first ``Conv1d``. Each stage
    is Conv1d (padding ``ceil((k - 1) / 2)``) -> BatchNorm1d -> ReLU -> MaxPool1d,
    with pool sizes 2, 4, 4 and kernel sizes ``ks[0..2]``. The output is averaged
    over time, then ``fcpart`` applies Linear -> ReLU -> Linear (to ``out_size``)
    and a ``Projector`` to the smaller of ``out_size`` and ``proj_dim``.
    """

    def __init__(self, Lv=[200, 150, 100, 75], ks=[11, 21, 31], out_size=2, proj_dim=5):
        super(Encoder, self).__init__()
        self.proj_dim = proj_dim if proj_dim <= out_size else out_size

        # (in_channels, out_channels, kernel, pool) for each conv stage.
        stage_specs = (
            (1, Lv[0], ks[0], 2),
            (Lv[0], Lv[1], ks[1], 4),
            (Lv[1], Lv[2], ks[2], 4),
        )
        conv_stack = []
        for in_ch, out_ch, width, pool in stage_specs:
            conv_stack += [
                nn.Conv1d(in_ch, out_ch, width, padding=math.ceil((width - 1) / 2)),
                nn.BatchNorm1d(out_ch),
                nn.ReLU(),
                nn.MaxPool1d(pool),
            ]
        self.enc_block1d = nn.Sequential(*conv_stack)
        self.avgpool1d = nn.AdaptiveAvgPool1d(1)

        self.fcpart = nn.Sequential(
            nn.Linear(Lv[2], Lv[3]),
            nn.ReLU(),
            nn.Linear(Lv[3], out_size),
            Projector(rep_dim=out_size, proj_dim=self.proj_dim),
        )
        self.Lv = Lv

    def forward(self, x):
        pooled = self.avgpool1d(self.enc_block1d(x))
        # (batch, Lv[2], 1) -> (batch, Lv[2])
        return self.fcpart(pooled.view(-1, self.Lv[2]))

    def load(self, fname_model):
        """Load ``checkpoint["state_dict"]`` with the first dotted key component dropped; return self."""
        state_dict = torch.load(fname_model, map_location="cpu")["state_dict"]
        renamed = OrderedDict(
            (key.partition('.')[2], tensor) for key, tensor in state_dict.items()
        )
        self.load_state_dict(renamed)
        return self
    

class Encoder2(nn.Module):
    """Four-stage 1-D CNN encoder with an MLP head of configurable depth and a Projector.

    Input is (batch, 1, length). Stages 1-3 are Conv1d -> BatchNorm1d -> ReLU ->
    MaxPool1d (pool 2, 4, 4); stage 4 is Conv1d -> BatchNorm1d -> ReLU. Every
    conv uses kernel ``ks[0]`` with padding ``ceil((ks[0] - 1) / 2)``. After a
    global average pool over time, ``fcpart`` applies ``fc_depth`` Linear+ReLU
    layers (the last maps to ``out_size``) and then a ``Projector`` to the
    smaller of ``out_size`` and ``proj_dim``.

    Args:
        Lv: channel widths of the four convs (Lv[0..3]) and hidden MLP width (Lv[4]).
        ks: list whose first entry is the kernel size used by all convs.
        out_size: representation size.
        proj_dim: requested projector output size.
        fc_depth: number of Linear layers before the projector (>= 2).
    """

    def __init__(self, Lv=[64, 128, 256, 256, 256], ks=[11], out_size = 2, proj_dim=5, fc_depth=2):
        super(Encoder2, self).__init__()
        self.proj_dim = out_size if out_size < proj_dim else proj_dim
        pad = math.ceil((ks[0] - 1) / 2)
        self.enc_block1d = nn.Sequential(
            nn.Conv1d(in_channels=1, out_channels=Lv[0], kernel_size=ks[0], padding=pad),
            nn.BatchNorm1d(Lv[0]),
            nn.ReLU(),
            nn.MaxPool1d(2),
            nn.Conv1d(Lv[0], Lv[1], ks[0], padding=pad),
            nn.BatchNorm1d(Lv[1]),
            nn.ReLU(),
            nn.MaxPool1d(4),
            nn.Conv1d(Lv[1], Lv[2], ks[0], padding=pad),
            nn.BatchNorm1d(Lv[2]),
            nn.ReLU(),
            nn.MaxPool1d(4),
            nn.Conv1d(Lv[2], Lv[3], ks[0], padding=pad),
            # Normalizes the Lv[3] channels produced by the conv above.
            nn.BatchNorm1d(Lv[3]),
            nn.ReLU(),
        )
        self.avgpool1d = nn.AdaptiveAvgPool1d(1)
        list_layers = [nn.Linear(Lv[3], Lv[4]), nn.ReLU(inplace=True)]
        for _ in range(fc_depth - 2):
            list_layers += [nn.Linear(Lv[4], Lv[4]), nn.ReLU(inplace=True)]
        # The representation itself passes through a ReLU before the projector.
        list_layers += [nn.Linear(Lv[4], out_size), nn.ReLU(inplace=True)]
        list_layers += [Projector(rep_dim=out_size, proj_dim=self.proj_dim)]
        self.fcpart = nn.Sequential(*list_layers)
        self.Lv = Lv

    def forward(self, x):
        x = self.enc_block1d(x)
        x = self.avgpool1d(x)
        # (batch, Lv[3], 1) -> (batch, Lv[3]); Lv[3] is the last conv's width.
        x = x.view(-1, self.Lv[3])
        return self.fcpart(x)

    def load(self, fname_model):
        """Load ``checkpoint["state_dict"]`` with the first dotted key component dropped; return self.

        This model has no positional encoder, so no key is transposed; any
        unexpected keys are reported by the strict ``load_state_dict``.
        """
        checkpoint = torch.load(fname_model, map_location="cpu")
        state_dict = checkpoint["state_dict"]
        new_state_dict = OrderedDict()
        for key in state_dict:
            new_key = '.'.join(key.split('.')[1:])
            new_state_dict[new_key] = state_dict[key]
        self.load_state_dict(new_state_dict)
        return self
    
    
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence-first tensor, then apply dropout.

    Even feature indices receive sines and odd ones cosines of geometrically
    spaced frequencies. Both even and odd ``d_model`` are supported.

    Args:
        d_model: feature size of the input.
        dropout: dropout probability applied after the addition.
        max_len: longest sequence length supported.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        # With odd d_model there is one fewer cosine slot than frequencies, so trim
        # div_term to fit (a no-op for even d_model).
        pe[:, 0, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Buffer: stored in state_dict and moved by .to(), but not a trainable parameter.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim]
        Returns:
            x + pe[:seq_len] (broadcast over the batch), after dropout.
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

    
class AttentionEnc(nn.Module):
    """Transformer encoder over waveform time samples, with an MLP head and a Projector.

    Each of the ``spike_size`` time samples is a token. Its ``n_channels`` values
    are linearly expanded to ``expand_dim`` features (skipped when
    ``expand_dim == 1``, in which case ``nhead`` is forced to 1). Sinusoidal
    position encodings are then added and ``nlayers`` batch-first
    ``TransformerEncoderLayer``s are applied. The flattened token sequence goes
    through ``fc_depth`` Linear layers to ``out_size`` and a ``Projector`` to the
    smaller of ``out_size`` and ``proj_dim``.

    ``dropout`` is used by the positional encoding only; the transformer layers
    use ``TransformerEncoderLayer``'s own default dropout.
    """

    def __init__(self, spike_size=121, n_channels=1, out_size=2, proj_dim=5, fc_depth=2, nlayers=9, nhead=8, dropout=0.1, expand_dim=16):
        super(AttentionEnc, self).__init__()
        self.spike_size = spike_size
        self.expand_dim = expand_dim
        self.proj_dim = out_size if out_size < proj_dim else proj_dim
        if expand_dim != 1:
            self.encoder = nn.Linear(n_channels, expand_dim)
        else:
            # One feature per token cannot be split across several heads.
            nhead = 1
        self.pos_encoder = PositionalEncoding(expand_dim, dropout, spike_size)
        encoder_layers = TransformerEncoderLayer(expand_dim, nhead, 512, batch_first=True)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        list_layers = [nn.Linear(self.spike_size * expand_dim, 256), nn.ReLU(inplace=True)]
        for _ in range(fc_depth - 2):
            list_layers += [nn.Linear(256, 256), nn.ReLU(inplace=True)]
        list_layers += [nn.Linear(256, out_size)]
        list_layers += [Projector(rep_dim=out_size, proj_dim=self.proj_dim)]
        self.fcpart = nn.Sequential(*list_layers)

    def init_weights(self):
        """Re-initialize the head's Linear layers: weights ~ U(-0.1, 0.1), biases = 0.

        Only direct ``nn.Linear`` children of ``fcpart`` are affected; the
        trailing ``Projector`` keeps its own initialization.
        """
        initrange = 0.1
        for layer in self.fcpart:
            if isinstance(layer, nn.Linear):
                layer.bias.data.zero_()
                layer.weight.data.uniform_(-initrange, initrange)

    def forward(self, src, src_mask=None):
        """
        Args:
            src: Tensor, shape [batch_size, n_channels, seq_len] (seq_len == spike_size)
            src_mask: optional attention mask, shape [seq_len, seq_len]
        Returns:
            output Tensor of shape [batch_size, proj_dim]
        """
        src = torch.transpose(src, 1, 2)  # -> [batch, seq_len, n_channels]
        if self.expand_dim != 1:
            src = self.encoder(src) * math.sqrt(self.expand_dim)
        # PositionalEncoding indexes its table along dim 0 (sequence-first) but this
        # model is batch-first, so transpose around it; otherwise positions would
        # be encoded along the batch axis instead of along time.
        src = self.pos_encoder(src.transpose(0, 1)).transpose(0, 1).contiguous()
        output = self.transformer_encoder(src, src_mask)
        output = output.reshape(-1, self.spike_size * self.expand_dim)
        return self.fcpart(output)

    def load(self, fname_model):
        """Load ``checkpoint["state_dict"]`` with the first dotted key component dropped; return self.

        Tensors under ``pos_encoder`` are transposed on dims 0/1 to match this
        module's sequence-first ``pe`` buffer layout.
        """
        checkpoint = torch.load(fname_model, map_location="cpu")
        state_dict = checkpoint["state_dict"]
        new_state_dict = OrderedDict()
        for key in state_dict:
            new_key = '.'.join(key.split('.')[1:])
            new_state_dict[new_key] = state_dict[key]
            if 'pos_encoder' in key:
                new_state_dict[new_key] = state_dict[key].transpose(0, 1)
        self.load_state_dict(new_state_dict)
        return self
    
class FullyConnectedEnc(nn.Module):
    """MLP encoder: three Linear+ReLU layers, a Linear to ``out_size``, then a Projector.

    Widths run Lv[0] -> Lv[1] -> Lv[2] -> Lv[3] -> out_size. The Linear layers act
    on the last input dimension, which must equal Lv[0]. The Projector maps
    out_size to the smaller of ``out_size`` and ``proj_dim``.
    """

    def __init__(self, Lv=[121, 550, 1100, 250], out_size=2, proj_dim=5):
        super(FullyConnectedEnc, self).__init__()
        self.proj_dim = proj_dim if proj_dim <= out_size else out_size

        widths = [Lv[0], Lv[1], Lv[2], Lv[3]]
        stack = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            stack += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        stack.append(nn.Linear(widths[-1], out_size))
        stack.append(Projector(rep_dim=out_size, proj_dim=self.proj_dim))
        self.fcpart = nn.Sequential(*stack)
        self.Lv = Lv

    def forward(self, x):
        return self.fcpart(x)

    def load(self, fname_model):
        """Load ``checkpoint["state_dict"]`` with the first dotted key component dropped; return self."""
        state_dict = torch.load(fname_model, map_location="cpu")["state_dict"]
        renamed = OrderedDict(
            (key.partition('.')[2], tensor) for key, tensor in state_dict.items()
        )
        self.load_state_dict(renamed)
        return self

class Projector(nn.Module):
    '''Projection head: an MLP of `depth` Linear layers mapping rep_dim -> proj_dim.

    Hidden layers have width Lvpj[0] and are each followed by ReLU. With
    bnorm=True, each hidden layer also gets its own BatchNorm1d before the ReLU.
    Lvpj[1] is currently unused. depth values below 2 behave like depth=2.
    '''

    def __init__(self, Lvpj=[512, 128], rep_dim=5, proj_dim=5, bnorm = False, depth = 3):
        super(Projector, self).__init__()
        print(f"Using projector; batchnorm {bnorm} with depth {depth}; hidden_dim={Lvpj[0]}")
        hidden = Lvpj[0]

        def hidden_block(in_dim):
            # Build a fresh BatchNorm1d per layer: reusing one instance would make
            # every hidden layer share normalization parameters and running stats.
            norm = [nn.BatchNorm1d(hidden)] if bnorm else []
            return [nn.Linear(in_dim, hidden)] + norm + [nn.ReLU()]

        list_layers = hidden_block(rep_dim)
        for _ in range(depth - 2):
            list_layers += hidden_block(hidden)
        list_layers += [nn.Linear(hidden, proj_dim)]
        self.proj_block = nn.Sequential(*list_layers)

    def forward(self, x):
        return self.proj_block(x)
    


## Functions to Compute and Plot Representations

In [34]:
# import umap.umap_ as umap
from sklearn.decomposition import PCA

# def learn_manifold_umap(data, umap_dim, umap_min_dist=0.2, umap_metric='euclidean', umap_neighbors=10):
#     md = float(umap_min_dist)
#     return umap.UMAP(random_state=0, metric=umap_metric, n_components=umap_dim, n_neighbors=umap_neighbors,
#                     min_dist=md).fit_transform(data)

def pca_train(train, test, n_comps):
    """Fit an `n_comps`-component PCA on `train` and project `test` onto it.

    Returns:
        (projected test data, explained-variance ratio of the fitted components)
    """
    reducer = PCA(n_components=n_comps).fit(train)
    print('train done')
    projected = reducer.transform(test)
    print('pca test done')
    return projected, reducer.explained_variance_ratio_

def pca(S, n_comps):
    """Fit PCA on `S`; return (its `n_comps`-dim projection, explained-variance ratio)."""
    model = PCA(n_components=n_comps)
    projected = model.fit_transform(S)
    return projected, model.explained_variance_ratio_

In [35]:
def compute_reps(model, og_temps, tform_temps):
    """Run original and transformed templates through `model`, one template at a time.

    Args:
        model: torch module; converted to float64 in place via ``model.double()``.
        og_temps: sequence of numpy arrays; each is reshaped to (1, -1) before the call.
        tform_temps: indexable of torch tensors; each gets a leading batch axis.

    Returns:
        (og_reps, tform_reps): numpy arrays of the stacked outputs, squeezed.

    The model runs in eval mode (BatchNorm uses running statistics, Dropout is
    off), and its previous train/eval mode is restored afterwards.
    """
    og_reps = []
    tform_reps = []
    model = model.double()
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for i, og_temp in enumerate(og_temps):
                og_rep = model(torch.from_numpy(og_temp.reshape(1, -1)).double())
                tf_rep = model(tform_temps[i][None, :].double())
                og_reps.append(og_rep.numpy())
                tform_reps.append(tf_rep.numpy())
    finally:
        model.train(was_training)

    return np.squeeze(np.array(og_reps)), np.squeeze(np.array(tform_reps))

def compute_reps_test(model, test_wfs):
    """Run each waveform in `test_wfs` through `model` individually and stack the outputs.

    Args:
        model: torch module; converted to float64 in place via ``model.double()``.
        test_wfs: sequence of numpy arrays; each is reshaped to (1, 1, -1)
            (batch, channel, samples) before the call.

    Returns:
        numpy array of the stacked outputs, squeezed.

    The model runs in eval mode, so BatchNorm uses its running statistics rather
    than the statistics of a single spike, and Dropout is disabled. Its previous
    train/eval mode is restored afterwards.
    """
    og_reps = []
    model = model.double()
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for og_temp in test_wfs:
                og_rep = model(torch.from_numpy(og_temp.reshape(1, 1, -1)).double())
                og_reps.append(og_rep.numpy())
    finally:
        model.train(was_training)

    return np.squeeze(np.array(og_reps))


def plot_reps(og_temps, tform_temps, og_reps, tform_reps, title=None, save_name=None):
    """Scatter the representations of a few random templates next to their waveforms.

    Args:
        og_temps: original 1-D waveforms, indexable by integer.
        tform_temps: augmented waveforms; each entry must support ``.numpy()``.
        og_reps, tform_reps: 2-D arrays of representations, one row per template.
        title: optional figure suptitle.
        save_name: if given, the figure is saved to this path.

    Dims 0-1 of the representations go to a 2-D scatter. When lat_dim > 2, a
    3-D scatter of dims 2-4 is added as well, so lat_dim must then be >= 5.
    Templates are drawn with replacement, so the same one can appear twice.
    """
    n_temps = len(og_temps)
    lat_dim = og_reps.shape[1]
    wide = lat_dim > 2
    num_sels = 6 if wide else 4
    temp_sels = np.random.choice(np.arange(n_temps), num_sels)

    og_sel = og_reps[temp_sels]
    tf_sel = tform_reps[temp_sels]
    # Shared y-limits for each original / augmented waveform pair.
    max_chan_max = [max(np.max(og_temps[s]), np.max(tform_temps[s].numpy())) for s in temp_sels]
    max_chan_min = [min(np.min(og_temps[s]), np.min(tform_temps[s].numpy())) for s in temp_sels]
    colors = ['blue', 'red', 'green', 'yellow', 'orange', 'black'][:num_sels]

    if wide:
        fig = plt.figure(figsize=(12, 8), constrained_layout=True)
        gs = GridSpec(4, 6, figure=fig)
        ax0 = fig.add_subplot(gs[2:, :3])
    else:
        fig = plt.figure(figsize=(12, 4), constrained_layout=True)
        gs = GridSpec(2, 6, figure=fig)
        ax0 = fig.add_subplot(gs[:2, :2])

    ax0.scatter(og_sel[:, 0], og_sel[:, 1], c=colors, clip_on=False)
    ax0.scatter(tf_sel[:, 0], tf_sel[:, 1], c=colors, clip_on=False)

    if wide:
        # 3-D view of representation dims 2, 3 and 4 for both point sets.
        ax1 = fig.add_subplot(gs[2:, 4:], projection='3d')
        ax1.scatter(og_sel[:, 2], og_sel[:, 3], og_sel[:, 4], c=colors, clip_on=False)
        ax1.scatter(tf_sel[:, 2], tf_sel[:, 3], tf_sel[:, 4], c=colors, clip_on=False)

    for i, s in enumerate(temp_sels):
        ax0.annotate('wf {}'.format(s), (og_sel[i, 0], og_sel[i, 1]))
        if wide:
            ax1.text(og_sel[i, 2], og_sel[i, 3], og_sel[i, 4], 'wf {}'.format(s))
    for i, s in enumerate(temp_sels):
        ax0.annotate('augmented wf {}'.format(s), (tf_sel[i, 0], tf_sel[i, 1]))
        if wide:
            ax1.text(tf_sel[i, 2], tf_sel[i, 3], tf_sel[i, 4], 'augmented wf {}'.format(s))

    # Waveform panels in the top two grid rows; in the narrow layout they sit
    # to the right of the scatter (columns 2-5).
    first_col = 0 if wide else 2
    cols = range(first_col, first_col + num_sels)
    axs = [fig.add_subplot(gs[0, c]) for c in cols] + [fig.add_subplot(gs[1, c]) for c in cols]

    for i, s in enumerate(temp_sels):
        ylim = (max_chan_min[i] - 0.5, max_chan_max[i] + 0.5)
        panels = (
            (axs[2 * i], og_temps[s], 'wf {}'.format(s)),
            (axs[2 * i + 1], tform_temps[s].numpy(), 'augmented wf {}'.format(s)),
        )
        for ax, wave, label in panels:
            ax.set_ylim(*ylim)
            ax.title.set_text(label)
            ax.plot(np.arange(len(wave)), wave, linewidth=2, markersize=12, color=colors[i])
            ax.get_xaxis().set_visible(False)

    fig.suptitle(title)

    if save_name is not None:
        fig.savefig(save_name)

### MLP 

In [36]:
import torch
from torch import nn
from torch.utils.data import DataLoader

class MLP(nn.Module):
    '''Multilayer perceptron: Flatten, then a stack of Linear layers with ReLU between them.

    Args:
        input_size: number of input features after flattening.
        layer_sizes: widths of the Linear layers in order; the last entry is the
            output size. Any non-empty length is supported. Lengths 2 and 3 build
            exactly the same modules (and state_dict keys) as before.

    Raises:
        ValueError: if layer_sizes is empty.
    '''
    def __init__(self, input_size=2, layer_sizes=[100, 50, 10]):
        super().__init__()
        if len(layer_sizes) == 0:
            raise ValueError("layer_sizes must contain at least one width")
        widths = [input_size] + list(layer_sizes)
        modules = [nn.Flatten()]
        for i, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            if i > 0:
                modules.append(nn.ReLU())
            modules.append(nn.Linear(fan_in, fan_out))
        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        '''Forward pass'''
        return self.layers(x)
    
def train(data, labels, layers=[1000, 50, 10], epochs=25, lr=1e-4, batch_size=256):
    """Train an MLP classifier on (data, labels) with Adam and cross-entropy loss.

    Args:
        data: 2-D array-like, one feature row per sample (cast to float32 per batch).
        labels: integer class ids aligned with the rows of `data`.
        layers: widths passed to ``MLP``; the last entry should be at least the
            number of classes.
        epochs: number of passes over the shuffled data.
        lr: Adam learning rate.
        batch_size: mini-batch size.

    Returns:
        The trained ``MLP``.

    Raises:
        ValueError: if `data` and `labels` have different lengths (``zip`` would
            otherwise silently drop the extra rows).
    """
    if len(data) != len(labels):
        raise ValueError(f"data has {len(data)} rows but labels has {len(labels)}")

    mlp = MLP(input_size=data.shape[1], layer_sizes=layers)
    trainloader = DataLoader(list(zip(data, labels)), batch_size=batch_size, shuffle=True)

    loss_function = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(mlp.parameters(), lr=lr)

    for _ in range(epochs):
        for inputs, targets in trainloader:
            optimizer.zero_grad()
            loss = loss_function(mlp(inputs.float()), targets)
            loss.backward()
            optimizer.step()

    print('Training process has finished.')
    return mlp

In [37]:
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score

def class_scores(train_reps, test_reps, num_classes, layers=[1000, 100, 10], epochs=100,
                 n_train_per_class=1200, n_test_per_class=300, class_names=None):
    """Train an MLP on class-sorted representations and report test accuracy per class.

    Rows of both arrays must be grouped by class in order: the first
    `n_train_per_class` (resp. `n_test_per_class`) rows are class 0, the next
    block is class 1, and so on.

    Args:
        train_reps, test_reps: 2-D numpy arrays of representations.
        num_classes: number of classes.
        layers, epochs: forwarded to ``train``.
        n_train_per_class, n_test_per_class: rows per class in each array.
        class_names: names used in the result keys; defaults to the notebook's
            ``wf_interest_dy``.

    Returns:
        dict mapping 'wf <name>' to that class's test accuracy in percent.
    """
    if class_names is None:
        class_names = wf_interest_dy
    labels_train = np.repeat(np.arange(num_classes), n_train_per_class)
    labels_test = np.repeat(np.arange(num_classes), n_test_per_class)

    mlp = train(train_reps, labels_train, layers, epochs=epochs)
    # One forward pass over all scored test rows (the MLP has no batch-dependent layers).
    n_scored = num_classes * n_test_per_class
    with torch.no_grad():
        logits = mlp(torch.from_numpy(test_reps[:n_scored]).float()).numpy()
    pred_labels = np.argmax(logits, axis=1)

    per_class_acc = {}
    for i in range(num_classes):
        block = slice(i * n_test_per_class, (i + 1) * n_test_per_class)
        class_score = accuracy_score(labels_test[block], pred_labels[block]) * 100
        per_class_acc['wf {}'.format(class_names[i])] = class_score
    return per_class_acc
        
def avg_score(train_reps, test_reps, num_classes, layers=[1000, 100, 10], epochs=100,
              n_train_per_class=1200, n_test_per_class=300):
    """Train an MLP on class-sorted representations and report overall test accuracy.

    Rows of both arrays must be grouped by class in equal consecutive blocks of
    `n_train_per_class` (train) and `n_test_per_class` (test) rows.

    Args:
        train_reps, test_reps: 2-D numpy arrays of representations.
        num_classes: number of classes.
        layers, epochs: forwarded to ``train``.
        n_train_per_class, n_test_per_class: rows per class in each array.

    Returns:
        {'score': test accuracy in percent}
    """
    labels_train = np.repeat(np.arange(num_classes), n_train_per_class)
    labels_test = np.repeat(np.arange(num_classes), n_test_per_class)

    # Pass the caller's `epochs` through (it was previously ignored in favour of 100).
    mlp = train(train_reps, labels_train, layers, epochs=epochs)
    with torch.no_grad():
        pred_labels = np.argmax(mlp(torch.from_numpy(test_reps).float()).numpy(), axis=1)
    # accuracy_score(y_true, y_pred)
    return {'score': accuracy_score(labels_test, pred_labels) * 100}

In [38]:
def per_class_accs(train_reps, test_reps, models, num_classes):
    """Per-class MLP accuracies for each model's (train, test) representation pair.

    Returns a dict keyed by the entries of `models`, each value being the
    per-class accuracy dict from class_scores.
    """
    return {
        models[idx]: class_scores(train_reps[idx], test_reps[idx], num_classes)
        for idx in range(len(train_reps))
    }

def avg_class_accs(train_reps, test_reps, models, num_classes):
    """Overall MLP test accuracy for each model's (train, test) representation pair.

    Returns a dict mapping each entry of `models` to avg_score's {'score': ...} dict.
    """
    results = {}
    for idx, train_rep in enumerate(train_reps):
        results[models[idx]] = avg_score(train_rep, test_reps[idx], num_classes)
    return results

In [43]:
def get_enc_backbone(enc):
    """Drop the final layer of `enc`'s last child module and store the rest as `enc.fcpart`.

    For the encoders in this notebook the last child is ``fcpart`` and its final
    layer is the Projector. Mutates and returns `enc`.
    """
    children = list(enc.children())
    head = children[-1]
    layers = list(head.children())
    enc.fcpart = nn.Sequential(*layers[:-1])
    return enc

def get_fcenc_backbone(enc):
    """Strip the last layer (the projection head) from `enc`'s final child; store the rest in `enc.fcpart`.

    Mutates and returns `enc`.
    """
    final_child = list(enc.children())[-1]
    trimmed = [layer for layer in final_child.children()][:-1]
    enc.fcpart = nn.Sequential(*trimmed)
    return enc

def get_ckpt_results(ckpt, lat_dim, train_data, test_data, plot=False, wfs=None, wfs_interest=None, title=None, enc_type=None, Lv=None, ks=None, fc=None, save_name=None):
    """Load a contrastive checkpoint, embed train/test waveforms, and compute PCA baselines.

    Args:
        ckpt: checkpoint path passed to the encoder's ``load``.
        lat_dim: representation size (encoder ``out_size``) and number of PCA components.
        train_data, test_data: 2-D arrays, one waveform per row.
        plot: if True, call plot_contr_v_pca with `wfs`, `wfs_interest`, `title`
            and `save_name`.
        enc_type: None or 'encoder', 'fc_encoder', 'custom_encoder2', 'attention_encoder'.
        Lv, ks: architecture overrides (per-type defaults); ignored for 'attention_encoder'.
        fc: MLP head depth for 'attention_encoder' (default 2).

    Returns:
        (contr_reps_train, contr_reps_test, contr_reps_test_pca, pca_tr, pca_test).
        contr_reps_test_pca is a 2-D PCA of the test representations when
        lat_dim > 2, otherwise the test representations themselves.
        pca_tr / pca_test are lat_dim-component projections of the raw data from
        a single PCA fitted on train_data.

    Raises:
        ValueError: for an unknown `enc_type`.
    """
    if enc_type is None or enc_type == 'encoder':
        Lv = [200, 150, 100, 75] if Lv is None else Lv
        ks = [11, 21, 31] if ks is None else ks
        enc = Encoder(Lv=Lv, ks=ks, out_size=lat_dim).load(ckpt)
        backbone = get_enc_backbone(enc)
    elif enc_type == 'fc_encoder':
        Lv = [121, 550, 1100, 250] if Lv is None else Lv
        enc = FullyConnectedEnc(Lv=Lv, out_size=lat_dim).load(ckpt)
        backbone = get_fcenc_backbone(enc)
    elif enc_type == 'custom_encoder2':
        Lv = [64, 128, 256, 256, 256] if Lv is None else Lv
        ks = [11] if ks is None else ks
        enc = Encoder2(Lv=Lv, ks=ks, out_size=lat_dim).load(ckpt)
        backbone = get_enc_backbone(enc)
    elif enc_type == 'attention_encoder':
        fc_depth = 2 if fc is None else fc
        enc = AttentionEnc(out_size=lat_dim, proj_dim=5, fc_depth=fc_depth, dropout=0.1, expand_dim=16).load(ckpt)
        backbone = get_fcenc_backbone(enc)
    else:
        # Fail clearly instead of an UnboundLocalError on `backbone` below.
        raise ValueError(
            f"Unknown enc_type {enc_type!r}; expected None, 'encoder', 'fc_encoder', "
            "'custom_encoder2' or 'attention_encoder'"
        )

    contr_reps_train = compute_reps_test(backbone, train_data)
    contr_reps_test = compute_reps_test(backbone, test_data)

    if lat_dim > 2:
        contr_reps_test_pca, _ = pca(contr_reps_test, 2)
    else:
        contr_reps_test_pca = contr_reps_test

    # One PCA fit on the raw training waveforms, applied to both splits, so the
    # train and test baselines share the same axes (and the fit is done once).
    raw_pca = PCA(n_components=lat_dim).fit(train_data)
    pca_tr = raw_pca.transform(train_data)
    pca_test = raw_pca.transform(test_data)

    if plot:
        plot_contr_v_pca(pca_test, contr_reps_test_pca, wfs, wfs_interest, title=title, save_name=save_name)

    return contr_reps_train, contr_reps_test, contr_reps_test_pca, pca_tr, pca_test

In [44]:
def plot_contr_v_pca(pca_reps, contr_reps, wfs, wf_interest, title=None, save_name=None, wf_selection=None):
    """Side-by-side scatter of PCA vs contrastive representations, with unit waveforms below.

    Rows of `pca_reps` / `contr_reps` must be grouped by unit: equal consecutive
    blocks, one per entry of `wf_interest`, each drawn in its own colour.

    Args:
        pca_reps, contr_reps: (n_points, >=2) arrays; only columns 0 and 1 are plotted.
        wfs: waveforms indexed by unit; ``wfs[wf_interest]`` is drawn in the bottom row.
        wf_interest: list of unit indices (waveforms and panel titles).
        title: optional figure suptitle.
        save_name: if given, the figure is saved to this path.
        wf_selection: optional (start, stop) slice into the 10-colour palette.
    """
    palette = ['blue', 'red', 'green', 'yellow', 'orange', 'black', 'cyan', 'violet', 'maroon', 'pink']
    og_wfs = wfs[wf_interest]
    num_wfs = len(og_wfs)

    # One y-range shared by every waveform panel.
    max_chan_max = np.max(og_wfs)
    max_chan_min = np.min(og_wfs)

    if wf_selection is None:
        colors = palette[:num_wfs]
    else:
        colors = palette[wf_selection[0]:wf_selection[1]]
    # One colour per unit, repeated over that unit's block of points.
    num_reps = len(pca_reps) // num_wfs
    labels = np.repeat([colors[i] for i in range(num_wfs)], num_reps)

    fig = plt.figure(figsize=(12, 8), constrained_layout=True)
    gs = GridSpec(4, num_wfs, figure=fig)
    half = num_wfs // 2

    ax0 = fig.add_subplot(gs[:3, :half])
    ax0.title.set_text('PCA wf representations')
    ax0.scatter(pca_reps[:, 0], pca_reps[:, 1], c=labels, clip_on=False)

    ax1 = fig.add_subplot(gs[:3, half:])
    ax1.title.set_text('Contrastive wf representations')
    ax1.scatter(contr_reps[:, 0], contr_reps[:, 1], c=labels, clip_on=True)

    x = np.arange(og_wfs.shape[1])
    for i in range(num_wfs):
        ax = fig.add_subplot(gs[3, i])
        ax.set_ylim(max_chan_min - 0.5, max_chan_max + 0.5)
        ax.title.set_text('unit {}'.format(wf_interest[i]))
        ax.plot(x, og_wfs[i], linewidth=2, markersize=12, color=colors[i])
        ax.get_xaxis().set_visible(False)

    fig.suptitle(title)

    if save_name is not None:
        fig.savefig(save_name)
        
def plot_recon_v_spike(wf_train, wf_test, wfs, wf_interest, ckpt, lat_dim, title, save_name=None, wf_selection=None):
    """Compare contrastive representations of raw test spikes with those of their PCA reconstructions.

    Every spike in `wf_train` / `wf_test` is passed through ``PCA_Reproj()``
    (defined elsewhere; presumably a PCA project-and-reconstruct transform —
    TODO confirm). Both the raw and the reconstructed sets are embedded with the
    checkpoint via get_ckpt_results. The left panel shows the raw test spikes'
    2-D contrastive representations and the right panel the reconstructions';
    the bottom row shows ``wfs[wf_interest]``.

    `wf_test` rows must be grouped by unit in equal consecutive blocks, one per
    entry of `wf_interest`.

    Args:
        wf_train, wf_test: 2-D arrays of spikes, one per row.
        wfs: waveforms indexed by unit.
        wf_interest: list of unit indices to display.
        ckpt, lat_dim: forwarded to get_ckpt_results.
        title: figure suptitle.
        save_name: if given, the figure is saved to this path.
        wf_selection: optional (start, stop) slice into the 10-colour palette.
    """
    palette = ['blue', 'red', 'green', 'yellow', 'orange', 'black', 'cyan', 'violet', 'maroon', 'pink']
    og_wfs = wfs[wf_interest]
    num_wfs = len(og_wfs)

    pca_aug = PCA_Reproj()
    # Named recon_* so they don't read like the module-level pca_train function.
    recon_train = np.array([pca_aug(wf) for wf in wf_train])
    recon_test = np.array([pca_aug(wf) for wf in wf_test])

    _, _, contr_spikes_test_pca, _, _ = get_ckpt_results(ckpt, lat_dim, wf_train, wf_test)
    _, _, contr_recon_test_pca, _, _ = get_ckpt_results(ckpt, lat_dim, recon_train, recon_test)

    # One y-range shared by every waveform panel.
    max_chan_max = np.max(og_wfs)
    max_chan_min = np.min(og_wfs)

    if wf_selection is None:
        colors = palette[:num_wfs]
    else:
        colors = palette[wf_selection[0]:wf_selection[1]]
    num_reps = len(wf_test) // num_wfs
    labels = np.repeat([colors[i] for i in range(num_wfs)], num_reps)

    fig = plt.figure(figsize=(12, 8), constrained_layout=True)
    gs = GridSpec(4, num_wfs, figure=fig)
    half = num_wfs // 2

    ax0 = fig.add_subplot(gs[:3, :half])
    ax0.title.set_text('Contrastive spike representations')
    ax0.scatter(contr_spikes_test_pca[:, 0], contr_spikes_test_pca[:, 1], c=labels, clip_on=False)

    ax1 = fig.add_subplot(gs[:3, half:])
    ax1.title.set_text('Contrastive pca recon. spike representations')
    ax1.scatter(contr_recon_test_pca[:, 0], contr_recon_test_pca[:, 1], c=labels, clip_on=True)

    x = np.arange(og_wfs.shape[1])
    for i in range(num_wfs):
        ax = fig.add_subplot(gs[3, i])
        ax.set_ylim(max_chan_min - 0.5, max_chan_max + 0.5)
        ax.title.set_text('unit {}'.format(wf_interest[i]))
        ax.plot(x, og_wfs[i], linewidth=2, markersize=12, color=colors[i])
        ax.get_xaxis().set_visible(False)

    fig.suptitle(title)

    if save_name is not None:
        fig.savefig(save_name)

In [45]:
def plot_aug_shifts(wf_train, wf_test, ckpt, lat_dim, title, save_name=None, spike_sel=None):
    """Show how each augmentation perturbs one test spike and where the copies embed.

    One spike is taken from ``wf_test`` and ``num_ex`` (10) augmented copies are
    made for each of four augmentations: amplitude scaling (0.9x-1.1x),
    ``Jitter``, ``Collide`` and ``SmartNoise``. Each augmentation occupies two
    figure rows. The first five columns draw its 10 copies (5 per row, with
    opacity rising with copy index). The right-hand panels scatter the first two
    coordinates of the copies' outputs from ``get_ckpt_results``, titled
    "Contrastive 2D PCA" and "PCA 2D" representations.

    Args:
        wf_train: training waveforms, forwarded to ``get_ckpt_results``.
        wf_test: 2-D array ``(n_spikes, n_times)`` to draw the example spike from.
        ckpt: checkpoint path, forwarded to ``get_ckpt_results``.
        lat_dim: latent dimension, forwarded to ``get_ckpt_results``.
        title: figure super-title.
        save_name: if not None, the figure is written to this path.
        spike_sel: row of ``wf_test`` to augment. Defaults to
            ``np.random.choice`` over all rows (the previous behavior). Pass an
            int for a reproducible figure.
    """
    n_spikes, n_times = wf_test.shape
    if spike_sel is None:
        spike_sel = np.random.choice(n_spikes)
    spike = wf_test[spike_sel]
    num_ex = 10   # augmented copies per augmentation
    per_row = 5   # copies drawn per figure row -> 2 rows per augmentation

    jit = Jitter()
    collide = Collide()
    noise = SmartNoise()

    amp_jitter_spikes = np.array([scale * spike for scale in np.linspace(0.9, 1.1, num=num_ex)])
    jitter_spikes = np.array([jit(spike) for _ in range(num_ex)])
    collided_spikes = np.array([collide(spike) for _ in range(num_ex)])
    noised_spikes = np.array([noise(spike) for _ in range(num_ex)])
    # (n_augs, num_ex, n_times) -- assumes each augmentation preserves spike length.
    aug_spikes = np.array([amp_jitter_spikes, jitter_spikes, collided_spikes, noised_spikes])
    aug_titles = ['Amplitude Jitter', 'Jitter', 'Collision', 'Noise']
    n_augs = len(aug_titles)

    # One waveform per row. Use the data's own length rather than a hard-coded
    # 121 so other window sizes work.
    _, _, contr_reps_test_pca, _, pca_test = get_ckpt_results(
        ckpt, lat_dim, wf_train, aug_spikes.reshape(-1, n_times))
    contr_reps_test_pca = contr_reps_test_pca.reshape(n_augs, num_ex, -1)
    pca_test = pca_test.reshape(n_augs, num_ex, -1)

    colors = ['blue', 'red', 'green', 'magenta']  # one colour per augmentation
    alphas = np.linspace(0.1, 1, num=num_ex)       # copy index -> opacity
    x = np.arange(0, n_times)

    # Enlarged fonts for this figure only. plt.rc() would leak into every later plot.
    fig_rc = {'axes.titlesize': 16, 'axes.labelsize': 12, 'figure.titlesize': 20}
    with plt.rc_context(fig_rc):
        # Plain figure. The former constrained_layout=True was discarded by the
        # tight_layout()/subplots_adjust() calls anyway (matplotlib warns and
        # disables it), so the manual subplots_adjust below is what applies.
        fig = plt.figure(figsize=(22, 18))
        # Columns 0-4: waveforms, 5-6: contrastive scatter, 7-8: PCA scatter.
        gs = GridSpec(2 * n_augs, 9, figure=fig)

        for row in range(2 * n_augs):
            aug = row // 2
            first_ex = (row % 2) * per_row  # copy index shown in column 0 of this row

            lead_ax = None
            for col in range(per_row):
                ex = first_ex + col
                ax = fig.add_subplot(gs[row, col], sharey=lead_ax)
                ax.plot(x, aug_spikes[aug, ex], linewidth=2, color=colors[aug], alpha=alphas[ex])
                ax.get_xaxis().set_visible(False)
                if lead_ax is None:
                    # Column 0 keeps its y-axis; the other columns share it and hide theirs.
                    lead_ax = ax
                    if row % 2 == 0:
                        ax.title.set_text(aug_titles[aug])
                else:
                    ax.get_yaxis().set_visible(False)

            if row % 2 == 0:
                # Scatter panels span both rows belonging to this augmentation.
                ax_contr = fig.add_subplot(gs[row:row + 2, 5:7])
                if row == 0:
                    ax_contr.title.set_text('Contrastive 2D PCA Representations')
                ax_contr.scatter(contr_reps_test_pca[aug, :, 0], contr_reps_test_pca[aug, :, 1],
                                 color=colors[aug], alpha=alphas)
                ax_contr.set_xlim([-1, 1])
                ax_contr.set_ylim([-1, 1])

                ax_pca = fig.add_subplot(gs[row:row + 2, 7:])
                if row == 0:
                    ax_pca.title.set_text('PCA 2D Representations')
                ax_pca.scatter(pca_test[aug, :, 0], pca_test[aug, :, 1],
                               color=colors[aug], alpha=alphas)
                ax_pca.set_xlim([-25, 25])
                ax_pca.set_ylim([-20, 20])

        fig.suptitle(title)
        fig.subplots_adjust(top=0.94, hspace=0.4)

        if save_name is not None:
            fig.savefig(save_name)
        

def plot_recon_v_spikes_indiv(wf_train, wf_test, ckpt, lat_dim, pca_dim, title, enc_type=None, Lv=None, ks=None, save_name=None):
    """Plot 10 test spikes and compare their embeddings with those of their PCA reconstructions.

    Ten spikes are taken from ``wf_test`` at indices ``15 + 300*k`` (k = 0..9) and
    drawn on a 2x5 grid, one colour each. Every waveform in ``wf_train`` and
    ``wf_test`` is passed through ``PCA_Reproj(pca_dim=pca_dim)`` to obtain
    "reconstructed" spikes. ``get_ckpt_results`` is then run on both the raw and
    the reconstructed data. The right-hand panels scatter the first two
    coordinates of the selected spikes (dots) and of their reconstructions
    (crosses), in matching colours, once for the contrastive output and once for
    the PCA output.

    Args:
        wf_train: training waveforms.
        wf_test: 2-D array ``(n_spikes, n_times)`` with at least 2716 rows, since
            the largest selected index is 2715.
        ckpt: checkpoint path, forwarded to ``get_ckpt_results``.
        lat_dim: latent dimension, forwarded to ``get_ckpt_results``.
        pca_dim: number of components passed to ``PCA_Reproj``.
        title: figure super-title.
        enc_type, Lv, ks: forwarded unchanged to ``get_ckpt_results``.
        save_name: if not None, the figure is written to this path.
    """
    _, n_times = wf_test.shape
    # NOTE(review): the stride of 300 presumably steps through blocks of 300
    # spikes per unit in wf_test -- confirm against how wf_test was built.
    spike_sels = np.array([15 + 300 * ind for ind in range(10)])
    spikes = wf_test[spike_sels]

    pca_aug = PCA_Reproj(pca_dim=pca_dim)
    pca_train = np.array([pca_aug(wf) for wf in wf_train])
    pca_test = np.array([pca_aug(wf) for wf in wf_test])

    ckpt_kwargs = dict(enc_type=enc_type, Lv=Lv, ks=ks)
    _, _, contr_spikes_test_pca, _, pca_spikes_test = get_ckpt_results(
        ckpt, lat_dim, wf_train, wf_test, **ckpt_kwargs)
    contr_spikes_test_pca = contr_spikes_test_pca[spike_sels]
    pca_spikes_test = pca_spikes_test[spike_sels]

    _, _, contr_recon_test_pca, _, pca_recon_test = get_ckpt_results(
        ckpt, lat_dim, pca_train, pca_test, **ckpt_kwargs)
    contr_recon_test_pca = contr_recon_test_pca[spike_sels]
    pca_recon_test = pca_recon_test[spike_sels]

    colors = ['blue', 'red', 'green', 'yellow', 'orange', 'black', 'cyan', 'violet', 'maroon', 'pink']
    alphas = np.ones(len(colors))
    per_row = 5  # spikes drawn per figure row (2 rows x 5 = 10 spikes)
    x = np.arange(0, n_times)

    # Enlarged fonts for this figure only. plt.rc() would leak into every later plot.
    fig_rc = {'axes.titlesize': 16, 'axes.labelsize': 12, 'figure.titlesize': 20}
    with plt.rc_context(fig_rc):
        # Plain figure. The former constrained_layout=True was discarded by the
        # tight_layout()/subplots_adjust() calls anyway, so the manual
        # subplots_adjust below is what applies.
        fig = plt.figure(figsize=(22, 5))
        gs = GridSpec(2, 9, figure=fig)  # cols 0-4 spikes, 5-6 contrastive, 7-8 PCA

        for row in range(2):
            lead_ax = None
            for col in range(per_row):
                idx = row * per_row + col
                ax = fig.add_subplot(gs[row, col], sharey=lead_ax)
                ax.plot(x, spikes[idx], linewidth=2, color=colors[idx], alpha=alphas[idx])
                ax.get_xaxis().set_visible(False)
                if lead_ax is None:
                    # Column 0 keeps its y-axis; the other columns share it and hide theirs.
                    lead_ax = ax
                else:
                    ax.get_yaxis().set_visible(False)

        # Both scatter panels span the two spike rows.
        ax_contr = fig.add_subplot(gs[:, 5:7])
        ax_contr.title.set_text('Contrastive Representations')
        ax_contr.scatter(contr_spikes_test_pca[:, 0], contr_spikes_test_pca[:, 1],
                         color=colors, alpha=alphas, label='spike')
        ax_contr.scatter(contr_recon_test_pca[:, 0], contr_recon_test_pca[:, 1],
                         color=colors, alpha=alphas, marker='x', label='recon. spike')
        ax_contr.legend()

        ax_pca = fig.add_subplot(gs[:, 7:])
        ax_pca.title.set_text('PCA Representations')
        ax_pca.scatter(pca_spikes_test[:, 0], pca_spikes_test[:, 1],
                       color=colors, alpha=alphas, label='spike')
        ax_pca.scatter(pca_recon_test[:, 0], pca_recon_test[:, 1],
                       color=colors, alpha=alphas, marker='x', label='recon. spike')
        ax_pca.set_xlim([-25, 25])
        ax_pca.set_ylim([-20, 20])
        ax_pca.legend()

        fig.suptitle(title)
        fig.subplots_adjust(top=0.85, hspace=0.2)

        if save_name is not None:
            fig.savefig(save_name)

        
def plot_recon_v_spikes_indiv_pca(wf_train, wf_test, ckpt, lat_dim, pca_dim, title, save_name=None, enc_type=None, Lv=None, ks=None):
    """Plot 10 test spikes and compare the PCA embeddings of spikes vs. their PCA reconstructions.

    Same selection and layout as ``plot_recon_v_spikes_indiv`` (indices
    ``15 + 300*k``, k = 0..9, drawn on a 2x5 grid), but with a single scatter
    panel. It shows the first two coordinates of the fifth output of
    ``get_ckpt_results`` for the raw spikes (dots) and for their
    ``PCA_Reproj(pca_dim)`` reconstructions (crosses).

    Args:
        wf_train: training waveforms.
        wf_test: 2-D array ``(n_spikes, n_times)`` with at least 2716 rows.
        ckpt: checkpoint path, forwarded to ``get_ckpt_results``.
        lat_dim: latent dimension, forwarded to ``get_ckpt_results``.
        pca_dim: number of components passed to ``PCA_Reproj``.
        title: figure super-title.
        save_name: if not None, the figure is written to this path.
        enc_type, Lv, ks: optional, forwarded to ``get_ckpt_results`` only when
            not None, so existing calls behave exactly as before. Needed for
            checkpoints such as the attention-encoder runs, which are loaded
            with ``enc_type='attention_encoder'``.
    """
    _, n_times = wf_test.shape
    # NOTE(review): the stride of 300 presumably steps through blocks of 300
    # spikes per unit in wf_test -- confirm against how wf_test was built.
    spike_sels = np.array([15 + 300 * ind for ind in range(10)])
    spikes = wf_test[spike_sels]

    pca_aug = PCA_Reproj(pca_dim=pca_dim)
    pca_train = np.array([pca_aug(wf) for wf in wf_train])
    pca_test = np.array([pca_aug(wf) for wf in wf_test])

    # Only forward the optional encoder settings that were actually supplied,
    # leaving get_ckpt_results' own defaults in force otherwise.
    ckpt_kwargs = {name: val for name, val in (('enc_type', enc_type), ('Lv', Lv), ('ks', ks))
                   if val is not None}
    # Only the fifth (PCA) output is plotted here; the contrastive outputs are discarded.
    _, _, _, _, pca_spikes_test = get_ckpt_results(ckpt, lat_dim, wf_train, wf_test, **ckpt_kwargs)
    pca_spikes_test = pca_spikes_test[spike_sels]

    _, _, _, _, pca_recon_test = get_ckpt_results(ckpt, lat_dim, pca_train, pca_test, **ckpt_kwargs)
    pca_recon_test = pca_recon_test[spike_sels]

    colors = ['blue', 'red', 'green', 'yellow', 'orange', 'black', 'cyan', 'violet', 'maroon', 'pink']
    alphas = np.ones(len(colors))
    per_row = 5  # spikes drawn per figure row (2 rows x 5 = 10 spikes)
    x = np.arange(0, n_times)

    # Enlarged fonts for this figure only. plt.rc() would leak into every later plot.
    fig_rc = {'axes.titlesize': 16, 'axes.labelsize': 12, 'figure.titlesize': 20}
    with plt.rc_context(fig_rc):
        # Plain figure. The former constrained_layout=True was discarded by the
        # tight_layout()/subplots_adjust() calls anyway, so the manual
        # subplots_adjust below is what applies.
        fig = plt.figure(figsize=(17, 5))
        gs = GridSpec(2, 7, figure=fig)  # cols 0-4 spikes, 5-6 scatter

        for row in range(2):
            lead_ax = None
            for col in range(per_row):
                idx = row * per_row + col
                ax = fig.add_subplot(gs[row, col], sharey=lead_ax)
                ax.plot(x, spikes[idx], linewidth=2, color=colors[idx], alpha=alphas[idx])
                ax.get_xaxis().set_visible(False)
                if lead_ax is None:
                    # Column 0 keeps its y-axis; the other columns share it and hide theirs.
                    lead_ax = ax
                else:
                    ax.get_yaxis().set_visible(False)

        # The scatter panel spans both spike rows.
        ax_pca = fig.add_subplot(gs[:, 5:7])
        ax_pca.title.set_text('2D Spike v. Reconstructed Spike Representations')
        ax_pca.scatter(pca_spikes_test[:, 0], pca_spikes_test[:, 1],
                       color=colors, alpha=alphas, label='spike')
        ax_pca.scatter(pca_recon_test[:, 0], pca_recon_test[:, 1],
                       color=colors, alpha=alphas, marker='x', label='recon. spike')
        ax_pca.set_xlim([-25, 25])
        ax_pca.set_ylim([-20, 20])
        ax_pca.legend()

        fig.suptitle(title)
        fig.subplots_adjust(top=0.85, hspace=0.2)

        if save_name is not None:
            fig.savefig(save_name)



## Plot results

In [46]:
# Evaluate the 5-D attention-encoder (transformer) runs against a 5-D PCA baseline.
# NOTE(review): absolute, machine-specific path -- move it to the config cell
# (ideally relative to a configurable data directory) so the notebook runs elsewhere.
ATT_RUNS_DIR = '/Users/ankit/Documents/PaninskiLab/SimCLR-torch/runs/new_augs/spike_att_tb/'
ATT_CKPT_FILE = 'checkpoint_0500.pth.tar'

fived_norm_attenc001_path = ATT_RUNS_DIR + 'attenc-5d5d-lr001/' + ATT_CKPT_FILE
fived_norm_attenc0001_fc4_path = ATT_RUNS_DIR + 'attenc-5d5d-lr0001-fc4/' + ATT_CKPT_FILE
fived_norm_attenc0001_path = ATT_RUNS_DIR + 'attenc-5d5d-lr0001/' + ATT_CKPT_FILE
fived_norm_attenc01_sgd_path = ATT_RUNS_DIR + 'attenc-5d5d-lr01-sgd/' + ATT_CKPT_FILE
fived_norm_attenc01_sgdfc4_path = ATT_RUNS_DIR + 'attenc-5d5d-lr01-sgd-fc4/' + ATT_CKPT_FILE
fived_norm_attenc001_sgd_path = ATT_RUNS_DIR + 'attenc-5d5d-lr001-sgd/' + ATT_CKPT_FILE
fived_norm_attenc001_sgdfc4_path = ATT_RUNS_DIR + 'attenc-5d5d-lr001-sgd-fc4/' + ATT_CKPT_FILE

# One row per run: (result key, checkpoint, descriptor, fc, save_name).
# The descriptor is used in both the plot title and the results-table row, so
# the two cannot drift apart. fc is None when the run did not set it.
ATT_RUNS = [
    ('att001',        fived_norm_attenc001_path,        'augs=a/j/n, transformer, lr=0.001',           None, 'pca_v_5dcont_normal_attenc_lr001'),
    ('att0001_fc4',   fived_norm_attenc0001_fc4_path,   'augs=a/j/n, transformer, lr=0.0001, fc4',     4,    'pca_v_5dcont_normal_attenc_lr0001fc4'),
    ('att0001',       fived_norm_attenc0001_path,       'augs=a/j/n, transformer, lr=0.0001',          None, 'pca_v_5dcont_normal_attenc_lr0001'),
    ('att01_sgd',     fived_norm_attenc01_sgd_path,     'augs=a/j/n, transformer, lr=0.01, sgd',       None, 'pca_v_5dcont_normal_attenc_lr01_sgd'),
    ('att01_sgdfc4',  fived_norm_attenc01_sgdfc4_path,  'augs=a/j/n, transformer, lr=0.01, sgd, fc4',  4,    'pca_v_5dcont_normal_attenc_lr01_sgdfc4'),
    ('att001_sgd',    fived_norm_attenc001_sgd_path,    'augs=a/j/n, transformer, lr=0.001, sgd',      None, 'pca_v_5dcont_normal_attenc_lr001_sgd'),
    ('att001_sgdfc4', fived_norm_attenc001_sgdfc4_path, 'augs=a/j/n, transformer, lr=0.001, sgd, fc4', 4,    'pca_v_5dcont_normal_attenc_lr001_sgdfc4'),
]

att_results = {}
for run_key, ckpt_path, run_desc, run_fc, run_save_name in ATT_RUNS:
    # Pass fc only for runs that used it, exactly as the per-run calls did.
    fc_kwargs = {} if run_fc is None else {'fc': run_fc}
    att_results[run_key] = get_ckpt_results(
        ckpt=ckpt_path, lat_dim=5, train_data=dy_wfs_interest, test_data=dy_wfs_test,
        plot=True, wfs=dy_wfs, wfs_interest=wf_interest_dy,
        title='5D contrastive model ({}) (2D PCA reps) - test set'.format(run_desc),
        enc_type='attention_encoder', save_name=run_save_name, **fc_kwargs)

# Keep the per-run variable names that other cells may reference.
# The PCA baseline is taken from the first run.
(contr_reps_train_5d_att001, contr_reps_test_5d_att001, contr_reps_test_5d_pca_att001,
 pca_train_5d, pca_test_5d) = att_results['att001']
contr_reps_train_5d_att0001_fc4, contr_reps_test_5d_att0001_fc4, contr_reps_test_5d_pca_att0001_fc4, _, _ = att_results['att0001_fc4']
contr_reps_train_5d_att0001, contr_reps_test_5d_att0001, contr_reps_test_5d_pca_att0001, _, _ = att_results['att0001']
contr_reps_train_5d_att01_sgd, contr_reps_test_5d_att01_sgd, contr_reps_test_5d_pca_att01_sgd, _, _ = att_results['att01_sgd']
contr_reps_train_5d_att01_sgdfc4, contr_reps_test_5d_att01_sgdfc4, contr_reps_test_5d_pca_att01_sgdfc4, _, _ = att_results['att01_sgdfc4']
contr_reps_train_5d_att001_sgd, contr_reps_test_5d_att001_sgd, contr_reps_test_5d_pca_att001_sgd, _, _ = att_results['att001_sgd']
contr_reps_train_5d_att001_sgdfc4, contr_reps_test_5d_att001_sgdfc4, contr_reps_test_5d_pca_att001_sgdfc4, _, _ = att_results['att001_sgdfc4']

# Baseline first, then the runs in ATT_RUNS order.
run_keys = [run[0] for run in ATT_RUNS]
train_reps = [pca_train_5d] + [att_results[key][0] for key in run_keys]
test_reps = [pca_test_5d] + [att_results[key][1] for key in run_keys]
model_names = ['PCA 5D'] + ['Contrastive 5D ({})'.format(run[2]) for run in ATT_RUNS]

# NOTE(review): 10 presumably is the number of units/classes
# (len(wf_interest_dy) == 10) -- confirm against per_class_accs.
per_class_map = per_class_accs(train_reps, test_reps, model_names, 10)
avg_class_map = avg_class_accs(train_reps, test_reps, model_names, 10)

pc_df = pd.DataFrame.from_dict(per_class_map, 'index')
ac_df = pd.DataFrame.from_dict(avg_class_map)

with pd.option_context('display.max_rows', None, 'display.max_columns', None):
    display(pc_df)
    display(ac_df)

Using projector; batchnorm False with depth 3; hidden_dim=512


KeyboardInterrupt: 