In [135]:
%load_ext autoreload
%autoreload 2

%cd -q ..


import lcpfn
import numpy as np
from matplotlib import pyplot as plt
import torch
import torch.nn as nn
import math

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


In [136]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer("pe", pe)

    def forward(self, x):
        x = x + self.pe[: x.size(0), :]
        return x


class TransformerModel(nn.Module):
    def __init__(self, input_dim, model_dim, num_heads, num_layers, dim_feedforward, dropout=0.1):
        super(TransformerModel, self).__init__()
        self.model_dim = model_dim
        self.pos_encoder = PositionalEncoding(model_dim)
        encoder_layers = nn.TransformerEncoderLayer(
            d_model=model_dim, nhead=num_heads, dim_feedforward=dim_feedforward
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layers, num_layers=num_layers
        )
        self.encoder = nn.Linear(input_dim, model_dim)
        self.decoder = nn.Linear(model_dim, 1)
        self.init_weights()

        self.dropout = nn.Dropout(dropout)


    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.weight.data.uniform_(-initrange, initrange)

    # forwards pass
    def forward(self, src):
        src = self.encoder(src) * math.sqrt(self.model_dim)

        src = self.pos_encoder(src)
        src = self.dropout(src)

        src_mask = nn.Transformer.generate_square_subsequent_mask(src.size(0)).to(device)
        output = self.transformer_encoder(src, src_mask)
        # Use the representation of the last position
        output = self.decoder(output)
        output = output[:, -1, :]
        return output.squeeze(-1)

Used to test our models

In [137]:
from torch.utils.data import TensorDataset, DataLoader

def test(model):

    get_batch_func = lcpfn.create_get_batch_func(prior=lcpfn.sample_from_prior)
    X, Y, Y_noisy = get_batch_func(batch_size=100, seq_len=100, num_features=1)
    Y = Y.permute(1, 0)

    dataset = TensorDataset(Y, Y)
    data_loader = DataLoader(
        dataset, batch_size=5, shuffle=False
    )

    criterion = torch.nn.MSELoss()

    total_loss = 0


    for input_sequence, target_sequence in data_loader:
        model.eval()

        input_sequence, target_sequence = input_sequence.to(device), target_sequence.to(device)

        input_sequence = input_sequence.unsqueeze(-1)  # [batch_size, input_length, features]

        step_loss = 0

        current_input = input_sequence[:,:15]


        for i in range(15, 99):

            prediction = model(current_input)

            loss = criterion(prediction, target_sequence[:, i + 1])
            step_loss += loss

            current_input =  torch.cat((current_input, prediction.unsqueeze(-1).unsqueeze(-1)), dim=1)

        total_loss += step_loss

    print(f"Loss: {total_loss / len(data_loader)}")
    return total_loss / len(data_loader)

In [141]:
model = torch.load('small_model_no_teaching_euclidean.pth')

In [142]:
import gc


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

results = []

for i in range(10):
    torch.cuda.empty_cache()
    gc.collect()
    with torch.no_grad():
        results.append(test(model))

results = torch.tensor(results)
print("mean:", torch.mean(results))
print("std:", torch.std(results))

Using device: cuda
Loss: 4.238088130950928
Loss: 4.261197090148926
Loss: 3.578749179840088
Loss: 3.11883282661438
Loss: 4.031520366668701
Loss: 2.3648316860198975
Loss: 4.415432929992676
Loss: 7.342675685882568
Loss: 5.032954692840576
Loss: 3.9560811519622803
mean: tensor(4.2340)
std: tensor(1.3174)


Used to test their models

In [143]:
def test_LCPFN():

    total_loss = 0
    criterion = torch.nn.MSELoss()

    for i in range(100):
        model = lcpfn.LCPFN()

        prior = lcpfn.sample_from_prior(np.random)
        curve, _ = prior()


        x = torch.arange(1, 101).unsqueeze(1)
        y = torch.from_numpy(curve).float().unsqueeze(1)

        cutoff = 15

        pred = model.predict_mean(x_train=x[:cutoff], y_train=y[:cutoff], x_test=x[cutoff:])
        loss = criterion(y[cutoff:], pred)
        total_loss += loss.item()

    print(f"Loss: {total_loss / 100}")

    return total_loss/100

In [144]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

results = []

for i in range(10):
    torch.cuda.empty_cache()
    gc.collect()
    with torch.no_grad():
        results.append(test_LCPFN())

results = torch.tensor(results)
print("mean:", torch.mean(results))
print("std:", torch.std(results))

