In [4]:
import json
from collections import defaultdict
import pandas as pd
from data.graph_loader import load_graph
import networkx as nx
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

# Load the MindReader ratings and drop "unknown" answers (sentiment == 0),
# keeping only explicit likes/dislikes.
print('Loading ratings')
mr = pd.read_csv('../data/mindreader/ratings.csv')
mr = mr[mr.sentiment != 0]

# Map each entity URI to the set of its labels. Each JSON row is unpacked as
# (uri, name, labels) where labels is a '|'-separated string; the name is unused.
# The file is decoded as UTF-8 explicitly (the standard JSON encoding) rather than
# with the platform-dependent locale default, which can fail on non-ASCII names.
entities = dict()
with open('../data/mindreader/entities_clean.json', 'r', encoding='utf-8') as file:
    data = json.load(file)

    for uri, _name, labels in data:
        entities[uri] = set(labels.split('|'))

# Load the triples into a directed NetworkX graph.
print('Loading graph')
g = load_graph('../data/graph/triples.csv', directed=True)

Loading ratings
Loading graph


In [5]:
def _propagate(node, preferences, preference, graph=None):
    """Append ``preference`` to the rating list of every neighbour of ``node``.

    Propagation is a single hop: only the nodes yielded by
    ``graph.neighbors(node)`` receive the rating (for a directed NetworkX graph
    these are the successors of ``node``).

    Parameters
    ----------
    node : node to propagate from (here, a URI of a movie the user rated).
    preferences : mapping of node -> list of ratings, mutated in place. It is
        expected to be a ``defaultdict(list)`` so unseen neighbours start with an
        empty list.
    preference : rating value appended once for each neighbour.
    graph : graph to walk; defaults to the notebook-level graph ``g``.

    NOTE(review): a ``node`` missing from the graph is not handled here;
    presumably the graph library raises in that case.
    """
    graph = g if graph is None else graph
    for neighbor in graph.neighbors(node):
        preferences[neighbor].append(preference)
            

def propagate(movie_preferences):
    """Spread every rated movie's preference onto its graph neighbours.

    Parameters
    ----------
    movie_preferences : mapping of movie URI -> rating value.

    Returns
    -------
    defaultdict(list) mapping each reached node to the list of ratings it
    received, in the iteration order of ``movie_preferences``.
    """
    collected = defaultdict(list)

    for movie_uri in movie_preferences:
        _propagate(movie_uri, collected, movie_preferences[movie_uri])

    return collected


def reduce_preferences(preferences, max_variance=0.5):
    """Collapse each entity's list of propagated ratings into one label.

    An entity is labelled with the sign of its mean rating (1 or -1) only when the
    ratings agree strongly enough. Entities are dropped from the result when:
      * their rating list is empty,
      * their mean rating is exactly 0, or
      * the (population) variance of their ratings exceeds ``max_variance``.

    This is stricter than a simple majority vote. Assuming ratings are only
    -1/1 (sentiment 0 is filtered out on load; verify no other values occur),
    the variance equals ``1 - mean**2``. The default threshold of 0.5 therefore
    keeps an entity only when roughly 85% or more of its ratings agree.

    Parameters
    ----------
    preferences : mapping of entity URI -> sequence of ratings.
    max_variance : largest variance still accepted as agreement (default 0.5).

    Returns
    -------
    dict mapping entity URI -> 1 or -1 for the entities that were kept.
    """
    reduced = dict()

    for uri, preference_list in preferences.items():
        # np.mean/np.var of an empty list yield NaN, which would otherwise fall
        # through to a -1 label; there is nothing to infer, so skip it.
        if len(preference_list) == 0:
            continue

        mean = np.mean(preference_list)
        variance = np.var(preference_list)

        if mean == 0 or variance > max_variance:
            continue

        reduced[uri] = 1 if mean > 0 else -1

    return reduced


