In [2]:
import os
from typing import Sequence
from dataclasses import dataclass
from flask import Flask, request, jsonify
import torch
import torch.nn as nn
from omegaconf import OmegaConf
from dacite import from_dict
from transformers import GPT2Tokenizer
from Helpers.loaders import ModelSaverReader
from xlstm.xlstm.utils import WeightDecayOptimGroupMixin
from xlstm.xlstm.components.init import small_init_init_
from xlstm.xlstm.xlstm_block_stack import xLSTMBlockStack, xLSTMBlockStackConfig

# Load the application configuration from params_app.yaml (path is relative
# to the current working directory).
cfg = OmegaConf.load('./params_app.yaml')

@dataclass
class xLSTMLMModelConfig(xLSTMBlockStackConfig):
    """Block-stack configuration extended with the options read by ``xLSTMLMModel``."""

    # Size of the token embedding table and of the LM head output.
    # NOTE(review): the -1 default presumably means "must be set by the caller" — confirm.
    vocab_size: int = -1
    # When True, the LM head reuses the token embedding weight tensor.
    tie_weights: bool = True
    # Selects whether the embedding weight goes into the weight-decay group
    # or the no-weight-decay group when optimizer groups are built.
    weight_decay_on_embedding: bool = True
    # When True, dropout is applied to the embeddings; otherwise an identity layer is used.
    add_embedding_dropout: bool = True

class xLSTMLMModel(WeightDecayOptimGroupMixin, nn.Module):
    """Language model: token embedding -> xLSTM block stack -> vocabulary projection."""

    config_class = xLSTMLMModelConfig

    def __init__(self, config: xLSTMLMModelConfig, **kwargs):
        """Build the sub-modules described by ``config``; extra kwargs are ignored."""
        super().__init__()
        self.config = config

        # Registration order is kept stable so parameter/state-dict ordering
        # and random initialisation order do not change.
        self.xlstm_block_stack = xLSTMBlockStack(config=config)
        self.token_embedding = nn.Embedding(
            num_embeddings=config.vocab_size,
            embedding_dim=config.embedding_dim,
        )
        if config.add_embedding_dropout:
            self.emb_dropout = nn.Dropout(config.dropout)
        else:
            self.emb_dropout = nn.Identity()

        self.lm_head = nn.Linear(
            in_features=config.embedding_dim,
            out_features=config.vocab_size,
            bias=False,
        )
        if config.tie_weights:
            # Share one parameter between the input embedding and the output projection.
            self.lm_head.weight = self.token_embedding.weight

    def reset_parameters(self):
        """Re-initialise the block stack, the embedding and (if untied) the LM head."""
        self.xlstm_block_stack.reset_parameters()

        emb_dim = self.config.embedding_dim
        small_init_init_(self.token_embedding.weight, dim=emb_dim)

        # With tied weights the head shares the embedding tensor, which was just initialised.
        if not self.config.tie_weights:
            small_init_init_(self.lm_head.weight, dim=emb_dim)

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        """Return vocabulary logits for the token ids in ``idx``."""
        embedded = self.emb_dropout(self.token_embedding(idx))
        hidden = self.xlstm_block_stack(embedded)
        return self.lm_head(hidden)

    def step(
        self, idx: torch.Tensor, state: dict[str, dict[str, tuple[torch.Tensor, ...]]] = None, **kwargs
    ) -> tuple[torch.Tensor, dict[str, dict[str, tuple[torch.Tensor, ...]]]]:
        """Run the block stack's ``step`` on ``idx`` and return ``(logits, new_state)``.

        ``state`` and any extra keyword arguments are forwarded to the block stack unchanged.
        """
        embedded = self.emb_dropout(self.token_embedding(idx))
        hidden, state = self.xlstm_block_stack.step(embedded, state=state, **kwargs)
        return self.lm_head(hidden), state

    def _create_weight_decay_optim_groups(self, **kwargs) -> tuple[Sequence[nn.Parameter], Sequence[nn.Parameter]]:
        """Return ``(weight_decay, no_weight_decay)`` parameter groups.

        Starts from the groups produced by the mixin, removes the token embedding
        from the weight-decay group, then appends it to whichever group
        ``config.weight_decay_on_embedding`` selects.
        """
        decay_params, no_decay_params = super()._create_weight_decay_optim_groups(**kwargs)
        emb_weight = self.token_embedding.weight

        # Filter by identity: parameters are tensors, so == would compare element-wise.
        decay_params = tuple(p for p in decay_params if p is not emb_weight)

        if self.config.weight_decay_on_embedding:
            decay_params += (emb_weight,)
        else:
            no_decay_params += (emb_weight,)

        return decay_params, no_decay_params
    
