In [1]:
import torch
import torch.nn as nn
import torchvision
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import torch.nn.functional as F
import numpy as np

# Parameters (change anything here!)
transform = transforms.ToTensor()
batch_size = 3
# Lifetime sparsity.
# NOTE(review): the model class defined in this cell does not read this
# module-level value; its forward pass supplies its own rate.
k_percent = 5


# Code starts here
# Data: MNIST training split, downloaded to ./data on first use
mnist_data = datasets.MNIST(root='./data', train = True, download = True, transform = transform)
data_loader = torch.utils.data.DataLoader(dataset= mnist_data, batch_size = batch_size, shuffle = True)
dataiter = iter(data_loader)
# Use the built-in next(): it works on every PyTorch version, whereas the
# iterator's `.next()` method is not available in recent releases.
images, labels = next(dataiter)


# testing model
''' Conv 2d Layer 
#         Accessible attributes: .weight (Tensor), .bias (Tensor)
#         parameters :
#         torch.nn.Conv2d(in_channels, out_channels, 
#                         kernel_size, stride=1, padding=0, 
#                         dilation=1, groups=1, bias=True, 
#                         padding_mode='zeros')
'''
# CONV-WTA CRITERIA
# - zero padded, so that each feature map has the same size as the input
# - hidden representation is mapped linearly to the output using a deconvolution operation
# - Parameters are optimized to reduce the mean squared error MSE
# - Conv layer is 5 x5, DECONVOLUTION layer is using filters of 11x 11
### In this implementation the decoder is a transposed convolution (nn.ConvTranspose2d) rather than a true deconvolution, to keep things simple
class Autoencoder_Test(nn.Module):
    """Single-layer convolutional winner-take-all (CONV-WTA) autoencoder.

    The encoder is one 5x5 convolution; its activations are sparsified twice
    (spatial sparsity, then lifetime sparsity) before an 11x11 transposed
    convolution maps them back to image space.

    Args:
        k_percent: fraction (0 < k_percent <= 1) of the batch that is allowed
            to keep its winner in each feature map during lifetime sparsity.
    """

    def __init__(self, k_percent=0.1):
        super().__init__()
        self.k_percent = k_percent

        # Input images: (N, 1, 28, 28). A 5x5 kernel with padding 2 keeps 28x28.
        self.conv1      = nn.Conv2d(1, 2, 5, stride=1, padding = 2)
        # An 11x11 transposed convolution with padding 5 also keeps the spatial size.
        # NOTE(review): the decoder emits 3 channels although the encoder input has 1;
        # kept as in the original -- confirm this is intended before training with MSE.
        self.transConv1 = nn.ConvTranspose2d(in_channels=2, out_channels=3, kernel_size=11, stride =1, padding = 5)

    def forward(self, x):
        """Encode, apply both sparsity constraints, then decode.

        Args:
            x: tensor of shape (N, 1, H, W).

        Returns:
            Tensor of shape (N, 3, H, W).
        """
        encoded = self.conv1(x)  # (N, 2, H, W)
        hidden, winners = self.spatial_sparsity_(encoded)
        hidden = self.lifetime_sparsity_(hidden, winners, k_percent = self.k_percent)
        decoded = self.transConv1(hidden)
        return decoded

    def spatial_sparsity_(self, hiddenMaps):
        """Keep only the maximum activation of every feature map.

        Args:
            hiddenMaps: tensor of shape (n_batches, n_features, H, W).

        Returns:
            (sparse, maxval): `sparse` has the shape of `hiddenMaps` with every
            value below its map's maximum set to zero (ties with the maximum
            are all kept); `maxval` has shape (n_batches, n_features) and holds
            the winning value of each map, detached from the graph.
        """
        n_batches, n_features = hiddenMaps.shape[:2]

        # Only the mask is built without gradient tracking; the multiplication
        # happens outside so gradients still reach the encoder through the winners.
        with torch.no_grad():
            # Step 1: flatten each map and find its maximum
            flatten = hiddenMaps.reshape(n_batches, n_features, -1)
            maxval, _ = torch.max(flatten, 2)  # (n_batches, n_features)

            # Step 2: 0/1 mask that drops every losing position. Deriving it from
            # the comparison keeps it on the same device/dtype as the input.
            maxval_p = maxval.reshape(n_batches, n_features, 1, 1)
            drop = (hiddenMaps >= maxval_p).to(hiddenMaps.dtype)

        return hiddenMaps * drop, maxval

    def lifetime_sparsity_(self, hiddenMaps, maxval, k_percent):
        """Keep, per feature, only the top-k winners across the batch.

        Args:
            hiddenMaps: spatially sparse maps, shape (n_batches, n_features, H, W).
            maxval: winner values from `spatial_sparsity_`, shape
                (n_batches, n_features).
            k_percent: fraction of the batch to keep per feature. The resulting
                k is clamped to [1, n_batches] so a small batch never ends up
                with k == 0 (which would make the threshold slice empty).

        Returns:
            Tensor shaped like `hiddenMaps` in which the maps of all samples
            outside the top-k for a feature are zeroed.
        """
        n_batches, n_features = maxval.shape
        k = min(n_batches, max(1, int(n_batches * k_percent)))

        with torch.no_grad():
            # Step 1: rank the winners of every feature along the batch axis
            winner_t = torch.transpose(maxval, 0, 1)  # (n_features, n_batches)
            top_k, _ = torch.topk(winner_t, k, 1)     # (n_features, k)

            # Step 2: the k-th largest winner is the survival threshold
            drop = (winner_t >= top_k[:, k-1:k]).to(hiddenMaps.dtype)
            drop = torch.transpose(drop, 0, 1)        # (n_batches, n_features)

        # Zero every (sample, feature) map whose winner did not make the cut
        return hiddenMaps * drop.reshape(n_batches, n_features, 1, 1)
    
