In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
from tqdm import tqdm
import numpy as np


  from .autonotebook import tqdm as notebook_tqdm


In [5]:
# Checkpoint pair to compare. The later cells compute `RL - SFT`, so
# SFT_MODEL is the reference and RL_MODEL is the updated checkpoint.
SFT_MODEL = "meta-llama/Llama-3.1-8B"
RL_MODEL  = "/shared/storage-01/users/sagnikm3/tulu_SFT_2000_steps_lr_5e-7"

# Alternative checkpoint pair, kept for reference:
# SFT_MODEL = "/home/sagnikm3/verl/checkpoints/prime_example/Eurus-2-7B-SFT-gsm8k/global_step_320/actor/hf/"
# RL_MODEL  = "/home/sagnikm3/verl/checkpoints/prime_example/Eurus-2-7B-SFT-gsm8k_masked/global_step_320/actor/hf"
# tokenizer = AutoTokenizer.from_pretrained(SFT_MODEL)

# Both checkpoints are loaded with identical settings (bf16, on CPU, shared
# Hugging Face cache) so that the element-wise comparison is like-for-like.
hf_load_kwargs = dict(
    torch_dtype=torch.bfloat16,
    device_map="cpu",
    cache_dir="/shared/storage-01/huggingface/models/",
)

model_sft = AutoModelForCausalLM.from_pretrained(SFT_MODEL, **hf_load_kwargs)
model_rl = AutoModelForCausalLM.from_pretrained(RL_MODEL, **hf_load_kwargs)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 99.45it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 40.99it/s]


In [6]:
# Build the flattened "task vector" (RL weights minus SFT weights) and record,
# for every compared parameter, the fraction of entries that changed.
all_deltas = []
param_sizes = []
sft_params = []  # references to SFT params so we can update them
rl_state_dict = model_rl.state_dict()
sft_state_dict = model_sft.state_dict()

# Report parameter names that exist in only one of the two models.
missing_in_rl = [name for name, _ in model_sft.named_parameters() if name not in rl_state_dict]
missing_in_sft = [name for name, _ in model_rl.named_parameters() if name not in sft_state_dict]

if missing_in_rl or missing_in_sft:
    print("Missing in RL:", missing_in_rl)
    print("Missing in SFT:", missing_in_sft)

num_nonzero_dict = {}  # parameter name -> fraction of entries with a nonzero delta
skipped_params = []    # names left out of the comparison (missing or shape mismatch)
with torch.no_grad():
    for name_sft, param_sft in tqdm(model_sft.named_parameters()):
        # Skip explicitly (and say which parameter) instead of catching every
        # exception: unrelated errors should not be silently swallowed.
        if name_sft not in rl_state_dict:
            print(f"Skipping {name_sft}: not present in the RL model")
            skipped_params.append(name_sft)
            continue

        weight_rl = rl_state_dict[name_sft]
        weight_sft = sft_state_dict[name_sft]
        if weight_rl.shape != weight_sft.shape:
            # The saved run of this cell hit exactly this case
            # (128264 vs 128256 along dimension 0) for two tensors.
            print(
                f"Skipping {name_sft}: shape mismatch "
                f"(RL {tuple(weight_rl.shape)} vs SFT {tuple(weight_sft.shape)})"
            )
            skipped_params.append(name_sft)
            continue

        delta = weight_rl - weight_sft
        num_nonzero = (delta != 0).sum().item()

        num_nonzero_dict[name_sft] = num_nonzero / delta.numel()

        param_sizes.append(delta.numel())
        sft_params.append(param_sft)

        all_deltas.append(delta.view(-1))

# One long 1-D tensor holding every compared delta. Skipped parameters are
# NOT part of it, so the statistic below excludes them.
all_deltas_tensor = torch.cat(all_deltas, dim=0)
print(all_deltas_tensor.size())
print("percentage of 0 values in the task vector")
# NOTE: the printed value is a fraction in [0, 1], not a percentage.
print((all_deltas_tensor == 0 ).sum() / len(all_deltas_tensor))

15it [00:00, 139.29it/s]

The size of tensor a (128264) must match the size of tensor b (128256) at non-singleton dimension 0


291it [00:07, 36.67it/s]


