将高保真测试数据分为测试集和验证集，分别对低保真度训练数据与高保真度训练数据建立神经网络进行特征提取，将低保真度数据特征与高保真度数据特征分别使用高斯过程进行预测。然后输入高保真度测试数据X，分别通过训练好的两个神经网络和高斯过程，再将两个正态分布相加，取出均值再用一个三层神经网络与高保真度真实值进行训练。最后输入高保真度验证数据，分别通过特征提取网络，高斯过程，均值神经网络，进行结果的预测。

In [1]:
import torch
import math
import numpy as np
import gpytorch
import torch.nn as nn
from gpytorch.kernels import RBFKernel, ScaleKernel
import matplotlib.pyplot as plt

In [2]:
XTestRaw = np.loadtxt('./dataset/x_test_high.txt').reshape(-1,32)
# XTest = XTest.mean(axis=1).reshape(-1)
YTestRaw = np.loadtxt('./dataset/y_test_high.txt').reshape(-1)

XLow = np.loadtxt('./dataset/x_train_low.txt').reshape(-1,32)
# XLow = XLow.mean(axis=1).reshape(-1)
YLow = (np.loadtxt('./dataset/y_train_low.txt')*1e4).reshape(-1)

XHigh = np.loadtxt('./dataset/x_train_high.txt').reshape(-1,32)
# XHigh = XHigh.mean(axis=1).reshape(-1)
YHigh = np.loadtxt('./dataset/y_train_high.txt').reshape(-1)

XTest=XTestRaw[0:5,:]
YTest=YTestRaw[0:5]

XVal=XTestRaw[5:,:]
YVal=YTestRaw[5:]


XLow = torch.from_numpy(XLow).float()
YLow = torch.from_numpy(YLow).float()
XHigh = torch.from_numpy(XHigh).float()
YHigh = torch.from_numpy(YHigh).float()
XTest = torch.from_numpy(XTest).float()
YTest = torch.from_numpy(YTest).float()
XVal = torch.from_numpy(XVal).float()
YVal = torch.from_numpy(YVal).float()

print(YTest)
print(YVal)


tensor([31.1779, 29.1824, 26.4095, 26.9644, 30.0900])
tensor([26.8370, 25.2800, 31.4606, 30.1555])


In [3]:
class LowFeatureExtractor(torch.nn.Sequential):
    def __init__(self):
        super(LowFeatureExtractor, self).__init__()
        self.add_module('linear1', torch.nn.Linear(32, 500))
        self.add_module('relu1', torch.nn.ReLU())
        self.add_module('linear2', torch.nn.Linear(500, 50))
        self.add_module('relu2', torch.nn.ReLU())
        self.add_module('linear3', torch.nn.Linear(50, 2))
        
class HighFeatureExtractor(torch.nn.Sequential):
    def __init__(self):
        super(HighFeatureExtractor, self).__init__()
        self.add_module('linear1', torch.nn.Linear(32, 500))
        self.add_module('relu1', torch.nn.ReLU())
        self.add_module('linear2', torch.nn.Linear(500, 50))
        self.add_module('relu2', torch.nn.ReLU())
        self.add_module('linear3', torch.nn.Linear(50, 2))

featureExtractorLow = LowFeatureExtractor()
featureExtractorHigh = HighFeatureExtractor()

In [4]:
class LowGPRegressionModel(gpytorch.models.ExactGP):
        def __init__(self, train_x, train_y, likelihood):
            super(LowGPRegressionModel, self).__init__(train_x, train_y, likelihood)
            self.mean_module = gpytorch.means.ConstantMean()
            self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())
            self.feature_extractor = featureExtractorLow

        def forward(self, train_x):
            # We're first putting our data through a deep net (feature extractor)
            # We're also scaling the features so that they're nice values
            projected_x = self.feature_extractor(train_x)
            projected_x = projected_x - projected_x.min(0)[0]
            projected_x = 2 * (projected_x / projected_x.max(0)[0]) - 1

            mean_x = self.mean_module(projected_x)
            covar_x = self.covar_module(projected_x)
            return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

