<a href="https://colab.research.google.com/github/Taedriel/ZSL-v2/blob/wordEmbedding/WordsEmbeddings.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# from google.colab import drive
# drive.mount('/content/drive')

# %cd /content/drive/MyDrive/Kingston/ZSL-v2/

In [None]:
!pip install transformers wget tensorflow_datasets wikipedia

In [3]:
import tensorflow_datasets as tfds
import tensorflow as tf

import wikipedia
import torch
import numpy as np
from transformers import BertTokenizer, BertModel

import logging
logging.basicConfig(level = logging.INFO)

from typing import List, Tuple
from os.path import exists


In [26]:
class WordToVecteur:

    def __init__(self, list_tags : List[str] = []):
        self.list_tags = list_tags

    def export(self, filename):
        return NotImplementedError

    def importTagList(self, filename):
        return NotImplementedError

    def convert(self):
        return NotImplementedError


class BERTModel(WordToVecteur):

    temp_dir = "./temp"

    def __init__(self, list_tag : List[str] = [], big: bool = False):
        super(BERTModel, self).__init__(list_tag)
        self.list_articles = {}

        self.model_size = "bert-large-uncased" if big else "bert-base-uncased"
        self.embeddings = []
        self.cosine_sim_matrix = None

        self.tokenizer = BertTokenizer.from_pretrained(self.model_size, padding=True, truncation=True,)
        self.model = BertModel.from_pretrained(self.model_size, output_hidden_states = True)

        self.merging_strategy = Sum4LastLayers()

        self.model.eval()


    def export(self, filename):
        if len(self.embeddings) == 0:
            raise Exception("Tags not converted yet !")
        
        try:
            f = open(filename, "w")
        except OSError:
            raise OSError("Could not open file")

        with f:
            for embedding in self.embeddings:
                line = ",".join(list(map(lambda x: str(float(x)), embedding[1])))
                print(embedding[0], ",", line, sep="", file=f)

    def import_tag_list(self, filename):
        self.embeddings.clear()
        self.cosine_sim_matrix = None

        try:
            f = open(filename, "r")
        except OSError:
            return OSError("Could not open file")

        with f:
            data = f.read().split("\n")
            for item in data:
                if item not in self.list_tags and str.strip(item) != "":
                    self.list_tags.append(item)
            print(self.list_tags)
            logging.info(f"Import finished : {len(self.list_tags)} elements imported.")

    def _retrieve_summary(self, title, depth : int, include_ambiguous : bool = False) -> Tuple[str, str]:
            if depth == 0:
                logging.warning(f"Maximum recursion depth for {title}")
                return None
            
            filename = BERTModel.temp_dir + "/" + title.replace(" ", "_")
            page = None

            if exists(filename):
                with open(filename, "r") as f:
                    return (title, f.read())

            search_result, suggest = wikipedia.search(title, suggestion = True)
            try:
                summary = wikipedia.page(search_result[0]).summary
                with open(filename, "w") as f:
                    logging.info(f"Saving {title}")
                    print(summary, file=f)
                return (search_result[0], summary)
            except wikipedia.PageError as e:
                return self._retrieve_summary(suggest, depth - 1, include_ambiguous) if suggest is not None else None
            except  wikipedia.DisambiguationError as e:
                return self._retrieve_summary(search_result[1], depth - 1, include_ambiguous) if include_ambiguous else None
            except IndexError as e:
                logging.warning(f"No result for {title}, skipping")
                return None
    

    def retrieve_all_articles(self, include_ambiguous : bool = False):
        logging.info("Starting retrieveing articles...")
        current_percent = 0.00

        nb_token = len(self.list_tags)
        for i, tag in enumerate(self.list_tags):
            
            percent_completion = round((i / nb_token) * 100, 2)
            if current_percent != percent_completion:
                logging.info(f"{percent_completion}% completed")
                current_percent = percent_completion

            self.list_articles[tag] = self._retrieve_summary(tag, 5, include_ambiguous)

        logging.info("Finished retrieving article !")


    def convert(self):
        """ convert all word in their embeddings"""
        if len(self.list_tags) == 0:
            raise Exception("not tags yet !")

        if len(self.list_tags) != len(self.list_articles):
            self.retrieve_all_articles()

        logging.info("Starting converting tokens...")
        nb_token = len(self.list_tags)
        current_percent = 0.00

        for i, tag in enumerate(self.list_tags):
            
            percent_completion = round((i / nb_token) * 100, 2)
            if current_percent != percent_completion:
                logging.info(f"{percent_completion}% completed")
                current_percent = percent_completion
            
            if self.list_articles[tag] is None: continue
            tag_plus_context = tag + ". " + self.list_articles[tag][1]
            inputs = self.tokenizer(tag_plus_context, return_tensors = "pt")

            with torch.no_grad():
                outputs = self.model(**inputs)

            hidden_states = outputs[2]

            # [# layers, # batches, # tokens, # features] ==> [# tokens, # layers, # features]
            token_embeddings = torch.stack(hidden_states, dim=0)
            token_embeddings = torch.squeeze(token_embeddings, dim=1)
            token_embeddings = token_embeddings.permute(1,0,2)

            # apply different strategy to summarize word embeddings
            # tokenized_text = self.tokenizer.tokenize(tag)
            # acc = []
            # for i, token in reversed(list(enumerate(tokenized_text))):

            #     embed = self.merging_strategy.merge(token_embeddings[i+1])
            #     if i == 0:
            #         if len(acc) != 0:
            #             embed = torch.mean(torch.stack([x[1] for x in acc]), dim=0)
            #             token = tag
            #             acc = []

            #         self.embeddings.append((token, embed))
            #     else:
            #         acc.append((token, embed))
            
            # here we are just taking the [CLS] (for classification) as an embedding for the tag
            self.embeddings.append((tag, self.merging_strategy.merge(token_embeddings[0])))


    def compute_sim(self):
        """ compute cosine similarity between all vectors """
        if len(self.embeddings) == 0:
            raise Exception("Tags not converted yet !")

        logging.info("Computing cosine similarity, this could take some time...")

        n_tokens = len(self.embeddings)
        self.cosine_sim_matrix = [[1 for j in range(n_tokens)] for i in range(n_tokens)]

        for j, vector in enumerate(self.embeddings):

            for i, other_vector in enumerate(self.embeddings):

                if i == j:
                    continue

                
                cos = torch.nn.CosineSimilarity(dim=0)
                similarity = cos(vector[1], other_vector[1])

                self.cosine_sim_matrix[i][j] = similarity
                self.cosine_sim_matrix[j][i] = similarity

    def export_sim_matrix(self, filename):
        if self.cosine_sim_matrix == None:
            self.compute_sim()
        
        try:
            f = open(filename, "w")
        except OSError:
            raise OSError("Could not open file")

        with f:
            print("/", ",".join([tag[0] for tag in self.embeddings]), sep = ",", file = f)

            for j, tag_y in enumerate(self.embeddings):
                print(tag_y[0], ",".join( [str(round(float(self.cosine_sim_matrix[j][i]), 3)) for i in range(len(self.embeddings))]), sep = ",", file = f)

    def sim_between(self, token1, token2):
        if self.cosine_sim_matrix is None:
            self.compute_co_sim()

        index1 = [i for i, v in enumerate(self.embeddings) if v[0] == token1][0]
        index2 = [i for i, v in enumerate(self.embeddings) if v[0] == token2][0]

        return self.cosine_sim_matrix[index1][index2]

    def get_embedding_of(self, token):
        res = [v for v in self.embeddings if v[0] == token]
        if len(res) == 0:
            raise Exception("no such token")
        
        return res[0]

    def get_nearest_embedding_of(self, embedding, nb = 10):

        if nb > len(self.embeddings):
            raise Exception("nb too high, not enough token")

        nearest = []
        for e in self.embeddings:

            cos = torch.nn.CosineSimilarity(dim=0)
            similarity = cos(embedding, e[1])

            nearest.append((e[0], similarity))
        
        nearest.sort(key = lambda tup : tup[1])
        return nearest[-nb:]

    def get_class_list(self):
        return [x[0] for x in self.embeddings]

