In [None]:
!pip install transformers
!pip install mlflow --quiet
!pip install pyngrok --quiet

In [None]:
# Colab only: mount Google Drive so the paths under /content/drive used in
# `config['paths']` (annotations, checkpoints, cached tensors) are reachable.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import os 
import torch 
import pickle 
import pandas as pd 
import seaborn as sns
from tqdm import tqdm
from time import time
import seaborn as sn
from matplotlib import pyplot as plt 
from torch.nn import LSTM
import numpy as np 
import datetime
import random
from shutil import copytree
from sklearn.model_selection import train_test_split
from transformers import BertTokenizer, LongformerTokenizer, \
    LongformerForSequenceClassification, BertForSequenceClassification,AdamW,\
    get_cosine_schedule_with_warmup, BertTokenizerFast, BertModel, BertForNextSentencePrediction
from sklearn.metrics import balanced_accuracy_score, f1_score
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import TensorDataset, RandomSampler, DataLoader
from torch.utils.data.sampler import WeightedRandomSampler
from torch.nn import CrossEntropyLoss
import torch.nn as nn
import torch 
import torch.nn as nn
import torch.nn.functional as F 
import torch.optim as Opt
#torch.set_default_tensor_type(torch.HalfTensor)
#torch.set_default_tensor_type(torch.cuda.HalfTensor)
import mlflow
from pyngrok import ngrok

In [None]:
# Central configuration: Drive locations plus training hyper-parameters.
config = {
    'paths':{
        # directory whose files are read line by line in ScriptsData._get_scripts
        'script_annotations': '/content/drive/MyDrive/Movie scripts dataset/Movie scripts and annotations/Script annotations by BERT/row_classification',
        # directory holding the '<exp_name>_checkpoint' file written by Train.save_model
        'ckpt_dir': '/content/drive/MyDrive/Movie scripts models/BERTNSP/ckpts',
        # prefix of the '/content/mlruns' copies made when the validation loss improves
        'mlruns': '/content/drive/MyDrive/Movie scripts models/BERTNSP/mlruns',
        # pickled tokenizer outputs cached / reloaded by ScriptsData.tokenize_scripts
        'data_input_ids':'/content/drive/MyDrive/Movie scripts models/BERTNSP/data/tokenized_scripts_input_ids.pickle',
        'data_attention_masks':'/content/drive/MyDrive/Movie scripts models/BERTNSP/data/tokenized_scripts_attention_masks.pickle',
        'data_token_type_ids':'/content/drive/MyDrive/Movie scripts models/BERTNSP/data/tokenized_scripts_token_type_ids.pickle'
    },
    'train': {
                # keyword arguments forwarded to AdamW in Train.__init__
                'optim' : {
                    'AdamW':{
                        'lr':1e-5,
                        'eps': 1e-8,
                        'weight_decay':0.0001
                    }
                },
                'embedding_size': 768, # not read by any code visible in this notebook
                'model': BertForNextSentencePrediction, # model class; instantiated with .from_pretrained in Train.__init__
                'tokenizer':BertTokenizerFast, # tokenizer class; instantiated with .from_pretrained below
                'pretrained_model_type': 'bert-base-cased', # checkpoint name passed to both .from_pretrained calls
                #'max_script_length' : 65540, #  131072, 65540, 32770
                'max_seq_length' : 512, # max_length each scene pair is truncated / padded to
                'num_classes' : 2, # passed as num_labels to the model
                #'max_scene_number' : 100, 
                #'nrof_steps' : 2100,
                'nrof_epochs' : 100,  
                 'tr_batch_size' : 4, # batch size of the training loader
                'tst_batch_size' : 4, # batch size of both the validation and the test loader
                'exp_name':'next_scene_prediction', # checkpoint file prefix and MLflow run name
                'device': torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
                #'heading_to_class_map': {'scene_heading':0, 'text':1, 'speaker_heading':2, 'dialog':3},
                #'class_to_heading_map': {0:'scene_heading', 1:'text', 2:'speaker_heading', 3:'dialog'},
                'load_model':True # when True, Train.train() starts by calling Train.load_model()
                }
}

## Dataset preprocessing:

### Main:

In [None]:
def show_histogram(x_data, y_data, values_to_show=None, figsize=(40, 10), x_label='x', y_label='y',
                   set_rotation=False, rotation_angle=45, title='title', file_path='name.png', dpi=100, to_save=False, to_show=True):
    """Draw a seaborn bar chart of ``y_data`` over ``x_data``.

    Args:
        x_data, y_data: bar positions / heights, forwarded to ``sns.barplot``.
        values_to_show: optional sequence aligned with ``y_data``; when given,
            every bar is annotated with ``"<y>\\n(<value>)"``.
        figsize: size of the newly created figure.
        x_label, y_label, title: axis labels and figure title.
        set_rotation: rotate the x tick labels by ``rotation_angle`` degrees.
        file_path, dpi: target file and resolution used when ``to_save`` is set.
        to_save: write the figure to ``file_path``.
        to_show: display the figure with ``plt.show()``.
    """
    fig = plt.figure(figsize=figsize)
    h = sns.barplot(x=x_data, y=y_data, palette="Blues_d")

    # Explicit None/length test: a NumPy array or pandas Series raises on a bare
    # truth test, and an empty sequence must not be indexed below.
    if values_to_show is not None and len(values_to_show) > 0:
        for i, y in enumerate(y_data):
            h.text(i, y, str(y)+'\n('+str(values_to_show[i]) + ')', color='black', ha='center')
    if set_rotation:
        h.set_xticklabels(h.get_xticklabels(), rotation=rotation_angle)
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.title(title)

    # Save first: depending on the backend the figure may already be closed
    # once plt.show() returns.
    if to_save:
        fig.savefig(file_path, dpi=dpi, bbox_inches='tight')

    if to_show:
        plt.show()
        
def remove_labels(row):
    """Strip a leading annotation label from ``row`` and trim whitespace.

    Recognized labels are ``text:``, ``dialog:``, ``speaker_heading:`` and
    ``scene_heading:``. Only the label itself is removed; the separator after
    the colon is handled by ``strip()``, so a row without a space after the
    colon (``"dialog:Hi"``) no longer loses its first character.

    Args:
        row: one annotated line.

    Returns:
        The line content without label and without surrounding whitespace.
    """
    for label in ('text:', 'dialog:', 'speaker_heading:', 'scene_heading:'):
        if row.startswith(label):
            row = row[len(label):]
            break
    return row.strip()

def check_if_paths_exist(*paths):
    """Return True only if every given path exists (True when called with no paths)."""
    return all(os.path.exists(candidate) for candidate in paths)


In [None]:
# Instantiate the tokenizer class chosen in `config` from the configured
# pretrained checkpoint. do_lower_case=False is passed explicitly; the
# configured checkpoint ('bert-base-cased') is a cased one.
tokenizer = config['train']['tokenizer'].from_pretrained(config['train']['pretrained_model_type'], 
                                          do_lower_case=False)

