In [None]:
import numpy as np
import retinapy.mea as mea
import plotly.express as px
import plotly.graph_objects as go
import plotly.subplots
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [None]:
# Display colour per channel. Row order matches channel index: kernel_plot
# looks a channel's colour up with colormap.loc[channel].
colormap = pd.DataFrame(
    [
        ('Red', '#ff0a0a'),
        ('Green', '#0aff0a'),
        ('UV', '#0a0aff'),
        ('Blue', '#303030'),
        ('Stim', '#0a0a0a'),
    ],
    columns=['names', 'display_hex'],
)
colormap

## Data

In [None]:
# Load recording: stimulus pattern, recorded stimulus and spike responses
# (presumably full-field "ff" noise, judging by the file names) are combined
# into one recording, then decompressed with a downsample factor of 18.
# NOTE(review): the .pickle loaders presumably unpickle; only use trusted files.
rec_name = 'Chicken_17_08_21_Phase_00'
stimulus_pattern = mea.load_stimulus_pattern('../data/ff_noise.h5')
recorded_stimulus = mea.load_recorded_stimulus('../data/ff_recorded_noise.pickle')
spike_response = mea.load_response('../data/ff_spike_response.pickle')
rec = mea.decompress_recording(
    mea.single_3brain_recording(
        rec_name, stimulus_pattern, recorded_stimulus, spike_response
    ),
    downsample=18,
)
print(rec)

## Model

In [None]:
class Block(nn.Module):
    """Conv1d -> BatchNorm1d -> ReLU, with an optional residual connection.

    The convolution has kernel_size=10 and no padding, so the output is
    ``9 * dilation`` samples shorter than the input.

    Args:
        in_n: number of input channels.
        out_n: number of output channels. Must equal ``in_n`` when
            ``residual`` is True, because the input is added to the output.
        residual: add the (cropped) input to the output.
        dilation: dilation of the convolution.
    """

    def __init__(self, in_n, out_n, residual=True, dilation=1):
        super().__init__()
        self.conv1 = nn.Conv1d(in_n, out_n, kernel_size=10, stride=1, padding=0, dilation=dilation)
        self.bn = nn.BatchNorm1d(out_n)
        self.residual = residual

    def forward(self, x):
        x_p = F.relu(self.bn(self.conv1(x)))
        if self.residual:
            # The unpadded conv shortens the sequence, so crop the input to the
            # output length before adding. NOTE(review): this keeps the *first*
            # samples of x; aligning to the end of each window may be intended.
            x_p = x_p + x[:, :, 0:x_p.shape[-1]]
        return x_p


