In [18]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [19]:
import pickle
from datasets import Dataset

def load_data_from_pickle(filepath):
    """Read a pickled mapping from disk and return its texts and labels.

    Args:
        filepath: Path of the pickle file; it is opened in binary mode.

    Returns:
        A ``(texts, labels)`` tuple taken from the ``'texts'`` and
        ``'labels'`` entries of the unpickled object.

    Raises:
        KeyError: If either entry is missing from the unpickled mapping.
    """
    # Unpickling can execute arbitrary code, so only point this at trusted files.
    with open(filepath, 'rb') as handle:
        payload = pickle.load(handle)
    texts, labels = payload['texts'], payload['labels']
    return texts, labels

def prepare_dataset(texts, labels, test_size=0.2, seed=None):
    """Wrap texts and labels in a ``Dataset`` and split it into train/test.

    Args:
        texts: Input strings; stored in the ``'texts'`` column.
        labels: Targets aligned with ``texts``; stored in the ``'tg'`` column.
        test_size: Value forwarded to ``train_test_split`` as the size of the
            test split. Defaults to 0.2, the value previously hard-coded.
        seed: Seed forwarded to ``train_test_split`` so the shuffled split can
            be reproduced across kernel restarts. The default ``None`` keeps
            the earlier behaviour of an unseeded shuffle.

    Returns:
        Whatever ``Dataset.train_test_split`` returns; the rest of the
        notebook indexes it with ``"train"`` and ``"test"``.
    """
    columns = {
        'texts': texts,
        'tg': labels,
    }
    ds = Dataset.from_dict(columns)
    return ds.train_test_split(test_size=test_size, shuffle=True, seed=seed)

In [20]:
# Path of the pickled training data (despite the name, this is a file path).
dataset_id = 'dataset/addition_dataset_ct.pkl'
# Load the raw texts/labels, then build the shuffled train/test split.
texts, labels = load_data_from_pickle(dataset_id)
dataset = prepare_dataset(texts, labels)

In [21]:
from transformers import AutoTokenizer

# Hub identifier used only to fetch the tokenizer in this cell.
model_path = 'afmck/testing-llama-tiny'

tokenizer = AutoTokenizer.from_pretrained(model_path)
# Reuse the end-of-sequence token id as the padding token id.
tokenizer.pad_token_id= tokenizer.eos_token_id

In [22]:
import numpy as np
def preprocess_function(example, num_classes=10):
    """Tokenize one example and reduce its flat target vector to class indices.

    Args:
        example: Mapping with a ``'texts'`` entry (passed to the module-level
            ``tokenizer``) and a ``'tg'`` entry holding the target values.
        num_classes: Width of each consecutive group in ``'tg'``. Defaults to
            10, the value that was previously hard-coded. When used through
            ``Dataset.map`` it can be overridden with ``fn_kwargs``.

    Returns:
        The tokenizer output with an added ``'labels'`` entry: for every
        group of ``num_classes`` target values, the index of its largest entry.

    Raises:
        ValueError: If ``num_classes`` is not positive or the number of target
            values is not a multiple of ``num_classes``.
    """
    target = np.array(example['tg'], dtype=np.int64)
    if num_classes <= 0 or target.size % num_classes != 0:
        raise ValueError(
            f"cannot split {target.size} target values into groups of {num_classes}"
        )
    # NOTE(review): multi_label_metrics reshapes to (batch, 22, 11), i.e. groups
    # of 11, while the default here is groups of 10 - confirm which is intended.
    encoded = tokenizer(example['texts'], padding=True, truncation=True)
    encoded['labels'] = np.argmax(target.reshape(-1, num_classes), axis=1)
    return encoded

# Run preprocess_function over the split dataset; batched=True is not passed to map here.
tokenized_dataset = dataset.map(preprocess_function)

Map:   0%|          | 0/8000 [00:00<?, ? examples/s]

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

In [23]:
from transformers import DataCollatorWithPadding

# Collator built around the tokenizer; it is handed to the Trainer below to assemble padded batches.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