def infer_preferences(ratings, user):
    """Infer a user's entity preferences from their movie ratings.

    The user's movie ratings (rows with ``isItem`` true) are propagated over the
    graph and reduced to one label per entity. The user's explicit ratings of
    non-movie entities (rows with ``isItem`` false) are returned alongside them
    as ground truth.

    Parameters
    ----------
    ratings : DataFrame with at least ``userId``, ``uri``, ``sentiment`` and
        ``isItem`` columns. ``isItem`` is assumed to be boolean, since it is used
        directly as a mask and negated with ``~``; confirm against the CSV.
    user : user id to select.

    Returns
    -------
    (inferred, explicit): ``inferred`` maps entity URI -> 1/-1 as produced by
    ``reduce_preferences``; ``explicit`` maps entity URI -> the user's sentiment.
    When a URI occurs more than once, the last row wins.
    """
    user_ratings = ratings.loc[ratings.userId == user, ['uri', 'sentiment', 'isItem']]
    is_item = user_ratings.isItem

    movies = user_ratings[is_item]
    others = user_ratings[~is_item]

    # Vectorized column zip instead of row-wise iterrows(): same mapping, far faster.
    movie_preferences = dict(zip(movies.uri, movies.sentiment))
    entity_preferences = dict(zip(others.uri, others.sentiment))

    return reduce_preferences(propagate(movie_preferences)), entity_preferences


def predict(user, uris=None, ratings=None):
    """Compare one user's inferred entity preferences with their explicit ratings.

    Only entities that are both inferred and explicitly rated are scored. A
    rating of 1 is the positive class and -1 the negative class. Pairs
    involving any other value are not counted.

    Parameters
    ----------
    user : user id to evaluate.
    uris : optional collection of entity URIs. When given, only entities in it
        are scored. ``None`` means "score every overlapping entity"; an empty
        collection scores nothing.
    ratings : ratings DataFrame to infer from; defaults to the notebook-level
        ``mr`` frame.

    Returns
    -------
    (tp, fp, tn, fn) counts for this user.
    """
    if ratings is None:
        ratings = mr

    inferred, actual = infer_preferences(ratings, user)
    overlapping = inferred.keys() & actual.keys()

    # Compare against None rather than relying on truthiness: an empty
    # restriction (e.g. a split whose label matches no entity) must score nothing
    # instead of silently falling back to scoring every entity.
    if uris is not None:
        overlapping = overlapping.intersection(uris)

    tp = fp = tn = fn = 0

    for key in overlapping:
        if actual[key] == 1 and inferred[key] == 1:
            tp += 1
        elif actual[key] == -1 and inferred[key] == 1:
            fp += 1
        elif actual[key] == -1 and inferred[key] == -1:
            tn += 1
        elif actual[key] == 1 and inferred[key] == -1:
            fn += 1

    return tp, fp, tn, fn


def _safe_divide(numerator, denominator):
    """Return ``numerator / denominator``, or NaN when the denominator is zero."""
    return numerator / denominator if denominator else float('nan')


def get_metrics(df, uris=None):
    """Aggregate confusion counts over all users and derive summary metrics.

    Parameters
    ----------
    df : DataFrame whose ``userId`` column supplies the users to evaluate.
        NOTE(review): only the user ids come from ``df``; the per-user
        predictions are computed by ``predict``, which reads the notebook-level
        ``mr`` ratings frame.
    uris : optional collection of entity URIs to restrict scoring to (passed
        through to ``predict``).

    Returns
    -------
    (precision, recall, true_negative_rate, f1, accuracy) as fractions. A metric
    whose denominator is zero (e.g. a split with no predicted positives) is NaN
    instead of raising ZeroDivisionError.
    """
    total_tp = total_fp = total_tn = total_fn = 0

    for user in tqdm(df.userId.unique()):
        tp, fp, tn, fn = predict(user, uris=uris)

        total_tp += tp
        total_fp += fp
        total_tn += tn
        total_fn += fn

    precision = _safe_divide(total_tp, total_tp + total_fp)
    recall = _safe_divide(total_tp, total_tp + total_fn)
    true_negative = _safe_divide(total_tn, total_tn + total_fp)
    accuracy = _safe_divide(total_tp + total_tn, total_tp + total_tn + total_fp + total_fn)
    # NaN precision/recall propagate to a NaN F1 without raising.
    f1 = _safe_divide(2 * precision * recall, precision + recall)

    return precision, recall, true_negative, f1, accuracy


