In [236]:
import pandas as pd

from typing import Dict
import numpy as np
from itertools import chain

from scipy.sparse import csr_matrix

In [339]:
# Load the raw sales records from a CSV in the working directory. The frame is
# later handed to UserItemMatrix, which reads its 'user_id', 'item_id' and
# 'qty' columns.
data = pd.read_csv('UserItem.csv')

In [322]:
class UserItemMatrix:
    """Build a sparse user x item quantity matrix from a sales table.

    Rows of the matrix correspond to distinct ``user_id`` values and columns
    to distinct ``item_id`` values, both numbered in ascending id order.
    """

    def __init__(self, sales_data: pd.DataFrame):
        """Store the sales table; every other method derives from it on call.

        Args:
            sales_data (pd.DataFrame): Sales dataset. The methods of this
                class read the ``user_id``, ``item_id`` and ``qty`` columns.

        Example:
            sales_data (pd.DataFrame):

                user_id  item_id  qty   price
            0        1      118    1   626.66
            1        1      285    1  1016.57
            2        2     1229    3   518.99
            3        4     1688    2   940.84
            4        5     2068    1   571.36
        """
        self.sales_data = sales_data

    def _index_map(self, column: str) -> Dict[int, int]:
        """Map each distinct value of ``column`` to its rank in sorted order."""
        ordered_ids = self.sales_data[column].sort_values().unique()
        return {raw_id: position for position, raw_id in enumerate(ordered_ids)}

    def user_count(self) -> int:
        """
        Returns:
            int: the number of distinct users in sales_data.
        """
        return self.sales_data['user_id'].nunique()

    def item_count(self) -> int:
        """
        Returns:
            int: the number of distinct items in sales_data.
        """
        return self.sales_data['item_id'].nunique()

    def user_map(self) -> Dict[int, int]:
        """Creates a mapping from user_id to matrix row indexes.

        Example:
            sales_data (pd.DataFrame):

                user_id  item_id  qty   price
            0        1      118    1   626.66
            1        1      285    1  1016.57
            2        2     1229    3   518.99
            3        4     1688    2   940.84
            4        5     2068    1   571.36

            user_map (Dict[int, int]):
                {1: 0, 2: 1, 4: 2, 5: 3}

        Returns:
            Dict[int, int]: User map
        """
        return self._index_map('user_id')

    def item_map(self) -> Dict[int, int]:
        """Creates a mapping from item_id to matrix column indexes.

        Example:
            sales_data (pd.DataFrame):

                user_id  item_id  qty   price
            0        1      118    1   626.66
            1        1      285    1  1016.57
            2        2     1229    3   518.99
            3        4     1688    2   940.84
            4        5     2068    1   571.36

            item_map (Dict[int, int]):
                {118: 0, 285: 1, 1229: 2, 1688: 3, 2068: 4}

        Returns:
            Dict[int, int]: Item map
        """
        return self._index_map('item_id')

    def csr_matrix(self):
        """User-item matrix in the form of a CSR matrix.

        Row and column indices come from ``user_map`` / ``item_map``; the
        stored values are taken from the ``qty`` column. Repeated
        (user, item) pairs are summed by the scipy constructor.

        Returns:
            csr_matrix: matrix of shape (number of users, number of items).
        """
        users = self.user_map()
        items = self.item_map()

        # Translate ids to positions column-wise instead of copying the whole
        # frame and applying a Python lambda per cell.
        row_ind = self.sales_data['user_id'].map(users).to_numpy(dtype=np.intp)
        col_ind = self.sales_data['item_id'].map(items).to_numpy(dtype=np.intp)
        quantities = self.sales_data['qty'].to_numpy()

        # Pass the shape explicitly: scipy cannot infer dimensions from
        # zero-length index arrays, so an empty sales table would otherwise
        # raise instead of producing an empty matrix.
        return csr_matrix(
            (quantities, (row_ind, col_ind)),
            shape=(len(users), len(items)),
        )


In [340]:
# Wrap the loaded sales frame so the id maps and the sparse matrix can be
# derived from it.
test = UserItemMatrix(sales_data=data)

In [342]:
# Display the item_id -> index mapping; the notebook prints the returned dict
# as the cell output below.
test.item_map()

{7: 0,
 112: 1,
 116: 2,
 118: 3,
 139: 4,
 147: 5,
 172: 6,
 272: 7,
 286: 8,
 289: 9,
 302: 10,
 334: 11,
 356: 12,
 365: 13,
 366: 14,
 396: 15,
 440: 16,
 474: 17,
 494: 18,
 497: 19,
 513: 20,
 527: 21,
 535: 22,
 536: 23,
 580: 24,
 589: 25,
 595: 26,
 601: 27,
 722: 28,
 815: 29,
 820: 30,
 834: 31,
 874: 32,
 887: 33,
 898: 34,
 910: 35,
 957: 36,
 959: 37,
 984: 38,
 988: 39,
 1049: 40,
 1139: 41,
 1186: 42,
 1187: 43,
 1216: 44,
 1217: 45,
 1229: 46,
 1295: 47,
 1320: 48,
 1324: 49,
 1331: 50,
 1362: 51,
 1413: 52,
 1427: 53,
 1444: 54,
 1507: 55,
 1514: 56,
 1563: 57,
 1568: 58,
 1571: 59,
 1583: 60,
 1599: 61,
 1655: 62,
 1669: 63,
 1716: 64,
 1753: 65,
 1765: 66,
 1809: 67,
 1868: 68,
 1874: 69,
 1968: 70,
 2028: 71,
 2037: 72,
 2068: 73,
 2134: 74,
 2152: 75,
 2168: 76,
 2192: 77,
 2198: 78,
 2399: 79,
 2433: 80,
 2469: 81,
 2513: 82,
 2587: 83,
 2643: 84,
 2654: 85,
 2717: 86,
 2810: 87,
 2846: 88,
 2853: 89,
 2860: 90,
 2889: 91,
 2918: 92,
 2926: 93,
 2987: 94,
 2989: 