# RoBERTa Model

In [None]:
import pandas as pd

# Combine train.csv and test_with_solutions.csv into a single frame and drop
# the "Date" and "Usage" columns, which are not used further down.
# ignore_index=True gives the result a fresh 0..n-1 index; without it each file
# contributes its own 0-based labels and the combined index contains duplicates.
df = pd.concat(
    [
        pd.read_csv('/content/train.csv'),
        pd.read_csv('/content/test_with_solutions.csv'),
    ],
    ignore_index=True,
).drop(columns=["Date", "Usage"])

df.head(20)

Unnamed: 0,Insult,Comment
0,1,"""You fuck your dad."""
1,0,"""i really don't understand your point.\xa0 It ..."
2,0,"""A\\xc2\\xa0majority of Canadians can and has ..."
3,0,"""listen if you dont wanna get married to a man..."
4,0,"""C\xe1c b\u1ea1n xu\u1ed1ng \u0111\u01b0\u1edd..."
5,0,"""@SDL OK, but I would hope they'd sign him to ..."
6,0,"""Yeah and where are you now?"""
7,1,"""shut the fuck up. you and the rest of your fa..."
8,1,"""Either you are fake or extremely stupid...may..."
9,1,"""That you are an idiot who understands neither..."


In [None]:
from torch.utils.data import Dataset
import numpy as np
import re
import nltk
import string

# Fetch the NLTK resources this notebook asks for, in the same order as before.
for nltk_resource in ('punkt', 'stopwords', 'wordnet'):
    nltk.download(nltk_resource)

