In [1]:
import os
import shutil
import numpy as np
import pickle
import optuna
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import random_split

from Arwin.dataset.synthetic_dataset import SyntheticDataset
from Arwin.dataset.windowed_dataset import WindowedDataset
from Arwin.model.trainer_Im2 import TrainerIM2 as Trainer
from Arwin.model.deeponet import DeepONet
from Arwin.model.embedding_forecaster import EmbeddingForcaster
from Arwin.src.utils import *

from torch.utils.tensorboard import SummaryWriter

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 128
data_train_path = 'Arwin/dataset/training_dataset.pkl'
data_test_path = 'Arwin/dataset/testing_dataset.pkl'

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
training_dataset = pickle.load(open(data_train_path, 'rb'))
valid_dataset = pickle.load(open(data_test_path, 'rb'))

In [3]:
""" Prepare dataset """
training_dataset_windows = WindowedDataset(training_dataset.functions, training_dataset.observations, training_dataset.masks, shuffle=True)
valid_dataset_windows = WindowedDataset(valid_dataset.functions, valid_dataset.observations, valid_dataset.masks, eval=True, shuffle=True)

train_loader = torch.utils.data.DataLoader(dataset=training_dataset_windows, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn_windows)
valid_loader = torch.utils.data.DataLoader(dataset=valid_dataset_windows, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn_windows)

In [8]:
def new_forward(deeponet, y, t, eval_grid_points, mask):
    # Generate the fine grid points batch dynamically for the current batch size
    batch_size = y.shape[0]
    fine_grid_points_batch = eval_grid_points.unsqueeze(0).expand(batch_size, -1)

    # Mask the input data
    y = y.unsqueeze(-1) * mask.unsqueeze(-1)
    t = t.unsqueeze(-1) * mask.unsqueeze(-1)
    t_sample =  fine_grid_points_batch.unsqueeze(-1)
    # Branch and Trunk Embedding
    branch_embedding_y = deeponet.branch_embedding_y(y)
    branch_embedding_t = deeponet.branch_embedding_t(t)
    trunk_encoder_input = deeponet.trunk_embedding_t(t_sample)

    # generate mask for the transformer encoder
    mask_enc = torch.where(mask == 1, False, True)

    # Transformer Encoder for the Branch Network
    branch_encoder_input = deeponet.embedding_act(branch_embedding_y + branch_embedding_t)
    branch_encoder_output = deeponet.branch_encoder(branch_encoder_input, src_key_padding_mask=mask_enc)

    # Mask the output of the transformer encoder
    branch_encoder_output = branch_encoder_output * mask.unsqueeze(-1)
    """ Modifications to the original DeepONet """
    # Attention-based summary
    H = branch_encoder_output  # Shape: [batch_size, 128, d_model]
    q = deeponet.query.unsqueeze(0).expand(batch_size, -1, -1)  # Shape: [batch_size, 1, d_model]

    # Multihead attention: query (q), keys/values (H)
    h_b, _ = deeponet.summary_attention(q, H, H, key_padding_mask=mask_enc)  # h_b: [batch_size, 1, d_model]
    h_b = h_b.squeeze(1)  # Flatten to [batch_size, d_model]

    return h_b
    """ -------------------------------------- """
    branch_output = self.branch_mlp(h_b) 
    trunk_output = self.trunk_mlp(trunk_encoder_input)

    """ Modifications to the original DeepONet """
    # Expand branch_output to match the sequence length of trunk_output
    branch_output_expanded = branch_output.unsqueeze(1).expand(-1, trunk_output.shape[1], -1)  # [batch_size, d_model, p]

    # Concatenate branch_output with trunk_output along the feature dimension
    combined_input = torch.cat((branch_output_expanded, trunk_output), dim=-1)  # [batch_size, d_model, 2 * p]

    # Flatten the batch and sequence dimensions to apply combined_mlp in one step
    combined_input_flattened = combined_input.view(-1, combined_input.shape[-1])  # [batch_size * d_model, 2 * p]

    # Pass through combined_mlp
    mlp_output = self.combined_mlp(combined_input_flattened)  # [batch_size * d_model, 1]

    # Reshape back to the original structure
    combined_out = mlp_output.view(branch_output.shape[0], trunk_output.shape[1])  # [batch_size, d_model]

    """ -------------------------------------- """
        
    return combined_out

In [4]:
d_model = 128
p = 128

