In [1]:
import numpy as np
import pandas as pd
import torch
import pickle
import re
from itertools import chain
from collections import Counter
import torch.nn as nn
import glob
import random
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from sklearn.utils.class_weight import compute_class_weight
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import OneHotEncoder
import transformers
from transformers import AdamW
from transformers import DistilBertTokenizer, DistilBertModel
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler, Dataset, random_split
from tqdm import tqdm
import time

%matplotlib inline

In [2]:
# Pick the compute device: CUDA when a GPU is visible, CPU otherwise
from torch import cuda
if cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Pretrained DistilBERT tokenizer (uncased checkpoint)
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

# Pretrained DistilBERT base model (same checkpoint as the tokenizer)
bert = DistilBertModel.from_pretrained('distilbert-base-uncased')

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_projector.weight', 'vocab_transform.bias', 'vocab_transform.weight', 'vocab_projector.bias', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [3]:
class SpeciesDescriptions(Dataset):

    """Description dataset without species names.

    Each sample is a ``(label, value)`` tuple: ``label`` is the integer code a
    ``LabelEncoder`` assigns to a dictionary key, and ``value`` is one element
    of the list stored under that key (presumably a description text --
    confirm against the pickled data).
    """

    def __init__(self, root_dir):
        """Collect samples from every ``TEST*PLANTS.pkl`` file in ``root_dir``.

        Args:
            root_dir: Directory prefix. It is concatenated directly with the
                file pattern, so it must end with a path separator.
        """
        self.root_dir = root_dir
        self.samples = []
        self._init_dataset()

    def __len__(self):
        """Return the number of ``(label, value)`` samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the ``(label, value)`` sample stored at position ``idx``."""
        return self.samples[idx]

    def _init_dataset(self):
        """Merge the pickled dictionaries and flatten them into samples."""

        # Encoder
        label_encoder = LabelEncoder()
        # Init dict
        datadict = {}
        # Sort the matches: glob returns them in arbitrary order, and the merge
        # order decides which file wins when two files share a key.
        data_files = sorted(glob.glob(self.root_dir + 'TEST*PLANTS.pkl'))
        for data_file in data_files:
            # NOTE: pickle can execute arbitrary code; only load trusted files.
            # The context manager closes the handle (the original leaked it).
            with open(data_file, 'rb') as handle:
                datadict.update(pickle.load(handle))

        # Get keys and encode them
        keys = np.array(list(datadict.keys()))
        print(len(keys))
        keys_encoded = label_encoder.fit_transform(keys)
        # keys_encoded is aligned with the dict's iteration order, so pairing
        # it with the values gives every list element its key's label.
        self.samples += [
            (key_label, value)
            for key_label, value_list in zip(keys_encoded, datadict.values())
            for value in value_list
        ]

In [4]:
try:
    # Colab: the import only succeeds inside a Colab runtime
    from google.colab import drive
    root = '/content/gdrive/My Drive/'
    drive.mount('/content/gdrive')
    print('Mounted @Google')
except Exception:
    # Local fallback. Catching Exception (not a bare except) keeps the
    # best-effort behaviour while letting KeyboardInterrupt/SystemExit through.
    root = "../data/processed/"
    print('Mounted @Local')

# Build the dataset and report how long loading took
start = time.time()
# Load data
data = SpeciesDescriptions(root)
end = time.time()
print("Time consumed in working: ",end - start)

Mounted @Local
35746
Time consumed in working:  0.6807208061218262


In [5]:
# 80 / 10 / 10 split; the test part absorbs the rounding remainder
total_count = len(data)
train_count, valid_count = int(0.8 * total_count), int(0.1 * total_count)
test_count = total_count - train_count - valid_count

# Fixed generator seed so the split is reproducible
train_dataset, valid_dataset, test_dataset = random_split(
    data,
    (train_count, valid_count, test_count),
    generator=torch.Generator().manual_seed(33),
)

In [6]:
batch_size = 2

