<a href="https://colab.research.google.com/github/Taedriel/ZSL-v2/blob/wordEmbedding/WordsEmbeddings.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [40]:
!yes | pip install transformers wget tensorflow_datasets wikipedia deprecated unzip gluonnlp mxnet --quiet
!mkdir -p temp article

In [41]:
import tensorflow_datasets as tfds
import tensorflow as tf

import gc
import torch
import wikipedia
import numpy as np
import pickle

import mxnet as mx
import gluonnlp as nlp
import traceback

from transformers import BertTokenizer, BertModel, RobertaModel, RobertaTokenizer

import logging
logging.basicConfig(level = logging.INFO, filename = "BERT.log" )
# logging.basicConfig(level = logging.INFO)

from tqdm import tqdm
from deprecated import deprecated
from typing import List, Tuple
from os.path import exists, join
from enum import Enum

wikipedia.set_rate_limiting(True)

### Utils classes

In [66]:
class customWikiArticle:
    """Record describing one Wikipedia lookup made by ArticleRetriever.

    Attributes:
        index: position of the record in the retriever's map when it was added.
        title: title that was requested.
        realtitle: title of the page actually retrieved (can differ from
            ``title`` after a redirect, search fallback or disambiguation);
            may be None when nothing could be resolved.
        summary: summary of the retrieved page, or None when none was found.
        ambiguous: True when the request went through a disambiguation page.
    """

    def __init__(self, index : int, title : str, realtitle : str, summary : str, ambiguous : bool):
        self.index, self.title, self.realtitle = index, title, realtitle
        self.summary, self.ambiguous = summary, ambiguous

class ArticleRetriever:
    """Fetch Wikipedia summaries for a list of titles and cache them on disk.

    The cache is a pickled ``dict`` mapping each requested title to a
    ``customWikiArticle``; it lives at ``<article_dir>/<name>``.
    """

    # directory holding the pickled caches (created by ``save`` if missing)
    article_dir = "./article"

    def __init__(self, name : str, list_title : List[str]):
        """
        :param name: file name of the cache inside ``article_dir``.
        :param list_title: titles processed by ``load_all_articles``.
        """
        self.name = name
        self.list_title = list_title
        # True when ``articles_map`` differs from what is stored on disk
        self.modified = False

        if not exists(self.get_filename()):
            self.articles_map = {}
        else:
            # NOTE: pickle can execute arbitrary code -- only load caches written by this notebook.
            with open(self.get_filename(), "rb") as mapfile:
                self.articles_map = pickle.load(mapfile)

    def get_filename(self) -> str:
        """Path of the pickled cache for this retriever."""
        return join(ArticleRetriever.article_dir, self.name)

    def load_article(self, title, force_reload : bool = False) -> "customWikiArticle":
        """Return the cached record for ``title``, fetching it from Wikipedia if needed.

        :param title: requested article title.
        :param force_reload: retry the lookup when the cached summary is None.
        """
        if title not in self.articles_map:
            self.modified = True
            realtitle, summary, ambiguous = self._retrieve_article(title)
            self.articles_map[title] = customWikiArticle(len(self.articles_map), title, realtitle, summary, ambiguous)

        # ``elif``: a title fetched just above must not be fetched a second time
        elif force_reload and self.articles_map[title].summary is None:
            self.modified = True
            realtitle, summary, ambiguous = self._retrieve_article(title)
            article = self.articles_map[title]
            # refresh every retrieved field so they stay consistent with the new summary
            article.realtitle = realtitle
            article.summary = summary
            article.ambiguous = ambiguous

        return self.articles_map[title]

    def _retrieve_article(self, title, closed_list : List = None) -> Tuple[str, str, bool]:
        """Look ``title`` up on Wikipedia, following search results and disambiguations.

        :param title: title to look up.
        :param closed_list: titles already tried during this lookup, used to stop
            cycles; a fresh list is created for every top-level call.
        :return: ``(real title, summary, ambiguous)``. Summary is None when nothing
            was found; all three are None when a disambiguation could not be resolved.
        """
        # a shared ``[]`` default would remember titles across unrelated lookups
        if closed_list is None:
            closed_list = []
        closed_list.append(title)
        try:
            article = wikipedia.page(title, auto_suggest=False, redirect=True)
            return (article.title, article.summary, False)

        except wikipedia.PageError:
            search_result = wikipedia.search(title, suggestion = False)
            best_find = search_result[0] if search_result else None

            logging.warning(f"{title} misspelled or article missing. Best find is {best_find}")
            if best_find is not None and best_find not in closed_list:
                return self._retrieve_article(best_find, closed_list)
            return (title, None, False)

        except wikipedia.DisambiguationError as e:
            first_option = e.options[0] if e.options else None
            logging.warning(f"{title} is ambiguous, trying first {first_option}")
            if first_option is not None and first_option not in closed_list:
                res = self._retrieve_article(first_option, closed_list)
                return (res[0], res[1], True)

            return (None, None, None)

    def load_all_articles(self, force_reload : bool = False) -> bool:
        """Make sure every title of ``list_title`` has a record.

        :param force_reload: retry titles whose cached summary is None.
        :return: True if the map changed since it was loaded or last saved.
        """
        logging.info("Starting loading articles...")
        nb_success = 0

        for title in tqdm(self.list_title, total=len(self.list_title)):
            if self.load_article(title, force_reload).summary is not None:
                nb_success += 1
        logging.info(f"Finished loading {nb_success} article(s) !")
        return self.modified

    def __call__(self, force_reload : bool = True) -> bool:
        """Shortcut for ``load_all_articles`` (note: ``force_reload`` defaults to True here)."""
        return self.load_all_articles(force_reload)

    def get_article(self, title):
        """Return the record for ``title``, fetching it (without force reload) if missing."""
        if title not in self.articles_map:
            self.load_article(title)

        return self.articles_map[title]

    def save(self):
        """Pickle ``articles_map`` to ``get_filename()``, creating ``article_dir`` if needed."""
        import os

        os.makedirs(ArticleRetriever.article_dir, exist_ok=True)
        with open(self.get_filename(), "wb") as mapfile:
            pickle.dump(self.articles_map, mapfile)
        # the file on disk now matches memory
        self.modified = False

