# Screenplay annotation task (token- or row-wise). Labels:

- scene_heading (0)
- text (1)
- speaker_heading (2)
- dialog (3)

## Training setup:

In [None]:
# Mount Google Drive (Colab only); the paths in `config` below all live under
# /content/drive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!pip install transformers
!pip install mlflow --quiet
!pip install pyngrok --quiet

In [None]:
import os
import re
from time import time 
from transformers import BertTokenizer, get_cosine_schedule_with_warmup
from transformers import BertForSequenceClassification, AdamW, BertConfig, \
DistilBertForSequenceClassification,get_linear_schedule_with_warmup,\
BertForTokenClassification
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from matplotlib import pyplot as plt 
import seaborn as sn
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data.sampler import WeightedRandomSampler
from torch.nn.utils.rnn import pad_sequence
from itertools import chain
import torch
import numpy as np
import datetime
import pandas as pd
import csv
import argparse
import mlflow
from pyngrok import ngrok

In [None]:
# Global experiment configuration shared by the data preparation, training
# and evaluation code in this notebook.
config = { 
        'paths': { 
            'logs_dir':'/content/drive/MyDrive/NLP/Movie scripts models/BERTAnno/logs',
            # Directory where Train.save_model / Train.load_model keep the checkpoint.
            'ckpt_dir': '/content/drive/MyDrive/NLP/Movie scripts models/BERTAnno/ckpts',
            # Root directory for annotations written by Train.evaluate_text
            # (one sub-directory per experiment name).
            'model_annotations': '/content/drive/MyDrive/NLP/Movie scripts dataset/Movie scripts and annotations/Script annotations by BERT',
            # Directory with the manually annotated scripts read by AnnoData.
            'manual_annotations': '/content/drive/MyDrive/NLP/Movie scripts dataset/Movie scripts and annotations/Script manual annotations',
            'mlruns': '/content/drive/MyDrive/NLP/Movie scripts models/BERTAnno/mlruns'
        },
        'train': {
                # Keyword arguments for the AdamW optimizer built in Train.__init__.
                'optim' : {
                    'AdamW':{
                        'lr':1e-5,
                        'eps': 1e-8,
                        'weight_decay':0.0001
                    }
                },
                'num_classes' : 4,
                # Total number of optimizer steps (used as Train.total_steps).
                'nrof_steps' : 500, 
                'tr_batch_size' : 8,
                # Batch size passed for both the validation and the test loader.
                'tst_batch_size' : 8,
                'exp_name':'row_classification', # from ['row_classification', 'token_classification']
                'device': torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
                'pretrained_model_type': 'bert-base-cased',
                # The two maps below must stay inverse to each other.
                'heading_to_class_map': {'scene_heading':0, 'text':1, 'speaker_heading':2, 'dialog':3},
                'class_to_heading_map': {0:'scene_heading', 1:'text', 2:'speaker_heading', 3:'dialog'}
                }
        }

In [None]:
# Shared tokenizer for the whole notebook; casing is preserved
# (do_lower_case=False) to match the cased checkpoint named in the config.
tokenizer = BertTokenizer.from_pretrained(
    config['train']['pretrained_model_type'], do_lower_case=False)


## Dataset preprocessing:


In [None]:
def make_tokenized_rows(rows, tokenizer):
    """Encode every row separately with ``tokenizer.encode``.

    Returns a list holding one encoding per row, in the order of ``rows``.
    """
    return [tokenizer.encode(text) for text in rows]

def make_tokenized_chunks_labels_from_rows(rows, labels, tokenizer,
                                           chunk_size=512):
    """Pack consecutive encoded rows into chunks with per-token labels.

    Every row is encoded with ``tokenizer.encode`` and its label is repeated
    once per produced token id. Rows are appended to the current chunk while
    the chunk stays strictly shorter than ``chunk_size``; otherwise a new
    chunk is started.

    Args:
        rows: iterable of row strings.
        labels: iterable of row labels, parallel to ``rows``.
        tokenizer: object exposing ``encode(row) -> list of ids``.
        chunk_size: exclusive upper bound on the length of a chunk that holds
            more than one row.

    Returns:
        ``(chunks_ids, new_labels)``: two parallel lists of lists. For empty
        ``rows`` the result is ``([[]], [[]])``.

    Note:
        A single row that is itself ``chunk_size`` tokens or longer is kept
        whole in a chunk of its own, so such a chunk can exceed
        ``chunk_size``.
    """
    chunks_ids, new_labels = [[]], [[]]

    for row, label in zip(rows, labels):
        row_input_ids = tokenizer.encode(row)
        row_labels = [label] * len(row_input_ids)

        # An empty current chunk always takes the row: opening a new chunk
        # here would leave an empty chunk behind in the output.
        current_is_empty = not chunks_ids[-1]
        fits = (len(chunks_ids[-1]) + len(row_input_ids)) < chunk_size

        if current_is_empty or fits:
            chunks_ids[-1].extend(row_input_ids)
            new_labels[-1].extend(row_labels)
        else:
            chunks_ids.append(row_input_ids)
            new_labels.append(row_labels)

    return chunks_ids, new_labels

