# Data Efficient Neural Scaling Law via Model Reusing



**Peihao Wang, Rameswar Panda, Zhangyang (Atlas) Wang**

This notebook helps reproduce Figure 4(a) of the ICML 2023 paper [Data Efficient Neural Scaling Law via Model Reusing](https://openreview.net/pdf?id=iXYnIz4RRx).


In [None]:
import os
import sys
import shutil
import pickle
import random
from typing import Dict, List, Tuple
from datetime import datetime
import time
import math

import numpy as np

import matplotlib
import matplotlib.pyplot as plt

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler

from tqdm import tqdm, trange

from transformers import (
    PreTrainedModel,
    PreTrainedTokenizer,
    BertConfig,
    BertForMaskedLM,
    BertTokenizer,
    GPT2Config,
    GPT2LMHeadModel,
    GPT2Tokenizer,
    OpenAIGPTConfig,
    OpenAIGPTLMHeadModel,
    OpenAIGPTTokenizer,
    RobertaConfig,
    RobertaForMaskedLM,
    RobertaTokenizer,
)

from data import CoLDataset
from model import SimpleBertForMaskedLM, SimpleRobertaForMaskedLM

try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    from tensorboardX import SummaryWriter


## Utilities

In [None]:

# Registry mapping a model-type key to its (config class, model class, tokenizer class) triple.
# The "bert" and "roberta" entries use the Simple*ForMaskedLM classes imported from the
# local `model` module rather than the transformers BertForMaskedLM / RobertaForMaskedLM.
MODEL_CLASSES = {
    "gpt2": (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
    "openai-gpt": (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
    "bert": (BertConfig, SimpleBertForMaskedLM, BertTokenizer),
    "roberta": (RobertaConfig, SimpleRobertaForMaskedLM, RobertaTokenizer),
}


def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random number generators with `seed`."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)


## Masked language modeling

In [None]:

def mask_tokens(inputs: torch.Tensor, tokenizer: PreTrainedTokenizer, mlm_probability=0.15) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build masked-LM inputs and labels from a batch of token ids.

    Each position is selected with probability `mlm_probability`; special tokens
    and padding are never selected. Of the selected positions, 80% are replaced by
    the mask token, 10% by a random token id, and 10% are left unchanged.

    Note that `inputs` is modified in place and returned. `labels` keeps the
    original ids at selected positions and is -100 everywhere else.

    Raises:
        ValueError: if the tokenizer has no mask token.
    """
    mask_token = tokenizer.mask_token
    if mask_token is None:
        raise ValueError(
            "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the --mlm flag if you want to use this tokenizer."
        )

    labels = inputs.clone()
    shape = labels.shape

    # Per-position selection probability; zeroed on special tokens and padding
    select_prob = torch.full(shape, mlm_probability)
    special_rows = [
        tokenizer.get_special_tokens_mask(row, already_has_special_tokens=True) for row in labels.tolist()
    ]
    is_special = torch.tensor(special_rows, dtype=torch.bool)
    select_prob.masked_fill_(is_special, value=0.0)
    if tokenizer._pad_token is not None:
        is_padding = labels.eq(tokenizer.pad_token_id)
        select_prob.masked_fill_(is_padding, value=0.0)

    selected = torch.bernoulli(select_prob).bool()
    # The loss is only computed on selected positions
    labels[~selected] = -100

    # 80% of the selected positions become the mask token
    draw_mask = torch.bernoulli(torch.full(shape, 0.8)).bool()
    use_mask_token = selected & draw_mask
    inputs[use_mask_token] = tokenizer.convert_tokens_to_ids(mask_token)

    # Half of the remaining 20% (i.e. 10% overall) become a random token id
    draw_random = torch.bernoulli(torch.full(shape, 0.5)).bool()
    use_random_token = selected & draw_random & ~use_mask_token
    random_ids = torch.randint(len(tokenizer), shape, dtype=torch.long)
    inputs[use_random_token] = random_ids[use_random_token]

    # The last 10% of the selected positions keep their original token
    return inputs, labels


## Hyperparameters

In [None]:

# --- Tunable hyperparameters (the trailing `# @param` markers are Colab form annotations) ---
seed = 0 # @param {type:"integer"}
gpu = 0 # @param {type:"integer"}
eval_data_file = 'data/wiki-cased/en.valid.raw' # @param {type:"string"}
checkpoints_path = './data-efficient-scaling/' # @param {type:"string"}
batch_size = 64 # @param {type:"integer"}

# --- Fixed hyperparameters ---
model_type = 'bert'
tokenizer_name = 'bert-base-uncased'
block_size = 126
split_sent = True
cache_dir = None

# Configuration names to sweep: '12L-<H>H' for H = 64, 128, ..., 640
configs = [f'12L-{hidden}H' for hidden in range(64, 641, 64)]

# Data fractions to sweep, in ascending order
ratios = [0.001, 0.002, 0.003, 0.004, 0.005, 0.009]


# Look up the (config, model, tokenizer) classes registered in MODEL_CLASSES for `model_type`
config_class, model_class, tokenizer_class = MODEL_CLASSES[model_type]


## Count model parameters

In [None]:

def compute_num_params(conf):
    """Return the total parameter count of the model described by `configs/bert-{conf}.json`.

    The model is built from its config alone; no checkpoint weights are loaded.
    """
    model_config = config_class.from_pretrained(f'configs/bert-{conf}.json', cache_dir=None)
    network = model_class(config=model_config)
    total = 0
    for param in network.parameters():
        total += param.numel()
    return total


# Parameter count of every configuration in the sweep, keyed by configuration name
num_model_params = dict(zip(configs, map(compute_num_params, configs)))


## Evaluate checkpoints

In [None]:

def evaluate(eval_dataset, eval_batch_size, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, device, prefix="", mlm=True) -> Dict:
    """Compute the perplexity of `model` on `eval_dataset`.

    Examples are read sequentially, padded to a common length per batch and,
    when `mlm` is true, masked with `mask_tokens`. The mean loss of each batch
    is averaged over all batches and exponentiated.

    Args:
        eval_dataset: dataset whose items are token-id tensors accepted by
            `pad_sequence` (assumed 1-D; confirm against the dataset class).
        eval_batch_size: number of examples per batch.
        model: model to evaluate. It is switched to eval mode, and its output
            must expose the loss under the key ``'lm_loss'``.
        tokenizer: tokenizer supplying the pad token (and the mask token when
            `mlm` is true).
        device: device the batches are moved to; the model is expected to
            already live there.
        prefix: unused; kept for interface compatibility.
        mlm: if true, evaluate with masked-LM inputs/labels; otherwise the
            batch is used as both inputs and labels.

    Returns:
        A dict ``{"perplexity": float}``.

    Raises:
        ValueError: if `eval_dataset` yields no batches.
    """
    has_pad_token = tokenizer._pad_token is not None

    def collate(examples: List[torch.Tensor]):
        # Pad with the tokenizer's pad id when it has one, otherwise with pad_sequence's default
        if has_pad_token:
            return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id)
        return pad_sequence(examples, batch_first=True)

    eval_sampler = SequentialSampler(eval_dataset)
    eval_dataloader = DataLoader(
        eval_dataset, sampler=eval_sampler, batch_size=eval_batch_size, collate_fn=collate
    )

    # Eval!
    eval_loss = 0.0
    nb_eval_steps = 0
    model.eval()

    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        inputs, labels = mask_tokens(batch, tokenizer) if mlm else (batch, batch)
        inputs = inputs.to(device)
        labels = labels.to(device)

        # The attention mask is only consumed on the MLM path and only meaningful
        # when the tokenizer has a pad token; comparing against a missing pad id
        # (None) would not yield a tensor mask.
        attention_mask = None
        if mlm and has_pad_token:
            # word_tokens --> 1, pad_token --> 0 (derived from the already-masked inputs)
            attention_mask = (inputs != tokenizer.pad_token_id)
            if attention_mask.all():
                # Nothing is padded in this batch, so no mask is needed
                attention_mask = None

        with torch.no_grad():
            if mlm:
                outputs = model(inputs, attention_mask=attention_mask, masked_lm_labels=labels)
            else:
                outputs = model(inputs, labels=labels)
            eval_loss += outputs['lm_loss'].mean().item()
        nb_eval_steps += 1

    if nb_eval_steps == 0:
        # Fail with a clear message instead of a ZeroDivisionError below
        raise ValueError("eval_dataset produced no batches; cannot compute perplexity.")

    eval_loss = eval_loss / nb_eval_steps
    perplexity = torch.exp(torch.tensor(eval_loss)).item()

    return {"perplexity": perplexity}


In [None]:
# Select the GPU used for evaluation
torch.cuda.set_device(gpu)
device = torch.device("cuda", gpu)

# Log-perplexity of every evaluated checkpoint, keyed by checkpoint directory name
ppl_results = {}

for ratio in ratios:
    for conf in configs:
        # Two checkpoints per (ratio, config) pair: `bert-...` and `bert-n2n-...`
        for k in (f'bert-{conf}-{ratio}D', f'bert-n2n-{conf}-{ratio}D'):
            model_name_or_path = os.path.join(checkpoints_path, k)
            print(f"Evaluating checkpoint at {model_name_or_path}")

            # Re-seed before each checkpoint so every evaluation starts from the same RNG state
            set_seed(seed)

            # Config and tokenizer are read from the checkpoint directory itself
            config = config_class.from_pretrained(model_name_or_path, cache_dir=cache_dir)
            tokenizer = tokenizer_class.from_pretrained(model_name_or_path, cache_dir=cache_dir)
            assert block_size <= tokenizer.model_max_length

            # Load the checkpoint weights and move the model to the GPU
            model = model_class.from_pretrained(
                model_name_or_path,
                from_tf=".ckpt" in model_name_or_path,
                config=config,
                cache_dir=cache_dir,
            )
            model.to(device)

            # Build the evaluation set with this checkpoint's tokenizer
            eval_dataset = CoLDataset(
                eval_data_file,
                tokenizer_name,
                tokenizer,
                block_size,
                split_sent=split_sent,
                verbose=False,
            )

            result = evaluate(eval_dataset, batch_size, model, tokenizer, device)

            # Store the natural log of the perplexity
            ppl_results[k] = math.log(result['perplexity'])


## Load or save results for next time

In [None]:

### To reuse previously saved results instead, uncomment the lines below
# if os.path.exists('ppl_results.pkl'):
#     with open('ppl_results.pkl', 'rb') as f:
#         # restore the results dictionary from the file
#         ppl_results = pickle.load(f)

# Persist the results dictionary to disk
with open('ppl_results.pkl', 'wb') as f:
    # serialize first, then write the bytes out
    f.write(pickle.dumps(ppl_results))


## Plot curves

In [None]:

# Figure setup
NEW_SIZE = 20
plt.rcParams["font.family"] = "DejaVu Sans"
plt.rcParams["font.size"] = NEW_SIZE
fig = plt.figure(figsize=(8, 8))
# Thicken the axes frame
for spine in plt.gca().spines.values():
    spine.set_linewidth(2.)


# One colour per data fraction: GnBu for the `bert-*` curves, YlOrBr for the `bert-n2n-*` curves.
# `plt.get_cmap` is used because `matplotlib.cm.get_cmap` was removed in Matplotlib 3.9.
cols_scratch = plt.get_cmap('GnBu', len(ratios))(np.linspace(0.4, 1., len(ratios)))
cols_n2n = plt.get_cmap('YlOrBr', len(ratios))(np.linspace(0.3, 1., len(ratios)))

# x values: parameter count in millions, truncated to an integer.
# They do not depend on the data fraction, so they are computed once.
xs = np.array([int(num_model_params[conf] / 1e6) for conf in configs])

# Plot one pair of curves per data fraction
for ratio, col_scratch, col_n2n in zip(ratios, cols_scratch, cols_n2n):
    ys_scratch = np.array([ppl_results[f'bert-{conf}-{ratio}D'] for conf in configs])
    ys_n2n = np.array([ppl_results[f'bert-n2n-{conf}-{ratio}D'] for conf in configs])

    # Legend text: the data fraction as a percentage
    ratio_txt = f'{float((ratio if ratio > 0 else 1.) * 100):.1f}%'

    # `bert-*` checkpoints: dotted line with star markers (no legend entry)
    plt.plot(xs, ys_scratch, linewidth=3.5, linestyle='dotted', alpha=0.8, color=col_scratch)
    plt.scatter(xs, ys_scratch, s=150, color=col_scratch, alpha=1., linewidths=0, marker='*')

    # `bert-n2n-*` checkpoints: solid line with round markers, labelled in the legend
    plt.plot(xs, ys_n2n, label=ratio_txt, linewidth=3.5, alpha=0.8, color=col_n2n)
    plt.scatter(xs, ys_n2n, s=100, color=col_n2n, alpha=1., linewidths=0, marker='o')

leg = plt.legend(title="Data Frac.", bbox_to_anchor=(1., 1), frameon=False, prop={'size': 18})
plt.setp(leg.get_title(), fontsize=18)

plt.xlabel('# Param. (M)')
plt.ylabel('Log Perplexity')

