In [None]:
import json

from collections import Counter

import torch

from transformers import AutoModel, AutoTokenizer

from tqdm import tqdm

from model import SpaceModelForClassification

In [None]:
# Load the raw dataset: a JSON object that later cells index by post id
# (each record is read for 'post_tokens' and 'annotators').
# Explicit UTF-8 so decoding does not depend on the platform's locale default.
with open('data/dataset.json', encoding='utf-8') as f:
    data = json.load(f)

In [None]:
# Distinct raw word tokens across every post in `data` (all splits, not just train).
# A comprehension keeps the loop variable out of the notebook's global namespace,
# so the builtin `id` is not shadowed for later cells.
vocab = {token for post in data.values() for token in post['post_tokens']}

In [None]:
# Post-id divisions per split; HateDataloader reads the 'train', 'test' and 'val' keys.
with open('data/post_id_divisions.json', encoding='utf-8') as f:
    ids_split = json.load(f)
assert {'train', 'test', 'val'} <= ids_split.keys(), f'unexpected split keys: {sorted(ids_split)}'

In [None]:
def encode_label(label, merge_offensive=True):
    """Map a majority-vote label string to an integer class id.

    Args:
        label: one of 'hatespeech', 'normal' or 'offensive'.
        merge_offensive: when True (the default, i.e. binary classification)
            'offensive' shares class 0 with 'hatespeech'; when False it gets its
            own class 2 for three-way classification.

    Returns:
        0 for 'hatespeech' (and for 'offensive' when merged), 1 for 'normal',
        2 for 'offensive' when not merged.

    Raises:
        ValueError: for any other label. ValueError subclasses Exception, so
            handlers written for the previous bare Exception still catch it.
    """
    label_ids = {
        'hatespeech': 0,
        'normal': 1,
        'offensive': 0 if merge_offensive else 2,
    }
    # isinstance guard keeps unhashable inputs on the ValueError path instead of TypeError.
    if isinstance(label, str) and label in label_ids:
        return label_ids[label]
    raise ValueError(f'Unknown Label: {label}!')


