In [None]:
import torch
import torch.nn as nn

In [None]:
class Discriminator(nn.Module):
    """Feed-forward binary classifier that scores embeddings with a probability.

    The network is a stack of ``num_layers`` blocks, each being
    ``Linear -> LeakyReLU(0.2) -> Dropout``, followed by a ``Linear(., 1)``
    head and a sigmoid.

    Args:
        input_dim: Size of each input vector.
        hidden_dim: Output width of the first hidden layer.
        num_layers: Number of hidden blocks placed before the output head.
        dropout: Dropout probability applied after every hidden activation.
        offset: Amount by which each successive hidden layer is narrowed.
            Narrowing stops once another step would leave a layer with no
            units; the remaining layers keep the last valid width.
    """

    def __init__(self, input_dim, hidden_dim=2048, num_layers=2, dropout=0.1, offset = 256):
        super().__init__()

        layers = []
        in_features = input_dim
        out_features = hidden_dim
        for _ in range(num_layers):
            layers.append(nn.Linear(in_features, out_features))
            layers.append(nn.LeakyReLU(0.2))
            layers.append(nn.Dropout(dropout))

            # The next layer must consume exactly what this one produced.
            in_features = out_features
            # Narrow the following layer, but never down to zero/negative units.
            if out_features - offset > 0:
                out_features -= offset

        # Final layer to output a probability (sigmoid).
        layers.append(nn.Linear(in_features, 1))
        layers.append(nn.Sigmoid())

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Score a batch of vectors.

        Args:
            x: Tensor of shape ``[batch_size, input_dim]``.

        Returns:
            Tensor of shape ``[batch_size]`` holding one sigmoid output per row.

        Raises:
            AssertionError: If ``x`` is not 2-dimensional.
        """
        assert x.dim() == 2, "Input must be of shape [batch_size, input_dim]"
        # The head emits [batch_size, 1]; flatten it to one value per sample.
        return self.model(x).view(-1)

In [None]:
# The mapper (generator); it is modelled on the mapping used in the MUSE paper
class Generator(nn.Module):
    """Mapper that projects embeddings into a space of the same dimensionality.

    The mapping is a single ``Linear(d, d)`` layer followed by a ReLU.

    Args:
        embedding_dimension: Size ``d`` of both the input and output vectors.
    """

    def __init__(self, embedding_dimension):
        # nn.Module must be initialised before any sub-module is assigned;
        # without this call the assignment of self.model raises AttributeError.
        super().__init__()
        self.dimension = embedding_dimension
        # NOTE(review): the ReLU clamps negative components of the mapped
        # vectors to zero - confirm this non-linearity is intended.
        self.model = nn.Sequential(nn.Linear(self.dimension, self.dimension), nn.ReLU())

    def forward(self, x):
        """Apply the mapping to ``x`` (last dimension must equal ``self.dimension``)."""
        return self.model(x)

In [None]:
## Assume that we have source and target embeddings

def _as_embedding(emb):
    """Return ``emb`` as an ``nn.Embedding``, wrapping a raw 2-D weight matrix if needed."""
    if isinstance(emb, nn.Embedding):
        return emb
    weights = torch.as_tensor(emb, dtype=torch.float32)
    return nn.Embedding.from_pretrained(weights, freeze=True)


def train(src_emb, tar_emb, generator, discriminator, num_epochs, config):
    """Adversarially train ``generator`` against ``discriminator``.

    In every epoch the discriminator is updated
    ``config["discriminator_train_steps"]`` times on a batch made of mapped
    source embeddings (label 0) and target embeddings (label 1). The generator
    is then updated ``config["generator_train_steps"]`` times so that the
    discriminator labels its output as real (label 1).

    Args:
        src_emb: Source embeddings, either an ``nn.Embedding`` or a 2-D
            weight matrix (tensor / array) with one row per word.
        tar_emb: Target embeddings, same accepted types as ``src_emb``.
        generator: Module mapping source vectors into the target space.
        discriminator: Module returning one probability per input row.
        num_epochs: Number of outer training iterations.
        config: Mapping with the keys ``"discriminator_lr"``,
            ``"generator_lr"``, ``"discriminator_train_steps"``,
            ``"generator_train_steps"`` and ``"batch_size"``.

    Returns:
        The ``(generator, discriminator)`` pair that was passed in, trained in
        place. Both modules (and any ``nn.Embedding`` passed in) are moved to
        CUDA when it is available.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Everything that takes part in a forward pass has to live on one device.
    src_emb = _as_embedding(src_emb).to(device)
    tar_emb = _as_embedding(tar_emb).to(device)
    generator.to(device)
    discriminator.to(device)

    # Build the optimizers after the modules have been moved to their device.
    optimizer_for_discriminator = torch.optim.Adam(discriminator.parameters(), lr=config["discriminator_lr"])
    optimizer_for_generator = torch.optim.Adam(generator.parameters(), lr=config["generator_lr"])

    batch_size = config["batch_size"]
    criterion = nn.BCELoss()

    # 0 for fake/generated, 1 for real.
    fake_labels = torch.zeros(batch_size, device=device)
    real_labels = torch.ones(batch_size, device=device)
    # The discriminator batch is [mapped source ; target], so the targets
    # must follow the same order and cover both halves.
    discriminator_targets = torch.cat([fake_labels, real_labels], dim=0)

    for _ in range(num_epochs):

        # Update the discriminator several times per epoch rather than once.
        for _ in range(config["discriminator_train_steps"]):
            discriminator.train()

            # Sampling with replacement.
            src_idx = torch.randint(0, src_emb.num_embeddings, (batch_size,), device=device)
            tar_idx = torch.randint(0, tar_emb.num_embeddings, (batch_size,), device=device)

            # No gradient is needed for the discriminator's inputs.
            with torch.no_grad():
                mapped_src_embeddings = generator(src_emb(src_idx))
                tar_word_batch_embeddings = tar_emb(tar_idx)

            # The discriminator has to see both fake and real samples;
            # trained on fake data alone it would always predict "fake".
            discriminator_input = torch.cat([mapped_src_embeddings, tar_word_batch_embeddings], dim=0)
            pred_output = discriminator(discriminator_input)

            optimizer_for_discriminator.zero_grad()  # avoid gradient accumulation
            discriminator_loss = criterion(pred_output, discriminator_targets)
            discriminator_loss.backward()
            optimizer_for_discriminator.step()

        for _ in range(config["generator_train_steps"]):
            src_idx = torch.randint(0, src_emb.num_embeddings, (batch_size,), device=device)
            mapped_src_embeddings = generator(src_emb(src_idx))

            discri_prediction = discriminator(mapped_src_embeddings)

            optimizer_for_generator.zero_grad()
            # The generator should learn to fool the discriminator, so its
            # outputs are scored against the "real" label.
            generator_loss = criterion(discri_prediction, real_labels)
            generator_loss.backward()
            optimizer_for_generator.step()

    return generator, discriminator

