In [None]:
# Authenticate this kernel with Weights & Biases; later cells call wandb.sweep() and wandb.init().
import wandb
wandb.login()

In [None]:
import sys
# Put the parent directory on the import path (presumably where the `utils` package imported below lives).
sys.path.append("..")
print(sys.path)
import os
import datetime
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ExponentialLR, ReduceLROnPlateau
import numpy as np
import h5py as h5
import matplotlib.pyplot as plt
from utils.PDE_Net import DeepONet_NS, weight_init
from utils.DataGenerate_DON import Dataset_DON
from torch.utils.tensorboard import SummaryWriter
from argparse import Namespace
# NOTE(review): normalize, load_model and save_model are called later in this notebook without a
# named import, so they presumably come from this star import - TODO confirm and import them by name.
from utils.utilities3 import *

# Use float32 as the default tensor dtype and prefer the GPU when CUDA is available.
torch.set_default_dtype(torch.float32)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Geometry of the sub-domain decomposition. Everything below is derived from `dx`,
# so changing `dx` here is the only edit needed to change the tiling.
# NOTE(review): the constants 8/3 and 800/300 look like the physical extent and the
# grid resolution of the full domain - TODO confirm against the data generator.
dx = 0.5
n_x = int(8/dx)			# number of sub-domains along the first axis
n_y = int(3/dx)			# number of sub-domains along the second axis
p_x = int(800/n_x+1)	# grid points per sub-domain along the first axis
p_y = int(300/n_y+1)	# grid points per sub-domain along the second axis
self_split = 2

# Single source of truth for a training run. The geometric entries reference the
# variables above instead of repeating their literal values, so they cannot drift apart.
config = Namespace(
	project_name='DON',

	device = torch.device('cuda' if torch.cuda.is_available() else 'cpu'),

	dx=dx,
	n_x=n_x,
	n_y=n_y,
	p_x=p_x,
	p_y=p_y,
	self_split=self_split,

	epochs=10000,
	batch_size=6000,
	learning_rate=0.0007565,
	dropout=0.4,
	weight_decay=0.00042,
	max_norm=2.864,

	# The branch input is the concatenation of the four boundary edges built in
	# create_dataloader (two edges of length p_x and two of length p_y); 204 for 51 x 51 patches.
	branch_input=2*(p_x+p_y),
	trunk_input=2,
	hidden_size=512,
	branch_layer=6,
	trunk_layer=4,
	
	optim_type='AdamW',

	path_trained_model=r'trained_model',
	path_label=r'../train_data/label',
	path_label_test=r'../test_data/label',
)

In [None]:


# Objective reported to the sweep: the per-epoch training loss logged as 'loss_train'.
metric = {'name': 'loss_train', 'goal': 'minimize'}

# Random-search sweep definition. Entries with 'value' are held constant for every
# trial; entries with 'values' or 'distribution' are sampled per trial.
sweep_config = {
	'method': 'random',
	'metric': metric,
	'parameters': {
		# --- constants ---
		'project_name': {'value': 'DON'},
		'epochs': {'value': 10000},
		'dx': {'value': 0.5},
		'n_x': {'value': 16},
		'n_y': {'value': 6},
		'p_x': {'value': 51},
		'p_y': {'value': 51},
		'self_split': {'value': 2},

		'branch_input': {'value': 204},
		'trunk_input': {'value': 2},
		'hidden_size': {'value': 512},

		'path_trained_model': {'value': r'trained_model'},
		'path_label': {'value': r'../train_data/label'},
		'path_label_test': {'value': r'../test_data/label'},

		# --- searched hyper-parameters ---
		'batch_size': {'values': [600,1500,3000]},
		'branch_layer': {'values': [4,5,6,7]},
		'trunk_layer': {'values': [4,5,6,7]},
		'optim_type': {'values': ['Adam','SGD','AdamW']},
		'learning_rate': {'distribution': 'log_uniform_values', 'min': 1e-5, 'max': 1e-1},
		'dropout': {'distribution': 'q_uniform', 'q': 0.2, 'min': 0, 'max': 0.6},
		'weight_decay': {'distribution': 'log_uniform_values', 'min': 1e-5, 'max': 1e-1},
		'max_norm': {'distribution': 'uniform', 'min': 1, 'max': 10},
	},
}

# Initialize the sweep controller and keep its id.
sweep_id = wandb.sweep(sweep_config, project='DON')

In [None]:
def _load_label_dir(path, p_x, p_y):
	"""Stack every ``.npy`` file found in ``path`` into one tensor along dim 0.

	Files are read in sorted name order because ``os.listdir`` returns entries in
	arbitrary order, which would otherwise make the sample order differ between runs.
	Returns an empty ``(0, p_x, p_y, 1)`` tensor when the directory has no ``.npy`` file.
	"""
	names = sorted(name for name in os.listdir(path) if name.endswith('.npy'))
	# The empty seed tensor keeps the result shape for an empty directory and makes
	# torch.cat reject files whose trailing dimensions are not (p_x, p_y, 1).
	chunks = [torch.empty(0, p_x, p_y, 1)]
	chunks.extend(torch.Tensor(np.load(os.path.join(path, name))) for name in names)
	# One concatenation instead of re-copying the growing tensor for every file.
	return torch.cat(chunks, 0)