class WordToVector:
    """Abstract base for models turning a list of tags (words) into embedding vectors.

    Subclasses fill ``self.embeddings`` (tag -> vector) in ``convert`` and
    write them out in ``export``.
    """

    def __init__(self, list_tags : List[str] = None):
        """:param list_tags: tags to embed; defaults to a new empty list."""
        # None instead of ``[]``: a list default would be shared by every instance
        self.list_tags = [] if list_tags is None else list_tags
        self.embeddings = {}

    def export(self, filename):
        """Write the embeddings to ``filename``; must be implemented by subclasses."""
        raise NotImplementedError

    def convert(self, *args, **kwargs):
        """Compute ``self.embeddings``; must be implemented by subclasses."""
        raise NotImplementedError

class Sum4LastLayers:
    """Layer-merging strategy: add up the last four slices of a layer stack."""

    def merge(self, vector):
        """Return ``vector[-4:]`` summed over its first dimension.

        In ``BERTModel.convert`` the argument is one token's hidden states
        laid out as [layers, features], so the first dimension is the layer.
        """
        last_four = vector[-4:]
        return last_four.sum(dim=0)


In [54]:
class EmbeddingsLoader:
    """Load ``tag -> embedding`` pairs from a CSV file.

    Expected layout (as written by ``BERTModel.export``): one header row, which
    is skipped, then one ``tag,v0,v1,...`` row per tag.
    """

    def __init__(self, filename : str):
        """:param filename: path of the CSV file; it is read immediately."""
        self.file = filename
        # tag -> 1-D FloatTensor, in file order
        self.embeddings = {}

        self._load_file()

    def _load_file(self):
        """Parse ``self.file`` into ``self.embeddings``.

        :raises IOError: if the file cannot be read (original error chained as cause).
        :raises ValueError: if a value column is not a number.
        """
        try:
            with open(self.file, "r") as f:
                lines = f.readlines()
        except IOError as e:
            raise IOError(f"No file {self.file}") from e

        for line in lines[1:]:
            # a trailing blank line would otherwise become an empty "" tag
            if not line.strip():
                continue
            tag, *values = line.rstrip("\r\n").split(",")
            self.embeddings[tag] = torch.FloatTensor([float(v) for v in values])

