|<h2>Course:</h2>|<h1><a href="https://udemy.com/course/dulm_x/?couponCode=202509" target="_blank">A deep understanding of AI language model mechanisms</a></h1>|
|-|:-:|
|<h2>Part 2:</h2>|<h1>Large language models</h1>|
|<h2>Section:</h2>|<h1>Build a GPT</h1>|
|<h2>Lecture:</h2>|<h1><b>Averaging the past while ignoring the future (code)</b></h1>|

<br>

<h5><b>Teacher:</b> Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h5>
<h5><b>Course URL:</b> <a href="https://udemy.com/course/dulm_x/?couponCode=202509" target="_blank">udemy.com/course/dulm_x/?couponCode=202509</a></h5>
<i>Using the code without the course may lead to confusion or errors.</i>

In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F

import numpy as np

# Equal-weighted average of the past

In [None]:
# Four "past" values to summarize with a single number
thepast = torch.tensor([ 4,1,-2,-3 ])
N = len(thepast)

# Equal weights: every past value counts the same and the weights sum to 1
weights = torch.ones(N) / N

print(f'The past: {thepast}')
print(f'Weights (importance) of the past: {weights}')
# tensor.sum() reduces in one vectorized call (builtin sum() loops in Python)
print(f'Sum over all weights: {weights.sum()}')

In [None]:
# Weighted sum of the past: sum_i thepast[i] * weights[i]
thepresent = (thepast*weights).sum()
print(f'The present (weighted sum of the past): {thepresent}')

# Weighted average of the past

In [None]:
# Unequal raw importances: the first past value counts twice as much as the others.
# These raw weights are not normalized, so their sum is not 1.
weights = torch.tensor([ 2,1,1,1 ])
print(f'Sum of weights: {weights.sum()}. Uh oh...')

In [None]:
# Two ways to turn the raw weights into weights that sum to 1:
#   linear : divide each weight by the total
#   softmax: exponentiate, then divide by the total of the exponentials
# (weights is an integer tensor; both computations produce floating-point results)
linear_weights = weights / weights.sum()
exp_weights = torch.exp(weights)  # computed once, reused for numerator and denominator
softmax_weights = exp_weights / exp_weights.sum()

print(f'Scaled weights: {linear_weights}')
print(f'\tTheir sum: {linear_weights.sum()}')

print(f'\nSoftmax weights: {softmax_weights}')
print(f'\tTheir sum: {softmax_weights.sum()}')

In [None]:
# Weighted average of the past under each normalization scheme
thepresent_linear  = (thepast*linear_weights).sum()
thepresent_softmax = (thepast*softmax_weights).sum()

print(f'The present (linear sum of the past):  {thepresent_linear}')
print(f'The present (softmax sum of the past): {thepresent_softmax}')

# Ignoring the future

In [None]:
# A longer sequence; the element at index `present_moment` is "now"
thedata = torch.tensor([ 4,1,-2,-3,8,3,-1 ])
present_moment = 4
N = thedata.shape[0]

# split the sequence into before / at / after the present moment
past_slice = thedata[:present_moment]
present_value = thedata[present_moment]
future_slice = thedata[present_moment+1:]

print(f'Past data: {past_slice}')
print(f'The present: {present_value}')
print(f'The future: {future_slice}')

In [None]:
# Weight 1 for indices 0..present_moment (past and present), 0 for the future
future_positions = slice(present_moment + 1, None)
past_weights = torch.ones(N)
past_weights[future_positions] = 0
past_weights

In [None]:
# Linear normalization: zero weights stay exactly zero after dividing by the total
past_weights_linear = past_weights / past_weights.sum()

print(f'Scaled weights: {past_weights_linear}')
print(f'\tTheir sum: {past_weights_linear.sum()}')

In [None]:
# Softmax the weights with zeros. Because exp(0) = 1, the zeroed future positions
# still receive nonzero weight after softmax, so 0 is not an effective mask here.
exp_past = torch.exp(past_weights)  # computed once, reused for numerator and denominator
past_weights_softmax = exp_past / exp_past.sum()

print(f'Softmax weights: {past_weights_softmax}')
print(f'\tTheir sum: {past_weights_softmax.sum()}')