In [5]:
likelihoodLow=gpytorch.likelihoods.GaussianLikelihood()
modelLow=LowGPRegressionModel(XLow,YLow,likelihoodLow)

In [6]:
class HighGPRegressionModel(gpytorch.models.ExactGP):
        def __init__(self, train_x, train_y, likelihood):
            super(HighGPRegressionModel, self).__init__(train_x, train_y, likelihood)
            self.mean_module = gpytorch.means.ConstantMean()
            self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())
            self.feature_extractor = featureExtractorHigh

        def forward(self, train_x):
            # We're first putting our data through a deep net (feature extractor)
            # We're also scaling the features so that they're nice values
#             LowMean=self.modelLow.mean
#             x=torch.stack(train_x,LowMean)
            projected_x = self.feature_extractor(x_train)
            projected_x = projected_x - projected_x.min(0)[0]
            projected_x = 2 * (projected_x / projected_x.max(0)[0]) - 1

            mean_x = self.mean_module(projected_x)
            covar_x = self.covar_module(projected_x)
            return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

In [7]:
likelihoodHigh=gpytorch.likelihoods.GaussianLikelihood()
modelHigh=LowGPRegressionModel(XHigh,YHigh,likelihoodHigh)

In [8]:
modelHigh.train()
likelihoodHigh.train()
modelLow.train()
likelihoodLow.train()

optimizerLow=torch.optim.Adam([
    {'params':modelLow.mean_module.parameters()},
    {'params':modelLow.covar_module.parameters()},
    {'params':modelLow.feature_extractor.parameters()},
    {'params':modelLow.likelihood.parameters()},
],lr=0.01)
optimizerHigh=torch.optim.SGD([
    {'params':modelHigh.mean_module.parameters()},
    {'params':modelHigh.covar_module.parameters()},
    {'params':modelHigh.feature_extractor.parameters()},
    {'params':modelHigh.likelihood.parameters()},
],lr=0.01)

mllLow=gpytorch.mlls.ExactMarginalLogLikelihood(likelihoodLow,modelLow)
mllHigh=gpytorch.mlls.ExactMarginalLogLikelihood(likelihoodHigh,modelHigh)

for t in range(1000):
    optimizerLow.zero_grad()    
    outputLow=modelLow(XLow)    
    lossLow=-mllLow(outputLow,YLow)    
    print("Epoch",t,"LossLow:",lossLow.item())  
    lossLow.backward()
    optimizerLow.step()


for t in range(1000):
    optimizerHigh.zero_grad()
    outputHigh=modelHigh(XHigh)   
    lossHigh=-mllHigh(outputHigh,YHigh)   
    print("Epoch",t,"LossHigh:",lossHigh.item())   
    lossHigh.backward()
    optimizerHigh.step()


Epoch 0 LossLow: 10.62673568725586
Epoch 1 LossLow: 8.06839370727539
Epoch 2 LossLow: 5.9377336502075195
Epoch 3 LossLow: 5.19174861907959
Epoch 4 LossLow: 4.8296027183532715
Epoch 5 LossLow: 4.5279860496521
Epoch 6 LossLow: 4.322329998016357
Epoch 7 LossLow: 4.178844451904297
Epoch 8 LossLow: 4.035300254821777
Epoch 9 LossLow: 3.8756163120269775
Epoch 10 LossLow: 3.7324180603027344
Epoch 11 LossLow: 3.612377166748047
Epoch 12 LossLow: 3.5044162273406982
Epoch 13 LossLow: 3.4096686840057373
Epoch 14 LossLow: 3.3146824836730957
Epoch 15 LossLow: 3.2300281524658203
Epoch 16 LossLow: 3.1501071453094482
Epoch 17 LossLow: 3.071692943572998
Epoch 18 LossLow: 3.002429962158203
Epoch 19 LossLow: 2.936347723007202
Epoch 20 LossLow: 2.8855841159820557
Epoch 21 LossLow: 2.8273963928222656
Epoch 22 LossLow: 2.7688348293304443
Epoch 23 LossLow: 2.714380979537964
Epoch 24 LossLow: 2.657128095626831
Epoch 25 LossLow: 2.6038808822631836
Epoch 26 LossLow: 2.5551674365997314
Epoch 27 LossLow: 2.51753044