def print_metrics(precision, recall, true_negative, f1, accuracy):
    """Print each metric on its own line as a percentage (value * 100).

    Nothing is returned; the report is written to stdout.
    """
    report = (
        ('Precision', precision),
        ('Recall', recall),
        ('True negative rate', true_negative),
        ('F1', f1),
        ('Accuracy', accuracy),
    )

    for label, value in report:
        print(f'{label}: {value * 100}%')

In [None]:
# Evaluation splits: display name -> set of entity labels used to restrict
# scoring (an entity qualifies if it has at least one of them); None = no restriction.
splits = dict(
    All=None,
    Genre={'Genre'},
    Subject={'Subject'},
    Person={'Person'},
    Studio={'Company'},  # the "Studio" split selects entities labelled 'Company'
    Decade={'Decade'},
)


def get_valid_uris(restrict_to=None, labels_by_uri=None):
    """Return the entity URIs eligible for scoring under a label restriction.

    Parameters
    ----------
    restrict_to : optional iterable of labels. A falsy value (``None`` or empty)
        means no restriction, so every known URI is returned. Otherwise a URI is
        kept when its label set shares at least one label with ``restrict_to``.
    labels_by_uri : mapping of URI -> set of labels; defaults to the
        notebook-level ``entities`` mapping.

    Returns
    -------
    set of URIs.
    """
    if labels_by_uri is None:
        labels_by_uri = entities

    if not restrict_to:
        return set(labels_by_uri)

    return {uri for uri, labels in labels_by_uri.items() if not labels.isdisjoint(restrict_to)}
    
# Evaluate the inferred preferences once per split; each split restricts scoring
# to entities carrying at least one of the split's labels.
for split, restriction in splits.items():
    metrics = get_metrics(mr, get_valid_uris(restriction))

    print(split)
    # print_metrics writes its own report and returns None, so it is called
    # directly rather than wrapped in print() (which printed a stray "None").
    print_metrics(*metrics)
    print()

  0%|          | 0/850 [00:00<?, ?it/s]  0%|          | 4/850 [00:00<00:22, 37.44it/s]  1%|          | 9/850 [00:00<00:20, 40.42it/s]  1%|▏         | 12/850 [00:00<00:25, 32.50it/s]  2%|▏         | 18/850 [00:00<00:22, 36.21it/s]  3%|▎         | 22/850 [00:00<00:23, 35.56it/s]  3%|▎         | 26/850 [00:00<00:24, 33.04it/s]  4%|▎         | 30/850 [00:00<00:25, 31.88it/s]  4%|▍         | 34/850 [00:00<00:24, 32.92it/s]  5%|▍         | 39/850 [00:01<00:24, 33.38it/s]  5%|▌         | 43/850 [00:01<00:25, 31.71it/s]  6%|▌         | 49/850 [00:01<00:21, 36.53it/s]  6%|▌         | 53/850 [00:01<00:25, 31.76it/s]  7%|▋         | 57/850 [00:01<00:24, 32.09it/s]  7%|▋         | 62/850 [00:01<00:22, 35.21it/s]  8%|▊         | 67/850 [00:01<00:20, 37.35it/s]  9%|▊         | 74/850 [00:02<00:19, 40.77it/s]  9%|▉         | 79/850 [00:02<00:19, 39.12it/s] 10%|▉         | 84/850 [00:02<00:19, 40.17it/s] 10%|█         | 89/850 [00:02<00:18, 41.62it/s] 11%|█         | 94/850 [00:02