# Compress custom network

This brief tutorial shows how to compress a custom network with EfficientBioAI. We use a simple 2D U-Net to perform 2D semantic segmentation on the [Simulated nuclei of HL60 cells stained with Hoechst](http://celltrackingchallenge.net/2d-datasets/) dataset. Both pruning and quantization are demonstrated.

In [1]:
import os
import numpy as np
import torch
import monai
from monai.data import DataLoader, Dataset
from monai.transforms import RandSpatialCropSamplesd, Compose, AddChanneld, ScaleIntensityRanged, ToTensord, Transform, CastToTyped, EnsureTyped, ScaleIntensityRangePercentilesd
from monai.engines import SupervisedTrainer, SupervisedEvaluator
from monai.losses import DiceLoss
from tqdm.contrib import tenumerate
from aicsimageio import AICSImage
from torch.optim.lr_scheduler import ReduceLROnPlateau

  from .autonotebook import tqdm as notebook_tqdm


2023-03-30 16:11:34,612 - Resource 'XMLSchema.xsd' is already loaded


## 1. Prepare the dataset

In [2]:
# Root folder of the local Fluo-N2DH-SIM+ copy; every path below hangs off it.
DATA_ROOT = "/home/ISAS.DE/yu.zhou/Downloads/Fluo-N2DH-SIM+"

# Training split: sequence 02 and its segmentation ground truth.
train_data_path = f"{DATA_ROOT}/02"
train_gt_path = f"{DATA_ROOT}/02_GT/SEG"

# Test split: sequence 03 and its segmentation ground truth.
test_data_path = f"{DATA_ROOT}/03"
test_gt_path = f"{DATA_ROOT}/03_GT/SEG"

In [3]:
def generate_data_dict(data_path, gt_path):
    """Pair every image file with its ground-truth file.

    Both directories are listed in sorted order before being zipped.
    ``os.listdir`` returns entries in arbitrary order, so zipping two unsorted
    listings could silently pair an image with the wrong mask.

    Args:
        data_path: Directory containing the raw image files.
        gt_path: Directory containing the ground-truth files.

    Returns:
        A list of dicts, one per pair, with keys ``'img'`` (image path),
        ``'seg'`` (ground-truth path) and ``'fn'`` (image file name up to its
        first dot).
    """
    # NOTE(review): assumes the i-th sorted image name belongs to the i-th
    # sorted ground-truth name (same numbering scheme in both folders) --
    # confirm for your dataset.
    image_names = sorted(os.listdir(data_path))
    label_names = sorted(os.listdir(gt_path))

    data_dicts = []
    # tenumerate is used for its progress bar; the index itself is not needed.
    # zip stops at the shorter listing, so surplus files on either side are
    # ignored.
    for _, (data, label) in tenumerate(zip(image_names, label_names)):
        data_dicts.append(
            {
                'img': os.path.join(data_path, data),
                'seg': os.path.join(gt_path, label),
                'fn': data.split(".")[0],
            }
        )
    return data_dicts

class LoadTiffd(Transform):
    """Dictionary transform that replaces file paths with 2D image planes.

    For every configured key, the value stored under that key is opened with
    ``AICSImage`` and replaced by the ``"YX"`` plane taken at ``S=0, T=0, C=0``.
    """

    def __init__(self, keys=('img', 'seg')):
        """Store the dictionary keys to load.

        Args:
            keys: Names of the entries whose values are image file paths.
                The default is a tuple so that it cannot be mutated and
                shared between instances.
        """
        super().__init__()
        self.keys = keys

    def __call__(self, data):
        """Load the image behind each configured key.

        Args:
            data: Mapping that holds a file path under every key in
                ``self.keys``.

        Returns:
            A shallow copy of ``data`` in which each configured key now maps
            to the loaded ``"YX"`` image data instead of the path.
        """
        d = dict(data)
        for key in self.keys:
            x = AICSImage(data[key])
            d[key] = x.get_image_data("YX", S=0, T=0, C=0)
        return d

class Ins2Semd(Transform):
    """Dictionary transform that turns instance labels into a binary mask.

    Every non-zero element of the selected entries is set to 1; zeros are
    left untouched.
    """

    def __init__(self, keys=('seg',)):
        """Store the dictionary keys to binarise.

        Args:
            keys: Names of the entries holding label arrays. The default is a
                tuple so that it cannot be mutated and shared between
                instances.
        """
        super().__init__()
        self.keys = keys

    def __call__(self, data):
        """Binarise the label arrays stored under the configured keys.

        Args:
            data: Mapping that holds an array supporting boolean-mask
                assignment under every key in ``self.keys``.

        Returns:
            A shallow copy of ``data``. The label arrays themselves are
            modified in place, so the arrays referenced by the input mapping
            are changed as well.
        """
        d = dict(data)
        for key in self.keys:
            # Boolean-mask assignment: every non-zero label becomes 1.
            d[key][d[key] != 0] = 1
        return d

# Preprocessing and augmentation pipeline applied to every {'img', 'seg'} sample.
transform = Compose(
    [
        # Replace the file paths with the loaded 2D planes.
        LoadTiffd(keys=["img", "seg"]),
        # Add a channel axis to both image and mask.
        AddChanneld(keys=["img", "seg"]),
        # Only the image is cast to float32; the mask keeps its dtype.
        CastToTyped(keys=["img"], dtype=np.float32),
        # Collapse instance labels into a foreground/background mask.
        Ins2Semd(keys=["seg"]),
        EnsureTyped(keys=["img", "seg"]),
        # Rescale the image intensities between the 0.5 and 99.5 percentiles
        # to the range 0..1.
        ScaleIntensityRangePercentilesd(
            keys=["img"], lower=0.5, upper=99.5, b_min=0, b_max=1
        ),
        # Draw 4 fixed-size 256x256 crops per sample, shared by image and mask.
        RandSpatialCropSamplesd(
            keys=["img", "seg"],
            roi_size=(256, 256),
            num_samples=4,
            random_size=False,
        ),
        ToTensord(keys=["img", "seg"]),
    ]
)



In [4]:
# Build the list of training samples, then wrap it in a dataset that applies
# the transform pipeline and a loader that shuffles and batches it.
train_records = generate_data_dict(train_data_path, train_gt_path)
dataset = Dataset(data=train_records, transform=transform)
dataloader = DataLoader(
    dataset,
    batch_size=2,
    shuffle=True,
    num_workers=0,
    # Pinned host memory is only requested when a CUDA device is available.
    pin_memory=torch.cuda.is_available(),
)

150it [00:00, 123507.18it/s]


In [40]:
# Pull a single batch from the loader so it can be inspected.
example = next(iter(dataloader))

Let us visualize the dataset first.

## 2. Train the model

In [5]:
from model.unet import Unet

In [None]:
# Seed BEFORE the network is built: seeding after construction (as was done
# previously) leaves the initial weights outside the seeded random stream, so
# runs would not be repeatable.
# NOTE(review): assumes Unet draws its initial weights from the torch RNG, and
# the random crops in the transform pipeline may use a separate RNG that
# torch.manual_seed does not cover -- confirm, and consider
# monai.utils.set_determinism if full reproducibility is required.
torch.manual_seed(0)
torch.backends.cudnn.deterministic = True

net = Unet(in_channels=1, classes=2)
# The network emits 2 channels; the loss one-hot encodes the label and applies
# softmax to the network output itself.
criterion = DiceLoss(to_onehot_y=True, softmax=True)
optimizer = torch.optim.Adam(net.parameters(), 1e-2)
# Multiply the learning rate by 0.1 every 10 epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
num_epoch = 100
net.to(device)

for i in range(num_epoch):
    # train step:
    net.train()
    epoch_loss = 0
    for j, batch_data in tenumerate(dataloader):
        data, label = batch_data['img'].to(device), batch_data['seg'].to(device)
        optimizer.zero_grad()
        out = net(data)
        loss = criterion(out, label)
        loss.backward()
        optimizer.step()
        # .item() extracts a Python float so no autograd graph is retained.
        epoch_loss += loss.item()
    print(f'epoch {i+1}/{num_epoch}, avg loss: {epoch_loss / len(dataloader)}')
    # The scheduler is stepped once per epoch, after the optimizer updates.
    scheduler.step()
    

## 3. Compress the model

In [14]:
# Persist only the trained weights (the state dict) in the working directory.
trained_weights = net.state_dict()
torch.save(trained_weights, "./unet.pth")

In [15]:
!pwd

/home/ISAS.DE/yu.zhou/EfficientBioAI/tutorial
