In [1]:
from huggingface_hub import snapshot_download

# Download the "daslab-testing/testing-800m" snapshot from the Hugging Face Hub
# into ../testing-800m; the returned local path is shown as the cell output.
snapshot_download(
    repo_id="daslab-testing/testing-800m",
    local_dir="../testing-800m",
)

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

summary.json:   0%|          | 0.00/1.80k [00:00<?, ?B/s]

.gitattributes:   0%|          | 0.00/1.52k [00:00<?, ?B/s]

'/mloscratch/homes/panferov/schedules-and-scaling/testing-800m'

In [2]:
import json

from optim.utils import load_checkpoint
from models.utils import get_model


class DotDict(dict):
    """A ``dict`` whose items can also be read, written and deleted as attributes.

    ``d.key`` is equivalent to ``d["key"]`` for reading, assignment and
    deletion. A missing key surfaces as ``AttributeError`` (not ``KeyError``)
    so that ``getattr(d, name, default)`` and ``hasattr`` behave as expected.
    Nested plain dicts are not converted.
    """

    def __getattr__(self, key):
        # __getattr__ only runs after normal attribute lookup has failed, so
        # regular dict methods (``keys``, ``items``, ...) are never shadowed.
        try:
            return self[key]
        except KeyError:
            # ``from None`` hides the internal KeyError from the traceback.
            raise AttributeError(f"'DotDict' object has no attribute '{key}'") from None

    def __setattr__(self, key, value):
        # Attribute assignment always stores into the mapping itself.
        self[key] = value

    def __delattr__(self, key):
        # Mirror __setattr__: attributes live in the mapping, so delete there.
        try:
            del self[key]
        except KeyError:
            raise AttributeError(f"'DotDict' object has no attribute '{key}'") from None


# Experiment directory holding `summary.json` (read here) and the checkpoint
# loaded further down. NOTE(review): this points at a local experiment folder;
# the trailing comment keeps the alternative location of the downloaded snapshot.
PATH = "../exps/UNTIED-800M-TrustQuantizer@4:TrustQuantizer@4-c4_c4_llama_nlayers16_nhead16_lr7.5e-05_sched_cos_warmup30517_decay_linear_0.1_iter305175_bs32x2_ws8_seed0_data_seed1337" # "../testing-800m"
# JSON is UTF-8 by specification; do not depend on the locale's default encoding.
with open(f"{PATH}/summary.json", "r", encoding="utf-8") as f:
    config = json.load(f)


In [3]:
import torch
from torch import nn
import torch.nn.functional as F

from fast_hadamard_transform import hadamard_transform

from models.quantization.base_linear import OPTIMAL_GAUSSIAN_SCALES, TrustQuantizer


def quantize_pack_dense(x: torch.Tensor, quantizer: TrustQuantizer):
    """Uniformly quantize ``x`` along its last dimension.

    For every slice along the last dimension the clipping range is
    ``[-scale, scale]`` where ``scale`` is the root-mean-square of the slice
    (no mean subtraction) times ``OPTIMAL_GAUSSIAN_SCALES[quantizer.bits]``,
    plus a small epsilon so an all-zero slice does not divide by zero.

    Despite the name, nothing is bit-packed here: the codes are returned as a
    tensor produced by ``torch.round``, holding whole numbers in
    ``[0, quantizer.n_levels - 1]``.

    Args:
        x: tensor to quantize.
        quantizer: only ``centered``, ``bits`` and ``n_levels`` are read;
            ``centered`` must be truthy.

    Returns:
        ``(xq, scale, step)`` such that ``xq * step - scale`` reconstructs the
        quantized values; ``scale`` and ``step`` keep a trailing dim of size 1.
    """
    assert quantizer.centered

    rms = x.pow(2).mean(dim=-1, keepdim=True).sqrt()
    scale = OPTIMAL_GAUSSIAN_SCALES[quantizer.bits] * rms + 1e-8

    # n_levels grid points span [-scale, scale], hence n_levels - 1 intervals.
    n_intervals = quantizer.n_levels - 1
    step = 2 * scale / n_intervals

    # Clip to the representable range, shift to start at zero, snap to the grid.
    xq = torch.round((torch.clamp(x, -scale, scale) + scale) / step)

    assert xq.min() >= 0 and xq.max() < quantizer.n_levels
    return xq, scale, step

def dequantize_dense(xq, scale, step):
    """Turn grid codes back into values: ``xq * step - scale``.

    Accepts the ``(xq, scale, step)`` triple in the same order that
    ``quantize_pack_dense`` returns it.
    """
    offset_from_lower_bound = xq * step
    return offset_from_lower_bound - scale


# Sanity check: the explicit quantize -> dequantize round trip must reproduce
# what TrustQuantizer itself returns for the same input.
# A locally seeded generator makes the check reproducible without touching the
# global RNG state.
weight = torch.rand(2, 128, generator=torch.Generator().manual_seed(0)).cuda()
quantizer = TrustQuantizer(bits=4, centered=True)
ref = quantizer(weight)
xq, scale, step = quantize_pack_dense(weight, quantizer)
deq = dequantize_dense(xq, scale, step)

torch.testing.assert_close(ref, deq, rtol=1e-4, atol=1e-4)

In [4]:
from models.quantization.base_linear import QuantizedLinear