def prepare_inputs_labels(inputs, labels):
    """Turn lists of token ids and labels into padded tensors.

    ``inputs`` are converted to tensors and padded (batch first) to the
    longest sequence. For the 'row_classification' experiment ``labels``
    become a single LongTensor; otherwise every label sequence is converted
    and padded the same way as the inputs.

    NOTE(review): padding uses pad_sequence's default value, which coincides
    with class id 0 for padded label positions - confirm this is intended.
    """
    pad = torch.nn.utils.rnn.pad_sequence
    padded_inputs = pad([torch.tensor(ids) for ids in inputs],
                        batch_first=True)

    if config['train']['exp_name'] != 'row_classification':
        label_tensors = [torch.tensor(seq) for seq in labels]
        return padded_inputs, pad(label_tensors, batch_first=True)

    return padded_inputs, torch.LongTensor(labels)

In [None]:
class AnnoData:
    '''
    Reads the manual annotation files and prepares model inputs and labels.

    Args:
        tokenizer: tokenizer used to encode rows and to decode ids back to
            tokens in show_dataset.
        config: notebook configuration dict ('paths' and 'train' sections).
        data_type: 'rows' (one sample per row) or 'chunks' (rows packed into
            chunks with per-token labels). When None it is derived from
            config['train']['exp_name']: 'rows' for 'row_classification',
            'chunks' otherwise. get_train_val_split expects the data type
            and the experiment name to agree.
    '''
    def __init__(self, tokenizer, config, data_type=None):
        self.config = config
        self.tokenizer = tokenizer

        if data_type is None:
            is_row_exp = config['train']['exp_name'] == 'row_classification'
            data_type = 'rows' if is_row_exp else 'chunks'
        if data_type not in ('rows', 'chunks'):
            raise ValueError(
                "data_type must be 'rows' or 'chunks', got {!r}".format(data_type))
        self.data_type = data_type

        anno_docs = self.read_all_annos()
        self.labeled_rows, self.init_rows = self.make_labeled_rows(anno_docs)

    def read_all_annos(self):
        '''
        Reads every file of the manual annotations directory.

        Files are read in sorted name order so that the dataset order (and
        therefore the seeded train/val/test split) is reproducible.
        Sub-directories are skipped.
        '''
        anno_docs = []
        manual_annotations_path = self.config['paths']['manual_annotations']

        for anno in sorted(os.listdir(manual_annotations_path)):
            anno_path = os.path.join(manual_annotations_path, anno)
            if not os.path.isfile(anno_path):
                continue
            with open(anno_path, 'r') as f:
                anno_docs.append(f.read())

        return anno_docs

    def make_labeled_rows(self, anno_docs):
        '''
        Gets labeled rows from docs (removes the leading "<heading>:" marker
        from the string and gathers the label separately).

        A row without a known heading marker keeps the label of the previous
        labeled row and its text is left untouched. Rows that appear before
        any label was seen, and rows that are empty after cleaning, are not
        added to the labeled rows.

        Returns:
            (labeled_rows, init_rows): list of (cleaned row, class id) pairs
            and the list of all raw rows.
        '''
        labeled_rows, init_rows = [], []
        heading_to_class_map = self.config['train']['heading_to_class_map']
        label = None  # no label known until the first heading marker is seen

        for anno_text in anno_docs:
            anno_text = re.sub('\n+', '\n', anno_text)
            for row in anno_text.split('\n'):
                init_rows.append(row)

                # Only a known heading before the FIRST colon is a marker;
                # any other colon belongs to the row text itself.
                prefix, colon, rest = row.partition(':')
                if colon and prefix in heading_to_class_map:
                    label = heading_to_class_map[prefix]
                    row = rest

                row = re.sub(' +', ' ', row)
                row = re.sub('\t', '', row).strip()

                if label is None or not row:
                    continue
                labeled_rows.append((row, label))

        return labeled_rows, init_rows

    def get_inputs_labels(self):
        '''
        Tokenizes the labeled rows according to self.data_type.

        Returns:
            (inputs, labels): per-row encodings and labels for 'rows', or
            chunk encodings with per-token labels for 'chunks'.

        Raises:
            ValueError: if no labeled rows were found.
        '''
        if not self.labeled_rows:
            raise ValueError('No labeled rows found in the manual annotations')

        rows, labels = zip(*self.labeled_rows)
        if self.data_type == 'rows':
            return make_tokenized_rows(rows, self.tokenizer), labels
        return make_tokenized_chunks_labels_from_rows(rows, labels, self.tokenizer)

    def get_train_val_split(self):
        '''
        Splits the data 70/15/15 into train/val/test (fixed random state),
        logs the split sizes to mlflow and converts everything to tensors.

        For 'row_classification' the per-row sampling weights are stored in
        self.row_weights / self.weight_per_class and the class distribution
        is logged; otherwise both attributes are set to None.

        Returns:
            tr_inputs, val_inputs, tst_inputs, tr_labels, val_labels, tst_labels
        '''
        inputs, labels = self.get_inputs_labels()
        tr_inputs, val_inputs, tr_labels, val_labels = train_test_split(inputs, labels,
                                                                     test_size=0.3,
                                                                     random_state=11)

        if self.config['train']['exp_name'] == 'row_classification':
            self.row_weights, self.weight_per_class = self.make_weights_for_balanced_classes(
                tr_labels)
            self.log_dataset_info()
        else:
            self.row_weights, self.weight_per_class = None, None

        # The held-out 30% is halved into the test and validation parts.
        tst_inputs, val_inputs, tst_labels, val_labels = train_test_split(
            val_inputs, val_labels, test_size=0.5, random_state=11)
        print('Train size:{}\nVal size:{}\nTest size:{}'.format(
            len(tr_inputs), len(val_inputs), len(tst_inputs)))

        mlflow.log_param('train_size', len(tr_inputs))
        mlflow.log_param('val_size', len(val_inputs))
        mlflow.log_param('test_size', len(tst_inputs))

        tr_inputs, tr_labels = prepare_inputs_labels(tr_inputs, tr_labels)
        val_inputs, val_labels = prepare_inputs_labels(val_inputs, val_labels)
        tst_inputs, tst_labels = prepare_inputs_labels(tst_inputs, tst_labels)

        return tr_inputs, val_inputs, tst_inputs, tr_labels, val_labels, tst_labels

    def make_weights_for_balanced_classes(self, labels):
        '''
        Computes inverse-frequency weights.

        Returns:
            (weight, weight_per_class): the weight of every sample in
            `labels` and the weight of every class (N / class count). A class
            without samples gets weight 0 instead of raising
            ZeroDivisionError.
        '''
        nclasses = self.config['train']['num_classes']
        count = [0] * nclasses

        for label in labels:
            count[label] += 1

        N = float(sum(count))
        weight_per_class = [N / float(c) if c else 0. for c in count]
        weight = [weight_per_class[label] for label in labels]

        return weight, weight_per_class

    def log_dataset_info(self):
        '''
        Prints the share of every class among the training rows, writes it to
        classes_info.txt and logs that file as an mlflow artifact.
        '''
        print('Classes distribution:')
        classes_info = []

        for i, weight in enumerate(self.weight_per_class):
            # weight is N / count, so its inverse is the class share;
            # a zero weight marks a class without samples.
            share = 1. / weight if weight else 0.
            info_str = 'Class {}: {:.2f} of all rows'.format(
                self.config['train']['class_to_heading_map'][i], share)
            print(info_str)
            classes_info.append(info_str)

        classes_info = '\n'.join(classes_info)

        with open('classes_info.txt', 'w') as f:
            f.write(classes_info)

        mlflow.log_artifact('classes_info.txt')

    def show_dataset(self, to_save=False):
        '''
        Builds a table with one line per sample: heading name, class id(s),
        tokens and the decoded text. For chunk samples the heading of the
        first token is shown.

        Args:
            to_save: when True the table is also written to
                'manual_annotations_dataset_<data_type>.xlsx'.

        Returns:
            The pandas DataFrame with the table.
        '''
        dataset_df = {'label':[], 'class':[], 'tokens':[], 'text':[]}
        class_to_heading_map = self.config['train']['class_to_heading_map']
        inputs, labels = self.get_inputs_labels()

        for token_ids, label in zip(inputs, labels):
            text = self.tokenizer.convert_ids_to_tokens(token_ids)
            string = self.tokenizer.convert_tokens_to_string(text)
            dataset_df['class'].append(label)

            if isinstance(label, list):
                heading = class_to_heading_map[label[0]]
            else:
                heading = class_to_heading_map[label]

            dataset_df['label'].append(heading)
            dataset_df['tokens'].append(text)
            dataset_df['text'].append(string)

        dataset_df = pd.DataFrame(dataset_df)

        if to_save:
            dataset_df.to_excel('manual_annotations_dataset_' + self.data_type+'.xlsx',
                                engine='xlsxwriter',
                                index=False)

        return dataset_df
        