deeponet = DeepONet(indicator_dim=128, d_model=d_model, heads=2, p=p).to(device)
scales_mlp = nn.Sequential(
                nn.Linear(6 ,d_model),
                nn.LeakyReLU(),
                nn.Linear(d_model, d_model),
                nn.LeakyReLU(),
                nn.Linear(d_model, d_model),
                nn.LeakyReLU(),
)
scales_mlp = scales_mlp.to(device)

forecast_encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=d_model*2, nhead=8, batch_first=True), num_layers=4, enable_nested_tensor=False
).to(device)

query = nn.Parameter(torch.randn(1, 2*d_model))
summary_attention = nn.MultiheadAttention(embed_dim=2*d_model, num_heads=8, batch_first=True).to(device)

# Linear projection layer to adjust output dimension
projection_layer = nn.Linear(2 * d_model, d_model).to(device)


deeponet.eval()
scales_mlp.eval()
forecast_encoder.eval()
summary_attention.eval()
projection_layer.eval()

for y_value_windows, (y_observation_windows, t_observation_windows), mask_windows, scale_windows in train_loader:
    y_value_windows = y_value_windows.to(device)
    y_observation_windows = y_observation_windows.to(device)
    t_observation_windows = t_observation_windows.to(device)
    mask_windows = mask_windows.to(device)
    eval_grid_points = torch.linspace(0, 1, 128, device=device)

    # Flatten Windows to be of shape (batch_size * num_windows, window_size)
    y_values = y_value_windows.view(-1, y_value_windows.size(2))
    y_observations = y_observation_windows.view(-1, y_observation_windows.size(2))
    t_observations = t_observation_windows.view(-1, t_observation_windows.size(2))
    masks = mask_windows.view(-1, mask_windows.size(2))
    scales = torch.stack([tensor.to(device) for sublist in scale_windows for tensor in sublist])

    # Select indices that are not every 5th element
    indices_to_keep = torch.arange(scales.size(0)) % 5 != 0  # Create a mask to keep elements
    scales_filtered = scales[indices_to_keep]  # Filter scales to keep only elements that are not every 5th element

    scales_out = scales_mlp(scales_filtered)
    print(scales_out.shape)

    h_b = new_forward(deeponet, y_observations, t_observations, eval_grid_points, masks)
    print(h_b.shape)

    concat = torch.cat((h_b[indices_to_keep], scales_out), dim=1)
    print(concat.shape)

    forecast_out = forecast_encoder(concat)
    print(forecast_out.shape)

    forecast_out = forecast_out.view(128, -1, d_model*2)
    print(forecast_out.shape)
    print(forecast_out[0]) # 4 embeddings of size 2*d_model for each window of the first batch element

    # mask should not be required since it was already applied to get h_b and scales do not require masking
    mask_enc = torch.where(masks == 1, False, True)
    print(mask_enc.shape)
    
    q = query.unsqueeze(0).expand(BATCH_SIZE, -1, -1).to(device)  # Shape: [batch_size, 1, 2*d_model]
    u_b, _ = summary_attention(q, forecast_out, forecast_out)  # u_b: [batch_size, 1, 2*d_model]
    print(u_b.shape)

    u_b = projection_layer(u_b.squeeze(1))
    print(u_b.shape)

    break

RuntimeError: mat1 and mat2 shapes cannot be multiplied (512x9 and 6x128)

In [4]:
deeponet = DeepONet(indicator_dim=128, d_model=128, heads=2, p=128).to(device)
optim = {"model": torch.optim.AdamW(deeponet.parameters(), lr=4.6e-4)}

deeponet, optimizers, epoch, stats = load_model(deeponet, optim, "./Arwin/checkpoints/New_Im1/checkpoint_epoch_3900_New_Im1.pth")

TBOARD_LOGS = os.path.join("./Arwin", "tboard_logs", "Im2")
if not os.path.exists(TBOARD_LOGS):
    os.makedirs(TBOARD_LOGS)
shutil.rmtree(TBOARD_LOGS) 
writer = SummaryWriter(TBOARD_LOGS)

indicator_dim = 128

forecaster = EmbeddingForcaster(d_model=128).to(device)
criterion = nn.MSELoss()
trainer = Trainer(deeponet=deeponet, model=forecaster, criterion=criterion, train_loader=train_loader, valid_loader=valid_loader, modelname="Im2", epochs=5, writer=writer)

