In [None]:
!pip install sentence-transformers

In [None]:
# utils
import os
import csv
import gzip
import pandas as pd
import typing as t
from sentence_transformers import util
from contextlib import contextmanager

# Remote gzipped TSV datasets hosted on sbert.net (fetched with download_dataset).
# AllNLI: sentence pairs; used in this notebook as the distillation training corpus.
ALLNLI_DATASET_URL = "https://sbert.net/datasets/AllNLI.tsv.gz"
# STS benchmark: scored sentence pairs; used in this notebook for evaluation.
STS_BENCHMARK_DATASET_URL = "https://sbert.net/datasets/stsbenchmark.tsv.gz"


def download_dataset(url: str, download_path: str) -> None:
    """Fetch the file at ``url`` into ``download_path`` unless that path already exists.

    Args:
        url (str): Remote location of the dataset.
        download_path (str): Local path to store the dataset at. If anything is
            already present at this path, no download happens.
    """
    if os.path.exists(download_path):
        # Already cached locally: nothing to do.
        return
    util.http_get(url, download_path)


@contextmanager
def open_dataset(dataset_path: str) -> t.Generator[csv.DictReader, None, None]:
    """Open a gzipped, tab-separated dataset within a context manager.

    The first row is used as the header. Quote characters are not special
    (``csv.QUOTE_NONE``), so fields are returned verbatim.

    Args:
        dataset_path (str): Path to the ``.tsv.gz`` dataset file.

    Yields:
        csv.DictReader: Reader producing one dict per data row.
    """
    # The nested `with` guarantees the gzip handle is closed even when the
    # caller's `with` body raises; a plain `yield` followed by `close()` would
    # skip the close on exceptions and leak the file.
    with gzip.open(dataset_path, "rt", encoding="utf8") as file:
        yield csv.DictReader(file, delimiter="\t", quoting=csv.QUOTE_NONE)


def read_as_dataframe(dataset_path: str) -> pd.DataFrame:
    """Load a gzipped TSV dataset into a DataFrame.

    Columns follow the file's header row; every cell is kept as a string.
    """
    with open_dataset(dataset_path) as reader:
        header = reader.fieldnames
        # Materialize all rows while the file is still open.
        records = [[record[name] for name in header] for record in reader]
    return pd.DataFrame(records, columns=header)


# genetics
import math
import typing as t
from random import Random
from copy import deepcopy

# Seeded RNG that Gene, Chromosome and Population use for all their random draws.
# NOTE: the parameters cell later rebinds this same global
# (random_state = random.Random(42)); the classes look the name up at call time,
# so from then on they draw from that generator instead.
random_state = Random(42)


class Gene:
    """A single integer gene whose value lies in the half-open range ``[min, max)``."""

    @property
    def value(self) -> int:
        """Current value of the gene."""
        return self._value

    @property
    def value_range(self) -> t.Tuple[int, int]:
        """``(min, max)`` bounds of the gene; ``min`` is included, ``max`` excluded."""
        return self._value_range

    def __init__(
        self,
        value_range: t.Tuple[int, int],
        initial_value: t.Optional[int] = None,
    ) -> None:
        """Gene base class.

        Args:
            value_range (t.Tuple[int, int]): Value range in `[min, max)` order.
            initial_value (t.Optional[int], optional): Initial value; must satisfy
                ``min <= initial_value < max``. If None, then init with a random
                value. Defaults to None.
        """
        assert value_range[0] < value_range[1]
        self._value_range = value_range
        if initial_value is not None:
            # Compare against None explicitly: a falsy-but-valid value such as 0
            # must be kept rather than silently replaced by a random one.
            a, b = value_range
            # Upper bound is exclusive, matching `_rand_value` (range(a, b)).
            assert a <= initial_value < b
            self._value = initial_value
        else:
            self._value = self._rand_value()

    def __repr__(self) -> str:
        return str(self.value)

    def _rand_value(self) -> int:
        """Draw a uniformly random value from ``[min, max)`` using ``random_state``."""
        a, b = self._value_range
        return random_state.choice(range(a, b))

    def randomize(self) -> None:
        """Replace the current value with a fresh random one."""
        self._value = self._rand_value()


