# In [None]:
import torch
import torch.nn as nn
import numpy as np
import random
from math import sqrt
from itertools import cycle
from copy import deepcopy
from time import time

class ProjectionModel(nn.Module):
    """Project query embeddings with several matrices and compare them to candidates.

    The module owns a stack of square projection matrices, one per projection.
    Each query embedding is multiplied by every matrix, and the projected
    queries are then dot-multiplied with the embeddings of the candidates.

    Args:
        embedding_layer: Embedding module used to look up candidate vectors.
            Its ``weight`` supplies the embedding dimension and the device.
        num_projections: Number of projection matrices (default 24).
        dropout: Dropout probability applied to queries, projections and
            candidate embeddings (default 0.5).
    """

    def __init__(self, embedding_layer, num_projections=24, dropout=0.5):
        super(ProjectionModel, self).__init__()
        self.embedding_layer = embedding_layer
        self.dropout = nn.Dropout(dropout)

        weight = embedding_layer.weight
        # Take the dimension from the embedding table instead of assuming 300,
        # so the projection matrices always match the candidate vectors.
        dim = weight.shape[1]

        # NOTE(review): 2 / (dim + dim) looks like a Glorot-style *variance*,
        # but ``normal_`` takes a standard deviation. The value is passed
        # through unchanged to keep the existing initialisation - confirm intent.
        init_scale = 2 / (dim + dim)
        noise = torch.empty(
            num_projections, dim, dim, device=weight.device
        ).normal_(0, init_scale)
        # Every matrix starts as the identity plus small Gaussian noise.
        identity = torch.eye(dim, device=weight.device)
        self.projection_matrices = nn.Parameter(noise + identity)

    def get_projections(self, embeddings):
        """Apply every projection matrix to every query embedding.

        Args:
            embeddings: Query vectors, expected to be ``(batch, dim)``.

        Returns:
            Tensor of shape ``(batch, num_projections, dim)``.
        """
        # nn.Dropout is already the identity in eval mode, so no explicit
        # ``self.training`` check is needed around it.
        embeddings = self.dropout(embeddings)
        # (K, dim, dim) @ (dim, batch) -> (K, dim, batch)
        projections = torch.matmul(self.projection_matrices, embeddings.transpose(0, 1))
        projections = self.dropout(projections)
        # (K, dim, batch) -> (batch, K, dim)
        return projections.permute(2, 0, 1)

    def forward(self, query_embeddings, candidate_indices):
        """Score each candidate against each projection of its query.

        Args:
            query_embeddings: Query vectors, expected to be ``(batch, dim)``.
            candidate_indices: Indices into ``embedding_layer``, expected to be
                ``(batch, num_candidates)``.

        Returns:
            Tensor of shape ``(batch, num_projections, num_candidates)`` holding
            the dot product of each projected query with each candidate.
        """
        projected_queries = self.get_projections(query_embeddings)
        candidate_embeddings = self.dropout(self.embedding_layer(candidate_indices))
        # (batch, num_candidates, dim) -> (batch, dim, num_candidates) for bmm.
        candidate_embeddings = candidate_embeddings.transpose(1, 2)
        return torch.bmm(projected_queries, candidate_embeddings)


class ClassifierModel(nn.Module):
    """Binary classifier on top of a projection model.

    The projection model yields one feature per projection matrix for every
    (query, candidate) pair; a single linear layer turns those features into
    one logit per pair.

    Args:
        projection_model: Callable module returning features; presumably of
            shape ``(batch, num_projections, num_candidates)`` - verify against
            the projection model in use. It must expose ``projection_matrices``,
            whose first dimension and device size and place the output layer.
    """

    def __init__(self, projection_model):
        super(ClassifierModel, self).__init__()
        self.projection_model = projection_model
        matrices = projection_model.projection_matrices
        # One input feature per projection matrix (24 with the default setup).
        self.output_layer = nn.Linear(matrices.shape[0], 1).to(matrices.device)
        # Summed (not averaged) binary cross-entropy on raw logits.
        self.loss_function = nn.BCEWithLogitsLoss(reduction='sum')
        self.sigmoid = nn.Sigmoid()

    def forward_to_logits(self, query_embeddings, candidate_indices):
        """Return one raw logit per (query, candidate) pair."""
        features = self.projection_model(query_embeddings, candidate_indices)
        # Move the per-projection features to the last axis for the linear
        # layer, then drop the resulting size-1 output axis.
        logits = self.output_layer(features.transpose(1, 2)).squeeze(2)
        return logits

    def calculate_loss(self, query_embeddings, candidate_indices, targets):
        """Return the summed BCE-with-logits loss against ``targets``."""
        logits = self.forward_to_logits(query_embeddings, candidate_indices)
        return self.loss_function(logits, targets)

    def forward(self, query_embeddings, candidate_indices):
        """Return probabilities; logits are clamped to [-10, 10] before the sigmoid."""
        logits = self.forward_to_logits(query_embeddings, candidate_indices)
        return self.sigmoid(logits.clamp(-10, 10))

