In [173]:
import copy

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader,TensorDataset
import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from torch import tensor
from pandas.core.frame import DataFrame
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_squared_error,mean_absolute_error
from sklearn.utils import shuffle
import scipy as sp
from sklearn.cluster import DBSCAN

In [174]:
def normalization(data,label):
    """Min-max scale windowed features and their targets.

    ``data`` is indexed as (sample, time step, feature): ``shape[1]`` is read
    as the window length and ``shape[2]`` as the feature count.  The windows
    are flattened to rows so that every feature column is scaled with one
    min/max taken over all samples and time steps, then folded back to the
    original window layout.  ``label`` is handed to ``fit_transform``
    unchanged, so it has to be 2-D.

    Returns:
        (scaled_data, scaled_label, label_scaler).  The label scaler is
        returned so predictions can be mapped back with
        ``inverse_transform``; the feature scaler is discarded.

    NOTE(review): both scalers are fitted on whatever is passed in; the
    caller currently passes the full data set before the train/test split.
    """
    feature_scaler, target_scaler = MinMaxScaler(), MinMaxScaler()
    window, n_features = data.shape[1], data.shape[2]
    flat_rows = data.reshape(-1, n_features)
    scaled = feature_scaler.fit_transform(flat_rows).reshape(-1, window, n_features)
    return scaled, target_scaler.fit_transform(label), target_scaler

In [175]:
def dbscan_predict(dbscan_model, X_new, metric=sp.spatial.distance.euclidean):
    """Assign cluster labels to new points using a fitted DBSCAN-like model.

    For each row of ``X_new`` the core samples in ``dbscan_model.components_``
    are scanned in order; the first one whose ``metric`` distance is strictly
    below ``dbscan_model.eps`` donates its label.  Points with no such core
    sample keep the noise label ``-1``.

    Args:
        dbscan_model: object exposing ``components_``, ``eps``, ``labels_``
            and ``core_sample_indices_``.
        X_new: sequence of points to label.
        metric: callable ``metric(a, b) -> distance``.

    Returns:
        Integer numpy array with one label per row of ``X_new``.
    """
    radius = dbscan_model.eps
    # Everything starts out as noise.
    assigned = np.full(len(X_new), -1, dtype=int)
    for row, sample in enumerate(X_new):
        # Position (within components_) of the first core sample in range.
        hit = next(
            (pos for pos, core in enumerate(dbscan_model.components_)
             if metric(sample, core) < radius),
            None,
        )
        if hit is not None:
            # components_[pos] corresponds to labels_[core_sample_indices_[pos]].
            assigned[row] = dbscan_model.labels_[dbscan_model.core_sample_indices_[hit]]
    return assigned

In [176]:
class Net(nn.Module):
    """Single-direction LSTM followed by a linear head on the last time step.

    Args:
        feature_size: number of input features per time step.
        size_hidden: LSTM hidden state size.
        out_size: number of regression outputs.
        num_layers: number of stacked LSTM layers.
        dropout: dropout probability.  It is only forwarded to the LSTM when
            ``num_layers > 1``; with a single layer PyTorch would ignore it
            and emit a warning.
    """
    def __init__(self, feature_size, size_hidden,out_size,num_layers=1,dropout=0):
        super().__init__()
        self.feature_size = feature_size
        self.size_hidden = size_hidden
        self.num_layers = num_layers
        self.num_direction = 1  # unidirectional LSTM
        self.rnn = nn.LSTM(input_size=feature_size, hidden_size=size_hidden, num_layers=num_layers,
                           dropout=dropout if num_layers > 1 else 0, batch_first=True)
        self.out = nn.Linear(size_hidden, out_size)
        # Defined but not applied in forward(); the original author noted that
        # results were noticeably better without dropout on the LSTM output.
        self.dropout=nn.Dropout(p=dropout)

    def forward(self, input):
        """Map a batch of windows (batch, seq_len, feature_size) to (batch, out_size)."""
        state_shape = (self.num_direction*self.num_layers, input.size(0), self.size_hidden)
        # Create the initial states on the same device/dtype as the input so the
        # model works on CPU as well as on any GPU (previously hard-coded .cuda()).
        # NOTE(review): the states are drawn at random on every call, so outputs
        # are not deterministic even in eval mode; zeros are the usual choice.
        # Kept random here to avoid changing training behaviour.
        h_0 = torch.rand(state_shape, device=input.device, dtype=input.dtype)
        c_0 = torch.rand(state_shape, device=input.device, dtype=input.dtype)
        output, _ = self.rnn(input,(h_0,c_0))
        # Only the last time step is used, so apply the linear head to it alone.
        return self.out(output[:, -1, :])

In [177]:
# ---- Column selection -------------------------------------------------------
# Columns present in the raw trajectory CSV files.
features = ['MMSI', 'BaseDateTime', 'LAT', 'LON', 'SOG', 'COG', 'Heading', 'Status']
# Columns loaded per point and fed to the clustering step.
cluster_tag = ['MMSI','LAT', 'LON','SOG','COG']
# Model input columns (earlier experiment: ['X', 'Y', 'SOG', 'COG', 'Heading']).
useful_tag = ['LAT', 'LON','SOG']
# Model target columns (earlier experiments: ['X', 'Y'] and ['LAT', 'LON']).
predict_tag = ['LAT', 'LON','SOG']

# ---- Data -------------------------------------------------------------------
data_file_root_path = './data/path_data/'
len_topredict = 5   # number of past points in each input window
ratio = 0.8         # train share of the train/test split

# ---- Model / optimisation ---------------------------------------------------
feature_size=len(useful_tag)
output_size=len(predict_tag)
hidden_size = 256
batch_size=32 # preferably a power of two
drop_out=0.3
lr = 0.01

