In [1]:
import os
import numpy as np

from IPython.display import Image
import matplotlib
import matplotlib.pyplot as plt

%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import seaborn as sns
sns.set(font_scale=2)
sns.set_style('ticks')

matplotlib.rcParams.update({'font.size': 16})
matplotlib.rc('axes', titlesize=16)

import torch
import torch.functional as F
import glob
from collections import defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import save_image

from glia.gn import WeightNoise
from glia.gn import WeightLoss
from glia import gn

class GliaNet(nn.Module):
    """A simple test model"""

    def __init__(self, z_features=10, activation_function='Softmax'):
        # --------------------------------------------------------------------
        # Init
        super().__init__()
        if z_features < 2:
            raise ValueError("z_features must be > 2.")
        self.z_features = z_features

        # Lookup activation function (a class)
        AF = getattr(nn, activation_function)
#         self.phi = AF(dim=1)
        
        # --------------------------------------------------------------------
        # Def fc1:
        self.fc1 = gn.Slide(self.z_features)
        self.fc2 = gn.Gather(self.z_features)

    def forward(self, x, verbose=False):
        x = self.fc1(x)
        if verbose: print(x)
        x = self.fc2(x)
        if verbose: print(x)
            
        return x

# Create some shared input

In [56]:
x = torch.rand(1, 4)
net = GliaNet(4)
print(x)

tensor([[0.3606, 0.2939, 0.8955, 0.9413]])


In [58]:
print(net)

GliaNet(
  (fc1): Slide(in_features=4, out_features=4, bias=True)
  (fc2): Gather(in_features=4, out_features=2, bias=True)
)


In [59]:
print(net.state_dict())

OrderedDict([('fc1.weight', tensor([[-0.3687, -0.3640, -0.4445,  0.4557],
        [ 0.3500,  0.3443, -0.0640, -0.3020],
        [ 0.0855,  0.2554, -0.3448,  0.2174],
        [-0.4958,  0.4272,  0.3583,  0.3005]])), ('fc1.bias', tensor([-0.2882,  0.0052, -0.4574, -0.0070])), ('fc2.weight', tensor([[ 0.0156, -0.4211,  0.0527, -0.4432],
        [-0.2517,  0.2604, -0.1322, -0.1776]])), ('fc2.bias', tensor([-0.0076,  0.4532]))])


In [60]:
print(net(x))

tensor([[-0.0691, -0.0426]], grad_fn=<AddBackward0>)


In [46]:
clone = GliaNet(4)
clone.load_state_dict(net.state_dict())
print(clone.state_dict())

OrderedDict([('fc1.weight', tensor([[ 0.3393, -0.1487, -0.3940, -0.0267],
        [ 0.0843,  0.3380, -0.3070,  0.3612],
        [-0.0080, -0.3376, -0.0311,  0.2769],
        [-0.0065, -0.2602,  0.2513, -0.0621]])), ('fc1.bias', tensor([-0.0601, -0.3608, -0.0242, -0.2862])), ('fc2.weight', tensor([[-0.2137, -0.3110, -0.2661,  0.0418],
        [ 0.2199,  0.1162, -0.0369,  0.1128]])), ('fc2.bias', tensor([0.0565, 0.4851]))])


# Add noise to the net`

In [47]:
noise = WeightNoise(.05)
net.apply(noise)

GliaNet(
  (fc1): Slide(in_features=4, out_features=4, bias=True)
  (fc2): Gather(in_features=4, out_features=2, bias=True)
)

In [48]:
print(net.state_dict())

OrderedDict([('fc1.weight', tensor([[ 0.2928, -0.2119, -0.3444, -0.0609],
        [ 0.0761,  0.3833, -0.2535,  0.3296],
        [-0.0302, -0.3918, -0.0610,  0.2345],
        [-0.0604, -0.2299,  0.2196, -0.0627]])), ('fc1.bias', tensor([-0.0601, -0.3608, -0.0242, -0.2862])), ('fc2.weight', tensor([[-0.2165, -0.3550, -0.3198,  0.0012],
        [ 0.1935,  0.0760, -0.0800,  0.0693]])), ('fc2.bias', tensor([0.0565, 0.4851]))])


In [49]:
print(net(x))

tensor([[-0.0307,  0.0425]], grad_fn=<AddBackward0>)


# Drop connections

In [50]:
lost = WeightLoss(.1)
net.apply(lost)

GliaNet(
  (fc1): Slide(in_features=4, out_features=4, bias=True)
  (fc2): Gather(in_features=4, out_features=2, bias=True)
)

In [51]:
print(net.state_dict())

OrderedDict([('fc1.weight', tensor([[ 0.2928, -0.2119, -0.3444, -0.0609],
        [ 0.0761,  0.3833, -0.0000,  0.3296],
        [-0.0302, -0.3918, -0.0610,  0.2345],
        [-0.0604, -0.2299,  0.2196, -0.0627]])), ('fc1.bias', tensor([-0.0601, -0.3608, -0.0242, -0.2862])), ('fc2.weight', tensor([[-0.2165, -0.3550, -0.3198,  0.0012],
        [ 0.1935,  0.0760, -0.0800,  0.0693]])), ('fc2.bias', tensor([0.0565, 0.4851]))])


In [36]:
print(net(x))

tensor([[-0.0128,  0.0000]], grad_fn=<AddBackward0>)


# Check clone

In [53]:
print(clone.state_dict())

OrderedDict([('fc1.weight', tensor([[ 0.3393, -0.1487, -0.3940, -0.0267],
        [ 0.0843,  0.3380, -0.3070,  0.3612],
        [-0.0080, -0.3376, -0.0311,  0.2769],
        [-0.0065, -0.2602,  0.2513, -0.0621]])), ('fc1.bias', tensor([-0.0601, -0.3608, -0.0242, -0.2862])), ('fc2.weight', tensor([[-0.2137, -0.3110, -0.2661,  0.0418],
        [ 0.2199,  0.1162, -0.0369,  0.1128]])), ('fc2.bias', tensor([0.0565, 0.4851]))])
