# Fine-Tuning DistilBERT for classification of NSFW prompts

We first import the libraries we need. TensorFlow, Transformers, scikit-learn, and NumPy must already be installed.

In [2]:
import tensorflow as tf
from transformers import DistilBertTokenizer, TFDistilBertForSequenceClassification
from sklearn.model_selection import train_test_split
import numpy as np

## Import the Data

First, we load the training data from [Hugging Face](https://huggingface.co/).

In [3]:
from datasets import load_dataset

dataset = load_dataset("thefcraft/civitai-stable-diffusion-337k")

Extract the training split from the dataset

In [4]:
data = dataset["train"]

Number of examples in the dataset

In [5]:
len(data)

327138

## Preprocess the data

We will then extract the prompts, negative prompts, and the labels

In [6]:
# Extract the prompts, negative prompts, and NSFW labels
prompts = [d['prompt'] for d in data]
neg_prompts = [d['negativePrompt'] for d in data]
labels = [d['nsfw'] for d in data]

Display an NSFW example from the data

In [7]:
print("NSFW: " + str(labels[1]))
print("Prompt: " + prompts[1])
print("Negative Prompt: " + neg_prompts[1])

NSFW: True
Prompt: masterpiece, best quality, highres, absurdres, concept art, character profile, reference sheet, turnaround, logo, 1girl, revealing clothes, nipples, topless, exhibitionism, school uniform
Negative Prompt: EasyNegative, extra fingers, fewer fingers


Display a non-NSFW example from the data

In [8]:
print("NSFW: " + str(labels[7]))
print("Prompt: " + prompts[7])
print("Negative Prompt: " + neg_prompts[7])

NSFW: False
Prompt: (8k, RAW photo, best quality, masterpiece:1.2), (realistic, photo-realistic:1.37),<lora:koreanDollLikeness_v10:0.5> ,<lora:arknightsTexasThe_v10:1>,omertosa,1girl,(Kpop idol), (aegyo sal:1),cute,cityscape, night, rain, wet, professional lighting, photon mapping, radiosity, physically-based rendering,
Negative Prompt: EasyNegative, paintings, sketches, (worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, ((monochrome)), ((grayscale)), skin spots, acnes, skin blemishes, age spot, glans,extra fingers,fewer fingers,strange fingers,bad hand


We will combine the two prompts into one sentence of the following format:

"Positive prompt: `pos_prompt`. Negative prompt: `neg_prompt`"

In [9]:
# Combine positive and negative prompts
prompts_combined = []

for pos_prompt, neg_prompt in zip(prompts, neg_prompts):
    combined_prompt = f"Positive prompt: {pos_prompt}. Negative prompt: {neg_prompt}"
    prompts_combined.append(combined_prompt)
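
The same template can be wrapped in a small helper, which makes it easy to try on a single pair. This is only a minimal sketch: the function name `combine_prompts` is hypothetical, and treating a missing (`None`) prompt as an empty string is an assumption, since the notebook does not say whether the dataset contains missing values.

```python
def combine_prompts(pos_prompt, neg_prompt):
    """Join a positive and a negative prompt using the template above."""
    # Assumption: a missing prompt is treated as an empty string.
    pos_prompt = pos_prompt or ""
    neg_prompt = neg_prompt or ""
    return f"Positive prompt: {pos_prompt}. Negative prompt: {neg_prompt}"


example = combine_prompts("1girl, cityscape, night, rain", "EasyNegative, extra fingers")
print(example)
```

Keeping the two labels in the text lets the model tell which part of the input was the positive prompt and which was the negative one.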

In [10]:
# Load DistilBERT tokenizer
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

# Tokenization
inputs = tokenizer(prompts_combined, padding=True, truncation=True, return_tensors="tf")

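The tokenizer output has two entries:
* `input_ids`: the token IDs of each combined prompt, padded with 0 to a common length
* `attention_mask`: 1 for positions that hold a real token, 0 for padding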

In [11]:
inputs

{'input_ids': <tf.Tensor: shape=(327138, 512), dtype=int32, numpy=
array([[  101,  3893, 25732, ...,     0,     0,     0],
       [  101,  3893, 25732, ...,     0,     0,     0],
       [  101,  3893, 25732, ...,     0,     0,     0],
       ...,
       [  101,  3893, 25732, ...,     0,     0,     0],
       [  101,  3893, 25732, ...,     0,     0,     0],
       [  101,  3893, 25732, ...,     0,     0,     0]])>, 'attention_mask': <tf.Tensor: shape=(327138, 512), dtype=int32, numpy=
array([[1, 1, 1, ..., 0, 0, 0],
       [1, 1, 1, ..., 0, 0, 0],
       [1, 1, 1, ..., 0, 0, 0],
       ...,
       [1, 1, 1, ..., 0, 0, 0],
       [1, 1, 1, ..., 0, 0, 0],
       [1, 1, 1, ..., 0, 0, 0]])>}
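
To make `input_ids` and `attention_mask` concrete, here is a toy sketch in plain Python of how padding and truncation produce the two arrays. It is not the tokenizer's actual implementation: the helper name `pad_and_mask` is hypothetical, the token IDs are illustrative, and the truncation simply cuts the sequence, whereas the real tokenizer keeps the closing special token.

```python
PAD_ID = 0  # padding id, matching the trailing zeros in the output above


def pad_and_mask(token_id_lists, max_length):
    """Truncate or pad each sequence to a common length and build attention masks."""
    # Pad to the longest sequence in the batch, but never beyond max_length.
    target = min(max(len(ids) for ids in token_id_lists), max_length)
    input_ids, attention_mask = [], []
    for ids in token_id_lists:
        kept = ids[:target]                 # simplified truncation
        padding = target - len(kept)
        input_ids.append(kept + [PAD_ID] * padding)
        attention_mask.append([1] * len(kept) + [0] * padding)
    return input_ids, attention_mask


# Illustrative token ids for two short sequences of different length
batch = [[101, 3893, 25732, 1024, 102], [101, 3893, 102]]
ids, mask = pad_and_mask(batch, max_length=512)
print(ids)
print(mask)
```

In the real output every row is 512 wide, which suggests that at least one combined prompt reached the 512-token limit and was truncated. For long positive prompts this means part or all of the negative prompt may be cut off.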

Split the data into training and test sets

In [12]:
input_ids = inputs['input_ids'].numpy()
attention_mask = inputs['attention_mask'].numpy()
labels = [int(x) for x in labels] # convert bool to int

# Split the data into training and testing sets
input_ids_train, input_ids_test, attention_mask_train, attention_mask_test, labels_train, labels_test = train_test_split(
    input_ids, attention_mask, labels, test_size=0.2, random_state=42
)

In [13]:
# Now, you can convert the split data back to TensorFlow tensors if needed
# input_ids
input_ids_train_tensor = tf.convert_to_tensor(input_ids_train)
input_ids_test_tensor = tf.convert_to_tensor(input_ids_test)

# attention_mask
attention_mask_train_tensor = tf.convert_to_tensor(attention_mask_train)
attention_mask_test_tensor = tf.convert_to_tensor(attention_mask_test)

# labels
train_labels_tensor = tf.convert_to_tensor(labels_train)
test_labels_tensor = tf.convert_to_tensor(labels_test)

## Create the DistilBERT model

In [14]:
# Build the DistilBERT-based model
model = TFDistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=1)

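# The model returns raw logits (no sigmoid), so the loss and the accuracy threshold must operate on logits
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=2e-5),
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.BinaryAccuracy(threshold=0.0)],
)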


Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFDistilBertForSequenceClassification: ['vocab_layer_norm.weight', 'vocab_projector.bias', 'vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.bias']
- This IS expected if you are initializing TFDistilBertForSequenceClassification from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFDistilBertForSequenceClassification from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).
Some weights or buffers of the TF 2.0 model TFDistilBertForSequenceClassification were not initialized from the PyTorch model and are newly initialized: ['pre_classifier.weight', 'pre_classifier.bias', 'classifier.weight', 'classifier.bias']
You should 
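
With `num_labels=1`, the classification head returns one raw score (a logit) per example rather than a probability. The sketch below shows in plain Python how such a logit maps to a probability, a predicted label, and a binary cross-entropy value. The function names are hypothetical, and the formula is the standard numerically stable form of the loss, not code taken from Keras or Transformers.

```python
import math


def sigmoid(logit):
    """Map a logit to a probability without overflowing for large |logit|."""
    if logit >= 0:
        return 1.0 / (1.0 + math.exp(-logit))
    e = math.exp(logit)
    return e / (1.0 + e)


def bce_from_logit(logit, label):
    """Binary cross-entropy computed directly from a logit (label is 0 or 1)."""
    # Stable form: max(z, 0) - z * y + log(1 + exp(-|z|))
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


logit = 2.0
prob = sigmoid(logit)
predicted_label = int(logit > 0.0)  # same decision as prob > 0.5
loss_if_nsfw = bce_from_logit(logit, 1)
print(prob, predicted_label, loss_if_nsfw)
```

A logit of 0 corresponds to a probability of 0.5, so thresholding logits at 0 gives the same decision as thresholding probabilities at 0.5.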

In [15]:
model.summary()  # call the method; without the parentheses the notebook only echoes the bound method object

Train the model

In [16]:
history = model.fit(
    {'input_ids': input_ids_train_tensor, 'attention_mask': attention_mask_train_tensor},
    train_labels_tensor,
    epochs=3,
    batch_size=8,
    validation_split=0.2
)

Epoch 1/3
   20/26171 [..............................] - ETA: 69:24:47 - loss: 0.7188 - accuracy: 0.5063

KeyboardInterrupt: 
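
The step count in the progress bar (26171) can be reproduced from the dataset size, the two 20% hold-outs, and the batch size. The sketch below assumes that scikit-learn rounds the test set size up and that Keras rounds the training portion of `validation_split` down. Both rounding rules are assumptions about library internals, not something stated in this notebook.

```python
import math

n_examples = 327138      # len(data)
test_size = 0.2          # train_test_split(..., test_size=0.2)
validation_split = 0.2   # model.fit(..., validation_split=0.2)
batch_size = 8

# Assumption: scikit-learn rounds the test set size up.
n_test = math.ceil(test_size * n_examples)
n_train = n_examples - n_test

# Assumption: Keras rounds the training portion down and validates on the rest.
n_fit = math.floor(n_train * (1.0 - validation_split))
n_val = n_train - n_fit

steps_per_epoch = math.ceil(n_fit / batch_size)
print(n_test, n_train, n_fit, n_val, steps_per_epoch)
```

Under these assumptions the result matches the 26171 steps shown above. Only about 64% of the examples are used for gradient updates, because 20% is held out by `train_test_split` and a further 20% of the remainder by `validation_split`. If the early ETA of 69:24:47 per epoch were representative, three epochs would take on the order of 200 hours on this hardware.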

In [None]:
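# Save in the Hugging Face format; the .keras format may not serialize this subclassed Transformers model reliably
model.save_pretrained('models/model')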