In [1]:
import datasets 
import torch 
import numpy as np
from tqdm import tqdm
from transformers import PreTrainedTokenizerFast
from transformers import Trainer, TrainingArguments
from transformers import RobertaForSequenceClassification
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

import os 
import warnings
from typing import List
from collections import defaultdict

warnings.filterwarnings("ignore")

os.environ['CUDA_VISIBLE_DEVICES'] = "5,6,7"

MAX_LENGTH = 64

In [None]:
def handle_sample(sample):
    texts = sample['text']
    labels = sample['label']
    
    flattened = defaultdict(list)

    for text, label in zip(texts, labels):
        tokenized = tokenizer(
            text,
            padding='max_length',
            max_length=MAX_LENGTH,
            return_overflowing_tokens=True,
            truncation=True,
            return_special_tokens_mask=True,
        )

        for i in range(len(tokenized['input_ids'])):
            for k in tokenized:
                flattened[k].append(tokenized[k][i])
            flattened['label'].append(label)

    return dict(flattened)

tokenizer = PreTrainedTokenizerFast.from_pretrained("../MalBERTa")
dataset = datasets.load_from_disk("../data/raw")
processed_dataset = dataset.map(
    handle_sample,
    remove_columns=dataset['test'].column_names,
    batch_size=64,
    batched=True,
    num_proc=8,
)

processed_dataset

DatasetDict({
    train: Dataset({
        features: ['label', 'input_ids', 'token_type_ids', 'attention_mask', 'special_tokens_mask', 'overflow_to_sample_mapping'],
        num_rows: 9230196
    })
    test: Dataset({
        features: ['label', 'input_ids', 'token_type_ids', 'attention_mask', 'special_tokens_mask', 'overflow_to_sample_mapping'],
        num_rows: 2291009
    })
})

In [3]:
# subset_size = 0.1
# processed_dataset['test'] = processed_dataset['test'].shuffle().select(range(int(len(processed_dataset['test']) * subset_size)))
# processed_dataset['train'] = processed_dataset['train'].shuffle().select(range(int(len(processed_dataset['train']) * subset_size)))

# processed_dataset
(len(processed_dataset['train']) // 256) // 10

3605

In [4]:
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = logits.argmax(axis=-1)
    return {
        "accuracy": accuracy_score(labels, predictions),
        "precision": precision_score(labels, predictions, average="weighted", zero_division=0),
        "recall": recall_score(labels, predictions, average="weighted", zero_division=0),
        "f1": f1_score(labels, predictions, average="weighted", zero_division=0),
    }

model = RobertaForSequenceClassification.from_pretrained("./MalBERTa")

train_args = TrainingArguments(
    output_dir="./MalBERTa-classifier",
    overwrite_output_dir=True,
    num_train_epochs=1,
    per_device_train_batch_size=256, 
    per_device_eval_batch_size=512, 
    save_strategy="no",
    eval_strategy="steps",
    logging_steps=100,
    eval_steps=(len(processed_dataset['train']) // 256) // 10,
    report_to="wandb",
)

trainer = Trainer(
    model=model,
    args=train_args, 
    processing_class=tokenizer,
    train_dataset=processed_dataset['train'],
    eval_dataset=processed_dataset['test'],
    compute_metrics=compute_metrics,
)

trainer.train()

Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at ./MalBERTa and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mlainon[0m ([33mhenry-williams[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
3605,0.3474,0.370647,0.825373,0.837079,0.825373,0.821213
7210,0.3389,0.362221,0.830187,0.844075,0.830187,0.825777
10815,0.334,0.362119,0.829934,0.845364,0.829934,0.825212


TrainOutput(global_step=12019, training_loss=0.350899040227881, metrics={'train_runtime': 2283.5552, 'train_samples_per_second': 4042.029, 'train_steps_per_second': 5.263, 'total_flos': 1875488398783488.0, 'train_loss': 0.350899040227881, 'epoch': 1.0})