<a href="https://colab.research.google.com/github/Taedriel/ZSL-v2/blob/wordEmbedding/WordsEmbeddings.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!yes | pip install transformers wget tensorflow_datasets wikipedia deprecated vecto unzip gluonnlp mxnet --quiet
!mkdir -p temp

In [18]:
import gc
import logging
import os
import pickle
from enum import Enum
from os.path import exists
from typing import List, Optional, Tuple

import tensorflow_datasets as tfds
import tensorflow as tf
import torch
import wikipedia
import numpy as np
import mxnet as mx
import gluonnlp as nlp
from transformers import BertTokenizer, BertModel
from tqdm import tqdm
from deprecated import deprecated

logging.basicConfig(level = logging.INFO, filename = "BERT.log" )
# logging.basicConfig(level = logging.INFO)

In [3]:
class customWikiArticle:
    """Plain record describing one Wikipedia article looked up for a requested title.

    Attributes:
        index: position assigned to the article by whoever created it.
        title: the title that was asked for.
        realtitle: the title of the article that was actually found.
        summary: the article summary text.
        ambiguous: whether the requested title was ambiguous.
    """

    def __init__(self, index : int, title : str, realtitle : str, summary : str, ambiguous : bool):
        self.index, self.title, self.realtitle = index, title, realtitle
        self.summary, self.ambiguous = summary, ambiguous

class ArticleRetriever:
    """Fetches Wikipedia articles for a list of titles and keeps them in a picklable map.

    ``articles_map`` maps each requested title to the ``customWikiArticle`` built for it.
    The map is loaded from ``article_dir/<name>`` on construction when that file exists,
    and written back there by ``save``.
    """

    # Directory in which the pickled article maps are stored.
    article_dir = "./article"

    # Upper bound on the number of titles tried (original + fallbacks) for one request.
    max_attempts = 5

    def __init__(self, name : str, list_title : List[str]):
        """
        Args:
            name: file name (inside ``article_dir``) of the pickled article map.
            list_title: titles to retrieve with ``load_all_articles``.
        """
        self.name = name
        self.list_title = list_title

        if not exists(self.get_filename()):
            self.articles_map = {}
        else:
            # NOTE(security): pickle.load can execute arbitrary code; only load map files
            # that were produced by this class on a trusted machine.
            # Pickle is a binary format, so the file must be opened in binary mode.
            with open(self.get_filename(), "rb") as mapfile:
                self.articles_map = pickle.load(mapfile)

    def get_filename(self):
        """Return the path of the pickled map: ``article_dir/<name>``."""
        return ArticleRetriever.article_dir + "/" + self.name

    def load_article(self, title) -> "customWikiArticle":
        """Return the article for ``title``, retrieving it from Wikipedia on first request."""
        if title not in self.articles_map:
            realtitle, summary, ambiguous = self._retrieve_article(title)
            self.articles_map[title] = customWikiArticle(len(self.articles_map), title, realtitle, summary, ambiguous)

        return self.articles_map[title]

    def _retrieve_article(self, title, closed_list : Optional[List] = None) -> Tuple[str, Optional[str], bool]:
        """Look ``title`` up on Wikipedia, falling back on search results when it fails.

        Args:
            title: title to look up.
            closed_list: titles already tried for this request; it is extended in place
                so that no title is tried twice. A fresh list is used when omitted.

        Returns:
            ``(real_title, summary, ambiguous)``. ``summary`` is None when no article
            could be retrieved; ``ambiguous`` is True when a disambiguation page was
            met on the way.
        """
        if closed_list is None:
            closed_list = []
        closed_list.append(title)

        try:
            article = wikipedia.page(title, auto_suggest=False, redirect=True)
            return (article.title, article.summary, False)
        # A tuple is required here: `except A or B` evaluates to `except A` only.
        except (wikipedia.PageError, wikipedia.DisambiguationError) as e:
            ambiguous = isinstance(e, wikipedia.DisambiguationError)
            if ambiguous:
                logging.warning(f"{title} is ambiguous, trying first")
            else:
                logging.warning(f"{title} misspelled or article missing.")

        if len(closed_list) >= ArticleRetriever.max_attempts:
            logging.warning(f"Maximum number of attempts reached for {title}")
            return (title, None, ambiguous)

        search_result = wikipedia.search(title, suggestion = False)
        # First search hit that has not been tried yet; the search may also return nothing.
        candidate = next((found for found in search_result if found not in closed_list), None)
        if candidate is None:
            return (title, None, ambiguous)

        logging.warning(f"\tBest find is {candidate}")
        realtitle, summary, nested_ambiguous = self._retrieve_article(candidate, closed_list)
        return (realtitle, summary, ambiguous or nested_ambiguous)

    def load_all_articles(self) -> None:
        """Retrieve every title of ``list_title`` and log how many have a summary."""
        logging.info("Starting loading articles...")
        nb_success = 0

        for title in tqdm(self.list_title, total=len(self.list_title)):
            if self.load_article(title).summary is not None:
                nb_success += 1
        logging.info(f"Finished loading {nb_success} article(s) !")

    def get_article(self, title):
        """Return the already-loaded article for ``title``, or None when it is unknown."""
        return self.articles_map[title] if title in self.articles_map else None

    def save(self):
        """Pickle the article map to ``get_filename()``, creating ``article_dir`` if needed."""
        os.makedirs(ArticleRetriever.article_dir, exist_ok=True)
        with open(self.get_filename(), "wb") as mapfile:
            pickle.dump(self.articles_map, mapfile)