In [None]:
class ScriptsData:
    """Builds scene-pair examples for next-scene prediction from annotated scripts.

    The constructor runs the whole text pipeline eagerly: read the annotation
    files, group every script's lines into scenes, then emit scene pairs with
    label 1 for consecutive scenes and label 0 for non-adjacent scenes of the
    same script.

    NOTE: ``labels`` are rebuilt from the scripts on every instantiation while
    the tokenized tensors may be read from cached pickles. The two only line up
    when the cache was produced from the same scene pairs in the same order, so
    delete the cached pickles (or call ``tokenize_scripts(to_load=False)``)
    whenever the annotation files or the scene extraction change.

    NOTE(review): Hugging Face documents the opposite label convention for
    ``BertForNextSentencePrediction`` (0 = B continues A) — confirm that
    training the pretrained head on inverted labels is intended.
    """

    def __init__(self, config, tokenizer):
        self.config = config 
        self.tokenizer = tokenizer
        self.scripts = self._get_scripts()
        self.scripts_scenes = self._get_script_scenes()
        self.scene_pairs, self.labels = self._get_scene_pairs()

    def _get_scripts(self, to_load=True):
        """Return every annotation file as a list of its lines.

        The result is cached in ``scripts.pickle`` in the working directory and
        read back from there when ``to_load`` is True and the file exists.
        """
        cache_path = 'scripts.pickle'
        if to_load and os.path.exists(cache_path):
            # pickle is only safe here because the cache is written by this notebook
            with open(cache_path, 'rb') as f:
                scripts = pickle.load(f)
            print('Data loaded')
            return scripts

        anno_dir = self.config['paths']['script_annotations']
        scripts = []
        # NOTE(review): os.listdir order is not guaranteed; the script order
        # determines the example order and therefore the train/val/test split.
        for file_name in tqdm(os.listdir(anno_dir)):
            with open(os.path.join(anno_dir, file_name), 'r') as f:
                scripts.append(f.readlines())
        with open(cache_path, 'wb') as f:
            pickle.dump(scripts, f)
        return scripts

    def _get_script_scenes(self, scene_max_length=20000):
        """Group each script's lines into scene texts.

        A scene starts at a ``scene_heading:`` line and runs up to the next
        one; labels are removed with ``remove_labels`` and the pieces are
        joined with single spaces. Empty scenes and scenes of
        ``scene_max_length`` characters or more are dropped, as are scripts
        left without any scene.

        Returns:
            A list with one list of scene strings per retained script.
        """
        def flush(parts, scenes):
            # Joining with a space keeps the last word of one line from fusing
            # with the first word of the next (remove_labels strips both ends).
            scene_text = ' '.join(part for part in parts if part)
            if scene_text and len(scene_text) < scene_max_length:
                scenes.append(scene_text)

        scripts_scenes = []
        for script_lines in tqdm(self.scripts):
            scenes = []
            # Reset per script so a scene never leaks into the following script.
            parts = []
            for line in script_lines:
                if line.startswith('scene_heading:'):
                    flush(parts, scenes)
                    parts = [remove_labels(line)]
                else:
                    parts.append(remove_labels(line))
            # The last scene has no following heading; flush it explicitly.
            flush(parts, scenes)
            if scenes:
                scripts_scenes.append(scenes)

        return scripts_scenes

    def _get_scene_pairs(self):
        """Build (scene_a, scene_b) pairs and their labels.

        For every scene ``i >= 1`` the pair ``(i-1, i)`` is emitted with label
        1. While at least one scene exists beyond ``i+1``, one extra pair with
        label 0 is emitted that combines scene ``i`` with a random scene that
        is neither ``i`` nor one of its direct neighbours, keeping the two in
        script order.

        Returns:
            ``(scene_pairs, labels)`` as two parallel lists.
        """
        scene_pairs, labels = [], []
        for script in tqdm(self.scripts_scenes):
            for i in range(1, len(script)):
                scene_pairs.append((script[i-1], script[i]))
                labels.append(1)
                if i+2<len(script):
                    random_scene_ind = random.choice(list(range(i-1)) + list(range(i+2, len(script))))
                    if random_scene_ind > i:
                        scene_pairs.append((script[i], script[random_scene_ind]))
                    else:
                        scene_pairs.append((script[random_scene_ind], script[i]))
                    labels.append(0)
        return scene_pairs, labels

    def tokenize_scripts(self, to_load=True, max_nrof_examples=200000):
        """Tokenize the first ``max_nrof_examples`` scene pairs, with caching.

        When ``to_load`` is True and all three cache files named in
        ``config['paths']`` exist, they are unpickled and returned. Otherwise
        the pairs are tokenized (truncated / padded to ``max_seq_length``) and
        the three tensors are pickled to those paths.

        Returns:
            ``(input_ids, token_type_ids, attention_masks)``.
        """
        paths = self.config['paths']
        if to_load and check_if_paths_exist(paths['data_input_ids'],
                                            paths['data_token_type_ids'],
                                            paths['data_attention_masks']):
            with open(paths['data_input_ids'], 'rb') as f:
                tokenized_scripts_input_ids = pickle.load(f)
            with open(paths['data_token_type_ids'], 'rb') as f:
                tokenized_scripts_token_type_ids = pickle.load(f)
            with open(paths['data_attention_masks'], 'rb') as f:
                tokenized_scripts_attention_masks = pickle.load(f)
            print('tokenized data loaded')
        else:
            tokenized_scripts = self.tokenizer(self.scene_pairs[:max_nrof_examples],
                                                max_length = self.config['train']['max_seq_length'],
                                                truncation=True, 
                                                padding='max_length',
                                                return_attention_mask=True,
                                                return_tensors='pt')
            tokenized_scripts_input_ids = tokenized_scripts['input_ids']
            tokenized_scripts_token_type_ids = tokenized_scripts['token_type_ids']
            tokenized_scripts_attention_masks = tokenized_scripts['attention_mask']

            cache = ((paths['data_input_ids'], tokenized_scripts_input_ids),
                     (paths['data_token_type_ids'], tokenized_scripts_token_type_ids),
                     (paths['data_attention_masks'], tokenized_scripts_attention_masks))
            for cache_path, tensor in cache:
                # Create the target directory so the tokenization result is not
                # lost to a missing folder.
                cache_dir = os.path.dirname(cache_path)
                if cache_dir:
                    os.makedirs(cache_dir, exist_ok=True)
                with open(cache_path, 'wb') as f:
                    pickle.dump(tensor, f)
        return tokenized_scripts_input_ids, tokenized_scripts_token_type_ids, tokenized_scripts_attention_masks

    def get_train_val_split(self):
        """Split the tokenized examples into train / validation / test tensors.

        80% of the examples go to training; the remaining 20% are split again
        with ``test_size=0.25``, whose larger part becomes the test set and
        whose smaller part the validation set. Both splits use
        ``random_state=11``.

        Returns:
            A 12-tuple ``(inputs, masks, token_type_ids, labels)`` for train,
            then validation, then test.

        Raises:
            ValueError: if there are more tokenized rows than labels.
        """
        input_ids, token_type_ids, attention_masks = self.tokenize_scripts()

        # Take exactly as many labels as there are tokenized rows instead of
        # repeating the example cap as a second hard-coded constant.
        nrof_examples = len(input_ids)
        if nrof_examples > len(self.labels):
            raise ValueError(
                'Tokenized data has {} rows but only {} labels are available; '
                'the cached tokenization does not match the current scripts.'.format(
                    nrof_examples, len(self.labels)))
        tensor_labels = torch.LongTensor(self.labels[:nrof_examples])

        tr_inputs, val_inputs, tr_masks, val_masks, tr_token_type_ids, val_token_type_ids, tr_labels, val_labels = train_test_split(
            input_ids, attention_masks, token_type_ids, tensor_labels, 
            test_size=0.2, random_state=11)

        tst_inputs, val_inputs, tst_masks, val_masks, tst_token_type_ids, val_token_type_ids, tst_labels, val_labels = train_test_split(
            val_inputs, val_masks, val_token_type_ids, val_labels, 
            test_size=0.25, random_state=11)

        return tr_inputs, tr_masks, tr_token_type_ids, tr_labels, val_inputs, val_masks, val_token_type_ids, val_labels, tst_inputs, tst_masks, tst_token_type_ids, tst_labels