class Evaluator:
    """Holds the objects needed to evaluate a model on a set of queries.

    The constructor only stores its four arguments. ``get_map`` is still a
    stub: it does not use the stored objects and returns a random score.
    """

    def __init__(self, model, query_embeddings, candidate_ids, embedding_layer):
        self.embedding_layer = embedding_layer
        self.candidate_ids = candidate_ids
        self.query_embeddings = query_embeddings
        self.model = model

    def get_map(self, gold_ids):
        """Return a placeholder "MAP" value.

        Intended to compute Mean Average Precision against the gold ids, but
        currently draws one ``random.random()`` value per entry of ``gold_ids``
        and returns their mean.
        """
        placeholder_scores = []
        for _ in gold_ids:
            placeholder_scores.append(random.random())
        return np.mean(placeholder_scores)

def make_sampler(data):
    """Create an endless generator that reshuffles the data on every cycle.

    The input is deep-copied, so the caller's sequence is never reordered.
    Each pass yields every item exactly once, in a fresh random order.

    Args:
        data: Non-empty mutable sequence of items to sample from.

    Yields:
        Items of (a copy of) ``data``, forever.

    Raises:
        ValueError: On the first ``next()`` if ``data`` is empty. Without this
            check the generator would loop forever without yielding.
    """
    shuffled_items = deepcopy(data)
    if len(shuffled_items) == 0:
        raise ValueError("make_sampler() requires a non-empty sequence")
    while True:
        random.shuffle(shuffled_items)
        for item in shuffled_items:
            yield item

def gold_ids(query_embeddings, pairs):
    """Group gold hypernym ids by query.

    Args:
        query_embeddings: Object whose ``weight.shape[0]`` gives the number of
            queries.
        pairs: Iterable of ``(query_id, hypernym_id)`` tuples.

    Returns:
        A list with one set per query; entry ``i`` holds every hypernym id
        paired with query ``i`` (an empty set if the query has no pairs).
    """
    total_queries = query_embeddings.weight.shape[0]
    per_query = []
    for _ in range(total_queries):
        per_query.append(set())
    for pair in pairs:
        q_id, h_id = pair
        per_query[q_id].add(h_id)
    return per_query
def gen_hyponyms(path):
    """Read the query terms from a tab-separated data file.

    Every line is stripped and split on tabs. Lines that do not consist of
    exactly two fields are ignored. The first field is kept, with spaces
    replaced by underscores and lower-cased; the second field is not used.

    Args:
        path: Path of the UTF-8 encoded file.

    Returns:
        List of normalised terms in file order.
    """
    hyponyms = []
    with open(path, "r", encoding="utf-8") as handle:
        for raw_line in handle:
            fields = raw_line.strip().split("\t")
            if len(fields) == 2:
                term = fields[0]
                hyponyms.append(term.replace(" ", "_").lower())
    return hyponyms


def gen_hypernyms(path):
    """Read the gold hypernym lists from a tab-separated file.

    Every line is stripped and split on tabs; each field has its spaces
    replaced by underscores and is lower-cased. No line is skipped, so the
    result has exactly one entry per line of the file (a blank line yields
    ``['']``).

    Args:
        path: Path of the UTF-8 encoded file.

    Returns:
        List of lists of normalised terms, one inner list per line.
    """
    with open(path, "r", encoding="utf-8") as handle:
        rows = handle.readlines()

    return [
        [term.replace(" ", "_").lower() for term in row.strip().split("\t")]
        for row in rows
    ]

