In [None]:
import json
import logging
import os
import time

import torch
import torch.profiler
from torch.nn import functional as F
from torch.nn.parameter import Parameter
from torch.profiler import profile, record_function, ProfilerActivity

from bitsandbytes.triton.int8_matmul_mixed_dequantize import (
    int8_matmul_mixed_dequantize,
)
from bitsandbytes.triton.quantize_global import quantize_global
from bitsandbytes.triton.quantize_rowwise import quantize_rowwise

from torch.profiler import profile, record_function, ProfilerActivity

# Set up logging: benchmark messages go to output.log only.
# A dedicated logger is used instead of logging.basicConfig because basicConfig
# is a no-op when the root logger already has handlers (which some notebook
# kernels install), and because configuring the root logger at INFO would also
# send every third-party library's INFO records into the file.
logger = logging.getLogger('switchback_benchmark')
logger.setLevel(logging.INFO)
# Keep records out of any root/kernel handlers so the (large) profiler tables
# are written to the file only.
logger.propagate = False
# Only add the handler once so re-running this cell does not duplicate lines.
if not logger.handlers:
    _log_handler = logging.FileHandler('output.log')
    _log_handler.setFormatter(logging.Formatter('%(message)s'))
    logger.addHandler(_log_handler)

def get_time(k, fn, X, W, repeat):
    """Time ``fn(X, W)`` on the GPU, log the mean latency and return it.

    Args:
        k: label used in the log line.
        fn: callable invoked as ``fn(X, W)``.
        X, W: arguments forwarded to ``fn``.
        repeat: number of timed calls (must be >= 1).

    Returns:
        Mean wall-clock time per call, in milliseconds.

    No warm-up is done, so any one-time cost on the first call is included.
    """
    # Drain previously queued GPU work so it is not attributed to fn.
    torch.cuda.synchronize()
    # perf_counter is monotonic and high-resolution; time.time() is neither
    # guaranteed, which matters for sub-millisecond measurements.
    start = time.perf_counter()
    for _ in range(repeat):
        fn(X, W)
    # CUDA launches are asynchronous: wait for completion before reading the clock.
    torch.cuda.synchronize()
    ms = (time.perf_counter() - start) / repeat * 1000
    logger.info(f"time {k}: {ms:.3f} ms")
    return ms

def get_time_swichback(k, fn, X, W, b, repeat):
    """Time ``fn(X, W, b)`` on the GPU after a warm-up pass; log and return the mean.

    Args:
        k: label used in the log line.
        fn: callable invoked as ``fn(X, W, b)``.
        X, W, b: arguments forwarded to ``fn``.
        repeat: number of calls in both the warm-up and the timed loop (must be >= 1).

    Returns:
        Mean wall-clock time per call of the timed loop, in milliseconds.
    """
    torch.cuda.synchronize()
    # Warm-up: absorbs one-time costs (e.g. kernel compilation/autotuning,
    # allocator growth) so they are excluded from the measurement below.
    for _ in range(repeat):
        fn(X, W, b)
    torch.cuda.synchronize()

    start = time.perf_counter()
    for _ in range(repeat):
        fn(X, W, b)
    # CUDA launches are asynchronous: wait for completion before reading the clock.
    torch.cuda.synchronize()
    ms = (time.perf_counter() - start) / repeat * 1000
    logger.info(f"time {k} after fit-params: {ms:.3f} ms")
    return ms

def standatrt(x, w, b):
    """Baseline forward: dense ``x @ w.T + b`` through ``F.linear`` (no quantization).

    ``b`` may be None, in which case no bias is added.
    """
    out = F.linear(x, w, bias=b)
    return out

def swichback(X_3D, W, b):
    """SwitchBack int8 forward: rowwise-quantized activations x globally-quantized weight.

    Args:
        X_3D: activations whose last dim is in_features; all leading dims are
            flattened into rows for the matmul.
        W: weight, shape (out_features, in_features) as used by ``F.linear``.
        b: bias of length out_features, or None.

    Returns:
        Tensor shaped like ``X_3D`` except the last dim, which becomes out_features.
    """
    X = X_3D.view(-1, X_3D.size(-1))
    X_int8, state_X = quantize_rowwise(X)
    # int8_matmul_mixed_dequantize is the "mixed" kernel: it expects per-row
    # scales for X but a single tensor-wide scale for W. Feeding it rowwise W
    # scales would make it use only the first row's scale, so quantize W globally.
    W_int8, state_W = quantize_global(W)
    # Pass the bias through (the kernel fuses the add) so this path does the same
    # work as the F.linear baseline it is compared against.
    out = int8_matmul_mixed_dequantize(X_int8, W_int8.t(), state_X, state_W, b)
    return out.view(*X_3D.size()[:-1], -1)

