Install 🤗 Transformers and 🤗 Datasets

In [10]:
# ! pip install datasets transformers[torch]
# ! pip install --upgrade transformers[torch]
# ! pip install causal-conv1d
# ! pip install mamba-ssm
# ! pip install --upgrade mamba-ssm

In [2]:
from datasets import load_dataset, load_metric
# from peft import LoraConfig
from transformers import AutoTokenizer, TrainingArguments, Trainer, AutoModelForCausalLM, MambaPreTrainedModel, MambaModel, PretrainedConfig, AutoConfig, MambaForCausalLM
from transformers.modeling_outputs import SequenceClassifierOutput
import torch
import torch.nn as nn
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
import os
import numpy as np
from typing import Any, Dict, Optional, Tuple, Union
import pickle
from sklearn.metrics import accuracy_score

In [3]:
# Show the kernel's working directory; the relative output folder passed to
# TrainingArguments further down is created underneath it.
!pwd

/home/zytadam/mamba


In [4]:
# GLUE task to fine-tune on; must be one of the keys of `task_to_keys` below.
task = "sst2"
# Hub checkpoint that `model_init` loads the pretrained Mamba weights from.
model_id = "state-spaces/mamba-130m-hf"
# Used for both the per-device train and eval batch size in TrainingArguments.
batch_size = 64
# Upper bound (in tokens) applied through `truncation=True` when tokenizing.
max_seq_length = 512

# NOTE(review): presumably a debugging aid that makes CUDA kernel launches
# synchronous so errors surface at the offending call — confirm it is still
# wanted, since it typically slows training down.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

In [5]:
# "mnli-mm" is not a GLUE configuration of its own: it is loaded through the
# "mnli" configuration (the mismatched split is picked later on).
actual_task = task if task != "mnli-mm" else "mnli"

# Disabled: optional local pickle cache of the downloaded dataset.
# if os.path.exists("dataset.pkl"):
#         print(f"Loading dataset from dataset.pkl")
#         with open("dataset.pkl", 'rb') as f:
#             dataset = pickle.load(f)
# else:
#     print(f"dataset.pkl not found. Loading dataset from original source and saving it.")
#     dataset = load_dataset("glue", actual_task)
#     with open("dataset.pkl", 'wb') as f:
#         pickle.dump(dataset, f)

dataset = load_dataset("glue", actual_task)
metric = load_metric("glue", actual_task)


def metric_sklearn(y_pred, y_true):
    """Accuracy of `y_pred` against `y_true` after thresholding `y_pred` at 0.5."""
    return accuracy_score(y_true, y_pred > 0.5)


print(dataset, metric, sep="\n")

  metric = load_metric('glue', actual_task)
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


