In [3]:
import json
from collections import defaultdict

# Read the JSON file
with open('Mixtral-8x7B-Instruct-v0.1-offloading-full-demo/model.safetensors.index.json', 'r') as f:
    data = json.load(f)

weight_map = data['weight_map']

# Create new organization
new_organization = defaultdict(list)
model_0_tensors = []

for key, file in weight_map.items():
    parts = key.split('.')
    if len(parts) > 3 and parts[0] == 'model' and parts[1] == 'layers':
        layer_num = int(parts[2])
        if 'block_sparse_moe' in parts and 'experts' in parts:
            new_file = f"model-{layer_num // 2 +1 :>2d}-of-17.safetensors".replace(' ','0')
            new_organization[new_file].append((key, file))
        else:
            model_0_tensors.append((key, file))
    else:
        model_0_tensors.append((key, file))

# Add model-0 tensors
new_organization['model-00-of-17.safetensors'] = model_0_tensors



In [4]:
# Create new weight map
new_weight_map = {}
for new_file, tensor_list in new_organization.items():
    for key, old_file in tensor_list:
        new_weight_map[key] = new_file

# Create new JSON structure
new_data = {
    "metadata": data["metadata"],
    "weight_map": new_weight_map
}

# Write the new JSON file
with open('new_model.safetensors.index.json', 'w') as f:
    json.dump(new_data, f, indent=2)

print("New organization created and saved to 'new_model.safetensors.index.json'")

New organization created and saved to 'new_model.safetensors.index.json'


In [42]:
import os
import torch
from safetensors import safe_open
from safetensors.torch import save_file

def load_weights(weight_files):
    """Load weights from multiple SafeTensors files."""
    state_dict = {}
    for file in weight_files:
        with safe_open(file, framework="pt", device="cpu") as f:
            for key in f.keys():
                state_dict[key] = f.get_tensor(key)
    return state_dict



# Directory containing the original weight files
input_dir = "Mixtral-8x7B-Instruct-v0.1-offloading-full-demo"

# Directory to save the reorganized weights
output_dir = "Mixtral-8x7B-Instruct-v0.1-reorganized"

# Get all weight files from the input directory
weight_files = [os.path.join(input_dir, f) for f in os.listdir(input_dir) if f.endswith('.safetensors')]

# Load all weights
state_dict = load_weights(weight_files)


In [43]:
state_dict

{'model.layers.10.block_sparse_moe.experts.1.w1.weight': tensor([[-0.0073,  0.0017,  0.0044,  ...,  0.0034, -0.0088, -0.0020],
         [-0.0128, -0.0093, -0.0141,  ..., -0.0047, -0.0042,  0.0107],
         [ 0.0028,  0.0034,  0.0022,  ..., -0.0021,  0.0121,  0.0039],
         ...,
         [ 0.0067, -0.0089,  0.0006,  ...,  0.0040, -0.0075, -0.0168],
         [-0.0013, -0.0074, -0.0089,  ...,  0.0015,  0.0229, -0.0045],
         [-0.0076,  0.0160, -0.0126,  ...,  0.0097,  0.0060, -0.0037]],
        dtype=torch.bfloat16),
 'model.layers.10.block_sparse_moe.experts.1.w2.weight': tensor([[-6.9885e-03,  1.0071e-03, -2.6733e-02,  ...,  1.9531e-02,
          -3.0670e-03,  1.4038e-02],
         [-7.9956e-03,  2.3499e-03,  6.2866e-03,  ..., -7.3547e-03,
           1.5747e-02, -4.6997e-03],
         [-8.2970e-05, -3.2959e-03, -1.0193e-02,  ...,  2.4872e-03,
          -1.0757e-03, -1.6098e-03],
         ...,
         [ 2.7771e-03, -2.1362e-02,  1.2939e-02,  ...,  2.8534e-03,
          -1.0204e-

In [10]:
'model.layers.10.block_sparse_moe.experts.1.w1.weight'.split('.')

['model', 'layers', '10', 'block_sparse_moe', 'experts', '1', 'w1', 'weight']

In [45]:
def save_layer_weights(state_dict, output_dir):
    """Save weights for each layer in separate files."""
    os.makedirs(output_dir, exist_ok=True)
    
    saving_index = {}
    for key in state_dict:
        parts = key.split('.')
        if len(parts) > 3 and parts[0] == 'model' and parts[1] == 'layers':
            layer_num = int(parts[2])
            if 'block_sparse_moe' in parts and 'experts' in parts:
                file_name = f"model-{layer_num // 2 +1 :>2d}-of-17.safetensors".replace(' ','0')  
            else:
                file_name = f"model-00-of-17.safetensors"
        else:
            file_name = f"model-00-of-17.safetensors"
        if file_name not in saving_index:
            saving_index[file_name] = {}
        saving_index[file_name].update({key:state_dict[key]})

    # return saving_index
    for filename in saving_index:
        print(filename)
        print(saving_index[filename].keys())
        file_path = os.path.join(output_dir, filename)
        save_file(saving_index[filename], file_path)




In [46]:
save_layer_weights(state_dict, output_dir)

model-06-of-17.safetensors
dict_keys(['model.layers.10.block_sparse_moe.experts.1.w1.weight', 'model.layers.10.block_sparse_moe.experts.1.w2.weight', 'model.layers.10.block_sparse_moe.experts.1.w3.weight', 'model.layers.10.block_sparse_moe.experts.2.w1.weight', 'model.layers.10.block_sparse_moe.experts.2.w2.weight', 'model.layers.10.block_sparse_moe.experts.2.w3.weight', 'model.layers.10.block_sparse_moe.experts.3.w1.weight', 'model.layers.10.block_sparse_moe.experts.3.w2.weight', 'model.layers.10.block_sparse_moe.experts.3.w3.weight', 'model.layers.10.block_sparse_moe.experts.4.w1.weight', 'model.layers.10.block_sparse_moe.experts.4.w2.weight', 'model.layers.10.block_sparse_moe.experts.4.w3.weight', 'model.layers.10.block_sparse_moe.experts.5.w1.weight', 'model.layers.10.block_sparse_moe.experts.5.w2.weight', 'model.layers.10.block_sparse_moe.experts.5.w3.weight', 'model.layers.10.block_sparse_moe.experts.6.w1.weight', 'model.layers.10.block_sparse_moe.experts.6.w2.weight', 'model.lay