In [1]:
import torch
from torch.utils.data import DataLoader, Dataset
from Model import PointNeXt
# from dataset.FFDshape import FFDshape_ptp
from Loss import reg_loss
from Transforms import PCDPretreatment, get_data_augment
import numpy as np
from numpy.core.umath import isnan
from Parameters import *
sys.path.insert(1, os.path.dirname(os.path.abspath(__name__)))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_cfg = MODEL_CONFIG['basic_c']
max_input = model_cfg['max_input']
normal = model_cfg['normal']
model = PointNeXt(model_cfg).to(device=device)

checkpoint_name = 'PointNeXt_shapesffd3_epoch1000.pth'
checkpoint_dir = 'result_train//PointNeXt_model=basic_c_ds=shapesffd3_aug=basic_lr=0.001_wd=0.0001_bs=16_AdamW_cosine//'
checkpoint_file = checkpoint_dir + checkpoint_name

# checkpoint_file = 'result_train\PointNeXt_model=basic_c_ds=waverider_aug=basic_lr=1e-07_wd=0.0001_bs=8_AdamW_cosine\PointNeXt_waverider_epoch30.pth'

checkpoint = torch.load(checkpoint_file, map_location=device)
model.load_state_dict(checkpoint['model'])
model.eval()

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


PointNeXt(
  (mlp): Conv1d(6, 32, kernel_size=(1,), stride=(1,))
  (stage): ModuleList(
    (0): Stage(
      (sa): SetAbstraction(
        (mlp): Sequential(
          (0): Conv2d(35, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
      )
      (irm): Sequential(
        (0): InvResMLP(
          (la): LocalAggregation(
            (mlp): Sequential(
              (0): Conv2d(67, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
          )
          (pw_conv): Sequential(
            (0): Conv1d(64, 256, kernel_size=(1,), stride=(1,), bias=False)
            (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
            (3):

In [2]:

class FFDshape_eval(Dataset):
    def __init__(self, root, transforms=None, split='test', npoints=1024, augment=False, dp=False, normalize=False):
        assert(split == 'train' or split == 'test')
        self.npoints = npoints
        self.transforms = transforms
        self.train_files_list = []
        self.test_files_list = []
        if split == 'train':
            self.training = True
        elif split == 'test':
            self.training = False
        
        name_list = os.listdir(os.path.join(root,'pc'))
        for i in range(len(name_list)):
            name_list[i] = os.path.splitext(name_list[i])[0]
        
        test_files_list = self.read_list_file(name_list, root)
        self.test_files_list = test_files_list

        # self.train_files_list = train_files_list # train_files_list

        self.caches = {}
        print(
            f'Training {len(self.train_files_list)} shapes. Testing {len(self.test_files_list)} shapes '
        )

    def read_list_file(self, name_list, root):
        # base = os.path.dirname(file_path)
        files_list = []
        for shape_name in name_list:
            cur = os.path.join(root, 'pc', '{}.txt'.format(shape_name))
            files_list.append(cur)
        return files_list


    def __getitem__(self, index):
        if index in self.caches:
            return self.caches[index]
        file = self.pcd[index]
        pc = np.loadtxt(file, delimiter=',').astype(np.float32)
        xyz_points = pc[:, :6]
        gts = pc[:, 6]
        s_mesh = pc[:, -1]

        # resample
        # choice = np.random.choice(len(xyz_points), self.npoints, replace=True)
        # xyz_points = xyz_points[choice, :]
        # gts = gts[choice]

        xyz_points = torch.from_numpy(xyz_points).float()
        gts = torch.from_numpy(gts).float()
        if self.transforms is not None:
            xyz_points, gts = self.transforms(xyz_points, gts)
        else:
            xyz_points = xyz_points.T
            
        return xyz_points, gts, s_mesh

    def __len__(self):
        return len(self.pcd)

    def train(self):
        self.training = True
        self.pcd = self.train_files_list
        if self.transforms is not None:
            self.transforms.set_mode('train')

    def eval(self):
        self.training = False
        self.pcd = self.test_files_list
        if self.transforms is not None:
            self.transforms.set_mode('eval')

In [3]:
# read *.stl
def stl_read(shape_file):
    import open3d as o3d
    my_mesh = o3d.io.read_triangle_mesh(shape_file)
    Points = np.asarray(my_mesh.vertices)
    Connectivity = np.asarray(my_mesh.triangles)

    # rescale to 25 m in X-axis
    rate = 25/((Points[:,0]).max()-(Points[:,0]).min())
    Points = Points * rate
    Points[:,0] = Points[:,0] - Points[:,0].mean()
    
    return Points, Connectivity, rate

# more feature of pc
def ex_feature(Points, Connectivity):
    normals = np.zeros(np.shape(Connectivity))
    for i in range(len(Connectivity)):
        temp1 = np.cross(Points[Connectivity[i,2],:] - Points[Connectivity[i,0],:],
            Points[Connectivity[i,1],:] - Points[Connectivity[i,0],:]); # mesh outer direction
        normals[i,:] = temp1/np.linalg.norm(temp1,ord=2)
    normals = -normals
    
    # tri face points
    Points_tri = np.zeros(np.shape(Connectivity))
    for i in range(len(Connectivity)):
        Points_tri[i,:] = np.mean(Points[Connectivity[i,:],:],axis=0)

    # cal tri-area
    tri_point = np.zeros((np.size(Connectivity,0),3,3))
    side_len = np.zeros((np.size(Connectivity,0),3))
    for i in range(3):
        tri_point[:,:,i] = Points[Connectivity[:,i],:]

    for i in range(3):
        tmp = np.array([i,i+1]).astype(int)
        tmp[tmp>=3] = 0
        side_len[:,i] = np.sqrt(np.sum(
            (tri_point[:,:,tmp[0]] - tri_point[:,:,tmp[1]])**2
            ,1))
    side_p = np.sum(side_len,1)/2
    tri_area = np.sqrt(side_p*
        (side_p-side_len[:,0]) *
        (side_p-side_len[:,1]) *
        (side_p-side_len[:,2]))
    tri_area[isnan(tri_area)]=0
    return normals, Points_tri, tri_area

# rotate Y-axis
def rotate_pc(pc, rotation_angle):
    cosval = np.cos(np.deg2rad(rotation_angle))
    sinval = np.sin(np.deg2rad(rotation_angle))
    rotation_matrix = np.array([[cosval, 0, sinval],
                                [0, 1, 0],
                                [-sinval, 0, cosval]])
    rotated_pc_xyz = np.dot(pc[:,:3], rotation_matrix)
    if pc.shape[-1] > 3:
        rotated_pc_normal = np.dot(pc[:,3:], rotation_matrix)
        rotated_pc = np.concatenate((rotated_pc_xyz, rotated_pc_normal), axis=1)
    else:
        rotated_pc = rotated_pc_xyz
    return torch.tensor(rotated_pc)

def uniform_pc(pcd):
    scale_rate = np.zeros(pcd.shape[0])
    shift = np.zeros((pcd.shape[0],3))
    pcd_out = torch.zeros(pcd.shape)
    for i in range(pcd.shape[0]):
        pcd_tmp = pcd[i,:,:].squeeze(0).T
        # 坐标归一化
        pcd_xyz = pcd_tmp[:, :3]
        pcd_xyz = pcd_xyz - pcd_xyz.mean(dim=0, keepdim=True)
        dis = torch.norm(pcd_xyz, dim=1)
        max_dis = dis.max()
        pcd_xyz /= max_dis
        scale_rate[i] = max_dis
        shift[i,:] = pcd_xyz.mean(dim=0, keepdim=True)

        # 法线
        if pcd_tmp.shape[1]>3:
            pcd_tmp[:, :3] = pcd_xyz
        else:
            pcd_tmp = pcd_xyz
        pcd_out[i,:,:] = pcd_tmp.T.unsqueeze(0)

    return pcd_out, scale_rate, shift
def index_points(points, idx):
    """
    跟据采样点索引获取其原始点云xyz坐标等信息
    :param points: <torch.Tensor> (B, N, 3+) 原始点云
    :param idx: <torch.Tensor> (B, S)/(B, S, G) 采样点索引，S为采样点数量，G为每个采样点grouping的点数
    :return: <torch.Tensor> (B, S, 3+)/(B, S, G, 3+) 获取了原始点云信息的采样点
    """
    B = points.shape[0]
    view_shape = list(idx.shape)
    view_shape[1:] = [1] * (len(view_shape) - 1)
    repeat_shape = list(idx.shape)
    repeat_shape[0] = 1
    batch_indices = torch.arange(B, dtype=torch.long, device=points.device).view(view_shape).repeat(repeat_shape)
    new_points = points[batch_indices, idx, :]
    return new_points


def farthest_point_sample(xyz, npoint):
    """
    最远点采样
    随机选择一个初始点作为采样点，循环的将与当前采样点距离最远的点当作下一个采样点，直至满足采样点的数量需求
    :param xyz: <torch.Tensor> (B, N, 3+) 原始点云
    :param npoint: <int> 采样点数量
    :return: <torch.Tensor> (B, npoint) 采样点索引
    """
    device = xyz.device
    B, N, C = xyz.shape
    npoint = min(npoint, N)
    centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
    distance = torch.ones(B, N).to(device) * 1e10  # 每个点与最近采样点的最小距离
    farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)  # 随机选取初始点

    batch_indices = torch.arange(B, dtype=torch.long).to(device)
    for i in range(npoint):
        centroids[:, i] = farthest
        centroid = xyz[batch_indices, farthest, :].view(B, 1, -1)  # [bs, 1, coor_dim]
        dist = torch.nn.functional.pairwise_distance(xyz, centroid)
        mask = dist < distance
        distance[mask] = dist[mask]
        farthest = torch.max(distance, -1)[1]
    return centroids

In [4]:
def three_interpolate(xyz1, xyz2, points2):
    '''

    :param xyz1: shape=(B, N1, 3)
    :param xyz2: shape=(B, N2, 3)
    :param points2: shape=(B, N2, C2)
    :return: interpolated_points: shape=(B, N1, C2)
    '''
    interp_num = 3
    _, _, C2 = points2.shape
    dists, inds = three_nn(xyz1, xyz2, interp_num)
    inversed_dists = 1.0 / (dists + 1e-8)
    weight = inversed_dists / torch.sum(inversed_dists, dim=-1, keepdim=True) # shape=(B, N1, 3)
    weight = torch.unsqueeze(weight, -1).repeat(1, 1, 1, C2)
    interpolated_points = gather_points(points2, inds)  # shape=(B, N1, 3, C2)
    interpolated_points = torch.sum(weight * interpolated_points, dim=2)
    return interpolated_points

def three_nn(xyz1, xyz2, interp_num=3):
    '''

    :param xyz1: shape=(B, N1, 3)
    :param xyz2: shape=(B, N2, 3)
    :return: dists: shape=(B, N1, 3), inds: shape=(B, N1, 3)
    '''
    dists = get_dists(xyz1, xyz2)
    dists, inds = torch.sort(dists, dim=-1)
    dists, inds = dists[:, :, :interp_num], inds[:, :, :interp_num]
    return dists, inds

def get_dists(points1, points2):
    '''
    Calculate dists between two group points
    :param cur_point: shape=(B, M, C)
    :param points: shape=(B, N, C)
    :return: 
    '''
    B, M, C = points1.shape
    _, N, _ = points2.shape
    dists = torch.sum(torch.pow(points1, 2), dim=-1).view(B, M, 1) + \
            torch.sum(torch.pow(points2, 2), dim=-1).view(B, 1, N)
    dists -= 2 * torch.matmul(points1, points2.permute(0, 2, 1))
    dists = torch.where(dists < 0, torch.ones_like(dists) * 1e-7, dists) # Very Important for dist = 0.
    return torch.sqrt(dists).float()

def gather_points(points, inds):
    '''

    :param points: shape=(B, N, C)
    :param inds: shape=(B, M) or shape=(B, M, K)
    :return: sampling points: shape=(B, M, C) or shape=(B, M, K, C)
    '''
    device = points.device
    B, N, C = points.shape
    inds_shape = list(inds.shape)
    inds_shape[1:] = [1] * len(inds_shape[1:])
    repeat_shape = list(inds.shape)
    repeat_shape[0] = 1
    batchlists = torch.arange(0, B, dtype=torch.long).to(device).reshape(inds_shape).repeat(repeat_shape)
    return points[batchlists, inds, :]
def all_nn(xyz1, xyz2):
    '''

    :param xyz1: shape=(B, N1, 3)
    :param xyz2: shape=(B, N2, 3)
    :return: dists: shape=(B, N1, N2), inds: shape=(B, N1, N2)
    '''
    dists = get_dists(xyz1, xyz2)
    dists, inds = torch.sort(dists, dim=-1)
    return dists, inds

def three_interpolate_normals(xyz1, xyz2, points2, normals1, normals2):
    '''

    :param xyz1: shape=(B, N1, 3)
    :param xyz2: shape=(B, N2, 3)
    :param points2: shape=(B, N2, C2)
    :return: interpolated_points: shape=(B, N1, C2)
    '''
    interp_num = 12
    _, _, C2 = points2.shape
    dists_normals = torch.matmul(normals1, normals2.permute(0, 2, 1))
    dists, inds = all_nn(xyz1, xyz2)
    
    # !!!矩阵化索引有问题
    # interp_num_tmp = interp_num
    # while not torch.all(torch.any((dists_normals[:,:,inds[0,i,:interp_num_tmp]]>0),dim=2)):
    #     interp_num_tmp = interp_num_tmp + 1
    #     print(interp_num_tmp)
    # for i in range(dists_normals.shape[1]):
    #     dists_normals[:,i,:interp_num_tmp] = dists_normals[:,i,inds[0,i,:interp_num_tmp]]
    
    # 粗暴方式
    interp_num_tmp = interp_num
    dists_normals_tmp = dists_normals[:,:,:64]
    while True:
        for i in range(dists_normals.shape[1]):
            dists_normals_tmp[:,i,:interp_num_tmp] = dists_normals[:,i,inds[0,i,:interp_num_tmp]]
        if torch.all(torch.any((dists_normals_tmp[:,:,:interp_num_tmp]>0),dim=2)):
            dists_normals = dists_normals_tmp
            break
        else:
            interp_num_tmp = interp_num_tmp + 2
            print(interp_num_tmp)
    
    dists_normals = dists_normals[:,:,:interp_num_tmp]
    dists_normals = (dists_normals+1e-7)
    dists_normals[dists_normals<=0] = 0
    inversed_dists = 1.0 / (dists[:,:,:interp_num_tmp] + 1e-8) *dists_normals
    weight = inversed_dists / torch.sum(inversed_dists, dim=-1, keepdim=True) # shape=(B, N1, 3)
    weight = torch.unsqueeze(weight, -1).repeat(1, 1, 1, C2)
    interpolated_points = gather_points(points2, inds[:,:,:interp_num_tmp])  # shape=(B, N1, 3, C2)
    interpolated_points = torch.sum(weight * interpolated_points, dim=2)
    return interpolated_points

In [5]:
# shape_base = 'D:\MyCode\PCdeep_TL\ShapeGenerator\shape_set\waverider\\testing_N200_D60_322'
shape_base = '..\ShapeGenerator\shape_set\waverider\\testing_N200_D60_322'
param_path = os.path.join(shape_base,'paramList.csv')
param_list = np.loadtxt(param_path, delimiter=',')
aoa_list = param_list[:,-1]
# # Tri-reading
# shape_file = shape_base + 'shape_001.stl'
# if shape_file[-3:] == 'stl' or 'ply':
#     Points_init, Connectivity, rate = stl_read(shape_file)

In [6]:
data_root = '..\shapesffd4'
pc_root = data_root + '\waverider\\testing_N200_D60_322'
# transforms = PCDPretreatment(num=1024, down_sample='random', normal=model_cfg['normal'])
transforms = None
dataset = FFDshape_eval(root=pc_root,split='test',transforms=transforms)
dataset.eval()
eval_dataloader = DataLoader(dataset=dataset,
                                batch_size=1,
                                num_workers=0,
                                pin_memory=False,#True
                                drop_last=False,
                                shuffle=False)
criterion = reg_loss().to(device)


Training 0 shapes. Testing 200 shapes 


In [9]:
import copy
from sklearn import metrics
torch.manual_seed(0)
# seed 0
# CP Direct, MAE: 0.0323; RMSE: 0.0698; R2: 0.8609 
# CP Interp, MAE: 0.0390; RMSE: 0.0810; R2: 0.7945 
# cds, MAE: 7.7161; RMSE: 11.3791; R2: 0.9939; relative l2: 0.0524
# seed 1234
# CP Direct, MAE: 0.0318; RMSE: 0.0676; R2: 0.8629 
# CP Interp, MAE: 0.0387; RMSE: 0.0803; R2: 0.7903 
# cds, MAE: 7.9880; RMSE: 12.1252; R2: 0.9931; relative l2: 0.0558

disp_loss = False

mse_list = np.zeros(len(eval_dataloader))
mae_list = np.zeros(len(eval_dataloader))
rmse_list = np.zeros(len(eval_dataloader))
r2_score_list = np.zeros(len(eval_dataloader))
mae_all_list = np.zeros(len(eval_dataloader))
rmse_all_list = np.zeros(len(eval_dataloader))
r2_score_all_list = np.zeros(len(eval_dataloader))
cds_pred = np.zeros(len(eval_dataloader))
cds_gt = np.zeros(len(eval_dataloader))


num_points = 1024
for count, data in enumerate(eval_dataloader):
    # size pcd_all: [B,C,N] gts: [B,N]
    pcd_all, gts_all, s_tri = data
    # pcd_rotated = rotate_pc(pcd_all.squeeze(0).T, -aoa_list[count])
    # pcd_all = pcd_rotated.clone().detach().T.unsqueeze(0).float()
    pcd_uniform = copy.deepcopy(pcd_all)
    pcd_uniform, scale_rate, shift = uniform_pc(pcd_uniform)
    choice_idx = torch.randperm(pcd_uniform.shape[2])[:num_points] # random sampling
#     choice_idx = farthest_point_sample(pcd_all.permute(0,2,1)[:,:,:3], num_points).squeeze(0) # fps
    pcd = pcd_uniform[:,:,choice_idx]
    pcd = pcd.to(device, non_blocking=True)
    gts = gts_all[:,choice_idx] 
    # model pred
    with torch.no_grad():
        pred = model(pcd)
        pred = torch.squeeze(pred,1)# Size([1, 1024])
        if disp_loss:
            print(criterion(pred, gts.to(device)))
    # re scale
    pred, gts, pcd = pred.to('cpu'), np.array(gts.to('cpu')), pcd.to('cpu')
    pcd[0,:3,:] = pcd[0,:3,:] * scale_rate
    pcd[0,:3,:] += shift.T
    # to np
    pcd_np = np.array(pcd.squeeze(0).T[:,:3])
    pred_np = np.array(pred.squeeze(0))
    pcd_all_np = np.array(pcd_all.squeeze(0).T[:,:3])

    # interp
    pred_all = three_interpolate_normals(pcd_all[:,:3,:].permute(0,2,1), \
                            pcd[:,:3,:].permute(0,2,1).to(torch.float32),\
                            pred.unsqueeze(2),\
                            pcd_all[:,3:6,:].permute(0,2,1),\
                            pcd[:,3:6,:].permute(0,2,1).to(torch.float32))
    # to 1 dim
    pred_1d = np.array(pred[0])
    gts_1d = np.array(gts[0])
    pred_1d_all = np.array(pred_all[0,:,0])
    gts_1d_all = np.array(gts_all[0,:])
    
    # cp metrics per shape
    mse = metrics.mean_squared_error(pred_1d, gts_1d)
    mae = metrics.mean_absolute_error(pred_1d, gts_1d)
    rmse = metrics.mean_squared_error(pred_1d, gts_1d)**0.5
    r2_score = metrics.r2_score(pred_1d, gts_1d)
    print('Direct-%d, MAE: %.4f; RMSE: %.4f; R2: %.4f; MSE: %.6f' %(count, mae, rmse, r2_score, mse))

    mae_all = metrics.mean_absolute_error(pred_1d_all, gts_1d_all)
    rmse_all = metrics.mean_squared_error(pred_1d_all, gts_1d_all)**0.5
    r2_score_all = metrics.r2_score(pred_1d_all, gts_1d_all)
    print('Interp-%d, MAE: %.4f; RMSE: %.4f; R2: %.4f ' %(count, mae_all, rmse_all, r2_score_all))

    # cp Metrics
    mse_list[count] = mse
    mae_list[count] = mae
    rmse_list[count] = rmse
    r2_score_list[count] = r2_score
    mae_all_list[count] = mae_all
    rmse_all_list[count] = rmse_all
    r2_score_all_list[count] = r2_score_all
    # cal forces
    normals_all_np = np.array(pcd_all.squeeze(0).T[:,3:])
    s_tri_np = np.array(s_tri)[0]
    force_pred = np.sum(normals_all_np*\
            np.repeat(np.expand_dims(pred_1d_all,axis=1),3,axis=1)*\
            np.repeat(np.expand_dims(s_tri_np,axis=1),3,axis=1), axis=0)
    force_gt = np.sum(normals_all_np*\
            np.repeat(np.expand_dims(gts_1d_all,axis=1),3,axis=1)*\
            np.repeat(np.expand_dims(s_tri_np,axis=1),3,axis=1), axis=0)
    cds_pred[count] = force_pred[0]
    cds_gt[count] = force_gt[0]

# metrics summary
# cp of direct set
mse_set = np.mean(mse_list)
mae_set = np.mean(mae_list)
rmse_set = np.mean(rmse_list)
r2_score_set = np.mean(r2_score_list)
print('CP Direct, MAE: %.4f; RMSE: %.4f; R2: %.4f; MSE: %.6f' %(mae_set, rmse_set, r2_score_set, mse_set))
# cp of interp set
mae_all_set = np.mean(mae_all_list)
rmse_all_set = np.mean(rmse_all_list)
r2_score_all_set = np.mean(r2_score_all_list)
print('CP Interp, MAE: %.4f; RMSE: %.4f; R2: %.4f ' %(mae_all_set, rmse_all_set, r2_score_all_set))
# cds
cds_mae = metrics.mean_absolute_error(cds_pred, cds_gt)
cds_rmse = metrics.mean_squared_error(cds_pred, cds_gt)**0.5
cds_r2_score = metrics.r2_score(cds_pred, cds_gt)
l2_loss = np.linalg.norm(cds_pred-cds_gt)/np.linalg.norm(cds_gt)
print('cds, MAE: %.4f; RMSE: %.4f; R2: %.4f; relative l2: %.4f' %(cds_mae, cds_rmse, cds_r2_score, l2_loss))

In [None]:
# import plotly.express as px
# fig = px.scatter(x=pred_1d,
#                  y=gts_1d
#                 )
# fig.show()

In [None]:
# import plotly.graph_objects as go

# tmp = pred_1d# pred_1d- gts_1d
# fig = go.Figure(data=[go.Scatter3d(
#     x=pcd_np[:,0],
#     y=pcd_np[:,1],
#     z=pcd_np[:,2],
#     mode='markers',
#     marker=dict(
#         size=3,
#         color=tmp,   # pred[:,0].to('cpu')
#         colorscale='Viridis',   # choose a colorscale
#         opacity=1
#     )
# )])

# # tight layout
# fig.update_layout(margin=dict(l=0, r=0, b=0, t=0))
# fig.show()

In [None]:
# import plotly.graph_objects as go

# fig = go.Figure(data=[go.Scatter3d(
#     x=pcd_all_np[:,0],
#     y=pcd_all_np[:,1],
#     z=pcd_all_np[:,2],
#     mode='markers',
#     marker=dict(
#         size=2,
#         color=pred_1d_all,                # set color to an array/list of desired values
#         colorscale='Viridis',   # choose a colorscale
#         opacity=1
#     )
# )])

# # tight layout
# fig.update_layout(margin=dict(l=0, r=0, b=0, t=0))
# fig.show()

In [None]:
# from sklearn import metrics 
# mae = metrics.mean_absolute_error(pred_1d, gts_1d)
# rmse = metrics.mean_squared_error(pred_1d, gts_1d)**0.5
# r2_score = metrics.r2_score(pred_1d, gts_1d)
# print('Direct, MAE: %.4f; RMSE: %.4f; R2: %.4f ' %(mae, rmse, r2_score))

# mae_all = metrics.mean_absolute_error(pred_1d_all, gts_1d_all)
# rmse_all = metrics.mean_squared_error(pred_1d_all, gts_1d_all)**0.5
# r2_score_all = metrics.r2_score(pred_1d_all, gts_1d_all)
# print('Interp, MAE: %.4f; RMSE: %.4f; R2: %.4f ' %(mae_all, rmse_all, r2_score_all))



In [None]:
# # save NN to need_frozen_list.csv
# path = 'nn_list.txt'
# data = np.loadtxt(path, dtype=str, delimiter=' ')
# nn_list = np.expand_dims(data,axis=1)
# need_frozen = np.ones((len(data),1),dtype=str)
# need_frozen_list = np.concatenate((nn_list,need_frozen),axis=1)
# np.savetxt('need_frozen_list.csv',need_frozen_list, fmt='%s',delimiter=',')

In [None]:
# need_frozen_list_path = 'need_frozen_list.csv'
# need_frozen_np = np.loadtxt(need_frozen_list_path, dtype=str, delimiter=',')
# need_frozen_list = {}

# for i, a in enumerate(need_frozen_np[:,0]):
#     need_frozen_list[a] = need_frozen_np[i,1]
# for param in model.named_parameters():
#     if need_frozen_list[param[0]] == '1':
#         # frozen
#         param[1].requires_grad = False
#     else:
#         param[1].requires_grad = True
#     print(param[0],param[1].requires_grad)