In [None]:
def get_dataloader(input_ids, labels, batch_size=32,  #attention_masks
                   phase='train', sampler=None):
    """Wrap the input and label tensors into a DataLoader.

    For ``phase='train'`` the loader uses ``batch_size`` and the given
    ``sampler`` (a RandomSampler when none is given). For any other phase the
    loader iterates in order with a fixed batch size of 128.

    NOTE(review): the non-train branch ignores ``batch_size`` - confirm that
    the fixed value is intended.
    """
    dataset = TensorDataset(input_ids, labels)

    if phase != 'train':
        return DataLoader(dataset, batch_size=128)

    if sampler is None:
        sampler = RandomSampler(dataset)

    return DataLoader(dataset, batch_size=batch_size, sampler=sampler)

In [None]:
def get_data_loaders():
    """Build the train, validation and test loaders from the annotations.

    Reads the global ``tokenizer`` and ``config``. For the
    'row_classification' experiment the training loader draws rows with a
    WeightedRandomSampler over the per-row weights computed by AnnoData;
    otherwise no sampler is passed.

    Returns:
        (tr_loader, val_loader, tst_loader)
    """
    train_cfg = config['train']
    anno_data = AnnoData(tokenizer, config)
    (tr_inputs, val_inputs, tst_inputs,
     tr_labels, val_labels, tst_labels) = anno_data.get_train_val_split()

    sampler = None
    if train_cfg['exp_name'] == 'row_classification':
        sampler = WeightedRandomSampler(anno_data.row_weights,
                                        len(anno_data.row_weights))

    eval_batch_size = train_cfg['tst_batch_size']
    tr_loader = get_dataloader(tr_inputs, tr_labels,
                               batch_size=train_cfg['tr_batch_size'],
                               sampler=sampler)
    val_loader = get_dataloader(val_inputs, val_labels,
                                batch_size=eval_batch_size)
    tst_loader = get_dataloader(tst_inputs, tst_labels,
                                batch_size=eval_batch_size)

    return tr_loader, val_loader, tst_loader


