# Visualize PCA of weights and activity, and estimate MI from data

## Import libraries

In [116]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader , TensorDataset
from torchvision.utils import save_image, make_grid
from torch.optim import Adam
import torch.nn.init as init

import numpy as np
import math

import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
from matplotlib.ticker import MultipleLocator
import matplotlib.cm as cm

import copy
import seaborn as sns

from scipy.stats import norm
from sklearn.neighbors import KernelDensity, LocalOutlierFactor
from sklearn.decomposition import PCA

import tqdm

import pickle

# MI estimators
from utils.estimators import *

## Load data

In [193]:
# Load the two recorded buffers used throughout the notebook. The file names
# suggest a plain Hebbian run and a "MORPH" variant of it (both seed 0) —
# confirm against the experiment that produced them.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load pickles produced by a trusted run of this project.
with open("data/weight_dynamic/HEBB-FULL_STATE-seed_0-rand-0.pkl", "rb") as f:
    loaded_buffer = pickle.load(f)
with open("data/weight_dynamic/MORPH-HEBB-FULL_STATE-seed_0-rand-0.pkl", "rb") as m:
    loaded_morph = pickle.load(m)

In [194]:
def reshape_input_output(val):
    '''
    Flatten every per-timestep entry to 1-D and stack the results.

    Input dim : [timestep , np:[1,dim] ]
    Return dim : np: [timestep , dim]

    The input sequence is left untouched: a new list of flattened entries is
    built instead of overwriting ``val[i]`` in place, so the caller's buffer
    (e.g. the loaded pickle) keeps its original shapes. Building a new list
    also makes an ndarray input of shape [timestep, 1, dim] come out flattened
    instead of being broadcast back into its old shape.
    '''
    flattened = [step.reshape(-1) for step in val]
    return np.array(flattened)

In [195]:
def select_activity_layer(activity,selected_layer):
    '''
    Input dim : [timestep][layer]
    Return dim : np: [timestep , dim]
    '''
    buff = []
    for t in range(len(activity)):
        layer = activity[t][selected_layer]
        buff.append(layer.reshape(-1) if isinstance(layer, np.ndarray) else layer.numpy().reshape(-1))
    return np.array(buff , dtype=object)

In [196]:
def select_weight_layer(weight,selected_layer):
    '''
    Pull one layer's weights out of every timestep.

    Input dim : [timestep][layer]
    Return dim : np: [timestep , ...]  (object dtype)

    Only element ``[0]`` of the selected layer is kept for each timestep; an
    element that is not already a NumPy array is converted through its
    ``.numpy()`` method.
    '''
    def _to_numpy(entry):
        return entry if isinstance(entry, np.ndarray) else entry.numpy()

    # Index by timestep so any container supporting len() and integer
    # indexing is accepted.
    per_step = [_to_numpy(weight[step][selected_layer][0]) for step in range(len(weight))]
    return np.array(per_step, dtype=object)

Load Normal

In [197]:
# Unpack the baseline ("normal") run into per-layer arrays.
# Shape notes come from the original annotations and are not checked here —
# TODO confirm against the code that recorded the buffer.
# NOTE(review): `input` shadows the Python builtin of the same name for the
# rest of the notebook.
activity    = loaded_buffer["brain"]["activity"]              # [timestep][layer_n][1, hidden_dim]
weight      = loaded_buffer["brain"]["weight"]                # [timestep][layer_n][1, 64, 128]
input       = reshape_input_output(loaded_buffer["data"]["state"])           # [timestep][1, all_observation] ---reshape---> [timestep, all_obs]
act_hid_1   = select_activity_layer(activity , 0)             # [timestep][layer_n][1, hidden_dim] ---reshape---> [timestep, num_dim_of_selected_layer]
act_hid_2   = select_activity_layer(activity , 1)
weight_hid_1 = select_weight_layer(weight,0)
weight_hid_2 = select_weight_layer(weight,1)
weight_hid_3 = select_weight_layer(weight,2)
output      = reshape_input_output(loaded_buffer["data"]["action"])          # presumably [timestep][1, action_dim] ---reshape---> [timestep, action_dim]; the original note was copied from the state line

Load Morph

