In [1]:
from pytorch_metric_learning import losses, miners, samplers, trainers, testers
from pytorch_metric_learning.utils import common_functions
import pytorch_metric_learning.utils.logging_presets as logging_presets
import numpy as np
import torchvision
from torchvision import datasets, transforms
import torch
import glob
import pickle
from tqdm import tqdm
import torch.nn as nn
import matplotlib.pyplot as plt
import pytorch_metric_learning
import transformers
from transformers import DistilBertTokenizer, DistilBertModel
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler, Dataset, random_split
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from sklearn.utils.class_weight import compute_class_weight
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import OneHotEncoder

The pytorch-metric-learning testing module requires faiss. You can install the GPU version with the command 'conda install faiss-gpu -c pytorch'
                        or the CPU version with 'conda install faiss-cpu -c pytorch'. Learn more at https://github.com/facebookresearch/faiss/blob/master/INSTALL.md


In [2]:
# specify device
from torch import cuda

device = 'cuda' if cuda.is_available() else 'cpu'
# Load the BERT tokenizer
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
# Bert mode
bert = DistilBertModel.from_pretrained('distilbert-base-uncased')

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_layer_norm.weight', 'vocab_projector.bias', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_transform.weight', 'vocab_transform.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [3]:
class SpeciesDescriptions(Dataset):
    
    """Description dataset without species names."""
    
    def __init__(self, root_dir):
        
        self.root_dir = root_dir
        self.samples = []
        self._init_dataset()
        #self.label_encoder = LabelEncoder()
        
    def __len__(self):
        return len(self.samples)
        
    def __getitem__(self, idx):
        return self.samples[idx]
     
    def _init_dataset(self):
        
        # Encoder 
        label_encoder = LabelEncoder()
        # Init dict
        datadict = {}
        # Load the pickle list
        data_files = glob.glob(self.root_dir + 'subset*.pkl')
        for data_file in data_files:
            # Open the dict and update
            datadict.update(pickle.load(open(data_file, 'rb')))
            
        # Get keys and encode them
        keys = np.array([key for key in datadict.keys()])
        print(len(keys))
        keys_encoded = label_encoder.fit_transform(keys)
        # Extract the values with the encoded keys
        self.samples += [(key_label, value[0]) for key_label, (key, value_list) in zip(keys_encoded, datadict.items()) for value in value_list]
        

In [4]:
try:
    # Colab
    from google.colab import drive
    root = '/content/gdrive/My Drive/'
    drive.mount('/content/gdrive')
    print('Mounted @Google')
except:
    # Local
    root = "../data/processed/"
    print('Mounted @Local')

# Load data
data = SpeciesDescriptions(root)

Mounted @Local
170


In [5]:
total_count = len(data)
train_count = int(1.0 * total_count)
valid_count = int(0.0 * total_count)
test_count = total_count - train_count - valid_count
train_dataset, valid_dataset, test_dataset = random_split(data, (train_count, valid_count, test_count), 
                                                       generator=torch.Generator().manual_seed(33))

In [6]:
batch_size = 2

# Random sample (skewed set)
train_sampler = RandomSampler(train_dataset)
# DataLoader for train set
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=batch_size)

# Random sample
val_sampler = SequentialSampler(valid_dataset)
# DataLoader for validation set
val_dataloader = DataLoader(valid_dataset, sampler=val_sampler, batch_size=batch_size)

In [7]:
# Freeze all the parameters
for param in bert.parameters():
    param.requires_grad = True

In [14]:
class BERT(nn.Module):
    def __init__(self, bert):
        
        super(BERT, self).__init__()
        
        # Distil Bert model
        self.bert = bert

    # Forward pass
    def forward(self, **kwargs):

        # Pass data trough bert and extract 
        cls_hs = self.bert(**kwargs)
        # Extract hidden state
        hidden_state = cls_hs.last_hidden_state
        # Only first is needed for classification
        x = hidden_state[:, 0]

        return x

In [15]:
# Load the entire model
model = BERT(bert)

# Load trained model (colab)
try:
    try:
        model_save_name = 'saved_weights_NLP_test.pt'
        path = F"/content/gdrive/My Drive/{model_save_name}"
        model.load_state_dict(torch.load(path))
        print('Google Success')

    except:
        model_save_name = 'saved_weights_NLP_subset.pt'
        path = "../models/" + model_save_name
        model.load_state_dict(torch.load(path, 
                                         map_location=torch.device('cpu')))
        print('Local Success')
except:
    print('No pretrained model found.')

# Push the model to GPU
model = model.to(device)

No pretrained model found.


In [16]:
# Load optimizer (Adam best for bert)
optimizer = torch.optim.Adam(params = model.parameters(), lr=3e-5)
# Define loss function
softmax = nn.Softmax(1)
CEloss = nn.CrossEntropyLoss()