def _lifetime_sparsity(self, h, winner, rate):
    """TensorFlow reference for lifetime sparsity.

    `winner` is indexed as (n, c). For every column c only the `rate * n`
    largest entries survive; the resulting 0/1 mask is reshaped to
    (n, 1, 1, c) and multiplied into `h`.

    NOTE(review): `h` is presumably laid out as (n, height, width, c) -- confirm.
    NOTE(review): this name is defined twice in the notebook; the later
    definition shadows this one. `tf` must already be imported when this runs.
    """
    winner_shape = tf.shape(winner)
    n, c = winner_shape[0], winner_shape[1]
    k = tf.cast(rate * tf.cast(n, tf.float32), tf.int32)

    winner_t = tf.transpose(winner)        # c, n
    th_k, _ = tf.nn.top_k(winner_t, k)     # c, k
    threshold = th_k[:, k-1:k]             # c, 1 -> k-th largest per row

    mask_shape = tf.stack([c, n])
    zeros = tf.zeros(mask_shape, tf.float32)
    ones = tf.ones(mask_shape, tf.float32)
    drop = tf.where(winner_t < threshold, zeros, ones)  # c, n
    drop = tf.transpose(drop)                           # n, c
    return h * tf.reshape(drop, tf.stack([n, 1, 1, c]))
    
# Build the autoencoder together with its loss function and optimizer.
model = Autoencoder_Test()
generator = model.parameters()  # parameters() hands back a generator
# criterion = RMSELoss()
criterion = nn.MSELoss()
optimizer = optim.Adam(
    params=model.parameters(),
    lr=1e-3,
    weight_decay=1e-5,
)

In [None]:
# Flow of the example code in tf:
def loss(self, x, lifetime_sparsity=0.20):
    """Sketch of the reference pipeline: encode, sparsify twice, decode.

    The decoded value is computed but nothing is returned.
    """
    encoded = self.encoder(x)
    sparse, winner = self._spatial_sparsity(encoded)
    sparse = self._lifetime_sparsity(sparse, winner, lifetime_sparsity)
    y = self._decoder(sparse)
    
# def _spatial_sparsity(self, h):
#     shape = tf.shape(h)
#     n = shape[0]
#     c = shape[3]

