In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch import cosine_similarity

from collections import Counter

import random

from dataSet import SGNS_store_DataSet

from typing import Sequence, Optional, Callable, List, Dict, Set

from copy import deepcopy

import nltk
from nltk.tokenize import word_tokenize
nltk.download('punkt_tab')
nltk.download('averaged_perceptron_tagger_eng')

import seaborn as sns
import matplotlib.pyplot as plt

import unicodedata
import string

from visuEmbedding import components_to_fig_3D, components_to_fig_3D_animation
import tool

import numpy as np
import pandas as pd

import re

[nltk_data] Downloading package punkt_tab to /home/pe/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger_eng to
[nltk_data]     /home/pe/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger_eng is already up-to-
[nltk_data]       date!


In [2]:
def remove_accents(text: str) -> str:
    """Strip diacritics from ``text`` (e.g. 'café' -> 'cafe').

    The string is NFKD-decomposed so that accented letters become a base character followed
    by combining marks; every combining mark is then dropped. Compatibility characters are
    also decomposed by NFKD (e.g. the ligature 'ﬁ' becomes 'fi').
    """
    decomposed = unicodedata.normalize("NFKD", text)
    kept_chars = [char for char in decomposed if unicodedata.combining(char) == 0]
    return "".join(kept_chars)

def prepare_data(
    file_path: str,
    language: str,
    remove_accent: bool = True,
    remove_punct: bool = True,
    keep_apostrophes: bool = True,
    contraction_map: Optional[Dict[str, str]] = None,
    stop_words: Optional[List[str]] = None,
    break_line: bool = True,
    expand_is_contraction: bool = True
    ) -> List[List[str]]:
    """Read a UTF-8 text file and return it as a list of tokenised sentences.

    Each line is stripped and lower-cased, then (optionally) split on runs of '.', '!' or '?'.
    Every resulting segment is cleaned in this order: contraction replacement, hyphen
    handling ('-' deleted, '—' turned into a space), accent removal, punctuation removal,
    ``nltk.word_tokenize``, optional "'s" handling via POS tags, then apostrophe stripping
    and stop-word filtering. Segments that end up with no tokens are dropped.

    Args:
        file_path: path of the text file to read.
        language: language name passed to ``word_tokenize``; the "'s" handling only runs
            when it is ``'english'``.
        remove_accent: strip diacritics with ``remove_accents``.
        remove_punct: replace ``string.punctuation`` characters with spaces.
        keep_apostrophes: keep "'" during punctuation removal.
        contraction_map: literal substring -> replacement map applied before tokenisation.
            Keys are matched as raw substrings (no word boundaries) on the lower-cased text,
            so keys containing upper-case letters never match.
        stop_words: tokens to drop, compared after leading/trailing apostrophes are stripped.
        break_line: split each line into sentences; if False, one line = one sentence.
        expand_is_contraction: POS-tag the tokens, drop tokens tagged 'POS' (possessive
            markers) and turn "'s" tagged 'VBZ' into "is". Also forces "'" to be kept during
            punctuation removal, whatever ``keep_apostrophes`` says.

    Returns:
        One list of tokens per non-empty sentence.
    """

    # Sentence boundaries: any run of '.', '!' or '?'.
    sentence_split_re = re.compile(r'[\.!\?]+')
    
    contraction_re = None
    if contraction_map:
        # Reverse lexicographic order puts a key before any shorter key that is its prefix,
        # so the regex alternation tries the longer contraction first.
        pattern = "|".join(re.escape(k) for k in sorted(contraction_map.keys(), reverse=True))
        contraction_re = re.compile(f"({pattern})")

    punctuation_chars = set(string.punctuation)
    if keep_apostrophes or expand_is_contraction:
        # Apostrophes must survive until tokenisation so "'s" can be POS-tagged.
        # ("’" is not in string.punctuation, so removing it here is a no-op.)
        punctuation_chars -= {"'", "’"}
    
    punct_trans_table = str.maketrans({c: " " for c in punctuation_chars})
    stop_words_set: Set[str] = set(stop_words) if stop_words else set()
    tokens_by_sentence: List[List[str]] = []
    
    with open(file_path, encoding="utf-8") as f:
        for line in f:
            sub_lines = sentence_split_re.split(line.strip().lower()) if break_line else [line.strip().lower()]
            
            for s in sub_lines:
                if not s: continue
                
                if contraction_re:
                    s = contraction_re.sub(lambda m: contraction_map[m.group(0)], s)
                    
                # Hyphenated compounds are glued together ("good-night" -> "goodnight");
                # em dashes separate words.
                s = s.replace("-", "")
                s = s.replace("—", " ")
                
                if remove_accent:
                    s = remove_accents(s) 

                if remove_punct:
                    s = s.translate(punct_trans_table)

                toks = word_tokenize(s, language=language)

                if expand_is_contraction and language == 'english':
                    tagged = nltk.pos_tag(toks)
                    new_toks = []
                    for word, tag in tagged:
                        if tag == 'POS': continue # Remove possession
                        elif word in ["'s", "’s"] and tag == 'VBZ':
                            new_toks.append("is")
                        else:
                            new_toks.append(word)
                    toks = new_toks

                clean_toks = []
                for t in toks:
                    # Drop leading/trailing apostrophes left by tokenisation (e.g. "'s" -> "s");
                    # tokens that become empty or are stop words are skipped.
                    t_stripped = t.strip("'’")
                    if t_stripped and t_stripped not in stop_words_set:
                        clean_toks.append(t_stripped)
                
                if clean_toks:
                    tokens_by_sentence.append(clean_toks)

    return tokens_by_sentence