class InsultDataset(Dataset):
    """Torch dataset that cleans comment text and tokenizes it up front.

    Every value of the ``Comment`` column is passed through ``_preprocess`` and
    then through ``tokenizer`` (padded/truncated to 150 tokens, returned as
    PyTorch tensors).  When the frame also has an ``Insult`` column its values
    are stored as ``self.labels``; otherwise the attribute is not created and
    ``__getitem__`` reports ``-1`` as the label.
    """

    # Cache for the NLTK English stop words; filled on first use so the corpus
    # is loaded once instead of once per comment.
    _STOP_WORDS = None

    def __init__(self, dataframe, tokenizer):
        """Preprocess and tokenize all comments of ``dataframe``.

        Args:
            dataframe: frame with a ``Comment`` column and, optionally, an
                ``Insult`` column holding the labels.
            tokenizer: callable applied to each cleaned string with
                ``padding='max_length'``, ``max_length=150``,
                ``truncation=True`` and ``return_tensors="pt"``.
        """
        texts = [self._preprocess(text)
                 for text in dataframe.Comment.values.tolist()]

        # Show a few cleaned comments as a quick sanity check.
        self._print_random_samples(texts)

        self.texts = [tokenizer(text, padding='max_length',
                                max_length=150,
                                truncation=True,
                                return_tensors="pt")
                      for text in texts]

        # Labels are optional so the same class can wrap unlabeled data.
        if 'Insult' in dataframe:
            self.labels = dataframe.Insult.values.tolist()

    def _print_random_samples(self, texts):
        """Print five randomly chosen entries of ``texts`` (with repetition).

        A private ``RandomState`` seeded with 42 picks the indices, so the same
        entries are shown on every run without reseeding NumPy's global
        generator.  An empty ``texts`` prints only the trailing blank line.
        """
        if texts:
            rng = np.random.RandomState(42)
            for i in rng.randint(0, len(texts), 5):
                print(f"Entry {i}: {texts[i]}")

        print()

    def _preprocess(self, text):
        """Return ``text`` cleaned by the steps below, stripped of outer spaces.

        Order matters: markup-like patterns (links, mentions, ...) are removed
        before punctuation is stripped, because the patterns rely on characters
        such as ``@`` and ``:``.  Lowercasing and stemming are currently
        disabled; ``_lowercase`` and ``_stemming`` are kept but not called.

        NOTE(review): the sample output in this notebook contains fragments
        such as "xa" and stray "n" characters, which suggests the raw comments
        hold literal escape sequences (backslash-n, backslash-xa0) that are not
        decoded here -- confirm against the data.
        """
        text = self._remove_amp(text)
        text = self._remove_links(text)
        text = self._remove_hashes(text)
        text = self._remove_retweets(text)
        text = self._remove_mentions(text)
        text = self._remove_multiple_spaces(text)

        text = self._remove_punctuation(text)
        text = self._remove_numbers(text)

        text_tokens = self._tokenize(text)
        text_tokens = self._stopword_filtering(text_tokens)
        text = self._stitch_text_tokens_together(text_tokens)

        return text.strip()

    def _remove_amp(self, text):
        """Replace the HTML entity ``&amp;`` with a space."""
        return text.replace("&amp;", " ")

    def _remove_mentions(self, text):
        """Replace ``@...`` mentions (up to the next whitespace) with a space.

        The mention may also end at the end of the string, so a trailing
        ``@user`` is removed as well.
        """
        return re.sub(r'@\S*(?:\s|$)', ' ', text)

    def _remove_multiple_spaces(self, text):
        """Collapse every run of whitespace into a single space."""
        return re.sub(r'\s+', ' ', text)

    def _remove_retweets(self, text):
        """Replace a leading ``RT`` marker and its whitespace with a space."""
        return re.sub(r'^RT[\s]+', ' ', text)

    def _remove_links(self, text):
        """Replace http/https URLs with a space."""
        return re.sub(r'https?:\/\/[^\s\n\r]+', ' ', text)

    def _remove_hashes(self, text):
        """Replace every ``#`` with a space (the tag word itself is kept)."""
        return re.sub(r'#', ' ', text)

    def _stitch_text_tokens_together(self, text_tokens):
        """Join tokens back into one space-separated string."""
        return " ".join(text_tokens)

    def _tokenize(self, text):
        """Split ``text`` into word tokens with NLTK's English tokenizer."""
        return nltk.word_tokenize(text, language="english")

    def _stopword_filtering(self, text_tokens):
        """Return the tokens that are not NLTK English stop words.

        The comparison is case-sensitive, exactly as the tokens are given.
        """
        if InsultDataset._STOP_WORDS is None:
            InsultDataset._STOP_WORDS = frozenset(
                nltk.corpus.stopwords.words('english'))
        stop_words = InsultDataset._STOP_WORDS

        return [token for token in text_tokens if token not in stop_words]

    def _stemming(self, text_tokens):
        """Return the Porter stem of every token (not used by ``_preprocess``)."""
        porter = nltk.stem.porter.PorterStemmer()
        return [porter.stem(token) for token in text_tokens]

    def _remove_numbers(self, text):
        """Replace every run of digits with a space."""
        return re.sub(r'\d+', ' ', text)

    def _lowercase(self, text):
        """Return ``text`` lowercased (not used by ``_preprocess``)."""
        return text.lower()

    def _remove_punctuation(self, text):
        """Drop every character listed in ``string.punctuation``."""
        return ''.join(character for character in text if character not in string.punctuation)

    def __len__(self):
        """Number of tokenized comments."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Return ``(tokenized_text, label)``; the label is ``-1`` if unlabeled."""
        text = self.texts[idx]

        label = -1
        if hasattr(self, 'labels'):
            label = self.labels[idx]

        return text, label