#     h_t = tf.transpose(h, [0, 3, 1, 2]) # n, c, h, w
#     h_r = tf.reshape(h_t, tf.stack([n, c, -1])) # n, c, h*w

#     th, _ = tf.nn.top_k(h_r, 1) # n, c, 1
#     th_r = tf.reshape(th, tf.stack([n, 1, 1, c])) # n, 1, 1, c
#     drop = tf.where(h < th_r, 
#       tf.zeros(shape, tf.float32), tf.ones(shape, tf.float32))

#     # spatially dropped & winner
#     return h*drop, tf.reshape(th, tf.stack([n, c])) # n, c

def _lifetime_sparsity(self, h, winner, rate):
    """TensorFlow reference for lifetime sparsity.

    `winner` is indexed as (n, c). For every column c only the `rate * n`
    largest entries survive; the resulting 0/1 mask is reshaped to
    (n, 1, 1, c) and multiplied into `h`.

    NOTE(review): `h` is presumably laid out as (n, height, width, c) -- confirm.
    NOTE(review): this name is also defined in the first cell of the notebook;
    this later definition shadows it. `tf` must already be imported when this runs.
    """
    winner_shape = tf.shape(winner)
    n, c = winner_shape[0], winner_shape[1]
    k = tf.cast(rate * tf.cast(n, tf.float32), tf.int32)

    winner_t = tf.transpose(winner)        # c, n
    th_k, _ = tf.nn.top_k(winner_t, k)     # c, k
    threshold = th_k[:, k-1:k]             # c, 1 -> k-th largest per row

    mask_shape = tf.stack([c, n])
    zeros = tf.zeros(mask_shape, tf.float32)
    ones = tf.ones(mask_shape, tf.float32)
    drop = tf.where(winner_t < threshold, zeros, ones)  # c, n
    drop = tf.transpose(drop)                           # n, c
    return h * tf.reshape(drop, tf.stack([n, 1, 1, c]))

In [12]:
# Scratch cell: prototypes the spatial-sparsity mask in PyTorch and walks through
# the TensorFlow lifetime-sparsity reference step by step, printing intermediates.
# Fix the seed to get the same outputs on every run.
# NOTE(review): these imports sit in the middle of the notebook, yet functions
# defined in earlier cells already reference `tf`.
import tensorflow as tf
import math as m
torch.manual_seed(10)

# torch.max experimentation
# a = torch.rand(4,4)
# print(a)
# print(a.view(-1, 4*4)) #same
# print(a.view(1, -1))   #same
# flatten = a.view(1, -1)
# maxval, idx = torch.max(flatten, 1)
# print(maxval, idx)
# print("y: ", m.floor(idx.item()/4)) # y axis, or dim - 0
# print("x: ",idx.item() %4)             # x-axis, or dim - 1

# Side length of the square toy feature maps used below
size = 4

# Create a random array shaped (3, 2, size, size)
b = torch.rand(3, 2, size, size)
# print(b)
# The result of this call is discarded; the next line keeps the flattened view.
b.view(3, 2, -1)
flatten = b.view(3, 2, -1)
# print(flatten)
# Maximum over the flattened last axis -> one value (and index) per (dim-0, dim-1) pair
maxval, idx = torch.max(flatten, 2)
# print(maxval.shape)

# Reshape to (3, 2, 1, 1) so the comparison below broadcasts over the last two axes
maxval = torch.reshape(maxval, (3, 2, 1, 1))
# print(maxval)

# Boolean mask of the entries below their map's maximum (not used again in this cell)
test = torch.reshape(b, (3, 2, size, size)) < maxval
# print(test)

# 0/1 mask: 0 where the entry is below its map's maximum, 1 otherwise
drop = torch.where(torch.reshape(b, (3, 2, size, size)) < maxval, 
                   torch.zeros((3, 2, size, size)), 
                   torch.ones((3, 2, size, size)))

# print(drop*b)
# print(drop)

