In [1]:
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py -q
!python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev
!export XLA_USE_BF16=1

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  5116  100  5116    0     0  21677      0 --:--:-- --:--:-- --:--:-- 21677
Updating... This may take around 2 minutes.
Updating TPU runtime to pytorch-nightly ...
Collecting cloud-tpu-client
  Downloading https://files.pythonhosted.org/packages/56/9f/7b1958c2886db06feb5de5b2c191096f9e619914b6c31fdf93999fdbbd8b/cloud_tpu_client-0.10-py3-none-any.whl
Collecting google-api-python-client==1.8.0
[?25l  Downloading https://files.pythonhosted.org/packages/9a/b4/a955f393b838bc47cbb6ae4643b9d0f90333d3b4db4dc1e819f36aad18cc/google_api_python_client-1.8.0-py3-none-any.whl (57kB)
[K     |████████████████████████████████| 61kB 2.8MB/s 
Uninstalling torch-1.6.0+cu101:
Installing collected packages: google-api-python-client, cloud-tpu-client
  Found existing installation: google-api-python-client 1.7.12
    Uninstalling google-api-python-c

In [2]:
from google.colab import drive
drive.mount('/gdrive')

Mounted at /gdrive


In [3]:
! pip install transformers -q
! pip install tokenizers -q

[K     |████████████████████████████████| 1.1MB 3.4MB/s 
[K     |████████████████████████████████| 3.0MB 48.9MB/s 
[K     |████████████████████████████████| 1.1MB 11.4MB/s 
[K     |████████████████████████████████| 890kB 46.7MB/s 
[?25h  Building wheel for sacremoses (setup.py) ... [?25l[?25hdone


In [4]:
import os
import re
import json
import ast
import pandas as pd
from pathlib import Path
import matplotlib.cm as cm
import numpy as np
import pandas as pd
from typing import *
from tqdm.notebook import tqdm
from sklearn.utils.extmath import softmax
from sklearn import model_selection
from sklearn.metrics import classification_report, f1_score

In [5]:
import torch
import torch.optim as optim
import torch.nn as nn
import transformers
from transformers import AdamW
import tokenizers

In [6]:
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

from joblib import Parallel, delayed

In [7]:
def seed_all(seed = 42):
  """
  Fix seed for reproducibility
  """
  # python RNG
  import random
  random.seed(seed)

  # pytorch RNGs
  import torch
  torch.manual_seed(seed)
  torch.backends.cudnn.deterministic = True
  if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)

  # numpy RNG
  import numpy as np
  np.random.seed(seed)

In [8]:
import os
os.chdir('/gdrive/My Drive/SDU_2/')

In [9]:
# from transformers import BertModel, BertTokenizer
# m = BertModel.from_pretrained('allenai/scibert_scivocab_uncased')
# t = BertTokenizer.from_pretrained('allenai/scibert_scivocab_uncased')

In [10]:
# ! mkdir scibert_scivocab_uncased
# m.save_pretrained('./scibert_scivocab_uncased')
# t.save_vocabulary('./scibert_scivocab_uncased')

In [11]:
class config:
  SEED = 42
  KFOLD = 5
  TRAIN_FILE = './data/train.csv'
  VAL_FILE = './data/dev.csv'
  TEST_FILE = './data/dev.csv'
  SAVE_DIR = 'run_scibert_tpu'
  MAX_LEN = 192
  MODEL = './scibert_scivocab_uncased'
  TOKENIZER = tokenizers.BertWordPieceTokenizer(f"{MODEL}/vocab.txt", lowercase=True)
  EPOCHS = 5
  TRAIN_BATCH_SIZE = 32
  VALID_BATCH_SIZE = 32
  DICTIONARY = json.load(open('./data/diction.json'))
  
  A2ID = {}
  for k, v in DICTIONARY.items():
    for w in v:
      A2ID[w] = len(A2ID)


In [12]:
class AverageMeter:
    """
    Computes and stores the average and current value
    Source : https://www.kaggle.com/abhishek/bert-base-uncased-using-pytorch/
    """
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [13]:
class EarlyStopping:
    """
    Early stopping utility
    Source : https://www.kaggle.com/abhishek/bert-base-uncased-using-pytorch/
    """
    
    def __init__(self, patience=7, mode="max", delta=0.001):
        self.patience = patience
        self.counter = 0
        self.mode = mode
        self.best_score = None
        self.early_stop = False
        self.delta = delta
        if self.mode == "min":
            self.val_score = np.Inf
        else:
            self.val_score = -np.Inf

    def __call__(self, epoch_score, model, model_path):
        if self.mode == "min":
            score = -1.0 * epoch_score
        else:
            score = np.copy(epoch_score)
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(epoch_score, model, model_path)
        elif score < self.best_score + self.delta:
            self.counter += 1
            print('EarlyStopping counter: {} out of {}'.format(self.counter, self.patience))
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(epoch_score, model, model_path)
            self.counter = 0

    def save_checkpoint(self, epoch_score, model, model_path):
        if epoch_score not in [-np.inf, np.inf, -np.nan, np.nan]:
            print('Validation score improved ({} --> {}). Saving model!'.format(self.val_score, epoch_score))
            xm.save(model.state_dict(), model_path)
        self.val_score = epoch_score

In [14]:
import re
def fix_latex(text):
  key_words = ['figure', 'fig', 'tabular', 'table', '&\d*\?\d*', '&', r'[\s]c*[\s]'
               r'minipage\d.\d*', 'center', r'[\(\[].*?[\)\]]', 
               r'align.*align', r'^\d{1,6}(\.\d{1,2})?$']

  text = re.sub(r'(\/.*?\.[\w:]+)', '', text)
  for kw in key_words:
    text = re.sub(kw, '', text)
  
  text = ' '.join(text.split())

  return text

In [15]:
def sample_text(text, acronym, max_len):
  text = text.split()
  idx = text.index(acronym)
  left_idx = max(0, idx - max_len//2)
  right_idx = min(len(text), idx + max_len//2)
  sampled_text = text[left_idx:right_idx]
  return ' '.join(sampled_text)

In [16]:
def process_data(text, acronym, expansion, tokenizer, max_len):

  text = str(text)
  expansion = str(expansion)
  acronym = str(acronym)

  n_tokens = len(text.split())
  if n_tokens>80:
    text = sample_text(text, acronym, 80)

  answers = ' or '.join(config.DICTIONARY[acronym])
  start = answers.find(expansion)
  end = start + len(expansion)

  char_mask = [0]*len(answers)
  for i in range(start, end):
    char_mask[i] = 1
  
  tok_answer = tokenizer.encode(answers)
  answer_ids = tok_answer.ids
  answer_offsets = tok_answer.offsets

  answer_ids = answer_ids[1:-1]
  answer_offsets = answer_offsets[1:-1]

  target_idx = []
  for i, (off1, off2) in enumerate(answer_offsets):
      if sum(char_mask[off1:off2])>0:
        target_idx.append(i)

  start = target_idx[0]
  end = target_idx[-1]

  
  text_ids = tokenizer.encode(text).ids[1:-1]

  token_ids = [101] + answer_ids + [102] + text_ids + [102]
  offsets =   [(0,0)] + answer_offsets + [(0,0)]*(len(text_ids) + 2)
  mask = [1] * len(token_ids)
  token_type = [0]*(len(answer_ids) + 1) + [1]*(2+len(text_ids))

  text = answers + text
  start = start + 1
  end = end + 1

  padding = max_len - len(token_ids)
    

  if padding>=0:
    token_ids = token_ids + ([0] * padding)
    token_type = token_type + [1] * padding
    mask = mask + ([0] * padding)
    offsets = offsets + ([(0, 0)] * padding)
  else:
    token_ids = token_ids[0:max_len]
    token_type = token_type[0:max_len]
    mask = mask[0:max_len]
    offsets = offsets[0:max_len]
  

  assert len(token_ids)==max_len
  assert len(mask)==max_len
  assert len(offsets)==max_len
  assert len(token_type)==max_len

  return {
          'ids': token_ids,
          'mask': mask,
          'token_type': token_type,
          'offset': offsets,
          'start': start,
          'end': end,  
          'text': text,
          'expansion': expansion,
          'acronym': acronym,
        }

In [17]:
class Dataset:
    def __init__(self, text, acronym, expansion):
        self.text = text
        self.acronym = acronym
        self.expansion = expansion
        self.tokenizer = config.TOKENIZER
        self.max_len = config.MAX_LEN
    
    def __len__(self):
        return len(self.text)

    def __getitem__(self, item):
        data = process_data(
            self.text[item],
            self.acronym[item],
            self.expansion[item], 
            self.tokenizer,
            self.max_len,
            
        )

        return {
            'ids': torch.tensor(data['ids'], dtype=torch.long),
            'mask': torch.tensor(data['mask'], dtype=torch.long),
            'token_type': torch.tensor(data['token_type'], dtype=torch.long),
            'offset': torch.tensor(data['offset'], dtype=torch.long),
            'start': torch.tensor(data['start'], dtype=torch.long),
            'end': torch.tensor(data['end'], dtype=torch.long),
            'text': data['text'],
            'expansion': data['expansion'],
            'acronym': data['acronym'],
        }

In [18]:
def get_loss(start, start_logits, end, end_logits):
  loss_fn = nn.CrossEntropyLoss()
  start_loss = loss_fn(start_logits, start)
  end_loss = loss_fn(end_logits, end)
  loss = start_loss + end_loss
  return loss

In [19]:
class BertAD(nn.Module):
  def __init__(self):
    super(BertAD, self).__init__()
    self.bert = transformers.BertModel.from_pretrained(config.MODEL, output_hidden_states=True)
    self.layer = nn.Linear(768, 2)
    

  def forward(self, ids, mask, token_type, start=None, end=None):
    output = self.bert(input_ids = ids,
                       attention_mask = mask,
                       token_type_ids = token_type)
    
    logits = self.layer(output[0]) 
    start_logits, end_logits = logits.split(1, dim=-1)
    
    start_logits = start_logits.squeeze(-1)
    end_logits = end_logits.squeeze(-1)

    loss = get_loss(start, start_logits, end, end_logits)    

    return loss, start_logits, end_logits

In [20]:
def train_fn(data_loader, model, optimizer, device, fold):
  model.train()
  losses = AverageMeter()
  tk0 = tqdm(data_loader, total=len(data_loader))
  
  for bi, d in enumerate(tk0):
    ids = d['ids']
    mask = d['mask']
    token_type = d['token_type']
    start = d['start']
    end = d['end']
    

    ids = ids.to(device, dtype=torch.long)
    token_type = token_type.to(device, dtype=torch.long)
    mask = mask.to(device, dtype=torch.long)
    start = start.to(device, dtype=torch.long)
    end = end.to(device, dtype=torch.long)
    

    model.zero_grad()
    loss, start_logits, end_logits = model(ids, mask, token_type, start, end)
    
    loss.backward()
    xm.optimizer_step(optimizer, barrier=True)
    
    losses.update(loss.item(), ids.size(0))
    tk0.set_postfix(fold=fold, loss=losses.avg)


In [21]:
def jaccard(str1, str2): 
    a = set(str1.lower().split()) 
    b = set(str2.lower().split())
    c = a.intersection(b)
    return float(len(c)) / (len(a) + len(b) - len(c))

In [22]:
def evaluate_jaccard(text, selected_text, acronym, offsets, idx_start, idx_end):
  filtered_output = ""
  for ix in range(idx_start, idx_end + 1):
      filtered_output += text[offsets[ix][0]: offsets[ix][1]]
      if (ix+1) < len(offsets) and offsets[ix][1] < offsets[ix+1][0]:
          filtered_output += " "

  candidates = config.DICTIONARY[acronym]
  candidate_jaccards = [jaccard(w.strip(), filtered_output.strip()) for w in candidates]
  idx = np.argmax(candidate_jaccards)

  return candidate_jaccards[idx], candidates[idx]

In [23]:
def eval_fn(data_loader, model, device, fold):
  model.eval()
  losses = AverageMeter()
  jac = AverageMeter()

  tk0 = tqdm(data_loader, total=len(data_loader))

  pred_expansion_ = []
  true_expansion_ = []

  for bi, d in enumerate(tk0):
    ids = d['ids']
    mask = d['mask']
    token_type = d['token_type']
    start = d['start']
    end = d['end']
    
    text = d['text']
    expansion = d['expansion']
    offset = d['offset']
    acronym = d['acronym']


    ids = ids.to(device, dtype=torch.long)
    mask = mask.to(device, dtype=torch.long)
    token_type = token_type.to(device, dtype=torch.long)
    start = start.to(device, dtype=torch.long)
    end = end.to(device, dtype=torch.long)
    
    with torch.no_grad():
      loss, start_logits, end_logits = model(ids, mask, token_type, start, end)


    start_prob = torch.softmax(start_logits, dim=1).detach().cpu().numpy()
    end_prob = torch.softmax(end_logits, dim=1).detach().cpu().numpy()
  
  
    jac_= []
    
    for px, s in enumerate(text):
      start_idx = np.argmax(start_prob[px,:])
      end_idx = np.argmax(end_prob[px,:])

      js, exp = evaluate_jaccard(s, expansion[px], acronym[px], offset[px], start_idx, end_idx)
      jac_.append(js)
      pred_expansion_.append(exp)
      true_expansion_.append(expansion[px])

    
    jac.update(np.mean(jac_), len(jac_))
    losses.update(loss.item(), ids.size(0))

    tk0.set_postfix(fold=fold, loss=losses.avg, jaccard=jac.avg)


  pred_expansion_ = [config.A2ID[w] for w in pred_expansion_]
  true_expansion_ = [config.A2ID[w] for w in true_expansion_]
  
  f1 = f1_score(true_expansion_, pred_expansion_, average='macro')

  print(f'Fold: {fold} | Average Jaccard: {jac.avg} | Macro F1: {f1}')

  return f1 
  

In [24]:
seed_all()

df_train = pd.read_csv(config.TRAIN_FILE)
df_val = pd.read_csv(config.VAL_FILE)

# concatenating train and validation set
train = pd.concat([df_train, df_val]).reset_index()

# train = df_train

# dividing folds
kf = model_selection.StratifiedKFold(n_splits=config.KFOLD, shuffle=False, random_state=config.SEED)
for fold, (train_idx, val_idx) in enumerate(kf.split(X=train, y=train.acronym_.values)):
    train.loc[val_idx, 'kfold'] = fold



In [25]:
def run(fold):
  df_train = train[train.kfold!=fold].reset_index(drop=True)
  df_val = train[train.kfold==fold].reset_index(drop=True)

  train_dataset = Dataset(
        text = df_train.text.values,
        acronym = df_train.acronym_.values,
        expansion = df_train.expansion.values
    )
  
  valid_dataset = Dataset(
        text = df_val.text.values,
        acronym = df_val.acronym_.values,
        expansion = df_val.expansion.values,
    )
    
  train_data_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.TRAIN_BATCH_SIZE,
        num_workers=4
    )

  valid_data_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=config.VALID_BATCH_SIZE,
        num_workers=2
    )
  

  model = BertAD()
  device = xm.xla_device(fold + 1)
  model.to(device)

  lr = 2e-5
  param_optimizer = list(model.named_parameters())
  no_decay = ['bias', 'gamma', 'beta']
  optimizer_grouped_parameters = [
      {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
      'weight_decay_rate': 0.01},
      {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
      'weight_decay_rate': 0.0}
  ]
  optimizer = AdamW(optimizer_grouped_parameters, lr=lr)

  es = EarlyStopping(patience=2, mode="max")

  for epoch in range(config.EPOCHS):
    train_fn(train_data_loader, model, optimizer, device, fold)
    valid_loss = eval_fn(valid_data_loader, model, device, fold)
    print(f'Fold: {fold} | Epoch: {epoch + 1} | Validation Score: {valid_loss}')
    if fold is None:
      es(valid_loss, model, model_path=os.path.join(config.SAVE_DIR, "model.bin"))
    else:
      es(valid_loss, model, model_path=os.path.join(config.SAVE_DIR, f"model_{fold}.bin"))

  return

In [26]:
! rm -rf {config.SAVE_DIR} && mkdir {config.SAVE_DIR}
Parallel(n_jobs=5, backend='threading')(delayed(run)(i) for i in range(5))

HBox(children=(FloatProgress(value=0.0, max=1406.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1406.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1406.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1406.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1406.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=352.0), HTML(value='')))





HBox(children=(FloatProgress(value=0.0, max=352.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=352.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=352.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=352.0), HTML(value='')))




Fold: 2 | Average Jaccard: 0.9857140629477151 | Macro F1: 0.7740237279873928
Fold: 2 | Epoch: 1 | Validation Score: 0.7740237279873928
Validation score improved (-inf --> 0.7740237279873928). Saving model!
Fold: 1 | Average Jaccard: 0.9881758968382971 | Macro F1: 0.7675582845942754
Fold: 1 | Epoch: 1 | Validation Score: 0.7675582845942754
Validation score improved (-inf --> 0.7675582845942754). Saving model!
Fold: 0 | Average Jaccard: 0.9880115893855339 | Macro F1: 0.7626775132345954
Fold: 0 | Epoch: 1 | Validation Score: 0.7626775132345954
Validation score improved (-inf --> 0.7626775132345954). Saving model!



Fold: 3 | Average Jaccard: 0.9873871259476907 | Macro F1: 0.7757277630053632
Fold: 3 | Epoch: 1 | Validation Score: 0.7757277630053632
Validation score improved (-inf --> 0.7757277630053632). Saving model!
Fold: 4 | Average Jaccard: 0.9879306217333278 | Macro F1: 0.7623925330475299
Fold: 4 | Epoch: 1 | Validation Score: 0.7623925330475299
Validation score improved (-inf -->

HBox(children=(FloatProgress(value=0.0, max=1406.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1406.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1406.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1406.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1406.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=352.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=352.0), HTML(value='')))





HBox(children=(FloatProgress(value=0.0, max=352.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=352.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=352.0), HTML(value='')))


Fold: 1 | Average Jaccard: 0.9922944214962174 | Macro F1: 0.8074059662641452
Fold: 1 | Epoch: 2 | Validation Score: 0.8074059662641452
Validation score improved (0.7675582845942754 --> 0.8074059662641452). Saving model!

Fold: 2 | Average Jaccard: 0.9897981857573445 | Macro F1: 0.8108169648020881
Fold: 2 | Epoch: 2 | Validation Score: 0.8108169648020881
Validation score improved (0.7740237279873928 --> 0.8108169648020881). Saving model!

Fold: 4 | Average Jaccard: 0.9890868089774176 | Macro F1: 0.8146346606569871
Fold: 4 | Epoch: 2 | Validation Score: 0.8146346606569871
Validation score improved (0.7623925330475299 --> 0.8146346606569871). Saving model!


HBox(children=(FloatProgress(value=0.0, max=1406.0), HTML(value='')))


Fold: 0 | Average Jaccard: 0.9908894318918097 | Macro F1: 0.8062484188598154
Fold: 0 | Epoch: 2 | Validation Score: 0.8062484188598154
Validation score improved (0.7626775132345954 --> 0.8062484188598154). Saving model!


HBox(children=(FloatProgress(value=0.0, max=1406.0), HTML(value='')))


Fold: 3 | Average Jaccard: 0.9906187440215137 | Macro F1: 0.8062770275504659
Fold: 3 | Epoch: 2 | Validation Score: 0.8062770275504659
Validation score improved (0.7757277630053632 --> 0.8062770275504659). Saving model!


HBox(children=(FloatProgress(value=0.0, max=1406.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1406.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1406.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=352.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=352.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=352.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=352.0), HTML(value='')))


Fold: 1 | Average Jaccard: 0.9907155265877754 | Macro F1: 0.8095744071152039
Fold: 1 | Epoch: 3 | Validation Score: 0.8095744071152039
Validation score improved (0.8074059662641452 --> 0.8095744071152039). Saving model!

Fold: 2 | Average Jaccard: 0.9893415610408669 | Macro F1: 0.8312211126497563
Fold: 2 | Epoch: 3 | Validation Score: 0.8312211126497563
Validation score improved (0.8108169648020881 --> 0.8312211126497563). Saving model!


HBox(children=(FloatProgress(value=0.0, max=1406.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1406.0), HTML(value='')))


Fold: 0 | Average Jaccard: 0.989665439416633 | Macro F1: 0.8216443851467474
Fold: 0 | Epoch: 3 | Validation Score: 0.8216443851467474
Validation score improved (0.8062484188598154 --> 0.8216443851467474). Saving model!

Fold: 4 | Average Jaccard: 0.9929161363352352 | Macro F1: 0.8232026211744651
Fold: 4 | Epoch: 3 | Validation Score: 0.8232026211744651
Validation score improved (0.8146346606569871 --> 0.8232026211744651). Saving model!


HBox(children=(FloatProgress(value=0.0, max=1406.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1406.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=352.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=352.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=352.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=352.0), HTML(value='')))


Fold: 1 | Average Jaccard: 0.9907199863210885 | Macro F1: 0.8209650460760191
Fold: 1 | Epoch: 4 | Validation Score: 0.8209650460760191
Validation score improved (0.8095744071152039 --> 0.8209650460760191). Saving model!

Fold: 2 | Average Jaccard: 0.9936331392534148 | Macro F1: 0.8379249612747134
Fold: 2 | Epoch: 4 | Validation Score: 0.8379249612747134
Validation score improved (0.8312211126497563 --> 0.8379249612747134). Saving model!


HBox(children=(FloatProgress(value=0.0, max=1406.0), HTML(value='')))


Fold: 0 | Average Jaccard: 0.9905155195513227 | Macro F1: 0.824114801866553
Fold: 0 | Epoch: 4 | Validation Score: 0.824114801866553
Validation score improved (0.8216443851467474 --> 0.824114801866553). Saving model!


HBox(children=(FloatProgress(value=0.0, max=1406.0), HTML(value='')))


Fold: 4 | Average Jaccard: 0.9942213926510749 | Macro F1: 0.8294489255267483
Fold: 4 | Epoch: 4 | Validation Score: 0.8294489255267483
Validation score improved (0.8232026211744651 --> 0.8294489255267483). Saving model!


HBox(children=(FloatProgress(value=0.0, max=1406.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1406.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=352.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=352.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=352.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=352.0), HTML(value='')))


Fold: 2 | Average Jaccard: 0.992123008646142 | Macro F1: 0.8365930230014893
Fold: 2 | Epoch: 5 | Validation Score: 0.8365930230014893
EarlyStopping counter: 1 out of 2

Fold: 1 | Average Jaccard: 0.9920769226838477 | Macro F1: 0.8287452529121295
Fold: 1 | Epoch: 5 | Validation Score: 0.8287452529121295
Validation score improved (0.8209650460760191 --> 0.8287452529121295). Saving model!

Fold: 4 | Average Jaccard: 0.9930170014174237 | Macro F1: 0.8306159772773395
Fold: 4 | Epoch: 5 | Validation Score: 0.8306159772773395
Validation score improved (0.8294489255267483 --> 0.8306159772773395). Saving model!

Fold: 0 | Average Jaccard: 0.9939139650872173 | Macro F1: 0.8303941424104928
Fold: 0 | Epoch: 5 | Validation Score: 0.8303941424104928
Validation score improved (0.824114801866553 --> 0.8303941424104928). Saving model!


KeyboardInterrupt: ignored

In [27]:
run(3)

HBox(children=(FloatProgress(value=0.0, max=1406.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=352.0), HTML(value='')))


Fold: 3 | Average Jaccard: 0.9878518471615181 | Macro F1: 0.7748265797680566
Fold: 3 | Epoch: 1 | Validation Score: 0.7748265797680566
Validation score improved (-inf --> 0.7748265797680566). Saving model!


HBox(children=(FloatProgress(value=0.0, max=1406.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=352.0), HTML(value='')))


Fold: 3 | Average Jaccard: 0.9879885568256174 | Macro F1: 0.807583219657655
Fold: 3 | Epoch: 2 | Validation Score: 0.807583219657655
Validation score improved (0.7748265797680566 --> 0.807583219657655). Saving model!


HBox(children=(FloatProgress(value=0.0, max=1406.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=352.0), HTML(value='')))


Fold: 3 | Average Jaccard: 0.9904137677967776 | Macro F1: 0.8288675885729078
Fold: 3 | Epoch: 3 | Validation Score: 0.8288675885729078
Validation score improved (0.807583219657655 --> 0.8288675885729078). Saving model!


HBox(children=(FloatProgress(value=0.0, max=1406.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=352.0), HTML(value='')))


Fold: 3 | Average Jaccard: 0.9886843204560987 | Macro F1: 0.8258755898204555
Fold: 3 | Epoch: 4 | Validation Score: 0.8258755898204555
EarlyStopping counter: 1 out of 2


HBox(children=(FloatProgress(value=0.0, max=1406.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=352.0), HTML(value='')))


Fold: 3 | Average Jaccard: 0.9917025878399478 | Macro F1: 0.8395925653294491
Fold: 3 | Epoch: 5 | Validation Score: 0.8395925653294491
Validation score improved (0.8288675885729078 --> 0.8395925653294491). Saving model!
