# Wide Multi-Layer Perceptron (Baseline)

WideMLP from Diera et al. (2022)

In [None]:
import json
import torch
from tqdm import tqdm, trange
from transformers import (AdamW, AutoTokenizer, get_linear_schedule_with_warmup)
from sklearn.metrics import f1_score, accuracy_score, classification_report
import numpy as np
import pandas as pd
from math import ceil
from collections import Counter
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
import re

In [None]:
import logging
# Configure logging once for the notebook: INFO level, each message prefixed
# with timestamp, level name and logger name.
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)

# Logger used by the training cell for its progress messages.
logger = logging.getLogger(__name__)

In [None]:
import os
# Set before the tokenizer is created.
# NOTE(review): presumably this silences the fork warning that the tokenizers
# library emits when the multi-worker DataLoaders are used — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [None]:
import wandb
# Master switch: every wandb.init / wandb.watch / wandb.log call in this
# notebook is guarded by this flag.
WANDB = True

In [None]:
# model_id is appended to MODEL_PATH (checkpoint) and RESULTS_PATH (results file).
model_id = "WideMLP_2s_32_0.1" #for filenames
# Run name and project passed to wandb.init.
wandb_id = "WideMLP_2s_32_0.1_0.2"
wandb_project = "section-clf-multisen"

In [None]:
# parameters
# Hugging Face tokenizer used to turn text into token ids.
tokenizer_name = "bert-base-uncased"

# Input corpus (JSON lines) and output directories for checkpoints / results.
INPUT_PATH='/media/nvme3n1/proj_scisen/datasets/SciSections_sentences.jsonl'
MODEL_PATH='/media/nvme3n1/proj_scisen/models/MLSC/'
RESULTS_PATH='/media/nvme3n1/proj_scisen/results/MLSC/'

# LAMBDA: decision threshold applied to the sigmoid outputs during evaluation.
LAMBDA=0.2
# Batch size for the train, validation and evaluation DataLoaders.
BATCH_SIZE=32
# Number of passes over the training set.
EPOCHS=100
# Initial learning rate handed to the optimizer (decayed by the linear schedule).
LEARNING_RATE = 0.1

In [None]:
# When False, paragraphs whose section_category is exactly ['appendix'] are
# skipped and 'appendix' is not added to the label set.
include_appendices = False
# Only paragraphs whose 'rank' is in this list are used.
included_conferences = ['as','a','b','c']
# Number of consecutive sentences joined into one input text (see context()).
context_width = 2

# Each inner list is one group of ranks; one test subset is built per group.
test_splits = [['as','a','b','c']] # [ [['as','a','b','c']] , [['as'],['a'],['b','c']] ] test data split by conference rank

# True: evaluate on the validation set; False: evaluate on the test split(s).
validate = False

In [None]:
# Tokenizer that maps each input text to token ids; its vocab_size also sizes
# the model's embedding table.
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

In [None]:
#placeholder; move/delete later TODO

#embedding = None
#do_truncate = False
#max_length = None # we only compute dataset stats including length, so NOT truncate

In [None]:
# set up the device for GPU usage
# NOTE(review): the GPU index (3) is hard-coded; on a machine with fewer
# visible GPUs this device string will not be usable — confirm before running.
torch_device = 'cuda:3' if torch.cuda.is_available() else 'cpu'

### set up WandB

In [None]:
if WANDB:
    # Start the experiment-tracking run and record the hyper-parameters with it.
    wandb.init(
        project=wandb_project,
        name=wandb_id,
        config={
            "epochs": EPOCHS,
            "context_width": context_width,
            "validation": validate,
            "batch_size": BATCH_SIZE,
            "learning_rate": LEARNING_RATE,
            "lambdas": LAMBDA,
            "trainingdata": "full",
            "conferences": "asabc",
        },
    )

## load data:

In [None]:
# Read the corpus: a JSON-lines file with one paragraph object per line.
paragraphList = list()
# Explicit encoding so the result does not depend on the platform default.
with open(INPUT_PATH, encoding="utf-8") as f:
    for paragraph in f:
        # Tolerate blank lines (e.g. a trailing newline at the end of the file),
        # which json.loads would reject.
        if not paragraph.strip():
            continue
        paragraphDict = json.loads(paragraph)
        paragraphList.append(paragraphDict)

