In [None]:
# Install/upgrade the notebook's dependencies into the active kernel's environment.
# NOTE(review): versions are unpinned, so a re-run can pick up different releases;
# consider pinning (e.g. `keras-nlp==X.Y.Z`). rouge-score is not used in the cells
# visible here — confirm it is still needed.
%pip install -q --upgrade rouge-score
%pip install -q --upgrade keras-nlp
%pip install -q --upgrade keras  # Upgrade to Keras 3.

#from ann_visualizer.visualize import ann_viz

In [None]:
import os

# Select the GPU before TensorFlow is imported (libraries such as keras_nlp may
# import it as well), so the restriction is in place regardless of when
# TensorFlow initializes CUDA.
os.environ['CUDA_VISIBLE_DEVICES'] = "0"

import keras_nlp
import pathlib
import random
# Loss and metrics
import tensorflow as tf
from tensorflow import keras

# Working directory holding the data CSV and saved models. Set the
# TRANSLATION_MODEL_DIR environment variable to use another location instead
# of editing this machine-specific default.
os.chdir(os.environ.get("TRANSLATION_MODEL_DIR", "/media/damian/PDB_DB/TRANSLATION/MODEL"))


# Keras 3 (installed above); this rebinds the `keras` name imported from tensorflow.
import keras
from keras import ops

import tensorflow.data as tf_data
from tensorflow_text.tools.wordpiece_vocab import (
    bert_vocab_from_dataset as bert_vocab,
)


In [None]:
# ---- Training schedule -------------------------------------------------------
BATCH_SIZE, EPOCHS = 256, 300

# ---- Sequence handling -------------------------------------------------------
# Fixed token length of the encoder and decoder inputs.
MAX_SEQUENCE_LENGTH = 300

# Requested WordPiece vocabulary size for each side (AA and PF).
AA_VOCAB_SIZE = PF_VOCAB_SIZE = 100

# ---- Transformer architecture ------------------------------------------------
# Embedding width, feed-forward width, and number of attention heads.
EMBED_DIM, INTERMEDIATE_DIM, NUM_HEADS = 128, 2048, 8

In [None]:
# CSV of comma-separated (AA, PF) sequence pairs, one pair per line. A relative
# path, so it resolves against the current working directory.
text_file = "GLOBAL_150_300_subset.csv"