# Training batches are drawn in random order
train_sampler = RandomSampler(train_dataset)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    sampler=train_sampler,
)

# Validation batches are visited sequentially (no shuffling)
val_sampler = SequentialSampler(valid_dataset)
val_dataloader = DataLoader(
    valid_dataset,
    batch_size=batch_size,
    sampler=val_sampler,
)

In [7]:
# Turn off gradient tracking for every weight of the pretrained encoder
for frozen_param in bert.parameters():
    frozen_param.requires_grad_(False)

In [8]:
class BERT(nn.Module):
    """Classification head on top of a DistilBERT-style encoder.

    The hidden state of the first token is passed through
    Linear -> ReLU -> Dropout -> Linear -> LogSoftmax.

    Args:
        bert: Encoder module whose output exposes ``last_hidden_state``.
        n_classes: Number of output classes (default keeps the original
            hard-coded 35746 so saved checkpoints still load).
        hidden_size: Width of the encoder's hidden state (default 768).
    """

    def __init__(self, bert, n_classes=35746, hidden_size=768):

        super(BERT, self).__init__()

        # Distil Bert model
        self.bert = bert
        ## Additional layers
        # Dropout layer
        self.dropout = nn.Dropout(0.3)
        # Relu
        self.relu = nn.ReLU()
        # Linear I
        self.fc1 = nn.Linear(hidden_size, 512)
        # Linear II (Out)
        self.fc2 = nn.Linear(512, n_classes)
        # Log-softmax over the class dimension
        self.softmax = nn.LogSoftmax(dim=1)

    # Forward pass
    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        """Return log-probabilities of shape ``(batch, n_classes)``.

        ``input_ids`` and ``attention_mask`` may be given positionally
        (``model(ids, mask)``) or by keyword (``model(**tokens)``); the
        original keyword-only signature raised ``TypeError`` on positional
        calls. Extra keyword arguments are forwarded to the encoder.
        """

        # Pass data through the encoder
        encoder_out = self.bert(input_ids=input_ids,
                                attention_mask=attention_mask,
                                **kwargs)
        # Only the first token's hidden state is used for classification
        pooler = encoder_out.last_hidden_state[:, 0]

        # Dense layer 1
        x = self.fc1(pooler)
        # ReLU activation
        x = self.relu(x)
        # Drop out
        x = self.dropout(x)
        # Dense layer 2
        x = self.fc2(x)
        # Log-softmax activation
        x = self.softmax(x)

        return x

In [9]:
# Load the entire model
model = BERT(bert)

# Candidate checkpoint locations, tried in order: Colab drive first, then local
model_save_name = 'saved_weights_NLP_test.pt'
checkpoint_candidates = [
    (F"/content/gdrive/My Drive/{model_save_name}", 'Google Success'),
    ("../models/" + model_save_name, 'Local Success'),
]

for path, success_message in checkpoint_candidates:
    try:
        # Always deserialise onto the CPU so GPU-saved weights also load on
        # CPU-only machines; the model is moved to `device` below.
        state_dict = torch.load(path, map_location=torch.device('cpu'))
    except OSError:
        # Checkpoint missing or unreadable at this location: try the next one
        continue
    # A checkpoint that exists but does not match the model now raises instead
    # of being silently reported as "not found".
    model.load_state_dict(state_dict)
    print(success_message)
    break
else:
    print('No pretrained model found.')

# Push the model to GPU
model = model.to(device)

Local Success


In [10]:
# Adam optimiser over the model's parameters
optimizer = torch.optim.Adam(model.parameters(), lr=3e-5)
# Softmax over dim 1, plus the cross-entropy criterion used for the loss
softmax = nn.Softmax(dim=1)
CEloss = nn.CrossEntropyLoss()

