In [1]:
!pip install -q -U auto_mix_prep datasets icecream prompt_toolkit pydantic
!pip install -q -U torch tqdm transformers accelerate

In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import numpy as np
import json
from prompt_toolkit.shortcuts import checkboxlist_dialog, input_dialog
import argparse
from tqdm import tqdm
import os
import time

In [2]:
class ModelModifier:
    """Score a causal LM's weight-bearing modules by a singular-value SNR and export rankings.

    Typical flow: construct with a model name, choose weight types with
    ``interactive_select_weights``, run ``assess_layers_snr`` and finally call
    ``save_snr_to_json``, which also writes a per-type sorted JSON file and a
    YAML list of the top-``top_percent`` layers of every type.
    """

    def __init__(self, model_name=None, top_percent=50, batch_size=1):
        """Load the model, an Adam optimizer over its parameters, and the tokenizer.

        Args:
            model_name: Identifier handed to ``from_pretrained``. When falsy,
                nothing is loaded and ``model``/``optimizer``/``tokenizer`` are ``None``.
            top_percent: Default share (0-100) of layers per type written to the YAML file.
            batch_size: Number of layers processed per progress-bar step.
        """
        self.model_name = model_name
        self.top_percent = top_percent
        self.batch_size = batch_size

        if model_name:
            # NOTE(review): trust_remote_code=True allows the model repository to run its
            # own Python code while loading; only use it with repositories you trust.
            self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32, low_cpu_mem_usage=True, trust_remote_code=True, device_map="auto")
            self.optimizer = torch.optim.Adam(self.model.parameters())
            self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=True, add_prefix_space=True)
        else:
            self.model = None
            self.optimizer = None
            self.tokenizer = None

        # Maps module name -> {'type': weight type, 'snr': float}; filled by calculate_snr_for_layer.
        self.layer_snr = {}
        # Weight types chosen by the last interactive_select_weights() call.
        self.layer_types = []

    def get_weight_types(self):
        """Return the distinct weight types found in the model (unordered list).

        Only modules exposing a ``weight``, ``bias`` or ``inv_freq`` attribute are
        considered. A module's type is the part of its dotted name after the first
        purely numeric component, or the full name when there is no numeric component.
        """
        weight_types = set()
        for name, module in self.model.named_modules():
            if not any(hasattr(module, attr) for attr in ('weight', 'bias', 'inv_freq')):
                continue
            parts = name.split('.')
            layer_index = next((i for i, part in enumerate(parts) if part.isdigit()), -1)
            weight_types.add('.'.join(parts[layer_index + 1:]) if layer_index != -1 else name)
        return list(weight_types)

    def interactive_select_weights(self):
        """Select the weight types to scan and remember them in ``self.layer_types``.

        No dialog is shown: every type reported by ``get_weight_types`` is selected,
        in the order produced by ``sort_weight_types``.
        """
        selected_types = self.sort_weight_types(self.get_weight_types())
        self.layer_types = selected_types
        return selected_types

    def sort_weight_types(self, weight_types):
        """Return ``weight_types`` ordered by first dotted component, then alphabetically."""
        categories = {}
        for wt in weight_types:
            categories.setdefault(wt.split('.')[0], []).append(wt)
        return [wt for category in sorted(categories) for wt in sorted(categories[category])]

    @staticmethod
    def _name_matches_type(name, layer_type):
        """Tell whether ``layer_type`` occurs in ``name`` as a run of whole dotted components.

        A raw substring test would let ``mlp.gate`` match ``...mlp.gate_proj`` and tag
        that module with the wrong type; comparing whole components prevents this.
        An empty ``layer_type`` matches every name, as ``'' in name`` always did.
        """
        if not layer_type:
            return True
        name_parts = name.split('.')
        type_parts = layer_type.split('.')
        width = len(type_parts)
        return any(
            name_parts[start:start + width] == type_parts
            for start in range(len(name_parts) - width + 1)
        )

    def _weight_snr_ratio(self, weight):
        """Compute the SNR of one weight tensor divided by its largest singular value.

        Singular values above the Marchenko-Pastur threshold are summed as signal,
        the remaining ones as noise. Returns a Python float, which is ``inf`` when
        the noise sum is zero.
        """
        weights = weight.detach()
        if weights.dtype in (torch.float16, torch.bfloat16):
            # svdvals only accepts float/double/complex input, so upcast half precision.
            weights = weights.float()
        if weights.ndim < 2:
            # Treat 1-D parameters as a single-row matrix so the SVD is defined.
            weights = weights.unsqueeze(0)
        S = torch.linalg.svdvals(weights)
        max_singular_value = S[0]  # svdvals returns values in descending order
        sigma_estimated = self.estimate_sigma_with_full_iqr(S)
        n, m = weights.shape[-2:]
        mp_threshold = self.marchenko_pastur_threshold(sigma_estimated, n, m)
        signal = S[S > mp_threshold].sum()
        noise = S[S <= mp_threshold].sum()
        snr = signal / noise if noise != 0 else float('inf')
        snr_ratio = snr / max_singular_value
        return snr_ratio.item()

    def calculate_snr_for_layer(self, layer_type):
        """Compute the SNR ratio of every module of ``layer_type`` and store it in ``self.layer_snr``.

        A module is scanned when its ``weight`` is a tensor and its name contains
        ``layer_type`` as whole dotted components. The progress bar advances once
        per batch of ``self.batch_size`` layers.
        """
        layers = [
            (name, module)
            for name, module in self.model.named_modules()
            if isinstance(getattr(module, 'weight', None), torch.Tensor)
            and self._name_matches_type(name, layer_type)
        ]
        # Guard against 0/None/negative batch sizes, which would divide by zero
        # or silently skip every layer.
        batch_size = max(1, int(self.batch_size or 1))
        num_batches = (len(layers) + batch_size - 1) // batch_size

        with tqdm(total=num_batches, unit='batch', desc=f'Calculating SNR for {layer_type}') as progress_bar:
            for start in range(0, len(layers), batch_size):
                for name, module in layers[start:start + batch_size]:
                    self.layer_snr[name] = {'type': layer_type, 'snr': self._weight_snr_ratio(module.weight)}
                progress_bar.update(1)

    @staticmethod
    def marchenko_pastur_threshold(sigma, n, m):
        """Return ``sigma * (1 + sqrt(beta))`` where ``beta`` is the smaller-over-larger ratio of ``n`` and ``m``."""
        beta = n / m if n < m else m / n
        # The original wrapped this in sqrt(x ** 2), which is just x for positive x.
        return sigma * (1 + np.sqrt(beta))

    @staticmethod
    def estimate_sigma_with_full_iqr(S):
        """Estimate the noise scale of the singular values ``S`` as their interquartile range / 1.349."""
        q75 = torch.quantile(S, 0.75)
        q25 = torch.quantile(S, 0.25)
        return (q75 - q25) / 1.349

    def assess_layers_snr(self, selected_weight_types):
        """Run ``calculate_snr_for_layer`` for each selected type and print the elapsed time."""
        start_time = time.time()

        with tqdm(total=len(selected_weight_types), unit='type', desc='Calculating SNR for types') as progress_bar:
            for layer_type in selected_weight_types:
                self.calculate_snr_for_layer(layer_type)
                progress_bar.update(1)

        total_time = time.time() - start_time
        print(f"Total time taken: {total_time:.2f} seconds")

    def save_snr_to_json(self):
        """Write ``self.layer_snr`` to ``model_snr_results/snr_results_<slug>.json``.

        Afterwards the per-type sorted JSON and the unfrozen-parameters YAML are
        generated from that file (both are written to the current directory).
        """
        model_name_slug = self.model_name.replace('/', '-').replace('_', '-')
        directory = 'model_snr_results'
        filename = os.path.join(directory, f'snr_results_{model_name_slug}.json')

        os.makedirs(directory, exist_ok=True)

        serializable_data = {}
        for layer_name, info in self.layer_snr.items():
            # Tensors are not JSON serializable; unwrap them to plain numbers.
            snr_value = info['snr'].item() if isinstance(info['snr'], torch.Tensor) else info['snr']
            serializable_data[layer_name] = {'snr': snr_value, 'type': str(info['type'])}

        with open(filename, 'w') as file:
            json.dump(serializable_data, file, indent=4)

        print(f"Results saved to {filename}")
        self.save_top_snr_ratios_to_json(filename)
        self.generate_unfrozen_params_yaml(filename)

    @staticmethod
    def _group_layers_by_type(snr_data):
        """Group ``{name: {'type', 'snr'}}`` records into ``{type: [(name, snr), ...]}`` in input order."""
        grouped = {}
        for layer_name, info in snr_data.items():
            grouped.setdefault(info['type'], []).append((layer_name, info['snr']))
        return grouped

    def generate_unfrozen_params_yaml(self, json_filename, top_percent=None):
        """Write the highest-SNR layers of each type to a YAML list in the current directory.

        Args:
            json_filename: Results file produced by ``save_snr_to_json``.
            top_percent: Share (0-100) of layers kept per type; defaults to ``self.top_percent``.

        The output is named ``<json base name>_unfrozenparameters_<top_percent>percent.yaml``
        and always starts with the ``lm_head`` and ``model.embed_tokens`` weight patterns.
        """
        top_percent = top_percent if top_percent is not None else self.top_percent
        with open(json_filename, 'r') as file:
            snr_data = json.load(file)

        top_layers_by_type = {}
        for layer_type, layers in self._group_layers_by_type(snr_data).items():
            layers_sorted = sorted(layers, key=lambda x: x[1], reverse=True)
            # int() truncates, so small groups may contribute no layer at all.
            num_top_layers = int(len(layers) * top_percent / 100)
            top_layers_by_type[layer_type] = [layer[0] for layer in layers_sorted[:num_top_layers]]

        json_file_base = os.path.splitext(os.path.basename(json_filename))[0]
        yaml_filename = f"{json_file_base}_unfrozenparameters_{top_percent}percent.yaml"
        with open(yaml_filename, 'w') as file:
            file.write("unfrozen_parameters:\n")
            file.write("- ^lm_head.weight$\n")
            file.write("- ^model.embed_tokens.weight$\n")
            for layer_type, layer_names in top_layers_by_type.items():
                file.write(f"# {layer_type} layers\n")
                for layer_name in layer_names:
                    file.write(f"- {layer_name}\n")
        print(f"Top {top_percent}% SNR layers saved to {yaml_filename}")

    def save_top_snr_ratios_to_json(self, json_filename, filename=None):
        """Write every layer grouped by type and sorted by descending SNR to a JSON file.

        Args:
            json_filename: Results file produced by ``save_snr_to_json``.
            filename: Output path; defaults to ``<json base name>_sorted.json`` in the
                current directory.
        """
        with open(json_filename, 'r') as file:
            snr_data = json.load(file)

        all_snr_layers = {
            layer_type: dict(sorted(layers, key=lambda x: x[1], reverse=True))
            for layer_type, layers in self._group_layers_by_type(snr_data).items()
        }

        json_file_base = os.path.splitext(os.path.basename(json_filename))[0]
        filename = f"{json_file_base}_sorted.json" if filename is None else filename

        with open(filename, 'w') as file:
            json.dump(all_snr_layers, file, indent=4)
        print(f"All SNR layers sorted and saved to {filename}")


