In [1]:
from huggingface_hub import snapshot_download

# Download the checkpoint repository into a sibling directory. The call is the
# last expression of the cell, so its return value is what the cell displays.
snapshot_download(
    repo_id="daslab-testing/hadamard-testing-800m",
    local_dir="../hadamard-testing-800m",
)

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

.gitattributes:   0%|          | 0.00/1.52k [00:00<?, ?B/s]

summary.json:   0%|          | 0.00/1.83k [00:00<?, ?B/s]

'/mloscratch/homes/panferov/schedules-and-scaling/hadamard-testing-800m'

In [2]:
import json

from optim.utils import load_checkpoint
from models.utils import get_model


class DotDict(dict):
    """Dictionary whose items can also be read, written and deleted as attributes.

    ``d.key`` is equivalent to ``d["key"]`` for reading, assignment and
    deletion. Missing keys surface as ``AttributeError`` (not ``KeyError``) so
    that ``hasattr`` and ``getattr(obj, name, default)`` behave as expected.
    """

    def __getattr__(self, key):
        # __getattr__ is only consulted after normal attribute lookup fails,
        # so regular dict methods (keys, items, ...) are unaffected.
        try:
            return self[key]
        except KeyError:
            # Suppress the KeyError chain: callers only care that the
            # attribute does not exist.
            raise AttributeError(f"'DotDict' object has no attribute '{key}'") from None

    def __setattr__(self, key, value):
        # Attribute assignment stores an item in the dict itself.
        self[key] = value

    def __delattr__(self, key):
        # Mirror __setattr__: attributes live in the dict, so deleting an
        # attribute must remove the corresponding item.
        try:
            del self[key]
        except KeyError:
            raise AttributeError(f"'DotDict' object has no attribute '{key}'") from None


PATH = "../hadamard-testing-800m"
# Read summary.json from the checkpoint directory. JSON text is UTF-8, so the
# encoding is pinned explicitly instead of relying on the platform's locale default.
with open(f"{PATH}/summary.json", "r", encoding="utf-8") as f:
    config = json.load(f)


In [3]:
import torch
from torch import nn
import torch.nn.functional as F

from fast_hadamard_transform import hadamard_transform

from models.quantization.base_linear import OPTIMAL_GAUSSIAN_SCALES, HadamardTrustQuantizer, HalfHadamardTrustQuantizer


def quantize_pack_hadamard_dense(x: torch.Tensor, quantizer: HadamardTrustQuantizer):
    """Rotate ``x`` with a grouped Hadamard transform and quantize it to integer levels.

    ``x`` is viewed as rows of 128 contiguous values, each row is passed through
    ``hadamard_transform`` with scale ``2 ** (-7/2)`` (i.e. ``1 / sqrt(128)``),
    and the result is reshaped back to ``x.shape``. A clipping range is then
    computed per slice along the last dimension as the RMS of the rotated
    values (plus ``1e-8``) times ``OPTIMAL_GAUSSIAN_SCALES[quantizer.bits]``.
    The clipped values are mapped onto a uniform grid with
    ``quantizer.n_levels`` points.

    Args:
        x: Tensor whose number of elements is a multiple of 128.
        quantizer: Provides ``bits`` and ``n_levels``; must have ``centered`` set.

    Returns:
        Tuple ``(levels, clip, grid_step)``: the rounded grid indices (asserted
        to lie in ``[0, n_levels)``), the clipping bound, and the grid spacing.
        The indices are returned as a regular tensor; no bit-packing is done
        here despite the function name.

    Note:
        The returned levels describe the *rotated* tensor, not ``x`` itself.
    """
    assert quantizer.centered

    rotated = hadamard_transform(x.reshape(-1, 128), scale=2 ** (-7/2))
    rotated = rotated.reshape(x.shape)

    # Root-mean-square over the last dimension; the epsilon avoids a zero range.
    rms = rotated.square().mean(dim=-1, keepdim=True).sqrt() + 1e-8
    clip = OPTIMAL_GAUSSIAN_SCALES[quantizer.bits] * rms
    grid_step = 2 * clip / (quantizer.n_levels - 1)

    # Shift the clipped values to [0, 2 * clip] and snap them to the grid.
    levels = ((rotated.clamp(-clip, clip) + clip) / grid_step).round()

    assert levels.min() >= 0
    assert levels.max() < quantizer.n_levels
    return levels, clip, grid_step

