# Pruning a 2D Convolutional Autoencoder

In [None]:
import torch
import torchvision
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import save_image
import torch.nn.utils.prune as prune
import os
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset
import matplotlib.pyplot as plt



# Preprocessing pipeline for a PIL image / ndarray:
#   1. ToTensor: HWC -> CHW float tensor (uint8 input is scaled to [0, 1]);
#   2. Normalize: (x - 0.5) / 0.5, i.e. [0, 1] -> [-1, 1], the same range as
#      the decoder's final Tanh.
# mean/std are one-element tuples (a bare `(0.5)` is just the float 0.5);
# a single value broadcasts over every channel.
# NOTE: ToTensor expects a PIL image or ndarray, not a tensor. This transform
# is not passed to MyDataset in the visible cells.
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

In [None]:
# dataset class
class MyDataset(Dataset):
    """Dataset that serves the same single image ``length`` times.

    The image at ``image_path`` is read with ``plt.imread`` and converted
    with ``TF.to_tensor`` on first access, then cached, so the file is read
    once instead of once per item. The requested index only selects "which
    copy"; every valid index returns the same image.

    Args:
        image_path: Path of the image file to load.
        transform: Optional callable applied to the image *tensor* on every
            access. The image is already a tensor at that point, so the
            transform must accept a tensor (a pipeline starting with
            ``transforms.ToTensor()`` would reject it).
        length: Number of items the dataset reports (default 100).
    """

    def __init__(self, image_path, transform=None, length=100):
        self.image_path = image_path
        self.transform = transform
        self.length = length
        self._image = None  # cached tensor, filled lazily by _load_image

    def __len__(self):
        return self.length

    def _load_image(self):
        """Read and cache the image as a tensor the first time it is needed."""
        if self._image is None:
            self._image = TF.to_tensor(plt.imread(self.image_path))
        return self._image

    def __getitem__(self, x):
        # Out-of-range indices must raise IndexError. Plain iteration over the
        # dataset falls back to __getitem__ and relies on it to terminate.
        if not -self.length <= x < self.length:
            raise IndexError(
                'index {} out of range for dataset of length {}'.format(x, self.length)
            )

        # Clone so callers (or an in-place transform) cannot modify the cache.
        image = self._load_image().clone()

        if self.transform is not None:
            image = self.transform(image)

        return image

In [None]:
### Change IMG_URL to the local file path of your input image

# Local path of the single image the autoencoder is trained on
# (the variable name says URL, but it is read from disk).
IMG_URL = '/content/Famous-portrait-CT_2860.jpeg'

# The dataset returns this same image for every index; batches hold one image.
dataset = MyDataset(IMG_URL)
dataloader = DataLoader(dataset, batch_size=1)

In [None]:
### Build an autoencoder

class autoencoder(nn.Module):
    """Small convolutional autoencoder for 3-channel images.

    Shape comments assume a 400x400 input (b = batch size). The output has
    the same spatial size as the input only for some input sizes (e.g.
    400x400 or 16x16). The encoder's strides/paddings and the decoder's two
    2x upsamplings are not an exact inverse in general.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and layer order are kept fixed: the state_dict keys
        # ('encoder.0.weight', 'encoder.2.weight', 'decoder.0.weight', ...)
        # are derived from them.
        encoder_layers = [
            nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=0),  # b, 16, 398, 398
            nn.ReLU(inplace=True),
            nn.Conv2d(16, 4, kernel_size=3, stride=2, padding=0),  # b, 4, 198, 198
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=1),      # b, 4, 100, 100
        ]
        decoder_layers = [
            nn.ConvTranspose2d(4, 16, kernel_size=2, stride=2),    # b, 16, 200, 200
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(16, 3, kernel_size=2, stride=2),    # b, 3, 400, 400
            nn.Tanh(),                                             # values in [-1, 1]
        ]
        self.encoder = nn.Sequential(*encoder_layers)
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Compress ``x`` to the 4-channel bottleneck and decode it back."""
        latent = self.encoder(x)
        return self.decoder(latent)


In [None]:
### Train an autoencoder

num_epochs = 100
# NOTE: not used below; the DataLoader is created with its own batch_size=1.
batch_size = 1
learning_rate = 1e-3

model = autoencoder()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5)

# Sum of every batch loss over the whole run. It is a plain float, so the
# running total keeps no tensors or autograd history alive.
total_loss = 0.0
for epoch in range(num_epochs):
    epoch_loss = 0.0
    num_batches = 0
    for data in dataloader:
        # torch.autograd.Variable has been a no-op wrapper since PyTorch 0.4,
        # so the batch tensor is used directly.
        img = data
        # ===================forward=====================
        # The autoencoder reconstructs its own input, so the target is `img`.
        # MSELoss needs output and input to have the same shape, which the
        # autoencoder only guarantees for some input sizes (e.g. 400x400).
        output = model(img)
        loss = criterion(output, img)
        # ===================backward====================
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        num_batches += 1
    # ===================log========================
    total_loss += epoch_loss
    # Report this epoch's mean loss, not the last batch alone or a running
    # sum across epochs.
    mean_loss = epoch_loss / num_batches if num_batches else float('nan')
    print('epoch [{}/{}], loss:{:.4f}'.format(epoch + 1, num_epochs, mean_loss))
    if epoch % 10 == 0 and num_batches:
        # Snapshot the last reconstruction of this epoch. Detach because only
        # the pixel values are needed, not the autograd graph.
        pic = output.detach().cpu()
        save_image(pic, 'image_{}.png'.format(epoch))

