### This Jupyter notebook trains a fully connected neural network (FC), also known as an artificial neural network (ANN), to learn the mapping from the input current configuration to the output magnetic field.

In [None]:
# Install Ray with the data/train/tune/serve extras; the training cells below import ray.train and ray.tune.
!pip install -U "ray[data,train,tune,serve]"

In [None]:
%reload_ext autoreload
%autoreload 2
import numpy as np
import torch
from early_stopping import EarlyStopping

# Select the compute device once; later cells read `device` (a torch device
# string) and `use_gpu` (passed to Ray's ScalingConfig).
use_gpu = bool(torch.cuda.device_count())
device = 'cuda' if use_gpu else 'cpu'
print('Good to go' if use_gpu else 'Using cpu')

In [None]:
from ReadData import ReadCurrentAndField 
import glob
import os 

# Raw data location: one text file per current configuration.
foldername = "./Data/"
filepattern = "MagneticField[0-9]*.txt"
load_file_num = 1460   # number of files read from `foldername`
train_file_num = 1000  # first files form the train/validation pool; the rest are the test set
grid_size = 21         # grid points per spatial axis in each file
data = ReadCurrentAndField(foldername, filepattern, load_file_num)

# Channel layout of the last axis, as used by the slicing below:
# [:12] currents, [12:15] position, [15:18] magnetic field.
NUM_CURRENTS = 12
NUM_FEATURES = NUM_CURRENTS + 3  # network input: currents + position
NUM_CHANNELS = NUM_FEATURES + 3  # plus the 3 field components (network target)

data = data.reshape(load_file_num, grid_size, grid_size, grid_size, NUM_CHANNELS)
# Scale the position and field channels by 1e3 (presumably m -> mm and T -> mT);
# currents are left unscaled. Units afterwards: position mm, B field mT,
# current Ampere.
mask = torch.cat(
    (torch.ones(1, 1, 1, 1, NUM_CURRENTS),
     1e3 * torch.ones(1, 1, 1, 1, NUM_CHANNELS - NUM_CURRENTS)),
    dim=4,
)
data = mask * data

# Keep every `sparsity`-th grid point per axis for the training/validation pool.
# The sub-sampled view is built once and reused instead of repeating the slice.
sparsity = 4
train_data = data[:train_file_num, ::sparsity, ::sparsity, ::sparsity, :]
Current_position = train_data[..., :NUM_FEATURES].reshape(-1, NUM_FEATURES)  # position unit mm
Bfield = train_data[..., NUM_FEATURES:].reshape(-1, 3)  # B field unit mT

print(data.shape)
print(train_data[..., :NUM_FEATURES].shape)
print('position shape', Current_position.shape)
print('Bfield shape', Bfield.shape)


In [None]:
from Neural_network import NN_net, Plain_fc_block, weight_init, eMNS_Dataset
from Training_loop_v2 import train_ANN
from ray.train import RunConfig, ScalingConfig, CheckpointConfig
from ray.train.torch import TorchTrainer
from ray.tune.tuner import Tuner
from ray import tune
from ray.tune.schedulers import ASHAScheduler
import ray, os
import torch.nn.functional as F

# construct dataset
# Reuse the training tensors prepared in the data-loading cell rather than
# re-slicing `data` here: `sparsity` is a notebook global that a later cell may
# reassign, and re-slicing would then silently change the training resolution.
dataset = eMNS_Dataset(
    x=Current_position,
    y=Bfield,
)

# Split into train / validation (90 % / 10 %). A seeded generator makes the
# split -- and the normalisation extremes derived from it -- reproducible.
split_generator = torch.Generator().manual_seed(0)
train_set, valid_set = torch.utils.data.random_split(
    dataset, [0.9, 0.1], generator=split_generator
)

# Normalisation statistics are presumably computed from the training indices
# only. `boundary_index=12` presumably separates the 12 current channels from
# the position channels -- confirm in eMNS_Dataset.train_norm_ANN.
extremes = dataset.train_norm_ANN(train_indices=train_set.indices, boundary_index=12)

