# Weight subcloning https://arxiv.org/pdf/2312.09299

## Initialize the base model (parent)

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from datasets import load_dataset
model_id = "meta-llama/Meta-Llama-3-8B"

tokenizer = AutoTokenizer.from_pretrained(model_id)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = model.config.eos_token_id

fw = load_dataset("HuggingFaceFW/fineweb", name="CC-MAIN-2024-10", split="train", streaming=True)

## Setup hooks to collect activations

In [4]:
hooks = []

In [7]:
activations = {
    'resid': {},
    'attn': {},
    'out': {},
    'in': {},
    'main': {},
    # 'interim': {}
}

def get_activation(name, layer_type, isinput):
    def hook(model, input, output):
        if name not in activations[layer_type]:
            activations[layer_type][name] = []
        if isinput == "input":
            activations[layer_type][name].append(input[0].sum(dim=0).detach().cpu())
        elif isinput == "output2":
            activations[layer_type][name].append(output[1].var(dim=-1).sum(dim=0).detach().cpu())
        elif isinput == "output":
            activations[layer_type][name].append(output[0].sum(dim=0).detach().cpu())
    return hook

for hook in hooks:
    hook.remove()
hooks = []

hooks.append(model.model.norm.register_forward_hook(get_activation('norm', 'main', 'output')))
hooks.append(model.model.embed_tokens.register_forward_hook(get_activation('embed', 'main', 'output')))
hooks.append(model.lm_head.register_forward_hook(get_activation('unembed', 'main', 'input')))

for i, layer in enumerate(model.model.layers):
    registerable_hooks = [
        (layer.self_attn.o_proj, "resid" , 'oprj', "output"),
        (layer.input_layernorm, "resid", 'inorm', "output"),
        (layer.mlp.down_proj, "resid", 'dprj', "output"),
        (layer.mlp.gate_proj, "resid", 'gprj', "output"),
        (layer.self_attn, "attn", '', "output2"),

    ]
    def add_if_not_taken(module, type, name, layer_type):
        # check by name + layer_type
        if module == layer.self_attn:
            return
        registerable_hooks.append((module, type, name, layer_type))
    for name, module in layer.named_modules():
        if ("self_attn" in name):
            continue
        if len(name) == 0:
            continue
        if isinstance(module, torch.nn.Linear):
            add_if_not_taken(module, "out", name, "output")
            add_if_not_taken(module, "in", name, "input")
        if "norm" in name:
            add_if_not_taken(module, "out", name, "output")

    for module, type, name, layer_type in registerable_hooks:
        hooks.append(module.register_forward_hook(get_activation(f'{i}_{name}', type, layer_type)))



## Collect activations

In [10]:
import numpy as np
import tqdm
import torch
import random
def collect_activations(model, dataset, tokenizer, max_tokens=80_000, batch_size=12):
    sequence_lengths = [512]  # Sequence lengths: 1, 2, 4, 8, ..., 2048
    total_tokens = 0
    batch_texts = []
    batch_lengths = []
    
    # Initialize progress bar
    with tqdm.tqdm(total=max_tokens, desc='Collecting activations', unit='tokens') as progress:
        for i,sample in enumerate(dataset):
            # print("starting with sample: ", i+1, " total tokens: ", total_tokens)
            if total_tokens >= max_tokens:
                print("Maximal tokens reached, total tokens: ", total_tokens)
                break
            # report which sample item we're in using tqdm report next to the progress bar
            progress.set_postfix({
                'sample': f"{i+1}/211m",
            })

            text = sample['text']
            for seq_len in sequence_lengths:
                if not (len(tokenizer.encode(text)) >= 512):
                    continue
                if random.random() > 0.9:
                    continue

         
                batch_texts.append(text)
                batch_lengths.append(seq_len)

                # Process in batches
                if len(batch_texts) >= batch_size:
                    
                    inputs = tokenizer(batch_texts, return_tensors='pt', padding=True, truncation=True, max_length=512).to("cuda")
                    total_tokens += inputs['input_ids'].numel()
                    
                    progress.update(inputs['input_ids'].numel())
                    with torch.no_grad():
                        out = model(**inputs, output_attentions=True)

                    batch_texts = []
                    batch_lengths = []
                    del inputs
                    del out
                    torch.cuda.empty_cache()
                    

                if total_tokens >= max_tokens:
                    print("Maximal tokens reached, total tokens: ", total_tokens)
                    break

        # Process any remaining texts in the batch
        print("SOME TEXTS LEFT, NUMBER OF TEXTS LEFT: ", len(batch_texts))
        
    print("done with collecting activations, now concatenating them")
    # Convert activations to numpy arrays
    for layer_type in activations:
        for layer_index in activations[layer_type]:
            activations[layer_type][layer_index] = torch.stack(activations[layer_type][layer_index], dim=0)

    return activations


