In [1]:
import os
import pickle
from contextlib import nullcontext
import torch
import time
import tiktoken
from model import GPTConfig, GPT
from vanilla_transformer import TransformerLM, TransformerConfig
from retnet import RetNet, retnet_1_3b, RetNetConfig

# -----------------------------------------------------------------------------
init_from = 'resume' # either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl')
out_dir = 'out' # ignored if init_from is not 'resume'
start = "\nAs the man walked down the stairs, " # or "<|endoftext|>" or etc. Can also specify a file, use as: "FILE:prompt.txt"
num_samples = 5 # number of samples to draw
max_new_tokens = 1000 # number of tokens generated in each sample
temperature = 1 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
top_k = 50 # retain only the top_k most likely tokens, clamp others to have 0 probability
seed = 10
model_name = 'ckpt_r_2048.pt'
device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16'
compile = False # use PyTorch 2.0 to compile the model to be faster
isRetnet = True
isTransformer = False
#exec(open('configurator.py').read()) # overrides from command line or config file
# -----------------------------------------------------------------------------

torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)



In [2]:
# model
if init_from == 'resume':
    # init from a model saved in a specific directory
    ckpt_path = os.path.join(out_dir, model_name)
    checkpoint = torch.load(ckpt_path, map_location=device)
    if isRetnet:
        conf = RetNetConfig(**checkpoint['model_args'])
        conf.n_embd = int(conf.n_embd)
        print(conf)
        model = RetNet(conf, 
        num_tokens=50304,
        d_model=conf.n_embd,
        nhead=conf.n_head,
        num_layers=conf.n_layer,
        dim_feedforward=conf.n_embd * 4)
        device=device
    elif isTransformer:
        conf = TransformerConfig(**checkpoint['model_args'])
        print(conf)
        model = TransformerLM(conf, 
        num_tokens=50304,
        d_model=conf.n_embd,
        nhead=conf.n_head,
        num_layers=conf.n_layer,
        dim_feedforward=conf.n_embd * 4)
        device=device
    else:
        gptconf = GPTConfig(**checkpoint['model_args'])
        model = GPT(gptconf)
    state_dict = checkpoint['model']
    unwanted_prefix = '_orig_mod.'
    for k,v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict, strict=False)
elif init_from.startswith('gpt2'):
    # init from a given GPT-2 model
    model = GPT.from_pretrained(init_from, dict(dropout=0.0))

model.eval()
model.to(device)
if compile:
    model = torch.compile(model) # requires PyTorch 2.0 (optional)

# look for the meta pickle in case it is available in the dataset folder
load_meta = False
if init_from == 'resume' and 'config' in checkpoint and 'dataset' in checkpoint['config']: # older checkpoints might not have these...
    meta_path = os.path.join('data', checkpoint['config']['dataset'], 'meta.pkl')
    load_meta = os.path.exists(meta_path)
if load_meta:
    print(f"Loading meta from {meta_path}...")
    with open(meta_path, 'rb') as f:
        meta = pickle.load(f)
    # TODO want to make this more general to arbitrary encoder/decoder schemes
    stoi, itos = meta['stoi'], meta['itos']
    encode = lambda s: [stoi[c] for c in s]
    decode = lambda l: ''.join([itos[i] for i in l])
else:
    # ok let's assume gpt-2 encodings by default
    print("No meta.pkl found, assuming GPT-2 encodings...")
    enc = tiktoken.get_encoding("gpt2")
    encode = lambda s: enc.encode(s, allowed_special={"<|endoftext|>"})
    decode = lambda l: enc.decode(l)

# encode the beginning of the prompt
if start.startswith('FILE:'):
    with open(start[5:], 'r', encoding='utf-8') as f:
        start = f.read()
start_ids = encode(start)
x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])

# run generation
time_per_sample = []
with torch.no_grad():
    with ctx:
        for k in range(num_samples):
            start_time = time.time()
            y = model.generate_parallel(x, max_new_tokens, temperature=temperature, top_k=top_k)
            #print(y)
            print(decode(y[0].tolist()))
            end_time = time.time()
            print('---------------')
            duration = end_time - start_time
            print(f"Elapsed time: {duration}")
            time_per_sample.append(duration)
print(time_per_sample)
print(sum(time_per_sample) / len(time_per_sample))