def mul_matrix(x_in, x_out, w_in, w_out, repeat, profiler_results_path='profiler_results',
               with_bias=True, profile_repeat=None, trace_dir='profiler_results'):
    """Benchmark ``F.linear`` against the SwitchBack int8 forward for one GEMM shape.

    Random fp16 CUDA tensors ``x`` of shape ``(x_in, x_out)`` and ``w`` of shape
    ``(w_in, w_out)`` are multiplied as ``x @ w.T (+ b)``, so ``x_out`` must equal
    ``w_out`` (in_features) and ``w_in`` is out_features.

    Both paths are timed first, with no profiler attached, so profiler
    instrumentation (shape recording, memory and FLOP tracking) does not inflate
    the returned latencies. A separate profiled run then produces a Chrome trace
    and an operator table in the log.

    Args:
        x_in, x_out: shape of the activation matrix.
        w_in, w_out: shape of the weight matrix.
        repeat: calls per loop in ``get_time_swichback``.
        profiler_results_path: filename prefix of the exported trace.
        with_bias: if True, add a zero bias of length ``w_in`` on both paths;
            pass False to benchmark bias-free layers.
        profile_repeat: calls per path inside the profiler (default: ``repeat``).
        trace_dir: directory for the trace file; created if missing.

    Returns:
        ``(standard_ms, switchback_ms)``: mean latency per call of each path.
    """
    x = torch.randn(x_in, x_out, dtype=torch.float16, device='cuda')
    w = torch.randn(w_in, w_out, dtype=torch.float16, device='cuda')
    # torch.empty would leave arbitrary uninitialised values in the bias.
    b = torch.zeros(w_in, dtype=torch.float16, device='cuda') if with_bias else None

    logger.info(f"Running fwd with x size: {x.size()}, w size: {w.size()}")

    # Measure latency without the profiler attached.
    standart_time = get_time_swichback("standard_fwd", standatrt, x, w, b, repeat)
    swichback_time = get_time_swichback("swichback", swichback, x, w, b, repeat)

    n_profile = repeat if profile_repeat is None else profile_repeat
    # Tensors live on CUDA, so CPU + CUDA activities are sufficient;
    # ProfilerActivity.XPU is irrelevant here and may be missing in older builds.
    with profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        record_shapes=True,
        with_flops=True,
        profile_memory=True,
    ) as prof:
        with record_function("standard_fwd"):
            for _ in range(n_profile):
                standatrt(x, w, b)
            torch.cuda.synchronize()
        with record_function("swichback"):
            for _ in range(n_profile):
                swichback(x, w, b)
            torch.cuda.synchronize()

    # export_chrome_trace fails if the target directory does not exist.
    os.makedirs(trace_dir, exist_ok=True)
    trace_name = f"{profiler_results_path}_({x.shape[0]},{x.shape[1]})_({w.shape[0]},{w.shape[1]}).json"
    prof.export_chrome_trace(os.path.join(trace_dir, trace_name))

    logger.info(prof.key_averages().table(sort_by="cuda_time_total"))
    logger.info('\n')
    return standart_time, swichback_time


In [2]:
def calculate_times(layers, sequence_length, batch_sizes, repeat=64):
    """Benchmark every layer shape at every batch size.

    For a batch size ``batch`` the activation matrix has ``batch * sequence_length``
    rows; each layer is benchmarked via ``mul_matrix`` with its
    ``[in_features, out_features]`` entry from ``layers``.

    Returns:
        dict with ``batch_sizes`` (the argument itself) plus, for each of
        ``standard``/``switchback``/``optimal``:
          * ``<name>_times``: list (one per batch size) of ``{layer: ms}`` dicts;
          * ``total_<name>_times``: list of per-batch sums over layers.
        ``optimal`` is ``min(standard, switchback)`` per layer.
    """
    strategies = ("standard", "switchback", "optimal")
    per_batch = {name: [] for name in strategies}
    totals = {name: [] for name in strategies}

    for batch in batch_sizes:
        n_rows = batch * sequence_length
        row = {name: {} for name in strategies}

        for layer_name, (in_features, out_features) in layers.items():
            std_ms, sb_ms = mul_matrix(
                n_rows, in_features, out_features, in_features, repeat=repeat
            )
            row["standard"][layer_name] = std_ms
            row["switchback"][layer_name] = sb_ms
            row["optimal"][layer_name] = min(std_ms, sb_ms)

        for name in strategies:
            per_batch[name].append(row[name])
            # Summed in layer insertion order, starting from 0.
            totals[name].append(sum(row[name].values()))

    return {
        "batch_sizes": batch_sizes,
        "standard_times": per_batch["standard"],
        "switchback_times": per_batch["switchback"],
        "optimal_times": per_batch["optimal"],
        "total_standard_times": totals["standard"],
        "total_switchback_times": totals["switchback"],
        "total_optimal_times": totals["optimal"]
    }