DatasetDict({
    train: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 67349
    })
    validation: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 872
    })
    test: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 1821
    })
})
Metric(name: "glue", features: {'predictions': Value(dtype='int64', id=None), 'references': Value(dtype='int64', id=None)}, usage: """
Compute GLUE evaluation metric associated to each GLUE dataset.
Args:
    predictions: list of predictions to score.
        Each translation should be tokenized into a list of tokens.
    references: list of lists of references for each translation.
        Each reference should be tokenized into a list of tokens.
Returns: depending on the GLUE subset, one or several of:
    "accuracy": Accuracy
    "f1": F1 score
    "pearson": Pearson Correlation
    "spearmanr": Spearman Correlation
    "matthews_correlation": Matthew Correlation
Example

In [6]:
# Width of the classification head: three classes for the MNLI variants, a
# single output for STS-B, two classes for every other task handled here.
if task.startswith("mnli"):
    num_labels = 3
elif task == "stsb":
    num_labels = 1
else:
    num_labels = 2

# Load the tokenizer from the same checkpoint as the model (`model_id`) so the
# two cannot drift apart when the checkpoint is changed in the config cell.
tokenizer = AutoTokenizer.from_pretrained(model_id)
# tokenizer.eos_token = "<|endoftext|>"
# Reuse the end-of-sequence token for padding so batches can be padded.
tokenizer.pad_token = tokenizer.eos_token

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [7]:
# Names of the text column(s) for each supported task; the second entry is
# None when the task has a single input sentence.
task_to_keys = {
    "cola":    ("sentence", None),
    "mnli":    ("premise", "hypothesis"),
    "mnli-mm": ("premise", "hypothesis"),
    "mrpc":    ("sentence1", "sentence2"),
    "qnli":    ("question", "sentence"),
    "qqp":     ("question1", "question2"),
    "rte":     ("sentence1", "sentence2"),
    "sst2":    ("sentence", None),
    "stsb":    ("sentence1", "sentence2"),
    "wnli":    ("sentence1", "sentence2"),
}
sentence1_key, sentence2_key = task_to_keys[task]

# Preview the first training example as a quick sanity check.
first_example = dataset['train'][0]
if sentence2_key is not None:
  print(f"Sentence 1: {first_example[sentence1_key]}")
  print(f"Sentence 2: {first_example[sentence2_key]}")
else:
  print(f"Sentence: {first_example[sentence1_key]}")
def preprocess_function(examples):
  """Tokenize a batch of examples.

  The text column(s) selected by `sentence1_key` / `sentence2_key` are passed
  to the tokenizer (as a pair when the task has two inputs) and truncated to
  `max_seq_length` tokens.
  """
  text_columns = [examples[sentence1_key]]
  if sentence2_key is not None:
    text_columns.append(examples[sentence2_key])
  return tokenizer(*text_columns, max_length=max_seq_length, truncation=True)

# Tokenize every split of the dataset; `batched=True` hands
# `preprocess_function` a batch of examples per call instead of one example.
encoded_dataset = dataset.map(preprocess_function, batched=True)

Sentence: hide new secretions from the parental units 


In [8]:
from transformers.models.mamba.modeling_mamba import *

def forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    cache_params: Optional["MambaCache"] = None,
    labels: Optional[torch.LongTensor] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    use_cache: Optional[bool] = None,
    **kwargs,  # carries `attention_mask` (and anything else the collator adds)
) -> Union[Tuple, "MambaCausalLMOutput"]:
    r"""
    Sequence-level forward pass: run the backbone, apply `self.lm_head` to
    every position and keep only the output at the last attended token of
    each sequence.

    Args:
        input_ids / inputs_embeds / cache_params / output_hidden_states /
        return_dict / use_cache: forwarded unchanged to `self.backbone`.
        labels: one target per sequence. Class indices when the head has
            several outputs; real-valued targets when it has exactly one.
        **kwargs: `attention_mask` of shape `(batch_size, sequence_length)`
            is read from here when present (1 = real token, 0 = padding).

    Returns:
        A tuple `(loss?, logits, *extra backbone outputs)` when `return_dict`
        is false, otherwise a `MambaCausalLMOutput` whose `logits` have shape
        `(batch_size, head_outputs)`.
    """
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    mamba_outputs = self.backbone(
        input_ids,
        cache_params=cache_params,
        inputs_embeds=inputs_embeds,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        use_cache=use_cache,
    )
    hidden_states = mamba_outputs[0]

    logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype)).float()

    # Index of the last attended token per sequence. Taking the largest
    # position whose mask is 1 works for right- AND left-padded batches
    # (for right padding it equals sum(mask) - 1). Without a mask every
    # position is treated as real, so the final position is used.
    num_sequences, sequence_length = logits.shape[0], logits.shape[1]
    attention_mask = kwargs.get("attention_mask", None)
    if attention_mask is None:
        last_token = torch.full(
            (num_sequences,), sequence_length - 1, dtype=torch.long, device=logits.device
        )
    else:
        positions = torch.arange(sequence_length, device=logits.device)
        last_token = (attention_mask.to(logits.device).long() * positions).amax(dim=1)

    # (batch, seq, outputs) -> (batch, outputs)
    logits = logits[torch.arange(num_sequences, device=logits.device), last_token]

    loss = None
    if labels is not None:
        # move labels to correct device to enable model parallelism
        labels = labels.to(logits.device)

        if logits.shape[-1] == 1:
            # Single-output head: regression target, so use squared error.
            loss = nn.MSELoss()(logits.squeeze(-1), labels.to(logits.dtype))
        else:
            loss = nn.CrossEntropyLoss()(logits, labels)

    if not return_dict:
        output = (logits,) + mamba_outputs[1:]
        return ((loss,) + output) if loss is not None else output

    return MambaCausalLMOutput(
        loss=loss,
        logits=logits,
        cache_params=mamba_outputs.cache_params,
        hidden_states=mamba_outputs.hidden_states,
    )

# Monkey-patch at class level: every MambaForCausalLM instance created from
# here on uses the `forward` defined above instead of the library's own.
MambaForCausalLM.forward = forward

def model_init():
    """Build a fresh model for the Trainer.

    Loads the pretrained checkpoint `model_id`, swaps its `lm_head` for a new
    linear layer with `num_labels` outputs, and marks the `cache_params` and
    `hidden_states` output fields as ignored at inference time.
    """
    classifier = MambaForCausalLM.from_pretrained(model_id)
    classifier.config.keys_to_ignore_at_inference = ["cache_params", "hidden_states"]
    head_in_features = classifier.config.d_model
    classifier.lm_head = torch.nn.Linear(head_in_features, num_labels)
    return classifier

In [9]:
# Metric used to select the best checkpoint: task-specific for STS-B and
# CoLA, plain accuracy for everything else.
metric_name = {"stsb": "pearson", "cola": "matthews_correlation"}.get(task, "accuracy")

def compute_metrics(eval_pred):
    """Turn raw model outputs into the metric dict the Trainer logs.

    Args:
        eval_pred: pair `(predictions, labels)`; `predictions` holds one row
            of model outputs per example.

    Returns:
        `{"accuracy": ...}` when `metric_name` is "accuracy"; otherwise the
        dict produced by the GLUE `metric`, so that the key named by
        `metric_name` (used as `metric_for_best_model`) is actually present.
    """
    predictions, labels = eval_pred
    if task == "stsb":
        # Regression task: the single output column is the prediction.
        predictions = predictions[:, 0]
    else:
        predictions = np.argmax(predictions, axis=1)

    if metric_name == "accuracy":
        # Compare class ids directly; thresholding them at 0.5 would merge
        # classes 1 and 2 on three-way tasks.
        return {'accuracy': accuracy_score(labels, predictions)}
    return metric.compute(predictions=predictions, references=labels)

# Name of the evaluation split: the MNLI variants have dedicated
# matched / mismatched splits, every other task uses "validation".
if task == "mnli-mm":
    validation_key = "validation_mismatched"
elif task == "mnli":
    validation_key = "validation_matched"
else:
    validation_key = "validation"

class MambaTrainer(Trainer):
  """Trainer that takes the loss from the model output and saves raw weights."""

  def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
    """Run the model on `inputs` and return the loss it computed.

    `num_items_in_batch` is accepted (and ignored) so the override stays
    callable by Trainer versions that pass this extra keyword.

    Returns `(loss, outputs)` when `return_outputs` is true, else `loss`.
    """
    outputs = model(**inputs)
    loss = outputs.loss
    return (loss, outputs) if return_outputs else loss

  def save_model(self, output_dir=None, _internal_call=False):
    """Write the model's state dict and the tokenizer files to `output_dir`.

    Both parameters now have defaults, so `trainer.save_model()` and
    `trainer.save_model(path)` work; a missing `output_dir` falls back to
    `self.args.output_dir`.
    """
    if output_dir is None:
      output_dir = self.args.output_dir
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(output_dir, exist_ok=True)

    torch.save(self.model.state_dict(), os.path.join(output_dir, "pytorch_model.bin"))
    if self.tokenizer is not None:
      self.tokenizer.save_pretrained(output_dir)

In [14]:
# Training once by arguments
torch.manual_seed(123)

args = TrainingArguments(
    f"Mamba-finetuned-{task}",
    evaluation_strategy = "epoch",
    save_strategy = "epoch",
    learning_rate=1.1e-4,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=3,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model=metric_name,
    push_to_hub=False,
    warmup_ratio=0.1,
    # max_steps=1
)

trainer = MambaTrainer(
    model=None,
    args=args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset[validation_key],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    model_init=model_init,
)

trainer.train()

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


Epoch,Training Loss,Validation Loss,Accuracy
1,0.1699,0.205166,0.932339
2,0.0951,0.214225,0.923165
3,0.0355,0.244805,0.931193


TrainOutput(global_step=3159, training_loss=0.11808190009155467, metrics={'train_runtime': 761.7225, 'train_samples_per_second': 265.25, 'train_steps_per_second': 4.147, 'total_flos': 4852811088511092.0, 'train_loss': 0.11808190009155467, 'epoch': 3.0})

In [15]:
# Evaluate on the validation split and display the selected metric.
trainer.evaluate()[f"eval_{metric_name}"]

0.9323394495412844