# Multi-Modalities Multi-Tasks Example

In [38]:
# define a model with multi-modalities and multi-tasks
import torch
import torch.nn as nn
import torch.nn.functional as F

class SingleModality(nn.Module):
    """Convolutional branch for one input modality with two linear task heads.

    Two 3x3 convolutions (3 -> 16 -> 32 channels, stride 1, padding 1) are
    applied back to back with no activation in between. Their output is
    flattened and the same flattened features are fed to both heads, each of
    which produces 10 values per sample.

    The heads are sized for 32 * 32 * 32 input features, i.e. 32 channels on a
    32x32 feature map; the convolutions keep the spatial size, so the branch
    expects 32x32 inputs.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are kept stable: they determine
        # the parameter names and the order in which weights are initialised.
        self.conv1 = nn.Conv2d(
            in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1
        )
        self.conv2 = nn.Conv2d(
            in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1
        )
        self.flat = nn.Flatten()

        # 32 output channels on a 32x32 map.
        head_in_features = 32 * 32 * 32
        self.task1 = nn.Linear(in_features=head_in_features, out_features=10)
        self.task2 = nn.Linear(in_features=head_in_features, out_features=10)

    def forward(self, x):
        """Return ``(task1_output, task2_output)`` computed from ``x``."""
        features = self.flat(self.conv2(self.conv1(x)))
        return self.task1(features), self.task2(features)

class MMMT(nn.Module):
    """Multi-modality, multi-task model built from two ``SingleModality`` branches.

    Each input goes through its own branch. The per-task outputs of the two
    branches are fused by element-wise addition, giving one output per task.
    """

    def __init__(self):
        super().__init__()
        self.modality1 = SingleModality()
        self.modality2 = SingleModality()

    def forward(self, x1, x2):
        """Return ``(task1, task2)``, each the sum of both branches' outputs.

        ``x1`` is processed by ``modality1`` and ``x2`` by ``modality2``.
        """
        heads_1 = self.modality1(x1)
        heads_2 = self.modality2(x2)
        fused_task1 = heads_1[0] + heads_2[0]
        fused_task2 = heads_1[1] + heads_2[1]
        return fused_task1, fused_task2

In [39]:
from unip.core.pruner import BasePruner
from unip.utils.evaluation import cal_flops

# Seed PyTorch's RNG before anything random is created, so the randomly
# initialised weights and the example inputs are identical on every
# "Restart & Run All". The pruner further down uses a weight-based score, so
# without a seed its input weights differ from run to run.
torch.manual_seed(0)

model = MMMT()

# note: the inputs are created with `requires_grad=True` so that they are
#       connected to the graph.
# One example tensor per model input (presumably passed on as `x1`, `x2`).
example_input = [
    torch.rand(1, 3, 32, 32, requires_grad=True),
    torch.rand(1, 3, 32, 32, requires_grad=True)]

# calculate the flops of the original model
print("original model:")
flops_ori, params_ori = cal_flops(model, example_input)


original model:
('11.633M', '1.321M')


## define MTU ratio
The final pruning ratio of each layer is the product of the ratio given in
the setting and the MTU ratio of the modalities and tasks involved:
    $$\rho = \rho_{setting} \times \rho_{MTU},$$
where $\rho_{MTU}$ is the mean of the $N$ MTU values:
    $$\rho_{MTU} = \frac{1}{N}\sum_{i=0}^{N-1} MTU_{i}.$$

Let's assume you want to prune the modules influenced by `input_0` and `output_1`, and the MTU values are set as follows:

In [40]:
# Setting for multiple modalities and multiple tasks: every model input and
# output gets an entry; the ones that should not drive pruning are set to 0.
MTU = dict(
    input_0=0.5,
    input_1=0,
    output_0=0,
    output_1=0.5,
)
# if you want to consider a single modality or a single task:
# MTU = {
#     "input_0": 0.5,
#     "output_0": 0.5,
# }

# note:
#  1. The difference between the two MTU settings above is that the former
#     calculates the mean over all the modalities and tasks if they have
#     the same MTU tag.
#
#  2. The latter calculates the mean of "input_0" and "output_0" only.
#
#  3. Thus, the prune ratio of the former is a little smaller than that of
#     the latter if the other entries in MTU are 0. The saved ratio of the
#     former is larger than that of the latter.



In [41]:
# Build the pruner for `model`, using the example inputs from the earlier cell.
# "MTURatio" is passed as the third positional argument (presumably the name
# of the ratio algorithm -- confirm against BasePruner). `algo_args` carries
# the score function name and the MTU mapping.
pruner = BasePruner(
    model, example_input, "MTURatio",
    algo_args=dict(score_fn="weight_sum_l1_out", MTU=MTU),
)

# Run the pruning with a ratio of 0.3 (presumably the setting ratio that is
# multiplied with the MTU ratio).
pruner.prune(0.3)

# Measure FLOPs and parameters again on the same `model` object, now pruned.
print("pruned model:")
flops_pruned, params_pruned = cal_flops(model, example_input)

pruned model:
('10.666M', '1.238M')