## Show and save dataset:

In [None]:
!pip install xlsxwriter

In [None]:
# Build the annotation dataset with data_type='chunks' and export it to an
# .xlsx file (presumably the chunk / token-level view - confirm).
# NOTE(review): this call only works if AnnoData.__init__ accepts `data_type`.
AD = AnnoData(tokenizer, config, data_type='chunks')
AD.show_dataset(to_save=True)

In [None]:
# Same export with the default AnnoData settings.
AD = AnnoData(tokenizer, config)
AD.show_dataset(to_save=True)

## Train:

In [None]:
def format_time(elapsed):
    '''
    Rounds a duration in seconds to whole seconds and renders it the way
    datetime.timedelta prints itself (h:mm:ss, with a day prefix if needed).
    '''
    whole_seconds = int(round(elapsed))
    delta = datetime.timedelta(seconds=whole_seconds)
    return str(delta)

def prepare_row(row):
    '''
    Encodes one row with the global tokenizer for single-row inference.

    The row is wrapped in special tokens and padded or truncated to exactly
    64 tokens. Padding and truncation are requested explicitly through
    `padding` / `truncation` instead of the deprecated `pad_to_max_length`
    flag and the implicit truncation fallback.

    Returns:
        (input_ids, attention_masks): the two tensors produced by the
        tokenizer for this single row (return_tensors='pt').
    '''
    encoded_dict = tokenizer.encode_plus(row,
                                         add_special_tokens=True,
                                         max_length=64,
                                         padding='max_length',
                                         truncation=True,
                                         return_attention_mask=True,
                                         return_tensors='pt')

    return encoded_dict['input_ids'], encoded_dict['attention_mask']

def plot_conf_matr(results):
    '''
    Draws the confusion matrix as a heatmap and saves it to conf_matrix.png.

    Args:
        results: square numpy array of counts; row/column i is labeled with
            the heading name of class i from the global config.

    The title shows the overall accuracy (trace / total), reported as 0 for
    an all-zero matrix.
    '''
    classes = [config['train']['class_to_heading_map'][x] for x in range(len(results))]
    # The builtin int replaces the np.int alias, which NumPy 1.24 removed.
    df_cm = pd.DataFrame(results.astype(int), index=classes, columns=classes)

    total = results.sum()
    accuracy = float(np.diagonal(results).sum()) / total if total else 0.

    plt.figure(figsize=(7, 7))
    ax = sn.heatmap(df_cm, annot=True, fmt='d')
    ax.set_title('Confusion matrix (test accuracy: {:.2f})'.format(accuracy))
    plt.savefig('conf_matrix.png', bbox_inches='tight')
    plt.close()