torch.save(model.state_dict(), './conv_autoencoder.pth')

  img = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous()


epoch [1/50], loss:0.0145
epoch [2/50], loss:0.0248
epoch [3/50], loss:0.0335
epoch [4/50], loss:0.0402
epoch [5/50], loss:0.0466
epoch [6/50], loss:0.0526
epoch [7/50], loss:0.0583
epoch [8/50], loss:0.0637
epoch [9/50], loss:0.0689
epoch [10/50], loss:0.0737
epoch [11/50], loss:0.0784
epoch [12/50], loss:0.0830
epoch [13/50], loss:0.0875
epoch [14/50], loss:0.0919
epoch [15/50], loss:0.0962
epoch [16/50], loss:0.1004
epoch [17/50], loss:0.1047
epoch [18/50], loss:0.1089
epoch [19/50], loss:0.1131
epoch [20/50], loss:0.1172
epoch [21/50], loss:0.1213
epoch [22/50], loss:0.1254
epoch [23/50], loss:0.1295
epoch [24/50], loss:0.1336
epoch [25/50], loss:0.1377
epoch [26/50], loss:0.1417
epoch [27/50], loss:0.1457
epoch [28/50], loss:0.1498
epoch [29/50], loss:0.1538
epoch [30/50], loss:0.1578
epoch [31/50], loss:0.1618
epoch [32/50], loss:0.1658
epoch [33/50], loss:0.1698
epoch [34/50], loss:0.1737
epoch [35/50], loss:0.1777
epoch [36/50], loss:0.1817
epoch [37/50], loss:0.1856
epoch [38/

In [None]:
### Check the reconstructed output by feeding the same input

# Reload the training image as a tensor and prepend a batch axis so it can be
# passed straight to the model: (C, H, W) -> (1, C, H, W).
image = TF.to_tensor(plt.imread(IMG_URL))[None, ...]

In [None]:
# Inference only: skip building the autograd graph.
with torch.no_grad():
    predicted = model(image)
save_image(predicted, 'predicted.png')

# Show the predicted image:
# - drop only the batch dimension;
# - move channels last for matplotlib, (C, H, W) -> (H, W, C);
# - clamp to [0, 1], because the decoder ends in Tanh (values can be negative)
#   while imshow expects floats in [0, 1].
predicted = predicted.squeeze(0).permute(1, 2, 0).clamp(0, 1)
plt.axis('off')
plt.imshow(predicted)
plt.show()

In [None]:
### Start pruning

# List the names of the trained model's state_dict entries (parameters and
# buffers). The Conv2d / ConvTranspose2d weights shown here are the tensors
# the pruning cell below operates on.
print(model.state_dict().keys())

odict_keys(['encoder.0.weight', 'encoder.0.bias', 'encoder.2.weight', 'encoder.2.bias', 'decoder.0.weight', 'decoder.0.bias', 'decoder.2.weight', 'decoder.2.bias'])


In [None]:
import copy

# Prune an independent copy so the trained `model` stays intact for
# comparison. A plain assignment would only bind a second name to the same
# module, so pruning it would prune `model` as well.
new_model = copy.deepcopy(model)

In [None]:

# Fraction of weights zeroed in each convolution layer (0.2 = 20 %).
PRUNING_AMOUNT = 0.2

for module in new_model.modules():
    # ConvTranspose2d is not a subclass of Conv2d, so both types are listed.
    if isinstance(module, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
        # Layer by layer, zero the PRUNING_AMOUNT fraction of weights with the
        # smallest absolute value (L1 magnitude). torch.nn.utils.prune keeps
        # the original values in `weight_orig` plus a `weight_mask` buffer and
        # recomputes `weight` before each forward pass.
        # Re-running this cell prunes again, compounding with the existing mask.
        prune.l1_unstructured(module, name='weight', amount=PRUNING_AMOUNT)

In [None]:
# Reconstruct the same input with the pruned copy. Inference only, so no
# autograd graph is needed.
with torch.no_grad():
    pruned20 = new_model(image)

In [None]:
save_image(pruned20, 'pruned20.png')

# Display tensor derived from the pruned model's output `pruned20`:
# - detach it (harmless if it was computed under no_grad);
# - drop the batch dimension;
# - move channels last for matplotlib, (C, H, W) -> (H, W, C);
# - clamp to [0, 1], because the decoder ends in Tanh while imshow expects
#   floats in [0, 1].
pruned2 = pruned20.detach().squeeze(0).permute(1, 2, 0).clamp(0, 1)
plt.axis('off')
plt.imshow(pruned2)
plt.show()