# 02_find_best_model.py
"""
Finding the best fitting variogram model
----------------------------------------
"""
import numpy as np
import gstools as gs
from matplotlib import pyplot as plt
###############################################################################
# Generate a synthetic field with an exponential model.

# 1000 sample positions scaled to the range [0, 100); x and y come from two
# separately seeded generators, so every run produces the same points.
x, y = (
    np.random.RandomState(seed).rand(1000) * 100.0
    for seed in (19970221, 20011012)
)
# Exponential covariance model in 2D that the spatial random field is
# generated from.
model = gs.Exponential(dim=2, var=2, len_scale=8)
srf = gs.SRF(model, mean=0, seed=19970221)
# Evaluate the random field at the scattered (x, y) positions.
field = srf((x, y))
###############################################################################
# Estimate the variogram of the field on 40 unit-spaced bin edges (0 to 39)
# and plot the result.

# Grab the current axes first so that later plot calls can be handed the same
# axes object the data points are drawn on.
ax = plt.gca()
bins = np.arange(40)
bin_center, gamma = gs.vario_estimate_unstructured((x, y), field, bins)
plt.scatter(bin_center, gamma, label="data")
###############################################################################
# Define a set of models to test.
models = {
    "gaussian": gs.Gaussian,
    "exponential": gs.Exponential,
    "matern": gs.Matern,
    "stable": gs.Stable,
    "rational": gs.Rational,
    "linear": gs.Linear,
    "circular": gs.Circular,
    "spherical": gs.Spherical,
}
# Maps each model name to the r2 score of its fit.
scores = {}

###############################################################################
# Iterate over all models, fit their variogram and calculate the r2 score.
for name, model_cls in models.items():
    fit_model = model_cls(dim=2)
    # With return_r2=True the fit returns three values; only the last one
    # (the r2 score) is needed here, the other two (fit parameters and their
    # covariance, per the gstools docs) are discarded.
    _, _, r2 = fit_model.fit_variogram(bin_center, gamma, return_r2=True)
    # Draw the fitted model onto the axes that already hold the data points.
    fit_model.plot(x_max=40, ax=ax)
    scores[name] = r2
###############################################################################
# Create a ranking based on the score and determine the best models

# sorted() already returns a list of (name, score) tuples, so no extra
# comprehension is needed; reverse=True puts the highest score first.
ranking = sorted(scores.items(), key=lambda item: item[1], reverse=True)
print("RANKING")
for i, (name, score) in enumerate(ranking, 1):
    print(i, name, score)

plt.show()