In [None]:
class Train():
    '''
    Fine-tunes and applies a BERT classifier for screenplay annotation.

    Depending on config['train']['exp_name'] the model is a sequence
    classifier ('row_classification': one label per row) or a token
    classifier (any other value: one label per token of a chunk).

    NOTE(review): the model is always called without an attention mask, so
    padding positions are attended to - confirm that this is intended.
    '''
    # ANSI escape sequences used to print the predicted heading in bold.
    _BOLD_ON = '\033[1m'
    _BOLD_OFF = '\033[0m'

    def __init__(self, config):
        '''
        Builds model, optimizer and cosine scheduler and logs the
        hyper-parameters to mlflow.
        '''
        self.config = config
        train_config = self.config['train']
        # Use the same checkpoint name as the tokenizer; fall back to the
        # previously hard-coded name when the key is missing.
        model_type = train_config.get('pretrained_model_type', 'bert-base-cased')

        if train_config['exp_name'] == 'row_classification':
            model_cls = BertForSequenceClassification
        else:
            model_cls = BertForTokenClassification
        self.model = model_cls.from_pretrained(
            model_type,
            num_labels = train_config['num_classes'],
            output_attentions = False,
            output_hidden_states = False)

        opt_config = train_config['optim']['AdamW']

        for key, val in opt_config.items():
            mlflow.log_param(key, val)
        mlflow.log_param('nrof_classes', train_config['num_classes'])

        self.total_steps = train_config['nrof_steps']
        self.optimizer = AdamW(self.model.parameters(),
                  lr = opt_config['lr'],
                  eps = opt_config['eps'],
                  weight_decay=opt_config['weight_decay'])
        self.scheduler = get_cosine_schedule_with_warmup(self.optimizer,
                                            num_warmup_steps = 10,
                                            num_training_steps = self.total_steps)
        self.model.to(train_config['device'])
        self.training_stats = []
        self.global_step = 0
        mlflow.log_param('total_steps', self.total_steps)

    def _checkpoint_path(self):
        '''Returns the checkpoint file path of the current experiment.'''
        return os.path.join(self.config['paths']['ckpt_dir'],
                            self.config['train']['exp_name'] + '_checkpoint')

    def _max_input_length(self):
        '''Returns the model's position limit, or None when it is unknown.'''
        return getattr(getattr(self.model, 'config', None),
                       'max_position_embeddings', None)

    def _format_line(self, heading, text):
        '''Formats one output line: bold heading, colon, row text.'''
        return self._BOLD_ON + heading + ': ' + self._BOLD_OFF + text + '\n'

    def save_model(self):
        '''
        Saves model, optimizer and scheduler states together with the index
        of the last completed step, creating the checkpoint directory if
        needed.
        '''
        os.makedirs(self.config['paths']['ckpt_dir'], exist_ok=True)
        torch.save({"model": self.model.state_dict(),
                    "optimizer": self.optimizer.state_dict(),
                    "scheduler": self.scheduler.state_dict(),
                    # load_model resumes at step + 1
                    "step": self.global_step - 1,
                    },
                   self._checkpoint_path())

    def load_model(self):
        '''
        Restores model, optimizer, scheduler and the global step from the
        experiment checkpoint. Checkpoints written without a "step" entry
        leave the current global step unchanged.
        '''
        ckpt = torch.load(self._checkpoint_path(),
                          map_location=self.config['train']['device'])
        if "step" in ckpt:
            self.global_step = ckpt["step"] + 1
        self.model.load_state_dict(ckpt["model"])
        self.optimizer.load_state_dict(ckpt["optimizer"])
        self.scheduler.load_state_dict(ckpt["scheduler"])
        print("Model loaded...")

    def train(self, train_dataloader, validation_dataloader, to_save=True):
        '''
        Runs optimizer steps until self.total_steps is reached, validating
        and logging metrics to mlflow every 10 steps.

        Args:
            train_dataloader: iterable of (input_ids, labels) batches.
            validation_dataloader: loader passed to validate().
            to_save: save a checkpoint when training is finished.

        Raises:
            ValueError: if train_dataloader yields no batches.
        '''
        device = self.config['train']['device']
        t0 = time()
        cur_loss, nrof_steps, nrof_samples, nrof_cor_predicts = 0., 0, 0, 0

        while self.global_step < self.total_steps:
            step_before_epoch = self.global_step

            for batch in train_dataloader:
                self.model.train()
                b_input_ids = batch[0].to(device)
                b_labels = batch[1].to(device)

                self.model.zero_grad()
                outputs = self.model(b_input_ids,
                                     token_type_ids=None,
                                     labels=b_labels,
                                     return_dict=True)

                cur_loss += outputs.loss.item()
                _, predicted = torch.max(outputs.logits, -1)
                nrof_cor_predicts += (predicted == b_labels).sum().item()

                outputs.loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                self.optimizer.step()
                self.scheduler.step()
                self.global_step += 1
                nrof_steps += 1
                # Count label elements, not sequences, so that the token
                # classification accuracy is a proper ratio.
                nrof_samples += b_labels.numel()

                if self.global_step % 10 == 0:
                    elapsed = format_time(time() - t0)
                    print('  Batch {:>5,}  of  {:>5,}.    Elapsed: {:}.'.format(
                        self.global_step, self.total_steps, elapsed))
                    val_acc, val_loss = self.validate(validation_dataloader)

                    print('val loss:', val_loss)
                    mlflow.log_metric("val_loss", val_loss)
                    mlflow.log_metric("val_accuracy", val_acc)
                    mlflow.log_metric("train_loss", cur_loss / nrof_steps)
                    mlflow.log_metric("train_accuracy", float(nrof_cor_predicts)/nrof_samples)
                    print('tr loss: {}\n'.format(cur_loss/nrof_steps))

                    cur_loss, nrof_steps = 0., 0
                    nrof_cor_predicts, nrof_samples = 0, 0

                # Stop mid-epoch: stepping further would run the cosine
                # schedule past the number of steps it was built for.
                if self.global_step >= self.total_steps:
                    break

            if self.global_step == step_before_epoch:
                # Without this guard an empty loader would loop forever.
                raise ValueError('train_dataloader produced no batches')

        if to_save:
            self.save_model()

    def validate(self, validation_dataloader,
                 to_calc_conf_matr=False, to_load=False):
        '''
        Evaluates the model on a loader.

        Args:
            validation_dataloader: iterable of (input_ids, labels) batches.
            to_calc_conf_matr: also accumulate a confusion matrix
                (rows: true class, columns: predicted class) and plot it
                with plot_conf_matr.
            to_load: load the checkpoint first.

        Returns:
            (accuracy, mean loss per batch); accuracy is taken over all
            label elements (rows or tokens).

        Raises:
            ValueError: if the loader yields no batches.
        '''
        if to_load:
            self.load_model()
        t1 = time()
        self.model.eval()
        device = self.config['train']['device']
        val_loss, nrof_cor_predicts, nrof_samples, nrof_batches = 0., 0, 0, 0

        if to_calc_conf_matr:
            conf_matr = np.zeros((self.config['train']['num_classes'],
                                  self.config['train']['num_classes']))

        for batch in validation_dataloader:
            with torch.no_grad():
                b_input_ids = batch[0].to(device)
                b_labels = batch[1].to(device)
                outputs = self.model(b_input_ids,
                                     token_type_ids=None,
                                     labels=b_labels,
                                     return_dict=True)
                val_loss += outputs.loss.item()
                _, predicted = torch.max(outputs.logits, -1)
                nrof_cor_predicts += (predicted == b_labels).sum().item()

                if to_calc_conf_matr:
                    true_classes = b_labels.cpu().numpy().flatten()
                    pred_classes = predicted.cpu().numpy().flatten()
                    # The index has to be a tuple: a list of arrays is not
                    # read as one index array per axis by current NumPy.
                    np.add.at(conf_matr, (true_classes, pred_classes), 1)

                nrof_samples += b_labels.numel()
                nrof_batches += 1

        if nrof_batches == 0:
            raise ValueError('validation_dataloader produced no batches')

        avg_val_accuracy = nrof_cor_predicts / nrof_samples
        avg_val_loss = val_loss / nrof_batches
        validation_time = (time() - t1)

        if to_calc_conf_matr:
            plot_conf_matr(conf_matr)

        print("  Validation took: {:}".format(validation_time))

        return avg_val_accuracy, avg_val_loss

    def evaluate(self, row, to_load=False):
        '''
        Predicts the heading name of a single row (row classification model).

        The row is encoded with the module-level prepare_row helper.
        '''
        if to_load:
            self.load_model()
        self.model.eval()
        device = self.config['train']['device']

        input_ids, input_mask = prepare_row(row)
        with torch.no_grad():
            outputs = self.model(input_ids.to(device),
                                 token_type_ids=None,
                                 attention_mask=input_mask.to(device),
                                 return_dict=True)
        pr_class = self.config['train']['class_to_heading_map'][
            int(torch.max(outputs.logits, 1)[1][0])]

        return pr_class

    def evaluate_rows(self, rows, inputs, labels):
        '''
        Labels rows one by one with the row classification model.

        Args:
            rows: row strings, parallel to `inputs`.
            inputs: 2-D tensor of encoded rows.
            labels: unused, kept for interface compatibility.

        Returns:
            Text with one "<bold heading>: <row>" line per row. Rows the
            model fails on are reported and skipped.
        '''
        device = self.config['train']['device']
        class_to_heading_map = self.config['train']['class_to_heading_map']
        max_len = self._max_input_length()
        lines = []

        with torch.no_grad():
            for row, row_ids in zip(rows, inputs):
                if max_len is not None:
                    # Ids beyond the model's position limit make it fail.
                    row_ids = row_ids[:max_len]
                try:
                    outputs = self.model(row_ids.unsqueeze(0).to(device),
                                         token_type_ids=None,
                                         return_dict=True)
                except (RuntimeError, IndexError) as e:
                    # best effort: report the failure and go on
                    print(e)
                    continue
                prediction = torch.max(outputs.logits, -1)[1]
                pr_class = class_to_heading_map[int(prediction[0])]
                lines.append(self._format_line(pr_class, row))

        return ''.join(lines)

    def evaluate_chunks(self, inputs, labels):
        '''
        Labels the rows packed into chunks with the token classification
        model.

        Rows are recovered from the '[CLS]' token that starts each of them;
        a row gets the most frequent predicted class among its tokens.
        '[PAD]' tokens are ignored and the last row of every chunk is
        emitted too.

        Args:
            inputs: 2-D tensor of encoded chunks.
            labels: unused, kept for interface compatibility.

        Returns:
            Text with one "<bold heading>: <row text>" line per row.
        '''
        device = self.config['train']['device']
        class_to_heading_map = self.config['train']['class_to_heading_map']
        max_len = self._max_input_length()
        lines = []

        def flush(row_tokens, row_votes):
            # Emit the collected row, if any tokens voted for a class.
            if not row_votes:
                return
            row_text = tokenizer.convert_tokens_to_string(row_tokens)
            row_text = row_text.replace('[SEP]', '')
            row_label = max(set(row_votes), key=row_votes.count)
            lines.append(self._format_line(class_to_heading_map[row_label],
                                           row_text))

        with torch.no_grad():
            for chunk_ids in inputs:
                if max_len is not None:
                    chunk_ids = chunk_ids[:max_len]
                tokens = tokenizer.convert_ids_to_tokens(chunk_ids.tolist())
                outputs = self.model(chunk_ids.unsqueeze(0).to(device),
                                     token_type_ids=None,
                                     return_dict=True)
                predictions = torch.max(outputs.logits, -1)[1][0].tolist()

                row_tokens, row_votes = [], []
                # Position 0 is the '[CLS]' that opens the first row.
                for token, predicted in zip(tokens[1:], predictions[1:]):
                    if token == '[PAD]':
                        continue
                    if token == '[CLS]':
                        flush(row_tokens, row_votes)
                        row_tokens, row_votes = [], []
                    else:
                        row_tokens.append(token)
                        row_votes.append(predicted)
                # No '[CLS]' follows the last row of a chunk.
                flush(row_tokens, row_votes)

        return ''.join(lines)

    def evaluate_text(self, text, to_load=False, to_save_text=False, script_name=''):
        '''
        Annotates every row of a script text with a predicted heading.

        Args:
            text: raw script text.
            to_load: load the checkpoint first.
            to_save_text: write the annotation, without the bold escape
                sequences, to
                <model_annotations>/<exp_name>/<script_name>_anno.txt.
            script_name: base name of the output file.

        Returns:
            The annotated text (with bold escape sequences).
        '''
        if to_load:
            self.load_model()
        self.model.eval()

        text = re.sub('\n+', '\n', text)
        text = re.sub(' +', ' ', text)
        rows = [x.strip() for x in text.split('\n')]
        labels = [0] * len(rows)  # placeholders; real labels are unknown
        is_row_exp = self.config['train']['exp_name'] == 'row_classification'

        if is_row_exp:
            inputs = make_tokenized_rows(rows, tokenizer)
        else:
            inputs, labels = make_tokenized_chunks_labels_from_rows(
                rows, labels, tokenizer)
        inputs, labels = prepare_inputs_labels(inputs, labels)

        if is_row_exp:
            labeled_text = self.evaluate_rows(rows, inputs, labels)
        else:
            labeled_text = self.evaluate_chunks(inputs, labels)

        if to_save_text:
            out_dir = os.path.join(self.config['paths']['model_annotations'],
                                   self.config['train']['exp_name'])
            os.makedirs(out_dir, exist_ok=True)
            with open(os.path.join(out_dir, script_name + '_anno.txt'), 'w') as f:
                f.write(re.sub(r'\x1b\[[01]m', '', labeled_text))

        return labeled_text