class WordToVector:
    """Base class for models that turn a list of tags into embedding vectors.

    Subclasses fill ``self.embeddings`` in ``convert`` and write it out in ``export``.
    """

    def __init__(self, list_tags : Optional[List[str]] = None):
        """
        Args:
            list_tags: tags to convert. The given list is stored as-is (not copied);
                a fresh empty list is created when omitted, so that instances never
                share one default list.
        """
        self.list_tags = [] if list_tags is None else list_tags
        self.embeddings = []

    def export(self, filename):
        """Write the embeddings to ``filename``; must be implemented by subclasses.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def convert(self):
        """Compute the embeddings of all tags; must be implemented by subclasses.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError


class BERTModel(WordToVector):
    """Builds one BERT embedding per tag, using the tag's Wikipedia summary as context.

    For every tag a Wikipedia summary is looked up (and cached as a text file under
    ``temp_dir``). The tag followed by the beginning of that summary is fed to BERT, and
    the hidden states of the first token ([CLS]) are merged by ``merging_strategy`` into
    a single vector. Results are kept in ``self.embeddings`` as ``(tag, vector)`` tuples.

    ``cosine_sim_matrix`` is either None or a complete pairwise matrix built by
    ``compute_sim`` for the current embeddings.
    """

    # Directory holding one cached summary file per looked-up title.
    temp_dir = "./temp"

    def __init__(self, list_tag : Optional[List[str]] = None, big: bool = False):
        """
        Args:
            list_tag: tags to embed. The given list is stored as-is; a fresh list is
                created when omitted (import_tag_list appends to it in place, so a
                shared default list would leak tags between instances).
            big: load "bert-large-uncased" instead of "bert-base-uncased".
        """
        super(BERTModel, self).__init__([] if list_tag is None else list_tag)
        # tag -> (article title, summary), or None when no article could be retrieved
        self.list_articles = {}

        self.model_size = "bert-large-uncased" if big else "bert-base-uncased"
        self.cosine_sim_matrix = None

        self.tokenizer = BertTokenizer.from_pretrained(self.model_size, padding=True, truncation=True,)
        self.model = BertModel.from_pretrained(self.model_size, output_hidden_states = True)

        self.merging_strategy = Sum4LastLayers()

        # inference only
        self.model.eval()

    def export(self, filename):
        """Write the embeddings to ``filename``, one ``tag,v1,v2,...`` line per tag.

        Raises:
            Exception: if ``convert`` has not produced any embedding yet.
            OSError: if the file cannot be opened for writing.
        """
        if len(self.embeddings) == 0:
            raise Exception("Tags not converted yet !")

        try:
            f = open(filename, "w")
        except OSError as err:
            raise OSError("Could not open file") from err

        with f:
            for tag, vector in self.embeddings:
                line = ",".join(str(float(x)) for x in vector)
                print(tag, ",", line, sep="", file=f)

    @deprecated(reason = "use constructor after manual import of the file is preffered")
    def import_tag_list(self, filename):
        """Append the tags listed in ``filename`` (one per line) to ``list_tags``.

        Blank lines and tags already present are skipped. Existing embeddings and the
        similarity matrix are discarded because they no longer match the tag list.

        Raises:
            OSError: if the file cannot be opened.
        """
        self.embeddings.clear()
        self.cosine_sim_matrix = None

        try:
            f = open(filename, "r")
        except OSError as err:
            # The error used to be returned instead of raised, and then silently ignored.
            raise OSError("Could not open file") from err

        with f:
            for line in f.read().splitlines():
                item = line.strip()
                if item != "" and item not in self.list_tags:
                    self.list_tags.append(item)
            print(self.list_tags)
            logging.info(f"Import finished : {len(self.list_tags)} elements imported.")

    def _retrieve_summary(self, title, depth : int, include_ambiguous : bool = True) -> Optional[Tuple[str, str]]:
        """Return ``(article title, summary)`` for ``title``, or None when nothing is found.

        The summary is read from the cache file of ``title`` when it exists. Otherwise
        the first Wikipedia search hit is fetched and cached. On a missing page the
        lookup is retried with that first hit; on a disambiguation page it is retried
        with the second hit when ``include_ambiguous`` is True.

        Args:
            title: title (tag) to look up.
            depth: remaining number of lookups allowed; None is returned at 0.
            include_ambiguous: whether to follow up on ambiguous titles.
        """
        if depth == 0:
            logging.warning(f"Maximum recursion depth for {title}")
            return None

        # "/" in a title would otherwise be read as a sub-directory of temp_dir.
        filename = BERTModel.temp_dir + "/" + title.replace(" ", "_").replace("/", "_")

        if exists(filename):
            with open(filename, "r", encoding="utf-8") as f:
                return (title, f.read())

        search_result = wikipedia.search(title, suggestion = False)
        if not search_result:
            logging.warning(f"No result for {title}, skipping")
            return None
        best_match = search_result[0]

        try:
            summary = wikipedia.page(best_match, auto_suggest=False, redirect=True).summary

        except wikipedia.PageError:
            logging.warning(f"{title} misspelled or article missing.")
            logging.warning(f"\tBest find is {best_match}")
            return self._retrieve_summary(best_match, depth - 1, include_ambiguous)

        except wikipedia.DisambiguationError:
            logging.warning(f"{title} is ambiguous")
            # A single search hit leaves no alternative to fall back on.
            if include_ambiguous and len(search_result) > 1:
                return self._retrieve_summary(search_result[1], depth - 1, include_ambiguous)
            return None

        os.makedirs(BERTModel.temp_dir, exist_ok=True)
        with open(filename, "w", encoding="utf-8") as f:
            logging.info(f"Saving {best_match} as {title}")
            print(summary, file=f)
        return (best_match, summary)

    def _retrieve_all_articles(self, include_ambiguous : bool = True):
        """Fill ``list_articles`` with the summary (or None) of every tag."""
        logging.info(f"Starting retrieveing articles {'(including ambiguous)' if include_ambiguous else ''}...")
        nb_success = 0

        for tag in tqdm(self.list_tags, total=len(self.list_tags)):
            res = self._retrieve_summary(tag, 5, include_ambiguous)
            self.list_articles[tag] = res

            if res is not None:
                nb_success += 1

        logging.info(f"Finished retrieving {nb_success} article(s) !")

    def convert(self, force_retrieve : bool = False, include_ambiguous : bool = True, context_size : int = 100):
        """Convert every tag into its embedding.

        Tags for which no article could be retrieved are skipped. Each call rebuilds
        ``self.embeddings`` from scratch and invalidates the similarity matrix.

        Args:
            force_retrieve: look the articles up again even if all tags have one.
            include_ambiguous: forwarded to the article retrieval.
            context_size: number of leading characters of the summary appended to the
                tag as context.

        Raises:
            Exception: if there is no tag to convert.
            ValueError: if ``context_size`` is negative.
        """
        if len(self.list_tags) == 0:
            raise Exception("not tags yet !")
        if context_size < 0:
            raise ValueError("context_size must be >= 0")

        if force_retrieve or any(tag not in self.list_articles for tag in self.list_tags):
            self._retrieve_all_articles(include_ambiguous)

        # Start over: appending to previous results would duplicate every tag on a re-run.
        self.embeddings = []
        self.cosine_sim_matrix = None

        logging.info("Starting converting tokens...")

        for tag in tqdm(self.list_tags, total = len(self.list_tags)):

            article = self.list_articles[tag]
            if article is None: continue

            tag_plus_context = tag + ". " + article[1][:context_size]
            # truncation keeps a large context within the model's maximum input length
            inputs = self.tokenizer(tag_plus_context, return_tensors = "pt", truncation = True)

            with torch.no_grad():
                outputs = self.model(**inputs)

            hidden_states = outputs[2]

            # [# layers, # batches, # tokens, # features] ==> [# tokens, # layers, # features]
            token_embeddings = torch.stack(hidden_states, dim=0)
            token_embeddings = torch.squeeze(token_embeddings, dim=1)
            token_embeddings = token_embeddings.permute(1,0,2)

            # here we are just taking the [CLS] (for classification) as an embedding for the tag
            self.embeddings.append((tag, self.merging_strategy.merge(token_embeddings[0])))

    def compute_sim(self):
        """Compute the cosine similarity between all pairs of embeddings.

        The result replaces ``cosine_sim_matrix``; the diagonal is left at 1.

        Raises:
            Exception: if ``convert`` has not produced any embedding yet.
        """
        if len(self.embeddings) == 0:
            raise Exception("Tags not converted yet !")

        logging.info("Computing cosine similarity, this could take some time...")

        n_tokens = len(self.embeddings)
        self.cosine_sim_matrix = [[1 for _ in range(n_tokens)] for _ in range(n_tokens)]
        cos = torch.nn.CosineSimilarity(dim=0)

        for j, (_, vector) in tqdm(enumerate(self.embeddings), total = n_tokens):
            # The matrix is symmetric: each pair is computed once and mirrored.
            for i in range(j + 1, n_tokens):
                similarity = cos(vector, self.embeddings[i][1])

                self.cosine_sim_matrix[i][j] = similarity
                self.cosine_sim_matrix[j][i] = similarity

    def export_sim_matrix(self, filename):
        """Write the similarity matrix as CSV, computing it first when needed.

        The first line is ``/`` followed by the tags; every following line is a tag
        followed by its similarities rounded to 3 decimals.

        Raises:
            OSError: if the file cannot be opened for writing.
        """
        if self.cosine_sim_matrix is None:
            self.compute_sim()

        try:
            f = open(filename, "w")
        except OSError as err:
            raise OSError("Could not open file") from err

        with f:
            tags = [tag for tag, _ in self.embeddings]
            print("/", ",".join(tags), sep = ",", file = f)

            for tag, row in zip(tags, self.cosine_sim_matrix):
                print(tag, ",".join(str(round(float(value), 3)) for value in row), sep = ",", file = f)

    def _vector_of(self, token):
        """Return the vector of the first embedding whose tag is ``token``, or None."""
        for tag, vector in self.embeddings:
            if tag == token:
                return vector
        return None

    def sim_between(self, token1, token2):
        """Return the cosine similarity between the embeddings of two tags.

        The value is computed on demand and ``cosine_sim_matrix`` is left untouched:
        storing single values in a placeholder matrix would make ``export_sim_matrix``
        treat that mostly-empty matrix as complete.

        Raises:
            IndexError: if one of the tokens has no embedding.
        """
        v1 = self._vector_of(token1)
        v2 = self._vector_of(token2)

        for token, vector in ((token1, v1), (token2, v2)):
            if vector is None:
                raise IndexError(f"no such token {token}")

        cos = torch.nn.CosineSimilarity(dim=0)
        return cos(v1, v2)

    def get_embedding_of(self, token):
        """Return the ``(tag, vector)`` tuple of ``token``.

        Raises:
            Exception: if the token has no embedding.
        """
        res = [v for v in self.embeddings if v[0] == token]
        if len(res) == 0:
            raise Exception(f"no such token {token}")

        return res[0]

    def get_nearest_embedding_of(self, embedding, nb = 10):
        """Return the ``nb`` tags most similar to ``embedding``.

        Returns:
            A list of ``(tag, similarity)`` tuples sorted by increasing cosine
            similarity, i.e. the nearest tag comes last.

        Raises:
            Exception: if ``nb`` exceeds the number of embeddings.
        """
        if nb > len(self.embeddings):
            raise Exception("nb too high, not enough token")
        if nb <= 0:
            # nearest[-0:] would return the whole list instead of nothing
            return []

        cos = torch.nn.CosineSimilarity(dim=0)
        nearest = [(tag, cos(embedding, vector)) for tag, vector in self.embeddings]

        nearest.sort(key = lambda tup : tup[1])
        return nearest[-nb:]

    def get_class_list(self):
        """Return the tags that have an embedding, in conversion order."""
        return [x[0] for x in self.embeddings]