Using device: cuda
Loss: 0.0002437597459966234
Loss: 0.0001133347373431448
Loss: 0.00019366741526180675
Loss: 0.0003397745910585215
Loss: 0.00019523907275143415
Loss: 0.00023576721137182232
Loss: 0.00024106429422417362
Loss: 0.00025883323125299286
Loss: 0.0002560805691150847
Loss: 0.0002724775551940439
mean: tensor(0.0002)
std: tensor(5.9255e-05)


Code to make the graphs (saved in graphs)

In [145]:
def get_results(cutoff, y, model):

    def predict_single_sequence(input_sequence):
        model.eval()  # Ensure the model is in eval mode
        with torch.no_grad():  # No gradients needed
            # Assuming input_sequence is already a PyTorch tensor with the right shape and dtype
            prediction = model(input_sequence).to(device)
            # Convert the prediction back to a Python number for easy interpretation
            predicted_value = prediction.item()
        return predicted_value

    # Example usage
    input_data = y[:cutoff].unsqueeze(0).to(device)  # Example input sequence

    result_tensor = torch.tensor(input_data, dtype=torch.float).to(device)

    for i in range(100-cutoff):
        predictions = torch.tensor(predict_single_sequence(result_tensor)).to(device).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        # print(f"Predicted Value: {predictions}")
        result_tensor = torch.cat((result_tensor, torch.tensor(predictions)), dim=1)


    results = result_tensor.squeeze(0).squeeze(-1).cpu()

    return results


In [154]:
def plot_all(filename):

    cutoff = 15

    prior = lcpfn.sample_from_prior(np.random)
    curve, _ = prior()
    plt.plot(curve, "black")
    plt.ylim(0, 1)

    model.eval()

    x = torch.arange(1, 101).unsqueeze(1)
    y = torch.from_numpy(curve).float().unsqueeze(1)

    model_lcpfn = lcpfn.LCPFN()
    lcpfn_curve = model_lcpfn.predict_mean(x_train=x[:cutoff], y_train=y[:cutoff], x_test=x[cutoff: ])

    big_model = torch.load('100_epochs.pth')
    big_model_result = get_results(cutoff, y, big_model)

    small_model = torch.load('small_model.pth')
    small_model_result = get_results(cutoff, y, small_model)

    small_model_no = torch.load('small_model_no_teaching.pth')
    small_model_result_no = get_results(cutoff, y, small_model_no)

    small_model_no_eu = torch.load('small_model_no_teaching_euclidean.pth')
    small_model_result_no_eu = get_results(cutoff, y, small_model_no_eu)

    plt.plot(curve, "black", label="Target")
    plt.plot(x[cutoff:], small_model_result[cutoff:], label="Small Model FT")
    plt.plot(x[cutoff:], small_model_result_no[cutoff:], label="Small Model")
    plt.plot(x[cutoff:], small_model_result_no_eu[cutoff:], label="Small Model EU")
    plt.plot(x[cutoff:], big_model_result[cutoff:], label="Large Model FT")
    plt.plot(x[cutoff:], lcpfn_curve, label="LCPFN")

    # plot cutoff
    plt.vlines(cutoff + 1, 0, 1, linewidth=2, color="k", label="Cutoff")
    plt.ylim(0, 1)

    plt.legend(loc="lower right")
    plt.savefig(f"graphs/{filename}.pdf")
    plt.close()


In [155]:
for i in range(20):
    plot_all(str(i))

  result_tensor = torch.tensor(input_data, dtype=torch.float).to(device)
  result_tensor = torch.cat((result_tensor, torch.tensor(predictions)), dim=1)
  result_tensor = torch.tensor(input_data, dtype=torch.float).to(device)
  result_tensor = torch.cat((result_tensor, torch.tensor(predictions)), dim=1)
  result_tensor = torch.tensor(input_data, dtype=torch.float).to(device)
  result_tensor = torch.cat((result_tensor, torch.tensor(predictions)), dim=1)
  result_tensor = torch.tensor(input_data, dtype=torch.float).to(device)
  result_tensor = torch.cat((result_tensor, torch.tensor(predictions)), dim=1)
  result_tensor = torch.tensor(input_data, dtype=torch.float).to(device)
  result_tensor = torch.cat((result_tensor, torch.tensor(predictions)), dim=1)
  result_tensor = torch.tensor(input_data, dtype=torch.float).to(device)
  result_tensor = torch.cat((result_tensor, torch.tensor(predictions)), dim=1)
  result_tensor = torch.tensor(input_data, dtype=torch.float).to(device)
  result_tensor