In [3]:
# Model identifier handed to ModelModifier, which loads it via from_pretrained.
model_name = "microsoft/phi-2"
# Number of layers ModelModifier processes per progress-bar step during the SNR scan.
batch_size = 1

In [4]:
# Same slug and path that ModelModifier.save_snr_to_json() uses for its raw results file.
model_name_slug = model_name.replace('/', '-').replace('_', '-')
snr_file_path = os.path.join('model_snr_results', f'snr_results_{model_name_slug}.json')

# A falsy batch size (0, None, '') falls back to one layer per step.
batch_size = int(batch_size or 1)
modifier = ModelModifier(model_name=model_name, batch_size=batch_size)
selected_weight_types = modifier.interactive_select_weights()
if not selected_weight_types:
    print("No weight types selected.")
else:
    # Scan every selected type, then write the JSON/YAML result files.
    modifier.assess_layers_snr(selected_weight_types)
    modifier.save_snr_to_json()
    print("Finished SNR scanning and data saved.")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Calculating SNR for types:   0%|          | 0/11 [00:00<?, ?type/s]
Calculating SNR for input_layernorm:   0%|          | 0/32 [00:00<?, ?batch/s][A
Calculating SNR for input_layernorm: 100%|██████████| 32/32 [00:01<00:00, 23.63batch/s]
Calculating SNR for types:   9%|▉         | 1/11 [00:01<00:13,  1.36s/type]
Calculating SNR for lm_head:   0%|          | 0/1 [00:00<?, ?batch/s][A
Calculating SNR for lm_head: 100%|██████████| 1/1 [00:02<00:00,  2.92s/batch]
Calculating SNR for types:  18%|█▊        | 2/11 [00:04<00:20,  2.28s/type]
Calculating SNR for mlp.fc1:   0%|          | 0/32 [00:00<?, ?batch/s][A
Calculating SNR for mlp.fc1:   3%|▎         | 1/32 [00:01<00:58,  1.90s/batch][A
Calculating SNR for mlp.fc1:   6%|▋         | 2/32 [00:03<00:56,  1.90s/batch][A
Calculating SNR for mlp.fc1:   9%|▉         | 3/32 [00:05<00:54,  1.89s/batch][A
Calculating SNR for m

Total time taken: 371.68 seconds
Results saved to model_snr_results/snr_results_microsoft-phi-2.json
All SNR layers sorted and saved to snr_results_microsoft-phi-2_sorted.json
Top 50% SNR layers saved to snr_results_microsoft-phi-2_unfrozenparameters_50percent.yaml
Finished SNR scanning and data saved.



