In [None]:
# Default Packages
import os
import sys
import pickle
import numpy as np
import pandas as pd
import os.path as path

# Torch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
# Using the standard huggingface tokenizer for compatibility
from transformers import (BertTokenizer, BertModel, AdamW, 
                          get_linear_schedule_with_warmup)

# PyTorch Lightning
import pytorch_lightning as pl
from pytorch_lightning.metrics.functional import accuracy, f1, auroc, recall, precision
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import WandbLogger


In [None]:
class RedditImplicit (Dataset):
    """Torch dataset of tokenized Reddit posts with binary suicide labels.

    Every row of the dataframe is tokenized once, at construction time, and
    kept in memory as tensors padded / truncated to ``max_example_len``.

    Args:
        reddit_df: dataframe with a 'text' column (raw post text) and a
            'label' column holding 'non-suicide' / 'suicide'.
        tokenizer: callable following the huggingface tokenizer call
            convention; its result must provide 'input_ids' and
            'attention_mask' entries indexable by example.
        max_example_len: length every example is padded / truncated to.
    """

    # String label -> integer target. Values that are not keys of this
    # mapping are passed through unchanged by `Series.replace`.
    LABEL_TO_ID = {'non-suicide': 0, 'suicide': 1}

    def __init__ (self, reddit_df, tokenizer, max_example_len = 512):
        # src is divided into input_ids, token_type_ids, and attention_mask
        self.src = tokenizer (
            reddit_df['text'].tolist(),
            add_special_tokens = True,
            truncation = True,
            padding = "max_length",
            return_attention_mask = True,
            return_tensors = "pt",
            max_length = max_example_len
        )

        self.trg = reddit_df['label'].replace (self.LABEL_TO_ID).tolist()

    @staticmethod
    def custom_vocab_preprocessing(df):
        """Returns the 'text' column of ``df``."""
        return df['text']

    def __getitem__(self, idx):
        """Returns ``[(input_ids, attention_mask), label]`` for example ``idx``."""
        features = (self.src['input_ids'][idx], self.src['attention_mask'][idx])
        return [features, torch.tensor(self.trg[idx])]

    def __len__(self):
        # Inputs and targets must stay aligned one-to-one.
        assert len(self.src['input_ids']) == len(self.trg)
        return len(self.trg)

    def __str__(self):
        # The previous implementation read `self.dataset_percent`, which is
        # never assigned, so str(dataset) raised AttributeError.
        return f'RedditImplicit ({len(self)} examples)'
        

