In [134]:
import numpy as np
import pandas as pd
import torch
import pickle
import torch.nn as nn
import glob
import transformers
from bs4 import BeautifulSoup
import requests
import re
import random
import time
import matplotlib
import matplotlib.pyplot as plt
from collections import defaultdict
import pdfplumber
from tqdm import tqdm
import collections
from statistics import mean
from selenium import webdriver
from transformers import DistilBertTokenizer, DistilBertModel
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler, Dataset

import spacy
# spaCy transformer pipeline, loaded without the tagger, NER and lemmatizer
# components (they are excluded at load time).
nlp = spacy.load("en_core_web_trf", exclude=["tagger", "ner", "lemmatizer", ])

import sys
# Put the project's source folders at the front of the module search path so
# the project-local `predict_model` module can be imported below.
# NOTE(review): these are relative paths, so the notebook presumably has to be
# run from its own directory - confirm.
sys.path.insert(0, '../src/models/')
sys.path.insert(0, '../src/features/')
import predict_model
#from build_features import random_text_splitter as split_text

# Load BERT
# NOTE(review): loadBERT is a project helper not visible here; it is called
# with the weights directory and the weights file name - confirm what it
# returns and whether it puts the model in eval mode.
model = predict_model.loadBERT("../models/", 'saved_weights_inf_FIXED.pt')
# Load the BERT tokenizer
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

%matplotlib inline

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_transform.weight', 'vocab_projector.bias', 'vocab_transform.bias', 'vocab_layer_norm.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


CPU Success


In [2]:
def text_splitter(text, split_value=20, min_words=10):
    """
    Split a text into consecutive chunks of at most `split_value` words.

    Parameters
    ----------
    text : str
        Text to split; words are delimited by any run of whitespace.
    split_value : int, default 20
        Maximum number of words per chunk. Must be positive.
    min_words : int, default 10
        Texts with fewer words than this are returned unsplit.

    Returns
    -------
    list of str
        `[text]` (the input, untouched) when the text has fewer than
        `min_words` words. Otherwise the chunks in reading order, each
        re-joined with single spaces; every chunk holds `split_value`
        words except possibly the last, which holds the remainder.

    Raises
    ------
    ValueError
        If `split_value` is not positive.
    """
    # A non-positive step can never consume the word list.
    if split_value <= 0:
        raise ValueError("split_value must be a positive integer")

    words = text.split()
    # Very short texts are passed through as-is.
    if len(words) < min_words:
        return [text]

    # The chunk size drives both the slicing and the size of the remainder,
    # so every requested granularity is honoured (not only split_value=20).
    return [' '.join(words[start:start + split_value])
            for start in range(0, len(words), split_value)]

In [3]:
def SpanPredictor(span, pred_values=False):

    """
    Classify a single text span with the notebook-level `model` and
    `tokenizer` (species description vs. anything else).

    Parameters
    ----------
    span : str
        Text to classify; it is tokenized with truncation enabled.
    pred_values : bool, default False
        When true, also return the exponentiated model output for the span.

    Returns
    -------
    int, or (int, torch.Tensor)
        The index of the highest-scoring class, optionally followed by the
        first row of the exponentiated model output.
    """

    # Inference only: no gradients are tracked.
    with torch.no_grad():
        encoded = tokenizer(span, return_tensors="pt", truncation=True)
        raw_output = model(encoded['input_ids'], encoded['attention_mask'])
        # NOTE(review): exponentiating presumably turns log-probabilities
        # into probabilities - confirm against the model's final layer.
        scores = torch.exp(raw_output)
        predicted_class = scores.argmax(1).item()

    if not pred_values:
        return predicted_class
    return predicted_class, scores[0]

In [98]:
def highlight_text(predictions, save=False):

    """
    Draw every text span on its own line, shaded by the value predicted
    for it by the description deep learning BERT model.

    Parameters
    ----------
    predictions : iterable of (str, float)
        Pairs of a text span and its predicted value. The value is mapped
        through the 'Spectral' colormap and compared against 0.5.
    save : bool, default False
        Currently unused; accepted so existing calls keep working.

    Notes
    -----
    Spans of five words or fewer are drawn on a grey background in a
    larger font, whatever their prediction. Longer spans are coloured by
    the colormap and set in bold when their value exceeds 0.5. Lines start
    at y=10 and step down by 0.22 on an axis that spans 0-10.
    """

    # pyplot.get_cmap replaces matplotlib.cm.get_cmap, which was deprecated
    # in Matplotlib 3.7 and removed in 3.9.
    cmap = plt.get_cmap('Spectral')

    # Blank 10x10 canvas without visible axes
    fig = plt.figure(figsize=(16, 16))
    ax = fig.add_subplot()
    ax.axis([0, 10, 0, 10])
    ax.axis('off')

    y_axis_start = 10
    for sent, pred in predictions:

        if len(sent.split()) <= 5:
            # Short spans: neutral grey box, larger font
            ax.text(.5, y_axis_start, sent, ha='left', wrap=True, fontsize=13,
                    bbox={'facecolor': 'grey', 'alpha': 0.6, 'pad': 2})
        else:
            # Bold only above the 0.5 threshold; otherwise leave the font
            # weight at whatever the Matplotlib defaults are.
            emphasis = {'weight': "bold"} if pred > 0.5 else {}
            ax.text(.25, y_axis_start, sent, ha='left', wrap=True,
                    bbox={'facecolor': cmap(pred), 'alpha': 0.6, 'pad': 4},
                    **emphasis)

        # Move down one line for the next span
        y_axis_start -= .22

