# Training of PyTorch models

In this notebook we build and train the 4 transformer models created using PyTorch.

In [2]:
%load_ext autoreload
%autoreload 2

%cd -q ..


import lcpfn
import numpy as np
from matplotlib import pyplot as plt
import torch
import torch.nn as nn
import math

from torch.utils.data import TensorDataset, DataLoader

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


Standard (sin/cos) positional encoding

In [12]:
# encode each point in input sequence into positional vectors that hold their positional information
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        # turn input tokens into positional vecotrs
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        # calulate positional value for every element in vector
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer("pe", pe)

    def forward(self, x):
        # Add positional encoding to input tensor
        x = x + self.pe[: x.size(0), :]
        return x

Euclidean positional encoding

In [13]:
# encode each point in input sequence into positional vectors that hold their positional information using euclidean distance
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()

        # get x-value of each point (index)
        position = torch.arange(max_len).unsqueeze(1).float()

        #calculate euclidean distance
        euclidean_distance = torch.sqrt(position ** 2)

        # need to be in vector format for transformer to accept
        pe = euclidean_distance.repeat(1, d_model)

        pe = pe.unsqueeze(0).transpose(0, 1)

        self.register_buffer('pe', pe)

    def forward(self, x):
        # Add positional encoding to input tensor
        x = x + self.pe[:x.size(0), :]
        return x

Transformer Model

In [14]:
# Transformer model built using PyTorch Transformer module
class TransformerModel(nn.Module):
    def __init__(self, input_dim, model_dim, num_heads, num_layers, dim_feedforward, dropout=0.1):
        super(TransformerModel, self).__init__()
        self.model_dim = model_dim
        # use the positional encoder defined above
        self.pos_encoder = PositionalEncoding(model_dim)
        encoder_layers = nn.TransformerEncoderLayer(
            d_model=model_dim, nhead=num_heads, dim_feedforward=dim_feedforward
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layers, num_layers=num_layers
        )
        self.encoder = nn.Linear(input_dim, model_dim)
        self.decoder = nn.Linear(model_dim, 1)
        self.init_weights()

        self.dropout = nn.Dropout(dropout)


    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.weight.data.uniform_(-initrange, initrange)

    # forwards pass
    def forward(self, src):
        src = self.encoder(src) * math.sqrt(self.model_dim)

        # encode input into positional vectors
        src = self.pos_encoder(src)
        src = self.dropout(src)

        # src_mask = nn.Transformer.generate_square_subsequent_mask(src.size(0)).to(device)
        # output = self.transformer_encoder(src, src_mask)

        output = self.transformer_encoder(src)
        output = self.decoder(output)

        # output only one prediction point (last)
        output = output[:, -1, :]
        return output.squeeze(-1)

In [15]:
# allow cuda support if GPU is present on device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


Instantiate the large model. 

In [16]:
# instantiate model
model = TransformerModel(
    input_dim=1,
    model_dim=128,
    num_heads=8,
    num_layers=4,
    dim_feedforward=512,
).to(device)

Instantiate the small model.

In [17]:
model = TransformerModel(
    input_dim=1,
    model_dim=32,
    num_heads=8,
    num_layers=4,
    dim_feedforward=32,
).to(device)

In [18]:
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [19]:
get_batch_func = lcpfn.create_get_batch_func(prior=lcpfn.sample_from_prior)
# get a curce from their prior data
# the x-values of the curve points are stored in X ([1..100]), the y-values are stored in Y
X, Y, Y_noisy = get_batch_func(batch_size=1, seq_len=100, num_features=1)
Y = Y.permute(1, 0) # permute because batch size was in dim 1

dataset = TensorDataset(Y, Y)

data_loader = DataLoader(
    dataset, batch_size=1, shuffle=True
)

No Forced Teaching:

The first 5 epochs are ran with forced teaching to give a reference point to the model

Then for every of the remaining 95 epochs:

- First 15 points are predicted with forced treaching (before cutoff). So model knows what curve to predict (otherwise would be just predicted points)
    
- Remaining 85 points are predicted without forced training (after cutoff). The prediction is added to the input of the model when predicting the next point.

In [21]:
# first 5 epochs are done with forced teaching 
for epoch in range(0, 5):
    model.train()
    # initialise loss
    total_loss = 0
    #  for every curve in batch
    for input_sequence, target_sequence in data_loader:
        optimizer.zero_grad()

        # transfer to device (GPU if available)
        input_sequence, target_sequence = input_sequence.to(device), target_sequence.to(device)

        input_sequence = input_sequence.unsqueeze(-1)  # [batch_size, input_length, features]

        step_loss = 0

        # for every point in curve
        for i in range(99):
            
            # points in sequence before point we want to predict
            current_input = input_sequence[:,:i + 1]

            # make prediction
            prediction = model(current_input)

            # calculate loss from prediction
            loss = criterion(prediction, target_sequence[:, i + 1])
            step_loss += loss


        # update params based on loss
        step_loss.backward()
        optimizer.step()
        total_loss += step_loss.item()

    print(f"Epoch {epoch}, Loss: {total_loss / len(data_loader)}")

