In [10]:
import torch
import torch_pruning as tp

import torch
from torch import nn
from torch.functional import F
import numpy as np
import swyft.lightning as sl
from toolz.dicttoolz import valmap
from sklearn.metrics import roc_curve, auc

import logging
logging.getLogger("pytorch_lightning.utilities.rank_zero").setLevel(logging.WARNING)
logging.getLogger("pytorch_lightning.accelerators.cuda").setLevel(logging.WARNING)


In [11]:
# Model

class InferenceNetwork(sl.SwyftModule):
    def __init__(self, **conf):
        super().__init__()
        
        self.one_d_only = True       
        self.batch_size = conf["batch_size"]
        self.noise_shuffling = True
        self.num_params = 15
        self.marginals = (0,1),
        self.include_noise = True
        
        self.unet_t = Unet(
            n_in_channels=3,
            n_out_channels=1,
            sizes=(16, 32, 64, 128, 256),
            down_sampling=(8, 8, 8, 8),
        )
        self.unet_f = Unet(
            n_in_channels=6,
            n_out_channels=1,
            sizes=(16, 32, 64, 128, 256),
            down_sampling=(4, 2, 2, 2),
        )

        self.flatten = nn.Flatten(1)
        self.linear_t = LinearCompression()
        self.linear_f = LinearCompression()

        self.logratios_1d = sl.LogRatioEstimator_1dim(
            num_features=32, num_params=int(self.num_params), varnames="z_total"
        )
            
        self.optimizer_init = sl.AdamOptimizerInit(lr=conf["learning_rate"])

    def forward(self, A, B):
        
        if self.noise_shuffling and A["d_t"].size(0) != 1:
            noise_shuffling = torch.randperm(self.batch_size)
            d_t = A["d_t"] + A["n_t"][noise_shuffling]
            d_f_w = A["d_f_w"] + A["n_f_w"][noise_shuffling]
        else:
            d_t = A["d_t"] + A["n_t"]
            d_f_w = A["d_f_w"] + A["n_f_w"]
        z_total = B["z_total"]

        d_t = self.unet_t(d_t)
        d_f_w = self.unet_f(d_f_w)

        features_t = self.linear_t(self.flatten(d_t))
        features_f = self.linear_f(self.flatten(d_f_w))
        features = torch.cat([features_t, features_f], dim=1)
        logratios_1d = self.logratios_1d(features, z_total)
        return logratios_1d
               
# 1D Unet implementation below
class DoubleConv(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        mid_channels=None,
        padding=1,
        bias=False,
    ):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv1d(
                in_channels,
                mid_channels,
                kernel_size=kernel_size,
                padding=padding,
                bias=bias,
            ),
            nn.BatchNorm1d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv1d(
                mid_channels,
                out_channels,
                kernel_size=kernel_size,
                padding=padding,
                bias=bias,
            ),
            nn.BatchNorm1d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    def __init__(self, in_channels, out_channels, down_sampling=2):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool1d(down_sampling), DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=2):
        super().__init__()
        self.up = nn.ConvTranspose1d(
            in_channels, in_channels // 2, kernel_size=kernel_size, stride=stride
        )
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        diff_signal_length = x2.size()[2] - x1.size()[2]

        x1 = F.pad(
            x1, [diff_signal_length // 2, diff_signal_length - diff_signal_length // 2]
        )
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1):
        super(OutConv, self).__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size)

    def forward(self, x):
        return self.conv(x)


class Unet(nn.Module):
    def __init__(
        self,
        n_in_channels,
        n_out_channels,
        sizes=(16, 32, 64, 128, 256),
        down_sampling=(2, 2, 2, 2),
    ):
        super(Unet, self).__init__()
        self.inc = DoubleConv(n_in_channels, sizes[0])
        self.down1 = Down(sizes[0], sizes[1], down_sampling[0])
        self.down2 = Down(sizes[1], sizes[2], down_sampling[1])
        self.down3 = Down(sizes[2], sizes[3], down_sampling[2])
        self.down4 = Down(sizes[3], sizes[4], down_sampling[3])
        self.up1 = Up(sizes[4], sizes[3])
        self.up2 = Up(sizes[3], sizes[2])
        self.up3 = Up(sizes[2], sizes[1])
        self.up4 = Up(sizes[1], sizes[0])
        self.outc = OutConv(sizes[0], n_out_channels)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        f = self.outc(x)
        return f


class LinearCompression(nn.Module):
    def __init__(self):
        super(LinearCompression, self).__init__()
        self.sequential = nn.Sequential(
            nn.LazyLinear(1024),
            nn.ReLU(),
            nn.LazyLinear(256),
            nn.ReLU(),
            nn.LazyLinear(64),
            nn.ReLU(),
            nn.LazyLinear(16),
        )

    def forward(self, x):
        return self.sequential(x)



In [12]:
model = Unet(
    n_in_channels=3,
    n_out_channels=1,
    sizes=(16, 32, 64, 128, 256),
    down_sampling=(8, 8, 8, 8),
)
example_inputs = torch.randn(5, 3, 8192)

device = 'cpu'
checkpoint_path = "models/unet_peregrine/epoch=41-step=305000_val_loss=-5.51_train_loss=-5.75.ckpt"
checkpoint = torch.load(checkpoint_path, map_location=device)

model_params = model.state_dict()
pretrained_params = {k[7:]: v for k, v in checkpoint['state_dict'].items() if 'unet_t' in k}
model_params.update(pretrained_params)
model.load_state_dict(pretrained_params)

