<a href="https://colab.research.google.com/github/Taedriel/ZSL-v2/blob/wordEmbedding/WordsEmbeddings.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/drive')

%cd /content/drive/MyDrive/Kingston/ZSL-v2/

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/MyDrive/Kingston/ZSL-v2


In [2]:
!pip install transformers wget tensorflow_datasets



In [4]:
import tensorflow_datasets as tfds
import tensorflow as tf

import torch
from transformers import BertTokenizer, BertModel

import logging
logging.basicConfig(level = logging.INFO)

from typing import List

In [3]:
class WordToVecteur:

    def __init__(self, list_tags : List[str] = []):
        self.list_tags = list_tags

    def export(self, filename):
        return NotImplementedError

    def importTagList(self, filename):
        return NotImplementedError

    def convert(self):
        return NotImplementedError


class BERTModel(WordToVecteur):

    def __init__(self, list_tag : List[str] = [], big: bool = False):
        super(BERTModel, self).__init__(list_tag)
        self.model_size = "bert-large-uncased" if big else "bert-base-uncased"
        self.embeddings = []
        self.cosine_sim_matrix = None

        self.tokenizer = BertTokenizer.from_pretrained(self.model_size, padding=True, truncation=True,)
        self.model = BertModel.from_pretrained(self.model_size, output_hidden_states = True)

        self.merging_strategy = Sum4LastLayers()

        self.model.eval()


    def export(self, filename):
        if len(self.embeddings) == 0:
            raise Exception("Tags not converted yet !")
        
        try:
            f = open(filename, "w")
        except OSError:
            raise OSError("Could not open file")

        with f:
            for embedding in self.embeddings:
                line = ",".join(list(map(str, map(float, embedding[1]))))
                print(embedding[0], ",", line, sep="", file=f)

    def import_tag_list(self, filename):
        self.embeddings.clear()
        self.cosine_sim_matrix = None

        try:
            f = open(filename, "r")
        except OSError:
            return OSError("Could not open file")

        with f:
            data = f.read().split("\n")
            for item in data:
                if item not in self.list_tags and str.strip(item) != "":
                    self.list_tags.append(item)
            print(self.list_tags)
            logging.info(f"Import finished : {len(self.list_tags)} elements imported.")
            

    def convert(self):
        """ convert all word in their embeddings"""

        logging.info("Starting converting tokens...")
        nb_token = len(self.list_tags)
        current_percent = 0

        for i, tag in enumerate(self.list_tags):
            
            percent_completion = (i / nb_token) * 100
            if percent_completion >= current_percent + 10:
                nearest_percent = (percent_completion // 10) * 10
                logging.info(f"{nearest_percent}% completed")
                current_percent = nearest_percent
            
            inputs = self.tokenizer(tag, return_tensors = "pt")

            with torch.no_grad():
                outputs = self.model(**inputs)

            hidden_states = outputs[2]

            # log.info(f"[{i}]","Number of layers:", len(hidden_states), "  (initial embeddings + 12 BERT layers)")
            # log.info(f"[{i}]","Number of batches:", len(hidden_states[0]))
            # logging.info(f"[{i}] Number of tokens: {len(hidden_states[0][0]) - 2}")
            # log.info(f"[{i}]","Number of hidden units:", len(hidden_states[0][0][0]))

            # [# layers, # batches, # tokens, # features] ==> [# tokens, # layers, # features]
            token_embeddings = torch.stack(hidden_states, dim=0)
            token_embeddings = torch.squeeze(token_embeddings, dim=1)
            token_embeddings = token_embeddings.permute(1,0,2)

            # apply different strategy to summarize word embeddings
            tokenized_text = self.tokenizer.tokenize(tag)
            print(tokenized_text)
            acc = []
            for i, token in reversed(list(enumerate(tokenized_text))):

                embed = self.merging_strategy.merge(token_embeddings[i+1])
                if not token.startswith("##"):
                    if len(acc) != 0:
                        embed = torch.mean(torch.stack([x[1] for x in acc]), dim=0)
                        token += "".join([x[0] for x in reversed(acc)])
                        acc = []

                    self.embeddings.append((token, embed))
                else:
                    acc.append((token, embed))

    def compute_sim(self):
        """ compute cosine similarity between all vectors """
        if len(self.embeddings) == 0:
            raise Exception("Tags not converted yet !")

        logging.info("Computing cosine similarity, this could take some time...")

        n_tokens = len(self.embeddings)
        self.cosine_sim_matrix = [[1 for j in range(n_tokens)] for i in range(n_tokens)]

        for j, vector in enumerate(self.embeddings):

            for i, other_vector in enumerate(self.embeddings):

                if i == j:
                    continue

                
                cos = torch.nn.CosineSimilarity(dim=0)
                similarity = cos(vector[1], other_vector[1])

                self.cosine_sim_matrix[i][j] = similarity
                self.cosine_sim_matrix[j][i] = similarity

    def export_sim_matrix(self, filename):
        if self.cosine_sim_matrix == None:
            self.compute_sim()
        
        try:
            f = open(filename, "w")
        except OSError:
            raise OSError("Could not open file")

        with f:
            print("/", ",".join([tag[0] for tag in self.embeddings]), sep = ",", file = f)

            for j, tag_y in enumerate(self.embeddings):
                print(tag_y[0], ",".join( [str(round(float(self.cosine_sim_matrix[j][i]), 3)) for i in range(len(self.embeddings))]), sep = ",", file = f)

    def sim_between(self, token1, token2):
        if self.cosine_sim_matrix is None:
            self.compute_co_sim()

        index1 = [i for i, v in enumerate(self.embeddings) if v[0] == token1][0]
        index2 = [i for i, v in enumerate(self.embeddings) if v[0] == token2][0]

        return self.cosine_sim_matrix[index1][index2]

    def get_embedding_of(self, token):
        res = [v for v in self.embeddings if v[0] == token]
        if len(res) == 0:
            raise Exception("no such token")
        
        return res[0]

    def get_nearest_embedding_of(self, embedding, nb = 10):

        if nb > len(self.embeddings):
            raise Exception("nb too high, not enough token")

        nearest = []
        for e in self.embeddings:

            cos = torch.nn.CosineSimilarity(dim=0)
            similarity = cos(embedding, e[1])

            nearest.append((e[0], similarity))
        
        nearest.sort(key = lambda tup : tup[1])
        return nearest[-nb:]

class Sum4LastLayers:

    def merge(self, vector):
        return torch.sum(vector[-4:], dim = 0)

Already up to date.


In [5]:
ds_name = 'oxford_flowers102'
splits = ['test', 'validation', 'train']
ds, info = tfds.load(ds_name, split = splits, with_info=True)
(train_examples, validation_examples, test_examples) = ds
print(f"Number of flower types {info.features['label'].num_classes}")
print(f"Number of training examples: {tf.data.experimental.cardinality(train_examples)}")
print(f"Number of validation examples: {tf.data.experimental.cardinality(validation_examples)}")
print(f"Number of test examples: {tf.data.experimental.cardinality(test_examples)}\n")

print('Flower types full list:')
print(info.features['label'].names)

class_list = ["pink primrose", 'hard-leaved pocket orchid', 'canterbury bells', 'sweet pea', 'english marigold', 'tiger lily', 'moon orchid', 'bird of paradise', 'monkshood', 'globe thistle', 'snapdragon', "colt's foot", 'king protea', 'spear thistle', 'yellow iris', 'globe-flower', 'purple coneflower', 'peruvian lily', 'balloon flower', 'giant white arum lily', 'fire lily', 'pincushion flower', 'fritillary', 'red ginger', 'grape hyacinth', 'corn poppy', 'prince of wales feathers', 'stemless gentian', 'artichoke', 'sweet william', 'carnation', 'garden phlox', 'love in the mist', 'mexican aster', 'alpine sea holly', 'ruby-lipped cattleya', 'cape flower', 'great masterwort', 'siam tulip', 'lenten rose', 'barbeton daisy', 'daffodil', 'sword lily', 'poinsettia', 'bolero deep blue', 'wallflower', 'marigold', 'buttercup', 'oxeye daisy', 'common dandelion', 'petunia', 'wild pansy', 'primula', 'sunflower', 'pelargonium', 'bishop of llandaff', 'gaura', 'geranium', 'orange dahlia', 'pink-yellow dahlia?', 'cautleya spicata', 'japanese anemone', 'black-eyed susan', 'silverbush', 'californian poppy', 'osteospermum', 'spring crocus', 'bearded iris', 'windflower', 'tree poppy', 'gazania', 'azalea', 'water lily', 'rose', 'thorn apple', 'morning glory', 'passion flower', 'lotus', 'toad lily', 'anthurium', 'frangipani', 'clematis', 'hibiscus', 'columbine', 'desert-rose', 'tree mallow', 'magnolia', 'cyclamen', 'watercress', 'canna lily', 'hippeastrum', 'bee balm', 'ball moss', 'foxglove', 'bougainvillea', 'camellia', 'mallow', 'mexican petunia', 'bromelia', 'blanket flower', 'trumpet creeper', 'blackberry lily']
class_list_m = ["_".join(i.split(" ")) for i in class_list]

print(class_list_m)

INFO:absl:Load dataset info from /root/tensorflow_datasets/oxford_flowers102/2.1.1
INFO:absl:Reusing dataset oxford_flowers102 (/root/tensorflow_datasets/oxford_flowers102/2.1.1)
INFO:absl:Constructing tf.data.Dataset for split ['test', 'validation', 'train'], from /root/tensorflow_datasets/oxford_flowers102/2.1.1


Number of flower types 102
Number of training examples: 6149
Number of validation examples: 1020
Number of test examples: 1020

Flower types full list:
['pink primrose', 'hard-leaved pocket orchid', 'canterbury bells', 'sweet pea', 'english marigold', 'tiger lily', 'moon orchid', 'bird of paradise', 'monkshood', 'globe thistle', 'snapdragon', "colt's foot", 'king protea', 'spear thistle', 'yellow iris', 'globe-flower', 'purple coneflower', 'peruvian lily', 'balloon flower', 'giant white arum lily', 'fire lily', 'pincushion flower', 'fritillary', 'red ginger', 'grape hyacinth', 'corn poppy', 'prince of wales feathers', 'stemless gentian', 'artichoke', 'sweet william', 'carnation', 'garden phlox', 'love in the mist', 'mexican aster', 'alpine sea holly', 'ruby-lipped cattleya', 'cape flower', 'great masterwort', 'siam tulip', 'lenten rose', 'barbeton daisy', 'daffodil', 'sword lily', 'poinsettia', 'bolero deep blue', 'wallflower', 'marigold', 'buttercup', 'oxeye daisy', 'common dandelion'

In [6]:
# ["king", "queen", "man", " woman", "splurgle", "pladonf"]
model = BERTModel(class_list_m)

# model.import_tag_list("en-basic")
model.convert()
model.export("flower-embeddeding.csv")

# model.computeCoSim()
# model.simBetween("cat", "dog")

# man = model.get_embedding_of("man")[1]
# woman = model.get_embedding_of("woman")[1]

# king = model.get_embedding_of("king")[1]

# totest = king.sub(man).add(woman)
# print(model.get_nearest_embedding_of(totest, 3))

# model.export_sim_matrix("sim_matrix.csv")

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
INFO:root:Starting converting tokens...


['pink', '_', 'pri', '##m', '##rose']


TypeError: ignored