In [178]:
class MyDataSet(Dataset):
    """Point-wise view of one trajectory frame restricted to ``cluster_tag``.

    ``x_data`` and ``y_data`` are two independent float32 arrays holding the
    same columns; indexing returns the matching rows of both.
    """

    def __init__(self, df: DataFrame):
        selected = df[cluster_tag]
        # Two separate conversions so the arrays do not share memory.
        self.x_data = selected.astype('float32').to_numpy()
        self.y_data = selected.astype('float32').to_numpy()
        self.length = len(self.x_data)

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        features_row = self.x_data[index]
        target_row = self.y_data[index]
        return features_row, target_row

In [179]:
def norm_for_cluster(X):
    """Return ``X`` min-max scaled column by column with a freshly fitted scaler.

    The scaler itself is not kept, so the transform cannot be inverted or
    re-applied to new points later.
    """
    return MinMaxScaler().fit_transform(X)

In [180]:
def get_cluster(Multidimensional_Points, eps=0.02, min_samples=7):
    """Cluster the given points with DBSCAN and return one label per point.

    Args:
        Multidimensional_Points: array-like of points, one row per sample.
        eps: DBSCAN neighbourhood radius (default keeps the previous
            hard-coded value).
        min_samples: DBSCAN core-point threshold (default keeps the previous
            hard-coded value).

    Returns:
        The fitted model's ``labels_`` array; ``-1`` marks noise points.
    """
    # Quick visual sanity check of what is being clustered.
    print(Multidimensional_Points[:5])
    cluster = DBSCAN(eps=eps, min_samples=min_samples)
    cluster.fit(np.array(Multidimensional_Points))
    return cluster.labels_


In [181]:
# Build sliding-window samples from the per-vessel trajectory CSV files.
X = []               # input windows: len_topredict consecutive points, useful_tag columns
y = []               # targets: the point right after each window, predict_tag columns
for_cluster_X = []   # first point of each window (all cluster_tag columns), used for clustering
num = 0              # number of trajectory files actually used
max_samples = 20000  # stop opening new files once this many samples were collected

# MyDataSet stores its columns in cluster_tag order, so resolve the input and
# target columns by name instead of a hand-maintained slice.
input_cols = [cluster_tag.index(tag) for tag in useful_tag]
target_cols = [cluster_tag.index(tag) for tag in predict_tag]

for i in range(1000):
    if len(y) >= max_samples:
        break
    data_file_path = data_file_root_path + 'id' + str(i) + '/'
    # Sorted so the set of files that fits under the cap is the same on every
    # run (os.listdir order is arbitrary).
    for filename in sorted(os.listdir(data_file_path)):
        if len(y) >= max_samples:
            # Check before reading: files past the cap were previously parsed
            # and then thrown away.
            break
        path = data_file_path + filename
        df = pd.read_csv(path)
        num += 1
        dataset = MyDataSet(df)
        for j in range(len_topredict, len(dataset)):
            window_start = j - len_topredict
            X.append(dataset.x_data[window_start:j][:, input_cols])
            y.append(dataset.y_data[j][target_cols])
            for_cluster_X.append(dataset.x_data[window_start])
# An earlier run was annotated as using 109 trajectories.
print(num)
X,y,for_cluster_X = np.array(X),np.array(y),np.array(for_cluster_X)
print(X.shape)
print(y.shape)
# Normalisation still has to happen after loading; mm_y is kept to undo it on predictions.
X,y,mm_y = normalization(X,y)
labels = get_cluster(norm_for_cluster(for_cluster_X)) # cluster the whole data set

199
(20231, 5, 4)
(20231, 4)
[[0.30810362 0.45527586 0.57463074 0.05474096 0.78444445]
 [0.30810362 0.45527682 0.57459843 0.05767351 0.7372222 ]
 [0.30810362 0.45527527 0.57456064 0.05474096 0.73444444]
 [0.30810362 0.45527318 0.57452846 0.03812317 0.74333334]
 [0.30810362 0.45527133 0.574515   0.00879765 0.69111115]]


In [182]:
def split_data(X,y,z,split_ratio):
    """Split three aligned sequences at ``int(len(y) * split_ratio)``.

    ``X`` and ``y`` parts are returned as float tensors, ``z`` parts as numpy
    arrays.  No shuffling is done here; the first part is the training part.

    Returns:
        (X_train, y_train, z_train, X_test, y_test, z_test)
    """
    cut = int(len(y) * split_ratio)
    head, tail = slice(0, cut), slice(cut, None)

    def to_tensor(chunk):
        return torch.Tensor(np.array(chunk))

    return (
        to_tensor(X[head]), to_tensor(y[head]), np.array(z[head]),
        to_tensor(X[tail]), to_tensor(y[tail]), np.array(z[tail]),
    )


In [183]:
def get_dataset(X_train,y_train,X_test,y_test):
    """Pair inputs with targets and return ``(train_dataset, test_dataset)``."""
    return TensorDataset(X_train, y_train), TensorDataset(X_test, y_test)

In [184]:
def train(net, train_iter, loss, epochs, lr):
    """Optimise ``net`` with Adam over ``train_iter`` for ``epochs`` passes.

    Args:
        net: the ``nn.Module`` to train; batches are moved to the device its
            parameters live on, so this works on CPU and GPU alike.
        train_iter: iterable of ``(inputs, targets)`` batches.
        loss: callable ``loss(prediction, target)``; its result is summed
            before back-propagation, so an unreduced loss is fine.
        epochs: number of passes over ``train_iter``.
        lr: Adam learning rate.

    Prints the summed batch loss every 100 optimisation steps.
    """
    trainer = torch.optim.Adam(net.parameters(), lr=lr)
    device = next(net.parameters()).device
    net.train()  # make sure train-mode layers are active even after a previous eval()
    step = 0
    for epoch in range(epochs):
        # No extra in-batch shuffling: permuting samples inside a batch does not
        # change a batch-summed loss, and the DataLoader decides batch order.
        for X, y in train_iter:
            X = X.to(device)
            y = y.to(device)
            trainer.zero_grad()
            batch_loss = loss(net(X), y).sum()
            step += 1
            if step % 100 == 0:
                print(f'iter: {step}', f'loss：{batch_loss}')
            batch_loss.backward()
            trainer.step()

