# In [None]:
#Note: This code plots Figure 3 of the manuscript.

import matplotlib.pyplot as plt
#fig, ax = plt.subplots()
import time
import numpy as np
import torch
from torch import nn, optim, autograd
from math import pi
from sklearn.model_selection import RepeatedKFold
from sklearn.metrics import r2_score
import torch.nn.functional as F


# Seed numpy's and torch's global RNGs with one shared value for reproducible runs.
SEED = 123456
np.random.seed(SEED)
torch.manual_seed(SEED)


def weights_init(m):
    """Xavier-normal initialise conv/linear weights and zero their biases.

    The single-module signature suits ``nn.Module.apply`` (it is not called in
    the code shown here). Modules that are not ``nn.Conv2d``/``nn.Linear`` are
    left untouched.
    """
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.xavier_normal_(m.weight)
        # Layers created with bias=False have m.bias set to None.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)


class Unit3(nn.Module):
    """One fully connected layer followed by a selectable activation.

    Args:
        in_N: number of input features.
        out_N: number of output features.
        actf: activation code -- 0 tanh, 1 sigmoid, 2 ReLU, 3 SELU,
            4 softmax over dim 1.
    """

    def __init__(self, in_N, out_N, actf):
        super(Unit3, self).__init__()
        self.in_N = in_N
        self.out_N = out_N
        self.actf = actf
        # The attribute name `L` determines the state_dict keys ('L.weight', 'L.bias').
        self.L = nn.Linear(in_N, out_N)

    def forward(self, x):
        """Apply the linear map, then the activation selected by ``self.actf``.

        Raises:
            ValueError: if ``self.actf`` is not one of the codes 0-4
                (previously this surfaced as an UnboundLocalError).
        """
        actf = self.actf
        x1 = self.L(x)
        if actf == 0:
            return torch.tanh(x1)
        if actf == 1:
            return torch.sigmoid(x1)
        if actf == 2:
            return torch.relu(x1)
        if actf == 3:
            return torch.selu(x1)
        if actf == 4:
            # Softmax along dim=1; for 2-D input this normalises each row.
            return F.softmax(x1, dim=1)
        raise ValueError(f"Unsupported activation code {actf!r}; expected 0-4")
    
class NN3(nn.Module):
    """Fully connected network assembled from ``Unit3`` layers.

    Layout of ``self.stack`` (this order fixes the state_dict keys):
      1. ``Unit3(in_N -> width1[0])``, followed by ``BatchNorm1d`` if ``bn == 1``;
      2. ``Unit3(width1[i-1] -> width1[i])`` for ``i = 1 .. depth1-1``;
      3. ``Dropout(p=dprate)`` if ``dp == 1``;
      4. head: if ``depth2 == 1``, one sigmoid ``Unit3`` to ``width2[0]``;
         otherwise ``Unit3`` layers through ``width2[0 .. depth2-2]`` using
         ``actf``, then a softmax ``Unit3`` to ``width2[depth2-1]``.

    Args:
        in_N: number of input features.
        width1: trunk layer widths (at least ``depth1`` entries are read).
        depth1: number of trunk ``Unit3`` layers; must be >= 1.
        width2: head layer widths (at least ``depth2`` entries are read).
        depth2: number of head ``Unit3`` layers.
        out_N: stored on the instance but not used to build layers; the output
            width comes from ``width2``.
        bn: 1 to add BatchNorm1d after the first layer.
        dp: 1 to add Dropout after the trunk.
        dprate: dropout probability.
        actf: ``Unit3`` activation code for the hidden layers.

    Raises:
        ValueError: if ``depth1 < 1``.
    """

    def __init__(self, in_N, width1, depth1, width2, depth2, out_N, bn, dp, dprate, actf):
        super(NN3, self).__init__()
        if depth1 < 1:
            raise ValueError(f"depth1 must be >= 1, got {depth1}")
        self.width1 = width1
        self.width2 = width2
        self.depth1 = depth1
        self.depth2 = depth2
        self.bn = bn
        self.dp = dp
        self.dprate = dprate
        self.actf = actf
        self.in_N = in_N
        self.out_N = out_N
        self.stack = nn.ModuleList()
        self.stack.append(Unit3(in_N, width1[0], actf))
        if bn == 1:
            self.stack.append(nn.BatchNorm1d(width1[0]))
        for i in range(1, depth1):
            self.stack.append(Unit3(width1[i - 1], width1[i], actf))
        # Width of the last trunk layer. It used to be read through the leaked
        # loop variable `i`, which is unbound when depth1 == 1.
        trunk_out = width1[depth1 - 1]

        if dp == 1:
            self.stack.append(nn.Dropout(p=dprate))
        if depth2 == 1:
            # Single head layer with sigmoid (code 1) output.
            self.stack.append(Unit3(trunk_out, width2[0], 1))
        else:
            self.stack.append(Unit3(trunk_out, width2[0], actf))
            for i in range(1, depth2 - 1):
                self.stack.append(Unit3(width2[i - 1], width2[i], actf))
            # Output layer with softmax (code 4).
            self.stack.append(Unit3(width2[depth2 - 2], width2[depth2 - 1], 4))

    def forward(self, x):
        """Pass ``x`` through every module in ``self.stack`` in order."""
        for layer in self.stack:
            x = layer(x)
        return x