class Chromosome:
    """A fixed-length, ordered list of genes together with a fitness score."""

    @property
    def fitness(self) -> float:
        """Fitness from the last `apply_fitness_fn` call (0 until then)."""
        return self._fitness

    @property
    def genes(self) -> t.List["Gene"]:
        """List of current genes."""
        return self._genes

    @property
    def length(self) -> int:
        """Number of genes."""
        return self._length

    def __init__(
        self,
        length: int,
        initial_genes: t.Optional[t.List["Gene"]] = None,
        **kwargs,
    ) -> None:
        """Chromosome base class.

        Args:
            length (int): Number of genes (ie. chromosome length).
            initial_genes (t.Optional[t.List[Gene]], optional): Initial genes. If None
                (or empty), then initialize with random genes, which requires a
                ``value_range`` keyword argument. Defaults to None.
        """
        assert length > 0
        self._length = length
        if not initial_genes:
            assert "value_range" in kwargs
            self._value_range = kwargs["value_range"]
            self._genes = self._rand_genes()
        else:
            self._genes = initial_genes
        self._fitness = 0

    def __repr__(self) -> str:
        return str(self._genes)

    def _rand_genes(self) -> t.List["Gene"]:
        """Generate `length` random genes within the stored value range."""
        return [Gene(self._value_range) for _ in range(self._length)]

    def to_list(self) -> t.List[int]:
        """Convert gene values to list."""
        return [g.value for g in self._genes]

    def apply_fitness_fn(self, fn: t.Callable[["Chromosome"], float]) -> None:
        """Apply fitness function. After applying, updates its current fitness value.

        Args:
            fn (t.Callable[["Chromosome"], float]): Fitness function; receives this chromosome.
        """
        self._fitness = fn(self)

    def mutate(self, rate: float) -> None:
        """Mutation operator: re-randomize each gene independently with probability `rate`.

        Args:
            rate (float): Mutation rate in `[0, 1]`.
        """
        for gene in self._genes:
            if random_state.random() < rate:
                gene.randomize()

    def crossover(self, other: "Chromosome") -> "Chromosome":
        """Single-point crossover operator.

        The child takes genes ``[0, mid)`` from ``self`` and ``[mid, length)`` from
        ``other``, where ``mid`` is drawn uniformly from ``[0, length]``.

        Args:
            other (Chromosome): Other chromosome to crossover.

        Returns:
            Chromosome: New (child) chromosome that owns copies of the inherited genes.
        """
        mid_point = random_state.randint(0, self._length)
        inherited = self._genes[:mid_point] + other.genes[mid_point:self._length]
        # Copy every gene: `mutate` randomizes genes in place, so sharing the
        # parents' Gene objects would let a child's mutation silently alter its
        # parents and every sibling that inherited the same gene.
        child_genes = [deepcopy(gene) for gene in inherited]
        return Chromosome(self._length, child_genes)


