<a href="https://colab.research.google.com/github/YanaySoker/Quantization/blob/main/Quantization_shell.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!git clone https://github.com/EleutherAI/lm-evaluation-harness.git
%cd lm-evaluation-harness

In [None]:
!pip install -r requirements.txt

In [None]:
!pip install matplotlib

In [None]:
!rm ./lm_eval/models/qu.py

In [None]:
%%writefile ./lm_eval/models/qu.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd.function import InplaceFunction
import time

from . import gpt2
from . import gpt3
from . import dummy

# Change here: the quantization hyper-parameters (E_bits, m) are set in __init__.
class qu(gpt2.HFLM):
    """HFLM variant that carries floating-point quantization settings.

    Attributes set by ``__init__``:
        E_bits: number of exponent bits (fixed at 2).
        B: exponent bias supplied by the caller.
        m: number of mantissa bits (fixed at 5).
        E: ``2 ** E_bits``.
        quantizeFwd: when True, ``forward`` quantizes weights and inputs with
            ``FPQuantizeSawb`` before the convolution; defaults to False.

    NOTE(review): ``forward`` reads ``self.weight``, ``self.bias``,
    ``self.stride``, ``self.padding``, ``self.dilation`` and ``self.groups``,
    none of which are assigned in this class -- confirm the base class
    provides them, or that ``forward`` is actually reached.
    """

    def __init__(self, B, *args, **kwargs):
        """Initialise the base model and store the quantization settings.

        Args:
            B: exponent bias. main.py passes it inside a textual
                ``model_args`` string ("B=<int>"), so it may arrive as a
                string; a string is converted to ``int`` so the arithmetic
                in the quantizer works.
            *args, **kwargs: forwarded unchanged to the base class.
        """
        super(qu, self).__init__(*args, **kwargs)

        self.E_bits = 2
        # A string B would break expressions such as ``E - B - 1`` later on.
        self.B = int(B) if isinstance(B, str) else B
        self.m = 5

        self.E = 2**self.E_bits

        self.quantizeFwd = False

    def forward(self, input):
        """Apply a 2-D convolution, optionally on quantized weights and input.

        Args:
            input: tensor passed to ``F.conv2d``.

        Returns:
            The result of ``F.conv2d``.
        """
        if self.quantizeFwd:
            w_q = FPQuantizeSawb.apply(self.weight, self.B, self.E, self.m)

            # ``abits`` is never assigned in this class; guard the access so a
            # negative input does not raise AttributeError. When a base class
            # does define it, the original assignment is kept as-is.
            if hasattr(self, "abits") and torch.min(input) < 0:
                self.QnA = -2 ** (self.abits - 1)

            qinput = FPQuantizeSawb.apply(input, self.B, self.E, self.m)
            # all
            output = F.conv2d(qinput, w_q, self.bias, self.stride,
                              self.padding, self.dilation, self.groups)

        else:
            output = F.conv2d(input, self.weight, self.bias, self.stride,
                              self.padding, self.dilation, self.groups)

        return output