class Linear4bit(nn.Module):
    """Inference-time replacement for a ``QuantizedLinear`` layer.

    The weight of the source layer is quantized and de-quantized once at
    construction time and stored as the buffer ``wq``. Activations are
    quantized and de-quantized on every forward pass with the source layer's
    activation quantizer, then a plain ``F.linear`` is applied.

    Args:
        quantizer_linear: layer exposing ``weight``, ``bias``,
            ``weight_quantizer`` and ``activation_quantizer``; both quantizers
            must be ``TrustQuantizer`` instances.
    """

    def __init__(self, quantizer_linear):
        super().__init__()

        assert isinstance(quantizer_linear.weight_quantizer, TrustQuantizer)
        assert isinstance(quantizer_linear.activation_quantizer, TrustQuantizer)

        self.activation_quantizer = quantizer_linear.activation_quantizer

        # Detach before quantizing: otherwise `wq` is a non-leaf tensor whose
        # autograd graph keeps the original weight and all intermediates alive,
        # and the resulting module cannot be deep-copied.
        frozen_weight = quantizer_linear.weight.detach()
        wq = dequantize_dense(*quantize_pack_dense(frozen_weight, quantizer_linear.weight_quantizer))
        self.register_buffer("wq", wq)
        # The bias is reused as-is (it may be None).
        self.bias = quantizer_linear.bias

    def forward(self, x):
        """Quantize-dequantize ``x`` along its last dim, then apply the linear map."""
        x = dequantize_dense(*quantize_pack_dense(x, self.activation_quantizer))
        return F.linear(x, self.wq, self.bias)


def replace_linears(model):
    """Recursively swap every ``QuantizedLinear`` child for a ``Linear4bit``.

    The module tree is modified in place; ``model`` itself is returned so the
    call can be chained. Children that are not ``QuantizedLinear`` are
    descended into.
    """
    # Iterate over a snapshot so that re-assigning entries cannot interfere
    # with the traversal.
    for child_name, child in list(model.named_children()):
        if not isinstance(child, QuantizedLinear):
            replace_linears(child)
            continue
        model._modules[child_name] = Linear4bit(child)
    return model

In [5]:
class PseudoDdp(nn.Module):
    """Wrapper that exposes ``model`` under ``self._orig_mod["module"]``.

    State-dict keys of the wrapped model therefore get the prefix
    ``_orig_mod.module.``. Presumably this matches the key layout of the
    checkpoint read by ``load_checkpoint`` -- confirm against that helper.
    """

    def __init__(self, model):
        super().__init__()
        holder = nn.ModuleDict()
        holder["module"] = model
        self._orig_mod = holder
        
class PseudoLoader:
    """Placeholder object whose ``load_state_dict`` accepts anything and does nothing."""

    def load_state_dict(self, *args, **kwargs):
        """Ignore all arguments; nothing is restored."""
        return None

# Build the network from the saved run arguments and wrap it so that its
# state-dict keys carry the `_orig_mod.module.` prefix.
ddp_like = PseudoDdp(get_model(DotDict(config['args'])))

# The two PseudoLoader instances are no-op stand-ins for the objects
# load_checkpoint would otherwise restore state into (presumably optimizer and
# scheduler -- confirm against load_checkpoint).
load_checkpoint(ddp_like, PseudoLoader(), PseudoLoader(), f"{PATH}/ckpts/latest/main.pt", "cuda")

# Move to the GPU, unwrap the real model, then swap in the inference layers.
model = replace_linears(ddp_like.cuda()._orig_mod["module"])


  ckpt = torch.load(ckpt_path, map_location=device)


In [6]:
from transformers import AutoTokenizer

# Tokenizer used for the generation demo below.
# NOTE(review): presumably the same tokenizer the checkpoint was trained with -- confirm.
TOKENIZER_NAME = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)

In [7]:
def generate_text_greedily(model, tokenizer, prompt, max_length=50, device='cuda'):
    """Greedy (argmax) decoding of a continuation for ``prompt``.

    The whole sequence is re-fed to the model at every step (no KV cache).

    Args:
        model: callable as ``model(input_ids, get_logits=True)`` returning a
            mapping with a ``'logits'`` entry indexed as
            ``[batch, position, vocab]``. It is switched to eval mode.
        tokenizer: provides ``encode(prompt, return_tensors='pt')`` and
            ``decode(ids, skip_special_tokens=True)``; if it exposes a
            non-None ``eos_token_id``, generation stops once that token is
            produced.
        prompt: text to continue.
        max_length: maximum number of NEW tokens to generate (the prompt
            length is not counted).
        device: device the token ids are moved to.

    Returns:
        Whatever ``tokenizer.decode`` returns for the prompt plus the
        generated tokens.
    """
    model.eval()
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    eos_token_id = getattr(tokenizer, "eos_token_id", None)

    # One no_grad context for the whole loop instead of re-entering it per step.
    with torch.no_grad():
        for _ in range(max_length):
            outputs = model(input_ids, get_logits=True)
            # Only the distribution at the last position decides the next token.
            logits = outputs['logits'][:, -1, :]

            next_token_id = torch.argmax(logits, dim=-1, keepdim=True)
            input_ids = torch.cat([input_ids, next_token_id], dim=-1)

            # Stop at end-of-sequence instead of generating past it.
            if eos_token_id is not None and bool((next_token_id == eos_token_id).all()):
                break

    return tokenizer.decode(input_ids[0], skip_special_tokens=True)

# Smoke test: greedily continue a short prompt with at most 20 new tokens.
demo_prompt = "Hi!"
generated_text = generate_text_greedily(model, tokenizer, demo_prompt, max_length=20)
print(generated_text)


Hi! Sign in to let us know how The Coffee Shop was?
by jessica


In [8]:
# Total number of elements across all registered buffers of the model,
# reported in millions.
numel = sum(buf.numel() for _, buf in model.named_buffers())

print(numel/1e6)

822.083584
