In [1]:
import plotly.graph_objects as go
import json

In [2]:
# Load the experiment results and group the runs that reached a validation
# accuracy of at least 0.9 by their hidden-layer width. The `with` statement
# closes the file on exit, so no explicit close() is needed afterwards.
with open('results.json') as f:
    data = json.load(f)

# Maps the short list name used by the plotting cells to the key stored in
# each run record.
metric_keys = {
    'num_params': 'number_parameter',
    'hidden_dim': 'hidden_dimension',
    'num_layers': 'number_layers',
    'intrinsic_dim': 'intrinsic_dimension',
}
_columns = ('num_params', 'hidden_dim', 'num_layers', 'intrinsic_dim')

# One set of parallel lists per supported layer width.
runs_by_width = {
    width: {column: [] for column in _columns}
    for width in (50, 100, 200, 400)
}

for run_id, run in data.items():
    if not run['validation_accuracy'] >= 0.9:
        continue
    width = run['hidden_dimension']
    bucket = runs_by_width.get(width)
    if bucket is None:
        # Report which run could not be bucketed instead of a bare 'Error'.
        print('Error: run {!r} has unexpected hidden_dimension {!r}'.format(run_id, width))
        continue
    for column in _columns:
        bucket[column].append(run[metric_keys[column]])

# Expose the per-width lists under the names the plotting cell uses.
num_params_50, hidden_dim_50, num_layers_50, intrinsic_dim_50 = (
    runs_by_width[50][column] for column in _columns)
num_params_100, hidden_dim_100, num_layers_100, intrinsic_dim_100 = (
    runs_by_width[100][column] for column in _columns)
num_params_200, hidden_dim_200, num_layers_200, intrinsic_dim_200 = (
    runs_by_width[200][column] for column in _columns)
num_params_400, hidden_dim_400, num_layers_400, intrinsic_dim_400 = (
    runs_by_width[400][column] for column in _columns)

In [3]:
# Scatter plot of intrinsic dimension against parameter count with one trace
# per layer width. Marker size distinguishes the layer width and marker colour
# encodes the number of layers.
fig = go.Figure()

# (layer width, marker size, x values, y values, colour values)
width_traces = [
    (50, 30, num_params_50, intrinsic_dim_50, num_layers_50),
    (100, 55, num_params_100, intrinsic_dim_100, num_layers_100),
    (200, 70, num_params_200, intrinsic_dim_200, num_layers_200),
    (400, 90, num_params_400, intrinsic_dim_400, num_layers_400),
]

for layer_width, marker_size, params, intrinsic, layers in width_traces:
    fig.add_trace(go.Scatter(
        x=params,
        y=intrinsic,
        mode='markers',
        marker=dict(
            color=layers,
            size=marker_size,
            # All traces share one colour axis so that a colour means the same
            # number of layers everywhere and only a single colour bar is
            # drawn (four per-trace colour bars would overlap each other).
            coloraxis='coloraxis',
        ),
        # No error bars: the previous error_y array [1, 2, 3, 4, 5] was a
        # hard-coded placeholder that was not computed from the results.
        name='layer width: {}'.format(layer_width),
    ))

fig.update_layout(
    xaxis_title='number of parameters',
    yaxis_title='intrinsic dimension',
    coloraxis=dict(colorbar=dict(title=dict(text='number of layers'))),
    legend=dict(
        itemsizing='trace',
        yanchor="top",
        y=0.99,
        xanchor="left",
        x=0.01,
        )
    )

fig.show()

In [6]:
# Bubble plot of all runs with validation accuracy >= 0.9: parameter count on
# x, intrinsic dimension on y, marker area scaled by hidden dimension and
# marker colour given by the layer count.
# The `with` statement closes the file, so no explicit close() is needed.
with open('results.json') as f:
    data = json.load(f)

num_params, hidden_dim, num_layers, intrinsic_dim = [], [], [], []
for run in data.values():
    if run['validation_accuracy'] >= 0.9:
        num_params.append(run['number_parameter'])
        hidden_dim.append(run['hidden_dimension'])
        # NOTE(review): the +1 is kept from the original analysis; presumably
        # it counts an extra (input or output) layer - confirm.
        num_layers.append(run['number_layers'] + 1)
        intrinsic_dim.append(run['intrinsic_dimension'])

# Scale marker areas relative to the largest hidden dimension. Fall back to
# plotly's default sizeref of 1 when no run passed the accuracy filter, because
# max() raises ValueError on an empty list.
# NOTE(review): plotly's bubble-chart docs suggest
# 2 * max(size) / (desired_max_px ** 2); the factor 2 is absent here, so the
# largest marker is drawn bigger than 30 px. Left unchanged to keep the look.
sizeref = max(hidden_dim) / (30 ** 2) if hidden_dim else 1

fig = go.Figure()

fig.add_trace(go.Scatter(
    x=num_params,
    y=intrinsic_dim,
    mode='markers',
    marker=dict(
        color=num_layers, size=hidden_dim,
        sizemode='area', sizeref=sizeref,
        sizemin=2, showscale=True,
        colorbar=dict(title=dict(text='number of layers + 1')),
        ),
))

