In [1]:
import pandas as pd
from scipy.spatial.distance import cosine
from scipy import stats
import torch
from transformers import BertTokenizerFast, BertForMaskedLM
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# MatBERT
tokenizer = BertTokenizerFast.from_pretrained('MatBERT/matbert-base-cased', do_lower_case=False)
model = BertForMaskedLM.from_pretrained('MatBERT/matbert-base-cased', output_hidden_states=True).eval()

In [24]:
def get_embedding(text, tokenizer, model):
    tokenized_text = tokenizer.tokenize(text)
    indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
    tokens_tensor = torch.tensor([indexed_tokens])
    with torch.no_grad():
        outputs = model(tokens_tensor)
        hidden_state = outputs[0]
    embedding = hidden_state[0].mean(dim=0)
    return embedding.numpy()

In [49]:
file_path = 'dataset/thermoelectric_dft/'
file_name = 'dft.txt'

In [50]:
df = pd.read_csv(file_path + file_name, sep='\t', header=None, names=['name', 'value'])

In [51]:
thermoelectric_embedding = get_embedding("thermoelectric", tokenizer, model)

def compute_similarities_and_ranks(df):
    df['similarity'] = df['name'].apply(lambda x: 1 - cosine(get_embedding(x, tokenizer, model), thermoelectric_embedding))
    df['value_rank'] = df['value'].rank(ascending=False)
    df['similarity_rank'] = df['similarity'].rank(ascending=False)
    return df

In [52]:
df = compute_similarities_and_ranks(df)

In [53]:
df.to_csv(file_path + 'matbert_' + file_name, encoding='utf-8', index=False)

In [54]:
print(df[['name', 'value_rank', 'similarity_rank']])

             name  value_rank  similarity_rank
0        Lu2Sn2O7         1.0           7047.0
1            B4O2         2.0           1450.0
2            GeTe         3.0           3634.0
3       Pd(NN)2Pd         4.0            320.0
4           Yb2S3         5.0           3914.0
...           ...         ...              ...
9478    Ce(ClO4)3      9479.0            541.0
9479        HfNi5      9480.0           2198.0
9480       H2S2O7      9481.0           5579.0
9481      K5V3F14      9482.0           3004.0
9482  CeCl3(H2O)7      9483.0            843.0

[9483 rows x 3 columns]


In [55]:
res = stats.spearmanr(df['value_rank'], df['similarity_rank'])
res.statistic

-0.02982334608655128

In [None]:
df.plot(x='name', y='similarity_rank', kind='bar', figsize=(10, 6))

experiment_pf: -0.10094247288823806

In [None]:
experiment_zt: -0.034989962282081795

dft: -0.02982334608655128