class Population:
    """A fixed-size set of chromosomes evolved by fitness-proportional selection."""

    @property
    def size(self) -> int:
        """Population size."""
        return self._size

    @property
    def global_best(self) -> "Chromosome":
        """Deep copy of the best chromosome seen across all evaluations.

        Before the first `eval` call this is an unevaluated copy of the first
        chromosome.
        """
        return self._global_best

    @property
    def local_best(self) -> "Chromosome":
        """Best chromosome of the last evaluation (the first chromosome before any)."""
        return self._local_best

    @property
    def best_chromosomes(self) -> t.List["Chromosome"]:
        """Deep copies of each evaluation's local best (when `keep_best_chromosomes`)."""
        return self._best_chromosomes

    def __init__(
        self,
        size: int,
        mutation_rate: float,
        initial_chromosomes: t.Optional[t.List["Chromosome"]] = None,
        keep_best_chromosomes: bool = True,
        **kwargs,
    ) -> None:
        """Population base class.

        Args:
            size (int): Population size.
            mutation_rate (float): Mutation rate in `[0, 1]` range.
            initial_chromosomes (t.Optional[t.List[Chromosome]], optional): Initial chromosomes. If None, then init with random chromosomes (requires `value_range` and `length` keyword arguments). Defaults to None.
            keep_best_chromosomes (bool): Keep local bests on each evaluation. Defaults to True.
        """
        self._size = size
        self._mutation_rate = mutation_rate
        if not initial_chromosomes:
            assert "value_range" in kwargs
            assert "length" in kwargs
            self._value_range = kwargs["value_range"]
            self._length = kwargs["length"]
            self.chromosomes = self._rand_chromosomes()
        else:
            self.chromosomes = initial_chromosomes
        self._global_best = deepcopy(self.chromosomes[0])
        self._local_best = self.chromosomes[0]
        # False until the first `eval`, so the placeholder global best above is
        # always replaced by a genuinely evaluated chromosome (even if every
        # fitness turns out to be <= 0).
        self._evaluated = False
        self.keep_best_chromosomes = keep_best_chromosomes
        self._best_chromosomes = []

    def _rand_chromosomes(self) -> t.List["Chromosome"]:
        """Generate `size` random chromosomes of the configured length and range."""
        return [
            Chromosome(self._length, value_range=self._value_range)
            for _ in range(self._size)
        ]

    def eval(self, fitness_fn: t.Callable[[t.List[int]], float]) -> None:
        """Evaluate all chromosomes with given function and update the bests.

        Args:
            fitness_fn (t.Callable[[t.List[int]], float]): Fitness function; receives
                a chromosome's gene values.
        """
        for chromosome in self.chromosomes:
            chromosome.apply_fitness_fn(lambda ch: fitness_fn(ch.to_list()))
        # max() keeps the first maximal element, i.e. the same tie-breaking as a
        # strict `>` scan in list order.
        self._local_best = max(self.chromosomes, key=lambda ch: ch.fitness)
        if not self._evaluated or self._local_best.fitness > self._global_best.fitness:
            self._global_best = deepcopy(self._local_best)
        self._evaluated = True
        if self.keep_best_chromosomes:
            self._best_chromosomes.append(deepcopy(self._local_best))

    def update(self) -> None:
        """Breed the next generation from the last evaluation's fitness values.

        Each chromosome enters the mating pool ``floor(100 * fitness / best)``
        times (non-positive fitness gets no entries). For every slot, two parents
        are drawn from the pool, crossed over, and the child is mutated. When the
        best fitness is not positive, parents are drawn uniformly instead.
        """
        best_fitness = self._local_best.fitness
        mating_pool: t.List["Chromosome"] = []
        if best_fitness > 0:
            for chromosome in self.chromosomes:
                copies = math.floor(max(chromosome.fitness, 0) / best_fitness * 100)
                mating_pool.extend([chromosome] * copies)
        if not mating_pool:
            # Proportional weights are undefined (division by zero) or inverted
            # (negative / negative) here, so fall back to uniform selection
            # instead of crashing or favouring the worst chromosomes.
            mating_pool = list(self.chromosomes)
        next_generation: t.List["Chromosome"] = []
        for _ in range(self._size):
            parent1 = random_state.choice(mating_pool)
            parent2 = random_state.choice(mating_pool)
            child = parent1.crossover(parent2)
            child.mutate(self._mutation_rate)
            next_generation.append(child)
        self.chromosomes = next_generation

# Genetic-Algorithm-Based Knowledge Distillation

This notebook demonstrates knowledge distillation of a pretrained Transformer-based model. It uses a genetic algorithm to select the best-fitting subset of pretrained layers from the base (teacher) model.

In [None]:
# import neural network libraries and utils
import torch
from torch.utils.data import DataLoader
from sentence_transformers import models, losses, evaluation
from sentence_transformers import (
    LoggingHandler,
    SentenceTransformer,
    util,
    InputExample,
)
from sentence_transformers.datasets import ParallelSentencesDataset
# import general utils and helper libraries
import random
import typing as t
from datetime import datetime

## Set parameters

Set all project parameters here, including the genetic-algorithm hyperparameters.

In [None]:
import math