In [185]:
# Training length shared by every per-cluster network.
epochs = 20
# Prefer the GPU when one is visible, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
nets = []
# Distinct DBSCAN labels; -1 is the noise label and does not count as a cluster.
label_set = set(labels)
cluster_num = len(label_set - {-1})

In [186]:
def test(net,X,y):
    """Evaluate ``net`` on ``(X, y)`` and print errors in original units.

    Predictions and targets are mapped back through the module-level label
    scaler ``mm_y``.  The first 20 targets and predictions are printed,
    followed by MAE and MSE for each column named in ``predict_tag``.

    Args:
        net: trained ``nn.Module``; left in eval mode on return.
        X: input tensor.
        y: normalised target tensor.
    """
    # Run on whatever device the model lives on instead of assuming CUDA.
    device = next(net.parameters()).device
    net.eval()
    with torch.no_grad():  # inference only: no autograd graph needed
        predict = net(X.to(device))
    predict = mm_y.inverse_transform(predict.cpu().numpy())
    y_data_plot = mm_y.inverse_transform(y.detach().cpu().numpy())
    print(y_data_plot[:20])
    print(predict[:20])
    # One metrics line per predicted column, so changing predict_tag does not
    # require editing this function.
    for col, tag in enumerate(predict_tag):
        print(f'{tag}:  mean_absolute_error: {mean_absolute_error(y_data_plot[:,col],predict[:,col])}, mean_squared_error: {mean_squared_error(y_data_plot[:,col],predict[:,col])}')

In [187]:
class SmallNet:
    """One LSTM model plus its own training batches and held-out samples.

    The model hyper-parameters, epoch count and learning rate come from the
    module-level configuration; training and evaluation are delegated to the
    module-level ``train`` and ``test`` functions.
    """

    def __init__(self,train_iter,test_data_X,test_data_y):
        target = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(target)
        model = Net(
            feature_size=feature_size,
            size_hidden=hidden_size,
            out_size=output_size,
            dropout=drop_out,
            num_layers=1,
        )
        self.net = model.to(self.device)
        self.loss = nn.MSELoss(reduction='none').to(self.device)
        self.train_iter = train_iter
        # The test samples arrive as sequences of tensors; stack each into one batch.
        self.test_data_X, self.test_data_y = torch.stack(test_data_X), torch.stack(test_data_y)

    def train(self):
        """Train the wrapped model on ``train_iter``."""
        print("number of training batch: ",len(self.train_iter))
        print("Training----")
        train(self.net, self.train_iter, self.loss, epochs, lr)
        print("Train Done----")

    def test(self):
        """Evaluate the wrapped model on the stacked test samples."""
        print("Testing----")
        test(self.net, self.test_data_X, self.test_data_y)
        print("Test Done----")

In [188]:
# Alternative loss tried earlier: nn.L1Loss(reduction='sum')
# NOTE: this rebinds X, y and labels; re-running the cell shuffles the already
# shuffled data again (deterministically, because of the fixed seed).
X,y,labels = shuffle(X,y,labels,random_state=711)

X_train,y_train,labels_train,X_test,y_test,labels_test = split_data(X,y,labels,ratio)


def _bucket_by_cluster(samples, targets, cluster_ids, n_clusters):
    """Group samples/targets into one list per cluster and count the noise points (-1)."""
    sample_buckets = [[] for _ in range(n_clusters)]
    target_buckets = [[] for _ in range(n_clusters)]
    n_noise = 0
    for sample, target, cluster_id in zip(samples, targets, cluster_ids):
        if cluster_id == -1:
            n_noise += 1
            continue
        sample_buckets[cluster_id].append(sample)
        target_buckets[cluster_id].append(target)
    return sample_buckets, target_buckets, n_noise


train_data_for_each_net, train_predict_for_each_net, num1 = _bucket_by_cluster(
    X_train, y_train, labels_train, cluster_num)
test_data_for_each_net, test_predict_for_each_net, num2 = _bucket_by_cluster(
    X_test, y_test, labels_test, cluster_num)
print(f"unuse train_data:{num1} , unuse test_data:{num2}",)

# Clusters with fewer training samples than this (or with no test sample) are skipped.
min_train_samples = 100
for cluster_idx in range(cluster_num):
    if len(train_data_for_each_net[cluster_idx]) < min_train_samples or len(test_data_for_each_net[cluster_idx]) == 0:
        continue
    train_dataset,test_dataset = get_dataset(
        torch.stack(train_data_for_each_net[cluster_idx]),
        torch.stack(train_predict_for_each_net[cluster_idx]),
        torch.stack(test_data_for_each_net[cluster_idx]),
        torch.stack(test_predict_for_each_net[cluster_idx]))
    train_loader = DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True,drop_last=True)
    # NOTE(review): test_loader is built but not consumed below; evaluation in
    # SmallNet.test() runs on the full per-cluster test lists instead.
    test_loader = DataLoader(dataset=test_dataset,batch_size=batch_size,shuffle=True,drop_last=True)
    net = SmallNet(train_iter=train_loader,
                   test_data_X=test_data_for_each_net[cluster_idx],
                   test_data_y=test_predict_for_each_net[cluster_idx])
    net.train()
    net.test()