In [None]:
def get_dataloader(input_ids, attention_masks, token_type_ids, labels, batch_size=1,  
                   phase='train', sampler=None, drop_last=None):
    """Wrap the four aligned tensors in a ``TensorDataset`` and a ``DataLoader``.

    Args:
        input_ids, attention_masks, token_type_ids, labels: tensors with the
            same first dimension; batches come out in this order.
        batch_size: number of examples per batch.
        phase: ``'train'`` samples with ``sampler`` (a ``RandomSampler`` when
            none is given); any other value iterates sequentially and ignores
            ``sampler``.
        sampler: optional sampler for the training phase.
        drop_last: whether to discard a final incomplete batch. ``None`` (the
            default) drops it for training only, so evaluation metrics cover
            every example.

    Returns:
        The configured ``DataLoader``.
    """
    dataset = TensorDataset(input_ids, attention_masks, token_type_ids, labels)
    is_train = phase == 'train'
    if drop_last is None:
        drop_last = is_train

    if is_train:
        if sampler is None:
            sampler = RandomSampler(dataset)
        return DataLoader(dataset,
                          batch_size=batch_size,
                          sampler=sampler,
                          drop_last=drop_last)

    return DataLoader(dataset,
                      batch_size=batch_size,
                      drop_last=drop_last)

In [None]:
def get_data_loaders():
    """Build the train, validation and test loaders from the global ``config``.

    A fresh ``ScriptsData`` is constructed from the module-level ``config``
    and ``tokenizer``, and its train/validation/test split is wrapped in
    loaders. The training loader uses ``tr_batch_size``; validation and test
    both use ``tst_batch_size``.

    Returns:
        ``(tr_loader, val_loader, tst_loader)``.
    """
    SD = ScriptsData(config, tokenizer)
    tr_inputs, tr_attention_masks, tr_token_type_ids, tr_labels, val_inputs, val_attention_masks, \
    val_token_type_ids, val_labels, tst_inputs, tst_attention_masks, tst_token_type_ids, tst_labels = SD.get_train_val_split()

    tr_loader = get_dataloader(tr_inputs, tr_attention_masks, tr_token_type_ids, tr_labels,
                               batch_size=config['train']['tr_batch_size'])
    # Pass the phase explicitly: with the default phase='train' the evaluation
    # loaders would be built with a random sampler like the training loader.
    val_loader = get_dataloader(val_inputs, val_attention_masks, val_token_type_ids, val_labels,
                                batch_size=config['train']['tst_batch_size'],
                                phase='val')
    tst_loader = get_dataloader(tst_inputs, tst_attention_masks, tst_token_type_ids, tst_labels,
                                batch_size=config['train']['tst_batch_size'],
                                phase='test')

    return tr_loader, val_loader, tst_loader