Shapes of the `nn.Linear` layers benchmarked below, taken from one `LlamaDecoderLayer`:

```
LlamaDecoderLayer(
  (self_attn): LlamaSdpaAttention(
    (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
    (k_proj): Linear(in_features=2048, out_features=256, bias=False)
    (v_proj): Linear(in_features=2048, out_features=256, bias=False)
    (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
    (rotary_emb): LlamaRotaryEmbedding()
  )
  (mlp): LlamaMLP(
    (gate_proj): Linear(in_features=2048, out_features=5632, bias=False)
    (up_proj): Linear(in_features=2048, out_features=5632, bias=False)
    (down_proj): Linear(in_features=5632, out_features=2048, bias=False)
    (act_fn): SiLU()
  )
  (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
  (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
)
```

In [None]:
# [in_features, out_features] of every nn.Linear in one decoder layer (repr above).
layers = {
    'q_proj': [2048, 2048],
    'k_proj': [2048, 256],
    'v_proj': [2048, 256],
    'o_proj': [2048, 2048],
    'gate_proj': [2048, 5632],
    'up_proj': [2048, 5632],
    'down_proj': [5632, 2048]
}

sequence_length = 128
batch_sizes = [2 ** p for p in range(2, 10)]  # 4, 8, ..., 512

results = calculate_times(layers, sequence_length=sequence_length, batch_sizes=batch_sizes)
print(results)

In [None]:
results  # raw dict returned by calculate_times: per-layer and total timings (ms) per batch size

In [None]:
import matplotlib.pyplot as plt
def prepare_layer_data(times):
    """Transpose per-batch timing dicts into per-layer timing series.

    Args:
        times: list with one ``{layer_name: ms}`` dict per batch size, e.g.
            ``results['standard_times']``. Every dict must use the layer names
            of the first one.

    Returns:
        ``{layer_name: [value for batch 0, value for batch 1, ...]}`` with layer
        order taken from the first dict; an empty ``times`` yields ``{}``.
    """
    if not times:
        return {}
    layer_data = {layer: [] for layer in times[0]}
    for layer_times in times:
        # Named elapsed_ms rather than `time` to avoid shadowing the time module.
        for layer, elapsed_ms in layer_times.items():
            layer_data[layer].append(elapsed_ms)
    return layer_data

# One series per layer (values ordered like batch_sizes) for each strategy.
standard_layer_data, switchback_layer_data, optimal_layer_data = [
    prepare_layer_data(results[key])
    for key in ('standard_times', 'switchback_times', 'optimal_times')
]

# Tick position -> batch size, used for x-axis labels in the plots.
x_legend_dict = dict(enumerate(batch_sizes))

def plot_layer_data(batch_sizes, standard_layer_data, switchback_layer_data, optimal_layer_data,
                    layer_shapes=None):
    """Draw one figure per layer comparing standard, SwitchBack and optimal timings.

    Args:
        batch_sizes: x-axis values, one per entry of each series.
        standard_layer_data, switchback_layer_data, optimal_layer_data:
            ``{layer_name: [ms per batch size]}`` as built by ``prepare_layer_data``.
        layer_shapes: ``{layer_name: [in_features, out_features]}`` used in the
            titles; defaults to the notebook-level ``layers`` dict.
    """
    if layer_shapes is None:
        layer_shapes = layers
    tick_labels = [str(bs) for bs in batch_sizes]

    for key in standard_layer_data:
        in_features, out_features = layer_shapes[key]
        fig, ax = plt.subplots(figsize=(10, 6))
        ax.plot(batch_sizes, optimal_layer_data[key], label=f'optimal_{key}', marker='*', color='green', linewidth=10)
        ax.plot(batch_sizes, standard_layer_data[key], label=f'standard_{key}', marker='o', color='blue')
        ax.plot(batch_sizes, switchback_layer_data[key], label=f'switchback_{key}', marker='x', color='red')

        # Title previously showed in_features twice; show (in;out) instead.
        ax.set_title(f'Performance of {key}: ({in_features};{out_features}) layer')
        ax.set_xlabel('Batch Size')
        ax.set_ylabel('Time (ms)')
        ax.set_xscale('log')
        # Label ticks from the batch_sizes argument, not a global lookup table.
        ax.set_xticks(batch_sizes)
        ax.set_xticklabels(tick_labels)
        ax.legend(title='Components')
        ax.grid(True)
        fig.tight_layout()
        plt.show()