# set utility variables
random_state = random.Random(42)  # seeded RNG so that every run of this notebook gives the same result
# NOTE: this rebinds the global `random_state` that Gene/Chromosome/Population read,
# so the genetic algorithm draws from this generator as well.
# set global variables
model_name = "all-MiniLM-L12-v2"  # model to be distilled (ie. teacher model)
output_path = f"output/{model_name}_" + datetime.now().strftime(r"%Y-%m-%d_%H-%M-%S")  # output path for model files and evaluation results
max_train_samples = 1_000  # maximum number of training samples
train_batch_size = 32  # batch size for training
inference_batch_size = 32  # batch size the teacher model uses to encode the training sentences
max_sentence_length = 256  # maximum char length for each sample (sentence) in the training set
### standard neural network hyperparameters ###
epochs = 1
# Warm the learning rate up over ~10% of the optimizer steps. A fixed 1000 steps is
# far longer than the whole run (~max_train_samples / train_batch_size steps per
# epoch), so the LR would never climb past a few percent of `learning_rate`.
warmup_steps = math.ceil(0.1 * epochs * math.ceil(max_train_samples / train_batch_size))
evaluation_steps = 5000  # larger than the run, so evaluation effectively happens at epoch end
learning_rate = 1e-4
epsilon = 1e-6
### standard neural network hyperparameters ###
# set hyperparameters for genetic algorithms
max_generation = 10  # maximum number of generations (ie. max iteration)
population_size = 10  # population size (ie. number of chromosomes in each generation)
mutation_rate = 0.01  # per-gene mutation probability in [0, 1]
chromosome_length = 10  # number of genes; each gene is a layer index (duplicates collapse, so the student has at most this many layers)
gene_value_range = (0, 12)  # value range for a gene => [a, b) means a is included while b is excluded; presumably matches the teacher's 12 encoder layers

## Load teacher model

The teacher model is the base model to be distilled. The genetic algorithm selects subsets of its encoder layers; each selection is trained as a candidate distilled (student) model, and the best one is kept.

In [None]:
# Load the pretrained teacher. get_student_model_from_layers loads its own fresh
# copy of the same checkpoint, so this instance is never pruned.
teacher_model = SentenceTransformer(model_name)

## Data preparation