In [None]:
def context(sentences, sentence_idx, k_context):
    """Return the target sentence joined with its neighbours as one string.

    The window has ``k_context`` consecutive sentences of one paragraph:
    k=1: the sentence alone, k=2: predecessor + sentence,
    k=3: predecessor + sentence + successor,
    k=2n: n predecessors + sentence + (n-1) successors,
    k=2n+1: n predecessors + sentence + n successors.

    Returns the empty string when the window would run past either end of
    ``sentences`` or when it contains a "[removed]" sentence.
    """
    half = k_context // 2
    first = sentence_idx - half
    # An odd width has one more sentence on the right-hand side of the slice.
    last = sentence_idx + half + (k_context % 2)

    # Window does not fit inside the paragraph.
    if first < 0 or last > len(sentences):
        return ""

    window = sentences[first:last]
    if "[removed]" in window:
        return ""
    return " ".join(str(sentence) for sentence in window)

In [None]:
# Build one training example per sentence that has not been removed.
inputList = []
for par in tqdm(paragraphList):
    # Skip pure-appendix paragraphs unless appendices are requested.
    if par['section_category'] == ['appendix'] and not include_appendices:
        continue
    # Skip paragraphs from conferences outside the selected ranks.
    if par['rank'] not in included_conferences:
        continue
    for idx, sentence in enumerate(par['sentences']):
        if sentence == "[removed]":
            continue
        # The text is the sentence together with its context window; it is the
        # empty string when no valid window exists (filtered out later).
        text_target = context(par['sentences'], idx, context_width)
        inputList.append({
            'text': text_target,
            'labels': par['section_category'],
            'rank': par['rank'],
        })

In [None]:
# One row per example with the columns 'text', 'labels' and 'rank'.
df_input = pd.DataFrame(inputList)

In [None]:
# Empty strings become missing values; rows without a text are then dropped
# (these are the examples for which no valid context window existed).
df_input = df_input.mask(df_input == '')
df_input = df_input[df_input['text'].notnull()]

In [None]:
# Optional down-sampling for quick (pre-)experiments; both variants are
# currently disabled, so the full dataset is used.
N = len(df_input)

# random 10 % of the rows
#df_input = df_input.sample(n=ceil(N*0.1), random_state = 42)
# fixed-size random subset
#df_input = df_input.sample(n=10000, random_state = 42)
# Display the resulting frame in the notebook.
df_input

In [None]:
# split: 70/20/10 (train / validation / test; the test share is the remainder)
train_size = 0.7
valid_size = 0.2

# split into train+validation and test:
# draw the combined train+validation sample with a fixed seed; every row not
# drawn becomes test data.
trainvalid_df=df_input.sample(frac= train_size + valid_size ,random_state=200)
test_dataset=df_input.drop(trainvalid_df.index).reset_index(drop=True)
trainvalid_df = trainvalid_df.reset_index(drop=True)

# split test set by conference rank:
# one subset per rank group, keyed by the string form of the group.
test_datasets = dict()
for split in test_splits:
    test_datasets[str(split)]=test_dataset[test_dataset['rank'].isin(split)].reset_index(drop=True)

# split into train, validation:
# the fraction is rescaled so that train is 70 % of the full dataset.
train_dataset=trainvalid_df.sample(frac= train_size/(train_size + valid_size) ,random_state=200)
valid_dataset=trainvalid_df.drop(train_dataset.index).reset_index(drop=True)
train_dataset = train_dataset.reset_index(drop=True)

# Report the size of every partition.
print("FULL Dataset: {}".format(df_input.shape))
print("TRAIN Dataset: {}".format(train_dataset.shape))
print("VALIDATION Dataset: {}".format(valid_dataset.shape))
print("FULL TEST Dataset: {}".format(test_dataset.shape))

for split in test_splits:
    print("TEST Dataset ranks {}: {}".format(str(split), test_datasets[str(split)].shape))

In [None]:
def _multi_hot(label_names, label2index):
    """Return a float vector with 1 at the index of every known label name.

    The vector always has ``len(label2index)`` entries so that it matches the
    number of model outputs. Names that are not in ``label2index`` are ignored.
    """
    vector = np.zeros(len(label2index))
    for label_name in label_names:
        index = label2index.get(label_name)
        if index is not None:
            vector[index] = 1
    return vector


