In [1]:
import torch
import torch.nn as nn
import os
import sys
dir_name = os.getcwd()
parent_dir_name = os.path.dirname(dir_name)
sys.path.insert(0, parent_dir_name)
from modules.model_gpt import GPT, GPTConfig
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import clear_output

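Model import: load the same checkpoint into a linear (SlidingWindowAttention) reference model and a nonlinear (DRAMAttention) model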

In [None]:
# Prefer the GPU, but fall back to CPU so the notebook also runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
init_from = "gpt2-LinearDRAM"
# Directory holding the saved checkpoints. Override it with the SEMG_SAVED_MODELS
# environment variable instead of editing this machine-specific absolute path.
saved_models_dir = os.environ.get("SEMG_SAVED_MODELS", "/Users/leroux/sEMG/saved_models")
ckpt_path = os.path.join(saved_models_dir, f"{init_from}.pt")
# map_location='cpu' lets a checkpoint saved on a GPU load on any machine; both models
# are moved to `device` in a later cell.
model_ld = torch.load(ckpt_path, map_location='cpu')['model']
print(f"Initializing from checkpoint: {ckpt_path}")


def _report_incompatible_keys(label, result, max_shown=5):
    """Print the keys a non-strict ``load_state_dict`` call skipped.

    ``strict=False`` silently ignores missing/unexpected keys; listing them makes a
    partially initialised model visible instead of hidden.

    Args:
        label: name printed in front of the report.
        result: the value returned by ``torch.nn.Module.load_state_dict``.
        max_shown: maximum number of key names printed per category.
    """
    for kind, keys in (("missing", result.missing_keys), ("unexpected", result.unexpected_keys)):
        if keys:
            shown = ", ".join(keys[:max_shown])
            more = f", ... (+{len(keys) - max_shown} more)" if len(keys) > max_shown else ""
            print(f"{label}: {len(keys)} {kind} keys: {shown}{more}")


config_args = dict(attention="SlidingWindowAttention",
                   batch_size=1,
                   n_layer=12,
                   n_head=12,
                   n_embd=768,
                   dropout=0.,
                   vocab_size=50257,
                   block_size=1024,
                   bias=True,
                   )
# Reference ("linear") model: SlidingWindowAttention.
config = GPTConfig(**config_args)
model_linear = GPT(config)
_report_incompatible_keys("model_linear", model_linear.load_state_dict(model_ld, strict=False))
# Nonlinear model: same hyper-parameters and checkpoint, but DRAMAttention.
# NOTE(review): presumably the DRAM-specific parameters are absent from the
# checkpoint, hence strict=False -- check the printed key report.
config_args.update({"attention": "DRAMAttention"})
config = GPTConfig(**config_args)
model = GPT(config)
_report_incompatible_keys("model", model.load_state_dict(model_ld, strict=False))
# State dict of the DRAM model (references to the live tensors, not a copy).
model_sd = model.state_dict()

In [None]:

def fix_scale_and_shift_a_and_b(model, offset=1000,
                                scaler_names=('q_scaler', 'k_scaler', 'v_scaler',
                                              'att_score_scaler', 'output_scaler')):
    """Advance the ``iter_num`` counter of every attention scaler in ``model``.

    For each transformer block in ``model.transformer.h``, adds ``offset`` to
    ``iter_num`` on the q/k/v, attention-score and output scalers of its
    ``attn`` module. NOTE(review): presumably this moves the scalers past an
    iteration-count-dependent phase so their a/b (scale/shift) values stay
    fixed -- confirm against ``modules.model_gpt``.

    Args:
        model: GPT-style model exposing ``model.transformer.h``; each block must
            have an ``attn`` with every attribute named in ``scaler_names``.
        offset: amount added to each ``iter_num`` (default 1000, the original value).
        scaler_names: names of the scaler attributes on ``layer.attn`` to update,
            visited in this order for each layer.

    Returns:
        The same ``model`` object, modified in place.
    """
    for layer in model.transformer.h:
        for scaler_name in scaler_names:
            scaler = getattr(layer.attn, scaler_name)
            # `+=` keeps the original semantics (in-place for tensor buffers).
            scaler.iter_num += offset
    return model
# Only the linear reference model gets its scaler counters advanced; `model` is left as loaded.
model_linear = fix_scale_and_shift_a_and_b(model=model_linear)

Synthetic data: random token sequences fed to both models

In [None]:
# Seed torch's RNG so the synthetic inputs (and any random data drawn afterwards) are reproducible.
SEED = 0
torch.manual_seed(SEED)
n_samples = 1
# Uniformly random token ids in [0, vocab_size), one full-length context per sample.
# Generated on CPU so the values do not depend on `device`; named `sample_tokens`
# rather than `input` to avoid shadowing the Python builtin.
sample_tokens = torch.randint(
    0, config_args["vocab_size"], (n_samples, config_args["block_size"])
).to(device)
model_linear = model_linear.to(device)
model = model.to(device)
with torch.no_grad():
    # The tokens are passed as both the inputs and the targets.
    out_linear = model_linear(sample_tokens, sample_tokens)
    out = model(sample_tokens, sample_tokens)

Histogram comparison of the two models' outputs

In [None]:
# Plot histograms
plt.hist(out_linear[0][0][0].detach().cpu().numpy(), bins=100, alpha=0.5, label='Linear out')
plt.hist(out[0][0][0].detach().cpu().numpy(), bins=100, alpha=0.5, label='Nonlinear out')
plt.legend(loc='upper right')
plt.title('Distribution before training')
plt.xlabel('Value')
plt.ylabel('Frequency')
plt.show()

Train the nonlinear (DRAM) model to reproduce the linear model's outputs

In [None]:
# for name, param in model.named_parameters():
#     if name.endswith('_scaler.a_param') or name.endswith('_scaler.b_param'):
#         pass
#     else:
#         param.requires_grad = False
# # Check that only the scaling factors are trained
# [print(name, param.requires_grad) for name, param in model.named_parameters()]
# for

In [None]:
class dataset_generator(torch.utils.data.Dataset):
    """In-memory dataset of uniformly random token sequences.

    Holds ``batch_size * n_batches`` sequences of ``seq_len`` token ids drawn
    uniformly from ``[0, vocab_size)``. The data is generated once at
    construction, so every epoch sees the same sequences.

    Args:
        batch_size: batch size the dataset is sized for.
        device: device each item is moved to when accessed.
        seq_len: tokens per sequence (default 1024).
        vocab_size: exclusive upper bound on token ids (default 50257).
        n_batches: number of batches worth of samples (default 1000).
    """

    def __init__(self, batch_size, device, seq_len=1024, vocab_size=50257, n_batches=1000):
        super().__init__()
        self.n_samples = batch_size * n_batches
        self.input = torch.randint(0, vocab_size, (self.n_samples, seq_len))
        self.device = device

    def __getitem__(self, idx):
        # Moving to the device here only works with single-process loading
        # (num_workers=0) when `device` is a CUDA device.
        return self.input[idx].to(self.device)

    def __len__(self):
        return self.n_samples
    
# `dataset` is actually a DataLoader: each iteration yields a (batch_size, 1024)
# LongTensor already on `device`; incomplete final batches are dropped.
batch_size = 1
dataset = torch.utils.data.DataLoader(
    dataset=dataset_generator(batch_size, device),
    batch_size=batch_size,
    drop_last=True,
    shuffle=True,
)

In [None]:
def make_zero_grad_parameters(model, trainable_suffixes=('_scaler.a_param', '_scaler.b_param')):
    """Zero the gradients of every parameter except the scaler a/b parameters.

    Meant to run between ``loss.backward()`` and ``optimizer.step()`` so that only
    parameters whose name ends with one of ``trainable_suffixes`` keep a gradient.
    This does not fully freeze the other parameters: an optimizer with decoupled
    weight decay (e.g. AdamW) still updates parameters whose gradient is zero.

    Args:
        model: ``torch.nn.Module`` whose gradients are edited in place.
        trainable_suffixes: parameter-name suffixes whose gradients are kept.

    Returns:
        The same ``model``.
    """
    for name, param in model.named_parameters():
        # Parameters that took no part in backward (or are frozen) have grad None.
        if name.endswith(trainable_suffixes) or param.grad is None:
            continue
        # zero_() rather than `*= 0`: NaN/inf * 0 would leave NaN in the gradient.
        param.grad.zero_()
    return model

In [None]:
lr = 1e-3
# Only the DRAM scaler a/b parameters are trained. Freeze everything else:
# zeroing gradients is not enough, because AdamW's decoupled weight decay would
# still shrink every parameter handed to it, even with a zero gradient.
SCALER_SUFFIXES = ('_scaler.a_param', '_scaler.b_param')
scaler_params = []
for name, param in model.named_parameters():
    is_scaler = name.endswith(SCALER_SUFFIXES)
    param.requires_grad_(is_scaler)
    if is_scaler:
        scaler_params.append(param)
assert scaler_params, "no '*_scaler.a_param' / '*_scaler.b_param' parameters found in model"
optim = torch.optim.AdamW(scaler_params, lr=lr)
mse = nn.MSELoss()
n_iters = len(dataset)
losses = []
a_param = []
b_param = []


def _plot_progress(i, n_iters, loss_value, losses, a_vals, b_vals, target, pred):
    """Redraw the live training dashboard.

    Panels: loss curve; layer-0 output-scaler ``a`` and ``b`` traces (log scale);
    histograms of the target (linear model) and prediction (nonlinear model)
    for the first position of the first sample.
    """
    fig, axes = plt.subplots(2, 2, figsize=(5, 5))
    # wait=True keeps the previous output until the new one is drawn (no flicker).
    clear_output(wait=True)
    steps = np.arange(i + 1)

    axes[0, 0].plot(steps, np.array(losses), '-')
    axes[0, 0].set_title(f"iter: {i} ({i/n_iters*100:.2f}%), Loss: {loss_value:.4f}")
    axes[0, 0].set_xlabel('Iter')
    axes[0, 0].set_ylabel('Loss')

    axes[0, 1].plot(steps, np.array(a_vals))
    axes[0, 1].set_xlabel('a_param')
    axes[0, 1].set_yscale('log')

    axes[1, 0].hist(target[0][0].detach().cpu().numpy(), bins=100, alpha=0.5, label='Linear out')
    axes[1, 0].hist(pred[0][0].detach().cpu().numpy(), bins=100, alpha=0.5, label='Nonlinear out')
    axes[1, 0].legend(loc='upper right')
    axes[1, 0].set_title('Distribution after training')
    axes[1, 0].set_xlabel('Value')
    axes[1, 0].set_ylabel('Frequency')

    axes[1, 1].plot(steps, np.array(b_vals))
    axes[1, 1].set_xlabel('b_param')
    axes[1, 1].set_yscale('log')

    plt.tight_layout()
    plt.show()
    sys.stdout.flush()
    plt.pause(0.1)  # give interactive backends a chance to redraw
    plt.close(fig)  # do not accumulate open figures over the run


for i, x in enumerate(dataset):
    optim.zero_grad()
    # The linear model is only a fixed target: don't build an autograd graph for it.
    with torch.no_grad():
        target = model_linear(x, x)[0]
    pred = model(x, x)[0]
    loss = mse(pred, target)
    loss_value = loss.item()
    losses.append(loss_value)
    a_param.append(model.transformer.h[0].attn.output_scaler.a.item())
    b_param.append(model.transformer.h[0].attn.output_scaler.b.item())
    loss.backward()
    optim.step()
    print(f"iter: {i} ({i/n_iters*100:.2f}%)\tLoss: {loss_value:.4f}")
    if i % 10 == 0:
        _plot_progress(i, n_iters, loss_value, losses, a_param, b_param, target, pred)


In [None]:
# Plot histograms
plt.hist(target[0][0].detach().cpu().numpy(), bins=100, alpha=0.5, label='Linear out')
plt.hist(pred[0][0].detach().cpu().numpy(), bins=100, alpha=0.5, label='NL out')
plt.legend(loc='upper right')
plt.title('Distribution after training')
plt.xlabel('Value')
plt.ylabel('Frequency')
plt.show()