<All keys matched successfully>

In [13]:
data_dir = '/scratch-shared/scur2012/training_data/default_limits_2e6/training_data'
zarr_store = sl.ZarrStore(f"{data_dir}")

train_data = zarr_store.get_dataloader(
    num_workers=8,
    batch_size=256,
    idx_range=[0, 1000],
    on_after_load_sample=False,
)

for batch_idx, batch in enumerate(train_data):
    #loss = model.training_step(batch, batch_idx)
    break

In [14]:
batch['d_t'].shape

torch.Size([256, 3, 8192])

In [16]:
for tf in ['t','f']:
    for ratio in [0, 0.05, 0.10, 0.2, 0.3, 0.4, 0.5]:

        if tf == 't':
            model = Unet(
                n_in_channels=3,
                n_out_channels=1,
                sizes=(16, 32, 64, 128, 256),
                down_sampling=(8, 8, 8, 8),
            )
            #example_inputs = torch.randn(5, 3, 8192)
            example_inputs = batch['d_t']
        else:
            model = Unet(
                n_in_channels=6,
                n_out_channels=1,
                sizes=(16, 32, 64, 128, 256),
                down_sampling=(2, 2, 2, 2),
            )
            example_inputs = torch.randn(5, 6, 4097)
            example_inputs = batch['d_f_w']

        checkpoint_path = "models/unet_peregrine/epoch=41-step=305000_val_loss=-5.51_train_loss=-5.75.ckpt"
        checkpoint = torch.load(checkpoint_path, map_location='cpu')

        model_params = model.state_dict()
        pretrained_params = {k[7:]: v for k, v in checkpoint['state_dict'].items() if f'unet_{tf}' in k}
        model_params.update(pretrained_params)
        model.load_state_dict(pretrained_params)


        # 1. Importance criterion
        imp = tp.importance.GroupNormImportance(p=2) #GroupTaylorImportance() # or GroupNormImportance(p=2), GroupHessianImportance(), etc.

        # 2. Initialize a pruner with the model and the importance criterion
        ignored_layers = []
        # for m in model.modules():
        #     if isinstance(m, torch.nn.Linear) and m.out_features == 1000:
        #         ignored_layers.append(m) # DO NOT prune the final classifier!

        pruner = tp.pruner.MetaPruner( # We can always choose MetaPruner if sparse training is not required.
            model,
            example_inputs,
            importance=imp,
            pruning_ratio=ratio, # remove 20% channels
            # pruning_ratio_dict = {model.conv1: 0.2, model.layer2: 0.8}, # customized pruning ratios for layers or blocks
            ignored_layers=ignored_layers,
        )

        # 3. Prune & finetune the model
        base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
        if isinstance(imp, tp.importance.GroupTaylorImportance):
            # Taylor expansion requires gradients for importance estimation
            loss = model(example_inputs) 
            loss.backward() # before pruner.step()

        pruner.step()
        macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
        # finetune the pruned model here
        # finetune(model)
        # ...

        # Save models
        model.zero_grad() # clear gradients
        torch.save(model, f'unet_{tf}_pruned_{ratio}_taylor.pth')
        
        print (ratio, tf, base_macs, base_nparams, macs, nparams)


RuntimeError: grad can be implicitly created only for scalar outputs

In [18]:
loss.shape

torch.Size([256, 1, 8192])

In [81]:
model2 = torch.load('/home/scur2012/Thesis/master-thesis/experiments/pruning/unet_t_pruned_0.5.pth')

In [85]:
class InferenceNetworkPruned(InferenceNetwork):
    def __init__(self, **conf):
        super().__init__(**conf)
        self.unet_t = torch.load('/home/scur2012/Thesis/master-thesis/experiments/pruning/unet_t_pruned_0.5.pth')
        self.unet_f = torch.load('/home/scur2012/Thesis/master-thesis/experiments/pruning/unet_f_pruned_0.5.pth')

network = InferenceNetworkPruned(batch_size=256, learning_rate=5e-4)



In [86]:
network

InferenceNetworkPruned(
  (unet_t): Unet(
    (inc): DoubleConv(
      (double_conv): Sequential(
        (0): Conv1d(3, 8, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)
        (1): BatchNorm1d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv1d(8, 8, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)
        (4): BatchNorm1d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
    )
    (down1): Down(
      (maxpool_conv): Sequential(
        (0): MaxPool1d(kernel_size=8, stride=8, padding=0, dilation=1, ceil_mode=False)
        (1): DoubleConv(
          (double_conv): Sequential(
            (0): Conv1d(8, 16, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)
            (1): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
            (3): Conv1d(16, 16, kernel_size=(3,), 

In [87]:
InferenceNetwork(batch_size=256, learning_rate=5e-4)

InferenceNetwork(
  (unet_t): Unet(
    (inc): DoubleConv(
      (double_conv): Sequential(
        (0): Conv1d(3, 16, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)
        (1): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv1d(16, 16, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)
        (4): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
    )
    (down1): Down(
      (maxpool_conv): Sequential(
        (0): MaxPool1d(kernel_size=8, stride=8, padding=0, dilation=1, ceil_mode=False)
        (1): DoubleConv(
          (double_conv): Sequential(
            (0): Conv1d(16, 32, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)
            (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
            (3): Conv1d(32, 32, kernel_size=(3,), 