class SimilarityCompute(EmbeddingsLoader):
    """Pairwise cosine similarity between the embeddings of a CSV file.

    ``cosine_sim_matrix[i][j]`` holds the similarity between the i-th and j-th
    tags of ``self.embeddings`` (file order). Entries are filled lazily; ``None``
    marks a pair that has not been computed yet.
    """

    def __init__(self, embeddings):
        """:param embeddings: path of the CSV file to load (see EmbeddingsLoader)."""
        super(SimilarityCompute, self).__init__(embeddings)
        # allocated on first use; the methods below read it, so it must exist
        self.cosine_sim_matrix = None

    def _ensure_matrix(self):
        """Allocate the n x n matrix (diagonal 1, other entries unknown) if needed."""
        if self.cosine_sim_matrix is None:
            n_tokens = len(self.embeddings)
            self.cosine_sim_matrix = [[1.0 if i == j else None for j in range(n_tokens)] for i in range(n_tokens)]

    def compute_sim(self):
        """Compute the cosine similarity between all pairs of vectors.

        :raises Exception: if no embedding is loaded.
        """
        if len(self.embeddings) == 0:
            raise Exception("Tags not converted yet !")

        logging.info("Computing cosine similarity, this could take some time...")
        self._ensure_matrix()

        vectors = list(self.embeddings.values())
        cos = torch.nn.CosineSimilarity(dim=0)
        for j in tqdm(range(len(vectors)), total = len(vectors)):
            # symmetric matrix: compute the upper triangle and mirror it
            for i in range(j + 1, len(vectors)):
                similarity = cos(vectors[j], vectors[i])
                self.cosine_sim_matrix[i][j] = similarity
                self.cosine_sim_matrix[j][i] = similarity

    def export_sim_matrix(self, filename):
        """Write the full similarity matrix to ``filename`` as CSV (values rounded to 3 decimals).

        :raises OSError: if ``filename`` cannot be opened for writing.
        """
        # sim_between may have filled only part of the matrix
        if self.cosine_sim_matrix is None or any(v is None for row in self.cosine_sim_matrix for v in row):
            self.compute_sim()

        try:
            f = open(filename, "w")
        except OSError as err:
            raise OSError("Could not open file") from err

        tags = list(self.embeddings)
        with f:
            print("/", ",".join(tags), sep = ",", file = f)

            for j, tag_y in enumerate(tags):
                row = ",".join(str(round(float(self.cosine_sim_matrix[j][i]), 3)) for i in range(len(tags)))
                print(tag_y, row, sep = ",", file = f)

    def sim_between(self, token1, token2):
        """Return (and cache) the cosine similarity between two loaded tags.

        :raises KeyError: if either token has no embedding.
        """
        positions = {tag: i for i, tag in enumerate(self.embeddings)}
        for token in (token1, token2):
            if token not in positions:
                raise KeyError(f"no such token {token}")
        index1, index2 = positions[token1], positions[token2]

        self._ensure_matrix()

        if self.cosine_sim_matrix[index1][index2] is None:
            cos = torch.nn.CosineSimilarity(dim=0)
            similarity = cos(self.embeddings[token1], self.embeddings[token2])

            self.cosine_sim_matrix[index1][index2] = similarity
            self.cosine_sim_matrix[index2][index1] = similarity

        return self.cosine_sim_matrix[index1][index2]