def separate_text_intonation(data:List[List[str]]) -> tuple[List[List[str]], List[List[int]]]:
    """Split sentences of alternating ``word, level, word, level, ...`` tokens.

    Even positions are words and odd positions are integer intonation levels. A sentence
    is kept only if every word is alphabetic, every level is a decimal number, and there
    is exactly one level per word. Otherwise the sentence is reported on stdout and
    skipped, so the two returned lists stay aligned word-for-word.

    Args:
        data: tokenised sentences, e.g. ``[["look", "2", "there", "0"]]``.

    Returns:
        ``(texts, intonations)``: per-sentence word lists and matching lists of ints.
    """
    texts = []
    intonations = []
    for sentence in data:
        intonation = sentence[1::2]
        text = sentence[::2]
        # A sentence with an odd number of tokens has a word without a level; accepting it
        # would give lists of different lengths and break indexing downstream.
        same_length = len(text) == len(intonation)
        # isdecimal (not isdigit): isdigit accepts characters such as '²' that int() rejects.
        if same_length and all(t.isalpha() for t in text) and all(t.isdecimal() for t in intonation):
            texts.append(text)
            intonations.append(list(map(int, intonation)))
        else:
            print("Warning: Mismatched text and intonation in sentence:", sentence)
            print("Extracted text:", text)
            print("Extracted intonation:", intonation)
            if not same_length:
                print(f" Length mismatch: {len(text)} words vs {len(intonation)} intonation tokens")
            for t in text:
                if not t.isalpha():
                    print(" Non-alpha text token:", t)
            for i in intonation:
                if not i.isdecimal():
                    print(" Non-digit intonation token:", i)
            
    return texts, intonations


In [3]:
# Load the annotated story. Each line alternates word and intonation level
# ("look 2 there 0 ..."), so punctuation is removed, a few contractions are rewritten as
# single tokens, leftover "s" / "n't" fragments are discarded, and one line = one sentence.
data = prepare_data(
    file_path="./data/GoodNightGorilla_Intonation.txt",
    language='english',
    remove_accent=True,
    remove_punct=True,
    keep_apostrophes=False,
    contraction_map={
        "that's" : "thatis",
        "it's" : "itis",
        "don't": "donot",
        "doesn't": "doesnot",},
    stop_words=["s", "n't"],
    break_line=False
)
for s in data:
    print(s)