[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


In [None]:
from sklearn.model_selection import train_test_split

# Fix random_state so the 70/30 split is the same on every run; the torch/numpy
# seeds further down are set only after this split has already happened.
train_data, test_data = train_test_split(df, test_size=0.3, random_state=0)

In [None]:
from torch import nn

class InsultClassifier(nn.Module):
    """Binary classifier head on top of a pretrained encoder.

    The encoder's first output is indexed with ``[:, 0]`` (the vector at the
    first sequence position) and fed through ``Linear -> ReLU -> Linear ->
    Sigmoid``, so ``forward`` returns one value in ``[0, 1]`` per example.

    Args:
        base_model: encoder module called as
            ``base_model(input_ids=..., attention_mask=...)``.  Its
            ``config.hidden_size`` sets the input width of the first linear
            layer; 768 is used when that attribute is missing.
    """

    def __init__(self, base_model):
        super().__init__()

        self.bert = base_model

        # Read the encoder width from its config instead of hard-coding it, so
        # encoders with a different hidden size work too.
        encoder_config = getattr(base_model, "config", None)
        hidden_size = getattr(encoder_config, "hidden_size", 768)

        self.fc1 = nn.Linear(hidden_size, 32)
        self.fc2 = nn.Linear(32, 1)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_ids, attention_mask):
        """Return the sigmoid output of the head, one row per input example."""
        # [0] is the encoder's first output; [:, 0] keeps the first position.
        bert_out = self.bert(input_ids=input_ids,
                             attention_mask=attention_mask)[0][:, 0]
        x = self.fc1(bert_out)
        x = self.relu(x)

        x = self.fc2(x)
        x = self.sigmoid(x)

        return x


In [None]:
import torch
from torch.optim import Adam
from tqdm import tqdm

def train(model, train_dataloader, val_dataloader, learning_rate, epochs,
          patience=1, checkpoint_path="best_model.pt"):
    """Train ``model`` with Adam and binary cross-entropy, with early stopping.

    After every epoch the model is evaluated on ``val_dataloader``.  Whenever
    the summed validation loss improves on the best value so far, the whole
    model object is written to ``checkpoint_path``.  Training stops once the
    loss has failed to improve for ``patience`` consecutive epochs.

    Args:
        model: module called as ``model(input_ids, attention_mask)``; it must
            return probabilities in ``[0, 1]`` because ``nn.BCELoss`` is
            applied to the output directly.
        train_dataloader: yields ``(inputs, labels)`` batches where ``inputs``
            has ``'input_ids'`` and ``'attention_mask'`` entries.
        val_dataloader: same batch layout, used for validation.
        learning_rate: learning rate passed to Adam.
        epochs: maximum number of epochs.
        patience: non-improving epochs tolerated before stopping (default 1,
            the previous hard-coded behaviour).
        checkpoint_path: file the best model is saved to with ``torch.save``.
    """
    best_val_loss = float('inf')
    early_stopping_threshold_count = 0

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    criterion = nn.BCELoss()
    optimizer = Adam(model.parameters(), lr=learning_rate)

    model = model.to(device)
    criterion = criterion.to(device)

    for epoch in range(epochs):
        total_acc_train = 0
        total_loss_train = 0

        model.train()

        for train_input, train_label in tqdm(train_dataloader):
            # Both tensors are squeezed on dim 1 so the mask has the same
            # (batch, seq) layout as the ids.
            attention_mask = train_input['attention_mask'].squeeze(1).to(device)
            input_ids = train_input['input_ids'].squeeze(1).to(device)

            train_label = train_label.to(device)

            output = model(input_ids, attention_mask)

            loss = criterion(output, train_label.float().unsqueeze(1))

            total_loss_train += loss.item()

            # Count correct predictions at the 0.5 decision threshold.
            acc = ((output >= 0.5).int() == train_label.unsqueeze(1)).sum().item()
            total_acc_train += acc

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        total_acc_val = 0
        total_loss_val = 0

        model.eval()

        with torch.no_grad():
            for val_input, val_label in tqdm(val_dataloader):
                attention_mask = val_input['attention_mask'].squeeze(1).to(device)
                input_ids = val_input['input_ids'].squeeze(1).to(device)

                val_label = val_label.to(device)

                output = model(input_ids, attention_mask)

                loss = criterion(output, val_label.float().unsqueeze(1))

                total_loss_val += loss.item()

                acc = ((output >= 0.5).int() == val_label.unsqueeze(1)).sum().item()
                total_acc_val += acc

        # Losses are averaged per batch, accuracies per example.
        print(f'Epochs: {epoch + 1} '
              f'| Train Loss: {total_loss_train / len(train_dataloader): .3f} '
              f'| Train Accuracy: {total_acc_train / (len(train_dataloader.dataset)): .3f} '
              f'| Val Loss: {total_loss_val / len(val_dataloader): .3f} '
              f'| Val Accuracy: {total_acc_val / len(val_dataloader.dataset): .3f}')

        if best_val_loss > total_loss_val:
            best_val_loss = total_loss_val
            # Saves the full model object (not just a state_dict).
            torch.save(model, checkpoint_path)
            print("Saved model")
            early_stopping_threshold_count = 0
        else:
            early_stopping_threshold_count += 1

        if early_stopping_threshold_count >= patience:
            print("Early stopping")
            break


