|<h2>Book:</h2>|<h1><a href="https://open.substack.com/pub/mikexcohen/p/llm-breakdown-16-tokenization-words" target="_blank">50 ML projects to understand LLMs</a></h1>|
|-|:-:|
|<h2>Project:</h2>|<h1><b>[21] Predict token position with linear and logistic regressions</b></h1>|
|<h2>Author:</h2>|<h1>Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h1>|

<br>

<i>Using the code without reading the book may lead to confusion or errors.</i>

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

from datasets import load_dataset

from sklearn.linear_model import LinearRegression,LogisticRegression
from sklearn.metrics import confusion_matrix
import statsmodels.api as sm
import seaborn as sns

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import AutoModelForCausalLM, AutoTokenizer

In [None]:
### matplotlib adjustments (commented lines are for dark mode)

# svg plots (higher-res)
import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

# dark-mode colors (uncomment to enable)
# plt.rcParams['figure.facecolor'] = '#282a2c'
# plt.rcParams['figure.edgecolor'] = '#282a2c'
# plt.rcParams['axes.facecolor']   = '#282a2c'
# plt.rcParams['axes.edgecolor']   = '#DDE2F4'
# plt.rcParams['axes.labelcolor']  = '#DDE2F4'
# plt.rcParams['xtick.color']      = '#DDE2F4'
# plt.rcParams['ytick.color']      = '#DDE2F4'
# plt.rcParams['text.color']       = '#DDE2F4'

# hide the top and right axis spines
plt.rcParams['axes.spines.right'] = False
plt.rcParams['axes.spines.top']   = False

# bold titles and axis labels
plt.rcParams['axes.titleweight'] = 'bold'
plt.rcParams['axes.labelweight'] = 'bold'

# resolution used by plt.savefig
plt.rcParams['savefig.dpi'] = 300

# **Part 1: Create a dataset**

In [None]:
# load the pretrained GPT-2 tokenizer
tokenizer = AutoTokenizer.from_pretrained('gpt2')

# load the pretrained GPT-2 model and switch it to evaluation mode
# (.eval() returns the module itself, so it can be chained onto the load)
model = AutoModelForCausalLM.from_pretrained('gpt2').eval()

# bare name as the last line so the notebook still displays the model
model

In [None]:
# import the HellaSwag validation set
# (dataset id 'hellaswag', split 'validation', fetched through datasets.load_dataset)
dataset = load_dataset('hellaswag',split='validation')
dataset # bare name as the last line so the notebook displays the object

In [None]:
# Build the token batch: one row per sample, filled inside the loop below.
# NOTE(review): the empty arguments and right-hand sides in this cell look like
# intentional fill-in-the-blank placeholders; the cell is not valid Python
# until they are completed.

# batch parameters
seqlen = 6 # presumably the number of tokens kept per sequence -- TODO confirm
nSamples = 500 # number of rows written by the loop below

# initialize the batch
# (shape and dtype are left blank; presumably (nSamples,seqlen) with an integer dtype -- TODO confirm)
batch = torch.zeros((,),dtype=)

# get the tokens for each sequence
for i in range(nSamples):
  firsttokens =
  batch[i,:] =

print(f'Shape of batch: {batch.shape}')
batch

In [None]:
# decode the tokens in row 13 of the batch back into text
tokenizer.decode(batch[13,:])

In [None]:
# forward pass to get logits (~16s on cpu)
# NOTE(review): the blank lines below appear to be a placeholder for the code
# that defines `out`, which the last line reads. Part 6 writes the same step as
# `with torch.no_grad(): out = model(batch)`.


out.logits.shape

In [None]:
# Fill `losses` with one next-token loss per (sequence, token position) pair.
# NOTE(review): the right-hand sides and the unclosed `np.zeros(` look like
# intentional placeholders; the cell is not valid Python until completed.

# log-softmax to get losses
logits_logsm =