In [11]:
collected_activations = collect_activations(model, fw, tokenizer)

Collecting activations: 86016tokens [00:31, 2705.56tokens/s, sample=380/211m]                                                                                                                                                                                                                                                           

Maximal tokens reached, total tokens:  86016
Maximal tokens reached, total tokens:  86016
SOME TEXTS LEFT, NUMBER OF TEXTS LEFT:  0
done with collecting activations, now concatenating them





In [15]:
import torch

layer_importance = {}
for i in range(32):
    layer_key = f'{i}'
    if layer_key in activations['out']:
        layer_activation = activations['out'][layer_key]
        layer_importance[i] = torch.mean(torch.abs(layer_activation)).item()

# Sort layers by importance
sorted_layers = sorted(layer_importance.items(), key=lambda x: x[1], reverse=True)

## Construct distillation network (child network)

In [13]:
from transformers import AutoConfig
import torch
model_id = "meta-llama/Meta-Llama-3-8B"



config = AutoConfig.from_pretrained(
    model_id,
    dtype=torch.float32,
    device_map="auto",
)


config.num_hidden_layers = 8
config.num_attention_heads = 8
config.num_key_value_heads = 4
config.hidden_size = 1024
config.intermediate_size = 3584

def get_model():
    new_model = AutoModelForCausalLM.from_config(config).cuda()
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
        new_model.config.pad_token_id = new_model.config.eos_token_id
    return new_model
new_model = get_model()



## Subclone weights

In [16]:
import torch
import torch.nn as nn
neuron_importancee = []
def compute_global_neuron_importance(activations):
    global_neuron_importance = torch.zeros(list(activations['resid'].values())[0].shape[-1])
    for layer_key in activations['resid']:
        layer_activations = activations['resid'][layer_key]
        neuron_importancee.append(layer_activations.abs().mean(dim=(0, 1)))
        global_neuron_importance += layer_activations.abs().mean(dim=(0, 1))
    return global_neuron_importance


def compute_head_importance(activations, num_heads):
    head_importance = []
    for layer_key in activations['attn']: 
        acts = activations['attn'][layer_key] # (B, H, S)
        acts = acts.sum(dim=-1).mean(dim=0) # (B, H, S)
        
        head_importance.append(acts)


    return torch.stack(head_importance, dim=0)

def prepare_attention_weights(old_attn):
    # Repeat K and V weights to match Q
    num_kv_heads = old_attn.num_key_value_heads
    repeat_factor = old_attn.num_heads // num_kv_heads
    q_weights = old_attn.q_proj.weight.data.view(old_attn.num_heads, old_attn.head_dim, -1)
    k_weights = old_attn.k_proj.weight.data.view(num_kv_heads, old_attn.head_dim, -1)
    v_weights = old_attn.v_proj.weight.data.view(num_kv_heads, old_attn.head_dim, -1)
    o_weights = old_attn.o_proj.weight.data.view(-1, old_attn.num_heads, old_attn.head_dim)

    
    return q_weights, k_weights, v_weights, o_weights




def subclone_attention(old_attn, new_attn, neuron_indices, head_indices):   
    # print (new_attn.q_proj.weight.data.shape) 
    num_kv_heads = old_attn.num_key_value_heads
    new_num_kv_heads = new_attn.num_key_value_heads
    repeat_factor = old_attn.num_heads // num_kv_heads

    q_weights, k_weights, v_weights, o_weights = prepare_attention_weights(old_attn)

    kv_head_indices = (head_indices % num_kv_heads)[:new_num_kv_heads]

    q_weights = q_weights[head_indices][:, :, neuron_indices]
    k_weights = k_weights[kv_head_indices][:, :, neuron_indices]
    v_weights = v_weights[kv_head_indices][:, :, neuron_indices]
    o_weights = o_weights[neuron_indices][:, head_indices, :]

    new_attn.q_proj.weight.data = q_weights.reshape(-1, len(neuron_indices)).float()
    new_attn.k_proj.weight.data = k_weights.reshape(-1, len(neuron_indices)).float()
    new_attn.v_proj.weight.data = v_weights.reshape(-1, len(neuron_indices)).float()
    new_attn.o_proj.weight.data = o_weights.reshape(len(neuron_indices), -1).float()