In [None]:
from transformers import AutoTokenizer, AutoModel
from torch.utils.data import DataLoader

# Seed torch and numpy before the datasets, loaders and classifier are built.
torch.manual_seed(0)
np.random.seed(0)

# Pretrained checkpoint used for both the tokenizer and the encoder.
BERT_MODEL = "roberta-base"
tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL)
base_model = AutoModel.from_pretrained(BERT_MODEL)

# Options shared by both loaders; only the training loader shuffles.
loader_options = dict(batch_size=8, num_workers=0)

train_dataset = InsultDataset(train_data, tokenizer)
train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_options)

# The held-out part of the split serves as the validation set during training.
val_dataset = InsultDataset(test_data, tokenizer)
val_dataloader = DataLoader(val_dataset, **loader_options)

model = InsultClassifier(base_model)

# Hyperparameters for the fine-tuning run.
learning_rate = 1e-5
epochs = 20
train(model, train_dataloader, val_dataloader, learning_rate, epochs)


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Entry 860: Wheres Ron Paul trolls defend thisxa piece shit
Entry 3772: Hmmmand course dummies going buy
Entry 3092: div classforumitem forumtopicdiv classforumitemtitleWhats American womendivdiv classforumitemdescriptionHope Lepchenko winsthat make I thinkdivdiv
Entry 466: typical Muslim expect
Entry 4426: Take pills The Voices obviously taking control

Entry 1126: Nobody said could take away honours accomplishments stop tying sour othersnnSo shut take loss fergie
Entry 1459: Your buddies repub turdraiboxcyndybnww nav party Obama blow doll You join koolaid gets warm
Entry 860: ELECTRu dfC SHOCKxa ELECTRu dfC SHOCKxa ELECTRu dfC SHOCKxa nELECTRu dfC SHOCKxa ELECTRu dfC SHOCKxa ELECTRu dfC SHOCKxa nELECTRu dfC SHOCKxa ELECTRu dfC SHOCKxa ELECTRu dfC SHOCKxa nELECTRu dfC SHOCKxa ELECTRu dfC SHOCKxa ELECTRu dfC SHOCKxa nELECTRu dfC SHOCKxa ELECTRu dfC SHOCKxa ELECTRu dfC SHOCK
Entry 1294: As underlying I thinknHonestly price action doesnt make allot sense I certainly take
Entry 1130: Thats

100%|██████████| 577/577 [02:03<00:00,  4.67it/s]
100%|██████████| 248/248 [00:16<00:00, 15.07it/s]


Epochs: 1 | Train Loss:  0.433 | Train Accuracy:  0.793 | Val Loss:  0.351 | Val Accuracy:  0.853
Saved model


100%|██████████| 577/577 [02:05<00:00,  4.60it/s]
100%|██████████| 248/248 [00:16<00:00, 14.96it/s]


Epochs: 2 | Train Loss:  0.309 | Train Accuracy:  0.871 | Val Loss:  0.321 | Val Accuracy:  0.864
Saved model