class Solver(EmbeddingsLoader):
    """Nearest-neighbour lookup by cosine similarity over embeddings loaded from a CSV file."""

    # upper bound on the number of results returned by ``__call__``
    DEFAULT_MIN_LIST_RESULT = 10

    def __init__(self, embeddings):
        """:param embeddings: path of the CSV file to load (see EmbeddingsLoader)."""
        super(Solver, self).__init__(embeddings)

    def get_nearest_embedding_of(self, embedding, nb = 10):
        """Return the ``nb`` tags most similar to ``embedding``.

        :return: list of ``(tag, similarity)`` pairs, most similar first.
        :raises Exception: if ``nb`` exceeds the number of loaded tags.
        :raises ValueError: if ``nb`` is negative.
        """
        if nb > len(self.embeddings):
            raise Exception("nb too high, not enough token")
        if nb < 0:
            # a negative nb used to produce an arbitrary slice of the ranking
            raise ValueError("nb must be non-negative")

        cos = torch.nn.CosineSimilarity(dim=0)
        nearest = [(tag, cos(embedding, e)) for tag, e in self.embeddings.items()]

        nearest.sort(key = lambda tup : tup[1])
        # read the ascending ranking backwards: highest similarity first
        return nearest[-1:-nb-1:-1]

    def __call__(self, embeddeding):
        """Nearest tags of ``embeddeding``, at most ``DEFAULT_MIN_LIST_RESULT`` of them."""
        return self.get_nearest_embedding_of(embeddeding, min(Solver.DEFAULT_MIN_LIST_RESULT, len(self.embeddings)))

### BERT model

In [44]:
class BERTModel(WordToVector):
    """Turn tags into embeddings with a pretrained BERT model.

    For every tag the model reads ``"<tag>. <context>"``, where ``<context>`` is
    the first ``window`` characters of the tag's Wikipedia summary (obtained from
    an ``ArticleRetriever``). The hidden states of the first token are merged
    across layers by ``self.merging_strategy`` to give the tag's embedding.
    """

    temp_dir = "./temp"

    def __init__(self, list_tag : List[str], big: bool = False, window : int = 100):
        """
        :param list_tag: tags (words) to embed.
        :param big: use ``bert-large-uncased`` instead of ``bert-base-uncased``.
        :param window: number of summary characters appended as context.
        """
        super(BERTModel, self).__init__(list_tag)
        self.list_articles = {}
        self.window_size = window

        self.model_size = "bert-large-uncased" if big else "bert-base-uncased"
        self.cosine_sim_matrix = None

        # NOTE(review): padding/truncation passed to from_pretrained do not appear to apply when
        # the tokenizer is called; truncation is requested explicitly in _embed_text.
        self.tokenizer = BertTokenizer.from_pretrained(self.model_size, padding=True, truncation=True,)
        self.model = BertModel.from_pretrained(self.model_size, output_hidden_states = True)

        self.merging_strategy = Sum4LastLayers()

        # inference only
        self.model.eval()

    def export(self, filename):
        """Write all embeddings to ``filename`` as CSV.

        The header row is ``embeddings,0,1,...,d-1``; each following row is
        ``tag,v0,...,v{d-1}``.

        :raises Exception: if the embeddings have not been computed yet.
        :raises OSError: if ``filename`` cannot be opened for writing.
        """
        if len(self.embeddings) == 0:
            raise Exception("Tags not converted yet !")

        dimension_number = len(next(iter(self.embeddings.values())))

        try:
            f = open(filename, "w")
        except OSError as err:
            raise OSError("Could not open file") from err

        with f:
            print("embeddings", *[str(i) for i in range(dimension_number)], sep=",", file=f)
            for tag, embedding in self.embeddings.items():
                print(tag, *[str(float(x)) for x in embedding], sep=",", file=f)

    def reset_embeddings(self):
        """Forget all computed embeddings."""
        self.embeddings.clear()

    def convert(self, article_ret : ArticleRetriever):
        """Compute the embedding of every tag in ``self.list_tags`` into ``self.embeddings``.

        :param article_ret: source of the Wikipedia summaries used as context.
        :raises Exception: if there is no tag to convert.
        """
        if len(self.list_tags) == 0:
            raise Exception("not tags yet !")

        logging.info("Starting converting tokens...")
        for tag in tqdm(self.list_tags, total = len(self.list_tags)):
            text = self._text_with_context(tag, article_ret.get_article(tag))
            self.embeddings[tag] = self._embed_text(text)

    def _text_with_context(self, tag, article):
        """Build the model input: the tag, then up to ``window_size`` summary characters."""
        if article.summary is None:
            return tag
        # slicing already clamps to the summary length
        return tag + ". " + article.summary[:self.window_size]

    def _embed_text(self, text):
        """Return the layer-merged hidden state of the first token of ``text``."""
        # truncate to the model's maximum input length so a long context cannot break the forward pass
        inputs = self.tokenizer(text, return_tensors = "pt", truncation = True)

        with torch.no_grad():
            outputs = self.model(**inputs)

        hidden_states = outputs[2]

        # [# layers, # batches, # tokens, # features] ==> [# tokens, # layers, # features]
        token_embeddings = torch.stack(hidden_states, dim=0)
        token_embeddings = torch.squeeze(token_embeddings, dim=1)
        token_embeddings = token_embeddings.permute(1, 0, 2)

        # position 0 is the leading special token ([CLS] for BERT), used as the tag embedding
        return self.merging_strategy.merge(token_embeddings[0])

    def get_embedding_of(self, token):
        """Return the embedding of ``token``.

        :raises Exception: if ``token`` has not been converted.
        """
        if token not in self.embeddings:
            raise Exception(f"no such token {token}")

        return self.embeddings[token]

    def get_class_list(self):
        """Tags that currently have an embedding."""
        return self.embeddings.keys()

