In [1]:
import os
import shutil
import numpy as np
import pickle
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import random_split

from Task_1.synthetic_dataset import SyntheticDataset
from Arwin.model.trainer import Trainer
from Arwin.model.deeponet import *
from Task_1.utils import collate_fn_fixed, collate_fn

from torch.utils.tensorboard import SummaryWriter

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 256
data_train_path = 'Arwin/dataset/training_dataset.pkl'
data_test_path = 'Arwin/dataset/testing_dataset.pkl'

# Fix RNG seeds so weight initialisation, DataLoader shuffling and (if it draws
# from numpy/torch RNGs -- TODO confirm for SyntheticDataset) dataset generation
# are reproducible under Restart Kernel -> Run All.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)  # also seeds all CUDA devices

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Generating the datasets is slow (~6 minutes, see the progress bars below), so
# only do it when the cached pickles are not both present; the loading cell
# (`pickle.load` on data_train_path / data_test_path) reads them from disk.
N_TRAIN_SAMPLES = 100_000  # number of sampled functions in the training set
N_VALID_SAMPLES = 1_000    # number of sampled functions in the validation set

if os.path.exists(data_train_path) and os.path.exists(data_test_path):
    print(f"Found cached datasets ({data_train_path}, {data_test_path}); skipping generation.")
else:
    # NOTE(review): the second argument (128) presumably must match
    # `indicator_dim` used for the DeepONet below -- confirm.
    train_dataset = SyntheticDataset(N_TRAIN_SAMPLES, 128, padding=False, verbose=True)
    valid_dataset = SyntheticDataset(N_VALID_SAMPLES, 128, padding=False, verbose=True, test=True)

Sampling Functions from Beta Distribution 1/6 with a=1, b=2: 100%|██████████| 16666/16666 [00:59<00:00, 280.37it/s]
Sampling Functions from Beta Distribution 2/6 with a=1, b=5: 100%|██████████| 16666/16666 [00:58<00:00, 282.48it/s]
Sampling Functions from Beta Distribution 3/6 with a=2, b=1: 100%|██████████| 16666/16666 [00:55<00:00, 298.07it/s]
Sampling Functions from Beta Distribution 4/6 with a=2, b=5: 100%|██████████| 16666/16666 [00:56<00:00, 293.68it/s]
Sampling Functions from Beta Distribution 5/6 with a=5, b=1: 100%|██████████| 16666/16666 [00:56<00:00, 295.04it/s]
Sampling Functions from Beta Distribution 6/6 with a=5, b=2: 100%|██████████| 16670/16670 [00:56<00:00, 297.41it/s]
Generating Observations: 100%|██████████| 100000/100000 [00:07<00:00, 13887.16it/s]
Sampling Functions from Beta Distribution 1/6 with a=1, b=2: 100%|██████████| 166/166 [00:00<00:00, 294.36it/s]
Sampling Functions from Beta Distribution 2/6 with a=1, b=5: 100%|██████████| 166/166 [00:00<00:00, 252.48it

In [None]:
def save_pickle_if_missing(obj, path):
    """Pickle `obj` to `path` unless a file already exists there.

    Existing caches are never overwritten, so re-running the notebook cannot
    silently replace a dataset that earlier experiments were trained on.

    Args:
        obj: Object to serialise with `pickle`.
        path: Destination file path; parent directories are created if needed.

    Returns:
        True if the file was written, False if it already existed.
    """
    if os.path.exists(path):
        return False
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, 'wb') as f:
        pickle.dump(obj, f)
    return True


# Only reached with undefined datasets if the cache is missing AND generation
# was skipped, which the generation cell's guard prevents.
if not os.path.exists(data_train_path):
    save_pickle_if_missing(train_dataset, data_train_path)
if not os.path.exists(data_test_path):
    save_pickle_if_missing(valid_dataset, data_test_path)

In [2]:
# Load the cached datasets. Use context managers so the file handles are closed
# (the previous `pickle.load(open(...))` leaked them).
# NOTE: pickle.load can execute arbitrary code -- only load files this project
# produced locally, never pickles from untrusted sources.
with open(data_train_path, 'rb') as f:
    train_dataset = pickle.load(f)
with open(data_test_path, 'rb') as f:
    valid_dataset = pickle.load(f)

# Shuffle only the training data; validation order stays fixed so the logged
# validation losses are comparable across iterations.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn_fixed)
validation_loader = torch.utils.data.DataLoader(dataset=valid_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn_fixed)

In [4]:
# Single source of truth for the run name: used both for the TensorBoard log
# directory and for the Trainer's `modelname`.
MODEL_NAME = "IM1_test2_old_data"
TBOARD_LOGS = os.path.join("./Arwin", "tboard_logs", MODEL_NAME)

# Start from an empty log directory so event files from a previous run under
# the same name are discarded. (Previously the directory was created only to be
# deleted immediately; SummaryWriter creates it itself.)
if os.path.exists(TBOARD_LOGS):
    shutil.rmtree(TBOARD_LOGS)
writer = SummaryWriter(TBOARD_LOGS)

# NOTE(review): presumably must match the 128 passed to SyntheticDataset -- confirm.
indicator_dim = 128

deeponet = DeepONet(indicator_dim=indicator_dim, d_model=128, heads=2, p=128).to(device)
criterion = nn.MSELoss()
trainer = Trainer(
    model=deeponet,
    criterion=criterion,
    train_loader=train_loader,
    valid_loader=validation_loader,
    modelname=MODEL_NAME,
    epochs=1,
    writer=writer,
)

In [5]:
# Always flush and close the TensorBoard writer, even if training is interrupted
# (e.g. KeyboardInterrupt), so buffered events are written to disk.
try:
    trainer.fit()
finally:
    writer.close()

Ep 0 Iter 1: Loss=1.00008:   0%|          | 0/391 [00:01<?, ?it/s]

Valid loss @ iteration 0: Loss=1.0000367164611816


Ep 0 Iter 51: Loss=0.42655:  13%|█▎        | 50/391 [01:01<06:41,  1.18s/it]

Valid loss @ iteration 50: Loss=0.44927891716361046


Ep 0 Iter 101: Loss=0.43601:  26%|██▌       | 100/391 [02:01<05:32,  1.14s/it]

Valid loss @ iteration 100: Loss=0.4092009849846363


Ep 0 Iter 151: Loss=0.42706:  38%|███▊      | 150/391 [03:00<04:38,  1.16s/it]

Valid loss @ iteration 150: Loss=0.4825812354683876


Ep 0 Iter 201: Loss=0.23276:  51%|█████     | 200/391 [04:00<03:38,  1.14s/it]

Valid loss @ iteration 200: Loss=0.2068569427356124


Ep 0 Iter 251: Loss=0.21065:  64%|██████▍   | 250/391 [04:59<02:43,  1.16s/it]

Valid loss @ iteration 250: Loss=0.18529073055833578


Ep 0 Iter 301: Loss=0.17802:  77%|███████▋  | 300/391 [05:58<01:44,  1.15s/it]

Valid loss @ iteration 300: Loss=0.16290128882974386


Ep 0 Iter 351: Loss=0.17941:  90%|████████▉ | 350/391 [06:57<00:47,  1.17s/it]

Valid loss @ iteration 350: Loss=0.14298695605248213


Ep 0 Iter 391: Loss=0.11295: 100%|█████████▉| 390/391 [07:45<00:01,  1.19s/it]


<Figure size 640x480 with 0 Axes>