In [1]:
import os
import re
import time

import jax
import torch.utils.data
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

import ml_collections

from molnet.data import input_pipeline
from configs.tests import attention_test
from configs import root_dirs

from typing import Any, Dict, Tuple, Sequence, List, Union

In [2]:
# Build the experiment configuration from the attention-test config module.
config = attention_test.get_config()
# Set the dataset root; get_dataloader() lists this directory for "maps_*" entries.
config.root_dir = root_dirs.get_root_dir()

In [3]:
def _preprocess_images(
    batch: Dict[str, tf.Tensor],
    noise_std: float = 0.0,
    seed: int = 0
) -> Dict[str, tf.Tensor]:
    """Normalizes the images, optionally adds noise, and reorders the atom map.

    Args:
        batch: Example dict containing at least "images" and "atom_map".
            Images are treated as [X, Y, Z]; the atom map must be rank 4 and is
            treated as [num_species, X, Y, Z] (its first axis is moved last).
        noise_std: Standard deviation of the Gaussian noise added to the
            normalized images. No noise is added when this is not positive.
        seed: Operation-level seed passed to `tf.random.normal`.

    Returns:
        The same `batch` dict (updated in place) with
        "images" as float32 [X, Y, Z, 1] and
        "atom_map" as float32 [X, Y, Z, num_species].
    """
    x = tf.cast(batch["images"], tf.float32)
    y = tf.cast(batch["atom_map"], tf.float32)

    # Normalize each z slice separately to zero mean and unit variance:
    # statistics are reduced over the first two (X, Y) axes only.
    xmean = tf.reduce_mean(x, axis=(0, 1), keepdims=True)
    xstd = tf.math.reduce_std(x, axis=(0, 1), keepdims=True)

    # A slice with zero standard deviation (e.g. a constant slice) would give
    # NaN/Inf with a plain division; divide_no_nan maps those entries to 0 and
    # is identical to `/` everywhere else.
    x = tf.math.divide_no_nan(x - xmean, xstd)

    # NOTE: interpolation to a fixed number of z slices is not done here.

    # Add noise to the (already normalized) images.
    if noise_std > 0.0:
        x = x + tf.random.normal(tf.shape(x), stddev=noise_std, seed=seed)

    # Add a trailing channel dimension.
    x = x[..., tf.newaxis]

    # Move the species axis from first to last.
    y = tf.transpose(y, perm=[1, 2, 3, 0])

    batch["images"] = x  # [X, Y, Z, 1]
    batch["atom_map"] = y  # [X, Y, Z, num_species]

    return batch



In [12]:
def get_dataloader():
    """Builds an endless iterator of preprocessed NumPy batches.

    Reads the module-level `config` (root_dir, shuffle_datasets, rng_seed,
    max_atoms, noise_std, batch_size). Every entry of `config.root_dir` whose
    name starts with "maps_" is loaded with `tf.data.Dataset.load` and the
    resulting datasets are interleaved.

    Returns:
        A NumPy iterator over dicts with keys "images", "xyz" and "atom_map".
        The underlying dataset is repeated indefinitely, so the iterator
        never raises StopIteration.

    Raises:
        FileNotFoundError: If `config.root_dir` contains no "maps_*" entries.
    """
    filenames = [
        os.path.join(config.root_dir, name)
        for name in sorted(os.listdir(config.root_dir))
        if name.startswith("maps_")
    ]
    if not filenames:
        # Fail with an actionable message instead of an IndexError below.
        raise FileNotFoundError(
            f"No datasets starting with 'maps_' found in {config.root_dir!r}."
        )

    # All saved datasets are loaded with the element spec of the first one.
    element_spec = tf.data.Dataset.load(filenames[0]).element_spec

    ds = tf.data.Dataset.from_tensor_slices(filenames)
    ds = ds.interleave(
        lambda path: tf.data.Dataset.load(path, element_spec=element_spec),
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=True,
    )

    # Shuffle the dataset.
    if config.shuffle_datasets:
        ds = ds.shuffle(1000, seed=config.rng_seed)

    # Repeat the dataset indefinitely.
    ds = ds.repeat()

    # Elements are dicts {'images': image, 'xyz': xyz, 'atom_map': atom_map}.
    # Zero-pad xyz along its first axis from [num_atoms, 5] to [max_atoms, 5]
    # so that elements can be stacked into a batch.
    # NOTE(review): an element with more than config.max_atoms atoms yields a
    # negative padding amount, which tf.pad rejects - confirm max_atoms is an
    # upper bound for the data.
    ds = ds.map(
        lambda x: {
            "images": x["images"],
            "xyz": tf.pad(x["xyz"], [[0, config.max_atoms - tf.shape(x["xyz"])[0]], [0, 0]]),
            "atom_map": x["atom_map"],
        },
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=True,
    )

    # Preprocess images.
    ds = ds.map(
        lambda x: _preprocess_images(x, config.noise_std, seed=config.rng_seed),
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=True,
    )

    # Batch, prefetch, and expose the result as an iterator of NumPy arrays.
    ds = ds.batch(config.batch_size)
    return ds.prefetch(tf.data.AUTOTUNE).as_numpy_iterator()

In [19]:
# Benchmark: time spent fetching each of 100 batches from the tf.data pipeline.
# NOTE: the mean includes the first fetch, which also pays for pipeline
# start-up, so it overstates the steady-state cost.
loader = get_dataloader()
tf_times = []
for _ in range(100):
    # perf_counter() is a monotonic, high-resolution clock meant for measuring
    # short intervals; time.time() can jump if the system clock is adjusted.
    t0 = time.perf_counter()
    batch = next(loader)
    t1 = time.perf_counter()
    tf_times.append(t1 - t0)

print(f"TF mean time: {np.mean(tf_times)*1000:.2f} ms")

[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text-format tensorflow.data.experimental.DistributedSnapshotMetadata: 1:1: Invalid control characters encountered in text.
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text-format tensorflow.data.experimental.DistributedSnapshotMetadata: 1:3: Expected identifier, got: 336318792119298919
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text-format tensorflow.data.experimental.DistributedSnapshotMetadata: 1:1: Invalid control characters encountered in text.
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text-format tensorflow.data.experimental.DistributedSnapshotMetadata: 1:3: Expected identifier, got: 336318792119298919
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text

TF mean time: 11.61 ms


In [20]:
# Benchmark: same as the TF-only timing, but each iteration also converts the
# NumPy "images" array into a torch tensor (torch.tensor copies the data).
# NOTE: the mean includes the first fetch, which also pays for pipeline
# start-up, so it overstates the steady-state cost.
loader = get_dataloader()

torch_times = []
for _ in range(100):
    # perf_counter() is a monotonic, high-resolution clock meant for measuring
    # short intervals; time.time() can jump if the system clock is adjusted.
    t0 = time.perf_counter()
    batch = next(loader)
    images = torch.tensor(batch["images"])
    t1 = time.perf_counter()
    torch_times.append(t1 - t0)

print(f"Torch mean time: {np.mean(torch_times)*1000:.2f} ms")

[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text-format tensorflow.data.experimental.DistributedSnapshotMetadata: 1:1: Invalid control characters encountered in text.
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text-format tensorflow.data.experimental.DistributedSnapshotMetadata: 1:3: Expected identifier, got: 336318792119298919
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text-format tensorflow.data.experimental.DistributedSnapshotMetadata: 1:1: Invalid control characters encountered in text.
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text-format tensorflow.data.experimental.DistributedSnapshotMetadata: 1:3: Expected identifier, got: 336318792119298919
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text

Torch mean time: 12.84 ms
