In [1]:
import Anti_HebbFF_Network as AHF
import Data_Loader as DL
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import random_split
import numpy as np
import matplotlib.pyplot as plt
import math
import random
import torchviz
import sys
import copy
import os
import seaborn as sns
from time import sleep
from rich.console import Console
from memory_profiler import profile

In [2]:
# Meta parameters shared by the cells below.

# -- network dimensions (forwarded to AHF.HebbFF) --
vec_len = 25     # used as the model's input_dim and as vec_len for generated sequences
hid_dim = 25     # passed as hid_dim
out_dim = 1      # passed as out_dim

# -- training settings --
batch_size = 10  # passed to AHF.HebbFF as batch_size
lr = 1e-3        # learning rate handed to the Adam optimizer
acc = 0.99       # accuracy threshold at which acc_train stops iterating

In [3]:
def load_model(PATH):
    """Build a fresh HebbFF network and populate it from a saved state dict.

    The network is constructed from the notebook-level ``vec_len``,
    ``hid_dim``, ``out_dim`` and ``batch_size`` values.

    Args:
        PATH: Location of a state dict readable by ``torch.load``.

    Returns:
        The network with the loaded state, after ``eval()`` has been called on it.
    """
    net = AHF.HebbFF(
        input_dim=vec_len,
        hid_dim=hid_dim,
        out_dim=out_dim,
        batch_size=batch_size,
    )
    saved_state = torch.load(PATH)
    net.load_state_dict(saved_state)
    net.eval()
    return net
    
def test(model,r1, r2,criterion,PATH):
    """Measure accuracy for every interval in ``range(r1, r2)`` and save the curve.

    For each interval value ``R`` a fresh sequence of length
    ``T = max(200, 20 * R)`` is generated with ``DL.generate_seq`` and scored
    with ``AHF.test``. The accuracies are plotted against ``R`` and the figure
    is written to ``PATH + '/Acc'``.

    Args:
        model: Network to evaluate; its ``batch_size`` attribute sizes the
            generated sequences.
        r1: First interval value (inclusive).
        r2: Interval value at which to stop (exclusive).
        criterion: Loss object forwarded to ``AHF.test``.
        PATH: Directory in which the plot is saved.

    Returns:
        None.
    """
    accuracies = []
    for interval in range(r1, r2):
        T = max(200, 20 * interval)
        input_seq, target_seq = DL.generate_seq(
            batch_size=model.batch_size, vec_len=vec_len, R=interval, T=T
        )
        # AHF.test also returns the outputs, total loss and hidden activity;
        # only the accuracy is needed for the curve.
        _, accuracy, _, _ = AHF.test(model, criterion, input_seq, target_seq, T)
        accuracies.append(accuracy.item())
        # TODO: a false-positive breakdown (outputs minus targets) was stubbed
        # out here in an earlier version but never implemented.

    r_axis = np.arange(r1, r2)
    # Use an explicit figure/axes pair so the plot does not depend on, or
    # alter, whichever figure pyplot currently considers "current".
    acc_fig, acc_ax = plt.subplots()
    acc_ax.plot(r_axis, accuracies)
    acc_ax.set_title("Accuracy across different intervals after training with R in [{},{}]".format(r1,r2))
    acc_ax.set_xlabel('Interval R')
    acc_ax.set_ylabel('Accuracy')
    acc_fig.savefig(PATH + '/Acc',  dpi = 600, facecolor='w', transparent=True)
    # clf() alone would leave the figure registered with pyplot; close it so
    # repeated calls do not pile up open figures.
    plt.close(acc_fig)
    