RetNetConfig(block_size=2048, vocab_size=50304, n_layer=4, n_head=16, n_embd=256, dropout=0.0, bias=False)


  from .autonotebook import tqdm as notebook_tqdm


No meta.pkl found, assuming GPT-2 encodings...

As the man walked down the stairs, owbrium off;
Of all: yet?
But by their female too much you are the day.
BRO:
I never be not I am this place and to the crown,
Which thou did never speak.
HENI'll play a prayer.
My husband's a noble heart,
Thanio,
If you the law, by my breast, it is my good
As, and what you are I should not to take.

First Murderer:
And, she with me not like a Montague!
I say it

COMINCE EDWARD:
By the world,
For your worship well in this?

GLOUCESTER:
'
Amen, I will your voices'd my good Camillo.

First Servingman:
It was the city with me;
But I, then, for you will make the queen,
As bright track.
O' faith hath a great affairs!
LUCIO:
A bachelor, the time of Gloucester and all thy hands.

Clown:
Marry, we have it had you both:
Who, and therefore be great love from the world, and you to this!
I may be no, my cousin that our mind:
And in the Duke to thee;
When I am not hear the king, though he hadst it!
Their harness, and 

In [3]:
import gc
import sys
import time
from pathlib import Path
from typing import Optional
import os
import numpy as np
import pickle
from contextlib import nullcontext
import time
import tiktoken
from model import GPTConfig, GPT
from vanilla_transformer import TransformerLM, TransformerConfig
from retnet import RetNet, retnet_1_3b, RetNetConfig

import torch
from datasets import load_dataset

# support running without installing as a package
#wd = Path(__file__).parent.parent.resolve()
#sys.path.append(str(wd))

from model import GPTConfig, GPT
from quantization import GPTQQuantizer
#from lit_llama.utils import EmptyInitOnDevice, llama_model_lookup


model_name = 'ckpt_r_2048.pt'
device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16'
compile = False # use PyTorch 2.0 to compile the model to be faster
isRetnet = True
isTransformer = False
out_dir = 'out'
NR_SAMPLES = 128


def get_sample_data():
        # traindata = load_dataset(
    #     "allenai/c4",
    #     "allenai--c4",
    #     data_files={"train": "en/c4-train.00000-of-01024.json.gz"},
    #     split="train",
    # )
    traindata = load_dataset(
        "wikitext",
        "wikitext-2-v1",
        #data_files={"train": "en/c4-train.00000-of-01024.json.gz"},
        split="train",
    )
    # heuristic for the data size?
    txt = "\n".join(
        traindata[i]["text"] for i in torch.randperm(len(traindata))[:5000].tolist()
    )
    return txt


@torch.no_grad()
def llama_blockwise_quantization(
    model, sample_inputs, working_device, *, bits=4, groupsize=-1
):
    """
    This is the classic post-training quantization of all linear layers.
    We quantize in order, i.e. when observing the inputs, we use the outputs of the previously quantized layers rather
    than doing them all at once.
    """
    print(model)
    print(model.config)

    print("Getting inputs for first block")
    print(model.decoder)
    model.embedding.to(working_device)
    sample_inputs = sample_inputs.to(working_device)
    inps = model.embedding(sample_inputs)
    model.embedding.to("cpu")
    torch.cuda.empty_cache()

    rope_cache = None
    mask_cache = None

    print("Starting to quantize blocks")
    outs = torch.zeros_like(inps)

    # better than relying on enumeration? originally the code bundled
    # the two mlp fc layers
    # we could automate this with a lot of hooks and another iteration
    submodules_to_process = [
        "retention.q_proj",
        "retention.k_proj",
        "retention.v_proj",
        "retention.g_proj",
        "linear1",
        "linear2",
    ]

    for i, block in enumerate(model.decoder.layers):
        block.to(working_device)

        for name in submodules_to_process:
            print(i, name, end=" ")
            t0 = time.perf_counter()
            print("collecting stats", end=" ")
            sys.stdout.flush()
            module = block.get_submodule(name)

            gptq = GPTQQuantizer(
                module,
                bits=bits,
                groupsize=groupsize,
                actorder=(groupsize == -1),
            )
            handle = module.register_forward_hook(gptq.collect_input_stats)
            for j in range(inps.size(0)):
                outs[j : j + 1] = block(
                    inps[j : j + 1],
                )

            handle.remove()

            print("quantizing", end=" ")
            sys.stdout.flush()
            q_module, error = gptq.quantize()
            # replace the linear module with the quantized module
            if(len(name.rsplit(".", 1)) == 2):
                pname, dname = name.rsplit(".", 1)
                print(q_module, pname, dname)
                setattr(block.get_submodule(pname), dname, q_module)
            else:
                print(q_module, name)
                setattr(block, name, q_module)

            # cleanup in an attempt to not run out of memory
            del gptq
            gc.collect()
            torch.cuda.empty_cache()
            t1 = time.perf_counter()
            print(f"time {int(t1 - t0 + 0.5)}s quantization error {error:.1f}")

        for j in range(inps.size(0)):
            outs[j : j + 1] = block(
                inps[j : j + 1],
            )

        block.cpu()
        gc.collect()
        torch.cuda.empty_cache()

        # the outputs are the next block's inputs and we'll reuse the old inputs
        inps, outs = outs, inps
    # print(inps.size(0))
    # for j in range(inps.size(0)):
    #     model.decoder.layers[3].norm2.to(working_device)
    #     outs[j : j + 1] = model.decoder.layers[3].norm2(inps[j : j + 1])
    #     model.decoder.layers[3].norm2.to('cpu')
    inps, outs = outs, inps
    model.out.to(working_device)
    gptq = GPTQQuantizer(
        model.out,
        bits=bits,
        groupsize=groupsize,
        actorder=(groupsize == -1),
    )
    handle = model.out.register_forward_hook(gptq.collect_input_stats)
    for j in range(inps.size(0)):
        model.out(inps[j : j + 1])
    handle.remove()
    q_module, error = gptq.quantize()
    model.out = q_module
    model.out.to("cpu")
    print(model)


