In [46]:
from torch import cuda
import torch.nn as nn
import transformers
from transformers import DistilBertTokenizer, DistilBertModel
import warnings
import torch
from sklearn.metrics.pairwise import cosine_similarity
import pandas as pd
import numpy as np
import pickle
from tqdm import tqdm


In [None]:
from torch import cuda

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Tokenizer and encoder are both loaded from the same pretrained checkpoint name.
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
bert = DistilBertModel.from_pretrained('distilbert-base-uncased')

In [14]:
class BERT(nn.Module):
    """Encoder wrapper that returns the 512-d ``fc1`` projection of the first token.

    The module owns a full two-layer head (``fc1`` -> ReLU -> dropout -> ``fc2``),
    but ``forward`` only exposes the ``fc1`` output, which the rest of this
    notebook uses as a sentence embedding. ``relu``, ``dropout`` and ``fc2`` are
    kept as attributes so that state dicts containing ``fc2`` weights still load
    by key.

    Args:
        bert: encoder module whose output, indexed with ``[0]``, is a tensor
            with the token axis second and a last dimension of 768
            (presumably DistilBERT's last hidden state -- confirm if the
            encoder is swapped).
    """

    def __init__(self, bert):
        super(BERT, self).__init__()

        self.bert = bert
        self.dropout = nn.Dropout(0.1)
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(768, 512)
        self.fc2 = nn.Linear(512, 2)

    def forward(self, **kwargs):
        """Encode the inputs and return ``fc1`` applied to the first token.

        Args:
            **kwargs: forwarded unchanged to the wrapped encoder.

        Returns:
            Tensor of shape ``(batch, 512)``.
        """
        hidden_state = self.bert(**kwargs)[0]
        # Position 0 along the token axis (presumably the [CLS] token).
        pooler = hidden_state[:, 0]

        # The previous implementation also ran relu -> dropout -> fc2 and threw
        # the result away before returning a second fc1(pooler) evaluation.
        # Only the value that was actually returned is computed now.
        return self.fc1(pooler)

In [15]:
model = BERT(bert)

model_save_name = 'saved_weights_BERT_description_classifier.pt'
path = "../../../models/saved_weights/" + model_save_name

try:
    # Deserialize onto the CPU so the checkpoint loads on machines without a
    # GPU; the model is moved to `device` afterwards.
    model.load_state_dict(torch.load(path,
                                     map_location=torch.device('cpu')))
    print('Local Success')
except Exception as err:
    # Still a ValueError as before, but it now says what failed and where.
    # Catching Exception (not a bare except) lets KeyboardInterrupt and
    # SystemExit propagate untouched.
    raise ValueError(F"Could not load model weights from {path!r}: {err}") from err

# The model is only used for inference in this notebook, so put the wrapper in
# eval mode. Chaining keeps the cell from echoing the model repr.
model = model.to(device).eval()

Local Success


In [39]:
def embeddings(span, model, pred_values=False, truncation=True, threshold=False):
    """Tokenize `span`, run it through `model` without gradients, and return the output.

    The tokenized inputs are moved to the device the model's parameters live
    on, and a tensor result is moved back to the CPU. Previously the inputs
    stayed on the CPU, which fails once the model has been moved to a GPU.

    Args:
        span: text passed to the module-level `tokenizer`.
        model: callable module invoked as ``model(**inputs)``.
        pred_values: unused; kept so existing calls keep working.
        truncation: forwarded to the tokenizer's ``truncation`` argument.
        threshold: unused; kept so existing calls keep working.

    Returns:
        The model output. If it is a tensor it is returned on the CPU;
        any other output type is returned unchanged.
    """
    # A module without parameters has no device to follow; leave inputs as-is.
    first_param = next(model.parameters(), None)

    with torch.no_grad():
        inputs = tokenizer(span, return_tensors="pt", truncation=truncation)
        if first_param is not None:
            inputs = inputs.to(first_param.device)
        output = model(**inputs)

    return output.cpu() if torch.is_tensor(output) else output

def similarity(GroundTruth, Prediction):
    """Return the cosine similarity between the embeddings of two texts.

    Both texts are embedded with the module-level `model`; the result is
    element ``[0][0]`` of the similarity matrix.
    """
    truth_vector, predicted_vector = (
        embeddings(text, model=model) for text in (GroundTruth, Prediction)
    )
    score_matrix = cosine_similarity(truth_vector, predicted_vector)
    return score_matrix[0][0]

In [57]:
root = "../../../data/interim/"

# Each table is read with a two-row column header and the first column as the
# index; later cells look species names up in that index.
df_andrei, df_pierre, df_daniel = (
    pd.read_csv(root + csv_name, header=[0, 1], index_col=0)
    for csv_name in ("DF_Andrei.csv", "DF_Pierre.csv", "DF_Daniel.csv")
)

### Load Sentences

In [31]:
root = "../../../data/processed/"
sentences_all = {}

