In [None]:
import torch
import numpy as np
from huggingface_hub import hf_hub_download
import gc  # garbage collection needed for low RAM footprint

# Fetch both checkpoint shards of Mistral-7B
# (https://huggingface.co/mistralai/Mistral-7B-v0.1) into the current directory.
shard_files = ('pytorch_model-00001-of-00002.bin', 'pytorch_model-00002-of-00002.bin')
for shard_file in shard_files:
  hf_hub_download(repo_id='mistralai/Mistral-7B-v0.1', filename=shard_file, local_dir='.')

# Load the shards in order; mmap is used to keep the RAM footprint low.
# m1 holds the first shard, m2 the second.
m1, m2 = [torch.load(shard_file, weights_only=True, mmap=True) for shard_file in shard_files]

def get_weights(model, layer, name):
  """Return one weight matrix of the model as a float64 NumPy array.

  Args:
    model: mapping from parameter names (prefixed with 'model.') to torch
      tensors, e.g. a loaded checkpoint shard.
    layer: index of the transformer layer to read from; ignored for 'E'.
    name: which matrix to fetch:
      'Q', 'K', 'V' -> self_attn.q_proj / k_proj / v_proj
      'P'           -> self_attn.o_proj
      'O'           -> mlp.down_proj
      'E'           -> embed_tokens (not tied to a layer)

  Returns:
    The weights converted to float64. Every matrix except 'E' is returned
    transposed relative to how it is stored; 'E' is returned as stored.

  Raises:
    ValueError: if `name` is not one of the supported identifiers.
    KeyError: if `model` is a dict that lacks the requested parameter.
  """
  # '{}' is filled with the layer index; the 'E' entry has no placeholder,
  # so str.format() leaves it unchanged.
  suffix_templates = {
    'Q': 'layers.{}.self_attn.q_proj.weight',
    'K': 'layers.{}.self_attn.k_proj.weight',
    'V': 'layers.{}.self_attn.v_proj.weight',
    'P': 'layers.{}.self_attn.o_proj.weight',
    'O': 'layers.{}.mlp.down_proj.weight',
    'E': 'embed_tokens.weight',
  }
  if name not in suffix_templates:
    # fail loudly instead of hitting an unbound local variable further down
    raise ValueError(
      f'unknown weight name {name!r}, expected one of {sorted(suffix_templates)}')
  suffix = suffix_templates[name].format(layer)

  W = model['model.' + suffix].to(torch.float64).numpy()  # convert to float64
  return W if name == 'E' else W.T  # transpose weights, except for 'E'

# Merge both shards into one name -> tensor mapping. This only copies
# references (the tensors themselves stay memory-mapped) and makes every
# lookup independent of where the checkpoint was split: a layer's attention
# weights and the previous layer's down_proj need not live in the same shard.
weights = {**m1, **m2}

for layer in range(0, 32):
  print('layer', layer)

  # get attention weights Q, K, V, P of the current layer
  Q = get_weights(weights, layer, 'Q')
  K = get_weights(weights, layer, 'K')
  V = get_weights(weights, layer, 'V')
  P = get_weights(weights, layer, 'P')

  # O is the matrix in front of this layer's attention: the previous layer's
  # mlp.down_proj, or the embedding for the 1st layer (where the layer index
  # passed to get_weights is ignored)
  O = get_weights(weights, layer - 1, 'E' if layer == 0 else 'O')

  # check if weight elimination is numerically identical:
  # fold Q into O (O* = O @ Q) and Q^-1 into K and V, so that
  # O* @ K* = O @ Q @ Q^-1 @ K should equal O @ K (same for V)
  Q_inv = np.linalg.inv(Q)  # raises LinAlgError if Q is singular
  K_star = Q_inv @ K
  V_star = Q_inv @ V
  O_star = O @ Q
  print('   is O* @ K* close to O @ K ?  ', np.allclose(O_star @ K_star, O @ K))
  print('   is O* @ V* close to O @ V ?  ', np.allclose(O_star @ V_star, O @ V))

  # also check if P is invertible; the result is computed only for this check
  P_inv = np.linalg.inv(P)  # raises LinAlgError if P is singular

# garbage collection (to avoid colab's RAM limit)
del m1, m2, weights, Q, K, V, P, O, Q_inv, P_inv, K_star, V_star, O_star
gc.collect()