Install requirements

In [None]:
# Install this project's dependencies (IPython shell escape; only valid inside Jupyter/IPython).
!pip install -r requirements.txt

In [None]:
import logging
import numpy as np
import evaluate

from datasets import load_dataset
from torch.utils.data import DataLoader, Subset

from transformers import (
    AutoTokenizer,
    default_data_collator,
    TrainingArguments,
    AutoModelForSequenceClassification,
    AutoConfig
)

from training import train_supernetwork
from search import multi_objective_search

# Module-level logger; NOTE(review): not referenced by any of the cells shown here.
logger = logging.getLogger(__name__)

In [None]:
# Text column(s) fed to the tokenizer for each GLUE task. Single-sentence
# tasks use None as the second key.
task_to_keys = dict(
    cola=("sentence", None),
    mnli=("premise", "hypothesis"),
    mrpc=("sentence1", "sentence2"),
    qnli=("question", "sentence"),
    qqp=("question1", "question2"),
    rte=("sentence1", "sentence2"),
    sst2=("sentence", None),
    stsb=("sentence1", "sentence2"),
    wnli=("sentence1", "sentence2"),
)

# Experiment configuration.
task_name = 'rte'
model_type = 'bert-base-cased'
output_dir = 'nas_output_dir'

# TODO(review): the original note here only said "fix"; presumably these
# hard-coded sequence length / batch sizes should become configurable.
max_seq_length = 128
per_device_train_batch_size = per_device_eval_batch_size = 8

# Load the GLUE task data and its matching metric.
raw_datasets = load_dataset("glue", task_name)
metric = evaluate.load("glue", task_name)

tokenizer = AutoTokenizer.from_pretrained(model_type)
if model_type.startswith("gpt2"):
    # GPT-2 has no dedicated padding token; reuse EOS so padding works.
    tokenizer.pad_token = tokenizer.eos_token

# Columns to tokenize for the chosen task.
sentence1_key, sentence2_key = task_to_keys[task_name]

# Pad every example to a fixed length, capped by what the tokenizer supports.
padding = "max_length"
max_seq_length = min(max_seq_length, tokenizer.model_max_length)

def preprocess_function(examples):
    """Tokenize a batch of examples for the current GLUE task.

    Intended for ``Dataset.map(batched=True)``, so ``examples`` maps column
    names to lists of values. Single-sentence tasks (``sentence2_key is None``)
    tokenize one text column; sentence-pair tasks tokenize both.

    Returns the tokenizer output, padded according to ``padding`` and
    truncated to ``max_seq_length``. The dataset's label column is left
    untouched: no label-to-id remapping is needed for GLUE tasks.
    """
    texts = [examples[sentence1_key]]
    if sentence2_key is not None:
        texts.append(examples[sentence2_key])
    return tokenizer(
        *texts, padding=padding, max_length=max_seq_length, truncation=True
    )

# Tokenize every split once; the tokenizer outputs are added as new columns
# next to the original text/label columns.
raw_datasets = raw_datasets.map(
    preprocess_function,
    batched=True,
    desc="Running tokenizer on dataset",
)

# STS-B is the GLUE regression task: its label is a float rather than a
# ClassLabel, so it has no `.names` and needs a single output unit.
is_regression = task_name == "stsb"
if is_regression:
    label_list = None
    num_labels = 1
else:
    label_list = raw_datasets["train"].features["label"].names
    num_labels = len(label_list)

train_dataset = raw_datasets["train"]
# The official validation split is used as the held-out test set
# (MNLI has two validation splits; the matched one is used).
test_dataset = raw_datasets[
    "validation_matched" if task_name == "mnli" else "validation"
]

train_dataset = train_dataset.remove_columns(["idx"])
test_dataset = test_dataset.remove_columns(["idx"])

# Split the original training data 70/30 into train / validation. The seed is
# fixed so all trials share the same data split.
split = train_dataset.train_test_split(train_size=0.7, seed=0)
# Train only on the 70% portion so the validation rows are never seen during
# training (using the full set here would leak validation data).
train_dataset = split["train"]
valid_dataset = split["test"]

# For the large tasks, evaluate on a fixed random subset to keep evaluation
# cheap. Sample without replacement (no duplicated rows) from a seeded
# generator so every run/trial uses the same subset.
if task_name in ["sst2", "qqp", "qnli", "mnli"]:
    rng = np.random.default_rng(seed=0)
    n_valid = min(2048, len(valid_dataset))
    valid_dataset = Subset(
        valid_dataset,
        rng.choice(len(valid_dataset), n_valid, replace=False).tolist(),
    )

data_collator = default_data_collator

# Only the training loader shuffles; evaluation order stays deterministic.
train_dataloader = DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=per_device_train_batch_size,
    collate_fn=data_collator,
)
eval_dataloader = DataLoader(
    valid_dataset,
    batch_size=per_device_eval_batch_size,
    collate_fn=data_collator,
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=per_device_eval_batch_size,
    collate_fn=data_collator,
)

In [None]:
# Build the model config for the task; num_labels comes from the data cell.
config = AutoConfig.from_pretrained(
    model_type,
    num_labels=num_labels,
    finetuning_task=task_name,
)
# GPT-2 has no pad token. The tokenizer cell reuses EOS for padding, and the
# sequence-classification head must know that id to find the last
# non-padding token in each row (otherwise batches larger than 1 fail).
if model_type.startswith("gpt2"):
    config.pad_token_id = tokenizer.pad_token_id

model = AutoModelForSequenceClassification.from_pretrained(
    model_type,
    config=config,
)

In [None]:
# Standard Trainer options go through the constructor so they are validated
# like any other TrainingArguments field.
training_args = TrainingArguments(output_dir=output_dir, save_strategy="epoch")
# The attributes below are not TrainingArguments fields; presumably they are
# extra settings consumed by train_supernetwork / multi_objective_search.
training_args.search_space = 'small'
training_args.use_accelerate = False  # set this to True to distribute training on multiple GPUs
# STS-B is the GLUE regression task; derive the flag from the task instead of
# requiring a manual edit.
training_args.is_regression = task_name == "stsb"
training_args.log_dir = '.log_dir'

train_supernetwork(model, train_dataloader, eval_dataloader, metric, training_args)

In [None]:
# Score, among those returned by the GLUE metric, that the search optimizes.
# GLUE reports Matthews correlation for CoLA and Pearson/Spearman for STS-B
# instead of accuracy.
# NOTE(review): assumes multi_objective_search looks metric_name up in the
# dict returned by metric.compute — confirm.
metric_name = {"cola": "matthews_correlation", "stsb": "pearson"}.get(task_name, "accuracy")
training_args.num_samples = 5  # presumably the number of sub-networks sampled by the search — confirm
pareto_set = multi_objective_search(model, eval_dataloader, metric, metric_name, training_args)

In [None]:
import matplotlib.pyplot as plt

# Scatter the Pareto set returned by the search, labeling each axis with the
# pareto_set key it plots.
plt.scatter(pareto_set['params'], pareto_set['error'])
plt.xlabel('params')
plt.ylabel('error')
plt.show()


In [None]:
# model = get_final_model(pareto_set.masks[0])