In [1]:
import torch
import torch.nn as nn
from math import pi
from itertools import chain
import os
import sys
dir_name = os.getcwd()
parent_dir_name = os.path.dirname(dir_name)
sys.path.insert(0, parent_dir_name)
from modules.model_gpt import GPT, GPTConfig
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import clear_output
from matplotlib import colormaps
import torch.nn.functional as F
from torchmetrics.text import Perplexity

Strip the `_orig_mod.` prefix (added by `torch.compile`) from state-dict keys:

In [2]:
def remove_state_dict_prefix(state_dict, prefix='_orig_mod.'):
    """Return a copy of ``state_dict`` with ``prefix`` removed from the keys that start with it.

    The input mapping is not modified. Each renamed entry is removed and
    re-inserted under its stripped key, so renamed keys move to the end of the
    copy (unless the stripped key already existed, in which case its value is
    replaced in place by the prefixed entry's value).

    Args:
        state_dict: mapping of parameter names to values (e.g. a model state dict).
        prefix: key prefix to strip; defaults to the ``_orig_mod.`` prefix.

    Returns:
        A new mapping (same type as ``state_dict.copy()`` returns) with stripped keys.
    """
    stripped = state_dict.copy()
    prefixed_items = [(name, tensor) for name, tensor in state_dict.items() if name.startswith(prefix)]
    for old_name, tensor in prefixed_items:
        stripped.pop(old_name)
        stripped[old_name[len(prefix):]] = tensor
    return stripped

In [3]:
device_linear = 'cuda:0'
device = 'cuda:1'
lr = 1e-3


def load_gpt_checkpoint(checkpoint_path, override_args):
    """Build a GPT model from a checkpoint holding {'model': state_dict, 'model_args': dict, ...}.

    Args:
        checkpoint_path: file passed to ``torch.load`` (tensors are mapped to CPU).
        override_args: entries merged into the checkpoint's ``model_args`` before the
            ``GPTConfig`` is built (the checkpoint's own dict is updated in place).

    Returns:
        (checkpoint, state_dict, config_args, config, model) where ``state_dict`` is the
        checkpoint's 'model' entry with the ``_orig_mod.`` prefix removed.
    """
    # NOTE(review): torch.load unpickles arbitrary objects — only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    state_dict = remove_state_dict_prefix(checkpoint['model'])
    config_args = checkpoint['model_args']
    config_args.update(override_args)
    print(f"Initializing {config_args['attention']} model from weights: {checkpoint_path}")
    config = GPTConfig(**config_args)
    model = GPT(config)
    # strict=False tolerates keys that differ between checkpoint and model; report
    # them instead of dropping them silently.
    incompatible = model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print(f"  missing keys: {incompatible.missing_keys}")
        print(f"  unexpected keys: {incompatible.unexpected_keys}")
    return checkpoint, state_dict, config_args, config, model


# First model
# init_from = "/Data/pgi-15/common_models/dram_attention_project/gpt2-LinearDRAMAttention.pt"
init_from = "/Users/leroux/sEMG/saved_models/gpt2.pt"
override_args = dict(
    batch_size=32,
    # quantization_levels_input=16,
    # quantization_levels_weights=8,
    # quantization_levels_output=32,
)
model_sd_linear, model_ld_linear, config_args_linear, config, model_linear = load_gpt_checkpoint(init_from, override_args)

# Model to compare
init_from = "/Users/leroux/sEMG/saved_models/gpt2-from-scratch.pt"
override_args = dict(
    batch_size=32,
    # attention="DRAMAttention",
)
model_sd, model_ld, config_args, config, model = load_gpt_checkpoint(init_from, override_args)

# [print(f"Buffer {name}: {buf.size()}") for name, buf in model.named_buffers()]

Initializing CausalSelfAttention model from weights: /Users/leroux/sEMG/saved_models/gpt2.pt
number of parameters: 123.65M
Initializing CausalSelfAttention model from weights: /Users/leroux/sEMG/saved_models/gpt2-from-scratch.pt
number of parameters: 123.59M


Use DataParallel (optional, currently disabled)

In [4]:
# Optional (disabled): wrap both models in torch.nn.DataParallel over GPUs 0-3.
# If enabled, cells that access model.transformer directly must use model.module.transformer.
# model_linear = torch.nn.DataParallel(model_linear, device_ids=[0,1,2,3]).to('cuda:0')
# model = torch.nn.DataParallel(model, device_ids=[0,1,2,3]).to('cuda:0')

OpenWebText data loading

In [5]:
data_dir = os.path.join('/Users/leroux/sEMG/datasets/texts/', "openwebtext")
# Token ids stored as flat uint16 arrays, memory-mapped read-only.
train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
batch_size = config_args['batch_size']
device_type = 'cuda'  # get_batch pins host memory when this is 'cuda'

# Sequence length sampled by get_batch. It is fixed here instead of being read from
# config_args['block_size']; warn when the two disagree, since the sampled batches
# would then not match the model configuration.
block_size = 1024
if config_args['block_size'] != block_size:
    print(f"Warning: sampling block_size={block_size} but config_args['block_size']={config_args['block_size']}")

