Skip to content

Commit

Permalink
Merge branch 'dev' of https://github.com/JannisHoch/copro into dev
Browse files Browse the repository at this point in the history
because it is
  • Loading branch information
JannisHoch committed May 20, 2021
2 parents 9ab7e4a + afdaaf4 commit 3a671f4
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 11 deletions.
6 changes: 3 additions & 3 deletions copro/conflict.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def split_conflict_geom_data(X):

return X_ID, X_geom, X_data

def get_pred_conflict_geometry(X_test_ID, X_test_geom, y_test, y_pred, y_prob_0, y_prob_1):
def get_pred_conflict_geometry(X_test_ID, X_test_geom, y_test, y_pred, y_prob, y_prob_0, y_prob_1):
"""Stacks together the arrays with unique identifier, geometry, test data, and predicted data into a dataframe.
Contains therefore only the data points used in the test-sample, not in the training-sample.
Additionally computes whether a correct prediction was made.
Expand All @@ -308,9 +308,9 @@ def get_pred_conflict_geometry(X_test_ID, X_test_geom, y_test, y_pred, y_prob_0,
dataframe: dataframe with each input list as column plus computed 'correct_pred'.
"""

arr = np.column_stack((X_test_ID, X_test_geom, y_test, y_pred, y_prob_0, y_prob_1))
arr = np.column_stack((X_test_ID, X_test_geom, y_test, y_pred, y_prob, y_prob_0, y_prob_1))

df = pd.DataFrame(arr, columns=['ID', 'geometry', 'y_test', 'y_pred', 'y_prob_0', 'y_prob_1'])
df = pd.DataFrame(arr, columns=['ID', 'geometry', 'y_test', 'y_pred', 'y_prob', 'y_prob_0', 'y_pro b_1'])

df['correct_pred'] = np.where(df['y_test'] == df['y_pred'], 1, 0)

Expand Down
8 changes: 0 additions & 8 deletions tests/test_conflict.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,3 @@ def test_get_poly_geometry():

assert len(gdf) == len(list_geometry)

def test_get_poly_ID():

gdf = gpd.read_file(gpd.datasets.get_path('naturalearth_lowres'))

list_ID = conflict.get_poly_ID(gdf)

assert len(gdf) == len(list_ID)

0 comments on commit 3a671f4

Please sign in to comment.