In [29]:
import json
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tqdm import tqdm

from data.graph_loader import load_graph
%matplotlib inline

print(f'Loading ratings')
mr = pd.read_csv('movielens.csv')

# Load entities
entity_labels = dict()
entity_names = dict()
with open('../data/mindreader/entities_clean.json', 'r') as file:
    data = json.load(file)
    
    for uri, name, labels in data:
        entity_labels[uri] = set(labels.split('|'))
        entity_names[uri] = name

# Load NX graph
print(f'Loading graph')
g = load_graph('../data/graph/triples.csv', directed=True)

Loading ratings
Loading graph


In [30]:
def _propagate(node, preferences, preference):
    for neighbor in g.neighbors(node):
        preferences[neighbor].append(preference)
        
        # _propagate(neighbor, preferences, preference * 0.5)
            

def propagate(movie_preferences):
    preferences = defaultdict(list)
    
    for movie, preference in movie_preferences.items():
        _propagate(movie, preferences, preference)
            
    return preferences


def between(num, lower, upper):
    return lower < num < upper


def reduce_preferences(preferences):
    reduced = dict()
    
    # Reduces preferences from a list of ratings to its majority vote
    for uri, preference_list in preferences.items():
        mean = np.around(np.mean(preference_list))
        std = np.std(preference_list)
        
        prediction = None if between(mean, 2.5, 3.5) or std > 0.5 else mean
        if prediction:
            reduced[uri] = prediction
    
    return reduced


def infer_preferences(ratings, user):
    user_ratings = ratings[ratings.userId == user][['uri', 'rating', 'isItem']]
    movie_preferences = dict()
    
    for idx, row in user_ratings[user_ratings.isItem].iterrows():
        movie_preferences[row['uri']] = row['rating']
    
    return reduce_preferences(propagate(movie_preferences))


def predict(user, uris=None):
    predicted = infer_preferences(mr, user)
    

    rows = []
    for key, rating in predicted.items():
        rows.append([user, rating, key, False, 'N/A'])
        
    data = pd.DataFrame.from_dict(rows, columns={'userId', 'rating', 'uri', 'isItem', 'title'})
    
    print(data)
    
    # remember uris
    # print(len(predicted))


def get_metrics(df, uris=None):
    users = df.userId.unique()
    
    for user in tqdm(users):
         predict(user, uris=uris)

In [31]:
print(mr)
splits = {
    'All': None,
    'Genre': {'Genre'},
    'Subject': {'Subject'},
    'Actor': {'Actor'},
    'Director': {'Director'},
    'Studio': {'Company'},
    'Decade': {'Decade'}
}

def get_valid_uris(restrict_to=None):
    if not restrict_to:
        return set(entity_labels.keys())
    
    return {e for e, l in entity_labels.items() if not restrict_to or l.intersection(restrict_to)}

get_metrics(mr, get_valid_uris())






  0%|          | 0/671 [00:00<?, ?it/s][A[A[A[A  0%|          | 0/671 [00:00<?, ?it/s]


       Unnamed: 0  userId  rating  \
0               0       1     2.5   
1             442       1     2.5   
2             259       1     2.0   
3             165       1     4.0   
4             305       1     2.0   
...           ...     ...     ...   
92248       56230     671     4.0   
92249       56267     671     4.0   
92250       27737     671     4.5   
92251       54017     671     4.0   
92252       30767     671     4.0   

                                                title  \
0                              Dangerous Minds (1995)   
1                Star Trek: The Motion Picture (1979)   
2                                      Ben-Hur (1959)   
3      Cinema Paradiso (Nuovo cinema Paradiso) (1989)   
4                                       Gandhi (1982)   
...                                               ...   
92248               O Brother, Where Art Thou? (2000)   
92249                            Thirteen Days (2000)   
92250                           Ocean's El

ValueError: cannot use columns parameter with orient='columns'