## Train logging:

In [None]:
def log_test_info(tst_acc,  tst_classes_accs, tst_loss):
    '''
    Formats, prints and returns the test loss and accuracy summary.

    Args:
        tst_acc: overall test accuracy.
        tst_classes_accs: currently unused, kept for interface compatibility.
        tst_loss: mean test loss.

    Returns:
        The formatted summary text.
    '''
    test_results_txt = 'Test loss: {:3f}\nTest accuracy: {:3f}\n'
    # The template lists the loss first, then the accuracy.
    test_results_txt = test_results_txt.format(tst_loss, tst_acc)
    print(test_results_txt)
    return test_results_txt

In [None]:
# Trial run: train without writing a checkpoint (to_save=False).
# NOTE(review): this cell is not wrapped in mlflow.start_run(), while Train
# and the data code call mlflow.log_param / log_metric - confirm the intended
# mlflow run handling.
T = Train(config)
tr_loader, val_loader, tst_loader = get_data_loaders()
T.train(tr_loader, val_loader, to_save=False)

In [None]:
# Full tracked experiment: train (a checkpoint is saved by default), then
# evaluate on the test loader with a confusion matrix, which validate() plots
# to conf_matrix.png, and attach that image to the mlflow run.
with mlflow.start_run(run_name=config['train']['exp_name']):
    tr_loader, val_loader, tst_loader = get_data_loaders()
    T = Train(config)
    T.train(tr_loader, val_loader)
    tst_acc,  tst_loss = T.validate(tst_loader, to_calc_conf_matr=True)
    mlflow.log_artifact('conf_matrix.png')

