In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from scipy import signal
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.gridspec as gs
import seaborn as sns

from tqdm import tqdm

import sys
sys.path.append('/Users/ag1880/Github_repos/Milstein-Lab/dentate_circuit_model/')
from optimize_dynamic_model import get_binary_input_patterns

%matplotlib inline

In [10]:
class dentate_network(nn.Module):
    def __init__(self, lr):
        self.lr = lr
        self.input_size = 7
        self.output_size = 128
        self.fbi_size = 7
        
        self.fbi_init = torch.zeros(self.fbi_size)
        
        super().__init__()
                        
        self.out_in  = nn.Linear(self.input_size, self.output_size)
        self.fbi_out = nn.Linear(self.output_size, self.fbi_size)     
        self.out_fbi = nn.Linear(self.fbi_size, self.output_size)
        
#         self.out_in.bias.data.uniform_(3,3)
                        
        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)
        
        self.lossCriterion = nn.CrossEntropyLoss()
        
    def forward(self, input_pattern):
        out1 = self.out_in(input_pattern)
        out1_activation = F.elu(out1) # initial output activity
        
        out2 = self.fbi_out(out1_activation)
        out2_activation = F.elu(out2) # FB inh activity
        
        out3 = self.out_fbi(out2_activation)
        
        output = out3
#         output = out1 + out3
        
        return F.softmax(output)
    
    def forward_relu(self, input_pattern):
        out1 = self.out_in(input_pattern)
        out1_activation = F.relu(out1) # initial output activity
        
        out2 = self.fbi_out(out1_activation)
        out2_activation = F.relu(out2) # FB inh activity
        
        out3 = self.out_fbi(out2)
        output = out1 + out3
        
        return F.relu(output)
    
    def forward_semilinear(self, input_pattern):
        out1 = self.out_in(input_pattern)
        out1_activation = F.relu(out1) # initial output activity
        
        out2 = self.fbi_out(out1_activation)
        out2_activation = F.relu(out2) # FB inh activity
        
        out3 = self.out_fbi(out2)
        output = out1 + out3
        
        return output
    
    def train(self,epochs,time,all_patterns):
        losses = []
        
        for i in tqdm(range(epochs)):
            self.optimizer.zero_grad()
            
            pattern_index = np.random.randint(0,128)
            input_pattern = all_patterns[pattern_index]
            
            output = self.forward(input_pattern)
            target = torch.eye(128)[pattern_index]
                        
            non_target_idx = [x for x in range(128) if x!=pattern_index]
            non_target_loss = -torch.sum(torch.log(1-output[non_target_idx]))
            target_loss = -torch.log(output[pattern_index]+0.01)
            loss = target_loss #+ non_target_loss
                 
            losses.append(loss.detach())
            loss.backward()
            self.optimizer.step()
            
        return losses

In [11]:
# Create network
dentate_net = dentate_network(lr=0.1)
all_patterns = torch.tensor(get_binary_input_patterns(7)).type(torch.float32)

In [12]:
# Save initial state
w_init = {'out_in': dentate_net.out_in.weight.detach().numpy().flatten(),
          'fbi_out': dentate_net.fbi_out.weight.detach().numpy().flatten(),
          'out_fbi': dentate_net.out_fbi.weight.detach().numpy().flatten()}

b_init = {'out_in': dentate_net.out_in.bias.detach().numpy().flatten(),
          'fbi_out': dentate_net.fbi_out.bias.detach().numpy().flatten(),
          'out_fbi': dentate_net.out_fbi.bias.detach().numpy().flatten()}

output_init = []
for pattern in all_patterns:
    output_init.append(list(dentate_net.forward(pattern).detach()))
output_init = np.array(output_init)

In [13]:
# Train network
epochs = 10**5
time = 10
losses = dentate_net.train(epochs, time, all_patterns)

  0%|          | 0/100000 [00:00<?, ?it/s]


RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

In [None]:
w_final = {'out_in': dentate_net.out_in.weight.detach().numpy().flatten(),
          'fbi_out': dentate_net.fbi_out.weight.detach().numpy().flatten(),
          'out_fbi': dentate_net.out_fbi.weight.detach().numpy().flatten()}

b_final = {'out_in': dentate_net.out_in.bias.detach().numpy().flatten(),
          'fbi_out': dentate_net.fbi_out.bias.detach().numpy().flatten(),
          'out_fbi': dentate_net.out_fbi.bias.detach().numpy().flatten()}