In [5]:
trainer.fit()

Ep 0 Iter 1: Loss=0.87433:   0%|          | 0/782 [00:01<?, ?it/s]

Valid loss @ iteration 0: Loss=0.7729404451800328


Ep 0 Iter 101: Loss=0.08433:  13%|█▎        | 100/782 [01:26<08:50,  1.28it/s]

Valid loss @ iteration 100: Loss=0.07623382993772918


Ep 0 Iter 201: Loss=0.07542:  26%|██▌       | 200/782 [02:52<07:29,  1.29it/s]

Valid loss @ iteration 200: Loss=0.07223113056491404


Ep 0 Iter 301: Loss=0.06541:  38%|███▊      | 300/782 [04:18<06:23,  1.26it/s]

Valid loss @ iteration 300: Loss=0.06751365002756025


Ep 0 Iter 401: Loss=0.06448:  51%|█████     | 400/782 [05:45<04:50,  1.32it/s]

Valid loss @ iteration 400: Loss=0.06332823057092872


Ep 0 Iter 501: Loss=0.06337:  64%|██████▍   | 500/782 [07:11<03:45,  1.25it/s]

Valid loss @ iteration 500: Loss=0.06374388801700928


Ep 0 Iter 601: Loss=0.05514:  77%|███████▋  | 600/782 [08:37<02:19,  1.31it/s]

Valid loss @ iteration 600: Loss=0.05937538244852833


Ep 0 Iter 701: Loss=0.05803:  90%|████████▉ | 700/782 [10:03<01:05,  1.26it/s]

Valid loss @ iteration 700: Loss=0.058248331295508965


Ep 0 Iter 782: Loss=0.06362: 100%|██████████| 782/782 [11:14<00:00,  1.16it/s]
Ep 1 Iter 19: Loss=0.06178:   2%|▏         | 18/782 [00:14<10:09,  1.25it/s]

Valid loss @ iteration 800: Loss=0.05692511183374068


Ep 1 Iter 119: Loss=0.06196:  15%|█▌        | 118/782 [01:40<08:35,  1.29it/s]

Valid loss @ iteration 900: Loss=0.05461935092713319


Ep 1 Iter 219: Loss=0.05725:  28%|██▊       | 218/782 [03:06<07:10,  1.31it/s]

Valid loss @ iteration 1000: Loss=0.053900719200279196


Ep 1 Iter 319: Loss=0.0466:  41%|████      | 318/782 [04:33<06:06,  1.27it/s] 

Valid loss @ iteration 1100: Loss=0.050893309653974046


Ep 1 Iter 419: Loss=0.04942:  53%|█████▎    | 418/782 [05:59<04:48,  1.26it/s]

Valid loss @ iteration 1200: Loss=0.048560019144240546


Ep 1 Iter 519: Loss=0.0525:  66%|██████▌   | 518/782 [07:26<03:31,  1.25it/s] 

Valid loss @ iteration 1300: Loss=0.048180603951800106


Ep 1 Iter 619: Loss=0.0528:  79%|███████▉  | 618/782 [08:52<02:07,  1.29it/s] 

Valid loss @ iteration 1400: Loss=0.04809432425627522


Ep 1 Iter 719: Loss=0.04628:  92%|█████████▏| 718/782 [10:18<00:48,  1.31it/s]

Valid loss @ iteration 1500: Loss=0.04833162597873632


Ep 1 Iter 782: Loss=0.04344: 100%|██████████| 782/782 [11:15<00:00,  1.16it/s]
Ep 2 Iter 37: Loss=0.04425:   5%|▍         | 36/782 [00:29<09:49,  1.26it/s]

Valid loss @ iteration 1600: Loss=0.04682091251015663


Ep 2 Iter 137: Loss=0.04409:  17%|█▋        | 136/782 [01:55<08:13,  1.31it/s]

Valid loss @ iteration 1700: Loss=0.045416006854936186


Ep 2 Iter 237: Loss=0.03946:  30%|███       | 236/782 [03:22<07:13,  1.26it/s]

Valid loss @ iteration 1800: Loss=0.04514013537589241


Ep 2 Iter 337: Loss=0.05037:  43%|████▎     | 336/782 [04:48<05:40,  1.31it/s]

Valid loss @ iteration 1900: Loss=0.04442735513051351


Ep 2 Iter 437: Loss=0.03877:  56%|█████▌    | 436/782 [06:14<04:24,  1.31it/s]