class HateDataloader:
    """Batched, split-aware loader over the post records in ``data``.

    For each of the 'train', 'test' and 'val' splits it keeps four parallel lists:
      * ``splits``         - raw word-token lists (``post_tokens``),
      * ``labels``         - majority-vote label strings,
      * ``tokenized``      - token-id tensors of shape (1, max_seq_len),
      * ``encoded_labels`` - integer class ids from ``label_encoder``.

    One batch cursor (``curr_batch``) is shared by all splits. It is advanced by
    ``take``/``take_tokenized`` and reset when iteration starts. Iterate with
    ``for batch in loader.get_split('train'): ...``.
    """

    _SPLIT_NAMES = ('train', 'test', 'val')

    def __init__(self, data, ids_split, tokenizer, label_encoder, batch_size, max_seq_len=128):
        """Materialise tokens, labels, token ids and encoded labels for every split.

        Args:
            data: mapping post_id -> record. Each record is read for 'post_tokens'
                (list of words) and 'annotators' (list of dicts with a 'label' key).
            ids_split: mapping with 'train', 'test' and 'val' keys, each an
                iterable of post ids.
            tokenizer: a Hugging Face tokenizer, or any callable accepting the
                same keyword arguments (is_split_into_words, padding, truncation,
                max_length, return_tensors='pt'). Its result is indexed with
                'input_ids'.
            label_encoder: callable mapping a label string to an int class id.
            batch_size: number of records per batch.
            max_seq_len: every record is padded/truncated to this many token ids
                so that a batch can be concatenated into a single tensor.
        """
        def get_label(annotators):
            # Majority vote. On ties Counter.most_common keeps first-seen order.
            # NOTE(review): a three-way annotator split still yields a label here
            # rather than being dropped — confirm this is intended.
            return Counter(anno['label'] for anno in annotators).most_common(1)[0][0]

        self.batch_size = batch_size
        self.max_seq_len = max_seq_len
        self.curr_batch = 0
        self.iterate_split = None

        self.splits = {}
        self.labels = {}
        for split in self._SPLIT_NAMES:
            # ids_split values may be lists; a set makes each membership test O(1).
            wanted_ids = set(ids_split[split])
            # Records keep the iteration order of `data`, not the order of ids_split.
            post_ids = [post_id for post_id in tqdm(data, desc=split.capitalize()) if post_id in wanted_ids]
            self.splits[split] = [data[post_id]['post_tokens'] for post_id in post_ids]
            self.labels[split] = [get_label(data[post_id]['annotators']) for post_id in post_ids]

        self.tokenized = {
            split: [self._tokenize(tokenizer, record)
                    for record in tqdm(records, desc=f'{split.capitalize()} Tokenization')]
            for split, records in self.splits.items()
        }
        self.encoded_labels = {
            split: [label_encoder(label) for label in labels]
            for split, labels in self.labels.items()
        }

    def _tokenize(self, tokenizer, words):
        """Encode one pre-split record into a (1, max_seq_len) tensor of token ids."""
        encoding = tokenizer(
            words,
            is_split_into_words=True,  # record is a list of words, not a batch of sentences
            padding='max_length',      # fixed width so torch.cat can stack a batch
            truncation=True,
            max_length=self.max_seq_len,
            return_tensors='pt',
        )
        # NOTE(review): the attention mask is not kept, so padding positions are
        # indistinguishable from real tokens downstream — confirm the model copes.
        return encoding['input_ids']

    def _batch_slice(self):
        """Slice selecting the records of batch number ``curr_batch``."""
        start = self.batch_size * self.curr_batch
        return slice(start, start + self.batch_size)

    def peek(self, split):
        """Return the current raw batch of ``split`` without advancing the cursor.

        The dict holds 'input_ids' (list of word-token lists) and 'label_ids'
        (list of label strings). Both are empty once the split is exhausted.
        """
        window = self._batch_slice()
        return {
            'input_ids': self.splits[split][window],
            'label_ids': self.labels[split][window],
        }

    def take(self, split):
        """Like ``peek`` but advances the shared batch cursor afterwards."""
        batch = self.peek(split)
        self.curr_batch += 1
        return batch

    def peek_tokenized(self, split):
        """Return the current tensor batch of ``split`` without advancing the cursor.

        The dict holds 'input_ids' (token ids, shape (batch, max_seq_len)) and
        'label_ids' (torch.long tensor of shape (batch,)). torch.cat raises if
        the current batch is empty.
        """
        window = self._batch_slice()
        return {
            'input_ids': torch.cat(self.tokenized[split][window], dim=0),
            'label_ids': torch.tensor(self.encoded_labels[split][window], dtype=torch.long),
        }

    def take_tokenized(self, split):
        """Like ``peek_tokenized`` but advances the shared batch cursor afterwards."""
        batch = self.peek_tokenized(split)
        self.curr_batch += 1
        return batch

    def get_split(self, split):
        """Select the split that iteration yields batches from; returns ``self``."""
        self.iterate_split = split
        return self

    def steps(self, split):
        """Number of batches one pass over ``split`` yields (the last may be partial)."""
        return (len(self.tokenized[split]) + self.batch_size - 1) // self.batch_size

    def __iter__(self):
        self.reset()
        return self

    def __next__(self):
        if self.iterate_split is None:
            raise RuntimeError('call get_split(split) before iterating over HateDataloader')
        if self.batch_size * self.curr_batch >= len(self.splits[self.iterate_split]):
            raise StopIteration
        return self.take_tokenized(self.iterate_split)

    def reset(self):
        """Rewind the shared batch cursor to the first batch."""
        self.curr_batch = 0

In [None]:
# Training configuration.
NUM_EPOCHS = 1          # presumably the number of passes over the training split
BATCH_SIZE = 16         # posts per batch handed to HateDataloader
MAX_SEQ_LEN = 128       # intended token-id length per post after padding/truncation
LEARNING_RATE = 1e-4    # optimiser learning rate (consumer not shown in this chunk)
SEED = 42               # seeds torch's RNG so model initialisation is reproducible

# Trailing semicolon hides the returned Generator repr in the cell output.
torch.manual_seed(SEED);

In [None]:
# Tokenizer for the 'bert-base-uncased' checkpoint; the bare name below displays its summary.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path='bert-base-uncased')
tokenizer

In [None]:
# Pretrained encoder weights for the 'bert-base-uncased' checkpoint.
base_model = AutoModel.from_pretrained('bert-base-uncased')
# One-line summary instead of the full nn.Module repr, which is very long for BERT.
f'{type(base_model).__name__}: {sum(p.numel() for p in base_model.parameters()):,} parameters'

In [None]:
# Build all three splits up front; each batch holds BATCH_SIZE posts.
dataloader = HateDataloader(
    data=data,
    ids_split=ids_split,
    tokenizer=tokenizer,
    label_encoder=encode_label,
    batch_size=BATCH_SIZE,
)

In [None]:
# Instantiate the project's classification model (constructed here with no
# arguments) and display it as the cell output.
space_model = SpaceModelForClassification()
space_model