In [None]:
# Load the fastText models that were trained on the custom dataset
import fasttext

# Locations of the two fastText binaries, relative to this notebook.
model_hi_path, model_en_path = (
    r"../custom_models/model_hi.bin",
    r"../custom_models/model_en.bin",
)

# Load the `hi` model first and the `en` model second.
model_hi, model_en = (
    fasttext.load_model(path) for path in (model_hi_path, model_en_path)
)

In [None]:
import numpy as np

In [None]:
# Build one dense matrix per language: row i is the vector of vocabulary word i.
src_vocab = model_en.get_words()
src_embedding_matrix = np.array([model_en.get_word_vector(word) for word in src_vocab])
tar_vocab = model_hi.get_words()
tar_embedding_matrix = np.array([model_hi.get_word_vector(word) for word in tar_vocab])

src_embedding_dim = model_en.get_dimension()
tar_embedding_dim = model_hi.get_dimension()

# The generator maps dim -> dim and the discriminator receives both mapped
# source vectors and raw target vectors, so the two spaces must share a size.
assert src_embedding_dim == tar_embedding_dim, (
    "source and target embeddings must have the same dimension"
)

# train() looks rows up by calling the embedding and reads `.num_embeddings`,
# so wrap the matrices in frozen nn.Embedding layers instead of raw tensors.
src_embed = nn.Embedding.from_pretrained(
    torch.tensor(src_embedding_matrix, dtype=torch.float32), freeze=True
)
tar_embed = nn.Embedding.from_pretrained(
    torch.tensor(tar_embedding_matrix, dtype=torch.float32), freeze=True
)

