Skip to content

Commit

Permalink
Merge pull request #55 from saghiles/master
Browse files Browse the repository at this point in the history
Add: GraphModule, c2pf example. Change: c2pf, pcrl, results. Remove: which_ function
  • Loading branch information
saghiles committed Mar 8, 2019
2 parents cfffee8 + 71756f1 commit 0721473
Show file tree
Hide file tree
Showing 20 changed files with 331 additions and 211 deletions.
55 changes: 52 additions & 3 deletions cornac/data/graph.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,67 @@
# -*- coding: utf-8 -*-

"""
@author: Quoc-Tuan Truong <tuantq.vnu@gmail.com>
@author: Aghiles Salah <asalah@smu.edu.sg>
"""

from . import Module
import scipy.sparse as sp
import numpy as np


class GraphModule(Module):
"""Graph module
"""

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.raw_data = kwargs.get('data', None)
self.matrix = None
self.map_data = []

def _build_triplet(self, ordered_ids):
"""Build adjacency matrix in sparse triplet format using maped ids
"""

for i, j, val in self.raw_data:
self.map_data.append([ordered_ids[i], ordered_ids[j], val])
self.map_data = np.asanyarray(self.map_data)
self.raw_data = None

def _build_sparse_matrix(self, triplet):
"""Build sparse adjacency matrix
"""

n_rows = max(triplet[:, 0]) + 1
n_cols = max(triplet[:, 1]) + 1
self.matrix = sp.csc_matrix((triplet[:, 2], (triplet[:, 0], triplet[:, 1])), shape=(n_rows, n_cols))

def get_train_triplet(self, train_row_ids, train_col_ids):
"""Get the training tuples
"""
train_triplet = []
# this makes operations much more efficient
train_row_ids = np.asanyarray(list(train_row_ids))
train_col_ids = np.asanyarray(list(train_col_ids))
for i, j, val in self.map_data:
if (i not in train_row_ids) or (j not in train_col_ids):
continue
train_triplet.append([i, j, val])

return np.asarray(train_triplet)

def build(self, ordered_ids):
pass
self._build_triplet(ordered_ids)
self._build_sparse_matrix(self.map_data)

def batch(self, batch_ids):

"""Collaborative Context Poisson Factorization.
Parameters
----------
batch_ids: array, required
An array conting the ids of rows to be returned from the sparse adjacency matrix.
"""

return self.matrix[batch_ids]
7 changes: 4 additions & 3 deletions cornac/data/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import numpy as np


class Module:
"""Module
Expand Down Expand Up @@ -43,8 +44,8 @@ def _build_feature(self, ordered_ids):
return

self.data_feature = np.zeros((len(ordered_ids), self.feature_dim))
for idx, id in enumerate(ordered_ids):
self.data_feature[idx] = self._id_feature[id]
for map_id, raw_id in enumerate(ordered_ids.keys()):
self.data_feature[map_id] = self._id_feature[raw_id]
if self._normalized:
self.data_feature = self.data_feature - np.min(self.data_feature)
self.data_feature = self.data_feature / (np.max(self.data_feature) + 1e-10)
Expand All @@ -59,4 +60,4 @@ def build(self, ordered_ids):
def batch_feature(self, batch_ids):
"""Return a matrix (batch of feature vectors) corresponding to provided batch_ids
"""
return self.data_feature[batch_ids]
return self.data_feature[batch_ids]
10 changes: 6 additions & 4 deletions cornac/eval_methods/base_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,12 +190,14 @@ def _build_uir(self, train_data, test_data, val_data=None):

def _build_modules(self):
for user_module in [self.user_text, self.user_image, self.user_graph]:
if user_module is None: continue
user_module.build(ordered_ids=self.global_uid_map.keys())
if user_module is None:
continue
user_module.build(ordered_ids=self.global_uid_map)

for item_module in [self.item_text, self.item_image, self.item_graph]:
if item_module is None: continue
item_module.build(ordered_ids=self.global_iid_map.keys())
if item_module is None:
continue
item_module.build(ordered_ids=self.global_iid_map)

for data_set in [self.train_set, self.test_set, self.val_set]:
if data_set is None: continue
Expand Down
2 changes: 1 addition & 1 deletion cornac/eval_methods/cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np
from .base_method import BaseMethod
from ..utils.common import safe_indexing
from ..experiment.cv_result import CVSingleModelResult
from ..experiment.result import CVSingleModelResult


class CrossValidation(BaseMethod):
Expand Down
65 changes: 0 additions & 65 deletions cornac/experiment/cv_result.py

This file was deleted.

4 changes: 2 additions & 2 deletions cornac/experiment/experiment.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
# -*- coding: utf-8 -*-

"""
@author: Aghiles Salah
@author: Aghiles Salah <asalah@smu.edu.sg>
Quoc-Tuan Truong <tuantq.vnu@gmail.com>
"""

from .result import Result
from .cv_result import CVResult
from .result import CVResult


class Experiment:
Expand Down
57 changes: 57 additions & 0 deletions cornac/experiment/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,38 @@ def _get_data_frame(self, avg_res, model_name, metric_names):
return avg_res


class CVSingleModelResult(SingleModelResult):
""" Cross Validation Result Class for a single model
Parameters
----------
"""

def __init__(self, metric_avg_results=None):
self.avg = metric_avg_results
self.per_fold_avg = {}
self.avg = {}

def _add_fold_res(self, fold, metric_avg_results):
# think to organize the results first
self.per_fold_avg[fold] = metric_avg_results

def _compute_avg_res(self):
for mt in self.per_fold_avg[0]:
self.avg[mt] = 0.0
for f in self.per_fold_avg:
for mt in self.per_fold_avg[f]:
self.avg[mt] += self.per_fold_avg[f][mt] / len(self.per_fold_avg)

def _organize_avg_res(self, model_name, metric_names):
# global avg
self.avg = self._get_data_frame(avg_res=self.avg, model_name=model_name, metric_names=metric_names)
# per_fold avg
for f in self.per_fold_avg:
self.per_fold_avg[f] = self._get_data_frame(avg_res=self.per_fold_avg[f], model_name=model_name,
metric_names=metric_names)


class Result:
""" Result Class
Expand All @@ -50,3 +82,28 @@ def _add_model_res(self, res, model_name):

def show(self):
print(self.avg)


class CVResult(Result):
""" Cross Validation Result Class
Parameters
----------
"""

def __init__(self, n_folds, avg_results=None):
self.avg = avg_results
self.per_fold_avg = {}
for f in range(n_folds):
self.per_fold_avg[f] = None

def _add_model_res(self, res, model_name):
if self.avg is None:
self.avg = res.avg
else:
self.avg = self.avg.append(res.avg)
for f in res.per_fold_avg:
if self.per_fold_avg[f] is None:
self.per_fold_avg[f] = res.per_fold_avg[f]
else:
self.per_fold_avg[f] = self.per_fold_avg[f].append(res.per_fold_avg[f])
2 changes: 1 addition & 1 deletion cornac/models/c2pf/cpp/cpp_c2pf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,7 @@ void c2pf_cpp(Mat const&tX, Mat const&tC, int const&g, Mat &G_s, Mat &G_r, Mat &
//Hyper parameter setting
double att = 1.0;
double aa = 0.3;
double a_ = 6.;
//double a_ = 6.;
double a1_ = 5.;
double a_t = at;
double b_t = bt;
Expand Down

0 comments on commit 0721473

Please sign in to comment.