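#### Kaggle inference notebook for the Jigsaw toxic-severity-rating competition: upload and run it to score every comment in `comments_to_score.csv` with an ensemble of fine-tuned BERT fold checkpoints and write the ranked scores to `submission.csv`.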

In [8]:
import os
import gc
import cv2
import copy
import time
import random

# For data manipulation
import numpy as np
import pandas as pd

# Pytorch Imports
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# For Transformer Models
from transformers import AutoTokenizer, AutoModel

# Utils
from tqdm import tqdm

# CUDA_LAUNCH_BLOCKING=1 makes every CUDA kernel launch synchronous. That gives
# tracebacks pointing at the failing op, but it slows GPU inference noticeably,
# so enable it only while debugging. It must be set before CUDA is initialised.
DEBUG_CUDA = False
if DEBUG_CUDA:
    os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

In [9]:
# Central run configuration: checkpoint/tokenizer location, batching,
# sequence length, head width and the device every tensor is moved to.
CONFIG = {
    'seed': 42,
    'model_name': '../input/bert-base-uncased',
    'test_batch_size': 64,
    'max_length': 256,
    'num_classes': 1,
    'device': torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'),
}

# The tokenizer is loaded from the same local directory as the backbone weights.
CONFIG['tokenizer'] = AutoTokenizer.from_pretrained(CONFIG['model_name'])

In [10]:
# One checkpoint per cross-validation fold (folds 0 through 4).
MODEL_PATHS = [f'../input/bert-fold-4/model_fold_{fold}.pth' for fold in range(5)]

In [11]:
def set_seed(seed = 42):
    '''Seed every random number generator used by this notebook.

    Seeds Python's ``random`` module, NumPy's legacy global RNG and PyTorch
    (CPU plus every visible CUDA device). It also puts cuDNN into deterministic,
    non-autotuned mode so repeated runs produce the same results.

    Args:
        seed: integer seed applied to every generator.
    '''
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all seeds every GPU; torch.cuda.manual_seed only seeds the
    # current one. This is silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    # Without these flags cuDNN may pick non-deterministic or benchmarked kernels.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # NOTE: str hashing in the *current* interpreter is fixed at startup, so
    # this only takes effect in child processes started after this call.
    os.environ['PYTHONHASHSEED'] = str(seed)
    
set_seed(seed=CONFIG['seed'])  # seed all RNGs once, before any model is built

In [12]:
# Comments that must be scored for the submission.
df = pd.read_csv(
    "../input/jigsaw-toxic-severity-rating/comments_to_score.csv",
)
df.head(n=5)

Unnamed: 0,comment_id,text
0,114890,"""\n \n\nGjalexei, you asked about whether ther..."
1,732895,"Looks like be have an abuser , can you please ..."
2,1139051,I confess to having complete (and apparently b...
3,1434512,"""\n\nFreud's ideas are certainly much discusse..."
4,2084821,It is not just you. This is a laundry list of ...


In [13]:
class JigsawDataset(Dataset):
    """Map-style dataset that tokenizes the ``text`` column of a DataFrame on demand.

    Each item is a dict with ``ids`` and ``mask`` LongTensors, truncated or
    padded to exactly ``max_length`` tokens by the tokenizer.

    Args:
        df: DataFrame with a ``text`` column; rows are indexed positionally.
        tokenizer: object exposing ``encode_plus`` (e.g. a Hugging Face tokenizer).
        max_length: fixed sequence length of every encoded item.
    """

    def __init__(self, df, tokenizer, max_length):
        self.df = df
        self.tokenizer = tokenizer
        self.max_len = max_length
        # Keep the raw strings as an array for cheap positional lookup.
        self.text = df['text'].values

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, index):
        encoded = self.tokenizer.encode_plus(
            self.text[index],
            truncation=True,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
        )
        # Rename the tokenizer's fields to the keys the model loop expects.
        return {
            out_key: torch.tensor(encoded[tok_key], dtype=torch.long)
            for out_key, tok_key in (('ids', 'input_ids'), ('mask', 'attention_mask'))
        }

In [14]:
test_dataset = JigsawDataset(df, CONFIG['tokenizer'], max_length=CONFIG['max_length'])
# shuffle=False is required: predictions are written back into df row by row,
# so batches must come out in the DataFrame's order.
# Pinned host memory only speeds up host->GPU copies, so request it only on CUDA
# (on a CPU-only machine it is useless and PyTorch warns about it).
test_loader = DataLoader(
    test_dataset,
    batch_size=CONFIG['test_batch_size'],
    shuffle=False,
    pin_memory=CONFIG['device'].type == 'cuda',
)

In [61]:
class JigsawModel(nn.Module):
    """Pretrained transformer encoder followed by a dropout + linear scoring head.

    The backbone's pooled output goes through dropout (p=0.2) and a linear layer
    that emits ``CONFIG['num_classes']`` scores per input sequence. The backbone
    is built with ``return_dict=False``, so the pooled output is the second
    element of the returned tuple.

    Args:
        model_name: name or local path passed to ``AutoModel.from_pretrained``.
    """

    def __init__(self, model_name):
        super(JigsawModel, self).__init__()
        self.model = AutoModel.from_pretrained(model_name, return_dict=False)
        self.drop = nn.Dropout(p=0.2)
        # Take the head width from the backbone config instead of hard-coding
        # 768, so other encoder sizes work too. bert-base is still 768, so
        # existing checkpoints load unchanged.
        self.fc = nn.Linear(self.model.config.hidden_size, CONFIG['num_classes'])

    def forward(self, ids, mask):
        """Score a batch of token ids / attention masks; returns raw (unbounded) scores."""
        backbone_out = self.model(input_ids=ids, attention_mask=mask)
        # Index rather than unpack: the tuple gets longer if hidden states or
        # attentions are enabled in the backbone config.
        pooled = backbone_out[1]
        return self.fc(self.drop(pooled))

In [62]:
@torch.no_grad()
def valid_fn(model, dataloader, device):
    """Run ``model`` over every batch of ``dataloader`` and collect its outputs.

    Gradients are disabled by the decorator, and the model is put in eval mode,
    so dropout is inactive.

    Args:
        model: module called as ``model(ids, mask)``.
        dataloader: iterable of dicts holding ``'ids'`` and ``'mask'`` tensors;
            must support ``len()`` (used for the progress bar).
        device: device each batch is moved to before the forward pass.

    Returns:
        1-D NumPy array of the flattened model outputs, in loader order. It is
        empty (float32) if the loader yields no batches.
    """
    model.eval()

    preds = []
    for data in tqdm(dataloader, total=len(dataloader)):
        ids = data['ids'].to(device, dtype=torch.long)
        mask = data['mask'].to(device, dtype=torch.long)

        outputs = model(ids, mask)
        # Flatten per-batch outputs so batches concatenate into one vector.
        # No detach() is needed: no_grad means outputs carry no autograd graph.
        preds.append(outputs.view(-1).cpu().numpy())

    gc.collect()

    if not preds:
        # np.concatenate raises on an empty list; return an empty vector instead.
        return np.empty(0, dtype=np.float32)
    return np.concatenate(preds)

In [64]:
def _load_fold_model(path, device):
    """Build a JigsawModel on ``device`` and load the checkpoint at ``path`` into it."""
    model = JigsawModel(CONFIG['model_name'])
    model.to(device)
    # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
    state_dict = torch.load(path, map_location=device)
    # strict=False tolerates buffer differences between transformers versions
    # (e.g. embeddings.position_ids). Any other missing key means those weights
    # silently kept their pretrained / random-init values, so fail loudly.
    result = model.load_state_dict(state_dict, strict=False)
    missing = [k for k in result.missing_keys if not k.endswith('position_ids')]
    if missing:
        raise RuntimeError(
            f"Checkpoint {path} is missing {len(missing)} parameter(s), e.g. {missing[:5]}"
        )
    if result.unexpected_keys:
        print(f"Ignoring {len(result.unexpected_keys)} unexpected key(s) in {path}: "
              f"{result.unexpected_keys[:5]}")
    return model


def inference(model_paths, dataloader, device):
    """Average the predictions of every fold checkpoint in ``model_paths``.

    Args:
        model_paths: sequence of checkpoint paths (one per fold).
        dataloader: loader passed unchanged to ``valid_fn`` for every model.
        device: device the models and batches are placed on.

    Returns:
        NumPy array holding the element-wise mean of the per-model predictions.

    Raises:
        ValueError: if ``model_paths`` is empty (a mean over zero models is NaN).
        RuntimeError: if a checkpoint does not cover the model's parameters.
    """
    if not model_paths:
        raise ValueError("inference() needs at least one model path")

    final_preds = []
    for i, path in enumerate(model_paths):
        model = _load_fold_model(path, device)

        print(f"Getting predictions for model {i+1}")
        final_preds.append(valid_fn(model, dataloader, device))

        # Free this fold's weights before the next model is built.
        del model
        gc.collect()
        torch.cuda.empty_cache()  # no-op when CUDA was never initialised

    return np.mean(np.array(final_preds), axis=0)

In [65]:
preds = inference(model_paths=MODEL_PATHS, dataloader=test_loader, device=CONFIG['device'])

Getting predictions for model 1


100%|██████████| 118/118 [01:02<00:00,  1.88it/s]


In [66]:
preds[0:5]  # first five ensemble scores

array([-0.3884304 , -0.5081454 , -0.32211098, -0.35622755, -0.10511924],
      dtype=float32)

In [67]:
# Fewer unique values than rows means some comments received tied scores;
# the ranking step breaks those ties with method='first'.
print(f"Total Predictions: {preds.shape[0]}")
print(f"Total Unique Predictions: {np.unique(preds).shape[0]}")

Total Predictiions: 7537
Total Unique Predictions: 7524


In [68]:
# Attach the averaged raw scores; wrapping them in a Series with df's index
# aligns them row-for-row (and raises ValueError on a length mismatch).
df['score'] = pd.Series(preds, index=df.index)
df.head(n=5)

Unnamed: 0,comment_id,text,score
0,114890,"""\n \n\nGjalexei, you asked about whether ther...",-0.38843
1,732895,"Looks like be have an abuser , can you please ...",-0.508145
2,1139051,I confess to having complete (and apparently b...,-0.322111
3,1434512,"""\n\nFreud's ideas are certainly much discusse...",-0.356228
4,2084821,It is not just you. This is a laundry list of ...,-0.105119


In [69]:
# Replace raw scores with their 1..n rank; method='first' gives tied scores
# distinct ranks in order of appearance.
df['score'] = df.loc[:, 'score'].rank(method='first')
df.head(n=5)

Unnamed: 0,comment_id,text,score
0,114890,"""\n \n\nGjalexei, you asked about whether ther...",1084.0
1,732895,"Looks like be have an abuser , can you please ...",225.0
2,1139051,I confess to having complete (and apparently b...,1874.0
3,1434512,"""\n\nFreud's ideas are certainly much discusse...",1428.0
4,2084821,It is not just you. This is a laundry list of ...,4657.0


In [70]:
# Select the submission columns explicitly instead of dropping 'text' in place:
# the cell stays re-runnable (a second in-place drop raises KeyError) and any
# extra column in the input CSV can't leak into the submission.
submission = df[['comment_id', 'score']]
submission.to_csv("submission.csv", index=False)