<a href="https://colab.research.google.com/github/DataJenius/NLPEncodingExperiment/blob/main/python/NLPEncodingExperiment_get_BERT_pt_embeddings.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
################################################################
# encode every Reddit comment as a 768-dimension BERT [CLS] embedding
# https://stackoverflow.com/questions/65627663/bert-extracting-cls-embedding-from-multiple-outputs-vs-single?rq=1
# https://stackoverflow.com/questions/59330597/bert-sentence-embeddings-how-to-obtain-sentence-embeddings-vector


In [None]:
###########################################
# install all dependencies
!pip install transformers



In [None]:
###########################################
# load dependencies
from transformers import AutoTokenizer, TFBertModel

#, AutoModel, AutoModelForSequenceClassification, Trainer, TrainingArguments
import pandas as pd
import numpy as np
import os
from google.colab import files

In [None]:
###############################################
# load the pretrained BERT tokenizer and the bare BERT encoder.
# TFBertModel has no classification head (the checkpoint's mlm/nsp heads are
# discarded on load), so no num_labels is needed: we only read hidden states.
BERT_CHECKPOINT = 'bert-base-uncased'
bert_tokenizer = AutoTokenizer.from_pretrained(BERT_CHECKPOINT)
# trainable=False: weights are frozen -- this notebook only runs inference.
bert_model = TFBertModel.from_pretrained(BERT_CHECKPOINT, trainable=False)

Downloading:   0%|          | 0.00/511M [00:00<?, ?B/s]

Some layers from the model checkpoint at bert-base-uncased were not used when initializing TFBertModel: ['nsp___cls', 'mlm___cls']
- This IS expected if you are initializing TFBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
All the layers of TFBertModel were initialized from the model checkpoint at bert-base-uncased.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBertModel for predictions without further training.


In [None]:
#################################################################################
# our labelled, raw comment data is on github, split across 5 group files
DATA_URL = 'https://raw.githubusercontent.com/DataJenius/NLPEncodingExperiment/main/data/comments/selected/'
all_files = [DATA_URL + 'selected_reddit_comments_group' + str(group) + '.csv'
             for group in range(1, 6)]

# Fixed seed so the shuffled row order -- and therefore which comments land in
# each saved embedding chunk -- is the same on every run. (Rows stay linked to
# their embeddings through msg_id regardless of order.)
SHUFFLE_SEED = 42

# concat into a single dataframe and shuffle the contents
df_from_each_file = (pd.read_csv(f) for f in all_files)
df_all_data = pd.concat(df_from_each_file, ignore_index=True)
df_all_data = df_all_data.sample(frac=1, random_state=SHUFFLE_SEED).reset_index(drop=True)
print(df_all_data.shape)
print(df_all_data.head())

(10000, 8)
   msg_id  token_count  my_group   my_role  label       source  \
0    9923           22         5      test      0  /r/StarWars   
1   38879           58         4  validate      1      /r/lotr   
2   29415           29         4  validate      1      /r/lotr   
3    4039           18         2     train      0  /r/StarWars   
4    1432           17         2     train      0  /r/StarWars   

                                            raw_text  \
0  Yep, RR directing was very bad, but there were...   
1  Way to ignore all context and arguments in goo...   
2  Again, you're basing this entire opinion on a ...   
3  And the spice syndicate, don't they have like ...   
4  I think a limited run of 5-8 episodes could be...   

                                          clean_text  
0  yep rr directing was very bad but there were a...  
1  way to ignore all context and arguments in goo...  
2  again youre basing this entire opinion on a te...  
3  and the spice syndicate dont the

In [None]:
#################################################################################
# Tokenize every comment to exactly 500 tokens: longer comments are truncated,
# shorter ones are padded, so all rows share one shape.
texts_to_encode = df_all_data["raw_text"].tolist()
tokenizer_output = bert_tokenizer(
    texts_to_encode,
    truncation=True,
    max_length=500,
    padding='max_length',
    return_tensors='np',
)

# Stack into a single array: index 0 -> input_ids, index 1 -> attention_mask,
# giving shape (2, n_comments, 500).
tokenizer_output_np = np.stack(
    (tokenizer_output['input_ids'], tokenizer_output['attention_mask'])
)
print(tokenizer_output_np.shape)

(2, 10000, 500)