# --- raw texts: train, validation, then every test split, in that order ---
train_data = np.array(train_dataset['text'].values.tolist(), dtype=object)
valid_data = np.array(valid_dataset['text'].values.tolist(), dtype=object)
raw_documents = np.append(train_data, valid_data)
test_data_dict = dict()
for split in test_splits:
    test_data_dict[str(split)] = np.array(test_datasets[str(split)]['text'].values.tolist(), dtype=object)
    raw_documents = np.append(raw_documents, test_data_dict[str(split)])
N = len(raw_documents)

# --- label lists, in the same order as the texts ---
print("Loading document metadata...")
train_labels = np.array(train_dataset['labels'].values.tolist(), dtype=object)
valid_labels = np.array(valid_dataset['labels'].values.tolist(), dtype=object)
test_labels_dict = dict()
for split in test_splits:
    test_labels_dict[str(split)] = np.array(test_datasets[str(split)]['labels'].values.tolist(), dtype=object)

# Flat array of every label occurrence. Iterating the label lists directly
# works both for ragged lists and for the case where all lists have the same
# length (where numpy builds a 2-D array).
label_groups = [train_labels, valid_labels] + [test_labels_dict[str(split)] for split in test_splits]
labels = np.array([name for group in label_groups for names in group for name in names])

# Labels that actually occur in the data (kept for inspection only).
unique_labels = list(Counter(labels).keys())

# Canonical label order; this defines the columns of the multi-hot vectors.
labels_in_order = ['introduction',
        'related work',
        'method',
        'experiment',
        'result',
        'discussion',
        'conclusion']
if include_appendices:
    labels_in_order.append('appendix')

print(f"Encoding documents without max_length")
enc_docs = [tokenizer.encode(raw_doc) for raw_doc in raw_documents]

print("Encoding labels...")
label2index = {label: idx for idx, label in enumerate(labels_in_order)}

# Multi-hot targets plus boolean masks that mark which rows of the
# concatenated arrays belong to which partition. The vectors are sized by
# label2index (not by the labels seen in the data), so they always match the
# model's output width.
enc_labels = []
idx = 0

train_mask, valid_mask = torch.zeros(N, dtype=torch.bool), torch.zeros(N, dtype=torch.bool)

enc_labels.extend(_multi_hot(names, label2index) for names in train_labels)
train_mask[idx:idx + len(train_labels)] = True
idx += len(train_labels)

enc_labels.extend(_multi_hot(names, label2index) for names in valid_labels)
valid_mask[idx:idx + len(valid_labels)] = True
idx += len(valid_labels)

test_mask_dict = dict()
for split in test_splits:
    split_labels = test_labels_dict[str(split)]
    test_mask_dict[str(split)] = torch.zeros(N, dtype=torch.bool)
    enc_labels.extend(_multi_hot(names, label2index) for names in split_labels)
    test_mask_dict[str(split)][idx:idx + len(split_labels)] = True
    idx += len(split_labels)

In [None]:
# Token-count statistics over all encoded documents.
lens = np.array([len(doc) for doc in enc_docs])
print("Min/max document length:", (lens.min(), lens.max()))
print("Mean document length: {:.4f} ({:.4f})".format(lens.mean(), lens.std()))
enc_docs_arr, enc_labels_arr = np.array(enc_docs, dtype='object'), np.array(enc_labels)

# From here on train_data / valid_data / test_data_dict no longer hold raw
# texts: they are rebound to lists of (token ids, multi-hot label) pairs,
# selected from the concatenated arrays through the partition masks.
train_data = list(zip(enc_docs_arr[train_mask], enc_labels_arr[train_mask]))
valid_data = list(zip(enc_docs_arr[valid_mask], enc_labels_arr[valid_mask]))
test_data_dict = dict()
for split in test_splits:
    test_data_dict[str(split)] = list(zip(enc_docs_arr[test_mask_dict[str(split)]], enc_labels_arr[test_mask_dict[str(split)]]))


# Sanity check: partition sizes and number of classes.
print("N", len(enc_docs))
print("N train", len(train_data))
print("N valid", len(valid_data))
for split in test_splits:
    print("N test ranks {}".format(str(split)), len(test_data_dict[str(split)]))
print("N classes", len(label2index))

# Train and evaluate the Model

In [None]:
# MLP model


import torch
import torch.nn as nn
import torch.nn.functional as F
import tokenizers