Epoch 221 LossLow: 0.9928992986679077
Epoch 222 LossLow: 0.9927203059196472
Epoch 223 LossLow: 0.9940090179443359
Epoch 224 LossLow: 0.9870951175689697
Epoch 225 LossLow: 0.98694908618927
Epoch 226 LossLow: 0.9746689796447754
Epoch 227 LossLow: 0.9715988636016846
Epoch 228 LossLow: 0.9703596234321594
Epoch 229 LossLow: 0.9682837724685669
Epoch 230 LossLow: 0.9686428308486938
Epoch 231 LossLow: 0.9631087779998779
Epoch 232 LossLow: 0.956511914730072
Epoch 233 LossLow: 0.9472244381904602
Epoch 234 LossLow: 0.9435133337974548
Epoch 235 LossLow: 0.9469581842422485
Epoch 236 LossLow: 0.9439964890480042
Epoch 237 LossLow: 0.9401429295539856
Epoch 238 LossLow: 0.9325847029685974
Epoch 239 LossLow: 0.9353055953979492
Epoch 240 LossLow: 0.9315766096115112
Epoch 241 LossLow: 0.9223405718803406
Epoch 242 LossLow: 0.9216132760047913
Epoch 243 LossLow: 0.9108937382698059
Epoch 244 LossLow: 0.9056105017662048
Epoch 245 LossLow: 0.901444673538208
Epoch 246 LossLow: 0.9015076160430908
Epoch 247 LossLo

Epoch 437 LossLow: 0.2579830586910248
Epoch 438 LossLow: 0.22438682615756989
Epoch 439 LossLow: 0.23763154447078705
Epoch 440 LossLow: 0.2510022521018982
Epoch 441 LossLow: 0.21919773519039154
Epoch 442 LossLow: 0.20621758699417114
Epoch 443 LossLow: 0.26714763045310974
Epoch 444 LossLow: 0.22565829753875732
Epoch 445 LossLow: 0.1987447440624237
Epoch 446 LossLow: 0.20438243448734283
Epoch 447 LossLow: 0.17622411251068115
Epoch 448 LossLow: 0.18349087238311768
Epoch 449 LossLow: 0.16726937890052795
Epoch 450 LossLow: 0.17404207587242126
Epoch 451 LossLow: 0.1628572642803192
Epoch 452 LossLow: 0.19150462746620178
Epoch 453 LossLow: 0.14973308145999908
Epoch 454 LossLow: 0.1383996456861496
Epoch 455 LossLow: 0.18652993440628052
Epoch 456 LossLow: 0.12922130525112152
Epoch 457 LossLow: 0.14418929815292358
Epoch 458 LossLow: 0.13166013360023499
Epoch 459 LossLow: 0.14921753108501434
Epoch 460 LossLow: 0.145150288939476
Epoch 461 LossLow: 0.13389967381954193
Epoch 462 LossLow: 0.10350444912