# Words are at even positions and integer intonation levels at odd positions;
# malformed sentences are reported and skipped by separate_text_intonation.
texts, intonations = separate_text_intonation(data)
for t, i in zip(texts, intonations):
    print("Text:", t)
    print("Intonation:", i)

['look', '2', 'there', '0', 'is', '0', 'the', '0', 'zookeeper', '5', 'he', '0', 'has', '0', 'a', '0', 'big', '3', 'flashlight', '5', 'to', '0', 'see', '2', 'in', '0', 'the', '0', 'dark', '4', 'click', '5', 'what', '2', 'is', '0', 'he', '0', 'saying', '2', 'to', '0', 'the', '0', 'animal', '3', 'he', '0', 'says', '2', 'good', '4', 'night', '4', 'gorilla', '5', 'can', '2', 'you', '0', 'say', '2', 'good', '4', 'night', '4', 'oh', '3', 'my', '0', 'goodness', '3', 'look', '2', 'closer', '3', 'is', '0', 'the', '0', 'gorilla', '4', 'going', '0', 'to', '0', 'sleep', '4', 'no', '5', 'he', '0', 'is', '0', 'reaching', '3', 'out', '2', 'and', '0', 'taking', '3', 'the', '0', 'keys', '5', 'that', '0', 'sneaky', '4', 'gorilla', '4', 'is', '0', 'stealing', '3', 'the', '0', 'keys', '5', 'right', '0', 'off', '0', 'the', '0', 'zookeeper', '3', 'belt', '3', 'jingle', '5', 'jangle', '5', 'who', '2', 'sees', '2', 'him', '0', 'doing', '0', 'it', '0', 'it', '0', 'the', '0', 'little', '3', 'mouse', '5', 'squeak

In [4]:
# First analysis of word frequencies
word_counter = Counter()
for sentence in texts:
    word_counter.update(sentence)
most_common_words = word_counter.most_common()
print("Most common words:", most_common_words)

# Word co-occurrence counts (window_size=2) from the project's `tool` module, wrapped in a
# DataFrame that is indexed by word on both axes.
occ_m, word_list, word_to_index = tool.compute_co_occurrence_matrix(texts, window_size=2)
ooc_df = pd.DataFrame(
    data=occ_m,
    index=word_list,
    columns=word_list
)

# NOTE(review): tool.get_parasite_word is not visible here. Presumably it scores every word
# from the co-occurrence matrix and flags those above the 95th percentile; confirm in tool.py.
# `score_series` (word -> score) is reused later to annotate the dataset tables.
bad_word, score_series = tool.get_parasite_word(ooc_df, percentile_threshold=95)
print(f"Identified parasite word: {bad_word}")


Most common words: [('the', 192), ('is', 117), ('he', 69), ('look', 48), ('to', 46), ('and', 42), ('a', 33), ('his', 29), ('at', 28), ('gorilla', 27), ('zookeeper', 25), ('in', 24), ('it', 24), ('good', 21), ('night', 21), ('are', 20), ('little', 19), ('big', 18), ('mouse', 18), ('she', 18), ('up', 17), ('who', 16), ('all', 16), ('right', 15), ('but', 15), ('they', 15), ('back', 15), ('there', 13), ('lion', 13), ('bed', 13), ('her', 13), ('on', 12), ('banana', 12), ('giraffe', 12), ('you', 11), ('that', 11), ('walking', 11), ('see', 10), ('oh', 10), ('out', 10), ('him', 10), ('has', 9), ('says', 9), ('elephant', 9), ('of', 9), ('wife', 9), ('keys', 8), ('doesnot', 8), ('now', 8), ('everyone', 8), ('click', 7), ('what', 7), ('sleep', 7), ('so', 7), ('still', 7), ('hyena', 7), ('them', 7), ('armadillo', 7), ('saying', 6), ('room', 6), ('where', 6), ('with', 6), ('behind', 6), ('looks', 6), ('for', 6), ('way', 6), ('next', 6), ('asleep', 6), ('can', 5), ('going', 5), ('inside', 5), ('foll

In [9]:
class W2V_weighted_DataSet(Dataset):
    def compute_importance(self, words, intonations):
        dict_list_importance = {}
        for sentence, intonation in zip(words, intonations) :
            for index, inton in enumerate(intonation):
                if sentence[index] not in dict_list_importance :
                    dict_list_importance[sentence[index]] = [float(inton)]
                else :
                    dict_list_importance[sentence[index]].append(float(inton))

        dict_importance = {}
        for word in dict_list_importance :
            dict_importance[word] = sum(dict_list_importance[word]) / len(dict_list_importance[word])

        return dict_importance
    
    def _get_unigram_dist(self):
        """Compute unigram distribution depending on word importance"""
        weight_list = [self.word_importance[token] for token in range(len(self.encoder))]
        unigram = torch.tensor([weight for weight in weight_list], dtype=torch.float)
        return unigram / unigram.sum()

    
    def _make_pairs_positif(self):
        print("It's not me ! I'm the parent")
        pairs = []
        for sent, intonation in zip(self.sentences, self.intonations):
            ids = self.encode(sent)
            L = len(ids)
            for i, center in enumerate(ids):
                cur_window = self.context_size
                start = max(0, i - cur_window)
                end = min(L, i + cur_window + 1)
                for j in range(start, end):
                    if j == i:
                        continue
                    context = ids[j]
                    pairs.append((center, context, intonation[i]))
        return pairs
    
    def __init__(self, sentences:list[list[str]], intonations:List[List[float]] , window_size:int=2, nb_neg:int=5):
        super().__init__()
        
        assert len(sentences) == len(intonations), f"Error: Sentences and intonations must have the same length."

        all_tokens = [t for sentence in sentences for t in sentence if t.isalpha()]
        self.vocab = list(set(all_tokens))
        self.encoder:dict = {w:i for i,w in enumerate(self.vocab)}
        self.decoder:dict = {i:w for i,w in enumerate(self.vocab)}
        self.context_size:int = window_size
        self.sentences = sentences
        self.intonations = intonations
        self.K = nb_neg
        
        self.tokens = []
        for s in sentences :
            self.tokens.append([])
            for w in s :
                self.tokens[-1] .append(self.encoder[w])

        self.word_importance:dict = self.compute_importance(self.tokens, intonations)
        self.unigram_dist = self._get_unigram_dist()
        self.pairs:List = self._make_pairs_positif()

    def encode(self, words:list|str) -> list|int:
        if isinstance(words, str) : return self.encoder[words]
        ids = []
        for w in words :
            ids.append(self.encoder[w])
        return ids
    
    def decode(self, ids:list|int) -> list|int:
        if isinstance(ids, int) : return self.decoder[ids]
        words = []
        for i in ids :
            words.append(self.decoder[i])
        return words

    def __getitem__(self, idx:int):
        print("It's not me ! I'm the parent (getitem)")
        center, pos, intonation = self.pairs[idx]
        neg = torch.multinomial(self.unigram_dist, self.K, replacement=True)
        return center, pos, neg, intonation
    
    def __len__(self):
        return len(self.pairs)


In [None]:
class W2V_weighted_DataSet_v2(W2V_weighted_DataSet):
    """Variant of the weighted skip-gram dataset that also keeps the context word's intonation.

    Items are ``(center_id, context_id, negative_ids, center_intonation, context_intonation)``.
    """

    def _make_pairs_positif(self):
        """Build (center, context, center_intonation, context_intonation) tuples per window."""
        window = self.context_size
        collected = []
        for sentence, levels in zip(self.sentences, self.intonations):
            token_ids = self.encode(sentence)
            n_tokens = len(token_ids)
            collected.extend(
                (center_id, token_ids[k], levels[pos], levels[k])
                for pos, center_id in enumerate(token_ids)
                for k in range(max(0, pos - window), min(n_tokens, pos + window + 1))
                if k != pos
            )
        return collected

    def __init__(self, sentences:list[list[str]], intonations:List[List[float]] , window_size:int=2, nb_neg:int=5):
        # No extra state: the parent constructor already calls the overridden pair builder.
        super().__init__(sentences, intonations, window_size, nb_neg)

    def __getitem__(self, idx:int):
        """Return pair ``idx`` with both intonations plus ``self.K`` sampled negative ids."""
        center_id, context_id, center_level, context_level = self.pairs[idx]
        negatives = torch.multinomial(self.unigram_dist, self.K, replacement=True)
        return center_id, context_id, negatives, center_level, context_level
    

In [None]:
def normalize_intonation(intonations:List[List[int]], range_normalize:float=1.0, center_intonation:float=1.0) -> List[List[float]]:
    """Min-max rescale all intonation values into a band centred on ``center_intonation``.

    The global minimum maps to ``center_intonation - range_normalize / 2`` and the global
    maximum to ``center_intonation + range_normalize / 2``. The nested sentence structure
    is preserved.

    Args:
        intonations: per-sentence intonation values.
        range_normalize: width of the output band.
        center_intonation: centre of the output band.

    Returns:
        The rescaled values, with the same nesting as ``intonations``. Input without any
        value (``[]`` or only empty sentences) is returned as empty sentences unchanged.

    Raises:
        AssertionError: if all values are identical (the rescaling would divide by zero).
    """
    all_intonations = [inton for sublist in intonations for inton in sublist]
    if not all_intonations:
        # Nothing to rescale; min()/max() would raise on an empty sequence.
        return [[] for _ in intonations]
    min_inton = min(all_intonations)
    max_inton = max(all_intonations)
    assert max_inton > min_inton, "Error: All intonation values are the same."

    span = max_inton - min_inton
    offset = center_intonation - range_normalize / 2
    # (x - min) / span lies in [0, 1]; scale it to the band width, then shift it to the band start.
    return [
        [range_normalize * ((inton - min_inton) / span) + offset for inton in sentence]
        for sentence in intonations
    ]

In [None]:
# Sanity check: values 1..5 are mapped onto [10 - 2/2, 10 + 2/2] = [9, 11];
# the expected result is [[9.0, 9.5, 10.0], [10.5, 11.0]].
normalize_intonation([[1, 2, 3], [4, 5]], range_normalize=2.0, center_intonation=10.0)

In [None]:
# Rescale the intonation levels into [1 - 0.8/2, 1 + 0.8/2] = [0.6, 1.4]; presumably this keeps
# every sample's loss weight strictly positive (OnlyOneEmb multiplies each sample's loss by it).
# NOTE: this rebinds `intonations`, so later cells see floats, not the raw integer levels.
intonations = normalize_intonation(intonations, range_normalize=0.8, center_intonation=1)

print(intonations)

In [12]:
# Build the v2 dataset and tabulate, per token id: mean intonation ("importance"), corpus
# frequency, parasite score and the resulting negative-sampling probability.
test = W2V_weighted_DataSet_v2(sentences=texts, intonations=intonations)
# NOTE: despite the name, this maps token id -> mean intonation, not a frequency.
freq_weighted = test.word_importance

rows = []
for token, imp in freq_weighted.items():
    word = test.decode(token)
    freq = float(word_counter.get(word, 0))
    parasite = float(score_series.get(word, 0.0))
    dist = test.unigram_dist[token].item()
    rows.append((int(token), word, float(imp), freq, parasite, dist))

df = pd.DataFrame(rows, columns=['token', 'word', 'importance_score', 'frequency', 'parasite_score', 'unigram_dist']).set_index('token')

It's me ! I'm the child


In [None]:
# `test` is a W2V_weighted_DataSet_v2, whose pairs are 4-tuples
# (center, context, center_intonation, context_intonation); unpacking them into three
# names raised "too many values to unpack".
for w1, w2, into_center, into_context in test.pairs:
    center = test.decode(w1)
    context = test.decode(w2)
    print(f"{center:<20}{context:<20}{str(into_center):<6}{str(into_context):<6}")

print(test.unigram_dist)


In [None]:
# Drop every token carrying the lowest intonation level ("0" in the raw annotation).
# Compare against the observed minimum rather than `int(i) != 0`: once `intonations`
# has been rescaled by normalize_intonation, level 0 is no longer literally 0, and int()
# truncation would also discard other levels (e.g. a rescaled 0.92 -> 0).
# NOTE(review): assumes the lowest level present in the corpus is the "0" level.
lowest_intonation = min((i for sentence_i in intonations for i in sentence_i), default=0)

text_without_0intonation = []
intonation_without_0intonation = []

for sentence_t, sentence_i in zip(texts, intonations):
    kept = [(t, i) for t, i in zip(sentence_t, sentence_i) if i != lowest_intonation]
    text_without_0intonation.append([t for t, _ in kept])
    intonation_without_0intonation.append([i for _, i in kept])

In [None]:
occ_m, word_list, word_to_index = tool.compute_co_occurrence_matrix(text_without_0intonation, window_size=2)
ooc_df = pd.DataFrame(
    data=occ_m,
    index=word_list,
    columns=word_list
)

test2 = W2V_weighted_DataSet(sentences=text_without_0intonation, intonations=intonation_without_0intonation)
freq_weighted2 = test2.word_importance

rows = []
for token, imp in freq_weighted2.items():
    # Token ids come from test2's own vocabulary, so they must be decoded with test2
    # (decoding them with `test` returned unrelated words).
    word = test2.decode(token)
    # NOTE(review): score_series still comes from the parasite analysis of the full text;
    # the ooc_df rebuilt above is not used to recompute it.
    parasite = float(score_series.get(word, 0.0))
    rows.append((int(token), word, float(imp), parasite))

df2 = pd.DataFrame(rows, columns=['token', 'word', 'importance_score', 'parasite_score']).set_index('token')

for w1, w2, intonation in test2.pairs:
    center = test2.decode(w1)
    context = test2.decode(w2)
    print(f"{center:<20}{context:<20}{str(intonation):<6}")
    

In [None]:
# Empirical check of sampling tokens proportionally to their mean-intonation score.
# `random` and `Counter` are already imported in the first cell.
data = test.word_importance
words = list(data.keys())  # token ids of `test`
weights = list(data.values())
scores = np.array(weights)

def softmax(x):
    """Numerically stable softmax over a 1-D array."""
    e_x = np.exp(x - np.max(x)) # Subtract max for numerical stability
    return e_x / e_x.sum()

probabilities = softmax(scores)

trials = 10000
results = random.choices(words, weights=weights, k=trials)
counts = Counter(results)

# The header previously claimed 1000 draws while `trials` is 10000; derive it from `trials`.
freq_header = f"Frequency (out of {trials})"
print(f"{'Word':<12} | {'Score':<8} | {freq_header:<25}")
print("-" * 50)
for word in words:
    # `word` is a token id; decode it so the 'Word' column actually shows the word.
    print(f"{test.decode(word):<12} | {data[word]:<8} | {counts[word]}")
    

In [None]:
# Same experiment, but sampling with softmax(scores) as weights instead of the raw scores.
results = random.choices(words, weights=probabilities, k=trials)
counts = Counter(results)

# Header derived from `trials` (the hard-coded "1000" did not match the actual draw count).
freq_header = f"Frequency (out of {trials})"
print(f"{'Word':<12} | {'Score':<8} | {freq_header:<25}")
print("-" * 50)
for word in words:
    # `word` is a token id of `test`; print the decoded word.
    print(f"{test.decode(word):<12} | {data[word]:<8} | {counts[word]}")

In [18]:
# Inspect the v2 samples one at a time with a 4-word window: each batch is
# (center, positive context, negatives, center intonation, context intonation).
# Ids are decoded back to words for display; only the center intonation is printed.
data_set = W2V_weighted_DataSet_v2(sentences=texts, intonations=intonations,window_size=4 ,nb_neg=5)
loader = DataLoader(data_set, batch_size=1, shuffle=False)
for center, pos, neg, intonation, intonationv2 in loader:
    center = data_set.decode(center.tolist())
    pos = data_set.decode(pos.tolist())
    neg = data_set.decode(neg[0].tolist())
    print(f"Center: {center}, Positive: {pos}, Negatives: {neg}, Intonation: {intonation}")

It's me ! I'm the child
It's me ! I'm the child (getitem)
Center: ['look'], Positive: ['there'], Negatives: ['green', 'barely', 'yaaaawn', 'party', 'lights'], Intonation: tensor([2])
It's me ! I'm the child (getitem)
Center: ['look'], Positive: ['is'], Negatives: ['mouth', 'pulling', 'click', 'squeaky', 'herself'], Intonation: tensor([2])
It's me ! I'm the child (getitem)
Center: ['look'], Positive: ['the'], Negatives: ['saying', 'finally', 'shell', 'walking', 'smiling'], Intonation: tensor([2])
It's me ! I'm the child (getitem)
Center: ['look'], Positive: ['zookeeper'], Negatives: ['moon', 'last', 'made', 'used', 'pull'], Intonation: tensor([2])
It's me ! I'm the child (getitem)
Center: ['there'], Positive: ['look'], Negatives: ['full', 'leaning', 'hyena', 'busy', 'tap'], Intonation: tensor([0])
It's me ! I'm the child (getitem)
Center: ['there'], Positive: ['is'], Negatives: ['pulls', 'shell', 'using', 'black', 'comfy'], Intonation: tensor([0])
It's me ! I'm the child (getitem)
Cente

In [None]:
class OnlyOneEmb(nn.Module):
    def __init__(self, emb_size:int, embedding_dimension:int=15, init_range:float|None=None, sparse:bool=True, device="cpu"):
        super().__init__()
        self.emb_size:int = emb_size
        self.emb_dim:int = embedding_dimension
        self.word_emb:nn.Embedding = nn.Embedding(num_embeddings=self.emb_size, embedding_dim=self.emb_dim, device=device, sparse=sparse)

        if init_range is None:
            init_range = 0.5 / self.emb_dim
        self.word_emb.weight.data.uniform_(-init_range, init_range)

    def forward(self, centrals_words:list|torch.Tensor, pos_context:list|torch.Tensor, neg_context:list|torch.Tensor, weights:List|torch.Tensor):
        words_emb:torch.Tensor = self.word_emb(centrals_words)
        context_emb:torch.Tensor = self.word_emb(pos_context) # [B, D]
        neg_emb:torch.Tensor = self.word_emb(neg_context) # [B, K, D]

        pos_score = torch.sum(words_emb * context_emb, dim=1)
        pos_loss = F.logsigmoid(pos_score)

        neg_score = torch.bmm(neg_emb, words_emb.unsqueeze(-1)).squeeze(2)
        neg_loss = F.logsigmoid(-neg_score).sum(1)
        loss = -((pos_loss + neg_loss) * weights).mean()
        
        return loss
    

In [None]:
# Training data: the filtered text (text_without_0intonation) with its intonations as
# per-sample weights, a 4-word window on each side and 5 negatives per positive pair.
data_set = W2V_weighted_DataSet(sentences=text_without_0intonation, intonations=intonation_without_0intonation, window_size=4, nb_neg=5)
loader = DataLoader(data_set, batch_size=1, shuffle=False)
# Optional debug dump of every training sample, decoded back to words:
# for center, pos, neg, intonation in loader:
#     center = data_set.decode(center.tolist())
#     pos = data_set.decode(pos.tolist())
#     neg = data_set.decode(neg[0].tolist())
#     print(f"Center: {center}, Positive: {pos}, Negatives: {neg}, Intonation: {intonation}")

In [None]:
modelW2V:OnlyOneEmb = OnlyOneEmb(len(data_set.encoder), embedding_dimension=3)
# SparseAdam matches the sparse gradients produced by OnlyOneEmb's default sparse=True embedding.
optimizer = torch.optim.SparseAdam(modelW2V.parameters(), lr=0.005)

nb_epoch = 5

for epoch in range(nb_epoch):
    epoch_losses = []
    for centers, pos, negs, intonation in loader:
        optimizer.zero_grad()
        loss = modelW2V(centers, pos, negs, intonation)
        loss.backward()
        optimizer.step()
        epoch_losses.append(loss.item())
    if not epoch_losses:
        # Previously `print(loss)` raised NameError when the loader was empty.
        print("Warning: the loader produced no batches; nothing was trained.")
        break
    # A single batch's loss is noisy, so report the epoch mean instead of only the last value.
    print(f"epoch {epoch + 1}/{nb_epoch}: mean loss = {sum(epoch_losses) / len(epoch_losses):.4f}")

In [None]:
def find_nearest_neighbors(vector_word:torch.Tensor, tensor:torch.Tensor, top_n:int=5):
    """Return the ``top_n`` rows of ``tensor`` most cosine-similar to ``vector_word``.

    Args:
        vector_word: query vector; it is reshaped to ``(1, -1)`` so it broadcasts against
            every row of ``tensor``.
        tensor: 2-D matrix whose rows (feature axis = dim 1) are compared with the query.
        top_n: number of results to keep.

    Returns:
        pandas Series indexed by row position and sorted by descending similarity.
    """
    all_scores = cosine_similarity(tensor, vector_word.reshape(1, -1))
    # Detach and move to CPU so the conversion also works for tensors that require grad
    # or live on a GPU (numpy conversion fails for both).
    scores = all_scores.detach().cpu().numpy().ravel()
    score_series = pd.Series(scores)
    top_words = score_series.sort_values(ascending=False).head(top_n)
    return top_words

def cosine_similarity_matrix(embeddings:nn.Embedding) -> torch.Tensor:
    """Pairwise cosine similarity between all rows of an embedding table.

    Rows are L2-normalised (on a detached copy, so no autograd graph is built), and the
    similarity matrix is the product of that matrix with its transpose: entry ``[i, j]`` is
    cos(emb_i, emb_j).
    """
    unit_rows = F.normalize(embeddings.weight.detach(), p=2, dim=1)
    return torch.matmul(unit_rows, unit_rows.T)

In [None]:
word_a = "banana"
matrix_of_similarity = cosine_similarity_matrix(modelW2V.word_emb)
# Compare embeddings directly. Feeding rows of `matrix_of_similarity` back into
# find_nearest_neighbors ranked words by the cosine between their similarity *profiles*
# (a second-order measure), not by the cosine between their embeddings.
# Row i of `matrix_of_similarity` already holds the latter.
embeddings = modelW2V.word_emb.weight.detach()
nearest_neighbors = find_nearest_neighbors(embeddings[data_set.encode(word_a)], embeddings,
                                            top_n=20)
nearest_neighbors = nearest_neighbors.rename(index=lambda x: data_set.decoder[x])
print(f"Nearest Neighbors to '{word_a}':")
print(nearest_neighbors)