def embeddings_dict(embeddings, dim=300):
    """Parse text-format word embeddings into a lookup table.

    The first entry of ``embeddings`` is skipped (it is treated as a header
    line). Every other entry is split on whitespace: the first token is the
    word and the remaining tokens are its vector components.

    Args:
        embeddings: Sequence of text lines, one word vector per line.
        dim: Length of each returned vector (default 300). Lines with fewer
            components are zero-padded on the right.

    Returns:
        A tuple ``(word2vec_d, word_vocab)`` where ``word2vec_d`` maps each word
        to a ``float32`` array of shape ``(dim,)`` and ``word_vocab`` lists the
        words in file order (duplicates included; the dict keeps the last one).

    Raises:
        IndexError: If a line carries more than ``dim`` components.
        ValueError: If a component cannot be parsed as a float.
    """
    word2vec_d = {}
    word_vocab = []
    for word_embed in embeddings[1:]:
        # Split once on any whitespace so that repeated spaces or tabs do not
        # produce empty tokens; skip blank lines instead of crashing on them.
        fields = word_embed.split()
        if not fields:
            continue
        word, values = fields[0], fields[1:]
        if len(values) > dim:
            raise IndexError(
                f"embedding for {word!r} has {len(values)} components, expected at most {dim}"
            )

        vector = np.zeros(dim, dtype=np.float32)
        if values:
            vector[:len(values)] = [float(value) for value in values]

        word_vocab.append(word)
        word2vec_d[word] = vector

    return word2vec_d, word_vocab
def _sample_negatives(candidate_sampler, gold, count, max_misses):
    """Draw ``count`` candidate ids from the sampler that are not in ``gold``.

    Raises ValueError after ``max_misses`` consecutive rejected draws, so a
    query whose gold set covers every candidate cannot loop forever.
    """
    negatives = []
    misses = 0
    while len(negatives) < count:
        candidate = next(candidate_sampler)
        if candidate in gold:
            misses += 1
            if misses >= max_misses:
                raise ValueError("could not draw a negative candidate outside the gold set")
            continue
        misses = 0
        negatives.append(candidate)
    return negatives


def train_model(model, query_train_emb, query_trial_emb, train_pairs, trial_pairs, query_trial_ids):
    """Train ``model`` with negative sampling and early stopping on trial MAP.

    Each epoch shuffles the training pairs, sub-samples them by hypernym
    frequency, attaches ten negative candidates to every kept pair and hands
    full batches to ``process_batch``. After every epoch the model is scored
    with ``evaluate_model``; the best-scoring model (by MAP) is kept, and
    training stops once MAP has not improved for ``patience`` epochs.

    Relies on ``make_sampler`` and ``gold_ids`` from this file and on
    ``process_batch``, ``evaluate_model`` and ``print_epoch_results``, which
    are defined elsewhere.

    Args:
        model: Classifier to train. Its projection sub-module is read from
            ``model.projector`` or, failing that, ``model.projection_model``.
        query_train_emb: Embedding layer of the training queries.
        query_trial_emb: Embedding layer of the trial queries.
        train_pairs: List of ``(query_id, hypernym_id)`` training pairs.
        trial_pairs: List of ``(query_id, hypernym_id)`` trial pairs.
        query_trial_ids: Passed through unchanged to ``evaluate_model``.

    Returns:
        A deep copy of the model from the epoch with the highest MAP, or
        ``None`` if no epoch ever improved on the initial score.

    Raises:
        ValueError: If ``train_pairs`` is empty.
    """
    max_epochs = 1000
    patience = 200
    batch_size = 32
    clip_threshold = 1e-4
    num_negatives = 10

    if not train_pairs:
        raise ValueError("train_pairs must not be empty")
    # Work on a copy so the per-epoch shuffle does not reorder the caller's list.
    train_pairs = list(train_pairs)

    # Accept both attribute names: callers use ``projector``, while the
    # ClassifierModel in this file stores it as ``projection_model``.
    projector = getattr(model, "projector", None)
    if projector is None:
        projector = model.projection_model
    candidate_ids = list(range(projector.embedding_layer.weight.shape[0]))
    candidate_sampler = make_sampler(candidate_ids)

    # Gold hypernym sets per query; previously referenced but never defined.
    train_gold_ids = gold_ids(query_train_emb, train_pairs)
    trial_gold_ids = gold_ids(query_trial_emb, trial_pairs)

    hypernym_freq = {}
    for _, hid in train_pairs:
        hypernym_freq[hid] = hypernym_freq.get(hid, 0) + 1
    min_freq = min(hypernym_freq.values())

    # Frequent hypernyms are kept with probability sqrt(min_freq / freq).
    pos_sample_prob = {hid: sqrt(min_freq / freq) for hid, freq in hypernym_freq.items()}
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
    best_model, best_score = None, float('-inf')
    # Must exist before the first comparison: a NaN MAP in epoch 1 would
    # otherwise reach ``no_gain += 1`` with the name unbound.
    no_gain = 0
    start_time = time()

    for epoch in range(1, max_epochs + 1):
        model.train()
        random.shuffle(train_pairs)
        total_pos_loss, total_neg_loss, total_loss, updates = 0, 0, 0, 0

        batch_queries, batch_pos, batch_neg = [], [], []
        for query_id, hyper_id in train_pairs:
            if random.random() >= pos_sample_prob[hyper_id]:
                continue

            batch_queries.append(query_id)
            batch_pos.append(hyper_id)
            # Every kept candidate is the one that was checked against the
            # gold set, and exactly ``num_negatives`` are collected.
            batch_neg.append(
                _sample_negatives(
                    candidate_sampler,
                    train_gold_ids[query_id],
                    num_negatives,
                    2 * len(candidate_ids),
                )
            )

            if len(batch_queries) == batch_size:
                losses = process_batch(model, optimizer, batch_queries, batch_pos, batch_neg, clip_threshold, device)
                total_pos_loss += losses[0]
                total_neg_loss += losses[1]
                total_loss += losses[2]
                updates += 1
                batch_queries, batch_pos, batch_neg = [], [], []  # Reset for next batch
        # A trailing partial batch is discarded, as before.

        # Avoid dividing by zero when an epoch produced no full batch.
        denom = max(updates, 1)
        avg_pos_loss = total_pos_loss / denom
        avg_neg_loss = total_neg_loss / denom
        avg_total_loss = total_loss / denom

        trial_loss, MAP = evaluate_model(model, query_trial_emb, trial_pairs, query_trial_ids, trial_gold_ids)
        print_epoch_results(epoch, avg_pos_loss, avg_neg_loss, avg_total_loss, trial_loss, MAP, start_time)

        if MAP > best_score:
            best_score = MAP
            best_model = deepcopy(model)
            no_gain = 0
        else:
            no_gain += 1

        if no_gain >= patience:
            print("EARLY STOP!")
            break

    return best_model