In [198]:
# Unpack the morph run the same way as the baseline run; every name carries
# an `m_` prefix.
# Shape notes come from the original annotations and are not checked here —
# TODO confirm against the code that recorded the buffer.
m_activity    = loaded_morph["brain"]["activity"]              # [timestep][layer_n][1, hidden_dim]
m_weight      = loaded_morph["brain"]["weight"]                # [timestep][layer_n][1, 64, 128]
m_input       = reshape_input_output(loaded_morph["data"]["state"])           # [timestep][1, all_observation] ---reshape---> [timestep, all_obs]
m_act_hid_1   = select_activity_layer(m_activity , 0)             # [timestep][layer_n][1, hidden_dim] ---reshape---> [timestep, num_dim_of_selected_layer]
m_act_hid_2   = select_activity_layer(m_activity , 1)
m_weight_hid_1 = select_weight_layer(m_weight,0)
m_weight_hid_2 = select_weight_layer(m_weight,1)
m_weight_hid_3 = select_weight_layer(m_weight,2)
m_output      = reshape_input_output(loaded_morph["data"]["action"])          # presumably [timestep][1, action_dim] ---reshape---> [timestep, action_dim]; the original note was copied from the state line

## Activity


- `Activity` : 
    - FROM : [timestep][layer], where each entry is either a NumPy array or a torch tensor
    - TO   : [timestep, num_dim_of_selected_layer]

- `INPUT / OUTPUT` : 
    - FROM : [timestep][1 , all_observation]
    - TO : [timestep , dim]


### Measuring information : INPUT Hidden 1

#### Binning

In [None]:
# Create the first histogram
# Histogram three signals with the same 30 bins over [-1, 1]. plt.hist returns
# (counts, bin_edges, patches); the counts are kept for the MI estimate below.

# State: only the first input dimension (column 0) is histogrammed.
state_count, state_bound, state_bar = plt.hist(input[:, 0].ravel(), bins=30, range=[-1, 1], alpha=0.6, label='State')
plt.xlabel('Value')
plt.ylabel('Frequency')
plt.legend()
plt.show()

# Hidden layer 1: all units and timesteps pooled into one histogram.
act_hid_1_count, act_hid_1_bound, act_hid_1_bar = plt.hist(act_hid_1.ravel(), bins=30, range=[-1, 1], alpha=0.6, label='Action Hidden 1',color='m')
plt.xlabel('Value')
plt.ylabel('Frequency')
plt.legend()
plt.show()

# Hidden layer 2: same pooling. The legend label used to be a copy of the
# hidden-1 label even though this plots act_hid_2.
act_hid_2_count, act_hid_2_bound, act_hid_2_bar = plt.hist(act_hid_2.ravel(), bins=30, range=[-1, 1], alpha=0.6, label='Action Hidden 2',color='m')
plt.xlabel('Value')
plt.ylabel('Frequency')
plt.legend()
plt.show()

In [None]:
# Binned MI estimate; mutual_information_binning is presumably provided by the
# `utils.estimators` star import.
# NOTE(review): the arguments are the 30-bin histogram *counts*, not the raw
# samples — confirm that is what the estimator expects.
# NOTE(review): this section's heading mentions hidden layer 1, but the second
# argument is the hidden-layer-2 histogram — confirm which layer is intended.
mutual_information_binning(state_count , act_hid_2_count , bins=30)

## PCA of HIDDEN LAYER

In [199]:
def PCA_plot3D_single(pca):
    '''
    Show one 3D scatter of the first three columns of ``pca``.

    ``pca`` is indexed as ``pca[:, k]`` for k = 0, 1, 2. Points are drawn in
    red and the legend entry is always 'w_hid1_flat'.
    '''
    axes = plt.figure().add_subplot(111, projection='3d')

    xs, ys, zs = pca[:, 0], pca[:, 1], pca[:, 2]
    axes.scatter(xs, ys, zs, label='w_hid1_flat', c='r', marker='o' ,s=10)

    # Axis labels and title in one call.
    axes.set(
        xlabel='PCA Component 1',
        ylabel='PCA Component 2',
        zlabel='PCA Component 3',
        title='PCA',
    )
    axes.legend()
    plt.show()