In [None]:
def wordmaker_kmer(seq, k=1):
    """Split ``seq`` into overlapping k-mers joined by single spaces.

    Args:
        seq: Input sequence (any sliceable string).
        k: Length of each k-mer. Defaults to 1 (one character per word),
            which matches the previously hard-coded value.

    Returns:
        A space-separated string of every length-``k`` window of ``seq`` in
        order, or ``""`` when ``seq`` is shorter than ``k``.

    Raises:
        ValueError: If ``k`` is less than 1.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    # len(seq) - k + 1 windows; negative/zero when seq is too short -> empty.
    return " ".join(seq[i:i + k] for i in range(len(seq) - k + 1))

def wordmaker_non_overlapping(seq, word_length=1):
    """Chop ``seq`` into consecutive, non-overlapping words joined by spaces.

    Args:
        seq: Input sequence (any sliceable string).
        word_length: Characters per word. The final word may be shorter if
            ``len(seq)`` is not a multiple of ``word_length``.

    Returns:
        The space-separated words, or ``""`` for an empty ``seq``.

    Raises:
        ValueError: If ``word_length`` is less than 1. (0 already failed inside
            ``range``; negative values used to silently return ``""``.)
    """
    if word_length < 1:
        raise ValueError(f"word_length must be >= 1, got {word_length}")
    return " ".join(
        seq[i:i + word_length] for i in range(0, len(seq), word_length)
    )



In [None]:
# Read the comma-separated (AA, PF) pairs and turn each sequence into
# space-separated single-character "words" for the WordPiece tokenizers.
# splitlines() handles "\r\n" endings and, unlike `split("\n")[:-1]`, does not
# drop the last record when the file has no trailing newline.
with open(text_file, encoding="utf-8") as f:
    lines = f.read().splitlines()
text_pairs = []
for line_no, line in enumerate(lines, start=1):
    if not line.strip():
        continue  # tolerate blank lines, e.g. extra newlines at end of file
    fields = line.split(",")
    if len(fields) != 2:
        raise ValueError(
            f"{text_file}:{line_no}: expected 2 comma-separated fields, got {len(fields)}"
        )
    AA, PF = fields
    AA = wordmaker_non_overlapping(AA.strip().lower())
    PF = wordmaker_non_overlapping(PF.strip())
    text_pairs.append((AA, PF))


In [None]:
# Spot-check the parsed data: five randomly drawn (AA, PF) pairs, one per line.
print(*(random.choice(text_pairs) for _ in range(5)), sep="\n")


In [None]:
# Deterministic 70/15/15 train/val/test split.
# The shuffle uses a fixed seed, so the split (and the vocabularies later
# learned from train_pairs) is the same after every kernel restart. A later
# cell reloads a saved model and evaluates on test_pairs; with an unseeded
# shuffle, sequences the model was trained on would end up in the "test" set.
# Shuffling a copy keeps this cell idempotent when it is re-run.
SPLIT_SEED = 42
shuffled_pairs = list(text_pairs)
random.Random(SPLIT_SEED).shuffle(shuffled_pairs)

num_val_samples = int(0.15 * len(shuffled_pairs))
num_train_samples = len(shuffled_pairs) - 2 * num_val_samples
train_pairs = shuffled_pairs[:num_train_samples]
val_pairs = shuffled_pairs[num_train_samples : num_train_samples + num_val_samples]
test_pairs = shuffled_pairs[num_train_samples + num_val_samples :]

print(f"{len(text_pairs)} total pairs")
print(f"{len(train_pairs)} training pairs")
print(f"{len(val_pairs)} validation pairs")
print(f"{len(test_pairs)} test pairs")


In [None]:
def train_word_piece(text_samples, vocab_size, reserved_tokens):
    """Learn a WordPiece vocabulary from whitespace-separated training strings.

    Args:
        text_samples: Sequence of training strings.
        vocab_size: Requested vocabulary size.
        reserved_tokens: Special tokens to include in the vocabulary.

    Returns:
        The vocabulary returned by
        ``keras_nlp.tokenizers.compute_word_piece_vocabulary``.
    """
    # Stream the samples in batches of 1000, with 2 batches prefetched.
    batched_samples = (
        tf_data.Dataset.from_tensor_slices(text_samples)
        .batch(1000)
        .prefetch(2)
    )
    return keras_nlp.tokenizers.compute_word_piece_vocabulary(
        batched_samples,
        vocabulary_size=vocab_size,
        reserved_tokens=reserved_tokens,
    )



In [None]:
# Special tokens shared by both vocabularies ("[PAD]" is listed first).
reserved_tokens = ["[PAD]", "[UNK]", "[START]", "[END]"]

# One WordPiece vocabulary per side, learned from the training split only.
AA_samples = [aa_text for aa_text, _pf_text in train_pairs]
AA_vocab = train_word_piece(AA_samples, AA_VOCAB_SIZE, reserved_tokens)

PF_samples = [pf_text for _aa_text, pf_text in train_pairs]
PF_vocab = train_word_piece(PF_samples, PF_VOCAB_SIZE, reserved_tokens)


In [None]:
# Peek at vocabulary entries 10-19 of each side.
for label, vocab in (("AA Tokens: ", AA_vocab), ("Proflex Tokens: ", PF_vocab)):
    print(label, vocab[10:20])

In [None]:
# WordPiece tokenizers built on the learned vocabularies (AA: lowercasing disabled).
AA_tokenizer = keras_nlp.tokenizers.WordPieceTokenizer(vocabulary=AA_vocab, lowercase=False)
PF_tokenizer = keras_nlp.tokenizers.WordPieceTokenizer(vocabulary=PF_vocab)


In [None]:
def _show_tokenizer_roundtrip(seq_label, text, tokenizer):
    """Tokenize `text`, print text / ids / detokenized text, and return the ids."""
    token_ids = tokenizer.tokenize(text)
    print(seq_label, text)
    print("Tokens: ", token_ids)
    print(
        "Recovered seq after detokenizing: ",
        tokenizer.detokenize(token_ids),
    )
    return token_ids


# Round-trip the first pair through each tokenizer as a sanity check.
AA_input_ex = text_pairs[0][0]
AA_tokens_ex = _show_tokenizer_roundtrip("AA seq: ", AA_input_ex, AA_tokenizer)

print()

PF_input_ex = text_pairs[0][1]
PF_tokens_ex = _show_tokenizer_roundtrip("PF seq: ", PF_input_ex, PF_tokenizer)


In [None]:
def preprocess_batch(AA, PF):
    """Tokenize and pack one batch of (AA, PF) strings for teacher forcing.

    Args:
        AA: Batch of space-separated AA strings (encoder side).
        PF: Batch of space-separated PF strings (decoder side).

    Returns:
        ``(inputs, targets)`` where ``inputs`` is a dict with
        ``"encoder_inputs"`` (AA token ids padded to ``MAX_SEQUENCE_LENGTH``)
        and ``"decoder_inputs"`` (the packed PF sequence without its last
        position), and ``targets`` is the packed PF sequence shifted left by
        one position. Both PF tensors are ``MAX_SEQUENCE_LENGTH`` long.
    """
    AA = AA_tokenizer(AA)
    PF = PF_tokenizer(PF)

    # Pad the AA (encoder) tokens to `MAX_SEQUENCE_LENGTH`; no start/end tokens.
    AA_start_end_packer = keras_nlp.layers.StartEndPacker(
        sequence_length=MAX_SEQUENCE_LENGTH,
        pad_value=AA_tokenizer.token_to_id("[PAD]"),
    )
    AA = AA_start_end_packer(AA)

    # Add "[START]"/"[END]" to PF and pad to MAX_SEQUENCE_LENGTH + 1, so that
    # the one-position shift below gives decoder inputs and targets that are
    # both exactly MAX_SEQUENCE_LENGTH long.
    PF_start_end_packer = keras_nlp.layers.StartEndPacker(
        sequence_length=MAX_SEQUENCE_LENGTH + 1,
        start_value=PF_tokenizer.token_to_id("[START]"),
        end_value=PF_tokenizer.token_to_id("[END]"),
        pad_value=PF_tokenizer.token_to_id("[PAD]"),
    )
    PF = PF_start_end_packer(PF)

    return (
        {
            "encoder_inputs": AA,
            "decoder_inputs": PF[:, :-1],  # [START] p1 ... (drop last position)
        },
        PF[:, 1:],  # p1 ... [END] (drop [START]): next-token targets
    )


def make_dataset(pairs, shuffle=True):
    """Build a batched, tokenized ``tf.data`` pipeline from (AA, PF) pairs.

    Args:
        pairs: Non-empty sequence of ``(AA_text, PF_text)`` tuples.
        shuffle: If True (the default), reshuffle individual examples every
            epoch before batching. Use False for evaluation data.

    Returns:
        A dataset yielding ``(inputs, targets)`` batches from
        ``preprocess_batch``.

    Raises:
        ValueError: If ``pairs`` is empty.
    """
    if len(pairs) == 0:
        raise ValueError("make_dataset() needs at least one (AA, PF) pair")
    AA_texts, PF_texts = zip(*pairs)
    dataset = tf_data.Dataset.from_tensor_slices((list(AA_texts), list(PF_texts)))
    if shuffle:
        # Shuffle examples (not whole batches) and draw a new order each epoch.
        # The old `.shuffle(1024).prefetch(64).cache()` cached the first
        # epoch's order, so every later epoch replayed the same batches.
        dataset = dataset.shuffle(len(pairs), reshuffle_each_iteration=True)
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.map(preprocess_batch, num_parallel_calls=tf_data.AUTOTUNE)
    return dataset.prefetch(tf_data.AUTOTUNE)


train_ds = make_dataset(train_pairs)
val_ds = make_dataset(val_pairs, shuffle=False)


In [None]:
# Sanity-check one batch: print the shape of each model input and of the targets.
for inputs, targets in train_ds.take(1):
    for input_name in ("encoder_inputs", "decoder_inputs"):
        print(f'inputs["{input_name}"].shape: {inputs[input_name].shape}')
    print(f"targets.shape: {targets.shape}")


In [None]:
# Encoder
# AA token ids of fixed length MAX_SEQUENCE_LENGTH.
encoder_inputs = keras.Input(shape=(MAX_SEQUENCE_LENGTH,), name="encoder_inputs")

# Token embedding plus position embedding for the AA side.
# NOTE(review): no `mask_zero` argument is passed, so presumably [PAD]
# positions are not masked inside attention — confirm this is intended.
x = keras_nlp.layers.TokenAndPositionEmbedding(
    vocabulary_size=AA_VOCAB_SIZE,
    sequence_length=MAX_SEQUENCE_LENGTH,
    embedding_dim=EMBED_DIM,
)(encoder_inputs)

# A single Transformer encoder block.
encoder_outputs = keras_nlp.layers.TransformerEncoder(
    intermediate_dim=INTERMEDIATE_DIM, num_heads=NUM_HEADS
)(inputs=x)

# Standalone encoder model (not referenced again in the visible cells).
encoder = keras.Model(encoder_inputs, encoder_outputs)

# Decoder
# Decoder token ids, plus the encoder sequence it cross-attends to, supplied
# as a separate input so the decoder can be built as its own model.
decoder_inputs = keras.Input(shape=(MAX_SEQUENCE_LENGTH,), name="decoder_inputs")
encoded_seq_inputs = keras.Input(shape=(MAX_SEQUENCE_LENGTH, EMBED_DIM), name="decoder_state_inputs")

# Token embedding plus position embedding for the PF side.
x = keras_nlp.layers.TokenAndPositionEmbedding(
    vocabulary_size=PF_VOCAB_SIZE,
    sequence_length=MAX_SEQUENCE_LENGTH,
    embedding_dim=EMBED_DIM,
)(decoder_inputs)

# A single Transformer decoder block (self-attention + cross-attention to the encoder sequence).
x = keras_nlp.layers.TransformerDecoder(
    intermediate_dim=INTERMEDIATE_DIM, num_heads=NUM_HEADS
)(decoder_sequence=x, encoder_sequence=encoded_seq_inputs)
x = keras.layers.Dropout(0.1)(x)
# Softmax probabilities over the PF vocabulary for every position
# (the loss is configured with from_logits=False accordingly).
decoder_outputs = keras.layers.Dense(PF_VOCAB_SIZE, activation="softmax")(x)

decoder = keras.Model(
    [
        decoder_inputs,
        encoded_seq_inputs,
    ],
    decoder_outputs,
)

# Wire the encoder output into the decoder model to get end-to-end outputs.
decoder_outputs = decoder([decoder_inputs, encoder_outputs])

# Full model: (AA ids, shifted PF ids) -> per-position PF probabilities.
transformer = keras.Model(
    [encoder_inputs, decoder_inputs],
    decoder_outputs,
    name="transformer",
)

# Early stopping callback.
# Monitor the validation loss, which ignores [PAD] targets (see custom_loss),
# rather than val_accuracy: the plain 'accuracy' metric also scores the padding
# positions that the loss deliberately masks out, which makes it a noisy and
# inconsistent stopping signal. The mode ('min') is inferred from the name.
early_stopping_callback = keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=8,
    restore_best_weights=True
)

# Custom loss function
# Per-token cross-entropy (no reduction; padding is masked out manually below).
# Built once rather than on every loss call.
_token_loss = keras.losses.SparseCategoricalCrossentropy(from_logits=False, reduction='none')


def custom_loss(real, pred, pad_id=0):
    """Mean sparse categorical cross-entropy over non-padding target tokens.

    Args:
        real: Integer target token ids. Positions equal to ``pad_id`` are ignored.
        pred: Predicted per-token probability distributions (the decoder ends
            in a softmax, hence ``from_logits=False``).
        pad_id: Token id treated as padding. Defaults to 0, presumably the id
            of "[PAD]" because it is the first reserved token.

    Returns:
        Scalar: the summed cross-entropy over non-pad positions divided by the
        number of non-pad positions. Returns 0 instead of NaN when a batch has
        no non-pad positions.
    """
    mask = tf.math.logical_not(tf.math.equal(real, pad_id))
    loss_ = _token_loss(real, pred)
    mask = tf.cast(mask, dtype=loss_.dtype)
    loss_ *= mask
    return tf.math.divide_no_nan(tf.reduce_sum(loss_), tf.reduce_sum(mask))

# Step-wise exponential decay: lr = 1e-3 * 0.96 ** floor(step / 10_000).
# (1e-3 is the starting rate; it is not lower than Adam's usual default.)
learning_rate = keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=1e-3,
    decay_steps=10_000,
    decay_rate=0.96,
    staircase=True,
)

# Adam driven by the decaying schedule above.
optimiser = keras.optimizers.Adam(learning_rate=learning_rate)

# Print the architecture, then attach the optimiser, masked loss and metric.
transformer.summary()
transformer.compile(
    loss=custom_loss,
    optimizer=optimiser,
    metrics=['accuracy'],
)

In [None]:
# Show the architecture and (re)compile, so this cell does not depend on
# whether the model-definition cell's compile call was run.
transformer.summary()
transformer.compile(optimizer=optimiser, loss=custom_loss, metrics=['accuracy'])

# Train with the early-stopping callback on the validation split.
history = transformer.fit(
    train_ds,
    epochs=EPOCHS,
    validation_data=val_ds,
    callbacks=[early_stopping_callback]
)

# Save immediately. The next cell reloads 'model2.keras' into `transformer`;
# without this, a top-to-bottom run would discard the model just trained,
# replacing it with (and later re-saving) whatever older file is on disk.
transformer.save('model2.keras')

In [None]:
# Reload the saved model. Use the loader from the same `keras` (Keras 3)
# package the model was built and saved with. `tensorflow.keras` can resolve
# to a different Keras implementation depending on the installed TensorFlow.
from keras.models import load_model

# custom_loss is not a built-in Keras loss, so it must be provided to
# deserialize the compiled model.
custom_objects = {
    'custom_loss': custom_loss,
}

model_path = 'model2.keras'
transformer = load_model(model_path, custom_objects=custom_objects)

In [None]:
import tensorflow as tf
import random


def decode_sequences(input_sentences):
    """Greedily translate AA sequences into PF sequences with `transformer`.

    Args:
        input_sentences: List of space-separated AA strings, in the same
            format as the training inputs. Several can be decoded at once.

    Returns:
        A string tensor with one detokenized PF output per input. Each row
        still contains the leading "[START]" token and may contain
        "[END]"/"[PAD]"; callers strip those.

    Raises:
        ValueError: If an input has more than MAX_SEQUENCE_LENGTH tokens.
    """
    pad_id = AA_tokenizer.token_to_id("[PAD]")

    # Tokenize the encoder input and densify the (possibly ragged) result.
    encoder_input_tokens = AA_tokenizer(input_sentences)
    if isinstance(encoder_input_tokens, tf.RaggedTensor):
        encoder_input_tokens = encoder_input_tokens.to_tensor(default_value=pad_id)
    encoder_input_tokens = tf.convert_to_tensor(encoder_input_tokens, dtype=tf.int32)

    # Derive the batch size from the data (was hard-coded to 1).
    batch_size = tf.shape(encoder_input_tokens)[0]

    # Limit input sequence
    input_length = tf.shape(encoder_input_tokens)[1]
    if input_length > MAX_SEQUENCE_LENGTH:
        raise ValueError("Input sequence length exceeds MAX_SEQUENCE_LENGTH.")

    # Right-pad with [PAD] to the fixed encoder length.
    if input_length < MAX_SEQUENCE_LENGTH:
        pads = tf.fill([batch_size, MAX_SEQUENCE_LENGTH - input_length], pad_id)
        encoder_input_tokens = tf.concat([encoder_input_tokens, pads], axis=1)

    def _next(prompt, cache, index):
        # Distribution for the token following position `index - 1`. The model
        # ends in a softmax, so these are probabilities rather than logits;
        # the greedy argmax is the same either way.
        probs = transformer([encoder_input_tokens, prompt])[:, index - 1, :]
        hidden_states = None
        return probs, hidden_states, cache

    # Prompt of length MAX_SEQUENCE_LENGTH: [START] followed by [PAD]s.
    start = tf.fill([batch_size, 1], PF_tokenizer.token_to_id("[START]"))
    pad = tf.fill([batch_size, MAX_SEQUENCE_LENGTH - 1], PF_tokenizer.token_to_id("[PAD]"))
    prompt = tf.concat([start, pad], axis=1)

    # Greedy sampling
    generated_tokens = keras_nlp.samplers.GreedySampler()(
        _next,
        prompt,
        stop_token_ids=[PF_tokenizer.token_to_id("[END]")],
        index=1,  # Start sampling after start token.
    )

    # Keep the leading [START] plus at most one output token per input token.
    # Position 0 is [START], so cutting at `input_length` (as before) silently
    # dropped the last predicted token.
    keep = tf.minimum(input_length + 1, tf.shape(generated_tokens)[1])
    generated_tokens = generated_tokens[:, :keep]

    # Detokenize the generated tokens to obtain the output sentences.
    return PF_tokenizer.detokenize(generated_tokens)

# Apply decoder to a few random held-out AA sequences and print input/output.
# Filter on the number of residues (space-separated words, one token each),
# not on the character length of the joined string. With one-character words
# that string is 2*n - 1 characters long, so the old check wrongly excluded
# every sequence longer than about MAX_SEQUENCE_LENGTH / 2 residues.
NUM_DECODE_EXAMPLES = 1
test_AA = [pair[0] for pair in test_pairs if len(pair[0].split()) <= MAX_SEQUENCE_LENGTH]
if not test_AA:
    raise ValueError("No test sequences fit within MAX_SEQUENCE_LENGTH tokens.")

for _ in range(NUM_DECODE_EXAMPLES):
    seq = random.choice(test_AA)
    translated = decode_sequences([seq])
    translated = translated.numpy()[0].decode("utf-8")
    translated = (
        translated.replace("[PAD]", "")
        .replace("[START]", "")
        .replace("[END]", "")
        .strip()
    )
    print(f'{seq}')
    print(f'{translated}')
    print()
    print()


In [None]:
# Persist the current `transformer` in Keras' native `.keras` format.
saved_model_path = 'model2.keras'
transformer.save(saved_model_path)
