In [1]:
import os
import shutil
import numpy as np
import pickle
import optuna
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import random_split

from Arwin.dataset.synthetic_dataset import SyntheticDataset
from Arwin.dataset.windowed_dataset import WindowedDataset
from Arwin.model.trainer_Im2 import TrainerIM2 as Trainer
from Arwin.model.deeponet import DeepONet
from Arwin.model.embedding_forecaster import EmbeddingForcaster
from Arwin.src.utils import *

from torch.utils.tensorboard import SummaryWriter

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 128
data_train_path = 'Arwin/dataset/training_dataset.pkl'
data_test_path = 'Arwin/dataset/testing_dataset.pkl'

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
training_dataset = pickle.load(open(data_train_path, 'rb'))
valid_dataset = pickle.load(open(data_test_path, 'rb'))

In [3]:
""" Prepare dataset """
training_dataset_windows = WindowedDataset(training_dataset.functions, training_dataset.observations, training_dataset.masks, shuffle=True)
valid_dataset_windows = WindowedDataset(valid_dataset.functions, valid_dataset.observations, valid_dataset.masks, eval=True, shuffle=True)

train_loader = torch.utils.data.DataLoader(dataset=training_dataset_windows, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn_windows)
valid_loader = torch.utils.data.DataLoader(dataset=valid_dataset_windows, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn_windows)

In [8]:
def new_forward(deeponet, y, t, eval_grid_points, mask):
    # Generate the fine grid points batch dynamically for the current batch size
    batch_size = y.shape[0]
    fine_grid_points_batch = eval_grid_points.unsqueeze(0).expand(batch_size, -1)

    # Mask the input data
    y = y.unsqueeze(-1) * mask.unsqueeze(-1)
    t = t.unsqueeze(-1) * mask.unsqueeze(-1)
    t_sample =  fine_grid_points_batch.unsqueeze(-1)
    # Branch and Trunk Embedding
    branch_embedding_y = deeponet.branch_embedding_y(y)
    branch_embedding_t = deeponet.branch_embedding_t(t)
    trunk_encoder_input = deeponet.trunk_embedding_t(t_sample)

    # generate mask for the transformer encoder
    mask_enc = torch.where(mask == 1, False, True)

    # Transformer Encoder for the Branch Network
    branch_encoder_input = deeponet.embedding_act(branch_embedding_y + branch_embedding_t)
    branch_encoder_output = deeponet.branch_encoder(branch_encoder_input, src_key_padding_mask=mask_enc)

    # Mask the output of the transformer encoder
    branch_encoder_output = branch_encoder_output * mask.unsqueeze(-1)
    """ Modifications to the original DeepONet """
    # Attention-based summary
    H = branch_encoder_output  # Shape: [batch_size, 128, d_model]
    q = deeponet.query.unsqueeze(0).expand(batch_size, -1, -1)  # Shape: [batch_size, 1, d_model]

    # Multihead attention: query (q), keys/values (H)
    h_b, _ = deeponet.summary_attention(q, H, H, key_padding_mask=mask_enc)  # h_b: [batch_size, 1, d_model]
    h_b = h_b.squeeze(1)  # Flatten to [batch_size, d_model]

    return h_b
    """ -------------------------------------- """
    branch_output = self.branch_mlp(h_b) 
    trunk_output = self.trunk_mlp(trunk_encoder_input)

    """ Modifications to the original DeepONet """
    # Expand branch_output to match the sequence length of trunk_output
    branch_output_expanded = branch_output.unsqueeze(1).expand(-1, trunk_output.shape[1], -1)  # [batch_size, d_model, p]

    # Concatenate branch_output with trunk_output along the feature dimension
    combined_input = torch.cat((branch_output_expanded, trunk_output), dim=-1)  # [batch_size, d_model, 2 * p]

    # Flatten the batch and sequence dimensions to apply combined_mlp in one step
    combined_input_flattened = combined_input.view(-1, combined_input.shape[-1])  # [batch_size * d_model, 2 * p]

    # Pass through combined_mlp
    mlp_output = self.combined_mlp(combined_input_flattened)  # [batch_size * d_model, 1]

    # Reshape back to the original structure
    combined_out = mlp_output.view(branch_output.shape[0], trunk_output.shape[1])  # [batch_size, d_model]

    """ -------------------------------------- """
        
    return combined_out

In [None]:
d_model = 128
p = 128

