In [80]:
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from torch.utils.data import DataLoader

import copy

from functools import reduce

In [8]:
class Net(nn.Module):
    """Convolutional Neural Network architecture.

    As described in McMahan 2017 paper :

    [Communication-Efficient Learning of Deep Networks from
    Decentralized Data] (https://arxiv.org/pdf/1602.05629.pdf)
    """

    def __init__(self, num_classes: int = 10) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 5, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 5, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=(2, 2), padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """Forward pass of the CNN.

        Parameters
        ----------
        x : torch.Tensor
            Input Tensor that will pass through the network

        Returns
        -------
        torch.Tensor
            The resulting Tensor after it has passed through the network
        """
        output_tensor = F.relu(self.conv1(input_tensor))
        output_tensor = self.pool(output_tensor)
        output_tensor = F.relu(self.conv2(output_tensor))
        output_tensor = self.pool(output_tensor)
        output_tensor = torch.flatten(output_tensor, 1)
        output_tensor = F.relu(self.fc1(output_tensor))
        output_tensor = self.fc2(output_tensor)
        return output_tensor


In [9]:
model = Net()

In [10]:
result = []
for layer in model.parameters():
    reshape_tensor = torch.flatten(layer.data)
#     print(reshape_tensor)
#     print(layer.data.shape)
#     print(reshape_tensor.shape)

    result.append(reshape_tensor)
result = torch.cat(result)


In [11]:
print(result.shape)    

torch.Size([1663370])


In [12]:
a = torch.tensor([1,2,3,4])

In [13]:
a.shape

torch.Size([4])

In [70]:
def compress_model_mask(net: nn.Module, mask: torch.tensor) -> nn.Module:
    list_of_reshaped_layers = []
    list_of_shapes = []
    for layer in net.parameters():
        reshaped_layer = torch.flatten(layer.data)
        list_of_reshaped_layers.append(reshaped_layer)
        shape = reduce((lambda x, y: x * y), list(layer.data.shape))
        list_of_shapes.append(shape)
    cat_full_vec = torch.cat(list_of_reshaped_layers)
    compressed_full_vec = torch.mul(cat_full_vec, mask)
    
    
    compressed_split_vec = torch.split(compressed_full_vec, list_of_shapes)
#     print(compressed_split_vec[0].shape)
    i = 0
    for layer in net.parameters():
        layer.data = compressed_split_vec[i].reshape(layer.data.shape)
        i+=1
        print(layer.data[0])
    
    return net


    

In [75]:
model = Net()

In [81]:
net = copy.deepcopy(model)

In [82]:
mask = torch.zeros_like(cat_full_vec)

In [86]:
for i in range(len(mask)):
    if i%2==0:
        mask[i] = 1
    else:
        mask[i] = 0

In [88]:
for layer in model.parameters():
    print(layer.data[0])

tensor([[[ 0.1910,  0.0403,  0.0855, -0.1819, -0.0555],
         [ 0.1871, -0.0156, -0.0966, -0.1714, -0.1367],
         [ 0.0938,  0.1438,  0.0437, -0.1529, -0.0639],
         [ 0.1319,  0.1302, -0.1780, -0.0480,  0.0627],
         [ 0.1551, -0.1815,  0.0705,  0.1847, -0.0073]]])
tensor(-0.0489)
tensor([[[ 0.0180, -0.0072,  0.0226,  0.0126, -0.0193],
         [ 0.0347,  0.0342, -0.0311,  0.0292, -0.0172],
         [-0.0209, -0.0238, -0.0049,  0.0300, -0.0198],
         [ 0.0275,  0.0207,  0.0145,  0.0140,  0.0115],
         [-0.0245,  0.0028,  0.0339, -0.0344,  0.0332]],

        [[-0.0187, -0.0227, -0.0046, -0.0063, -0.0271],
         [-0.0262, -0.0262, -0.0195, -0.0245, -0.0149],
         [ 0.0350,  0.0324, -0.0050,  0.0344, -0.0008],
         [-0.0283, -0.0238,  0.0231,  0.0300, -0.0185],
         [-0.0317, -0.0272,  0.0226,  0.0173, -0.0218]],

        [[-0.0044, -0.0055, -0.0275,  0.0313, -0.0004],
         [-0.0119, -0.0224, -0.0065,  0.0314,  0.0227],
         [-0.0056,  0.0319

In [89]:
compress_model_mask(net, mask)

tensor([[[ 0.1910,  0.0000,  0.0855, -0.0000, -0.0555],
         [ 0.0000, -0.0156, -0.0000, -0.1714, -0.0000],
         [ 0.0938,  0.0000,  0.0437, -0.0000, -0.0639],
         [ 0.0000,  0.1302, -0.0000, -0.0480,  0.0000],
         [ 0.1551, -0.0000,  0.0705,  0.0000, -0.0073]]])
tensor(-0.0489)
tensor([[[ 0.0180, -0.0000,  0.0226,  0.0000, -0.0193],
         [ 0.0000,  0.0342, -0.0000,  0.0292, -0.0000],
         [-0.0209, -0.0000, -0.0049,  0.0000, -0.0198],
         [ 0.0000,  0.0207,  0.0000,  0.0140,  0.0000],
         [-0.0245,  0.0000,  0.0339, -0.0000,  0.0332]],

        [[-0.0000, -0.0227, -0.0000, -0.0063, -0.0000],
         [-0.0262, -0.0000, -0.0195, -0.0000, -0.0149],
         [ 0.0000,  0.0324, -0.0000,  0.0344, -0.0000],
         [-0.0283, -0.0000,  0.0231,  0.0000, -0.0185],
         [-0.0000, -0.0272,  0.0000,  0.0173, -0.0000]],

        [[-0.0044, -0.0000, -0.0275,  0.0000, -0.0004],
         [-0.0000, -0.0224, -0.0000,  0.0314,  0.0000],
         [-0.0056,  0.0000

Net(
  (conv1): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(1, 1))
  (pool): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=1, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=3136, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=10, bias=True)
)

In [14]:
for layer in model.parameters():
    shape = reduce((lambda x, y: x * y), list(layer.data.shape))
    print(shape)

800
32
51200
64
1605632
512
5120
10


In [31]:
mask = torch.zeros_like(cat_full_vec)

In [33]:
mask.shape

torch.Size([1663370])

In [17]:
list_of_reshaped_layers = []
list_of_shapes = []
for layer in model.parameters():
    reshaped_layer = torch.flatten(layer.data)
    list_of_reshaped_layers.append(reshaped_layer)
    shape = reduce((lambda x, y: x * y), list(layer.data.shape))
    list_of_shapes.append(shape)
cat_full_vec = torch.cat(list_of_reshaped_layers)

compressed_full_vec = torch.mul(cat_full_vec, mask)
compressed_split_vec = torch.split(compressed_full_vec, list_of_shapes)
i = 0
for layer in model.parameters():
    layer.data = compressed_split_vec[i]
    i+=1
    
print(list_of_shapes)
print(cat_full_vec.shape[0])
print(compressed_split_vec[7].shape)
for layer in model.parameters():
    print(layer.data)

NameError: name 'mask' is not defined