In [1]:
import numpy as np
import pandas as pd
import torch
import seaborn as sns
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from sklearn.metrics.pairwise import cosine_similarity
from transformers import BertTokenizer, BertModel

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the book corpus table; the recorded output shows a `book` path column
# and a `words` text column.
df = pd.read_csv(filepath_or_buffer='./data.csv')

In [3]:
# Show the loaded corpus table (one row per source text file).
df

Unnamed: 0,book,words
0,./data/cook_book_one.txt,project gutenberg's the whitehouse cookbook by...
1,./data/cook_book_three.txt,the project gutenberg ebook of new royal cook ...
2,./data/gothic_novel_four.txt,the project gutenberg ebook of the works of ed...
3,./data/gothic_novel_six.txt,the project gutenberg ebook of northanger abbe...
4,./data/gothic_novel_two.txt,project gutenberg’s the complete works of will...
5,./data/gothic_novel_three.txt,the project gutenberg ebook of dracula by bram...
6,./data/cook_book_four.txt,the project gutenberg ebook of the italian coo...
7,./data/gothic_novel_ten.txt,the project gutenberg ebook of the castle of o...
8,./data/gothic_novel_eight.txt,the project gutenberg ebook of the vampyre a t...
9,./data/gothic_novel_nine.txt,the project gutenberg ebook of the masque of t...


In [4]:
# Load the word-piece tokenizer that belongs to the 'bert-base-uncased' checkpoint.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [5]:
# Load the bare BERT encoder from the same checkpoint as the tokenizer.
# The load message lists the `cls.*` weights as unused; as that message itself
# says, this is expected when the checkpoint was trained with another architecture.
model = BertModel.from_pretrained("bert-base-uncased")

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [6]:
# Two single-word inputs whose embeddings are compared further down.
text = ['love', 'hate']

In [7]:
# Encode the batch as PyTorch tensors.
# padding=True pads every sequence to the longest one in the batch, so the call
# keeps working when the words split into different numbers of word-pieces
# (without it the tokenizer cannot stack ragged sequences into one tensor).
# 'love' and 'hate' are one word-piece each, so no padding is added for them.
encoded_input = tokenizer(text, padding=True, return_tensors='pt')

In [8]:
# Inspect the encoding: input_ids, token_type_ids and attention_mask, one row per input word.
encoded_input

{'input_ids': tensor([[ 101, 2293,  102],
        [ 101, 5223,  102]]), 'token_type_ids': tensor([[0, 0, 0],
        [0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1],
        [1, 1, 1]])}

In [9]:
# Forward pass for feature extraction only. Running it under torch.no_grad()
# stops autograd from recording a graph, so the hidden states carry no grad_fn
# and no memory is spent on tensors kept for a backward pass that never happens.
with torch.no_grad():
    output = model(**encoded_input)

In [10]:
# Shape of the final hidden layer; the recorded value is [2, 3, 768], i.e. two
# inputs of three token positions each (matching input_ids) with 768 features.
output.last_hidden_state.shape

torch.Size([2, 3, 768])

In [11]:
# Peek at the raw hidden-state tensor (one vector per token position per input).
output.last_hidden_state

tensor([[[-0.1931,  0.2282, -0.1564,  ..., -0.6209, -0.0099,  0.2842],
         [ 0.3865,  0.3619,  0.2342,  ..., -0.9576, -0.0633, -0.4696],
         [ 0.8001,  0.1751, -0.2877,  ..., -0.0113, -0.7751, -0.2699]],

        [[-0.2322,  0.4478, -0.2562,  ..., -0.1797,  0.0831,  0.0075],
         [ 0.3562,  0.3625, -0.2206,  ...,  0.2490,  0.4487,  0.0989],
         [ 0.7516,  0.0380, -0.2849,  ..., -0.0262, -0.8469, -0.4179]]],
       grad_fn=<NativeLayerNormBackward0>)

In [12]:
# Each encoded input has three positions; the first and last hold the special
# tokens (ids 101 and 102 in input_ids), so position 1 is the word itself.
# Row 0 is 'love', row 1 is 'hate'.
love_embedding = output.last_hidden_state[0, 1]
hate_embedding = output.last_hidden_state[1, 1]

