In [13]:
import os
!pip install transformers
if not os.path.exists("scibert_uncased.tar"):
    !wget -O scibert_uncased.tar https://s3-us-west-2.amazonaws.com/ai2-s2-research/scibert/huggingface_pytorch/scibert_scivocab_uncased.tar
if not os.path.exists("scibert_scivocab_uncased"):
    !tar -xvf scibert_uncased.tar
!pip install umap-learn

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [3]:
import torch
from transformers import BertTokenizer, BertModel

model_version = 'scibert_scivocab_uncased'
do_lower_case = True
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = BertModel.from_pretrained(model_version).to(DEVICE)
tokenizer = BertTokenizer.from_pretrained(model_version, do_lower_case=do_lower_case)

Some weights of the model checkpoint at scibert_scivocab_uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokeni

In [4]:
from sklearn.metrics.pairwise import cosine_similarity


def embed_text(text, model):
    with torch.no_grad():
        input_ids = torch.tensor(tokenizer.encode(text, truncation=True, max_length=512)).unsqueeze(0).to(DEVICE)  # Batch size 1
        outputs = model(input_ids)
        last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple
        return last_hidden_states

def get_similarity(em, em2):
    return cosine_similarity(em.detach().numpy(), em2.detach().numpy())


In [5]:
import umap.umap_ as umap
reducer = umap.UMAP()

In [6]:
import pandas as pd

title_description_metadata = pd.read_csv("/content/drive/MyDrive/title_description_zenodo.csv").head(1000).values

In [7]:

def make_data_embedding(title_description_metadata, method="mean", dim=1):
    embedding_list = []
    for i in range(len(title_description_metadata)):
        embedding = embed_text(title_description_metadata[i, 1], model)
        if method == "mean":
            embedding_list.append(embedding.mean(dim))


    return embedding_list

description_embedding = make_data_embedding(title_description_metadata)

In [8]:
title_list = title_description_metadata[:, 0]


embed_list = torch.cat(description_embedding, dim=0)
red = reducer.fit_transform(embed_list.cpu().detach().numpy())

In [9]:
#plot embeding with reduced dimmensionality)

In [10]:

def make_plot(red, title_list, number=200, color = True, color_mapping_cat=None, color_cats = None, bg_color="white"):
    digits_df = pd.DataFrame(red, columns=('x', 'y'))
    if color_mapping_cat:
        digits_df['colors'] = color_mapping_cat
    digits_df['digit'] = title_list
    datasource = ColumnDataSource(digits_df)
    plot_figure = figure(
    title='UMAP projection of the article title embeddings',
    width=890,
    height=600,
    tools=('pan, wheel_zoom, reset'),
    background_fill_color = bg_color
    )
    plot_figure.legend.location = "top_left",
    plot_figure.add_tools(HoverTool(tooltips="""
    <div>
    <div>
        <img src='@image' style='float: left; margin: 5px 5px 5px 5px'/>
    </div>
    <div>
        <span style='font-size: 10px; color: #224499'></span>
        <span style='font-size: 10px'>@digit</span>
    </div>
    </div>
    """))
    if color:
        color_mapping = CategoricalColorMapper(factors=title_list, palette=magma(number))
        plot_figure.circle(
            'x',
            'y',
            source=datasource,
            color=dict(field='digit', transform=color_mapping),
            line_alpha=0.6,
            fill_alpha=0.6,
            size=7
        )
        show(plot_figure)
    elif color_mapping_cat:
        color_mapping = CategoricalColorMapper(factors=color_cats, palette=magma(len(color_cats)+2)[2:])
        plot_figure.circle(
            'x',
            'y',
            source=datasource,
            color=dict(field='colors', transform=color_mapping),
            line_alpha=0.6,
            fill_alpha=0.6,
            size=8,
            legend_field='colors'
        )
        show(plot_figure)
    else:

        plot_figure.circle(
            'x',
            'y',
            source=datasource,
            color=dict(field='digit'),
            line_alpha=0.6,
            fill_alpha=0.6,
            size=7
        )
        show(plot_figure)

make_plot(red, title_list, number=200)

You are attempting to set `plot.legend.location` on a plot that has zero legends added, this will have no effect.

Before legend properties can be set, you must add a Legend explicitly, or call a glyph method with a legend parameter set.



 'Data from: Evolution of mir-92a underlies natural morphological variation in Drosophila melanogaster'
 'Data from: High flight costs, but low dive costs, in auks support the biomechanical hypothesis for flightlessness in penguins'
 'Data from: Bioclimatic, ecological, and phenotypic intermediacy and high genetic admixture in a natural hybrid of octoploid strawberries'
 'Data from: Evidence for a host role in thermotolerance divergence between populations of the mustard hill coral (Porites astreoides) from different reef environments'
 'Data from: Patterns of host-parasite adaptation in three populations of monarch butterflies infected with a naturally occurring protozoan disease: virulence, resistance, and tolerance'
 'Data from: Correlated responses to clonal selection in populations of Daphnia pulicaria: mechanisms of genetic correlation and the creative power of sex'
 'Data from: The complete sequence of the mitochondrial genome of Butomus umbellatus - a member of an early branchi