# Define binary classification model
class xLSTMLMModelBinary(nn.Module):
    """Binary classifier: backbone embedding + block stack, mean pooling, one-unit linear head."""

    def __init__(self, config, pretrained_model=None):
        """Wrap ``pretrained_model`` as the backbone, or build a new ``xLSTMLMModel`` from ``config``."""
        super().__init__()
        # A fresh backbone is only constructed when the caller did not provide one.
        self.pretrained_model = xLSTMLMModel(config) if pretrained_model is None else pretrained_model
        # The head assumes the backbone's hidden size equals config.embedding_dim.
        self.fc = nn.Linear(config.embedding_dim, 1)
        self.dropout = nn.Dropout(0.1)

    def forward(self, input_ids):
        """Return one logit per input sequence (the trailing size-1 dimension is squeezed away)."""
        backbone = self.pretrained_model
        embedded = backbone.token_embedding(input_ids)
        features = backbone.xlstm_block_stack(embedded)
        # Average over dim 1 (the sequence/context dimension), then regularise.
        pooled = self.dropout(features.mean(dim=1))
        return self.fc(pooled).squeeze(-1)


# Load the pre-trained weights into the binary classification model
def load_pretrained_weights(pretrained_model, binary_model):
    """Copy matching weights from ``pretrained_model`` into ``binary_model``.

    ``xLSTMLMModelBinary`` stores its backbone under the ``pretrained_model``
    attribute, so its state-dict keys carry a ``pretrained_model.`` prefix that
    a bare backbone's keys lack. Each source key is therefore matched either
    directly (source is itself a binary model) or through that prefix (source
    is the bare backbone). The classification head (``fc.*``) is never
    overwritten.

    Args:
        pretrained_model: module whose ``state_dict()`` supplies the weights.
        binary_model: module that receives the weights; modified in place.

    Returns:
        ``binary_model``, after loading the merged state dict.
    """
    backbone_prefix = 'pretrained_model.'
    model_dict = binary_model.state_dict()

    transferred = {}
    for key, tensor in pretrained_model.state_dict().items():
        if key in model_dict:
            target_key = key
        elif backbone_prefix + key in model_dict:
            target_key = backbone_prefix + key
        else:
            # No counterpart in the destination model.
            continue
        # Keep the destination's own classification head.
        if target_key.startswith('fc.'):
            continue
        transferred[target_key] = tensor

    # Overwrite the matched entries and load the merged state dict.
    model_dict.update(transferred)
    binary_model.load_state_dict(model_dict)

    return binary_model


from IPython.display import display, HTML

class HateSpeechDetector:
    """Pair a binary classifier with a tokenizer to label one text as hate / not hate."""

    def __init__(self, model, tokenizer, context_length, device):
        """Store the classifier, tokenizer, padded sequence length and target device."""
        self.model = model
        self.tokenizer = tokenizer
        self.context_length = context_length
        self.device = device

    def predict(self, tweet):
        """Return True when the sigmoid of the model output for ``tweet`` is >= 0.5.

        The text is padded/truncated to ``context_length`` tokens. The model is
        put in eval mode and run without gradient tracking. The shape of the
        token-id tensor is printed on every call.
        """
        self.model.eval()
        with torch.no_grad():
            # Tokenize to a fixed-length tensor of token ids.
            inputs = self.tokenizer.encode_plus(
                tweet,
                add_special_tokens=True,
                max_length=self.context_length,
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            )
            input_ids = inputs['input_ids'].to(self.device)
            print(input_ids.shape)
            # Classify: .item() requires the model to return a single value.
            outputs = self.model(input_ids)
            prediction = torch.sigmoid(outputs).item()
            is_hate = prediction >= 0.5  # Adjust threshold if needed
            return is_hate

    def display_prediction(self, tweet):
        """Render ``tweet`` and its colour-coded label as HTML via IPython ``display``."""
        from html import escape

        is_hate = self.predict(tweet)
        color = 'red' if is_hate else 'green'
        label = 'Hate' if is_hate else 'Not Hate'
        result_html = f'<span style="color:{color}; font-weight:bold;">{label}</span>'
        # The text is arbitrary user input: escape it so it is shown literally
        # instead of being interpreted as markup.
        safe_tweet = escape(str(tweet))
        display(HTML(f"<p>{safe_tweet}</p><p>{result_html}</p>"))