unuse train_data:757 , unuse test_data:190
number of training batch:  143
Training----




iter: 100 loss：0.0838337242603302
iter: 200 loss：0.1395207643508911
iter: 300 loss：0.24567347764968872
iter: 400 loss：0.18065668642520905
iter: 500 loss：0.2847444713115692
iter: 600 loss：1.3895087242126465
iter: 700 loss：0.1973751336336136
iter: 800 loss：0.18649566173553467
iter: 900 loss：0.5717215538024902
iter: 1000 loss：1.0600852966308594
iter: 1100 loss：1.212276816368103
iter: 1200 loss：0.10978655517101288
iter: 1300 loss：0.05253786966204643
iter: 1400 loss：0.18154768645763397
iter: 1500 loss：0.06481243669986725
iter: 1600 loss：0.1805438995361328
iter: 1700 loss：0.48834413290023804
iter: 1800 loss：1.6129734516143799
iter: 1900 loss：1.0272668600082397
iter: 2000 loss：0.1261216700077057
iter: 2100 loss：0.687880277633667
iter: 2200 loss：1.1843634843826294
iter: 2300 loss：0.11492057144641876
iter: 2400 loss：0.746954619884491
iter: 2500 loss：0.6452222466468811
iter: 2600 loss：0.5478875637054443
iter: 2700 loss：0.28424301743507385
iter: 2800 loss：0.22517603635787964
Train Done----
Testin



