In [1]:
import gzip
from itertools import batched
import multiprocessing
from pathlib import Path
from typing import Iterable

import polars as pl

In [2]:
# %%snakeviz
# import openff.pdbscan.polars.COORD_LINE_SCHEMA as schema
# import openff.pdbscan.polars.load_coords as proc_chunk

schema = {
    "id": pl.datatypes.String,
    "err": pl.datatypes.String,
    "lineNo": pl.datatypes.Int32,
    "serial": pl.datatypes.Int32,
    "name": pl.datatypes.String,
    "altLoc": pl.datatypes.String,
    "resName": pl.datatypes.String,
    "chainID": pl.datatypes.String,
    "resSeq": pl.datatypes.Int32,
    "iCode": pl.datatypes.String,
    "x": pl.datatypes.Float32,
    "y": pl.datatypes.Float32,
    "z": pl.datatypes.Float32,
    "occupancy": pl.datatypes.Float32,
    "tempFactor": pl.datatypes.Float32,
    "element": pl.datatypes.String,
    "charge": pl.datatypes.String,
    "terminated": pl.datatypes.Boolean,
    "modelNo": pl.datatypes.Int32,
    "conects": pl.datatypes.List(pl.datatypes.Int32),
}


def prepare_row(data: pl.DataFrame, path: Path):
    """
    Prepare a new row in a dataframe for a PDB file

    The dataframe is modified in place. The new row can be assigned to
    or read from as `data[colname][-1]`.
    """
    for column in data.values():
        column.append(None)
    data["id"][-1] = path.stem[3:-4]


def proc_chunk(batch: Iterable[Path]):
    data = {k: [] for k in schema}
    for path in batch:
        try:
            with gzip.open(path, "rt") as f:
                lines = f.readlines()
        except Exception as e:
            prepare_row(data, path)
            data["err"][-1] = repr(e)
            continue
        model_n = 0
        conects = {}
        for line in lines:
            if line.startswith("CONECT "):
                a = int(line[6:11])
                bs = []
                for start, stop in [(11, 16), (16, 21), (21, 26), (26, 31)]:
                    try:
                        b = int(line[start:stop])
                    except (ValueError, IndexError):
                        continue
                    bs.append(b)
                    conects.setdefault(b, set()).add(a)
                conects.setdefault(a, set()).update(bs)
        for line_n, line in enumerate(lines):
            if line.startswith("MODEL "):
                model_n = int(line[10:14])
            if line.startswith("ENDMDL "):
                model_n += 1
            if line.startswith("HETATM") or line.startswith("ATOM  "):
                prepare_row(data, path)
                data["lineNo"][-1] = line_n
                data["serial"][-1] = int(line[6:11])
                data["name"][-1] = line[12:16].strip()
                data["altLoc"][-1] = line[16].strip() or None
                data["resName"][-1] = line[17:20].strip()
                data["chainID"][-1] = line[21].strip()
                data["resSeq"][-1] = int(line[22:26])
                data["iCode"][-1] = line[26].strip() or None
                data["x"][-1] = line[30:38].strip()
                data["y"][-1] = line[38:46].strip()
                data["z"][-1] = line[46:54].strip()
                data["occupancy"][-1] = line[54:60].strip()
                data["tempFactor"][-1] = line[60:66].strip()
                data["element"][-1] = line[76:78].strip()
                data["charge"][-1] = line[78:80].strip() or None
                data["terminated"][-1] = False
                data["modelNo"][-1] = model_n
                data["conects"][-1] = list(conects.get(data["serial"][-1], []))
            if line.startswith("TER   "):
                terminated_resname = line[17:20].strip() or data["resName"][-1]
                terminated_chainid = line[21].strip() or data["chainID"][-1]
                terminated_resseq = int(line[22:26]) or data["resSeq"][-1]
                for i in range(-1, -999, -1):
                    if (
                        data["resName"][i] == terminated_resname
                        and data["chainID"][i] == terminated_chainid
                        and data["resSeq"][i] == terminated_resseq
                    ):
                        data["terminated"][i] = True
                    else:
                        break
                else:
                    assert False, "last residue too big"
    return data

In [3]:
# %%snakeviz

chunk_size = 20
threads = 24
datadir = Path("/home/joshmitchell/Downloads/pdb")
df = pl.DataFrame(schema=schema)

for i, batch in enumerate(batched(datadir.glob("*/*.ent.gz"), n=chunk_size * threads)):
    print("batch", i)
    with multiprocessing.Pool(processes=threads) as pool:
        batch_results = list(
            pool.imap_unordered(
                proc_chunk,
                batched(batch, chunk_size),
            )
        )
    print("concatenating results of batch", i)
    df = pl.concat(
        [df, *(pl.DataFrame(d, schema=schema) for d in batch_results)]
    )
    print("rechunking after batch", i)
    df.rechunk()
    print("done with batch", i, f"({df.estimate_size('mb')} mb)")
    if df.estimated_size('gb') >= 5:
        print("writing recent batches to file")
        df.write_parquet(datadir / "coords_parquet" / f"coords_{i}.parquet")
        df = pl.DataFrame(schema=schema)
print("writing final batches to file")
df.write_parquet(datadir / "coords_parquet" / f"coords_{i}.parquet")

batch 0
concatenating results of batch 0
rechunking after batch 0
done with batch 0
batch 1
concatenating results of batch 1
rechunking after batch 1
done with batch 1
batch 2
concatenating results of batch 2
rechunking after batch 2
done with batch 2
batch 3
concatenating results of batch 3
rechunking after batch 3
done with batch 3
batch 4
concatenating results of batch 4
rechunking after batch 4
done with batch 4
batch 5
concatenating results of batch 5
rechunking after batch 5
done with batch 5
batch 6
concatenating results of batch 6
rechunking after batch 6
done with batch 6
batch 7
concatenating results of batch 7
rechunking after batch 7
done with batch 7
batch 8
concatenating results of batch 8
rechunking after batch 8
done with batch 8
batch 9
concatenating results of batch 9
rechunking after batch 9
done with batch 9
batch 10
concatenating results of batch 10
rechunking after batch 10
done with batch 10
batch 11
concatenating results of batch 11
rechunking after batch 11
don

In [4]:
print("writing final batches to file")
df.write_parquet(datadir / "coords_parquet" / f"coords_{i}.parquet")

writing final batches to file


In [12]:
import io
uncompressed_pdb_size = 0
for path in batch:
    with gzip.open(path, 'rb') as file_obj:
        uncompressed_pdb_size += file_obj.seek(0, io.SEEK_END)
uncompressed_pdb_size

419255352

In [8]:
df.estimated_size('b')

211046689

In [7]:
sum(path.stat().st_size for path in batch)

97648413

In [10]:
(datadir / "coords_parquet" / f"coords_{i}.parquet").stat().st_size

56238480