We download the training and evaluation (i.e. benchmark) datasets separately and convert them to DataFrames.
We are using these datasets:
- For training: [**AllNLI**](https://www.sbert.net/examples/datasets/README.html): includes the [SNLI](https://nlp.stanford.edu/projects/snli/) and [MultiNLI](https://www.nyu.edu/projects/bowman/multinli/) datasets
- For benchmark: [**STS Benchmark**](http://ixa2.si.ehu.eus/stswiki/index.php/Main_Page): STS Benchmark comprises a selection of the English datasets used in the STS tasks organized in the context of SemEval between 2012 and 2017.

In [None]:
def download_as_dataframe(url: str, download_path: str):
    """Make sure the dataset at ``url`` is cached at ``download_path``, then load it.

    Returns:
        pd.DataFrame: The dataset, one column per header field.
    """
    download_dataset(url, download_path)  # skipped when the file is already present
    frame = read_as_dataframe(download_path)
    return frame


# Fetch (or reuse the cached copy of) both datasets as DataFrames:
# AllNLI is the training corpus, the STS benchmark is the evaluation set.
training_ds = download_as_dataframe(ALLNLI_DATASET_URL, "datasets/allnli.tsv.gz")
benchmark_ds = download_as_dataframe(
    STS_BENCHMARK_DATASET_URL, "datasets/stsbenchmark.tsv.gz"
)

## Create train and benchmark evaluators

Evaluators are the objects passed to a model's `fit` function (here, to each student model's). Each one bundles an evaluation dataset with the logic that scores a model on it.

In [None]:
# training evaluator
train_sents = training_ds[training_ds["split"] == "train"].loc[
    :, ["sentence1", "sentence2"]
]
# Deduplicate while keeping first-seen order. `list(set(...))` would order the
# strings by hash, which Python randomizes per process (PYTHONHASHSEED), so the
# seeded shuffle below would select a different training subset on every run.
X_train = list(
    dict.fromkeys(
        train_sents["sentence1"].to_list() + train_sents["sentence2"].to_list()
    )
)
random_state.shuffle(X_train)
X_train = X_train[:max_train_samples]  # limit train dataset
# MSE between the student's and the teacher's embeddings of the training sentences.
train_eval = evaluation.MSEEvaluator(
    X_train,
    X_train,
    teacher_model,
    name="allnli-train",
)

# benchmark evaluator
bench_sents = benchmark_ds[benchmark_ds["split"] == "dev"].loc[
    :, ["sentence1", "sentence2", "score"]
]
# Scores are read as strings; dividing by 5 rescales them (presumably the 0-5
# STS range) to [0, 1].
bench_samples = [
    InputExample(texts=[sent1, sent2], label=(float(score) / 5.0))
    for sent1, sent2, score in zip(
        bench_sents["sentence1"], bench_sents["sentence2"], bench_sents["score"]
    )
]
bench_eval = evaluation.EmbeddingSimilarityEvaluator.from_input_examples(
    bench_samples,
    name="sts-dev",
)

## Evaluate teacher model

We first evaluate the teacher model on the benchmark dataset.

In [None]:
# Reference score: fitness_function divides each student's benchmark result by this value.
teacher_eval_result = bench_eval(teacher_model)
print("Teacher model's benchmark result:", teacher_eval_result)

## Define fitness function

We define the fitness function that the genetic algorithm uses to score each candidate layer selection.

In [None]:
def get_student_model_from_layers(layers: t.List[int]) -> SentenceTransformer:
    """Create a student model from the teacher checkpoint, keeping only selected encoder layers.

    A fresh copy of ``model_name`` is loaded, so the global ``teacher_model`` is not
    modified. Kept layers stay in their original (teacher) order. Duplicate indices
    and indices outside the encoder's layer list are ignored.

    Args:
        layers (t.List[int]): Layer indices to keep (e.g. a chromosome's gene values).

    Returns:
        SentenceTransformer: Student model.
    """
    student_model = SentenceTransformer(model_name)
    auto_model = student_model._first_module().auto_model
    keep = set(layers)
    new_layers = torch.nn.ModuleList(
        [
            layer_module
            for i, layer_module in enumerate(auto_model.encoder.layer)
            if i in keep
        ]
    )
    auto_model.encoder.layer = new_layers
    # Record the depth actually kept: duplicate genes collapse above, so
    # len(layers) can overstate it and desync the config from the model.
    auto_model.config.num_hidden_layers = len(new_layers)
    return student_model


def fitness_function(layers: t.List[int]) -> float:
    """Fitness (objective) function for the genetic algorithm.

    Builds a student from the selected teacher layers and trains it on ``X_train``
    with an MSE loss against the teacher's embeddings. It then scores the student
    on the benchmark.

    Args:
        layers (t.List[int]): List of layer indices.

    Returns:
        float: Student's benchmark eval result divided by the teacher's
        (``teacher_eval_result``); 1.0 means parity with the teacher.
    """
    student_model = get_student_model_from_layers(layers)
    train_data = ParallelSentencesDataset(
        student_model=student_model,
        teacher_model=teacher_model,
        batch_size=inference_batch_size,
        use_embedding_cache=False,
    )
    train_data.add_dataset(
        [[sent] for sent in X_train],
        max_sentence_length=max_sentence_length,
    )
    train_dataloader = DataLoader(train_data, shuffle=True, batch_size=train_batch_size)
    train_loss = losses.MSELoss(model=student_model)

    student_model.fit(
        train_objectives=[(train_dataloader, train_loss)],
        evaluator=evaluation.SequentialEvaluator([bench_eval, train_eval]),
        epochs=epochs,
        warmup_steps=warmup_steps,
        evaluation_steps=evaluation_steps,
        output_path=output_path,
        optimizer_params={"lr": learning_rate, "eps": epsilon},
        save_best_model=False,
        # Mixed precision in fit() targets CUDA (torch.cuda.amp); enable it only
        # when a GPU is available so the notebook also runs cleanly on CPU.
        use_amp=torch.cuda.is_available(),
    )
    return bench_eval(student_model) / teacher_eval_result

## Run genetic algorithms

Finally, we run the genetic algorithm to find the best-suited student model (i.e. the best selection of teacher layers).

In [None]:
population = Population(
    population_size,
    mutation_rate,
    length=chromosome_length,
    value_range=gene_value_range,
    keep_best_chromosomes=True,
)

# Every generation: score each chromosome (one student training run each via
# fitness_function), then breed the next generation; report the best of each round.
for _generation in range(max_generation):
    population.eval(fitness_function)
    population.update()
    print(population.local_best, population.local_best.fitness)

# Best layer selection found over all generations.
print(population.global_best, population.global_best.fitness)