In [1]:
import os
import numpy as np

from IPython.display import Image
import matplotlib
import matplotlib.pyplot as plt

%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import seaborn as sns
sns.set(font_scale=2)
sns.set_style('ticks')

matplotlib.rcParams.update({'font.size': 16})
matplotlib.rc('axes', titlesize=16)

import torch
import torch.functional as F
import glob
from collections import defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import save_image

from glia.gn import WeightNoise
from glia.gn import WeightLoss
from glia import gn

class GliaNet(nn.Module):
    """A simple test model"""

    def __init__(self, z_features=10, activation_function='Softmax'):
        # --------------------------------------------------------------------
        # Init
        super().__init__()
        if z_features < 2:
            raise ValueError("z_features must be > 2.")
        self.z_features = z_features

        # Lookup activation function (a class)
        AF = getattr(nn, activation_function)
#         self.phi = AF(dim=1)
        
        # --------------------------------------------------------------------
        # Def fc1:
        self.fc1 = gn.Slide(self.z_features)
        self.fc2 = gn.Gather(self.z_features)

    def forward(self, x, verbose=False):
        x = self.fc1(x)
        if verbose: print(x)
        x = self.fc2(x)
        if verbose: print(x)
            
        return x

# Create some shared input

In [2]:
x = torch.rand(1, 4)
net = GliaNet(4)
print(x)

tensor([[0.9521, 0.1484, 0.4430, 0.0227]])


In [3]:
print(net.state_dict())

OrderedDict([('fc1.weight', tensor([[-0.3427, -0.2033,  0.4842,  0.1578],
        [ 0.3134,  0.2830,  0.4041, -0.4798],
        [-0.1142,  0.2313,  0.4118, -0.2118],
        [-0.4191, -0.0824,  0.4373, -0.3539]])), ('fc1.bias', tensor([-0.3990, -0.2006,  0.2941, -0.3459])), ('fc2.weight', tensor([[ 0.4834,  0.4472, -0.4185,  0.0995],
        [ 0.0701,  0.0525,  0.2630, -0.0287]])), ('fc2.bias', tensor([ 0.3917, -0.3640]))])


In [4]:
print(net(x))

tensor([[0.1261, 0.0085]], grad_fn=<AddBackward0>)


# Add noise to the net`

In [5]:
noise = WeightNoise(.05)
net.apply(noise)

GliaNet(
  (fc1): Slide(in_features=4, out_features=4, bias=True)
  (fc2): Gather(in_features=4, out_features=2, bias=True)
)

In [6]:
print(net.state_dict())

OrderedDict([('fc1.weight', tensor([[-0.3128, -0.2219,  0.5441,  0.1779],
        [ 0.3258,  0.3897,  0.4004, -0.4754],
        [-0.2146,  0.2446,  0.3605, -0.1491],
        [-0.4010, -0.1423,  0.3702, -0.3478]])), ('fc1.bias', tensor([-0.3990, -0.2006,  0.2941, -0.3459])), ('fc2.weight', tensor([[ 0.5563,  0.4668, -0.5004,  0.1363],
        [ 0.0401,  0.1108,  0.2897, -0.1559]])), ('fc2.bias', tensor([ 0.3917, -0.3640]))])


In [7]:
print(net(x))

tensor([[0.1393, 0.0226]], grad_fn=<AddBackward0>)


# Drop connections

In [34]:
lost = WeightLoss(.1)
net.apply(lost)

GliaNet(
  (fc1): Slide(in_features=4, out_features=4, bias=True)
  (fc2): Gather(in_features=4, out_features=2, bias=True)
)

In [35]:
print(net.state_dict())

OrderedDict([('fc1.weight', tensor([[-0.0000, -0.2219,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.4004, -0.0000],
        [-0.2146,  0.0000,  0.0000, -0.0000],
        [-0.0000, -0.0000,  0.0000, -0.0000]])), ('fc1.bias', tensor([-0.3990, -0.2006,  0.2941, -0.3459])), ('fc2.weight', tensor([[0.5563, 0.0000, -0.0000, 0.1363],
        [0.0000, 0.0000, 0.2897, -0.0000]])), ('fc2.bias', tensor([ 0.3917, -0.3640]))])


In [36]:
print(net(x))

tensor([[-0.0128,  0.0000]], grad_fn=<AddBackward0>)