The size of tensor a (128264) must match the size of tensor b (128256) at non-singleton dimension 0
torch.Size([6979588096])
percentage of 0 values in the task vector
tensor(0.5994)


In [7]:
# Fraction of task-vector entries that are zero up to an absolute tolerance.
# The comparison target is an all-zeros tensor, so only `atol` matters here.
tolerances = [1e-5]
zero_tensor = torch.zeros_like(all_deltas_tensor)
total_entries = all_deltas_tensor.numel()

for tol in tolerances:
    near_zero = torch.isclose(all_deltas_tensor, zero_tensor, atol=tol)
    fraction_close_to_zero = near_zero.sum() / total_entries
    print(f"Tolerance = {tol:.0e} -> Fraction close to zero: {fraction_close_to_zero:.4f}")


Tolerance = 1e-05 -> Fraction close to zero: 0.6310


In [10]:
# Average the per-parameter nonzero fraction within each transformer layer.
# Despite the name, the values are fractions of *changed* entries, because
# `num_nonzero_dict` stores nonzero / numel for each parameter.
layerwise_sparsity = {}
for key, nonzero_fraction in num_nonzero_dict.items():
    if not key.startswith('model.layers'):
        continue
    # Parameter names look like "model.layers.<idx>.<...>"; field 2 is the index.
    layer = key.split(".")[2]
    layerwise_sparsity.setdefault(layer, []).append(nonzero_fraction)

# Collapse each layer's list to its mean, keeping first-seen layer order.
layerwise_sparsity = [sum(values) / len(values) for values in layerwise_sparsity.values()]
keys = list(range(len(layerwise_sparsity)))
# plt.bar(keys, layerwise_sparsity)

In [None]:
# Quick look at the shape of one attention projection matrix.
sft_state_dict['model.layers.0.self_attn.q_proj.weight'].shape

In [None]:
# Nonzero-delta fraction of the token embedding matrix. The delta loop leaves
# out parameters it could not compare, so this key may be absent; use .get()
# so the cell reports that instead of raising KeyError.
num_nonzero_dict.get(
    "model.embed_tokens.weight",
    "not available: parameter was skipped when computing deltas",
)

In [None]:
# List every parameter name that has a recorded nonzero fraction.
for key in num_nonzero_dict.keys():
    print(key)

Experiment: check whether the same subnetwork is updated across training steps

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm
import matplotlib.pyplot as plt
import os
import re  # only needed for the natural sort of checkpoint directories below

path_to_folder = "/home/sagnikm3/PRM/outputs_batchsize8"


def _natural_sort_key(path):
    """Sort key that orders directory names by their embedded integers.

    Splits the base name into alternating text / digit-run tokens and turns
    the digit runs into ints, so e.g. "step_9" sorts before "step_10".
    """
    tokens = re.split(r"(\d+)", os.path.basename(path))
    # With one capturing group, odd positions are always the digit runs.
    return [int(tok) if idx % 2 else tok for idx, tok in enumerate(tokens)]


# os.listdir() returns entries in arbitrary order, but the loop below treats
# directories[i] and directories[i + 1] as consecutive checkpoints, so the
# list must be sorted explicitly.
directories = sorted(
    (
        os.path.join(path_to_folder, d)
        for d in os.listdir(path_to_folder)
        if os.path.isdir(os.path.join(path_to_folder, d))
    ),
    key=_natural_sort_key,
)
per_step_deltas = []  # one flattened delta tensor per consecutive checkpoint pair
sparsity = []         # fraction of exactly-zero entries in each of those tensors
for i in tqdm(range(len(directories)-1)):
    SFT_MODEL = directories[i]
    RL_MODEL  = directories[i+1]

    # The tokenizer is loaded here but not otherwise used in this cell.
    tokenizer = AutoTokenizer.from_pretrained(SFT_MODEL)
    model_sft = AutoModelForCausalLM.from_pretrained(
        SFT_MODEL,
        torch_dtype=torch.float16,
        device_map="cpu",
        cache_dir="/shared/storage-01/huggingface/models/"
    )
    model_rl = AutoModelForCausalLM.from_pretrained(
        RL_MODEL,
        torch_dtype=torch.float16,
        device_map="cpu",
        cache_dir="/shared/storage-01/huggingface/models/"
    )

    all_deltas = []
    param_sizes = []
    sft_params = []  # references to SFT params so we can update them

    num_nonzero_dict = {}
    with torch.no_grad():
        rl_state_dict = model_rl.state_dict()

        for name_sft, param_sft in tqdm(model_sft.named_parameters()):
            param_rl = rl_state_dict[name_sft].to(param_sft.device)

            # Element-wise change between the two consecutive checkpoints.
            delta =  param_rl - param_sft.data

            num_nonzero = (delta != 0).sum().item()
            num_nonzero_dict[name_sft] = num_nonzero/delta.numel()

            param_sizes.append(delta.numel())
            sft_params.append(param_sft)

            all_deltas.append(delta.view(-1))

    all_deltas_tensor = torch.cat(all_deltas, dim=0)
    per_step_deltas.append(all_deltas_tensor)
    sparsity.append(((all_deltas_tensor ==0 ).sum() / len(all_deltas_tensor)).item())