plot_layer_data(batch_sizes, standard_layer_data, switchback_layer_data, optimal_layer_data)

In [None]:
batch_sizes = [2 ** p for p in range(2, 10)]  # 4, 8, ..., 512

# Tick position -> batch size, used for x-axis labels in the plots.
x_legend_dict = dict(enumerate(batch_sizes))

# Per-layer series (one value per batch size) for each strategy.
standard_layer_data = prepare_layer_data(results['standard_times'])
switchback_layer_data = prepare_layer_data(results['switchback_times'])
optimal_layer_data = prepare_layer_data(results['optimal_times'])
def plot_layer_data(batch_sizes, standard_layer_data, switchback_layer_data, optimal_layer_data,
                    layer_shapes=None, ncols=4):
    """Draw all per-layer comparisons as a grid of subplots in a single figure.

    This definition replaces the one-figure-per-layer version from an earlier cell.

    Args:
        batch_sizes: x-axis values, one per entry of each series.
        standard_layer_data, switchback_layer_data, optimal_layer_data:
            ``{layer_name: [ms per batch size]}`` as built by ``prepare_layer_data``.
        layer_shapes: ``{layer_name: [in_features, out_features]}`` used in the
            titles; defaults to the notebook-level ``layers`` dict.
        ncols: number of subplot columns.
    """
    if layer_shapes is None:
        layer_shapes = layers
    num_layers = len(standard_layer_data)
    if num_layers == 0:
        return

    # Ceiling division: the old (num_layers + 1) // ncols under-allocated rows
    # (e.g. 5 layers -> 1 row of 4), silently dropping layers.
    nrows = (num_layers + ncols - 1) // ncols

    # squeeze=False always yields a 2-D array, so flattening is uniform.
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(30, 6 * nrows),
                             constrained_layout=True, squeeze=False)
    axes = axes.flatten()
    tick_labels = [str(bs) for bs in batch_sizes]

    for ax, key in zip(axes, standard_layer_data):
        in_features, out_features = layer_shapes[key]
        ax.plot(batch_sizes, optimal_layer_data[key], label=f'optimal_{key}', marker='*', color='green', linewidth=5)
        ax.plot(batch_sizes, standard_layer_data[key], label=f'standard_{key}', marker='o', color='blue')
        ax.plot(batch_sizes, switchback_layer_data[key], label=f'switchback_{key}', marker='x', color='red')

        ax.set_title(f'Performance of {key}: ({in_features};{out_features}) layer')
        ax.set_xlabel('Batch Size')
        ax.set_ylabel('Time (ms)')
        ax.set_xscale('log')
        ax.set_xticks(batch_sizes)
        ax.set_xticklabels(tick_labels)
        ax.legend(title='Components')
        ax.grid(True)

    # Hide the unused trailing subplots.
    for ax in axes[num_layers:]:
        fig.delaxes(ax)

    plt.show()

plot_layer_data(batch_sizes, standard_layer_data, switchback_layer_data, optimal_layer_data)

In [None]:
# Total forward time per batch size = sum over all benchmarked layers
# (see total_*_times in calculate_times) -- a sum, not an average.
SHOW_OPTIMAL = False  # set True to overlay the per-layer best-of-both total

standard_total_times = results['total_standard_times']
switchback_total_times = results['total_switchback_times']
optimal_total_times = results['total_optimal_times']

fig, ax = plt.subplots(figsize=(10, 6))
if SHOW_OPTIMAL:
    ax.plot(batch_sizes, optimal_total_times, label='Total_optimal', marker='*', color='green', linewidth=10)