deeponet = DeepONet(indicator_dim=128, d_model=d_model, heads=2, p=p).to(device)
scales_mlp = nn.Sequential(
                nn.Linear(6 ,d_model),
                nn.LeakyReLU(),
                nn.Linear(d_model, d_model),
                nn.LeakyReLU(),
                nn.Linear(d_model, d_model),
                nn.LeakyReLU(),
)
scales_mlp = scales_mlp.to(device)

forecast_encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=d_model*2, nhead=8, batch_first=True), num_layers=4, enable_nested_tensor=False
).to(device)

query = nn.Parameter(torch.randn(1, 2*d_model))
summary_attention = nn.MultiheadAttention(embed_dim=2*d_model, num_heads=8, batch_first=True).to(device)

# Linear projection layer to adjust output dimension
projection_layer = nn.Linear(2 * d_model, d_model).to(device)


deeponet.eval()
scales_mlp.eval()
forecast_encoder.eval()
summary_attention.eval()
projection_layer.eval()

for y_value_windows, (y_observation_windows, t_observation_windows), mask_windows, scale_windows in train_loader:
    y_value_windows = y_value_windows.to(device)
    y_observation_windows = y_observation_windows.to(device)
    t_observation_windows = t_observation_windows.to(device)
    mask_windows = mask_windows.to(device)
    eval_grid_points = torch.linspace(0, 1, 128, device=device)

    # Flatten Windows to be of shape (batch_size * num_windows, window_size)
    y_values = y_value_windows.view(-1, y_value_windows.size(2))
    y_observations = y_observation_windows.view(-1, y_observation_windows.size(2))
    t_observations = t_observation_windows.view(-1, t_observation_windows.size(2))
    masks = mask_windows.view(-1, mask_windows.size(2))
    scales = torch.stack([tensor.to(device) for sublist in scale_windows for tensor in sublist])

    # Select indices that are not every 5th element
    indices_to_keep = torch.arange(scales.size(0)) % 5 != 0  # Create a mask to keep elements
    scales_filtered = scales[indices_to_keep]  # Filter scales to keep only elements that are not every 5th element

    scales_out = scales_mlp(scales_filtered)
    print(scales_out.shape)

    h_b = new_forward(deeponet, y_observations, t_observations, eval_grid_points, masks)
    print(h_b.shape)

    concat = torch.cat((h_b[indices_to_keep], scales_out), dim=1)
    print(concat.shape)

    forecast_out = forecast_encoder(concat)
    print(forecast_out.shape)

    forecast_out = forecast_out.view(128, -1, d_model*2)
    print(forecast_out.shape)
    print(forecast_out[0]) # 4 embeddings of size 2*d_model for each window of the first batch element

    # mask should not be required since it was already applied to get h_b and scales do not require masking
    mask_enc = torch.where(masks == 1, False, True)
    print(mask_enc.shape)
    
    q = query.unsqueeze(0).expand(BATCH_SIZE, -1, -1).to(device)  # Shape: [batch_size, 1, 2*d_model]
    u_b, _ = summary_attention(q, forecast_out, forecast_out)  # u_b: [batch_size, 1, 2*d_model]
    print(u_b.shape)

    u_b = projection_layer(u_b.squeeze(1))
    print(u_b.shape)

    break

torch.Size([512, 128])
torch.Size([640, 128])
160.0
128.0
torch.Size([512, 256])
torch.Size([512, 256])
torch.Size([128, 4, 256])
tensor([[-0.3692,  0.8525,  1.7620,  ...,  0.9500,  0.3524, -1.5420],
        [-0.2209,  0.7849,  1.7999,  ...,  0.9170,  0.1758, -1.5426],
        [-0.4388,  0.7011,  1.8005,  ...,  0.8730,  0.3761, -1.6279],
        [-0.3215,  0.8001,  1.8848,  ...,  0.9460,  0.4008, -1.5643]],
       device='cuda:0', grad_fn=<SelectBackward0>)
torch.Size([640, 128])
torch.Size([128, 1, 256])
torch.Size([128, 128])


In [4]:
deeponet = DeepONet(indicator_dim=128, d_model=128, heads=2, p=128).to(device)
optim = {"model": torch.optim.AdamW(deeponet.parameters(), lr=4.6e-4)}

deeponet, optimizers, epoch, stats = load_model(deeponet, optim, "./Arwin/checkpoints/New_Im1/checkpoint_epoch_3900_New_Im1.pth")

TBOARD_LOGS = os.path.join("./Arwin", "tboard_logs", "Im2_combined_loss")
if not os.path.exists(TBOARD_LOGS):
    os.makedirs(TBOARD_LOGS)
shutil.rmtree(TBOARD_LOGS) 
writer = SummaryWriter(TBOARD_LOGS)

