|<h2>Book:</h2>|<h1><a href="https://open.substack.com/pub/mikexcohen/p/llm-breakdown-16-tokenization-words" target="_blank">50 ML projects to understand LLMs</a></h1>|
|-|:-:|
|<h2>Project:</h2>|<h1><b>[29] Hidden state dimensionality reduction</b></h1>|
|<h2>Author:</h2>|<h1>Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h1>|

<br>

<i>Using the code without reading the book may lead to confusion or errors.</i>

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib.gridspec import GridSpec

# for part 1
from PIL import Image
import requests
from io import BytesIO

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

In [None]:
### matplotlib adjustments (uncomment the dark-mode entries below for a dark theme)

# render inline figures as SVG so they stay sharp at any zoom level
import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

# global style applied to every figure in this notebook
style_overrides = {
    # dark mode:
    # 'figure.facecolor': '#282a2c',
    # 'figure.edgecolor': '#282a2c',
    # 'axes.facecolor':   '#282a2c',
    # 'axes.edgecolor':   '#DDE2F4',
    # 'axes.labelcolor':  '#DDE2F4',
    # 'xtick.color':      '#DDE2F4',
    # 'ytick.color':      '#DDE2F4',
    # 'text.color':       '#DDE2F4',
    'axes.spines.right': False,
    'axes.spines.top':   False,
    'axes.titleweight':  'bold',
    'axes.labelweight':  'bold',
    'savefig.dpi':       300,
}
plt.rcParams.update(style_overrides)

# **Part 1: Reducing dimensions while preserving shape**

In [None]:
# photo of the Stedelijk Museum Amsterdam (the building described in the Part 2 text)
url = 'https://upload.wikimedia.org/wikipedia/commons/6/61/De_nieuwe_vleugel_van_het_Stedelijk_Museum_Amsterdam.jpg'

# identify the client via User-Agent, don't hang forever, and fail with a clear HTTP error
# (instead of an obscure PIL decoding error) if the download did not succeed
response = requests.get(url,headers={'User-Agent':'Mozilla/5.0 (compatible; ImageLoader/1.0)'},timeout=30)
response.raise_for_status()
img = Image.open(BytesIO(response.content))

# convert to grayscale and then to a float numpy matrix (rows x columns of pixels)
img = np.array( img.convert('L') ,dtype=float)
img = img[::2,::2] # downsample by 2 along each dimension

plt.imshow(img,cmap='gray')

plt.tight_layout()
plt.savefig('ch5_proj29_part1a.png')
plt.show()

In [None]:
# full SVD of the image: img = U @ Sigma @ Vh
U,s,Vh = np.linalg.svd(img)


# figure layout: U, Sigma and V^T across the top row; the singular-value spectrum spans the bottom row
fig = plt.figure(figsize=(8,5))
gs = GridSpec(2,3,figure=fig)
ax1,ax2,ax3 = (fig.add_subplot(gs[0,col]) for col in range(3))
ax4 = fig.add_subplot(gs[1,:])

# left and right singular vectors share one diverging colour scale
for ax,mat,label in ((ax1,U,'U'),(ax3,Vh,'V$^{\\top}$')):
  ax.imshow(mat,vmin=-.01,vmax=.01,cmap='plasma')
  ax.set(title=label)

# singular values sit on the diagonal of Sigma
ax2.imshow(np.diag(s),vmin=0,vmax=200)
ax2.set(title='$\\mathbf{\\Sigma}$')

# singular-value spectrum
ax4.plot(s,'ks-',markerfacecolor=[.9,.7,.7],markersize=8)
ax4.set(xlabel='Singular values (index)',ylabel='$\\sigma$')

plt.tight_layout()
plt.savefig('ch5_proj29_part1b.png')
plt.show()

In [None]:
# percent of the total sum of squares captured by each component (squared singular values),
# and its running total: cum_pct[i] is the percentage captured by the first i+1 components
pct = 100*s**2/np.sum(s**2)
cum_pct = np.cumsum(pct)


# components to keep: copy the first comps2keep singular values onto the diagonal of a zero Sigma
comps2keep = 50
S = np.zeros_like(img)
for i in range(comps2keep):
  S[i,i] = s[i]

# reconstruct the matrix from the truncated SVD
recon_img1 = U@S@Vh