def quantizing(
    *,
    output_path: Optional[Path] = None,
    n_samples: int = NR_SAMPLES,
    dtype: str = "float32",
    quantize: Optional[str] = None,
) -> None:
    """Generates text samples based on a pre-trained LLaMA model and tokenizer.

    Args:
        checkpoint_path: The checkpoint path to load.
        output_path: Path to write the quantized model's state dict to.
        tokenizer_path: The tokenizer path to load.
        n_samples: Number of example inputs to use for statistics (default: 128)
        dtype: The dtype to use to load the model.
        quantize: Mode to quantize the model to:
            ``"gptq.int4"``: GPTQ 4-bit mode.
            Note that ``"llm.int8"```does not need a quantization step.
    """
    device = "cuda"

    dt = getattr(torch, dtype, None)
    if not isinstance(dt, torch.dtype):
        raise ValueError(f"{dtype} is not a valid dtype.")
    dtype = dt

    if quantize == "gptq.int4":
        bits = 4
    elif quantize == "gptq.int8":
        bits = 8
    else:
        raise RuntimeError(f"unknown/unsupported quantization mode {quantize}")

    # we avoid loading the entire model on the GPU and do this block by block
    print("Loading model ...", file=sys.stderr)
    t0 = time.time()
    ckpt_path = os.path.join(out_dir, model_name)
    checkpoint = torch.load(ckpt_path, map_location=device)
    print(checkpoint['model_args'])
    if isRetnet:
        conf = RetNetConfig(**checkpoint['model_args'])
        print(conf)
        model = RetNet(conf, 
        num_tokens=50304,
        d_model=conf.n_embd,
        nhead=conf.n_head,
        num_layers=conf.n_layer,
        dim_feedforward=conf.n_embd * 4)
        device=device
    elif isTransformer:
        conf = TransformerConfig(**checkpoint['model_args'])
        print(conf)
        model = TransformerLM(conf, 
        num_tokens=50304,
        d_model=conf.n_embd,
        nhead=conf.n_head,
        num_layers=conf.n_layer,
        dim_feedforward=conf.n_embd * 4)
        device=device
    state_dict = checkpoint['model']
    unwanted_prefix = '_orig_mod.'
    for k,v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)
    print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)

    model.eval()

    enc = tiktoken.get_encoding("gpt2")
    def process(example):
        ids = enc.encode_ordinary(example) # encode_ordinary ignores any special tokens
        ids.append(enc.eot_token) # add the end of text token, e.g. 50256 for gpt2 bpe
        # note: I think eot should be prepended not appended... hmm. it's called "eot" though...
        out = {'ids': ids, 'len': len(ids)}
        return ids

    # tokenize the dataset
    text = get_sample_data()
    encoded_text = torch.tensor(process(text))

    # tokenizer = Tokenizer(tokenizer_path)

    # test_string = get_sample_data()
    # encoded_text = tokenizer.encode(
    #     test_string,
    #     bos=True,
    #     eos=False,
    # )
    block_size = 2048  # this is for compat with gptq, and indeed we get much worse beyond this (https://github.com/facebookresearch/llama/blob/57b0eb62de0636e75af471e49e2f1862d908d9d8/llama/model.py#L30)
    encoded_text = encoded_text[: n_samples * block_size].reshape(n_samples, block_size)

    t0 = time.perf_counter()
    llama_blockwise_quantization(model, encoded_text, device, bits=bits)
    t = time.perf_counter() - t0

    print(
        f"\n\nTime for quantization: {t:.02f} sec total",
        file=sys.stderr,
    )
    print(
        f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB",
        file=sys.stderr,
    )
    print(model.state_dict().keys())
    checkpoint_2 = {
                    'model': model.state_dict(),
                    'optimizer': checkpoint['optimizer'],
                    'model_args': checkpoint['model_args'],
                    'iter_num': checkpoint['iter_num'],
                    'best_val_loss': checkpoint['best_val_loss'],
                    'config': checkpoint['config'],
                }
    print(f"saving quantized model to {output_path}")
    torch.save(checkpoint_2, os.path.join(out_dir, output_path))
    return model

