In [1]:
import numpy as np
import pandas as pd
import torch
import pickle
import re
import requests
from matplotlib import cm
import matplotlib
from bs4 import BeautifulSoup
import collections
from itertools import chain
from collections import Counter
import torch.nn as nn
import glob
import random
import string
from pathlib import Path
from transformers import DistilBertTokenizer, DistilBertModel
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler, Dataset, random_split
from tqdm import tqdm
import time
import urllib.parse

import spacy
from spacy import displacy
from spacy.matcher import PhraseMatcher
from spacy.tokens import Span
from spacy.util import filter_spans
nlp = spacy.load("en_core_web_lg")

from torch import cuda
device = 'cuda' if cuda.is_available() else 'cpu'

import sys
sys.path.insert(0, '../src/models/')
sys.path.insert(0, '../src/features/')
#sys.path.insert(0, '../src/visualization/')

import predict_model
from build_features import text_cleaner
#import visualize as vis

%matplotlib inline

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_projector.bias', 'vocab_projector.weight', 'vocab_layer_norm.weight', 'vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [2]:
# Load the uncased DistilBERT tokenizer, the trained span classifier
# (weights read from ../models/) and the sentence-similarity encoder.
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = predict_model.loadBERT("../models/", 'saved_weights_inf_FIXED_boot.pt')
sim_model = predict_model.load_simBERT()

CPU Success


In [3]:
def classify(span, pred_values=False):
    """
    Use the trained BERT classifier to decide whether a span of text belongs
    to a species description or otherwise.

    Parameters
    ----------
    span : str
        Text to classify. Inputs that are not str/list/tuple (e.g. sentence
        objects produced by ``text_cleaner``) are converted with ``str()``
        first, because the tokenizer rejects anything else.
    pred_values : bool, default False
        If True, also return the exponentiated model outputs.

    Returns
    -------
    int
        Index of the highest-scoring class. Callers treat a non-zero value
        as "description".
    (int, torch.Tensor)
        When ``pred_values`` is True: the class and ``exp(outputs)`` for the
        first (only) input.
    """
    # The tokenizer only accepts str (or a list/tuple of str); anything else
    # would raise and, in the scraping loops, silently drop the whole page.
    if not isinstance(span, (str, list, tuple)):
        span = str(span)

    with torch.no_grad():
        # Tokenize, truncating to the model's maximum input length
        inputs = tokenizer(span, return_tensors="pt", truncation=True)
        outputs = model(**inputs)
        # exp() of the outputs -- presumably log-probabilities from a
        # log_softmax head in predict_model (TODO confirm)
        exps = torch.exp(outputs)
        span_class = exps.argmax(1).item()

    if pred_values:
        return span_class, exps[0]
    return span_class
        

def _cosine_similarity(vectors):
    """Pairwise cosine similarity between the rows of a 2-D numpy array."""
    norms = np.linalg.norm(vectors, axis=1, keepdims=True)
    # Leave all-zero rows unscaled (their similarities come out as 0)
    norms[norms == 0] = 1.0
    unit = vectors / norms
    return unit @ unit.T


def similarity_matrix(sentence_list):
    """
    Compute the pairwise cosine-similarity matrix of a list of sentences.

    Each sentence is encoded with ``sim_model``. Its token hidden states are
    summed over the non-padding positions, giving one vector per sentence.

    Parameters
    ----------
    sentence_list : list of str

    Returns
    -------
    numpy.ndarray
        (n, n) matrix of cosine similarities rounded to 5 decimals, with
        n = len(sentence_list). An empty input yields a (0, 0) array.
    """
    if len(sentence_list) == 0:
        return np.zeros((0, 0))

    # Encode each sentence to a fixed length (512) so the rows can be stacked
    encoded = [tokenizer.encode_plus(sentence, max_length=512,
                                     truncation=True,
                                     padding='max_length',
                                     return_tensors='pt')
               for sentence in sentence_list]
    # Drop each encoding's batch dimension and stack into (n, 512) tensors
    tokens = {
        'input_ids': torch.stack([enc['input_ids'][0] for enc in encoded]),
        'attention_mask': torch.stack([enc['attention_mask'][0] for enc in encoded]),
    }

    # Inference only: no autograd graph is needed
    with torch.no_grad():
        hiddenstates = sim_model(**tokens)
        # Zero out padded positions before summing. Padding tokens still get
        # non-zero hidden states and would otherwise dominate short sentences.
        # Assumes hiddenstates is (batch, seq_len, hidden) -- TODO confirm
        # against predict_model.load_simBERT.
        mask = tokens['attention_mask'].unsqueeze(-1).to(hiddenstates.dtype)
        summed_hs_np = torch.sum(hiddenstates * mask, 1).numpy()

    # Computed with numpy: cosine_similarity was never imported in this notebook
    return _cosine_similarity(summed_hs_np).round(5)

