In [44]:
pip install scikit-sparse

Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.


In [45]:
pip install sansa

Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.


In [57]:
import numpy as np
import pandas as pd
from scipy.sparse import csr_matrix
from sklearn.model_selection import train_test_split
from sansa import (
    SANSA,
    SANSAConfig,
    CHOLMODGramianFactorizerConfig,
    ICFGramianFactorizerConfig,
    UMRUnitLowerTriangleInverterConfig,
)

In [58]:
# Sample data: user and item IDs are contiguous, with extra interactions added.
# Each position across the three lists describes one (user, item, interaction) record.
data = dict(
    user_id=[0, 0, 1, 1, 2, 2, 3, 3, 0, 1, 2, 0, 1, 2, 0, 1],
    item_id=[0, 1, 0, 2, 1, 2, 0, 1, 2, 1, 2, 0, 1, 2, 1, 0],
    interaction=[1] * 16,  # every record is a positive interaction
)
df = pd.DataFrame(data)

In [59]:
# Make sure user and item IDs are contiguous: re-encode each ID column as its
# categorical code, which maps the distinct values onto 0..n-1.
for id_column in ('user_id', 'item_id'):
    df[id_column] = df[id_column].astype('category').cat.codes

In [60]:
# Build the user-item interaction matrix (rows = users, columns = items).
# NOTE(review): the sample data repeats some (user, item) pairs; 16 records
# yielding 10 stored entries suggests duplicates are summed, so some values
# are > 1 — confirm that non-binary interactions are intended.
num_users, num_items = df['user_id'].nunique(), df['item_id'].nunique()
matrix_shape = (num_users, num_items)
coordinates = (df['user_id'], df['item_id'])
interaction_matrix = csr_matrix((df['interaction'], coordinates), shape=matrix_shape)

# Sanity-check the matrix shape and the number of stored entries
print("Interaction matrix shape:", interaction_matrix.shape)
print("Number of non-zero entries:", interaction_matrix.nnz)

# Extract the training set: hold out 20% of the matrix rows, with a fixed seed
# so the split is reproducible.
X_train, X_test = train_test_split(
    interaction_matrix,
    test_size=0.2,
    random_state=42,
)

Interaction matrix shape: (4, 3)
Number of non-zero entries: 10


In [61]:
# Configure the model.
# The factorizer / inverter config classes are already imported in the
# notebook's import cell, so they are not re-imported here.
factorizer_config = CHOLMODGramianFactorizerConfig()  # use default parameters
inverter_config = UMRUnitLowerTriangleInverterConfig(scans=1, finetune_steps=5)

# NOTE(review): on this tiny sample the training log reports that a density of
# 0.005% is too low and gets clipped upward — revisit weight_matrix_density
# when moving to real data.
config = SANSAConfig(
    l2=20.0,
    weight_matrix_density=5e-5,
    gramian_factorizer_config=factorizer_config,
    lower_triangle_inverter_config=inverter_config,
)

In [62]:
# Train the model.
# Instantiate the model from `config` first: no other cell defines `model`, so
# on a fresh kernel the fit call would otherwise hit a NameError that the
# handler below would report as a fitting error.
model = SANSA(config)

try:
    model.fit(X_train)
    print("Model trained successfully.")
except Exception as e:
    # Fitting problems are reported rather than re-raised; later cells check
    # whether weights were actually produced.
    print("Error during fitting:", e)

INFO:sansa.model:Computing LDL^T decomposition of permuted item-item matrix...
INFO:sansa.core.factorizers:Computing incomplete Cholesky decomposition of X^TX + 20.0*I...
                    Selected density 0.005000% is too low, clipping to 66.666667%. 
                    Minimum density might result in worse quality of the approximate factor.
                
INFO:sansa.core.factorizers:Finding a fill-in reducing ordering (method = colamd)...
INFO:sansa.core.factorizers:Computing approximate Cholesky decomposition (method = CHOLMOD)...
INFO:sansa.core.factorizers:Dropping small entries in L (66.666667% dense, target = 66.666667%)...
INFO:sansa.core.factorizers:Scaling columns and creating diagonal matrix D (LL^T -> L'DL'^T)...
INFO:sansa.model:Computing approximate inverse of L...
INFO:sansa.core.inverters:Calculating initial guess using 1 step of Schultz method...
INFO:sansa.core.inverters:Calculating approximate inverse using Uniform Minimal Residual algorithm...
INFO:sansa.core._

Model trained successfully.


In [63]:
# Read the learned weights: a tuple of two scipy.sparse.csr_matrix objects,
# each of shape (num_items, num_items).
fitted_weights = model.weights
w1, w2 = fitted_weights

# Load the same pair back into the model; as the cell's last expression,
# the call's return value is what the notebook displays.
model.load_weights((w1, w2))

<sansa.model.SANSA at 0x26a3df3bbd0>

In [64]:
# Inspect the model weights: warn if none are present, otherwise report the
# shapes of the two weight matrices.
weights = model.weights
if weights is None:
    print("Model weights are None. The model may not have trained correctly.")
else:
    w1, w2 = weights
    print("Model weights shapes:", w1.shape, w2.shape)

Model weights shapes: (3, 3) (3, 3)