# initialize matrix of losses
# (indexed below as losses[seqi,tokeni], so it needs at least nSamples rows and seqlen-1 columns)
losses = np.zeros(

# loop over sequences and tokens
for seqi in range(nSamples):
  # only seqlen-1 positions are visited; presumably the final token has no
  # following token to act as a target -- TODO confirm
  for tokeni in range(seqlen-1):

    # single-token loss is -loglikelihood of target token
    target_idx =
    losses[seqi,tokeni] =

In [None]:
# matrix of ordinal positions
# NOTE(review): the right-hand side is an unfilled placeholder. A later cell prints
# its shape next to that of `losses`, so the two presumably match -- TODO confirm.
ord_position =
ord_position

In [None]:
# flatten (vectorize) the matrices
# NOTE(review): both right-hand sides are unfilled placeholders.
losses_flat =
ord_position_flat =

# report the shapes before and after flattening
print(f'           Shape of losses: {losses.shape}')
print(f'      Shape of losses_flat: {losses_flat.shape}\n')
print(f'     Shape of ord_position: {ord_position.shape}')
print(f'Shape of ord_position_flat: {ord_position_flat.shape}')

In [None]:
# Part 1 figure: loss by token position (panel A) and loss distributions (panel B).
# NOTE(review): the empty call arguments below look like intentional placeholders;
# the cell is not valid Python until they are completed.

# setup the figure and axes
fig = plt.figure(figsize=(12,4))
gs = GridSpec(1,3,figure=fig)
ax1 = fig.add_subplot(gs[:-1]) # spans the first two of the three grid columns
ax2 = fig.add_subplot(gs[-1]) # occupies the last grid column

# bar plot and scatter
ax1.bar()
ax1.plot()

ax1.set(xlabel='Token position',ylabel='Loss',title='A) Next-token prediction loss')


# distributions
# one line per token position, drawn from a histogram of that position's values
for i in range(seqlen-1):
  y,x = np.histogram(
  ax2.plot(,,linewidth=2,label=f'Token {i}')

ax2.set(xlabel='Per-token loss',ylabel='Probability density',title='B) Distribution by token position')
ax2.legend()

plt.tight_layout()
plt.savefig('ch4_proj21_part1.png')
plt.show()

# **Part 2: Linear regression in statsmodels and sklearn**

In [None]:
# create and fit the model
# NOTE(review): as written, `X` is bound to the `sm.add_constant` function itself
# (it is never called) and `sm.OLS()` receives no data; both look like
# placeholders to be completed.
X = sm.add_constant # design matrix with intercept term
smreg = sm.OLS().fit()
print( smreg.summary() )

In [None]:
# repeat using sklearn's LinearRegression
# ordpos must be squeezed back
# NOTE(review): the `.fit(` call is left unclosed and the f-string fields are
# empty; these look like placeholders, and the cell is not valid Python until
# they are completed.
reg = LinearRegression().fit(
print(f'const: {}')
print(f'x1   : {}')

# **Part 3: Does linear regression reconstruct ordinal position?**

In [None]:
# generate predictions
# NOTE(review): the method name after `reg.` and the first argument of
# `plt.plot` are unfilled placeholders.
predictions = reg.

# visualize
# scatter of predictions (y-axis) against the true positions (x-axis)
plt.figure(figsize=(8,4))
plt.plot(,predictions,'ko',markerfacecolor=[.7,.7,.9,.5])
plt.gca().set(xlabel='True positions',ylabel='Predicted positions',xticks=range(seqlen-1),yticks=range(seqlen-1))

plt.tight_layout()
plt.savefig('ch4_proj21_part3.png')
plt.show()

# **Part 4: Fit a multinomial logistic regression**

In [None]:
# fit a logistic regression model
# NOTE(review): as written, `logreg` is bound to the `LogisticRegression` class
# itself (no instance is created or fitted) and the print arguments are empty;
# both look like placeholders to be completed.
logreg = LogisticRegression

print('Intercepts (beta_0):\n',,'\n')
print('Coefficients (beta_1):\n',)

In [None]:
# visualize the coefficients
# NOTE(review): the `bar()` calls have no arguments and `axhline` is referenced
# without being called; these look like placeholders to be completed.
_,axs = plt.subplots(1,2,figsize=(12,4))

# left panel: slopes per token position
axs[0].bar()
axs[0].set(xlabel='Token position',ylabel='Coefficient value',xticks=range(seqlen-1),title='A) Relative slopes')

# right panel: intercepts per token position
axs[1].axhline
axs[1].bar()
axs[1].set(xlabel='Token position',ylabel='Coefficient value',xticks=range(seqlen-1),title='B) Relative intercepts')

plt.tight_layout()
plt.savefig('ch4_proj21_part4.png')
plt.show()

# **Part 5: Categorical predictions**

In [None]:
# prediction scores
# NOTE(review): the method name after `logreg.` is an unfilled placeholder.
predictions = logreg.

print(f'Shape of predictions: {predictions.shape}\n')
# row sums (axis=1) of the scores; presumably a check that each row sums to 1 -- TODO confirm
print(f'Sum of predictions for each token:\n{predictions.sum(axis=1)}')

In [None]:
# get the class predictions (np.argmax selects the category with highest score)
# NOTE(review): the right-hand side is an unfilled placeholder.
predicted_categories =
predicted_categories

In [None]:
# create the confusion matrix
# NOTE(review): as written, `cm` is bound to the `confusion_matrix` function
# itself (never called) and the normalization expression is incomplete; both
# look like placeholders to be completed.
cm = confusion_matrix

# and normalize it by row sum
# (the leading factor of 100 expresses the result as a percentage)
cm_norm = 100 *

In [None]:
# visualize per-category (token position) prediction accuracy
# NOTE(review): the empty right-hand sides and call arguments look like
# intentional placeholders; the cell is not valid Python until completed.
_,axs = plt.subplots(1,2,figsize=(10,4))

# one bar plus a text label per token position
for i in range(seqlen-1):
  idxs =
  accs = predicted_categories[idxs]==
  axs[0].bar(i,
  axs[0].text(,,f'{}%',
              fontsize=12,fontweight='bold',ha='center',va='bottom')

# chance-level performance
axs[0].axhline(,linestyle='--',color='k',linewidth=.5,zorder=-10)

axs[0].set(xlabel='Token position',ylabel='Prediction accuracy (%)',xticks=range(seqlen-1),
           title='A) Category-specific prediction accuracy')


# and the confusion matrix
sns.heatmap(,,ax=axs[1])
axs[1].set(xlabel='Predicted position',ylabel='True position',
              title='B) Confusion matrix (% row-wise)')
plt.suptitle('TRAIN test performance',fontsize=16,fontweight='bold')

plt.tight_layout()
plt.savefig('ch4_proj21_part5.png')
plt.show()

# **Part 6: Test on new (untrained) data**

In [None]:
# get a new batch of data (parameters defined in Part 1)
# NOTE(review): `batch` is not re-created in this cell as written, so the loop
# would overwrite the existing tensor row by row unless an initialization is
# added at the blank line below. The right-hand sides in the loop are unfilled
# placeholders.

# get the tokens for each sample (note dataset indexing)
for i in range(nSamples):
  firsttokens =
  batch[i,:] =

# forward pass to get logits (16s on cpu)
# (torch.no_grad() disables gradient tracking for the forward pass)
with torch.no_grad():
  out = model(batch)

In [None]:
# Fill `losses_test` with per-token losses for the new batch, then flatten it.
# NOTE(review): as written, `logits_logsm` is bound to the `F.log_softmax`
# function itself, the `np.zeros` shape is empty, and both loops iterate over
# `range` without calling it; these look like placeholders to be completed.

# log-softmax to get losses
logits_logsm = F.log_softmax

# initialize matrix of losses
losses_test = np.zeros((,))

# loop over sequences and tokens
for seqi in range:
  for tokeni in range:

    # single-token loss is -loglikelihood of target token
    target_idx =
    losses_test[seqi,tokeni] =

losses_test_flat = losses_test.flatten() # need to reshape?

In [None]:
# prediction scores using logistic regression model from the first dataset
# NOTE(review): the first right-hand side is empty and `np.argmax` is referenced
# without being called; both look like placeholders to be completed.
predictions =
predicted_categories = np.argmax

In [None]:
# create the confusion matrix
# NOTE(review): both right-hand sides are unfilled placeholders. `cm` and
# `cm_norm` reuse the names assigned in Part 5, so completing this cell
# replaces the training-set values.
cm =

# and normalize it by row sum
cm_norm =

In [None]:
# visualize per-category (token position) prediction accuracy
# NOTE(review): the empty right-hand sides and call arguments look like
# intentional placeholders; the cell is not valid Python until completed.
_,axs = plt.subplots(1,2,figsize=(10,4))

for i in range(seqlen-1):
  idxs =
  accs =
  # bar height is the mean of `accs` as a percentage; the label sits 1 unit above the bar
  axs[0].bar(i,100*accs.mean())
  axs[0].text(i,1+100*accs.mean(),f'{100*accs.mean():.1f}%',
              fontsize=12,fontweight='bold',ha='center',va='bottom')

# chance-level performance
axs[0].axhline(,linestyle='--',color='k',linewidth=.5,zorder=-10)

axs[0].set(xlabel='Token position',ylabel='Prediction accuracy (%)',xticks=range(seqlen-1),
           title='A) Category-specific prediction accuracy')


# and the confusion matrix
sns.heatmap()
axs[1].set(xlabel='Predicted position',ylabel='True position',
              title='B) Confusion matrix (% row-wise)')

plt.suptitle('TEST test performance',fontsize=16,fontweight='bold')

plt.tight_layout()
plt.savefig('ch4_proj21_part6.png')
plt.show()