In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Number of CUDA devices PyTorch reports; as the cell's last expression it is displayed as the output.
torch.cuda.device_count()

1

In [2]:
# Checkpoint identifier handed to both from_pretrained calls below, so the
# tokenizer and the weights always come from the same source.
model_path = "meta-llama/Meta-Llama-3.1-8B"

tokenizer = AutoTokenizer.from_pretrained(model_path)

# Weights are loaded as bfloat16 and placed via device_map="cuda".
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map="cuda",
    torch_dtype=torch.bfloat16,
)


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [3]:
# Data used for published models
from datasets import load_dataset

# Each corpus is the "text" field of (up to) the first 20,000 training rows,
# concatenated with a blank line between consecutive rows.
row_window = slice(20000)
row_separator = "\n\n"

dataset1 = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
corpus1 = row_separator.join(dataset1[row_window]["text"])

# Only the first C4 training shard is requested via data_files.
dataset2 = load_dataset(
    "allenai/c4",
    data_files={"train": "en/c4-train.00000-of-01024.json.gz"},
    split="train",
)
corpus2 = row_separator.join(dataset2[row_window]["text"])


In [None]:
import flute.integrations.base
import flute.integrations.learnable

In [5]:
# Run FLUTE's learn_scales on the loaded model, with both corpora supplied as
# the custom text. Per the captured log, this adds tunable scales to the linear
# layers, tokenizes the corpora, and then runs a training epoch.
# num_bits / group_size are repeated in the later prepare_model_flute call;
# presumably the two calls must agree — TODO confirm.
scale_learning_corpora = [corpus1, corpus2]

flute.integrations.learnable.learn_scales(
    model=model,
    tokenizer=tokenizer,
    custom_corpora=scale_learning_corpora,
    samples=128,
    num_bits=4,
    group_size=64,
)

Adding tunable scales to the linear layers...
Tokenizing corpora...


Token indices sequence length is longer than the specified maximum sequence length for this model (1315227 > 131072). Running this sequence through the model will result in indexing errors


Prepare model for training...
Running epoch 0...


  0%|          | 0/128 [00:00<?, ?it/s]

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


In [6]:
# Optional: casting the model and the learned scales to float16 (instead of bfloat16) may improve speed due to kernel specifics. This cell is a note only; no cast is performed here.

In [7]:
# Hand the model's decoder-layer container to FLUTE's prepare_model_flute.
# The return value is not captured, so the call presumably modifies the layers
# in place — TODO confirm. fake=False presumably selects the real (non-simulated)
# path — verify against the FLUTE documentation.
target_layers = model.model.layers

flute.integrations.base.prepare_model_flute(
    module=target_layers,
    name="model.layers",
    fake=False,
    num_bits=4,
    group_size=64,
)



In [8]:
from lm_eval import evaluator
from lm_eval.models.huggingface import HFLM

In [9]:
# Wrap the already-loaded model and tokenizer in lm_eval's HFLM class so the
# evaluator can drive them; no checkpoint is reloaded by name here.
lm = HFLM(
    pretrained=model,
    tokenizer=tokenizer,
    add_bos_token=True,
    batch_size=32,
)



In [None]:
# Zero-shot evaluation (num_fewshot=0) with no cap on examples per task
# (limit=None).
# `tasks` is passed as a list, matching simple_evaluate's documented signature
# rather than relying on a bare string being coerced internally. Append more
# names to the list to score several benchmarks in a single call.
results = evaluator.simple_evaluate(
    lm,
    tasks=["arc_easy"],  # available here: piqa, arc_easy, arc_challenge, hellaswag, winogrande
    num_fewshot=0,
    limit=None,
)

In [15]:
# Display the per-task metrics stored under the "results" key of the evaluation output.
results["results"]

{'arc_easy': {'acc,none': 0.8173400673400674,
  'acc_stderr,none': 0.007928503719209124,
  'acc_norm,none': 0.8122895622895623,
  'acc_norm_stderr,none': 0.008012496274011486,
  'alias': 'arc_easy'}}