In [1]:
import os
import json
import torch
import glob

# Model-parallel layer type for each parameter-name suffix (matched after any
# ".weight" is removed from the key). The first matching suffix, in insertion
# order, wins. ``None`` means the parameter is replicated on every shard.
_LAYER_KIND = {
    'tok_embeddings': 'ParallelEmbedding',
    'output': 'ColumnParallelLinear',
    'attention.wq': 'ColumnParallelLinear',
    'attention.wk': 'ColumnParallelLinear',
    'attention.wv': 'ColumnParallelLinear',
    'attention.wo': 'RowParallelLinear',
    'feed_forward.w1': 'ColumnParallelLinear',
    'feed_forward.w2': 'RowParallelLinear',
    'feed_forward.w3': 'ColumnParallelLinear',
    'attention_norm': None,
    'ffn_norm': None,
    'norm': None,
    'rope.freqs': None,
}

# Tensor axis along which each sharded layer kind is concatenated / split.
# Kinds not listed here (i.e. ``None``) are copied unchanged to every shard.
_SPLIT_AXIS = {
    'ColumnParallelLinear': 0,
    'ParallelEmbedding': 1,
    'RowParallelLinear': 1,
}


def _layer_kind(key):
    """Return the ``_LAYER_KIND`` entry for parameter ``key``.

    Raises:
        ValueError: if no suffix in ``_LAYER_KIND`` matches ``key``.
    """
    stem = key.replace('.weight', '')
    for pattern, kind in _LAYER_KIND.items():
        if stem.endswith(pattern):
            return kind
    raise ValueError(f'parameter name not recognized: {key}')


def _reshard(tensors, axis, num_shards, key):
    """Concatenate ``tensors`` along ``axis`` and split into ``num_shards`` equal parts.

    Raises:
        ValueError: if the merged size along ``axis`` is not divisible by
            ``num_shards`` (splitting would otherwise silently drop data).
    """
    merged = torch.cat(tensors, axis)
    size = merged.shape[axis]
    if size % num_shards != 0:
        raise ValueError(
            f'{key}: size {size} along axis {axis} is not divisible by '
            f'num_shards={num_shards}'
        )
    # Clone each slice: torch.save preserves views and would otherwise write the
    # whole merged tensor's storage into every shard file.
    return [part.clone() for part in torch.split(merged, size // num_shards, dim=axis)]


def decompose_recompose_llama_model(num_shards, input_model_dir, output_model_dir):
    """Re-shard a LLaMA-style checkpoint into ``num_shards`` model-parallel shards.

    All ``*.pth`` files in ``input_model_dir`` are loaded in sorted filename
    order (rank order for ``consolidated.NN.pth``). Each parameter is merged
    back into its full tensor along its model-parallel axis (see
    ``_SPLIT_AXIS``) and re-split into ``num_shards`` equal slices. Replicated
    parameters (kind ``None``) are taken from the first input file. The result
    is written to ``output_model_dir`` as ``consolidated.00.pth``, … together
    with an unchanged copy of ``params.json``. Progress is printed to stdout.

    Args:
        num_shards: Number of output shards (>= 1).
        input_model_dir: Directory holding ``params.json`` and the input ``*.pth`` files.
        output_model_dir: Destination directory; created if it does not exist.

    Raises:
        ValueError: if ``num_shards`` < 1, if ``params['dim']`` or any sharded
            tensor's split axis is not divisible by ``num_shards``, or if a
            parameter name is not recognized.
        FileNotFoundError: if ``input_model_dir`` contains no ``*.pth`` files.
    """
    if num_shards < 1:
        raise ValueError(f'num_shards must be >= 1, got {num_shards}')

    with open(os.path.join(input_model_dir, 'params.json'), 'r', encoding='utf-8') as fp:
        params = json.load(fp)

    if params['dim'] % num_shards != 0:
        raise ValueError('number of shards must divide parameter dimension %d' % params['dim'])

    print('loading...')
    # glob.glob returns files in arbitrary filesystem order; sort so shards are
    # concatenated in rank order (consolidated.00.pth, consolidated.01.pth, ...).
    paths = sorted(glob.glob(os.path.join(input_model_dir, '*.pth')))
    if not paths:
        raise FileNotFoundError(f'no *.pth checkpoint files found in {input_model_dir!r}')
    # NOTE: torch.load unpickles the file -- only run this on trusted checkpoints.
    checkpoints = [torch.load(path, map_location=torch.device('cpu')) for path in paths]

    output = [{} for _ in range(num_shards)]

    print('converting...')
    with torch.no_grad():
        # Iterate over a snapshot of the keys: each entry is popped from the
        # input dicts as it is converted so the input tensors can be released.
        for key in list(checkpoints[0]):
            tensors = [m.pop(key) for m in checkpoints]
            print(key)
            print('  in shapes=', [t.shape for t in tensors])
            kind = _layer_kind(key)
            print('  kind=', kind)
            axis = _SPLIT_AXIS.get(kind)
            if axis is None:
                # Replicated parameter: every shard gets the first input's tensor.
                shards = [tensors[0]] * num_shards
            else:
                shards = _reshard(tensors, axis, num_shards, key)
            for rank, shard in enumerate(shards):
                output[rank][key] = shard
            print('  out shapes=', [shard.shape for shard in shards])
            print()

    print('saving...')
    os.makedirs(output_model_dir, exist_ok=True)
    with open(os.path.join(output_model_dir, 'params.json'), 'w', encoding='utf-8') as fp:
        json.dump(params, fp)

    for rank, state_dict in enumerate(output):
        print(' ', rank)
        torch.save(state_dict, os.path.join(output_model_dir, 'consolidated.%02d.pth' % rank))

    print('done.')

# Re-shard the checkpoint in ../7B into 8 shards, written to ../7B_sharded_2/.
# NOTE(review): the output directory is named "_sharded_2" but num_shards is 8
# (the captured log below shows 8 output shards) -- confirm which was intended.
decompose_recompose_llama_model(
    num_shards=8,
    input_model_dir='../7B',
    output_model_dir='../7B_sharded_2/',
)

loading...
converting...
tok_embeddings.weight
  in shapes= [torch.Size([32000, 4096])]
  kind= ParallelEmbedding
  out shapes= [torch.Size([32000, 512]), torch.Size([32000, 512]), torch.Size([32000, 512]), torch.Size([32000, 512]), torch.Size([32000, 512]), torch.Size([32000, 512]), torch.Size([32000, 512]), torch.Size([32000, 512])]

norm.weight
  in shapes= [torch.Size([4096])]
  kind= None
  out shapes= [torch.Size([4096]), torch.Size([4096]), torch.Size([4096]), torch.Size([4096]), torch.Size([4096]), torch.Size([4096]), torch.Size([4096]), torch.Size([4096])]

output.weight
  in shapes= [torch.Size([32000, 4096])]
  kind= ColumnParallelLinear
  out shapes= [torch.Size([4000, 4096]), torch.Size([4000, 4096]), torch.Size([4000, 4096]), torch.Size([4000, 4096]), torch.Size([4000, 4096]), torch.Size([4000, 4096]), torch.Size([4000, 4096]), torch.Size([4000, 4096])]

layers.0.attention.wq.weight
  in shapes= [torch.Size([4096, 4096])]
  kind= ColumnParallelLinear
  out shapes= [torch.