In [34]:
import evaluate
import numpy as np
from sklearn.metrics import f1_score, roc_auc_score, accuracy_score
from transformers import EvalPrediction
import torch
    

def sigmoid(x):
    """Return the logistic function ``1 / (1 + exp(-x))`` of ``x``.

    ``x`` may be a scalar or a NumPy array; ``np.exp`` applies element-wise.
    """
    denominator = 1 + np.exp(-x)
    return 1/denominator


def multi_label_metrics(predictions, labels, threshold=0.5, num_positions=22, num_classes=11):
    """Compute per-position and exact-match accuracy for grid-shaped outputs.

    Both ``predictions`` and ``labels`` are reshaped to
    ``(batch, num_positions, num_classes)`` and reduced with ``argmax`` over
    the last axis before being compared.

    Args:
        predictions: Array whose first dimension is the batch and which holds
            ``num_positions * num_classes`` values per sample.
        labels: Array with the same number of values per sample as
            ``predictions``.
        threshold: Unused; kept so existing callers keep working.
        num_positions: Number of groups per sample. Defaults to 22, the value
            previously hard-coded.
        num_classes: Number of values per group. Defaults to 11, the value
            previously hard-coded.

    Returns:
        Dict with ``'accuracy'`` (fraction of positions whose argmax matches)
        and ``'abs_accuracy'`` (fraction of samples where every position
        matches).
    """
    batch_size = predictions.shape[0]
    grid = (batch_size, num_positions, num_classes)
    y_pred = np.argmax(predictions.reshape(grid), axis=-1)
    y_true = np.argmax(labels.reshape(grid), axis=-1)
    matches = y_pred == y_true
    # A sample counts as exactly right only if all of its positions match.
    exact = np.all(matches, axis=-1)
    return {'accuracy': np.mean(matches),
            'abs_accuracy': np.mean(exact)}



def compute_metrics(eval_pred):
    """Unpack an evaluation result and score it with ``multi_label_metrics``.

    Args:
        eval_pred: Two-item iterable of ``(predictions, labels)``.

    Returns:
        The metrics dict produced by ``multi_label_metrics``.
    """
    logits, targets = eval_pred
    return multi_label_metrics(predictions=logits, labels=targets)
     


In [35]:
# Same assignment as the one made right after the tokenizer is loaded; repeated here.
tokenizer.pad_token_id = tokenizer.eos_token_id

In [36]:
from transformers import MambaConfig
# Configuration for a small Mamba classifier trained from scratch.
config = MambaConfig(
        vocab_size=tokenizer.vocab_size,  # Based on the number of unique tokens
        hidden_size = 256,
        state_size = 6,
        num_hidden_layers = 22,
        # Was misspelled `expend`; the config accepted it as an unused extra
        # attribute instead of the real `expand` argument.
        expand = 2,
        conv_kernel = 4,
        use_cache = True,
        # 242 == 22 * 11, the grid multi_label_metrics reshapes the outputs to.
        num_labels = 242,
        pad_token_id = tokenizer.eos_token_id,
        problem_type = "multi_label_classification"
    )

In [37]:
from Mamba4SC import MambaForSequenceClassification

In [38]:
# Build the classifier from the config alone; no from_pretrained call is made here.
model = MambaForSequenceClassification(config)

In [39]:
# Report the size of the freshly built model.
print(f"Total parameters in the model: {model.num_parameters()}")

Total parameters in the model: 17552640


In [40]:
from transformers import Trainer, TrainingArguments
# Training hyper-parameters: evaluate and checkpoint once per epoch for 100 epochs.
training_args = TrainingArguments(

   output_dir="my_awesome_model",
   learning_rate=5e-5,
   per_device_train_batch_size=64,
   per_device_eval_batch_size=64,
   num_train_epochs=100,
   weight_decay=0.01,
   # Evaluation and saving use the same "epoch" schedule.
   evaluation_strategy="epoch",
   # NOTE(review): no save_total_limit is set, so presumably one checkpoint is
   # kept per epoch (100 in total) - confirm the disk budget.
   save_strategy="epoch",
   # NOTE(review): metric_for_best_model is not set - confirm which metric
   # should decide the "best" checkpoint.
   load_best_model_at_end=True,
)