def dequantize_dense(xq, scale, step):
    """Map grid indices back to values on the quantization grid.

    Computes ``xq * step - scale``, so index ``0`` maps to ``-scale``.

    Args:
        xq: Grid indices.
        scale: Clipping bound that was added before rounding.
        step: Spacing between adjacent grid points.

    Returns:
        The dequantized values.
    """
    rescaled = xq * step
    return rescaled - scale

# Sanity check: the standalone quantize/dequantize pair should match the
# reference quantizer once the reference output is moved into the rotated space.
weight = torch.rand(2, 128).cuda()
quantizer = HadamardTrustQuantizer(bits=4)

ref = quantizer(weight)
ref_rotated = hadamard_transform(ref.reshape(-1, 128), scale=2 ** (-7/2)).reshape(ref.shape)

xq, scale, step = quantize_pack_hadamard_dense(weight, quantizer)
deq = dequantize_dense(xq, scale, step)

torch.testing.assert_close(ref_rotated, deq, rtol=1e-3, atol=1e-3)

In [7]:
from models.quantization.base_linear import QuantizedLinear

class Linear4bit(nn.Module):
    """Inference stand-in for a quantized linear layer with a pre-quantized weight.

    At construction the source layer's weight is rotated, quantized and
    dequantized once, and the result is stored as the ``wq`` buffer. At call
    time the input goes through the same rotate/quantize/dequantize round trip
    with the layer's activation quantizer before the matrix multiply.

    Both operands of the matmul are therefore in the rotated space.
    NOTE(review): presumably the two rotations cancel in the product because
    the scaled Hadamard transform is orthonormal per group of 128 — confirm.

    Args:
        quantizer_linear: Source layer exposing ``weight``, ``bias``,
            ``weight_quantizer`` and ``activation_quantizer``; both quantizers
            must be ``HadamardTrustQuantizer`` instances.
    """

    def __init__(self, quantizer_linear):
        super().__init__()

        assert isinstance(quantizer_linear.weight_quantizer, HadamardTrustQuantizer)
        assert isinstance(quantizer_linear.activation_quantizer, HadamardTrustQuantizer)

        self.activation_quantizer = quantizer_linear.activation_quantizer

        # Build the frozen weight without autograd tracking. Otherwise the
        # buffer is a non-leaf tensor that keeps the whole graph (and the
        # tensors saved for backward) alive for the lifetime of the module.
        with torch.no_grad():
            wq = dequantize_dense(
                *quantize_pack_hadamard_dense(quantizer_linear.weight, quantizer_linear.weight_quantizer)
            )
        self.register_buffer("wq", wq)
        self.bias = quantizer_linear.bias

    def forward(self, x):
        """Quantize ``x`` in the rotated space and apply the stored weight and bias."""
        x = dequantize_dense(*quantize_pack_hadamard_dense(x, self.activation_quantizer))
        return F.linear(x, self.wq, self.bias)


def replace_linears(model):
    """Recursively swap every ``QuantizedLinear`` below ``model`` for a ``Linear4bit``.

    The module tree is modified in place. Children that are not
    ``QuantizedLinear`` are descended into.

    Args:
        model: Root module to process.

    Returns:
        The same ``model`` object, for call chaining.
    """
    for child_name, child in model.named_children():
        if not isinstance(child, QuantizedLinear):
            # Not a target itself: look for targets further down.
            replace_linears(child)
            continue
        # Re-binding an existing key does not change the size of the mapping
        # being iterated.
        model._modules[child_name] = Linear4bit(child)
    return model