class Sum4LastLayers:
    """Merging strategy that adds up the last four entries along the first dimension."""

    def merge(self, vector):
        """Return the sum over dim 0 of the last four slices of ``vector``."""
        last_four = vector[-4:]
        return last_four.sum(dim=0)





In [4]:
class EmbeddingTest:
    """Loads embeddings from a CSV file so that they can be evaluated.

    Each non-blank line of the file is expected to be ``tag,v1,v2,...`` (the layout
    written by ``BERTModel.export``). ``embeddings`` maps each tag to its list of
    float values. The tag is everything before the first comma, so tags that contain
    a comma themselves are not supported.
    """

    def __init__(self, filename : str):
        """
        Args:
            filename: path of the CSV file to load.

        Raises:
            IOError: if the file cannot be read.
        """
        self.filename = filename
        # Same value under the attribute name used previously, kept for compatibility.
        self.file = filename
        self.embeddings = {}

        self._load_file()

    def _load_file(self):
        """Parse ``self.filename`` into ``self.embeddings``."""
        try:
            with open(self.filename, "r") as f:
                lines = f.read().splitlines()

        except IOError as e:
            raise IOError(f"No file {self.filename}") from e

        for line in lines:
            if line.strip() == "":
                continue
            # Split the line into fields; indexing the raw string would yield characters.
            tag, *values = line.split(",")
            self.embeddings[tag] = [float(value) for value in values]

    def evaluate(self):
        """Evaluate the loaded embeddings; to be implemented by subclasses."""
        raise NotImplementedError