def _boundary_values(label):
	"""Concatenate the four edges of each ``(p_x, p_y)`` patch into one row per sample.

	Edge order: ``[:, :, 0]``, ``[:, -1, :]``, ``[:, :, -1]``, ``[:, 0, :]`` of channel 0,
	giving a tensor of shape ``(N, 2 * (p_x + p_y))``.
	"""
	return torch.cat((label[:,:,0,0], label[:,-1,:,0], label[:,:,-1,0], label[:,0,:,0]), dim=1)


def create_dataloader(config):
	"""Build the train/test loaders and the trunk-net query coordinates.

	Args:
		config: object exposing ``path_label``, ``path_label_test``, ``p_x``, ``p_y``,
			``dx`` and ``batch_size``.

	Returns:
		``(train_loader, test_loader, coordinate)``. Each loader yields 4-tuples built
		from the outputs of ``normalize(boundary, flattened_label)``; the train loader
		is shuffled, the test loader is not. ``coordinate`` is a float tensor of shape
		``(p_x * p_y, 2)`` on the module-level ``device``.

	Note:
		Train and test sets are normalised independently, each by its own call to
		``normalize`` (defined outside this cell).
	"""
	train_label = _load_label_dir(config.path_label, config.p_x, config.p_y)
	train_bc = _boundary_values(train_label)
	train_label = train_label.reshape(train_label.shape[0],-1)

	test_label = _load_label_dir(config.path_label_test, config.p_x, config.p_y)
	test_bc = _boundary_values(test_label)
	test_label = test_label.reshape(test_label.shape[0],-1)

	x_data,y_data,x_min,x_max = normalize(train_bc, train_label)
	x_test,y_test,x_min_test,x_max_test = normalize(test_bc, test_label)

	train_loader = DataLoader(torch.utils.data.TensorDataset(x_data, y_data, x_min, x_max),batch_size=config.batch_size,shuffle=True)
	test_loader = DataLoader(torch.utils.data.TensorDataset(x_test, y_test, x_min_test, x_max_test),batch_size=config.batch_size,shuffle=False)

	# Query points on [0, dx] x [0, dx], flattened to one (x, y) pair per row.
	# NOTE(review): meshgrid's default 'xy' indexing makes the first coordinate vary
	# fastest, while the labels are flattened with their last spatial axis fastest -
	# TODO confirm the label axis convention matches this ordering.
	coordinate = np.meshgrid(np.linspace(0,config.dx,config.p_x), np.linspace(0,config.dx,config.p_y))
	coordinate = np.stack((coordinate[0],coordinate[1]), axis=-1).reshape(-1,2)

	coordinate = torch.Tensor(coordinate).float().to(device)

	return train_loader, test_loader, coordinate

In [None]:
def train_epoch(config,model,optimizer,myloss,scheduler,train_loader,test_loader,coordinate,device):
	"""Run one optimisation pass over ``train_loader``.

	Args:
		config: object exposing ``weight_decay`` (penalty coefficient) and ``max_norm``
			(gradient-clipping threshold).
		model: module called as ``model(x, coordinate)``.
		optimizer: optimizer stepping ``model``'s parameters.
		myloss: callable ``myloss(pred, target)`` returning a scalar tensor.
		scheduler: unused here (the scheduler step is currently disabled).
		train_loader: sized iterable of ``(x, y, min, max)`` batches; ``min``/``max``
			are ignored because inverse normalisation is disabled.
		test_loader: unused; kept so the signature matches ``eval_epoch``.
		coordinate: second positional argument forwarded to the model.
		device: device the batch tensors are moved to.

	Returns:
		``(model, optimizer, train_loss_epoch)`` where ``train_loss_epoch`` is the mean
		over batches of the *penalised* loss, i.e. it includes the weight penalty and
		is therefore not directly comparable with the loss from ``eval_epoch``.
	"""
	model.train()
	train_loss_epoch = 0
	for batch in train_loader:
		data_x,data_y,_data_min,_data_max = batch

		x = data_x.float().to(device)
		y = data_y.float().to(device)

		pred = model(x, coordinate)

		loss_train = myloss(pred, y)
		# Penalty is the sum of (unsquared) L2 norms of every parameter tensor.
		regularization_loss = 0
		for param in model.parameters():
			regularization_loss += torch.norm(param, p=2)
		loss_train = loss_train + config.weight_decay * regularization_loss
		train_loss_epoch = train_loss_epoch + loss_train.item()

		optimizer.zero_grad()
		# A fresh graph is built for every batch and backpropagated exactly once, so
		# there is no need to retain it; retaining only delays freeing its buffers.
		loss_train.backward()
		torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_norm)
		optimizer.step()
	train_loss_epoch = train_loss_epoch / len(train_loader)
	return model, optimizer, train_loss_epoch

