In [1]:
from bbml.foundations.gpt2 import GPT2Foundation, GPTConfig

In [2]:
# Instantiate the GPT-2 foundation model with a default GPTConfig. The cell
# output below reports that pretrained `gpt2` weights are loaded.
# NOTE(review): the meaning of the second positional argument (None) is not
# visible in this notebook — confirm against GPT2Foundation's signature.
foundation = GPT2Foundation(  # based on nanoGPT
    GPTConfig(),
    None
)

loading weights from pretrained gpt: gpt2
forcing vocab_size=50257, block_size=1024, bias=True
number of parameters: 123.65M


In [3]:
# Baseline generation from the unmodified model, for comparison with the
# same prompt after compression.
foundation.run(foundation.input_model(text="The quick brown fox jumps over", max_new_tokens=10)).text

'The quick brown fox jumps over the front wheel of a Joltcar. Then'

In [4]:

from enum import Enum
import math

import torch
from torch import nn
from torch.nn import init
from torch.nn import functional as F
from einops import rearrange


class SplitLinear(nn.Module):

    def __init__(self, bias: bool, out_features: int | None = None, device=None, dtype=None):
        super().__init__()
        self.splits: nn.ModuleList = nn.ModuleList()

        if bias:
            self.bias = nn.Parameter(
                torch.empty(out_features, device=device, dtype=dtype)
            )
        else:
            self.register_parameter('bias', None)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        outs = []
        for split in self.splits:
            outs.append(split(x))
        out = torch.cat(outs, dim=-1)
        if self.bias is not None:
            out = out + self.bias
        return out
    

class ShareLinearState(str, Enum):
    """Forward-pass mode of a ``ShareLinear`` layer."""

    ORIGINAL = "original"  # Forward uses original weights
    CALIBRATING = "calibrating"  # Forward uses original weights + tracks inputs
    COMPRESSED = "compressed"  # Forward uses basis @ coefficient


