# Weight Sharing in PyTorch

Unfortunately, PyTorch does not support weight sharing. Instead, we will provide a high-level description of a possible implementation. In this example, we will provide a custom implementation of weight sharing for PyTorch

In [23]:
import torch
import torch.quantization
import torch.nn as nn

from sklearn.cluster import KMeans

torch.manual_seed(0)  # set the seed for reproducibility

<torch._C.Generator at 0x7f4815b25fa8>

### Model with weight clustering support

In [24]:
# A model with few linear layers 
class SampleLinearModel(torch.nn.Module): 

    def __init__(self): 
        super(SampleLinearModel, self).__init__() 
        self.linear = torch.nn.Linear(10, 10)
        
    def cluster_weights(self, num_cluster): 

        # cluster weights of the layer          
        km = KMeans(
            n_clusters=num_cluster, init='random',
            n_init=10, max_iter=300, 
            tol=1e-04, random_state=0
        )
        
        # construct a mapping from a cluster index to a centroid value and store at self.weights_mapping 
        weights = model.linear.weight.reshape([-1, 1]).detach().numpy()
        self.weights_cluster = km.fit_predict(weights)

        # find cluster index for each weight value and store at self.weights_cluster 
        self.weights_mapping = km.cluster_centers_

        # drop the original weights to reduce the model size 
        self.linear.weight = None

    def forward(self, x): 
        if self.training:
            x = self.linear(x)
        else: # in eval mode
            # update weights of the self.layer by reassigning each value based on 
            # self.weights_cluster and self.weights_mapping 
            self.linear.weight = torch.nn.Parameter(torch.Tensor(self.weights_mapping[self.weights_cluster]).reshape(10,10))
            x = self.linear(x)
        return x 

### train the model

In [25]:
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self):
        self.num_samples = 100
        self.data = torch.rand([self.num_samples, 10])
        self.label = torch.rand([self.num_samples, 1])

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return self.data[idx], self.label[idx]


train_dataset = CustomDataset()
training_data_loader = torch.utils.data.DataLoader(train_dataset)

In [26]:
model = SampleLinearModel() 

# train the model
model.train()
mse_loss = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
for data, label in training_data_loader:
    optimizer.zero_grad()
    pred = model(data)
    loss = mse_loss(pred, label)
    loss.backward()
    optimizer.step()

  return F.mse_loss(input, target, reduction=self.reduction)


In [27]:
original_output = model(train_dataset[0][0])
print(original_output)

tensor([ 0.0296, -0.1846, -0.0341, -0.0587, -0.0704,  0.2134, -0.1665,  0.4920,
         0.1367,  0.0739], grad_fn=<AddBackward0>)


### check the original model size

In [28]:
import os

# save the model and check the model size
def print_size_of_model(model, label=""):
    torch.save(model.state_dict(), "temp.p")
    size=os.path.getsize("temp.p")
    print("model: ",label,' \t','Size (KB):', size/1e3)
    os.remove('temp.p')
    return size


In [29]:
clustered_model = SampleLinearModel() 
original_model_size=print_size_of_model(model,"original model")

model:  original model  	 Size (KB): 1.511


### Apply weight clustering

In [30]:
model.eval()
model.cluster_weights(5)

In [31]:
print("linear.weights_mapping: \n", model.weights_mapping)
print("linear.weights_cluster: \n", model.weights_cluster)
print("linear.weight: \n", model.linear.weight)

linear.weights_mapping: 
 [[-0.20028913]
 [-0.06823479]
 [ 0.17053993]
 [ 0.02950809]
 [ 0.28837958]]
linear.weights_cluster: 
 [3 1 4 4 0 2 0 0 2 3 1 0 1 2 0 0 2 0 0 1 3 0 0 3 0 3 1 0 2 0 0 4 2 4 1 1 3
 3 4 0 3 2 1 1 2 0 2 0 2 0 4 0 0 3 4 0 1 4 3 2 0 1 1 4 2 0 2 3 1 1 1 4 0 0
 1 3 2 3 1 0 3 1 4 0 0 0 3 4 4 0 3 2 2 4 3 0 4 1 2 1]
linear.weight: 
 None


In [33]:
clustered_output = model(train_dataset[0][0])
print(clustered_output)

tensor([ 0.0817, -0.1601, -0.0350, -0.0173, -0.1515,  0.2588, -0.2154,  0.4650,
         0.1360,  0.0468], grad_fn=<AddBackward0>)


### compare the difference in model size

In [32]:
# compare the sizes
clustered_model_size=print_size_of_model(model,"clustered model")
print("{0:.2f} times smaller".format(original_model_size/clustered_model_size))

model:  clustered model  	 Size (KB): 0.811
1.86 times smaller
