# Building a Model Merging Function

To fully leverage the capabilities of this library, it's important to understand some of its inner workings. The library is designed to be easily extendable, making it simple to customize and build upon. To facilitate this, the library provides two key Python function decorators:

## `@model_merge` Decorator

The `@model_merge` decorator is designed to wrap around your outer merging function, taking care of complex operations such as model loading, casting to and from `StateDict`, and managing skipped layers. When you apply this decorator to a function, it allows the function to accept a list of either `torch.nn.Module` objects or Hugging Face pipelines as input, and it will output a similarly structured object.

However, the core merging logic that you need to implement is much simpler. Your function only needs to accept a list of `StateDict` objects (which represent the state of each model) and return a `StateDict` for the merged model. The `@model_merge` decorator handles all the other complexities for you!


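```python
from typing import List

from mergecraft import StateDict, model_merge

@model_merge
def my_merge_function(models: List[StateDict], **kw) -> StateDict:
    # Your merging logic here: combine the input state dicts into one
    merged_state_dict = ...
    return merged_state_dict

# list_of_models: the models (or pipelines) you want to merge
merged_model = my_merge_function(list_of_models)
```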
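To make this division of labour concrete, here is a minimal sketch of what a decorator like `@model_merge` might do for plain `torch.nn.Module` inputs. It is *not* mergecraft's implementation. `toy_model_merge`, `average` and the `StateDictT` alias are hypothetical names, and the sketch ignores pipelines, loading models by name, and skipped layers. It extracts a state dict from each model, calls the wrapped function, and loads the result into a copy of the first model.

```python
import copy
from typing import Callable, Dict, List

import torch
import torch.nn as nn

StateDictT = Dict[str, torch.Tensor]  # hypothetical alias for a state dict


def toy_model_merge(func: Callable[..., StateDictT]) -> Callable[..., nn.Module]:
    """Hypothetical stand-in for @model_merge, for nn.Module inputs only."""
    def wrapper(models: List[nn.Module], **kw) -> nn.Module:
        # Module -> state dict for every input model
        state_dicts = [m.state_dict() for m in models]
        merged_sd = func(state_dicts, **kw)
        # State dict -> module: reuse the first model's architecture
        merged_model = copy.deepcopy(models[0])
        merged_model.load_state_dict(merged_sd)
        return merged_model
    return wrapper


@toy_model_merge
def average(models: List[StateDictT], **kw) -> StateDictT:
    # Element-wise mean of each parameter across models
    # (assumes floating-point parameters)
    return {k: torch.stack([sd[k] for sd in models]).mean(dim=0) for k in models[0]}


torch.manual_seed(0)
a, b = nn.Linear(3, 2), nn.Linear(3, 2)
merged = average([a, b])
print(type(merged).__name__)  # Linear
```

The merging function itself only ever sees dictionaries of tensors; all conversion happens in the wrapper.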

## `@dict_map` Decorator

The `@dict_map` decorator is intended to wrap the inner function that performs the actual layer-wise merging of the models.

This decorator enables you to write a function that operates on individual layers (i.e., `torch.Tensor` objects) and then applies this function across all the layers in the model. The function you write should take a `List[torch.Tensor]` as input, where each tensor holds the parameters of the same layer in a different model, and return a single tensor representing the merged layer.

### Example Usage

```python
from typing import List

from torch import Tensor

from mergecraft import dict_map

@dict_map
def merge_layers(models: List[Tensor], **kw) -> Tensor:
    # Your layer-wise merging logic here
    return merged_layer
```
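Conceptually, a decorator like `@dict_map` lifts a per-tensor function to whole state dicts by applying it key by key. The sketch below illustrates that idea with a hypothetical `toy_dict_map` and an illustrative `max_merge` rule; it assumes every state dict has the same keys and shapes, and it is not the library's actual code.

```python
from typing import Callable, Dict, List

import torch
from torch import Tensor


def toy_dict_map(func: Callable[..., Tensor]) -> Callable[..., Dict[str, Tensor]]:
    """Hypothetical stand-in for @dict_map: apply `func` to each layer."""
    def wrapper(state_dicts: List[Dict[str, Tensor]], **kw) -> Dict[str, Tensor]:
        # For every parameter name, gather that layer from all models
        # and merge the list of tensors into a single tensor
        return {
            key: func([sd[key] for sd in state_dicts], **kw)
            for key in state_dicts[0]
        }
    return wrapper


@toy_dict_map
def max_merge(models: List[Tensor], **kw) -> Tensor:
    # Illustrative per-layer rule: element-wise maximum across models
    return torch.stack(models).amax(dim=0)


sd1 = {"w": torch.tensor([1.0, 5.0]), "b": torch.tensor([0.0])}
sd2 = {"w": torch.tensor([3.0, 2.0]), "b": torch.tensor([-1.0])}
print(max_merge([sd1, sd2]))  # {'w': tensor([3., 5.]), 'b': tensor([0.])}
```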

By using these decorators, you can focus on the core logic of merging models and layers without worrying about the underlying complexities of model management. This makes your code more modular, readable, and maintainable.

# Median Merger
Imagine you hypothesize that taking the median value of the weights from several models might improve their overall performance compared to the original models. How would you go about implementing this idea?

### Step 1: Implementing the Layer-Wise Median Merger

The first step is to create a function that performs the merging operation on a layer-by-layer basis. This function will be wrapped with the `@dict_map` decorator to apply your merging logic across all layers of the model.

In this example, the merging strategy stacks the weights from each model and then takes the element-wise median along this stacked dimension. The resulting weights therefore reflect the central tendency across all models and are less sensitive to a single outlying model than a simple average would be.

Here’s how you can implement this:

```python
from typing import List

import torch
from torch import Tensor

from mergecraft import StateDict, dict_map, layer_merge, model_merge

@dict_map
def median_layer_merging(models: List[Tensor], **kw) -> Tensor:
    # Stack the same layer from every model along a new leading dimension,
    # then take the element-wise median across that model dimension
    return torch.median(torch.stack(models), dim=0).values
```
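To see why the median is an appealing choice, the standalone example below (plain PyTorch, no mergecraft) compares it with the mean on three toy "layers", one of which is an outlier. Note that for an even number of models `torch.median` returns the lower of the two middle values rather than their average.

```python
import torch

# Three toy versions of the same layer; the third one is an outlier
layers = [
    torch.tensor([1.0, 2.0, 3.0]),
    torch.tensor([2.0, 3.0, 4.0]),
    torch.tensor([100.0, 100.0, 100.0]),
]
stacked = torch.stack(layers)  # shape: (num_models, 3)

median_vals = torch.median(stacked, dim=0).values
mean_vals = stacked.mean(dim=0)

print(median_vals)  # tensor([2., 3., 4.])
print(mean_vals)    # pulled towards the outlier: roughly [34.3, 35.0, 35.7]

# With an even number of models, torch.median picks the lower middle value
print(torch.median(torch.tensor([1.0, 2.0, 3.0, 4.0])))  # tensor(2.)
```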


### Step 2: Applying the Median Merger to the Entire Model

Once the layer-wise merging function is defined, you can use it within a model-merging function. This function will be decorated with `@model_merge` to handle the higher-level tasks, like converting models to `StateDict` and back.

```python
@model_merge
def median(models: List[StateDict], **kw) -> StateDict:
    # Delegate to the layer-wise merger defined in Step 1
    return median_layer_merging(models)
```

### How It Works

1. **Layer-wise Median Calculation**: The `median_layer_merging` function stacks the corresponding layers from each model and computes the element-wise median across them.
2. **Model-wide Application**: The `@dict_map` decorator applies this median calculation across all layers in the model, while `@model_merge` manages the conversion and overall structure.
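In symbols: for $K$ models, let $\theta_{k,\ell}$ denote the parameters of layer $\ell$ in model $k$. The median merger sets every entry $j$ of each merged layer independently:

$$
\theta^{\text{merged}}_{\ell}[j] = \operatorname{median}\big(\theta_{1,\ell}[j],\ \theta_{2,\ell}[j],\ \dots,\ \theta_{K,\ell}[j]\big)
$$

Because the median is taken entry by entry, the merged tensor has the same shape as each input layer.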

This process allows you to easily merge multiple models using the median of their weights, potentially improving performance by leveraging the central tendencies of the models' parameters.

## Evaluating the Merged Model

To assess the effectiveness of our novel merging approach, we can apply it to a real-world scenario. For this example, we'll use the BERT model and the RTE (Recognizing Textual Entailment) task from the GLUE benchmark.


First, we'll load a set of pre-trained BERT models that have been fine-tuned on the RTE task. These models will be used as input for our median merging function.

Next, we will apply our median merging function to these models. The process involves merging the models' weights using the median strategy, as previously defined, and then evaluating the merged model's performance on the text classification task.


### Explanation:

- **Model Selection**: The list contains five independently fine-tuned checkpoints of `bert-base-uncased`, all trained on RTE. Because they share the same architecture, their weights can be merged layer by layer, and the hope is that the median over several independent fine-tunes yields a well-rounded set of weights.

- **Task Specification**: The `task='text-classification'` parameter tells the library that we're working with a classification task, which it needs in order to load the models correctly from Hugging Face.
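Layer-wise merging only makes sense when every model has the same parameter names and shapes, which is expected here since all checkpoints are fine-tunes of `bert-base-uncased`. A quick pre-flight check could look like the hypothetical helper below; `check_mergeable` is written in plain PyTorch for illustration and is not part of mergecraft.

```python
from typing import Dict, List

import torch
import torch.nn as nn


def check_mergeable(state_dicts: List[Dict[str, torch.Tensor]]) -> None:
    """Hypothetical helper: raise if state dicts differ in keys or shapes."""
    ref = state_dicts[0]
    for i, sd in enumerate(state_dicts[1:], start=1):
        if set(sd) != set(ref):
            raise ValueError(f"model {i} has different parameter names")
        for key, tensor in ref.items():
            if sd[key].shape != tensor.shape:
                raise ValueError(
                    f"model {i}: shape mismatch for {key!r}: "
                    f"{tuple(sd[key].shape)} vs {tuple(tensor.shape)}"
                )


# Same architecture: passes silently
check_mergeable([nn.Linear(4, 2).state_dict(), nn.Linear(4, 2).state_dict()])

# Different output size (e.g. a different classification head): raises
try:
    check_mergeable([nn.Linear(4, 2).state_dict(), nn.Linear(4, 3).state_dict()])
except ValueError as err:
    print(err)  # model 1: shape mismatch for 'weight': (3, 4) vs (2, 4)
```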



```python
%%time
models = [
    'textattack/bert-base-uncased-RTE',
    'yoshitomo-matsubara/bert-base-uncased-rte',
    'Ruizhou/bert-base-uncased-finetuned-rte',
    'howey/bert-base-uncased-rte',
    'anirudh21/bert-base-uncased-finetuned-rte',
]

merged_model = median(models, task='text-classification')
```

```
CPU times: total: 9min 37s
Wall time: 2min 9s
```


```python
from mergecraft import evaluate_glue_pipeline

evaluate_glue_pipeline(merged_model, 'rte')
```

```
{'accuracy': 0.592057761732852}
```

### Wrapping Up

Congratulations! You've just implemented, run, and evaluated a completely novel model merging approach. While the process was successful, the results didn't quite meet our expectations: the merged model achieved an accuracy of 0.59 on RTE, lower than the baseline models, whose accuracies range between 0.66 and 0.72.

### Reflections on the Results

Although this method wasn't as effective as we initially hoped and proved to be quite time-consuming, this is all part of the experimentation process. Every attempt, successful or not, brings valuable insights. By understanding what doesn't work, we're one step closer to discovering what does.

### Looking Forward

Don't be discouraged by these results! The beauty of this framework is its flexibility: you can easily tweak, modify, and re-run your merging strategies as often as needed. With each iteration, you're refining your approach and inching closer to a more effective solution.

Remember, innovation often comes from persistence. So keep experimenting, iterating, and learning until you develop a merging method that not only works but excels.

### Simplified Merging with `@layer_merge`

If you look closely at the `median()` function, you'll notice that it's mainly passing parameters down the chain without doing much on its own. This is typical for simple merging methods like the ones we've explored. 

To further simplify the construction of merging methods, the `mergecraft` library provides a `@layer_merge` decorator. This decorator combines the functionalities of `@dict_map` and `@model_merge` into one, making the process more straightforward. Here’s how it works:

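```python
from typing import Callable

from mergecraft import dict_map, model_merge

def layer_merge(func) -> Callable:
    func = dict_map(func)     # lift the per-layer function to whole state dicts
    func = model_merge(func)  # handle model loading and StateDict conversion
    return func
```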
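Stacking decorators is equivalent to applying them bottom-up, so this three-line body matches writing `@model_merge` directly above `@dict_map` on the same function. The toy decorators below (hypothetical, pure Python; `tag`, `combined`, `stacked` and `composed` are illustrative names) only record the order in which their wrappers run, which makes the equivalence visible: the outer `model_merge` layer runs first on each call and then hands off to the `dict_map` layer.

```python
from typing import Callable, List


def tag(name: str) -> Callable:
    """Hypothetical decorator factory: append `name` to the call trace."""
    def decorator(func: Callable) -> Callable:
        def wrapper(trace: List[str]) -> List[str]:
            return func(trace + [name])
        return wrapper
    return decorator


inner, outer = tag("dict_map"), tag("model_merge")


def combined(func: Callable) -> Callable:
    # Same shape as layer_merge: wrap with inner first, then outer
    return outer(inner(func))


@outer
@inner
def stacked(trace: List[str]) -> List[str]:
    return trace


@combined
def composed(trace: List[str]) -> List[str]:
    return trace


print(stacked([]))   # ['model_merge', 'dict_map']
print(composed([]))  # ['model_merge', 'dict_map']
```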

With `@layer_merge`, you can simplify the median merger even further:

```python
@layer_merge
def median(models: List[Tensor], **kw) -> Tensor:
    # One function now covers both the per-layer logic and model handling
    return torch.median(torch.stack(models), dim=0).values
```