### BIRDS OF THE WORLD
Structured queries of the Birds of the World (BOW) website; the BERT classifier is used to extract species-description sentences from the saved HTML pages.

In [None]:
# Read all html files.
# Species folders holding exactly one page are parsed section by section
# (via their Introduction.html); folders with several pages are parsed page
# by page.
species_folder = sorted(glob.glob('../data/raw/BOW/*'))
single_list, multi_list = [], []
for species_dir in species_folder:
    # Glob each folder once. Sort the pages so species_list[0] -- whose title
    # provides the species name -- does not depend on filesystem order.
    pages = sorted(glob.glob(species_dir + '/*'))
    if len(pages) == 1:
        single_list.append(species_dir + '/Introduction.html')
    else:
        multi_list.append(pages)

In [None]:
def _read_bow_page(path):
    """Parse one saved BOW HTML file.

    The file is passed as bytes so BeautifulSoup detects the page encoding.
    """
    with open(path, 'rb') as f:
        return BeautifulSoup(f, 'html.parser')


def _bow_paragraphs(soup):
    """Return the page's <p> elements, skipping any whose text contains 'fig'.

    Those are presumably figure captions or references.
    """
    return [p for p in soup.find_all('p') if 'fig' not in p.text]


def _accepted_sentences(text):
    """Yield the text_cleaner sentences of ``text`` (as str) that classify accepts."""
    for sentence in text_cleaner(text):
        sentence_str = str(sentence)
        if classify(sentence_str):
            yield sentence_str


def _accepted_identification_sentences(text):
    """Yield the accepted sentences of an Identification paragraph.

    These paragraphs are split on '. ', with '; ' treated as a sentence
    break, instead of going through text_cleaner.
    """
    text_id = text.strip().replace('\n', "").replace('; ', '. ')
    text_id = re.sub(' +', ' ', text_id)
    for sentence in text_id.split('. '):
        if classify(sentence):
            yield sentence


data = collections.defaultdict(list)

# Single-page species: only paragraphs with a previous sibling are used.
# If that sibling contains an <h2> reading 'Identification', the paragraph is
# split on punctuation; otherwise it goes through text_cleaner.
for html in tqdm(single_list):
    try:
        soup = _read_bow_page(html)
        species = soup.title.text.strip().split(' - ')[1]
        for span in _bow_paragraphs(soup):
            previous = span.find_previous_sibling()
            if previous is None:
                continue
            heading = previous.find('h2')
            if heading is not None and heading.text == 'Identification':
                sentences = _accepted_identification_sentences(span.text)
            else:
                sentences = _accepted_sentences(span.text)
            # Append one by one so species without hits get no empty key
            for sentence in sentences:
                data[species].append(sentence)
    except Exception as exc:
        # Exception, not a bare except, so Ctrl-C still interrupts the loop
        tqdm.write(f'fail: {html} ({exc!r})')

# Multi-page species: every paragraph of every page goes through text_cleaner.
for species_list in tqdm(multi_list):
    if not species_list:
        continue  # empty species folder
    try:
        # Name from the first page's title only (titles are inconsistent across pages)
        species = _read_bow_page(species_list[0]).title.text.strip().split(' - ')[2]
        for html in species_list:
            for span in _bow_paragraphs(_read_bow_page(html)):
                for sentence in _accepted_sentences(span.text):
                    data[species].append(sentence)
    except Exception as exc:
        tqdm.write(f'fail: {species_list[0]} ({exc!r})')

# Dump pickle into file
with open('../data/processed/scrapeddata_train_species_description_bow.pkl', 'wb') as f:
    pickle.dump(data, f)

### AGROFORESTRY

In [4]:
# Collect every tree name linked from the third table of each A-Z letter
# page of the World Agroforestry tree database.
agro_list = []
for initial in tqdm(string.ascii_uppercase):
    index_url = 'http://apps.worldagroforestry.org/treedb2/index.php?letter={0}'.format(initial)
    index_soup = BeautifulSoup(requests.get(index_url, timeout=5).content, 'html.parser')
    agro_list.extend(link.text for link in index_soup.find_all('table')[2].find_all('a'))

# Pair each name with its species-properties URL (spaces become underscores)
agro_data = [
    (tree, 'http://db.worldagroforestry.org//species/properties/' + tree.replace(' ', '_'))
    for tree in agro_list
]

100%|███████████████████████████████████████████| 26/26 [00:09<00:00,  2.72it/s]


In [5]:
# Scrape each species-properties page and keep the sentences the classifier
# flags as descriptions, together with their source URL.
data = collections.defaultdict(list)

