# Setup

---

In [1]:
from kaggle_secrets import UserSecretsClient
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification
import wandb
import torch

# Log into Wandb
user_secrets = UserSecretsClient()
wandb_key = user_secrets.get_secret("WANDB_API_KEY")
!wandb login $wandb_key

# Set Wandb project bane
%env WANDB_PROJECT=toxic_classification

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
env: WANDB_PROJECT=toxic_classification


In [9]:
def load_model_from_wandb(run_name='load_and_predict',
                          artifact_name="distilbert_finetuned_compromised:latest",
                          project="toxic_classification"):
    """Download a fine-tuned DistilBERT artifact from W&B and load it.

    A W&B run is opened so the artifact download is recorded against it. The
    artifact is downloaded, and a DistilBERT sequence-classification model and
    fast tokenizer are loaded from the downloaded directory.

    Args:
        run_name: Name given to the W&B run created for the download.
        artifact_name: Artifact reference of the form "<name>:<alias>".
        project: W&B project in which the run is created.

    Returns:
        Tuple ``(model, tokenizer)``.

    Raises:
        Propagates errors from ``wandb.use_artifact`` / ``download`` /
        ``from_pretrained`` (e.g. a CommError when the artifact does not
        exist in the project). The W&B run is finished in every case.
    """
    wandb.init(project=project, name=run_name)
    try:
        artifact = wandb.use_artifact(artifact_name)
        artifact_dir = artifact.download()

        # Model and tokenizer were saved into the same artifact directory.
        loaded_model = DistilBertForSequenceClassification.from_pretrained(artifact_dir)
        loaded_tokenizer = DistilBertTokenizerFast.from_pretrained(artifact_dir)
    finally:
        # Always close the run, so a failed download/load does not leave a
        # dangling active run in the kernel.
        wandb.finish()

    return loaded_model, loaded_tokenizer

def predict_toxicity(sentence, model, tokenizer, max_length=128):
    """Return the model's probability that ``sentence`` belongs to the toxic class.

    Args:
        sentence: Text to classify. If a batch (list) is passed, only the
            first element's score is returned.
        model: Sequence-classification model exposing ``.device`` and whose
            output has ``.logits``. Class index 1 is treated as "toxic".
        tokenizer: Callable tokenizer returning a mapping of PyTorch tensors.
        max_length: Maximum number of tokens; longer inputs are truncated.

    Returns:
        Probability of class 1 for the first input, as a NumPy scalar.
    """
    inputs = tokenizer(sentence, return_tensors="pt", padding=True,
                       truncation=True, max_length=max_length)
    # Put every input tensor on the same device as the model.
    device = model.device
    inputs = {key: tensor.to(device) for key, tensor in inputs.items()}

    # Pure inference: disable autograd so no graph is built (saves memory/time).
    with torch.no_grad():
        outputs = model(**inputs)
        probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)

    # Row 0 = first input, column 1 = toxic class; move to CPU for NumPy.
    toxicity = probabilities.cpu().numpy()[0, 1]
    return toxicity

In [10]:
# Load model
# Fetches the fine-tuned model and tokenizer from the W&B artifact.
# NOTE(review): the recorded output of this cell is a CommError (artifact
# "distilbert_finetuned_compromised:latest" not found in the project), and the
# prediction cell below has an earlier execution count (In [8] vs In [10]).
# Its `model`/`tokenizer` therefore came from an earlier kernel state, not
# this cell. Confirm the artifact name before re-running top to bottom.
model, tokenizer = load_model_from_wandb()

CommError: Project tommyhe6/toxic_classification does not contain artifact: "distilbert_finetuned_compromised:latest"

# Prediction

---

In [8]:
# Score a single sample sentence with the loaded model; this relies on
# `model` and `tokenizer` produced by the load cell above.
example_sentence = "yd4%ˆ&a0o fuck you"
toxicity = predict_toxicity(example_sentence, model=model, tokenizer=tokenizer)
print("Toxicity score: {:.4f}".format(toxicity))

Toxicity score: 0.9889