In [4]:
model_8bit = quantizing(output_path=model_name[:-3] + '_q.pt', quantize='gptq.int8')
model_8bit.to(device)

device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
print("No meta.pkl found, assuming GPT-2 encodings...")
enc = tiktoken.get_encoding("gpt2")
encode = lambda s: enc.encode(s, allowed_special={"<|endoftext|>"})
decode = lambda l: enc.decode(l)

# encode the beginning of the prompt
if start.startswith('FILE:'):
    with open(start[5:], 'r', encoding='utf-8') as f:
        start = f.read()
start_ids = encode(start)
x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])

# run generation
time_per_sample = []
with torch.no_grad():
    with ctx:
        for k in range(num_samples):
            start_time = time.time()
            y = model_8bit.generate_parallel(x, max_new_tokens, temperature=temperature)
            #print(y)
            print(decode(y[0].tolist()))
            end_time = time.time()
            print('---------------')
            duration = end_time - start_time
            print(f"Elapsed time: {duration}")
            time_per_sample.append(duration)
print(time_per_sample)
print(sum(time_per_sample) / len(time_per_sample))


Loading model ...
Time to load model: 0.13 seconds.


{'n_layer': 4, 'n_head': 16, 'n_embd': 256, 'block_size': 2048, 'bias': False, 'vocab_size': 50304, 'dropout': 0.0}
RetNetConfig(block_size=2048, vocab_size=50304, n_layer=4, n_head=16, n_embd=256, dropout=0.0, bias=False)
RetNet(
  (embedding): Embedding(50304, 256)
  (decoder): RetNetDecoder(
    (layers): ModuleList(
      (0-3): 4 x RetNetDecoderLayer(
        (dropout): Dropout(p=0.1, inplace=False)
        (norm1): LayerNorm((256,), eps=1e-06, elementwise_affine=True)
        (retention): MultiScaleRetention(
          (q_proj): Linear(in_features=256, out_features=256, bias=True)
          (k_proj): Linear(in_features=256, out_features=256, bias=True)
          (v_proj): Linear(in_features=256, out_features=256, bias=True)
          (group_norm): GroupNorm(16, 256, eps=1e-06, affine=False)
          (g_proj): Linear(in_features=256, out_features=256, bias=True)
          (out_proj): Linear(in_features=256, out_features=256, bias=True)
        )
        (norm2): LayerNorm((256,),



Time for quantization: 35.02 sec total
Memory used: 3.20 GB


No meta.pkl found, assuming GPT-2 encodings...

As the man walked down the stairs, 
Was ranrate, I haveian our presence,
Have you shall and meeting from heaven, which are amazed;
And then IUGH both. For I could shortly,

CORIET:
WICK of that I say you do must be Paris
Faith, my soul's,
Would she knew barren served Polixenes the vap most truly is
Than Hector:--auteous trusty bal.

Second election, then begin; stones:
It seems thou not jar come out.
First Servant bears, and twenty years together,
Your high, like civil cause dead is but
No remedy of kings, a shrewdared no unt King Henry's eldest open thy heartily tenfold Edward.
BALT:
In this unmus's
Within and honour to seek be satisfied!

FLORIZABELLA liestent, the sunset up his blood!

Ready and hanging,
KING EDWARD:
What is no more, you are going
As now far off.
GLOUCESTER:
I'll be satisfied.



QUEENI'll pity so fast sitting than what a
And now!
Second Servingman:
Why, some ill, in his followers's knife is not high.

His dullhip me, 

In [7]:
model_4bit = quantizing(output_path=model_name[:-3] + '_q.pt', quantize='gptq.int4')
model_4bit.to(device)

device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
print("No meta.pkl found, assuming GPT-2 encodings...")
enc = tiktoken.get_encoding("gpt2")
encode = lambda s: enc.encode(s, allowed_special={"<|endoftext|>"})
decode = lambda l: enc.decode(l)

# encode the beginning of the prompt
if start.startswith('FILE:'):
    with open(start[5:], 'r', encoding='utf-8') as f:
        start = f.read()
start_ids = encode(start)
x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])

# run generation
time_per_sample = []
with torch.no_grad():
    with ctx:
        for k in range(num_samples):
            start_time = time.time()
            y = model_4bit.generate_parallel(x, max_new_tokens, temperature=temperature)
            #print(y)
            print(decode(y[0].tolist()))
            end_time = time.time()
            print('---------------')
            duration = end_time - start_time
            print(f"Elapsed time: {duration}")
            time_per_sample.append(duration)
print(time_per_sample)
print(sum(time_per_sample) / len(time_per_sample))

Loading model ...


{'n_layer': 4, 'n_head': 16, 'n_embd': 256, 'block_size': 2048, 'bias': False, 'vocab_size': 50304, 'dropout': 0.0}
RetNetConfig(block_size=2048, vocab_size=50304, n_layer=4, n_head=16, n_embd=256, dropout=0.0, bias=False)


Time to load model: 0.24 seconds.


RetNet(
  (embedding): Embedding(50304, 256)
  (decoder): RetNetDecoder(
    (layers): ModuleList(
      (0-3): 4 x RetNetDecoderLayer(
        (dropout): Dropout(p=0.1, inplace=False)
        (norm1): LayerNorm((256,), eps=1e-06, elementwise_affine=True)
        (retention): MultiScaleRetention(
          (q_proj): Linear(in_features=256, out_features=256, bias=True)
          (k_proj): Linear(in_features=256, out_features=256, bias=True)
          (v_proj): Linear(in_features=256, out_features=256, bias=True)
          (group_norm): GroupNorm(16, 256, eps=1e-06, affine=False)
          (g_proj): Linear(in_features=256, out_features=256, bias=True)
          (out_proj): Linear(in_features=256, out_features=256, bias=True)
        )
        (norm2): LayerNorm((256,), eps=1e-06, elementwise_affine=True)
        (linear1): Linear(in_features=256, out_features=1024, bias=True)
        (linear2): Linear(in_features=1024, out_features=256, bias=True)
      )
    )
  )
  (out): Linear(in_fea



Time for quantization: 35.12 sec total
Memory used: 3.20 GB


No meta.pkl found, assuming GPT-2 encodings...

As the man walked down the stairs,  and teachmen and what,
Seal dreadful irCLAR PETERLE:
Cons were of you no more gone, would may.
What be done, then be strength yet there to live.

First Citizen:
FRIARENCE:
A shadow our fault indifferent! why
KING RICHARD III:
Well, never talk of death and that long of want thou art night.


YORK:
I conj's good your many time will strongly;
Said talk of kings restitution, not priv, see my see pathways and he
Outlive'? Hang them together: what title of the m scales
Ready to acquaint your will bewe most stifle the son
Or murderounceanderersmen-freto, lower my,
That they say. Hang of ease her most unjust gold,
For girlsue you restore's glory!
TRANLEY:
Here comes that calls, ha: subjected him whose plots.
It was not in the word, that thou that vice!
By deceived, bones is dead, he that,
and a thousandout of onearewell it befall him
Against
Come belie. So forward I am resolved it and

GLOUCALUS:
Ah, he wakes a

: 