def eval_epoch(config,model,optimizer,myloss,scheduler,train_loader,test_loader,coordinate,device):
	"""Return the mean of ``myloss`` over the batches of ``test_loader``.

	The model is switched to eval mode and gradients are disabled. Only ``model``,
	``myloss``, ``test_loader``, ``coordinate`` and ``device`` are used; the other
	parameters mirror the signature of ``train_epoch``. Each batch is a 4-tuple
	``(x, y, min, max)`` whose last two entries are ignored because inverse
	normalisation is disabled. No weight penalty is added to this loss.
	"""
	model.eval()
	batch_losses = []
	with torch.no_grad():
		for inputs, targets, _batch_min, _batch_max in test_loader:
			inputs = inputs.float().to(device)
			targets = targets.float().to(device)

			prediction = model(inputs, coordinate)

			batch_loss = myloss(prediction.clone(), targets.clone())
			batch_losses.append(batch_loss.item())
	# Average over the number of batches, not the number of samples.
	return sum(batch_losses) / len(test_loader)

In [None]:
def _latest_checkpoint(directory):
	"""Return the path of the ``DON_<epoch>.pth`` file with the highest epoch, or ``None``.

	``None`` is returned when ``directory`` does not exist or holds no file matching
	the naming scheme used by ``train`` when it saves checkpoints.
	"""
	if not os.path.isdir(directory):
		return None
	latest_epoch, latest_name = -1, None
	for name in os.listdir(directory):
		stem, ext = os.path.splitext(name)
		if ext != '.pth' or not stem.startswith('DON_'):
			continue
		try:
			epoch = int(stem[len('DON_'):])
		except ValueError:
			continue
		if epoch > latest_epoch:
			latest_epoch, latest_name = epoch, name
	if latest_name is None:
		return None
	return os.path.join(directory, latest_name)


def train(is_model_saved=False):
	"""Train a DeepONet_NS model with the module-level ``config`` and log to W&B.

	Args:
		is_model_saved: when true, resume from the newest ``DON_<epoch>.pth`` found in
			``config.path_trained_model``; if none exists, training starts from scratch.

	Returns:
		``(model, optimizer)`` after the last epoch.

	Side effects:
		Starts and finishes a W&B run, prints one line per epoch, and writes a
		checkpoint to ``config.path_trained_model`` every 100 epochs.
	"""

	train_loader, test_loader, coordinate = create_dataloader(config)
	device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
	model = DeepONet_NS(config.branch_input, config.trunk_input, config.branch_layer, config.trunk_layer, config.hidden_size, config.dropout)
	model.apply(weight_init)
	model.to(device)

	# The optimizer must exist before a checkpoint is loaded, because load_model
	# receives it as an argument.
	optimizer = getattr(torch.optim, config.optim_type)(params=model.parameters(), lr=config.learning_rate)

	begin_epoch = 0
	if is_model_saved:
		load_path = _latest_checkpoint(config.path_trained_model)
		if load_path is None:
			print('no checkpoint found in', config.path_trained_model, '- training from scratch')
		else:
			begin_epoch = load_model(load_path, optimizer, model)

	myloss = nn.MSELoss()
	# The scheduler is constructed and passed along, but train_epoch does not step it at present.
	scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=100, verbose=True, min_lr=1e-6)
	# Make sure the periodic checkpoint save below has a directory to write into.
	os.makedirs(config.path_trained_model, exist_ok=True)
	#======================================================
	nowtime = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
	wandb.init(project=config.project_name, config=config.__dict__, name=nowtime, save_code=True)
	model.run_id = wandb.run.id
	#======================================================
	try:
		for epoch in range(begin_epoch, config.epochs+1):
			model, optimizer, train_loss_epoch = train_epoch(config,model,optimizer,myloss,scheduler,train_loader,test_loader,coordinate,device)
			test_loss_epoch = eval_epoch(config,model,optimizer,myloss,scheduler,train_loader,test_loader,coordinate,device)
			print('epoch:', epoch, 'loss_train:', train_loss_epoch, 'loss_test:', test_loss_epoch)
			#======================================================
			wandb.log({'epoch': epoch, 'loss_train': train_loss_epoch, 'loss_test': test_loss_epoch})
			#======================================================
			if epoch % 100 == 0:
				# Same directory the resume branch reads from.
				save_path = os.path.join(config.path_trained_model, 'DON_{}.pth'.format(epoch))
				save_model(save_path, epoch, optimizer, model)
	finally:
		# Close the W&B run even when training is interrupted or raises.
		wandb.finish()
	return model, optimizer

In [None]:
# Start training from freshly initialised weights (is_model_saved=False skips checkpoint loading).
model, optimizer = train(is_model_saved=False)