indicator_dim = 128

forecaster = EmbeddingForcaster(d_model=128).to(device)
criterion = nn.MSELoss()
trainer = Trainer(deeponet=deeponet, model=forecaster, criterion=criterion, train_loader=train_loader, valid_loader=valid_loader, modelname="Im2_combined_loss", epochs=5, writer=writer)

In [5]:
trainer.fit()

Ep 0 Iter 1: Loss=1.31519:   0%|          | 0/782 [00:01<?, ?it/s]

Valid loss @ iteration 0: Loss=1.2426890526332108


Ep 0 Iter 101: Loss=0.30953:  13%|█▎        | 100/782 [01:28<09:06,  1.25it/s]

Valid loss @ iteration 100: Loss=0.3032458030125674


Ep 0 Iter 201: Loss=0.30905:  26%|██▌       | 200/782 [02:55<07:42,  1.26it/s]

Valid loss @ iteration 200: Loss=0.3000796112944098


Ep 0 Iter 301: Loss=0.29805:  38%|███▊      | 300/782 [04:22<06:21,  1.26it/s]

Valid loss @ iteration 300: Loss=0.29786408268937875


Ep 0 Iter 401: Loss=0.31274:  51%|█████     | 400/782 [05:49<05:07,  1.24it/s]

Valid loss @ iteration 400: Loss=0.29413389297677023


Ep 0 Iter 501: Loss=0.29662:  64%|██████▍   | 500/782 [07:16<03:48,  1.24it/s]

Valid loss @ iteration 500: Loss=0.29178807942890655


Ep 0 Iter 601: Loss=0.28993:  77%|███████▋  | 600/782 [08:42<02:27,  1.23it/s]

Valid loss @ iteration 600: Loss=0.2863539908738697


Ep 0 Iter 701: Loss=0.29248:  90%|████████▉ | 700/782 [10:09<01:05,  1.25it/s]

Valid loss @ iteration 700: Loss=0.2849053947948942


Ep 0 Iter 782: Loss=0.28375: 100%|██████████| 782/782 [11:21<00:00,  1.15it/s]
Ep 1 Iter 19: Loss=0.28318:   2%|▏         | 18/782 [00:15<10:20,  1.23it/s]

Valid loss @ iteration 800: Loss=0.28553333732427333


Ep 1 Iter 119: Loss=0.28283:  15%|█▌        | 118/782 [01:42<08:59,  1.23it/s]

Valid loss @ iteration 900: Loss=0.28178607964632557


Ep 1 Iter 219: Loss=0.29697:  28%|██▊       | 218/782 [03:09<07:17,  1.29it/s]

Valid loss @ iteration 1000: Loss=0.2821245548479697


Ep 1 Iter 319: Loss=0.27852:  41%|████      | 318/782 [04:36<06:07,  1.26it/s]

Valid loss @ iteration 1100: Loss=0.2812826744481629


Ep 1 Iter 419: Loss=0.29081:  53%|█████▎    | 418/782 [06:04<04:50,  1.25it/s]

Valid loss @ iteration 1200: Loss=0.2778192507285698


Ep 1 Iter 519: Loss=0.28253:  66%|██████▌   | 518/782 [07:31<03:31,  1.25it/s]

Valid loss @ iteration 1300: Loss=0.2782200012429088


Ep 1 Iter 619: Loss=0.28621:  79%|███████▉  | 618/782 [08:58<02:10,  1.25it/s]

Valid loss @ iteration 1400: Loss=0.2766389719703618


Ep 1 Iter 719: Loss=0.27857:  92%|█████████▏| 718/782 [10:25<00:51,  1.25it/s]

Valid loss @ iteration 1500: Loss=0.27701248213940977


Ep 1 Iter 782: Loss=0.27289: 100%|██████████| 782/782 [11:22<00:00,  1.15it/s]
Ep 2 Iter 37: Loss=0.28964:   5%|▍         | 36/782 [00:29<09:58,  1.25it/s]

Valid loss @ iteration 1600: Loss=0.27433865736512575


Ep 2 Iter 137: Loss=0.28584:  17%|█▋        | 136/782 [01:56<08:40,  1.24it/s]

Valid loss @ iteration 1700: Loss=0.2768523073663898


Ep 2 Iter 237: Loss=0.29464:  30%|███       | 236/782 [03:23<07:16,  1.25it/s]

Valid loss @ iteration 1800: Loss=0.2754643074437684


Ep 2 Iter 337: Loss=0.28456:  43%|████▎     | 336/782 [04:50<05:51,  1.27it/s]