In [8]:
class PseudoDdp(nn.Module):
    """Wrapper that nests ``model`` under the attribute path ``_orig_mod.module``.

    The wrapped model's state-dict keys thereby gain the prefix
    ``_orig_mod.module.``.
    NOTE(review): presumably this mirrors the key layout of the saved
    checkpoint (compiled + DDP-wrapped model) — confirm against the training code.
    """

    def __init__(self, model):
        super().__init__()
        container = nn.ModuleDict()
        container["module"] = model
        self._orig_mod = container
        
class PseudoLoader:
    """Placeholder object whose ``load_state_dict`` ignores whatever it is given."""

    def load_state_dict(self, *args, **kwargs):
        """Accept any arguments, do nothing, and return ``None``."""
        return None

# Build the architecture from the saved run arguments and wrap it so that the
# model sits under `_orig_mod.module` while the checkpoint is loaded.
ckpt_wrapper = PseudoDdp(get_model(DotDict(config["args"])))

# The two PseudoLoader objects are no-op placeholders.
# NOTE(review): presumably they fill the optimizer/scheduler slots of
# load_checkpoint — confirm against its signature.
load_checkpoint(ckpt_wrapper, PseudoLoader(), PseudoLoader(), f"{PATH}/main.pt", "cuda")
ckpt_wrapper = ckpt_wrapper.cuda()

# Unwrap the real model and swap its quantized linear layers in place.
model = replace_linears(ckpt_wrapper._orig_mod["module"])


In [9]:
from transformers import AutoTokenizer

# Hub repository the tokenizer is loaded from.
TOKENIZER_REPO = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_REPO)

In [10]:
def generate_text_greedily(model, tokenizer, prompt, max_length=50, device='cuda'):
    """Extend ``prompt`` by greedy (argmax) decoding and return the decoded text.

    The full sequence is re-fed to the model at every step (no cache is used).
    Generation stops after ``max_length`` new tokens, or earlier once the
    tokenizer's end-of-sequence token has been produced.

    Args:
        model: Module called as ``model(input_ids, get_logits=True)``; must
            return a mapping whose ``'logits'`` entry is indexed as
            ``[batch, position, vocab]``. It is switched to eval mode as a
            side effect.
        tokenizer: Object providing ``encode(prompt, return_tensors='pt')`` and
            ``decode(ids, skip_special_tokens=True)``. If it exposes a
            non-``None`` ``eos_token_id``, that id ends generation.
        prompt: Text to continue.
        max_length: Maximum number of *new* tokens to append (not the total
            sequence length).
        device: Device the input ids are moved to.

    Returns:
        Whatever ``tokenizer.decode`` returns for the first sequence in the
        batch (prompt plus generated tokens).
    """
    model.eval()
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    # Tokenizers without an EOS id keep the fixed-length behaviour.
    eos_token_id = getattr(tokenizer, "eos_token_id", None)

    # One no_grad scope for the whole loop instead of re-entering it per step.
    with torch.no_grad():
        for _ in range(max_length):
            outputs = model(input_ids, get_logits=True)
            # Only the distribution at the last position decides the next token.
            logits = outputs['logits'][:, -1, :]

            next_token_id = torch.argmax(logits, dim=-1).unsqueeze(-1)
            input_ids = torch.cat([input_ids, next_token_id], dim=-1)

            # Stop once every sequence in the batch has emitted EOS; anything
            # generated past that point would be noise.
            if eos_token_id is not None and bool((next_token_id == eos_token_id).all()):
                break

    return tokenizer.decode(input_ids[0], skip_special_tokens=True)

# Smoke test: greedily continue a short prompt by up to 20 new tokens.
generated_text = generate_text_greedily(model, tokenizer, prompt="Hi!", max_length=20)
print(generated_text)


Hi! Sign in to let us know how The Pizza Shack was?
If you've been


In [11]:
# Total number of elements held in the model's registered buffers
# (parameters are not included), reported in millions.
numel = sum(buffer.numel() for _, buffer in model.named_buffers())

print(numel/1e6)

822.083584
