Skip to content

Commit

Permalink
Merge d06a051 into acbd91e
Browse files Browse the repository at this point in the history
  • Loading branch information
trevorb1 committed Oct 22, 2023
2 parents acbd91e + d06a051 commit 1f33fd4
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 10 deletions.
19 changes: 14 additions & 5 deletions src/otoole/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ def convert(self, input_filepath: str, output_filepath: str, **kwargs: Dict):

class Strategy(ABC):
"""
Arguments
---------
user_config : dict, default=None
Expand All @@ -139,10 +138,20 @@ def _add_dtypes(self, config: Dict):
dtypes = {}
for column in details["indices"] + ["VALUE"]:
if column == "VALUE":
dtypes["VALUE"] = details["dtype"]
dtypes["VALUE"] = (
details["dtype"] if details["dtype"] != "int" else "int64"
)
else:
dtypes[column] = config[column]["dtype"]
dtypes[column] = (
config[column]["dtype"]
if config[column]["dtype"] != "int"
else "int64"
)
details["index_dtypes"] = dtypes
elif details["type"] == "set":
details["dtype"] = (
details["dtype"] if details["dtype"] != "int" else "int64"
)
return config

@property
Expand Down Expand Up @@ -491,8 +500,8 @@ def _check_index_dtypes(
except ValueError: # ValueError: invalid literal for int() with base 10:
df = df.dropna(axis=0, how="all").reset_index()
for index, dtype in config["index_dtypes"].items():
if dtype == "int":
df[index] = df[index].astype(float).astype(int)
if dtype == "int64":
df[index] = df[index].astype(float).astype("int64")
else:
df[index] = df[index].astype(dtype)
df = df.set_index(config["indices"])
Expand Down
2 changes: 1 addition & 1 deletion src/otoole/results/result_package.py
Original file line number Diff line number Diff line change
Expand Up @@ -858,7 +858,7 @@ def discount_factor(
if regions and years:
discount_rate["YEAR"] = [years]
discount_factor = discount_rate.explode("YEAR").reset_index(level="REGION")
discount_factor["YEAR"] = discount_factor["YEAR"].astype(int)
discount_factor["YEAR"] = discount_factor["YEAR"].astype("int64")
discount_factor["NUM"] = discount_factor["YEAR"] - discount_factor["YEAR"].min()
discount_factor["RATE"] = discount_factor["VALUE"] + 1
discount_factor["VALUE"] = (
Expand Down
14 changes: 12 additions & 2 deletions src/otoole/write_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,12 @@ def _write_parameter(
df = self._form_parameter(df, default)
handle.write("param default {} : {} :=\n".format(default, parameter_name))
df.to_csv(
path_or_buf=handle, sep=" ", header=False, index=True, float_format="%g"
path_or_buf=handle,
sep=" ",
header=False,
index=True,
float_format="%g",
lineterminator="\n",
)
handle.write(";\n")

Expand All @@ -171,7 +176,12 @@ def _write_set(self, df: pd.DataFrame, set_name, handle: TextIO):
"""
handle.write("set {} :=\n".format(set_name))
df.to_csv(
path_or_buf=handle, sep=" ", header=False, index=False, float_format="%g"
path_or_buf=handle,
sep=" ",
header=False,
index=False,
float_format="%g",
lineterminator="\n",
)
handle.write(";\n")

Expand Down
4 changes: 2 additions & 2 deletions tests/test_read_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,7 +700,7 @@ def test_index_dtypes_available(self, user_config):
assert actual == {
"REGION": "str",
"FUEL": "str",
"YEAR": "int",
"YEAR": "int64",
"VALUE": "float",
}

Expand Down Expand Up @@ -834,7 +834,7 @@ def test_read_config(self, user_config):
"FUEL": "str",
"REGION": "str",
"VALUE": "float",
"YEAR": "int",
"YEAR": "int64",
},
}
assert actual["AccumulatedAnnualDemand"] == expected
Expand Down

0 comments on commit 1f33fd4

Please sign in to comment.