### Check:

In [None]:
# Build the dataset helper; its constructor reads the scripts, groups them
# into scenes and builds the labelled scene pairs.
SD = ScriptsData(config, tokenizer)

In [None]:
# Load the pickled tokenizer outputs numbered 1..4 from /content into three
# parallel lists. The lists are not referenced again in the visible notebook.
# NOTE(review): these files are not produced by any visible cell — confirm
# where they come from before relying on this cell.
attention_masks, input_ids, token_type_ids = [], [], []

for i in range(1,5):
    with open('/content/tokenized_scripts_attention_masks_' + str(i) + '.pickle', 'rb') as f:
        attention_masks.append(pickle.load(f))

    with open('/content/tokenized_scripts_input_ids_' + str(i) + '.pickle', 'rb') as f:
        input_ids.append(pickle.load(f))

    with open('/content/tokenized_scripts_token_type_ids_' + str(i) + '.pickle', 'rb') as f:
        token_type_ids.append(pickle.load(f))

In [None]:
# Materialize the train / validation / test tensors (this tokenizes the scene
# pairs or loads the cached tokenization). Note this rebinds input-related
# names at module level for the inspection cells below.
tr_inputs, tr_masks, tr_token_type_ids, tr_labels, val_inputs, val_masks, val_type_ids, val_labels, tst_inputs, tst_masks, tst_type_ids, tst_labels = SD.get_train_val_split()

In [None]:
# Sanity check: size of the training input-id tensor.
print(tr_inputs.size())

In [None]:
# Map a hand-picked list of token ids back to tokens to eyeball the tokenization.
tokenizer.convert_ids_to_tokens([101, 157, 10964, 12426, 1592, 22219, 2036, 2924, 2036, 2924, 7729, 5208, 1118, 2107, 24554, 1161, 139, 9435, 4729, 2064, 6530, 1181, 1113, 1103, 1520, 1118, 2101, 28114, 14159, 6262, 16838, 1116, 10973, 1582, 1357, 1371, 102])

In [None]:
# Build the three loaders (get_data_loaders constructs its own ScriptsData).
tr_loader, val_loader, tst_loader = get_data_loaders()

In [None]:
# Peek at the first training batch only: print it and the sizes of its four
# tensors (input ids, attention mask, token type ids, labels — the order in
# which get_dataloader builds the TensorDataset).
for d in tr_loader:
    print(d)
    print(d[0].size())
    print(d[1].size())
    print(d[2].size())
    print(d[3].size())
    break

## Training process:

In [None]:
def plot_conf_matr(trues, predicts, title='tr_', 
                   classes_names=['next', 'not next'],
                   nrof_classes=2):
    """Save a confusion-matrix heatmap to ``<title>conf_matrix.png``.

    Args:
        trues: iterable of integer ground-truth class ids.
        predicts: iterable of integer predicted class ids, aligned with ``trues``.
        title: prefix of the output file name.
        classes_names: tick labels; entry ``k`` names class id ``k``. With the
            labels produced in this notebook (0 = not next, 1 = next) callers
            should pass ``['not next', 'next']``.
        nrof_classes: number of classes (side length of the matrix).
    """
    # Integer counts from the start; the former `astype(np.int)` breaks on
    # NumPy >= 1.24, where the `np.int` alias no longer exists.
    results = np.zeros((nrof_classes, nrof_classes), dtype=int)
    for t, p in zip(trues, predicts):
        results[t][p] += 1  # rows: actual class, columns: predicted class

    df_cm = pd.DataFrame(results, index=classes_names, columns=classes_names)
    plt.figure(figsize=(7, 7))
    ax = sns.heatmap(df_cm, annot=True, fmt='d')
    ax.set(xlabel='predicted', ylabel='actual', title='Confusion matrix')
    plt.savefig(title+'conf_matrix.png', bbox_inches='tight')
    plt.close()  # the figure is only written to disk, never displayed
    
