In [9]:
import torch
import numpy as np
import torch.functional as F
from torch import nn
from torch.nn.utils import prune

In [10]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Location of the trained C-UNet checkpoint that is loaded further down.
MODEL_PATH = (
    "/home/alan/Documents/Cloud_Detection/neural network/"
    "c_unet_1649683131.7209947.pth"
)

In [11]:
class depthwiseSeparableConv(nn.Module):
    """Depthwise-separable 2-D convolution.

    A 3x3 convolution applied independently to every input channel
    (``groups=nin``, padding 1) is followed by a 1x1 convolution that mixes
    channels and maps ``nin`` channels to ``nout`` channels.

    Args:
        nin: Number of input channels.
        nout: Number of output channels.
    """

    def __init__(self, nin, nout):
        super().__init__()
        # Per-channel spatial filtering: one 3x3 kernel for each input channel.
        self.depthwise = nn.Conv2d(
            in_channels=nin,
            out_channels=nin,
            kernel_size=3,
            padding=1,
            groups=nin,
        )
        # Channel mixing: 1x1 convolution from nin to nout channels.
        self.pointwise = nn.Conv2d(in_channels=nin, out_channels=nout, kernel_size=1)

    def forward(self, x):
        """Apply the depthwise convolution, then the pointwise convolution."""
        return self.pointwise(self.depthwise(x))

class C_UNet(nn.Module):
    """Small encoder/decoder convolutional network.

    Three contracting stages (32, 64, 128 channels), each ending in a
    stride-2 max-pool, are followed by three expanding stages, each ending in
    a stride-2 transposed convolution. A final 1x1 convolution plus sigmoid
    produces the output. The stages are applied strictly in sequence; no skip
    connections are used.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the output tensor.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Encoder: each stage changes the channel count and downsamples by 2.
        self.conv1 = self.contract_block(in_channels, 32)
        self.conv2 = self.contract_block(32, 64)
        self.conv3 = self.contract_block(64, 128)

        # Decoder: each stage changes the channel count and upsamples by 2.
        self.upconv3 = self.expand_block(128, 64, 3, 1)
        self.upconv2 = self.expand_block(64, 32, 3, 1)
        self.upconv1 = self.expand_block(32, out_channels, 3, 1)

        # Output head: 1x1 convolution followed by a sigmoid, so every output
        # value lies in (0, 1).
        self.out = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Run the network on ``x`` and return the sigmoid-activated output.

        This is deliberately ``forward`` rather than an override of
        ``__call__``: ``nn.Module.__call__`` dispatches to ``forward`` and is
        what runs registered forward/pre-forward hooks, so ``model(x)`` keeps
        working while the standard module machinery stays intact.

        Args:
            x: Input tensor with ``in_channels`` channels (as passed to the
                constructor) in dimension 1.

        Returns:
            The output of the final 1x1 convolution + sigmoid head.
        """
        # Downsampling path.
        conv1 = self.conv1(x)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)

        # Upsampling path.
        upconv3 = self.upconv3(conv3)
        upconv2 = self.upconv2(upconv3)
        upconv1 = self.upconv1(upconv2)

        return self.out(upconv1)

    def contract_block(self, in_channels, out_channels):
        """Build one encoder stage.

        The stage is: depthwise-separable convolution -> batch norm -> ReLU
        -> 3x3 max-pool with stride 2 and padding 1.

        Args:
            in_channels: Channels entering the stage.
            out_channels: Channels leaving the stage.

        Returns:
            An ``nn.Sequential`` holding the four layers in that order.
        """
        contract = nn.Sequential(
            depthwiseSeparableConv(in_channels, out_channels),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        return contract

    def expand_block(self, in_channels, out_channels, kernel_size, padding):
        """Build one decoder stage.

        The stage is: convolution (stride 1) -> batch norm -> ReLU ->
        3x3 transposed convolution with stride 2, padding 1 and
        output_padding 1.

        Args:
            in_channels: Channels entering the stage.
            out_channels: Channels leaving the stage.
            kernel_size: Kernel size of the first (regular) convolution only;
                the transposed convolution always uses a 3x3 kernel.
            padding: Padding of the first (regular) convolution only.

        Returns:
            An ``nn.Sequential`` holding the four layers in that order.
        """
        expand = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1)
        )
        return expand

In [12]:
# Rebuild the architecture (4 input channels, 2 output channels) and move it
# to the selected device before restoring the trained weights.
model = C_UNet(4, 2)
model.to(DEVICE)

# map_location remaps the stored tensors onto DEVICE, so a checkpoint saved on
# a CUDA machine still loads on a CPU-only host.
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a
# trusted source.
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))

# Switch to inference mode (batch norm uses its running statistics); the
# returned module is the cell's displayed value.
model.eval()

C_UNet(
  (conv1): Sequential(
    (0): depthwiseSeparableConv(
      (depthwise): Conv2d(4, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=4)
      (pointwise): Conv2d(4, 32, kernel_size=(1, 1), stride=(1, 1))
    )
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  )
  (conv2): Sequential(
    (0): depthwiseSeparableConv(
      (depthwise): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)
      (pointwise): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))
    )
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  )
  (conv3): Sequential(
    (0): depthwiseSeparableConv(
      (depthwise): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding

In [13]:
# Every convolution / transposed-convolution weight tensor in the network,
# paired with the parameter name to prune. Biases and batch-norm parameters
# are not included.
param_to_prune = tuple(
    (layer, "weight")
    for layer in (
        # encoder stages: index 0 of each stage is the depthwise-separable conv
        model.conv1[0].depthwise,
        model.conv1[0].pointwise,
        model.conv2[0].depthwise,
        model.conv2[0].pointwise,
        model.conv3[0].depthwise,
        model.conv3[0].pointwise,
        # decoder stages: index 0 is the Conv2d, index 3 the ConvTranspose2d
        model.upconv3[0],
        model.upconv3[3],
        model.upconv2[0],
        model.upconv2[3],
        model.upconv1[0],
        model.upconv1[3],
        # output head: the 1x1 convolution
        model.out[0],
    )
)

# Global magnitude pruning: the 15% of weights with the smallest absolute
# value, ranked across all listed tensors together, are masked to zero.
prune.global_unstructured(
    parameters=param_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.15,
)

In [14]:
# Sanity check: report how many (module, parameter-name) pairs were registered
# for pruning.
num_pruned_tensors = len(param_to_prune)
print(num_pruned_tensors)

13


In [15]:
# Make the pruning permanent: prune.remove drops the mask/original-weight
# re-parametrization and leaves the already-masked tensor as a plain
# "weight" parameter on each module.
for module, param_name in param_to_prune:
    prune.remove(module, param_name)

In [16]:
# Example input used only to record the traced graph: batch of 1, 4 channels,
# 384x384 pixels (presumably the deployment tile size -- confirm).
# It is created on the same device as the model's parameters; a CPU tensor
# fed to a CUDA-resident model would fail with a device mismatch.
x = torch.zeros(1, 4, 384, 384, device=next(model.parameters()).device)

# Record the operations executed for `x` into a TorchScript module and
# serialize it so it can be loaded without the Python class definitions.
model_trace = torch.jit.trace(model, x)
torch.jit.save(model_trace, "/home/alan/Documents/Cloud_Detection/neural network/cloud_detection_c_unet.pth")