In [1]:
from datetime import datetime
import os

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from dataset import TenBarsPlanarTrussDataset
from loss import DirectStiffnessLoss, construct_k_from_ea
from models.architecture import MultiLayerPerceptron

In [2]:
filepath = {
    'train': "data/dataset/10_bar_truss/train/data.hdf5",
    'validation': "data/dataset/10_bar_truss/validation/data.hdf5",
    'test': "data/dataset/10_bar_truss/test/data.hdf5"
}

In [3]:
layers = [50, 50, 50]
model = MultiLayerPerceptron(25, layers, 10)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

fn_loss_data = nn.MSELoss()
fn_loss_physics = DirectStiffnessLoss()

In [4]:
N_EPOCH = 3
BATCH_SIZE = 256

_train_dataset = TenBarsPlanarTrussDataset(filepath['train'])
_test_dataset = TenBarsPlanarTrussDataset(filepath['test'])
_validation_dataset = TenBarsPlanarTrussDataset(filepath['validation'])

train_dataloader = DataLoader(_train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                              num_workers=1, persistent_workers=True)
test_dataloader = DataLoader(_test_dataset, batch_size=BATCH_SIZE, shuffle=True,
                             num_workers=1, persistent_workers=True)
validation_dataloader = DataLoader(_validation_dataset, batch_size=BATCH_SIZE, shuffle=True,
                                   num_workers=1, persistent_workers=True)

connectivity = torch.tensor([[0, 1, 3, 4, 1, 2, 0, 3, 1, 4],
                             [1, 2, 4, 5, 4, 5, 4, 1, 5, 2]]).T

support = torch.tensor([0, 1, 6, 7])

In [5]:
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_dir = './_runs/EA_prediction_{}'.format(timestamp)

if not os.path.exists(f"{log_dir}/log"):
    os.makedirs(f"{log_dir}/log")
    
if not os.path.exists(f"{log_dir}/weights"):
    os.makedirs(f"{log_dir}/weights")
    
writer = SummaryWriter(log_dir + "/log")

best_v_loss = np.inf

for epoch in range(N_EPOCH):
    loop = tqdm(enumerate(train_dataloader), total=len(train_dataloader), leave=True)
    loop.set_description(f'Epoch {epoch + 1:3d}/{N_EPOCH}')

    running_loss = 0.
    running_physic_loss = 0.
    running_data_loss = 0.

    last_loss = np.inf
    last_data_loss = np.inf
    last_physic_loss = np.inf

    model.train(True)
    for i, data in loop:
        inputs, target, nodes, _, u, q = data

        optimizer.zero_grad()

        ea_pred = model(inputs)
        k_pred = construct_k_from_ea(ea_pred, nodes, connectivity, support)

        loss_data = fn_loss_data(ea_pred, target)
        loss_physic = fn_loss_physics(k_pred, u, q)
        loss = loss_data + loss_physic
        loss.backward()

        optimizer.step()

        running_loss += loss_data.item()
        running_physic_loss += loss_physic.item()
        running_data_loss += loss_data.item()

        if i % 10 == 9:
            last_loss = running_loss / 10
            last_physic_loss = running_physic_loss / 10
            last_data_loss = running_data_loss / 10

            tb_x = epoch * len(train_dataloader) + i + 1
            writer.add_scalar('Loss/train', last_loss, tb_x)
            writer.add_scalar('Physics_loss/train', last_physic_loss, tb_x)
            writer.add_scalar('Data_loss/train', last_data_loss, tb_x)

            running_loss = 0.
            running_physic_loss = 0.
            running_data_loss = 0.

        loop.set_postfix({
            'train loss': running_loss,
            'train MSE': running_data_loss,
            'train Phys. loss': running_physic_loss,
        })

    model.eval()
    with torch.no_grad():
        v_loss = 0
        v_loss_data = 0
        v_loss_physic = 0
        for data in validation_dataloader:
            inputs, target, nodes, _, u, q = data

            ea_pred = model(inputs)
            k_pred = construct_k_from_ea(ea_pred, nodes, connectivity, support)

            v_loss_data += fn_loss_data(ea_pred, target).item()
            v_loss_physic += fn_loss_physics(k_pred, u, q).item()
            v_loss += loss_data.item() + loss_physic.item()

    v_loss /= len(validation_dataloader.dataset)
    v_loss_data /= len(validation_dataloader.dataset)
    v_loss_physic /= len(validation_dataloader.dataset)

    writer.add_scalars('Training vs. Validation Loss',
                       {'Training loss': last_loss,
                        'Training data loss': last_data_loss,
                        'Training physics loss': last_physic_loss,
                        'Validation loss': v_loss,
                        'Validation data loss': v_loss_data,
                        'Validation physics loss': v_loss_physic, },
                       epoch + 1)
    writer.flush()

    tqdm.write(
        f"Validation Loss: {v_loss:.4f}, Validation MSE: {v_loss_data:.4f}, Validation Phys. Loss: {v_loss_physic:.4f}")

    if v_loss < best_v_loss:
        best_vloss = v_loss
        model_path = f"{log_dir}/weights/model_{timestamp}_{epoch}"
        torch.save(model.state_dict(), model_path)


Epoch   1/3: 100%|██████████| 391/391 [00:08<00:00, 44.82it/s, train loss=2.75e+16, train MSE=2.75e+16, train Phys. loss=2.11e+9] 


Validation Loss: 108571829812853.1250, Validation MSE: 125463980693939.0469, Validation Phys. Loss: 11490762.7021


Epoch   2/3: 100%|██████████| 391/391 [00:08<00:00, 46.92it/s, train loss=3.32e+16, train MSE=3.32e+16, train Phys. loss=2.34e+9] 


Validation Loss: 131239672947278.4688, Validation MSE: 124534806885412.2188, Validation Phys. Loss: 13597599.7644


Epoch   3/3: 100%|██████████| 391/391 [00:08<00:00, 47.87it/s, train loss=2.74e+16, train MSE=2.74e+16, train Phys. loss=4.45e+9] 


Validation Loss: 108336592585383.8906, Validation MSE: 124008457312578.1562, Validation Phys. Loss: 14064472.0716