# epochs without forced teaching: 
# first 15 points still forced teaching
for epoch in range(5, 100):
    model.train()
    # initialise loss
    total_loss = 0
    #  for every curve in batch
    for input_sequence, target_sequence in data_loader:
        # transfer to device (GPU if available)
        input_sequence, target_sequence = input_sequence.to(device), target_sequence.to(device)
        input_sequence = input_sequence.unsqueeze(-1)  # [batch_size, input_length, features]

        current_input = input_sequence[:, :1]

        # for every point in curve
        for i in range(99):
            optimizer.zero_grad() 

            if i < 15: # first 15 points use forced teaching
                current_input = input_sequence[:, :i + 1]
            else: # remaining 85 without forced teaching (uses previous predictions)
                prediction = prediction.detach()
                current_input = torch.cat((current_input, prediction.unsqueeze(-1).unsqueeze(-1)), dim=1)

            # update params based on loss
            prediction = model(current_input)
            loss = criterion(prediction, target_sequence[:, i + 1])
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

    print(f"Epoch {epoch}, Loss: {total_loss / len(data_loader)}")

# torch.save(model, 'small_model.pth')

Epoch 0, Loss: 0.07294952124357224
Epoch 1, Loss: 0.12479734420776367
Epoch 2, Loss: 0.137273371219635
Epoch 3, Loss: 0.07458188384771347
Epoch 4, Loss: 0.13763538002967834
Epoch 5, Loss: 0.32834605314577914
Epoch 6, Loss: 0.08655254443011628
Epoch 7, Loss: 0.07276828444535965
Epoch 8, Loss: 0.06043922241916988
Epoch 9, Loss: 0.06990627178373643
Epoch 10, Loss: 0.07276537762379576
Epoch 11, Loss: 0.07949924768594485
Epoch 12, Loss: 0.06096301249914404
Epoch 13, Loss: 0.06353910553127484
Epoch 14, Loss: 0.08316188104530653
Epoch 15, Loss: 0.06655544537818514
Epoch 16, Loss: 0.06975374672482104
Epoch 17, Loss: 0.0794184916316567
Epoch 18, Loss: 0.06724467207827445
Epoch 19, Loss: 0.06402873078729954
Epoch 20, Loss: 0.07056964406859834
Epoch 21, Loss: 0.06080140176426596
Epoch 22, Loss: 0.06095988152604548
Epoch 23, Loss: 0.0630595552012494
Epoch 24, Loss: 0.07092146430665025
Epoch 25, Loss: 0.06935044517135225
Epoch 26, Loss: 0.06033167481906876
Epoch 27, Loss: 0.052580945930102985
Epoch

Forced Teaching: Give true target point values as input everytime we make model predict

In [23]:
# train with forced teaching
for epoch in range(100):
    model.train()
    # initialise loss
    total_loss = 0
    # for every curve in batch
    for input_sequence, target_sequence in data_loader:
        optimizer.zero_grad()

        # transfer to device (GPU if available)
        input_sequence, target_sequence = input_sequence.to(device), target_sequence.to(device)

        input_sequence = input_sequence.unsqueeze(-1)  # [batch_size, input_length, features]

        step_loss = 0

        # for every point in curve
        for i in range(99):
            
            # points in sequence before point we want to predict
            current_input = input_sequence[:,:i + 1]

            # make prediction
            prediction = model(current_input)

            # calculate loss from prediction
            loss = criterion(prediction, target_sequence[:, i + 1])
            step_loss += loss


        # update params based on loss
        step_loss.backward()
        optimizer.step()
        total_loss += step_loss.item()

    print(f"Epoch {epoch}, Loss: {total_loss / len(data_loader)}")

# torch.save(model, 'small_model_FT.pth')

Epoch 0, Loss: 0.07982659339904785
Epoch 1, Loss: 0.047912146896123886
Epoch 2, Loss: 0.07933411747217178
Epoch 3, Loss: 0.0660719946026802
Epoch 4, Loss: 0.045396190136671066
Epoch 5, Loss: 0.059247951954603195
Epoch 6, Loss: 0.060916077345609665
Epoch 7, Loss: 0.04747267812490463
Epoch 8, Loss: 0.0565451942384243
Epoch 9, Loss: 0.051484573632478714
Epoch 10, Loss: 0.04767507314682007
Epoch 11, Loss: 0.05113532394170761
Epoch 12, Loss: 0.052436549216508865
Epoch 13, Loss: 0.04851679876446724
Epoch 14, Loss: 0.04516364261507988
Epoch 15, Loss: 0.05020304024219513
Epoch 16, Loss: 0.04838121682405472
Epoch 17, Loss: 0.046012599021196365
Epoch 18, Loss: 0.047351472079753876
Epoch 19, Loss: 0.05071686580777168
Epoch 20, Loss: 0.048999276012182236
Epoch 21, Loss: 0.04605140909552574
Epoch 22, Loss: 0.047080788761377335
Epoch 23, Loss: 0.048482805490493774
Epoch 24, Loss: 0.048989132046699524
Epoch 25, Loss: 0.047332484275102615
Epoch 26, Loss: 0.046500030905008316
Epoch 27, Loss: 0.04660132