In [5]:
# Download the species page and parse its HTML.
URL = 'http://db.worldagroforestry.org//species/properties/Enterolobium_cyclocarpum'
page = requests.get(URL, timeout=5)
# Stop here on an HTTP error status instead of silently parsing an error
# page as if it were the species description.
page.raise_for_status()
soup = BeautifulSoup(page.content, 'html.parser')

In [41]:
# Whole page as one string: stripped text fragments joined by single spaces.
text = soup.get_text(" ", strip=True)
# Individual page strings that contain more than five words.
text_list = [string for string in soup.stripped_strings if len(string.split()) > 5]


def _predict_spans(spans):
    """Pair each span with SpanPredictor's (class, values) result for it."""
    return [(span, SpanPredictor(span, pred_values=True)) for span in spans]


# Split the same page text at three granularities ...
text_list_5 = text_splitter(text, split_value=5)
text_list_10 = text_splitter(text, split_value=10)
text_list_20 = text_splitter(text, split_value=20)

# ... and classify every chunk at each granularity.
text_list_5_preds = _predict_spans(text_list_5)
text_list_10_preds = _predict_spans(text_list_10)
text_list_20_preds = _predict_spans(text_list_20)

In [42]:
# Blend the three granularities: every 5-word chunk gets the mean of its own
# class-1 value and the class-1 values of the 10- and 20-word chunks at the
# matching positions.
text_list_preds = []

# The loop variables are deliberately not called `text`, so the full-page
# `text` string from the previous cell is not overwritten.
for idx, (span, (_label, values)) in enumerate(text_list_5_preds):
    # 5-word chunk `idx` is matched with 10-word chunk idx // 2 and 20-word
    # chunk idx // 4 (the same mapping as bumping a counter on every 2nd and
    # 4th iteration). This assumes the three splits line up word for word.
    pred5 = values[1].numpy().item()
    pred10 = text_list_10_preds[idx // 2][1][1][1].numpy().item()
    pred20 = text_list_20_preds[idx // 4][1][1][1].numpy().item()

    text_list_preds.append((span, mean([pred5, pred10, pred20])))

In [114]:
#highlight_text(text_list_preds)

In [108]:
# Fetch and parse the species page for the per-string analysis below.
URL = 'http://db.worldagroforestry.org//species/properties/Enterolobium_cyclocarpum'
page = requests.get(URL, timeout=5)
# An HTTP error status raises here rather than yielding an error page
# that would be split and classified like real content.
page.raise_for_status()
soup = BeautifulSoup(page.content, 'html.parser')

In [109]:
# Keep every stripped string of the page as its own list entry
# (no joining into one text and no minimum-length filter here).
text_list = list(soup.stripped_strings)

In [110]:
# One list of chunks per page string, using text_splitter's default size.
text_list_splitted_20 = list(map(text_splitter, text_list))

# Flatten to (chunk, class-1 value) pairs; the progress bar advances once
# per page string.
span_list_20_predictions = [
    (span, SpanPredictor(span, pred_values=True)[1][1].numpy().item())
    for span_list in tqdm(text_list_splitted_20)
    for span in span_list
]



100%|███████████████████████████████████████████| 43/43 [00:03<00:00, 13.49it/s]


In [135]:
# Run the spaCy pipeline over every page string, one Doc per string.
docs = list(map(nlp, text_list))

In [136]:
# Print each sentence spaCy finds in the page string at index 19, followed
# by the class SpanPredictor assigns to it.
# NOTE(review): 19 is a hard-coded position in `docs`/`text_list` - confirm
# it still points at the intended paragraph if the page content changes.
for sent in docs[19].sents:
    # The spaCy sentence span is converted to a plain string before it is
    # handed to the tokenizer inside SpanPredictor.
    print(sent, SpanPredictor(str(sent)))

The small, white flowers are borne in clusters or heads at the base of the leaves. 1
Flowering takes place in March and April during the regrowth of new leaves after the leafless dry season. 0
There is no indication in the literature as to what age flowers 1st appear. 0
Seed dissemination is mainly by cattle, horses and wild ungulates, attracted by the syrupy pulp of the fruits. 0


In [39]:
# Show the predicted class together with the per-class values for the
# first page string.
SpanPredictor(text_list[0], pred_values=True)

(1, tensor([0.1581, 0.8419]))

In [None]:
# Disabled draft of highlight_text that places spans next to each other,
# left to right, and starts a new row once the x position passes 10.
# It is kept as `#` comments: wrapping it in a triple-quoted string does not
# work, because the draft's own docstring quotes end that string early and
# leave the cell a SyntaxError.
#
# def highlight_text(predictions, save=False):
#
#     """
#     Highlight text based on the predicted value by the
#     description deep learning BERT model.
#     """
#
#
#     cmap = matplotlib.cm.get_cmap('Spectral')
#
#     # Call figure and add axis
#     fig = plt.figure(figsize=(16, 16))
#     ax = fig.add_subplot()
#     # Set axis
#     ax.axis([0, 10, 0, 10])
#     ax.axis('off')
#
#     y_axis_start = 10
#     x_axis_start = 0.0
#     # Loop over the text
#
#     for sent, pred in predictions:
#
#         # Set axis with correct color and align right
#         if pred > 0.5:
#             ax.text(x_axis_start, y_axis_start, sent, ha='left', weight="bold", wrap=True,
#                     bbox={'facecolor': cmap(pred), 'alpha': 0.6, 'pad': 2})
#         else :
#             ax.text(x_axis_start, y_axis_start, sent, ha='left', wrap=True,
#                     bbox={'facecolor': cmap(pred), 'alpha': 0.6, 'pad': 2})
#
#
#         #if x_axis_start
#         # Update axis start
#         if x_axis_start > 10:
#             x_axis_start = 0
#             y_axis_start -= .25
#         else:
#             x_axis_start += len(sent.lower()) / 14