Valid loss @ iteration 1900: Loss=0.2727103166136087


Ep 2 Iter 437: Loss=0.27288:  56%|█████▌    | 436/782 [06:17<04:38,  1.24it/s]

Valid loss @ iteration 2000: Loss=0.27170293851225985


Ep 2 Iter 537: Loss=0.28387:  69%|██████▊   | 536/782 [07:43<03:13,  1.27it/s]

Valid loss @ iteration 2100: Loss=0.2731901906868991


Ep 2 Iter 637: Loss=0.28054:  81%|████████▏ | 636/782 [09:10<01:57,  1.24it/s]

Valid loss @ iteration 2200: Loss=0.2709195172669841


Ep 2 Iter 737: Loss=0.27116:  94%|█████████▍| 736/782 [10:37<00:36,  1.26it/s]

Valid loss @ iteration 2300: Loss=0.2723681633086765


Ep 2 Iter 782: Loss=0.29379: 100%|██████████| 782/782 [11:20<00:00,  1.15it/s]
Ep 3 Iter 55: Loss=0.27833:   7%|▋         | 54/782 [00:43<09:50,  1.23it/s]

Valid loss @ iteration 2400: Loss=0.2702297945232952


Ep 3 Iter 155: Loss=0.27567:  20%|█▉        | 154/782 [02:11<08:29,  1.23it/s]

Valid loss @ iteration 2500: Loss=0.26939185811024086


Ep 3 Iter 255: Loss=0.26475:  32%|███▏      | 254/782 [03:37<06:57,  1.26it/s]

Valid loss @ iteration 2600: Loss=0.26850293766634137


Ep 3 Iter 355: Loss=0.27574:  45%|████▌     | 354/782 [05:05<05:38,  1.26it/s]

Valid loss @ iteration 2700: Loss=0.2686753230644207


Ep 3 Iter 455: Loss=0.26301:  58%|█████▊    | 454/782 [06:32<04:18,  1.27it/s]

Valid loss @ iteration 2800: Loss=0.2660626462277244


Ep 3 Iter 555: Loss=0.27828:  71%|███████   | 554/782 [07:59<03:02,  1.25it/s]

Valid loss @ iteration 2900: Loss=0.26487988584181843


Ep 3 Iter 655: Loss=0.26127:  84%|████████▎ | 654/782 [09:26<01:41,  1.26it/s]

Valid loss @ iteration 3000: Loss=0.26285914594636245


Ep 3 Iter 755: Loss=0.27825:  96%|█████████▋| 754/782 [10:53<00:22,  1.26it/s]

Valid loss @ iteration 3100: Loss=0.26231999578429205


Ep 3 Iter 782: Loss=0.25759: 100%|██████████| 782/782 [11:22<00:00,  1.15it/s]
Ep 4 Iter 73: Loss=0.26271:   9%|▉         | 72/782 [00:58<09:13,  1.28it/s]

Valid loss @ iteration 3200: Loss=0.26038285695454655


Ep 4 Iter 173: Loss=0.26785:  22%|██▏       | 172/782 [02:25<08:03,  1.26it/s]

Valid loss @ iteration 3300: Loss=0.25889970128442724


Ep 4 Iter 273: Loss=0.25769:  35%|███▍      | 272/782 [03:52<06:43,  1.26it/s]

Valid loss @ iteration 3400: Loss=0.25754546520172383


Ep 4 Iter 373: Loss=0.25019:  48%|████▊     | 372/782 [05:18<05:26,  1.26it/s]

Valid loss @ iteration 3500: Loss=0.2579229818839653


Ep 4 Iter 473: Loss=0.25837:  60%|██████    | 472/782 [06:45<04:16,  1.21it/s]

Valid loss @ iteration 3600: Loss=0.25677898292448004


Ep 4 Iter 573: Loss=0.26438:  73%|███████▎  | 572/782 [08:12<02:50,  1.23it/s]

Valid loss @ iteration 3700: Loss=0.2564508906182121


Ep 4 Iter 673: Loss=0.2612:  86%|████████▌ | 672/782 [09:39<01:28,  1.24it/s] 

Valid loss @ iteration 3800: Loss=0.2564668312084441


Ep 4 Iter 773: Loss=0.25811:  99%|█████████▊| 772/782 [11:06<00:08,  1.24it/s]

Valid loss @ iteration 3900: Loss=0.25572802783811793


Ep 4 Iter 782: Loss=0.25893: 100%|█████████▉| 781/782 [11:20<00:00,  1.15it/s]


<Figure size 640x480 with 0 Axes>