# initialize the generator and discriminator
generator = Generator(src_embedding_dim)
discriminator = Discriminator(src_embedding_dim, hidden_dim=1024, num_layers=3, dropout=0.1, offset=256)

config = {
    "discriminator_lr":0.00001,
    "generator_lr": 0.00001,
    "discriminator_train_steps": 2,
    "generator_train_steps": 3,
    "batch_size": 10
}

mapper, trained_discriminator = train(src_embed, tar_embed, generator, discriminator, num_epochs=4, config=config)

In [None]:
# Cosine similarity is computed directly with numpy below; the previous
# `import cosine_similarity` could at best bind a module, which is not callable.
import numpy as np


def get_similarities(word, word_embedding, fasttext_model):
    """Score every word of ``fasttext_model`` against ``word_embedding``.

    Args:
        word: The query word. Unused by the computation; kept so existing
            callers do not have to change.
        word_embedding: 1-D vector to compare against the model's vocabulary.
        fasttext_model: Object exposing ``get_words()`` and
            ``get_word_vector(word)``.

    Returns:
        A list of ``(target_word, cosine_similarity)`` tuples in the order
        returned by ``get_words()``. The similarity is ``0.0`` whenever either
        vector has zero norm.
    """
    target_words = list(fasttext_model.get_words())
    if not target_words:
        return []

    query = np.asarray(word_embedding).ravel()
    target_matrix = np.array([fasttext_model.get_word_vector(target) for target in target_words])

    dot_products = target_matrix @ query
    norm_products = np.linalg.norm(target_matrix, axis=1) * np.linalg.norm(query)

    # Leave the similarity at 0 where a norm is zero instead of dividing by zero.
    similarities = np.divide(
        dot_products,
        norm_products,
        out=np.zeros_like(dot_products, dtype=float),
        where=norm_products > 0,
    )
    return list(zip(target_words, similarities.tolist()))

def get_top_k_pairs(similarities, k):
    """Return the words of the ``k`` best-scoring ``(word, score)`` pairs, best first."""
    ranked = list(similarities)
    # Stable descending sort on the score, i.e. the second tuple element.
    ranked.sort(key=lambda pair: pair[1], reverse=True)

    best_words = []
    for candidate, _score in ranked[:k]:
        best_words.append(candidate)
    return best_words