_,axs = plt.subplots(1,2,figsize=(10,4))
axs[0].imshow(img,cmap='gray')
axs[0].set(title='Original image')

# components 0..comps2keep-1 were kept, so the variance they explain is cum_pct[comps2keep-1]
axs[1].imshow(recon_img1,cmap='gray')
axs[1].set(title=f'Recon. with {comps2keep}/{len(s)} components ({cum_pct[comps2keep-1]:.2f}% var.)')

plt.tight_layout()
plt.savefig('ch5_proj29_part1c.png')
plt.show()

In [None]:
# percent variance to keep
thresh = 98

# smallest number of components whose cumulative variance exceeds thresh:
# searchsorted(side='right') gives the first index with cum_pct > thresh, and cum_pct[i]
# covers i+1 components, hence the +1 (capped at the total number of components)
r = int(min(np.searchsorted(cum_pct,thresh,side='right')+1, len(s)))

# loopless method: first r singular values on the diagonal of an otherwise-zero Sigma
S = np.zeros_like(img)
S[:r,:r] = np.diag(s[:r])

# reconstruction
recon_img2 = U@S@Vh

_,axs = plt.subplots(1,2,figsize=(10,4))
axs[0].imshow(img,cmap='gray')
axs[0].set(title='Original image')

# r components (indices 0..r-1) were kept, so their variance is cum_pct[r-1]
axs[1].imshow(recon_img2,cmap='gray')
axs[1].set(title=f'Recon. with {cum_pct[r-1]:.2f}% variance ({r} comps)')

plt.tight_layout()
plt.savefig('ch5_proj29_part1d.png')
plt.show()

In [None]:
# numerical rank of the original image and of the two truncated reconstructions;
# the reconstructions' ranks are expected to match the number of components kept
print(f'Rank of original matrix: {np.linalg.matrix_rank(img)}')
print(f'Rank of r = {comps2keep} matrix  : {np.linalg.matrix_rank(recon_img1)}')
print(f'Rank of p = {thresh}% matrix : {np.linalg.matrix_rank(recon_img2)}')

# **Part 2: Dimension-reduce one hidden state**

In [None]:
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') # first GPU when present, otherwise CPU

In [None]:
# EleutherAI's GPT-Neo: tokenizer and causal-LM weights (swap '1.3b' for '125m' to use the smaller model)
model_id = 'EleutherAI/gpt-neo-1.3b'
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# inference mode, then move the weights to the selected device
model.eval()
model.to(device)

In [None]:
# a layer index that matches no transformer block, so any reduction hook leaves every layer unchanged
layer2reduce = 4000

# source text: https://en.wikipedia.org/wiki/Stedelijk_Museum_Amsterdam
text = 'The Stedelijk Museum Amsterdam (Dutch pronunciation: [ˈsteːdələk myˈzeːjʏm ˌɑmstərˈdɑm]; Municipal Museum Amsterdam), colloquially known as the Stedelijk, is a museum for modern art, contemporary art, and design located in Amsterdam, Netherlands.[6] The 19th-century building was designed by Adriaan Willem Weissman. The connecting 21st-century wing, which houses the current entrance, was designed by Benthem Crouwel Architects. The museum is located at the Museum Square in the borough Amsterdam South,[2] where it is close to the Van Gogh Museum, the Rijksmuseum, and the Concertgebouw.'
tokens = tokenizer.encode(text,return_tensors='pt').to(device)

# clean forward pass, keeping the hidden state of every layer
with torch.no_grad():
  outs_clean = model(tokens,output_hidden_states=True)

outs_clean.hidden_states[0].shape

In [None]:
# hidden state at index 4 of the clean run, first (only) sequence in the batch: (tokens x embedding dims)
hs = outs_clean.hidden_states[4][0,:,:].cpu()

# double-centre before the SVD: subtract per-dimension means (across tokens) and per-token means
# (across dimensions), then add back the grand mean that was subtracted twice
vectmeans0 = hs.mean(dim=0,keepdim=True)
vectmeans1 = hs.mean(dim=1,keepdim=True)
globalmean = hs.mean()
hs_demean = hs - vectmeans0 - vectmeans1 + globalmean
U,s,Vh = torch.linalg.svd(hs_demean)

# figure layout: U, Sigma and V^T on top; the singular-value spectrum along the bottom
fig = plt.figure(figsize=(8,5))
gs = GridSpec(2,3,figure=fig)
ax1,ax2,ax3 = (fig.add_subplot(gs[0,col]) for col in range(3))
ax4 = fig.add_subplot(gs[1,:])

