In [None]:
!pip install datasets
!pip install einops
!pip install tensorflow-text
!pip install sentencepiece

Collecting datasets
  Downloading datasets-2.18.0-py3-none-any.whl (510 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m510.5/510.5 kB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m10.2 MB/s[0m eta [36m0:00:00[0m
Collecting xxhash (from datasets)
  Downloading xxhash-3.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.1/194.1 kB[0m [31m5.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting multiprocess (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl (134 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m6.9 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: xxhash, dill, multiprocess, datasets
Successfully installed datasets-

In [None]:
# imports
import tensorflow as tf
import numpy as np

import datasets
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoTokenizer
from einops import rearrange, repeat
import tensorflow_text as tft

from google.colab import files
import json

import matplotlib.pyplot as plt
import sentencepiece as sp
import re

In [None]:
# Opens the Colab upload widget. Per the cell output below, the chosen file is
# stored as "bible.txt" in the working directory, which is the path the later
# cells open.
# NOTE(review): the name `bible` is rebound to the corpus text further down the
# notebook, so the value returned here is not used afterwards.
bible = files.upload()

Saving bible.txt to bible.txt


In [None]:
def data_cleaning(path="bible.txt"):
    """
    Normalise a text corpus in place.

    The file is read, lowercased, and every run of characters that is neither
    an ASCII letter nor a newline is collapsed into a single space. The result
    overwrites the original file, so the raw text is not recoverable afterwards.
    Running the function again on an already cleaned file leaves it unchanged.

    Args:
        path: Text file to clean. Defaults to "bible.txt", the file the
            notebook uploads.

    Returns:
        None. The cleaned text replaces the content of ``path``.
    """
    # Context managers close both handles even if reading or writing fails.
    # errors="replace" is safe here: any undecodable byte becomes a non-letter
    # character, which the regex below turns into a space anyway.
    with open(path, encoding="utf-8", errors="replace") as bible_file:
        bible_str = bible_file.read()

    # preprocessing
    # convert to lowercase
    bible_str = bible_str.lower()
    # remove all special characters and numbers except \n
    bible_str = re.sub(r'[^A-Za-z\n]+', ' ', bible_str)

    # write the preprocessed text back into the same file
    with open(path, "w", encoding="utf-8") as bible_file:
        bible_file.write(bible_str)

def train_tokenizer(vocab_size=2000, input_path="bible.txt", model_prefix="tokenizer_model"):
    """
    Train a unigram SentencePiece model on a text file and wrap it for TensorFlow.

    Training writes the model to ``<model_prefix>.model`` in the working
    directory; that file is then read back as bytes and handed to
    ``tensorflow_text.SentencepieceTokenizer``.

    Args:
        vocab_size: Number of pieces in the trained vocabulary (default 2000).
        input_path: Text file to train on (default "bible.txt").
        model_prefix: Prefix of the files SentencePiece writes
            (default "tokenizer_model").

    Returns:
        A ``tft.SentencepieceTokenizer`` that emits int32 token ids.
    """
    # Train the SentencePiece tokenizer on our text file
    sp.SentencePieceTrainer.train(
        input=input_path,
        model_prefix=model_prefix,
        model_type="unigram",
        vocab_size=vocab_size,
    )

    # load the trained model file as raw bytes; the context manager closes the
    # handle, which the previous bare GFile(...).read() never did
    with tf.io.gfile.GFile(model_prefix + ".model", "rb") as model_file:
        trained_tokenizer_model = model_file.read()

    # load the model as a tokenizer that we can use for our models
    # NOTE(review): nbest_size=-1 together with alpha=1 appears to request
    # sampled (non-deterministic) segmentation per the tensorflow_text docs;
    # confirm this is intended for building a fixed training set.
    tokenizer = tft.SentencepieceTokenizer(
        model=trained_tokenizer_model, out_type=tf.int32, nbest_size=-1, alpha=1, reverse=False,
        add_bos=False, add_eos=False, return_nbest=False, name=None
    )

    return tokenizer

def prepare_data(text, tokenizer, vocab_size, inputlength_m=32, batch_size=32, train_fraction=0.7): # input_length_m between 32 and 256
    """
    Turn raw text into batched next-token-prediction datasets.

    The text is tokenized, the token stream is cut into a leading training part
    and a trailing validation part, and each part is sliced into overlapping
    windows of ``inputlength_m + 1`` tokens. The first ``inputlength_m`` tokens
    of a window are the model input, the last token is the target.

    The split is made on the token stream *before* windowing. Randomly
    splitting the windows themselves would put windows that share almost all
    of their tokens into both sets, so validation metrics would mostly measure
    memorisation of the training text.

    Args:
        text: Corpus as a single string.
            (Assumes ``tokenizer.tokenize(text)`` yields a rank-1 tensor of
            ids, as it does for a scalar string - TODO confirm for other inputs.)
        tokenizer: Object with a ``tokenize`` method returning integer ids.
        vocab_size: Depth of the one-hot encoded targets.
        inputlength_m: Number of input tokens per example.
        batch_size: Examples per batch (default 32).
        train_fraction: Share of the token stream used for training; the
            remainder is used for validation (default 0.7).

    Returns:
        ``(train_dataset, val_dataset)``: two ``tf.data.Dataset`` objects whose
        elements are ``(inputs, targets)`` with inputs of shape
        ``(batch, inputlength_m)`` and one-hot targets of shape
        ``(batch, vocab_size)``.

    Raises:
        ValueError: If ``train_fraction`` is not strictly between 0 and 1, or
            if either part of the token stream is shorter than one window.
    """
    if not 0.0 < train_fraction < 1.0:
        raise ValueError("train_fraction must be strictly between 0 and 1")

    # tokenize the text
    tokens = tokenizer.tokenize(text)

    window_width = inputlength_m + 1
    num_tokens = int(tf.shape(tokens)[0])
    split_index = int(num_tokens * train_fraction)
    if split_index < window_width or num_tokens - split_index < window_width:
        raise ValueError(
            "text is too short: need at least %d tokens in both the training "
            "and the validation part, got %d tokens in total"
            % (window_width, num_tokens)
        )

    def make_dataset(token_slice, shuffle):
        # create windows
        windowed_tokens = tft.sliding_window(data=token_slice, width=window_width)
        # the first m window tokens are inputs, the last one is the target
        inputs = windowed_tokens[:, :inputlength_m]
        targets = windowed_tokens[:, inputlength_m]
        # create TensorFlow dataset from the input-target pairs
        dataset = tf.data.Dataset.from_tensor_slices((inputs, targets))
        if shuffle:
            # Windows are stored in reading order, so shuffle over all of them.
            # Elements are still small integer rows at this point, which keeps
            # the full-size buffer cheap.
            dataset = dataset.shuffle(int(windowed_tokens.shape[0]))
        dataset = dataset.batch(batch_size)
        # One-hot encode per batch instead of materialising the whole
        # (num_windows, vocab_size) float matrix in memory up front.
        dataset = dataset.map(
            lambda batch_inputs, batch_targets: (batch_inputs, tf.one_hot(batch_targets, vocab_size)),
            num_parallel_calls=tf.data.AUTOTUNE,
        )
        return dataset.prefetch(tf.data.AUTOTUNE)

    # Shuffling only matters for gradient updates; validation order is irrelevant.
    train_dataset = make_dataset(tokens[:split_index], shuffle=True)
    val_dataset = make_dataset(tokens[split_index:], shuffle=False)

    return train_dataset, val_dataset

In [None]:
# Clean the uploaded corpus in place, then fit the SentencePiece tokenizer on it.
data_cleaning()
tokenizer = train_tokenizer()

# Must stay equal to the vocabulary size the tokenizer is trained with (2000),
# because it sets both the embedding table size and the one-hot target depth.
vocab_size = 2000
sequence_length = 128

# shorten dataset: keep only the first eighth of the characters to bound
# training time. The `with` block closes the file handle after reading.
with open("bible.txt") as bible_file:
    bible = bible_file.read()
bible = bible[:len(bible)//8]


train_data, val_data = prepare_data(bible, tokenizer,vocab_size,inputlength_m = sequence_length)


In [None]:
class MambaResBlock(tf.keras.Model):
  """
  Residual block with a gated selective-SSM branch.

  The input is layer-normalised and projected twice to ``projection_dim``.
  One projection runs through a causal depthwise convolution, SiLU and the
  selective SSM; the other only through SiLU and acts as a multiplicative
  gate. Their product is projected back to ``input_dim``, added to the block
  input, and passed through dropout.

  Args:
    input_dim: Feature size of the block input and output.
    projection_dim: Feature size used inside the block.
  """

  def __init__(self, input_dim, projection_dim):
    super().__init__()

    # normalisation
    self.layernorm = tf.keras.layers.LayerNormalization()
    # projection for the convolution / SSM branch
    self.dense1 = tf.keras.layers.Dense(units=projection_dim)
    # projection for the gating branch
    self.dense2 = tf.keras.layers.Dense(units=projection_dim)
    # Causal convolution with one group per channel (depthwise). The group
    # count follows projection_dim instead of a fixed 256 so the block also
    # works for other widths; for projection_dim=256 nothing changes.
    self.conv1d = tf.keras.layers.Conv1D(
        filters=projection_dim, kernel_size=4, strides=1, padding="causal",
        groups=projection_dim, data_format="channels_last")
    # SSM block with 32 states per channel, sized to the projection width
    self.ssm = SelectiveSSM(32, projection_dim)
    # projection back to the residual width
    self.dense3 = tf.keras.layers.Dense(units=input_dim)
    # dropout
    self.dropout = tf.keras.layers.Dropout(rate=0.2)

  def call(self, input, training=None):
    """
    Apply the block.

    Args:
      input: Tensor whose last dimension is ``input_dim``.
        (Assumed to be (batch, length, input_dim), as required by the
        causal Conv1D - TODO confirm against callers.)
      training: Keras training flag. It is declared explicitly and forwarded
        to Dropout so that dropout is active during ``fit`` and disabled
        during evaluation regardless of how the outer model is called.

    Returns:
      Tensor with the same shape as ``input``.
    """
    x = self.layernorm(input)

    # convolution + SSM branch
    x1 = self.dense1(x)
    x1 = self.conv1d(x1)
    x1 =  tf.nn.silu(x1)
    x1 = self.ssm(x1)

    # gating branch
    x2 = self.dense2(x)
    x2 =  tf.nn.silu(x2)

    x = x1 * x2
    x = self.dense3(x)

    # skip connection
    x = x + input

    # NOTE(review): dropout is applied after the residual add, so it also
    # zeroes parts of the skip path. Kept as is to preserve behaviour; applying
    # it before the add is the more common arrangement.
    x = self.dropout(x, training=training)

    return x

class SelectiveSSM(tf.keras.Model):
  """
  Selective state-space layer with input-dependent B, C and step size.

  ``A_log`` (shape ``(internal_dim, states)``) and ``D`` (shape
  ``(internal_dim,)``) are learned parameters; B, C and delta are produced
  from the input by dense layers on every call.

  Args:
    states: Number of state dimensions per channel.
    internal_dim: Number of channels of the input.
  """

  def __init__(self, states, internal_dim):
    super().__init__()

    self.states = states
    self.internal_dim = internal_dim

    # Open question from the original author: should A use a HiPPO
    # initialisation? That would require A to be square - square matrix or not?
    # Currently every channel's row of A starts as 1..states.
    def a_log_initializer(shape, dtype=None):
      A = repeat(tf.range(1, states+1, dtype=tf.float32),'n -> d n', d=internal_dim)
      return tf.cast(tf.math.log(A), dtype or tf.float32)

    # Parameters are created with add_weight rather than as bare tf.Variable
    # attributes: Keras 3 does not track raw tf.Variable attributes, so they
    # would be missing from trainable_weights (never updated, never saved).
    self.A_log = self.add_weight(
        name="A_log", shape=(internal_dim, states),
        initializer=a_log_initializer, trainable=True, dtype="float32")

    self.D = self.add_weight(
        name="D", shape=(internal_dim,),
        initializer="ones", trainable=True, dtype="float32")

    self.denseB = tf.keras.layers.Dense(units=self.states)
    self.denseC = tf.keras.layers.Dense(units=self.states)
    self.densedelta = tf.keras.layers.Dense(units=self.internal_dim)

  def selective_scan(self,u, delta, A, B, C, D):
    """
    Evaluate the state-space recurrence for all time steps with cumulative sums.

    Per the einsum signatures: ``u`` and ``delta`` are (batch, length, dim),
    ``A`` is (dim, states), ``B`` and ``C`` are (batch, length, states).
    ``D`` is multiplied element-wise with ``u``.

    Returns:
      Tensor of shape (batch, length, dim): the scan output plus ``u * D``.
    """
    # first step of A_bar = exp(ΔA), i.e., ΔA (delta times A)
    dA = tf.einsum('bld,dn->bldn', delta, A)
    # input times B times delta
    dB_u = tf.einsum('bld,bld,bln->bldn', delta, u, B)

    # Shift dA one step towards the start along the length axis and append a
    # zero step at the end (pad both sides, then drop the leading pad).
    dA_cumsum = tf.pad(
        dA[:, 1:], [[0, 0], [1, 1], [0, 0], [0, 0]])[:, 1:, :, :]

    dA_cumsum = tf.reverse(dA_cumsum, axis=[1])  # Flip along axis 1

    # Cumulative sum over the flipped sequence: after flipping back, position t
    # holds the sum of dA over all later positions.
    dA_cumsum = tf.math.cumsum(dA_cumsum, axis=1)

    # second step of A_bar = exp(ΔA), i.e., exp(ΔA)
    dA_cumsum = tf.exp(dA_cumsum)
    dA_cumsum = tf.reverse(dA_cumsum, axis=[1])  # Flip back along axis 1

    x = dB_u * dA_cumsum
    # 1e-12 to avoid division by 0
    # NOTE(review): for long sequences exp(...) of a large negative sum may
    # underflow, making this ratio numerically fragile - worth confirming.
    x = tf.math.cumsum(x, axis=1)/(dA_cumsum + 1e-12)

    y = tf.einsum('bldn,bln->bld', x, C)

    return y + u * D

  def call(self, input):
    """
    Run the selective scan on ``input`` with parameters derived from it.

    Args:
      input: Tensor whose last dimension is ``internal_dim``.

    Returns:
      Tensor with the same shape as ``input``.
    """
    # A is kept negative by construction: A = -exp(A_log), shape (d_in, n)
    A = -tf.exp(tf.cast(self.A_log, tf.float32))

    C = self.denseC(input)
    B = self.denseB(input)
    # softplus keeps the step size positive
    delta = tf.nn.softplus(self.densedelta(input))

    return self.selective_scan(input, delta, A, B, C, self.D)

class MambaModel(tf.keras.Model):
  """
  Next-token classifier: embedding, a stack of MambaResBlocks, then a dense head.

  Args:
    num_layers: Number of MambaResBlocks to stack.
    vocab_size: Size of the token vocabulary; used for the embedding table and
      for the width of the softmax output.
  """

  def __init__(self, num_layers, vocab_size):
    super().__init__()

    self.num_layers = num_layers

    # token ids -> 128-dimensional vectors
    self.embedding = tf.keras.layers.Embedding(input_dim = vocab_size, output_dim = 128)
    # residual blocks: width 128 on the residual path, 256 inside the block
    self.layer_list = [MambaResBlock(128, 256) for _ in range(num_layers)]
    # NOTE(review): this normalisation layer is created but never applied in
    # call(); kept so the attribute stays available. Confirm whether a final
    # norm before the head was intended.
    self.layernorm = tf.keras.layers.LayerNormalization()
    self.flatten = tf.keras.layers.Flatten()
    self.dense = tf.keras.layers.Dense(units=1024, activation=tf.nn.gelu)
    # softmax output: pair with a loss that expects probabilities, not logits
    self.out = tf.keras.layers.Dense(units=vocab_size, activation=tf.nn.softmax)

  def call(self, input, training=None):
    """
    Compute next-token probabilities.

    Args:
      input: Integer token ids.
        (Assumed to be (batch, length) with a fixed length, since the output
        of the blocks is flattened before the dense head - TODO confirm.)
      training: Keras training flag. It is deliberately part of the signature:
        Keras records the flag in its call context when the outer model
        accepts it, and nested layers such as Dropout read it from there.
        It is not forwarded explicitly so this method works with any block
        signature.

    Returns:
      Tensor of shape (batch, vocab_size) with softmax probabilities.
    """
    x = self.embedding(input)

    for block in self.layer_list:
      x = block(x)

    x = self.flatten(x)
    x = self.dense(x)
    x = self.out(x)

    return x

In [None]:
# Build a single-block Mamba model over the tokenizer vocabulary.
model = MambaModel(num_layers=1, vocab_size=vocab_size)

# Targets are one-hot vectors and the model ends in a softmax, hence plain
# categorical cross-entropy on probabilities.
loss = tf.keras.losses.CategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

# Attach loss, optimizer and the accuracy metric to the model.
model.compile(
    loss=loss,
    optimizer=optimizer,
    metrics=["accuracy"],
)

In [None]:
history = model.fit(train_data,validation_data=val_data, epochs=10)
# Persist the weights only. Saving a subclassed model as a whole-model ".h5"
# file is not supported by tf.keras 2 and cannot be reliably reloaded
# otherwise, and a failure here would come only after the full training run.
# To restore: rebuild MambaModel with the same arguments, call it once on a
# batch, then load_weights() from this file.
model.save_weights("mamba_07_04_2024.weights.h5")

Epoch 1/10




[1m2507/2507[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m19699s[0m 8s/step - accuracy: 0.1516 - loss: 5.2655 - val_accuracy: 0.2473 - val_loss: 4.1696
Epoch 2/10
[1m 582/2507[0m [32m━━━━[0m[37m━━━━━━━━━━━━━━━━[0m [1m3:34:23[0m 7s/step - accuracy: 0.3019 - loss: 3.5524

In [None]:
def visualise_results(history):
  """
  Plot loss and accuracy per epoch, one figure for training and one for validation.

  Args:
    history: Object returned by ``model.fit``; its ``history`` dict must
      contain the keys "loss", "accuracy", "val_loss" and "val_accuracy".

  Returns:
    None. Two figures are shown.
  """

  def plot_curves(loss_key, accuracy_key, title):
    # Both curves share one axis, so each gets a label and the legend tells
    # them apart.
    fig, ax = plt.subplots()
    ax.plot(history.history[loss_key], label="loss")
    ax.plot(history.history[accuracy_key], label="accuracy")
    ax.set_title(title)
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Accuracy, Loss")
    ax.legend()
    plt.show()

  plot_curves("loss", "accuracy", "Train- Accuracy and Loss")
  plot_curves("val_loss", "val_accuracy", "Validation- Accuracy and Loss")

visualise_results(history)