|<h2>Book:</h2>|<h1><a href="https://open.substack.com/pub/mikexcohen/p/llm-breakdown-16-tokenization-words" target="_blank">50 ML projects to understand LLMs</a></h1>|
|-|:-:|
|<h2>Project:</h2>|<h1><b>[21] Predict token position with linear and logistic regressions</b></h1>|
|<h2>Author:</h2>|<h1>Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h1>|

<br>

<i>Using the code without reading the book may lead to confusion or errors.</i>

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

from datasets import load_dataset

from sklearn.linear_model import LinearRegression,LogisticRegression
from sklearn.metrics import confusion_matrix
import statsmodels.api as sm
import seaborn as sns

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import AutoModelForCausalLM, AutoTokenizer

In [None]:
### matplotlib adjustments (commented lines are for dark mode)

# render inline figures as SVG (higher-res than the default raster output)
import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

# dark-mode palette: uncomment the following lines to enable it
# plt.rcParams['figure.facecolor'] = '#282a2c'
# plt.rcParams['figure.edgecolor'] = '#282a2c'
# plt.rcParams['axes.facecolor']   = '#282a2c'
# plt.rcParams['axes.edgecolor']   = '#DDE2F4'
# plt.rcParams['axes.labelcolor']  = '#DDE2F4'
# plt.rcParams['xtick.color']      = '#DDE2F4'
# plt.rcParams['ytick.color']      = '#DDE2F4'
# plt.rcParams['text.color']       = '#DDE2F4'

# defaults applied to every figure: no top/right spines, bold titles and
# axis labels, and 300 dpi for saved files
plt.rcParams['axes.spines.right'] = False
plt.rcParams['axes.spines.top'] = False
plt.rcParams['axes.titleweight'] = 'bold'
plt.rcParams['axes.labelweight'] = 'bold'
plt.rcParams['savefig.dpi'] = 300

# **Part 1: Create a dataset**

In [None]:
# fetch the pretrained GPT-2 weights and the matching tokenizer
checkpoint = 'gpt2'
model = AutoModelForCausalLM.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# switch the model to evaluation mode; as the last expression of the cell,
# the returned model is also displayed by the notebook
model.eval()

In [None]:
# download (or load from cache) the validation split of HellaSwag
dataset = load_dataset(path='hellaswag',split='validation')

# display the dataset object
dataset

In [None]:
# batch parameters
seqlen = 6      # number of leading tokens kept from each text
nSamples = 500  # number of texts, i.e., rows in the batch

# initialize the batch (integer dtype because it holds token indices)
batch = torch.zeros((nSamples,seqlen),dtype=torch.long)

# get the tokens for each sequence
for i in range(nSamples):
  firsttokens = tokenizer.encode(dataset[i]['ctx_a'])[:seqlen]

  # a text with fewer than seqlen tokens cannot fill a row: a single token
  # would be silently broadcast across the whole row, and any other length
  # fails with an uninformative shape-mismatch error
  if len(firsttokens) != seqlen:
    raise ValueError(f'Sample {i} has {len(firsttokens)} tokens but seqlen={seqlen} are required.')

  batch[i,:] = torch.tensor(firsttokens)

print(f'Shape of batch: {batch.shape}')
batch

In [None]:
# sanity check: convert one row of token indices back into text
tokenizer.decode(batch[13])

In [None]:
# push the whole batch through the model with gradient tracking
# switched off (16s on cpu)
with torch.no_grad():
  out = model(input_ids=batch)

# size of the logits returned by the model
out.logits.shape

In [None]:
# convert the logits into log-probabilities over the vocabulary
logits_logsm = F.log_softmax(out.logits,dim=-1)

# the prediction made at position t is scored against the token at t+1, so
# the first seqlen-1 positions are paired with the last seqlen-1 tokens
next_tokens = batch[:,1:].unsqueeze(-1)
target_logprobs = logits_logsm[:,:-1,:].gather(-1,next_tokens).squeeze(-1)