def tokenize_batch(batch_set):

    """
    Turn one dataloader batch into model-ready tensors.

    ``batch_set[0]`` holds the labels and ``batch_set[1]`` the raw texts.
    The texts are encoded with the module-level Hugging Face ``tokenizer``
    (padded to the longest text, truncated at 512 tokens).

    Returns the tuple ``(input_ids, attention_mask, labels)``; the labels
    are passed through untouched.
    """

    labels, texts = batch_set[0], batch_set[1]

    # Encode the whole batch in one call
    encoded = tokenizer.batch_encode_plus(
        texts,
        max_length=512,
        padding=True,
        truncation=True,
    )

    # Python lists -> tensors
    input_ids = torch.tensor(encoded['input_ids'])
    attention_mask = torch.tensor(encoded['attention_mask'])

    return input_ids, attention_mask, labels

def train():

    """
    Run one training epoch over ``train_dataloader``.

    Uses the module-level ``model``, ``optimizer``, ``CEloss``, ``device``
    and ``tokenize_batch``. Returns the mean batch loss of the epoch.
    """

    model.train()
    total_loss = 0

    # Iterate over batches
    for batch in tqdm(train_dataloader):

        # Tokenize batch
        train_seq, train_mask, train_y = tokenize_batch(batch)
        # Push to device
        sent_id, mask, labels = [t.to(device) for t in [train_seq, train_mask, train_y]]
        # Clear gradients
        model.zero_grad()
        # Get predictions. Keyword arguments are required: the model's forward
        # takes keywords (positional tensors raised a TypeError).
        preds = model(input_ids=sent_id, attention_mask=mask)
        # Compute loss
        loss = CEloss(preds, labels)
        # Update total loss
        total_loss = total_loss + loss.item()
        # Backward pass to calculate the gradients
        loss.backward()
        # Clip the gradient norm to 1.0 to guard against exploding gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        # Update parameters
        optimizer.step()

    # Compute the training loss of the epoch
    avg_loss = total_loss / len(train_dataloader)

    return avg_loss


def evaluate():

    """
    Compute the mean loss over ``val_dataloader`` without updating weights.

    Uses the module-level ``model``, ``CEloss``, ``device`` and
    ``tokenize_batch``. Returns the mean batch loss.
    """

    # Deactivate dropout layers
    model.eval()
    total_loss = 0

    # Iterate over batches
    for batch in tqdm(val_dataloader):
        # Tokenize batch
        val_seq, val_mask, val_y = tokenize_batch(batch)
        # Push to device
        sent_id, mask, labels = [t.to(device) for t in [val_seq, val_mask, val_y]]
        # Deactivate autograd
        with torch.no_grad():
            # Model predictions. Keyword arguments are required: the model's
            # forward takes keywords (positional tensors raised a TypeError).
            preds = model(input_ids=sent_id, attention_mask=mask)
            # Compute the validation loss between actual and predicted values
            loss = CEloss(preds, labels)
            total_loss = total_loss + loss.item()

    # Compute the validation loss of the epoch
    avg_loss = total_loss / len(val_dataloader)

    return avg_loss

In [None]:
# Pull a single training batch to sanity-check the forward pass
batch = next(iter(train_dataloader))

train_seq, train_mask, train_y = tokenize_batch(batch)
# Push to device
sent_id, mask, labels = [t.to(device) for t in [train_seq, train_mask, train_y]]

# Get predictions. Keyword arguments are required: the model's forward takes
# keywords (positional tensors raised a TypeError).
preds = model(input_ids=sent_id, attention_mask=mask)

In [None]:
# Display the model output for the sample batch (LogSoftmax output of the classifier head)
preds

In [None]:
# Number of passes over the training set
epochs = 1

# Starts at +inf; this cell never updates it
best_valid_loss = float('inf')

# Per-epoch loss history
train_losses, valid_losses = [], []

for epoch in range(epochs):

    print('\n Epoch {:} / {:}'.format(epoch + 1, epochs))

    # One training pass followed by one validation pass
    train_loss = train()
    valid_loss = evaluate()

    # Keep both losses for later inspection
    train_losses.append(train_loss)
    valid_losses.append(valid_loss)

    print(f'\nTraining Loss: {train_loss:.6f}')
    print(f'Validation Loss: {valid_loss:.6f}')