def collate_for_mlp(list_of_samples):
    """Collate (doc, label) samples into the flat layout used by EmbeddingBag.

    Returns a tuple of three tensors: all token ids of the batch concatenated
    into one sequence, the start position of each document inside that
    sequence, and the stacked labels.
    """
    token_ids, bag_starts, targets = [], [], []
    for doc, label in list_of_samples:
        # Tokenizer Encoding objects carry their ids in `.ids`; anything else
        # is taken to be a sequence of ids already.
        ids = doc.ids if isinstance(doc, tokenizers.Encoding) else doc
        # The next document starts where the ids collected so far end.
        bag_starts.append(len(token_ids))
        token_ids.extend(ids)
        targets.append(label)
    return (
        torch.tensor(np.array(token_ids)),
        torch.tensor(np.array(bag_starts)),
        torch.tensor(np.array(targets)),
    )


class MLP(nn.Module):
    """Bag-of-tokens MLP classifier with a multi-label (BCE-with-logits) loss.

    The input-to-hidden layer is an ``nn.EmbeddingBag``, so a document is given
    as a flat tensor of token ids plus per-document offsets.

    Args:
        vocab_size: Number of rows of the embedding table (ignored when
            ``pretrained_embedding`` is given).
        num_classes: Number of output logits.
        num_hidden_layers: Total number of hidden layers, counting the
            embedding layer; ``1`` means embedding -> output.
        hidden_size: Width of the hidden layers.
        hidden_act: Name of an activation function in ``torch.nn.functional``.
        dropout: Dropout probability after each hidden-to-hidden layer.
        idf: Optional 1-D tensor of per-token weights. When given, the bag
            mode is forced to ``'sum'`` and the result is L2-normalised.
        mode: EmbeddingBag reduction mode used when ``idf`` is ``None``.
        pretrained_embedding: Optional embedding matrix to initialise from.
        freeze: Whether a pretrained embedding is kept fixed.
        embedding_dropout: Dropout probability applied to the bag embedding.
    """

    def __init__(self, vocab_size, num_classes,
                 num_hidden_layers=1,
                 hidden_size=1024, hidden_act='relu',
                 dropout=0.5, idf=None, mode='mean',
                 pretrained_embedding=None, freeze=True,
                 embedding_dropout=0.5):
        super().__init__()
        # Treat TF-IDF mode appropriately: per-sample weights require 'sum'
        mode = 'sum' if idf is not None else mode
        # Registered as a buffer so that model.to(device) moves it together
        # with the parameters; non-persistent so that state_dict keys (and thus
        # existing checkpoints) are unaffected. A None value is allowed.
        self.register_buffer('idf', idf, persistent=False)

        # Input-to-hidden (efficient via embedding bag)
        if pretrained_embedding is not None:
            # vocab size is defined by the embedding in this case
            self.embed = nn.EmbeddingBag.from_pretrained(pretrained_embedding, freeze=freeze, mode=mode)
            embedding_size = pretrained_embedding.size(1)
            self.embedding_is_pretrained = True
        else:
            assert vocab_size is not None
            self.embed = nn.EmbeddingBag(vocab_size, hidden_size, mode=mode)
            embedding_size = hidden_size
            self.embedding_is_pretrained = False

        self.activation = getattr(F, hidden_act)
        self.embedding_dropout = nn.Dropout(embedding_dropout)
        self.dropout = nn.Dropout(dropout)
        self.layers = nn.ModuleList()

        # Hidden-to-hidden (the embedding already counts as the first hidden layer)
        for i in range(num_hidden_layers - 1):
            if i == 0:
                self.layers.append(nn.Linear(embedding_size, hidden_size))
            else:
                self.layers.append(nn.Linear(hidden_size, hidden_size))

        # Hidden-to-output
        self.layers.append(nn.Linear(hidden_size if self.layers else embedding_size, num_classes))

        # Loss function (independent sigmoid per class)
        self.loss_function = nn.BCEWithLogitsLoss()

    def forward(self, input, offsets, labels=None):
        """Compute the logits and, when ``labels`` is given, the loss.

        Args:
            input: Flat tensor of token ids for the whole batch.
            offsets: Start index of each document inside ``input``.
            labels: Optional targets passed to ``BCEWithLogitsLoss``.

        Returns:
            ``(loss, logits)`` when ``labels`` is given, otherwise ``logits``.
        """
        # Use idf weights if present
        idf_weights = self.idf[input] if self.idf is not None else None

        h = self.embed(input, offsets, per_sample_weights=idf_weights)

        if self.idf is not None:
            # In the TF-IDF case: renormalize according to l2 norm.
            # F.normalize clamps the norm from below, so an empty or
            # all-zero bag yields zeros instead of NaN.
            h = F.normalize(h, p=2, dim=1)

        if not self.embedding_is_pretrained:
            # No nonlinearity when embedding is pretrained
            h = self.activation(h)

        h = self.embedding_dropout(h)

        last = len(self.layers) - 1
        for i, layer in enumerate(self.layers):
            # at least one layer (hidden-to-output) is always present
            h = layer(h)
            if i != last:
                # No activation/dropout for final layer
                h = self.activation(h)
                h = self.dropout(h)

        if labels is not None:
            loss = self.loss_function(h, labels)
            return loss, h
        return h


