In [1]:
import logging
import tensorflow as tf


# Half-open index range [GPU_FROM, GPU_TO) into the list of GPUs reported by
# TensorFlow; only the devices inside this slice are made visible.
GPU_FROM = 0
GPU_TO = 1

# The root logger defaults to WARNING, so the INFO messages below would be
# dropped silently. basicConfig installs a stream handler when the root logger
# has none; the explicit setLevel covers the case where a handler already
# exists (basicConfig is then a no-op and would leave the level untouched).
logging.basicConfig(level=logging.INFO)
logging.getLogger().setLevel(logging.INFO)

visible_devices = tf.config.get_visible_devices('GPU')
logging.info("Num GPUs visible: %d", len(visible_devices))

# Restrict TensorFlow to the selected slice. Slicing an empty list yields an
# empty list, so this is also safe on a machine without GPUs.
# NOTE(review): per the TensorFlow docs this has to happen before the runtime
# initialises its devices, so keep this cell first in the notebook.
tf.config.set_visible_devices(visible_devices[GPU_FROM:GPU_TO], 'GPU')

# Re-query so the log reflects what TensorFlow will actually use.
visible_devices = tf.config.get_visible_devices('GPU')
logging.info("Num GPUs to be used: %d", len(visible_devices))

2023-02-20 17:50:59.696173: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-02-20 17:50:59.855692: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-02-20 17:50:59.892335: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-02-20 17:51:00.564863: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; 

In [2]:


from segmentation.model import SpaceSegmentationTransformer
from segmentation.model import LossWithVoids

# Synthetic labels: 10000 rows of 100 Binomial(1, 0.8) draws shifted by +1, so
# every entry is 1 or 2 (2 with probability 0.8). The seed is fixed, which
# makes the tensor reproducible across runs.
data = tf.random.stateless_binomial(
    shape=(10000, 100),
    counts=1,
    probs=0.8,
    seed=[1997, 1997],
) + 1

# Row index separating the first three quarters (training) from the rest
# (validation). Despite the name this is a row count, not a fraction.
train_frac = int(data.shape[0] * 3 / 4)

train_ds = tf.data.Dataset.from_tensor_slices(data[:train_frac])
val_ds = tf.data.Dataset.from_tensor_slices(data[train_frac:])

def mapper(y):
    """Turn one row of integer labels into a ``((text, None), labels)`` example.

    Every label is shifted by 4 and rendered as a decimal string, then the
    per-position strings are joined along the last axis into a single string.
    With the labels 1 and 2 generated in this notebook that yields "5" for a
    character and "6" for a space (roughly 80% of positions are spaces).

    The second input slot is always ``None``.
    NOTE(review): presumably an optional model input that is unused here --
    confirm against SpaceSegmentationTransformer.

    Returns:
        ``((text, None), labels)`` where ``labels`` is ``y`` cast to float16.
    """
    digit_strings = tf.strings.as_string(y + 4)
    text = tf.strings.reduce_join(digit_strings, axis=-1)
    labels = tf.cast(y, "float16")
    return (text, None), labels

# Training pipeline: convert rows to examples, shuffle with a 100-element
# buffer, then group into batches of 8.
train_ds = (
    train_ds
    .map(mapper)
    .shuffle(100)
    .batch(8)
)

# Validation pipeline: same conversion and batch size, but no shuffling.
val_ds = (
    val_ds
    .map(mapper)
    .batch(8)
)

# Display the element structure of both pipelines as the cell output.
train_ds.element_spec, val_ds.element_spec

2023-02-20 17:51:01.362693: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-02-20 17:51:02.043448: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1616] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 22296 MB memory:  -> device: 0, name: GeForce RTX 3090, pci bus id: 0000:3b:00.0, compute capability: 8.6


(((TensorSpec(shape=(None,), dtype=tf.string, name=None), NoneTensorSpec()),
  TensorSpec(shape=(None, 100), dtype=tf.float16, name=None)),
 ((TensorSpec(shape=(None,), dtype=tf.string, name=None), NoneTensorSpec()),
  TensorSpec(shape=(None, 100), dtype=tf.float16, name=None)))

In [3]:
# Character-level vectoriser: every character becomes one integer id and each
# sequence is padded/truncated to 100 ids, matching the 100 labels per row.
tokenizer = tf.keras.layers.TextVectorization(
    standardize="lower_and_strip_punctuation",
    split="character",
    output_mode="int",
    output_sequence_length=100,
)

# Learn the vocabulary from the training strings only: the first entry of the
# input tuple is the text, the labels are dropped.
tokenizer.adapt(train_ds.map(lambda inputs, _labels: inputs[0]))

# Show the learned vocabulary as the cell output.
tokenizer.get_vocabulary()

['', '[UNK]', '6', '5']

In [4]:
# Sanity check: pull a single batch from the (shuffled) training pipeline and
# print it to inspect the ((text, None), labels) structure. Because of the
# shuffle step the printed batch can differ between runs.
for x in train_ds.take(1):
    print(x)

