-
Notifications
You must be signed in to change notification settings - Fork 43
/
inference.py
executable file
·114 lines (94 loc) · 3 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import sys
sys.path.append(".") # dataset.read
import itertools
import json
import os
import shutil
import zipfile
from dataset.read_data import prepare_data
import joblib
import numpy as np
import pandas as pd
import tensorflow as tf
def transform2train(sample, start_month=1, only_roi=False):
    """Flatten a gridded 4-D sample into a long-format DataFrame.

    Each row is one (month, lat, lon) cell, in the row-major order of the
    first three axes of ``sample``. The first four channels of the last axis
    become the ``sst``, ``t300``, ``ua`` and ``va`` columns.

    Args:
        sample: Array indexed as (month, lat, lon, channel).
            - Full grid: the first three axes must hold 12 x 24 x 72 cells
              (lat -55..60 and lon 0..355, both in 5-degree steps).
            - ``only_roi=True``: the array is first cropped to lat indices
              10:13 and lon indices 38:49. On that grid these correspond to
              lat -5..5 and lon 190..240.
            - In both cases at least 4 channels are required.
        start_month: Month label attached to the first/last time step.
        only_roi: Crop to the region of interest before flattening.

    Returns:
        pandas.DataFrame with columns month, lat, lon, year, sst, t300, ua,
        va. ``year`` is the constant 1.
    """
    if only_roi:
        sample = sample[:, 10:13, 38:49, :]  # crop to the ROI
        # Same long layout as SODA_train_roi, (12, 3, 11, 4).
        # NOTE(review): unlike the full-grid branch, month labels here run
        # from start_month + 11 down to start_month (descending). Presumably
        # this mirrors how SODA_train_roi was built -- confirm.
        months = range(start_month + 11, start_month - 1, -1)
        lats = [-5, 0, 5]
        lons = [190.0 + 5.0 * step for step in range(11)]  # 190.0 .. 240.0
    else:
        months = range(start_month, start_month + 12)
        lats = range(-55, 65, 5)
        lons = range(0, 360, 5)

    frame = pd.DataFrame(
        list(itertools.product(months, lats, lons)),
        columns=["month", "lat", "lon"],
    )
    frame["year"] = 1
    for channel, name in enumerate(("sst", "t300", "ua", "va")):
        frame[name] = sample[..., channel].reshape(-1)
    return frame
def predict_single(data_dir, file, model):
    """Run ``model`` on one ``.npy`` test sample and return the flat forecast.

    The start month is parsed from the third ``_``-separated field of
    ``file`` (e.g. ``test_0144_01_12.npy`` -> 1).

    Args:
        data_dir: Directory containing ``file``.
        file: File name of the sample inside ``data_dir``.
        model: Callable model, e.g. the object returned by
            ``tf.saved_model.load``. Its output must expose ``.numpy()``.

    Returns:
        1-D numpy array: the model output flattened, plus a constant 0.01.
        NOTE(review): the purpose of the +0.01 offset is not stated here --
        confirm.
    """
    raw = np.load(os.path.join(data_dir, file))
    start_month = int(file.split("_")[2])
    # NOTE(review): an out-of-range month is only reported, not rejected;
    # execution continues with the bad value.
    if not 1 <= start_month <= 12:
        print("month Error")

    frame = transform2train(raw, start_month=start_month)
    sst, t300, ua, va, month = prepare_data(frame)
    month = month - 1  # align with the month encoding used for training features

    # Add a leading batch axis of size 1 and cast to float32 for the model.
    batch = tuple(
        feature[np.newaxis, ...].astype(np.float32)
        for feature in (sst, t300, ua, va, month)
    )
    prediction = model(batch)
    return prediction.numpy().reshape(-1) + 0.01
def predict(
    data_dir="../tcdata/enso_round1_test_20210201",
    model_dir="../user_data/fine",
    result_dir="../result",
):
    """Predict every ``.npy`` sample in ``data_dir`` and save one result per file.

    ``result_dir`` is deleted (best effort) and recreated first, so any
    previous contents are discarded. Each forecast is saved with
    ``np.save`` under the same file name as its input sample.
    For submission, ``data_dir`` was '../tcdata/enso_round1_test_20210201'.

    Args:
        data_dir: Directory of input ``.npy`` samples. Non-``.npy`` entries
            (stray files, sub-directories) are skipped instead of crashing
            ``np.load``.
        model_dir: Path passed to ``tf.saved_model.load``.
        result_dir: Output directory (defaults to the previously hard-coded
            ``../result``).
    """
    if os.path.exists(result_dir):
        shutil.rmtree(result_dir, ignore_errors=True)
    # Without exist_ok: if the rmtree above left the directory behind, this
    # raises instead of silently mixing in stale results.
    os.makedirs(result_dir)

    model = tf.saved_model.load(model_dir)
    # Sorted for a deterministic processing order; only sample files are used.
    for file in sorted(os.listdir(data_dir)):
        if not file.endswith(".npy"):
            continue
        res = predict_single(data_dir, file, model)
        np.save(os.path.join(result_dir, file), res)
def compress(res_dir="../result", output_dir="result.zip"):
    """Zip every entry directly inside ``res_dir`` into the archive ``output_dir``.

    Args:
        res_dir: Directory whose immediate entries are archived (not
            recursive).
        output_dir: Path of the zip file to create/overwrite. Despite the
            name, this is a file path.
    """
    # The context manager finalizes and closes the archive even if a write
    # raises. Previously the handle leaked and a truncated zip could remain.
    with zipfile.ZipFile(output_dir, "w") as archive:
        for name in sorted(os.listdir(res_dir)):
            # The default arcname is the path as given, minus drive letter and
            # leading separators. With the default res_dir, entries are
            # therefore stored as "../result/<name>".
            # NOTE(review): confirm the consumer of the zip expects that layout.
            archive.write(os.path.join(res_dir, name))
def local_test(
    model_dir="../user_data/nn",
    data_dir="../tcdata/enso_round1_test_20210201",
    file="test_0144_01_12.npy",
):
    """Smoke-test the model on one sample, print the forecast and return it.

    The printed values match what ``predict`` would save for the same file.

    Args:
        model_dir: Path passed to ``tf.saved_model.load``.
        data_dir: Directory containing ``file``.
        file: Sample file name (its third ``_`` field is the start month).

    Returns:
        The 1-D forecast array produced by ``predict_single``.
    """
    model = tf.saved_model.load(model_dir)
    # predict_single already adds the +0.01 offset. The old extra
    # ``y = y + 0.01`` applied it twice, so the printed values no longer
    # matched the submitted ones.
    y = predict_single(data_dir=data_dir, file=file, model=model)
    print(y)
    return y
if __name__ == "__main__":
    # Default entry point: a quick single-sample sanity check.
    # For a full submission run, use instead:
    #     predict(model_dir=model_dir, data_dir="../tcdata/enso_final_test_data_B")
    #     compress()
    model_dir = "../user_data/nn"
    local_test()