# NOTE: train_model assumes that helper functions such as process_batch, evaluate_model and
# print_epoch_results are defined elsewhere. This section is only a basic setup and does not
# contain the full implementation of every function used in the training process.
if __name__ == '__main__':


    # p = argparse.ArgumentParser()
    # p.add_argument('domain',type=str)
    # domain = p.parse_args().domain
    domain = "italian"
    if(domain == "english"):
        # path = "/content/english_merged_vocab.txt"
        path = "/kaggle/input/vocabulary/1A.english.vocabulary.txt"
        path1 = "/kaggle/input/training/1A.english.training.data.txt"
        path2 = "/kaggle/input/trailll/1A.english.trial.data.txt"
        path3 = "/kaggle/input/testdata/1A.english.test.data.txt"
        pathh1 = "/kaggle/input/training/1A.english.training.gold.txt"
        pathh2 = "/kaggle/input/trailll/1A.english.trial.gold.txt"
        e_path = "/kaggle/input/engdata/english_sg_embed.txt"
        # log_path = "logs/english_logs.txt"
        p_file = "english_results.txt"
        model = "english_final.pt"
    elif(domain == "italian"):
        path = "/kaggle/input/vocabulary/1B.italian.vocabulary.txt"
        path1 = "/kaggle/input/training/1B.italian.training.data.txt"
        path2 = "/kaggle/input/trailll/1B.italian.trial.data.txt"
        path3 = "/kaggle/input/testdata/1B.italian.test.data.txt"
        pathh1 = "/kaggle/input/training/1B.italian.training.gold.txt"
        pathh2 = "/kaggle/input/trailll/1B.italian.trial.gold.txt"
        e_path = "/kaggle/input/vocabulary/1B.italian.vocabulary.txt"