class GanglionAsCNN(nn.Module):
    """Dilated 1-D CNN mapping a multichannel sequence to one logit per sample.

    The fixed layer stack in ``network`` shrinks its input by 999 samples, so
    each output depends on 1000 consecutive input samples. ``receptive_len``
    must match this (the default 1000 does); ``forward`` asserts it.

    Args:
        in_len: length of the full input sequence; only used for ``out_len``.
        receptive_len: input samples per output (must be 1000 for this stack).
        receptive_offset: only used to compute ``out_len``.
        inc_cluster: add one input channel on top of the LED_CHANNELS stimulus
            channels (the cell cluster's response).
        per_loop_win: ``forward`` processes the input in chunks of
            ``per_loop_win * receptive_len`` samples.
    """

    LED_CHANNELS = 4

    def __init__(self, in_len, receptive_len=1000, receptive_offset=1, inc_cluster=True,
                 per_loop_win=50):
        super().__init__()
        self.receptive_len = receptive_len
        self.in_len = in_len
        self.receptive_offset = receptive_offset
        self.n_features = 20
        self.n_fc_features = 40
        # Not read elsewhere in this class; kept as a public attribute.
        self.out_len = in_len - receptive_len - receptive_offset
        self.PER_LOOP_WIN = per_loop_win

        # Input is the LED stimulus and (optionally) the cell cluster's response.
        self.num_input_channels = self.LED_CHANNELS + int(inc_cluster)
        self.network = nn.Sequential(
            # Lengths below are for a 1000-sample input. Each Block (kernel 10,
            # no padding) removes 9 * dilation samples; each
            # MaxPool1d(kernel 2, stride 1) removes `dilation` samples.
            Block(self.num_input_channels, self.n_features, residual=False),
            # -> 991
            Block(self.n_features, self.n_features),
            # -> 982
            nn.MaxPool1d(kernel_size=2, stride=1, padding=0, dilation=1),
            # -> 981
            Block(self.n_features, self.n_features, dilation=2),
            # -> 963
            Block(self.n_features, self.n_features, dilation=2),
            # -> 945
            Block(self.n_features, self.n_features, dilation=2),
            # -> 927
            nn.MaxPool1d(kernel_size=2, stride=1, padding=0, dilation=2),
            # -> 925
            Block(self.n_features, self.n_features, dilation=4),
            # -> 889
            Block(self.n_features, self.n_features, dilation=4),
            # -> 853
            nn.MaxPool1d(kernel_size=2, stride=1, padding=0, dilation=4),
            # -> 849
            Block(self.n_features, self.n_features, dilation=8),
            # -> 777
            Block(self.n_features, self.n_features, dilation=8),
            # -> 705
            Block(self.n_features, self.n_features, dilation=8),
            # -> 633
            nn.MaxPool1d(kernel_size=2, stride=1, padding=0, dilation=8),
            # -> 625
            Block(self.n_features, self.n_features, dilation=16),
            # -> 481
            Block(self.n_features, self.n_fc_features, residual=False, dilation=16),
            # -> 337; the final conv spans 21 * 16 + 1 = 337 samples -> 1.
            nn.Conv1d(self.n_fc_features, 1, kernel_size=22, dilation=16),
            # No final activation: outputs are logits, as required by
            # binary_cross_entropy_with_logits in `loss`. (A ReLU here would
            # pin every predicted probability at >= 0.5.)
        )

    def forward(self, x):
        """Run the network over ``x`` in overlapping chunks.

        Args:
            x: tensor in Conv1d layout, presumably (batch, channels, length);
                chunking slices the last axis.

        Returns:
            Tensor of shape (batch, 1, length). Output ``j`` depends on
            ``x[..., j:j + receptive_len]``; the last ``receptive_len - 1``
            outputs include zero padding past the end of ``x``.
        """
        seq_len = x.shape[-1]
        per_loop_in = self.PER_LOOP_WIN * self.receptive_len
        shrink = self.receptive_len - 1
        # Chunks overlap by `shrink` samples, so their valid outputs tile the
        # sequence without gaps.
        step = per_loop_in - shrink
        ys = []
        for i in range(0, seq_len, step):
            x_part = x[:, :, i:i + per_loop_in]
            pad_len = per_loop_in - x_part.shape[-1]
            if pad_len:
                # Tail chunk(s): zero-pad up to a full chunk.
                x_part = F.pad(input=x_part, pad=(0, pad_len))
            y_part = self.network(x_part)
            assert y_part.shape[-1] == x_part.shape[-1] - shrink, f'Got ({y_part.shape}-{x_part.shape})'
            ys.append(y_part)
        y = torch.cat(ys, dim=-1)
        return y[:, :, 0:seq_len]

    def loss(self, pred_z, actual):
        """Compute the training loss and print accuracy metrics.

        Args:
            pred_z: logits from ``forward``; squeezed to 1-D (assumes a batch
                of 1 — TODO confirm).
            actual: 0/1 float target per sample, same length as ``pred_z``.

        Returns:
            BCE-with-logits over all samples, plus an extra BCE term on the
            spike samples (``actual == 1``). The extra term is skipped when
            there are no spikes; otherwise it would be a mean over an empty
            tensor and make the loss NaN.
        """
        pred_z = torch.squeeze(pred_z)
        # NOTE(review): pred_z[j] is computed from input samples
        # j..j+receptive_len-1. If the input includes the target's own spike
        # channel (inc_cluster=True), comparing with actual[j] lets the model
        # see the label. The disabled line below suggests a shifted target was
        # intended — confirm the alignment before trusting these metrics.
        #actual = actual[self.receptive_len + self.receptive_offset :]
        mask = actual == 1
        loss = F.binary_cross_entropy_with_logits(pred_z, actual)
        if mask.any():
            spike_logits = pred_z[mask]
            loss = loss + F.binary_cross_entropy_with_logits(spike_logits, torch.ones_like(spike_logits))
        pred = torch.round(torch.sigmoid(pred_z))
        accuracy = torch.sum(pred == actual) / actual.shape[0]
        spike_only_correct = torch.sum(pred[mask] == 1)
        # NaN when there are no spikes; only used for the printout.
        spike_only_accuracy = spike_only_correct / torch.sum(mask)
        print(
            f"loss: {loss.item():.4f} accuracy: {accuracy:.4f}, spike-only "
            f"accuracy: {spike_only_accuracy:.4f}"
        )
        return loss
    