In [None]:
# For each pair of consecutive per-step deltas, measure which share of the
# entries changed in the first delta are also changed in the second one.
overlap = []
num_pairs = len(per_step_deltas) - 1
for i in tqdm(range(num_pairs)):
    A, B = per_step_deltas[i], per_step_deltas[i + 1]

    # Boolean masks of the entries each update touched.
    A_mask = A.ne(0)
    B_mask = B.ne(0)

    overlap_mask = torch.logical_and(A_mask, B_mask)
    overlap_count = overlap_mask.sum()

    # Normalise by the number of entries touched in the first update.
    overlap.append(overlap_count / A_mask.sum())


In [None]:
# Plot the overlap ratio for each pair of consecutive updates. The entries of
# `overlap` are 0-dim tensors, so convert them to plain floats before plotting,
# and label the figure so it can be read on its own.
fig, ax = plt.subplots()
ax.plot([float(value) for value in overlap], marker="o")
ax.set_xlabel("index i of consecutive update pair (i, i+1)")
ax.set_ylabel("share of update i's nonzero entries also nonzero in update i+1")
ax.set_title("Overlap of updated entries between consecutive steps")
plt.show()

In [None]:
import torch
# Load the raw policy checkpoints and keep only their 'state' entry.
# map_location="cpu" keeps the load independent of the device the checkpoint
# was saved from (the rest of this notebook works on CPU as well).
# NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
sft_state_dict = torch.load(
    '/home/sagnikm3/direct-preference-optimization/.cache/sagnikm3/sft_llama_full_precision/step-119808/policy.pt',
    map_location="cpu",
)['state']
rl_state_dict = torch.load(
    '/home/sagnikm3/direct-preference-optimization/.cache/sagnikm3/dp_llama_bf16_2025-02-22_14-53-10_470298/step-20000/policy.pt',
    map_location="cpu",
)['state']

# Cast both to float16 so they are compared in the same dtype.
# NOTE(review): the cast itself can round small differences to zero (or
# create new ones) - confirm float16 is the intended comparison precision.
sft_state_dict = {k: v.half() for k, v in sft_state_dict.items()}
rl_state_dict = {k: v.half() for k, v in rl_state_dict.items()}

In [None]:
# Flattened task vector (RL minus SFT) computed directly from two state dicts,
# plus the per-tensor fraction of entries that changed.
all_deltas = []
param_sizes = []
sft_params = []  # references to SFT params so we can update them

num_nonzero_dict = {}  # tensor name -> fraction of entries with a nonzero delta
with torch.no_grad():

    for name_sft in tqdm(sft_state_dict):
        param_sft = sft_state_dict[name_sft]

        # Handle the two expected failure modes explicitly and name the tensor,
        # rather than crashing on a missing key or hiding errors in a broad except.
        if name_sft not in rl_state_dict:
            print(f"Skipping {name_sft}: not present in the RL state dict")
            continue

        param_rl = rl_state_dict[name_sft].to(param_sft.device)
        if param_rl.shape != param_sft.shape:
            print(
                f"Skipping {name_sft}: shape mismatch "
                f"(RL {tuple(param_rl.shape)} vs SFT {tuple(param_sft.shape)})"
            )
            continue

        delta = param_rl - param_sft

        num_nonzero = (delta != 0).sum().item()
        num_nonzero_dict[name_sft] = num_nonzero / delta.numel()

        param_sizes.append(delta.numel())
        sft_params.append(param_sft)

        all_deltas.append(delta.view(-1))


