In [1]:
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import scipy.io as scio
import scipy.signal as signal
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import h5py

In [6]:
"""
Hyper-parameters of the hybrid CRNN experiment.

LSTM_step      -- number of LSTM steps in the network
len_coarse     -- number of coarse-mesh samples assigned to each LSTM step
len_dense      -- number of dense-mesh samples assigned to each LSTM step
tot_len_coarse -- total length of the input coarse-mesh FDTD simulation result
tot_len_dense  -- total length of the output dense-mesh FDTD simulation result
INPUT_SIZE     -- LSTM input size
LR             -- learning rate
BATCH          -- batch size
save_model     -- whether to save the trained model
"""

# Per-step sequence layout.
LSTM_step = 10
len_coarse = 18
len_dense = 13

# Totals are the per-step lengths multiplied by the number of LSTM steps.
tot_len_coarse = len_coarse * LSTM_step
tot_len_dense = len_dense * LSTM_step

# The LSTM input size is the sum of the coarse and dense per-step lengths.
INPUT_SIZE = len_dense + len_coarse

# Optimisation settings.
LR = 1e-4
BATCH = 16

# Checkpointing switch.
save_model = True

In [8]:
# Build the hybrid CRNN model

from HybridCRNN import CRNN

crnn = CRNN(INPUT_SIZE, len_dense)

# Optimizer (Adam over every model parameter, with weight decay) and MSE loss

optimizer = torch.optim.Adam(
    crnn.parameters(),
    lr=LR,
    weight_decay=5e-5,
)
loss_func = nn.MSELoss()

# Device selection: CUDA when it is available, otherwise the CPU

use_gpu = torch.cuda.is_available()
device = torch.device("cuda" if use_gpu else "cpu")

# Wrap the model in DataParallel when more than one GPU is visible
if torch.cuda.device_count() >= 2:
    crnn = nn.DataParallel(crnn)

crnn.to(device)
loss_func = loss_func.to(device)

# Training and testing datasets, built from the two .h5 files

from ClassDataset import H5Dataset

# Both datasets share the same step count and per-step lengths
dataset_args = (LSTM_step, len_coarse, len_dense)
training_set = H5Dataset('training_data.h5', *dataset_args)
testing_set = H5Dataset('testing_data.h5', *dataset_args)

# Both loaders draw shuffled mini-batches of BATCH samples
training_dataloader, testing_dataloader = (
    DataLoader(dataset, batch_size=BATCH, shuffle=True)
    for dataset in (training_set, testing_set)
)

In [20]:
# Train the model: hand the model, both dataloaders, the loss, the optimizer,
# the epoch count and the LSTM step count to the trainer, then run it.
from TrainCRNN import CRNNTrainer

epochs = 100
dataloaders = (training_dataloader, testing_dataloader)
trainer = CRNNTrainer(
    crnn,
    dataloaders,
    loss_func,
    optimizer,
    epochs,
    LSTM_step,
)
trainer.train()

# get_losses() is unpacked into a (training, testing) pair
# NOTE(review): presumably per-epoch loss histories; confirm against TrainCRNN.
training_losses, testing_losses = trainer.get_losses()

In [19]:
# Post-processing: persist the trained weights to 'crnn.pth' when requested.
if save_model:
    # nn.DataParallel stores the real network under `.module`, and its
    # state_dict keys carry a "module." prefix. Save the wrapped network's
    # weights instead so the checkpoint loads into a plain CRNN instance
    # no matter how many GPUs were used for training.
    model_to_save = crnn.module if isinstance(crnn, nn.DataParallel) else crnn
    torch.save(model_to_save.state_dict(), 'crnn.pth')

In [14]:
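# Validation process (disabled: every line below is commented out).
# NOTE(review): `trainer` was built with the training-time `crnn`; confirm that `trainer.predict` uses the reloaded model before re-enabling.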

# crnn = CRNN(INPUT_SIZE, len_dense)
# crnn.load_state_dict(torch.load('crnn.pth'))
# crnn.eval()  # Set the model to evaluation mode

# validation_set = H5Dataset('validation_data.h5', LSTM_step, len_coarse, len_dense)

# val_dataloader = DataLoader(validation_set, batch_size=BATCH, shuffle=True)
# prediction, gt = trainer.predict(val_dataloader)
# prediction, gt = torch.stack(prediction), torch.stack(gt)
# pre_np, gt_np = prediction.numpy().reshape((len(validation_set),-1)), gt.numpy().reshape((len(validation_set),-1))