### RoBERTa model

In [45]:
class ROBERTAModel(BERTModel):
    """RoBERTa variant of ``BERTModel``: same conversion/export, different backbone."""

    def __init__(self, list_tag : List[str], big: bool = False, window : int = 100):
        """
        :param list_tag: tags (words) to embed.
        :param big: use ``roberta-large`` instead of ``roberta-base``.
        :param window: number of summary characters appended as context.
        """
        # bypass BERTModel.__init__ so BERT's tokenizer and weights are not loaded
        WordToVector.__init__(self, list_tag)
        # same attribute BERTModel.__init__ sets, kept for parity between the two models
        self.list_articles = {}
        self.window_size = window

        self.model_size = "roberta-large" if big else "roberta-base"
        self.cosine_sim_matrix = None

        self.tokenizer = RobertaTokenizer.from_pretrained(self.model_size, padding=True, truncation=True,)
        self.model = RobertaModel.from_pretrained(self.model_size, output_hidden_states = True)

        self.merging_strategy = Sum4LastLayers()

        # inference only
        self.model.eval()

## Embeddings to nearest words

In [None]:
# Sanity check of Solver: nearest neighbours of "cat" among the animal10 embeddings.
# The query vector is "cat" itself, so it is expected first with similarity 1.
solver = Solver("animal10-embeddeding.csv")

totest = solver.embeddings["cat"]

print(solver.get_nearest_embedding_of(totest))

[('cat', tensor(1.)), ('horse', tensor(0.9067)), ('squirrel', tensor(0.8875)), ('chicken', tensor(0.8678)), ('butterfly', tensor(0.8576)), ('cow', tensor(0.8493)), ('spider', tensor(0.8481)), ('dog', tensor(0.8410)), ('elephant', tensor(0.7702)), ('sheep', tensor(0.7604))]


In [None]:
# labels_path = tf.keras.utils.get_file('ImageNetLabels.txt','https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt')
# imagenet_labels = np.array(open(labels_path).read().splitlines())
# imagenet_labels = list(imagenet_labels)
# sub_imagenet_labels = imagenet_labels[:200]

## Words to embeddings

In [71]:
# Class names of the image datasets whose labels are embedded.
animal10 = [
    "dog", "cat", "horse", "spider", "butterfly",
    "chicken", "sheep", "cow", "squirrel", "elephant",
]
cifar10 = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]
# CIFAR-100 class names, alphabetical, ten per row
cifar100 = [
    "apple", "aquarium_fish", "baby", "bear", "beaver", "bed", "bee", "beetle", "bicycle", "bottle",
    "bowl", "boy", "bridge", "bus", "butterfly", "camel", "can", "castle", "caterpillar", "cattle",
    "chair", "chimpanzee", "clock", "cloud", "cockroach", "computer_keyboard", "couch", "crab", "crocodile", "cup",
    "dinosaur", "dolphin", "elephant", "flatfish", "forest", "fox", "girl", "hamster", "house", "kangaroo",
    "lamp", "lawn_mower", "leopard", "lion", "lizard", "lobster", "man", "maple_tree", "motorcycle", "mountain",
    "mouse", "mushroom", "oak_tree", "orange", "orchid", "otter", "palm_tree", "pear", "pickup_truck", "pine_tree",
    "plain", "plate", "poppy", "porcupine", "possum", "rabbit", "raccoon", "ray", "road", "rocket",
    "rose", "sea", "seal", "shark", "shrew", "skunk", "skyscraper", "snail", "snake", "spider",
    "squirrel", "streetcar", "sunflower", "sweet_pepper", "table", "tank", "telephone", "television", "tiger", "tractor",
    "train", "trout", "tulip", "turtle", "wardrobe", "whale", "willow_tree", "wolf", "woman", "worm",
]