successful_weight_transfers = 0
def subclone_weight(old_layer, new_layer, activations, sort_dim, pre_adjust=None):
    global successful_weight_transfers
    dim_k = new_layer.weight.data.shape[sort_dim]
    neuron_importance = activations.abs()
    neuron_importance = neuron_importance.sum(dim=[0])
    if (neuron_importance.dim() > 1):
        neuron_importance = neuron_importance.sum(dim=[0])
    
    sorted_neurons = torch.topk(neuron_importance, k=dim_k, sorted=False).indices
    weight_data = old_layer.weight.data
    
    if sort_dim == 0:
        new_layer.weight.data = weight_data[sorted_neurons].float()
    elif sort_dim == 1:
        if pre_adjust is not None:
            new_layer.weight.data = pre_adjust(weight_data)[:, sorted_neurons].float()
        else:
            new_layer.weight.data = weight_data[:, sorted_neurons].float()
    elif sort_dim == 2:
        if pre_adjust is not None:
            new_layer.weight.data = pre_adjust(weight_data)[:, :, sorted_neurons].float()
        else:
            new_layer.weight.data = weight_data[:, :, sorted_neurons].float()

    if (weight_data.dim() > 1):
        new_layer.weight.data *= ((weight_data.shape[-1]/new_layer.weight.data.shape[-1]) ** 0.5)
    successful_weight_transfers += 1
    

def subclone_both_weight(old_layer, new_layer, outact, inact):
    global successful_weight_transfers
    new_input_dim = new_layer.weight.data.shape[1]
    new_output_dim = new_layer.weight.data.shape[0]
    def get_neuron_importance(activations, dim_k):
        neuron_importance = activations.abs()
        neuron_importance = neuron_importance.sum(dim=[0])
        if (neuron_importance.dim() > 1):
            neuron_importance = neuron_importance.sum(dim=[0])
        
        sorted_neurons = torch.topk(neuron_importance, k=dim_k, sorted=False).indices
        return sorted_neurons

    out_importance = get_neuron_importance(outact, new_output_dim)
    in_importance = get_neuron_importance(inact, new_input_dim)
    
    new_layer.weight.data = old_layer.weight.data[out_importance][:, in_importance].float() 
    if (new_layer.weight.data.dim() > 1):
        new_layer.weight.data *= ((old_layer.weight.data.shape[-1]/new_layer.weight.data.shape[-1]) ** 0.5)
    successful_weight_transfers += 1
    

def subclone_layer(old_idx, activations,old_layer, new_layer, neuron_indices, head_indices, new_model):
    subclone_attention(old_layer.self_attn, new_layer.self_attn, neuron_indices, head_indices)
    

    new_layer_named_modules = dict(new_layer.named_modules())

    def pre_adjust_setup(name):
        if "down_proj" in name:
            return lambda x: x[neuron_indices, :]
        elif "gate_proj" in name:
            return lambda x: x[:, neuron_indices]
        elif "o_proj" in name:
            return lambda x: x[:, neuron_indices]
        elif "norm" in name:
            return lambda x: x[neuron_indices]
        return None
    
    for name, module in old_layer.named_modules():
        if len(name) == 0:
            continue
        if not ((isinstance(module, torch.nn.Linear)) or ("norm" in name)):
            continue

        new_module = new_layer_named_modules[name]

        
        if (f"{old_idx}_{name}" in activations["out"] and f"{old_idx}_{name}" in activations["in"]):

            outact = activations["out"][f"{old_idx}_{name}"]
            inact = activations["in"][f"{old_idx}_{name}"]
            subclone_both_weight(module, new_module, outact, inact)

        elif (f"{old_idx}_{name}" in activations["out"]):
      
            outact = activations["out"][f"{old_idx}_{name}"]
            subclone_weight(module, new_module, outact, 0, pre_adjust=pre_adjust_setup(name))

        elif (f"{old_idx}_{name}" in activations["in"]):
      
            inact = activations["in"][f"{old_idx}_{name}"]
            subclone_weight(module, new_module, inact, 1, pre_adjust=pre_adjust_setup(name))
        
                
    

def subclone_model(base_model, new_model, activations):
    if ("subcloned_model" in locals()):
        del subcloned_model

    global_neuron_importance = compute_global_neuron_importance(activations)

    head_importance = compute_head_importance(activations, base_model.config.num_attention_heads)
    
    # Select top neurons and heads
    top_neurons = torch.topk(global_neuron_importance, k=new_model.config.hidden_size, sorted=False).indices # global
    top_heads = torch.topk(head_importance, k=new_model.config.num_attention_heads, sorted=False).indices # for each layer (layer, head)
    
    # Subclone embedding
    subclone_weight(base_model.model.embed_tokens, new_model.model.embed_tokens, activations["main"]["embed"], 1) # because it's embedding
    subclone_weight(base_model.lm_head, new_model.lm_head, activations["main"]["unembed"], 1)
    subclone_weight(base_model.model.norm, new_model.model.norm, activations["main"]["norm"], 0)
    
    # Subclone layers
    layers_to_keep = [0, 1, 2,3, -4, -3, -2, -1]  # Adjust as needed
    for new_idx, old_idx in enumerate(layers_to_keep):
        if old_idx < 0:
            old_idx = len(base_model.model.layers) + old_idx
        subclone_layer(old_idx, activations ,base_model.model.layers[old_idx], new_model.model.layers[new_idx], top_neurons, top_heads[old_idx], new_model)
    
    
    return new_model