# singular-vector matrices; U is shown with a wider colour range than V^T
for ax,mat,lim,label in ((ax1,U,.1,'U'),(ax3,Vh,.01,'V$^{\\top}$')):
  ax.imshow(mat,vmin=-lim,vmax=lim,cmap='plasma')
  ax.set(title=label)

# singular values on the diagonal of Sigma
ax2.imshow(torch.diag(s),vmin=0,vmax=100)
ax2.set(title='$\\mathbf{\\Sigma}$')

# singular-value spectrum
ax4.plot(s,'ks-',markerfacecolor=[.9,.7,.7],markersize=8)
ax4.set(xlabel='Singular values (index)',ylabel='$\\sigma$')

plt.tight_layout()
plt.savefig('ch5_proj29_part2a.png')
plt.show()

In [None]:
# reduction threshold (percent of variance to retain)
svd_thresh = 95

# cumulative percent variance; cum_pct[i] covers the first i+1 components
cum_pct = torch.cumsum(100*s**2/torch.sum(s**2),dim=-1)

# smallest number of components whose cumulative variance exceeds the threshold
# (first index above threshold, +1 to turn it into a count; all components if never exceeded)
above = torch.nonzero(cum_pct>svd_thresh)
r = int(above[0])+1 if len(above) else len(s)

# create a new Sigma matrix holding only the first r singular values
S = torch.zeros(hs.shape,dtype=s.dtype)
S[:r,:r] = torch.diag(s[:r])

# reconstruct data (with means added back)
hs_R = U@S@Vh + vectmeans0 + vectmeans1 - globalmean

print('hidden state shape:\n ',hs.shape,'\n')
print('reconstruction shape:\n ',hs_R.shape)

In [None]:
# original hidden state, its low-rank reconstruction and their difference on a common colour scale
_,axs = plt.subplots(2,2,figsize=(12,8))

heatmap_panels = [
  (axs[0,0], hs,      'Original HS'),
  (axs[0,1], hs_R,    f'Reconstruction (r={r}, {svd_thresh}% var.)'),
  (axs[1,0], hs-hs_R, 'Difference'),
]
for ax,data,panel_title in heatmap_panels:
  ax.imshow(data,aspect='auto',vmin=-2,vmax=2,cmap='rainbow')
  ax.set(xlabel='Embeddings indices',ylabel='Token position',title=panel_title)

# element-wise scatter (every 10th value) with the Pearson correlation in the title
orig_vals = hs.flatten()
recon_vals = hs_R.flatten()
axs[1,1].plot(orig_vals[::10],recon_vals[::10],'ko',markersize=8,markerfacecolor=[.7,.9,.7,.3])
axs[1,1].set(xlabel='Original HS',ylabel='Reconstructed HS',
              title=f'r = {np.corrcoef(orig_vals,recon_vals)[0,1]:.5f}')

plt.tight_layout()
plt.savefig('ch5_proj29_part2b.png')
plt.show()

In [None]:
# numerical rank before and after the SVD truncation
rank_original = torch.linalg.matrix_rank(hs)
rank_reduced = torch.linalg.matrix_rank(hs_R)
print(f'Rank of original hidden-state matrix: {rank_original}')
print(f'Rank of reduced hidden-state matrix : {rank_reduced:3}')

# **Part 3: A hook to reduce hidden state dimensionality**

In [None]:
svd_thresh = 99 # percent variance to retain in the reduced layer
layer2reduce = 4 # block index to reduce; an out-of-range value disables the reduction
compsCount = []  # components kept, appended once per reduced forward pass

