# Step Analysis Script
This script analyzes the model step logs and extracts information that is useful for debugging, such as statistics of the input and output tensors, collected recursively for each module in the model.

In [1]:
import json
from pathlib import Path
import tqdm as tqdm
import pandas as pd

In [2]:
# Display configuration: lift every pandas truncation limit so the filtered
# tensor-statistics tables below are rendered completely.
pd.set_option('display.max_rows', None)       # Show all rows
pd.set_option('display.max_columns', None)    # Show all columns
pd.set_option('display.width', None)          # Don't limit the width
# Show full content in each cell. One call is enough: the option was previously
# set twice with the same value (None is the pandas >= 1.0 spelling).
pd.set_option('display.max_colwidth', None)

In [6]:
# Tensor-statistics log (JSON lines) written by rank 0 of the run under analysis.
# Run setup noted by the author: 32 ranks, TP = 1.
log_file_path = Path(
    "/raid/s3/opengptx/max_lue/repositories/modalities/data/checkpoints/2025-07-14__14-58-03_e39ded5c/logs/tensor_stats_rank_0.jsonl"
)

In [7]:
# Load the JSON-lines log: every non-blank line is parsed as one JSON record and
# the records are collected in file order in `tensor_stats_list`.
tensor_stats_list = []
# JSON text is UTF-8; state the encoding instead of relying on the platform default.
with open(log_file_path, 'r', encoding='utf-8') as file:
    for line_number, line in enumerate(tqdm.tqdm(file), start=1):
        if not line.strip():
            # Blank lines (e.g. a trailing newline) carry no record; skip them
            # instead of reporting them as decode errors.
            continue
        try:
            tensor_stats_list.append(json.loads(line))
        except json.JSONDecodeError as e:
            # Best effort: report the malformed line and keep loading the rest.
            # The position inside `e` is relative to this single line, so the
            # line number within the file is reported as well.
            print(f"Error decoding JSON on line {line_number}: {e}")
        # The former KeyError handler was unreachable (nothing guarded here looks
        # up a key) and the blanket `except Exception` would only have hidden
        # genuine faults, so any other error is allowed to surface.

0it [00:00, ?it/s]

9380it [00:00, 133462.82it/s]


In [8]:
# Turn the parsed log records into a table, one row per record.
data_frame = pd.DataFrame(tensor_stats_list)
# Keep only the columns of interest, in this display order
# (a KeyError is raised if any of them is missing from the records).
data_frame = data_frame.loc[
    :,
    [
        "counter",
        "hook_type",
        "tensor_tag",
        "global_shape",
        "local_shape",
        "mean",
        "std",
        "min",
        "max",
        "nan_count",
        "inf_count",
        "rank",
    ],
]
# Preview the first rows.
data_frame.head()

Unnamed: 0,counter,hook_type,tensor_tag,global_shape,local_shape,mean,std,min,max,nan_count,inf_count,rank
0,0,forward_input,transformer.wte,"[1, 4096]","[1, 4096]",-1.0,-1.0,0.0,50256.0,0,0,0
1,0,forward_weights,transformer.wte.weight,"[50304, 512]","[12576, 512]",1.8e-05,0.02002,-0.105469,0.105957,0,0,0
2,0,forward_output,transformer.wte,"[1, 1024, 512]","[1, 1024, 512]",0.000158,0.019897,-0.089355,0.082031,0,0,0
3,0,forward_input,transformer.drop,"[1, 1024, 512]","[1, 1024, 512]",0.000158,0.019897,-0.089355,0.082031,0,0,0
4,0,forward_output,transformer.drop,"[1, 1024, 512]","[1, 1024, 512]",0.000158,0.019897,-0.089355,0.082031,0,0,0


In [9]:

# Rows with counter < 2 and hook_type "pre_forward", kept in their original row order.
# NOTE(review): the recorded output of this cell is empty, and the frame preview only
# shows hook types such as forward_input / forward_weights / forward_output --
# confirm that "pre_forward" is still a hook_type present in this log.
data_frame.loc[data_frame["counter"].lt(2) & data_frame["hook_type"].eq("pre_forward")]

Unnamed: 0,counter,hook_type,tensor_tag,global_shape,local_shape,mean,std,min,max,nan_count,inf_count,rank


In [10]:
# The 20 rows with the smallest `min` among counter < 2 / hook_type "pre_forward"
# (sorted ascending by the `min` column).
# NOTE(review): this cell was labelled "sort by max" although the code sorts by
# "min" -- confirm which column is intended. The recorded output is empty either way.
(
    data_frame.loc[data_frame["counter"].lt(2) & data_frame["hook_type"].eq("pre_forward")]
    .sort_values("min", ascending=True)
    .head(20)
)

Unnamed: 0,counter,hook_type,tensor_tag,global_shape,local_shape,mean,std,min,max,nan_count,inf_count,rank


In [11]:
# All rows recorded for counter 1 with hook_type "pre_forward", in original row order.
# NOTE(review): the recorded output of this cell is empty -- confirm that
# "pre_forward" is a hook_type present in this log.
data_frame.loc[data_frame["counter"].eq(1) & data_frame["hook_type"].eq("pre_forward")]

Unnamed: 0,counter,hook_type,tensor_tag,global_shape,local_shape,mean,std,min,max,nan_count,inf_count,rank


In [12]:
# Forward outputs recorded for counter 0: the 20 rows with the largest `max`,
# sorted in descending order.
(
    data_frame.loc[data_frame["counter"].eq(0) & data_frame["hook_type"].eq("forward_output")]
    .sort_values("max", ascending=False)
    .head(20)
)