In [None]:
# Close the active mlflow run, if any (e.g. one left open by an interrupted cell).
mlflow.end_run()

In [None]:
# Start the mlflow UI in the background on port 5000 and expose it through an
# ngrok tunnel.
get_ipython().system_raw("mlflow ui --port 5000 &")
# Fill in a personal ngrok auth token before running this cell.
NGROK_AUTH_TOKEN = ""
ngrok.set_auth_token(NGROK_AUTH_TOKEN)

ngrok_tunnel = ngrok.connect(addr="5000", proto="http", bind_tls=True)
print("MLflow Tracking UI:", ngrok_tunnel.public_url)

In [None]:
# Shut down the ngrok process and, with it, the tunnel opened above.
ngrok.kill()

## Evaluation:

In [None]:
# Directory with the raw script files that make_script_annotations reads.
scripts_path = '/content/drive/MyDrive/NLP/Movie scripts dataset/Movie scripts and annotations/Scripts'

In [None]:
# Collect the imdb ids of the movies whose mean IoU value is below 1.
mean_matching_scores_df = pd.read_excel('/content/drive/MyDrive/NLP/Movie scripts dataset/Movie characters/Matching evaluation and statistics/movies_mean_matching_scores.xlsx')
imdb_ids = mean_matching_scores_df[mean_matching_scores_df.iou_values_mean<1.]['imdb_id'].tolist()

