In [1]:
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

class StateSpaceCollaborativeFilter:
    def __init__(self, n_users, n_items):
        self.n_users = n_users
        self.n_items = n_items
        self.user_states = np.random.rand(n_users, 10)  # Initialize user state vectors
        self.item_states = np.random.rand(n_items, 10)  # Initialize item state vectors

    def update_states(self, interactions, learning_rate=0.01):
        """Update user and item states based on observed interactions."""
        for user, item, rating in interactions:
            error = rating - np.dot(self.user_states[user], self.item_states[item])
            self.user_states[user] += learning_rate * error * self.item_states[item]
            self.item_states[item] += learning_rate * error * self.user_states[user]

    def predict(self, user, item):
        """Predict the rating for a user-item pair."""
        return np.dot(self.user_states[user], self.item_states[item])

    def recommend(self, user, top_k=5):
        """Recommend top_k items for a given user."""
        scores = np.dot(self.item_states, self.user_states[user])
        return np.argsort(scores)[-top_k:][::-1]  # Return indices of top_k items

# Example usage
def main():
    n_users = 5
    n_items = 10
    interactions = [
        (0, 1, 4.0),
        (0, 2, 5.0),
        (1, 2, 3.0),
        (3, 4, 2.0),
    ]  # Format: (user_id, item_id, rating)

    model = StateSpaceCollaborativeFilter(n_users, n_items)

    print("Initial State:")
    print("User States:", model.user_states)
    print("Item States:", model.item_states)

    for epoch in range(10):
        model.update_states(interactions)

    print("\nUpdated State:")
    print("User States:", model.user_states)
    print("Item States:", model.item_states)

    user = 0
    print(f"\nTop recommendations for user {user}:", model.recommend(user))

if __name__ == "__main__":
    main()


Initial State:
User States: [[0.23960302 0.8608224  0.76370735 0.21628098 0.06142652 0.84458568
  0.64029663 0.13552885 0.9563945  0.72933667]
 [0.34753389 0.79684356 0.90064535 0.82224556 0.71082559 0.10565611
  0.05468688 0.61704738 0.80263969 0.41281787]
 [0.64832368 0.00433176 0.5070145  0.4007786  0.57184146 0.10882089
  0.0845258  0.97698251 0.43477065 0.7523511 ]
 [0.30824471 0.04648345 0.33343507 0.25904574 0.48601466 0.8198254
  0.70201992 0.91270031 0.42437763 0.90804264]
 [0.58397798 0.00576773 0.52576731 0.78411547 0.62896956 0.06405239
  0.56671912 0.87428647 0.72234375 0.66064156]]
Item States: [[0.40594265 0.20100721 0.95453887 0.65377963 0.46516061 0.665635
  0.0463661  0.45388411 0.30170365 0.36038983]
 [0.89992535 0.78910515 0.32303305 0.99928981 0.91992431 0.64901656
  0.675      0.84124552 0.42429245 0.12535704]
 [0.1039352  0.47208739 0.06987412 0.67026291 0.24311446 0.53506081
  0.74122054 0.90356712 0.97192106 0.7701846 ]
 [0.27975872 0.98838729 0.07881316 0.7923