## Train

In [None]:
def train_model(model, input_, spikes, epochs, learning_rate=1e-3, weight_decay=1e-5):
    """Train ``model`` with full-batch Adam for ``epochs`` optimizer steps.

    Each step feeds the whole ``input_`` through the model and minimizes
    ``model.loss(pred, spikes)``. The loss function is defined by the model,
    not here.

    Args:
        model: an ``nn.Module`` exposing ``loss(pred, target) -> Tensor``.
        input_: model input, passed to ``model`` unchanged.
        spikes: target passed to ``model.loss``.
        epochs: number of optimizer steps (one full-batch step per epoch).
        learning_rate: Adam learning rate.
        weight_decay: Adam weight decay (L2 penalty).

    Returns:
        The same ``model`` object, trained in place and left in train mode.
    """
    model.train()
    optimizer = torch.optim.Adam(
        model.parameters(), lr=learning_rate, weight_decay=weight_decay
    )
    for _ in range(epochs):
        pred = model(input_)
        loss = model.loss(pred, spikes)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return model

In [None]:
# Reproducible weight initialisation.
torch.manual_seed(0)
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Row of rec.spikes.T (i.e. the cell cluster) to model.
CLUSTER_IDX = 14
spikes = torch.from_numpy(np.array(rec.spikes.T[CLUSTER_IDX])).float().to(device)
stimulus = torch.from_numpy(rec.stimulus).float().to(device)
# Append the spike train as an extra channel, then convert to Conv1d layout
# (1, channels, time). Assumes rec.stimulus is (time, LED channels) — TODO confirm.
# NOTE(review): `spikes` is also the training target, so the model sees the
# label it must predict; check the intended input/target alignment.
input_ = torch.concat([stimulus, torch.unsqueeze(spikes, dim=-1)], dim=1)
input_ = torch.unsqueeze(input_.T, dim=0)  # channel first, batch of 1
in_len = rec.stimulus.shape[0]
model = GanglionAsCNN(in_len=in_len)
model.to(device)
model = train_model(model, input_, spikes, epochs=100)

## Results

In [None]:
def kernel_plot(kernel, channels=(1, 2), zero_idx=None):
    """Plot selected channels of a temporal kernel as line traces.

    Args:
        kernel: array indexable as ``kernel[:, c]``; presumably rows are time
            samples and columns are channels (2-D assumed — TODO confirm).
        channels: column indices to plot. Each is drawn in the colour from row
            ``c`` of the module-level ``colormap``. Defaults to (1, 2), the
            channels this function has always plotted.
        zero_idx: sample index placed at x=0. Defaults to the middle sample
            (``kernel.shape[0] // 2``).

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    n_samples = kernel.shape[0]
    if zero_idx is None:
        zero_idx = n_samples // 2
    # Shift the x-axis so that sample `zero_idx` sits at zero (middle by default).
    xs = np.arange(n_samples) - zero_idx
    fig = go.Figure()
    for c in channels:
        fig.add_trace(go.Scatter(x=xs,
                                 y=kernel[:, c],
                                 line_color=colormap.loc[c, 'display_hex'],
                                 mode='lines'))
    # NOTE(review): the x unit is samples; after downsampling these may not be
    # milliseconds as the axis title states — confirm.
    fig.update_layout(autosize=False,
                      height=300,
                      margin=dict(l=1, r=1, b=1, t=25, pad=1),
                      yaxis_fixedrange=True,
                      showlegend=False,
                      title='Kernel',
                      title_x=0.5,
                      title_pad=dict(l=1, r=1, b=10, t=1),
                      xaxis={'title':'time (ms), with spike at 0'},
                      yaxis={'title':'summed responses'} )
    return fig

In [None]:
# NOTE(review): `k` is not defined in any cell shown here, so this cell raises
# NameError after Restart & Run All. TODO: compute the kernel in an earlier cell.
kernel_fig = kernel_plot(k)
kernel_fig.show()