In [1]:
import numpy as np
import pickle
from tqdm import tqdm
import re
import collections
import glob
from transformers import DistilBertTokenizer, DistilBertModel
import sys
import time
import requests
from IPython.display import display, HTML
from selenium import webdriver
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC
from bs4 import BeautifulSoup

sys.path.insert(0, '../src/models/')
sys.path.insert(0, '../src/features/')

from predict_model import loadBERT
from predict_model import SpanPredictor as classify
from build_features import text_cleaner, DuckDuckGo_Java, Bing_HTML, colorize_prediction

%matplotlib inline

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_transform.bias', 'vocab_transform.weight', 'vocab_projector.bias', 'vocab_layer_norm.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [2]:
# Load the saved span-classifier weights through the project helper `loadBERT`,
# plus the base DistilBERT tokenizer.
# NOTE(review): `tokenizer` is not referenced by any later cell in this notebook.
weights_dir = "../models/"
weights_file = 'saved_weights_inf_FIXED_boot_beta80.pt'
model = loadBERT(weights_dir, weights_file)
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

CPU Success


In [3]:
# Merge every per-batch PLANTS pickle into one mapping:
# species -> [(sentence, URL), ...]   (pair shape as unpacked in the merge cell below)
plants_dict = collections.defaultdict(list)
root = '../data/description/'
# sorted(): glob order is filesystem-dependent; sorting makes the per-species
# list order reproducible across machines.
data_files = sorted(glob.glob(root + 'description*PLANTS.pkl'))
for data_file in data_files:
    # NOTE: pickle.load can execute arbitrary code -- only load trusted, locally produced files.
    with open(data_file, 'rb') as fh:
        dict_ = pickle.load(fh)
    for key, value in dict_.items():
        plants_dict[key] += value

# Correct the mis-parsed first key ('oa' -> 'Poa') BEFORE ordering. Doing it after the
# sort (pop + re-insert) would move 'Poa' to the end of the length-ordered dict.
# Merge rather than overwrite, and skip quietly if the bad key is absent.
if 'oa' in plants_dict:
    plants_dict['Poa'] += plants_dict.pop('oa')

# Order species by number of collected sentences, most first. sorted() is stable,
# so ties keep their merge order. Batch slices such as plants[15000:20000] index
# into this ordering.
plants_dict = collections.OrderedDict(
    sorted(plants_dict.items(), key=lambda item: len(item[1]), reverse=True)
)
# Species names in that order
plants = list(plants_dict)

### URLS

In [None]:
# Collect candidate description URLs for a batch of species by searching
# DuckDuckGo (through a Selenium driver) and Bing with several query suffixes.
data_links = collections.defaultdict(list)
# Failed searches, kept for inspection instead of being silently discarded:
# species -> [(engine, query, repr(error)), ...]
failed_queries = collections.defaultdict(list)

queries = ['description', 'diagnosis', '', 'attributes', 'captions']

# Create the driver in this cell rather than relying on one left in the kernel.
# With no `driver` defined, every search raises NameError; a bare `except` hides that,
# and every species comes back empty.
driver = webdriver.Safari()
search_engines = {
    'duckduckgo': lambda q: DuckDuckGo_Java(q, driver=driver),
    'bing': Bing_HTML,
}

try:
    for species in tqdm(plants[15000:20000]):
        search_links = []
        species_q = species.replace(' ', '+')
        for query in queries:
            search = f'"{species_q}"+{query}'
            # Query each engine separately, so a failure in one (e.g. a connection
            # timeout) does not skip the other.
            for engine_name, engine in search_engines.items():
                try:
                    search_links += engine(search)
                except Exception as err:  # not bare: Ctrl-C can still stop the loop
                    failed_queries[species].append((engine_name, search, repr(err)))
        # Drop duplicates; dict.fromkeys keeps first-seen order, unlike set()
        search_links = list(dict.fromkeys(search_links))
        if not search_links:
            # tqdm.write keeps the progress bar intact
            tqdm.write(f'empty: {species}')
        data_links[species] += search_links
finally:
    # Always release the browser, even if the loop is interrupted
    driver.quit()

In [None]:
# Persist the species -> URLs mapping collected for this batch
with open('../data/description/01_URLS_15000-20000_PLANTS.pkl', mode='wb') as out_file:
    pickle.dump(data_links, out_file)

### TEXT

In [4]:
# Load the species -> [URL, ...] mapping produced by the URL-collection step.
# `with` closes the file handle. NOTE: pickle.load can execute arbitrary code;
# only open trusted, locally produced files.
with open('../data/description/01_URLS_0-10000_PLANTS.pkl', 'rb') as fh:
    URLS = pickle.load(fh)

In [None]:
# Fetch every collected URL for a batch of species. Keep the cleaned sentences of
# pages whose <title> shares a word with the species name:
# species -> [text_cleaner(page_text), ...]   (one entry per accepted page)
sentence_list = collections.defaultdict(list)