In [None]:
#################################################################################
# Extract the [CLS] embedding (token 0 of BERT's last hidden state) for every
# tokenized comment. Results are saved as CSV chunks of SAVE_EVERY rows, and
# each chunk is downloaded so progress survives a Colab disconnect.
#
# Each saved row holds the embedding (columns 0..767) plus the comment's msg_id,
# which links the embedding back to its row in df_all_data.
SAVE_EVERY = 1000  # rows per saved CSV file
BATCH_SIZE = 32    # comments per forward pass; 1 reproduces one-at-a-time inference


def _cls_embeddings(model, input_ids, attention_mask, batch_size):
    """Return the [CLS] vector of the model's last hidden state for each row.

    Parameters
    ----------
    model : TFBertModel
        Encoder whose ``predict`` output carries the last hidden state.
    input_ids, attention_mask : np.ndarray
        Token arrays of shape (n, seq_len), i.e. row slices of ``tokenizer_output_np``.
    batch_size : int
        Rows per ``predict`` call. This bounds memory: each call materialises a
        (batch, seq_len, hidden) float array before the [CLS] slice is taken.

    Returns
    -------
    np.ndarray
        Shape (n, hidden), cast to float64 so CSV values are written the same
        way as when each vector went through Python floats via ``.tolist()``.
    """
    cls_parts = []
    for start in range(0, len(input_ids), batch_size):
        stop = start + batch_size
        output = model.predict([input_ids[start:stop], attention_mask[start:stop]],
                               batch_size=batch_size, verbose=0)
        # predict() may return a tuple-like or a dict-like (ModelOutput) result
        # depending on library versions; take the last hidden state from either.
        hidden = output["last_hidden_state"] if isinstance(output, dict) else output[0]
        cls_parts.append(np.asarray(hidden)[:, 0, :])  # [CLS] is token 0
    return np.concatenate(cls_parts).astype(np.float64)


n_comments = tokenizer_output_np.shape[1]
assert n_comments == len(df_all_data), "token arrays and comments are out of sync"

for chunk_start in range(0, n_comments, SAVE_EVERY):
    # min() makes sure a trailing partial chunk is embedded and saved too
    chunk_stop = min(chunk_start + SAVE_EVERY, n_comments)
    print(f"embedding comments {chunk_start}..{chunk_stop - 1} of {n_comments}")

    cls_vectors = _cls_embeddings(bert_model,
                                  tokenizer_output_np[0][chunk_start:chunk_stop],
                                  tokenizer_output_np[1][chunk_start:chunk_stop],
                                  BATCH_SIZE)

    # Build the whole chunk in one go: row-by-row DataFrame.append was quadratic
    # and no longer exists in pandas 2.0.
    df_embeddings = pd.DataFrame(cls_vectors, index=range(chunk_start, chunk_stop))
    df_embeddings["msg_id"] = df_all_data["msg_id"].iloc[chunk_start:chunk_stop].to_numpy()

    # save results to local CSV; the name is the index of the chunk's last row
    file_name = 'BERT_embeddings_i' + str(chunk_stop - 1) + '.csv'
    df_embeddings.to_csv(file_name, index=False)
    files.download(file_name)
    print("\n\nsave progress...\n\n", df_embeddings.head(10))

0 1
1 2
2 3
3 4
4 5
5 6
6 7
7 8
8 9
9 10
10 11
11 12
12 13
13 14
14 15
15 16
16 17
17 18
18 19
19 20
20 21
21 22
22 23
23 24
24 25
25 26
26 27
27 28
28 29
29 30
30 31
31 32
32 33
33 34
34 35
35 36
36 37
37 38
38 39
39 40
40 41
41 42
42 43
43 44
44 45
45 46
46 47
47 48
48 49
49 50
50 51
51 52
52 53
53 54
54 55
55 56
56 57
57 58
58 59
59 60
60 61
61 62
62 63
63 64
64 65
65 66
66 67
67 68
68 69
69 70
70 71
71 72
72 73
73 74
74 75
75 76
76 77
77 78
78 79
79 80
80 81
81 82
82 83
83 84
84 85
85 86
86 87
87 88
88 89
89 90
90 91
91 92
92 93
93 94
94 95
95 96
96 97
97 98
98 99
99 100
100 101
101 102
102 103
103 104
104 105
105 106
106 107
107 108
108 109
109 110
110 111
111 112
112 113
113 114
114 115
115 116
116 117
117 118
118 119
119 120
120 121
121 122
122 123
123 124
124 125
125 126
126 127
127 128
128 129
129 130
130 131
131 132
132 133
133 134
134 135
135 136
136 137
137 138
138 139
139 140
140 141
141 142
142 143
143 144
144 145
145 146
146 147
147 148
148 149
149 150
150 151
151 152
15

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>