In [None]:
# Number of batches whose gradients are accumulated before one optimizer step.
GRADIENT_ACCUMULATION_STEPS = 1
# Interval (in optimizer steps) for TensorBoard logging and wandb.watch.
LOGGING_STEPS = 50

In [None]:
print("Initializing MLP")

# Embedding table size is taken from the tokenizer's vocabulary.
vocab_size = tokenizer.vocab_size

# One output logit per label; all other MLP arguments keep their defaults.
model = MLP(vocab_size, len(label2index))

model.to(torch_device)

if WANDB:
    # Let wandb track the model every LOGGING_STEPS steps.
    wandb.watch(model, log_freq=LOGGING_STEPS)

In [None]:
def loss_plot(epochs, loss):
    """Plot ``loss`` over ``epochs`` in red on the current axes and save the
    current figure as ``val_loss.png``."""
    axes = plt.gca()
    axes.plot(epochs, loss, color='red', label='loss')
    axes.set_xlabel("epochs")
    axes.set_title("validation loss")
    plt.savefig("val_loss.png")

### train

In [None]:
collate_fn = collate_for_mlp
train_loader = torch.utils.data.DataLoader(train_data,
                                               collate_fn=collate_fn,
                                               shuffle=True,
                                               batch_size=BATCH_SIZE,
                                               num_workers=4,
                                               pin_memory=('cuda' in torch_device))

# Validation needs no shuffling; a fixed order keeps the reported loss reproducible.
valid_loader = torch.utils.data.DataLoader(valid_data,
                                               collate_fn=collate_fn,
                                               shuffle=False,
                                               batch_size=BATCH_SIZE,
                                               num_workers=4,
                                               pin_memory=('cuda' in torch_device))

# len(train_loader) is the number of batches per epoch
t_total = len(train_loader) // GRADIENT_ACCUMULATION_STEPS * EPOCHS

# NOTE(review): transformers' AdamW is deprecated upstream; when switching to
# torch.optim.AdamW pass weight_decay=0.0 explicitly to keep the current
# behaviour (the two classes have different weight-decay defaults) — confirm.
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, eps=1e-8)
# Linear decay of the learning rate to zero over t_total steps, no warm-up.
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0,
                                                num_training_steps=t_total)
writer = SummaryWriter()

# Train!
logger.info("***** Running training *****")
logger.info("  Num examples = %d", len(train_data))
logger.info("  Num Epochs = %d", EPOCHS)
logger.info("  Batch size  = %d", BATCH_SIZE)
logger.info("  Total train batch size (w. accumulation) = %d",
            BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS)
logger.info("  Gradient Accumulation steps = %d", GRADIENT_ACCUMULATION_STEPS)
logger.info("  Total optimization steps = %d", t_total)