# Hyperparameters describing the network. The architecture settings must match
# the saved checkpoint for load_state_dict to succeed.
activation = 0           # Unit3 activation code for hidden layers (0 = tanh)
dropout = 1              # 1 -> Dropout layer after the trunk
dropout_rate = 0.29791
normalization = 1        # 1 -> BatchNorm1d after the first layer
batch_size = 1000        # not used by the code shown here
layers1 = 10             # number of trunk layers
layers2 = 1              # number of hidden head layers before the output layer
neurons = 86             # width of every hidden layer
learning_rate = 0.00065  # passed to the optimizer rebuilt when loading the checkpoint
L1 = [neurons] * layers1
L2 = [neurons] * layers2 + [8]  # head widths; the last entry (8) is the output width
# 35 input features, 8 outputs. Pass `activation` (previously a literal 0) so
# the configuration variable above actually controls the hidden activation.
model_h = NN3(35, L1, layers1, L2, layers2 + 1, 8, normalization, dropout, dropout_rate, activation)


        
load = 1
PATH = "checkpoint/model-1406.pt"
if load == 1:
    # The rest of the script works with CPU tensors, so remap any CUDA-saved
    # tensors to the CPU instead of failing on a machine without a GPU.
    checkpoint = torch.load(PATH, map_location="cpu")
    model_h.load_state_dict(checkpoint['model_h_state_dict'])
    # Optimizer state is restored for completeness; inference does not need it.
    optimizer2 = optim.AdamW([{'params': model_h.parameters()}], lr=learning_rate)
    optimizer2.load_state_dict(checkpoint['optimizer2_state_dict'])
    
model_h.eval()  # inference mode: Dropout disabled, BatchNorm uses running statistics
xlo_test = np.load('xlo_test.npy')
ylo_test = np.load('ylo_test.npy')
# Inference only: no_grad avoids building an autograd graph over the whole test set.
with torch.no_grad():
    pred_2h_star_test = model_h(torch.from_numpy(xlo_test).float())
pred_test_np = pred_2h_star_test.numpy()  # convert once instead of once per output

# R^2 of each of the 8 model outputs against the matching column of ylo_test.
r2_test_final = np.hstack([r2_score(ylo_test[:, k], pred_test_np[:, k]) for k in range(8)])
# Keep the per-output names available for any later cells.
(r2_test0, r2_test1, r2_test2, r2_test3,
 r2_test4, r2_test5, r2_test6, r2_test7) = r2_test_final

# Parity plots: predicted vs. true value for each of the 8 outputs on a 2 x 4
# grid, with the ideal y = x line drawn in red.
pred_np = pred_2h_star_test.detach().numpy()  # convert once instead of per panel
tick_values = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
# linspace includes the endpoint, so the y = x line reaches 1.0 like the ticks
# (np.arange(0, 1, 0.01) stopped at 0.99).
x = np.linspace(0.0, 1.0, 101)

fig, axes = plt.subplots(2, 4, figsize=(16, 8), tight_layout=True)

# axes.flat walks the grid row by row, so panel k is row k // 4, column k % 4.
for k, ax in enumerate(axes.flat):
    ax.scatter(ylo_test[:, k], pred_np[:, k], color='black', alpha=0.1)
    ax.plot(x, x, color='red')
    ax.set_xlabel('True label', fontsize=25, fontname='Times New Roman')
    ax.set_ylabel('Predicted label', fontsize=25, fontname='Times New Roman')
    ax.set_aspect('equal')
    ax.set_xticks(tick_values)
    ax.set_xticklabels(tick_values, fontsize=25, fontname='Times New Roman')
    ax.set_yticks(tick_values)
    ax.set_yticklabels(tick_values, fontsize=25, fontname='Times New Roman')

fig.savefig('figure.png', dpi=600)
print('R2 of all phases are as follows respectively:', r2_test_final)