save progress...

           0         1         2         3         4         5         6  \
0  0.007723 -0.025574  0.175493 -0.038083 -0.329894 -0.110158  0.157542   
1  0.341545  0.185957 -0.088983 -0.016812 -0.402677 -0.264315  0.137367   
2  0.051385 -0.205393 -0.067455  0.084979 -0.575922 -0.201654  0.131130   
3  0.244150 -0.063882  0.190421  0.056093 -0.722250 -0.145684  0.059774   
4  0.007838 -0.398030  0.048651 -0.297497 -0.310664 -0.298068  0.001150   
5 -0.205928 -0.681554 -0.079084  0.202485 -0.698614 -0.418531 -0.434251   
6  0.436328 -0.608490  0.402088  0.153449 -0.612226 -0.604861  0.129110   
7 -0.389868 -0.494005 -0.253824  0.305682 -0.621787 -0.361605 -0.021366   
8  0.184861 -0.369838 -0.019996  0.031214 -0.322090 -0.491898 -0.028071   
9  0.248473 -0.073738  0.218109 -0.225075 -0.579103 -0.375420  0.429654   

          7         8         9  ...       759       760       761       762  \
0  0.624229  0.160256 -0.574947  ... -0.594470  0.141395 -0.278089  0.137

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>



save progress...

              0         1         2         3         4         5         6  \
1000  0.044482 -0.310627  0.011957 -0.386451 -0.501136 -0.623518  0.189006   
1001  0.353751 -0.486148  0.664450  0.015563  0.049035 -0.551894  0.265181   
1002  0.327814 -0.214555  0.074856  0.144886 -0.870469 -0.134359 -0.333768   
1003  0.216073 -0.156897 -0.067295  0.440750 -0.431627 -0.141125  0.242515   
1004 -0.114648 -0.144961  0.190749  0.173238 -0.415455 -0.810847  0.077555   
1005  0.576664 -0.008391  0.246919 -0.029492  0.237073 -0.014608  0.260022   
1006 -0.079233 -0.362919  0.154612  0.101771 -0.448867 -0.453476 -0.118294   
1007  0.183946  0.419669  0.010117 -0.350259 -0.469092 -0.474982  0.428322   
1008 -0.401581 -0.172269 -0.452259  0.002972 -0.856626  0.002627  0.477895   
1009  0.157893 -0.007522  0.461957 -0.177991 -0.341115 -0.399954  0.292756   

             7         8         9  ...       759       760       761  \
1000  0.736995  0.447655 -0.178106  ... -0.5152

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>



save progress...

              0         1         2         3         4         5         6  \
2000 -0.169047  0.051896 -0.437290 -0.140519 -0.325450 -0.118745 -0.176385   
2001  0.047366 -0.294134 -0.209982 -0.191281 -0.365260 -0.330248  0.122568   
2002 -0.195334 -0.218001 -0.033639  0.027422 -0.224368 -0.554424 -0.231550   
2003 -0.093748 -0.031999 -0.192881 -0.484286 -0.255244 -0.484653  0.214267   
2004  0.498683  0.313168  0.416882 -0.599127 -0.357528 -0.344527  0.677865   
2005  0.131450 -0.266419  0.409662  0.059815 -0.216402 -0.391881  0.025815   
2006 -0.303346 -0.979906 -0.006425 -0.511717 -0.418003 -0.347349 -0.142098   
2007 -0.018670  0.095594 -0.522275 -0.117483 -0.702087 -0.550364  0.245768   
2008 -0.066318 -0.491389 -0.444886 -0.083209 -0.787048 -0.179607 -0.445208   
2009  0.064420 -0.256464  0.192897 -0.273768 -0.242173 -0.588877  0.071819   

             7         8         9  ...       759       760       761  \
2000  0.238382  0.153004 -0.285203  ... -0.4311

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>



save progress...

              0         1         2         3         4         5         6  \