def get_batch(split, device_id=None):
    """Sample a random batch of input sequences and their next-token targets.

    Args:
        split: 'train' samples from ``train_data``; any other value (e.g. 'val'
            or 'test') samples from ``val_data``.
        device_id: device the batch is moved to. When omitted, the notebook-level
            ``device`` is used (looked up at call time), so ``get_batch('train')`` works.

    Returns:
        (x, y): int64 tensors of shape (batch_size, block_size); ``y`` is ``x``
        shifted by one position in the token stream.
    """
    if device_id is None:
        device_id = device
    data = train_data if split == 'train' else val_data
    # Exclusive upper bound keeps room for the block plus the one-token-shifted target.
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([torch.from_numpy(data[i:i + block_size].astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy(data[i + 1:i + 1 + block_size].astype(np.int64)) for i in ix])
    if device_type == 'cuda':
        # Pin host memory so the host-to-device copy can run asynchronously (non_blocking=True).
        x = x.pin_memory().to(device_id, non_blocking=True)
        y = y.pin_memory().to(device_id, non_blocking=True)
    else:
        x, y = x.to(device_id), y.to(device_id)
    return x, y

Fix issue with model statistics

In [6]:
# Disabled: run model_linear in train mode (under no_grad) for 1000 training batches,
# presumably to refresh statistics the model accumulates in train mode, then save the
# checkpoint without its optimizer state.
# NOTE(review): get_batch('train') below has no device argument, and the code updates
# `model_sd_linear` but deletes the optimizer from / saves `model_sd` — confirm which
# checkpoint was meant to be saved before re-enabling.
# with torch.no_grad():
#     model_linear.train()
#     model_linear.to(device)
#     for i in range(1000):
#         X, Y = get_batch('train')     
#         logits, loss = model_linear(X, Y)
#         print(f'iter: {i}, loss: {loss.item():.2f}')

# model_ld_linear = model_linear.state_dict()
# model_sd_linear.update({'model': model_ld_linear})
# del model_sd['optimizer']
# torch.save(model_sd, "/Users/leroux/sEMG/saved_models/gpt2-xl-LinearDRAMAttention-no-saved_optimizer.pt")
# print('ok')
pass

Init a and b

Calibration

In [7]:
# Iterative calibration of the attention scalers of `model`. Each scaler has a
# multiplicative parameter `a` and an additive parameter `b`. Every iteration runs a
# forward pass on a fixed batch, then moves `a`/`b` so that the scaler's
# std_after_scale / mean_after_scale approach its target_std / target_mean. The step
# size `alpha` decays linearly from alpha_max to alpha_min over alpha_decay_max_step
# iterations.
# NOTE(review): the whole cell is disabled by `if False:`. The "Plot hists" cell reads
# `q` / `q_after_scale`, which are only defined here.
if False:
    # NOTE(review): two independent random batches are drawn, so model_linear and model
    # see different tokens; the loss comparison is therefore not paired.
    X_linear, Y_linear = get_batch('train', device_linear)  
    model_linear = model_linear.to(device_linear)
    X, Y = get_batch('train', device)  
    model = model.to(device)
    # NOTE(review): presumably train mode is what makes the scalers record their
    # *_after_scale statistics — confirm in modules.model_gpt.
    model.train()
    done = False
    error_threshold = 0.001    
    error_threshold_TIA = 0.2  # looser std threshold used for att_score_scaler (param_id 3)
    max_calibration_iter = 100
    alpha_max = 1.0
    alpha_min = 0.01
    alpha_decay_max_step = 50
    # alpha_calibration = torch.cos(torch.arange(alpha_decay_max_step) / alpha_decay_max_step * pi / 2) + alpha_min
    alpha_calibration = torch.linspace(alpha_max, alpha_min, alpha_decay_max_step)
    cmap = colormaps['plasma']
    # Take colors at regular intervals spanning the colormap.
    # One color per (layer, scaler) pair: 5 scalers per layer.
    colors = cmap(np.linspace(0, 1, 5 * config_args['n_layer']))
    # colors = cmap(np.linspace(0, 1, 6 * config_args['n_layer']))

    calibration_iter = 0
    with torch.no_grad():        
        # Start every scaler from the identity transform (a=1, b=0).
        for layer in model.transformer.h:
        # for layer in model.module.transformer.h:
            for scaler_a_b in [layer.attn.q_scaler, layer.attn.k_scaler, layer.attn.v_scaler, layer.attn.att_score_scaler, layer.attn.output_scaler]:
            # for scaler_a_b in [layer.attn.q_scaler, layer.attn.k_scaler, layer.attn.v_scaler, layer.attn.att_score_scaler, layer.attn.NL_scaler, layer.attn.output_scaler]:
                nn.init.constant_(scaler_a_b.a, val=1.0)
                nn.init.constant_(scaler_a_b.b, val=0.0)
        
        # Error history per (layer, scaler): one column appended per iteration.
        std_errors_plots = torch.zeros(5 * config_args['n_layer'], 0)
        mean_errors_plots = torch.zeros(5 * config_args['n_layer'], 0)
        # std_errors_plots = torch.zeros(6 * config_args['n_layer'], 0)
        # mean_errors_plots = torch.zeros(6 * config_args['n_layer'], 0)
        losses = []
        while not(done):
            # X, Y = get_batch('train')
            logits_linear, loss_linear = model_linear(X_linear, targets=Y_linear)
            logits, loss = model(X, targets=Y)
            if calibration_iter==0:
                # Snapshot last-layer Q before any scaler update (plotted in "Plot hists").
                q = model.transformer.h[-1].attn.Q[0,:,0].flatten().cpu().clone()
                q_after_scale = model.transformer.h[-1].attn.Q_after_scale[0,:,0].flatten().cpu().clone()
            losses += [loss.item()]
            # Init a and b w.r.t computed statistics
            std_errors = []
            mean_errors = []            
            if calibration_iter < alpha_decay_max_step:
                alpha = alpha_calibration[calibration_iter].item()
            else:
                alpha = alpha_min                
            done_list = []
            for l, layer in enumerate(model.transformer.h):
                for param_id, scaler_a_b in enumerate([layer.attn.q_scaler, layer.attn.k_scaler, layer.attn.v_scaler, layer.attn.att_score_scaler, layer.attn.output_scaler]):
                # for param_id, scaler_a_b in enumerate([layer.attn.q_scaler, layer.attn.k_scaler, layer.attn.v_scaler, layer.attn.att_score_scaler, layer.attn.NL_scaler, layer.attn.output_scaler]):
                    std_errors += [torch.abs(scaler_a_b.std_after_scale-scaler_a_b.target_std).item()]   
                    if param_id != 3:   
                        threshold = error_threshold
                    else:
                        threshold = error_threshold_TIA
                    if std_errors[-1] > threshold:
                        # Rescale `a` by target/observed std, blended with the old value by alpha.
                        # A zero observed std skips the update but still counts as not converged.
                        if scaler_a_b.std_after_scale != 0.:
                            new_a = scaler_a_b.a * scaler_a_b.target_std / scaler_a_b.std_after_scale
                            scaler_a_b.a.fill_(alpha * new_a + (1-alpha) * scaler_a_b.a)
                        done_list += [False]
                    else:
                        done_list += [True]
                    
                    mean_errors += [torch.abs(scaler_a_b.mean_after_scale-scaler_a_b.target_mean).item()]
                    # The mean (offset `b`) is not calibrated for att_score_scaler (param_id 3).
                    if param_id != 3: 
                        if mean_errors[-1] > error_threshold:
                            new_b = scaler_a_b.b + (scaler_a_b.target_mean - scaler_a_b.mean_after_scale)
                            scaler_a_b.b.fill_(alpha * new_b + (1-alpha) * scaler_a_b.b)
                            done_list += [False]
                        else:
                            done_list += [True]
                    else:
                        done_list += [True]
                    
            std_errors_plots = torch.cat((std_errors_plots, torch.tensor(std_errors).unsqueeze(1)), dim=-1)
            mean_errors_plots = torch.cat((mean_errors_plots, torch.tensor(mean_errors).unsqueeze(1)), dim=-1)
            # NOTE(review): a new figure is created every iteration and never closed.
            fig, ax = plt.subplots(1, 2, figsize=(8, 4))
            clear_output(wait=True)
            for p, param in enumerate(range(len(std_errors_plots))):
                ax[0].plot(std_errors_plots[param], linewidth=0.5, color=colors[p])
                ax[0].plot(mean_errors_plots[param], '--', linewidth=0.5, color=colors[p])
            ax[0].set_yscale('log')
            ax[0].set_xlabel('iters')
            # NOTE(review): `param` is the loop variable leaked from the loop above.
            ax[0].hlines(error_threshold, 0., len(std_errors_plots[param])-1.0, colors='black', linewidth=2)
            ax[1].plot(torch.tensor(losses), linewidth=1)
            ax[1].set_xlabel('iters')
            ax[1].set_ylabel('Cross Entropy Loss')
            fig.tight_layout()
            plt.show()
            sys.stdout.flush()
            plt.pause(0.1)  # Pause to update the plot
            print(f'Calibraton iter {calibration_iter} | Loss: {loss.item():.4f} | error threshold: {error_threshold:.3f}\tnum valid params: {torch.sum(torch.tensor(done_list))}/{len(done_list)}\tstd errors: {torch.sort(torch.tensor(std_errors), descending=True)[0][:3]}\tmean errors: {torch.sort(torch.tensor(mean_errors), descending=True)[0][:3]}')
            calibration_iter += 1
            # NOTE(review): the assert runs before the convergence check below, so a run that
            # converges exactly on iteration max_calibration_iter still raises.
            assert calibration_iter < max_calibration_iter, f'Calibration algorithm did not converge after {calibration_iter} steps.'
            if torch.all(torch.tensor(done_list)):
                print(f'Calibration finished after {calibration_iter} steps.')
                done = True          
    print(f"Losses tensor: {torch.tensor(losses)}")
    # End calibration procedure
    # Switch every scaler out of calibration mode.
    # for layer in model.module.transformer.h:
    for layer in model.transformer.h:
        for scaler_a_b in [layer.attn.q_scaler, layer.attn.k_scaler, layer.attn.v_scaler, layer.attn.att_score_scaler, layer.attn.output_scaler]:
        # for scaler_a_b in [layer.attn.q_scaler, layer.attn.k_scaler, layer.attn.v_scaler, layer.attn.att_score_scaler, layer.attn.NL_scaler, layer.attn.output_scaler]:
            scaler_a_b.calibration = False

Inference

In [8]:
model_linear = model_linear.to(device_linear)
model = model.to(device)
model_linear.eval()
model.eval()
torch.cuda.empty_cache()
# Calling a torchmetrics Perplexity returns the value for the current batch AND
# accumulates its state, so .compute() at the end gives the perplexity over all batches.
ppl_linear = Perplexity().to(device_linear)
ppl = Perplexity().to(device)

n_eval_batches = 100
batch_losses_linear, batch_losses = [], []
batch_ppls_linear, batch_ppls = [], []

with torch.no_grad():
    for i in range(n_eval_batches):
        # Draw ONE batch and copy it to both devices so both models are evaluated on
        # identical tokens (paired comparison).
        X, Y = get_batch('test', device)
        X_linear, Y_linear = X.to(device_linear), Y.to(device_linear)
        logits_linear, loss_linear = model_linear(X_linear, targets=Y_linear)
        logits, loss = model(X, targets=Y)

        # ppl_exp = torch.exp(loss) -> same measure as perplexity
        ppl_linear_model = ppl_linear(logits_linear, Y_linear)
        ppl_model = ppl(logits, Y)

        batch_losses_linear.append(loss_linear.item())
        batch_losses.append(loss.item())
        batch_ppls_linear.append(ppl_linear_model.item())
        batch_ppls.append(ppl_model.item())

        print(f'Batch iteration: {i}\tLinear loss: {loss_linear:.2f}\tNonlinear loss: {loss:.2f}\tppl linear: {ppl_linear_model:.2f}\tppl: {ppl_model:.2f}')

avg_loss_linear = np.mean(batch_losses_linear)
avg_loss = np.mean(batch_losses)
avg_ppl_linear = np.mean(batch_ppls_linear)
avg_ppl = np.mean(batch_ppls)

print(f'Average: Linear loss: {avg_loss_linear:.2f}\tNonlinear loss: {avg_loss:.2f}\tppl linear: {avg_ppl_linear:.2f}\tppl: {avg_ppl:.2f}')
# The mean of per-batch perplexities overestimates exp(mean loss); report the
# perplexity accumulated over all evaluated tokens as well.
print(f'Corpus-level: ppl linear: {ppl_linear.compute().item():.2f}\tppl: {ppl.compute().item():.2f}')

# torch.save({'model': model_linear.state_dict()}, '/Users/leroux/sEMG/saved_models/gpt2_with_scaling_statistics.pt')

Batch iteration: 0	Linear loss: 3.15	Nonlinear loss: 3.09	ppl linear: 23.32	ppl: 21.95
Batch iteration: 1	Linear loss: 3.05	Nonlinear loss: 3.11	ppl linear: 21.04	ppl: 22.37
Batch iteration: 2	Linear loss: 3.13	Nonlinear loss: 3.26	ppl linear: 22.94	ppl: 25.98
Batch iteration: 3	Linear loss: 3.14	Nonlinear loss: 3.11	ppl linear: 23.07	ppl: 22.38
Batch iteration: 4	Linear loss: 3.11	Nonlinear loss: 2.99	ppl linear: 22.35	ppl: 19.79
Batch iteration: 5	Linear loss: 3.16	Nonlinear loss: 3.15	ppl linear: 23.47	ppl: 23.41
Batch iteration: 6	Linear loss: 3.05	Nonlinear loss: 3.19	ppl linear: 21.14	ppl: 24.40
Batch iteration: 7	Linear loss: 3.01	Nonlinear loss: 2.99	ppl linear: 20.35	ppl: 19.91
Batch iteration: 8	Linear loss: 3.11	Nonlinear loss: 3.27	ppl linear: 22.50	ppl: 26.25
Batch iteration: 9	Linear loss: 3.03	Nonlinear loss: 3.04	ppl linear: 20.61	ppl: 20.96
Batch iteration: 10	Linear loss: 3.01	Nonlinear loss: 3.22	ppl linear: 20.25	ppl: 25.14
Batch iteration: 11	Linear loss: 3.19	Nonl

Plot Q histograms

In [None]:
# Compare histograms of the last attention block's cached Q (index [0, :, 0]) before
# (top row) and after (bottom row) its scaler, for model_linear, for `model` at
# calibration iteration 0, and for `model` now.
with torch.no_grad():
    logits, loss = model(X, targets=Y)

fig, ax = plt.subplots(2, 3, figsize=(15/2.54, 10/2.54))
q_linear = model_linear.transformer.h[-1].attn.Q[0,:,0].flatten().cpu().clone()
q_linear_after_scale = model_linear.transformer.h[-1].attn.Q_after_scale[0,:,0].flatten().cpu().clone()
q_after_adapation = model.transformer.h[-1].attn.Q[0,:,0].flatten().cpu().clone()
q_after_adapation_after_scale = model.transformer.h[-1].attn.Q_after_scale[0,:,0].flatten().cpu().clone()
rwidth = 0.8
columns = [
    ('model_linear', q_linear, q_linear_after_scale),
    # `q` / `q_after_scale` are only defined by the calibration cell (disabled with
    # `if False:`); show a placeholder instead of raising NameError on a fresh kernel.
    ('model, calib iter 0', globals().get('q'), globals().get('q_after_scale')),
    ('model', q_after_adapation, q_after_adapation_after_scale),
]
for col, (title, before_scale, after_scale) in enumerate(columns):
    ax[0, col].set_title(title, fontsize=8)
    if before_scale is None or after_scale is None:
        ax[0, col].text(0.5, 0.5, 'not available', ha='center', va='center', transform=ax[0, col].transAxes)
        continue
    ax[0, col].hist(before_scale, rwidth=rwidth)
    ax[1, col].hist(after_scale, rwidth=rwidth)
ax[0, 0].set_ylabel('Q')
ax[1, 0].set_ylabel('Q after scale')
fig.tight_layout()
plt.show()

Test multiple quantization levels (without quantization-aware training)

In [None]:
# Post-training quantization sweep: for each (input levels, weights levels) pair,
# rebuild `model` from `model_ld` and measure per-sequence cross-entropy on one fixed
# test batch. losses[i, j] / losses_std[i, j] hold the mean / std over sequences.
# Pass the target device explicitly so the batch lands where the models run.
X, Y = get_batch('test', device)
input_levels = torch.tensor([2**32, 64, 32, 16, 8, 4, 2])
weights_levels = torch.tensor([2**32, 64, 32, 16, 8, 4, 2])
# input_levels = torch.tensor([2**32, 64, 32, 16])
# weights_levels = torch.tensor([2**32, 64, 32, 16])
input_levels_len = len(input_levels)
weights_levels_len = len(weights_levels)
losses = torch.zeros(input_levels_len, weights_levels_len)
losses_std = torch.zeros(input_levels_len, weights_levels_len)
for i, input_level in enumerate(input_levels):
    for j, weights_level in enumerate(weights_levels):
        # NOTE: this mutates the shared config_args, so it keeps the last levels set here
        # after the sweep (the histogram cell reads quantization levels from it).
        config_args.update({"quantization_levels_input": input_level,
                            "quantization_levels_weights": weights_level})
        config = GPTConfig(**config_args)
        model = GPT(config)
        model.load_state_dict(model_ld, strict=False)
        model.eval()
        model.to(device)
        with torch.no_grad():
            logits, _ = model(X, Y)
            token_losses = F.cross_entropy(logits.view(-1, logits.size(-1)), Y.view(-1), ignore_index=-1, reduction='none')
            # Reshape from Y itself so this matches whatever block_size get_batch sampled.
            seq_losses = token_losses.view_as(Y).mean(dim=-1)  # average over tokens
            loss_mean, loss_std = seq_losses.mean(), seq_losses.std()
        losses[i, j] = loss_mean.item()
        losses_std[i, j] = loss_std.item()
        print(f'input levels: {input_level}\tweights_level: {weights_level}\tLoss: {loss_mean.item():.4f}\tStandard dev: {loss_std.item():.4f}')

print('losses mean', losses)
print('losses std', losses_std)

Saved results of the quantization experiment with "checkpoints/gpt2-DRAMAttention_temporal_encoding_real_grad_lr_1e-5/ckpt" (output quantization and output clamping; decay not implemented yet)

In [None]:
# Hard-coded results of the quantization sweep, so the plots below can be drawn
# without re-running it. Row i corresponds to input_levels[i] and column j to
# weights_levels[j]. `losses` holds the mean and `losses_std` the std, over the batch,
# of the per-sequence cross-entropy. The 2**32 level presumably acts as the
# (effectively) unquantized reference — the plot cell uses column 0 for reference lines.
input_levels = torch.tensor([2**32, 64, 32, 16, 8, 4, 2])
weights_levels = torch.tensor([2**32, 64, 32, 16, 8, 4, 2])

losses = torch.tensor([[3.1784, 3.1790, 3.1806, 3.1893, 3.2325, 4.0268, 9.2421],
        [3.1793, 3.1801, 3.1823, 3.1919, 3.2328, 4.0729, 9.3272],
        [3.1813, 3.1823, 3.1826, 3.1950, 3.2400, 4.0452, 9.2603],
        [3.2022, 3.2033, 3.2034, 3.2152, 3.2580, 4.1199, 9.1787],
        [3.4788, 3.4855, 3.4964, 3.5085, 3.6436, 4.9188, 9.1064],
        [5.4929, 5.4719, 5.4791, 5.5990, 6.0570, 7.7586, 8.8878],
        [8.2598, 8.2439, 8.2380, 8.2542, 8.2821, 8.2791, 9.0375]])

losses_std = torch.tensor([[0.3129, 0.3121, 0.3144, 0.3122, 0.3117, 0.2819, 0.4138],
        [0.3136, 0.3135, 0.3140, 0.3088, 0.3106, 0.3500, 0.4447],
        [0.3113, 0.3111, 0.3126, 0.3101, 0.3121, 0.3053, 0.5186],
        [0.3115, 0.3132, 0.3124, 0.3110, 0.3075, 0.3048, 0.5548],
        [0.2694, 0.2728, 0.2812, 0.2676, 0.2618, 0.2552, 0.6936],
        [0.7377, 0.7313, 0.7468, 0.6983, 0.6816, 0.3085, 0.4240],
        [0.2267, 0.2179, 0.2080, 0.2623, 0.1870, 0.1764, 0.1862]])

Plots

In [None]:
# Error-bar plot: mean ± std loss vs number of KV (weights) bits, one curve per number
# of input bits. The 2**32 reference level and the 2-level extreme are excluded on both
# axes. Dashed lines show, for the same input level, the loss with 2**32 weights levels
# (column 0 of `losses`).
from matplotlib.ticker import StrMethodFormatter

fontsize = 12
weights_levels_ = weights_levels[1:-1]
input_levels_ = input_levels[1:-1]
losses_ = losses[1:-1, 1:-1]
losses_std_ = losses_std[1:-1, 1:-1]
cmap = colormaps['plasma']
x_bits = torch.log2(weights_levels_)

fig, ax = plt.subplots()
fig.set_figwidth(5)
fig.set_figheight(3)
for i, input_level in enumerate(input_levels_):
    color = cmap(i / len(input_levels_))
    ax.errorbar(x_bits, losses_[i], yerr=losses_std_[i], fmt='-o', linewidth=1, c=color,
                label=f'{torch.log2(input_level).item():.0f} input bits')
    # input_levels_ starts at index 1 of input_levels, so this input level is row i + 1
    # of the full `losses` table (row i would be the previous, finer input level).
    ax.hlines(losses[i + 1, 0], x_bits[0], x_bits[-1], linestyles='--', linewidth=1.0, color=color,
              #   label=f'{torch.log2(input_level).item():.0f} input bits reference',
              )
ax.set_xlabel('# KV bits', fontsize=fontsize)
ax.set_ylabel('Cross-entropy loss', fontsize=fontsize)
# ax.invert_xaxis()
ax.set_yscale('log')
ax.legend(bbox_to_anchor=(1.05, 1.0), loc='upper right', frameon=False, fontsize=fontsize)
ax.set_xticks([2, 3, 4, 5, 6])
ax.set_yticks([3, 4, 5, 6, 7, 8])
ax.yaxis.set_major_formatter(StrMethodFormatter('{x:.0f}'))
fig.tight_layout()

file_out = '/Users/leroux/sEMG/python_codes/plots/KV_and_Q_quantization_results'
for fmt in ['png', 'svg', 'pdf']:
    fig.savefig(file_out + '.%s' % fmt, format=fmt, dpi=1200)

plt.show()

# 3D plot: loss over (input bits, KV bits). Drops the 2**32 reference level and the
# two coarsest levels (4 and 2) from both axes.
# input_levels_ = input_levels[1:-1]
# weights_levels_ = weights_levels[1:-1]
# losses_ = losses[1:-1, 1:-1]
input_levels_ = input_levels[1:-2]
weights_levels_ = weights_levels[1:-2]
losses_ = losses[1:-2, 1:-2]

# Grid of (log2 input levels, log2 weights levels) aligned with losses_[i, j].
x_mesh, y_mesh = torch.meshgrid(torch.log2(input_levels_), torch.log2(weights_levels_), indexing='ij')
x, y = x_mesh.ravel(), y_mesh.ravel()
top = losses_.cpu().ravel()
# Bars start at the smallest loss so differences between cells stay visible.
bottom = losses_.min().cpu() * torch.ones_like(top)

fig_3d = plt.figure()
ax = fig_3d.add_subplot(projection='3d')
ax.view_init(elev=30, azim=45, roll=0)
ax.bar3d(x, y, bottom, dx=1, dy=1, dz=top - bottom, shade=True)
ax.set_xlabel('# input bits')
ax.set_ylabel('# KV bits')
ax.set_zlabel('Cross entropy loss')
# ax.set_zscale('log')
plt.show()

Save statistics histogram

In place histograms

In [None]:
# Aggregate the attention modules' bins_count_* buffers and save the normalized value
# histograms (density + left bin edges) of Q, K, V, Attention and Output to an .npz file.
if True:
    total_q = torch.zeros(0)
    total_k = torch.zeros(0)
    total_v = torch.zeros(0)
    total_A = torch.zeros(0)
    total_out = torch.zeros(0)
    for layer in model.transformer.h:
        # torch.histogram only supports CPU tensors, so move the buffers off the GPU.
        total_q = torch.cat((total_q, layer.attn.bins_count_q.data.cpu()), dim=0)
        total_k = torch.cat((total_k, layer.attn.bins_count_k.data.cpu()), dim=0)
        total_v = torch.cat((total_v, layer.attn.bins_count_v.data.cpu()), dim=0)
        total_A = torch.cat((total_A, layer.attn.bins_count_A.data.cpu()), dim=0)
        total_out = torch.cat((total_out, layer.attn.bins_count_out.data.cpu()), dim=0)
        # NOTE(review): this `break` limits the aggregation to the first block despite
        # the `total_*` names — confirm this is intended.
        break

    total_density_q, bins_edges_q = torch.histogram(total_q, bins=config_args['quantization_levels_input'], density=False)
    total_density_k, bins_edges_k = torch.histogram(total_k, bins=config_args['quantization_levels_weights'], density=False)
    total_density_v, bins_edges_v = torch.histogram(total_v, bins=config_args['quantization_levels_weights'], density=False)
    total_density_A, bins_edges_A = torch.histogram(total_A, bins=config_args['quantization_levels_input'], density=False)
    total_density_out, bins_edges_out = torch.histogram(total_out, bins=config_args['quantization_levels_output'], density=False)

    # Turn counts into relative frequencies.
    total_density_q /= total_q.numel()
    total_density_k /= total_k.numel()
    total_density_v /= total_v.numel()
    total_density_A /= total_A.numel()
    total_density_out /= total_out.numel()

    # bins_edges[:-1] keeps the left edge of each bin (one per density value).
    hist_to_save = {'Q': {'density': total_density_q.to(torch.float).numpy(), 'bins_edges': bins_edges_q[:-1].to(torch.float).numpy()},
                    'K': {'density': total_density_k.to(torch.float).numpy(), 'bins_edges': bins_edges_k[:-1].to(torch.float).numpy()},
                    'V': {'density': total_density_v.to(torch.float).numpy(), 'bins_edges': bins_edges_v[:-1].to(torch.float).numpy()},
                    'Attention': {'density': total_density_A.to(torch.float).numpy(), 'bins_edges': bins_edges_A[:-1].to(torch.float).numpy()},
                    'Output': {'density': total_density_out.to(torch.float).numpy(), 'bins_edges': bins_edges_out[:-1].to(torch.float).numpy()},
    }

    # `quantization` was never defined in this notebook (NameError on a fresh kernel).
    # NOTE(review): the file-name tag now uses the input quantization levels — confirm.
    quantization = config_args['quantization_levels_input']
    file_name = f"./{quantization}_levels_histogram.npz"
    # NOTE(review): on NumPy versions whose savez has no `allow_pickle` parameter, this
    # keyword is stored as an extra array named 'allow_pickle' in the archive.
    np.savez(file_name, **hist_to_save, allow_pickle=True)

    with np.load(file_name, allow_pickle=True) as network_histogram:
        for (key, value) in network_histogram.items():
            print(key, value)

Hist density plots

In [None]:
# Plot the value histograms saved by the previous cell, one log-scale panel per tensor.
# file_name = f"./16_levels_histogram_temporal_encoding.npz"
# file_name = f"./256_levels_histogram_saturated_layer_1.npz"
fontsize = 12
names = ['Q', 'K', 'V', 'Attention', 'Output']
with np.load(file_name, allow_pickle=True) as network_histogram:
    # Select entries by name rather than by position: depending on the NumPy version,
    # np.savez(..., allow_pickle=True) may also have stored an 'allow_pickle' array.
    # `[()]` unwraps the 0-d object array holding each {'density', 'bins_edges'} dict.
    histograms = {name: network_histogram[name][()] for name in names if name in network_histogram.files}

fig, axes = plt.subplots(1, len(histograms), squeeze=False)
fig.set_figwidth(10)
fig.set_figheight(2.5)
for i, (name, param) in enumerate(histograms.items()):
    panel = axes[0, i]
    panel.bar(param['bins_edges'], param['density'], color='darkblue', width=0.05)
    panel.plot(param['bins_edges'], param['density'], color='darkblue', linewidth=2)
    panel.set_xlabel(name, fontsize=fontsize)
    if i == 0:
        panel.set_ylabel('Frequency')
    panel.set_yscale('log')
fig.tight_layout()
file_out = '/Users/leroux/sEMG/python_codes/plots/values_density_log'
# file_out = '/Users/leroux/sEMG/python_codes/plots/values_density'
# for fmt in ['png', 'svg', 'pdf']:
#     plt.savefig(file_out + '.%s' % fmt, format=fmt, dpi=1200)
plt.show()

Histogram compare results

In [None]:
# For one random position, overlay the one-hot target and both models' output
# distributions on a window of vocabulary ids around the target token, per sample.
# NOTE(review): meaningful only if `logits_linear` and `logits` were computed on the
# same input batch and `Y` is that batch's target — confirm against the producing cell.
assert logits_linear.shape == logits.shape, 'both logits must come from same-shaped batches'
n_samples, n_tokens, vocab_size = logits_linear.shape
token = int(torch.randint(n_tokens, (1,)))
# Softmax only at the selected position: the full (batch, tokens, vocab) softmax and
# one-hot target would take several GB.
probability_linear = F.softmax(logits_linear[:, token].detach(), dim=-1).cpu()
probability = F.softmax(logits[:, token].detach(), dim=-1).cpu()
targets_at_token = Y[:, token].cpu()
xaxis = torch.arange(0, vocab_size)
n_values = 200
for i in range(n_samples):
    value = int(targets_at_token[i])
    # Clamp the window to [0, vocab_size): target ids below n_values would otherwise give
    # negative slice starts that wrap around to the end of the vocabulary.
    lo, hi = max(value - n_values, 0), min(value + n_values, vocab_size)
    target = torch.zeros(hi - lo)
    target[value - lo] = 1.0
    window = xaxis[lo:hi]
    fig, ax = plt.subplots()
    clear_output(wait=True)
    max_idx_linear = int(torch.argmax(probability_linear[i]))
    max_idx_other = int(torch.argmax(probability[i]))
    print(f'Target word: {value}\t Linear model word: {max_idx_linear}\t Other network word: {max_idx_other}')
    ax.bar(window, target.numpy(), alpha=0.5, label='Target', color='blue')
    ax.bar(window, probability_linear[i, lo:hi].numpy(), alpha=0.5, label='Linear out', color='red')
    ax.bar(window, probability[i, lo:hi].numpy(), alpha=0.5, label='Normal network out', color='green')
    ax.legend(loc='upper right')
    ax.set_title('Output probabilities')
    ax.set_xlabel('Value')
    ax.set_ylabel('Frequency')
    plt.show()
    plt.close(fig)
    sys.stdout.flush()
    plt.pause(1.0)  # Pause to update the plot

Histogram compare intermediate results

In [None]:
# Compare, per attention layer, the cached attn.x_cache tensors of model_linear and
# model: overlaid value histograms (left) and a scatter of model vs model_linear values
# at index [0, 0] (right).
# NOTE(review): `range(0)` makes the loop a no-op, so this cell currently plots nothing;
# use the commented range to enable it.
with torch.no_grad():
    # for i in range(len(model_linear.transformer.h)):
    for i in range(0):
        out_linear = model_linear.transformer.h[i].attn.x_cache
        out = model.transformer.h[i].attn.x_cache
        fig, ax = plt.subplots(1, 2)
        clear_output(wait=True)
        # NOTE(review): the masking_1 mask is immediately overwritten by an all-True mask,
        # so no element is filtered out.
        mask = model_linear.transformer.h[i].attn.masking_1.mask.expand(config_args['batch_size'], config_args['n_head'], -1, -1).flatten()!=0
        mask = torch.ones_like(out, dtype=torch.bool).flatten()
        ax[0].hist(
                    out_linear.flatten()[mask].detach().cpu().numpy(),
                    alpha=0.5,
                    label='Linear out',
                #    density=True,
                #    bins='auto',
                    bins=100,
                )
        ax[0].hist(
                    out.flatten()[mask].detach().cpu().numpy(),
                    alpha=0.5,
                    label='Nonlinear out',
                    # density=True,
                    # bins='auto',
                    bins=100,
                )
        ax[0].legend(loc='upper right')
        ax[0].set_title(f'Distribution attention layer {i}')
        ax[0].set_xlabel('Value')
        ax[0].set_ylabel('Frequency')
        
        # Same overwrite pattern as above: the all-True mask replaces the masking_1 mask.
        mask = model_linear.transformer.h[i].attn.masking_1.mask[0,0].flatten()!=0
        mask = torch.ones_like(out[0,0], dtype=torch.bool).flatten()
        # Identity reference (linear vs itself) as a line and as points, then model vs linear.
        ax[1].plot(out_linear[0,0].flatten()[mask].detach().cpu().numpy(), out_linear[0,0].flatten()[mask].detach().cpu().numpy(), 'k', linewidth=1, label='Linear vs linear')
        ax[1].scatter(out_linear[0,0].flatten()[mask].detach().cpu().numpy(), out_linear[0,0].flatten()[mask].detach().cpu().numpy(), color='darkblue', s=1, label='Linear vs linear')
        ax[1].scatter(out_linear[0,0].flatten()[mask].detach().cpu().numpy(), out[0,0].flatten()[mask].detach().cpu().numpy(), color='darkred', s=1, label='Nonlinear vs linear')
        
        # ax[1].scatter(out[0,0].flatten()[mask].detach().cpu().numpy(), out[0,0].flatten()[mask].detach().cpu().numpy(), color='darkblue', s=1, label='Nonlinear vs Nonlinear')
        # ax[1].scatter(out[0,0].flatten()[mask].detach().cpu().numpy(), out_linear[0,0].flatten()[mask].detach().cpu().numpy(), color='darkred', s=1, label='Linear vs Nonlinear')
        
        ax[1].legend(loc='upper right')
        ax[1].set_title(f'Attention layer {i}')
        ax[1].set_xlabel('X')
        ax[1].set_ylabel('Y')
        
        plt.tight_layout()
        plt.show()
        sys.stdout.flush()
        plt.pause(0.1) # Pause to update the plot

Try optimizing through backprop

In [None]:
# Fine-tune all parameters of `model` on the training split with its own cross-entropy
# loss; model_linear runs alongside (no gradients) only for the comparison scatter plot.
model_linear = model_linear.to(device)
model_linear.eval()
model = model.to(device)
# NOTE(review): eval() keeps dropout disabled during this optimisation — confirm intended.
model.eval()
calib_lr = 1e-4
max_iter = 10000
alpha = 0.1  # smoothing factor of the exponential running average of the loss

loss_list = []
loss_total = []
# parameters = []
# for layer in model.transformer.h:
#     parameters += list(layer.attn.q_scaler.parameters())
#     parameters += list(layer.attn.k_scaler.parameters())
#     parameters += list(layer.attn.v_scaler.parameters())
#     parameters += list(layer.attn.att_score_scaler.parameters())
#     parameters += list(layer.attn.output_scaler.parameters())
parameters = model.parameters()
optim = torch.optim.AdamW(params=parameters, lr=calib_lr)
for step in range(max_iter):  # `step`, not `iter`: don't shadow the builtin at notebook scope
    optim.zero_grad()
    # Pass the target device explicitly so the batch lands where both models run.
    X, Y = get_batch('train', device)
    # Linear model
    with torch.no_grad():
        out_linear, cel_linear = model_linear(X, Y)
    # Nonlinear model
    out, cel = model(X, Y)
    if step == 0:
        loss_total.append(cel.item())
    else:
        loss_total.append(cel.item() * alpha + loss_total[-1] * (1 - alpha))
    loss = cel.clone()
    # loss = nn.MSELoss()(out, out_linear)
    # loss = 0.5 * cel + 0.5 * loss
    loss_list.append(loss.item())
    loss.backward()
    optim.step()
    print(f'iter: {step}\tLoss: {loss.item():.4f}\tRunning average loss: {loss_total[-1]:.4f}')
    if step % 10 == 0:
        fig, ax = plt.subplots(1, 3)
        clear_output(wait=True)
        # ax[0].scatter(out_linear[0].flatten().detach().cpu().numpy(), out_linear[0].flatten().detach().cpu().numpy(), color='darkblue', s=1, label='Linear vs linear')
        # ax[0].scatter(out_linear[0].flatten().detach().cpu().numpy(), out[0].flatten().detach().cpu().numpy(), color='darkred', s=1, label='Nonlinear vs linear')
        ax[0].scatter(out_linear[0,0].flatten().detach().cpu().numpy(), out_linear[0,0].flatten().detach().cpu().numpy(), color='darkblue', s=1, label='Linear vs linear')
        ax[0].scatter(out_linear[0,0].flatten().detach().cpu().numpy(), out[0,0].flatten().detach().cpu().numpy(), color='darkred', s=1, label='Nonlinear vs linear')
        ax[1].plot(torch.tensor(loss_list), linewidth=1)
        ax[1].set_yscale('log')
        ax[1].set_ylabel('Loss')
        ax[2].plot(torch.tensor(loss_total), linewidth=1)
        ax[2].set_ylabel('Global cross entropy loss')
        fig.tight_layout()
        plt.show()
        plt.close(fig)  # avoid accumulating one open figure per refresh
        sys.stdout.flush()
        plt.pause(0.1)  # Pause to update the plot

Try optimizing through backprop, layer-wise

In [None]:
# Layer-wise calibration: optimise only the attention scalers of one block of `model`
# at a time so that the block's output matches the same block of model_linear (MSE),
# and move to the next block once that loss drops below `loss_threshold`.
model_linear = model_linear.to(device)
model_linear.eval()
model = model.to(device)
model.eval()
calib_lr = 1e-4
max_iter = 10000
loss_threshold = 1e-2
layer_num = -1

loss_list = []
loss_total = []
loss = 1000.  # placeholder; step 0 always selects the first block
for step in range(max_iter):  # `step`, not `iter`: don't shadow the builtin at notebook scope
    if loss < loss_threshold or step == 0:
        # Advance to the next block (stop after the last one) with a fresh optimizer.
        layer_num += 1
        if layer_num + 1 > config_args['n_layer']:
            break
        layer = model.transformer.h[layer_num]
        parameters = []
        for scaler in (layer.attn.q_scaler, layer.attn.k_scaler, layer.attn.v_scaler,
                       layer.attn.att_score_scaler, layer.attn.output_scaler):
            parameters += list(scaler.parameters())
        optim = torch.optim.AdamW(params=parameters, lr=calib_lr)

    # Reset gradients of the whole model, not only of the optimised scalers: backward()
    # also fills .grad of the block's other parameters, which would otherwise accumulate.
    model.zero_grad(set_to_none=True)
    # Pass the target device explicitly so the batch lands where both models run.
    X, Y = get_batch('train', device)
    pos = torch.arange(0, X.size(1), dtype=torch.long, device=X.device)  # shape (t)

    with torch.no_grad():
        # Linear model: embeddings, then blocks 0..layer_num.
        tok_emb = model_linear.transformer.wte(X)  # token embeddings of shape (b, t, n_embd)
        pos_emb = model_linear.transformer.wpe(pos)  # position embeddings of shape (t, n_embd)
        out_linear = model_linear.transformer.drop(tok_emb + pos_emb)
        for l in range(layer_num + 1):
            out_linear = model_linear.transformer.h[l](out_linear)
        # out_linear = model_linear.transformer.h[layer_num].attn(model_linear.transformer.h[layer_num].ln_1(out_linear))
        # Nonlinear model: embeddings, then the blocks before layer_num.
        tok_emb = model.transformer.wte(X)  # token embeddings of shape (b, t, n_embd)
        pos_emb = model.transformer.wpe(pos)  # position embeddings of shape (t, n_embd)
        out = model.transformer.drop(tok_emb + pos_emb)
        for l in range(layer_num):
            out = model.transformer.h[l](out)
    # Only the block being calibrated runs with autograd enabled.
    out = model.transformer.h[layer_num](out)
    # out = model.transformer.h[layer_num].attn(model.transformer.h[layer_num].ln_1(out))

    loss = nn.MSELoss()(out, out_linear)
    # loss = nn.CrossEntropyLoss()(torch.nn.functional.softmax(out, dim=-1), torch.nn.functional.softmax(out_linear, dim=-1))
    loss_list.append(loss.item())
    loss.backward()
    optim.step()
    print(f'iter: {step}\tLoss: {loss.item():.4f}\tLayer: {layer_num}')
    if step % 10 == 0:
        fig, ax = plt.subplots(1, 3)
        clear_output(wait=True)
        ax[0].scatter(out_linear[0].flatten().detach().cpu().numpy(), out_linear[0].flatten().detach().cpu().numpy(), color='darkblue', s=1, label='Linear vs linear')
        ax[0].scatter(out_linear[0].flatten().detach().cpu().numpy(), out[0].flatten().detach().cpu().numpy(), color='darkred', s=1, label='Nonlinear vs linear')
        ax[1].plot(torch.tensor(loss_list), linewidth=1)
        ax[1].set_yscale('log')
        ax[1].set_ylabel('Cross model layer loss')
        with torch.no_grad():
            logits, loss_global = model(X, Y)
        loss_total.append(loss_global.item())
        ax[2].plot(torch.tensor(loss_total), linewidth=1)
        ax[2].set_ylabel('Global cross entropy loss')
        fig.tight_layout()
        plt.show()
        plt.close(fig)  # avoid accumulating one open figure per refresh
        sys.stdout.flush()
        plt.pause(0.1)  # Pause to update the plot

Calibration that stops updating specific parameters once they have reached valid values multiple times

In [None]:
# Closed-form calibration of the attention scalers: iteratively adjust each
# scaler's `a` (multiplicative) and `b` (additive) parameters so that the
# statistics it reports after a forward pass (std_after_scale, mean_after_scale)
# approach its target_std / target_mean. Each update is damped by `alpha`, which
# decays linearly from 1 to alpha_min over the first alpha_decay_max_step
# iterations. A statistic counts as "valid" for an iteration when its absolute
# error is <= error_threshold; the loop stops once all statistics are valid, and
# raises AssertionError after max_calibration_iter iterations otherwise.
# Relies on globals from earlier cells: model, device, config_args, X, Y.
# NOTE(review): X, Y are whatever batch an earlier cell drew last, so every
# iteration calibrates on the same batch — confirm this is intended.
if True:
    model = model.to(device)
    model.train()
    done = False
    error_threshold = 0.001
    max_calibration_iter = 1000
    alpha_min = 0.01
    alpha_decay_max_step = 50
    alpha_calibration = torch.linspace(1, alpha_min, alpha_decay_max_step)

    # 6 scalers per block; each scaler contributes two done flags (std, then mean).
    n_scalers = 6 * config_args['n_layer']
    calibration_iter = 0
    with torch.no_grad():
        std_errors_plots = torch.zeros(n_scalers, 0)
        mean_errors_plots = torch.zeros(n_scalers, 0)
        done_list = torch.zeros(n_scalers * 2, 0, dtype=torch.bool)
        while not done:
            # One new all-False column per iteration. Previously hard-coded to
            # 72 * 2 rows, which only matched n_layer == 12.
            done_list = torch.cat((done_list, torch.zeros(n_scalers * 2, 1, dtype=torch.bool)), dim=-1)
            # Forward pass; presumably refreshes each scaler's *_after_scale
            # statistics while in calibration mode (TODO confirm in the module).
            logits, loss = model(X, Y)
            std_errors = []
            mean_errors = []
            if calibration_iter < alpha_decay_max_step:
                alpha = alpha_calibration[calibration_iter].item()
            else:
                alpha = alpha_min

            param_id = 0
            for layer in model.transformer.h:
                attn = layer.attn
                for scaler_a_b in (attn.q_scaler, attn.k_scaler, attn.v_scaler, attn.att_score_scaler, attn.NL_scaler, attn.output_scaler):
                    # Scale: a <- a * target_std / observed_std, blended with the old a.
                    std_errors.append(torch.abs(scaler_a_b.std_after_scale - scaler_a_b.target_std).item())
                    if std_errors[-1] > error_threshold:
                        new_a = scaler_a_b.a * scaler_a_b.target_std / scaler_a_b.std_after_scale
                        scaler_a_b.a.fill_(alpha * new_a + (1 - alpha) * scaler_a_b.a)
                    else:
                        done_list[param_id, -1] = True
                    param_id += 1

                    # Shift: b <- b + (target_mean - observed_mean), blended with the old b.
                    mean_errors.append(torch.abs(scaler_a_b.mean_after_scale - scaler_a_b.target_mean).item())
                    if mean_errors[-1] > error_threshold:
                        new_b = scaler_a_b.b + (scaler_a_b.target_mean - scaler_a_b.mean_after_scale)
                        scaler_a_b.b.fill_(alpha * new_b + (1 - alpha) * scaler_a_b.b)
                    else:
                        done_list[param_id, -1] = True
                    param_id += 1

            std_errors_plots = torch.cat((std_errors_plots, torch.tensor(std_errors).unsqueeze(1)), dim=-1)
            mean_errors_plots = torch.cat((mean_errors_plots, torch.tensor(mean_errors).unsqueeze(1)), dim=-1)

            fig, ax = plt.subplots()
            clear_output(wait=True)
            for param in range(len(std_errors_plots)):
                ax.plot(std_errors_plots[param], linewidth=0.5)
                ax.plot(mean_errors_plots[param], linewidth=0.5)
            ax.set_yscale('log')
            # Use the history length directly instead of relying on the leaked loop variable.
            ax.hlines(error_threshold, 0., std_errors_plots.shape[1] - 1.0, colors='black', linewidth=2)
            plt.show()
            sys.stdout.flush()
            plt.pause(0.1)  # Pause to update the plot
            plt.close(fig)  # don't accumulate one open figure per iteration

            valid_now = done_list[:, -1]  # already a bool tensor; no torch.tensor() copy needed
            print(f'Calibraton iter {calibration_iter} | error threshold: {error_threshold:.3f}\tnum valid params: {valid_now.sum()}/{len(valid_now)}\tstd errors: {torch.sort(torch.tensor(std_errors), descending=True)[0][:3]}\tmean errors: {torch.sort(torch.tensor(mean_errors), descending=True)[0][:3]}')
            calibration_iter += 1
            # Check convergence before the iteration budget so that converging on
            # the last allowed iteration is not reported as a failure.
            done = bool(valid_now.all())
            if not done:
                assert calibration_iter < max_calibration_iter, f'Calibration algorithm did not converge after {calibration_iter} steps.'
    # End calibration procedure
    for layer in model.transformer.h:
        attn = layer.attn
        for scaler_a_b in (attn.q_scaler, attn.k_scaler, attn.v_scaler, attn.att_score_scaler, attn.NL_scaler, attn.output_scaler):
            scaler_a_b.calibration = False
            

Try to optimize through gradient descent

In [None]:
# learning rate decay scheduler (cosine with warmup)
warmup_iters = 20       # iterations of linear warmup
min_lr = 0.001          # floor reached at (and kept after) lr_decay_iters
learning_rate = 0.1     # peak learning rate, reached at warmup_iters
lr_decay_iters = 100    # iteration at which the cosine decay reaches min_lr


def get_lr(it):
    """Learning rate for iteration ``it`` (warmup followed by cosine decay).

    Reads the module-level constants defined just above:
      * ``it < warmup_iters``   -> linear ramp ``learning_rate * it / warmup_iters``
      * ``it > lr_decay_iters`` -> constant ``min_lr``
      * otherwise               -> cosine from ``learning_rate`` down to ``min_lr``
    """
    if it < warmup_iters:
        return learning_rate * it / warmup_iters
    elif it > lr_decay_iters:
        return min_lr

    progress = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    assert 0 <= progress <= 1
    cosine_weight = 0.5 * (1.0 + np.cos(pi * progress))  # 1 at start of decay, 0 at end
    lr_range = learning_rate - min_lr
    return min_lr + cosine_weight * lr_range

In [None]:
# Gradient-based calibration: train the parameters of every attention scaler in
# `model` with AdamW so that each scaler's reported statistics
# (std_after_scale, mean_after_scale) match those of the corresponding scaler in
# the reference `model_linear` (L1 loss summed over all scalers). A statistic is
# "valid" when its absolute error is <= error_threshold; the loop stops once all
# statistics are valid, and raises AssertionError after max_calibration_iter
# iterations otherwise. The scalers are switched back out of calibration mode
# even when the loop fails or is interrupted.
# Relies on globals from earlier cells: model, model_linear, device,
# config_args, get_batch (and get_lr, if the LR schedule is re-enabled).
if True:
    # Attention scalers of each block, in the fixed order used by every loop below.
    scaler_names = ('q_scaler', 'k_scaler', 'v_scaler', 'att_score_scaler', 'NL_scaler', 'output_scaler')

    def _scalers(layer):
        """Return the six attention scalers of transformer block `layer`, in `scaler_names` order."""
        return [getattr(layer.attn, name) for name in scaler_names]

    for layer in model.transformer.h:
        for scaler_a_b in _scalers(layer):
            scaler_a_b.calibration = True

    try:
        model_linear = model_linear.to(device)
        model_linear.train()
        model = model.to(device)
        model.train()
        done = False
        error_threshold = 0.01
        calibration_iter = 0
        max_calibration_iter = 10000
        calib_lr = 0.1

        parameters = [p for layer in model.transformer.h for scaler_a_b in _scalers(layer) for p in scaler_a_b.parameters()]
        optim = torch.optim.AdamW(params=parameters, lr=calib_lr)

        std_errors_plots = torch.zeros(6 * config_args['n_layer'], 0)
        mean_errors_plots = torch.zeros(6 * config_args['n_layer'], 0)
        total_loss = torch.zeros(0)
        while not done:
            X, Y = get_batch('train')

            # Forward passes; presumably these refresh the scalers' *_after_scale
            # statistics (TODO confirm in the scaler module).
            with torch.no_grad():
                logits_linear, _ = model_linear(X, Y)
            logits, _ = model(X, Y)
            # To enable the warmup + cosine schedule, set
            # param_group['lr'] = get_lr(calibration_iter) for each optim.param_groups here.
            optim.zero_grad()

            # Per block: all std terms first, then all mean terms (same summation order as before).
            loss = 0.
            for (layer, linear_layer) in zip(model.transformer.h, model_linear.transformer.h):
                pairs = list(zip(_scalers(layer), _scalers(linear_layer)))
                for scaler_a_b, linear_scaler_a_b in pairs:
                    loss += nn.L1Loss()(scaler_a_b.std_after_scale, linear_scaler_a_b.std_after_scale)
                for scaler_a_b, linear_scaler_a_b in pairs:
                    loss += nn.L1Loss()(scaler_a_b.mean_after_scale, linear_scaler_a_b.mean_after_scale)

            total_loss = torch.cat((total_loss, loss.detach().cpu().unsqueeze(0)))

            with torch.no_grad():
                std_errors = []
                mean_errors = []
                done_list = []  # two flags per scaler: std valid, mean valid
                for (layer, linear_layer) in zip(model.transformer.h, model_linear.transformer.h):
                    for scaler_a_b, linear_scaler_a_b in zip(_scalers(layer), _scalers(linear_layer)):
                        std_errors.append(torch.abs(scaler_a_b.std_after_scale - linear_scaler_a_b.std_after_scale).item())
                        mean_errors.append(torch.abs(scaler_a_b.mean_after_scale - linear_scaler_a_b.mean_after_scale).item())
                        # `<=` rather than `not >` so that a NaN error never counts as converged.
                        done_list.append(std_errors[-1] <= error_threshold)
                        done_list.append(mean_errors[-1] <= error_threshold)
                std_errors_plots = torch.cat((std_errors_plots, torch.tensor(std_errors).unsqueeze(1)), dim=-1)
                mean_errors_plots = torch.cat((mean_errors_plots, torch.tensor(mean_errors).unsqueeze(1)), dim=-1)
                if calibration_iter % 10 == 0:
                    fig, ax = plt.subplots(1, 2)
                    clear_output(wait=True)
                    for param in range(len(std_errors_plots)):
                        ax[0].plot(std_errors_plots[param], linewidth=0.5)
                        ax[0].plot(mean_errors_plots[param], linewidth=0.5)
                    ax[0].set_yscale('log')
                    ax[0].hlines(error_threshold, 0., std_errors_plots.shape[1] - 1.0, colors='black', linewidth=2)
                    ax[1].plot(total_loss, linewidth=2)
                    ax[1].set_ylabel('Loss')
                    ax[1].set_yscale('log')
                    fig.tight_layout()
                    plt.show()
                    sys.stdout.flush()
                    plt.pause(0.1)  # Pause to update the plot
                    plt.close(fig)  # don't accumulate one open figure per plot refresh

                calibration_iter += 1

            loss.backward()
            optim.step()
            print(f'Calibraton iter {calibration_iter}| Loss: {loss.item():.3f} | error threshold: {error_threshold:.3f}\tnum valid params: {torch.sum(torch.tensor(done_list))}/{len(done_list)}\tstd errors: {torch.sort(torch.tensor(std_errors), descending=True)[0][:3]}\tmean errors: {torch.sort(torch.tensor(mean_errors), descending=True)[0][:3]}')
            # Debug trace of one representative scaler parameter.
            print(model.transformer.h[0].attn.att_score_scaler.a.item())

            # Check convergence before the iteration budget so that converging on
            # the last allowed iteration is not reported as a failure.
            done = all(done_list)
            if not done:
                assert calibration_iter < max_calibration_iter, f'Calibration algorithm did not converge after {calibration_iter} steps.'
    finally:
        # End calibration procedure — also runs on non-convergence or interrupt,
        # so the scalers are never left in calibration mode.
        for layer in model.transformer.h:
            for scaler_a_b in _scalers(layer):
                scaler_a_b.calibration = False