In [1]:
import math

import numpy as np

from GenZ import analyse_model, get_model_df, System
from GenZ.Models.get_language_model import einsum_test

In [2]:
# Automatically reload all imported modules (GenZ included) before each cell runs,
# so edits to library code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
class Einsum():
    """Analytical size / op-count model for an einsum contraction.

    Labels are single characters looked up in ``dims``; the ``...`` ellipsis is
    not supported. Whitespace in the equation is ignored. If the equation has
    no ``->``, the output follows NumPy's implicit mode: every label that
    appears exactly once across the inputs, in alphabetical order.
    """

    def __init__(self, equation, dims):
        """
        equation: Einstein summation notation string, e.g. 'bhqd,hql->bhdl'
        dims: Dictionary of tensor dimensions keyed by the corresponding label in the equation
        """
        self.equation = equation
        self.dims = dims

    def _parse(self):
        """Return (list of per-operand label strings, output label string)."""
        equation = ''.join(self.equation.split())  # drop all whitespace
        if '->' in equation:
            inputs, output = equation.split('->', 1)
        else:
            inputs = equation
            labels = inputs.replace(',', '')
            # NumPy implicit mode: labels used exactly once, sorted alphabetically.
            output = ''.join(sorted(label for label in set(labels) if labels.count(label) == 1))
        return inputs.split(','), output

    def _shape(self, labels):
        """Map a label string to its list of dimension sizes (KeyError on unknown label)."""
        return [self.dims[label] for label in labels]

    def get_size(self, tensor):
        """Return the number of elements of a tensor given its list of dimension sizes.

        math.prod multiplies Python ints exactly (np.prod uses a fixed-width
        integer and can silently overflow); an empty shape (scalar) gives 1.
        """
        return math.prod(tensor)

    def get_tensors(self):
        """Return the shapes (input_a, input_b, output) as lists of dimension sizes.

        Raises:
            ValueError: if the equation does not have exactly two input operands.
            KeyError: if a label is missing from ``dims``.
        """
        operands, output = self._parse()
        if len(operands) != 2:
            raise ValueError(
                f"expected exactly two input operands, got {len(operands)} in {self.equation!r}"
            )
        input_a, input_b = (self._shape(labels) for labels in operands)
        return input_a, input_b, self._shape(output)

    def get_num_ops(self, ops_per_mac=1):
        """
        Compute the number of operations needed for the given einsum configuration.

        Every distinct input label spans one axis of the iteration space and each
        point of that space is one multiply-accumulate, so the count is the product
        of the sizes of all distinct input labels, scaled by ``ops_per_mac``.
        The default of 1 counts MACs; pass 2 to count the multiply and the add as
        separate FLOPs.
        """
        operands, _ = self._parse()
        labels = set(''.join(operands))
        return ops_per_mac * math.prod(self.dims[label] for label in labels)

In [4]:
# Batched contraction over q: (b,h,q,d) x (h,q,l) -> (b,h,d,l)
op = Einsum(
    'bhqd,hql->bhdl',
    dict(b=32, h=64, q=128, d=256, l=512),
)

In [5]:
# Product of the sizes of every distinct input label (b*h*q*d*l here): one count per point of the iteration space.
op.get_num_ops()

34359738368

In [6]:
# Dimension-size lists for (input A, input B, output), in equation-label order.
op.get_tensors()

([32, 64, 128, 256], [64, 128, 512], [32, 64, 256, 512])

In [7]:
# Number of elements (not bytes) in input A, input B and the output.
[op.get_size(shape) for shape in op.get_tensors()]

[67108864, 4194304, 268435456]

In [8]:
# 'b' is left as a symbolic placeholder; it appears to be bound by the batch_size
# passed to get_model_df (the resulting Dimension column shows 1) — confirm in GenZ.
einsum_vars = {'b': 'b', 'l': 512, 'd': 2048, 'h': 12, 'q': 128}
model = einsum_test(equation='bld,dhq->blhq', einsum_vars=einsum_vars)

In [9]:
# einsum_test returns the name of a generated CSV file; the name appears to embed a timestamp, so this output changes on every run.
model

'einsum_10_25_2024_17_05_17.csv'

In [10]:
# Per-op latency / throughput breakdown (compute- vs memory-bound) on a default System() at batch size 1.
default_system = System()
get_model_df(model, system=default_system, batch_size=1)

Unnamed: 0,Op Type,Dimension,Bound,C/M ratio,Op Intensity,Latency (msec),Cycles,C Effcy,Num ops (MFLOP),Input_a (MB),Input_w (MB),Output (MB),Total Data (MB),Throughput (Tflops),Compute time (msec),Memory time (msec),Communication time (msec),Compute cycle,Memory cycle,Communication cycle
0,Einsum,"([1, 512, 2048], [2048, 12, 128], [1, 512, 12,...",Compute,2.540592,323.368421,0.026189,24617.495477,1,3221.225472,2.0,6.0,1.5,9.5,123.0,0.026189,0.010308,0.0,24617.495477,9689.670139,0.0