In [17]:
subcloned_model = subclone_model(model, get_model(), collected_activations)

## Analyze results

In [18]:
successful_weight_transfers, sum(1 for i in subcloned_model.parameters())

(43, 75)

In [19]:
mismatch = []
sucloned_params = dict(subcloned_model.named_parameters())
for n,p in new_model.named_parameters():
    if p.shape != sucloned_params[n].shape:
        mismatch.append((n, p.shape, sucloned_params[n].shape))
print(mismatch)

[]


In [21]:
# number of params
def count_parameters(model):
    num_of_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    model_norm = sum(p.norm() for p in model.parameters()) / num_of_params
    memory_needed = num_of_params * 4 / 1024 / 1024 / 1024
    print (f"Number of parameters: {num_of_params}, memory needed: {memory_needed:.2f} GB, model norm: {model_norm}")
    
count_parameters(new_model)
count_parameters(subcloned_model)
count_parameters(model)

Number of parameters: 375931904, memory needed: 1.40 GB, model norm: 6.601186669286108e-06
Number of parameters: 375931904, memory needed: 1.40 GB, model norm: 6.537608442158671e-06
Number of parameters: 8030261248, memory needed: 29.92 GB, model norm: 2.0712614059448242e-06


## Check convergence of the subcloned network

In [27]:
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import AdamW, get_linear_schedule_with_warmup
from torch.cuda.amp import autocast, GradScaler
import torch.nn.functional as F
from torch.utils.data import IterableDataset

class StreamingDataset(IterableDataset):
    def __init__(self, fw_dataset, tokenizer, max_length=512, count=None):
        self.fw_dataset = fw_dataset
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.count = count

    def __iter__(self):
        for item in self.fw_dataset:
            encoding = self.tokenizer(
                item['text'],
                truncation=True,
                max_length=self.max_length,
                padding='max_length',
                return_tensors='pt'
            )
            yield {
                'input_ids': encoding['input_ids'].squeeze(),
                'attention_mask': encoding['attention_mask'].squeeze(),
                'labels': encoding['input_ids'].squeeze() # Shifted right by 1, original input as label
            }

    def __len__(self):
        return self.count 
    
def train(model, train_dataset, val_dataset, tokenizer, batch_size=4, epochs=6, lr=0.0001, max_grad_norm=1.0, warmup_steps=0, gradient_accumulation_steps=4):
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size)

    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=0.005)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=len(train_dataloader) * epochs)
    scaler = GradScaler()

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for step, batch in enumerate(train_dataloader):
            input_ids = batch['input_ids'].to(model.device).long()
            attention_mask = batch['attention_mask'].to(model.device).float()
            labels = batch['labels'].to(model.device).long()
            
            with autocast(dtype=torch.bfloat16):
                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                logits = outputs.logits
                
                # Shift logits and labels for next token prediction
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
                
                # Calculate loss
                loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1), ignore_index=model.config.eos_token_id)
                loss = loss / gradient_accumulation_steps

            scaler.scale(loss).backward()

            if (step + 1) % gradient_accumulation_steps == 0:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
                scaler.step(optimizer)
                scaler.update()
                scheduler.step()
                optimizer.zero_grad()

            total_loss += loss.item() * gradient_accumulation_steps

            if step % 100 == 0:
                print(f"Epoch {epoch+1}/{epochs} | Step {step}/{len(train_dataloader)} | Loss: {total_loss / (step+1)}")

        # Validation
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for batch in val_dataloader:
                input_ids = batch['input_ids'].to(model.device).long()
                attention_mask = batch['attention_mask'].to(model.device)
                labels = batch['labels'].to(model.device).long()
                
                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                logits = outputs.logits
                
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
                
                loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1), ignore_index=model.config.eos_token_id)
                val_loss += loss.item()

        val_loss /= len(val_dataloader)
        print(f"Epoch {epoch+1}/{epochs} | Validation Loss: {val_loss}")

        # Save checkpoint
        torch.save(model.state_dict(), f"model_checkpoint_epoch_{epoch+1}.pt")


In [28]:
# Prepare your datasets
train_dataset = StreamingDataset(fw.take(130_000_000).shuffle(), tokenizer, max_length=2048, count=130_000_000)
val_dataset = StreamingDataset(fw.skip(130_000_000).take(10_000_000), tokenizer, max_length=2048, count=10_000_000)


In [None]:
subcloned_model.gradient_checkpointing_enable()
train(subcloned_model, train_dataset, val_dataset, tokenizer) # Subcloned model