# Skipped tensors are not part of the concatenated vector or the statistic.
all_deltas_tensor = torch.cat(all_deltas, dim=0)
print(all_deltas_tensor.size())
print("percentage of 0 values in the task vector")
# NOTE: the printed value is a fraction in [0, 1], not a percentage.
print((all_deltas_tensor ==0 ).sum() / len(all_deltas_tensor))

In [None]:
# Print every tensor name stored in the SFT state dict.
for name_sft in tqdm(sft_state_dict.keys()):
    print(name_sft)

In [None]:
# Spot-check a single weight value in the SFT checkpoint.
sft_state_dict['model.layers.0.self_attn.q_proj.weight'][0, 0].item()

In [None]:
# Same weight position in the RL checkpoint, for a side-by-side comparison.
rl_state_dict['model.layers.0.self_attn.q_proj.weight'][0, 0].item()

In [None]:
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# 8-bit quantization settings shared by both model loads below.
quantization_config = BitsAndBytesConfig(load_in_8bit=True)

SFT_MODEL = "deepseek-ai/DeepSeek-V3-Base"
RL_MODEL  = "deepseek-ai/DeepSeek-R1-Zero"

tokenizer = AutoTokenizer.from_pretrained(SFT_MODEL)

# Identical load settings for both checkpoints.
# NOTE(review): this combines device_map="cpu" with an 8-bit bitsandbytes
# config - confirm that this combination is supported in the target environment.
hf_load_kwargs = dict(
    torch_dtype=torch.float32,
    device_map="cpu",
    cache_dir="/shared/storage-01/huggingface/models/",
    quantization_config=quantization_config,
)

model_sft = AutoModelForCausalLM.from_pretrained(SFT_MODEL, **hf_load_kwargs)
model_rl = AutoModelForCausalLM.from_pretrained(RL_MODEL, **hf_load_kwargs)

In [None]:
# Flattened task vector (RL minus SFT) for the currently loaded model pair,
# plus the per-parameter fraction of entries that changed.
# NOTE(review): if the loaded models store quantized weights, this subtraction
# compares the stored representation, not dequantized weights - confirm.
all_deltas = []
param_sizes = []
sft_params = []  # references to SFT params so we can update them

num_nonzero_dict = {}  # parameter name -> fraction of entries with a nonzero delta
with torch.no_grad():
    rl_state_dict = model_rl.state_dict()

    for name_sft, param_sft in tqdm(model_sft.named_parameters()):
        # Handle the two expected failure modes explicitly and name the
        # parameter, rather than crashing on a missing key or hiding errors
        # in a broad except.
        if name_sft not in rl_state_dict:
            print(f"Skipping {name_sft}: not present in the RL model")
            continue

        param_rl = rl_state_dict[name_sft].to(param_sft.device)
        if param_rl.shape != param_sft.shape:
            print(
                f"Skipping {name_sft}: shape mismatch "
                f"(RL {tuple(param_rl.shape)} vs SFT {tuple(param_sft.shape)})"
            )
            continue

        delta = param_rl - param_sft.data

        num_nonzero = (delta != 0).sum().item()
        num_nonzero_dict[name_sft] = num_nonzero / delta.numel()

        param_sizes.append(delta.numel())
        sft_params.append(param_sft)

        all_deltas.append(delta.view(-1))

# Skipped parameters are not part of the concatenated vector or the statistic.
all_deltas_tensor = torch.cat(all_deltas, dim=0)
print(all_deltas_tensor.size())
print("percentage of 0 values in the task vector")
# NOTE: the printed value is a fraction in [0, 1], not a percentage.
print((all_deltas_tensor ==0 ).sum() / len(all_deltas_tensor))

In [None]:
# Checkpoint pair: the public Tulu-3 SFT model and a locally trained checkpoint.
SFT_MODEL = "allenai/Llama-3.1-Tulu-3-8B-SFT"
RL_MODEL  = "/shared/storage-01/users/sagnikm3/tulu_bs32/"