class FPQuantizeSawb(InplaceFunction):
    """Round a tensor onto a low-precision floating-point grid.

    The forward pass quantizes magnitudes binade by binade; the backward pass
    is a straight-through estimator (the gradient passes through unchanged).
    """

    @staticmethod
    def forward(ctx, input, B, E, m):
        """Quantize ``input`` element-wise.

        Args:
            ctx: autograd context (unused).
            input: tensor to quantize.
            B: exponent bias (Python number). The smallest binade starts at
                ``2 ** (1 - B)``; everything below ``2 ** (2 - B)`` is rounded
                on a uniform grid that starts at zero.
            E: number of exponent levels; the largest representable magnitude
                is ``(2 - 2 ** -m) * 2 ** (E - B - 1)``.
            m: number of mantissa bits, i.e. ``2 ** m`` steps per binade.

        Returns:
            A tensor of the same shape holding the quantized values.
        """
        # +1 for non-negative entries, -1 for negative ones.
        sign = (input.__ge__(0).int() - 0.5) * 2
        abs_input = input * sign

        # Saturate at the largest representable magnitude.
        Max = (2 - 2 ** (-m)) * (2 ** (E - B - 1))
        abs_input = torch.clip(abs_input, max=Max)

        # Per-element binade exponent. Clamping first keeps log2 finite for
        # zeros, and adding B before truncation makes the operand >= 1 so the
        # integer cast acts as a floor.
        abs_input_without_0 = torch.clip(abs_input, min=2 ** (1 - B))
        exp_range = torch.log2(abs_input_without_0)
        exp_range = exp_range + B
        exp_range = exp_range.int()
        exp_range = exp_range - B

        # Binade bounds must be computed in floating point: an integer tensor
        # raised via ``2 ** exp_range`` yields 0 for every negative exponent,
        # which collapsed the bounds (and divided by zero) for magnitudes < 1.
        float_exp = exp_range.to(abs_input.dtype)
        upper_bound = 2.0 ** (float_exp + 1)
        # 0 in the lowest binade (exp_range == 1 - B), 1 everywhere else.
        min_mask = torch.ge(exp_range, 1.5 - B).int()
        lower_bound = 2.0 ** float_exp
        # The lowest binade is extended down to zero ...
        lower_bound = lower_bound * min_mask
        num_of_value = 2 ** m
        # ... and therefore gets twice as many steps, keeping the step size
        # equal to that of the first regular binade's lower half.
        num_of_values = (2 - min_mask) * num_of_value

        # Map to step units, round half up, and map back.
        abs_input = abs_input - lower_bound
        abs_input = abs_input * num_of_values
        abs_input = abs_input / (upper_bound - lower_bound)
        abs_input = abs_input + 0.5
        abs_input = abs_input.int()
        abs_input = abs_input * (upper_bound - lower_bound)
        abs_input = abs_input / num_of_values
        abs_input = abs_input + lower_bound
        return abs_input * sign

    @staticmethod
    def backward(ctx, grad_output):
        """Straight-through estimator: pass the gradient to ``input`` as-is.

        Returns one entry per forward argument (input, B, E, m); the three
        non-tensor arguments receive no gradient.
        """
        grad_input = grad_output
        return grad_input, None, None, None

Writing ./lm_eval/models/qu.py


In [None]:
!rm ./lm_eval/models/__init__.py

In [None]:
%%writefile ./lm_eval/models/__init__.py
from . import gpt2
from . import gpt3
from . import dummy

from . import qu

# Model classes keyed by name; get_model() looks names up in this dict.
# Only the quantized wrapper from qu.py is registered.
MODEL_REGISTRY = {
    "qu": qu.qu,
}


def get_model(model_name):
    """Return the model class registered under ``model_name``.

    Raises:
        KeyError: if ``model_name`` is not a key of ``MODEL_REGISTRY``.
    """
    model_cls = MODEL_REGISTRY[model_name]
    return model_cls

Writing ./lm_eval/models/__init__.py


In [None]:
!rm ./main.py

In [None]:
%%writefile ./main.py
import argparse
import json
import logging
import fnmatch
import matplotlib.pyplot as plt

# Change here: main() evaluates the model once for every integer B in
# range(FROM, TO), i.e. FROM is included and TO is excluded.
FROM, TO = -6, 20


from lm_eval import tasks, evaluator

# Only emit WARNING and above from the "openai" logger.
logging.getLogger("openai").setLevel(logging.WARNING)


class MultiChoice:
    """Container of allowed names supporting wildcard membership tests.

    Intended for use as an argparse ``choices`` object: ``value in obj`` is
    true when every comma-separated pattern in ``value`` matches at least one
    stored choice.
    """

    def __init__(self, choices):
        self.choices = choices

    def __contains__(self, values):
        """Return True iff each comma-separated glob matches some choice."""
        # fnmatch.filter returns the (possibly empty) list of matching names;
        # an empty list is falsy, so one unmatched pattern fails the test.
        return all(
            fnmatch.filter(self.choices, pattern)
            for pattern in values.split(",")
        )

    def __iter__(self):
        """Yield the stored choices in order."""
        yield from self.choices


