Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ repos:
- id: codespell
args:
- '-w'
- '--skip="*.txt,pylintrc,.*,src/maxtext/assets/*"'
- '--skip="*.txt,pylintrc,.*,src/maxtext/assets/*,src/maxtext/input_pipeline/protos/*"'
- '-L ND,nd,sems,TE,ROUGE,rouge,astroid,ags,dout'
- '.'
additional_dependencies:
Expand All @@ -30,6 +30,7 @@ repos:
args:
- '--disable=R0401,R0917,W0201,W0613'
- "--ignore-patterns='.pytype,.*pyi$'"
- '--ignore-paths=src/maxtext/input_pipeline/protos'
- 'benchmarks'
- 'src'
- 'tests'
Expand All @@ -47,6 +48,7 @@ repos:
rev: 24.10.1
hooks:
- id: pyink
exclude: src/maxtext/input_pipeline/protos/
args:
- '--pyink-indentation=2'
- '--line-length=122'
Expand Down
40 changes: 23 additions & 17 deletions src/maxtext/input_pipeline/input_pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,17 @@

if TYPE_CHECKING:
import datasets
import tensorflow as tf

import grain.python as grain
import numpy as np
import tensorflow as tf
from maxtext.input_pipeline.protos import example_pb2
from maxtext.input_pipeline import tokenizer
from maxtext.multimodal import processor as mm_processor
from maxtext.multimodal import utils as mm_utils
from maxtext.utils import max_logging

Features = dict[str, tf.Tensor]
AUTOTUNE = tf.data.experimental.AUTOTUNE
Comment thread
aireenmei marked this conversation as resolved.
Features = dict[str, Any]
INPUT_TOKENS_KEY = "input_ids"

########## Functions used by TFDS pipeline
Expand All @@ -58,6 +58,8 @@ def shift_data_by_truncation(x):


def add_segmentation_and_position(x, data_columns, padding_token=0):
import tensorflow as tf # pylint: disable=import-outside-toplevel

for data_column in data_columns:
x[f"{data_column}_segmentation"] = tf.cast(x[data_column] != padding_token, tf.int32)
x[f"{data_column}_position"] = tf.broadcast_to(
Expand All @@ -68,6 +70,7 @@ def add_segmentation_and_position(x, data_columns, padding_token=0):

def TokenizeOp(tokenizer_model, features: Features, data_keys: Iterable[str] = ("inputs", "targets")) -> Features:
"""Op for tokenization"""
import tensorflow as tf # pylint: disable=import-outside-toplevel

def _process_string(string_tensor):
# Extract string value and decode it if necessary
Expand Down Expand Up @@ -421,20 +424,23 @@ class ParseFeatures(grain.MapTransform):

def __init__(self, data_columns, tokenize):
self.data_columns = data_columns
if tokenize:
self.dtype = tf.string
Comment thread
aireenmei marked this conversation as resolved.
else:
self.dtype = tf.int64
self.tokenize = tokenize

def map(self, element):
def _parse(example):
parsed = tf.io.parse_example(
example,
{col: tf.io.FixedLenSequenceFeature([], dtype=self.dtype, allow_missing=True) for col in self.data_columns},
)
return parsed

return _parse(element)
"""Parse a serialized tf.train.Example proto and extract features."""
example = example_pb2.Example()
example.ParseFromString(element)
features = example.features.feature

parsed = {}
for col in self.data_columns:
if col in features:
f = features[col]
if self.tokenize:
parsed[col] = np.array(f.bytes_list.value, dtype=object)
else:
parsed[col] = np.array(f.int64_list.value, dtype=np.int32)
return parsed


@dataclasses.dataclass
Expand All @@ -447,9 +453,9 @@ def __init__(self, column_names, tokenize):

def map(self, element):
if self.tokenize:
return {col: element[col].numpy()[0].decode() for col in self.column_names}
return {col: element[col][0].decode() for col in self.column_names}
else:
return {col: element[col].numpy() for col in self.column_names}
return {col: element[col] for col in self.column_names}


@dataclasses.dataclass
Expand Down
13 changes: 13 additions & 0 deletions src/maxtext/input_pipeline/protos/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright 2023-2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Loading
Loading