output_final = []
for pattern in all_patterns:
    output_final.append(list(dentate_net.forward(pattern).detach()))
output_final = np.array(output_final)

output_final_semilinear = []
for pattern in all_patterns:
    output_final_semilinear.append(list(dentate_net.forward_semilinear(pattern).detach()))
output_final_semilinear = np.array(output_final_semilinear)

In [None]:
matplotlib.rcParams.update({'font.size': 15})

fig = plt.figure(figsize=(14,8.5))
axes = gs.GridSpec(nrows=3, ncols=12)
    
# Row 1: output activities
ax = fig.add_subplot(axes[0,0:3])
im = ax.imshow(output_init)
ax.set_title('Initital output')
ax.set_xlabel('unit ID')
ax.set_ylabel('pattern ID')
cbar = plt.colorbar(im, ax=ax)

ax = fig.add_subplot(axes[0,3:6])
ax.plot(losses)
ax.set_xlabel('epochs')
ax.set_title('Hyperbolic(+squared off-target) Loss')
ax.set_ylim([0,np.max(losses)])
# ax.set_ylim([0,1])

ax = fig.add_subplot(axes[0,6:9])
im = ax.imshow(output_final)
ax.set_title('Final output \n Softmax')
ax.set_xlabel('unit ID')
ax.set_ylabel('pattern ID')
cbar = plt.colorbar(im, ax=ax)

ax = fig.add_subplot(axes[1,9:12])
im = ax.imshow(output_final,vmax=1)
ax.set_title('Final output \n ReLU (vmax=1)')
ax.set_xlabel('unit ID')
ax.set_ylabel('pattern ID')
cbar = plt.colorbar(im, ax=ax)

ax = fig.add_subplot(axes[0,9:12])
im = ax.imshow(output_final_semilinear)
ax.set_title('Final output \n linear')
ax.set_xlabel('unit ID')
ax.set_ylabel('pattern ID')
cbar = plt.colorbar(im, ax=ax)


# Row 2: weights
ax = fig.add_subplot(axes[1,0:3])
ax.hist(w_init['out_in'],30,alpha=0.5,label='initial')
ax.hist(w_final['out_in'],30,alpha=0.5,label='final')
ax.set_title('Weights: Input -> Output')
ax.legend(loc='best', frameon=False,handlelength=1)

ax = fig.add_subplot(axes[1,3:6])
ax.hist(w_init['fbi_out'],30,alpha=0.5,label='initial')
ax.hist(w_final['fbi_out'],30,alpha=0.5,label='final')
ax.set_title('Weights: Output -> FB Inh')
ax.legend(loc='best', frameon=False,handlelength=1)

ax = fig.add_subplot(axes[1,6:9])
ax.hist(w_init['out_fbi'],30,alpha=0.5,label='initial')
ax.hist(w_final['out_fbi'],30,alpha=0.5,label='final')
ax.set_title('Weights: FB Inh -> Output')
ax.legend(loc='best', frameon=False,handlelength=1)


# Row 3: biases
ax = fig.add_subplot(axes[2,0:3])
ax.hist(b_init['out_in'],30,alpha=0.5,label='initial')
ax.hist(b_final['out_in'],30,alpha=0.5,label='final')
ax.set_title('Biases: Input -> Output')
ax.legend(loc='best', frameon=False,handlelength=1)

ax = fig.add_subplot(axes[2,3:6])
ax.hist(b_init['fbi_out'],4,alpha=0.5,label='initial')
ax.hist(b_final['fbi_out'],50,alpha=0.5,label='final')
ax.set_title('Biases: Output -> FB Inh')
ax.legend(loc='best', frameon=False,handlelength=1)

ax = fig.add_subplot(axes[2,6:9])
ax.hist(b_init['out_fbi'],30,alpha=0.5,label='initial')
ax.hist(b_final['out_fbi'],30,alpha=0.5,label='final')
ax.set_title('Biases: FB Inh -> Output')
ax.legend(loc='best', frameon=False,handlelength=1)


sns.despine()
fig.tight_layout()
plt.show()

# name = "backprop_network_FBI_MSE_ReLU"
# fig.savefig('../plots/'+name+'.png', edgecolor='white', dpi=300, facecolor='white', transparent=True)