def parse_args():
    """Build the command-line parser and return the parsed arguments.

    Options are registered in the order listed below, so ``--help`` output is
    unchanged. ``--tasks`` accepts comma-separated wildcard patterns validated
    against ``tasks.ALL_TASKS`` through ``MultiChoice``.
    """
    cli = argparse.ArgumentParser()

    # (flag, keyword arguments for add_argument), in registration order.
    option_specs = (
        ("--model", {"required": True}),
        ("--model_args", {"default": ""}),
        ("--tasks", {"default": None, "choices": MultiChoice(tasks.ALL_TASKS)}),
        ("--provide_description", {"action": "store_true"}),
        ("--num_fewshot", {"type": int, "default": 0}),
        ("--batch_size", {"type": int, "default": None}),
        ("--device", {"type": str, "default": None}),
        ("--output_path", {"default": None}),
        ("--limit", {"type": int, "default": None}),
        ("--no_cache", {"action": "store_true"}),
        ("--decontamination_ngrams_path", {"default": None}),
        ("--description_dict_path", {"default": None}),
        ("--check_integrity", {"action": "store_true"}),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)

    return cli.parse_args()


def pattern_match(patterns, source_list):
    """Return the entries of ``source_list`` matched by any glob in ``patterns``.

    Each entry appears at most once; the order of the returned list is the
    iteration order of the intermediate set and should not be relied on.
    """
    matched = {
        name
        for glob in patterns
        for name in fnmatch.filter(source_list, glob)
    }
    return list(matched)


def main():
    """Evaluate the model once per exponent bias B and plot every metric.

    For each integer B in ``range(FROM, TO)`` the selected tasks are evaluated
    with ``B=<value>`` appended to the user's ``--model_args``. Every metric
    whose name does not contain "_stderr" is collected per B, printed, and
    finally plotted against B.

    NOTE(review): ``--output_path`` is parsed but nothing is written to it.
    """
    args = parse_args()

    assert not args.provide_description  # not implemented

    if args.limit:
        print(
            "WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
        )

    if args.tasks is None:
        task_names = tasks.ALL_TASKS
    else:
        task_names = pattern_match(args.tasks.split(","), tasks.ALL_TASKS)

    print(f"Selected Tasks: {task_names}")

    description_dict = {}
    if args.description_dict_path:
        with open(args.description_dict_path, "r") as f:
            description_dict = json.load(f)

    b_values = range(FROM, TO)
    # "<task>: <metric>" -> list of metric values, one per B in sweep order.
    R = dict()

    for i in b_values:
        print(f"B = {i}:")
        # Keep whatever the user passed via --model_args (previously it was
        # dropped) and append the swept bias; an empty --model_args yields
        # exactly the old "B=<i>" string.
        model_args = ",".join(
            part for part in (args.model_args, f"B={i}") if part
        )
        results = evaluator.simple_evaluate(
            model=args.model,
            model_args=model_args,
            tasks=task_names,
            num_fewshot=args.num_fewshot,
            batch_size=args.batch_size,
            device=args.device,
            no_cache=args.no_cache,
            limit=args.limit,
            description_dict=description_dict,
            decontamination_ngrams_path=args.decontamination_ngrams_path,
            check_integrity=args.check_integrity,
        )
        results = results['results']
        for K, metrics in results.items():
            for k, value in metrics.items():
                # Standard-error entries are not plotted.
                if "_stderr" in k:
                    continue
                R.setdefault(f"{K}: {k}", []).append(value)
                print(f"{K}: {k}: {value}")

    _keys = list(R.keys())
    for k in _keys:
        print(f"{k}: {R[k]}")
        plt.plot(b_values, R[k], label=k)
    plt.legend(_keys)
    plt.show()

# Run the sweep only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

Writing ./main.py


Change the `--tasks` argument in the command below to choose which tasks are evaluated:

In [None]:
!python main.py \
	--model qu \
	--device 0 \
	--tasks hellaswag