# Models:

1. bloom
2. flan-t5
3. mT5
4. mt0

## Imports

In [1]:
%load_ext autoreload
%autoreload 2

import os
import sys
import urllib
sys.path.append('..')

import numpy as np
import pandas as pd

import torch

from src.evaluation import DetoxificationMetrics
from src.models import Bloom3b, FlanT5XL, MT5XL, MT0XL

device = 'cuda' if torch.cuda.is_available() else 'cpu'
device = 'cuda:2'

## Data

In [2]:
DATA_PATH = '../data'

if not os.path.isdir(DATA_PATH):
    os.makedirs(DATA_PATH)
    urllib.request.urlretrieve('https://raw.githubusercontent.com/skoltech-nlp/russe_detox_2022/main/data/input/train.tsv', os.path.join(DATA_PATH, 'train.tsv'))
    urllib.request.urlretrieve('https://raw.githubusercontent.com/skoltech-nlp/russe_detox_2022/main/data/input/dev.tsv', os.path.join(DATA_PATH, 'dev.tsv'))
    urllib.request.urlretrieve('https://raw.githubusercontent.com/skoltech-nlp/russe_detox_2022/main/data/input/test.tsv', os.path.join(DATA_PATH, 'test.tsv'))

In [3]:
lang = 'en' # ru, en

if lang == 'ru':
    train_data = pd.read_csv(os.path.join(DATA_PATH, 'train.tsv'), sep='\t').drop(columns=['index'])
    val_data = pd.read_csv(os.path.join(DATA_PATH, 'dev.tsv'), sep='\t')
    df = pd.concat([train_data, val_data]).reset_index(drop=True)
    toxic_inputs = df['toxic_comment'].tolist()
    neutral_inputs = df['neutral_comment1'].tolist()
elif lang == 'en':
    df = pd.read_csv(os.path.join(DATA_PATH, 'data_en.csv'))
    toxic_inputs = df['toxic_comment'].tolist()
    neutral_inputs = df['neutral_comment'].tolist()
else:
    raise ValueError(f'Unrecognized language option. Expected one of ["ru", "en"], but got "{lang}"')

## Models Quality

In [4]:
model = Bloom3b(device=device)

In [14]:
print(model.generate('Where are you from?', max_length=128))

Where are you from? <a ">buy cialis online</a> The company said it had been forced to close its operations in the U.S. and Canada because of the shutdown. It said it would continue to operate in the U.S. and Canada, but would not be able to ship products to customers in those countries.
I work for a publishers <a ">buy cialis online</a> The U.S. government has been trying to get the U.S. Chamber of Commerce to support the legislation, which would allow the government to buy the bonds. The Chamber has said it would not support the legislation because it would


In [4]:
model = FlanT5XL(device=device)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [11]:
print(model.generate('Tell me a storytale, please.', max_length=128))

The king of the kings was a king of the kings. He was a king of the kings, and he was a king of the kings. One day, he was a king of the kings, and he was a king of the kings. The king of the kings was a king of the kings, and he was a king of the kings.


In [4]:
model = MT5XL(device=device)



In [6]:
print(model.generate('Tell me a storytale, please.', max_length=128))

<extra_id_0>. Tell me a storytale. Tell me. Tell me.


In [5]:
model = MT0XL(device=device)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [11]:
print(model.generate('Tell me a storytale, please.', max_length=128))

The story begins with a young girl named Hannah (Katy Perry) who is a teenager. She is a teenager who is a teenager who is a teenager who is a teenager who is a teenager who is a teenager who is a teenager who is a teenager who is a teenager who is a teenager who is a teenager who is a teenager who is a teenager who is a teenager who is a teenager who is a teenager who is a 


## Promting

In [4]:
model = FlanT5XL(device=device)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [7]:
promt_template = lambda question, answer: f'Remove swear wordss {question}.\nAnswer: {answer}\n'

num_promts = 5
np.random.seed(42)
ids = np.random.choice(len(toxic_inputs), num_promts, replace=False)
text = ''.join([promt_template(toxic_inputs[i], neutral_inputs[i]) for i in ids]) + f'Rewrite text, removing swear words: {toxic_inputs[0]}.\nAnswer: '

model.generate(text, max_length=128)

'And, fuck, where the fuck was that fucking guy with his evidence before?'

In [8]:
def query_from_list(query, options):
    t5query = f"""Question: Select the item from this list which is "{query}". Context: * {" * ".join(options)}"""
    outputs = model.generate(t5query, max_length=30)
    return outputs


result = query_from_list('swear', [toxic_inputs[0]])
print(result)

fuck


In [4]:
import torch
from transformers import BloomTokenizerFast 
from petals import DistributedBloomForCausalLM


tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-petals')
model = DistributedBloomForCausalLM.from_pretrained('bigscience/bloom-petals').to(device)

In [5]:
promt_template = lambda question, answer: f'Rewrite text, removing swear words: {question}.\nAnswer: {answer}\n'

num_promts = 5
np.random.seed(42)
ids = np.random.choice(len(toxic_inputs), num_promts, replace=False)
text = ''.join([promt_template(toxic_inputs[i], neutral_inputs[i]) for i in ids]) + f'Rewrite text, removing swear words: {toxic_inputs[0]}.\n'

inputs = tokenizer(text, return_tensors='pt')['input_ids'].to(device)
outputs = model.generate(inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0]))

Rewrite text, removing swear words: Yeah, time's ticking, and idiots, more and more....
Answer: Yeah, time is ticking, and more and more stupid...
Rewrite text, removing swear words: Everyone screamed quarantine changed that our lives would never be as fucked as they were and stayed..
Answer: Everyone screamed quarantine changed that our lives would never continue to lie like they were stupid and stayed.
Rewrite text, removing swear words: These assholes cut off the listing of posts through API :(.
Answer: They've disabled the listing of posts through the API.
Rewrite text, removing swear words: .. with their accent on your eyes, shaking the fuck out of infinity by saying "Oh, why" (and it looks fucking fucking funny)..
Answer: .. from the infinity of the jaws, saying "Okay" (and it doesn't look great at all).
Rewrite text, removing swear words: Which, snitcho, is typical - there's a reactor in the back that only appeared on the car in the second series..
Answer: That there's a reactor

## Evaluation

In [None]:
metrics = DetoxificationMetrics(batch_size=64, use_cuda=torch.cuda.is_available(), verbose=False, aggregate=True)
scores = metrics(toxic_inputs, neutral_inputs)
scores

In [None]:
from functools import partial
model.generate = partial(model.generate, max_length=20)
metrics.evaluate_model(model, toxic_inputs, neutral_inputs)