# single-token loss is -loglikelihood of target token
# (float64 matrix of size samples x seqlen-1)
losses = -target_logprobs.detach().cpu().double().numpy()

In [None]:
# every row holds the positions 0..seqlen-2; one row per sample
ord_position = np.repeat(np.arange(seqlen-1)[np.newaxis,:],nSamples,axis=0)
ord_position

In [None]:
# turn both matrices into column vectors (row-major order, so consecutive
# entries come from the same sample)
losses_flat = losses.flatten()[:,np.newaxis]
ord_position_flat = ord_position.flatten()[:,np.newaxis]

# report the sizes before and after flattening
print(f'           Shape of losses: {losses.shape}')
print(f'      Shape of losses_flat: {losses_flat.shape}\n')
print(f'     Shape of ord_position: {ord_position.shape}')
print(f'Shape of ord_position_flat: {ord_position_flat.shape}')

In [None]:
# setup the figure and axes (left panel spans two of the three grid columns)
fig = plt.figure(figsize=(12,4))
gs = GridSpec(1,3,figure=fig)
ax1 = fig.add_subplot(gs[:-1])
ax2 = fig.add_subplot(gs[-1])

# bar plot of the mean loss per position...
ax1.bar(range(seqlen-1),losses.mean(axis=0),color=[.7,.7,.9],edgecolor='k',linewidth=.5)

# ...and a scatter of all individual losses (small random x-offsets so that
# dots at the same position do not sit exactly on top of each other)
ax1.plot(ord_position_flat+np.random.normal(0,.03,ord_position_flat.shape),
         losses_flat,'ko',markerfacecolor=[.9,.7,.7,.5])

ax1.set(xlabel='Token position',ylabel='Loss',title='A) Next-token prediction loss')


# distributions
for i in range(seqlen-1):
  y,x = np.histogram(losses[:,i],bins=10,density=True)

  # np.histogram returns the bin edges (one more value than densities);
  # draw each density at the center of its bin instead of at its left edge,
  # which would shift the whole curve by half a bin width
  bin_centers = (x[:-1]+x[1:])/2
  ax2.plot(bin_centers,y,linewidth=2,label=f'Token {i}')

ax2.set(xlabel='Per-token loss',ylabel='Probability density',title='B) Distribution by token position')
ax2.legend()

plt.tight_layout()
plt.savefig('ch4_proj21_part1.png')
plt.show()

# **Part 2: Linear regression in statsmodels and sklearn**

In [None]:
# design matrix: the losses plus a constant column for the intercept term
X = sm.add_constant(losses_flat)

# ordinary least squares with token position as the dependent variable
smreg = sm.OLS(endog=ord_position_flat,exog=X).fit()
print( smreg.summary() )

In [None]:
# same regression with sklearn's LinearRegression
# the target is squeezed to 1D so that intercept_ is a scalar and coef_ is a
# 1D vector, which the formatted prints below rely on
reg = LinearRegression()
reg.fit(losses_flat,np.squeeze(ord_position_flat))

print(f'const: {reg.intercept_:7.4f}')
print(f'x1   : {reg.coef_[0]:7.4f}')

# **Part 3: Does linear regression reconstruct ordinal position?**

In [None]:
# positions predicted by the linear model from the losses
predictions = reg.predict(losses_flat)

# true positions with a little horizontal jitter to separate overlapping dots
_,ax = plt.subplots(figsize=(8,4))
jittered_truth = ord_position_flat + np.random.normal(0,.03,ord_position_flat.shape)

# predicted vs. true position, with ticks only at the valid positions
ax.plot(jittered_truth,predictions,'ko',markerfacecolor=[.7,.7,.9,.5])
ax.set(xlabel='True positions',ylabel='Predicted positions',
       xticks=range(seqlen-1),yticks=range(seqlen-1))

