In [2]:
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from torch.utils.data import DataLoader

from functools import reduce

In [3]:
class Net(nn.Module):
    """Convolutional Neural Network architecture.

    As described in McMahan 2017 paper :

    [Communication-Efficient Learning of Deep Networks from
    Decentralized Data] (https://arxiv.org/pdf/1602.05629.pdf)
    """

    def __init__(self, num_classes: int = 10) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 5, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 5, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=(2, 2), padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """Forward pass of the CNN.

        Parameters
        ----------
        x : torch.Tensor
            Input Tensor that will pass through the network

        Returns
        -------
        torch.Tensor
            The resulting Tensor after it has passed through the network
        """
        output_tensor = F.relu(self.conv1(input_tensor))
        output_tensor = self.pool(output_tensor)
        output_tensor = F.relu(self.conv2(output_tensor))
        output_tensor = self.pool(output_tensor)
        output_tensor = torch.flatten(output_tensor, 1)
        output_tensor = F.relu(self.fc1(output_tensor))
        output_tensor = self.fc2(output_tensor)
        return output_tensor


In [4]:
model = Net()

In [5]:
result = []
for layer in model.parameters():
    reshape_tensor = torch.flatten(layer.data)
#     print(reshape_tensor)
#     print(layer.data.shape)
#     print(reshape_tensor.shape)

    result.append(reshape_tensor)
result = torch.cat(result)


In [6]:
print(result.shape)    

torch.Size([1663370])


In [7]:
a = torch.tensor([1,2,3,4])

In [8]:
a.shape

torch.Size([4])

In [9]:
def compress_model_mask(net: nn.Module, mask: torch.tensor) -> nn.Module:
    list_of_reshaped_layers = []
    list_of_shapes = []
    for layer in net.parameters():
        reshaped_layer = torch.flatten(layer.data)
        list_of_reshaped_layers.append(reshaped_layer)
        shape = reduce((lambda x, y: x * y), list(layer.data.shape))
        list_of_shapes.append(shape)
    cat_full_vec = torch.cat(reshape_list)
    compressed_full_vec = torch.mul(cat_full_vec, mask)
    
    compressed_split_vec = torch.split(compressed_full_vec, list_of_shapes)
    i = 0
    for layer in net.parameters():
        layer.data = compressed_split_vec[i]
        i+=1
    
    return net


    

In [10]:
model = Net()

In [14]:
for layer in model.parameters():
    shape = reduce((lambda x, y: x * y), list(layer.data.shape))
    print(shape)

800
32
51200
64
1605632
512
5120
10


In [31]:
mask = torch.zeros_like(cat_full_vec)

In [33]:
mask.shape

torch.Size([1663370])

In [36]:
list_of_reshaped_layers = []
list_of_shapes = []
for layer in model.parameters():
    reshaped_layer = torch.flatten(layer.data)
    list_of_reshaped_layers.append(reshaped_layer)
    shape = reduce((lambda x, y: x * y), list(layer.data.shape))
    list_of_shapes.append(shape)
cat_full_vec = torch.cat(list_of_reshaped_layers)

compressed_full_vec = torch.mul(cat_full_vec, mask)
compressed_split_vec = torch.split(compressed_full_vec, list_of_shapes)
i = 0
for layer in model.parameters():
    layer.data = compressed_split_vec[i]
    i+=1
    
print(list_of_shapes)
print(cat_full_vec.shape[0])
print(compressed_split_vec[7].shape)
for layer in model.parameters():
    print(layer.data)

[800, 32, 51200, 64, 1605632, 512, 5120, 10]
1663370
torch.Size([10])
tensor([0., -0., 0., -0., -0., 0., 0., -0., 0., 0., -0., -0., -0., 0., -0., 0., 0., 0., 0., 0., 0., 0., -0., -0.,
        -0., 0., 0., 0., -0., 0., 0., -0., -0., 0., -0., 0., -0., 0., 0., 0., 0., 0., -0., 0., 0., 0., 0., -0.,
        0., 0., -0., -0., 0., -0., 0., 0., -0., -0., 0., -0., -0., -0., -0., 0., 0., -0., -0., 0., -0., 0., -0., -0.,
        -0., 0., 0., 0., -0., -0., -0., 0., 0., 0., -0., 0., -0., -0., -0., -0., -0., -0., -0., -0., 0., 0., -0., 0.,
        0., 0., -0., -0., 0., -0., -0., -0., -0., -0., -0., -0., -0., 0., -0., -0., -0., 0., 0., -0., 0., 0., -0., -0.,
        -0., -0., -0., 0., 0., 0., -0., -0., 0., 0., 0., -0., -0., 0., 0., -0., -0., -0., 0., -0., -0., 0., 0., 0.,
        0., 0., -0., -0., -0., -0., -0., 0., -0., -0., -0., -0., -0., -0., 0., -0., -0., -0., 0., -0., 0., -0., 0., 0.,
        0., -0., 0., 0., 0., -0., 0., 0., 0., -0., -0., -0., -0., 0., -0., 0., -0., 0., 0., 0., -0., 0., 0., 0.,