iter: 100 loss：2.003721237182617
Train Done----
Testing----
[[ 4.07110214e+01 -7.40295181e+01  7.69999933e+00  1.85100006e+02]
 [ 4.06890907e+01 -7.40102005e+01  7.29999971e+00  2.01000000e+02]
 [ 4.06866417e+01 -7.40116577e+01  4.50000000e+00  3.02000027e+01]
 [ 4.07016792e+01 -7.40189285e+01  6.09999990e+00  1.61100006e+02]
 [ 4.06813507e+01 -7.40395508e+01  3.00000012e-01  2.08500015e+02]
 [ 4.06994934e+01 -7.40013809e+01  6.50000000e+00  2.09600021e+02]
 [ 4.06815109e+01 -7.40396194e+01  4.00000006e-01  2.09899994e+02]
 [ 4.06659889e+01 -7.40304413e+01  4.90000010e+00  1.05399994e+02]
 [ 4.07251892e+01 -7.40254288e+01  7.09999990e+00  2.07300018e+02]
 [ 4.07357712e+01 -7.40154572e+01  5.90000010e+00  4.00000000e+00]
 [ 4.07550316e+01 -7.40170822e+01  8.89999962e+00  1.96699982e+02]
 [ 4.06813316e+01 -7.40396881e+01  5.00000000e-01  2.08500015e+02]
 [ 4.06942482e+01 -7.40347824e+01  9.30000019e+00  1.95800003e+02]
 [ 4.06815414e+01 -7.40404510e+01  5.00000000e-01  2.28800003e+02]
 [



iter: 100 loss：1.625928282737732
Train Done----
Testing----
[[ 40.76776   -74.00616     5.7        30.3      ]
 [ 40.70515   -73.99351    10.         71.9      ]
 [ 40.68719   -74.03613     0.9       200.3      ]
 [ 40.68402   -74.0416      1.2       204.5      ]
 [ 40.72926   -74.01735     2.6         6.2      ]
 [ 40.73073   -74.01468     2.4        64.5      ]
 [ 40.68798   -74.01017     5.9        33.3      ]
 [ 40.67056   -74.02239     5.8        11.599999 ]
 [ 40.75393   -74.01437     6.1       195.8      ]
 [ 40.716     -74.0194      3.          9.4      ]
 [ 40.75086   -74.01352     5.9        23.7      ]
 [ 40.7186    -74.01896     2.9        10.1      ]
 [ 40.74267   -74.01645     5.8         6.2999997]
 [ 40.74496   -74.01819     6.5       199.30002  ]
 [ 40.714603  -74.02574     7.3999996 231.3      ]
 [ 40.69858   -74.00769     6.        260.1      ]
 [ 40.70032   -74.01812     7.2       340.8      ]
 [ 40.71739   -74.01919     2.9         7.3000007]
 [ 40.71238   -74.0272



iter: 100 loss：1.530205249786377
iter: 200 loss：1.862841248512268
iter: 300 loss：2.043771266937256
iter: 400 loss：0.7231346368789673
iter: 500 loss：1.1067371368408203
iter: 600 loss：1.4071898460388184
iter: 700 loss：1.6303383111953735
iter: 800 loss：1.3365509510040283
iter: 900 loss：1.7228069305419922
iter: 1000 loss：1.700258493423462
iter: 1100 loss：0.6264661550521851
Train Done----
Testing----
[[ 3.99057312e+01 -7.40181503e+01  2.70000005e+00  3.30100006e+02]
 [ 3.99005203e+01 -7.40172806e+01  1.10000002e+00  1.51600006e+02]
 [ 3.99295807e+01 -7.40234299e+01  3.00000012e-01  1.15899994e+02]
 [ 3.99178085e+01 -7.40315323e+01  1.00000000e+00  3.07999992e+01]
 [ 3.99287300e+01 -7.40260620e+01  5.00000000e-01  2.37699997e+02]
 [ 3.99275284e+01 -7.40209274e+01  3.00000012e-01  2.06399994e+02]
 [ 3.99252586e+01 -7.40229416e+01  3.00000012e-01  2.14599991e+02]
 [ 3.99145889e+01 -7.40188065e+01  1.20000005e+00  1.41600006e+02]
 [ 3.99348907e+01 -7.51415482e+01  1.00000001e-01  1.91800003e+02



number of training batch:  5
Training----
iter: 100 loss：0.00823121052235365
Train Done----
Testing----
[[ 27.032043 -90.38795    8.848148 115.34013 ]
 [ 27.205002 -90.851715   8.5      112.06667 ]
 [ 27.6903   -92.16282    9.       113.3     ]
 [ 27.5917   -91.9118     9.2      111.6     ]
 [ 27.54806  -91.78894    8.6      112.5     ]
 [ 27.6146   -91.97415    8.8      114.3     ]
 [ 26.942745 -90.14004    8.9      113.984955]
 [ 27.084732 -90.517525   8.76875  113.90909 ]
 [ 27.68009  -92.13676    8.9      113.2     ]
 [ 27.53928  -91.76462    8.4      110.8     ]
 [ 27.475452 -91.58399    8.6      110.19999 ]
 [ 27.69674  -92.18159    9.       110.7     ]
 [ 27.43654  -91.45984    8.3      105.7     ]
 [ 27.195713 -90.82681    8.5      112.933334]
 [ 27.71968  -92.24833    9.       112.3     ]
 [ 27.221392 -90.89356    8.55     112.03893 ]
 [ 27.71839  -92.24477    9.1      110.1     ]
 [ 27.104992 -90.573      8.5      112.4     ]
 [ 27.16479  -90.73642    8.5      109.399994]
 [ 



iter: 100 loss：1.3172872066497803
Train Done----
Testing----
[[ 2.9743120e+01 -9.5195679e+01  1.0000000e-01  2.5150000e+02]
 [ 2.9720600e+01 -9.5247299e+01  6.6999998e+00  2.9879999e+02]
 [ 2.9743000e+01 -9.5194519e+01  6.6999998e+00  2.4370000e+02]
 [ 2.9737669e+01 -9.5203423e+01  7.4999995e+00  2.2160001e+02]
 [ 2.9723619e+01 -9.5250977e+01  6.8000002e+00  3.1629999e+02]
 [ 2.9724930e+01 -9.5216782e+01  6.5999999e+00  2.5639999e+02]
 [ 2.9724768e+01 -9.5231033e+01  6.6999998e+00  2.4700002e+02]
 [ 2.9745050e+01 -9.5190117e+01  6.6999998e+00  2.4970001e+02]
 [ 2.9749029e+01 -9.5289162e+01  5.5000000e+00  6.0000000e+01]
 [ 2.9613670e+01 -9.4994431e+01  6.0000000e+00  2.6410001e+02]
 [ 2.9724800e+01 -9.5218498e+01  5.0000000e+00  2.7200000e+02]
 [ 2.9763950e+01 -9.5097504e+01  4.6999998e+00  2.7989999e+02]
 [ 2.9755070e+01 -9.5177254e+01  1.2000000e+00  3.3639999e+02]
 [ 2.9745220e+01 -9.5188797e+01  3.3000000e+00  2.4970001e+02]
 [ 2.9746220e+01 -9.5180771e+01  4.8000002e+00  2.6279999



iter: 100 loss：0.03664398193359375
Train Done----
Testing----
[[ 30.02863  -88.20493    9.       230.1     ]
 [ 29.68512  -88.66509    9.4      230.8     ]
 [ 29.62227  -88.74947    9.8      229.7     ]
 [ 29.50714  -88.87017   10.       221.3     ]
 [ 29.74687  -88.58004    9.1      232.8     ]
 [ 29.42757  -88.9366     9.5      173.      ]
 [ 29.96128  -88.29375    9.1      227.8     ]
 [ 28.85905  -88.86724    9.4      199.9     ]
 [ 29.50953  -88.86772   10.       220.8     ]
 [ 28.88474  -88.85653    9.5      199.6     ]
 [ 29.63927  -88.72735    9.8      225.69998 ]
 [ 29.65785  -88.70385    9.5      227.5     ]
 [ 29.70445  -88.63832    9.5      231.9     ]
 [ 29.44781  -88.93015    9.9      213.3     ]
 [ 29.560148 -88.81613    9.7      222.3     ]
 [ 30.04835  -88.17887    9.2      229.09999 ]
 [ 29.85796  -88.43011    9.1      230.2     ]
 [ 29.49553  -88.88224    9.9      223.2     ]
 [ 29.624388 -88.74676    9.7      229.49998 ]
 [ 29.72188  -88.61444    9.3      228.8     



Train Done----
Testing----
[[ 29.25069 -88.86848  10.      157.6    ]
 [ 29.23327 -88.85938   9.9     152.9    ]
 [ 29.13702 -88.80557   9.8     157.     ]
 [ 29.14875 -88.81106   9.7     156.9    ]
 [ 29.08209 -88.78159   9.2     161.6    ]
 [ 29.3002  -88.89226   9.7     157.8    ]
 [ 29.19733 -88.83641  10.      149.9    ]
 [ 29.39619 -88.92982   9.9     167.6    ]
 [ 29.30316 -88.89369   9.9     157.3    ]
 [ 29.12542 -88.80008   9.7     157.7    ]
 [ 29.31508 -88.89931  10.1     156.9    ]
 [ 29.40555 -88.93215   9.8     168.     ]
 [ 29.25715 -88.87155  10.2     156.4    ]
 [ 29.21965 -88.85063   9.9     151.     ]
 [ 29.18939 -88.83135   9.8     152.1    ]
 [ 29.41822 -88.93483   9.8     169.4    ]
 [ 29.3618  -88.91867   9.9     163.     ]
 [ 29.26603 -88.87583   9.7     157.     ]
 [ 29.09095 -88.7852    9.2     158.8    ]
 [ 29.21405 -88.84713   9.7     151.     ]]
[[ 29.10843  -89.18654    9.610907 158.60008 ]
 [ 29.097551 -89.124626   9.375472 158.5663  ]
 [ 29.113419 -89.2



number of training batch:  10
Training----
iter: 100 loss：0.5486551523208618
iter: 200 loss：0.7220097184181213
Train Done----
Testing----
[[ 2.8695530e+01 -9.5947273e+01  3.0999999e+00  2.3950000e+02]
 [ 2.8690750e+01 -9.5957497e+01  1.1000000e+00  2.4420000e+02]
 [ 2.8754829e+01 -9.5686600e+01  4.5999999e+00  2.4170001e+02]
 [ 2.8715231e+01 -9.5894783e+01  4.0000001e-01  2.4389999e+02]
 [ 2.8671869e+01 -9.5998817e+01  1.0000000e-01  2.0289999e+02]
 [ 2.8664730e+01 -9.6013527e+01  3.3000000e+00  2.4180000e+02]
 [ 2.8742670e+01 -9.5805931e+01  4.5000000e+00  2.5610001e+02]
 [ 2.8763248e+01 -9.5670624e+01  4.1999998e+00  2.3510001e+02]
 [ 2.8668131e+01 -9.6006378e+01  3.7999997e+00  2.4300000e+02]
 [ 2.8704399e+01 -9.5926949e+01  1.3000000e+00  2.7710001e+02]
 [ 2.8673830e+01 -9.5994202e+01  2.5999999e+00  2.4270000e+02]
 [ 2.8676220e+01 -9.5988831e+01  2.4000001e+00  6.4900002e+01]
 [ 2.8671930e+01 -9.5998283e+01  6.9999999e-01  2.7689999e+02]
 [ 2.8691050e+01 -9.5956718e+01  1.4000000e



iter: 100 loss：0.010411111637949944
iter: 200 loss：0.02142507955431938
Train Done----
Testing----
[[ 28.79199 -93.44741   9.9     182.19998]
 [ 29.50778 -93.36579   9.5     184.80002]
 [ 29.74811 -93.34197  10.6     172.7    ]
 [ 29.5441  -93.36136   9.5     187.6    ]
 [ 28.73532 -93.44905   9.5     183.2    ]
 [ 29.43076 -93.37343   9.6     188.29999]
 [ 29.4082  -93.3768    8.8     182.3    ]
 [ 28.88197 -93.43623   9.5     183.70001]
 [ 29.6427  -93.35006   8.8     182.6    ]
 [ 29.48673 -93.36778   9.3     181.9    ]
 [ 29.05653 -93.41964  10.      185.99998]
 [ 29.13536 -93.41263   9.      191.9    ]
 [ 28.95786 -93.42936   9.5     185.5    ]
 [ 28.97651 -93.42752   9.2     191.1    ]
 [ 28.64872 -93.45983   9.      187.     ]
 [ 29.11653 -93.41504   9.7     187.50002]
 [ 29.37763 -93.38141  10.      188.8    ]
 [ 28.87896 -93.43659   9.8     188.9    ]
 [ 29.12904 -93.4135    9.8     180.8    ]
 [ 28.50583 -93.48234   9.4     191.9    ]]
[[ 28.861967  -93.479324    9.073092  184



Train Done----
Testing----
[[ 38.0125   -90.07157    5.4      277.1     ]
 [ 38.19312  -90.3355     5.3      325.      ]
 [ 38.08962  -90.21008    5.2      303.      ]
 [ 38.09575  -90.22211    5.1      305.      ]
 [ 38.05765  -90.12231    4.4      342.      ]
 [ 38.24485  -90.36578    5.1      355.      ]
 [ 38.06571  -90.13548    5.       273.      ]
 [ 38.004158 -90.04808    5.6      299.2     ]
 [ 38.01003  -90.06131    5.4      296.      ]
 [ 38.08082  -90.19596    5.6      309.      ]
 [ 38.26162  -90.36938    4.9      347.      ]
 [ 38.09301  -90.21791    5.1      310.      ]
 [ 38.2048   -90.34599    5.       329.8     ]
 [ 38.22067  -90.35693    5.1      333.      ]
 [ 38.04667  -90.12095    5.7      352.      ]
 [ 38.01071  -90.06333    5.4      292.      ]
 [ 38.12315  -90.25069    5.1      334.      ]
 [ 38.00505  -90.05002    5.6      301.      ]
 [ 38.17981  -90.31918    5.1      291.      ]
 [ 38.19675  -90.33886    5.1      320.3     ]]
[[ 37.4433    -90.681175    6.19



iter: 100 loss：1.043613314628601
Train Done----
Testing----
[[ 3.836152e+01 -9.035643e+01  4.700000e+00  1.700000e+01]
 [ 3.858058e+01 -9.021870e+01  3.000000e-01  3.550000e+02]
 [ 3.855533e+01 -9.024247e+01  4.900000e+00  3.600000e+01]
 [ 3.842932e+01 -9.029118e+01  4.800000e+00  2.900000e+01]
 [ 3.842804e+01 -9.029224e+01  4.800000e+00  3.400000e+01]
 [ 3.836627e+01 -9.035453e+01  5.100000e+00  8.000000e+00]
 [ 3.860028e+01 -9.019428e+01  1.000000e-01  3.180000e+02]
 [ 3.853335e+01 -9.025754e+01  4.900000e+00  2.500000e+01]
 [ 3.841481e+01 -9.031219e+01  5.600000e+00  5.600000e+01]
 [ 3.836558e+01 -9.035469e+01  5.000000e+00  1.400000e+01]
 [ 3.854687e+01 -9.024863e+01  4.800000e+00  2.050000e+01]
 [ 3.856772e+01 -9.023461e+01  1.000000e-01  1.300000e+01]
 [ 3.843812e+01 -9.028667e+01  4.600000e+00  1.800000e+01]
 [ 3.857423e+01 -9.022650e+01  2.200000e+00  4.700000e+01]
 [ 3.841097e+01 -9.031902e+01  4.100000e+00  3.400000e+01]
 [ 3.839043e+01 -9.034090e+01  4.500000e+00  4.700000e+



iter: 100 loss：1.010598063468933
iter: 200 loss：0.21896010637283325
iter: 300 loss：1.899187445640564
iter: 400 loss：0.9307044744491577
iter: 500 loss：0.16603393852710724
iter: 600 loss：0.07573643326759338
Train Done----
Testing----
[[ 3.9966709e+01 -8.0743759e+01  4.3000002e+00  3.3000000e+01]
 [ 4.0264858e+01 -7.9898628e+01  3.0999999e+00  2.2089999e+02]
 [ 3.9996319e+01 -8.0738640e+01  4.1999998e+00  3.5300000e+02]
 [ 3.9993790e+01 -7.9948402e+01  2.8000000e+00  2.7879999e+02]
 [ 3.9742008e+01 -8.0851517e+01  4.4000001e+00  3.0400000e+02]
 [ 3.9987919e+01 -8.0737633e+01  4.3000002e+00  3.5500000e+02]
 [ 4.0233871e+01 -8.0657860e+01  4.4000001e+00  1.8000000e+01]
 [ 4.0134930e+01 -8.0704697e+01  1.1000000e+00  8.0000000e+00]
 [ 4.0176640e+01 -7.9850151e+01  5.0000000e+00  1.5000000e+00]
 [ 4.0253151e+01 -8.0638672e+01  4.3000002e+00  5.3000000e+01]
 [ 4.0240299e+01 -7.9957253e+01  4.8000002e+00  2.5080002e+02]
 [ 4.0044540e+01 -7.9889290e+01  3.5999999e+00  1.4220000e+02]
 [ 3.9602859



number of training batch:  13
Training----
iter: 100 loss：0.15218301117420197
iter: 200 loss：0.13346043229103088
Train Done----
Testing----
[[ 42.60375   -82.52842     9.3       229.3      ]
 [ 43.14806   -82.40963    12.3       178.4      ]
 [ 42.54551   -82.60404     7.8       245.19998  ]
 [ 42.77927   -82.47045    10.3       168.6      ]
 [ 43.56327   -82.42739    12.7       175.2      ]
 [ 42.590633  -82.54867    10.2       220.1      ]
 [ 42.36541   -82.90006    12.4       238.4      ]
 [ 42.65089   -82.51123    10.2       178.3      ]
 [ 42.90668   -82.46353     9.9       198.9      ]
 [ 43.42278   -82.41135    12.8       180.60002  ]
 [ 42.53134   -82.64387     7.4999995 244.40001  ]
 [ 42.73071   -82.48359    10.        191.6      ]
 [ 43.22289   -82.4115     12.5       178.3      ]
 [ 42.7484    -82.47392     9.3       205.7      ]
 [ 42.887974  -82.47301    10.        187.       ]
 [ 42.12261   -83.12507    11.        186.1      ]
 [ 42.28986   -83.09624     9.        206.20



Train Done----
Testing----
[[ 41.695248 -82.17421   12.5      116.      ]
 [ 41.67997  -82.13211   12.6      115.899994]
 [ 41.60519  -81.92786   12.5      116.80001 ]
 [ 41.78225  -82.39361   12.3      118.4     ]
 [ 41.53803  -81.75935    9.5      124.2     ]
 [ 41.72156  -82.24176   12.5      117.9     ]
 [ 41.67311  -82.11312   12.6      115.6     ]
 [ 41.75     -82.31341   12.4      117.99999 ]
 [ 41.53202  -81.75118    8.2      136.4     ]
 [ 41.55105  -81.78883   11.       120.3     ]
 [ 41.79934  -82.43619   12.       118.4     ]
 [ 41.812416 -82.4682    11.9      120.2     ]
 [ 41.67466  -82.11741   12.6      115.700005]
 [ 41.83833  -82.5296    11.6      118.6     ]
 [ 41.59868  -81.9106    12.5      116.7     ]
 [ 41.84816  -82.55101   11.3      122.7     ]
 [ 41.7481   -82.30866   12.4      118.19999 ]
 [ 41.79004  -82.41298   12.1      118.19999 ]
 [ 41.53965  -81.76288    9.6      120.3     ]
 [ 41.66779  -82.09842   12.6      115.700005]]
[[ 41.861656  -81.76531    12.71



iter: 100 loss：0.10935482382774353
iter: 200 loss：0.13178759813308716
Train Done----
Testing----
[[ 28.58976  -89.28275   11.9      154.7     ]
 [ 28.73512  -89.37354   11.7      151.2     ]
 [ 27.97314  -88.94692   11.7      156.8     ]
 [ 29.92308  -89.94      12.4       90.1     ]
 [ 28.75403  -89.38562   11.6      151.2     ]
 [ 29.52364  -89.72772   14.599999 113.00001 ]
 [ 29.45203  -89.60549   14.2      118.9     ]
 [ 28.56129  -89.26731   11.8      154.2     ]
 [ 28.15962  -89.04172   11.9      154.4     ]
 [ 28.02899  -88.97406   11.9      156.6     ]
 [ 27.813879 -88.85687   11.4      151.      ]
 [ 29.49386  -89.69915   14.4      150.2     ]
 [ 29.51902  -89.71868   13.8      133.6     ]
 [ 29.88106  -89.90248   13.299999 174.6     ]
 [ 29.50146  -89.70415   14.1      149.7     ]
 [ 28.760202 -89.38958   11.6      150.9     ]
 [ 29.71844  -89.98578   14.1      154.4     ]
 [ 29.35767  -89.46261   12.7       35.1     ]
 [ 28.21255  -89.06989   11.9      154.4     ]
 [ 29.6165



iter: 100 loss：0.020805612206459045
Train Done----
Testing----
[[ 24.32393  -81.37655   14.799999  75.      ]
 [ 24.23315  -81.74052   15.8       74.      ]
 [ 24.30447  -81.45367   15.3       74.      ]
 [ 24.4306   -80.96743   15.399999  75.      ]
 [ 24.45997  -80.84758   15.399999  74.      ]
 [ 24.54978  -80.51525   17.3       67.      ]
 [ 24.3639   -81.22312   14.599999  72.      ]
 [ 24.427    -80.98248   15.399999  74.      ]
 [ 24.25998  -81.63373   14.999999  74.      ]
 [ 24.37237  -81.19433   14.799999  72.      ]
 [ 24.31513  -81.41143   15.1       74.      ]
 [ 24.6063   -80.35565   18.2       42.      ]
 [ 24.43813  -80.93657   15.399999  74.      ]
 [ 24.24273  -81.7022    15.5       74.      ]
 [ 24.33803  -81.31792   14.599999  75.      ]
 [ 24.39308  -81.11963   14.999999  74.      ]
 [ 24.37885  -81.17208   14.799999  72.      ]
 [ 24.47903  -80.77018   15.1       74.      ]
 [ 24.29003  -81.51212   15.199999  75.      ]
 [ 24.48563  -80.74345   15.5       74.     



iter: 100 loss：0.024181270971894264
iter: 200 loss：0.030167508870363235
Train Done----
Testing----
[[ 25.48183  -79.90905   18.3        6.      ]
 [ 27.888678 -79.66673   18.5        6.      ]
 [ 27.81128  -79.67533   18.7        5.      ]
 [ 26.05873  -79.82628   17.9        9.      ]
 [ 25.95912  -79.84472   17.9        9.      ]
 [ 26.10793  -79.81952   17.9        1.      ]
 [ 26.63588  -79.78262   18.9        8.      ]
 [ 26.494051 -79.80232   18.7        5.      ]
 [ 26.64775  -79.78065   19.         8.      ]
 [ 25.70713  -79.8843    18.1        5.      ]
 [ 25.643778 -79.89028   17.9        5.      ]
 [ 28.01728  -79.65175   18.4        5.      ]
 [ 26.208551 -79.81653   18.2        1.      ]
 [ 27.63875  -79.69105   19.         4.      ]
 [ 26.49948  -79.80177   18.6        5.      ]
 [ 26.37517  -79.81007   18.4        2.      ]
 [ 27.4115   -79.71318   19.4        5.      ]
 [ 27.82272  -79.6741    18.7        5.      ]
 [ 27.9383   -79.66095   18.5        6.      ]
 [ 26.63



Train Done----
Testing----
[[ 28.73538  -79.6593     2.8      346.      ]
 [ 28.78872  -79.67562    2.7      345.      ]
 [ 28.707668 -79.65263    2.6      349.      ]
 [ 28.67587  -79.64765    2.8      354.      ]
 [ 28.79198  -79.67647    3.       347.      ]
 [ 28.5753   -79.65768    3.2        1.      ]
 [ 28.73463  -79.6591     2.8      346.      ]
 [ 28.77733  -79.67152    3.       345.      ]
 [ 28.76277  -79.66698    2.9      344.      ]
 [ 28.74108  -79.66082    2.6      346.      ]
 [ 28.67487  -79.64755    2.7      355.      ]
 [ 28.75802  -79.6656     3.       347.      ]
 [ 28.74565  -79.66203    2.9      346.      ]
 [ 28.73018  -79.65788    2.7      345.      ]
 [ 28.77097  -79.66953    3.1      345.      ]
 [ 28.731968 -79.65838    2.8      346.      ]
 [ 28.70505  -79.65208    2.7      348.      ]
 [ 28.69665  -79.65053    2.9      351.      ]
 [ 28.68127  -79.64818    2.9      356.      ]
 [ 28.5732   -79.65772    3.2      358.      ]]
[[ 28.599     -80.09809     2.61



iter: 100 loss：0.08209829032421112
iter: 200 loss：0.053070373833179474
iter: 300 loss：0.03042154386639595
iter: 400 loss：0.017518293112516403
Train Done----
Testing----
[[ 43.83012 -87.62647   8.4     188.8    ]
 [ 41.83387 -87.4802    8.4     167.7    ]
 [ 42.7181  -87.6981    8.6     197.1    ]
 [ 41.76577 -87.45985   8.3     163.5    ]
 [ 43.04578 -87.83457   8.4     179.2    ]
 [ 42.3692  -87.73287   8.5     177.5    ]
 [ 42.58627 -87.74897   8.7     196.9    ]
 [ 43.32833 -87.83285   8.      177.6    ]
 [ 42.37703 -87.73362   8.6     174.     ]
 [ 41.88592 -87.49763   8.3     167.7    ]
 [ 43.09595 -87.83428   8.4     177.5    ]
 [ 41.88855 -87.49838   8.3     168.6    ]
 [ 43.92475 -87.60192   8.3     189.00002]
 [ 43.9099  -87.60568   8.4     190.6    ]
 [ 43.5062  -87.74257   8.4     189.7    ]
 [ 43.8187  -87.62935   8.3     190.10002]
 [ 43.96405 -87.59228   8.4     188.60002]
 [ 41.79707 -87.46903   8.4     167.9    ]
 [ 43.7668  -87.6428    8.3     190.8    ]
 [ 43.34978 -8