In [None]:
import random
import numpy as np
import pandas as pd
import torch
from data import MyData
from torch.utils.data import DataLoader
from model import BaseLSTM,MyModel,MyLSTM
import matplotlib.pyplot as plt

# Use the first CUDA device when torch can see one; otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# RdBu colormap sampled at integer lookup indices 0, 25, ..., 250 (11 RGBA rows).
mycolors = plt.cm.RdBu(np.arange(0, 256, 256 // 10))

def setup_seed(seed):
	"""Seed every random number generator used by this notebook.

	Args:
		seed: integer seed applied, in order, to torch's CPU RNG, the RNGs of
			all CUDA devices, NumPy's legacy global RNG and Python's ``random``.
	"""
	seeders = (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed, random.seed)
	for seeder in seeders:
		seeder(seed)

In [1]:
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import cartopy.io.shapereader as shapereader

ModuleNotFoundError: No module named 'cartopy'

In [None]:
def _predict_all(model, dataloader, inverse_norm, device):
	"""Run ``model`` over every batch of ``dataloader`` and stack the de-normalised results.

	Args:
		model: torch module applied to each batch of inputs ``X``.
		dataloader: iterable yielding 4-tuples ``(X, y, _, _)``; only ``X`` and ``y`` are used.
		inverse_norm: callable that undoes the dataset normalisation; its result must
			support ``.numpy()`` (it is given CPU tensors).
		device: torch device the inputs are moved to before the forward pass.

	Returns:
		``(true_y, predict_y)`` as NumPy arrays, batches concatenated along axis 0
		in loader order.

	Raises:
		ValueError: if ``dataloader`` yields no batches.
	"""
	model.eval()
	true_batches, pred_batches = [], []
	# Inference only: no_grad skips building an autograd graph for every batch.
	with torch.no_grad():
		for X, y, _, _ in dataloader:
			pred_batches.append(inverse_norm(model(X.to(device)).cpu().detach()).numpy())
			true_batches.append(inverse_norm(y).numpy())
	if not pred_batches:
		raise ValueError('dataloader yielded no batches; nothing to evaluate')
	return np.concatenate(true_batches, axis=0), np.concatenate(pred_batches, axis=0)


def test_19(data_path='./data/TestData.json', taifeng_id=19, batch_size=4096, checkpoint='checkpoints/BaseLSTM/best.pth'):
	"""Evaluate the saved BaseLSTM checkpoint on the test samples of one typhoon.

	Every batch is collected. The previous version kept only the last batch,
	which broke as soon as the test set exceeded ``batch_size`` samples.

	Args:
		data_path: JSON file passed to ``MyData``.
		taifeng_id: typhoon ID passed to ``MyData`` as ``TaifengID``.
		batch_size: DataLoader batch size (inference only, no shuffling).
		checkpoint: path of the state dict loaded into ``BaseLSTM``.

	Returns:
		pandas.DataFrame with one row per test sample and these columns:
		``True_lat``, ``True_lon``, ``Predict_lat``, ``Predict_lon``: the first two
		target components of truth and prediction, each multiplied by 0.1.
		``SE``: the squared Euclidean distance between the true and predicted
		points in those scaled units.
	"""
	testdata = MyData(data_path=data_path, l=4, frac=1, TaifengID=taifeng_id)

	model = BaseLSTM().to(device)
	# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
	model.load_state_dict(torch.load(checkpoint, map_location=device))
	test_dataloader = DataLoader(testdata, batch_size=batch_size, shuffle=False)

	true_y, predict_y = _predict_all(model, test_dataloader, testdata.inverse_norm, device)

	GT_Predict = pd.DataFrame(
		np.full([len(testdata), 5], np.nan),
		columns=['True_lat', 'True_lon', 'Predict_lat', 'Predict_lon', 'SE'],
	)
	# NOTE(review): the 0.1 factor presumably converts stored tenths of a degree
	# to degrees. TODO confirm against MyData.
	GT_Predict.iloc[:, :2] = true_y[:, :2] * 0.1
	GT_Predict.iloc[:, 2:4] = predict_y[:, :2] * 0.1
	GT_Predict.iloc[:, 4] = np.sum(
		np.square(GT_Predict.iloc[:, :2].values - GT_Predict.iloc[:, 2:4].values), axis=1
	)
	return GT_Predict


In [None]:
# points_lab (ground truth) and points_pre (prediction) are 2-D np.array objects
# whose two columns are latitude and longitude, in that order.
# NOTE(review): neither name is defined in any cell visible here, so this cell
# raises NameError on a fresh kernel. They presumably come from test_19()'s
# True_lat/True_lon and Predict_lat/Predict_lon columns. Confirm and define them explicitly.
lat_lon_columns = ['lat', 'lon']
df_lab = pd.DataFrame(points_lab, columns=lat_lon_columns)
df_pre = pd.DataFrame(points_pre, columns=lat_lon_columns)

In [None]:
# Plot the true and predicted tracks on a PlateCarree (plain lon/lat) map.
# One projection object is reused for the axes, the extent and both line transforms.
proj = ccrs.PlateCarree()
fig, ax = plt.subplots(figsize=(7, 5), dpi=200, subplot_kw={'projection': proj})

# Base map layers, all at Natural Earth's 1:50m scale.
# Coastlines are drawn once. The original also added cfeature.COASTLINE at lw=0.25,
# which only redrew the same lines underneath these.
ax.coastlines(resolution='50m', lw=0.5)
ax.add_feature(cfeature.LAND.with_scale('50m'))  # land
ax.add_feature(cfeature.RIVERS.with_scale('50m'), lw=0.4)  # rivers
ax.add_feature(cfeature.LAKES.with_scale('50m'))  # lakes
# Country borders. The original author marked this layer "not recommended" because
# its depiction of some disputed boundaries (e.g. South Tibet, Taiwan) differs from theirs.
ax.add_feature(cfeature.BORDERS.with_scale('50m'), linestyle='-', lw=0.5)
ax.add_feature(cfeature.OCEAN.with_scale('50m'))  # ocean

# Dashed lat/lon gridlines with small tick-label fonts.
gl = ax.gridlines(draw_labels=True, linewidth=0.2, color='k', alpha=0.5, linestyle='--')
gl.xlabel_style = {'size': 6.5}
gl.ylabel_style = {'size': 6.5}

# Map window in degrees: [lon_min, lon_max, lat_min, lat_max].
extent = [95, 170, 0, 40]
ax.set_extent(extent, crs=proj)

line1, = ax.plot(df_lab["lon"], df_lab["lat"], marker='o', markersize=3, linewidth=0.5, c="r", transform=proj)
line2, = ax.plot(df_pre["lon"], df_pre["lat"], marker='o', markersize=3, linewidth=0.5, c="b", transform=proj)

ax.legend(handles=[line1, line2], labels=['true', 'pred'], loc='upper right', fontsize=6)
ax.set_title('Typhoon track: true vs. predicted', fontsize=8)
plt.show()