Epoch 645 LossLow: -0.2874588668346405
Epoch 646 LossLow: -0.23522557318210602
Epoch 647 LossLow: -0.29984012246131897
Epoch 648 LossLow: -0.2667330801486969
Epoch 649 LossLow: -0.25484776496887207
Epoch 650 LossLow: -0.2908478081226349
Epoch 651 LossLow: -0.29403480887413025
Epoch 652 LossLow: -0.2982156574726105
Epoch 653 LossLow: -0.3132180869579315
Epoch 654 LossLow: -0.2684619426727295
Epoch 655 LossLow: -0.23703217506408691
Epoch 656 LossLow: -0.21328558027744293
Epoch 657 LossLow: -0.21318165957927704
Epoch 658 LossLow: -0.2160043716430664
Epoch 659 LossLow: -0.2488773912191391
Epoch 660 LossLow: -0.24869762361049652
Epoch 661 LossLow: -0.2911681532859802
Epoch 662 LossLow: -0.2543533742427826
Epoch 663 LossLow: -0.3250391483306885
Epoch 664 LossLow: -0.31227830052375793
Epoch 665 LossLow: -0.25844910740852356
Epoch 666 LossLow: -0.24033285677433014
Epoch 667 LossLow: -0.3222319185733795
Epoch 668 LossLow: -0.2711107134819031
Epoch 669 LossLow: -0.2391849309206009
Epoch 670 Loss

Epoch 854 LossLow: -0.338882178068161
Epoch 855 LossLow: -0.3825267255306244
Epoch 856 LossLow: -0.3693176209926605
Epoch 857 LossLow: -0.3870680630207062
Epoch 858 LossLow: -0.4165741205215454
Epoch 859 LossLow: -0.44402408599853516
Epoch 860 LossLow: -0.47597840428352356
Epoch 861 LossLow: -0.42353078722953796
Epoch 862 LossLow: -0.49779805541038513
Epoch 863 LossLow: -0.5068013668060303
Epoch 864 LossLow: -0.5057908892631531
Epoch 865 LossLow: -0.45771849155426025
Epoch 866 LossLow: -0.4447300434112549
Epoch 867 LossLow: -0.5199031233787537
Epoch 868 LossLow: -0.49735525250434875
Epoch 869 LossLow: -0.40792086720466614
Epoch 870 LossLow: -0.44548797607421875
Epoch 871 LossLow: -0.47940927743911743
Epoch 872 LossLow: -0.5059725642204285
Epoch 873 LossLow: -0.44727155566215515
Epoch 874 LossLow: -0.46918153762817383
Epoch 875 LossLow: -0.4769875407218933
Epoch 876 LossLow: -0.4412921071052551
Epoch 877 LossLow: -0.3992465138435364
Epoch 878 LossLow: -0.43033406138420105
Epoch 879 Loss

Epoch 104 LossHigh: 7.045207977294922
Epoch 105 LossHigh: 7.022876739501953
Epoch 106 LossHigh: 7.000845909118652
Epoch 107 LossHigh: 6.979109287261963
Epoch 108 LossHigh: 6.9576592445373535
Epoch 109 LossHigh: 6.936487197875977
Epoch 110 LossHigh: 6.915585517883301
Epoch 111 LossHigh: 6.894947528839111
Epoch 112 LossHigh: 6.874564170837402
Epoch 113 LossHigh: 6.854434013366699
Epoch 114 LossHigh: 6.834547519683838
Epoch 115 LossHigh: 6.8148980140686035
Epoch 116 LossHigh: 6.7954864501953125
Epoch 117 LossHigh: 6.776298999786377
Epoch 118 LossHigh: 6.757329940795898
Epoch 119 LossHigh: 6.738574028015137
Epoch 120 LossHigh: 6.720027446746826
Epoch 121 LossHigh: 6.701681137084961
Epoch 122 LossHigh: 6.683531761169434
Epoch 123 LossHigh: 6.665578365325928
Epoch 124 LossHigh: 6.647812843322754
Epoch 125 LossHigh: 6.630228042602539
Epoch 126 LossHigh: 6.612825870513916
Epoch 127 LossHigh: 6.5955939292907715
Epoch 128 LossHigh: 6.578536033630371
Epoch 129 LossHigh: 6.5616607666015625
Epoch 1

