In [1]:
%%time
# Install the pinned third-party packages into the kernel's environment.
# The commented-out line below is an alternative command that additionally
# pins a CUDA-enabled TensorFlow build.
# %pip install tensorflow[and-cuda]==2.15.0.post1 transformers==4.36.2 einops==0.7.0 datasets==2.16.1
%pip install -qU transformers==4.36.2 einops==0.7.0 datasets==2.16.1

Note: you may need to restart the kernel to use updated packages.
CPU times: user 216 ms, sys: 40.3 ms, total: 256 ms
Wall time: 14.1 s


In [2]:
import math
import pprint
from dataclasses import dataclass, field
from typing import Union

import datasets
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from einops import rearrange, repeat
from tensorflow import keras
from tensorflow.keras import layers, Model
from transformers import AutoTokenizer

2024-08-15 20:41:46.625166: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-08-15 20:41:46.625234: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-08-15 20:41:46.626733: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [3]:
# Fixed number of tokens per example: used as the tokenizer's padding/truncation
# length and as the default `ModelArgs.seq_length` (the model's input length).
max_seq_length = 128

In [4]:
@dataclass
class ModelArgs:
    """Hyper-parameters and training configuration for the Mamba model.

    Every attribute below is a constructor argument. `__post_init__`
    additionally derives:

    * `model_internal_dim` = `projection_expand_factor * model_input_dims`
    * `delta_t_rank`       = `ceil(model_input_dims / 16)`

    Raises:
        ValueError: if `vocab_size` is None, if `use_lm_head` is False and
            `num_classes` is None, or if `loss` is None.
    """
    # Width of the token embedding / residual stream.
    model_input_dims: int = 64
    # Size of the SSM state (the `n` axis of A, B and C).
    model_states: int = 64
    projection_expand_factor: int = 2
    conv_kernel_size: int = 4
    # NOTE(review): the four delta_t_* values are not read by any code visible
    # in this notebook.
    delta_t_min: float = 0.001
    delta_t_max: float = 0.1
    delta_t_scale: float = 0.1
    delta_t_init_floor: float = 1e-4
    conv_use_bias: bool = True
    dense_use_bias: bool = False
    # -1 means "pick a random id in __post_init__".
    layer_id: int = -1
    # The default is captured from the notebook-level constant when this class
    # is defined.
    seq_length: int = max_seq_length
    num_layers: int = 5
    dropout_rate: float = 0.2
    use_lm_head: bool = True
    num_classes: Union[int, None] = None
    vocab_size: Union[int, None] = None
    # Factories give every config its own loss/optimizer object. Optimizers in
    # particular are stateful, so a single instance created at class-definition
    # time must not be shared between models.
    loss: Union[str, keras.losses.Loss] = field(
        default_factory=lambda: keras.losses.SparseCategoricalCrossentropy(from_logits=True))
    optimizer: Union[str, keras.optimizers.Optimizer] = field(
        default_factory=keras.optimizers.AdamW)
    # Appended after the pre-existing fields so positional construction keeps
    # its original order.
    final_activation: Union[str, None] = None
    metrics: list = field(default_factory=lambda: ['accuracy'])

    def __post_init__(self):
        """Derive dependent sizes and validate the configuration."""
        self.model_internal_dim: int = int(self.projection_expand_factor * self.model_input_dims)

        self.delta_t_rank = math.ceil(self.model_input_dims / 16)
        if self.layer_id == -1:
            # Plain Python int so the id formats cleanly into variable names.
            self.layer_id = int(np.random.randint(0, 1000))

        if self.vocab_size is None:
            raise ValueError("vocab size cannot be none")

        if self.use_lm_head:
            # Language-model head: one output class per vocabulary entry.
            self.num_classes = self.vocab_size
        else:
            if self.num_classes is None:
                raise ValueError(f'num classes cannot be {self.num_classes}')

            # Only choose an activation when the caller did not supply one.
            # NOTE(review): the default loss uses from_logits=True, which does
            # not match a sigmoid/softmax output - pass a matching loss when
            # use_lm_head is False.
            if self.final_activation is None:
                if self.num_classes == 1:
                    self.final_activation = 'sigmoid'
                else:
                    self.final_activation = 'softmax'

        if self.loss is None:
            raise ValueError(f"loss cannot be {self.loss}")