# hooking functions
def implant_hook(layer_number):
  """Return a forward hook that SVD-reduces the hidden state produced by block `layer_number`.

  The hook only acts when `layer_number == layer2reduce` (a global read at call time);
  otherwise it returns the block's output unchanged. When active it:
    - double-centres the first sequence's hidden state,
    - keeps the smallest number of singular components whose cumulative variance
      exceeds `svd_thresh` percent and appends that count to the global `compsCount`,
    - returns a copy of the output whose hidden state is the reconstruction (means added back).
  """
  def hook(module, args, output):

    # only change one layer
    if layer_number != layer2reduce:
      return output

    # depending on the transformers version, a block returns (hidden, *extras) or just the tensor
    is_tuple = isinstance(output, tuple)
    hidden = output[0] if is_tuple else output

    # hidden state of the first sequence in the batch (tokens x dims), detached from autograd
    hs = hidden[0,:,:].detach()
    vectmeans0 = hs.mean(dim=0,keepdim=True) # per-dimension means across tokens
    vectmeans1 = hs.mean(dim=1,keepdim=True) # per-token means across dimensions
    globalmean = hs.mean()
    hs_demean = hs - vectmeans0 - vectmeans1 + globalmean

    # SVD
    U,s,Vh = torch.linalg.svd(hs_demean)

    # smallest number of components whose cumulative variance exceeds svd_thresh
    # (cum_pct[i] covers the first i+1 components, hence the +1)
    cum_pct = torch.cumsum(100*s**2/torch.sum(s**2),dim=-1)
    above = torch.nonzero(cum_pct>svd_thresh)
    r = int(above[0])+1 if len(above) else len(s)
    compsCount.append(r) # plain Python int, so the list lives on the CPU

    # Sigma matrix holding only the first r singular values
    S = torch.zeros(hs.shape,device=hs.device,dtype=hs.dtype)
    S[:r,:r] = torch.diag(s[:r])

    # print some info (r components = indices 0..r-1, whose variance is cum_pct[r-1])
    print(f'Changed layer {layer_number:2} with {cum_pct[r-1]:.2f}% variance ({r:2}/{len(s)} comps)')

    # reconstruct data (with means added back) in a copy, leaving the block's tensor untouched
    hidden_copy = hidden.clone()
    hidden_copy[0,:,:] = U@S@Vh + vectmeans0 + vectmeans1 - globalmean

    return (hidden_copy, *output[1:]) if is_tuple else hidden_copy

  return hook


# remove hooks left over from a previous run of this cell, so re-running it doesn't stack duplicates
for old_handle in globals().get('handles',[]):
  old_handle.remove()

# loop over layers and attach a hook to every transformer block
handles = []
for layeri in range(model.config.num_layers):
  baselayer = model.transformer.h[layeri]
  h = baselayer.register_forward_hook(implant_hook(layeri))
  handles.append(h)

In [None]:
# reduced run: the hook on block `layer2reduce` replaces its hidden state with the SVD reconstruction
layer2reduce = 4
with torch.set_grad_enabled(False):
  outs_reduced = model(tokens,output_hidden_states=True)

outs_reduced.hidden_states[0].shape

In [None]:
# Frobenius norm of (clean - reduced) hidden states for each block's output
# (hidden_states[k+1] follows block k, so the first entry is skipped)
diffnorms = torch.tensor([
  torch.norm(h_clean-h_reduced).item()
  for h_clean,h_reduced in zip(outs_clean.hidden_states[1:],outs_reduced.hidden_states[1:])
])

# and plot, marking the reduced layer with a dashed line
fig,ax = plt.subplots(figsize=(10,3))
ax.plot(diffnorms,'kh',markerfacecolor=[.9,.7,.7],markersize=12)
ax.axvline(layer2reduce,color='k',linestyle='--',zorder=-3)
ax.axhline(0,color='k',linestyle=':',zorder=-10)
ax.set(xlabel='Layer',ylabel='Norm of difference')

plt.tight_layout()
plt.savefig('ch5_proj29_part3.png')
plt.show()

# **Part 4: Laminar-sweep reduction experiment**

In [None]:
# second clean run, now with labels so the model also computes a loss;
# layer2reduce matches no block index, so every registered hook passes its output through unchanged
layer2reduce = 4000
with torch.set_grad_enabled(False):
  outs_clean = model(tokens,labels=tokens,output_hidden_states=True)

In [None]:
# initializations: diffnorms row 0 = hidden-state diff norms, row 1 = logit diff norms
diffnorms = np.zeros((2,model.config.num_layers))
pertoken_diff = np.zeros((model.config.num_layers,tokens.shape[1]))
losses = np.zeros(model.config.num_layers)

# threshold (was 99% in part 3)
svd_thresh = 95


# reset the per-pass component counts appended by the hook
compsCount = []

logits_clean = outs_clean.logits[0].detach().cpu().numpy()

