Merge pull request #86 from WMD-group/comp_updates
Comp updates
AntObi committed Aug 7, 2023
2 parents 0b6b2f9 + 48ae3d3 commit 60cf8cd
Showing 2 changed files with 9 additions and 3 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -10,6 +10,10 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.

 <small>[Compare with latest](https://github.com/WMD-group/ElementEmbeddings/compare/v0.3.0...HEAD)</small>

+### Added
+
+- Add arg for formula in featuriser function ([13dc313](https://github.com/WMD-group/ElementEmbeddings/commit/13dc313b40753aa267c878b86a33ba76944a5228) by Anthony Onwuli).
+
 ### Removed

 - Removed warning ([271f61e](https://github.com/WMD-group/ElementEmbeddings/commit/271f61e8653b706a6dd716bf6c0ced9396965750) by Anthony Onwuli).
8 changes: 5 additions & 3 deletions src/elementembeddings/composition.py
@@ -354,6 +354,7 @@ def _composition_distance(

 def composition_featuriser(
     data: Union[pd.DataFrame, pd.Series, CompositionalEmbedding, list],
+    formula_column: str = "formula",
     embedding: Union[Embedding, str] = "magpie",
     stats: Union[str, list] = "mean",
     inplace: bool = False,
@@ -385,13 +386,14 @@ def composition_featuriser(
     if isinstance(data, pd.DataFrame):
         if not inplace:
             data = data.copy()
-        if "formula" not in data.columns:
+        if formula_column not in data.columns:
             raise ValueError(
-                "The data must contain a column named 'formula' to featurise."
+                f"The data must contain a column named {formula_column} to featurise."
             )
         print("Featurising compositions...")
         comps = [
-            CompositionalEmbedding(x, embedding) for x in tqdm(data["formula"].tolist())
+            CompositionalEmbedding(x, embedding)
+            for x in tqdm(data[formula_column].tolist())
         ]
         print("Computing feature vectors...")
         fvs = [x.feature_vector(stats) for x in tqdm(comps)]
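
The change above replaces a hard-coded column name with a `formula_column` argument that is validated before use. As a rough illustration of that pattern only, the sketch below mirrors it on plain lists of dictionaries instead of a pandas DataFrame. `featurise_rows` and `count_characters` are hypothetical names invented for this example and are not part of ElementEmbeddings. Quoting the column name in the error message is also a choice made for this sketch.

```python
from typing import Callable, Dict, List


def count_characters(formula: str) -> int:
    """Toy stand-in for a real featuriser: returns the formula's length."""
    return len(formula)


def featurise_rows(
    rows: List[Dict[str, str]],
    formula_column: str = "formula",
    featurise: Callable[[str], int] = count_characters,
) -> List[int]:
    """Apply `featurise` to the `formula_column` entry of every row.

    Hypothetical helper: it only mirrors the validate-then-look-up
    pattern from the diff, not the library's actual behaviour.
    """
    # Fail early, naming the column the caller asked for.
    missing = [i for i, row in enumerate(rows) if formula_column not in row]
    if missing:
        raise ValueError(
            f"The data must contain a column named '{formula_column}' to featurise."
        )
    # Look the formula up under the caller-supplied key, not a fixed one.
    return [featurise(row[formula_column]) for row in rows]


rows = [{"composition": "NaCl"}, {"composition": "Fe2O3"}]
features = featurise_rows(rows, formula_column="composition")
print(features)  # [4, 5]
```

In this sketch, calling `featurise_rows(rows)` without the argument raises `ValueError`, because the default column name `"formula"` is absent from the rows.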
