Commit

Merge branch 'mmtf' into pre-commit-ci-update-config
a-r-j committed Apr 4, 2023
2 parents 38fd549 + e8c6afb commit cb8fdaa
Showing 20 changed files with 189 additions and 182 deletions.
2 changes: 1 addition & 1 deletion .requirements/base.in
@@ -1,5 +1,5 @@
pandas<2.0.0
-biopandas>=0.4.1
+biopandas==0.5.0.dev0
biopython
bioservices>=1.10.0
deepdiff
1 change: 1 addition & 0 deletions .requirements/torch_cpu.in
@@ -0,0 +1 @@
+einops
1 change: 1 addition & 0 deletions .requirements/torch_gpu.in
@@ -0,0 +1 @@
+einops
2 changes: 1 addition & 1 deletion README.md
@@ -89,7 +89,7 @@ from graphein.protein.utils import download_alphafold_structure

config = ProteinGraphConfig()
fp = download_alphafold_structure("Q5VSL9", aligned_score=False)
-g = construct_graph(config=config, pdb_path=fp)
+g = construct_graph(config=config, path=fp)
```

### Creating a Protein Mesh
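The README change above is just the keyword rename (`pdb_path` → `path`) in `construct_graph`. As a hedged sketch of the updated call, using only the API shown in the snippet plus an illustrative local file name that is not part of the repository:

```python
# Hedged sketch of the renamed keyword, mirroring the README snippet above.
from graphein.protein.config import ProteinGraphConfig
from graphein.protein.graphs import construct_graph
from graphein.protein.utils import download_alphafold_structure

config = ProteinGraphConfig()

# AlphaFold structure, exactly as in the README diff.
fp = download_alphafold_structure("Q5VSL9", aligned_score=False)
g = construct_graph(config=config, path=fp)

# The same keyword accepts a local structure file; this file name is
# illustrative, not shipped with the repository.
g_local = construct_graph(config=config, path="3eiy.pdb")
```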
16 changes: 8 additions & 8 deletions graphein/cli.py
@@ -21,7 +21,7 @@
)
@click.option(
"-p",
"--pdb_path",
"--path",
help="Path to input pdbs",
type=click.Path(
exists=True, file_okay=True, dir_okay=True, path_type=pathlib.Path
@@ -35,20 +35,20 @@
exists=True, file_okay=False, dir_okay=True, path_type=pathlib.Path
),
)
-def main(config_path, pdb_path, output_path):
+def main(config_path, path, output_path):
"""Build the graphs and save them in output dir."""
config = parse_config(path=config_path) if config_path else None
-if pdb_path.is_file():
-pdb_paths = [pdb_path]
-elif pdb_path.is_dir():
-pdb_paths = list(pdb_path.glob("*.pdb"))
+if path.is_file():
+paths = [path]
+elif path.is_dir():
+paths = list(path.glob("*.pdb"))
else:
raise NotImplementedError(
"Given PDB path needs to point to either a pdb file or a directory with pdb files."
)

-for path in pdb_paths:
-g = construct_graph(config=config, pdb_path=str(path))
+for path in paths:
+g = construct_graph(config=config, path=str(path))

with open(str(output_path / f"{path.stem}.gpickle"), "wb") as f:
pickle.dump(g, f)
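For context, the CLI now walks a file or directory supplied via `--path` and builds one graph per structure. A minimal Python equivalent of that loop, assuming only what is visible in the diff (the `path` keyword, `*.pdb` globbing, gpickle output); directory names are placeholders:

```python
# Hedged Python equivalent of graphein/cli.py after the rename: glob a
# directory of PDB files, build a graph per file with the new ``path``
# keyword, and pickle each graph. Directory names are placeholders.
import pathlib
import pickle

from graphein.protein.graphs import construct_graph

input_dir = pathlib.Path("pdbs")     # placeholder input directory
output_dir = pathlib.Path("graphs")  # placeholder output directory
output_dir.mkdir(exist_ok=True)

for path in input_dir.glob("*.pdb"):
    # config=None falls back to the default config, as in the CLI when no
    # config file is supplied.
    g = construct_graph(config=None, path=str(path))
    with open(output_dir / f"{path.stem}.gpickle", "wb") as f:
        pickle.dump(g, f)
```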
52 changes: 26 additions & 26 deletions graphein/ml/datasets/torch_geometric_dataset.py
@@ -41,7 +41,7 @@ def __init__(
self,
root: str,
name: str,
-pdb_paths: Optional[List[str]] = None,
+paths: Optional[List[str]] = None,
pdb_codes: Optional[List[str]] = None,
uniprot_ids: Optional[List[str]] = None,
graph_label_map: Optional[Dict[str, torch.Tensor]] = None,
@@ -73,9 +73,9 @@ def __init__(
:type root: str
:param name: Name of the dataset. Will be saved to ``data_$name.pt``.
:type name: str
-:param pdb_paths: List of full path of pdb files to load. Defaults to
-``None``.
-:type pdb_paths: Optional[List[str]], optional
+:param paths: List of full path of PDB or MMTF files to load. Defaults
+to ``None``.
+:type paths: Optional[List[str]], optional
:param pdb_codes: List of PDB codes to download and parse from the PDB.
Defaults to None.
:type pdb_codes: Optional[List[str]], optional
@@ -139,8 +139,8 @@ def __init__(
else None
)

-self.pdb_paths = pdb_paths
-if self.pdb_paths is None:
+self.paths = paths
+if self.paths is None:
if self.pdb_codes and self.uniprot_ids:
self.structures = self.pdb_codes + self.uniprot_ids
elif self.pdb_codes:
@@ -150,12 +150,12 @@
# Use local saved pdb_files instead of download or move them to
# self.root/raw dir
else:
-if isinstance(self.pdb_paths, list):
+if isinstance(self.paths, list):
self.structures = [
-os.path.splitext(os.path.split(pdb_path)[-1])[0]
-for pdb_path in self.pdb_paths
+os.path.splitext(os.path.split(path)[-1])[0]
+for path in self.paths
]
-self.pdb_path, _ = os.path.split(self.pdb_paths[0])
+self.path, _ = os.path.split(self.paths[0])

if self.pdb_codes and self.uniprot_ids:
self.structures = self.pdb_codes + self.uniprot_ids
@@ -201,8 +201,8 @@ def processed_file_names(self) -> List[str]:

@property
def raw_dir(self) -> str:
-if self.pdb_paths is not None:
-return self.pdb_path # replace raw dir with user local pdb_path
+if self.paths is not None:
+return self.path # replace raw dir with user local path
else:
return os.path.join(self.root, "raw")

@@ -276,7 +276,7 @@ def process(self):
# Create graph objects
print("Constructing Graphs...")
graphs = construct_graphs_mp(
-pdb_path_it=structure_files,
+path_it=structure_files,
config=self.config,
chain_selections=chain_selections,
return_dict=True,
@@ -329,7 +329,7 @@ class ProteinGraphDataset(Dataset):
def __init__(
self,
root: str,
-pdb_paths: Optional[List[str]] = None,
+paths: Optional[List[str]] = None,
pdb_codes: Optional[List[str]] = None,
uniprot_ids: Optional[List[str]] = None,
# graph_label_map: Optional[Dict[str, int]] = None,
@@ -358,9 +358,9 @@ def __init__(
:param root: Root directory where the dataset should be saved.
:type root: str
-:param pdb_paths: List of full path of pdb files to load. Defaults to
-``None``.
-:type pdb_paths: Optional[List[str]], optional
+:param paths: List of full path of PDB or MMTF files to load. Defaults
+to ``None``.
+:type paths: Optional[List[str]], optional
:param pdb_codes: List of PDB codes to download and parse from the PDB.
Defaults to ``None``.
:type pdb_codes: Optional[List[str]], optional
@@ -422,8 +422,8 @@ def __init__(
if uniprot_ids is not None
else None
)
-self.pdb_paths = pdb_paths
-if self.pdb_paths is None:
+self.paths = paths
+if self.paths is None:
if self.pdb_codes and self.uniprot_ids:
self.structures = self.pdb_codes + self.uniprot_ids
elif self.pdb_codes:
@@ -433,12 +433,12 @@
# Use local saved pdb_files instead of download or move them to
# self.root/raw dir
else:
-if isinstance(self.pdb_paths, list):
+if isinstance(self.paths, list):
self.structures = [
-os.path.splitext(os.path.split(pdb_path)[-1])[0]
-for pdb_path in self.pdb_paths
+os.path.splitext(os.path.split(path)[-1])[0]
+for path in self.paths
]
-self.pdb_path, _ = os.path.split(self.pdb_paths[0])
+self.path, _ = os.path.split(self.paths[0])

# Labels & Chains

@@ -496,8 +496,8 @@ def processed_file_names(self) -> List[str]:

@property
def raw_dir(self) -> str:
-if self.pdb_paths is not None:
-return self.pdb_path # replace raw dir with user local pdb_path
+if self.paths is not None:
+return self.path # replace raw dir with user local path
else:
return os.path.join(self.root, "raw")

@@ -607,7 +607,7 @@ def divide_chunks(l: List[str], n: int = 2) -> Generator:
file_names = [f"{self.raw_dir}/{pdb}.pdb" for pdb in pdbs]

graphs = construct_graphs_mp(
-pdb_path_it=file_names,
+path_it=file_names,
config=self.config,
chain_selections=chain_selections,
return_dict=False,
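The dataset classes pick up the same rename: local structure files are now passed as `paths=` rather than `pdb_paths=`. A minimal usage sketch for `ProteinGraphDataset`, assuming it is importable from the module named in the diff header; the file names and PDB codes are illustrative:

```python
# Hedged sketch: point ProteinGraphDataset at local structure files via the
# renamed ``paths`` argument. File names and PDB codes are illustrative.
from graphein.ml.datasets.torch_geometric_dataset import ProteinGraphDataset

dataset = ProteinGraphDataset(
    root="data/",
    paths=["structures/3eiy.pdb", "structures/4hhb.pdb"],
)

# Alternatively, omit ``paths`` and let the dataset fetch PDB codes itself.
dataset_from_codes = ProteinGraphDataset(root="data/", pdb_codes=["3eiy", "4hhb"])
```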
8 changes: 4 additions & 4 deletions graphein/protein/features/nodes/dssp.py
@@ -85,7 +85,7 @@ def add_dssp_df(

config = G.graph["config"]
pdb_code = G.graph["pdb_code"]
-pdb_path = G.graph["pdb_path"]
+path = G.graph["path"]
pdb_name = G.graph["name"]

# Extract DSSP executable
@@ -97,9 +97,9 @@
), "DSSP must be on PATH and marked as an executable"

pdb_file = None
-if pdb_path:
-if os.path.isfile(pdb_path):
-pdb_file = pdb_path
+if path:
+if os.path.isfile(path):
+pdb_file = path
else:
if config.pdb_dir:
if os.path.isfile(config.pdb_dir / (pdb_code + ".pdb")):
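DSSP feature functions now read the structure location from `G.graph["path"]` instead of `G.graph["pdb_path"]`. A hedged sketch of how that attribute comes into play when DSSP-derived features are requested; `DSSPConfig`, `rsa`, and the `graph_metadata_functions` field are assumptions to verify against the installed graphein version, and a DSSP executable must be on `PATH` as the assertion in the diff requires:

```python
# Hedged sketch (not the library's documented recipe): request a DSSP-derived
# node feature and note that the structure file now travels in
# ``g.graph["path"]``. Assumes ``DSSPConfig``, ``rsa`` and the
# ``graph_metadata_functions`` field exist as named; requires a DSSP
# executable on PATH.
from graphein.protein.config import DSSPConfig, ProteinGraphConfig
from graphein.protein.features.nodes.dssp import rsa
from graphein.protein.graphs import construct_graph

config = ProteinGraphConfig(
    graph_metadata_functions=[rsa],  # assumed hook for DSSP features
    dssp_config=DSSPConfig(),
)

g = construct_graph(config=config, path="3eiy.pdb")  # placeholder file
print(g.graph["path"])  # location the DSSP functions will read
```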