class RedditImplicitDataModule (pl.LightningDataModule):
    """Lightning data module that slices a dataframe into consecutive splits
    and serves them as train / validation / test dataloaders.

    Args:
        data: dataframe handed, split by split, to ``RedditImplicit``.
        tokenizer: tokenizer forwarded to ``RedditImplicit``.
        splits: fractions of ``data`` assigned, in order, to each split
            (e.g. ``[0.8, 0.2]``). With fewer than three entries the last
            split is reused for the missing validation / test sets.
        max_example_len: padding / truncation length forwarded to
            ``RedditImplicit``.
        shuffle: whether the training dataloader shuffles its examples.
        batch_size: batch size of every dataloader.
        num_workers: worker count of every dataloader.
    """

    def __init__(
        self, data:pd.DataFrame, 
        tokenizer, splits:list =[1], 
        max_example_len: int = 512, 
        shuffle: bool = True,
        batch_size:int = 32, 
        num_workers:int = 0
    ):
    
        super().__init__()

        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.max_example_len = max_example_len
        self.shuffle = shuffle
        self.num_workers = num_workers

        # Slice `data` into consecutive, non-overlapping pieces. Each piece
        # starts exactly where the previous one stopped. Boundaries come from
        # the running sum of the fractions; rounding to 6 decimals before
        # truncating removes float noise (0.7 + 0.2 == 0.8999999999999999
        # would otherwise truncate 8.999... to 8, making splits overlap and
        # dropping the final row).
        self.df_splits = list()
        datalen = len(data)
        start = 0
        cumulative = 0.0
        for split_percent in splits:
            cumulative += split_percent
            stop = int(round(cumulative * datalen, 6))
            self.df_splits.append(data[start:stop])
            start = stop

    def setup (self, stage=None):
        """Tokenizes every dataframe split and assigns train/valid/test sets."""
        self.splits = [
            RedditImplicit (frame, self.tokenizer, self.max_example_len)
            for frame in self.df_splits
        ]

        # With more than three splits only `self.splits` is populated; the
        # named train/valid/test attributes are not assigned.
        if len(self.splits) <= 3:
            # Pad with the last split so validset/testset always exist even
            # when fewer than three fractions were given.
            missing = 3 - len(self.splits)
            padded = list(self.splits) + [self.splits[-1]] * missing
            self.trainset, self.validset, self.testset = padded

            self.datasets = {
                'train': self.trainset,
                'valid': self.validset,
                'test': self.testset
            }

    def _build_loader (self, dataset, shuffle):
        """Wraps ``dataset`` in a DataLoader using the module's batch settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers
        )

    def train_dataloader (self):
        """Dataloader over the training split (shuffled if ``self.shuffle``)."""
        return self._build_loader (self.trainset, self.shuffle)

    def val_dataloader (self):
        """Dataloader over the validation split, never shuffled."""
        return self._build_loader (self.validset, False)

    def test_dataloader (self):
        """Dataloader over the test split, never shuffled."""
        return self._build_loader (self.testset, False)


In [None]:
class SuicideClassifier (pl.LightningModule):
    """BERT encoder followed by a linear head with sigmoid outputs.

    One output unit is created per entry of ``output_classes`` and the model
    is trained with binary cross entropy.

    Args:
        output_classes: names of the predicted classes; its length sets the
            number of output units.
        training_steps: total number of training steps handed to the linear
            learning-rate schedule.
        warmup_steps: number of warmup steps of that schedule.
        lr: learning rate given to AdamW.
        metrics: keys of ``implemented_metrics`` ('ROC', 'binary_report')
            to log at the end of every training / validation epoch.

    Note:
        The pretrained weights are looked up through the module-level
        ``BERT_MODEL`` name, which must be defined before instantiation.
    """

    def __init__ (
        self, output_classes:list = ['suicide'], 
        training_steps:int = None, 
        warmup_steps:int = 0, lr=None,
        metrics = []):
        
        super().__init__()
        
        self.training_steps = training_steps
        self.warmup_steps = warmup_steps
        self.output_classes = output_classes
        self.output_dim = len (output_classes)

        self.bert = BertModel.from_pretrained(BERT_MODEL, return_dict=True)
        self.ff = nn.Linear (self.bert.config.hidden_size, self.output_dim)
        self.output_norm = nn.Sigmoid ()

        # loss function
        self.criterion = nn.BCELoss()
        self.lr = lr
        
        self.metrics = metrics
        # metric name -> method called as method(preds, labels, step_type)
        self.implemented_metrics = {
            'ROC':self.calculate_ROC,
            'binary_report':self.calculate_binary_report
        }

    def forward (self, input_ids, attention_mask, 
                 labels =None, normalize= True):
        """ Performs a forward pass through the model and, when labels are
            given, runs the loss calculation

        Args:
            input_ids (torch.tensor[N, max_example_len]): integer encoded words
            attention_mask (torch.tensor[N, max_example_len]): mask for self attention (1:unmasked, 0: masked)
            labels (torch.tensor [N]): ground truth y values for batch
            normalize (bool): apply the sigmoid to the linear head's output

        Returns:
            tuple: ``(loss, y_hat)``; ``loss`` is 0 when ``labels`` is None
        """
        # with return_dict=True, bert outputs expose `pooler_output`
        x = self.bert (input_ids, attention_mask=attention_mask)
        y_hat = self.ff (x.pooler_output)
        if normalize:
            y_hat = self.output_norm (y_hat)
        
        if self.output_dim==1:
            # Drop only the trailing class axis. A bare torch.squeeze also
            # removed the batch axis for a batch holding a single example,
            # producing a 0-d tensor that no longer matches `labels` and
            # cannot be iterated when metrics are gathered.
            y_hat = torch.squeeze(y_hat, -1)

        loss = 0
        if labels is not None:
            # BCELoss needs float targets
            loss = self.criterion (y_hat, labels.type (torch.float32))

        return loss, y_hat

    def _step (self, batch, step_type):
        """Shared train/valid/test step: forward pass plus loss logging."""
        (input_ids, attention_mask), labels = batch
        loss, output = self (input_ids, attention_mask, labels)
        self.log ('{}_loss'.format(step_type), loss, prog_bar=True, logger=True)

        return {'loss':loss, 'output':output, 'labels': labels}

    def training_step (self, batch, batch_idx):
        return self._step (batch, 'train')

    def validation_step (self, batch, batch_idx):
        return self._step (batch, 'valid')

    def test_step (self, batch, batch_idx):
        # Only the loss is returned for the test stage.
        return self._step (batch, 'test')['loss']
    
    def calculate_ROC (self, preds, labels, step_type):
        """Logs one ROC-AUC value per output class."""
        for i, name in enumerate(self.output_classes):
            # For a single-output model the index is None, so `[:, None]`
            # adds a trailing axis instead of selecting a column.
            # NOTE(review): confirm auroc treats the resulting extra axis as
            # intended for the installed pytorch_lightning version.
            column = None if self.output_dim == 1 else i

            class_roc_auc = auroc(preds[:, column], labels[:, column], pos_label=1)
            
            # Only name and value are passed: the third positional parameter
            # of `self.log` is `prog_bar`, so the epoch number that used to
            # be passed here was interpreted as that flag.
            self.log (f"{name}_roc_auc/{step_type}", class_roc_auc)
            
    def calculate_binary_report(self, preds, labels, step_type):
        """Logs accuracy, f1, precision and recall for a single-output model."""
        assert len(labels.shape)==1, 'binary report is reserved for output_dim==1'
        assert len(preds.shape)==1
        
        # log name -> (metric function, extra keyword arguments)
        binary_metrics = {
            'accuracy':(accuracy, {}),
            'f1 score':(f1, {'num_classes':1}),
            'precision':(precision, {}),
            'recall_count':(recall, {})
        }
        
        for name, (metric_fn, kwargs) in binary_metrics.items():
            self.log ('{}/{}'.format (name, step_type), metric_fn(preds, labels, **kwargs))

    def log_metrics (self, outputs, step_type):
        """Gathers the labels / predictions of every step of an epoch and
        runs each metric named in ``self.metrics`` on them."""
        # Concatenating the per-step batches along the batch axis yields the
        # same tensors as stacking their individual rows.
        labels = torch.cat (
            [output["labels"].detach().cpu() for output in outputs]).int()
        preds = torch.cat (
            [output["output"].detach().cpu() for output in outputs])

        for metric in self.metrics:
            self.implemented_metrics[metric](preds, labels, step_type)

    def training_epoch_end (self, outputs):
        self.log_metrics (outputs, 'train')
    
    def validation_epoch_end (self, outputs):
        self.log_metrics (outputs, 'valid')
            
    def configure_optimizers (self):
        """AdamW with a linear warmup / decay schedule stepped every batch."""
        optimizer = AdamW (self.parameters (), lr = self.lr)

        scheduler = get_linear_schedule_with_warmup (
            optimizer, 
            num_warmup_steps = self.warmup_steps,
            num_training_steps = self.training_steps
        )

        return {
            'optimizer':optimizer,
            'lr_scheduler':{
                'scheduler':scheduler,
                # advance the schedule after every optimizer step
                'interval':'step'
            }
        }




In [None]:
# Project root. Defaults to the original location; set the
# PROFILE_SI_RUNNING_DIR environment variable to run on another machine.
RUNNING_DIR = os.environ.get(
    'PROFILE_SI_RUNNING_DIR', r'C:\Code\NLP\ProfileLevel_SI_Classifier')
datasets_dir = path.join(RUNNING_DIR, 'Datasets')

BERT_MODEL = 'bert-base-uncased'
CLASSES = ['suicidal']
BATCH_SIZE = 12
NUM_EPOCHS = 10
LEARNING_RATE = 1.5e-5  # 2e-5 was the previous setting
PATIENCE = 2
MAX_EXAMPLE_LEN =100

N_EXAMPLES = 3000
# Seed for the subsample below only, so a re-run draws the same rows. The
# global numpy RNG is left unseeded on purpose: the run name drawn from it
# later must keep changing between runs.
SEED = 42


# Data Loading
reddit_df = pd.read_csv (path.join (datasets_dir, 'Implicitly_Labeled_Suicide_Reddit.csv'))[['text', 'class']]
# Never ask for more rows than the file holds (sample would raise).
reddit_df = reddit_df.sample (
    min (N_EXAMPLES, len (reddit_df)), random_state=SEED).reset_index(drop=True)
reddit_df = reddit_df.rename (columns= {'class':'label'})

# Tokenization and Batching
tokenizer = BertTokenizer.from_pretrained(BERT_MODEL)
data_module = RedditImplicitDataModule (
    reddit_df,
    tokenizer,
    splits=[0.8,0.2],
    max_example_len = MAX_EXAMPLE_LEN, 
    shuffle=True,
    batch_size=BATCH_SIZE,
)

# Scheduler length = optimizer steps actually taken: only the training split
# (first entry of df_splits) is iterated during training, and the last,
# possibly partial, batch of an epoch still counts as a step. Using the whole
# dataframe overstated the schedule, so the learning rate never finished
# decaying.
n_train_examples = len (data_module.df_splits[0])
steps_per_epoch = (n_train_examples + BATCH_SIZE - 1) // BATCH_SIZE
training_steps = steps_per_epoch * NUM_EPOCHS

In [None]:
model = SuicideClassifier (
    output_classes= CLASSES,
    training_steps = training_steps,
    # Warm up over the first fifth of training. Integer division keeps the
    # value an int, as the parameter's annotation declares.
    warmup_steps=training_steps // 5,
    lr=LEARNING_RATE, 
    metrics=['ROC','binary_report']
)

In [None]:
# Build a random two-word run name from the word list.
with open (path.join(RUNNING_DIR, 'words.txt')) as f:
  # Skip blank lines: a trailing newline used to contribute an empty "word",
  # which could produce names such as "-antiques".
  words = [line.strip() for line in f if line.strip()]
display_name = '-'.join (np.random.choice (words, size=2))
  
wandb_logger = WandbLogger(
  name = display_name,
  project="BERT Implicitly Labeled Reddit v2",
  config = {
    'model':BERT_MODEL,
    'classes':CLASSES,
    'batch_size':BATCH_SIZE,
    'num_epochs':NUM_EPOCHS,
    'learning_rate':LEARNING_RATE,
    'early_stopping_patience':PATIENCE,
    'max_example_len':MAX_EXAMPLE_LEN,
    'n_examples':N_EXAMPLES,
  }
)

save_dir = path.join (RUNNING_DIR, 'model_checkpoints', display_name)
# exist_ok: re-running the cell, or drawing the same two words again, must
# not abort the notebook.
os.makedirs(save_dir, exist_ok=True)

# Stop once valid_loss has not improved for PATIENCE validation checks.
early_stopping_callback = EarlyStopping(
  monitor='valid_loss', patience=PATIENCE)
# Keep only the checkpoint with the lowest valid_loss.
checkpoint_callback = ModelCheckpoint(
  dirpath=save_dir,
  filename="best-checkpoint",
  save_top_k=1,
  verbose=True,
  monitor="valid_loss",
  mode="min"
)


In [None]:
trainer = pl.Trainer(
  logger=wandb_logger,
  checkpoint_callback=checkpoint_callback,
  callbacks=[early_stopping_callback],
  max_epochs=NUM_EPOCHS,
  # Use one GPU when CUDA is available, otherwise fall back to the CPU
  # instead of failing at trainer construction.
  gpus=1 if torch.cuda.is_available() else 0,
  progress_bar_refresh_rate=30
)

In [None]:
# Train the classifier; the data module supplies the train / validation loaders.
trainer.fit(model, datamodule=data_module)


In [None]:
# Run the test loop; presumably this reuses the model and data module from the
# fit call above (confirm for the installed pytorch_lightning version).
# NOTE: the data module was created with splits=[0.8, 0.2], so its setup()
# pads the missing third split with the last one, which means the test set here is
# the same data as the validation set.
trainer.test()


In [None]:
# Reload the best checkpoint of an earlier run (run name 'walk-antiques').
checkpoint_path = path.join (
  RUNNING_DIR, 'model_checkpoints', 'walk-antiques', 'best-checkpoint.ckpt')
loaded_model = SuicideClassifier.load_from_checkpoint(
  checkpoint_path,
)

# NOTE(review): this trainer shares the logger and callback instances created
# for the first trainer; confirm that any state they carry over from the
# earlier fit is intended.
twitter_trainer = pl.Trainer(
  logger=wandb_logger,
  checkpoint_callback=checkpoint_callback,
  callbacks=[early_stopping_callback],
  max_epochs=NUM_EPOCHS,
  # Use one GPU when CUDA is available, otherwise fall back to the CPU
  # instead of failing at trainer construction.
  gpus=1 if torch.cuda.is_available() else 0,
  progress_bar_refresh_rate=30
)