plt.tight_layout()
plt.savefig('ch4_proj21_part3.png')
plt.show()

# **Part 4: Fit a multinomial logistic regression**

In [None]:
# logistic regression with the loss as the only feature and the token
# position as the class label (1D target)
logreg = LogisticRegression(solver='lbfgs')
logreg.fit(losses_flat,np.squeeze(ord_position_flat))

# fitted parameters: one intercept and one slope per class
print('Intercepts (beta_0):\n',logreg.intercept_,'\n')
print('Coefficients (beta_1):\n',logreg.coef_.T)

In [None]:
# bar plots of the per-class slopes (left) and intercepts (right)
_,axs = plt.subplots(1,2,figsize=(12,4))
class_idx = range(len(logreg.coef_))

# (axis, values, bar color, title) for each panel
panels = (
  (axs[0],logreg.coef_.squeeze(),     [.9,.7,.7],'A) Relative slopes'),
  (axs[1],logreg.intercept_.squeeze(),[.7,.7,.9],'B) Relative intercepts'),
)

for ax,values,barcolor,title in panels:
  # dashed reference line at zero, then the bars
  ax.axhline(y=0,color='k',linestyle='--',linewidth=.5)
  ax.bar(class_idx,values,color=barcolor,edgecolor='k',linewidth=.5)
  ax.set(xlabel='Token position',ylabel='Coefficient value',xticks=range(seqlen-1),title=title)

plt.tight_layout()
plt.savefig('ch4_proj21_part4.png')
plt.show()

# **Part 5: Categorical predictions**

In [None]:
# class probabilities: one row per token, one column per position
predictions = logreg.predict_proba(losses_flat)

# the probabilities in each row are summed as a sanity check
row_sums = predictions.sum(axis=1)

print(f'Shape of predictions: {predictions.shape}\n')
print(f'Sum of predictions for each token:\n{row_sums}')

In [None]:
# the predicted class is the column with the highest score in each row
predicted_categories = predictions.argmax(axis=1)
predicted_categories

In [None]:
# confusion matrix: rows are true positions, columns are predicted positions
cm = confusion_matrix(y_true=ord_position_flat.squeeze(),y_pred=predicted_categories)

# convert the counts in each row into percentages of that row's total
cm_norm = 100 * cm / cm.sum(axis=1,keepdims=True)

In [None]:
# visualize per-category (token position) prediction accuracy
_,axs = plt.subplots(1,2,figsize=(10,4))

# 1D vector of true positions. Comparing the 1D predictions against the
# (N,1) column vector would broadcast into an N-by-N matrix instead of an
# element-wise comparison.
true_positions = ord_position_flat.squeeze()

for i in range(seqlen-1):
  # tokens whose true position is i, and the percentage predicted correctly
  idxs = true_positions==i
  accuracy = 100*np.mean(predicted_categories[idxs]==true_positions[idxs])

  axs[0].bar(i,accuracy)
  axs[0].text(i,1+accuracy,f'{accuracy:.1f}%',
              fontsize=12,fontweight='bold',ha='center',va='bottom')

# chance-level performance (100% divided by the number of positions)
axs[0].axhline(100/(ord_position_flat.max()+1),linestyle='--',color='k',linewidth=.5,zorder=-10)

axs[0].set(xlabel='Token position',ylabel='Prediction accuracy (%)',xticks=range(seqlen-1),
           title='A) Category-specific prediction accuracy')


# and the confusion matrix
sns.heatmap(cm_norm,annot=True,fmt='.1f',cmap='Reds',annot_kws={'size': 15},ax=axs[1])
axs[1].set(xlabel='Predicted position',ylabel='True position',
              title='B) Confusion matrix (% row-wise)')
plt.suptitle('TRAIN test performance',fontsize=16,fontweight='bold')

plt.tight_layout()
plt.savefig('ch4_proj21_part5.png')
plt.show()