for (species, URL) in tqdm(agro_data):
    try:
        page = requests.get(URL, timeout=5)
        # Don't classify the boilerplate of error pages
        page.raise_for_status()
        soup = BeautifulSoup(page.content, 'html.parser')
        dirty_text = soup.get_text(". ", strip=True)
        for sentence in text_cleaner(dirty_text):
            # text_cleaner may yield non-str sentence objects; the classifier's
            # tokenizer needs plain strings, and str is what gets pickled
            sentence_str = str(sentence)
            if classify(sentence_str):
                data[species].append((sentence_str, URL))
    except Exception as exc:
        # Exception, not a bare except, so Ctrl-C still interrupts the loop
        tqdm.write(f'fail: {URL} ({exc!r})')

# Dump pickle into file
with open('../data/processed/descriptions_agroforestry_PLANTS.pkl', 'wb') as f:
    pickle.dump(data, f)

100%|█████████████████████████████████████████| 617/617 [27:49<00:00,  2.71s/it]


### LLIFLE

In [6]:
# LLIFLE tree index pages 1-7 (the trailing /100/ appears to be the page size)
tree_links_index = ['http://www.llifle.com/Encyclopedia/TREES/Species/all/{0}/100/'.format(i)
                    for i in range(1, 8)]

# Collect the absolute URL of every species link on the index pages
tree_links = []

for URL in tqdm(tree_links_index):
    # timeout: without one, a stalled server blocks the cell indefinitely
    page = requests.get(URL, timeout=5)
    page.raise_for_status()
    soup = BeautifulSoup(page.content, 'html.parser')
    hrefs = (anchor.get('href') for anchor in soup.find_all('a'))
    # Species pages are linked relatively under /Encyclopedia/TREES/Family/
    tree_links += ['http://www.llifle.com' + href for href in hrefs
                   if href is not None
                   and href.startswith('/Encyclopedia/TREES/Family/')]

100%|█████████████████████████████████████████████| 7/7 [00:05<00:00,  1.39it/s]


In [7]:
# Scrape every LLIFLE species page and keep the sentences the classifier
# flags as descriptions, together with their source URL.
data = collections.defaultdict(list)

for URL in tqdm(tree_links):
    try:
        # timeout: without one, a stalled server blocks the loop indefinitely
        page = requests.get(URL, timeout=5)
        page.raise_for_status()
        soup = BeautifulSoup(page.content, 'html.parser')
        # The page title (newlines removed) is used as the species key
        species = soup.title.text.replace('\n', '')
        dirty_text = soup.get_text(". ", strip=True)
        for sentence in text_cleaner(dirty_text):
            # text_cleaner may yield non-str sentence objects; the classifier's
            # tokenizer needs plain strings, and str is what gets pickled
            sentence_str = str(sentence)
            if classify(sentence_str):
                data[species].append((sentence_str, URL))
    except Exception as exc:
        # Exception, not a bare except, so Ctrl-C still interrupts the loop
        tqdm.write(f'fail: {URL} ({exc!r})')

# Dump pickle into file
with open('../data/processed/descriptions_llife_PLANTS.pkl', 'wb') as f:
    pickle.dump(data, f)

100%|█████████████████████████████████████████| 647/647 [18:16<00:00,  1.69s/it]


### PLANTS OF THE WORLD ONLINE

In [8]:
# Classify the text of every locally saved POWO page and keep the description
# sentences, together with the file they came from.
data = collections.defaultdict(list)

# Read files
powo_HTMLs = glob.glob('../data/raw/POWO/*')

for html in tqdm(powo_HTMLs):
    try:
        # Binary mode lets BeautifulSoup detect the page's encoding itself
        with open(html, 'rb') as f:
            soup = BeautifulSoup(f, 'html.parser')
        # Species name = file name up to the first ' - '. Path.name is used
        # because str.lstrip removes a *set of characters*, not a prefix, and
        # would also strip leading letters of the name (e.g. 'Parkia' -> 'kia').
        species = Path(html).name.split(' - ')[0]
        dirty_text = soup.get_text(". ", strip=True)
        for sentence in text_cleaner(dirty_text):
            # text_cleaner may yield non-str sentence objects; the classifier's
            # tokenizer needs plain strings, and str is what gets pickled
            sentence_str = str(sentence)
            if classify(sentence_str):
                # Record the source file: this cell has no URL of its own
                data[species].append((sentence_str, html))
    except Exception as exc:
        # Exception, not a bare except, so Ctrl-C still interrupts the loop
        tqdm.write(f'fail: {html} ({exc!r})')

# Dump pickle into file
with open('../data/processed/descriptions_powo_PLANTS.pkl', 'wb') as f:
    pickle.dump(data, f)

100%|██████████████████████████████████| 40187/40187 [18:56:57<00:00,  1.70s/it]