class ShareLinear(nn.Module):
    """Bias-free linear layer that can switch to a low-rank factorisation.

    The layer holds the dense weight ``original`` (``[out, in]``) alongside a
    factorised pair ``basis`` (``[basis, in]``) and ``coefficient``
    (``[out, basis]``). ``state`` selects which of them ``forward`` uses.
    All three parameters are allocated uninitialised; the caller is expected
    to fill them in.

    Args:
        in_features: Size of the last input dimension.
        basis_features: Width of the intermediate (basis) projection.
        out_features: Size of the last output dimension.
        device: Device for the parameters.
        dtype: Dtype for the parameters.
    """

    def __init__(
        self,
        in_features: int,
        basis_features: int,
        out_features: int,
        device=None,
        dtype=None
    ):
        super().__init__()
        self.in_features = in_features
        self.basis_features = basis_features
        self.out_features = out_features

        self.basis = nn.Parameter(
            torch.empty(basis_features, in_features, device=device, dtype=dtype)
        )
        self.coefficient = nn.Parameter(
            torch.empty(out_features, basis_features, device=device, dtype=dtype)
        )
        self.original = nn.Parameter(
            torch.empty(out_features, in_features, device=device, dtype=dtype)
        )
        self.state = ShareLinearState.ORIGINAL
        # Running sum of X^T X over every tracked input; None until first use.
        self.register_buffer("xtx", None)

    @torch.no_grad()
    def track_input(self, x: torch.Tensor):
        """Accumulate the Gram matrix ``X^T X`` of ``x`` into ``self.xtx``.

        ``x`` may have any number of leading dimensions; they are flattened so
        that every row is one feature vector of size ``x.shape[-1]``. The
        accumulation is done in float32.
        """
        inp = x.detach().float()
        inp = inp.reshape(-1, inp.shape[-1])  # (..., H) -> (N, H)
        xtx = inp.T @ inp  # (H, H)

        # Summing per-batch x^T x equals X^T X over the concatenated inputs.
        if self.xtx is None:
            self.xtx = xtx
        else:
            self.xtx = self.xtx + xtx

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the dense weight, or the factorised pair when compressed."""
        if self.state == ShareLinearState.CALIBRATING:
            self.track_input(x)

        if self.state != ShareLinearState.COMPRESSED:
            return F.linear(x, self.original)

        b = F.linear(x, self.basis)
        out = F.linear(b, self.coefficient)
        return out

    def extra_repr(self) -> str:
        return f"in={self.in_features}, basis={self.basis_features}, out={self.out_features}, state={self.state.value}"

        

In [5]:
# Scratch tensor: 10 random rows with 768 features each.
a = torch.rand((10, 768))

In [6]:
# Shape check: a^T a is (features, features), independent of the row count.
(a.T @ a).shape

torch.Size([768, 768])

In [None]:
from collections import defaultdict
import re
from pathlib import Path
from typing import Annotated, Literal

from pydantic import BaseModel, Field, model_validator
from torch.utils.data import Dataset
import tqdm
import bbml

class WeightConfig(BaseModel):
    """Selection and splitting rules for one kind of linear layer."""

    # Regex applied with re.match to fully qualified module names.
    pattern: str
    # "qkv": cut the weight into three equal row blocks; "heads": cut each of
    # those blocks again into n_head pieces; any other value: keep it whole.
    split: bool|Literal["qkv", "heads"] = False
    # Order of the q/k/v blocks along the weight's output rows.
    qkv_order: str = "qkv"
    # Number of attention heads; only needed for split="heads".
    n_head: int|None = None

    @model_validator(mode="after")
    def qkv_string(self):
        """Require ``qkv_order`` to be a permutation of q, k and v when splitting."""
        if self.split:
            # Compare sorted characters rather than a set so that duplicates
            # such as "qqkv" are rejected: each letter names one third of the
            # output rows, so exactly three are needed.
            if sorted(self.qkv_order.lower()) != ["k", "q", "v"]:
                raise ValueError(
                    "qkv_order must contain exactly 'q', 'k', and 'v' characters "
                    "(each once) when split is enabled"
                )
        return self

    @model_validator(mode="after")
    def validate_heads_split(self):
        """Require a positive ``n_head`` when splitting per head."""
        if self.split == "heads":
            if self.n_head is None:
                raise ValueError("n_head must be set when split='heads'")
            if self.n_head <= 0:
                raise ValueError("n_head must be a positive integer when split='heads'")
        return self


Percent = Annotated[float, Field(ge=0, le=1)]
class SplitConfig(BaseModel):
    """Top-level configuration for splitting, grouping and compressing weights."""

    # Weight-type name -> how to find and split layers of that type.
    weight_types: dict[str, WeightConfig]
    # Number of consecutive entries of one weight type that share a basis.
    # Must be at least 1: smaller values would silently behave like 1.
    group_size: Annotated[int, Field(ge=1)]
    # Regex whose first capture group is taken as the layer identifier.
    layer_pattern: str
    # Target parameter count of the factorisation, as a fraction of the
    # original weight's parameter count.
    compression_rate: Percent


class SplitLinearFinetuner(bbml.Finetuner):
    def __init__(self, model: bbml.Foundation, config: SplitConfig):
        """Collect every module of ``model`` selected by ``config.weight_types``.

        Each match is recorded in ``self.original_weights`` together with its
        parent module, so that it can later be swapped out in place.

        Args:
            model: Foundation whose ``named_modules()`` are scanned.
            config: Patterns and compression settings.

        Raises:
            ValueError: If a selected module's name does not match
                ``config.layer_pattern``.
        """
        super().__init__(model)
        self.config: SplitConfig = config

        self.original_weights = []

        named_modules = dict(model.named_modules())

        for name, module in named_modules.items():
            for wtype, wtype_cfg in self.config.weight_types.items():
                if re.match(wtype_cfg.pattern, name) is None:
                    continue

                layer_match = re.search(config.layer_pattern, name)
                if layer_match is None:
                    # Fail with context instead of an AttributeError on None.
                    raise ValueError(
                        f"module {name!r} matches weight type {wtype!r} but not "
                        f"layer_pattern {config.layer_pattern!r}"
                    )

                # "a.b.c" -> parent "a.b", leaf "c"; a top-level name has the
                # root module (registered under "") as its parent.
                parent_name, _, last_name_part = name.rpartition(".")
                # A numeric leaf means the parent is indexed, not attribute-accessed.
                list_id = int(last_name_part) if last_name_part.isdigit() else None

                self.original_weights.append({
                    "name": name,
                    "module": module,
                    "parent": named_modules[parent_name],
                    "list_id": list_id,
                    "last_name_part": last_name_part,
                    "layer": layer_match.group(1),
                    "weight_type": wtype,
                })

    def get_train_parameters(self):
        """Not implemented for this finetuner; always raises NotImplementedError."""
        raise NotImplementedError


    def save(self, save_path: str | Path):
        """Not implemented for this finetuner; always raises NotImplementedError."""
        raise NotImplementedError


    def load(self, load_path: str | Path):
        """Not implemented for this finetuner; always raises NotImplementedError."""
        raise NotImplementedError

    def split_weights(self):
        """Replace every collected linear layer by a ``SplitLinear`` of ``ShareLinear`` parts.

        Depending on the weight type's ``split`` setting the weight is kept
        whole, cut into three equal row blocks (``"qkv"``), or cut into
        ``3 * n_head`` row blocks (``"heads"``). Each part's ``original``
        weight is a view of the corresponding rows of the source weight.

        Populates ``self.all_sharelinears`` (part name -> ``ShareLinear``) and
        ``self.weight_types`` (split weight type -> list of part records).

        Raises:
            ValueError: If the output width cannot be divided evenly into the
                requested parts. The check runs before the model is modified.
        """
        self.all_sharelinears = {}
        self.weight_types = defaultdict(list)

        for wt_dict in self.original_weights:
            wtype = wt_dict["weight_type"]
            wtype_cfg = self.config.weight_types[wtype]
            module = wt_dict["module"]
            name = wt_dict["name"]
            in_feats = module.in_features
            out_feats = module.out_features

            # Plan the row slices first: (name suffix, number of rows), in row order.
            if wtype_cfg.split == "qkv" or wtype_cfg.split == "heads":
                if out_feats % 3 != 0:
                    raise ValueError(
                        f"{name}: out_features={out_feats} is not divisible by 3"
                    )
                qkv_feats = out_feats // 3
                if wtype_cfg.split == "qkv":
                    pieces = [(qkv_part, qkv_feats) for qkv_part in wtype_cfg.qkv_order]
                else:
                    n_heads = wtype_cfg.n_head
                    if qkv_feats % n_heads != 0:
                        # Otherwise head slices would drift across q/k/v boundaries.
                        raise ValueError(
                            f"{name}: per-projection width {qkv_feats} is not "
                            f"divisible by n_head={n_heads}"
                        )
                    head_dim = qkv_feats // n_heads
                    pieces = [
                        (f"{qkv_part}.{head_num}", head_dim)
                        for qkv_part in wtype_cfg.qkv_order
                        for head_num in range(n_heads)
                    ]
            else:  # no split
                pieces = [("", out_feats)]

            has_bias = module.bias is not None
            split_linear = SplitLinear(bias=has_bias, out_features=out_feats)
            if has_bias:
                split_linear.bias.data = module.bias.data

            cur_ind = 0
            for suffix, n_rows in pieces:
                part = ShareLinear(in_feats, in_feats, n_rows)
                # [out_dim, in_dim] -> [n_rows, in_dim]
                part.original.data = module.weight.data[cur_ind:cur_ind + n_rows, :]
                cur_ind += n_rows

                part_name = f"{name}.{suffix}" if suffix else name
                split_wtype = f"{wtype}.{suffix}" if suffix else wtype

                split_linear.splits.append(part)
                self.all_sharelinears[part_name] = part
                self.weight_types[split_wtype].append({
                    "module": part,
                    "name": part_name,
                    "split_name": suffix,
                    "weight_type": wtype,
                    "layer": wt_dict["layer"],
                })

            # Install the replacement only once it has been fully built.
            if wt_dict["list_id"] is None:
                setattr(wt_dict["parent"], wt_dict["last_name_part"], split_linear)
            else:
                wt_dict["parent"][wt_dict["list_id"]] = split_linear
        
    def group_weights(self):
        """Chunk each weight type's parts into runs of ``config.group_size``.

        Parts are taken in the order they appear in ``self.weight_types``; the
        final chunk of a weight type may be shorter. The result is stored in
        ``self.groups`` as weight type -> list of chunks.
        """
        self.groups = defaultdict(list)

        # A chunk closes as soon as it holds at least group_size items, so any
        # size below one behaves like one.
        chunk_len = max(self.config.group_size, 1)
        for weight_type, weights_list in self.weight_types.items():
            for start in range(0, len(weights_list), chunk_len):
                self.groups[weight_type].append(weights_list[start:start + chunk_len])

    @torch.no_grad()
    def calibrate(self, dataset: Dataset|bbml.DataPipe):
        """Run ``dataset`` through the model while every ``ShareLinear`` records its inputs.

        The model is put in eval mode. A plain ``Dataset`` is wrapped in a
        ``bbml.DataPipe`` (batch size 1, no shuffling) using the model's data
        transforms. Every layer's previous state is restored afterwards, also
        when a batch fails, so later forward passes do not keep adding to the
        collected statistics.

        Args:
            dataset: Calibration data, either a dataset or a ready ``DataPipe``.
        """
        self.model.eval()

        if not isinstance(dataset, bbml.DataPipe):
            dataset = bbml.DataPipe(
                batch_size=1,
                shuffle=False,
                num_workers=2,
            ).add_dataset(
                dataset
            ).add_transforms(
                self.model.data_transforms
            )

        dataloader = dataset.get_loader()

        share_linears = list(self.all_sharelinears.values())
        previous_states = [share_linear.state for share_linear in share_linears]
        for share_linear in share_linears:
            share_linear.state = ShareLinearState.CALIBRATING

        try:
            for batch in tqdm.tqdm(dataloader):
                step_info = {
                    "step": 0,
                    "split": "validation",
                }
                batch.update(step_info)
                self.model.single_step(batch)
        finally:
            for share_linear, state in zip(share_linears, previous_states):
                share_linear.state = state
    
    @staticmethod
    def compute_num_basis(
        in_features: int, out_features: int, compression_ratio: float
    ) -> int:
        """Return the rank whose factorisation holds ``compression_ratio`` of the original parameters.

        A rank-``k`` factorisation of an ``in x out`` matrix stores
        ``k * (in + out)`` values, so ``k`` solves
        ``k * (in + out) = compression_ratio * in * out``.

        Args:
            in_features: Number of input features.
            out_features: Number of output features (summed over a group).
            compression_ratio: Fraction of the original parameter count to keep.

        Returns:
            The rank, rounded down, at least 1 and at most
            ``min(in_features, out_features)`` when that minimum is positive.
        """
        total_original = in_features * out_features
        num_basis = (total_original * compression_ratio) / (in_features + out_features)
        # A rank above min(in, out) carries no extra information.
        max_rank = min(in_features, out_features)
        return max(1, min(int(num_basis), max_rank))

    def run_svd(self):
        """Factorise every weight group into one shared basis and per-layer coefficients.

        For each group the members' ``original`` weights are concatenated along
        the output dimension into ``W`` (``[in, out_cat]``) and whitened with
        ``S``, the upper Cholesky factor of the summed calibration statistics
        (``S^T S = sum(xtx)``). A truncated SVD of ``S @ W`` then gives
        ``W ~= (S^-1 U_k diag(sigma_k)) @ Vh_k``. The first factor becomes the
        basis shared by the whole group; the columns of ``Vh_k`` are handed out
        to the members as their coefficients. Every member ends up in the
        ``COMPRESSED`` state.

        Shared bases are also kept in ``self.group_bases``.

        Raises:
            RuntimeError: If a layer has no calibration statistics.
        """
        self.group_bases = defaultdict(list)
        for weight_type, groups_list in self.groups.items():
            for group in tqdm.tqdm(groups_list):
                modules = [weights_dict["module"] for weights_dict in group]

                uncalibrated = [
                    weights_dict["name"]
                    for weights_dict in group
                    if weights_dict["module"].xtx is None
                ]
                if uncalibrated:
                    raise RuntimeError(
                        f"no calibration statistics for {uncalibrated}; "
                        "run calibrate() before run_svd()"
                    )

                # `original` is [out, in], so dim 0 is each member's output width.
                out_sizes = [module.original.size(0) for module in modules]
                # Template for the dtype/device of the new parameters.
                target = modules[0].original

                all_xtx = sum(module.xtx for module in modules)
                try:
                    S = torch.linalg.cholesky(all_xtx).T
                except RuntimeError:  # includes torch.linalg.LinAlgError
                    print("Warning: eigen scaling_diag_matrix is not positive!")
                    # Shift the spectrum so the smallest eigenvalue is slightly positive.
                    eigenvalues = torch.linalg.eigvalsh(all_xtx)
                    shift = -eigenvalues[0] + 7e-6
                    all_xtx = all_xtx + shift * torch.eye(
                        all_xtx.shape[0], device=all_xtx.device, dtype=all_xtx.dtype
                    )
                    S = torch.linalg.cholesky(all_xtx).T
                S_inv = torch.linalg.inv(S)

                W = torch.cat([module.original.data for module in modules], dim=0).T  # -> [in, out_cat]
                W = W.to(device=S.device, dtype=S.dtype)

                W_white = S @ W

                U, sigma, Vh = torch.linalg.svd(W_white, full_matrices=False)

                # The rank must be known before the factors are truncated.
                k = self.compute_num_basis(W.size(0), W.size(1), self.config.compression_rate)
                k = min(k, sigma.numel())

                compressed_basis = S_inv @ (U[:, :k] * sigma[:k])  # [in, k]
                compressed_coefficient = Vh[:k, :]  # [k, out_cat]

                # ShareLinear applies F.linear, which expects [out, in] layouts:
                # basis is [k, in] and each coefficient is [out_i, k].
                group_basis = nn.Parameter(
                    compressed_basis.T.contiguous().to(device=target.device, dtype=target.dtype)
                )
                self.group_bases[weight_type].append(group_basis)

                cur_id = 0
                for module, cur_out_dim in zip(modules, out_sizes):
                    # Same Parameter object in every member: the basis is shared.
                    module.basis = group_basis
                    coefficient = compressed_coefficient[:, cur_id:cur_id + cur_out_dim].T
                    module.coefficient = nn.Parameter(
                        coefficient.contiguous().to(device=target.device, dtype=target.dtype)
                    )
                    module.basis_features = k
                    cur_id += cur_out_dim
                    module.state = ShareLinearState.COMPRESSED

                
                
        


In [11]:
%%writefile config.yaml
# One entry per kind of linear layer to compress. `pattern` is a regular
# expression matched (re.match) against fully qualified module names.
weight_types:
    attn_c_attn:
        pattern: '.*\.attn\.c_attn$'
        split: qkv  # "qkv": three equal row blocks; "heads": additionally one block per head
        qkv_order: qkv  # order of the q/k/v blocks along the output rows
        n_head: 12  # only required when split is "heads"
    attn_c_proj: 
        pattern: '.*\.attn\.c_proj$'
    mlp_c_fc: 
        pattern: '.*\.mlp\.c_fc$'
    mlp_c_proj: 
        pattern: '.*\.mlp\.c_proj$'

# Regex whose first capture group is recorded as the layer identifier.
layer_pattern: 'h\.(\d+)'

# Number of consecutive entries of one weight type that share a basis.
group_size: 2
# Fraction of the original parameter count kept by the factorization.
compression_rate: 0.9

Overwriting config.yaml


In [12]:
import yaml
# Read back the YAML file written by the previous cell as a plain dict.
with open("config.yaml") as f:
    cfg = yaml.safe_load(f)

In [13]:
# Validate the raw dict against the pydantic schema.
splitcfg = SplitConfig(**cfg)

In [14]:
# Scan the model for modules matching the configured patterns (no changes yet).
wrapper = SplitLinearFinetuner(foundation, splitcfg)

In [15]:
# Swap the matched modules for SplitLinear wrappers; this mutates `foundation` in place.
wrapper.split_weights()

In [16]:
# Partition the parts of each weight type into groups of `group_size`.
wrapper.group_weights()

In [None]:
from bbml.data.datasets import WikiTextDataset

# Calibrate on a small fixed prefix of the training split to keep this cell fast.
N_CALIBRATION_SAMPLES = 256

ds = WikiTextDataset(split="train")
ds.ds = ds.ds.select(range(N_CALIBRATION_SAMPLES))

wrapper.calibrate(ds)

Downloading readme: 0.00B [00:00, ?B/s]

Downloading data:   0%|          | 0.00/733k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.36M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/657k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/4358 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/36718 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3760 [00:00<?, ? examples/s]

Filter:   0%|          | 0/36718 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/6211 [00:00<?, ? examples/s]

In [None]:
# Factorise each group using the statistics gathered during calibration.
wrapper.run_svd()

In [None]:
# Same prompt as the baseline cell, now through the modified model.
foundation.run(foundation.input_model(text="The quick brown fox jumps over", max_new_tokens=10)).text

In [None]:
# The registry built by split_weights() is named `all_sharelinears`.
wrapper.all_sharelinears

In [None]:
# Inspect the part records, keyed by split weight type.
wrapper.weight_types

In [None]:
# List every (name, module) pair of the model as it currently stands.
list(foundation.named_modules())

In [None]:
from pprint import pprint

# Pretty-print the groups built by group_weights().
pprint(wrapper.groups)

In [None]:


# Pretty-print the part records, keyed by split weight type.
pprint(wrapper.weight_types)