# In [27]:
import json
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Dict, List


@dataclass
class Parameters:
    """Parameter-size figures for a model.

    ``json_2_model`` builds this from the ``model.parameters`` JSON object by
    keyword expansion, so the field names here must match the JSON keys
    exactly.
    """

    # NOTE(review): the ``_bytes`` suffix suggests sizes in bytes; nothing in
    # this file enforces or converts the unit -- confirm with the producer.
    total_parameters_bytes: int
    parameters_per_layer_bytes: List[int]  # presumably one entry per layer -- confirm
    activation_parameters_bytes: List[int]  # presumably one entry per layer -- confirm


@dataclass
class Model:
    """Description of a model: its name, layer count and parameter sizes.

    ``json_2_model`` fills this from the ``model`` object of the input JSON.
    """

    model_name: str
    num_layers: int
    parameters: Parameters  # nested record built from ``model.parameters``


@dataclass
class ExecutionTime:
    """Timing figures for a run.

    ``json_2_model`` builds this from the ``execution_time`` JSON object by
    keyword expansion, so the field names here must match the JSON keys
    exactly.
    """

    # NOTE(review): the ``_ms`` suffix suggests milliseconds; the unit is not
    # checked anywhere in this file -- confirm with the producer.
    total_time_ms: float
    forward_backward_time_ms: float
    batch_generator_time_ms: float
    layernorm_grads_all_reduce_time_ms: float
    embedding_grads_all_reduce_time_ms: float
    optimizer_time_ms: float
    layer_compute_total_ms: List[float]  # presumably one entry per layer -- confirm


@dataclass
class ExecutionMemory:
    """Memory figures for a run.

    ``json_2_model`` builds this from the ``execution_memory`` JSON object by
    keyword expansion, so the field names here must match the JSON keys
    exactly.
    """

    # NOTE(review): the ``_mb`` suffix suggests megabytes; the unit is not
    # checked anywhere in this file -- confirm with the producer.
    total_memory_mb: float
    # ``manipulate_write_new_file`` rescales every entry of this list by the
    # tp number parsed from the input filename.
    layer_memory_total_mb: List[float]


@dataclass
class ModelMetrics:
    """Top-level record grouping a model description with its run metrics.

    This is the object returned by ``json_2_model``; ``dataclasses.asdict``
    turns it back into nested dicts for JSON serialization.
    """

    model: Model
    execution_time: ExecutionTime
    execution_memory: ExecutionMemory


def json_2_model(json_data):
    """Convert a decoded metrics JSON document into a ``ModelMetrics`` tree.

    Args:
        json_data: Mapping with the top-level keys ``"model"``,
            ``"execution_time"`` and ``"execution_memory"``. ``"model"`` must
            hold ``"model_name"``, ``"num_layers"`` and ``"parameters"``.

    Returns:
        A ``ModelMetrics`` instance populated from ``json_data``.

    Raises:
        KeyError: If one of the required keys is missing.
        TypeError: If a keyword-expanded section (``parameters``,
            ``execution_time``, ``execution_memory``) has keys that do not
            match the corresponding dataclass fields.
    """
    model_section = json_data["model"]
    return ModelMetrics(
        model=Model(
            # Built first so a malformed "parameters" section is reported
            # before any other missing key.
            parameters=Parameters(**model_section["parameters"]),
            model_name=model_section["model_name"],
            num_layers=model_section["num_layers"],
        ),
        execution_time=ExecutionTime(**json_data["execution_time"]),
        execution_memory=ExecutionMemory(**json_data["execution_memory"]),
    )


def read_json_file(file_name):
    """Load a JSON file and return its decoded contents.

    Args:
        file_name: Path of the file to read (anything ``open`` accepts).

    Returns:
        The decoded JSON value (a ``dict`` for an object document).

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
        json.JSONDecodeError: If the contents are not valid JSON.
    """
    # JSON text is UTF-8 by specification; pinning the encoding keeps the
    # result independent of the platform's locale default.
    with open(file_name, "r", encoding="utf-8") as f:
        return json.load(f)

import re

def extract_tp_bs(filename):
    """Parse the ``tp`` and ``bs`` numbers embedded in a filename.

    The name must contain ``_tp<digits>_bs<digits>.json``; the pattern is
    searched for anywhere in the string, not only at its end.
    (Presumably tp is a tensor-parallel degree and bs a batch size -- confirm
    against whatever produces these files.)

    Args:
        filename: File name to inspect, as a ``str``.

    Returns:
        A ``(tp, bs)`` tuple of ints.

    Raises:
        ValueError: If the pattern does not occur in ``filename``.
    """
    found = re.search(r"_tp(\d+)_bs(\d+)\.json", filename)
    if found is None:
        raise ValueError("Filename does not match the expected pattern")

    tp_text, bs_text = found.groups()
    return int(tp_text), int(bs_text)

# In [None]:
def manipulate_write_new_file(json_file_path, new_file_path):
    """Rescale per-layer memory in a metrics file and write the result.

    The tp number is parsed from the input file's name
    (``..._tp<N>_bs<M>.json``). Every entry of
    ``execution_memory.layer_memory_total_mb`` is multiplied by it, and the
    whole metrics record is written to ``new_file_path`` as indented JSON.
    All other fields, including ``total_memory_mb``, are written unchanged.

    NOTE(review): presumably this turns per-rank figures into whole-layer
    figures across tp ranks -- confirm the intent.

    Args:
        json_file_path: Path of the input metrics JSON file.
        new_file_path: Destination path; overwritten if it exists.

    Raises:
        ValueError: If the input file name lacks the ``_tp.._bs...json``
            pattern (raised before the input file is read).
    """
    source_path = Path(json_file_path)

    # Only the tp value is used; the bs value is parsed but not needed here.
    tp_factor, _ = extract_tp_bs(source_path.name)

    metrics = json_2_model(read_json_file(source_path))

    memory = metrics.execution_memory
    memory.layer_memory_total_mb = [
        per_layer * tp_factor for per_layer in memory.layer_memory_total_mb
    ]

    # Serialize before opening the destination so a serialization failure
    # cannot leave a truncated output file behind.
    serialized = json.dumps(asdict(metrics), indent=2)
    with open(new_file_path, "w") as out:
        out.write(serialized)