In [72]:
# Dataset whose class names are embedded; also the prefix of the cache/output file names.
save_name = "cifar100"
vocab = cifar100

# Alternative backbone: model = BERTModel(vocab, big = True, window = 100)
model = ROBERTAModel(vocab, big = True, window = 100)

# Wikipedia summaries are cached in ./article/<save_name>.art
articlesRetriever = ArticleRetriever(save_name + ".art", vocab)

Some weights of the model checkpoint at roberta-large were not used when initializing RobertaModel: ['lm_head.bias', 'lm_head.dense.bias', 'lm_head.decoder.weight', 'lm_head.layer_norm.bias', 'lm_head.layer_norm.weight', 'lm_head.dense.weight']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [74]:
# Fetch missing summaries (calling the retriever uses force_reload=True, so empty
# summaries are retried) and rewrite the cache only if something changed.
if articlesRetriever():
    articlesRetriever.save()

# Recompute every embedding and write them to <save_name>-<model>-<window>.csv
model.reset_embeddings()
model.convert(articlesRetriever)
model.export(f"{save_name}-{model.model_size}-{model.window_size}.csv")
print("\n", len(model.get_class_list()))

100%|██████████| 100/100 [00:00<00:00, 97338.22it/s]
100%|██████████| 100/100 [00:52<00:00,  1.90it/s]



 100


## Wikipedia debug

In [76]:
#@title Article to search for { run: "auto", vertical-output: true, display-mode: "both" }
totest = "king" #@param {type:"string"}

# Compare what a direct page lookup and the first search hit resolve to.
result = wikipedia.search(totest, suggestion = False)
print(result)
if result:
    print(f"first result is: {result[0]}")
else:
    print("no search result")

try:
    print(wikipedia.page(totest, auto_suggest=False, redirect=True))
    if result:
        print(wikipedia.page(result[0], auto_suggest=False, redirect=True))
except wikipedia.DisambiguationError as e:
    # only disambiguation errors carry ``options``
    print(f"best option envisaged: {e.options[0] if e.options else None}")
    print(e)
except wikipedia.PageError as e:
    print(e)

['King', 'King (disambiguation)', 'Martin Luther King Jr.', 'Stephen King', 'King & King', 'George VI', 'The Lion King', 'Burger King', 'King King', 'List of King of the Hill episodes']
first result is: King
<WikipediaPage 'King'>
<WikipediaPage 'King'>


## Test part

In [65]:
gc.collect()

# Analogy check: where does King - men + woman land among the four words?
# NOTE(review): the textbook analogy uses "man"; "men" is kept to match the recorded results.
save_name = "king-test"
vocab = ["King", "Queen", "men", "woman"]

king_test_model = ROBERTAModel(vocab, big = False, window = 10)

king_test_articlesRetriever = ArticleRetriever(save_name + ".art", vocab)
if king_test_articlesRetriever():
    king_test_articlesRetriever.save()


# Use this cell's retriever: the cifar100 ``articlesRetriever`` would instead fetch
# these titles and add them to its own map.
king_test_model.convert(king_test_articlesRetriever)
king_test_model.export(save_name + ".csv")
king_test_solver = Solver(save_name + ".csv")


men = king_test_model.get_embedding_of("men")
woman = king_test_model.get_embedding_of("woman")
king = king_test_model.get_embedding_of("King")

