In [None]:
# Download the MIND-small train and dev archives.
!wget https://mind201910small.blob.core.windows.net/release/MINDsmall_train.zip https://mind201910small.blob.core.windows.net/release/MINDsmall_dev.zip
# Install `unzip` through apt (requires sudo), then unpack the archives.
!sudo apt install unzip
!unzip MINDsmall_train.zip -d train
# The dev archive is unpacked into ./test; the cells below read it as the test split.
!unzip MINDsmall_dev.zip -d test
!pip install torch

In [None]:
import os
import re
import json
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torch import nn
import torch.nn.functional as F
from ast import literal_eval
from tqdm import tqdm
import copy

# 1) Deep Knowledge-Aware Network (DKN)

DKN is a deep learning model which incorporates information from knowledge graph for better news recommendation. Specifically, DKN uses [TransX](https://towardsdatascience.com/introduction-to-knowledge-graph-embedding-with-dgl-ke-77ace6fb60ef) method for knowledge graph representation learning, then applies a CNN framework, named KCNN, to combine entity embedding with word embedding and generate a final embedding vector for a news article. CTR prediction is made via an attention-based neural scorer.

<img src='https://camo.githubusercontent.com/a7331fdfe727b5012bd32f44444273ba43273f9d99f71d32f46e14c955935df4/68747470733a2f2f7265636f64617461736574732e7a32302e7765622e636f72652e77696e646f77732e6e65742f696d616765732f646b6e5f6172636869746563747572652e706e67'>

DKN takes one piece of candidate news and one piece of a user’s clicked news as input. For each piece of news, a specially designed KCNN is used to process its title and generate an embedding vector. KCNN is an extension of traditional CNN that allows flexibility in incorporating symbolic knowledge from a knowledge graph into sentence representation learning.

With the KCNN, we obtain a set of embedding vectors for a user’s clicked history. To get final embedding of the user with respect to the current candidate news, we use an attention-based method to automatically match the candidate news to each piece of his clicked news, and aggregate the user’s historical interests with different weights. The candidate news embedding and the user embedding are concatenated and fed into a deep neural network (DNN) to calculate the predicted probability that the user will click the candidate news.

## 1.1 Pre-processing

* **Impression** means that content was delivered to someone's feed. The field lists the news displayed in this impression together with the user's click behavior on each item (1 for click and 0 for non-click). The order of the news within an impression has been shuffled.

* **clicked_news** is the history (ID list of clicked news) of this user before this impression. The clicked news articles are ordered by time.

* **entity_embedding** and **relation_embedding** contain the 100-dimensional embeddings of the entities and relations learned from the subgraph (from WikiData knowledge graph) by TransE method. In both files, the first column is the ID of entity/relation, and the other columns are the embedding vector values.

* **Entities**

1. Label	The entity name in the Wikidata knowledge graph
2. Type	The type of this entity in Wikidata
3. WikidataId	The entity ID in Wikidata
4. Confidence	The confidence of entity linking
5. OccurrenceOffsets	The character-level entity offset in the text of title or abstract
6. SurfaceForms	The raw entity names in the original text

In [None]:
class Config():
    """Hyper-parameters and preprocessing constants used throughout the notebook.

    The three boolean switches can be overridden through the environment
    variables CONTEXT, ATTENTION and LOAD_CHECKPOINT; a value of '1' enables
    the switch and any other value disables it.
    """

    # --- KCNN architecture ---
    num_filters = 50
    window_sizes = [2, 3, 4]
    word_embedding_dim = 100
    entity_embedding_dim = 100
    # presumably 14760 vocabulary words plus index 0 for padding/unknown — TODO confirm
    num_word_tokens = 14760 + 1

    # --- training schedule ---
    num_batches = 8000
    num_batches_batch_loss = 50  # report the batch loss every this many batches
    num_batches_val_loss_and_acc = 300
    num_batches_save_checkpoint = 400
    batch_size = 256
    learning_rate = 0.001
    train_validation_split = (0.8, 0.2)
    num_workers = 4

    # --- input shaping ---
    num_clicked_news_a_user = 50
    num_words_a_news = 20

    # --- vocabulary / entity filtering ---
    entity_confidence_threshold = 0.5
    word_freq_threshold = 3
    entity_freq_threshold = 3

    # --- feature switches (defaults: context off, attention on, checkpoint on) ---
    use_context = os.environ.get('CONTEXT', '0') == '1'
    use_attention = os.environ.get('ATTENTION', '1') == '1'
    load_checkpoint = os.environ.get('LOAD_CHECKPOINT', '1') == '1'

# Settings object that is passed to the dataset and model constructors below.
model_config = Config()
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
### pre-process data
def clean_dataset(
    behaviors_source, behaviors_target, news_source, news_target):
    """Flatten raw behaviors to one row per impression item and trim the news table.

    Args:
        behaviors_source: path of the raw, header-less behaviors TSV. Columns 3
            and 4 are read as the click history and the impression list.
        behaviors_target: path the flattened behaviors TSV is written to
            (columns ``clicked_news``, ``candidate_news``, ``clicked``).
        news_source: path of the raw, header-less news TSV. Columns 0, 3 and 6
            are read as ``id``, ``title`` and ``entities``.
        news_target: path the trimmed news TSV is written to.

    Returns:
        ``(behaviors, news)`` DataFrames holding the same columns that were
        written to disk. ``clicked`` is kept as the raw ``'0'``/``'1'`` string
        in the returned frame.
    """
    behaviors = pd.read_table(
        behaviors_source, header=None, usecols=[3, 4], names=['clicked_news', 'impressions'])

    # "N1-1 N2-0" -> ["N1-1", "N2-0"] -> one row per displayed article.
    behaviors['impressions'] = behaviors['impressions'].str.split()
    behaviors = behaviors.explode('impressions').reset_index(drop=True)

    # Split "<news id>-<label>" into two columns. `expand=True` replaces the
    # tuple-unpacking of the `.str` accessor, which relied on iterating it.
    # `reindex` guarantees both columns exist even when nothing could be split.
    id_and_label = (
        behaviors['impressions']
        .str.rsplit('-', n=1, expand=True)
        .reindex(columns=[0, 1])
    )
    behaviors['candidate_news'] = id_and_label[0]
    behaviors['clicked'] = id_and_label[1]

    # Users without history get an empty string instead of NaN.
    behaviors['clicked_news'] = behaviors['clicked_news'].fillna('')
    behaviors.to_csv(
        behaviors_target, sep='\t', index=False, columns=['clicked_news', 'candidate_news', 'clicked'])

    news = pd.read_table(
        news_source, header=None, usecols=[0, 3, 6], names=['id', 'title', 'entities'])
    news.to_csv(news_target, sep='\t', index=False)
    return behaviors[['clicked_news', 'candidate_news', 'clicked']], news[['id', 'title', 'entities']]

# Clean the training split; the flattened behaviors and trimmed news are also written to ./train.
df_behaviors, df_news = clean_dataset(
    './train/behaviors.tsv', 
    './train/behaviors_cleaned.tsv',
    './train/news.tsv', 
    './train/news_cleaned.tsv',
    )

# Same cleaning for the test split under ./test.
df_behaviors_test, df_news_test = clean_dataset(
    './test/behaviors.tsv', 
    './test/behaviors_cleaned.tsv',
    './test/news.tsv', 
    './test/news_cleaned.tsv',
    )

# Preview: attach title and entities to each candidate article of the training impressions.
pd.merge(df_behaviors, df_news, left_on='candidate_news', right_on='id').head()

  behaviors['candidate_news'], behaviors['clicked'] = behaviors.impressions.str.split('-').str


Unnamed: 0,clicked_news,candidate_news,clicked,id,title,entities
0,N55189 N42782 N34694 N45794 N18445 N63302 N104...,N55689,1,N55689,"Charles Rogers, former Michigan State football...","[{""Label"": ""Charles Rogers (American football)..."
1,N8419 N15771 N1431 N5888 N18663 N24123 N22130 ...,N55689,0,N55689,"Charles Rogers, former Michigan State football...","[{""Label"": ""Charles Rogers (American football)..."
2,N58936 N15919 N11917 N2153 N55312 N13008 N4142...,N55689,0,N55689,"Charles Rogers, former Michigan State football...","[{""Label"": ""Charles Rogers (American football)..."
3,N41089 N3577 N59496 N18086 N56175 N56630 N1389...,N55689,0,N55689,"Charles Rogers, former Michigan State football...","[{""Label"": ""Charles Rogers (American football)..."
4,N15415 N3680 N19638 N9155 N848 N16636 N1603 N1...,N55689,1,N55689,"Charles Rogers, former Michigan State football...","[{""Label"": ""Charles Rogers (American football)..."


In [58]:
# Print five randomly chosen raw entity annotations, each followed by a blank line.
for entity_annotation in df_news['entities'].sample(5).values:
    print(entity_annotation, '\n')

[{"Label": "Hawthorn Woods, Illinois", "Type": "G", "WikidataId": "Q2288376", "Confidence": 1.0, "OccurrenceOffsets": [0], "SurfaceForms": ["Hawthorn Woods"]}] 

[{"Label": "Audible (store)", "Type": "O", "WikidataId": "Q366651", "Confidence": 1.0, "OccurrenceOffsets": [46], "SurfaceForms": ["Audible"]}] 

[{"Label": "Philadelphia City Council", "Type": "B", "WikidataId": "Q7182649", "Confidence": 1.0, "OccurrenceOffsets": [0], "SurfaceForms": ["Philadelphia City Council"]}] 

[{"Label": "Washington Redskins", "Type": "O", "WikidataId": "Q212654", "Confidence": 0.998, "OccurrenceOffsets": [0], "SurfaceForms": ["Redskins"]}] 

[{"Label": "Queens", "Type": "G", "WikidataId": "Q18424", "Confidence": 1.0, "OccurrenceOffsets": [58], "SurfaceForms": ["Queens"]}] 



In [None]:
### balance negative and positive data
def balance(source, target, true_false_division_range, random_state=None):
    """Down-sample one class so the clicked / non-clicked ratio falls in a range.

    Args:
        source: path of a TSV with a ``clicked`` column holding 0/1 labels.
        target: path the balanced, shuffled TSV is written to.
        true_false_division_range: ``[low, high]`` bounds for
            ``len(clicked) / len(non_clicked)``. Below ``low`` non-clicked rows
            are dropped; above ``high`` clicked rows are dropped.
        random_state: optional seed forwarded to ``DataFrame.sample`` so the
            down-sampling and the final shuffle are reproducible. ``None``
            keeps the previous unseeded behaviour.

    Returns:
        The balanced and shuffled DataFrame (same content as written to ``target``).
    """
    low = true_false_division_range[0]
    high = true_false_division_range[1]

    original = pd.read_table(source)
    true_part = original[original['clicked'] == 1]
    false_part = original[original['clicked'] == 0]

    if len(false_part) == 0:
        # The ratio is undefined without negatives (this used to raise
        # ZeroDivisionError); keep the rows as they are.
        print('No non-clicked rows found, skip balancing')
    else:
        ratio = len(true_part) / len(false_part)
        if ratio < low:
            keep = int(len(true_part) / low)
            print(f'Drop {len(false_part) - keep} from false part')
            false_part = false_part.sample(n=keep, random_state=random_state)
        elif ratio > high:
            keep = int(len(false_part) * high)
            print(f'Drop {len(true_part) - keep} from true part')
            true_part = true_part.sample(n=keep, random_state=random_state)

    # Shuffle so that positives and negatives are interleaved.
    balanced = (
        pd.concat([true_part, false_part])
        .sample(frac=1, random_state=random_state)
        .reset_index(drop=True)
    )
    balanced.to_csv(target, sep='\t', index=False)
    return balanced

# Balance the cleaned training behaviors so the clicked / non-clicked ratio lies in [1/2, 2].
df_balance = balance(
    './train/behaviors_cleaned.tsv',
    './train/behaviors_cleaned_balanced.tsv',
    [1 / 2, 2],
)

df_balance.head()

Drop 5134412 from false part


Unnamed: 0,clicked_news,candidate_news,clicked
0,N26168 N45395 N27402 N55846 N27927 N3046 N6228...,N4663,0
1,N51238 N13995 N61277 N36030 N36312 N40704 N465...,N40381,0
2,N10779 N61864 N14538 N44018 N14939 N22570 N558...,N47652,1
3,N60615 N28533 N15185 N31686 N12393 N30145 N298...,N38869,1
4,N8748 N23505 N61864 N11071 N29177 N16695 N1809...,N24180,1


In [None]:
def _clean_title(text):
    """Drop every character that is not an ASCII letter or a space, then lower-case and strip."""
    return re.sub(r'[^a-zA-Z ]', '', text).lower().strip()


def _encode_news(news, word2int, entity2int):
    """Encode each title as fixed-length word-id and entity-id lists.

    Args:
        news: DataFrame with ``id``, ``title`` and JSON-encoded ``entities`` columns.
        word2int: dict mapping a cleaned word to its integer id.
        entity2int: dict mapping a Wikidata id to its integer id.

    Returns:
        ``(parsed_news, word_total, word_missed)`` where ``parsed_news`` has the
        columns ``id``, ``title`` and ``entities`` (lists of
        ``Config.num_words_a_news`` ints, 0 meaning padding/unknown), and the
        two counters cover the words inside the length limit.
    """
    max_len = Config.num_words_a_news
    rows = []
    word_total = 0
    word_missed = 0

    with tqdm(total=len(news), desc="Parsing words and entities") as pbar:
        for row in news.itertuples(index=False):
            word_ids = [0] * max_len
            entity_ids = [0] * max_len

            # Map each lower-cased surface-form word to the id of its entity.
            local_entity_map = {}
            for e in json.loads(row.entities):
                if (e['Confidence'] > Config.entity_confidence_threshold
                        and e['WikidataId'] in entity2int):
                    for x in ' '.join(e['SurfaceForms']).lower().split():
                        local_entity_map[x] = entity2int[e['WikidataId']]

            # Explicit truncation instead of relying on IndexError, so words past
            # the limit no longer leak into the out-of-vocabulary statistics.
            for i, w in enumerate(_clean_title(row.title).split()[:max_len]):
                word_total += 1
                if w in word2int:
                    word_ids[i] = word2int[w]
                    if w in local_entity_map:
                        entity_ids[i] = local_entity_map[w]
                else:
                    word_missed += 1

            rows.append([row.id, word_ids, entity_ids])
            pbar.update(1)

    # Build the frame once; appending with .loc per row is quadratic.
    parsed_news = pd.DataFrame(rows, columns=['id', 'title', 'entities'])
    return parsed_news, word_total, word_missed


def _read_mapping(path):
    """Read a two-column TSV (key, integer id) written by train mode into a dict.

    ``na_filter=False`` keeps vocabulary words such as "null" or "nan" as
    strings instead of letting pandas turn them into NaN.
    """
    table = pd.read_table(path, na_filter=False)
    return {str(k): int(v) for k, v in zip(table.iloc[:, 0], table.iloc[:, 1])}


def parse_news(source, target, word2int_path, entity2int_path, mode):
    """Convert cleaned news into integer-encoded titles and entity channels.

    Args:
        source: path of the cleaned news TSV (``id``, ``title``, ``entities``).
        target: path the encoded news TSV is written to.
        word2int_path: vocabulary TSV; written in train mode, read in test mode.
        entity2int_path: entity-id TSV; written in train mode, read in test mode.
        mode: ``'train'`` builds the vocabularies from ``source`` using the
            frequency thresholds in ``Config``; ``'test'`` reuses the saved
            vocabularies and prints the out-of-vocabulary rate.

    Returns:
        Train mode: ``(parsed_news, word2int, entity2int)`` DataFrames.
        Test mode: ``parsed_news`` DataFrame.
        Any other mode prints ``'Wrong mode!'`` and returns ``0``.
    """
    if mode not in ('train', 'test'):
        print('Wrong mode!')
        return 0

    news = pd.read_table(source)
    # News without linked entities carry NaN; treat them as an empty JSON list.
    news['entities'] = news['entities'].fillna('[]')

    if mode == 'test':
        word2int = _read_mapping(word2int_path)
        entity2int = _read_mapping(entity2int_path)
        parsed_news, word_total, word_missed = _encode_news(news, word2int, entity2int)
        oov_rate = word_missed / word_total if word_total else 0.0
        print(f'Out-of-Vocabulary rate: {oov_rate:.4f}')
        parsed_news.to_csv(target, sep='\t', index=False)
        return parsed_news

    # ---- train mode: count frequencies, then build the vocabularies ----
    word2freq = {}
    entity2freq = {}
    with tqdm(total=len(news), desc="Counting words and entities") as pbar:
        for row in news.itertuples(index=False):
            for w in _clean_title(row.title).split():
                word2freq[w] = word2freq.get(w, 0) + 1

            for e in json.loads(row.entities):
                # Confidence-weighted number of occurrences that fall inside the title.
                in_title = [x for x in e['OccurrenceOffsets'] if x < len(row.title)]
                times = len(in_title) * e['Confidence']
                if times > 0:
                    entity2freq[e['WikidataId']] = entity2freq.get(e['WikidataId'], 0) + times
            pbar.update(1)

    # Ids start at 1; 0 is reserved for padding / unknown.
    word2int = {}
    for k, v in word2freq.items():
        if v >= Config.word_freq_threshold:
            word2int[k] = len(word2int) + 1

    entity2int = {}
    for k, v in entity2freq.items():
        if v >= Config.entity_freq_threshold:
            entity2int[k] = len(entity2int) + 1

    parsed_news, _, _ = _encode_news(news, word2int, entity2int)
    parsed_news.to_csv(target, sep='\t', index=False)

    word2int = pd.DataFrame(list(word2int.items()), columns=['word', 'int'])
    word2int.to_csv(word2int_path, sep='\t', index=False)

    entity2int = pd.DataFrame(list(entity2int.items()), columns=['entity', 'int'])
    entity2int.to_csv(entity2int_path, sep='\t', index=False)
    return parsed_news, word2int, entity2int

# Build the word / entity vocabularies from the training news and encode the training titles.
df_parsed_news, df_word2int, df_entity2int = parse_news(
    './train/news_cleaned.tsv',
    './train/news_with_entity.tsv',
    './train/word2int.tsv',
    './train/entity2int.tsv',
    mode='train'
)

# Encode the test news with the vocabularies saved from the training split.
df_parsed_news_test = parse_news(
    './test/news_cleaned.tsv',
    './test/news_with_entity.tsv',
    './train/word2int.tsv',
    './train/entity2int.tsv',
    mode='test'
)

# Quick look at the two id mappings and the encoded training news.
print(df_entity2int.head())
print(df_word2int.head())
df_parsed_news.head()

Counting words and entities: 100%|██████████| 51282/51282 [00:01<00:00, 48194.92it/s]
Parsing words and entities: 100%|██████████| 51282/51282 [06:47<00:00, 125.94it/s]
Parsing words and entities: 100%|██████████| 42416/42416 [05:02<00:00, 140.25it/s]


Out-of-Vocabulary rate: 0.0506
     entity  int
0    Q43274    1
1     Q9682    2
2   Q193583    3
3  Q1215884    4
4    Q49233    5
        word  int
0        the    1
1     brands    2
2      queen    3
3  elizabeth    4
4     prince    5


Unnamed: 0,id,title,entities
0,N55528,"[1, 2, 3, 4, 5, 6, 7, 5, 8, 9, 10, 0, 0, 0, 0,...","[0, 0, 2, 2, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, ..."
1,N19639,"[11, 12, 13, 14, 15, 0, 0, 0, 0, 0, 0, 0, 0, 0...","[0, 0, 0, 3, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
2,N61837,"[1, 16, 17, 18, 19, 20, 21, 1, 22, 17, 23, 24,...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
3,N53526,"[25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 3...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
4,N38324,"[31, 37, 38, 39, 17, 40, 41, 42, 37, 43, 44, 0...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."


In [None]:
def transform_entity_embedding(source, target, entity2int_path):
    """Re-index pre-trained entity vectors by the integer ids in ``entity2int``.

    Args:
        source: header-less TSV whose first column is the entity id and whose
            next ``Config.entity_embedding_dim`` columns are the vector values.
        target: path handed to ``np.save`` for the resulting matrix.
        entity2int_path: TSV with ``entity`` and ``int`` columns.

    Returns:
        Array of shape ``(len(entity2int) + 1, Config.entity_embedding_dim)``.
        Row ``i`` holds the vector of the entity mapped to ``i``; row 0 and the
        rows of entities without a pre-trained vector stay all-zero.
    """
    dim = Config.entity_embedding_dim

    raw = pd.read_table(source, header=None)
    # Take exactly `dim` value columns after the id column; any extra trailing
    # column in the file is ignored.
    vectors = raw.iloc[:, 1:1 + dim].to_numpy(dtype=float)

    entity2int = pd.read_table(entity2int_path)
    row_of_entity = dict(zip(entity2int['entity'], entity2int['int']))

    # Target row for every embedding line; NaN for entities outside the mapping.
    target_rows = raw[0].map(row_of_entity)
    known = target_rows.notna().to_numpy()

    entity_embedding_transformed = np.zeros((len(entity2int) + 1, dim))
    entity_embedding_transformed[target_rows[known].astype(int).to_numpy()] = vectors[known]

    np.save(target, entity_embedding_transformed)
    return entity_embedding_transformed

# Re-index the pre-trained entity vectors by the training entity ids and save them as .npy.
entity_transformed = transform_entity_embedding(
    "./train/entity_embedding.vec",
    "./train/entity_embedding.npy",
    "./train/entity2int.tsv",
)

print(entity_transformed.shape)

(3072, 100)


In [None]:
def transform2json(source, target):
    """Write one JSON object per behaviors row with the ground-truth click labels.

    Args:
        source: header-less behaviors TSV. The last four columns are read as
            ``uid``, ``time``, ``clicked_news`` and ``impression``.
        target: output path; receives one JSON document per line.

    Each output line has the keys ``uid`` (user id without its first
    character), ``time`` and ``impression``, the latter mapping every news id
    (without its first character) to its integer click label.
    """
    behaviors = pd.read_table(
        source, header=None, names=['uid', 'time', 'clicked_news', 'impression'])

    # The context manager closes the output file even if a row fails to parse.
    with open(target, "w") as out_file, \
            tqdm(total=len(behaviors), desc="Transforming tsv to json") as pbar:
        for row in behaviors.itertuples(index=False):
            impression = {}
            for pair in row.impression.split():
                parts = pair.split('-')
                impression[parts[0][1:]] = int(parts[1])

            item = {
                'uid': row.uid[1:],
                'time': row.time,
                'impression': impression,
            }
            out_file.write(json.dumps(item) + '\n')
            pbar.update(1)

# Write the test-split click labels as JSON lines; the inference cell reads this file back.
transform2json('./test/behaviors.tsv', './test/truth.json')

Transforming tsv to json: 100%|██████████| 73152/73152 [00:03<00:00, 18766.81it/s]


## 1.2 DataLoader

In [None]:
class DKNDataset(Dataset):
    """Map-style dataset yielding one (candidate news, click history, label) example.

    Args:
        behaviors_path: TSV with ``clicked_news`` (space-separated news ids),
            ``candidate_news`` and ``clicked`` columns.
        news_with_entity_path: TSV with ``id``, ``title`` and ``entities``
            columns, the latter two holding Python-literal lists of ints.
        config: object providing ``num_words_a_news`` and
            ``num_clicked_news_a_user``.
    """

    def __init__(self, behaviors_path, news_with_entity_path, config):
        super().__init__()
        self.behaviors = pd.read_table(behaviors_path)
        # An empty history is stored as an empty field and read back as NaN.
        self.behaviors['clicked_news'] = self.behaviors['clicked_news'].fillna('')
        self.news_with_entity = pd.read_table(
            news_with_entity_path,
            index_col='id',
            converters={'title': literal_eval, 'entities': literal_eval}
            )
        self.config = config

        # Plain dict for O(1) lookups in __getitem__ instead of repeated
        # DataFrame.loc calls for every news id of every example.
        self._news_lookup = {
            news_id: (title, entities)
            for news_id, title, entities in zip(
                self.news_with_entity.index,
                self.news_with_entity['title'],
                self.news_with_entity['entities'])
        }

    def __len__(self):
        return len(self.behaviors)

    def _news2dict(self, news_id):
        """Return the word/entity id lists of ``news_id``, or all zeros if it is unknown."""
        found = self._news_lookup.get(news_id)
        if found is None:
            return {
                "word": [0] * self.config.num_words_a_news,
                "entity": [0] * self.config.num_words_a_news}
        return {"word": found[0], "entity": found[1]}

    def __getitem__(self, idx):
        """Build the example at position ``idx``.

        Returns:
            dict with ``clicked`` (label), ``candidate_news`` (word/entity
            dict) and ``clicked_news`` (list of exactly
            ``num_clicked_news_a_user`` word/entity dicts; histories are
            truncated to their first entries and padded with zeros).
        """
        row = self.behaviors.iloc[idx]
        history_ids = row['clicked_news'].split()[:self.config.num_clicked_news_a_user]

        item = {
            "clicked": row['clicked'],
            "candidate_news": self._news2dict(row['candidate_news']),
            "clicked_news": [self._news2dict(news_id) for news_id in history_ids],
        }

        # Pad short histories so every example has the same number of entries.
        padding = {
            "word": [0] * self.config.num_words_a_news,
            "entity": [0] * self.config.num_words_a_news
        }
        repeated_times = self.config.num_clicked_news_a_user - len(item["clicked_news"])
        item["clicked_news"].extend([padding] * repeated_times)
        return item

### Build the dataset from the cleaned (not the balanced) training behaviors
dataset = DKNDataset(
    './train/behaviors_cleaned.tsv',
    './train/news_with_entity.tsv',
    model_config
)

## Random train / validation split using the fractions in train_validation_split
train_size = int(model_config.train_validation_split[0] / sum(Config.train_validation_split) * len(dataset))
validation_size = len(dataset) - train_size
# NOTE(review): no generator/seed is passed, so the split differs between runs.
train_dataset, val_dataset = random_split(dataset, (train_size, validation_size))

train_dataloader = DataLoader(
    train_dataset, batch_size=model_config.batch_size, shuffle=True,
    num_workers=model_config.num_workers, drop_last=True)

# drop_last=False: the final validation batch can be smaller than batch_size.
val_dataloader = DataLoader(
    val_dataset, batch_size=model_config.batch_size, shuffle=False,
    num_workers=model_config.num_workers, drop_last=False)

# NOTE(review): this takes ONE un-collated example straight from the dataset
# (plain Python lists, see the printed output), not a mini-batch from the loader.
sample_dataset = next(iter(train_dataset))
entity_embedding = np.load('./train/entity_embedding.npy')
# NOTE(review): the context embedding is loaded from the same file as the entity
# embedding — presumably a placeholder; confirm.
context_embedding = np.load('./train/entity_embedding.npy')

print(entity_embedding.shape)
print(context_embedding.shape)
print(sample_dataset['clicked'])
print(sample_dataset['candidate_news'])
print(sample_dataset['clicked_news'][0])

(3072, 100)
(3072, 100)
0
{'word': [768, 4468, 37, 2723, 1079, 146, 9561, 12829, 53, 1989, 4644, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'entity': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}
{'word': [43, 1519, 260, 136, 1040, 37, 2478, 299, 1060, 719, 1637, 120, 2882, 2757, 11743, 0, 0, 0, 0, 0], 'entity': [0, 172, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}




## 1.3 Model Construction

In [None]:
class Attention(torch.nn.Module):
    """
    Attention Net
    Scores every clicked-news vector against the candidate-news vector and
    returns the softmax-weighted sum of the clicked-news vectors as the user
    vector with respect to that candidate.
    """

    def __init__(self, config):
        super(Attention, self).__init__()
        self.config = config
        # Input is a clicked-news vector concatenated with the candidate vector,
        # hence twice the per-news dimension.
        self.dnn = nn.Sequential(
            nn.Linear(len(self.config.window_sizes) * 2 * self.config.num_filters, 16),
            nn.Linear(16, 1)
        )

    def forward(self, candidate_news_vector, clicked_news_vector):
        """
        Args:
          candidate_news_vector: batch_size, len(window_sizes) * num_filters
          clicked_news_vector: num_clicked_news, batch_size, len(window_sizes) * num_filters
        Returns:
          user_vector: batch_size, len(window_sizes) * num_filters
        """
        # Take the history length from the input so histories of any length
        # work, not only config.num_clicked_news_a_user.
        num_clicked = clicked_news_vector.size(0)

        # [num_clicked, batch_size, dim]: the candidate repeated for every history slot
        candidate_expanded = candidate_news_vector.unsqueeze(0).expand(num_clicked, -1, -1)

        # [num_clicked, batch_size, 1]
        scores = self.dnn(torch.cat((clicked_news_vector, candidate_expanded), dim=-1))

        # [batch_size, num_clicked], normalised over the history
        clicked_news_weights = F.softmax(scores.squeeze(-1).transpose(0, 1), dim=1)

        # [batch_size, 1, num_clicked] x [batch_size, num_clicked, dim] -> [batch_size, dim]
        user_vector = torch.bmm(
            clicked_news_weights.unsqueeze(1), clicked_news_vector.transpose(0, 1)).squeeze(1)
        return user_vector

class KCNN(torch.nn.Module):
    """
    Knowledge-aware CNN (KCNN) based on Kim CNN.
    Input a news sentence (e.g. its title), produce its embedding vector.

    Word embeddings are learned. Entity (and optionally context) embeddings are
    fixed lookup tables that are projected into the word space by a shared,
    learned affine map followed by tanh. The aligned channels are stacked and
    passed through one convolution per window size with max-over-time pooling.
    """

    def __init__(self, config, entity_embedding, context_embedding):
        """
        Args:
          config: object providing num_word_tokens, word_embedding_dim,
            entity_embedding_dim, window_sizes, num_filters and use_context
          entity_embedding: array-like, num_entities x entity_embedding_dim
          context_embedding: array-like of the same shape, used only when
            config.use_context is true (may be None otherwise)
        """
        super(KCNN, self).__init__()
        self.config = config
        self.word_embedding = nn.Embedding(
            config.num_word_tokens, config.word_embedding_dim)

        self.entity_embedding = entity_embedding
        self.context_embedding = context_embedding
        # Non-persistent buffers: the tables follow the module across devices and
        # are converted once instead of on every forward pass, while state_dict
        # keeps exactly the same keys as before.
        self.register_buffer(
            '_entity_table',
            torch.as_tensor(entity_embedding, dtype=torch.float32),
            persistent=False)
        self.register_buffer(
            '_context_table',
            None if context_embedding is None
            else torch.as_tensor(context_embedding, dtype=torch.float32),
            persistent=False)

        self.transform_matrix = nn.Parameter(
            torch.empty(self.config.word_embedding_dim, self.config.entity_embedding_dim)
            )
        self.transform_bias = nn.Parameter(
            torch.empty(self.config.word_embedding_dim)
            )

        # One convolution per window size; each kernel spans the full embedding width.
        self.conv_filters = nn.ModuleDict({
            str(x): nn.Conv2d(
                3 if self.config.use_context else 2,
                self.config.num_filters,
                (x, self.config.word_embedding_dim),
                ) for x in self.config.window_sizes
        })

        self.transform_matrix.data.uniform_(-0.1, 0.1)
        self.transform_bias.data.uniform_(-0.1, 0.1)

    def _to_word_space(self, vectors):
        """Apply tanh(M e + b) to the last dimension: [..., entity_dim] -> [..., word_dim]."""
        return torch.tanh(F.linear(vectors, self.transform_matrix, self.transform_bias))

    def forward(self, news):
        """
        Args:
          news:
            {
                "word": [Tensor(batch_size) * num_words_a_news],
                "entity":[Tensor(batch_size) * num_words_a_news]
            }
        Returns:
          final_vector: batch_size, len(window_sizes) * num_filters
        """
        # Run on whatever device the module lives on instead of a notebook global.
        module_device = self.word_embedding.weight.device

        # [m, num_words_a_news]; m is read from the input, so a final batch
        # smaller than config.batch_size is handled as well.
        word_index = torch.stack(news["word"], dim=1).to(module_device)
        entity_index = torch.stack(news["entity"], dim=1).to(module_device)

        # [m, num_words_a_news, word_embedding_dim]
        word_vector = self.word_embedding(word_index)

        # [m, num_words_a_news, entity_embedding_dim] -> [m, num_words_a_news, word_embedding_dim]
        transformed_entity_vector = self._to_word_space(
            F.embedding(entity_index, self._entity_table))

        channels = [word_vector, transformed_entity_vector]
        if self.config.use_context:
            # The context channel shares the entity indices and the affine map.
            channels.append(self._to_word_space(
                F.embedding(entity_index, self._context_table)))

        # [m, 2 or 3, num_words_a_news, word_embedding_dim]
        multi_channel_vector = torch.stack(channels, dim=1)

        pooled_vectors = []
        for x in self.config.window_sizes:
            # [m, num_filters, num_words_a_news + 1 - x]
            convoluted = self.conv_filters[str(x)](multi_channel_vector).squeeze(dim=3)
            activated = F.relu(convoluted)

            # Max over time: [m, num_filters]
            pooled = F.max_pool1d(activated, activated.size(2)).squeeze(dim=2)
            pooled_vectors.append(pooled)

        # [m, len(window_sizes) * num_filters]
        final_vector = torch.cat(pooled_vectors, dim=1)
        return final_vector

class DKN(torch.nn.Module):
    """
    Deep knowledge-aware network.
    Encodes a candidate news and a user's clicked news with a shared KCNN,
    aggregates the clicked news into a user vector and scores the pair.
    The returned score is a raw logit; the sigmoid is left to the loss.
    """

    def __init__(self, config, entity_embedding, context_embedding):
        super(DKN, self).__init__()
        self.config = config
        self.kcnn = KCNN(config, entity_embedding, context_embedding)

        if self.config.use_attention:
            self.attention = Attention(config)

        # Width of one KCNN output vector.
        news_dim = len(self.config.window_sizes) * self.config.num_filters
        # The scorer sees the user vector concatenated with the candidate vector.
        self.dnn = nn.Sequential(
            nn.Linear(2 * news_dim, 16),
            nn.Linear(16, 1),
        )

    def forward(self, candidate_news, clicked_news):
        """
        Args:
          candidate_news:
            {
                "word": [Tensor(batch_size) * num_words_a_news],
                "entity":[Tensor(batch_size) * num_words_a_news]
            }
          clicked_news:
            [
                {
                    "word": [Tensor(batch_size) * num_words_a_news],
                    "entity":[Tensor(batch_size) * num_words_a_news]
                } * num_clicked_news_a_user
            ]
        Returns:
          click logit per example: batch_size
        """
        # [m, len(window_sizes) * num_filters]
        candidate_vec = self.kcnn(candidate_news)

        # One KCNN pass per history slot, stacked to
        # [num_clicked_news_a_user, m, len(window_sizes) * num_filters]
        history_vecs = []
        for clicked in clicked_news:
            history_vecs.append(self.kcnn(clicked))
        history = torch.stack(history_vecs)

        # [m, len(window_sizes) * num_filters]
        if not self.config.use_attention:
            user_vec = history.mean(dim=0)
        else:
            user_vec = self.attention(candidate_vec, history)

        # Sigmoid is done with BCEWithLogitsLoss
        pair = torch.cat((user_vec, candidate_vec), dim=1)
        logits = self.dnn(pair)
        return logits.squeeze(dim=1)

### testing
# Draw one collated mini-batch from the loader. The model expects lists of
# batched tensors, which a single example taken directly from the dataset
# does not provide, so the forward pass below needs a real batch.
sample_batch = next(iter(train_dataloader))
dkn = DKN(model_config, entity_embedding, context_embedding).to(device)
# pos_weight up-weights clicked examples in the loss.
# NOTE(review): 23.7 is presumably the non-clicked : clicked ratio of the
# unbalanced training behaviors — confirm.
criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([23.7]).float().to(device))
optimizer = torch.optim.Adam(dkn.parameters(), lr=Config.learning_rate)

y_pred = dkn(
    sample_batch["candidate_news"], sample_batch["clicked_news"])
y = sample_batch["clicked"].float().to(device)
loss = criterion(y_pred, y)
# One optimisation step would be:
# optimizer.zero_grad()
# loss.backward()
# optimizer.step()

print(y_pred.size())
print(loss)

torch.Size([256])
tensor(1.4333, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


## 1.4 Evaluation

In [None]:
### evaluate
from sklearn.metrics import roc_auc_score

def mrr_score(y_true, y_score):
    """Reciprocal-rank score of the positive items, averaged over the positives.

    Args:
        y_true: binary relevance labels (1 = clicked).
        y_score: predicted scores, same length as ``y_true``.

    Returns:
        Sum of ``1 / rank`` over the positive items divided by the number of
        positives, with items ranked by descending score. Returns ``0.0`` when
        there is no positive item instead of dividing by zero.
    """
    order = np.argsort(y_score)[::-1]
    ranked_labels = np.take(y_true, order)
    num_positive = np.sum(ranked_labels)
    if num_positive == 0:
        return 0.0
    reciprocal_ranks = ranked_labels / (np.arange(len(ranked_labels)) + 1)
    return np.sum(reciprocal_ranks) / num_positive

def dcg_score(y_true, y_score, k=10):
    """Discounted cumulative gain of the ``k`` highest-scored items.

    Args:
        y_true: relevance labels.
        y_score: predicted scores, same length as ``y_true``.
        k: number of top-ranked items to include.

    Returns:
        Sum of ``(2 ** label - 1) / log2(rank + 1)`` over the top ``k`` items.
    """
    order = np.argsort(y_score)[::-1]
    top_labels = np.take(y_true, order[:k])
    gains = 2 ** top_labels - 1
    discounts = np.log2(np.arange(len(top_labels)) + 2)
    return np.sum(gains / discounts)


def ndcg_score(y_true, y_score, k=10):
    """DCG of the predicted ranking divided by the DCG of the ideal ranking.

    Returns ``0.0`` when the ideal DCG is zero (no relevant item) instead of
    dividing by zero.
    """
    best = dcg_score(y_true, y_true, k)
    if best == 0:
        return 0.0
    actual = dcg_score(y_true, y_score, k)
    return actual / best

loss_full, aucs, mrrs, ndcg5s, ndcg10s = [[] for _ in range(5)]

# Evaluation only: switch to eval mode and disable autograd so no graph is
# built or kept alive for the validation batches.
dkn.eval()
with torch.no_grad(), \
        tqdm(total=len(val_dataloader), desc="Checking loss and accuracy") as pbar:

    for minibatch in val_dataloader:

        y_pred = dkn(minibatch["candidate_news"], minibatch["clicked_news"])
        y = minibatch["clicked"].float().to(device)
        loss = criterion(y_pred, y)
        loss_full.append(loss.item())
        y_pred_list = y_pred.tolist()
        y_list = y.tolist()
        pbar.update(1)

        # AUC is undefined when a batch holds a single class; record the loss
        # for such a batch but skip its ranking metrics.
        if len(set(y_list)) < 2:
            continue

        aucs.append(roc_auc_score(y_list, y_pred_list))
        mrrs.append(mrr_score(y_list, y_pred_list))
        ndcg5s.append(ndcg_score(y_list, y_pred_list, 5))
        ndcg10s.append(ndcg_score(y_list, y_pred_list, 10))

# NOTE(review): the ranking metrics are averaged per mini-batch, not per impression.
print(
    f"loss_full: {np.mean(loss_full)}",
    f"aucs: {np.mean(aucs)}",
    f"mrrs: {np.mean(mrrs)}",
    f"ndcg5s: {np.mean(ndcg5s)}",
    f"ndcg10s: {np.mean(ndcg10s)}",
)

Checking loss and accuracy:   0%|          | 1/4566 [00:31<39:35:36, 31.22s/it]

loss_full: 1.4227112531661987 aucs: 0.6055327868852459 mrrs: 0.016584527987922888 ndcg5s: 0.0 ndcg10s: 0.0





In [None]:
### inference
# Score the test impressions in file order. truth.json is derived from
# ./test/behaviors.tsv, so the predictions must come from the test behaviors
# (unshuffled) rather than from the validation split of the training data.
test_dataset = DKNDataset(
    './test/behaviors_cleaned.tsv',
    './test/news_with_entity.tsv',
    model_config
)
test_dataloader = DataLoader(
    test_dataset, batch_size=model_config.batch_size, shuffle=False,
    num_workers=model_config.num_workers, drop_last=False)

# Cap for quick runs; set to None to score every impression.
# NOTE: scoring every batch requires the model to accept a final batch that is
# smaller than batch_size.
max_batches = 500

y_pred = []
y = []
count = 0

dkn.eval()
with torch.no_grad(), tqdm(total=len(test_dataloader), desc="Inferering") as pbar:
    for minibatch in test_dataloader:
        y_pred.extend(dkn(minibatch["candidate_news"], minibatch["clicked_news"]).tolist())
        y.extend(minibatch["clicked"].float().tolist())
        pbar.update(1)
        count += 1
        if max_batches is not None and count >= max_batches:
            break

# Write one answer line per user whose impressions were scored completely.
os.makedirs('./data/test', exist_ok=True)
cursor = 0
with open('./test/truth.json', 'r') as truth_file, \
        open('./data/test/answer.json', 'w') as submission_answer_file:
    for line in truth_file:
        user_truth = json.loads(line)
        impression_keys = list(user_truth['impression'].keys())
        if cursor + len(impression_keys) > len(y_pred):
            # Predictions ran out because of max_batches; stop instead of failing.
            break

        user_inference = copy.deepcopy(user_truth)
        for k in impression_keys:
            # Guards the alignment between predictions and truth. It can trip if
            # an impression lists the same news id twice, because the truth
            # dict keeps only one entry per id.
            assert y[cursor] == user_truth['impression'][k], \
                f'label mismatch at prediction {cursor}'
            user_inference['impression'][k] = y_pred[cursor]
            cursor += 1
        submission_answer_file.write(json.dumps(user_inference) + '\n')