def format_time(elapsed):
    """Format a duration in seconds, rounded to whole seconds, as ``h:mm:ss``."""
    whole_seconds = int(round(elapsed))
    duration = datetime.timedelta(seconds=whole_seconds)
    return str(duration)

In [None]:
class Train():
    """Fine-tunes and evaluates the model selected in ``config``.

    The model class, pretrained weights, AdamW hyper-parameters, device and
    checkpoint location are all read from ``config``. Training and validation
    metrics are reported with ``mlflow.log_metric``.
    """

    def __init__(self, config):
        self.config = config
        train_cfg = self.config['train']
        self.model = train_cfg['model'].from_pretrained(train_cfg['pretrained_model_type'],
                                                        num_labels=train_cfg['num_classes'],
                                                        output_attentions=False,
                                                        output_hidden_states=False)

        opt_config = train_cfg['optim']['AdamW']
        self.optimizer = AdamW(self.model.parameters(),
                               lr=opt_config['lr'],
                               eps=opt_config['eps'],
                               weight_decay=opt_config['weight_decay'])
        self.model.to(train_cfg['device'])
        self.global_step = 0           # optimizer steps completed so far
        self.max_steps_to_stop = 1000  # not read by the code in this class (early stopping is disabled)
        self.best_val_loss = 1000      # sentinel: any real validation loss is expected to be lower

    def _checkpoint_path(self):
        """Return the checkpoint file path derived from ``ckpt_dir`` and ``exp_name``."""
        return os.path.join(self.config['paths']['ckpt_dir'],
                            self.config['train']['exp_name'] + '_checkpoint')

    def save_model(self):
        """Write model, optimizer, step counter and best validation loss to the checkpoint."""
        ckpt_path = self._checkpoint_path()
        os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)
        torch.save({"model": self.model.state_dict(),
                    "optimizer": self.optimizer.state_dict(),
                    'global_step': self.global_step,
                    # persisted so a resumed run cannot overwrite the best
                    # checkpoint with a worse model
                    'best_val_loss': self.best_val_loss,
                    },
                   ckpt_path)

    def load_model(self):
        """Restore model, optimizer and counters from the checkpoint.

        Raises whatever ``torch.load`` raises when the checkpoint is missing.
        """
        ckpt = torch.load(self._checkpoint_path(),
                          map_location=self.config['train']['device'])
        self.model.load_state_dict(ckpt["model"])
        self.optimizer.load_state_dict(ckpt["optimizer"])
        # The stored value already counts completed steps (it is incremented
        # before saving), so it is restored as is; adding 1 skipped one batch.
        self.global_step = ckpt["global_step"]
        # Checkpoints written before this key existed keep the current value.
        self.best_val_loss = ckpt.get('best_val_loss', self.best_val_loss)
        print("Model loaded...")

    def train(self, train_dataloader, validation_dataloader, to_save=True):
        """Run the training loop with periodic logging and validation.

        Every 10 steps the running training loss, the balanced accuracy over
        the most recent (at most 50) predictions and the learning rate are
        logged. Every 100 steps the model is validated; when the validation
        loss improves and ``to_save`` is True the checkpoint is written and
        ``/content/mlruns`` is copied next to ``config['paths']['mlruns']``.

        Args:
            train_dataloader: yields ``(input_ids, attention_mask,
                token_type_ids, labels)`` batches.
            validation_dataloader: loader passed to ``validate``.
            to_save: write checkpoints / MLflow copies on improvement.
        """
        train_cfg = self.config['train']
        device = train_cfg['device']

        if train_cfg['load_model']:
            if os.path.exists(self._checkpoint_path()):
                self.load_model()
            else:
                # First run: nothing to resume from yet.
                print('No checkpoint found at {}; starting from the pretrained weights.'.format(
                    self._checkpoint_path()))
        self.model.train()
        t0 = time()

        cur_loss, nrof_steps = 0., 0
        print('Global step:', self.global_step)
        steps_per_epoch = len(train_dataloader)

        for epoch in tqdm(range(train_cfg['nrof_epochs'])):
            # Epochs that were fully completed before the resume point are
            # skipped without iterating the loader.
            if (epoch + 1) * steps_per_epoch <= self.global_step:
                continue
            predicts, trues = [], []
            for step, batch in enumerate(train_dataloader):
                # Fast-forward inside the epoch in which training was interrupted.
                if epoch * steps_per_epoch + step < self.global_step:
                    continue
                print('epoch: {} step: {}'.format(epoch, step))
                b_input_ids, b_input_mask, b_token_type_ids, b_labels = (t.to(device) for t in batch)

                self.model.zero_grad()
                outputs = self.model(b_input_ids,
                                     attention_mask=b_input_mask,
                                     token_type_ids=b_token_type_ids,
                                     labels=b_labels,
                                     return_dict=True)
                cur_loss += outputs.loss.item()
                predicted = outputs.logits.argmax(dim=-1)
                predicts.extend(predicted.detach().cpu().tolist())
                trues.extend(b_labels.detach().cpu().tolist())

                outputs.loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                self.optimizer.step()
                self.global_step += 1
                nrof_steps += 1

                if self.global_step % 10 != 0:
                    continue

                if self.global_step % 100 == 0:
                    self._validate_and_checkpoint(validation_dataloader, to_save)

                # Training accuracy is computed over a sliding window of the
                # latest predictions only.
                predicts, trues = predicts[-50:], trues[-50:]
                elapsed = format_time(time() - t0)
                print('Elapsed: {:}.'.format(elapsed))

                avg_tr_loss = cur_loss / nrof_steps
                tr_accuracy = balanced_accuracy_score(trues, predicts)
                mlflow.log_metric("train_loss", avg_tr_loss, step=self.global_step)
                mlflow.log_metric("train_accuracy", tr_accuracy, step=self.global_step)
                mlflow.log_metric("learning_rate", self.optimizer.param_groups[0]["lr"],
                                  step=self.global_step)

                print('tr loss: {}\ntr accuracy: {}'.format(avg_tr_loss, tr_accuracy))
                print()

                cur_loss, nrof_steps = 0., 0

    def _validate_and_checkpoint(self, validation_dataloader, to_save):
        """Validate, log the results and checkpoint when the loss improved."""
        avg_val_loss, val_trues, val_predicts = self.validate(validation_dataloader)
        avg_val_accuracy = balanced_accuracy_score(val_trues, val_predicts)
        mlflow.log_metric("val_loss", avg_val_loss, step=self.global_step)
        mlflow.log_metric("val_accuracy", avg_val_accuracy, step=self.global_step)
        print('val loss: {}\nval accuracy: {}'.format(avg_val_loss, avg_val_accuracy))

        conf_matr_prefix = str(self.global_step) + '_'
        plot_conf_matr(val_trues, val_predicts, title=conf_matr_prefix,
                       classes_names=['not next', 'next'])
        mlflow.log_artifact(conf_matr_prefix + 'conf_matrix.png')

        if avg_val_loss < self.best_val_loss:
            self.best_val_loss = avg_val_loss
            if to_save:
                self.save_model()
                try:
                    copytree('/content/mlruns',
                             self.config['paths']['mlruns'] + '_' + str(self.global_step))
                except Exception as e:
                    # Best effort: a failed backup must not abort training.
                    print('Exception {} on step {}'.format(e, self.global_step))

        # validate() switched the model to eval mode.
        self.model.train()
        torch.cuda.empty_cache()

    def validate(self, validation_dataloader, to_load=False, max_batches=10001):
        """Evaluate the model on ``validation_dataloader`` without gradients.

        Args:
            validation_dataloader: yields ``(input_ids, attention_mask,
                token_type_ids, labels)`` batches.
            to_load: reload the checkpoint before evaluating.
            max_batches: evaluate at most this many batches; ``None`` removes
                the cap.

        Returns:
            ``(avg_loss, trues, predicts)`` where ``avg_loss`` is the mean of
            the per-batch losses (NaN when no batch was evaluated) and
            ``trues`` / ``predicts`` are flat lists of class ids.
        """
        if to_load:
            self.load_model()
        t1 = time()
        self.model.eval()
        device = self.config['train']['device']
        predicts, trues = [], []
        nrof_steps, val_loss = 0, 0.

        with torch.no_grad():
            for i, batch in tqdm(enumerate(validation_dataloader)):
                if max_batches is not None and i >= max_batches:
                    break
                b_input_ids, b_input_mask, b_token_type_ids, b_labels = (t.to(device) for t in batch)

                output = self.model(b_input_ids,
                                    attention_mask=b_input_mask,
                                    token_type_ids=b_token_type_ids,
                                    labels=b_labels,
                                    return_dict=True)

                val_loss += output.loss.item()
                predicted = output.logits.argmax(dim=-1)
                predicts.extend(predicted.detach().cpu().tolist())
                trues.extend(b_labels.detach().cpu().tolist())
                nrof_steps += 1

        # An empty loader yields NaN instead of a ZeroDivisionError; NaN never
        # compares lower than the best loss, so it cannot trigger a checkpoint.
        avg_val_loss = val_loss / nrof_steps if nrof_steps else float('nan')
        validation_time = (time() - t1)

        print("  Validation took: {:}".format(validation_time))

        return avg_val_loss, trues, predicts


