In [None]:
from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset
from pathlib import Path
import numpy as np
from transformers import AutoTokenizer
import torch
import pickle

import matplotlib.pyplot as plt

In [None]:
# Root directory of the dumped tensors; later cells read the "fast_llm/..." and
# "hf/..." subdirectories underneath it.
# NOTE(review): the directory names suggest this is the float32 dump and the
# commented-out path below is a different-precision dump — confirm.
files_root = Path("/mnt/datasets/tests/denis/tensors_f32/")
#files_root = Path("/mnt/datasets/tests/denis/tensors/")

In [None]:
# Map the integer suffix of each "tensor<N>.pt" file name to its path, once for
# the Fast-LLM dump and once for the HF dump.
fm_files = {
    int(path.stem.split("tensor")[1]): path
    for path in (files_root / "fast_llm/logits/").glob("tensor*.pt")
}
hf_files = {
    int(path.stem.split("tensor")[1]): path
    for path in (files_root / "hf/logits").glob("tensor*.pt")
}
# The comparison loop further down indexes both dicts with range(len(fm_files)),
# so the indices must be exactly 0..N-1 and identical on both sides. Checking
# only the lengths would let a gap or mismatch surface later as a KeyError.
assert sorted(fm_files) == list(range(len(fm_files))), "fast_llm logits indices are not contiguous from 0"
assert sorted(hf_files) == sorted(fm_files), "hf and fast_llm logits cover different indices"
len(fm_files)

In [None]:
# Per-file results; entry k corresponds to the files stored under index k.
hf_tokens, fm_tokens = [], []
max_adiff, mean_adiff, sum_adiff = [], [], []

for i in range(len(fm_files)):
    # Only the vector at [0, -1, :] of each dumped tensor is compared.
    fm_last = torch.load(fm_files[i])[0, -1, :]
    hf_last = torch.load(hf_files[i])[0, -1, :]

    # Arg-max entry of each vector (the greedy token choice).
    hf_tokens.append(hf_last.argmax().item())
    fm_tokens.append(fm_last.argmax().item())

    # Element-wise absolute gap between the two vectors, summarised three ways.
    gap = torch.abs(hf_last - fm_last)
    max_adiff.append(gap.max().item())
    mean_adiff.append(gap.mean().item())
    sum_adiff.append(gap.sum().item())

# True when both implementations picked the same token at every index
# (both lists always have the same length here).
hf_tokens == fm_tokens

In [None]:
# Index of the first position where HF and Fast-LLM chose different tokens.
# If every position matches, the sentinel len(hf_tokens) + 1 is returned.
# next() stops at the first mismatch and, unlike min() over the generator,
# does not raise when the token lists are empty.
next(
    (i for i, (hf_tok, fm_tok) in enumerate(zip(hf_tokens, fm_tokens)) if hf_tok != fm_tok),
    len(hf_tokens) + 1,
)

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(12, 5), sharex=True)
stats_ax, sum_ax = axes

# Left panel: max and mean absolute difference per index.
for series, series_label in ((max_adiff, 'max'), (mean_adiff, 'mean')):
    stats_ax.plot(series, label=series_label)
stats_ax.set_title('Max and Mean Absolute Difference')

# Right panel: summed absolute difference per index.
sum_ax.plot(sum_adiff, label='sum', color='tab:orange')
sum_ax.set_title('Sum Absolute Difference')

# Decoration shared by both panels.
for ax in axes:
    ax.set_xlabel('Token Position Index')
    ax.set_ylabel('Absolute Difference')
    ax.legend()
    ax.grid(True)

plt.tight_layout()
plt.show()

In [None]:
# Map the integer suffix of each "data<N>.pickle" file name to its path, for
# the Fast-LLM and HF hidden-state dumps respectively.
fm_hidden_files = {
    int(dump.stem.split("data")[1]): dump
    for dump in (files_root / "fast_llm/hidden_states/").glob("data*.pickle")
}
hf_hidden_files = {
    int(dump.stem.split("data")[1]): dump
    for dump in (files_root / "hf/hidden_states").glob("data*.pickle")
}

