In [1]:
import numpy as np
import torch
from torch import nn, optim
import transformers
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
from transformers.pytorch_utils import apply_chunking_to_forward
from transformers.models.roberta.modeling_roberta import (
    RobertaAttention,
    RobertaEmbeddings,
    RobertaLayer,
    RobertaIntermediate,
    RobertaOutput,
)
from transformers.models.roberta.tokenization_roberta import RobertaTokenizer
from typing import List, Optional, Tuple, Union

In [2]:
class ToxicRobertaLayer(nn.Module):
    def __init__(self, off_dictionary, config):
        super().__init__()
        self.device = "cpu" # "cuda" if torch.cuda.is_available() else: "cpu"
        self.off_dict = off_dictionary  ###
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = RobertaAttention(config)
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        if self.add_cross_attention:
            if not self.is_decoder:
                raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
            self.crossattention = RobertaAttention(config, position_embedding_type="absolute")
            self.toxic_crossattention = RobertaAttention(config, position_embedding_type="absolute")
        self.toxic_attention = RobertaAttention(config)  ###
        self.intermediate = RobertaIntermediate(config)
        self.output = RobertaOutput(config)

    def forward_attention(
        self,
        attention: RobertaAttention,
        crossattention: RobertaAttention,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        self_attention_outputs = attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )
        attention_output = self_attention_outputs[0]

        # if decoder, the last output is tuple of self-attn cache
        if self.is_decoder:
            outputs = self_attention_outputs[1:-1]
            present_key_value = self_attention_outputs[-1]
        else:
            outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        cross_attn_present_key_value = None
        if self.is_decoder and encoder_hidden_states is not None:
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
                    " by setting `config.add_cross_attention=True`"
                )

            # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
            cross_attention_outputs = crossattention(
                attention_output,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                cross_attn_past_key_value,
                output_attentions,
            )
            attention_output = cross_attention_outputs[0]
            outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights

            # add cross-attn cache to positions 3,4 of present_key_value tuple
            cross_attn_present_key_value = cross_attention_outputs[-1]
            present_key_value = present_key_value + cross_attn_present_key_value

        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )
        outputs = (layer_output,) + outputs

        # if decoder, return the attn key/values as the last output
        if self.is_decoder:
            outputs = outputs + (present_key_value,)

        return outputs

    def forward(
        self,
        input_batch: List[str],
        tokenizer: RobertaTokenizer,
        embeddings: RobertaEmbeddings,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        if self.add_cross_attention: 
            crossattention = self.crossattention 
            toxic_crossattention = self.toxic_crossattention
        else: 
            crossattention = None
            toxic_crossattention = None
            
        attn_block1_outputs = self.forward_attention(
            self.attention,
            crossattention,
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        input_profanities = [
            " ".join([word for word in text.split() if word in self.off_dict])
            for text in input_batch
        ]
        input_prof_tokens = tokenizer(input_profanities, return_tensors="pt", truncation=True, padding=True).to(self.device)
        input_token_embeddings = embeddings(input_prof_tokens['input_ids'])
        attn_block2_outputs = self.forward_attention(
            self.toxic_attention,
            toxic_crossattention,
            input_token_embeddings, 
            input_prof_tokens['attention_mask'],
            head_mask,
            encoder_hidden_states, 
            encoder_attention_mask, 
            past_key_value,
            output_attentions,
        ) 
        
        outputs = (attn_block1_outputs[0] + attn_block2_outputs[0].mean(axis=1, keepdims=True)) / 2
        return outputs
        

    def feed_forward_chunk(self, attention_output):
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output

In [3]:
class ToxicRobertaEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([RobertaLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        input_batch: List[str],
        tokenizer: RobertaTokenizer,
        embeddings: RobertaEmbeddings,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        next_decoder_cache = () if use_cache else None
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    input_batch,
                    tokenizer,
                    embeddings,
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                )
            else:
                layer_outputs = layer_module(
                    input_batch,
                    tokenizer,
                    embeddings,
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                )
                
            hidden_states = layer_outputs
            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )

In [4]:
class RobertaToxicAttentionModel(nn.Module):
    def __init__(
        self, 
        checkpoint: str,
        offensive_dict_dir: str
    ):
        super().__init__()
        self.device = "cpu" # "cuda" if torch.cuda.is_available() else "cpu"
        loaded_ckpt = torch.hub.load_state_dict_from_url(checkpoint, map_location=self.device)
        model_class = getattr(transformers, loaded_ckpt['config']['arch']['args']['model_name'])
        config = model_class.config_class.from_pretrained(
            loaded_ckpt['config']['arch']['args']['model_type'], 
            num_labels = loaded_ckpt['config']['arch']['args']['num_classes']
        )

        self.model = model_class.from_pretrained(
            pretrained_model_name_or_path = None, 
            config = config, 
            state_dict = loaded_ckpt['state_dict']
        )
        self.tokenizer = getattr(transformers, loaded_ckpt['config']['arch']['args']['tokenizer_name']).from_pretrained(
            pretrained_model_name_or_path = loaded_ckpt['config']['arch']['args']['model_type']
        )
        new_encoder = ToxicRobertaEncoder(config)
        for i, enc_block in enumerate(self.model.roberta.encoder.layer):
            toxic_sub_roberta_layer = ToxicRobertaLayer('off_words_cmu.txt', config)
            toxic_sub_roberta_layer.attention = enc_block.attention
            if config.add_cross_attention: toxic_sub_roberta_layer.crossattention = enc_block.crossattention
            new_encoder.layer[i] = toxic_sub_roberta_layer
        self.model.roberta.encoder = new_encoder
        self.downstream_classifier = nn.Linear(16, 1, bias=True)
        self.sigmoid = nn.Sigmoid()

    def forward(
        self,
        batch_inputs,
        return_tensors="pt", 
        truncation=True, 
        padding=True
    ):
        inputs = self.tokenizer(batch_inputs, return_tensors=return_tensors, truncation=truncation, padding=padding).to(self.model.device)
        x = self.model.roberta.embeddings(inputs['input_ids'])
        x = self.model.roberta.encoder(batch_inputs, self.tokenizer, self.model.roberta.embeddings, x)
        x = self.model.classifier(x[0])
        x = self.downstream_classifier(x)
        x = self.sigmoid(x)
        return x

    def freeze_roberta(self):
        for param in self.model.roberta.embeddings.parameters():
            param.requires_grad = False
        for l in self.model.roberta.encoder.layer:
            for param in l.attention.parameters():
                param.requires_grad = False
            for param in l.intermediate.parameters():
                param.requires_grad = False
            for param in l.output.parameters():
                param.requires_grad = False
        for param in self.model.classifier.parameters():
            param.requires_grad = False

    def backward_pass(self, data, target):
        self.optimizer.zero_grad()
        output = self.model(data)
        loss = self.loss(output, target)
        loss.backward()
        optimizer.step()
        return loss.item()

    def fit(self, epochs, learning_rate, x_train, y_train, x_test: Optional = None, y_test: Optional = None):
        self.optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)
        self.loss_fn = nn.BCELoss()  # Assuming binary classification
        
        self.freeze_roberta()
        model.train()
        train_loss = []
        test_loss = []
        test_str = ""
        for epoch in range(epochs):
            print(f"Epoch ({epoch + 1} / {epochs}) : {"=" * (epoch*100/epochs) + '>'}", end='\r')
            train_epoch_loss = 0.0
            test_epoch_loss = 0.0
            for i, (data, target) in enumerate(zip(x_train, y_train)):
                data, target = data.to(device), target.to(self.device)
                loss = backward_pass(data, target)
                train_epoch_loss += loss
                if y_test: test_epoch_loss += self.loss(x_test, y_test)
            avg_train_epoch_loss = train_epoch_loss / len(y_train)
            train_loss.append(avg_train_epoch_loss)
            if y_test: 
                avg_test_epoch_loss = test_epoch_loss / len(y_test)
                test_loss.append(avg_test_epoch_loss)
                test_str = f", test_loss = {avg_test_epoch_loss}"
            print(f"Epoch ({epoch + 1} / {epochs}) : {"=" * (epoch*100/epochs) + '>'} | train_loss = {avg_train_epoch_loss}" + test_str)
        loss = {
            'train_loss': train_loss,
            'test_loss': test_loss
        }
        return train_loss

In [5]:
## unitary/unbiased-toxic-roberta
checkpoint = "https://github.com/unitaryai/detoxify/releases/download/v0.3-alpha/toxic_debiased-c7548aa0.ckpt"
device = "cpu" #"cuda" if torch.cuda.is_available() else "cpu"

loaded_ckpt = torch.hub.load_state_dict_from_url(checkpoint, map_location=device)
model_class = getattr(transformers, loaded_ckpt['config']['arch']['args']['model_name'])
config = model_class.config_class.from_pretrained(
    loaded_ckpt['config']['arch']['args']['model_type'], 
    num_labels = loaded_ckpt['config']['arch']['args']['num_classes']
)

In [6]:
## original model
# orig_model = model_class.from_pretrained(
#     pretrained_model_name_or_path=None,
#     config=config,
#     state_dict=loaded_ckpt['state_dict'],
# )
# orig_model

In [7]:
model = RobertaToxicAttentionModel(checkpoint, 'off_words_cmu.txt')
model



RobertaToxicAttentionModel(
  (model): RobertaForSequenceClassification(
    (roberta): RobertaModel(
      (embeddings): RobertaEmbeddings(
        (word_embeddings): Embedding(50265, 768, padding_idx=1)
        (position_embeddings): Embedding(514, 768, padding_idx=1)
        (token_type_embeddings): Embedding(1, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): ToxicRobertaEncoder(
        (layer): ModuleList(
          (0-11): 12 x ToxicRobertaLayer(
            (attention): RobertaAttention(
              (self): RobertaSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): RobertaSelfOutput(
  

In [8]:
input_batch = ['Mission Impossible', 'Mission Imposter']
model(input_batch)

tensor([[0.2538],
        [0.2344]], grad_fn=<SigmoidBackward0>)