In [None]:
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt

#from sklearn import preprocessing as pp
#from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split
import re
import time
from tqdm import tqdm
import random

import torch
import torch.nn as nn
import torch.nn.functional as F

import nltk
from bs4 import BeautifulSoup
import transformers
from transformers import AdamW

#import torchvision.transforms as transforms
import torch.optim as optimizers
from torch.utils.data import random_split
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import WeightedRandomSampler, BatchSampler

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

In [None]:
# Load the competition train/test CSVs (paths are relative to the notebook's working directory).
train = pd.read_csv('../input/intern-compe/train_added_encoded.csv')
test = pd.read_csv('../input/intern-compe/test_added_encoded.csv')

In [None]:
# Inspect the first rows of the training frame as loaded from CSV.
train.head()

In [None]:
def bs_get_text(html, parser="html.parser"):
    """Return the visible text of an HTML document.

    Args:
        html: HTML markup (str or bytes). Missing values -- ``None`` or a float
            NaN, which is what pandas produces for an empty CSV cell -- are
            treated as an empty document instead of crashing the parser.
        parser: BeautifulSoup tree builder. Naming it explicitly silences bs4's
            "no parser was explicitly specified" warning and makes the
            extracted text independent of which optional parsers (e.g. lxml)
            happen to be installed in the current environment.

    Returns:
        The text returned by ``BeautifulSoup(...).get_text()``, or ``""`` for a
        missing document.
    """
    if html is None or (isinstance(html, float) and pd.isna(html)):
        return ""
    return BeautifulSoup(html, parser).get_text()


def add_html_raw(df_origin):
    """Return a copy of ``df_origin`` with an added ``html_raw`` text column.

    ``html_raw`` holds ``bs_get_text`` applied to each row's ``html_content``.
    The input frame is not modified, so re-running the calling cell is safe.
    """
    df = df_origin.copy()
    df["html_raw"] = df["html_content"].map(bs_get_text)
    return df

In [None]:
# Add the extracted-text column `html_raw`; add_html_raw works on a copy and recomputes
# the column from `html_content`, so re-running this cell gives the same result.
train = add_html_raw(train)

In [None]:
# Check the newly added html_raw column.
train.head()

In [None]:
"""
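TODO:

As in the Student Cup competition, the data may contain several near-duplicate
samples. Check for them before the train/validation split; otherwise duplicates
can leak across the split and inflate validation scores.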

"""

In [None]:
class HtmlDataset(Dataset):
    """Map-style dataset over a DataFrame of extracted HTML text.

    Each item is ``(text, label)``:

    - ``text`` is the row's ``html_raw`` value, passed through ``transform``
      when one is given.
    - ``label`` is the row's ``state`` value.

    Frames without a ``state`` column (e.g. the test split) yield the row's
    ``id`` instead.

    Args:
        csv_file: DataFrame with an ``html_raw`` column and either a ``state``
            or an ``id`` column. Rows are addressed positionally (``iloc``),
            so the frame's index need not be contiguous.
        transform: optional callable applied to the text (e.g. a tokenizer).
    """

    def __init__(self, csv_file, transform=None):
        self.csv_file = csv_file
        self.transform = transform

    def __len__(self):
        return len(self.csv_file)

    def __getitem__(self, idx):
        # Samplers may pass a tensor index; iloc needs plain Python ints.
        if torch.is_tensor(idx):
            idx = idx.tolist()

        html_text = self.csv_file["html_raw"].iloc[idx]

        # Choose the label column explicitly instead of using a bare `except:`,
        # which also hid unrelated errors (e.g. an out-of-range idx).
        label_column = "state" if "state" in self.csv_file.columns else "id"
        label = self.csv_file[label_column].iloc[idx]

        if self.transform:
            html_text = self.transform(html_text)

        return html_text, label

In [None]:
#define transform
class BERT_Tokenize(object):
    """Callable transform that turns a string into fixed-length model inputs.

    Args:
        model_type: one of the keys of ``_TOKENIZERS`` (e.g. ``"BERT"``,
            ``"ROBERTA"``). ``"TAPTBERT"`` reuses the ``bert-base-uncased``
            vocabulary.
        max_len: every encoded sequence is truncated/padded to exactly this
            many tokens, so samples can be stacked by the default collate_fn.

    Raises:
        ValueError: if ``model_type`` is not recognised. Previously no tokenizer
            was created and the error only surfaced on the first call.
    """

    # model_type -> (tokenizer class name in `transformers`, pretrained checkpoint)
    _TOKENIZERS = {
        "BERT": ("BertTokenizer", "bert-base-uncased"),
        "TAPTBERT": ("BertTokenizer", "bert-base-uncased"),
        "ALBERT": ("AlbertTokenizer", "albert-base-v2"),
        "XLNET": ("XLNetTokenizer", "xlnet-base-cased"),
        "ROBERTA": ("RobertaTokenizer", "roberta-base"),
        "csROBERTA": ("AutoTokenizer", "allenai/cs_roberta_base"),
        "XLMROBERTA": ("XLMRobertaTokenizer", "xlm-roberta-base"),
        "ELECTRA": ("ElectraTokenizer", "google/electra-base-discriminator"),
    }

    def __init__(self, model_type, max_len):
        self.max_len = max_len
        try:
            class_name, checkpoint = self._TOKENIZERS[model_type]
        except KeyError:
            raise ValueError(
                f"unknown model_type {model_type!r}; expected one of {sorted(self._TOKENIZERS)}"
            ) from None

        # Imported lazily (as before) so the class itself can be defined without it.
        import transformers
        tokenizer_cls = getattr(transformers, class_name)
        self.bert_tokenizer = tokenizer_cls.from_pretrained(checkpoint)

    def __call__(self, text):
        """Encode ``text`` into ``(input_ids, attention_mask)`` LongTensors of length ``max_len``."""
        inputs = self.bert_tokenizer.encode_plus(
            text,
            add_special_tokens=True,     # add the model's special tokens ([CLS]/[SEP] or equivalent)
            max_length=self.max_len,
            padding="max_length",        # explicit form of the deprecated pad_to_max_length=True
            truncation=True,             # long pages must be cut to max_len to be batchable
            return_attention_mask=True,
        )
        return (torch.LongTensor(inputs["input_ids"]),
                torch.LongTensor(inputs["attention_mask"]))

In [None]:
#define BERT based model
class BertModule(pl.LightningModule):
    """LightningModule wrapping a pretrained transformer encoder for classification.

    Args:
        model_type: which backbone to load.

            - ``ALBERT``, ``BERT``, ``XLNET``, ``ROBERTA``, ``XLMROBERTA`` and
              ``ELECTRA`` use Hugging Face ``*ForSequenceClassification``
              models.
            - ``csROBERTA`` and ``TAPTBERT`` load a bare encoder plus a small
              MLP head on its second output.

        num_classes: number of output logits.

            - ``1`` means binary classification with a single logit (sigmoid +
              binary cross-entropy, float 0/1 targets).
            - ``>= 2`` uses softmax + cross-entropy with integer targets in
              ``[0, num_classes)``.

    Raises:
        ValueError: if ``model_type`` is not recognised.

    Note:
        Metrics are reported with ``self.log`` (PyTorch Lightning >= 1.0) so that
        ``EarlyStopping('val_loss')`` can find ``val_loss``.
    """

    # model_type -> (transformers class name, checkpoint, extra from_pretrained kwargs)
    _SEQUENCE_CLASSIFIERS = {
        "ALBERT": ("AlbertForSequenceClassification", "albert-base-v2", {"early_stopping": False}),
        "BERT": ("BertForSequenceClassification", "bert-base-uncased", {"early_stopping": False}),
        "XLNET": ("XLNetForSequenceClassification", "xlnet-base-cased", {"early_stopping": False}),
        "ROBERTA": ("RobertaForSequenceClassification", "roberta-base", {"early_stopping": False}),
        "XLMROBERTA": ("XLMRobertaForSequenceClassification", "xlm-roberta-base", {}),
        "ELECTRA": ("ElectraForSequenceClassification", "google/electra-base-discriminator", {}),
    }
    # Backbones loaded as a bare encoder plus the MLP head from _make_classifier.
    _CUSTOM_HEAD_TYPES = ("csROBERTA", "TAPTBERT")

    def __init__(self, model_type, num_classes=1):
        super().__init__()
        import transformers

        # Historical name kept for compatibility: True when forward() must apply self.classifier.
        self.model_type = model_type in self._CUSTOM_HEAD_TYPES

        if model_type in self._SEQUENCE_CLASSIFIERS:
            class_name, checkpoint, extra = self._SEQUENCE_CLASSIFIERS[model_type]
            model_cls = getattr(transformers, class_name)
            self.base_model = model_cls.from_pretrained(checkpoint, num_labels=num_classes, **extra)
        elif model_type == "csROBERTA":
            self.base_model = transformers.AutoModel.from_pretrained("allenai/cs_roberta_base")
            self.classifier = self._make_classifier(self.base_model.config.hidden_size, num_classes)
        elif model_type == "TAPTBERT":
            config = transformers.AutoConfig.from_pretrained("../input/tapt-v2/config.json")
            self.base_model = transformers.AutoModel.from_pretrained(
                "../input/tapt-v2/pytorch_model.bin", config=config)
            self.classifier = self._make_classifier(self.base_model.config.hidden_size, num_classes)
        else:
            raise ValueError(f"unknown model_type {model_type!r}")

    @staticmethod
    def _make_classifier(hidden_size, num_classes):
        """MLP head mapping a ``hidden_size`` encoder output to ``num_classes`` logits."""
        return nn.Sequential(
            nn.Linear(hidden_size, hidden_size), nn.ReLU(), nn.Dropout(p=0.1),
            nn.Linear(hidden_size, hidden_size), nn.ReLU(), nn.Dropout(p=0.1),
            nn.Linear(hidden_size, num_classes))

    def _logits(self, x):
        """Run the backbone (and custom head, if any); return raw, unnormalised logits."""
        ids, mask = x
        if self.model_type:
            outputs = self.base_model(input_ids=ids, attention_mask=mask)
            # outputs[1] is presumably the pooled [CLS] representation -- TODO confirm for TAPT.
            return self.classifier(outputs[1])
        outputs = self.base_model(input_ids=ids, attention_mask=mask, labels=None)
        return outputs[0]

    def forward(self, x):
        """Return class probabilities for a batch.

        Args:
            x: ``(input_ids, attention_mask)`` pair, as produced by ``BERT_Tokenize``
                and stacked by the DataLoader.

        Returns:
            - ``num_classes >= 2``: softmax over the logits.
            - ``num_classes == 1``: sigmoid of the single logit.
        """
        logits = self._logits(x)
        if logits.size(-1) == 1:
            # Softmax over one logit is identically 1.0; sigmoid gives P(label == 1).
            return torch.sigmoid(logits)
        return torch.softmax(logits, dim=1)

    def training_step(self, batch, batch_idx):
        x, t = batch
        # Losses are computed on logits: the criteria apply (log-)softmax/sigmoid
        # themselves, so feeding probabilities would normalise twice.
        logits = self._logits(x)
        loss = self.criterion(logits, t)
        acc = self.metric(logits, t)
        self.log("train/train_loss", loss, prog_bar=True)
        self.log("train/train_acc", acc, prog_bar=True)
        return {"loss": loss, "acc": acc}

    def validation_step(self, batch, batch_idx):
        x, t = batch
        logits = self._logits(x)
        loss = self.criterion(logits, t)
        acc = self.metric(logits, t)
        # Epoch-averaged "val_loss" is the metric EarlyStopping('val_loss') monitors.
        self.log("val_loss", loss, on_epoch=True, prog_bar=True)
        self.log("val_acc", acc, on_epoch=True, prog_bar=True)
        return {"val_loss": loss, "val_acc": acc}

    def validation_end(self, outputs):
        """Average ``val_loss`` over a list of ``validation_step`` outputs.

        Legacy hook name that recent PyTorch Lightning versions do not call;
        epoch aggregation now happens through ``self.log`` in ``validation_step``.
        Kept so existing callers keep working.
        """
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        tensorboard_logs = {'val/avg_loss': avg_loss}
        return {'val_loss': avg_loss, 'log': tensorboard_logs}

    def configure_optimizers(self):
        return AdamW(self.parameters(), lr=1e-5)

    def criterion(self, pred, t):
        """Loss on raw logits ``pred`` (NOT probabilities) against targets ``t``.

        A single logit uses binary cross-entropy on float 0/1 targets; otherwise
        cross-entropy on integer class targets.
        """
        if pred.size(-1) == 1:
            return F.binary_cross_entropy_with_logits(pred.squeeze(-1), t.float())
        return F.cross_entropy(pred, t)

    def metric(self, pred, t):
        """Accuracy of raw logits ``pred`` against targets ``t``, as a 0-d float tensor."""
        if pred.size(-1) == 1:
            predicted = (pred.squeeze(-1) > 0).long()  # logit > 0  <=>  sigmoid > 0.5
        else:
            predicted = pred.argmax(dim=1)
        return (predicted == t).float().mean()

In [None]:
class CFhtmlDataModule(pl.LightningDataModule):
    """Split a labelled DataFrame into train/validation ``HtmlDataset`` loaders.

    Args:
        csv_file: DataFrame consumed by ``HtmlDataset`` (``html_raw`` + ``state``).
        transform: callable applied to each text (e.g. ``BERT_Tokenize``).
        split_rate: fraction of rows used for training; the rest is validation.
        batch_size: batch size for both loaders.
        num_workers: number of DataLoader worker processes.
        seed: ``random_state`` for the split. It keeps repeated ``setup`` calls
            and notebook reruns on the same train/validation partition.
    """

    def __init__(self, csv_file, transform, split_rate, batch_size, num_workers, seed=42):
        super().__init__()
        self.csv_file = csv_file
        self.transform = transform
        self.split_rate = split_rate
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.seed = seed

    def setup(self, stage=None):
        n_samples = len(self.csv_file)
        # Honour split_rate (it used to be ignored in favour of a hard-coded 0.8).
        n_train = int(n_samples * self.split_rate)
        n_val = n_samples - n_train
        train_df, val_df = train_test_split(
            self.csv_file, train_size=n_train, test_size=n_val, random_state=self.seed)

        self.train_dataset = HtmlDataset(csv_file=train_df, transform=self.transform)
        self.val_dataset = HtmlDataset(csv_file=val_df, transform=self.transform)

    def train_dataloader(self):
        # Reshuffle every epoch; drop a ragged last batch for stable training steps.
        return DataLoader(self.train_dataset,
                          batch_size=self.batch_size,
                          shuffle=True,
                          drop_last=True,
                          num_workers=self.num_workers,
                          pin_memory=True)

    def val_dataloader(self):
        # Keep the last partial batch: every validation row should count towards val_loss.
        return DataLoader(self.val_dataset,
                          batch_size=self.batch_size,
                          drop_last=False,
                          num_workers=self.num_workers,
                          pin_memory=True)

In [None]:
def main(model_type="BERT", epochs=1, max_length=214, batch_size=2, seed=42):
    """Fine-tune a transformer classifier on the notebook's global ``train`` frame.

    Args:
        model_type: backbone key; used for BOTH the tokenizer and the model, so
            the two always match.
        epochs: maximum number of training epochs.
        max_length: token length every text is padded/truncated to.
        batch_size: DataLoader batch size.
        seed: global seed passed to ``pl.seed_everything`` for reproducible runs.
    """
    pl.seed_everything(seed)

    # trainer config
    output_path = './'

    # data module config
    csv_file = train
    # Previously hard-coded to "BERT", which mismatched any other model_type.
    transform = BERT_Tokenize(model_type, max_length)
    split_rate = 0.8
    num_workers = 4

    # model config: cross-entropy targets must lie in [0, num_classes), so size the
    # head from the largest label (0/1 labels -> 2 classes). The old num_classes=1
    # made the softmax output a constant 1.0 and rejected label 1.
    num_classes = int(csv_file["state"].max()) + 1

    # early stopping config
    patience = 3

    cf = CFhtmlDataModule(csv_file, transform, split_rate, batch_size, num_workers)
    model = BertModule(model_type, num_classes)

    early_stopping = EarlyStopping('val_loss', patience=patience, verbose=True)
    trainer = Trainer(
        max_epochs=epochs,
        weights_save_path=output_path,
        gpus=1 if torch.cuda.is_available() else None,
        callbacks=[early_stopping],
    )

    trainer.fit(model, cf)
    #torch.cuda.empty_cache()
    # TO DO: use model.apply(weights_init) instead of torch.cuda.empty_cache()

In [None]:
# Run training with main()'s default configuration.
main()