3000  0.223520  0.078309 -0.065063  0.339993 -0.894714 -0.139897  0.282089   
3001 -0.222325  0.129116 -0.024783  0.130754 -0.617399  0.027512  0.180820   
3002  0.216769 -0.088915 -0.337220 -0.184133 -0.248928 -0.388060  0.179488   
3003  0.047779  0.009021  0.216328  0.050217 -0.583657 -0.162699  0.370643   
3004  0.202744 -0.346893 -0.400062 -0.061221 -0.681065 -0.060310  0.137107   
3005 -0.067730 -0.446351  0.238151 -0.305641 -0.521404 -0.439547  0.257423   
3006 -0.011840 -0.084725  0.023957  0.257405 -0.459756 -0.515967  0.346956   
3007 -0.089580 -0.398766 -0.439232  0.468536 -0.375929 -0.215812 -0.287776   
3008 -0.072175 -0.045297  0.125705 -0.080221 -0.865534 -0.020435  0.349743   
3009  0.364442 -0.036703  0.253396 -0.268762 -0.168965 -0.120088  0.445753   

             7         8         9  ...       759       760       761  \
3000  0.513092  0.086421 -0.408794  ... -0.3803

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>



save progress...

              0         1         2         3         4         5         6  \
4000 -0.197073 -0.487693  0.225170  0.011303 -0.801810 -0.034710  0.302273   
4001 -0.115582 -0.426119 -0.017874 -0.164226 -0.358396 -0.551242 -0.178921   
4002  0.157013  0.074554  0.069331  0.010963 -0.206704 -0.433839  0.272444   
4003 -0.362344 -0.667772  0.141796  0.159639 -0.502621 -0.359730 -0.155155   
4004  0.184357  0.149457 -0.085492 -0.292684 -0.405697 -0.440149  0.048892   
4005  0.105583  0.169103 -0.396065  0.048229 -0.601005 -0.268363 -0.111210   
4006  0.184320 -0.093221  0.027526 -0.075397 -0.428329 -0.576483 -0.093503   
4007  0.153408 -0.030137 -0.022466 -0.061353 -0.313072  0.106574  0.272106   
4008 -0.181815  0.198009 -0.148819  0.110888 -0.388441 -0.177471 -0.384168   
4009  0.334540 -0.636705 -0.226619  0.352922 -0.336663 -0.606143 -0.648719   

             7         8         9  ...       759       760       761  \
4000  0.332527 -0.234938 -0.367399  ... -0.6754

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>



save progress...

              0         1         2         3         4         5         6  \
5000 -0.223043 -0.220668 -0.008518  0.319638 -0.591191 -0.360585 -0.059769   
5001 -0.265886  0.054243 -0.291478 -0.440960 -0.108685 -0.134640  0.120237   
5002  0.259322 -0.315193  0.048649  0.336300 -0.420177 -0.025799 -0.060935   
5003 -0.460033 -0.048545 -0.187315 -0.020369 -0.581050 -0.194240  0.261992   
5004  0.227871 -0.197131 -0.597947  0.094650 -0.573975 -0.622309  0.334084   
5005 -0.113949 -0.178736  0.361607  0.324255 -0.545913 -0.424971  0.387817   
5006 -0.141983 -0.027733  0.746989 -0.258104 -0.544662 -0.424887  0.124975   
5007  0.005824 -0.023382  0.053872 -0.297711 -0.389156 -0.404627  0.121578   
5008  0.162099  0.086180  0.317405  0.142770 -0.527357 -0.511904  0.307947   
5009  0.236441  0.051686  0.377807 -0.112795 -0.710133 -0.316042  0.314812   

             7         8         9  ...       759       760       761  \
5000  0.232537  0.267376 -0.183015  ... -0.2436

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>



save progress...

              0         1         2         3         4         5         6  \
6000  0.057713 -0.296822 -0.278459  0.562855 -0.166555 -0.532253 -0.204827   
6001  0.319087 -0.295068  0.254527 -0.034498 -0.755354 -0.387920 -0.032774   
6002  0.370987 -0.266930  0.160087 -0.189839 -0.437752 -0.467280  0.066221   
6003  0.341367  0.018423  0.257837 -0.350166 -0.456960 -0.249546  0.357720   
6004 -0.125385  0.092718 -0.006860 -0.053224 -0.284245 -0.419431 -0.015725   
6005 -0.587162  0.234852 -0.003351 -0.111512 -0.136963 -0.132590  0.083665   
6006  0.132693  0.004501  0.263799 -0.427792 -0.210223 -0.265674  0.210355   
6007  0.583655 -0.226516  0.093201  0.053656 -0.239784 -0.589727  0.167788   
6008 -0.158507  0.218863  0.124950 -0.183512 -0.590424  0.026242  0.336536   
6009 -0.329731 -0.211487  0.605021  0.105369 -0.421001 -0.287727  0.121154   

             7         8         9  ...       759       760       761  \