In [None]:
def mad(new_token_index, fm_hidden_files, hf_hidden_files):
    """Return the max absolute difference per hidden-state entry for one dump.

    Args:
        new_token_index: Key selecting which pickle file to read from each mapping.
        fm_hidden_files: Mapping from index to a path-like object (supporting
            ``.open("rb")``) for the Fast-LLM dump. Each unpickled entry is
            indexed as ``entry['tensor']``.
        hf_hidden_files: Same mapping for the HF dump. Each unpickled entry is
            used directly as a tensor.

    Returns:
        A list of Python floats with one value per entry of the HF dump: the
        largest absolute element-wise difference between the two tensors at
        ``[0, -1, :]``.
    """

    def _unpickle(path):
        # pickle can execute arbitrary code; only use this on trusted dumps.
        with path.open("rb") as handle:
            return pickle.load(handle)

    fm_entries = _unpickle(fm_hidden_files[new_token_index])
    hf_entries = _unpickle(hf_hidden_files[new_token_index])

    # The HF dump defines how many entries are compared.
    return [
        torch.abs(hf_entries[layer][0, -1, :] - fm_entries[layer]['tensor'][0, -1, :]).max().item()
        for layer in range(len(hf_entries))
    ]
    

In [None]:
# Dump indices to inspect in addition to index 0.
# NOTE(review): 107/108 are presumably the steps around the first token
# mismatch — confirm.
new_token_index, new_token_index1 = 107, 108

# Per-layer max absolute differences for index 0 and for the two indices above.
max_adiffs_hidden_layers, max_adiffs_hidden_layers2, max_adiffs_hidden_layers3 = (
    mad(dump_index, fm_hidden_files, hf_hidden_files)
    for dump_index in (0, new_token_index, new_token_index1)
)

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(12, 5), sharex=True)

# Each panel overlays the index-0 curve with the curve of one later index.
panels = (
    (axes[0], max_adiffs_hidden_layers2, new_token_index),
    (axes[1], max_adiffs_hidden_layers3, new_token_index1),
)
for ax, later_diffs, later_index in panels:
    ax.plot(max_adiffs_hidden_layers, label='new_token_0', color='blue')
    ax.plot(later_diffs, label=f'new_token_{later_index}', color='green')
    # Only max differences are plotted, so the panel title names the two
    # curves rather than claiming "Max and Mean".
    ax.set_title(f'new_token_0 vs new_token_{later_index}')
    ax.set_xlabel('Hidden Layer Index')
    ax.set_ylabel('Max Absolute Difference')
    ax.legend()
    ax.grid(True)

# Figure-level title. plt.title() would only retitle the current (right-hand)
# Axes and overwrite that panel's own title.
fig.suptitle('Per-layer Max Absolute Differences')
# Reserve a strip at the top so the suptitle does not overlap the panel titles.
plt.tight_layout(rect=(0, 0, 1, 0.95))
plt.show()

In [None]:
print(hf_tokens_bf16[106:120])
print(fm_tokens_b16[106:120])

In [None]:
print(hf_tokens[106:120])
print(fm_tokens[106:120])

In [None]:
hf_tokens_bf16  = hf_tokens
fm_tokens_b16 = fm_tokens

In [None]:
# Index of the first position where HF and Fast-LLM chose different tokens.
# If every position matches, the sentinel len(hf_tokens) + 1 is returned.
# next() stops at the first mismatch and, unlike min() over the generator,
# does not raise when the token lists are empty.
next(
    (i for i, (hf_tok, fm_tok) in enumerate(zip(hf_tokens, fm_tokens)) if hf_tok != fm_tok),
    len(hf_tokens) + 1,
)

In [None]:
# Index of the first position where the current HF tokens differ from the
# snapshot stored in hf_tokens_bf16. If every position matches, the sentinel
# len(hf_tokens) + 1 is returned. next() stops at the first mismatch and,
# unlike min() over the generator, does not raise on empty lists.
next(
    (i for i, (current_tok, stored_tok) in enumerate(zip(hf_tokens, hf_tokens_bf16)) if current_tok != stored_tok),
    len(hf_tokens) + 1,
)

