In [1]:
from gensim.models import Word2Vec
import logging
import os
from tqdm import tqdm
import sys
import json

from gensim import downloader
text8 = downloader.load('text8')

In [2]:
from flashtext import KeywordProcessor
keyword_processor = KeywordProcessor()

def load_json_pairs(input_file):
    with open(input_file, "r") as fp:
        pairs = json.load(fp)
    return pairs
PATH = '/home/manni/bias/counterfactual-data-substitution/cds/data/'
base_pairs = load_json_pairs(PATH+'cda_default_pairs_new.json')
#base_pairs = load_json_pairs('/home/manni/bias/counterfactual-data-substitution/cds/data/cda_default_filtered_pairs.json')
name_pairs = load_json_pairs(PATH+'names_pairs_1000_scaled.json')

exclusions = set()

for (male, female) in base_pairs:
    keyword_processor.add_keyword(male.lower(),male.lower()+'_'+female.lower())
    keyword_processor.add_keyword(female.lower(),male.lower()+'_'+female.lower())
    exclusions.add(male.lower())
    exclusions.add(female.lower())

for (male, female) in name_pairs:
    keyword_processor.add_keyword(male.lower(),male.lower()+'_'+female.lower())
    keyword_processor.add_keyword(female.lower(),male.lower()+'_'+female.lower())
    exclusions.add(male.lower())
    exclusions.add(female.lower())

In [3]:
# Set up logging
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.DEBUG)
print('Cpu count:',os.cpu_count())

Cpu count: 48


In [4]:
def preprocess_text(text):
    text = ' '.join(text)
    text = text.lower()
    text_ = keyword_processor.replace_keywords(text) 
    return text_.split()


class SentenceIterator:
    def __iter__(self):
        for document in text8:
            _document = preprocess_text(document)
            yield _document

sentences = SentenceIterator()

In [5]:
# Initialize and train the Word2Vec model
model = Word2Vec(sentences=sentences,epochs=10,shrink_windows=True,sg=1,vector_size=300,window=5,min_count=3,workers=os.cpu_count())