Unnamed: 0,counter,hook_type,tensor_tag,global_shape,local_shape,mean,std,min,max,nan_count,inf_count,rank
23,0,forward_output,transformer.h.0.attn.qkv_transforms.0,"[1, 2, 4096, 64]","[1, 2, 4096, 64]",-0.001792908,0.443359,-1.859375,10.4375,0,0,0
13,0,forward_output,transformer.h.0.attn.q_attn,"[1, 4096, 128]","[1, 4096, 128]",-0.001792908,0.443359,-1.859375,10.4375,0,0,0
215,0,forward_output,transformer.h.4.ffn_norm,"[1, 4096, 512]","[1, 1024, 512]",-1.02073e-06,0.996094,-4.90625,4.84375,0,0,0
234,0,forward_output,transformer.h.5.attention_norm,"[1, 4096, 512]","[1, 1024, 512]",-4.004687e-07,0.996094,-4.71875,4.65625,0,0,0
279,0,forward_output,transformer.h.6.attention_norm,"[1, 4096, 512]","[1, 1024, 512]",-6.854534e-07,0.996094,-5.0,4.65625,0,0,0
260,0,forward_output,transformer.h.5.ffn_norm,"[1, 4096, 512]","[1, 1024, 512]",-3.8445e-06,0.996094,-4.75,4.59375,0,0,0
530,0,forward_output,transformer.h.11.ffn_norm,"[1, 4096, 512]","[1, 1024, 512]",4.30271e-07,1.0,-4.03125,4.46875,0,0,0
170,0,forward_output,transformer.h.3.ffn_norm,"[1, 4096, 512]","[1, 1024, 512]",-1.15484e-06,0.992188,-4.125,4.40625,0,0,0
54,0,forward_output,transformer.h.1.attention_norm,"[1, 4096, 512]","[1, 1024, 512]",-8.903444e-07,0.988281,-4.40625,4.375,0,0,0
189,0,forward_output,transformer.h.4.attention_norm,"[1, 4096, 512]","[1, 1024, 512]",1.385808e-06,0.992188,-4.125,4.375,0,0,0


In [19]:
# Backward outputs recorded for counter 10: the 20 rows with the largest `max`,
# sorted in descending order.
# NOTE(review): the recorded output of this cell is empty -- confirm that the log
# contains backward_output rows for counter 10.
(
    data_frame.loc[data_frame["counter"].eq(10) & data_frame["hook_type"].eq("backward_output")]
    .sort_values("max", ascending=False)
    .head(20)
)

Unnamed: 0,counter,hook_type,tensor_tag,global_shape,local_shape,mean,std,min,max,nan_count,inf_count,rank


In [14]:
# Backward outputs recorded for counter 0, sorted by `max` in ascending order
# (smallest maxima first). No row limit is applied, so every match is displayed.
(
    data_frame.loc[data_frame["counter"].eq(0) & data_frame["hook_type"].eq("backward_output")]
    .sort_values("max", ascending=True)
)

Unnamed: 0,counter,hook_type,tensor_tag,global_shape,local_shape,mean,std,min,max,nan_count,inf_count,rank
1362,0,backward_output,transformer.lm_head,"[1, 4096, 50304]","[1, 4096, 50304]",4.067857e-13,1.087785e-06,-0.0002441406,3.632158e-08,0,0,0
1394,0,backward_output,transformer.h.28.attn.q_attn,"[1, 4096, 128]","[1, 4096, 128]",-1.077751e-10,2.04891e-08,-2.04891e-07,2.058223e-07,0,0,0
1389,0,backward_output,transformer.h.28.attn.qkv_transforms.0,"[1, 2, 4096, 64]","[1, 2, 4096, 64]",-1.077751e-10,2.04891e-08,-2.04891e-07,2.058223e-07,0,0,0
1457,0,backward_output,transformer.h.24.attn.qkv_transforms.0,"[1, 2, 4096, 64]","[1, 2, 4096, 64]",-6.082246e-12,2.28174e-08,-2.942979e-07,2.160668e-07,0,0,0
1462,0,backward_output,transformer.h.24.attn.q_attn,"[1, 4096, 128]","[1, 4096, 128]",-6.082246e-12,2.28174e-08,-2.942979e-07,2.160668e-07,0,0,0
1423,0,backward_output,transformer.h.26.attn.qkv_transforms.0,"[1, 2, 4096, 64]","[1, 2, 4096, 64]",-4.66116e-11,2.386514e-08,-2.151355e-07,2.291054e-07,0,0,0
1428,0,backward_output,transformer.h.26.attn.q_attn,"[1, 4096, 128]","[1, 4096, 128]",-4.66116e-11,2.386514e-08,-2.151355e-07,2.291054e-07,0,0,0
1542,0,backward_output,transformer.h.19.attn.qkv_transforms.0,"[1, 2, 4096, 64]","[1, 2, 4096, 64]",8.071765e-12,2.211891e-08,-2.607703e-07,2.402812e-07,0,0,0
1547,0,backward_output,transformer.h.19.attn.q_attn,"[1, 4096, 128]","[1, 4096, 128]",8.071765e-12,2.211891e-08,-2.607703e-07,2.402812e-07,0,0,0
1440,0,backward_output,transformer.h.25.attn.qkv_transforms.0,"[1, 2, 4096, 64]","[1, 2, 4096, 64]",5.31486e-12,1.944136e-08,-2.384186e-07,2.458692e-07,0,0,0