Epoch 326 LossHigh: 4.757335662841797
Epoch 327 LossHigh: 4.753049373626709
Epoch 328 LossHigh: 4.748689651489258
Epoch 329 LossHigh: 4.7444586753845215
Epoch 330 LossHigh: 4.740142822265625
Epoch 331 LossHigh: 4.735919952392578
Epoch 332 LossHigh: 4.731668949127197
Epoch 333 LossHigh: 4.72745418548584
Epoch 334 LossHigh: 4.723268032073975
Epoch 335 LossHigh: 4.719062805175781
Epoch 336 LossHigh: 4.714939117431641
Epoch 337 LossHigh: 4.710745811462402
Epoch 338 LossHigh: 4.706675052642822
Epoch 339 LossHigh: 4.702524185180664
Epoch 340 LossHigh: 4.6984543800354
Epoch 341 LossHigh: 4.694370746612549
Epoch 342 LossHigh: 4.69030237197876
Epoch 343 LossHigh: 4.686284065246582
Epoch 344 LossHigh: 4.682218074798584
Epoch 345 LossHigh: 4.678263187408447
Epoch 346 LossHigh: 4.67422342300415
Epoch 347 LossHigh: 4.670283317565918
Epoch 348 LossHigh: 4.666301727294922
Epoch 349 LossHigh: 4.662357330322266
Epoch 350 LossHigh: 4.658442974090576
Epoch 351 LossHigh: 4.6544952392578125
Epoch 352 LossH

Epoch 549 LossHigh: 4.080374717712402
Epoch 550 LossHigh: 4.0781168937683105
Epoch 551 LossHigh: 4.075943946838379
Epoch 552 LossHigh: 4.073694705963135
Epoch 553 LossHigh: 4.071500301361084
Epoch 554 LossHigh: 4.0692949295043945
Epoch 555 LossHigh: 4.06707239151001
Epoch 556 LossHigh: 4.064911842346191
Epoch 557 LossHigh: 4.062686920166016
Epoch 558 LossHigh: 4.060521602630615
Epoch 559 LossHigh: 4.058333396911621
Epoch 560 LossHigh: 4.056136131286621
Epoch 561 LossHigh: 4.053995132446289
Epoch 562 LossHigh: 4.051791191101074
Epoch 563 LossHigh: 4.049641132354736
Epoch 564 LossHigh: 4.047474384307861
Epoch 565 LossHigh: 4.045290470123291
Epoch 566 LossHigh: 4.043173789978027
Epoch 567 LossHigh: 4.0409932136535645
Epoch 568 LossHigh: 4.03885555267334
Epoch 569 LossHigh: 4.0367207527160645
Epoch 570 LossHigh: 4.034554481506348
Epoch 571 LossHigh: 4.032454967498779
Epoch 572 LossHigh: 4.030308246612549
Epoch 573 LossHigh: 4.028168678283691
Epoch 574 LossHigh: 4.026078224182129
Epoch 575 

Epoch 771 LossHigh: 3.670801877975464
Epoch 772 LossHigh: 3.6692447662353516
Epoch 773 LossHigh: 3.6676878929138184
Epoch 774 LossHigh: 3.6661336421966553
Epoch 775 LossHigh: 3.664581775665283
Epoch 776 LossHigh: 3.6630313396453857
Epoch 777 LossHigh: 3.661482572555542
Epoch 778 LossHigh: 3.659935712814331
Epoch 779 LossHigh: 3.658390998840332
Epoch 780 LossHigh: 3.656848669052124
Epoch 781 LossHigh: 3.655306816101074
Epoch 782 LossHigh: 3.6537678241729736
Epoch 783 LossHigh: 3.6522300243377686
Epoch 784 LossHigh: 3.6506946086883545
Epoch 785 LossHigh: 3.649160385131836
Epoch 786 LossHigh: 3.6476292610168457
Epoch 787 LossHigh: 3.6460988521575928
Epoch 788 LossHigh: 3.6445705890655518
Epoch 789 LossHigh: 3.6430439949035645
Epoch 790 LossHigh: 3.641519069671631
Epoch 791 LossHigh: 3.63999605178833
Epoch 792 LossHigh: 3.638474702835083
Epoch 793 LossHigh: 3.6369547843933105
Epoch 794 LossHigh: 3.635437488555908
Epoch 795 LossHigh: 3.6339213848114014
Epoch 796 LossHigh: 3.6324069499969482