TripletLoss = losses.TripletMarginLoss(margin=0.1)

def tokenize_batch(batch_set):
    
    """
    Tokenize a pytorch dataset using the hugging face tokenizer.
    """
    
    # Extract the labels and text
    y = batch_set[0]
    text = batch_set[1]
    
    
    # Tokenize the text
    tokens = tokenizer.batch_encode_plus(text,
                max_length = 512,
                padding=True,
                truncation=True)
    
    # Convert to tensors
    seq = torch.tensor(tokens['input_ids'])
    mask = torch.tensor(tokens['attention_mask'])
    
    return seq, mask, y

def train():
  
    """
    Function to train classification Bert model.
    """
    
    model.train()
    total_loss = 0
    
    # Iterate over batches
    for batch in tqdm(train_dataloader):
        
        # Tokenize batch
        train_seq, train_mask, train_y = tokenize_batch(batch)
        # Push to device
        sent_id, mask, labels = [t.to(device) for t in [train_seq, train_mask, train_y]]
        # Clear gradients 
        model.zero_grad()        
        # Get predictions
        preds = model(sent_id, mask)
        # Compute loss
        loss =  TripletLoss(preds, labels) 
        #loss = cross_entropy(preds, labels)
        # Update total loss
        total_loss = total_loss + loss.item()
        # Backward pass to calculate the gradients
        loss.backward()
        # Clip the the gradients to 1.0. It helps in preventing the exploding gradient problem
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        # Update parameters
        optimizer.step()

    # Compute the training loss of the epoch
    avg_loss = total_loss / len(train_dataloader)

    return avg_loss


def evaluate():
    
    """
    Function to test classification Bert model.
    """
  
    # Deactivate dropout layers
    model.eval()
    total_loss = 0

    # Iterate over batches
    for batch in tqdm(val_dataloader):   
        # Tokenize batch
        val_seq, val_mask, val_y = tokenize_batch(batch)
        # Push to device
        sent_id, mask, labels = [t.to(device) for t in [val_seq, val_mask, val_y]]
        # Deactivate autograd
        with torch.no_grad():
            # Model predictions
            preds = model(sent_id, mask)
            # Compute the validation loss between actual and predicted values
            loss =  TripletLoss(preds, labels) 
            #loss = cross_entropy(preds,labels)
            total_loss = total_loss + loss.item()

    # Compute the validation loss of the epoch
    avg_loss = total_loss / len(val_dataloader) 

    return avg_loss

In [61]:
# Epochs
epochs = 1

# Init loss
best_valid_loss = float('inf')

# data lists
train_losses=[]

# Loop over epochs
for epoch in range(epochs):
     
    print('\n Epoch {:} / {:}'.format(epoch + 1, epochs))
    
    # Train model
    train_loss = train() 
        
    # Append training and validation loss
    train_losses.append(train_loss)
    
    print(f'\nTraining Loss: {train_loss:.6f}')



 Epoch 1 / 1


100%|█████████████████████████████████| 4190/4190 [55:10<00:00,  1.27it/s]


Training Loss: 0.000000





In [17]:
string = 'This is a test.'

In [18]:
test = tokenizer(string, return_tensors="pt", truncation=True)

In [19]:
output = model(**test)

In [20]:
output

tensor([[-0.5201,  0.5049, -0.8210,  0.1483,  0.4277, -0.8137,  1.4250,  0.3003,
         -0.5727,  0.5645, -0.0306, -0.4534,  0.3267,  0.8140, -0.7719,  0.8386,
          0.2793, -0.7581,  0.7456,  1.1674,  0.1890, -0.2233, -0.0487, -0.1938,
          0.1842, -0.0789, -0.1713, -0.5290, -0.2594,  0.0734,  0.3469,  0.1272,
          0.4685, -0.1654,  1.4168, -0.0442, -0.7272,  0.4253, -0.8182, -0.2440,
          0.1256, -0.3787, -0.1420, -0.4979, -0.3011, -0.0638, -0.2782, -0.1950,
         -0.7990,  0.5789, -0.8186,  0.1311,  0.3382, -0.3325, -0.5793,  0.0871,
         -0.7164, -0.8506, -0.7472, -0.6172,  0.6725, -0.1354,  1.4744, -0.6154,
         -0.4263, -0.8802, -0.8408, -0.1092, -0.6150, -0.5463, -0.0990,  0.6968,
          0.5898, -0.0431, -0.0584,  0.3107,  0.3548, -0.6508,  0.8356, -1.0803,
          1.0688, -0.0106, -0.3511,  0.1426,  1.4892, -1.3131, -0.3424,  0.1309,
         -0.9179, -0.0051,  0.1006, -0.5325, -0.8342, -0.2906, -0.4913, -0.1670,
         -0.1618,  0.2526,  