# NOTE(security): pickle.load can execute arbitrary code from the file.
# Only load pickles produced by this project.
# Files are opened in `with` blocks so the handles are closed again
# (previously all three stayed open).
# `|=` merges each dict into `sentences_all`; on a duplicate key the dict
# loaded later replaces the earlier entry.
with open(F"{root}Sentences_Pierre.pkl", 'rb') as f:
    sentences_Pierre = pickle.load(f)
sentences_all |= sentences_Pierre

with open(F"{root}Sentences_Andrei.pkl", 'rb') as f:
    sentences_Andrei = pickle.load(f)
sentences_all |= sentences_Andrei

with open(F"{root}Sentences_Kissling.pkl", 'rb') as f:
    sentences_Kissling = pickle.load(f)
sentences_all |= sentences_Kissling

# Drop duplicates. dict.fromkeys keeps the first occurrence of each sentence in
# its original position, so the order is the same on every run (list(set(...))
# depends on string hashing and changes between kernel restarts).
for species, sentences in sentences_all.items():
    sentences_all[species] = list(dict.fromkeys(sentences))

### Match Data

In [59]:
k = 5              # number of best-matching sentences kept per trait
max_species = 2    # debug limit on how many species are processed; None = all
google_form_lst = []


def _top_k_matches(query, sentences, sentence_vectors, k):
    """Rank `sentences` against `query`; return the k best and their mean similarity.

    Args:
        query: ground-truth text to embed and compare against.
        sentences: candidate sentences.
        sentence_vectors: embeddings of `sentences`, in the same order.
        k: number of top-ranked sentences to return.

    Returns:
        ``(top_k_list, mean_sim)``. ``top_k_list`` always has exactly `k`
        entries: a sentence whose similarity is zero is replaced by ``np.nan``,
        and the list is padded with ``np.nan`` when fewer than `k` sentences
        exist. ``mean_sim`` is the similarity sum divided by the number of
        non-NaN entries, or ``0.0`` when there are none.
    """
    query_vector = embeddings(query, model=model)

    scored = []
    for sentence, vector in zip(sentences, sentence_vectors):
        gt_sim = cosine_similarity(query_vector, vector)[0][0]
        if not gt_sim:
            # np.nan replaces np.NaN, which was removed in NumPy 2.0.
            sentence = np.nan
        scored.append((gt_sim, sentence))
    scored.sort(reverse=True)

    gt_sim_sum = 0
    top_k_list = []
    for gt_sim, sentence in scored[:k]:
        gt_sim_sum += gt_sim
        top_k_list.append(sentence)

    # Normalise by the sentences that actually matched. Dividing by k was
    # wrong whenever a species had fewer than k sentences.
    valid = sum(1 for sentence in top_k_list if sentence is not np.nan)
    # Pad so every row appended to google_form_lst has the same length.
    top_k_list += [np.nan] * (k - len(top_k_list))

    mean_sim = gt_sim_sum / float(valid) if valid else 0.0
    return top_k_list, mean_sim


for idx, (species, sentences) in enumerate(tqdm(sentences_all.items())):

    # Stop at the debug limit instead of iterating (and printing) every
    # remaining species only to skip it.
    if max_species is not None and idx >= max_species:
        break

    print(species)

    # Use the first trait table whose index contains this species.
    for df_select in (df_andrei, df_pierre, df_daniel):
        if species in df_select.index:
            break
    else:
        raise ValueError(F"Species {species!r} not found in any trait table")

    # Embed each sentence once per species. Every trait below reuses these
    # vectors instead of re-running the model for each (trait, sentence) pair.
    sentence_vectors = [
        embeddings(sentence, model=model)
        for sentence in tqdm(sentences, leave=False, desc="Sentences")
    ]

    species_rows = df_select[df_select.index == species]

    for gt_main_trait in df_select.columns.get_level_values(0).unique():
        df_subset = species_rows[gt_main_trait]

        if gt_main_trait == 'Measurement':
            # One row per sub-trait, with the query built from "<sub_trait>: <value>".
            for sub_trait, value in zip(df_subset.columns, df_subset.values[0]):
                df_sent = F"{sub_trait}: {value}"
                top_k_list, mean_sim = _top_k_matches(df_sent, sentences, sentence_vectors, k)
                google_form_lst.append((species, gt_main_trait, sub_trait, [df_sent], *top_k_list, df_sent.capitalize(), mean_sim))
        else:
            # Columns with at least one truthy value for this species.
            present_traits = df_subset.loc[:, df_subset.any()].columns.values
            if not present_traits.shape[0]:
                continue

            # Single-trait version: only the first present trait forms the query.
            df_sent = F"{gt_main_trait} {present_traits[0]}"
            top_k_list, mean_sim = _top_k_matches(df_sent, sentences, sentence_vectors, k)

            GoogleSent = F"{gt_main_trait}: {present_traits[0]}"
            google_form_lst.append((species, gt_main_trait, gt_main_trait, list(present_traits), *top_k_list, GoogleSent.capitalize(), mean_sim))

  0%|          | 0/647 [00:00<?, ?it/s]

Acacia amythethophylla


  0%|          | 1/647 [00:26<4:47:51, 26.74s/it]

Acacia ataxacantha


  0%|          | 1/647 [01:06<11:51:50, 66.12s/it]


KeyboardInterrupt: 