# Distinct name from the loop variable below (the list was previously shadowed by it)
species_names = list(URLS.keys())

for species in tqdm(species_names[2000:5000]):
    for URL in URLS[species]:
        # Skip google archives and Plants of the World (already done)
        if 'google' in URL or 'powo' in URL:
            continue
        # Skip plain-text and PDF documents by extension (case-insensitive)
        if URL.lower().endswith(('.txt', '.pdf')):
            continue
        try:
            page = requests.get(URL, timeout=5)
            # HTTP error bodies (404/500 pages) are not descriptions
            if not page.ok:
                continue
            # Skip PDFs served without a .pdf extension; the header may be missing
            if page.headers.get('Content-Type', '').startswith('application/pdf'):
                continue
            # Let BeautifulSoup detect the encoding (meta charset / UTF-8 sniffing).
            # Forcing ISO-8859-1 turned UTF-8 pages into mojibake.
            soup = BeautifulSoup(page.content, 'html.parser')
            # A page without a <title> cannot be matched to the species
            if soup.title is None:
                continue
            title = soup.title.text
            # Skip embedded PDF viewers
            if 'pdf' in title.lower():
                continue
            # Keep the page only if some word of the species name appears in the title
            if set(species.split()).intersection(title.split()):
                # Newlines become '.' so text_cleaner can split on them as sentence breaks
                dirty_text = soup.get_text(" ", strip=False).replace('\n', '.')
                sentence_list[species].append(text_cleaner(dirty_text))
        # Best-effort scrape: skip unreachable or unparsable pages. Unlike a bare
        # `except`, this lets KeyboardInterrupt through, so the hours-long loop can be stopped.
        except Exception:
            continue

 52%|█████████████████▎               | 1571/3000 [12:37:23<13:26:50, 33.88s/it]

In [6]:
# Persist the scraped per-species sentences for this batch
with open('../data/description/02_SENTS_2000-5000_PLANTS.pkl', mode='wb') as out_file:
    pickle.dump(sentence_list, out_file)

In [None]:
# Number of species with at least one accepted page
len(sentence_list)

# Classify

In [None]:
# Reload previously scraped sentences (species -> one entry per scraped page).
# `with` closes the file handle. NOTE: only unpickle trusted, locally produced files.
with open('../data/description/02_SENTS_0-1000_PLANTS.pkl', 'rb') as fh:
    sentence_list = pickle.load(fh)

In [None]:

# Keep only the sentences the span classifier accepts, grouped by species.
# A species gets an entry only once one of its sentences is accepted.
descriptions = collections.defaultdict(list)
species_list = list(sentence_list.keys())

for species in tqdm(species_list):
    # Each species holds one entry per scraped page (text_cleaner output); walk them flat
    page_sentences = (sentence for text in sentence_list[species] for sentence in text)
    for sentence in page_sentences:
        if classify(sentence, model=model):
            descriptions[species].append(sentence)

        

In [None]:
# Persist the classifier-accepted sentences per species
with open('../data/description/03_DESC_0-1000_PLANTS.pkl', mode='wb') as out_file:
    pickle.dump(descriptions, out_file)

# Resample: merge with the original PLANTS descriptions and de-duplicate

In [None]:
# Reload the classifier-accepted sentences per species.
# `with` closes the file handle. NOTE: only unpickle trusted, locally produced files.
with open('../data/description/03_DESC_0-1000_PLANTS.pkl', 'rb') as fh:
    description_dict = pickle.load(fh)

In [None]:
# Order the dictionary based on the list length
# Re-order species so those with the most accepted sentences come first
# (sorted() is stable, so ties keep their loaded order).
by_count = sorted(description_dict.items(), key=lambda kv: len(kv[1]), reverse=True)
description_dict = collections.OrderedDict(by_count)

In [None]:
# Number of species with at least one accepted sentence
len(description_dict)

In [None]:
#description_dict.keys()

In [None]:
# Update the dict
# Merge each species' original PLANTS sentences into its classified descriptions
# and drop duplicates. plants_dict values are (sentence, URL) pairs; only the
# sentence is kept.
# dict.fromkeys de-duplicates while keeping first-seen order. set() order varies
# between runs (str hashing is randomized), which made the saved training data
# non-reproducible.
# Note: the length ordering from the earlier sort reflects pre-merge counts.
for species in description_dict.keys():
    # plants_dict is a plain OrderedDict at this point, so .get avoids a KeyError
    # for a species it lacks
    sents = [sent for (sent, _url) in plants_dict.get(species, [])]
    description_dict[species] = list(dict.fromkeys(description_dict[species] + sents))

In [None]:
# Persist the merged, de-duplicated training sentences per species
with open('../data/description/04_TRAIN_0-817_PLANTS.pkl', mode='wb') as out_file:
    pickle.dump(description_dict, out_file)