In [22]:
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from torch.utils.data import DataLoader

from functools import reduce

In [2]:
class Net(nn.Module):
    """Convolutional Neural Network architecture.

    As described in McMahan 2017 paper:

    [Communication-Efficient Learning of Deep Networks from
    Decentralized Data] (https://arxiv.org/pdf/1602.05629.pdf)

    Layout: two 5x5 convolutions (32 and 64 channels), each followed by a
    ReLU and a 2x2 max-pool, then a 512-unit fully connected layer with ReLU
    and a final linear layer producing ``num_classes`` logits.

    The ``64 * 7 * 7`` input size of ``fc1`` corresponds to a single-channel
    28x28 input: 28 -> 26 (conv1, padding=1) -> 14 (pool, padding=1)
    -> 12 (conv2, padding=1) -> 7 (pool, padding=1). Other spatial sizes
    will fail at ``fc1``.

    Parameters
    ----------
    num_classes : int, optional
        Number of output logits, by default 10.
    """

    def __init__(self, num_classes: int = 10) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 5, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 5, padding=1)
        # Stride defaults to kernel_size, so each pool roughly halves H and W;
        # padding=1 rounds odd/even sizes up (26 -> 14, 12 -> 7).
        self.pool = nn.MaxPool2d(kernel_size=(2, 2), padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """Forward pass of the CNN.

        Parameters
        ----------
        input_tensor : torch.Tensor
            Batch of shape ``(N, 1, 28, 28)`` that will pass through the
            network (one input channel for ``conv1``; 28x28 so that the
            flattened features match ``fc1``).

        Returns
        -------
        torch.Tensor
            Tensor of shape ``(N, num_classes)`` holding the raw output of
            ``fc2`` (no softmax is applied).
        """
        output_tensor = F.relu(self.conv1(input_tensor))
        output_tensor = self.pool(output_tensor)
        output_tensor = F.relu(self.conv2(output_tensor))
        output_tensor = self.pool(output_tensor)
        # Keep the batch dimension, flatten channels x height x width.
        output_tensor = torch.flatten(output_tensor, 1)
        output_tensor = F.relu(self.fc1(output_tensor))
        output_tensor = self.fc2(output_tensor)
        return output_tensor


In [3]:
# Fix the RNG before building the model so parameter initialisation (and
# therefore everything derived from it) is reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
model = Net()

In [49]:
# Flatten every parameter tensor of `model` (read through `.data`, i.e. without
# autograd tracking) and concatenate them, in `parameters()` order, into a
# single 1-D vector.
result = torch.cat([torch.flatten(param.data) for param in model.parameters()])


In [50]:
# Total number of scalars across all model parameters.
print(result.size())

torch.Size([1663370])


In [31]:
a = torch.arange(1, 5)  # int64 tensor [1, 2, 3, 4]

In [32]:
a.size()

torch.Size([4])

In [None]:
def Compress_Model_Mask(net: nn.Module, mask: torch.Tensor) -> nn.Module:
    """Flatten model parameters into one vector and multiply it elementwise by ``mask``.

    Parameters
    ----------
    net : nn.Module
        Model whose parameters are meant to be masked.
        NOTE(review): the loop below iterates over the global ``model``,
        not ``net``, so this argument is currently ignored. Looks like a
        bug; confirm and switch to ``net.parameters()``.
    mask : torch.Tensor
        Multiplied elementwise with the concatenated parameter vector.
        Presumably a 0/1 vector with one entry per parameter scalar;
        TODO confirm.

    Returns
    -------
    nn.Module
        NOTE(review): no ``return`` statement is visible so far. If the
        function ends here it returns ``None``, contradicting the
        annotation. Confirm the intended result.
    """
    reshape_list = []
    # Flatten each parameter tensor (read through ``.data``, outside autograd).
    for layer in model.parameters():
        reshape_tensor = torch.flatten(layer.data)
        reshape_list.append(reshape_tensor)
    # One 1-D vector holding all parameters in ``parameters()`` order.
    cat_full_vec = torch.cat(reshape_list)
    # Elementwise product; ``mask`` must be broadcastable to ``cat_full_vec``.
    compressed_full_vec = torch.mul(cat_full_vec, mask)
    