class Sum4LastLayers:

    def merge(self, vector):
        return torch.sum(vector[-4:], dim = 0)

In [5]:
labels_path = tf.keras.utils.get_file('ImageNetLabels.txt','https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt')
imagenet_labels = np.array(open(labels_path).read().splitlines())
imagenet_labels = list(imagenet_labels)
len(imagenet_labels)

1001

In [22]:
wikipedia.set_rate_limiting(True)

In [27]:
model = BERTModel(imagenet_labels, big = True)

# model.import_tag_list("en-basic")
model.convert()
model.export("imageNet-embeddeding.csv")
print(model.get_class_list())

# model.computeCoSim()
# model.simBetween("cat", "dog")

# man = model.get_embedding_of("man")[1]
# woman = model.get_embedding_of("woman")[1]

# king = model.get_embedding_of("king")[1]

# totest = king.sub(man).add(woman)
# print(model.get_nearest_embedding_of(totest, 3))

# model.export_sim_matrix("sim_matrix.csv")

Some weights of the model checkpoint at bert-large-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
INFO:root:Starting retrieveing articles...


  lis = BeautifulSoup(html).find_all('li')
INFO:root:0.1% completed
INFO:root:0.2% completed
INFO:root:0.3% c

RuntimeError: ignored

In [10]:
print(wikipedia.search("dotswitcher", suggestion = True))

print(wikipedia.search("dotswitcher", suggestion = False))

([], None)
[]


In [21]:
!ls "/content/temp"

In [20]:
!rm -f /content/temp/*