# Wire the model, the tokenized train/test splits, the padding collator and the
# metric callback into a Trainer.
trainer = Trainer(

   model=model,
   args=training_args,
   train_dataset=tokenized_dataset["train"],
   eval_dataset=tokenized_dataset["test"],
   tokenizer=tokenizer,
   data_collator=data_collator,
   compute_metrics=compute_metrics,
)



huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [41]:
# NOTE(review): the recorded run of this cell failed with a RuntimeError about
# reshaping 15488 values to [-1, 10]. 15488 == 64 * 242 (train batch size times
# num_labels), and 10 is the group width used in preprocess_function, whereas
# num_labels and multi_label_metrics assume 22 groups of 11. Presumably the
# class width must be made consistent (and checked against Mamba4SC) - confirm.
trainer.train()

RuntimeError: shape '[-1, 10]' is invalid for input of size 15488

In [32]:
# Evaluate the trainer on the files addition_dataset_ct6.pkl ... addition_dataset_ct20.pkl in turn.
for i in range(6,21):
    print('Class: ', i)
    # Rebinds the notebook-level dataset_id / texts / labels set by the loading cell.
    dataset_id = f'dataset/addition_dataset_ct{i}.pkl'
    texts, labels = load_data_from_pickle(dataset_id)
    # prepare_dataset splits each file again; only its "test" split is scored below.
    test_dataset = prepare_dataset(texts, labels)
    # Both splits are preprocessed here although only tt_dataset["test"] is evaluated.
    tt_dataset = test_dataset.map(preprocess_function)
    print(trainer.evaluate(tt_dataset["test"]))

Class:  6


Map:   0%|          | 0/8000 [00:00<?, ? examples/s]

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

{'eval_loss': 0.9485108852386475, 'eval_accuracy': 0.7318636363636364, 'eval_abs_accuracy': 0.0, 'eval_runtime': 0.6187, 'eval_samples_per_second': 3232.36, 'eval_steps_per_second': 51.718, 'epoch': 100.0}
Class:  7


Map:   0%|          | 0/8000 [00:00<?, ? examples/s]

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

{'eval_loss': 1.251745343208313, 'eval_accuracy': 0.6898636363636363, 'eval_abs_accuracy': 0.0, 'eval_runtime': 0.6648, 'eval_samples_per_second': 3008.503, 'eval_steps_per_second': 48.136, 'epoch': 100.0}
Class:  8


Map:   0%|          | 0/8000 [00:00<?, ? examples/s]

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

{'eval_loss': 1.5513814687728882, 'eval_accuracy': 0.6493409090909091, 'eval_abs_accuracy': 0.0, 'eval_runtime': 0.7136, 'eval_samples_per_second': 2802.77, 'eval_steps_per_second': 44.844, 'epoch': 100.0}
Class:  9


Map:   0%|          | 0/8000 [00:00<?, ? examples/s]

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

{'eval_loss': 1.8684321641921997, 'eval_accuracy': 0.6074318181818181, 'eval_abs_accuracy': 0.0, 'eval_runtime': 0.7414, 'eval_samples_per_second': 2697.442, 'eval_steps_per_second': 43.159, 'epoch': 100.0}
Class:  10


Map:   0%|          | 0/8000 [00:00<?, ? examples/s]

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

{'eval_loss': 2.1233696937561035, 'eval_accuracy': 0.5673409090909091, 'eval_abs_accuracy': 0.0, 'eval_runtime': 0.7136, 'eval_samples_per_second': 2802.677, 'eval_steps_per_second': 44.843, 'epoch': 100.0}
Class:  11


Map:   0%|          | 0/8000 [00:00<?, ? examples/s]

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

{'eval_loss': 2.419036865234375, 'eval_accuracy': 0.5277272727272727, 'eval_abs_accuracy': 0.0, 'eval_runtime': 0.7317, 'eval_samples_per_second': 2733.315, 'eval_steps_per_second': 43.733, 'epoch': 100.0}
Class:  12