Epoch 993 LossHigh: 3.3619260787963867
Epoch 994 LossHigh: 3.3606669902801514
Epoch 995 LossHigh: 3.3594095706939697
Epoch 996 LossHigh: 3.358152151107788
Epoch 997 LossHigh: 3.356896162033081
Epoch 998 LossHigh: 3.355640172958374
Epoch 999 LossHigh: 3.3543860912323


In [14]:
meanToTestNetModel=torch.nn.Sequential(
    torch.nn.Linear(1,10),
    torch.nn.ReLU(),
    torch.nn.Linear(10,5),
    torch.nn.ReLU(),
    torch.nn.Linear(5,1),
)

In [15]:
modelLow.eval()
likelihoodLow.eval()
modelHigh.eval()
likelihoodHigh.eval()
''' Predict at test points '''
# sample f_1 at xtest
with torch.no_grad(), gpytorch.settings.fast_pred_var():
    NetLow = likelihoodLow(modelLow(XTest))
    NetHigh=likelihoodHigh(modelHigh(XTest))
    
    NetIn=NetLow+NetHigh

In [17]:
learning_rate=1e-3
loss_fn=torch.nn.MSELoss(reduction='sum')
optimizer=torch.optim.SGD(meanToTestNetModel.parameters(),lr=learning_rate)
Input=NetIn.mean.reshape(-1,1)
YTest=YTest.reshape(-1,1)
print(Input)
print(YTest)
for t in range(0,1000):
    predLast=meanToTestNetModel(Input)
    loss=loss_fn(predLast,YTest)
    if t%100==0:
        print("第",t,"轮：loss:",loss.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

tensor([[68.3248],
        [68.3938],
        [63.3249],
        [68.5021],
        [62.9255]])
tensor([[31.1779],
        [29.1824],
        [26.4095],
        [26.9644],
        [30.0900]])
第 0 轮：loss: 19.018659591674805
第 100 轮：loss: 16.542478561401367
第 200 轮：loss: 16.542478561401367
第 300 轮：loss: 16.542478561401367
第 400 轮：loss: 16.542478561401367
第 500 轮：loss: 16.542478561401367
第 600 轮：loss: 16.542478561401367
第 700 轮：loss: 16.542478561401367
第 800 轮：loss: 16.542478561401367
第 900 轮：loss: 16.542478561401367


In [19]:
with torch.no_grad():
    # Initialize plot
    PredLow = likelihoodLow(modelLow(XVal))
    PredHigh=likelihoodHigh(modelHigh(XVal))
    PredMean=PredHigh+PredLow
    mean=PredMean.mean.reshape(-1,1)
    meanlast=meanToTestNetModel(mean)
    print(meanlast)
    print(YVal)
#     f, ax = plt.subplots(1, 1, figsize=(4, 3))
#     # Get upper and lower confidence bounds
#     lower, upper = pred.confidence_region()
    
#     XTest.numpy()reshape, lower.numpy(), upper.numpy()
#     # Plot predictive means as blue line
#     ax.plot(XTest.numpy(), pred.mean.numpy(), 'b')
#     # Shade between the lower and upper confidence bounds
#     ax.fill_between(XTest.numpy(), lower.numpy(), upper.numpy(), alpha=0.5)
#     ax.set_ylim([-3, 3])
#     ax.legend([ 'Mean', 'Confidence'])

tensor([[28.7648],
        [28.7648],
        [28.7648],
        [28.7648]])
tensor([26.8370, 25.2800, 31.4606, 30.1555])


tensor([31.1779, 29.1824, 26.4095, 26.9644, 30.0900, 26.8370, 25.2800, 31.4606,
        30.1555])
tensor([31.1779, 29.1824, 26.4095, 26.9644, 30.0900, 26.8370, 25.2800, 31.4606,
        30.1555])