def generate_pseudo_translation_pairs(mapper, word, model_en, model_hi, k):
    """Return the ``k`` words of ``model_hi`` closest to the mapped vector of ``word``.

    Args:
        mapper: Module that maps a ``[1, dim]`` batch of ``model_en`` vectors
            into the space of ``model_hi``. Pass ``None`` to compare the
            unmapped vector directly.
        word: Word looked up in ``model_en``.
        model_en: Object exposing ``get_word_vector(word)``.
        model_hi: Object exposing ``get_words()`` and ``get_word_vector(word)``.
        k: Number of candidates to return.

    Returns:
        A list of up to ``k`` words, ordered from most to least similar.
    """
    word_embedding = model_en.get_word_vector(word)

    if mapper is not None:
        # Run the mapper on whichever device its parameters live on.
        first_param = next(mapper.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

        was_training = mapper.training
        mapper.eval()
        with torch.no_grad():
            source = torch.as_tensor(word_embedding, dtype=torch.float32, device=device).unsqueeze(0)
            word_embedding = mapper(source).squeeze(0).cpu().numpy()
        mapper.train(was_training)  # restore the caller's train/eval mode

    similarities = get_similarities(word, word_embedding, model_hi)

    return get_top_k_pairs(similarities, k)

In [None]:
def precision(muse_dict, mapper, model_en, model_hi, k):
    """Fraction of dictionary entries whose translation is among the top-``k`` candidates.

    Args:
        muse_dict: Mapping from a source word to its reference translation.
        mapper: Passed through to ``generate_pseudo_translation_pairs``.
        model_en: Passed through to ``generate_pseudo_translation_pairs``.
        model_hi: Passed through to ``generate_pseudo_translation_pairs``.
        k: Number of candidates generated per source word.

    Returns:
        ``correct / total`` as a float, or ``0.0`` when ``muse_dict`` is empty.
    """
    total = len(muse_dict)
    if total == 0:
        # Nothing to evaluate; avoid dividing by zero.
        return 0.0

    correct = sum(
        1
        for en, hi in muse_dict.items()
        if hi in generate_pseudo_translation_pairs(mapper, en, model_en, model_hi, k)
    )

    return correct / total

In [None]:
# Loading the muse en_hi parallel corpus dictionary
def create_dict(file_path, size=5000):
    """Read a whitespace-separated two-column dictionary file into a dict.

    Each non-blank line must hold exactly two fields: a source word and its
    translation. If a source word occurs more than once, the last translation
    read wins. Reading stops as soon as the dict holds ``size`` entries.

    Args:
        file_path: Path of the UTF-8 encoded dictionary file.
        size: Number of distinct source words after which reading stops.

    Returns:
        A dict mapping each source word to a translation.

    Raises:
        ValueError: If a non-blank line does not contain exactly two fields.
    """
    en_hi = {}
    with open(file_path, "r", encoding="utf-8") as f:
        for line_number, line in enumerate(f, start=1):
            fields = line.split()
            if not fields:
                # Blank lines (e.g. a trailing newline) carry no entry.
                continue
            if len(fields) != 2:
                raise ValueError(
                    f"{file_path}:{line_number}: expected 2 whitespace-separated "
                    f"fields, got {len(fields)}"
                )
            en, hi = fields
            en_hi[en] = hi
            if len(en_hi) == size:
                break

    return en_hi

In [None]:
# Reference dictionary used to score the learned mapping.
muse_dictionary_path = r"MUSE//data//crosslingual//dictionaries//en-hi.txt"
muse_pairs = create_dict(muse_dictionary_path)

# Bare final expression so the notebook displays the resulting precision.
precision(muse_dict=muse_pairs, mapper=mapper, model_en=model_en, model_hi=model_hi, k=10)

In [None]:
if __name__ == "__main__":
    # `yaml` is used below but was never imported anywhere in the notebook;
    # import it here, next to its only use.
    import yaml

    with open("config.yaml", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    # NOTE(review): `load_embeddings` is not defined in the visible part of
    # this notebook - confirm where it comes from before running this cell.
    src_emb, tar_emb, src_dim, tar_dim = load_embeddings()

    generator = Generator(src_dim)
    discriminator = Discriminator(src_dim, config["hidden_dim"], config["num_layers"], config["dropout"], config["offset"])

    generator, discriminator = train(src_emb, tar_emb, generator, discriminator, config["num_epochs"], config)

    # Evaluation
    muse_dict = create_dict("data/en_hi.txt")
    # precision() needs both fastText models and k; the previous call passed a
    # literal Ellipsis and omitted the remaining arguments. k=10 matches the
    # evaluation cell earlier in the notebook.
    acc = precision(muse_dict, generator, model_en, model_hi, k=10)
    print(f"Precision@k: {acc}")