Map:   0%|          | 0/8000 [00:00<?, ? examples/s]

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

{'eval_loss': 2.7186620235443115, 'eval_accuracy': 0.4864772727272727, 'eval_abs_accuracy': 0.0, 'eval_runtime': 0.7444, 'eval_samples_per_second': 2686.573, 'eval_steps_per_second': 42.985, 'epoch': 100.0}
Class:  13


Map:   0%|          | 0/8000 [00:00<?, ? examples/s]

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

{'eval_loss': 3.000615119934082, 'eval_accuracy': 0.44388636363636363, 'eval_abs_accuracy': 0.0, 'eval_runtime': 0.7955, 'eval_samples_per_second': 2514.211, 'eval_steps_per_second': 40.227, 'epoch': 100.0}
Class:  14


Map:   0%|          | 0/8000 [00:00<?, ? examples/s]

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

{'eval_loss': 3.303467273712158, 'eval_accuracy': 0.40293181818181817, 'eval_abs_accuracy': 0.0, 'eval_runtime': 0.8735, 'eval_samples_per_second': 2289.569, 'eval_steps_per_second': 36.633, 'epoch': 100.0}
Class:  15


Map:   0%|          | 0/8000 [00:00<?, ? examples/s]

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

{'eval_loss': 3.5690181255340576, 'eval_accuracy': 0.36472727272727273, 'eval_abs_accuracy': 0.0, 'eval_runtime': 0.8709, 'eval_samples_per_second': 2296.591, 'eval_steps_per_second': 36.745, 'epoch': 100.0}
Class:  16


Map:   0%|          | 0/8000 [00:00<?, ? examples/s]

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

{'eval_loss': 3.9011664390563965, 'eval_accuracy': 0.3214772727272727, 'eval_abs_accuracy': 0.0, 'eval_runtime': 0.8488, 'eval_samples_per_second': 2356.266, 'eval_steps_per_second': 37.7, 'epoch': 100.0}
Class:  17


Map:   0%|          | 0/8000 [00:00<?, ? examples/s]

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

{'eval_loss': 4.195313453674316, 'eval_accuracy': 0.2834090909090909, 'eval_abs_accuracy': 0.0, 'eval_runtime': 0.8608, 'eval_samples_per_second': 2323.503, 'eval_steps_per_second': 37.176, 'epoch': 100.0}
Class:  18


Map:   0%|          | 0/8000 [00:00<?, ? examples/s]

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

{'eval_loss': 4.468696117401123, 'eval_accuracy': 0.24186363636363636, 'eval_abs_accuracy': 0.0, 'eval_runtime': 0.8819, 'eval_samples_per_second': 2267.939, 'eval_steps_per_second': 36.287, 'epoch': 100.0}
Class:  19


Map:   0%|          | 0/8000 [00:00<?, ? examples/s]

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

{'eval_loss': 4.761209964752197, 'eval_accuracy': 0.2000909090909091, 'eval_abs_accuracy': 0.0, 'eval_runtime': 0.8943, 'eval_samples_per_second': 2236.427, 'eval_steps_per_second': 35.783, 'epoch': 100.0}
Class:  20


Map:   0%|          | 0/8000 [00:00<?, ? examples/s]

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

{'eval_loss': 5.064380645751953, 'eval_accuracy': 0.15986363636363637, 'eval_abs_accuracy': 0.0, 'eval_runtime': 0.9427, 'eval_samples_per_second': 2121.599, 'eval_steps_per_second': 33.946, 'epoch': 100.0}


In [33]:
# Score the model on the training split itself, as a reference point for the held-out results.
trainer.evaluate(tokenized_dataset["train"])

{'eval_loss': 0.307955265045166,
 'eval_accuracy': 0.9184204545454545,
 'eval_abs_accuracy': 0.18225,
 'eval_runtime': 2.3738,
 'eval_samples_per_second': 3370.169,
 'eval_steps_per_second': 52.659,
 'epoch': 100.0}