In [200]:
def PCA_plot3D_three(pca1,pca2,pca3):
    '''
    Show three 3D scatters side by side (1 row, 3 columns).

    Each argument is indexed as ``data[:, k]`` for k = 0, 1, 2. The panels
    are coloured red, green and blue, left to right.
    '''
    fig = plt.figure(figsize=(15, 5))

    # (subplot position, data, colour, title) for each panel.
    panels = (
        (131, pca1, 'r', 'PCA data 1'),
        (132, pca2, 'g', 'PCA data 2'),
        (133, pca3, 'b', 'PCA data 3'),
    )
    for position, points, colour, title in panels:
        panel = fig.add_subplot(position, projection='3d')
        panel.scatter(points[:, 0], points[:, 1], points[:, 2], c=colour, marker='o', s=10)
        panel.set_xlabel('PCA Component 1')
        panel.set_ylabel('PCA Component 2')
        panel.set_zlabel('PCA Component 3')
        panel.set_title(title)

    plt.tight_layout()  # keep the panels from overlapping
    plt.show()

In [201]:
def PCA_plot3D_3in1(pca1 , pca2 , pca3):
    '''
    Overlay three point clouds in a single 3D scatter plot and show it.

    Each argument is indexed as ``data[:, k]`` for k = 0, 1, 2. The clouds
    are drawn in red, green and blue, with legend entries 'pca1', 'pca2' and
    'pca3'. The third entry used to read 'w_hid3_flat', which mislabelled the
    plot whenever the function was given anything other than that array.
    '''
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    # One scatter per input, labelled after the parameter it came from.
    ax.scatter(pca1[:, 0], pca1[:, 1], pca1[:, 2], label='pca1', c='r', marker='o' ,s=10)
    ax.scatter(pca2[:, 0], pca2[:, 1], pca2[:, 2], label='pca2', c='g', marker='o', s=10)
    ax.scatter(pca3[:, 0], pca3[:, 1], pca3[:, 2], label='pca3', c='b', marker='o', s=10)

    ax.set_xlabel('PCA Component 1')
    ax.set_ylabel('PCA Component 2')
    ax.set_zlabel('PCA Component 3')
    ax.set_title('PCA')
    ax.legend()
    plt.show()

In [202]:
def PCA_plot3D_two(pca1,pca2):
    '''
    Show two 3D scatters side by side (1 row, 2 columns).

    Each argument is indexed as ``data[:, k]`` for k = 0, 1, 2. The left
    panel is red and the right panel is green.
    '''
    fig = plt.figure(figsize=(15, 5))

    # (subplot position, data, colour, title) for each panel.
    panels = (
        (121, pca1, 'r', 'PCA data 1'),
        (122, pca2, 'g', 'PCA data 2'),
    )
    for position, points, colour, title in panels:
        panel = fig.add_subplot(position, projection='3d')
        panel.scatter(points[:, 0], points[:, 1], points[:, 2], c=colour, marker='o', s=10)
        panel.set_xlabel('PCA Component 1')
        panel.set_ylabel('PCA Component 2')
        panel.set_zlabel('PCA Component 3')
        panel.set_title(title)

    plt.tight_layout()  # keep the panels from overlapping
    plt.show()

Normal PCA

In [203]:
# A single 3-component PCA estimator is reused for every projection below.
# Each fit_transform call refits it, so after this cell `pca` only holds the
# last fit (the one on w_all_flat).
pca = PCA(n_components=3)

# Flatten each timestep's weights into one row: [timestep, n_weights].
# NOTE(review): the weight_hid_* arrays are object dtype (they are built with
# dtype=object), so these stay object arrays — presumably scikit-learn
# coerces them to float; confirm.
w_hid1_flat = weight_hid_1.reshape(weight_hid_1.shape[0] , -1)
w_hid2_flat = weight_hid_2.reshape(weight_hid_2.shape[0] , -1)
w_hid3_flat = weight_hid_3.reshape(weight_hid_3.shape[0] , -1)
# Concatenate the three layers column-wise: one row per timestep holding all weights.
w_all_flat  = np.hstack((w_hid1_flat,w_hid2_flat,w_hid3_flat))

# Each layer, and then all layers together, gets its own independent PCA fit,
# so components are not comparable across the four results.
w1_pca = pca.fit_transform(w_hid1_flat)
w2_pca = pca.fit_transform(w_hid2_flat)
w3_pca = pca.fit_transform(w_hid3_flat)
w_pca  = pca.fit_transform(w_all_flat)

# Zero out weights