def HA(model,r,criterion,PATH,T_step):
    """Return a random ``T_step``-wide window of hidden activity for interval ``r``.

    A single sequence (batch size 1) of length ``T = max(200, 20 * r)`` is
    generated and run through ``AHF.test``. The per-entry hidden activity it
    returns is squeezed, converted to NumPy, stacked and transposed, and a
    window of ``T_step`` consecutive columns starting at a random offset is
    returned.

    The model's ``batch_size`` attribute is switched to 1 only for the
    duration of this call and then restored, so later training or evaluation
    that reads ``model.batch_size`` keeps using the configured value.

    Args:
        model: Network to probe.
        r: Interval value used to generate the sequence.
        criterion: Loss object forwarded to ``AHF.test``.
        PATH: Unused; kept so existing callers do not break.
        T_step: Number of columns in the returned window.

    Returns:
        NumPy array sliced to ``T_step`` entries along its second axis
        (presumably hidden units x time steps -- confirm against ``AHF.test``).

    Raises:
        ValueError: From ``random.randint`` when ``T_step`` exceeds ``T``.
    """
    T = max(200, 20 * r)

    original_batch_size = model.batch_size
    model.batch_size = 1
    try:
        input_seq, target_seq = DL.generate_seq(
            batch_size=model.batch_size, vec_len=vec_len, R=r, T=T
        )
        _, _, _, hidden_activity = AHF.test(model, criterion, input_seq, target_seq, T)
    finally:
        # Undo the temporary override even if generation or evaluation fails.
        model.batch_size = original_batch_size

    # Drop the size-1 axes at positions -3 and -1 of every entry.
    traces = [
        torch.squeeze(torch.squeeze(step, -3), -1).detach().numpy()
        for step in hidden_activity
    ]

    # randint is inclusive on both ends, so start + T_step never exceeds T.
    start = random.randint(0, T - T_step)

    stacked = np.transpose(np.array(traces))
    return stacked[:, start:start + T_step]
    
def acc_train(model, acc, vec_len, R, T, criterion, optimizer,print_log):
    """Train ``model`` on freshly generated batches until one batch reaches ``acc``.

    Each iteration generates a new sequence with ``DL.generate_seq`` and runs
    one ``AHF.train`` step on it. The loop stops as soon as the accuracy
    reported for the most recent batch is no longer below ``acc``. Running
    averages of loss and accuracy are printed every ``print_log`` iterations,
    followed by a one-line summary when training ends.

    Args:
        model: Network to train; its ``batch_size`` attribute sizes each batch.
        acc: Accuracy threshold that ends training. With ``acc <= 0`` no
            iteration is run.
        vec_len: Vector length forwarded to ``DL.generate_seq``.
        R: Interval value forwarded to ``DL.generate_seq``.
        T: Sequence length forwarded to ``DL.generate_seq`` and ``AHF.train``.
        criterion: Loss object forwarded to ``AHF.train``.
        optimizer: Optimizer forwarded to ``AHF.train``.
        print_log: Print the running averages every this many iterations;
            ``0`` or ``None`` disables the periodic log.

    Returns:
        None.
    """
    iterations = 0
    running_loss = 0.0
    running_acc = 0.0
    latest_acc = 0.0
    # Defined before the loop so the summary line below is still printable
    # when the loop body never executes.
    loss = accuracy = float('nan')

    while latest_acc < acc:
        iterations += 1
        input_seq, target_seq = DL.generate_seq(
            batch_size=model.batch_size, vec_len=vec_len, R=R, T=T
        )
        accuracy, loss = AHF.train(model, criterion, optimizer, input_seq, target_seq, T)

        # float() accepts plain numbers as well as one-element tensors, so the
        # running sums are ordinary Python floats whatever AHF.train returns.
        loss = float(loss)
        accuracy = float(accuracy)
        running_loss += loss
        running_acc += accuracy

        # A falsy print_log skips the periodic log instead of dividing by zero.
        if print_log and iterations % print_log == 0:
            print('Iterations: {}....'.format(iterations), end=' ')
            print("Average Loss: {:.4f}".format(running_loss / print_log), end=' | ')
            print("Average Accuracy: {:.4f}".format(running_acc / print_log))
            running_loss = 0.0
            running_acc = 0.0

        latest_acc = accuracy

    print('----------------------------Training Summary: Total Iterations: {}'.format(iterations),'Final Loss:{:.4f}'.format(loss),"Final Accuracy: {:.4f}".format(accuracy),'----------------------------')