# loop over layers, reducing one layer per forward pass (the hooks read the global layer2reduce)
for layer2reduce in range(model.config.num_layers):

  # run the model
  with torch.no_grad():
    outs_reduced = model(tokens,output_hidden_states=True,labels=tokens)

  # hidden states right after the reduced layer, kept as (batch, tokens, dims)
  hs_clean = outs_clean.hidden_states[layer2reduce+1].detach().cpu().numpy()
  hs_compr = outs_reduced.hidden_states[layer2reduce+1].detach().cpu().numpy()

  # norm of hidden-state differences across the entire hidden state matrix
  diffnorms[0,layer2reduce] = np.linalg.norm(hs_clean-hs_compr)

  # per-token difference analysis (first sequence of the batch -> one value per token)
  pertoken_diff[layer2reduce,:] = np.linalg.norm(hs_clean-hs_compr,axis=-1)[0]

  # norm of logit differences
  logits_compr = outs_reduced.logits[0].detach().cpu().numpy()
  diffnorms[1,layer2reduce] = np.linalg.norm(logits_compr-logits_clean)

  # loss
  losses[layer2reduce] = outs_reduced.loss.item()

# the hooks stay registered: switch reduction off so any later forward pass is clean
layer2reduce = 4000

# each pass should have reduced exactly one layer; more entries mean hooks were registered twice
assert len(compsCount)==model.config.num_layers, \
  f'expected {model.config.num_layers} component counts, got {len(compsCount)} (duplicate hooks?)'

In [None]:
_,axs = plt.subplots(2,2,figsize=(10,6))

# A) number of components the hook kept in each layer to reach the variance threshold
axs[0,0].plot(compsCount,'kh-',markerfacecolor=[.9,.7,.7],linewidth=.5,markersize=12)
axs[0,0].set(xlabel='Layer',ylabel='Count',title=f'A) Components to achieve {svd_thresh}% variance')

# B) hidden-state and logit difference norms, each scaled by its own maximum
axs[0,1].plot(diffnorms[0,:]/diffnorms[0,:].max(),'bs-',linewidth=.5,markerfacecolor=[.7,.7,.9,.7],markersize=12,label='Hidden states')
axs[0,1].plot(diffnorms[1,:]/diffnorms[1,:].max(),'ro-',linewidth=.5,markerfacecolor=[.9,.7,.7,.7],markersize=12,label='Logits')
axs[0,1].set(xlabel='Layer',ylabel='Diff norm (max-val scaled)',
              title='B) Impact on logits and hidden states')
axs[0,1].legend()

# C) loss of the reduced model as a function of which layer was reduced
axs[1,0].plot(losses,'kp-',linewidth=.5,markerfacecolor=[.7,.9,.7],markersize=12)
axs[1,0].set(xlabel='Layer',ylabel='Loss',title='C) Impact of reduction on loss')


# D) retained dimensionality vs. loss, coloured by layer index.
# Layer indices run 0..num_layers-1, so that is the colour range.
norm = mpl.colors.Normalize(vmin=0,vmax=model.config.num_layers-1)
# percent of the maximum possible rank, min(tokens, dims); hs_compr is (batch, tokens, dims)
hs_dim = 100*np.array(compsCount)/min(hs_compr.shape[1:])
axs[1,1].scatter(hs_dim,losses,color=plt.cm.CMRmap(norm(np.arange(model.config.num_layers))),
                 alpha=.8,s=150,edgecolor='k')

sm = mpl.cm.ScalarMappable(cmap=mpl.cm.CMRmap,norm=norm)
cbar = plt.colorbar(sm,ax=axs[1,1],pad=.02)
cbar.set_label(r'Hidden layer')


axs[1,1].set(ylabel='Loss',xlabel='Hidden state dimensionality (% total)',title='D) Reduction vs. loss')

plt.tight_layout()
plt.savefig('ch5_proj29_part4a.png')
plt.show()

In [None]:
# per-token distance between clean and reduced hidden states, one row per reduced layer
fig,ax = plt.subplots(figsize=(11,4))

heatmap = ax.imshow(pertoken_diff,aspect='auto',origin='lower',cmap='magma',vmin=20,vmax=120)
ax.set(ylabel='Layer',xlabel='Token index',title='Impact of reduction on hidden states')
fig.colorbar(heatmap,ax=ax,pad=.01)

plt.tight_layout()
plt.savefig('ch5_proj29_part4b.png')
plt.show()