2023-11-13 07:36:55,714 : INFO : collecting all words and their counts
2023-11-13 07:36:55,716 : DEBUG : {'uri': '/home/manni/gensim-data/text8/text8.gz', 'mode': 'rb', 'buffering': -1, 'encoding': None, 'errors': None, 'newline': None, 'closefd': True, 'opener': None, 'compression': 'infer_from_extension', 'transport_params': None}
2023-11-13 07:36:55,768 : INFO : PROGRESS: at sentence #0, processed 0 words, keeping 0 word types
2023-11-13 07:37:29,358 : INFO : collected 252872 word types from a corpus of 17005207 raw words and 1701 sentences
2023-11-13 07:37:29,360 : INFO : Creating a fresh vocabulary
2023-11-13 07:37:29,741 : DEBUG : starting a new internal lifecycle event log for Word2Vec
2023-11-13 07:37:29,742 : INFO : Word2Vec lifecycle event {'msg': 'effective_min_count=3 retains 99230 unique words (39.24% of original 252872, drops 153642)', 'datetime': '2023-11-13T07:37:29.741582', 'gensim': '4.3.2.dev0', 'python': '3.10.12 (main, Jun 11 2023, 05:26:28) [GCC 11.4.0]', 'platfor

In [6]:
# Save the trained model
vocab = model.wv.key_to_index
emb_file = '/home/manni/embs/bw2v_merge_en_text8_mc3_epoch5_300.txt'
logging.info('Save trained word vectors')
with open(emb_file, 'w', encoding='utf-8') as f:
    f.write('%d %d\n' % (len(vocab), 300)) 
    for word in tqdm(vocab, position=0):
        f.write('%s %s\n' % (word, ' '.join([str(v) for v in model.wv[word]])))
logging.info('Done')

2023-11-13 07:44:50,255 : INFO : Save trained word vectors
100%|██████████| 99230/99230 [00:17<00:00, 5651.00it/s]
2023-11-13 07:45:12,321 : INFO : Done


In [7]:
model.save('text8_bw2v.model')

2023-11-13 07:45:12,334 : INFO : Word2Vec lifecycle event {'fname_or_handle': 'text8_bw2v.model', 'separately': 'None', 'sep_limit': 10485760, 'ignore': frozenset(), 'datetime': '2023-11-13T07:45:12.334143', 'gensim': '4.3.2.dev0', 'python': '3.10.12 (main, Jun 11 2023, 05:26:28) [GCC 11.4.0]', 'platform': 'Linux-5.15.0-76-generic-x86_64-with-glibc2.35', 'event': 'saving'}
2023-11-13 07:45:12,336 : INFO : storing np array 'vectors' to text8_bw2v.model.wv.vectors.npy
2023-11-13 07:45:14,589 : INFO : storing np array 'syn1neg' to text8_bw2v.model.syn1neg.npy
2023-11-13 07:45:16,871 : INFO : not storing attribute cum_table
2023-11-13 07:45:16,872 : DEBUG : {'uri': 'text8_bw2v.model', 'mode': 'wb', 'buffering': -1, 'encoding': None, 'errors': None, 'newline': None, 'closefd': True, 'opener': None, 'compression': 'infer_from_extension', 'transport_params': None}
2023-11-13 07:45:17,597 : INFO : saved text8_bw2v.model


In [8]:
from gensim.models import Word2Vec
model = Word2Vec.load("text8_bw2v.model")

2023-11-13 07:45:17,611 : INFO : loading Word2Vec object from text8_bw2v.model
2023-11-13 07:45:17,613 : DEBUG : {'uri': 'text8_bw2v.model', 'mode': 'rb', 'buffering': -1, 'encoding': None, 'errors': None, 'newline': None, 'closefd': True, 'opener': None, 'compression': 'infer_from_extension', 'transport_params': None}
2023-11-13 07:45:17,673 : INFO : loading wv recursively from text8_bw2v.model.wv.* with mmap=None
2023-11-13 07:45:17,674 : INFO : loading vectors from text8_bw2v.model.wv.vectors.npy with mmap=None
2023-11-13 07:45:17,747 : INFO : loading syn1neg from text8_bw2v.model.syn1neg.npy with mmap=None
2023-11-13 07:45:17,804 : INFO : setting ignored attribute cum_table to None
2023-11-13 07:45:18,521 : INFO : Word2Vec lifecycle event {'fname': 'text8_bw2v.model', 'datetime': '2023-11-13T07:45:18.521756', 'gensim': '4.3.2.dev0', 'python': '3.10.12 (main, Jun 11 2023, 05:26:28) [GCC 11.4.0]', 'platform': 'Linux-5.15.0-76-generic-x86_64-with-glibc2.35', 'event': 'loaded'}


In [9]:
print('vocab before:{} pair vocab:{}'.format(len(model.wv.key_to_index),len(exclusions)))

vocab before:99230 pair vocab:2451


In [10]:
model.build_vocab(text8,update=True)

2023-11-13 07:45:18,618 : INFO : collecting all words and their counts
2023-11-13 07:45:18,619 : DEBUG : {'uri': '/home/manni/gensim-data/text8/text8.gz', 'mode': 'rb', 'buffering': -1, 'encoding': None, 'errors': None, 'newline': None, 'closefd': True, 'opener': None, 'compression': 'infer_from_extension', 'transport_params': None}
2023-11-13 07:45:18,649 : INFO : PROGRESS: at sentence #0, processed 0 words, keeping 0 word types
2023-11-13 07:45:23,552 : INFO : collected 253854 word types from a corpus of 17005207 raw words and 1701 sentences
2023-11-13 07:45:23,553 : INFO : Updating model with new vocabulary
2023-11-13 07:45:24,036 : INFO : Word2Vec lifecycle event {'msg': 'added 1995 new unique words (0.79% of original 253854) and increased the count of 98043 pre-existing words (38.62% of original 253854)', 'datetime': '2023-11-13T07:45:24.036755', 'gensim': '4.3.2.dev0', 'python': '3.10.12 (main, Jun 11 2023, 05:26:28) [GCC 11.4.0]', 'platform': 'Linux-5.15.0-76-generic-x86_64-with

In [11]:
print('vocab after:{} pair vocab:{}'.format(len(model.wv.key_to_index),len(exclusions)))

vocab after:101225 pair vocab:2451


In [12]:
oov = exclusions-set(model.wv.key_to_index)
print('OOV:{}'.format(len(oov)))

OOV:456


In [13]:
model.train(text8,total_examples=model.corpus_count,epochs=model.epochs)

2023-11-13 07:45:25,800 : INFO : Word2Vec lifecycle event {'msg': 'training model with 48 workers on 101225 vocabulary and 300 features, using sg=1 hs=0 sample=0.001 negative=5 window=5 shrink_windows=True', 'datetime': '2023-11-13T07:45:25.800810', 'gensim': '4.3.2.dev0', 'python': '3.10.12 (main, Jun 11 2023, 05:26:28) [GCC 11.4.0]', 'platform': 'Linux-5.15.0-76-generic-x86_64-with-glibc2.35', 'event': 'train'}
2023-11-13 07:45:25,823 : DEBUG : {'uri': '/home/manni/gensim-data/text8/text8.gz', 'mode': 'rb', 'buffering': -1, 'encoding': None, 'errors': None, 'newline': None, 'closefd': True, 'opener': None, 'compression': 'infer_from_extension', 'transport_params': None}
2023-11-13 07:45:26,831 : INFO : EPOCH 0 - PROGRESS: at 5.29% examples, 661504 words/s, in_qsize 0, out_qsize 0
2023-11-13 07:45:27,837 : INFO : EPOCH 0 - PROGRESS: at 12.05% examples, 750365 words/s, in_qsize 0, out_qsize 0
2023-11-13 07:45:28,846 : INFO : EPOCH 0 - PROGRESS: at 18.75% examples, 779070 words/s, in_qs

(126133109, 170052070)

In [14]:
old_model = Word2Vec.load("text8_bw2v.model")

2023-11-13 07:48:03,642 : INFO : loading Word2Vec object from text8_bw2v.model
2023-11-13 07:48:03,643 : DEBUG : {'uri': 'text8_bw2v.model', 'mode': 'rb', 'buffering': -1, 'encoding': None, 'errors': None, 'newline': None, 'closefd': True, 'opener': None, 'compression': 'infer_from_extension', 'transport_params': None}
2023-11-13 07:48:03,696 : INFO : loading wv recursively from text8_bw2v.model.wv.* with mmap=None
2023-11-13 07:48:03,697 : INFO : loading vectors from text8_bw2v.model.wv.vectors.npy with mmap=None
2023-11-13 07:48:03,762 : INFO : loading syn1neg from text8_bw2v.model.syn1neg.npy with mmap=None
2023-11-13 07:48:03,818 : INFO : setting ignored attribute cum_table to None
2023-11-13 07:48:04,528 : INFO : Word2Vec lifecycle event {'fname': 'text8_bw2v.model', 'datetime': '2023-11-13T07:48:04.528174', 'gensim': '4.3.2.dev0', 'python': '3.10.12 (main, Jun 11 2023, 05:26:28) [GCC 11.4.0]', 'platform': 'Linux-5.15.0-76-generic-x86_64-with-glibc2.35', 'event': 'loaded'}


In [15]:
old_model.build_vocab(text8,update=True)

2023-11-13 07:48:04,536 : INFO : collecting all words and their counts
2023-11-13 07:48:04,537 : DEBUG : {'uri': '/home/manni/gensim-data/text8/text8.gz', 'mode': 'rb', 'buffering': -1, 'encoding': None, 'errors': None, 'newline': None, 'closefd': True, 'opener': None, 'compression': 'infer_from_extension', 'transport_params': None}
2023-11-13 07:48:04,541 : INFO : PROGRESS: at sentence #0, processed 0 words, keeping 0 word types
2023-11-13 07:48:09,542 : INFO : collected 253854 word types from a corpus of 17005207 raw words and 1701 sentences
2023-11-13 07:48:09,545 : INFO : Updating model with new vocabulary
2023-11-13 07:48:10,032 : INFO : Word2Vec lifecycle event {'msg': 'added 1995 new unique words (0.79% of original 253854) and increased the count of 98043 pre-existing words (38.62% of original 253854)', 'datetime': '2023-11-13T07:48:10.032413', 'gensim': '4.3.2.dev0', 'python': '3.10.12 (main, Jun 11 2023, 05:26:28) [GCC 11.4.0]', 'platform': 'Linux-5.15.0-76-generic-x86_64-with

In [16]:
new_vocab = old_model.wv.key_to_index 
for word in tqdm(exclusions):
    if word in new_vocab:
        old_model.wv[word]=model.wv[word]

100%|██████████| 2451/2451 [01:28<00:00, 27.70it/s]


In [17]:
emb_file = '/home/manni/embs/bw2v_MIX_en_text8_mc3_epoch5_300.txt'

In [18]:
# Save the trained model
logging.info('Save trained word vectors')
with open(emb_file, 'w', encoding='utf-8') as f:
    f.write('%d %d\n' % (len(new_vocab), 300)) 
    for word in tqdm(new_vocab, position=0):
        f.write('%s %s\n' % (word, ' '.join([str(v) for v in old_model.wv[word]])))
logging.info('Done')

2023-11-13 07:49:40,205 : INFO : Save trained word vectors
100%|██████████| 101225/101225 [00:17<00:00, 5689.52it/s]
2023-11-13 07:50:02,642 : INFO : Done


# de conflate copy

In [None]:
# old model is now the mixed version
import numpy as np
DIM = old_model.vector_size
vectors = dict()
bvocab = old_model.wv.key_to_index
for word in tqdm(bvocab,leave=0):
    if word in exclusions:
        vec = old_model.wv.get_vector(word)
        vec = vec/np.linalg.norm(vec)
        vectors['@'+word]=vec
        continue
    words = word.split('_')
    if len(words)>1:
        for _word in words:
            if _word == 'he' or _word == 'she':
                continue
            vec = old_model.wv.get_vector(word)
            vec = vec/np.linalg.norm(vec)
            vectors[_word]=vec
    else:
        vec = old_model.wv.get_vector(word)
        vec = vec/np.linalg.norm(vec)
        vectors[word]=vec

In [None]:
import gzip
emb_copy_file = emb_file+'.copy.gz'
with gzip.open(emb_copy_file, 'wt', encoding='utf-8') as f:
    f.write('%d %d\n' % (len(vectors), DIM))
    for word,vector in tqdm(vectors.items(), position=0):
        vector = vector/np.linalg.norm(vector)
        f.write('%s %s\n' % (word, ' '.join([str(v) for v in vector])))

In [None]:
emb_copy_file

# deconf avg 

In [None]:
# old model is now the mixed version
import numpy as np
DIM = old_model.vector_size
vectors = dict()
bvocab = old_model.wv.key_to_index
for word in tqdm(bvocab,leave=0):
    words = word.split('_')
    if len(words)>1:
        for _word in words:
            if _word not in bvocab:
                continue
            vec_1 = old_model.wv.get_vector(word)
            vec_1 = vec/np.linalg.norm(vec_1)
            vec_2 = old_model.wv.get_vector(_word)
            vec_2 = vec/np.linalg.norm(vec_2)
            vec = vec_1 + vec_2
            assert vec.shape[0]==DIM
            vec = vec/np.linalg.norm(vec)
            vectors[_word]=vec
    else:
        vec = old_model.wv.get_vector(word)
        vec = vec/np.linalg.norm(vec)
        vectors[word]=vec

In [None]:
import gzip
emb_avg_file = emb_file+'.avg.gz'
with gzip.open(emb_avg_file, 'wt', encoding='utf-8') as f:
    f.write('%d %d\n' % (len(vectors), DIM))
    for word,vector in tqdm(vectors.items(), position=0):
        vector = vector/np.linalg.norm(vector)
        f.write('%s %s\n' % (word, ' '.join([str(v) for v in vector])))

In [None]:
emb_avg_file

# validation

In [None]:
old_model.wv.most_similar('he')

In [None]:
old_model.wv.most_similar('she')

In [None]:
old_model.wv.most_similar('he_she')

In [None]:
old_model.wv.most_similar('man_woman')

In [None]:
old_model.wv.most_similar('bank')

In [None]:
old_model.wv.most_similar('cat')

In [None]:
for sent in sentences:
    print(len(sent))
    for i,word in enumerate(sent):
        if '_' in word:
            print(sent[i],sent[(len(sent)//2)+i])
            break
    input()

In [None]:
model.wv.index2word(1)

In [None]:
model.wv.similarity('he','she')

In [None]:
model.wv.similarity('he','he_she')

In [None]:
model.wv.similarity('she','he_she')