In [1]:
import random
from tqdm import tqdm
import time
import pickle
from itertools import chain
from tqdm import tqdm
from transformers import DistilBertTokenizer, DistilBertModel

import sys
sys.path.insert(0, '../src/models/')
sys.path.insert(0, '../src/features/')
#sys.path.insert(0, '../src/visualization/')

from predict_model import loadBERT
from predict_model import SpanPredictor as classify
from build_features import text_cleaner
from build_features import get_prediction_results
#import visualize as vis

%matplotlib inline

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_layer_norm.bias', 'vocab_layer_norm.weight', 'vocab_transform.bias', 'vocab_transform.weight', 'vocab_projector.weight', 'vocab_projector.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## Load the model, tokenizer and external datasets

In [2]:
# Load the classifier through the project helper `loadBERT`.
# NOTE(review): the arguments are presumably (weights directory, weights file name) -
# confirm against predict_model.loadBERT.
model = loadBERT("../models/", 'saved_weights_inf_FIXED_boot_beta80.pt')
# Tokenizer for the 'distilbert-base-uncased' checkpoint.
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

CPU Success


In [3]:
# Load the two pickled external test sets.
# SECURITY: pickle.load can execute arbitrary code; only load files from a trusted source.
# `with` guarantees the file handles are closed, even if unpickling fails.
with open('../data/processed/testing_web_trees_agroforestry.pkl', 'rb') as f:
    test_agro = pickle.load(f)
with open('../data/processed/testing_web_trees_llifle.pkl', 'rb') as f:
    test_llifle = pickle.load(f)

# Merge into a fresh dict so `test_agro` itself is not mutated.
# On duplicate keys the entries from `test_llifle` replace those from `test_agro`.
data = dict(test_agro)
data.update(test_llifle)

## Get uncertain labels

In [6]:
# Flatten the dict: concatenate every value of `data` into one flat list of entries
external_data = [entry for entries in data.values() for entry in entries]

In [7]:
# Bounds (exclusive) of the score band around 0.5 that counts as "uncertain".
UNCERTAIN_LOW, UNCERTAIN_HIGH = 0.45, 0.55

# A dict is used as an insertion-ordered set: duplicates are dropped while the
# first-seen order is kept. list(set(...)) would also de-duplicate, but its
# order changes between kernel sessions (string hashing is randomised), which
# makes the saved list and any preview of it non-reproducible.
uncertain_seen = {}

# Loop over the (label, span) entries of the external data
for (label, span) in tqdm(external_data):
    # Clean the span; the result is iterated sentence by sentence
    sentences = text_cleaner(span)
    for sent in sentences:
        # Get the predicted label together with the prediction values
        (pred_label, pred_value) = classify(sent, model=model, pred_values=True)
        # Take the value at index 1 and convert it to a plain Python number
        pred_value = pred_value[1].numpy().item()
        if UNCERTAIN_LOW < pred_value < UNCERTAIN_HIGH:
            # Key only; the dict value is irrelevant
            uncertain_seen[(pred_label, pred_value, sent)] = None

# Unique (pred_label, pred_value, sentence) tuples in first-seen order
uncertain_list = list(uncertain_seen)

100%|██████████████████████████████████| 17204/17204 [59:57<00:00,  4.78it/s]


In [6]:
# Persist the uncertain sentences to disk as a pickle
with open('../data/processed/uncertain_sentences.pkl', 'wb') as handle:
    pickle.dump(uncertain_list, handle)

In [8]:
# Print how many uncertain sentences were collected, then display the first 30
print(len(uncertain_list))
uncertain_list[:30]

546


[(0,
  0.49759429693222046,
  'Biology Individual trees often flower irregularly, some trees do not flower for periods of 10-20 years, sometimes even longer.'),
 (0,
  0.48895519971847534,
  'The bat-adapted fruits with strong musty odour and colour duller than those of bird-dispersed ones.'),
 (1,
  0.5126233100891113,
  'The fruits often explosively dehisce dispersing their seeds.'),
 (0, 0.4915165305137634, 'Fruit usual for the Fabaceae family.'),
 (0, 0.45987844467163086, 'Timber: Sapwood is whitish and heartwood brown.'),
 (0,
  0.45801159739494324,
  'Individual trees growing on steep slopes tend to develop some buttressing that extends from the roots up into the base of the stem.'),
 (0,
  0.4520483613014221,
  'The 1st fruits are borne 3-4 years after flowering, and trees do not start to bear fruit until they are 15-20 years old.'),
 (1,
  0.5360655188560486,
  'Some branch profusely and others are more solitary.'),
 (1,
  0.5083258748054504,
  'Description: Mainly deciduous fa

## Prediction on external dataset

In [4]:
# Run the project helper over the merged external data; it returns a report
# and the misclassified sentences.
report, misclassified_sents = get_prediction_results(
    data,
    model=model,
    soft_error=True,
    beta=0.70,
)

100%|██████████████████████████████████| 10602/10602 [55:45<00:00,  3.17it/s]


In [11]:
# Show the report returned by get_prediction_results
print(report)

              precision    recall  f1-score   support

         0.0       0.99      0.97      0.98     60437
         1.0       0.79      0.92      0.85      7019

    accuracy                           0.97     67456
   macro avg       0.89      0.94      0.92     67456
weighted avg       0.97      0.97      0.97     67456



In [5]:
# Show the report again.
# NOTE(review): the saved output of this cell differs from the previous print(report)
# cell and the execution counts are out of order, so the two outputs presumably come
# from different runs - confirm which one is current and drop the stale cell.
print(report)

              precision    recall  f1-score   support

         0.0       0.99      0.98      0.99     60198
         1.0       0.85      0.95      0.90      7310

    accuracy                           0.98     67508
   macro avg       0.92      0.96      0.94     67508
weighted avg       0.98      0.98      0.98     67508