100%|██████████| 577/577 [02:05<00:00,  4.59it/s]
100%|██████████| 248/248 [00:16<00:00, 14.96it/s]

Epochs: 3 | Train Loss:  0.240 | Train Accuracy:  0.901 | Val Loss:  0.354 | Val Accuracy:  0.856
Early stopping





In [None]:
def get_text_predictions(model, loader):
    """Run ``model`` over ``loader`` and return its 0/1 predictions.

    Args:
        model: module called as ``model(input_ids, attention_mask)`` that
            returns one probability per example.
        loader: iterable of ``(inputs, label)`` batches where ``inputs`` has
            ``'input_ids'`` and ``'attention_mask'`` entries; the label is
            ignored.

    Returns:
        Integer NumPy array of the concatenated batch predictions, where an
        output of at least 0.5 counts as 1 (the same threshold ``train`` uses
        for its accuracy).  An empty ``loader`` yields an empty ``(0, 1)``
        array, matching the one-column output of ``InsultClassifier``.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    model = model.to(device)
    model.eval()

    results_predictions = []
    with torch.no_grad():
        for data_input, _ in tqdm(loader):
            # Both tensors are squeezed on dim 1 so the mask has the same
            # (batch, seq) layout as the ids.
            attention_mask = data_input['attention_mask'].squeeze(1).to(device)
            input_ids = data_input['input_ids'].squeeze(1).to(device)

            output = model(input_ids, attention_mask)

            results_predictions.append((output >= 0.5).int())

    # torch.cat raises on an empty list, so handle "no batches" explicitly.
    if not results_predictions:
        return torch.empty((0, 1), dtype=torch.int32).numpy()

    return torch.cat(results_predictions).cpu().detach().numpy()


In [None]:
# Reload the checkpoint that train() wrote on its best validation epoch.
# NOTE(review): torch.load unpickles a full model object here, so only load
# checkpoint files you trust.
model = torch.load("best_model.pt")

# Tokenize the test comments in their original order.
sample_submission = pd.read_csv("/content/test.csv")
test_dataset = InsultDataset(sample_submission, tokenizer)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=8,
    shuffle=False,
    num_workers=0,
)

# Store the 0/1 predictions in the submission frame and preview the result.
test_predictions = get_text_predictions(model, test_dataloader)
sample_submission["Insult"] = test_predictions

display(sample_submission.head(20))

Entry 860: Hahaha moron You say Rose sucks basically saying win wo Rose Fuckkkinnngggg dumb asssssss
Entry 1294: Naturally youre proud racism
Entry 1130: I wonder would presidentIm sure fulfilled every promise ever made right I sure perfect never make mistakes Im sure walk water sound little holier thou
Entry 1095: lawl bro last year
Entry 1638: Thats right asshole Im filthy Jew Yspanou however wouldnt make good pivot man circle jerk If dont believe ask wife Thats keeps diacount herselfrnspan



100%|██████████| 280/280 [00:18<00:00, 15.30it/s]


Unnamed: 0,id,Date,Comment,Insult
0,1,20120603163526Z,"""like this if you are a tribe fan""",0
1,2,20120531215447Z,"""you're idiot.......................""",1
2,3,20120823164228Z,"""I am a woman Babs, and the only ""war on women...",0
3,4,20120826010752Z,"""WOW & YOU BENEFITTED SO MANY WINS THIS YEAR F...",1
4,5,20120602223825Z,"""haha green me red you now loser whos winning ...",1
5,6,20120603202442Z,"""\nMe and God both hate-faggots.\n\nWhat's the...",1
6,7,20120603163604Z,"""Oh go kiss the ass of a goat....and you DUMMY...",1
7,8,20120602223902Z,"""Not a chance Kid, you're wrong.""",0
8,9,20120528064125Z,"""On Some real Shit FUck LIVE JASMIN!!!""",1
9,10,20120603071243Z,"""ok but where the hell was it released?you all...",1
