### Setup

In [None]:
!pip install transformers -q
!pip install git+https://github.com/agemagician/Ankh.git

In [2]:
import numpy as np
import random
import tensorflow as tf
from tensorflow import keras
from datasets import load_dataset
import ankh
from transformers.models import convbert
from keras import layers

# Seed every random number generator imported above (Python's `random`,
# NumPy and TensorFlow) with the same value so repeated runs are comparable.
seed = 7
random.seed(seed)
np.random.seed(seed)
tf.random.set_seed(seed)

### Load Ankh Base as a TensorFlow model.

In [None]:
# `ankh.load_base_model` returns a (model, tokenizer) pair; request the
# TensorFlow format of the model.
model, tokenizer = ankh.load_base_model(
    model_format='tf',
)

### Load the fluorescence dataset from Proteinea's Hugging Face datasets.

In [None]:
# Load the fluorescence dataset by its repository id.
dataset = load_dataset('proteinea/fluorescence')

# Pull out the three splits that are used in the rest of the notebook.
train, validation, test = (
    dataset[split_name] for split_name in ('train', 'validation', 'test')
)

### Prepare the tokenization function that will be used in the `tf.data.Dataset` pipeline.

In [36]:
def tokenize(x, y):
  """Encode one protein sequence into token ids and pass its label through.

  Intended to be wrapped in `tf.py_function`, so `x` is an eager tensor whose
  `.numpy()` value is a `bytes` object.

  Args:
    x: Scalar string tensor holding the sequence.
    y: Label associated with the sequence; returned unchanged.

  Returns:
    A `(token_ids, y)` tuple, where `token_ids` is the result of
    `tokenizer.encode` on the individual characters of the sequence.
  """
  # TensorFlow hands strings over as bytes, so decode them back to `str`.
  sequence = x.numpy().decode('utf8')
  # Split into single characters so that each one is passed as its own word.
  residues = list(sequence)
  token_ids = tokenizer.encode(residues, is_split_into_words=True, add_special_tokens=True)
  return token_ids, y

In [37]:
# Number of sequences per padded batch.
BATCH_SIZE = 16


def make_dataset(split, batch_size=BATCH_SIZE):
  """Build the tokenized, padded and prefetched input pipeline for one split.

  Args:
    split: Dataset split providing the `'primary'` (sequence) and
      `'log_fluorescence'` (label) columns.
    batch_size: Number of examples per padded batch.

  Returns:
    A `tf.data.Dataset` yielding `(token_ids, labels)` batches.
  """
  # Create the dataset using the `.from_tensor_slices()` method.
  ds = tf.data.Dataset.from_tensor_slices((split['primary'], split['log_fluorescence']))
  # Map each residue in every protein sequence to its corresponding id.
  # `tokenize()` is plain Python, so it has to be wrapped in `tf.py_function`.
  ds = ds.map(
      lambda x, y: tf.py_function(tokenize, inp=[x, y], Tout=(tf.int32, tf.float32)),
      num_parallel_calls=tf.data.AUTOTUNE,
  )
  # Pad the sequences so that each one has the same length as the longest
  # sequence in its batch; the labels are scalars and need no padding.
  ds = ds.padded_batch(batch_size, padded_shapes=((None,), []))
  # Let tf.data choose the prefetch buffer size instead of buffering a fixed
  # 1024 batches.
  return ds.prefetch(tf.data.AUTOTUNE)


train_dataset = make_dataset(train)
test_dataset = make_dataset(test)
valid_dataset = make_dataset(validation)

### Create the downstream model.

In [33]:
# Create ConvBert configuration,
# the same configuration that was used in the paper.
convbert_config = convbert.ConvBertConfig(hidden_size=768,
                                          num_hidden_layers=1,
                                          num_attention_heads=8,
                                          intermediate_size=768//2,
                                          hidden_dropout_prob=0.1,
                                          conv_kernel_size=7)

# Freeze Ankh Base weights.
model.trainable = False

# Variable-length sequences of token ids. The ids are integers (the dataset
# pipeline emits `tf.int32`), so declare the input dtype explicitly instead of
# relying on the float32 default.
inputs = layers.Input((None,), dtype=tf.int32)

# Pass the inputs layer to the model.
# NOTE(review): no attention mask is passed here or to the ConvBert layer, so
# padded positions are processed like real residues and also take part in the
# max-pooling below - confirm this is intended.
x = model(inputs, training=False)

# Pass the encoder output (`last_hidden_state`) to the ConvBert layer; the
# three `None` positional arguments leave its optional inputs unset, and `[0]`
# selects the first element of the layer's output.
x = convbert.TFConvBertLayer(convbert_config)(x.last_hidden_state, None, None, None)[0]
# Apply Global Max Pooling over the timesteps.
x = layers.GlobalMaxPooling1D()(x)
# Single-unit regression head with a seeded uniform initializer.
output = layers.Dense(1, kernel_initializer=tf.keras.initializers.RandomUniform(minval=-0.1, maxval=0.1, seed=seed))(x)

# Create our downstream model
downstream_model = keras.Model(inputs, output)

### Compile our downstream model.

In [34]:
# Mean-squared-error loss with the Adam optimizer; `jit_compile` is enabled.
downstream_model.compile(
    optimizer='adam',
    loss='mse',
    jit_compile=True,
)

### Train our downstream model.

In [None]:
# Train for five epochs, evaluating on the validation split during training.
downstream_model.fit(
    train_dataset,
    epochs=5,
    validation_data=valid_dataset,
)

### Evaluate the model.

In [None]:
# Compute the compiled loss on the held-out test split.
downstream_model.evaluate(x=test_dataset)

### Prediction

In [None]:
# Predict with the trained downstream regressor. `model` is the frozen
# Ankh encoder and does not produce the regression output.
downstream_model.predict(test_dataset)