###############################################
# Config the neural network
###############################################
num_input = 15   # width of each input sample (the 15 channels sliced as Current_position)
num_output = 3   # width of each target sample (the 3 Bfield channels)

# Fully connected stages as (in_width, out_width, 1) tuples chained so each
# stage feeds the next: 15 -> 128 -> 64 -> 32. The trailing 1 is passed through
# unchanged to NN_net/Plain_fc_block (meaning defined there -- TODO confirm).
layer_widths = [num_input, 128, 64, 32]
fc_stages = [
    (width_in, width_out, 1)
    for width_in, width_out in zip(layer_widths, layer_widths[1:])
]
fc_network = NN_net(None, fc_stages, None, Plain_fc_block, num_output=num_output)


def loss_func(preds, y):
    """Return the mean absolute (L1) error between `preds` and `y`."""
    return F.l1_loss(preds, y)


################################################
# Train the neural network
################################################

# Settings forwarded to `train_ANN` through Ray's `train_loop_config`; how each
# entry is interpreted is defined in Training_loop_v2. Any value can be replaced
# by a Ray Tune search space (e.g. tune.grid_search([...])) for tuning runs.
train_loop_config = dict(
    epochs=50,
    lr_max=1e-3,
    lr_min=2.5e-6,
    batch_size=128,
    L2_norm=0,
    verbose=False,
    schedule=[],
    learning_rate_decay=0.5,
    # network description
    num_input=num_input,
    num_output=num_output,
    fc_stages=fc_stages,
    backward=False,
    # field bounds taken from the training-set normalisation extremes
    maxB=extremes[4],
    minB=extremes[5],
    # runtime
    device=device,
    loss_func=loss_func,
    forward_model_path=None,
)

# Ray's ScalingConfig raises a ValueError when resources_per_worker contains a
# "GPU" entry while use_gpu is False, so request a GPU only when one was found.
# The single worker needs 7 CPUs available on the node to be scheduled.
worker_resources = {"CPU": 7}
if use_gpu:
    worker_resources["GPU"] = 1

scaling_config = ScalingConfig(
    num_workers=1,
    use_gpu=use_gpu,
    resources_per_worker=worker_resources,
)

# Results go to ~/ray_results/EMS_ANN_v2; only the most recent checkpoint is kept.
run_config = RunConfig(
    name="EMS_ANN_v2",
    storage_path="~/ray_results",
    checkpoint_config=CheckpointConfig(num_to_keep=1),
)

def train_loop_per_worker(params):
    """Ray Train worker entry point: run `train_ANN` on the notebook's split.

    `params` is the `train_loop_config` dict handed over by TorchTrainer; the
    train/validation subsets are taken from the notebook globals.
    """
    train_ANN(
        config=params,
        train_set=train_set,
        valid_set=valid_set,
    )

# Assemble the Ray trainer from the pieces above and run it; `result` carries
# the reported metrics and the checkpoint used by the following cells.
trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config=train_loop_config,
    scaling_config=scaling_config,
    run_config=run_config,
)
result = trainer.fit()



In [None]:
from torchsummary import summary
# Layer-by-layer summary of the locally built network. torchsummary defaults to
# device="cuda", which breaks when the model's parameters live elsewhere, so
# pass the device the parameters are actually on ('cpu' or 'cuda').
model_device = next(fc_network.parameters()).device.type
summary(fc_network, (1, num_input), device=model_device)

# Parameter names and shapes; state_dict() is built once instead of per lookup.
for param_name, param_tensor in fc_network.state_dict().items():
    print(param_name, '\t', param_tensor.size())

In [None]:
from utils import plot_ray_results
# Training curves for the RMSE metrics (presumably reported by train_ANN).
tracked_metrics = ['rmse_val', 'rmse_train']
plot_ray_results(results=result, metrics_names=tracked_metrics)