In [None]:
# e.g.: a very negative input makes exp() tiny, but still not exactly zero
torch.tensor([-10]).exp()

In [None]:
# Recreate the weights for the past, but set future values to -infinity.
# exp(-inf) = 0, so after softmax the future receives exactly zero weight.
past_weights = torch.ones(N)
past_weights[present_moment+1:] = -torch.inf

# softmaxify
exp_past = torch.exp(past_weights)  # computed once, reused for numerator and denominator
past_weights_softmax = exp_past / exp_past.sum()

# print the results
print(f'Unscaled weights: {past_weights}')
print(f'Scaled weights: {past_weights_softmax}')
print(f'\tTheir sum: {past_weights_softmax.sum()}')

# Steps toward the future, looking back into the past

In [None]:
# Lower-triangular matrix of ones: rows are calculation steps, columns are time
# points; row t has ones in columns 0..t (the past and present of step t)
tril = torch.ones(9, 9).tril()
tril

In [None]:
# Replace every zero (the future of each step) with -inf so softmax ignores it
tril = tril.masked_fill(tril == 0, -torch.inf)
tril

In [None]:
# softmaxify each row: every step gets weights over its own past/present time points
tril_softmax = tril.softmax(dim=-1)
tril_softmax

In [None]:
# Show the softmax weights used at each calculation step (one row per step)
for timepoint, row_weights in enumerate(tril_softmax):
  print(f'\nWeights for calculation at time point {timepoint}:')
  print(f'\t{row_weights}')

# Final demo with random activations

In [None]:
# Random activations: rows are calculation steps, columns are time points
activations = torch.randn(N,N)
tril = torch.tril(torch.ones(N,N))

print('-- ORIGINAL ACTIVATIONS:')
print(activations)

print('\n-- PAST WEIGHTING FACTOR:')
print(tril)

# Mask by *position* (where tril is 0), not by value: testing the product
# activations*tril for == 0 would also hide any past activation that is exactly 0.
scaled_activations = activations.masked_fill(tril==0, -torch.inf)
print('\n-- SCALED PAST ACTIVATIONS:')
print(scaled_activations)

# Row-wise softmax: each step's weights over its past/present; future gets 0
softmax_past = F.softmax(scaled_activations,dim=-1)
print('\n-- SOFTMAX PAST ACTIVATIONS:')
print(softmax_past)

In [None]:
# confirm: every row of softmax weights sums to 1 (checked, not just eyeballed)
row_sums = softmax_past.sum(dim=-1)
assert torch.allclose(row_sums, torch.ones_like(row_sums)), 'softmax rows should sum to 1'
row_sums

# FYI, timing some alternatives

In [None]:
import time
nIters = int(2e5)

# Three ways to build a 10x10 mask: ones on/below the diagonal, -inf above it.
# Each builder returns its finished mask so the options can be compared directly.

def _mask_option1():
  """Option 1: find zeros and assign -torch.inf (in place)."""
  tril = torch.tril(torch.ones(10,10))
  tril[tril==0] = -torch.inf
  return tril

def _mask_option2():
  """Option 2: find zeros and assign float('-inf') (in place)."""
  tril = torch.tril(torch.ones(10,10))
  tril[tril==0] = float('-inf')
  return tril

def _mask_option3():
  """Option 3: out-of-place masked_fill with float('-inf')."""
  tril = torch.tril(torch.ones(10,10))
  return tril.masked_fill(tril==0, float('-inf'))

def _time_option(label, make_mask, n_iters=nIters):
  """Call make_mask() n_iters times and print the elapsed time in seconds."""
  # perf_counter is a monotonic, high-resolution clock meant for benchmarking;
  # time.time() follows the system clock, which can be adjusted mid-run.
  start_time = time.perf_counter()
  for _ in range(n_iters):
    make_mask()
  print(f'{label}: {time.perf_counter()-start_time:.3f} sec')

# sanity check: all three options build the same mask
assert torch.equal(_mask_option1(), _mask_option2())
assert torch.equal(_mask_option1(), _mask_option3())

_time_option('Option 1', _mask_option1)
_time_option('Option 2', _mask_option2)
_time_option('Option 3', _mask_option3)