In [13]:
# View the 'love' vector as a NumPy array. detach() returns a tensor cut off
# from the autograd graph, which .numpy() requires for gradient-tracking tensors.
love_embedding.detach().numpy()

array([ 3.86491150e-01,  3.61881346e-01,  2.34232858e-01, -3.95797491e-01,
        9.35692608e-01, -3.20417464e-01,  2.04268217e-01,  3.38452309e-01,
       -5.20041510e-02, -8.10698509e-01, -5.65118551e-01, -1.61363661e-01,
        3.03391844e-01,  3.40914965e-01, -4.51137006e-01, -3.86415571e-01,
        5.06600022e-01,  3.11436683e-01,  8.61643851e-01,  5.42768419e-01,
       -2.06672251e-01,  1.60407662e-01, -5.38816869e-01,  3.24378252e-01,
       -1.79528072e-02,  5.04499853e-01,  7.72020742e-02, -2.57169545e-01,
        5.15542328e-01,  3.37318391e-01,  7.22072482e-01,  7.50575811e-02,
       -1.59209967e-01,  3.52997780e-02, -8.67908359e-01, -3.30941498e-01,
        3.19348037e-01,  4.92764235e-01, -1.04944324e+00,  7.51347363e-01,
        8.89055729e-01, -4.18411404e-01,  1.28195584e-01, -9.13363278e-01,
        4.92032826e-01, -4.75546420e-02,  1.97652772e-01,  5.56535959e-01,
        1.00315318e-01, -2.43042260e-01, -1.55048698e-01,  3.88772845e-01,
        5.86102366e-01,  

In [14]:
# love_embedding

In [15]:
# x = [love_embedding.detach().numpy(), hate_embedding.detach().numpy()]
# from sklearn.decomposition import PCA
# pca = PCA(n_components=2)
# principalComponents = pca.fit_transform(x)
# principalDf = pd.DataFrame(data = principalComponents
#              , columns = ['principal component 1', 'principal component 2'])

In [16]:
# principalDf

In [17]:
# TODO: check for correct embedding
# TODO: make sure PCA is working the way it's supposed to
# TODO: make good visuals (categories)

In [18]:
# Cosine similarity between the 'love' and 'hate' vectors.
# cosine_similarity works on 2-D inputs, so each 1-D vector is reshaped into a
# single-row matrix; the result is therefore a 1x1 array.
cosine_similarity(
    love_embedding.detach().numpy().reshape(1, -1),
    hate_embedding.detach().numpy().reshape(1, -1),
)

array([[0.7019417]], dtype=float32)

# Smaller Example

In [19]:
# Load the labelled word list; the recorded preview shows a `word` column and a
# sentiment `label` column.
data = pd.read_csv(filepath_or_buffer='./words.csv')

In [20]:
# Preview the first rows of the word list.
data.head()

Unnamed: 0,word,label
0,loved,positive
1,terrific,positive
2,admired,positive
3,jolly,positive
4,brave,positive


In [21]:
# Tokenize every word on its own (batch of one), producing one encoding per row.
encoded_data = [tokenizer(word, return_tensors='pt') for word in data['word']]

In [22]:
# Run the encoder once per word. This is inference only, so autograd is switched
# off: otherwise a computation graph is kept alive for every word in the list.
with torch.no_grad():
    word_embeddings = [model(**encoding) for encoding in encoded_data]

In [23]:
# Keep only the final hidden layer of each model output.
embeddings = [result.last_hidden_state for result in word_embeddings]

In [24]:
# Reduce each (1, tokens, features) hidden state to one vector per word.
# Positions 0 and -1 are the special tokens the tokenizer adds around the word.
# A word can split into several word-pieces, so the pieces in between are
# averaged instead of keeping only the first one. For a word that is a single
# word-piece this is exactly the vector at position 1.
embeddings = [hidden[0, 1:-1].mean(dim=0) for hidden in embeddings]

In [25]:
# Convert every word vector from a torch tensor to a NumPy array.
embeddings = [vector.detach().numpy() for vector in embeddings]

In [26]:
# Attach the vectors to the word table, one array per row (this adds a column to `data` in place).
data['word_embeddings'] = embeddings

In [27]:
# Preview the table again, now including the word_embeddings column.
data.head()

Unnamed: 0,word,label,word_embeddings
0,loved,positive,"[0.5129337, -0.04044947, 0.42860067, -0.150437..."
1,terrific,positive,"[0.33330074, -0.20869645, -0.1652221, -0.06318..."
2,admired,positive,"[-0.25650117, 0.022439487, 0.25480837, -0.0880..."
3,jolly,positive,"[-0.22022378, -0.20243269, 0.10135959, -0.3264..."
4,brave,positive,"[-0.08306722, -0.11899652, -0.46022692, -0.438..."


In [31]:
# Inspect the stored vector for the row labelled 1 ('terrific' in the preview).
data.word_embeddings[1]
# Disabled PCA draft.
# NOTE(review): fit_transform's second positional argument is `y`, so this draft
# passes one 1-D vector as X and the other as y. Presumably the intent was to
# stack all rows into a single 2-D (n_words, n_features) matrix — confirm before enabling.
# p_c_a = PCA(n_components=2)
# principal_components = p_c_a.fit_transform(data.word_embeddings[0], data.word_embeddings[1])
# principal_df = pd.DataFrame(data = principal_components, columns=['pc_one', 'pc_two'])

array([ 3.33300740e-01, -2.08696455e-01, -1.65222093e-01, -6.31842613e-02,
        2.73006201e-01, -4.36092794e-01,  1.81016490e-01,  1.88062936e-01,
       -5.29205978e-01, -4.24398065e-01, -4.46680546e-01,  9.83447209e-02,
        6.48081958e-01,  3.54262978e-01,  1.12375468e-01,  1.41649470e-02,
        4.19651389e-01, -1.11519605e-01,  2.08385997e-02,  5.61181456e-02,
       -2.70916730e-01, -2.93104291e-01, -5.54386735e-01,  1.36999398e-01,
       -4.42754105e-02,  4.39313829e-01,  5.30328989e-01, -1.63119674e-01,
       -7.85207748e-01,  1.97202086e-01,  6.51294649e-01,  1.18266881e-01,
        1.23914115e-01, -4.48257998e-02, -9.81006145e-01, -4.78658974e-01,
        2.32817363e-02,  4.74838138e-01, -4.54063743e-01, -5.28650582e-01,
        6.65784478e-01, -3.98789108e-01,  6.25975490e-01, -3.47875059e-01,
       -2.09135264e-01,  4.14062738e-01,  4.31793928e-02,  3.95659089e-01,
        1.89450130e-01, -9.04028639e-02, -2.44704902e-01,  3.69696736e-01,
       -4.63985279e-03,  

In [34]:
# Show the standalone 'love' vector again as a NumPy array (same expression as the earlier inspection cell).
love_embedding.detach().numpy()

array([ 3.86491150e-01,  3.61881346e-01,  2.34232858e-01, -3.95797491e-01,
        9.35692608e-01, -3.20417464e-01,  2.04268217e-01,  3.38452309e-01,
       -5.20041510e-02, -8.10698509e-01, -5.65118551e-01, -1.61363661e-01,
        3.03391844e-01,  3.40914965e-01, -4.51137006e-01, -3.86415571e-01,
        5.06600022e-01,  3.11436683e-01,  8.61643851e-01,  5.42768419e-01,
       -2.06672251e-01,  1.60407662e-01, -5.38816869e-01,  3.24378252e-01,
       -1.79528072e-02,  5.04499853e-01,  7.72020742e-02, -2.57169545e-01,
        5.15542328e-01,  3.37318391e-01,  7.22072482e-01,  7.50575811e-02,
       -1.59209967e-01,  3.52997780e-02, -8.67908359e-01, -3.30941498e-01,
        3.19348037e-01,  4.92764235e-01, -1.04944324e+00,  7.51347363e-01,
        8.89055729e-01, -4.18411404e-01,  1.28195584e-01, -9.13363278e-01,
        4.92032826e-01, -4.75546420e-02,  1.97652772e-01,  5.56535959e-01,
        1.00315318e-01, -2.43042260e-01, -1.55048698e-01,  3.88772845e-01,
        5.86102366e-01,  