In [7]:
import logging
import tensorflow as tf


GPU_FROM = 0
GPU_TO = 1    

visible_devices = tf.config.get_visible_devices('GPU')
logging.info(f"Num GPUs visible:{len(visible_devices)}")
tf.config.set_visible_devices(visible_devices[GPU_FROM:GPU_TO],'GPU')

visible_devices = tf.config.get_visible_devices('GPU')
logging.info(f"Num GPUs to be used: {len(visible_devices)}")


from segmentation.model import SpaceSegmentationTransformer
from segmentation.model import LossWithVoids

data = tf.random.stateless_binomial(shape=(10000, 100), counts=1, probs=0.8, seed=[1997,1997]) + 1
train_frac = int(data.shape[0]*3/4)

train_ds, val_ds = tf.data.Dataset.from_tensor_slices(data[:train_frac]), tf.data.Dataset.from_tensor_slices(data[train_frac:])

def mapper(y):
    x = tf.strings.as_string(y+4)
    x = tf.strings.reduce_join(x, axis=-1)
    return (x, x), tf.cast(y, "float16")

train_ds = train_ds.map(mapper).shuffle(100).batch(8)
val_ds = val_ds.map(mapper).batch(8)

train_ds.element_spec, val_ds.element_spec

(((TensorSpec(shape=(None,), dtype=tf.string, name=None),
   TensorSpec(shape=(None,), dtype=tf.string, name=None)),
  TensorSpec(shape=(None, 100), dtype=tf.float16, name=None)),
 ((TensorSpec(shape=(None,), dtype=tf.string, name=None),
   TensorSpec(shape=(None,), dtype=tf.string, name=None)),
  TensorSpec(shape=(None, 100), dtype=tf.float16, name=None)))

In [8]:
tokenizer = tf.keras.layers.TextVectorization(
    output_sequence_length=100,
    standardize="lower_and_strip_punctuation",
    split="character",
    output_mode="int",
)

tokenizer.adapt(train_ds.map(lambda x,y: x[0]))
tokenizer.get_vocabulary()

['', '[UNK]', '6', '5']

In [9]:
tokenizer("5556665")

<tf.Tensor: shape=(100,), dtype=int64, numpy=
array([3, 3, 3, 2, 2, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])>

In [10]:
model = SpaceSegmentationTransformer(
    num_layers=2,
    d_model=512,
    num_attention_heads=3,
    seq_len=100,
    dff=1028,
    input_tokenizer=tokenizer,
    dropout_rate=0.1
)

In [11]:
from segmentation.metrics import SparseAccuracyWithIgnore
from segmentation.metrics import SparsePrecision
from segmentation.metrics import SparseRecall
from segmentation.metrics import SparseF1

model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False), # Why can I not ignore class 0?
    metrics=[
        SparseAccuracyWithIgnore(ignore_token=0.),
        SparsePrecision(class_id=2, name="space_precision"),
        SparseRecall(class_id=2, name="space_recall"),
        SparseF1(class_id=2, name="space_f1"),
        SparsePrecision(class_id=1, name="char_precision"),
        SparseRecall(class_id=1, name="char_recall"),
        SparseF1(class_id=1, name="char_f1"),
    ]
)

In [12]:
model.fit(train_ds, validation_data=val_ds, epochs=2)

Epoch 1/2
Epoch 2/2


<keras.callbacks.History at 0x7f109d67adf0>

In [13]:
model.evaluate(val_ds)



[2.759265669283195e-07, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]

In [14]:
preds = model([("1111", "1212111", "222221111"),(None,None, None)])
tf.argmax(preds, axis=-1)

<tf.Tensor: shape=(3, 100), dtype=int64, numpy=
array([[1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
       [1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
       [1, 1, 1, 1, 1, 1, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1