totest = king.sub(men).add(woman)
print("\n", king_test_solver(totest))

Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.bias', 'lm_head.dense.bias', 'lm_head.decoder.weight', 'lm_head.layer_norm.bias', 'lm_head.layer_norm.weight', 'lm_head.dense.weight']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
100%|██████████| 4/4 [00:00<00:00, 23730.15it/s]
100%|██████████| 4/4 [00:00<00:00,  8.22it/s]


 [('woman', tensor(0.9997)), ('King', tensor(0.9996)), ('Queen', tensor(0.9996)), ('men', tensor(0.9986))]





In [None]:
# WordSim353 benchmark from gluonnlp; iterated below as (word1, word2, score) triples,
# the score presumably being the human similarity rating -- TODO confirm.
wordsim353 = nlp.data.WordSim353('all')

### Results of the analogy King − men + woman for different context window sizes

| model | window size (characters) | rank of Queen | distance from first |
|-------|-------------|---------------|-----------|
| bert-large | 0  | 2 |  .0693 |
| bert-large | 10 | 3 |  .1191 | 
| bert-large | 50 | 2 |  .1672 |
| bert-large | 100 | 2 | .1275 |
| bert-large | 150 | 3 | .1134 |
| bert-large | 200 | 3 | .1923 |
| bert-large | 300 | 3 | .0939 |
| bert-large | 400 | 3 | .1455 |
| roberta-large | 0 | 2 | .0001 |
| roberta-large | 10 | 3 | .0011 |
| roberta-large | 50 | 2 | .0029 |
| roberta-large | 100 | 4 | .0061 |
| roberta-large | 150 | 3 | .0023 |
| roberta-large | 200 | 4 | .0055 |
| roberta-large | 300 | 4 | .0127 |
| roberta-large | 400 | 4 | .0045 |








In [None]:
gc.collect()

# Unique benchmark words in first-seen order (w1 before w2); dict.fromkeys keeps
# insertion order and avoids the quadratic ``not in list`` test.
vocab = list(dict.fromkeys(word for w1, w2, _score in wordsim353 for word in (w1, w2)))

print(len(vocab))
wordsim353_model = BERTModel(vocab, big = False, window = 0)
articlesRetriever = ArticleRetriever("wordsim353", vocab)

try:
    articlesRetriever.load_all_articles(force_reload = True)
    articlesRetriever.save()
except Exception:
    # best effort: convert() below still works and fetches anything missing
    traceback.print_exc()

wordsim353_model.convert(articlesRetriever)

In [None]:
from scipy.stats import spearmanr

# Correlate model cosine similarity with the benchmark score of every word pair.
total_comparison = 0
sim_list = []
i_list = []

# BERTModel has no sim_between, so the similarity is computed here from its embeddings
cos = torch.nn.CosineSimilarity(dim=0)
for w1, w2, i in wordsim353:
    try:
        e1 = wordsim353_model.get_embedding_of(w1)
        e2 = wordsim353_model.get_embedding_of(w2)
    except Exception:
        # get_embedding_of raises a plain Exception for words that were not converted
        print(w1, w2)
        continue

    # plain floats for scipy
    sim_list.append(float(cos(e1, e2)))
    i_list.append(i)
    total_comparison += 1

print(spearmanr(sim_list, i_list))

### Spearman rank correlation on WordSim353 for different context window sizes
 
| model | corpus | test set | window size | Spearman correlation (p-value) |
|-------|--------|----------|-------------|--------------------------------|
| bert-large | articles from WordSim353 vocab | wordsim353 | 100 | 0.2158 (5.8353e-05) |
| bert-base  | articles from WordSim353 vocab | wordsim353 | 100 | 0.2284 (6.8215e-05) |
| bert-base  | articles from WordSim353 vocab | wordsim353 | 300 | 0.1326 (1.27e-02) |
| bert-large | articles from WordSim353 vocab | wordsim353 | 300 | 0.2117 (6.2153e-05) |
| bert-large | articles from WordSim353 vocab | wordsim353 | 0 | 0.2638 (5.1232e-07) |
| bert-base  | articles from WordSim353 vocab | wordsim353 | 0 | 0.3721 (5.2306e-13) |