In [5]:
tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
# `tokenizer.vocab_size` counts only the base vocabulary, while `len(tokenizer)`
# also includes added tokens. The embedding table and the output layer are sized
# from this value, so it must cover every id the tokenizer can emit.
vocab_size = len(tokenizer)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [6]:
def selective_scan(u, delta, A, B, C, D):
    """Vectorised selective scan built from cumulative sums.

    Axis letters follow the einsum subscripts used below:
    b = batch, l = sequence position, d = channel, n = state.

    Args:
        u: input sequence, indexed as (b, l, d).
        delta: per-token step sizes, indexed as (b, l, d).
        A: state matrix, indexed as (d, n).
        B: input projection, indexed as (b, l, n).
        C: output projection, indexed as (b, l, n).
        D: skip-connection weight multiplied element-wise with `u`
            (presumably shape (d,) so it broadcasts - confirm with caller).

    Returns:
        Tensor indexed as (b, l, d): `einsum(x, C) + u * D`.

    With S[t] = sum of dA[r] for r > t, the code computes
        x[t] = cumsum_{s <= t}(dB_u[s] * exp(S[s])) / (exp(S[t]) + 1e-12)
    which, up to the epsilon, equals
        sum_{s <= t} dB_u[s] * exp(sum_{s < r <= t} dA[r]).

    NOTE(review): exp(S[t]) can presumably underflow to 0 for long sequences
    or large |dA|, in which case the division is dominated by the 1e-12
    epsilon - verify numerical behaviour for the shapes in use.
    """
    # first step of A_bar = exp(ΔA), i.e., ΔA
    dA = tf.einsum('bld,dn->bldn', delta, A)
    dB_u = tf.einsum('bld,bld,bln->bldn', delta, u, B)

    # Drop dA at position 0 and append a zero at the end, so that entry t holds
    # dA[t + 1] and the last entry is 0 (pad one step on both sides of the
    # length axis, then slice the leading pad away again).
    dA_cumsum = tf.pad(
        dA[:, 1:], [[0, 0], [1, 1], [0, 0], [0, 0]])[:, 1:, :, :]

    dA_cumsum = tf.reverse(dA_cumsum, axis=[1])  # Flip along axis 1

    # Cumulative sum over the flipped sequence axis (a suffix sum in the
    # original order), computed for all input tokens in parallel.
    dA_cumsum = tf.math.cumsum(dA_cumsum, axis=1)

    # second step of A_bar = exp(ΔA), i.e., exp(ΔA)
    dA_cumsum = tf.exp(dA_cumsum)
    dA_cumsum = tf.reverse(dA_cumsum, axis=[1])  # Flip back along axis 1

    # Scale each contribution by its suffix factor, prefix-sum over positions,
    # then divide the factor belonging to the current position back out.
    x = dB_u * dA_cumsum
    # 1e-12 to avoid division by 0
    x = tf.math.cumsum(x, axis=1)/(dA_cumsum + 1e-12)

    # Contract the state axis against C to obtain the per-channel output.
    y = tf.einsum('bldn,bln->bld', x, C)

    # Add the skip connection.
    return y + u * D