global_step = 0
loss_vals = []  # mean validation loss of every epoch
tr_loss, logging_loss, nb_val_steps, vl_loss = 0.0, 0.0, 0, 0.0
model.zero_grad()
for epoch in range(EPOCHS):
    print("Epoch: ", epoch)

    # ---- training pass ----
    model.train()
    for step, batch in enumerate(tqdm(train_loader, desc="Iteration")):
        batch = tuple(t.to(torch_device) for t in batch)
        # Same dtype casts as in the evaluation loop: integer ids/offsets and
        # float targets for BCEWithLogitsLoss.
        outputs = model(batch[0].to(torch.long), batch[1].to(torch.long), batch[2].to(torch.float))
        loss = outputs[0]
        if GRADIENT_ACCUMULATION_STEPS > 1:
            loss = loss / GRADIENT_ACCUMULATION_STEPS
        loss.backward()
        tr_loss += loss.item()
        if (step + 1) % GRADIENT_ACCUMULATION_STEPS == 0:
            optimizer.step()
            scheduler.step()
            model.zero_grad()
            global_step += 1
            if WANDB:
                # Log a plain float rather than the loss tensor.
                wandb.log({'epoch': epoch,
                            'lr': scheduler.get_last_lr()[0],
                            'loss': loss.item()})

            # Checked only right after an optimizer step so that each logging
            # point is written exactly once, also with gradient accumulation.
            if LOGGING_STEPS > 0 and global_step % LOGGING_STEPS == 0:
                avg_loss = (tr_loss - logging_loss) / LOGGING_STEPS
                writer.add_scalar('lr', scheduler.get_last_lr()[0], global_step)
                writer.add_scalar('loss', avg_loss, global_step)
                logging_loss = tr_loss

    # ---- validation pass ----
    model.eval()
    # Start from zero every epoch; otherwise the running sum and step count of
    # earlier epochs leak into this epoch's mean.
    vl_loss, nb_val_steps = 0.0, 0
    for step, batch in enumerate(tqdm(valid_loader, desc="Validating")):
        batch = tuple(t.to(torch_device) for t in batch)
        with torch.no_grad():
            outputs = model(batch[0].to(torch.long), batch[1].to(torch.long), batch[2].to(torch.float))

        nb_val_steps += 1
        loss, logits = outputs[:2]
        vl_loss += loss.mean().item()

    if nb_val_steps:
        vl_loss /= nb_val_steps
    loss_vals.append(vl_loss)
    if WANDB:
        wandb.log({'val_loss': vl_loss})

writer.close()
loss_plot(np.arange(1, EPOCHS + 1), loss_vals)
#return global_step, tr_loss / global_step

### evaluate

In [None]:
#model.load_state_dict(torch.load(MODEL_PATH+model_id))

In [None]:
# evaluate
# test specific model:
# NOTE(review): this replaces the in-memory (possibly just trained) weights with
# the checkpoint on disk, and the checkpoint is only written by a later cell.
# When running top to bottom, confirm that evaluating the stored checkpoint is
# what is intended.
# map_location makes the checkpoint loadable on a machine whose devices differ
# from the one it was saved on.
model.load_state_dict(torch.load(MODEL_PATH+model_id, map_location=torch_device))

collate_fn = collate_for_mlp

if validate:
    eval_data_dict = {"validation data": valid_data}
else:
    eval_data_dict = test_data_dict

# Dropout must be off for the whole evaluation.
model.eval()

for evalset in eval_data_dict:
    print(f"--------------- {evalset} ---------------")
    if (WANDB and len(eval_data_dict)>1):
        # With several evaluation sets each one is logged as its own run; the
        # set name is reduced to letters and digits for use in the run name.
        evalset_tag = re.sub(r'[\W_]+', '', evalset)
        wandb.init(project=wandb_project, name = wandb_id+"_"+evalset_tag, config={"epochs": EPOCHS, "context_width": context_width, "validation": validate, "batch_size": BATCH_SIZE, "learning_rate": LEARNING_RATE, "lambdas": LAMBDA, "trainingdata":"full", "conferences": evalset_tag})
    data_loader = torch.utils.data.DataLoader(eval_data_dict[evalset],
                                                  collate_fn=collate_fn,
                                                  num_workers=4,
                                                  batch_size=BATCH_SIZE,
                                                  pin_memory=('cuda' in str(torch_device)),
                                                  shuffle=False)
    all_logits = []
    all_targets = []
    nb_eval_steps, eval_loss = 0, 0.0
    for batch in tqdm(data_loader, desc="Evaluating"):
        batch = tuple(t.to(torch_device) for t in batch)
        with torch.no_grad():
            # batch consists of (flat_inputs, offsets, labels)
            outputs = model(batch[0].to(torch.long), batch[1].to(torch.long), batch[2].to(torch.float))
            all_targets.append(batch[2].detach().cpu())

        nb_eval_steps += 1
        # outputs[:2] holds loss, logits
        loss, logits = outputs[:2]
        eval_loss += loss.mean().item()
        all_logits.append(logits.detach().cpu())

    # Per-class probabilities, then a hard decision at threshold LAMBDA.
    probs = torch.sigmoid(torch.cat(all_logits)).numpy()
    targets = torch.cat(all_targets).numpy()
    eval_loss /= nb_eval_steps
    preds = (probs >= LAMBDA).astype(probs.dtype)
    logits = preds  # kept under the old name as well
    acc = accuracy_score(targets, preds)

    f1_samples = f1_score(targets, preds, average='samples')
    f1_micro = f1_score(targets, preds, average='micro')
    f1_macro = f1_score(targets, preds, average='macro')

    if WANDB:
        wandb.log({"test/acc": acc, "test/loss": eval_loss,
                    "test/f1_samples": f1_samples,
                    "test/f1_micro": f1_micro,
                    "test/f1_macro": f1_macro})

    print(f"Test accuracy: {acc:.4f}, f1_samples: {f1_samples:.4f}, f1_micro: {f1_micro:.4f}, f1_macro: {f1_macro:.4f}, Eval loss: {eval_loss}")