((<tf.Tensor: shape=(8,), dtype=string, numpy=
array([b'6666566665666666656666566566666666666666566566666656666566666665555656666566665566666656666666666666',
       b'6666655556566666665656666655666666665666666666666666666665666655666666666566666656665665665565656555',
       b'5656666666666666665666566666656656566666566666666666665556566665666666666665666666665656656666656656',
       b'6665666566656566665566556566666666666656656666656656566665666666656666666566565656556566665666666666',
       b'6666665566666666656666556666665666666665666566566656666666666666665665566666666666665666565666656666',
       b'5565666656665656566666566665566666566665666665566566666666666566666565566565565565666666665666666655',
       b'6665666666666666565566666665666666665656666566666565655556556656656665565666566656666656666656666566',
       b'6665666665666666566566566656666666666666666566566666666666566665666666656665666665656666666666666665'],
      dtype=object)>, None), <tf.Tensor: shape=(8, 100),

In [5]:
# Smoke test: tokenise a short string. Each character maps to its vocabulary
# id and the result is padded with 0 up to output_sequence_length (100).
tokenizer("5556665")

<tf.Tensor: shape=(100,), dtype=int64, numpy=
array([3, 3, 3, 2, 2, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])>

In [6]:
# Build the segmentation model. Keyword arguments are grouped by concern; the
# values are unchanged.
model = SpaceSegmentationTransformer(
    # --- input / output ---
    input_tokenizer=tokenizer,
    seq_len=100,  # same length as the tokenizer's output_sequence_length
    num_classes=3,  # presumably padding (0), character (1), space (2) -- TODO confirm
    # --- encoder size ---
    num_layers=2,
    d_model=512,
    # NOTE(review): 512 is not divisible by 3 heads; confirm the attention
    # implementation does not rely on d_model % num_attention_heads == 0.
    num_attention_heads=3,
    # NOTE(review): 1028 looks like it might have been meant as 1024 -- confirm.
    dff=1028,
    # --- regularisation ---
    dropout_rate=0.1,
)

In [7]:
from segmentation.metrics import SparseAccuracyWithIgnore
from segmentation.metrics import SparsePrecision
from segmentation.metrics import SparseRecall
from segmentation.metrics import SparseF1
from segmentation.model import LossWithVoids

# Compile with Adam and sparse categorical cross-entropy on probabilities.
# Metric order is significant: it fixes the order of values that fit() logs
# and evaluate() returns.
model.compile(
    optimizer="adam",
    # TODO: open question from the author -- why can class 0 not be ignored by
    # the loss? Only the accuracy metric below is given ignore_token=0.
    # NOTE(review): newer Keras releases document an `ignore_class` argument on
    # this loss; check whether the installed version provides it.
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=[
        SparseAccuracyWithIgnore(ignore_token=0),
        # Precision / recall / F1, first for spaces (class 2), then for
        # characters (class 1); names come out as "<label>_<metric>".
        *(
            metric_cls(class_id=class_id, name=f"{label}_{suffix}")
            for label, class_id in (("space", 2), ("char", 1))
            for suffix, metric_cls in (
                ("precision", SparsePrecision),
                ("recall", SparseRecall),
                ("f1", SparseF1),
            )
        ),
    ],
)

In [8]:
# Train for two epochs, validating on the held-out split; the returned
# History object is intentionally discarded.
model.fit(
    train_ds,
    epochs=2,
    validation_data=val_ds,
)

# Print the layer / parameter table of the trained model.
model.summary()

Epoch 1/2
 10/938 [..............................] - ETA: 18s - loss: 1.6368 - sparse_categorical_accuracy: 0.7241 - space_precision: 0.8497 - space_recall: 0.8100 - space_f1: 0.8294 - char_precision: 0.3676 - char_recall: 0.3932 - char_f1: 0.3799  

2023-02-20 17:51:11.150388: I tensorflow/stream_executor/cuda/cuda_blas.cc:1614] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.


Epoch 2/2
Model: "space_segmentation_transformer"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 encoder (Encoder)           multiple                  8416264   
                                                                 
 text_vectorization (TextVec  multiple                 0         
 torization)                                                     
                                                                 
 dense_4 (Dense)             multiple                  1539      
                                                                 
Total params: 8,417,803
Trainable params: 8,417,803
Non-trainable params: 0
_________________________________________________________________


In [9]:
# Evaluate on the validation split. The displayed list holds eight values,
# which lines up with the loss followed by the seven compiled metrics.
model.evaluate(val_ds)



[0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]

In [10]:
# Run the model directly on four hand-written strings. The input mirrors the
# training structure: a tuple of texts plus a matching tuple of None values.
preds = model(
    [
        ("66666", "66665", "5556665", "56565665656"),
        (None, None, None, None),
    ]
)

# Most likely class per position (argmax over the last axis).
# NOTE(review): positions beyond each string's length are tokenizer padding,
# so the classes reported there are probably not meaningful -- confirm.
tf.argmax(preds, axis=-1)

<tf.Tensor: shape=(4, 100), dtype=int64, numpy=
array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
       [2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
       [1, 1, 1, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2