In [None]:
# Download (or reuse the Keras-cached copy of) the ImageNet label list, one label per line.
labels_path = tf.keras.utils.get_file('ImageNetLabels.txt','https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt')
# Read through a context manager so the file handle is closed; splitlines() already
# yields a plain list of strings, so no numpy round-trip is needed.
with open(labels_path) as labels_file:
    imagenet_labels = labels_file.read().splitlines()
len(imagenet_labels)

1001

In [None]:
# Keep only the first 200 labels as a smaller working subset.
sub_imagenet_labels = imagenet_labels[:200]
print(len(sub_imagenet_labels))

200


In [4]:
# Turn on the wikipedia client's request rate limiting before the bulk article lookups below.
wikipedia.set_rate_limiting(True)

In [8]:
# Ten animal class names; big = True selects "bert-large-uncased" in BERTModel.
model = BERTModel(["dog", "cat", "horse", "spider", "butterfly", "chicken", "sheep", "cow", "squirrel", "elephant"], big = True)

Some weights of the model checkpoint at bert-large-uncased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [9]:
# Optional: extend the tag list from a file before converting.
# model.import_tag_list("en-basic")
# Retrieve the articles, compute one embedding per tag, then write them out as CSV.
model.convert()
model.export("animal10-embeddeding.csv")
print(len(model.get_class_list()))