In [None]:
# Write the headline metrics of the most recent evaluation as one
# space-separated line (the file is overwritten).
with open(RESULTS_PATH+model_id+'_results.txt', 'w') as f:
    f.write(" ".join((
        f"F1 Score (Samples) = {f1_samples}",
        f"Accuracy Score = {acc}",
        f"F1 Score (Micro) = {f1_micro}",
        f"F1 Score (Macro) = {f1_macro}",
    )) + "\n")

# Store the model weights under MODEL_PATH + model_id.
torch.save(model.state_dict(), MODEL_PATH+model_id)
print("saved as "+MODEL_PATH+model_id)

#### further evaluation

In [None]:
# Per-class precision / recall / F1 for the most recently evaluated set.
# The result is stored under its own name so that the imported
# classification_report function stays callable when this cell is re-run.
report = classification_report(
    targets,
    preds,
    output_dict=False,
    target_names= labels_in_order,
    digits = 4
)
print(report)
# Append the report to the results file as well.
with open(RESULTS_PATH+model_id+'_results.txt', 'a') as f:
    print(report, file=f)

In [None]:
def label_totitles(labels, titles=None): #labels should be list of int or convertible to int
    """Turn a multi-hot label vector into a comma-separated string of titles.

    Args:
        labels: Iterable of values convertible to ``int``; position ``i`` is
            considered set when ``int(labels[i]) == 1``.
        titles: Title for each position. Defaults to the module-level
            ``labels_in_order``.

    Returns:
        The titles of all set positions joined by ``", "``; the empty string
        when no position is set.
    """
    if titles is None:
        titles = labels_in_order
    return ", ".join(
        titles[position]
        for position, flag in enumerate(labels)
        if int(flag) == 1
    )

# Human-readable true and predicted label combinations, one row per example.
df_labels = pd.DataFrame({
    "label": [label_totitles(t) for t in targets],
    "pred": [label_totitles(p) for p in preds],
})
# How often each predicted combination occurs.
df_labels["pred"].value_counts()

In [None]:
# confusion plots
# One bar chart per true label combination, showing how often each predicted
# combination was produced for it.
for title in set(df_labels["label"]):
    #print("------"+title+"------")
    # rows whose true label combination is `title`
    df_temp = df_labels[df_labels["label"] == title]
    # single-row table: counts of every predicted combination
    crosstab = pd.crosstab(df_temp["label"], df_temp["pred"], rownames=["label"], colnames=["pred"])
    #print(crosstab)
    ax = crosstab.plot.bar(rot=0, figsize = (25, 10), width = 0.75)
    # annotate every bar with its count
    for container in ax.containers:
        ax.bar_label(container)

In [None]:
#examples
# For every predicted label combination: the distribution of the true labels
# and up to five randomly chosen example texts.
unique_preds = list(set(df_labels["pred"]))

if validate:
    eval_dataset = valid_dataset
else:
    # The predictions belong to the evaluation set that was processed last,
    # i.e. the last rank split. Its rows (not those of the full test set) are
    # aligned with df_labels.
    eval_dataset = test_datasets[list(test_datasets)[-1]]

for category in unique_preds:
    print(f"----- {category} -----")
    df_cat = df_labels[df_labels['pred'] == category]
    print(df_cat["label"].value_counts())
    df_cat = df_cat.sample(n = min(5, len(df_cat)))
    for idx in df_cat.index.values:
        print(f"~~~ true label: {df_cat.at[idx, 'label']} ~~~")
        # df_labels has a default integer index, so idx is also the row position
        print(eval_dataset.iloc[idx]["text"])
    print()