fig.update_layout(
    xaxis_title='number of parameters',
    yaxis_title='intrinsic dimension',
    legend=dict(
        yanchor="top",
        y=0.99,
        xanchor="left",
        x=0.01,
        )
    )

fig.show()

In [43]:
# Load the weights from every .pt file in a folder into a list, then plot them.
# Note: when the length of a weight tensor is not a perfect square, the 2-D and
# 3-D plots pad it with zeros at the start so that it fills a square grid.
import torch
import plotly.express as px
import os
import math
import plotly.graph_objects as go

def _as_square_grid(tensor):
    """Reshape a tensor into the smallest square grid that can hold it.

    The tensor is detached, moved to the CPU and flattened. If its element
    count is not a perfect square, zeros are prepended so that the original
    values occupy the end of the grid.

    Returns a 2-D tensor of shape (side, side).
    """
    flat = tensor.detach().cpu().reshape(-1)
    count = flat.shape[0]
    side = int(math.sqrt(count))
    if side * side < count:
        # Not a perfect square: use the next larger square.
        side += 1
    padding = side * side - count
    if padding:
        flat = torch.cat([flat.new_zeros(padding), flat])
    return flat.reshape(side, side)


def plot_kernels_3d(tensors):
    """Show every tensor as a 3-D surface over a square grid.

    Each tensor is laid out with ``_as_square_grid`` (zero-padded at the start
    when its size is not a perfect square) and displayed in its own figure
    with z contours projected onto the surface.
    """
    for tensor in tensors:
        fig = go.Figure(data=go.Surface(z=_as_square_grid(tensor)))
        fig.update_traces(contours_z=dict(show=True, usecolormap=True,
                                          highlightcolor="limegreen", project_z=True))
        fig.show()


def plot_kernels_2d(tensors):
    """Show every tensor as a 2-D image over a square grid.

    Each tensor is laid out with ``_as_square_grid`` (zero-padded at the start
    when its size is not a perfect square) and displayed in its own figure.
    """
    for tensor in tensors:
        fig = px.imshow(_as_square_grid(tensor))
        fig.show()
         
def plot_kernels_1d(tensors):
    """Plot each tensor's values against their element index.

    For every tensor a new figure is created, the tensor's size is printed,
    and a single scatter trace (x = 0..n-1, y = the tensor) is shown.
    """
    for weights in tensors:
        figure = go.Figure()
        print(weights.size())
        length = int(weights.size()[0])
        positions = list(range(length))
        trace = go.Scatter(x=positions, y=weights)
        figure.add_trace(trace)
        figure.show()

def plot_kernels_1d_sorted(tensors):
    """Plot each tensor's values in ascending order against their rank.

    For every tensor a new figure is created, the tensor's size is printed,
    and a single scatter trace (x = 0..n-1, y = the sorted values) is shown.
    """
    for weights in tensors:
        figure = go.Figure()
        print(weights.size())
        length = int(weights.size()[0])
        ranks = list(range(length))
        ordered_values, _ = weights.sort()
        figure.add_trace(go.Scatter(x=ranks, y=ordered_values))
        figure.show()

def load_model(path):
    """Load a checkpoint written with ``torch.save`` from ``path``.

    Tensors are mapped to the CPU while loading, so checkpoints saved on a GPU
    machine can also be opened where no GPU is available.

    Returns the object stored in the file.
    """
    # NOTE: torch.load unpickles the file, which can execute arbitrary code;
    # only load checkpoints that come from a trusted source.
    model = torch.load(path, map_location='cpu')
    return model

def get_weights(model):
    """Extract the weights to plot from a loaded checkpoint.

    Reads ``model['V']``, transposes it and returns row 0 of the transpose
    (for a 2-D ``V`` this is its first column) as a CPU tensor.

    The result is detached from the autograd graph, because a tensor that
    still requires grad cannot be converted to a NumPy array when plotted.
    """
    weights = model['V']
    return weights.detach().t()[0].cpu()

def get_weights_from_folder(folder):
    """Plot the sorted weights of every checkpoint in ``folder``.

    Every file whose name ends in ``.pt`` and does not contain ``id0`` is
    loaded with ``load_model``; its weights are extracted with ``get_weights``
    and all of them are passed to ``plot_kernels_1d_sorted``.

    Files are visited in sorted name order, because ``os.listdir`` returns
    entries in arbitrary order and the sequence of plots should be
    reproducible. Nothing is returned.
    """
    weights = []
    for file_name in sorted(os.listdir(folder)):
        if file_name.endswith(".pt") and "id0" not in file_name:
            # os.path.join avoids a doubled separator when `folder` already
            # ends with a slash.
            model = load_model(os.path.join(folder, file_name))
            weights.append(get_weights(model))
    plot_kernels_1d_sorted(weights)

# Plot the sorted weights of every .pt checkpoint (except files with "id0" in
# their name) found in ./model/lenet/mnist/.
get_weights_from_folder('./model/lenet/mnist/')



torch.Size([1000])


torch.Size([170])


torch.Size([175])


torch.Size([195])


torch.Size([205])


torch.Size([210])


torch.Size([225])


torch.Size([250])


torch.Size([275])


torch.Size([300])


torch.Size([400])


torch.Size([500])