ax.plot(batch_sizes, standard_total_times, label='Total_standard', marker='o', color='blue')
ax.plot(batch_sizes, switchback_total_times, label='Total_switchback', marker='x', color='red')

ax.set_title('Total forward time over all benchmarked layers')
ax.set_xlabel('Batch Size')
ax.set_ylabel('Time (ms)')
ax.set_xscale('log')
ax.set_xticks(batch_sizes)
ax.set_xticklabels([str(bs) for bs in batch_sizes])
ax.legend(title='Components')
ax.grid(True)
fig.tight_layout()
plt.show()


In [None]:
import pandas as pd
def calculate_speedup(standard_times, switchback_times, total_standard_times, total_switchback_times):
    """Percentage of time saved by a candidate strategy relative to the standard one.

    Each value is ``100 * (standard - candidate) / standard``: positive means the
    candidate is faster (50.0 = half the time), negative means slower. This is
    time saved in percent, not a ``standard / candidate`` ratio.

    Args:
        standard_times: list of ``{layer: ms}`` dicts, one per batch size (baseline).
        switchback_times: matching list for the candidate strategy (SwitchBack,
            optimal, ...); must contain the baseline's layer names.
        total_standard_times, total_switchback_times: matching per-batch totals.

    Returns:
        ``(layer_speedup_data, overall_speedup)`` where ``layer_speedup_data`` is
        ``{layer: [pct per batch]}`` (layer order from the first baseline dict)
        and ``overall_speedup`` is ``[pct per batch]``. Empty inputs yield
        ``({}, [])`` instead of raising IndexError.
    """
    layer_speedup_data = {layer: [] for layer in standard_times[0]} if standard_times else {}

    for standard_layer_times, switchback_layer_times in zip(standard_times, switchback_times):
        for layer, series in layer_speedup_data.items():
            standard_time = standard_layer_times[layer]
            switchback_time = switchback_layer_times[layer]
            series.append(-100 * (switchback_time - standard_time) / standard_time)

    overall_speedup = [
        -100 * (total_switchback_time - total_standard_time) / total_standard_time
        for total_standard_time, total_switchback_time in zip(total_standard_times, total_switchback_times)
    ]

    return layer_speedup_data, overall_speedup

In [None]:
# Percent time saved by SwitchBack vs. F.linear, per layer and overall.
layer_speedup_data, overall_speedup = calculate_speedup(
    results['standard_times'],
    results['switchback_times'],
    results['total_standard_times'],
    results['total_switchback_times'],
)

overall_speedup_df = pd.DataFrame({'Overall Speedup (%)': overall_speedup}, index=batch_sizes)
overall_speedup_df.index.name = 'Batch Size'

layer_speedup_df = pd.DataFrame(layer_speedup_data, index=batch_sizes)
layer_speedup_df.index.name = 'Batch Size'
# Both frames share the batch_sizes index, so this aligns row by row.
layer_speedup_df['Overall Speedup (%)'] = overall_speedup_df['Overall Speedup (%)']

print("Layer Speedup Data (Swichback vs Standard) with Overall Speedup:")
rounded_speedup_df = layer_speedup_df.round(2)
rounded_speedup_df.to_csv('Layer Speedup Data (Swichback vs Standard) with Overall Speedup.csv')
rounded_speedup_df

In [None]:
# Percent time saved by picking the faster path per layer vs. always using F.linear.
layer_optimal_speedup_data, overall_optimal_speedup = calculate_speedup(
    results['standard_times'],
    results['optimal_times'],
    results['total_standard_times'],
    results['total_optimal_times'],
)

overall_optimal_speedup_df = pd.DataFrame(
    {'Overall Speedup (%)': overall_optimal_speedup}, index=batch_sizes
)
overall_optimal_speedup_df.index.name = 'Batch Size'

layer_optimal_speedup_df = pd.DataFrame(layer_optimal_speedup_data, index=batch_sizes)
layer_optimal_speedup_df.index.name = 'Batch Size'
# Both frames share the batch_sizes index, so this aligns row by row.
layer_optimal_speedup_df['Overall Speedup (%)'] = overall_optimal_speedup_df['Overall Speedup (%)']

print("Layer Speedup Data (Optimal vs Standard) with Overall Speedup:")
rounded_optimal_df = layer_optimal_speedup_df.round(2)
rounded_optimal_df.to_csv('Layer Speedup Data (Optimal vs Standard) with Overall Speedup.csv')
rounded_optimal_df