# **Part 6: Test on new (untrained) data**

In [None]:
# get a new batch of data (parameters defined in Part 1)
# NOTE: the rows of `batch` are overwritten in place, so the tokens from
# Part 1 are no longer available after this cell has run

# get the tokens for each sample (note dataset indexing: the offset of 1000
# selects different texts than the ones indexed from 0 in Part 1)
for i in range(nSamples):
  firsttokens = tokenizer.encode(dataset[1000+i]['ctx_a'])[:seqlen]

  # a text with fewer than seqlen tokens cannot fill a row: a single token
  # would be silently broadcast across the whole row, and any other length
  # fails with an uninformative shape-mismatch error
  if len(firsttokens) != seqlen:
    raise ValueError(f'Sample {1000+i} has {len(firsttokens)} tokens but seqlen={seqlen} are required.')

  batch[i,:] = torch.tensor(firsttokens)

# forward pass to get logits (16s on cpu)
with torch.no_grad():
  out = model(batch)

In [None]:
# convert the logits into log-probabilities over the vocabulary
logits_logsm = F.log_softmax(out.logits,dim=-1)

# index grids that broadcast to (nSamples, seqlen-1): the sample, the
# position making the prediction, and the token that actually came next
sample_idx = torch.arange(nSamples).unsqueeze(1)
position_idx = torch.arange(seqlen-1).unsqueeze(0)
next_token = batch[:,1:]

# single-token loss is -loglikelihood of target token (float64 matrix)
target_logprobs = logits_logsm[sample_idx,position_idx,next_token]
losses_test = -target_logprobs.detach().cpu().double().numpy()

# column vector with one row per token
losses_test_flat = losses_test.flatten()[:,np.newaxis]

In [None]:
# score the new losses with the logistic regression model from the first dataset
predictions = logreg.predict_proba(losses_test_flat)

# the predicted position is the class with the highest score
predicted_categories = predictions.argmax(axis=1)

In [None]:
# confusion matrix: rows are true positions, columns are predicted positions
# (ord_position_flat is reused as the vector of true positions)
cm = confusion_matrix(y_true=ord_position_flat.squeeze(),y_pred=predicted_categories)

# convert the counts in each row into percentages of that row's total
cm_norm = 100 * cm / cm.sum(axis=1,keepdims=True)

In [None]:
# visualize per-category (token position) prediction accuracy
_,axs = plt.subplots(1,2,figsize=(10,4))

# 1D vector of true positions. Comparing the 1D predictions against the
# (N,1) column vector would broadcast into an N-by-N matrix instead of an
# element-wise comparison.
true_positions = ord_position_flat.squeeze()

for i in range(seqlen-1):
  # tokens whose true position is i, and the percentage predicted correctly
  idxs = true_positions==i
  accuracy = 100*np.mean(predicted_categories[idxs]==true_positions[idxs])

  axs[0].bar(i,accuracy)
  axs[0].text(i,1+accuracy,f'{accuracy:.1f}%',
              fontsize=12,fontweight='bold',ha='center',va='bottom')

# chance-level performance (100% divided by the number of positions)
axs[0].axhline(100/(1+ord_position_flat.max()),linestyle='--',color='k',linewidth=.5,zorder=-10)

axs[0].set(xlabel='Token position',ylabel='Prediction accuracy (%)',xticks=range(seqlen-1),
           title='A) Category-specific prediction accuracy')


# and the confusion matrix
sns.heatmap(cm_norm,annot=True,fmt='.1f',cmap='Reds',annot_kws={'size': 15},ax=axs[1])
axs[1].set(xlabel='Predicted position',ylabel='True position',
              title='B) Confusion matrix (% row-wise)')

plt.suptitle('TEST test performance',fontsize=16,fontweight='bold')

plt.tight_layout()
plt.savefig('ch4_proj21_part6.png')
plt.show()