# Optional: pairwise cosine similarities (the BERTModel method is named compute_sim).
# model.compute_sim()
# model.export_sim_matrix("sim_matrix.csv")

100%|██████████| 10/10 [00:00<00:00, 850.89it/s]
100%|██████████| 10/10 [00:05<00:00,  1.72it/s]


10


In [12]:
#@title Article to search for { run: "auto", vertical-output: true, display-mode: "both" }
# Scratch cell: compare the first search hit for a term with the page of the raw term itself.
totest = "Queen" #@param {type:"string"}

result = wikipedia.search(totest, suggestion = False)
print(result)
print(f"first result with suggestion is: {result[0]}")

# Both lookups share one try block: if the first one raises, the second one is never
# attempted and only the exception is printed.
try:
    print(wikipedia.page(result[0], auto_suggest=False, redirect=True))
    print(wikipedia.page(totest, auto_suggest=False, redirect=True))
except Exception as e:
    print(e)

['Queen', 'Queen (band)', 'Queen Elizabeth The Queen Mother', 'Queen Victoria', 'Queen (Queen album)', 'The Queen', 'Mary, Queen of Scots', 'Elizabeth II', 'Elizabeth I', 'Queen Mary']
first result with suggestion is: Queen
"Queen" may refer to: 
Queen regnant
List of queens regnant
Queen consort
Queen dowager
Queen mother
Queen (Marvel Comics)
Evil Queen
Red Queen (Through the Looking-Glass)
Queen of Hearts (Alice's Adventures in Wonderland)
Queen (chess)
Queen (playing card)
Carrom
Queen (band)
Queen (Queen album)
Queen (Kaya album)
Queen (Nicki Minaj album)
Ten Walls
Lovers Rock
G Flip
R.O.S.E.
Stoner Witch
Too Bright
Shawn Mendes
Deltarune Chapter 2 OST
Record
Q.U.E.E.N.
Queen Records
Queen (magazine)
Queen: The Story of an American Family
Alex Haley's Queen
Queen (2014 film)
Queen (2018 film)
Queen (web series)
Queen, New Mexico
Queen, Pennsylvania
May Queen
Queen of Heaven
Queen of heaven (antiquity)
eusociality
Queen ant
Queen bee
Queen (butterfly)
cat
Queen (Canadian automobile



  lis = BeautifulSoup(html).find_all('li')


In [6]:
# Four tags for the King - men + woman analogy check; big = True selects "bert-large-uncased".
test_model = BERTModel(["King", "Queen", "men", "woman"], big = True)

Some weights of the model checkpoint at bert-large-uncased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [7]:
gc.collect()
# force_retrieve = True: look the articles up again even if they were already retrieved.
test_model.convert(force_retrieve = True)

# get_embedding_of returns a (tag, vector) tuple; [1] keeps the vector.
men = test_model.get_embedding_of("men")[1]
woman = test_model.get_embedding_of("woman")[1]
king = test_model.get_embedding_of("King")[1]

# Analogy vector King - men + woman, then the four tags ranked by cosine similarity to it
# (get_nearest_embedding_of sorts by increasing similarity, so the nearest tag is last).
totest = king.sub(men).add(woman)
print(test_model.get_nearest_embedding_of(totest, nb=4))



  lis = BeautifulSoup(html).find_all('li')
100%|██████████| 4/4 [00:04<00:00,  1.17s/it]
100%|██████████| 4/4 [00:02<00:00,  1.95it/s]

[('Queen', tensor(0.5186)), ('men', tensor(0.6197)), ('woman', tensor(0.7733)), ('King', tensor(0.9269))]





In [14]:
# Load the WordSim353 dataset ('all' segment) through gluonnlp; the cells below iterate
# over it as (word1, word2, score) triples.
wordsim353 = nlp.data.WordSim353('all')

In [15]:
gc.collect()

# Vocabulary of the dataset: every distinct word, in order of first appearance.
vocab = []

for w1, w2, i in wordsim353:
    for word in (w1, w2):
        if word not in vocab:
            vocab.append(word)

print(len(vocab))
# Embed the whole vocabulary with the base model (big = False).
wordsim353_model = BERTModel(vocab, big = False)
wordsim353_model.convert()

437


Downloading:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/420M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


  lis = BeautifulSoup(html).find_all('li')
100%|██████████| 437/437 [04:38<00:00,  1.57it/s]
100%|██████████| 437/437 [01:04<00:00,  6.80it/s]


In [16]:
from scipy.stats import spearmanr

# Compare the model's cosine similarities with the WordSim353 scores.
total_comparison = 0
sim_list = []
i_list = []

for w1, w2, i in wordsim353:
    try:
        sim = wordsim353_model.sim_between(w1, w2)
    except IndexError:
        # sim_between raises IndexError when a word has no embedding (no article was
        # retrieved for it): list the pair and skip it. Any other error is a real
        # failure and is no longer hidden by a blanket `except Exception`.
        print(w1, w2)
        continue

    # store plain floats rather than 0-d tensors
    sim_list.append(float(sim))
    i_list.append(i)
    total_comparison += 1

print(spearmanr(sim_list, i_list))

exhibit memorabilia
focus life
lover quarrel
noon string
precedent collection
rock jazz
smart student
smart stupid
start match
start year
type kind
SpearmanrResult(correlation=0.24093068514956245, pvalue=6.821500240639248e-06)


### Different results for the King - men + woman equation with different context window sizes

* 300 [('men', tensor(0.6871)), ('Queen regnant', tensor(0.8214)), ('woman', tensor(0.8371)), ('King', tensor(0.9153))]

* 400 [('men', tensor(0.6763)), ('Queen regnant', tensor(0.7927)), ('woman', tensor(0.8225)), ('King', tensor(0.9382))]

* 100 [('men', tensor(0.6197)), ('woman', tensor(0.7733)), ('Queen regnant', tensor(0.7994)), ('King', tensor(0.9269))]

* 10 [('men', tensor(0.6339)), ('Queen regnant', tensor(0.7719)), ('woman', tensor(0.8425)), ('King', tensor(0.8910))]

* 50 [('men', tensor(0.5651)), ('woman', tensor(0.7553)), ('Queen regnant', tensor(0.7684)), ('King', tensor(0.9356))]

* 200 [('men', tensor(0.4828)), ('Queen regnant', tensor(0.6896)), ('woman', tensor(0.7900)), ('King', tensor(0.8819))]

### Spearman rank correlation with different context window sizes
 
| model | corpus | test set | window size | Spearman rank correlation (p-value) |
|-------|--------|----------|-------------|-------------------------------------|
| bert large | articles from wordsim vocab | wordsim353 | 100 | 0.2158 (5.8353, exponent not recorded) |
| bert base  | articles from wordsim vocab | wordsim353 | 100 | 0.2409 (6.8215e-06) |


In [7]:
!rm -f /content/temp/*