In [3]:
# Imports used by this cell (repeated from the first cell).
from dacite import from_dict
from IPython.display import display, HTML
from transformers import GPT2Tokenizer

# Map training-step thresholds to the context lengths stored under cfg.model.schedul.
# NOTE(review): the 1/8, 1/4 and 1/2 step marks are paired with the keys
# 'quarter', 'half' and 'three_quarters' — confirm this offset is intended.
schedul = {
    1: cfg.model.schedul['first'],
    int(cfg.training.num_steps * (1/8)): cfg.model.schedul['quarter'],
    int(cfg.training.num_steps * (1/4)): cfg.model.schedul['half'],
    int(cfg.training.num_steps * (1/2)): cfg.model.schedul['three_quarters'],
}

# The entry with the largest step threshold supplies the context length used from here on.
final_context_length = schedul[max(schedul)]
cfg.model.context_length = final_context_length

# Restore the saved binary classifier, move it to the configured device and
# switch it to evaluation mode.
model_saver_reader = ModelSaverReader('./Models')
model_bin_final_10k = model_saver_reader.load_model(
    xLSTMLMModelBinary,
    "model1_bin_final",
    from_dict(xLSTMLMModelConfig, OmegaConf.to_container(cfg.model, resolve=True)),
).to(cfg.training.device)
model_bin_final_10k.eval()

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# If the tokenizer has no padding token, reuse its end-of-sequence token for padding.
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': tokenizer.eos_token})

class HateSpeechDetector:
    """Pair a binary classifier with a tokenizer to label one text as hate / not hate."""

    def __init__(self, model, tokenizer, context_length, device):
        """Store the classifier, tokenizer, padded sequence length and target device."""
        self.model = model
        self.tokenizer = tokenizer
        self.context_length = context_length
        self.device = device

    def predict(self, tweet):
        """Return True when the sigmoid of the model output for ``tweet`` is >= 0.5.

        The text is padded/truncated to ``context_length`` tokens. The model is
        put in eval mode and run without gradient tracking. The shape of the
        token-id tensor is printed on every call.
        """
        self.model.eval()
        with torch.no_grad():
            # Tokenize to a fixed-length tensor of token ids.
            inputs = self.tokenizer.encode_plus(
                tweet,
                add_special_tokens=True,
                max_length=self.context_length,
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            )
            input_ids = inputs['input_ids'].to(self.device)
            print(input_ids.shape)
            # Classify: .item() requires the model to return a single value.
            outputs = self.model(input_ids)
            prediction = torch.sigmoid(outputs).item()
            is_hate = prediction >= 0.5  # Adjust threshold if needed
            return is_hate

    def display_prediction(self, tweet):
        """Render ``tweet`` and its colour-coded label as HTML via IPython ``display``."""
        from html import escape

        is_hate = self.predict(tweet)
        color = 'red' if is_hate else 'green'
        label = 'Hate' if is_hate else 'Not Hate'
        result_html = f'<span style="color:{color}; font-weight:bold;">{label}</span>'
        # The text is arbitrary user input: escape it so it is shown literally
        # instead of being interpreted as markup.
        safe_tweet = escape(str(tweet))
        display(HTML(f"<p>{safe_tweet}</p><p>{result_html}</p>"))


# Reload the configuration from disk.
cfg = OmegaConf.load('./params_app.yaml')

# Fall back to a default when val_every_step is absent or null. DictConfig.get()
# returns None for a missing key, whereas plain attribute access raises on
# OmegaConf >= 2.1 instead of yielding None.
if cfg.training.get('val_every_step') is None:
    cfg.training.val_every_step = 100  # Set to 100 or any reasonable default value

# Map training-step thresholds to the context lengths stored under cfg.model.schedul.
schedul = {
    1: cfg.model.schedul['first'],
    int(cfg.training.num_steps * (1/8)): cfg.model.schedul['quarter'],
    int(cfg.training.num_steps * (1/4)): cfg.model.schedul['half'],
    int(cfg.training.num_steps * (1/2)): cfg.model.schedul['three_quarters']
}

# Ensure we use the final context length (the entry with the largest step threshold).
final_context_length = schedul[max(schedul.keys())]
cfg.model.context_length = final_context_length

# Initialize the detector
detector = HateSpeechDetector(model_bin_final_10k, tokenizer, cfg.model.context_length, cfg.training.device)

# Example prediction: sample input for the classifier; replace with any text to classify.
tweet = "suck my dick, stupid leftie! "
detector.display_prediction(tweet)

Model loaded from ./Models/model1_bin_final.pth
torch.Size([1, 128])