In [None]:
# Copy the latest Ray checkpoint of the run into the project's Trained_model folder.
if result.checkpoint is None:
    raise RuntimeError("Training finished without reporting a checkpoint; nothing to export.")

# The checkpoint presumably stores a pickled model object (later cells read its
# 'model' entry as a network), so it cannot be loaded with weights_only=True,
# the default since PyTorch 2.6. Only load checkpoints produced by this notebook.
checkpoint_data = torch.load(
    os.path.join(result.checkpoint.path, "model.pt"), weights_only=False
)

model_path = "./Trained_model/EMS_ANN_v2.pt"
# torch.save does not create missing directories.
os.makedirs(os.path.dirname(model_path), exist_ok=True)
torch.save(checkpoint_data, model_path)

## Test dataset performance

In [None]:
# Held-out test set: every file from index `train_file_num` onward, taken at
# full grid resolution (stride 1). A dedicated name is used so the training
# `sparsity` global is not overwritten by this cell.
test_sparsity = 1
# position unit mm, current unit Ampere
Current_position_test = data[
    train_file_num:, ::test_sparsity, ::test_sparsity, ::test_sparsity, :15
].reshape(-1, 15)
# B field unit mT
Bfield_test = data[
    train_file_num:, ::test_sparsity, ::test_sparsity, ::test_sparsity, 15:
].reshape(-1, 3)

num_sample = Current_position_test.shape[0]
print('position shape', Current_position_test.shape)
print('Bfield shape', Bfield_test.shape)

# construct dataset, normalised with the extremes computed on the training split
test_set = eMNS_Dataset(
    x=Current_position_test,
    y=Bfield_test,
)
test_set.test_norm_ANN(extremes=extremes, boundary_index=12)

# Keep the sample order (shuffle=False) so predictions can be reshaped back
# onto the grid for plotting.
test_loader = torch.utils.data.DataLoader(
    dataset=test_set,
    batch_size=train_loop_config['batch_size'],
    shuffle=False,
)


In [None]:
from utils import predict_check_rmse_ANN, check_rmse_ANN
from Training_loop_v2 import construct_model_ANN 

# Reload the exported checkpoint and evaluate it on the held-out test set.
model_path = "./Trained_model/EMS_ANN_v2.pt"
# The file presumably holds a pickled network under 'model', which PyTorch >= 2.6
# refuses to load with its default weights_only=True; only load trusted files
# this way. map_location puts the model on the device selected at the top of
# the notebook, so a checkpoint written on a GPU can be evaluated on a CPU.
model = torch.load(model_path, map_location=device, weights_only=False)['model']

prediction, rmse, mse, Rsquare = predict_check_rmse_ANN(test_loader, model, config=train_loop_config)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
# Visual check: predicted vs. ground-truth field components on one z-plane
# (grid axis 3) for one test current configuration. The reshape assumes the
# test set was built at full resolution, i.e. grid_size points per axis.
B_est = prediction.reshape(-1, grid_size, grid_size, grid_size, 3)
Bfield_test = Bfield_test.reshape(-1, grid_size, grid_size, grid_size, 3)

current_index = 3   # which test current configuration to show
z_plane_index = 5   # which plane along grid axis 3 to show
# The old labels 'Bx\mT' contained the invalid escape sequence '\m'
# (SyntaxWarning on Python 3.12+); they are meant to read "component / unit".
ylabels = ['Bx / mT', 'By / mT', 'Bz / mT']

fig = plt.figure()
for row, label in enumerate(ylabels):
    # Left column: network prediction.
    plt.subplot(3, 2, 2 * row + 1)
    plt.imshow(B_est[current_index, :, :, z_plane_index, row])
    plt.colorbar()
    plt.ylabel(label)
    if row == 0:
        plt.title('Prediction')
    # Right column: reference field from the test files.
    plt.subplot(3, 2, 2 * row + 2)
    plt.imshow(Bfield_test[current_index, :, :, z_plane_index, row])
    plt.colorbar()
    if row == 0:
        plt.title('Ground truth')
fig.tight_layout()
plt.show()