# # Setting all the existing non-zero loser values into 0
# - flow
# shape = tf.shape(h) initial shape  GUESS: (n, h, w, c)
# n = number of batches?
# c = number of features?
# transpose(hidden, [0, 3, 1, 2]) # n, c, h, w (changing the order?)
# reshape(transposed hidden, stack([n, c, -1])) # n, c, h*w
# th, _ = get_top(reshaped hidden, one value) # n, 1, 1, c

# Lifetime sparsity flow, replayed on a small 4x2 constant
t = tf.constant([[1.0, 2.0], 
                 [3.0, 4.0], 
                 [5.0, 6.0], 
                 [7.0, 8.0]])
rate = 0.5
shape = tf.shape(t)
n = shape[0]
c = shape[1]
print(shape)
print(n, c)
# k = rate * n, cast to int32
k = tf.cast(rate * tf.cast(n, tf.float32), tf.int32)
print(k)
print(tf.stack([n,c]))
t = tf.transpose(t) # c, n
print(t)
th_k, _ = tf.nn.top_k(t, k) # c, k
print("topk:\n", th_k)
# Last of the k columns = k-th largest value of every row
print(th_k[:,k-1:k])
# shape_t = tf.stack([c, n])
# drop = tf.where(winner < th_k[:,k-1:k], # c, n
#   tf.zeros(shape_t, tf.float32), tf.ones(shape_t, tf.float32))
# drop = tf.transpose(drop) # n, c
# return h * tf.reshape(drop, tf.stack([n, 1, 1, c]))

# Same top-k threshold slice in PyTorch; note that `k` and `a` are rebound here
k = torch.tensor(2)
a = torch.arange(4*3).view(4, 3) #(n, c)
a = torch.transpose(a, 0, 1) #(c, n)
# print(a)
topk, _ = torch.topk(a, k, 1)
# print(topk)

# Final expression of the cell: its value is what the notebook displays
topk[: , k-1:k]

tf.Tensor([4 2], shape=(2,), dtype=int32)
tf.Tensor(4, shape=(), dtype=int32) tf.Tensor(2, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)
tf.Tensor([4 2], shape=(2,), dtype=int32)
tf.Tensor(
[[1. 3. 5. 7.]
 [2. 4. 6. 8.]], shape=(2, 4), dtype=float32)
topk:
 tf.Tensor(
[[7. 5.]
 [8. 6.]], shape=(2, 2), dtype=float32)
tf.Tensor(
[[5.]
 [6.]], shape=(2, 1), dtype=float32)


tensor([[6],
        [7],
        [8]])

In [93]:
# Broadcasting demo: multiplying a (2, 3, 1, 1) tensor with a (2, 3, 4, 4) tensor
# scales every 4x4 map by its own scalar.
a = torch.arange(96).reshape(2, 3, 4, 4)
print(a)
b = torch.arange(6).reshape(2, 3, 1, 1)
print(b)
print(torch.mul(b, a).shape)
print(torch.mul(b, a))

tensor([[[[ 0,  1,  2,  3],
          [ 4,  5,  6,  7],
          [ 8,  9, 10, 11],
          [12, 13, 14, 15]],

         [[16, 17, 18, 19],
          [20, 21, 22, 23],
          [24, 25, 26, 27],
          [28, 29, 30, 31]],

         [[32, 33, 34, 35],
          [36, 37, 38, 39],
          [40, 41, 42, 43],
          [44, 45, 46, 47]]],


        [[[48, 49, 50, 51],
          [52, 53, 54, 55],
          [56, 57, 58, 59],
          [60, 61, 62, 63]],

         [[64, 65, 66, 67],
          [68, 69, 70, 71],
          [72, 73, 74, 75],
          [76, 77, 78, 79]],

         [[80, 81, 82, 83],
          [84, 85, 86, 87],
          [88, 89, 90, 91],
          [92, 93, 94, 95]]]])
tensor([[[[0]],

         [[1]],

         [[2]]],


        [[[3]],

         [[4]],

         [[5]]]])
torch.Size([2, 3, 4, 4])
tensor([[[[  0,   0,   0,   0],
          [  0,   0,   0,   0],
          [  0,   0,   0,   0],
          [  0,   0,   0,   0]],

         [[ 16,  17,  18,  19],
          [ 20,  21, 