# Backward Compatibility During Data Updates by Weight Interpolation

---

This notebook's CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook. 

![This us-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-2/introduction_to_amazon_algorithms|jumpstart_regression_free_training|Amazon_JumpStart_NLP_Regression_Free_Training.ipynb)

---

## Introduction

Backward compatibility of model predictions is a desirable property when updating a machine-learning-driven application. It allows the underlying model to be improved seamlessly without introducing regression bugs. In classification tasks, these bugs occur in the form of negative flips: the new model incorrectly predicts the output for a test sample that the old (reference) model classified correctly.
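Concretely, the negative flip rate (NFR) computed later in this notebook is the fraction of the $N$ test samples that the old model classifies correctly and the new model misclassifies (reported in percent). Here $y_i$ is the label of sample $i$, and $\hat{y}_i^{old}$ and $\hat{y}_i^{new}$ are the predictions of the old and new model:

\begin{equation}
    \mathrm{NFR} = \frac{1}{N} \sum_{i=1}^{N} \mathbb{1}\left[\hat{y}_i^{old} = y_i \wedge \hat{y}_i^{new} \neq y_i\right],
\end{equation}

where $\mathbb{1}[\cdot]$ is 1 if the condition holds and 0 otherwise.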

A common reason to update a model is that new training data has become available and needs to be incorporated. Simply retraining the model on the updated data can introduce unwanted negative flips.

In this notebook, we introduce and implement Backward Compatible Weight Interpolation (BCWI), a method that reduces regression during data updates by interpolating between the weights of the old and new models.

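As demonstrated in the notebook implementation below, BCWI significantly reduces the negative flip rate: in the sample run shown at the end of the notebook, it drops from 3.25% for the retrained model to 2.225% for the interpolated model (roughly a 32% relative reduction) without lowering accuracy. Further details of the method will be released in the upcoming paper.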

If you find this notebook useful, please consider citing our preprint:
```
@article{Schumann2023BCWI,
title={Backward Compatibility During Data Updates by Weight Interpolation},
author={Raphael Schumann and Elman Mansimov and Yi-An Lai and Nikolaos Pappas and Xibin Gao and Yi Zhang},
journal={ArXiv},
year={2023}}
```

We also suggest checking out the related notebook on improving backward compatibility during model architecture updates: [Amazon JumpStart Regression-Free Training](https://github.com/aws/amazon-sagemaker-examples/blob/main/introduction_to_amazon_algorithms/jumpstart_regression_free_training/Amazon_JumpStart_Regression_Free_Training.ipynb)

## Install required Python libraries (transformers and datasets)

In [None]:
# Install huggingface transformers and datasets
!pip install transformers
!pip install datasets

In [None]:
# imports
import json
import os
import torch
import transformers
import datasets
import requests
import copy

from transformers import RobertaTokenizer

## Shared variables pointing to the data url and model name

In [None]:
# Shared Variables
DATA_URL = "https://raw.githubusercontent.com/amazon-science/regression-constraint-model-upgrade/main/nlp/data/MASSIVE/"
PT_MODEL_NAME = "roberta-base"

## Set up tokenizer

In [None]:
# Load tokenizer
tokenizer = []


def load_tokenizer():
    if len(tokenizer) == 0:
        tokenizer.append(RobertaTokenizer.from_pretrained(PT_MODEL_NAME))


load_tokenizer()

print(tokenizer)

In [None]:
def tokenize_function(examples):
    """Helper tokenize function"""
    return tokenizer[0](
        examples["text"], padding=False, truncation=True, return_attention_mask=False
    )


def load_dataset(splits, file_template):
    """Helper function to load MASSIVE dataset"""
    data_files = dict()
    for split in splits:
        data_files[split] = file_template.format(split)

    print(data_files)
    dataset = datasets.load_dataset("json", data_files=data_files)
    # "{}.jsonl".format("info")[:-1] turns "info.jsonl" into "info.json", the label metadata file
    dataset_info = json.loads(requests.get(file_template.format("info")[:-1]).content)

    new_label_ids = [dataset_info["labels"].index(c) for c in dataset_info["add_classes"]]
    old_label_ids = [
        i for i, c in enumerate(dataset_info["labels"]) if c not in dataset_info["add_classes"]
    ]
    dataset_info["new_label_ids"] = new_label_ids
    dataset_info["old_label_ids"] = old_label_ids

    tokenized_dataset = dataset.map(tokenize_function, batched=False)

    return tokenized_dataset, dataset_info
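`load_dataset` splits the label indices into classes that are new in the updated data (`add_classes`) and all remaining classes. A toy sketch of that bookkeeping, using hypothetical label names rather than the actual MASSIVE label set:

```python
# Toy illustration of the label-ID split performed in load_dataset().
# The label names are hypothetical placeholders, not the MASSIVE labels.
dataset_info = {
    "labels": ["alarm_set", "weather_query", "play_music", "news_query"],
    "add_classes": ["play_music"],
}

# Indices of classes introduced by the data update
new_label_ids = [dataset_info["labels"].index(c) for c in dataset_info["add_classes"]]
# Indices of all classes that already existed before the update
old_label_ids = [
    i for i, c in enumerate(dataset_info["labels"]) if c not in dataset_info["add_classes"]
]

print(new_label_ids)  # [2]
print(old_label_ids)  # [0, 1, 3]
```

The new IDs have the format that `interpolate_weights` below accepts through its `new_label_ids` argument, which copies those classifier rows from the new model.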

## Download the initial dataset used before new data is introduced (old dataset)

In [None]:
# Download the data before the update (old dataset)
old_dataset_files = os.path.join(DATA_URL, "add_data", "old", "{}.jsonl")
old_dataset, old_dataset_info = load_dataset(["train", "dev", "test"], old_dataset_files)

# should contain 1000 lines in train, 333 lines in dev, and 4000 lines in test
print("Old dataset before update")
print(old_dataset)

print("Old dataset before update info")
print(old_dataset_info)

## Set up SageMaker and SageMaker Training Environment

In [None]:
## Set up SageMaker environment
import sagemaker
from sagemaker.huggingface import HuggingFace
import boto3

# Use remote mode
sagemaker_region = boto3.Session().region_name
sagemaker_session = sagemaker.Session()
bucket_name = sagemaker_session.default_bucket()

role = sagemaker.get_execution_role()
instance_type = "ml.p3.2xlarge"

print(f"sagemaker region: {sagemaker_region}")
print(f"sagemaker session: {sagemaker_session}")
print(f"bucket name: {bucket_name}")
print(f"role: {role}")
print(f"instance type: {instance_type}")

In [None]:
# Configure the training job for the old model on the data before the update

# git configuration to download regression-free training script
git_config = {
    "repo": "https://github.com/amazon-science/regression-constraint-model-upgrade.git",
    "branch": "main",
}

huggingface_estimator = HuggingFace(
    entry_point="train.py",
    source_dir="nlp",
    git_config=git_config,
    instance_type=instance_type,
    instance_count=1,
    role=role,
    transformers_version="4.17.0",
    pytorch_version="1.10.2",
    py_version="py38",
    hyperparameters={
        "dataset": "MASSIVE",
        "scenario": "add_data",
        "data_type": "old",
        "bucket_name": bucket_name,
    },
)

## Train the initial model before data update

In [None]:
huggingface_estimator.fit()

## Pull the trained old model from S3 for inference and weight interpolation

In [None]:
# Load the old model from S3
# Used for inference and for calculating the negative flip rate
import boto3
from transformers import RobertaForSequenceClassification

s3 = boto3.client("s3")

# S3 prefix of the old model checkpoint
load_from_s3 = "./bcwi_nlp_outputs/v1/MASSIVE/1111/add_data/old_model/model"

print(f"Loading from S3 {load_from_s3}")
os.makedirs("old_model_dir", exist_ok=True)
# Download the config, hyperparameters, and weights of the old model
for file_name in ["config.json", "hparams.json", "pytorch_model.bin"]:
    local_path = os.path.join("old_model_dir", file_name)
    with open(local_path, "wb") as f:
        s3.download_fileobj(bucket_name, os.path.join(load_from_s3, file_name), f)
    print(f"Downloaded {local_path}")

old_model = RobertaForSequenceClassification.from_pretrained("old_model_dir")

print(old_model)

## Download the updated dataset with new data added to it (new dataset)

In [None]:
# Download the data after the update (new dataset)
new_dataset_files = os.path.join(DATA_URL, "add_data", "updated", "{}.jsonl")
new_dataset, new_dataset_info = load_dataset(["train", "dev", "test"], new_dataset_files)

# should contain 1500 lines in train, 500 lines in dev, and 4000 lines in test
print("New dataset after update")
print(new_dataset)

print("New dataset after update info")
print(new_dataset_info)

## Train the new model using updated data

In [None]:
# Start from the old model checkpoint and continue training on the updated data

# git configuration to download regression-free training script
git_config = {
    "repo": "https://github.com/amazon-science/regression-constraint-model-upgrade.git",
    "branch": "main",
}

huggingface_estimator_new = HuggingFace(
    entry_point="train.py",
    source_dir="nlp",
    git_config=git_config,
    instance_type=instance_type,
    instance_count=1,
    role=role,
    transformers_version="4.17.0",
    pytorch_version="1.10.2",
    py_version="py38",
    hyperparameters={
        "dataset": "MASSIVE",
        "scenario": "add_data",
        "data_type": "updated",
        "load_from_s3": f"./bcwi_nlp_outputs/v1/MASSIVE/1111/add_data/old_model/model",
        "bucket_name": bucket_name,
        "output_dir": "add_data",
        "num_epochs": 3,
    },
)

In [None]:
huggingface_estimator_new.fit()

## Pull the newly trained model from S3 for inference

In [None]:
# Load the new model from S3
# Used for inference and for calculating the negative flip rate

load_from_s3 = "./bcwi_nlp_outputs/add_data/MASSIVE/1111/add_data/old_model/model"

print(f"Loading from S3 {load_from_s3}")
os.makedirs("new_model_dir", exist_ok=True)
# Download the config, hyperparameters, and weights of the new model
for file_name in ["config.json", "hparams.json", "pytorch_model.bin"]:
    local_path = os.path.join("new_model_dir", file_name)
    with open(local_path, "wb") as f:
        s3.download_fileobj(bucket_name, os.path.join(load_from_s3, file_name), f)
    print(f"Downloaded {local_path}")

from transformers import RobertaForSequenceClassification

new_model = RobertaForSequenceClassification.from_pretrained("new_model_dir")

print(new_model)

## Helper function to collect model predictions for accuracy and negative flip rate

In [None]:
# Prepare for calculating negative flip rate and accuracy
test_data = new_dataset["test"]  # test sets in old and new datasets are the same


def calculate_accuracy(data, model, batch_size=40):
    """Return the model's predictions and the gold labels for `data`.

    Accuracy and negative flip rate are computed from these two lists below.
    """
    all_preds = []
    all_labels = []
    for i in range(0, len(data), batch_size):
        if int(i / batch_size) % 10 == 0:
            print(f"working on {int(i/batch_size)} out of {int(len(data)/batch_size)}")
        torch.cuda.empty_cache()  # clear memory
        # process examples in the batch
        examples = data[i : (i + batch_size)]
        text = examples["text"]
        label = examples["label"]
        # Pad to the longest sequence in the batch. Do not set max_length to the
        # number of words: RoBERTa splits words into subword tokens and adds special
        # tokens, so a word-based limit would truncate the longest inputs.
        text_tokenizer = tokenizer[0](
            text,
            padding=True,
            truncation=True,
            return_tensors="pt",
        )

        with torch.no_grad():
            outputs = model(**text_tokenizer)
        preds = outputs.logits.argmax(-1).tolist()

        # merge them into the list that combines all labels and predictions
        all_preds.extend(preds)
        all_labels.extend(label)
    return all_preds, all_labels

## Calculate accuracy of old and new models together with a negative flip rate

In [None]:
# Get old model prediction
print("Getting old model predictions")
old_model_preds, labels = calculate_accuracy(test_data, old_model)

In [None]:
# Get new model prediction
print("Getting new model predictions")
new_model_preds, _ = calculate_accuracy(test_data, new_model)

In [None]:
# Get old model accuracy
old_acc = (
    100 * sum([old_pred == l for old_pred, l in zip(old_model_preds, labels)]) / float(len(labels))
)
print(f"Old model accuracy {old_acc}%")

# Get new model accuracy
new_acc = (
    100 * sum([new_pred == l for new_pred, l in zip(new_model_preds, labels)]) / float(len(labels))
)
print(f"New model accuracy {new_acc}%")

# Calculate negative flip rate
nfr = (
    100
    * sum(
        [
            old_pred == l and new_pred != l
            for old_pred, new_pred, l in zip(old_model_preds, new_model_preds, labels)
        ]
    )
    / float(len(labels))
)
print(f"Negative Flip Rate {nfr}%")
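The accuracy and negative-flip-rate arithmetic above is repeated later for the interpolated model. One way to avoid the duplication is a small pure-Python helper; the name `flip_metrics` and its return format are our own choices, not part of the notebook or the training repo:

```python
from typing import Dict, Sequence


def flip_metrics(
    old_preds: Sequence[int], new_preds: Sequence[int], labels: Sequence[int]
) -> Dict[str, float]:
    """Accuracy of both models and their negative flip rate, all in percent."""
    n = len(labels)
    if n == 0 or not (len(old_preds) == len(new_preds) == n):
        raise ValueError("predictions and labels must be non-empty and equally long")
    old_correct = [p == y for p, y in zip(old_preds, labels)]
    new_correct = [p == y for p, y in zip(new_preds, labels)]
    # A negative flip: the old model is right and the new model is wrong on the same sample
    flips = sum(o and not c for o, c in zip(old_correct, new_correct))
    return {
        "old_acc": 100 * sum(old_correct) / n,
        "new_acc": 100 * sum(new_correct) / n,
        "nfr": 100 * flips / n,
    }


# Toy usage with made-up predictions (not results from this notebook)
toy_metrics = flip_metrics(old_preds=[0, 1, 2, 2], new_preds=[0, 2, 2, 1], labels=[0, 1, 2, 1])
print(toy_metrics)  # {'old_acc': 75.0, 'new_acc': 75.0, 'nfr': 25.0}
```

In the notebook this could be called as `flip_metrics(old_model_preds, new_model_preds, labels)`; passing `interpolated_model_preds` as the second argument gives the old-versus-interpolated negative flip rate.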

## Backward compatible weight interpolation (our core method to reduce regression)

Backward compatible weight interpolation (BCWI) is defined as 

\begin{equation}
    \theta_{\mathrm{BCWI}} = \alpha \theta_{old} + (1-\alpha) \theta_{new},
\end{equation}

where $\alpha \in [0.0, 1.0]$, $\theta_{old}$ are the old model parameters, and $\theta_{new}$ are the new model parameters. Setting $\alpha = 1$ recovers the old model and $\alpha = 0$ the new model.

We use $\alpha = 0.3$ in this notebook.
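To make the formula concrete, here is a minimal pure-Python sketch that applies it to toy parameter dictionaries. Lists of floats stand in for weight tensors, and `bcwi_interpolate` is a hypothetical helper, not the notebook's `interpolate_weights`:

```python
from typing import Dict, List

Params = Dict[str, List[float]]


def bcwi_interpolate(theta_old: Params, theta_new: Params, alpha: float) -> Params:
    """theta_BCWI = alpha * theta_old + (1 - alpha) * theta_new, per parameter entry."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must lie in [0, 1]")
    if theta_old.keys() != theta_new.keys():
        raise ValueError("both models must have the same parameter names")
    return {
        name: [alpha * o + (1 - alpha) * n for o, n in zip(theta_old[name], theta_new[name])]
        for name in theta_old
    }


# Toy "models" with a single two-element parameter each (made-up values)
old = {"w": [1.0, 0.0]}
new = {"w": [0.0, 1.0]}

print(bcwi_interpolate(old, new, alpha=0.3))  # {'w': [0.3, 0.7]}
print(bcwi_interpolate(old, new, alpha=1.0))  # recovers the old model
print(bcwi_interpolate(old, new, alpha=0.0))  # recovers the new model
```

The notebook's `interpolate_weights` applies the same update with in-place PyTorch tensor operations on a `state_dict`, restricted to weight and bias entries.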

In [None]:
# Interpolate between the old and new models
# More details can be found in the GitHub repo


def interpolate_weights(old_model, new_models, alpha, new_label_ids=None, weighted=None):
    """Return a state_dict whose weights interpolate between old_model and new_models.

    alpha: weight of the old model (1.0 gives the old model, 0.0 the new one).
    new_models: list of new models; their weights are averaged first.
    new_label_ids: classifier rows of classes the old model was not trained on;
        these rows are copied from the new model.
    weighted: optional dict of per-parameter weights for the old model, with the
        same keys as the state_dict.
    """
    # Average the weights of the new models into a single "soup"
    new_state_dicts = [new_model.state_dict() for new_model in new_models]
    new_model_state_dict = dict()
    for key in new_models[0].state_dict():
        if not (key.endswith("bias") or key.endswith("weight")):
            continue

        new_model_state_dict[key] = torch.mean(
            torch.stack([s[key] for s in new_state_dicts]), dim=0
        )

    print("alpha", alpha)

    # Use the old model as the basis of the interpolated model weights
    model = copy.deepcopy(old_model)
    # All weights of a model can be accessed by its state_dict
    state_dict = model.state_dict()
    for key in state_dict:
        # Only interpolate weights and biases (including layer norm parameters); skip buffers
        if not (key.endswith("bias") or key.endswith("weight")):
            continue

        if weighted is not None:
            # When alpha = 1.0, the update below reduces to (w * old) / w, which is NaN (0 / 0)
            # wherever weighted[key] is zero or underflows. Those entries are restored from the old weights.
            if alpha == 1.0:
                c = state_dict[key].detach().clone()

            # In-place operations to modify the weights of the model.
            # state_dict initially holds the weights of the old model.
            state_dict[key] *= alpha * weighted[key]
            state_dict[key] += (1 - alpha) * new_model_state_dict[key]
            state_dict[key] /= alpha * weighted[key] + (1 - alpha)

            # Three lines above as one-liner
            # state_dict[key].data.copy_(((alpha * weighted[key] * state_dict[key]) + ((1-alpha) * new_model_state_dict[key])) / (alpha * weighted[key] + (1-alpha)))

            if alpha == 1.0:
                nans = torch.isnan(state_dict[key])
                state_dict[key][nans] = c[nans]
        else:
            # Simple linear interpolation with parameter alpha.
            # state_dict initially holds the weights of the old model.
            state_dict[key] *= alpha
            state_dict[key] += (1 - alpha) * new_model_state_dict[key]

        # Copy classifier weights of new classes from the new model. The old model was not trained on those classes.
        if new_label_ids:
            if key == "classifier.out_proj.weight":
                state_dict[key][new_label_ids, :] = new_model_state_dict[key][new_label_ids, :]
            if key == "classifier.out_proj.bias":
                state_dict[key][new_label_ids] = new_model_state_dict[key][new_label_ids]
    return state_dict
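For reference, the `weighted` branch of `interpolate_weights` computes the following element-wise update, where $w$ denotes `weighted[key]` and $\bar{\theta}_{new}$ is the mean of the weights of the $K$ models in `new_models`. The notebook does not specify how $w$ is obtained, so treat it as a given per-parameter weight:

\begin{equation}
    \bar{\theta}_{new} = \frac{1}{K} \sum_{k=1}^{K} \theta_{new}^{(k)}, \qquad
    \theta_{\mathrm{BCWI}} = \frac{\alpha\, w \odot \theta_{old} + (1-\alpha)\, \bar{\theta}_{new}}{\alpha\, w + (1-\alpha)},
\end{equation}

with the division taken element-wise. Setting $w = 1$ everywhere recovers the plain BCWI formula above, and the unweighted branch likewise uses $\bar{\theta}_{new}$ in place of $\theta_{new}$. For $\alpha = 1$ the update reduces to $(w \odot \theta_{old}) / w$, which is undefined wherever $w = 0$; the code restores $\theta_{old}$ at those positions.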

In [None]:
# Start from a fresh copy of the old model
interpolated_model = RobertaForSequenceClassification.from_pretrained("old_model_dir")

# Interpolate between the old model weights and the new model weights
interpolated_state_dict = interpolate_weights(interpolated_model, [new_model], alpha=0.3)
# Load the interpolated weights into the interpolated model
interpolated_model.load_state_dict(interpolated_state_dict, strict=True)

In [None]:
# Get interpolated model prediction
print("Getting interpolated model predictions")
interpolated_model_preds, _ = calculate_accuracy(test_data, interpolated_model)

## Calculate accuracy and negative flip rate of interpolated model. Compare it to the baselines

In [None]:
# Get old model accuracy
old_acc = (
    100 * sum([old_pred == l for old_pred, l in zip(old_model_preds, labels)]) / float(len(labels))
)
print(f"Old model accuracy {old_acc}%")

# Get new model accuracy
new_acc = (
    100 * sum([new_pred == l for new_pred, l in zip(new_model_preds, labels)]) / float(len(labels))
)
print(f"New model accuracy {new_acc}%")

# Get interpolated model accuracy
interpolated_acc = (
    100
    * sum(
        [interpolated_pred == l for interpolated_pred, l in zip(interpolated_model_preds, labels)]
    )
    / float(len(labels))
)
print(f"Interpolated model accuracy {interpolated_acc}%")


# Calculate negative flip rate
nfr = (
    100
    * sum(
        [
            old_pred == l and new_pred != l
            for old_pred, new_pred, l in zip(old_model_preds, new_model_preds, labels)
        ]
    )
    / float(len(labels))
)
print(f"Negative Flip Rate {nfr}%")

# Calculate negative flip rate between the old model and the interpolated model
interpolate_nfr = (
    100
    * sum(
        [
            old_pred == l and interpolate_pred != l
            for old_pred, interpolate_pred, l in zip(
                old_model_preds, interpolated_model_preds, labels
            )
        ]
    )
    / float(len(labels))
)
print(f"Negative Flip Rate Old and Interpolated (Ours) {interpolate_nfr}%")

After running the last cell of the notebook, you should see output similar to the following (exact numbers may vary between runs):

```
Old model accuracy 81.7%
New model accuracy 82.55%
Interpolated model accuracy 82.65%
Negative Flip Rate 3.25%
Negative Flip Rate Old and Interpolated (Ours) 2.225%
```
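The relative reduction in negative flip rate follows directly from the sample numbers above. A quick check, with values copied from the sample output (a different run will give different figures):

```python
# Values copied from the sample output above; your run may differ
nfr_retrained = 3.25      # "Negative Flip Rate" (old vs. retrained new model)
nfr_interpolated = 2.225  # "Negative Flip Rate Old and Interpolated (Ours)"

absolute_drop = nfr_retrained - nfr_interpolated
relative_drop = 100 * absolute_drop / nfr_retrained
print(f"absolute: {absolute_drop:.3f} points, relative: {relative_drop:.1f}%")
# absolute: 1.025 points, relative: 31.5%
```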

We hope you found this implementation helpful!

## Notebook CI Test Results

This notebook was tested in multiple regions. The test results are as follows, except for us-west-2 which is shown at the top of the notebook.

![This us-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-1/introduction_to_amazon_algorithms|jumpstart_regression_free_training|Amazon_JumpStart_NLP_Regression_Free_Training.ipynb)

![This us-east-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-2/introduction_to_amazon_algorithms|jumpstart_regression_free_training|Amazon_JumpStart_NLP_Regression_Free_Training.ipynb)

![This us-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-1/introduction_to_amazon_algorithms|jumpstart_regression_free_training|Amazon_JumpStart_NLP_Regression_Free_Training.ipynb)

![This ca-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ca-central-1/introduction_to_amazon_algorithms|jumpstart_regression_free_training|Amazon_JumpStart_NLP_Regression_Free_Training.ipynb)

![This sa-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/sa-east-1/introduction_to_amazon_algorithms|jumpstart_regression_free_training|Amazon_JumpStart_NLP_Regression_Free_Training.ipynb)

![This eu-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-1/introduction_to_amazon_algorithms|jumpstart_regression_free_training|Amazon_JumpStart_NLP_Regression_Free_Training.ipynb)

![This eu-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-2/introduction_to_amazon_algorithms|jumpstart_regression_free_training|Amazon_JumpStart_NLP_Regression_Free_Training.ipynb)

![This eu-west-3 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-3/introduction_to_amazon_algorithms|jumpstart_regression_free_training|Amazon_JumpStart_NLP_Regression_Free_Training.ipynb)

![This eu-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-central-1/introduction_to_amazon_algorithms|jumpstart_regression_free_training|Amazon_JumpStart_NLP_Regression_Free_Training.ipynb)

![This eu-north-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-north-1/introduction_to_amazon_algorithms|jumpstart_regression_free_training|Amazon_JumpStart_NLP_Regression_Free_Training.ipynb)

![This ap-southeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-1/introduction_to_amazon_algorithms|jumpstart_regression_free_training|Amazon_JumpStart_NLP_Regression_Free_Training.ipynb)

![This ap-southeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-2/introduction_to_amazon_algorithms|jumpstart_regression_free_training|Amazon_JumpStart_NLP_Regression_Free_Training.ipynb)

![This ap-northeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-1/introduction_to_amazon_algorithms|jumpstart_regression_free_training|Amazon_JumpStart_NLP_Regression_Free_Training.ipynb)

![This ap-northeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-2/introduction_to_amazon_algorithms|jumpstart_regression_free_training|Amazon_JumpStart_NLP_Regression_Free_Training.ipynb)

![This ap-south-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-south-1/introduction_to_amazon_algorithms|jumpstart_regression_free_training|Amazon_JumpStart_NLP_Regression_Free_Training.ipynb)