In [7]:
class MambaBlock(layers.Layer):
    """Single Mamba (selective state-space) mixer block.

    Pipeline implemented by `call`:
        in_projection -> split into (x, res)
        x: causal depthwise Conv1D -> swish -> selective SSM
        output: out_projection(ssm(x) * swish(res))

    Args:
        modelargs: `ModelArgs` providing all sizes and bias flags.
        *args, **kwargs: forwarded to `keras.layers.Layer`.
    """

    def __init__(self, modelargs: ModelArgs, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.args = modelargs
        cfg = modelargs
        self.layer_id = cfg.layer_id

        # Produces 2 * model_internal_dim features per token; `call` splits
        # them into the SSM branch and the gating branch. The input size is
        # inferred by Keras on first use.
        self.in_projection = layers.Dense(cfg.model_internal_dim * 2, use_bias=False)

        # groups == filters makes this a depthwise convolution; it runs on a
        # channels-first layout, so `call` transposes around it.
        self.conv1d = layers.Conv1D(
            filters=cfg.model_internal_dim,
            use_bias=cfg.conv_use_bias,
            kernel_size=cfg.conv_kernel_size,
            groups=cfg.model_internal_dim,
            data_format='channels_first',
            padding='causal'
        )

        # this layer takes in current token 'x'
        # and outputs the input-specific Δ, B, C (according to S6)
        self.x_projection = layers.Dense(cfg.delta_t_rank + cfg.model_states * 2, use_bias=False)

        # this layer projects Δ from delta_t_rank to the mamba internal
        # dimension
        self.delta_t_projection = layers.Dense(cfg.model_internal_dim, use_bias=True)

        # Initial (positive) magnitudes of A: each of the d rows is 1..n.
        self.A = repeat(
            tf.range(1, cfg.model_states + 1, dtype=tf.float32),
            'n -> d n', d=cfg.model_internal_dim)

        def _a_log_initializer(shape, dtype=None):
            # Custom initializer: start A_log at log(1..n) for every channel.
            return tf.cast(tf.math.log(self.A), dtype or tf.float32)

        # Register A_log and D through add_weight so the layer tracks them as
        # trainable weights. Keras 3 does not track a raw tf.Variable assigned
        # as an attribute, which would leave these parameters untrained.
        self.A_log = self.add_weight(
            name=f"SSM_A_log_{cfg.layer_id}",
            shape=(cfg.model_internal_dim, cfg.model_states),
            dtype="float32",
            initializer=_a_log_initializer,
            trainable=True)

        self.D = self.add_weight(
            name=f"SSM_D_{cfg.layer_id}",
            shape=(cfg.model_internal_dim,),
            dtype="float32",
            initializer="ones",
            trainable=True)

        self.out_projection = layers.Dense(
            cfg.model_input_dims,
            use_bias=cfg.dense_use_bias)

    def call(self, x):
        """Mamba block forward. This looks the same as Figure 3 in Section 3.4 in the Mamba paper.

        Args:
            x: tensor of rank 3; axis 1 is treated as the sequence axis and the
                last axis is projected by `in_projection`.

        Returns:
            Output of `out_projection`, whose last axis has size
            `model_input_dims`.

        Official Implementation:
            class Mamba, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L119
            mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
        """
        seq_len = x.shape[1]

        x_and_res = self.in_projection(x) # shape = (batch, seq_len, 2 * model_internal_dimension)
        (x, res) = tf.split(x_and_res,
                            [self.args.model_internal_dim,
                             self.args.model_internal_dim], axis=-1)

        # The convolution is channels-first, so move the sequence axis last,
        # convolve, and move it back.
        x = rearrange(x, 'b l d_in -> b d_in l')
        x = self.conv1d(x)[:, :, :seq_len]
        x = rearrange(x, 'b d_in l -> b l d_in')

        x = tf.nn.swish(x)
        y = self.ssm(x)
        # Gate the SSM output with the activated residual branch.
        y = y * tf.nn.swish(res)
        return self.out_projection(y)

    def ssm(self, x):
        """Runs the SSM. See:
            - Algorithm 2 in Section 3.2 in the Mamba paper
            - run_SSM(A, B, C, u) in The Annotated S4
            Official Implementation:
            mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311

        Args:
            x: activated convolution output, (batch, seq_len, model_internal_dim).

        Returns:
            Result of `selective_scan`, same leading layout as `x`.
        """
        n = self.args.model_states

        # Compute ∆ A B C D, the state space parameters.
        #     A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
        #     ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
        #                                  and is why Mamba is called **selective** state spaces)

        # A is stored as a log and negated after exponentiation, so it is
        # always strictly negative.
        A = -tf.exp(tf.cast(self.A_log, tf.float32)) # shape -> (d_in, n)
        D = tf.cast(self.D, tf.float32)

        x_dbl = self.x_projection(x) # shape -> (batch, seq_len, delta_t_rank + 2*n)

        (delta, B, C) = tf.split(
                x_dbl,
                num_or_size_splits=[self.args.delta_t_rank, n, n],
                axis=-1) # delta shape -> (batch, seq_len, delta_t_rank) & B, C shape -> (batch, seq_len, n)

        # Project Δ up to the internal width; softplus keeps it positive.
        delta = tf.nn.softplus(self.delta_t_projection(delta)) # shape -> (batch, seq_len, model_internal_dim)

        return selective_scan(x, delta, A, B, C, D)

In [8]:
class ResidualBlock(layers.Layer):
    """Pre-norm residual wrapper: `x + MambaBlock(LayerNorm(x))`."""

    def __init__(self, modelargs: ModelArgs, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.args = modelargs
        self.mixer = MambaBlock(modelargs)
        self.norm = layers.LayerNormalization(epsilon=1e-5)

    def call(self, x):
        """Normalise the input, run the Mamba mixer, then add the skip path.

        Official Implementation:
            Block.forward(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L297

        The official repository chains its blocks as
            [Add -> Norm -> Mamba] -> [Add -> Norm -> Mamba] -> ...
        (the first Add being a no-op) so that Add and Norm can be fused for
        speed. This layer uses the more familiar ordering
            [Norm -> Mamba -> Add] -> [Norm -> Mamba -> Add] -> ...
        """
        normalized = self.norm(x)
        mixed = self.mixer(normalized)
        return mixed + x

In [9]:
def init_model(args: ModelArgs):
    """Build and compile the Mamba model described by `args`.

    Architecture: token embedding -> `num_layers` x (ResidualBlock + Dropout)
    -> LayerNormalization -> Dense(1024, gelu) -> Dense(num_classes).
    When `args.use_lm_head` is False the sequence axis is flattened before the
    dense head, so a single prediction is produced per example instead of one
    per token.

    Args:
        args: model and training configuration.

    Returns:
        A compiled `keras.Model` named 'MambaTimeModel' whose single input is
        named 'input_ids'.
    """
    input_layer = layers.Input(shape=(args.seq_length,), name='input_ids')
    # The sequence length is already fixed by the Input layer, so the
    # deprecated `input_length` argument is not needed.
    x = layers.Embedding(args.vocab_size, args.model_input_dims)(input_layer)

    for i in range(args.num_layers):
        x = ResidualBlock(args, name=f"Residual_{i}")(x)
        x = layers.Dropout(args.dropout_rate)(x)

    x = layers.LayerNormalization(epsilon=1e-5)(x)

    # Flatten only when the model is NOT used as a language model.
    if not args.use_lm_head:
        x = layers.Flatten()(x)
    x = layers.Dense(1024, activation=tf.nn.gelu)(x)
    output_layer = layers.Dense(args.num_classes, activation=args.final_activation)(x)

    model = Model(inputs=input_layer, outputs=output_layer, name='MambaTimeModel')
    model.compile(
        loss=args.loss,
        optimizer=args.optimizer,
        metrics=args.metrics
    )

    return model

In [10]:
# Configure and build the language model (use_lm_head keeps its default).
# Note: `model_input_dims` is the embedding width, not a sequence length; it is
# merely set to the same number as `max_seq_length` here.
args = ModelArgs(
    model_input_dims=max_seq_length,
    model_states=32,
    num_layers=12,
    dropout_rate=0.2,
    vocab_size=vocab_size
)
model = init_model(args)
model.summary()

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


In [11]:
from datasets import load_dataset
# NOTE(review): `tqdm` is not used by any cell visible in this notebook.
from tqdm import tqdm

# Load only the first 1000 rows of the train split.
dataset = load_dataset("neuralwork/fashion-style-instruct", split="train[:1000]")

In [12]:
dataset

Dataset({
    features: ['input', 'completion', 'context'],
    num_rows: 1000
})

In [13]:
# End-of-sequence marker string of the tokenizer; appended to every formatted
# prompt and completion.
EOS_TOKEN = tokenizer.eos_token 
EOS_TOKEN

'<|endoftext|>'

In [14]:
# Prompt template: the first `{}` receives the sample's context and the second
# its input text (see the `.format(context, inputText)` call site).
inputFormatString = """You are a personal stylist recommending fashion advice and clothing combinations. Use the self body and style description below, combined with the event described in the context to generate 5 self-contained and complete outfit combinations.
### Context:
{}

### Input:
{}"""

# Target template: the single `{}` receives the sample's completion text.
outputFormatString = """### Completion:
{}
"""

In [15]:
# Build the 'inputText' / 'outputText' columns from the raw dataset columns
def formatDatasetAlpaca(sample):
    """Add the formatted prompt and target strings to one dataset row.

    Args:
        sample: mapping with the keys 'context', 'input' and 'completion'.

    Returns:
        The same mapping with two extra keys: 'inputText' (the filled prompt
        template) and 'outputText' (the filled completion template), each
        terminated by EOS_TOKEN.
    """
    prompt = inputFormatString.format(sample['context'], sample['input'])
    target = outputFormatString.format(sample['completion'])

    sample['inputText'] = prompt + EOS_TOKEN
    sample['outputText'] = target + EOS_TOKEN

    return sample

# Format every row and drop the original columns
dataset = dataset.map(formatDatasetAlpaca, remove_columns=dataset.column_names)

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [16]:
# 70/30 train/test split. A fixed seed keeps the split (and therefore every
# downstream result) reproducible across kernel restarts.
dataset = dataset.train_test_split(test_size=0.3, seed=42)
dataset

DatasetDict({
    train: Dataset({
        features: ['inputText', 'outputText'],
        num_rows: 700
    })
    test: Dataset({
        features: ['inputText', 'outputText'],
        num_rows: 300
    })
})

In [17]:
pprint.pprint(dataset['train'][0])

{'inputText': 'Below is an instruction that describes a task, paired with an '
              'input that provides further context. Write a response that '
              'appropriately completes the request.\n'
              '\n'
              '### Context:\n'
              'I usually go for jeans and tees with an edgy, androgynous look '
              "I'm going to a play / concert.\n"
              '\n'
              '### Input:\n'
              "I'm a tomboyish woman with big biceps and shoulders from "
              'weightlifting.<|endoftext|>',
 'outputText': '### Completion:\n'
               'Outfit Combination 1:\n'
               '- Top: A fitted leather jacket with shoulder pads to enhance '
               'your shoulders and give structure to the look.\n'
               '- Bottom: Black ripped skinny jeans for a edgy and rugged '
               'touch.\n'
               '- Shoe: Black combat boots for a comfortable yet stylish '
               'option.\n'
               '- A

In [18]:
dataset

DatasetDict({
    train: Dataset({
        features: ['inputText', 'outputText'],
        num_rows: 700
    })
    test: Dataset({
        features: ['inputText', 'outputText'],
        num_rows: 300
    })
})

In [19]:
# Pull the prompt and completion columns out of each split as Python lists.
train_split = dataset['train']
test_split = dataset['test']

train_texts, train_completions = train_split['inputText'], train_split['outputText']
test_texts, test_completions = test_split['inputText'], test_split['outputText']

In [20]:
np.shape(train_texts)

(700,)

In [21]:
def _tokenize_fixed_length(texts):
    """Tokenize `texts`, truncating/padding every example to max_seq_length."""
    return tokenizer(texts, truncation=True, padding='max_length', max_length=max_seq_length)


train_encodings = _tokenize_fixed_length(train_texts)
train_labels = _tokenize_fixed_length(train_completions)

test_encodings = _tokenize_fixed_length(test_texts)
test_labels = _tokenize_fixed_length(test_completions)

In [22]:
np.shape(train_encodings['input_ids'])

(700, 128)

In [24]:
train_labels.keys()

dict_keys(['input_ids', 'attention_mask'])

In [25]:
# Convert to TensorFlow Datasets
def _as_tf_dataset(encodings, label_encodings):
    """Pair the tokenizer output (as a feature dict) with the label token ids."""
    features = dict(encodings)               # input_ids and attention_mask
    targets = label_encodings['input_ids']   # labels (target sequences)
    return tf.data.Dataset.from_tensor_slices((features, targets))


train_dataset = _as_tf_dataset(train_encodings, train_labels)

test_dataset = _as_tf_dataset(test_encodings, test_labels)

In [26]:
# Shuffle and batch the datasets.
# Shuffle BEFORE batching so individual examples are re-mixed every epoch;
# shuffling after batching would only permute already-formed batches.
BATCH_SIZE = 8
SHUFFLE_BUFFER_SIZE = 1000  # larger than the training split, i.e. a full shuffle
train_dataset = train_dataset.shuffle(SHUFFLE_BUFFER_SIZE).batch(BATCH_SIZE)
test_dataset = test_dataset.batch(BATCH_SIZE)

In [27]:
# Print out the train dataset tensors; if you're insane
# list(train_dataset.as_numpy_iterator())

In [29]:
%%time
# Train for a single epoch, validating on the held-out split.
# NOTE(review): the recorded output of this cell reports `loss: nan`;
# presumably worth checking that every token id fed to the model is smaller
# than the embedding/vocabulary size - TODO confirm.
history = model.fit(train_dataset, validation_data=test_dataset, epochs=1)

I0000 00:00:1723754592.113708     725 device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
W0000 00:00:1723754592.174646     725 graph_launch.cc:671] Fallback to op-by-op mode because memset node breaks graph update
W0000 00:00:1723754592.175817     725 graph_launch.cc:671] Fallback to op-by-op mode because memset node breaks graph update
W0000 00:00:1723754592.176576     725 graph_launch.cc:671] Fallback to op-by-op mode because memset node breaks graph update
W0000 00:00:1723754592.177229     725 graph_launch.cc:671] Fallback to op-by-op mode because memset node breaks graph update
W0000 00:00:1723754592.177947     725 graph_launch.cc:671] Fallback to op-by-op mode because memset node breaks graph update
W0000 00:00:1723754592.178664     725 graph_launch.cc:671] Fallback to op-by-op mode because memset node breaks graph update
W0000 00:00:1723754592.179370     725 graph_launch.cc:671] Fallback to op-by-op mode because m

[1m80/88[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m4s[0m 532ms/step - accuracy: 0.0081 - loss: nan

W0000 00:00:1723754679.593438     724 graph_launch.cc:671] Fallback to op-by-op mode because memset node breaks graph update


[1m88/88[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m176s[0m 1s/step - accuracy: 0.0075 - loss: nan - val_accuracy: 0.0000e+00 - val_loss: nan
CPU times: user 3min 10s, sys: 2.5 s, total: 3min 13s
Wall time: 2min 56s