In [None]:
# Seed Python's `random`, NumPy and PyTorch (CPU and all CUDA devices) before
# building the loaders; scene-pair construction draws from `random`.
seed_val = 42
random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)

tr_loader, val_loader, tst_loader = get_data_loaders()

In [None]:
# Number of batches in the training loader.
print(len(tr_loader))

In [None]:
# Instantiate the trainer: loads the pretrained model, builds the optimizer
# and moves the model to the configured device.
T = Train(config)

In [None]:
# Evaluate on the test loader; to_load=True restores the saved checkpoint first.
avg_tst_loss, tst_trues, tst_predicts = T.validate(tst_loader, to_load=True)

In [None]:
# Balanced accuracy of the test predictions.
balanced_accuracy_score(tst_trues, tst_predicts)

In [None]:
# Confusion matrix of the TEST predictions computed above. The cell used
# `trues` / `predicts`, which are only defined further down (NameError on a
# fresh kernel, training-set results otherwise). Class names are passed in
# label order: 0 = not next, 1 = next.
plot_conf_matr(tst_trues, tst_predicts, title='bert_nsp_test_', classes_names=['not next', 'next'])

In [None]:
# Evaluate the saved checkpoint on the training loader (reloads the checkpoint first).
avg_tr_loss, trues, predicts = T.validate(tr_loader, to_load=True)

