
<div style="display: flex; justify-content: space-between; align-items: center;">
    <h1>Hello there! 👋🏻 This is an at-home task for Eedi</h1>
    <img src="https://drive.google.com/uc?export=view&id=1qtHBk_bf_f9ikBcg5QQ1i7xlM0UZv4D9" width="150" style="margin-right: 40px;" />
</div>


Thank you for participating in this task! It is designed to assess your knowledge of knowledge graphs and negative sampling techniques. Please review the following details carefully.

### Task Overview 📝
This task consists of two parts:

* Part 1: Build a ComplEx knowledge graph embedding model.
* Part 2: Implement a negative sampler.

### Deadline ⏰: 27th October

### Submission 📤

You are required to submit two items:

1. The Code
    * Use the provided Colab notebook: either make a copy of it or download it as a `.py` file. Submit the `.ipynb` or `.py` file via email.

2. A Video Explanation

    * Record a **5-minute** video explaining your solution. Imagine you are presenting your newly developed code to a colleague. In the video, highlight the key aspects of your code, any assumptions you made, and any optimizations that are important for a teammate to understand.

### Notes:
1. We recognize that AI tools may assist you in completing this task. However, if your solution is entirely generated by an AI, it may closely resemble other candidates' submissions, making it harder for you to stand out in the process.

2. The code should be able to run without requiring any highly specialized hardware. Assume that you have access to one GPU and a standard CPU.


## Download Data: [here](https://drive.google.com/file/d/1fzwUXMnDm_JbGYvAvgReipbVasTct8XQ/view?usp=sharing)

## Background KG Data 🌐

Graph data are represented as triples. A triple `(h, r, t)` consists of three parts: the head (`h`), the relation (`r`), and the tail (`t`).

Here’s a simple example:

This graph:

<div style="text-align:center;">
<img src="https://drive.google.com/uc?export=view&id=1mKxZX0sTk584MUdyXsPT1zFKx1tLcUl9"   width="300"/>
</div>

is represented as the following triple:

```
Head: "France"
Relation: "hasCapital"
Tail: "Paris"
```
i.e. `(France, hasCapital, Paris)`
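Embedding models index their parameters by integer ids, so string triples are usually mapped to ids first. The sketch below assumes the raw triples fit in memory as `(head, relation, tail)` string tuples; the helper name `index_triples` is hypothetical and is not part of the provided data.

```python
import torch


def index_triples(triples):
    """Map (head, relation, tail) string triples to an integer tensor of shape (N, 3).

    Hypothetical helper: ids are assigned in order of first appearance.
    """
    entity2id: dict[str, int] = {}
    relation2id: dict[str, int] = {}
    rows = []
    for h, r, t in triples:
        # setdefault assigns the next free id the first time a name is seen
        h_id = entity2id.setdefault(h, len(entity2id))
        r_id = relation2id.setdefault(r, len(relation2id))
        t_id = entity2id.setdefault(t, len(entity2id))
        rows.append((h_id, r_id, t_id))
    return torch.tensor(rows, dtype=torch.long), entity2id, relation2id


ids, entity2id, relation2id = index_triples([("France", "hasCapital", "Paris")])
```

Heads and tails share one entity vocabulary, so an entity such as `Paris` gets the same id whether it appears as a head or a tail.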


## Dataset Overview 📊
The dataset is a heterogeneous graph whose nodes are _authors, books, genres, publishers, awards, and readers_. The edges represent different types of relationships between these entities, such as:

* (Author, _wrote_, Book)
* (Book, _published_by_, Publisher)
* (Book, _belongs_to_genre_, Genre)
* (Book, _won_award_, Award)
* (Reader, _read_, Book)

### Diagram
<div style="text-align:center;">
    <img src="https://drive.google.com/uc?export=view&id=1QOmxWVFYLmUwfKEHKHoePdbTLbV37nCO" width="600" />
</div>







In [None]:
import torch
import numpy as np
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import Optimizer


def train(
    dataloader: DataLoader,
    model: nn.Module,
    loss_fn: nn.Module,
    optimizer: Optimizer,
    device: str,
) -> None:
    """
    General training loop for PyTorch models.
    """
    model.train()  # put layers such as dropout back into training mode
    size = len(dataloader.dataset)
    for batch, (X, y) in enumerate(dataloader):
        ## alter/add etc code
        X, y = X.to(device), y.to(device)

        pred = model(X)
        loss = loss_fn(pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 10 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")


def evaluate(
    dataloader: DataLoader, model: nn.Module, device: str, k: int = 10
) -> None:
    """
    General evaluation loop for PyTorch models. Calculates hits@k.
    """
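    model.eval()
    ranks: list[int] = []  # 1-based rank of the true entity for each test triple
    with torch.no_grad():
        for batch in dataloader:
            batch = batch.to(device)
            ## alter/add etc code: score the candidates and append each true entity's rank to `ranks`
            pred = model(batch)

    hits = hits_at_k(ranks, k=k)
    print(f"hits@{k}: {hits:>7f}")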


def hits_at_k(ranks: torch.Tensor | np.ndarray | list[int], k: int) -> torch.Tensor:
    """Fraction of (1-based) ranks that are at most k."""
    return torch.mean((torch.as_tensor(ranks) <= k).float())


## Part 1: ComplEx

Using the provided heterogeneous book graph dataset, implement a [ComplEx](https://arxiv.org/pdf/1606.06357) knowledge graph embedding model from scratch using PyTorch. You are encouraged to optimize the model for improved performance or make any necessary enhancements.
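For reference, the ComplEx paper scores a triple with the real part of a trilinear product of complex embeddings $\mathbf{e}_h, \mathbf{w}_r, \mathbf{e}_t \in \mathbb{C}^d$, where $\bar{\mathbf{e}}_t$ is the complex conjugate of the tail embedding and $\langle \mathbf{a}, \mathbf{b}, \mathbf{c} \rangle = \sum_k a_k b_k c_k$. Expanding into real and imaginary parts gives four real-valued terms, which is a convenient form to implement:

$$
\phi(h, r, t) = \operatorname{Re}\Big(\sum_{k=1}^{d} w_{r,k}\, e_{h,k}\, \bar{e}_{t,k}\Big)
= \langle \operatorname{Re}\mathbf{w}_r, \operatorname{Re}\mathbf{e}_h, \operatorname{Re}\mathbf{e}_t \rangle
+ \langle \operatorname{Re}\mathbf{w}_r, \operatorname{Im}\mathbf{e}_h, \operatorname{Im}\mathbf{e}_t \rangle
+ \langle \operatorname{Im}\mathbf{w}_r, \operatorname{Re}\mathbf{e}_h, \operatorname{Im}\mathbf{e}_t \rangle
- \langle \operatorname{Im}\mathbf{w}_r, \operatorname{Im}\mathbf{e}_h, \operatorname{Re}\mathbf{e}_t \rangle
$$

Because of the conjugate on the tail, $\phi(h, r, t) \neq \phi(t, r, h)$ in general, which is what lets ComplEx model asymmetric relations such as _wrote_.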

Two helper functions, `train()` and `evaluate()`, are available for you to use when training and testing your model. They have been intentionally left incomplete so that you can customize them as you see fit.

**NOTE: Using them or training the model is NOT a requirement, but it can help you check the model's behaviour.**
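A minimal PyTorch sketch of the ComplEx scoring function, assuming entities and relations are already mapped to integer ids and that the real and imaginary parts are stored in separate `nn.Embedding` tables. The class name `ComplExSketch`, the Xavier initialization, and the `score_all_tails` and `tail_ranks` helpers are illustrative choices, not part of the provided template. `tail_ranks` computes "raw" 1-based ranks (no filtering of other known true tails), of the kind that could be collected into `ranks` inside `evaluate()`.

```python
import torch
import torch.nn as nn


class ComplExSketch(nn.Module):
    """Illustrative ComplEx model: score(h, r, t) = Re(<w_r, e_h, conj(e_t)>)."""

    def __init__(self, num_entities: int, num_relations: int, dim: int = 64):
        super().__init__()
        # Separate tables for the real and imaginary parts of each embedding.
        self.ent_re = nn.Embedding(num_entities, dim)
        self.ent_im = nn.Embedding(num_entities, dim)
        self.rel_re = nn.Embedding(num_relations, dim)
        self.rel_im = nn.Embedding(num_relations, dim)
        for emb in (self.ent_re, self.ent_im, self.rel_re, self.rel_im):
            nn.init.xavier_uniform_(emb.weight)

    def forward(self, triples: torch.Tensor) -> torch.Tensor:
        """Score a (B, 3) LongTensor of (head, relation, tail) ids; returns shape (B,)."""
        h, r, t = triples[:, 0], triples[:, 1], triples[:, 2]
        h_re, h_im = self.ent_re(h), self.ent_im(h)
        r_re, r_im = self.rel_re(r), self.rel_im(r)
        t_re, t_im = self.ent_re(t), self.ent_im(t)
        # The four real-valued terms of Re(<w_r, e_h, conj(e_t)>).
        return (
            (r_re * h_re * t_re).sum(-1)
            + (r_re * h_im * t_im).sum(-1)
            + (r_im * h_re * t_im).sum(-1)
            - (r_im * h_im * t_re).sum(-1)
        )

    def score_all_tails(self, heads: torch.Tensor, rels: torch.Tensor) -> torch.Tensor:
        """Score every entity as a candidate tail; returns shape (B, num_entities)."""
        h_re, h_im = self.ent_re(heads), self.ent_im(heads)
        r_re, r_im = self.rel_re(rels), self.rel_im(rels)
        # Group the four terms by the tail component they multiply.
        q_re = r_re * h_re - r_im * h_im  # multiplies Re(e_t)
        q_im = r_re * h_im + r_im * h_re  # multiplies Im(e_t)
        return q_re @ self.ent_re.weight.T + q_im @ self.ent_im.weight.T


def tail_ranks(model: ComplExSketch, triples: torch.Tensor) -> torch.Tensor:
    """Raw 1-based rank of each true tail among all entities."""
    scores = model.score_all_tails(triples[:, 0], triples[:, 1])
    true_scores = scores.gather(1, triples[:, 2:3])  # (B, 1)
    # Rank = 1 + number of entities that score strictly higher than the true tail.
    return 1 + (scores > true_scores).sum(dim=1)
```

Scoring all tails with two matrix multiplications avoids materialising one triple per candidate entity, which matters once the entity count grows.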

## Part 2: Negative Sampler

In this section, we will address the challenge our model faces under the Open World Assumption: distinguishing between false facts and missing ones. To assist our model in learning this distinction, we will employ a Negative Sampling strategy. This involves generating "corrupted" versions of existing facts, which we will use as negative samples.


#### Example
Let's consider a simple example:

$$
\begin{aligned}
\mathcal{E} &= \{\text{Mike}, \text{George}, \text{Liverpool}, \text{Manchester}, \text{London}\}, \\
\mathcal{R} &= \{\text{bornIn}, \text{friendsWith}\}, \\
f &= (\text{Mike}, \text{bornIn}, \text{Liverpool}) \in G,
\end{aligned}
$$

where $G \subseteq \mathcal{E} \times \mathcal{R} \times \mathcal{E}$ is our knowledge graph, $\mathcal{E}$ is the set of entities, $\mathcal{R}$ is the set of relations, and $f$ is a true fact.

If we change the tail of the fact, we can generate the following synthetic negatives:

$$
\text{negatives} = \begin{bmatrix}
\text{(Mike, bornIn, Manchester)  } \\
\text{(Mike, bornIn, George)  } \\
\text{(Mike, bornIn, London )  }
\end{bmatrix}.
$$
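Uniform tail corruption can be sketched as follows, assuming triples are stored as a `(B, 3)` tensor of integer ids. The function name `corrupt_tails_uniform` and the `num_negatives` parameter are hypothetical. Replacement tails are drawn uniformly from every entity except the original tail, so, unlike the hand-picked list above, this version can also propose the head itself.

```python
import torch


def corrupt_tails_uniform(
    triples: torch.Tensor, num_entities: int, num_negatives: int
) -> torch.Tensor:
    """Return (B, num_negatives, 3) negatives whose tails are resampled uniformly."""
    batch_size = triples.size(0)
    true_tails = triples[:, 2:3]  # (B, 1)
    # Sample from num_entities - 1 ids, then shift ids >= the true tail by one,
    # so the original tail is never drawn and every other entity is equally likely.
    sampled = torch.randint(
        0, num_entities - 1, (batch_size, num_negatives), device=triples.device
    )
    new_tails = sampled + (sampled >= true_tails).long()
    # repeat() copies memory, so writing into the tail column is safe.
    negatives = triples.unsqueeze(1).repeat(1, num_negatives, 1)
    negatives[:, :, 2] = new_tails
    return negatives
```

Generating negatives on the fly like this keeps memory flat, but a sampled triple may still be true somewhere else in the graph (a false negative); that check is left out of the sketch.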


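Some negative samplers are more edge-informed and produce more relevant negatives. For example, a sampler that takes into account that the tail of _bornIn_ should be a place would never propose `George`: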

$$
\text{negatives} = \begin{bmatrix}
\text{(Mike, bornIn, Manchester)  } \\
\text{(Mike, bornIn, London )  }
\end{bmatrix}.
$$
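One way to make the sampler data-informed, sketched below under the assumption that no explicit entity types are used, is to treat the set of tails observed for each relation in the training triples as that relation's range and to sample corrupted tails only from it. In the example above, the observed tails of _bornIn_ would be cities, so `George` is never proposed. The class name `RangeConstrainedTailSampler` and its parameters are hypothetical; the node types listed in the dataset overview could serve the same purpose. The sketch assumes CPU tensors and loops over the batch in Python for clarity rather than speed.

```python
import torch


class RangeConstrainedTailSampler:
    """Corrupt tails using only entities observed as tails of the same relation."""

    def __init__(self, train_triples: torch.Tensor, num_entities: int):
        # Sorted, de-duplicated tail ids per relation, taken from the training data.
        self.range_of = {
            int(r): torch.unique(train_triples[train_triples[:, 1] == r, 2])
            for r in torch.unique(train_triples[:, 1])
        }
        # Fallback pool for unseen relations or relations with a single tail.
        self.all_entities = torch.arange(num_entities)

    def sample(self, triples: torch.Tensor, num_negatives: int) -> torch.Tensor:
        """Return (B, num_negatives, 3) negatives with relation-consistent tails."""
        negatives = triples.unsqueeze(1).repeat(1, num_negatives, 1)
        for i, (_, r, t) in enumerate(triples.tolist()):
            candidates = self.range_of.get(r, self.all_entities)
            # Drop the true tail so the original fact is never returned as a negative.
            candidates = candidates[candidates != t]
            if candidates.numel() == 0:
                candidates = self.all_entities[self.all_entities != t]
            idx = torch.randint(0, candidates.numel(), (num_negatives,))
            negatives[i, :, 2] = candidates[idx]
        return negatives
```

Restricting to a relation's range makes negatives harder and more informative, but it also raises the chance of sampling a true but unobserved fact, so filtering against the known triples becomes more important.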


## Requirements

1. Implement a negative sampler that corrupts the tail entity (you are encouraged to create the data-informed version) __(code needed)__
2. Explain how you would integrate it into the training pipeline __(text needed)__
    * When and where will the negative samples be generated?
    * How will they be used, which loss function(s) will you use, and why?
    * How could it affect your evaluation pipeline?
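As one illustration of how positive and negative scores could be combined, the ComplEx paper trains a logistic model, minimising $\log(1 + \exp(-y\,\phi))$ with labels $y = +1$ for facts and $y = -1$ for negatives, plus L2 regularization. A minimal sketch, assuming positive scores of shape `(B,)` and negative scores of shape `(B, K)`; the function name `logistic_loss` and the `l2_weight` parameter are hypothetical, and margin-ranking or cross-entropy losses are common alternatives.

```python
import torch
import torch.nn.functional as F


def logistic_loss(pos_scores, neg_scores, params=(), l2_weight=0.0):
    """Logistic loss with label +1 for facts and -1 for sampled negatives."""
    # softplus(-y * score) == log(1 + exp(-y * score)), computed stably.
    pos_term = F.softplus(-pos_scores).mean()
    neg_term = F.softplus(neg_scores).mean()
    # Optional L2 penalty on the embedding tensors passed in `params`.
    reg = sum((p**2).sum() for p in params)
    return pos_term + neg_term + l2_weight * reg
```

This sketch averages the positive and negative terms separately so both contribute equally regardless of `K`; summing over all examples, as the paper's objective does, weights the negatives more heavily as `K` grows.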

You may reuse or alter any code you wrote earlier.

**NOTE: For the second requirement, you do not need to implement the training code; a written answer is sufficient.**

## You've Made it to the End! 🎉

Good luck, and we’re excited to see your solution! 🎉