#         log_path = "logs/italian_logs.txt"
        p_file = "italian_results.txt"
        model = "italian_final.pt"
    elif(domain == "spanish"):
        path = "/content/drive/MyDrive/SemEval2018-Task9/vocabulary/1C.spanish.vocabulary.txt"
        path1 = "/content/drive/MyDrive/SemEval2018-Task9/training/data/1C.spanish.training.data.txt"
        path2 = "/content/drive/MyDrive/SemEval2018-Task9/trial/data/1C.spanish.trial.data.txt"
        path3 = "/content/drive/MyDrive/SemEval2018-Task9/test/data/1C.spanish.test.data.txt"
        pathh1 = "/content/drive/MyDrive/SemEval2018-Task9/training/gold/1C.spanish.training.gold.txt"
        pathh2 = "/content/drive/MyDrive/SemEval2018-Task9/trial/gold/1C.spanish.trial.gold.txt"
        e_path = "embeddings/spanish_embeddings.txt"
        log_path = "logs/spanish_logs.txt"
        p_file = "results/spanish_results.txt"
        model = "models/spanish_final.pt"
    elif(domain == "medical"):
        path = "/kaggle/input/semevaldata/vocabulary/2A.medical.vocabulary.txt"
        path1 = "/kaggle/input/semevaldata/training/data/2A.medical.training.data.txt"
        path2 = "/kaggle/input/semevaldata/trial/data/2A.medical.trial.data.txt"
        path3 = "/kaggle/input/semevaldata/test/data/2A.medical.test.data.txt"
        pathh1 = "/kaggle/input/semevaldata/training/gold/2A.medical.training.gold.txt"
        pathh2 = "/kaggle/input/semevaldata/trial/gold/2A.medical.trial.gold.txt"
        e_path = "/kaggle/input/medmusic/medical_sg_embed.txt"
        log_path = "logs/medical_logs.txt"
        p_file = "medical_results.txt"
        model = "medical_final.pt"
    elif(domain == "music"):
        path = "/kaggle/input/semevaldata/vocabulary/2B.music.vocabulary.txt"
        path1 = "/kaggle/input/semevaldata/training/data/2B.music.training.data.txt"
        path2 = "/kaggle/input/semevaldata/trial/data/2B.music.trial.data.txt"
        path3 = "/kaggle/input/semevaldata/test/data/2B.music.test.data.txt"
        pathh1 = "/kaggle/input/semevaldata/training/gold/2B.music.training.gold.txt"
        pathh2 = "/kaggle/input/semevaldata/trial/gold/2B.music.trial.gold.txt"
        e_path = "/kaggle/input/medmusic/music_sg_embed.txt"
        log_path = "logs/music_logs.txt"
        p_file = "music_results.txt"
        model = "music_final.pt"
    else:
      print("wrong argument")
      exit()
    candidates = gen_vocab(path)
    # print(candidates)



    q_train= gen_hyponyms(path1)
    q_trial = gen_hyponyms(path2)
    q_test = gen_hyponyms(path3)


    h_train = gen_hypernyms(pathh1)
    h_trial= gen_hypernyms(pathh2)



    embeddings_file = open(e_path,"r",encoding="utf-8")
    embeddings = embeddings_file.read().splitlines()

    word2vec_d,word_vocab = embeddings_dict(embeddings)
    word_vocabs = set(word_vocab)

    candidates_embeds = em_matrix(word2vec_d,candidates)
    qtrain_embeds = em_matrix(word2vec_d,q_train)
    qtrail_embeds = em_matrix(word2vec_d,q_trial)
    qtest_embeds = em_matrix(word2vec_d,q_test)

    candidate_to_id = inv_dict(candidates)
    qtrain_to_id = inv_dict(q_train)
    qtrial_to_id = inv_dict(q_trial)

    # print(candidate_to_id)
    # print(qtrial_to_id)

    train_pairs = hypo_hyper(q_train,h_train,qtrain_to_id,candidate_to_id)
    trial_pairs = hypo_hyper(q_trial,h_trial,qtrial_to_id,candidate_to_id)

    qtrain_candids = gen_cand_ids(q_train,candidate_to_id)
    qtrial_candids = gen_cand_ids(q_trial,candidate_to_id)
    qtest_candids = gen_cand_ids(q_test,candidate_to_id)

    candidate_temb = embedder(candidates_embeds)
    qtrain_temb = embedder(qtrain_embeds)
    qtrial_temb = embedder(qtrail_embeds)

    projector1 = Projector1(candidate_temb)
    print(projector1)

    classifier1 = Classifier1(projector1)
    print(classifier1)

    trainables = list(filter(lambda x:x.requires_grad,classifier1.parameters()))
    trainables.append(qtrain_temb.weight)

    optimizer = torch.optim.Adam(trainables,lr=2e-4,betas=(0.9,0.9),eps=1e-8,weight_decay=0)