In [None]:
# Mean per-batch loss over the evaluated training batches.
print(avg_tr_loss)

In [None]:
# Balanced accuracy of the training-set predictions.
balanced_accuracy_score(trues, predicts)

In [None]:
# Confusion matrix of the training-set predictions. Class names are passed in
# label order (0 = not next, 1 = next), matching the call made during
# training; the function's default order would swap the axis labels.
plot_conf_matr(trues, predicts, title='bert_nsp_train_', classes_names=['not next', 'next'])

In [None]:
# Re-seed all random number generators, then train inside an MLflow run.
seed_val = 42
random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)

# NOTE(review): the hard-coded run_id presumably resumes an earlier run, so
# that run must already exist in the local MLflow store — confirm the store
# has been restored before running this cell, or drop run_id to start a new run.
with mlflow.start_run(run_name=config['train']['exp_name'], run_id='b0dbc31ddd184269b188097712a888a0'):
    # Loaders and trainer are rebuilt here, so this cell does not depend on
    # the evaluation cells above.
    tr_loader, val_loader, tst_loader = get_data_loaders()
    T = Train(config)
    T.train(tr_loader, val_loader)

In [None]:
import shutil

# Remove the local MLflow store if there is one. The existence check keeps
# the cell re-runnable: rmtree raises when the directory is missing.
if os.path.isdir('/content/mlruns'):
    shutil.rmtree('/content/mlruns')

In [None]:
# Copy an MLflow store snapshot from Drive into the working directory
# (presumably the copy made at global step 800).
# NOTE(review): this creates ./mlruns_800, whereas `mlflow ui` reads ./mlruns
# by default — confirm the directory is renamed or the UI is pointed at it.
!cp -r '/content/drive/MyDrive/Movie scripts models/BERTNSP/mlruns_800' . 

In [None]:
from getpass import getpass

get_ipython().system_raw("mlflow ui --port 5000 &") # run tracking UI in the background

# Keep the token out of the notebook source: take it from the environment,
# otherwise prompt for it without echoing.
NGROK_AUTH_TOKEN = os.environ.get("NGROK_AUTH_TOKEN") or getpass("ngrok auth token: ")
ngrok.set_auth_token(NGROK_AUTH_TOKEN)

# Open an HTTPs tunnel on port 5000 for http://localhost:5000
ngrok_tunnel = ngrok.connect(addr="5000", proto="http", bind_tls=True)
print("MLflow Tracking UI:", ngrok_tunnel.public_url)

In [None]:
# Stop ngrok once the tracking UI is no longer needed.
ngrok.kill()