Valid loss @ iteration 2000: Loss=0.043549499529249525


Ep 2 Iter 537: Loss=0.03786:  69%|██████▊   | 536/782 [07:40<03:11,  1.29it/s]

Valid loss @ iteration 2100: Loss=0.043499155225707034


Ep 2 Iter 637: Loss=0.04435:  81%|████████▏ | 636/782 [09:05<01:54,  1.27it/s]

Valid loss @ iteration 2200: Loss=0.04384953572469599


Ep 2 Iter 737: Loss=0.04131:  94%|█████████▍| 736/782 [10:31<00:35,  1.31it/s]

Valid loss @ iteration 2300: Loss=0.043176799209094514


Ep 2 Iter 782: Loss=0.03592: 100%|██████████| 782/782 [11:14<00:00,  1.16it/s]
Ep 3 Iter 55: Loss=0.04322:   7%|▋         | 54/782 [00:42<09:19,  1.30it/s]

Valid loss @ iteration 2400: Loss=0.04330571886955523


Ep 3 Iter 155: Loss=0.0431:  20%|█▉        | 154/782 [02:08<08:17,  1.26it/s] 

Valid loss @ iteration 2500: Loss=0.042850011119655536


Ep 3 Iter 255: Loss=0.04287:  32%|███▏      | 254/782 [03:33<06:39,  1.32it/s]

Valid loss @ iteration 2600: Loss=0.04181300541933845


Ep 3 Iter 355: Loss=0.04912:  45%|████▌     | 354/782 [04:59<05:25,  1.31it/s]

Valid loss @ iteration 2700: Loss=0.04168848408495679


Ep 3 Iter 455: Loss=0.04445:  58%|█████▊    | 454/782 [06:25<04:06,  1.33it/s]

Valid loss @ iteration 2800: Loss=0.04175590570358669


Ep 3 Iter 555: Loss=0.03779:  71%|███████   | 554/782 [07:52<03:00,  1.26it/s]

Valid loss @ iteration 2900: Loss=0.0418675480520024


Ep 3 Iter 655: Loss=0.03786:  84%|████████▎ | 654/782 [09:17<01:35,  1.34it/s]

Valid loss @ iteration 3000: Loss=0.04061716444352094


Ep 3 Iter 755: Loss=0.03544:  96%|█████████▋| 754/782 [10:43<00:22,  1.26it/s]

Valid loss @ iteration 3100: Loss=0.039965021186599545


Ep 3 Iter 782: Loss=0.05008: 100%|██████████| 782/782 [11:12<00:00,  1.16it/s]
Ep 4 Iter 73: Loss=0.04584:   9%|▉         | 72/782 [00:56<09:16,  1.28it/s]

Valid loss @ iteration 3200: Loss=0.041153512675972545


Ep 4 Iter 173: Loss=0.04235:  22%|██▏       | 172/782 [02:22<08:04,  1.26it/s]

Valid loss @ iteration 3300: Loss=0.0418728424199656


Ep 4 Iter 273: Loss=0.04526:  35%|███▍      | 272/782 [03:48<06:35,  1.29it/s]

Valid loss @ iteration 3400: Loss=0.039968300537735806


Ep 4 Iter 373: Loss=0.03485:  48%|████▊     | 372/782 [05:15<05:21,  1.27it/s]

Valid loss @ iteration 3500: Loss=0.03929505387649817


Ep 4 Iter 473: Loss=0.04095:  60%|██████    | 472/782 [06:41<04:00,  1.29it/s]

Valid loss @ iteration 3600: Loss=0.03905038294546744


Ep 4 Iter 573: Loss=0.03558:  73%|███████▎  | 572/782 [08:07<02:42,  1.29it/s]

Valid loss @ iteration 3700: Loss=0.038952243605665134


Ep 4 Iter 673: Loss=0.04026:  86%|████████▌ | 672/782 [09:33<01:27,  1.25it/s]

Valid loss @ iteration 3800: Loss=0.03860693790164648


Ep 4 Iter 773: Loss=0.03741:  99%|█████████▊| 772/782 [11:00<00:07,  1.29it/s]

Valid loss @ iteration 3900: Loss=0.040014855490595684


Ep 4 Iter 782: Loss=0.03913: 100%|█████████▉| 781/782 [11:15<00:00,  1.16it/s]


<Figure size 640x480 with 0 Axes>