Skip to content

Commit

Permalink
further improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
ecodiv committed Mar 19, 2022
1 parent d29f2d2 commit c100368
Showing 1 changed file with 14 additions and 17 deletions.
31 changes: 14 additions & 17 deletions src/raster/r.vif/r.vif.py
Expand Up @@ -205,7 +205,7 @@ def read_data(raster, n, flag_s, seed):
def compute_vif(mapx, mapy):
"""Compute rsqr of linear regression between layers mapx and mapy."""
x_i = np.hstack((mapx, np.ones((mapx.shape[0], 1))))
_, resid = np.linalg.lstsq(x_i, mapy, rcond=None)[:2]
unused, resid = np.linalg.lstsq(x_i, mapy, rcond=None)[:2]
if resid.size == 0:
resid = 0
r2 = float(1 - resid / (mapy.size * mapy.var()))
Expand Down Expand Up @@ -240,16 +240,16 @@ def main(options, flags):

# Variables
input_maps = options["maps"].split(",")
retain_map = options["retain"].split(",")
if retain_map != [""]:
check_layer(retain_map)
for k, _ in enumerate(retain_map):
if retain_map[k] not in input_maps:
input_maps.extend([retain_map[k]])
retain_maps = options["retain"].split(",")
if retain_maps != [""]:
check_layer(retain_maps)
for retain_map in retain_maps:
if retain_map not in input_maps:
input_maps.extend([retain_map])
input_map_names = [i.split("@")[0] for i in input_maps]
retain_map_names = [i.split("@")[0] for i in retain_map]
retain_map_names = [i.split("@")[0] for i in retain_maps]
max_vif = options["maxvif"]
if max_vif != "":
if max_vif:
max_vif = float(max_vif)
output_file = options["file"]
number_points = options["n"]
Expand Down Expand Up @@ -330,8 +330,7 @@ def main(options, flags):

# print the header of the output table to the console
if not flag_v:
print("\n")
print("VIF round " + str(m))
print("\nVIF round " + str(m))
print("--------------------------------------")
print(
"{0[0]:{1}s} {0[1]:>8s} {0[2]:>8s}".format(
Expand Down Expand Up @@ -388,8 +387,7 @@ def main(options, flags):

# Write final selected variables to std output
if not flag_v:
print("/n")
print("selected variables are: ")
print("\nselected variables are: ")
print("--------------------------------------")
print(", ".join(input_map_names))
else:
Expand All @@ -400,15 +398,15 @@ def main(options, flags):
text_file = open(output_file, "w")
if max_vif == "":
text_file.write("variable,vif,sqrtvif\n")
for i, _ in enumerate(out_vif):
for i in range(len(out_vif)):
text_file.write(
"{0:s},{1:.6f},{2:.6f}\n".format(
out_variable[i], out_vif[i], out_sqrt[i]
)
)
else:
text_file.write("round,removed,variable,vif,sqrtvif\n")
for i, _ in enumerate(out_vif):
for i in range(len(out_vif)):
text_file.write(
"{0:d},{1:s},{2:s},{3:.6f},{4:.6f}\n".format(
out_round[i],
Expand All @@ -420,8 +418,7 @@ def main(options, flags):
)
finally:
text_file.close()
gs.message("\n")
gs.message("Statistics are written to " + output_file + "\n")
gs.message("\nStatistics are written to {}\n".format(output_file))


if __name__ == "__main__":
Expand Down

0 comments on commit c100368

Please sign in to comment.