# Utility Functions

## Overview

This notebook defines utility functions for managing our multitask neural network models. These models contain custom masked linear layers (`MultitaskMaskLinear` and `MultitaskMaskLinearV2`) that apply a task-specific mask to shared weights, so that a different subset of the network is active for each task.

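Three functions are defined. Each one walks over the model's modules with `named_modules()` and updates every masked linear layer it finds: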

* **set_model_task**: Sets the active task on every masked linear layer in a model. This determines which task's mask each layer applies, so it switches the model's behaviour between tasks.
* **cache_masks**: Calls `cache_masks()` on every masked linear layer, so that each layer stores its current mask state for later use.
* **set_num_tasks_learned**: Sets the number of tasks learned so far on every masked linear layer, so that each layer knows how many of its task-specific masks have already been trained.

Together, these functions provide a convenient interface for controlling our multitask models during training and evaluation.
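To make the masking mechanism concrete, here is a minimal, dependency-free sketch of how a masked linear layer might choose its mask from a `task` attribute. `ToyMaskedLinear` is hypothetical: the real `MultitaskMaskLinear` layers are PyTorch modules whose masks are learned, whereas here the shared weights and the 0/1 masks are fixed lists chosen by hand.

```python
# Hypothetical toy layer: NOT the real MultitaskMaskLinear, which is a PyTorch
# module with learned masks. Plain lists keep the example dependency-free.
class ToyMaskedLinear:
    def __init__(self, weight, masks):
        self.weight = weight  # shared weights, shape: out_features x in_features
        self.masks = masks    # one 0/1 mask per task, each the same shape as weight
        self.task = 0         # index of the task whose mask is applied

    def forward(self, x):
        # Element-wise mask the shared weights, then compute y = (W * M_task) x
        mask = self.masks[self.task]
        return [
            sum(w * m * xi for w, m, xi in zip(w_row, m_row, x))
            for w_row, m_row in zip(self.weight, mask)
        ]


layer = ToyMaskedLinear(
    weight=[[1.0, 2.0], [3.0, 4.0]],
    masks=[
        [[1, 0], [0, 1]],  # task 0 keeps only the diagonal weights
        [[0, 1], [1, 0]],  # task 1 keeps only the off-diagonal weights
    ],
)

layer.task = 0
print(layer.forward([1.0, 1.0]))  # [1.0, 4.0]
layer.task = 1
print(layer.forward([1.0, 1.0]))  # [2.0, 3.0]
```

The weights never change between tasks; only the mask selected by `task` does, which is why setting `task` on every layer is enough to switch the whole model.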

## Importing Required Libraries

```python
from utilities.models import MultitaskMaskLinear, MultitaskMaskLinearV2
```

## Functions for Multitask Model Management

```python
# Both layer variants expose `task`, `num_tasks_learned` and `cache_masks()`
MASK_LAYER_TYPES = (MultitaskMaskLinear, MultitaskMaskLinearV2)


def set_model_task(model, task, verbose=True):
    """Set the active task on every multitask masked layer in `model`."""
    for n, m in model.named_modules():
        if isinstance(m, MASK_LAYER_TYPES):
            if verbose:
                print(f"=> Set task of {n} to {task}")
            m.task = task


def cache_masks(model, verbose=True):
    """Cache the current mask state of every multitask masked layer in `model`."""
    for n, m in model.named_modules():
        if isinstance(m, MASK_LAYER_TYPES):
            if verbose:
                print(f"=> Caching mask state for {n}")
            m.cache_masks()


def set_num_tasks_learned(model, num_tasks_learned, verbose=True):
    """Record how many tasks have been learned on every multitask masked layer."""
    for n, m in model.named_modules():
        if isinstance(m, MASK_LAYER_TYPES):
            if verbose:
                print(f"=> Setting learned tasks of {n} to {num_tasks_learned}")
            m.num_tasks_learned = num_tasks_learned
```
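All three helpers repeat the same loop over `named_modules()`. The sketch below factors that loop into a single generator and shows one plausible order of calls across tasks. `StubMaskLayer`, `StubModel` and `mask_layers` are hypothetical stand-ins that mimic just enough of the PyTorch interface to run without `torch` or `utilities.models`, and verbose printing is omitted for brevity. The call order (set the task, train, cache the masks, then update the learned-task count) is an assumption, not something this notebook prescribes.

```python
# Hypothetical stand-ins for the real PyTorch layers and model.
class StubMaskLayer:
    def __init__(self):
        self.task = None
        self.num_tasks_learned = 0
        self.cache_calls = 0  # counts cache_masks() calls, for illustration only

    def cache_masks(self):
        self.cache_calls += 1


class StubModel:
    def __init__(self, **layers):
        self._layers = layers

    def named_modules(self):
        # torch.nn.Module.named_modules() also yields the root module as ("", self)
        yield "", self
        yield from self._layers.items()


# In the notebook this would be (MultitaskMaskLinear, MultitaskMaskLinearV2)
MASK_LAYER_TYPES = (StubMaskLayer,)


def mask_layers(model):
    """Yield (name, module) for every multitask masked layer in the model."""
    for name, module in model.named_modules():
        if isinstance(module, MASK_LAYER_TYPES):
            yield name, module


def set_model_task(model, task):
    for _, m in mask_layers(model):
        m.task = task


def cache_masks(model):
    for _, m in mask_layers(model):
        m.cache_masks()


def set_num_tasks_learned(model, num_tasks_learned):
    for _, m in mask_layers(model):
        m.num_tasks_learned = num_tasks_learned


# Assumed continual-learning loop: select a task, train it, then record progress.
model = StubModel(fc1=StubMaskLayer(), fc2=StubMaskLayer())
num_tasks = 3
for t in range(num_tasks):
    set_model_task(model, t)
    # ... train the masks for task t here ...
    cache_masks(model)
    set_num_tasks_learned(model, t + 1)

print([(n, m.task, m.num_tasks_learned, m.cache_calls) for n, m in mask_layers(model)])
# [('fc1', 2, 3, 3), ('fc2', 2, 3, 3)]
```

Keeping the type check in one place means a future layer variant only needs to be added to `MASK_LAYER_TYPES`, rather than to three separate `isinstance` calls.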

--------------------------------------------------------------------------------------------------------------------------------

#### Code adapted from:

* https://github.com/RAIVNLab/supsup