In [15]:
import os
import math
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
import torch.nn.functional as F

In [None]:
import matplotlib.pyplot as plt
plt.set_cmap('cividis')
%matplotlib inline
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('svg', 'pdf') # For export
from matplotlib.colors import to_rgb
import matplotlib
matplotlib.rcParams['lines.linewidth'] = 2.0
import seaborn as sns
sns.reset_orig()

## tqdm for loading bars
from tqdm.notebook import tqdm

In [2]:
## Imports for plotting


  set_matplotlib_formats('svg', 'pdf') # For export


<Figure size 640x480 with 0 Axes>

In [3]:
# Path to the folder where the datasets are/should be downloaded (e.g. CIFAR10)
DATASET_PATH = "../data"
# Path to the folder where the pretrained models are saved
CHECKPOINT_PATH = "../saved_models/t6"

# Setting the seed
pl.seed_everything(42)

# Ensure that all operations are deterministic on GPU (if used) for reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

Seed set to 42


Device: cuda:0


In [6]:
import urllib.request
from urllib.error import HTTPError
# Github URL where saved models are stored for this tutorial
base_url = "https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial6/"
# Files to download
pretrained_files = ["ReverseTask.ckpt", "SetAnomalyTask.ckpt"]

# Create checkpoint path if it doesn't exist yet
os.makedirs(CHECKPOINT_PATH, exist_ok=True)

# For each file, check whether it already exists. If not, try downloading it.
for file_name in pretrained_files:
    file_path = os.path.join(CHECKPOINT_PATH, file_name)
    if "/" in file_name:
        os.makedirs(file_path.rsplit("/",1)[0], exist_ok=True)
    if not os.path.isfile(file_path):
        file_url = base_url + file_name
        print(f"Downloading {file_url}...")
        try:
            urllib.request.urlretrieve(file_url, file_path)
        except HTTPError as e:
            print("Something went wrong. Please try to download the file from the GDrive folder, or contact the author with the full output including the following error:\n", e)

Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial6/ReverseTask.ckpt...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial6/SetAnomalyTask.ckpt...


In [20]:
from icecream import ic

In [23]:
def scaled_dot_product(q, k, v, mask=None):
    d_k = q.shape[-1]
    attn_logits = torch.matmul(q, k.transpose(-2,-1))
    attn_logits = attn_logits/math.sqrt(d_k)
    if mask is not None:
        attn_logits = attn_logits.masked_fill(mask==0, -9e15)
    ic(attn_logits)
    attention = F.softmax(attn_logits, dim=-1)
    ic(attention)
    values = torch.matmul(attention, v)
    return values, attention 

In [24]:
seq_len, d_k = 3,2
pl.seed_everything(42)
q = torch.randint(1,5, size=(seq_len, d_k)).float()
k = torch.randint(1,5, size=(seq_len, d_k)).float()
v = torch.randint(1,5, size=(seq_len, d_k)).float()
# q, k, v
# q.shape
# kt = k.transpose(-2,-1)
# k, k.shape, kt, kt.shape
values, attention = scaled_dot_product(q,k,v)
q, k, v, values, attention

Seed set to 42
ic| attn_logits: tensor([[ 4.9497, 12.0208, 14.8492],
                         [ 2.8284,  6.3640,  8.4853],
                         [ 4.9497, 12.0208, 14.8492]])
ic| attention: tensor([[4.7396e-05, 5.5805e-02, 9.4415e-01],
                       [3.1098e-03, 1.0671e-01, 8.9018e-01],
                       [4.7396e-05, 5.5805e-02, 9.4415e-01]])


(tensor([[3., 4.],
         [1., 3.],
         [3., 4.]]),
 tensor([[1., 1.],
         [3., 2.],
         [3., 3.]]),
 tensor([[3., 3.],
         [4., 1.],
         [4., 4.]]),
 tensor([[4.0000, 3.8325],
         [3.9969, 3.6768],
         [4.0000, 3.8325]]),
 tensor([[4.7396e-05, 5.5805e-02, 9.4415e-01],
         [3.1098e-03, 1.0671e-01, 8.9018e-01],
         [4.7396e-05, 5.5805e-02, 9.4415e-01]]))