tokenizer = AutoTokenizer.from_pretrained(SFT_MODEL)

# Identical load settings (float32, on CPU, shared Hugging Face cache) for both.
hf_load_kwargs = dict(
    torch_dtype=torch.float32,
    device_map="cpu",
    cache_dir="/shared/storage-01/huggingface/models/",
)

model_sft = AutoModelForCausalLM.from_pretrained(SFT_MODEL, **hf_load_kwargs)
model_rl = AutoModelForCausalLM.from_pretrained(RL_MODEL, **hf_load_kwargs)

In [None]:
import torch
import sys
sys.setrecursionlimit(100000)  # If you have very large models, you may need a higher recursion limit

# The two deepspeed submodules are imported so their classes can be referenced
# in the add_safe_globals() call below.
import deepspeed.runtime.fp16.loss_scaler
import deepspeed.runtime.zero.config

from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint

# `model_rl` must already exist before this cell runs: it is the model object
# passed to the DeepSpeed helper below.
# (An earlier version of this note referred to `model_sft`.)

# Register the DeepSpeed classes with torch's serialization allow-list.
# NOTE(review): presumably required so that the checkpoint files can be
# unpickled when torch.load runs in weights-only mode - confirm.
torch.serialization.add_safe_globals([
    deepspeed.runtime.zero.config.ZeroStageEnum,
    deepspeed.runtime.fp16.loss_scaler.LossScaler,
])


# Directory of the DeepSpeed ZeRO checkpoint to load.
checkpoint_dir = "/shared/storage-01/users/sagnikm3/tulu_bs32_epoch5/step_18000/"

# NOTE(review): the DeepSpeed helper is documented as loading the consolidated
# weights into the model it is given and returning that model, in which case
# `state_dict` holds a model rather than a state dict - confirm before
# enabling the commented line below.
state_dict = load_state_dict_from_zero_checkpoint(model_rl, checkpoint_dir)
# model_sft.load_state_dict(state_dict)


In [None]:
# Flattened task vector (RL minus SFT) for the currently loaded model pair,
# plus the per-parameter fraction of entries that changed.
all_deltas = []
param_sizes = []
sft_params = []  # references to SFT params so we can update them

num_nonzero_dict = {}  # parameter name -> fraction of entries with a nonzero delta
with torch.no_grad():
    rl_state_dict = model_rl.state_dict()

    for name_sft, param_sft in tqdm(model_sft.named_parameters()):
        # Handle the two expected failure modes explicitly and name the
        # parameter, rather than crashing on a missing key or hiding errors
        # in a broad except.
        if name_sft not in rl_state_dict:
            print(f"Skipping {name_sft}: not present in the RL model")
            continue

        param_rl = rl_state_dict[name_sft].to(param_sft.device)
        if param_rl.shape != param_sft.shape:
            print(
                f"Skipping {name_sft}: shape mismatch "
                f"(RL {tuple(param_rl.shape)} vs SFT {tuple(param_sft.shape)})"
            )
            continue

        delta = param_rl - param_sft.data

        num_nonzero = (delta != 0).sum().item()
        num_nonzero_dict[name_sft] = num_nonzero / delta.numel()

        param_sizes.append(delta.numel())
        sft_params.append(param_sft)

        all_deltas.append(delta.view(-1))

# Skipped parameters are not part of the concatenated vector or the statistic.
all_deltas_tensor = torch.cat(all_deltas, dim=0)
print(all_deltas_tensor.size())
print("percentage of 0 values in the task vector")
# NOTE: the printed value is a fraction in [0, 1], not a percentage.
print((all_deltas_tensor ==0 ).sum() / len(all_deltas_tensor))

In [None]:
# Keep the un-flattened delta tensor (RL minus SFT) of every parameter so that
# per-matrix properties can be inspected afterwards.
all_deltas = []
param_sizes = []
sft_params = []  # references to SFT params so we can update them
rl_state_dict = model_rl.state_dict()
sft_state_dict = model_sft.state_dict()