6000  0.479311 -0.235834 -0.007217  ... -0.5082

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>



save progress...

              0         1         2         3         4         5         6  \
7000  0.387666  0.020113  0.391210 -0.181992 -0.395911 -0.462945  0.454598   
7001  0.072350 -0.556157 -0.199769  0.169695 -0.983217 -0.251525  0.031349   
7002  0.345303  0.460067  0.024411 -0.188794 -0.549817 -0.492763  0.719411   
7003  0.348857  0.039958  0.427950  0.128136 -0.083189 -0.374783  0.343466   
7004  0.143201 -0.274372  0.402892  0.104848 -0.690015 -0.269729 -0.163506   
7005  0.230012  0.160443  0.015181 -0.151319 -0.276459 -0.473565  0.717835   
7006 -0.125230 -0.548312  0.333534  0.023651 -0.851254 -0.356904  0.262676   
7007  0.153443 -0.677614  0.324012 -0.280727 -0.455693 -0.398441 -0.269778   
7008  0.051075 -0.004651  0.165945 -0.167595 -0.660474 -0.269096  0.134382   
7009  0.339639 -0.609704  0.232759  0.035154 -0.520723 -0.416084 -0.326989   

             7         8         9  ...       759       760       761  \
7000  0.604792  0.185253 -0.618157  ... -0.5361

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>



save progress...

              0         1         2         3         4         5         6  \
8000  0.210978 -0.365924  0.241585 -0.013909 -0.389829 -0.073530  0.229285   
8001  0.275279  0.134905  0.284627 -0.192636 -0.186372 -0.379877  0.330736   
8002  0.639710  0.224408 -0.293656  0.104530 -0.289828 -0.494360  0.575686   
8003  0.004174 -0.142704  0.006451 -0.317812 -0.272723 -0.395661  0.105163   
8004  0.600863 -0.205149  0.044715  0.101431 -0.497202 -0.501527  0.100017   
8005  0.183072 -0.352979 -0.378455 -0.317445 -0.426657 -0.219283  0.224555   
8006  0.018982  0.113867  0.047510 -0.163166 -0.363451 -0.137090  0.363874   
8007  0.288142  0.022036 -0.203982  0.200026 -0.326876 -0.709024  0.112925   
8008  0.491562  0.314121 -0.227868 -0.163597 -0.224419 -0.350725  0.176580   
8009 -0.505444 -0.419083 -0.140344  0.452833 -0.418154 -0.410523  0.018005   

             7         8         9  ...       759       760       761  \
8000  0.518172  0.063919 -0.084734  ... -0.2240

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>



save progress...

              0         1         2         3         4         5         6  \
9000 -0.048557 -0.168804 -0.014584 -0.418477 -0.284797 -0.151885  0.489796   
9001  0.165299  0.011111  0.028682 -0.177584 -0.507271 -0.711152  0.365144   
9002  0.252007 -0.272850  0.332500 -0.119530  0.038965 -0.337528 -0.030102   
9003 -0.707969 -0.373064 -0.894306  0.394536 -0.836453 -0.096555  0.037092   
9004  0.033425 -0.310231 -0.155892  0.161684 -1.023448 -0.266278  0.013619   
9005 -0.349715 -0.277529 -0.001990  0.324522 -0.658866 -0.010215  0.277428   
9006 -0.155868 -0.029656  0.444205  0.145694 -0.370254 -0.117515  0.282467   
9007  0.342806  0.172664 -0.310461 -0.343477 -0.573348 -0.418449  0.365239   
9008  0.290053 -0.205997 -0.313159 -0.044952 -0.130626 -0.357653  0.333982   
9009  0.056056 -0.120343  0.520524 -0.373062 -0.084793 -0.396445  0.140631   

             7         8         9  ...       759       760       761  \
9000  0.546657  0.134113 -0.123901  ... -0.3701