## R = 1

In [None]:
# Starting New Models
R = 0
HebbFF_CFD = AHF.HebbFF(input_dim = vec_len, hid_dim=hid_dim, out_dim=out_dim, batch_size = batch_size)
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(HebbFF_CFD.parameters(),lr = lr)

# To resume from a previously saved model instead, replace the four
# statements above with:
#
#     R = 3
#     PATH = os.getcwd() + '/Sigmoid_v_{}_h_{}/R_{}'.format(vec_len,hid_dim,R)
#     HebbFF_CFD = load_model(PATH + '/Model')
#     criterion = nn.BCELoss()
#     optimizer = torch.optim.Adam(HebbFF_CFD.parameters(), lr = 1e-3)

MAX_R = 14  # training stops after this interval value has been trained


def save_heatmap(matrix, title, out_file):
    """Draw ``matrix`` as a heatmap on its own figure, save it, then close it.

    Using a dedicated figure keeps the plot independent of whatever figure
    pyplot currently treats as "current", and closing it afterwards prevents
    open figures from accumulating across loop iterations.
    """
    fig, ax = plt.subplots()
    sns.heatmap(matrix, linewidth=0.3, ax=ax)
    ax.set_title(title)
    fig.savefig(out_file,  dpi = 600, facecolor='w', transparent=True)
    plt.close(fig)


# Curriculum: train on R = R+1, ..., MAX_R, reusing the same model and
# optimizer from one interval value to the next.
while R < MAX_R:
    R += 1
    print('Current Training Interval Value: ',R)
    T = max(200,R*20)

    PATH = os.getcwd() + '/Sigmoid_v_{}_h_{}/R_{}'.format(vec_len,hid_dim,R)
    # exist_ok avoids the separate exists() check and the gap between the
    # check and the creation.
    os.makedirs(PATH, exist_ok=True)

    # training
    acc_train(HebbFF_CFD, acc, vec_len, R, T, criterion, optimizer,print_log = 100)

    # saving the model
    torch.save(HebbFF_CFD.state_dict(), PATH + '/Model')

    # test accuracy
    test(HebbFF_CFD,1,20,criterion,PATH)

    # hidden_layer heatmap
    hidden_activity = HA(HebbFF_CFD,R,criterion,PATH,T_step = 20)
    save_heatmap(hidden_activity, 'Hidden Activity', PATH+'/Hidden_Activity')

    # static matrix
    save_heatmap(HebbFF_CFD.w1.detach().numpy(), 'Static Weight Matrix', PATH+'/W_1')

Current Training Interval Value:  1
Iterations: 100.... Average Loss: 0.6463 | Average Accuracy: 0.6518
Iterations: 200.... Average Loss: 0.5012 | Average Accuracy: 0.6679
Iterations: 300.... Average Loss: 0.4305 | Average Accuracy: 0.6681
Iterations: 400.... Average Loss: 0.4039 | Average Accuracy: 0.7139
Iterations: 500.... Average Loss: 0.3791 | Average Accuracy: 0.7562
Iterations: 600.... Average Loss: 0.3558 | Average Accuracy: 0.7879
Iterations: 700.... Average Loss: 0.3313 | Average Accuracy: 0.8178
Iterations: 800.... Average Loss: 0.3070 | Average Accuracy: 0.8464
Iterations: 900.... Average Loss: 0.2841 | Average Accuracy: 0.8735
Iterations: 1000.... Average Loss: 0.2652 | Average Accuracy: 0.8956
Iterations: 1100.... Average Loss: 0.2483 | Average Accuracy: 0.9155
Iterations: 1200.... Average Loss: 0.2315 | Average Accuracy: 0.9324
Iterations: 1300.... Average Loss: 0.2158 | Average Accuracy: 0.9466
Iterations: 1400.... Average Loss: 0.2018 | Average Accuracy: 0.9587
Iterati