# Sanity check in both directions: flag any parameter name that exists in one
# model but not in the other.
for name_sft, param_sft in tqdm(model_sft.named_parameters()):
    if not (name_sft in rl_state_dict):
        print("woops")
for name_rl, param_rl in tqdm(model_rl.named_parameters()):
    if not (name_rl in sft_state_dict):
        print("woops")

# Despite its name, this dict maps parameter name -> full delta tensor.
nonzero_dict = {}
with torch.no_grad():
    for name_sft, param_sft in tqdm(model_sft.named_parameters()):
        param_rl = rl_state_dict[name_sft].to(param_sft.device)
        try:
            delta = param_rl - param_sft.data
            nonzero_dict[name_sft] = delta
        except Exception as e:
            # Best effort: report the failure and move on to the next parameter.
            print(e)

In [None]:
# Print the matrix rank of every delta that has at least two dimensions.
for key in nonzero_dict.keys():
    delta_shape = nonzero_dict[key].shape
    if len(delta_shape) < 2:
        continue
    print(key, delta_shape)
    # Cast to float32 before handing the tensor to matrix_rank.
    print(torch.linalg.matrix_rank(nonzero_dict[key].float()))

In [None]:
# Shape of the delta for whatever `key` currently holds. `key` is a loop
# variable left over from an earlier cell, so this depends on execution order.
nonzero_dict[key].shape

In [2]:
import torch
# Load the saved mask object onto the CPU. The later cells treat it as a
# dict of tensors (they call .values(), .sum() and .numel() on it).
# NOTE(security): torch.load unpickles the file; only load trusted files.
mask_cpu = torch.load(
    "/home/sagnikm3/open-instruct/sft_vs_rl_mask.pt",
    map_location=torch.device("cpu"),
)

In [3]:

# Global counts over all masks: the sum of all mask values and the total
# number of elements.
mask_tensors = list(mask_cpu.values())
total_true = sum(m.sum().item() for m in mask_tensors)
total_elems = sum(m.numel() for m in mask_tensors)

In [4]:

# Flat all-False mask over every element, plus a random ordering of its
# indices. The permutation uses an explicitly seeded generator so that the
# random baseline mask is reproducible across kernel restarts.
RANDOM_MASK_SEED = 0
flat = torch.zeros(total_elems, dtype=torch.bool)
perm = torch.randperm(
    total_elems,
    generator=torch.Generator().manual_seed(RANDOM_MASK_SEED),
)

In [5]:
# Switch on exactly `total_true` randomly chosen positions: the first
# `total_true` indices of the permutation.
chosen_positions = perm[:total_true]
flat[chosen_positions] = True

# Containers for slicing the flat mask back into per-tensor masks.
random_global_masks = {}
offset = 0

In [6]:
# Cut the flat random mask into pieces that match each original mask's shape.
# `offset` is reset here so the cell can be re-run safely; otherwise a second
# run would continue from the stale end-of-loop offset and slice past the end.
offset = 0
for name, mask in mask_cpu.items():
    n = mask.numel()
    chunk = flat[offset:offset + n]
    # view() shares storage with `flat`; the chunk is only reshaped, not copied.
    random_global_masks[name] = chunk.view(mask.shape)
    offset += n

# Every element of `flat` must have been consumed exactly once; a mismatch
# means `flat` and `mask_cpu` come from inconsistent (stale) cell runs.
assert offset == flat.numel(), (
    f"consumed {offset} mask elements but flat mask has {flat.numel()}"
)

# From here on `mask_cpu` refers to the random masks, not the loaded ones.
mask_cpu = random_global_masks

In [None]:
# Check the random masks: count the True entries and report the density
# (share of True entries) and its complement.
random_mask_tensors = list(random_global_masks.values())
true_count = sum(m.sum().item() for m in random_mask_tensors)
total_count = sum(m.numel() for m in random_mask_tensors)

density = true_count / total_count
sparsity = 1.0 - density

summary_line = (f"Random global mask: {true_count}/{total_count} ones  "
    f"({density*100:.2f}% density, {sparsity*100:.2f}% sparsity)")
print(summary_line)

In [None]:
# Display the current mask dict. After the split cell above, `mask_cpu` is
# the random mask dict, not the one loaded from disk.
mask_cpu