In [None]:
def make_script_annotations(imdb_ids_to_annotate):
    '''
    Annotates the scripts of the given movies with the trained model.

    For every id the first file in `scripts_path` whose name contains the id
    is used. Scripts that already have an annotation file are skipped, and so
    are ids without a matching file and files that cannot be read. The
    checkpoint is loaded once, before the first script is annotated.

    Args:
        imdb_ids_to_annotate: iterable of imdb ids.
    '''
    T = Train(config)
    file_names = os.listdir(scripts_path)
    annotations_dir = os.path.join(T.config['paths']['model_annotations'],
                                   T.config['train']['exp_name'])
    model_loaded = False

    for id_ in tqdm(imdb_ids_to_annotate):
        file_name = next((name for name in file_names if str(id_) in name), None)
        if file_name is None:
            continue

        script_name = file_name.split('.')[0]
        if os.path.exists(os.path.join(annotations_dir, script_name + '_anno.txt')):
            continue

        script_path = os.path.join(scripts_path, file_name)
        try:
            try:
                with open(script_path, 'r') as f:
                    text = f.read()
            except UnicodeDecodeError:
                # Not decodable with the default encoding: latin-1 accepts
                # any byte sequence.
                with open(script_path, 'r', encoding='latin-1') as f:
                    text = f.read()
        except OSError:
            continue

        T.evaluate_text(text, to_save_text=True, to_load=not model_loaded,
                        script_name=script_name)
        model_loaded = True

In [None]:
# Annotate the scripts of the movies selected into `imdb_ids` above; the name
# `imdb_ids_to_annotate` is not defined anywhere at module level.
make_script_annotations(imdb_ids)