In [None]:
# Index of the first position where the current Fast-LLM tokens differ from the
# snapshot stored in fm_tokens_b16. If every position matches, the sentinel
# len(hf_tokens) + 1 is returned. next() stops at the first mismatch and,
# unlike min() over the generator, does not raise on empty lists.
next(
    (i for i, (current_tok, stored_tok) in enumerate(zip(fm_tokens, fm_tokens_b16)) if current_tok != stored_tok),
    len(hf_tokens) + 1,
)

In [None]:
# Index of the first position where HF and Fast-LLM chose different tokens.
# If every position matches, the sentinel len(hf_tokens) + 1 is returned.
# next() stops at the first mismatch and, unlike min() over the generator,
# does not raise when the token lists are empty.
next(
    (i for i, (hf_tok, fm_tok) in enumerate(zip(hf_tokens, fm_tokens)) if hf_tok != fm_tok),
    len(hf_tokens) + 1,
)

In [None]:
import safetensors

In [None]:

def load(path, model):
    """Sketch of importing one tensor from a safetensors file, slice by slice.

    This only demonstrates the possibility. It assumes that no conversion of
    key names or tensors, and no aggregation of tensors, is needed.

    Args:
        path: Location of the safetensors file to open.
        model: Object providing ``distributed.device``,
            ``get_local_slice_ranges(key)`` and ``import_tensor(key, tensor)``.
    """
    with safetensors.safe_open(path, 'pt', device=model.distributed.device) as reader:
        tensor_name = 'model.embed_tokens.weight'
        lazy_slice = reader.get_slice(tensor_name)
        # Intended to return a multidimensional range object describing the
        # part of the tensor owned by this rank (tensor parallel, etc.).
        local_ranges = model.get_local_slice_ranges(tensor_name)
        # Indexing the lazy slice is meant to load only that part of the tensor.
        local_part = lazy_slice[local_ranges]
        model.import_tensor(tensor_name, local_part)
      

In [None]:
from fast_llm.engine.distributed.config import DistributedConfig

In [None]:
# Markdown table with one row per rank of a 16-process layout.
# The unnamed trailing header columns receive the "<name>_<id>" cells below.
print("| rank | local_rank | tensor_rank | pipeline_rank | data_rank | sequence_data_rank | batch_data_rank | | | | | | |")
print("| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |")
for rank in range(16):
    cfg = DistributedConfig(
        rank=rank,
        world_size=16,
        local_world_size=8,
        tensor_parallel=2,
        pipeline_parallel=2,
        sequence_data_parallel=2,
        pipeline_first=True,
    )
    rank_fields = (
        cfg.rank,
        cfg.local_rank,
        cfg.tensor_rank,
        cfg.pipeline_rank,
        cfg.data_rank,
        cfg.sequence_data_rank,
        cfg.batch_data_rank,
    )
    res = "| " + " | ".join(f"{value}" for value in rank_fields) + " |"
    # Append one cell per distributed dimension, leaving out 'world'.
    for name, dm in cfg.distributed_dims.items():
        if name != 'world':
            res += f"{name}_{dm.id} |"
    print(res)


In [None]:
# Rebuild only the "<name>_<id>" cells, skipping 'world'. `cfg` is the config
# left over from the last iteration of the rank loop in the previous cell.
res = '|' + "".join(
    f"{name}_{dm.id} |"
    for name, dm in cfg.distributed_dims.items()
    if name != 'world'
)

In [None]:
# Display the row string built in the previous cell.
res

In [None]:
import pickle

In [None]:
# Load one pickled lm_eval batch from disk.
# pickle can execute arbitrary code; only load files from a trusted source.
batch_path = Path("/mnt/checkpoints/test/denis/smol_eval_experiment_test/lm_eval/batch_0.pkl")
with batch_path.open('rb') as batch_file:
    data = pickle.load(batch_file)

In [None]:
# Inspect the first element of the loaded object.
data[0]

In [None]:
# Inspect everything after the first element of the loaded object.
data[1:]