Skip to content

Commit

Permalink
[pre-commit.ci] Add auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jun 26, 2024
1 parent 2c1c48f commit e961e2f
Show file tree
Hide file tree
Showing 9 changed files with 488 additions and 184 deletions.
6 changes: 4 additions & 2 deletions atomgen/models/configuration_atomformer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from transformers.configuration_utils import PretrainedConfig
from typing import Any

from transformers.configuration_utils import PretrainedConfig


class AtomformerConfig(PretrainedConfig): # type: ignore
r"""
Configuration of a :class:`~transformers.AtomformerModel`.
Expand Down Expand Up @@ -39,4 +41,4 @@ def __init__(
self.pad_token_id = pad_token_id
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
self.cls_token_id = cls_token_id
self.cls_token_id = cls_token_id
3 changes: 2 additions & 1 deletion atomgen/models/modeling_atomformer.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
"""Implementation of the Atomformer model."""

from typing import Any, Optional, Tuple
from typing import Optional, Tuple

import torch
import torch.nn.functional as f
from torch import nn
from transformers.modeling_utils import PreTrainedModel

from .configuration_atomformer import AtomformerConfig


Expand Down
28 changes: 21 additions & 7 deletions scratch.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
"metadata": {},
"outputs": [],
"source": [
"from atomgen.models.atomformer import AtomformerConfig\n",
"from atomgen.models.atomformer import Structure2EnergyAndForces\n",
"from atomgen.models.atomformer import AtomformerModel"
"from atomgen.models.atomformer import (\n",
" AtomformerConfig,\n",
" AtomformerModel,\n",
" Structure2EnergyAndForces,\n",
")"
]
},
{
Expand All @@ -27,7 +29,9 @@
"metadata": {},
"outputs": [],
"source": [
"model = Structure2EnergyAndForces.from_pretrained(\"/mnt/data2/s2ef_all_10epochs_weights/checkpoint-292950\")"
"model = Structure2EnergyAndForces.from_pretrained(\n",
" \"/mnt/data2/s2ef_all_10epochs_weights/checkpoint-292950\"\n",
")"
]
},
{
Expand Down Expand Up @@ -131,7 +135,10 @@
"source": [
"from transformers import AutoModel\n",
"\n",
"model = AutoModel.from_pretrained(\"vector-institute/atomformer-base\", trust_remote_code=True)"
"\n",
"model = AutoModel.from_pretrained(\n",
" \"vector-institute/atomformer-base\", trust_remote_code=True\n",
")"
]
},
{
Expand All @@ -142,11 +149,16 @@
"source": [
"import torch\n",
"\n",
"\n",
"input_ids = torch.randint(0, 100, (1, 10))\n",
"coords = torch.randn(1, 10, 3)\n",
"attention_mask = torch.ones(1, 10)\n",
"\n",
"input_ids, coords, attn_mask = torch.randint(0, 100, (1, 10)), torch.randn(1, 10, 3), torch.ones(1, 10)\n",
"input_ids, coords, attn_mask = (\n",
" torch.randint(0, 100, (1, 10)),\n",
" torch.randn(1, 10, 3),\n",
" torch.ones(1, 10),\n",
")\n",
"\n",
"output = model(input_ids, coords=coords, attention_mask=attention_mask)"
]
Expand Down Expand Up @@ -208,7 +220,9 @@
"outputs": [],
"source": [
"config = AtomformerConfig.from_json_file(\"atomgen/models/configs/atomformer-base.json\")\n",
"model = AtomformerModel.from_pretrained(\"/mnt/data2/s2ef_all_10epochs_weights/checkpoint-292950\", config=config)"
"model = AtomformerModel.from_pretrained(\n",
" \"/mnt/data2/s2ef_all_10epochs_weights/checkpoint-292950\", config=config\n",
")"
]
},
{
Expand Down
114 changes: 81 additions & 33 deletions scripts/data/mptrj_s2ef_train_convert.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,47 @@
import argparse
import numpy as np
import json
from pymatgen.core import Structure

import numpy as np
from datasets import Dataset
from pymatgen.core import Structure
from tqdm import tqdm
from json import JSONDecodeError


def process_json_chunk(chunk):
    """Turn one JSON-encoded record into a dict of NumPy arrays.

    Parameters
    ----------
    chunk : str
        JSON text of a single record. The keys ``structure``, ``force``,
        ``ef_per_atom`` and ``corrected_total_energy`` are read from it.

    Returns
    -------
    dict
        ``input_ids`` (int16 atomic numbers), ``coords`` (float32 Cartesian
        coordinates), ``forces`` (float32), ``formation_energy`` (float32,
        ``ef_per_atom`` multiplied by the number of atoms), ``total_energy``
        (float32) and ``has_formation_energy`` (always ``True``).
    """
    record = json.loads(chunk)
    # Rebuild the pymatgen structure from its serialized dict form.
    structure = Structure.from_dict(record["structure"])

    atomic_numbers = np.array(list(structure.atomic_numbers)).astype("int16")
    coords = structure.cart_coords.astype("float32")
    forces = np.array(record["force"]).astype("float32")

    # The record stores a per-atom value; scale it by the atom count.
    n_atoms = len(atomic_numbers)
    formation_energy = np.array(record["ef_per_atom"] * n_atoms).astype("float32")
    total_energy = np.array(record["corrected_total_energy"]).astype("float32")

    return {
        "input_ids": atomic_numbers,
        "coords": coords,
        "forces": forces,
        "formation_energy": formation_energy,
        "total_energy": total_energy,
        "has_formation_energy": True,
    }


def main(args):
with open(args.input_file, 'r') as file:
with open(args.input_file, "r") as file:
output = ""
dataset = {"input_ids": [], "coords": [], "forces": [], "formation_energy": [], "total_energy": [], "has_formation_energy": []}
dataset = {
"input_ids": [],
"coords": [],
"forces": [],
"formation_energy": [],
"total_energy": [],
"has_formation_energy": [],
}
num_datasets = 0
num_samples = 0
pbar = tqdm(total=args.total_samples)
read = True

while True:
try:
if read:
Expand All @@ -37,24 +50,31 @@ def main(args):
start = output.find('{"structure"')
end = output.find('"mp_id"')
if start != -1 and end != -1:
end = output.find('}', end)
end = output.find("}", end)
if end != -1:
end += 1
num_samples += 1
pbar.update(1)

chunk_data = process_json_chunk(output[start:end])
for key, value in chunk_data.items():
dataset[key].append(value)

output = output[end:]

if num_samples == args.samples_per_dataset:
dataset = Dataset.from_dict(dataset)
dataset.save_to_disk(f'{args.output_dir}/{num_datasets}')
dataset.save_to_disk(f"{args.output_dir}/{num_datasets}")
num_datasets += 1
num_samples = 0
dataset = {"input_ids": [], "coords": [], "forces": [], "formation_energy": [], "total_energy": [], "has_formation_energy": []}
dataset = {
"input_ids": [],
"coords": [],
"forces": [],
"formation_energy": [],
"total_energy": [],
"has_formation_energy": [],
}
pbar.close()
pbar = tqdm(total=args.total_samples)
else:
Expand All @@ -66,26 +86,54 @@ def main(args):
except Exception as e:
print(f"An error occurred: {e}")
break

pbar.close()

# Save any remaining data
if num_samples > 0:
dataset = Dataset.from_dict(dataset)
dataset.save_to_disk(f'{args.output_dir}/{num_datasets}')
dataset.save_to_disk(f"{args.output_dir}/{num_datasets}")

# Concatenate all datasets
all_datasets = [Dataset.load_from_disk(f'{args.output_dir}/{i}') for i in range(num_datasets + 1)]
all_datasets = [
Dataset.load_from_disk(f"{args.output_dir}/{i}")
for i in range(num_datasets + 1)
]
final_dataset = Dataset.concatenate_datasets(all_datasets)
final_dataset.save_to_disk(args.final_output)


if __name__ == "__main__":
    # Command-line entry point: declare the options, parse them, and hand
    # the resulting namespace to main().
    parser = argparse.ArgumentParser(
        description="Preprocess MPtrj dataset into a HuggingFace dataset."
    )

    # Required string-valued path options, registered in a fixed order.
    required_paths = (
        ("--input_file", "Path to the input JSON file"),
        ("--output_dir", "Directory to save processed datasets"),
        ("--final_output", "Path to save the final concatenated dataset"),
    )
    for flag, help_text in required_paths:
        parser.add_argument(flag, type=str, required=True, help=help_text)

    # Optional integer options with their default values.
    int_options = (
        ("--total_samples", 1580395, "Total number of samples in the dataset"),
        ("--samples_per_dataset", 1580394, "Number of samples per dataset chunk"),
    )
    for flag, default_value, help_text in int_options:
        parser.add_argument(flag, type=int, default=default_value, help=help_text)

    args = parser.parse_args()
    main(args)
Loading

0 comments on commit e961e2f

Please sign in to comment.