In [24]:
# Free-text description (copper beech, Fagus sylvatica) used below as a
# single input example for the classifier.
strings = """
 The copper beech is a large tree that often grows on chalky soil. The inconspicuous flowers are pollinated by the wind and the fruits are three-sided nuts (called beech mast) inside spiky cases.

The size and shape of Fagus sylvatica depends on its environment. Given space, the beech will spread its branches widely and can grow up to a massive 42 m (140 ft) high. In tightly packed woods, the tree will grow straight, with few side branches, to reach the light. This becomes an even greater priority with time.

The arrangement of leaves is such that they overlap, which, while efficient for the tree, shades the ground beneath and can also prevent rain from reaching it. So if the floor of the English wood you're walking through comprises little more than fungi and rotting leaves, it's probably a beech wood.

Copper beech trees are quite variable in leaf colour as they are normally propagated by seed. Particularly dark clones have been selected over the years and those are grafted to maintain the true colour. The best-known are 'Cuprea', 'Nigra', 'Riversii' and 'Spaethiana'. 
"""
#string = '. '.join(strings)

In [25]:
# Encode the example text as PyTorch tensors, truncating over-long input
tokens = tokenizer(
    strings,
    truncation=True,
    return_tensors="pt",
)

In [26]:
# Inference: switch off dropout (a fresh or just-trained module is in training
# mode, which makes predictions stochastic), skip autograd bookkeeping, and
# feed the inputs from the same device the model lives on.
model.eval()
with torch.no_grad():
    outputs = model(**{name: tensor.to(device) for name, tensor in tokens.items()})

In [27]:
# The model emits log-probabilities; exponentiate to get probabilities
exps = outputs.exp()
# Index of the most probable class
span_class = exps.argmax(dim=1).item()

In [28]:
# Encoded label (class index) with the highest probability for the example text
span_class

7785

In [29]:
# Convert the class probabilities to a NumPy array. Recomputing from `outputs`
# keeps the cell re-runnable (the original rebound `exps` from itself), and
# `.cpu()` is required before `.numpy()` when the model runs on a GPU.
exps = torch.exp(outputs).detach().cpu().numpy()

In [30]:
# Number of top-scoring classes to keep
n=10

# Drop the batch dimension to get a flat vector of class probabilities
numbers = np.squeeze(exps)
# argpartition puts the n largest entries (unordered) in the last n slots ...
idx = np.argpartition(numbers, -n)[-n:]
# ... which are then sorted by descending probability
indices = idx[np.argsort((-numbers)[idx])]

In [35]:
# Class indices of the n highest-scoring classes, best first
indices

array([ 7785, 24482,   328,   322,  3027,  4493, 27624,   330, 23253,
       29063])

In [32]:
# Rebuild the key -> label mapping the same way the dataset class does, so
# that predicted class indices can be translated back into keys.
# Encoder
label_encoder = LabelEncoder()
# Init dict
datadict = {}
# Sorted so the merge order (and thus which file wins on duplicate keys) is
# deterministic; glob returns matches in arbitrary order.
data_files = sorted(glob.glob(root + 'TEST*PLANTS.pkl'))
for data_file in data_files:
    # NOTE: pickle can execute arbitrary code; only load trusted files.
    # The context manager closes the handle (the original leaked it).
    with open(data_file, 'rb') as handle:
        datadict.update(pickle.load(handle))

# Get keys and encode them
keys = np.array(list(datadict.keys()))
print(len(keys))
keys_encoded = label_encoder.fit_transform(keys)

35746


In [33]:
# Translate the top-n class indices back into the original dictionary keys
label_encoder.inverse_transform(indices)

array(['Chorisia speciosa', 'Sandoricum koetjape', 'Adansonia situla',
       'Adansonia bahobab', 'Arecaceae', 'Baobabus digitata',
       'Terebinthus microphylla', 'Adansonia sphaerocarpa',
       'Pircunia dioica', 'Urostigma petiolaris'], dtype='<U56')