# project

## project 的三个需求

1. 表示数字艺术品
2. 存储数字艺术品
3. 交换数字艺术品（transaction，交易）

## 安全需求分析

1. 需要记录用户和用户拥有的艺术品
    - 需要将用户和用户的独特签名对应起来
    - 需要在艺术品的存储中添加用户的独特签名
    - 将签名添加到艺术品的算法需要加密，防止攻击者修改艺术品数据内的独特签名
2. 数据交换

    - 交换艺术品时，需要将交换的请求加密，防止被中途修改
    - ACID 原则
        - 原子性（A）：一个事务的所有系列操作步骤被看成一个动作，所有的步骤要么全部完成，要么一个也不会完成。如果在事务过程中发生错误，则会回滚到事务开始前的状态，将要被改变的数据库记录不会被改变。
        - 一致性（C）：一致性是指在事务开始之前和事务结束以后，数据库的完整性约束没有被破坏，即数据库事务不能破坏关系数据的完整性及业务逻辑上的一致性。
        - 隔离性（I）：主要用于实现并发控制。隔离能够确保并发执行的事务互不干扰，其效果等同于这些事务按某种顺序一个接一个地执行。通过隔离，一个未完成事务不会影响另外一个未完成事务。
        - 持久性（D）：一旦一个事务被提交，其结果应该被持久保存，即使之后系统发生故障也不会丢失，也不会因为与其他操作冲突而取消这个事务。

3. 数据存储的安全性

    - NFT 的不可篡改性依靠区块链的安全特性来实现。这里，我们将区块链替换为了一般的数据库，这将使平台失去去中心化的优势。为了仍然保障安全性，需要做出以下假设：
        - 部署数据库的设备不会被攻破
        - 数据库管理员的密码不会被泄露
        - 数据库管理员不会恶意修改数据库

4. 需要保护所有用户的独特签名

## 重要 reference

-   [中泰证券 NFT 技术分析](https://dfscdn.dfcfw.com/download/A2_cms_f_20220216123508144922&direct=1&abc3847.pdf)
-   [我的总结（详见参考列表） - csdn](https://blog.csdn.net/weixin_39591031/article/details/124138855)


In [None]:
import sqlite3
import json
import base64
import io
import time

from Crypto.PublicKey import ECC
from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes

from PIL import Image


### 数据库

处理一切和数据库的交互

reference:

-   [https://developer.51cto.com/article/624601.html](https://developer.51cto.com/article/624601.html)
-   [https://www.runoob.com/sqlite/sqlite-data-types.html](https://www.runoob.com/sqlite/sqlite-data-types.html)


In [1]:
class DBmanager:
    """Handle every interaction with the SQLite database.

    On construction the manager opens ``DATABASE_PATH`` and makes sure the
    collections, users and transections tables exist. Values are always handed
    to SQLite as bound parameters; only table and column names, which come from
    constants defined in this class, are formatted into the SQL text.
    """

    DATABASE_PATH = "./demo.db"
    COLLECTIONS_TABLE_NAME = "collections"
    USER_TABLE_NAME = "users"
    TRANSECTIONS_TABLE_NAME = "transections"

    def __init__(self):
        """Open the connection and create any table that is missing."""
        # init connection, db will be created if doesnt exist
        self.conn = sqlite3.connect(self.DATABASE_PATH)
        self.cur = self.conn.cursor()
        # init tables
        self.init_collections_table()
        self.init_user_table()
        self.init_transections_table()

    # private helpers

    def _table_exists(self, table_name):
        """Return True when a table called *table_name* is listed in sqlite_master."""
        self.cur.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name=?;",
            (table_name,),
        )
        return self.cur.fetchone() is not None

    def _update_fields(self, table_name, key_column, key_value, fields):
        """Write every non-None entry of *fields* in a single UPDATE.

        @param fields: mapping of column name -> new value; None means
                       "leave this column untouched".
        """
        # `is not None` (rather than truthiness) so that 0 and "" can be stored.
        assignments = {col: val for col, val in fields.items() if val is not None}
        if not assignments:
            return
        set_clause = ", ".join("{} = ?".format(col) for col in assignments)
        self.execute_and_commit(
            "UPDATE {} SET {} WHERE {} = ?;".format(
                table_name, set_clause, key_column
            ),
            (*assignments.values(), key_value),
        )

    # init tables

    def init_collections_table(self):
        """Create the collections table unless it already exists."""
        if self._table_exists(self.COLLECTIONS_TABLE_NAME):
            print("Find {} table in db.".format(self.COLLECTIONS_TABLE_NAME))
            return
        # id | owner_id | price | encrypted_content | preview | statue
        # The last column is spelled STATUE; the name is kept so that database
        # files created earlier keep working.
        self.execute_and_commit(
            "CREATE TABLE {} (ID TEXT, OWNER_ID TEXT, PRICE REAL, "
            "ENCRYPTED_CONTENT BLOB, PREVIEW BLOB, STATUE TEXT);".format(
                self.COLLECTIONS_TABLE_NAME
            )
        )
        print("Images table initialized.")

    def init_user_table(self):
        """Create the users table unless it already exists."""
        if self._table_exists(self.USER_TABLE_NAME):
            print("Find {} table in db.".format(self.USER_TABLE_NAME))
            return
        # id | validation_file | pub_key | balance
        self.execute_and_commit(
            "CREATE TABLE {} (ID TEXT, VALIDATION_FILE TEXT, PUB_KEY TEXT, "
            "BALANCE REAL);".format(self.USER_TABLE_NAME)
        )
        print("Users table initialized.")

    def init_transections_table(self):
        """Create the transections table unless it already exists."""
        if self._table_exists(self.TRANSECTIONS_TABLE_NAME):
            print("Find {} table in db.".format(self.TRANSECTIONS_TABLE_NAME))
            return
        # id | timestamp | type | content | collection_id | src_user_id | dest_user_id | status
        self.execute_and_commit(
            "CREATE TABLE {} ("
            "ID INTEGER PRIMARY KEY, "
            "TIMESTAMP REAL, "
            "TYPE TEXT, "
            "CONTENT TEXT, "
            "COLLECTION_ID TEXT, "
            "SRC_USER_ID TEXT, "
            "DEST_USER_ID TEXT, "
            "STATUS TEXT);".format(self.TRANSECTIONS_TABLE_NAME)
        )
        print("Transections table initialized.")

    # manage collections TABLE

    def add_collection(
        self,
        collection_id: str,
        price: float = None,
        owner_id: str = None,
        encrypted_content: str = None,
        preview: bytes = None,
        status: str = None,
    ):
        """Insert one collection row; arguments left as None are stored as NULL."""
        # Columns are named explicitly because the parameter order of this
        # method (price before owner_id) differs from the table's column order.
        self.execute_and_commit(
            "INSERT INTO {} "
            "(ID, OWNER_ID, PRICE, ENCRYPTED_CONTENT, PREVIEW, STATUE) "
            "VALUES (?, ?, ?, ?, ?, ?)".format(self.COLLECTIONS_TABLE_NAME),
            (collection_id, owner_id, price, encrypted_content, preview, status),
        )

    def remove_collection(self, collection_id):
        """Delete the collection row(s) whose id equals *collection_id*."""
        self.execute_and_commit(
            "DELETE FROM {} WHERE ID = ?".format(self.COLLECTIONS_TABLE_NAME),
            (collection_id,),
        )

    def update_collection(
        self,
        collection_id: str,
        price: float = None,
        owner_id: str = None,
        encrypted_content: str = None,
        preview: bytes = None,
        status: str = None,
    ):
        """Update any field of the collection table in database.

        Only arguments that are not None are written.
        """
        self._update_fields(
            self.COLLECTIONS_TABLE_NAME,
            "ID",
            collection_id,
            {
                "PRICE": price,
                "OWNER_ID": owner_id,
                "ENCRYPTED_CONTENT": encrypted_content,
                "PREVIEW": preview,
                "STATUE": status,  # schema spelling of the status column
            },
        )

    def get_all_collections(self):
        """Return every row of the collections table."""
        self.cur.execute("SELECT * FROM {}".format(self.COLLECTIONS_TABLE_NAME))
        return self.cur.fetchall()

    def get_collection_by_id(self, collection_id):
        """
        Find collection from database. Return the collection info if exist, otherwise None.
        @return All data item of the collection: (id, owner_id, price, encrypted_content, preview, statue)
        @raise AssertionError when more than one row carries the same id.
        """
        self.cur.execute(
            "SELECT * FROM {} WHERE ID = ?".format(self.COLLECTIONS_TABLE_NAME),
            (collection_id,),
        )
        res = self.cur.fetchall()
        if len(res) > 1:
            raise AssertionError("Fatel error, more than one collecion have same id.")
        return res[0] if res else None

    def get_collections_by_user_id(self, user_id):
        """
        Find all collections belongs to user. Return the collection info list.
        @return [(id, owner_id, price, encrypted_content, preview, statue), ...]
        """
        self.cur.execute(
            "SELECT * FROM {} WHERE OWNER_ID = ?".format(self.COLLECTIONS_TABLE_NAME),
            (user_id,),
        )
        return self.cur.fetchall()

    # manage users TABLE

    def add_user(
        self,
        user_id: str,
        validation_file: bytes = None,
        pub_key: str = None,
        balance: float = None,
    ):
        """Insert one user row; arguments left as None are stored as NULL."""
        self.execute_and_commit(
            "INSERT INTO {} (ID, VALIDATION_FILE, PUB_KEY, BALANCE) "
            "VALUES (?, ?, ?, ?)".format(self.USER_TABLE_NAME),
            (user_id, validation_file, pub_key, balance),
        )

    def remove_user(self, user_id):
        """Delete the user row(s) whose id equals *user_id*."""
        self.execute_and_commit(
            "DELETE FROM {} WHERE ID = ?".format(self.USER_TABLE_NAME), (user_id,)
        )

    def update_user(
        self,
        user_id: str,
        validation_file: bytes = None,
        pub_key: str = None,
        balance: float = None,
    ):
        """Update any field of the user table in database.

        Only arguments that are not None are written.
        """
        self._update_fields(
            self.USER_TABLE_NAME,
            "ID",
            user_id,
            {
                "VALIDATION_FILE": validation_file,
                "PUB_KEY": pub_key,
                "BALANCE": balance,
            },
        )

    def get_all_user(self):
        """Return every row of the users table."""
        self.cur.execute("SELECT * FROM {}".format(self.USER_TABLE_NAME))
        return self.cur.fetchall()

    def get_user_by_id(self, user_id):
        """
        Find user from database. Return the user info if exist, otherwise None.
        @return All data item of the user: (id, validation_file, pub_key, balance)
        @raise AssertionError when more than one row carries the same id.
        """
        self.cur.execute(
            "SELECT * FROM {} WHERE ID = ?".format(self.USER_TABLE_NAME), (user_id,)
        )
        res = self.cur.fetchall()
        if len(res) > 1:
            raise AssertionError("Fatel error, more than one user have same id.")
        return res[0] if res else None

    # manage transections TABLE

    def add_transection(
        self,
        transeciton_id: str,
        timestamp: float = None,
        type: str = None,
        content: str = None,
        collection_id: str = None,
        src_user_id: str = None,
        dest_user_id: str = None,
        status: str = None,
    ):
        """Insert one transection row and return the id SQLite assigned to it.

        `transeciton_id` is accepted for interface compatibility but not
        stored: the ID column is an INTEGER PRIMARY KEY, so SQLite picks the
        value itself when the column is omitted from the INSERT.
        """
        self.execute_and_commit(
            "INSERT INTO {} "
            "(TIMESTAMP, TYPE, CONTENT, COLLECTION_ID, SRC_USER_ID, DEST_USER_ID, STATUS) "
            "VALUES (?, ?, ?, ?, ?, ?, ?)".format(self.TRANSECTIONS_TABLE_NAME),
            (
                timestamp,
                type,
                content,
                collection_id,
                src_user_id,
                dest_user_id,
                status,
            ),
        )
        return self.cur.lastrowid

    def remove_transection(self, transection_id):
        """Delete the transection row whose id equals *transection_id*."""
        self.execute_and_commit(
            "DELETE FROM {} WHERE ID = ?".format(self.TRANSECTIONS_TABLE_NAME),
            (transection_id,),
        )

    def update_transection(
        self,
        timestamp: float = None,
        type: str = None,
        content: str = None,
        collection_id: str = None,
        src_user_id: str = None,
        dest_user_id: str = None,
        status: str = None,
    ):
        """Update any field of the transection table in database.

        Rows are selected by their `timestamp`; only arguments that are not
        None are written. A `timestamp` of None matches no row.
        """
        self._update_fields(
            self.TRANSECTIONS_TABLE_NAME,
            "TIMESTAMP",
            timestamp,
            {
                "TYPE": type,
                "CONTENT": content,
                "COLLECTION_ID": collection_id,
                "SRC_USER_ID": src_user_id,
                "DEST_USER_ID": dest_user_id,
                "STATUS": status,
            },
        )

    def get_all_transection(self):
        """Return every row of the transections table."""
        self.cur.execute("SELECT * FROM {}".format(self.TRANSECTIONS_TABLE_NAME))
        return self.cur.fetchall()

    def get_transection_by_user_id(self, user_id):
        """
        Find all transections related to user, as source or as destination.
        @return [(id, timestamp, type, content, collection_id, src_user_id, dest_user_id, status), ...]
        """
        self.cur.execute(
            "SELECT * FROM {} WHERE SRC_USER_ID = ? OR DEST_USER_ID = ?".format(
                self.TRANSECTIONS_TABLE_NAME
            ),
            (user_id, user_id),
        )
        return self.cur.fetchall()

    # general function

    def execute_and_commit(self, sql_cmd: str, params=()):
        """Run one SQL statement with optional bound parameters, then commit.

        @param sql_cmd: SQL text, using `?` placeholders for values
        @param params: sequence of values bound to the placeholders
        """
        self.cur.execute(sql_cmd, params)
        self.conn.commit()

    def destroy(self):
        """Close the cursor and the connection."""
        self.cur.close()
        self.conn.close()
        print("Db connection closed.")


In [None]:
%%script false
# NOTE(review): the `%%script false` cell magic appears to keep this cell from running as Python — confirm before relying on `db` below.
# init database
# DBmanager() opens the SQLite connection and creates any missing table.
db = DBmanager()


In [None]:
%%script false
# Close the cursor and the connection held by the DBmanager instance.
db.destroy()


### 用户

属性

-   `ID`: user name, must be unique, thus can be viewed as the ID
-   `validation_file`: JSON-serialized file (2 fields: user_id & AES key), meant to be encrypted with the user's private key
-   `pub_key`: user's ECC (P-256) public key
-   `balance`: user's balance of XAV coin
-   `transections`: all of the user's transections

TODO: 数据库条件查询语句可以在DBmanager里实现。只需要附带条件字符串参数即可。

In [None]:
class User:
    """A platform user, stored as one row of the users table.

    Constructing a user with only an id (and a db) registers a new user: keys
    are generated, a validation file is built and the row is inserted.
    Constructing it with `pub_key`, `validation_file` and `balance` wraps an
    already-stored user without touching the users table.
    """

    DEFAULT_BALANCE = 3  # user default balance
    db = None  # optional class-wide DBmanager, used when no db is passed in

    def __init__(
        self,
        id: str,
        # can be left empty during registration:
        pub_key: str = None,
        validation_file: bytes = None,
        balance: float = None,
        # can be always left empty
        collections: list = None,
        transections: list = None,
        # must be provided in instantiation:
        db: DBmanager = None,
    ):
        """
        id: user name, must be unique, thus can be view as ID
        pub_key: user's ECC public key (PEM text)
        validation_file: JSON-serialized text with 2 fields: user_id & AES key
        balance: user's balance of XAV coin
        collections: user's all collections (loaded from db when omitted)
        transections: user's all transections (loaded from db when omitted)
        db: DBmanager

        Raises AttributeError when no database is available, or when only
        some of pub_key / validation_file / balance are given.
        """
        self.db = self._resolve_db(db)

        # necessary fields
        self.id = id

        # `is None` checks so that a legitimate balance of 0 counts as provided.
        provided = [
            value is not None for value in (pub_key, validation_file, balance)
        ]
        if not any(provided):
            # register mode: nothing but the id was supplied
            # NOTE: the private key is not kept on the object; _gen_keys only
            # prints it.
            _, self.pub_key, aes_key = self._gen_keys()
            self.validation_file = self._gen_validation_file(aes_key)
            self.balance = self.DEFAULT_BALANCE
            self.collections = []
            self.transections = []
            self._add_to_db()
        elif all(provided):
            # normal mode: wrap an existing user
            self.pub_key = pub_key
            self.validation_file = validation_file
            self.balance = balance
            self.collections = (
                collections if collections is not None else self._get_collections()
            )
            self.transections = (
                transections
                if transections is not None
                else self._get_transecions()
            )
        else:
            raise AttributeError("In non-register mode all fields must be provided.")

    @classmethod
    def fromID(cls, id, db: DBmanager = None):
        """Load the stored user called *id*.

        Raises AttributeError when no such user exists or its row is incomplete.
        """
        db = cls._resolve_db(db)
        row = cls._find_row(db, id)
        if row is None:
            raise AttributeError("User doesn't exist with id={}.".format(id))
        # users table column order: id | validation_file | pub_key | balance
        _, validation_file, pub_key, balance = row
        if pub_key is None or validation_file is None or balance is None:
            # Passing all-None fields on would silently register a new user.
            raise AttributeError("User record is incomplete for id={}.".format(id))
        return cls(
            id,
            pub_key=pub_key,
            validation_file=validation_file,
            balance=balance,
            db=db,
        )

    @classmethod
    def new(cls, id, db: DBmanager = None):
        """Register a new user called *id* and return it.

        Raises AttributeError when the id is already taken.
        """
        db = cls._resolve_db(db)
        if cls._find_row(db, id) is not None:
            raise AttributeError("User id already exists, please use another id.")
        return cls(id, db=db)

    @classmethod
    def _resolve_db(cls, db):
        """Return *db*, falling back to the class-level `db`; raise if both are None."""
        resolved = db if db is not None else cls.db
        if resolved is None:
            raise AttributeError("Haven't connect to database, please connect first.")
        return resolved

    @staticmethod
    def _find_row(db, id):
        """Return the users-table row for *id*, or None when there is none."""
        try:
            return db.get_user_by_id(id)
        except IndexError:
            # A lookup that indexes an empty result list surfaces as IndexError;
            # treat that the same as "not found".
            return None

    @staticmethod
    def _gen_keys():
        """
        Generate:
            1. a pair of keys using ECC algorithm (curve P-256)
            2. a 32-byte key for AES

        @return (priv_key, pub_key, aes_key); the first two are PEM text.
        """
        key = ECC.generate(curve="P-256")

        # PEM is human readable string
        # when import, read as text (e.g., open(path, 'rt'))
        priv_key = key.export_key(format="PEM")
        pub_key = key.public_key().export_key(format="PEM")

        aes_key = get_random_bytes(32)  # 32-bytes is safer than 16-bytes

        # NOTE(review): secrets are written to stdout here; this is the only
        # place the private key is exposed to the caller.
        print("Generate user keys:")
        print(priv_key)
        print(pub_key)
        print("AES key:", aes_key)

        return priv_key, pub_key, aes_key

    def _gen_validation_file(self, aes_key):
        """Return JSON text holding the user id and the base64-encoded AES key."""
        # json cannot serialize raw bytes, so the key is base64-encoded first.
        # NOTE(review): the result is not encrypted — confirm whether it should be.
        return json.dumps(
            {
                "user_id": self.id,
                "aes_key": base64.b64encode(aes_key).decode("ascii"),
            }
        )

    def _add_to_db(self):
        """Insert this user's id, validation file, public key and balance."""
        # Keyword arguments keep each value in its matching add_user parameter.
        self.db.add_user(
            self.id,
            validation_file=self.validation_file,
            pub_key=self.pub_key,
            balance=self.balance,
        )

    def _get_collections(self):
        # retrieve user's collections from database
        return self.db.get_collections_by_user_id(self.id)

    def _get_transecions(self):
        # retrieve user's transecions from database (as source or destination)
        self.db.cur.execute(
            "SELECT * FROM {} WHERE src_user_id = ? OR dest_user_id = ?".format(
                self.db.TRANSECTIONS_TABLE_NAME
            ),
            (self.id, self.id),
        )
        return self.db.cur.fetchall()


In [None]:
# Build a user from an id and a db handle only (no keys or balance supplied).
# NOTE(review): relies on `db` from the database-init cell, which is marked `%%script false` — confirm it has been run.
xav = User("xav", db=db)


### Collection

属性：
- id
- owner_id
- price: 初始价格 0.1 XAV，每次交易升值 1 XAV
- encrypted_content
- preview: 低分辨率预览图
- status

数据库格式：
`(id, owner_id, price, encrypted_content, preview, status)`

注：想要解决盗版图片重新上传的问题
1. 算法检查图片重复度
2. 审核人员人工检查


In [None]:
class Collection:
    """A digital artwork, stored as one row of the collections table."""

    # Format in database: (id, owner_id, price, encrypted_content, preview, status)

    _DEFAULT_PRICE = 0.1  # default price of a collection
    _STATUS_CONFIRMED = "confirmed"  # default status
    _STATUS_PENDING = "pending"  # collection on processing
    db = None  # optional class-wide DBmanager, used when no db is passed in

    def __init__(
        self,
        id: str,
        owner_id: str,
        price: float,
        encrypted_content: str,
        preview: str,
        status: str,
        # must be provided in instantiation:
        db: DBmanager = None,
    ):
        """
        @params
        - id: collection unique name
        - owner_id: id of collection's owner
        - price: price of the collection
        - encrypted_content: raw data of the collection after encrypted with owner's AES key
        - preview: low resolution version of the image
        - status: pending if in the middle of a transection, otherwise confirmed
        - db: DBmanager instance.

        Raises AttributeError when neither `db` nor the class-level `db` is set.
        """
        self.db = self._resolve_db(db)

        self.id = id
        self.owner_id = owner_id
        self.price = price
        self.encrypted_content = encrypted_content
        self.preview = preview
        self.status = status

    @classmethod
    def fromID(cls, id, db: DBmanager = None):
        """Load the stored collection called *id*.

        Raises AttributeError when no such collection exists.
        """
        db = cls._resolve_db(db)
        row = cls._find_row(db, id)
        if row is None:
            raise AttributeError("Collection doesn't exist with id={}.".format(id))
        _, owner_id, price, encrypted_content, preview, status = row
        return cls(id, owner_id, price, encrypted_content, preview, status, db)

    @classmethod
    def new(cls, id, owner_id, raw_data, aes_key, db: DBmanager = None):
        """Create a new collection and add to database.

        @param raw_data: image data in bytes
        @param aes_key: AES key used to encrypt `raw_data`
        Raises AttributeError when the id is already taken.
        """
        db = cls._resolve_db(db)
        if cls._id_exists(db, id):
            raise AttributeError("Collection id already exists, please use another id.")
        collection = cls(
            id,
            owner_id,
            cls._DEFAULT_PRICE,
            cls._encrypte_content(raw_data, aes_key),
            cls._gen_preview(raw_data),
            cls._STATUS_CONFIRMED,
            db,
        )
        collection._add_to_db()
        return collection

    @classmethod
    def _resolve_db(cls, db):
        """Return *db*, falling back to the class-level `db`; raise if both are None."""
        resolved = db if db is not None else cls.db
        if resolved is None:
            raise AttributeError("Haven't connect to database, please connect first.")
        return resolved

    @staticmethod
    def _id_exists(db, id):
        """Return True when the collections table of *db* has a row with this id."""
        db.cur.execute(
            "SELECT 1 FROM {} WHERE id = ?".format(db.COLLECTIONS_TABLE_NAME), (id,)
        )
        return db.cur.fetchone() is not None

    @staticmethod
    def _find_row(db, id):
        """Return the collections-table row for *id*, or None when there is none."""
        try:
            return db.get_collection_by_id(id)
        except IndexError:
            # A lookup that indexes an empty result list surfaces as IndexError;
            # treat that the same as "not found".
            return None

    def is_id_repeated(self, id):
        """Return whether or not the collection's id already exists in database."""
        return self._id_exists(self.db, id)

    def _add_to_db(self):
        """Insert this collection's fields as a new row of the collections table."""
        # Keyword arguments keep each value in its matching add_collection parameter.
        self.db.add_collection(
            self.id,
            price=self.price,
            owner_id=self.owner_id,
            encrypted_content=self.encrypted_content,
            preview=self.preview,
            status=self.status,
        )

    @staticmethod
    def _encrypte_content(data, aes_key) -> str:
        """
        Encrypt content using AES (CTR mode, allow arbitrary length of data).
        @param data: raw data of image in bytes
        @return serialized json string with base64 values (keys: "nonce", "ciphertext")
        """
        # NOTE(review): CTR mode alone does not detect tampering with the
        # ciphertext — confirm whether integrity protection is needed.
        cipher = AES.new(aes_key, AES.MODE_CTR)
        ct_bytes = cipher.encrypt(data)
        nonce = base64.b64encode(cipher.nonce).decode("utf-8")
        ct = base64.b64encode(ct_bytes).decode("utf-8")
        result = json.dumps({"nonce": nonce, "ciphertext": ct})
        print("Encrypt result:", result)
        return result

    @staticmethod
    def _decrypte_content(data, aes_key) -> bytes:
        """
        Decrypt content produced by `_encrypte_content` (AES, CTR mode).
        @param data: json serialized string with base64 values (keys: "nonce", "ciphertext")
        @return decrypted bytes data
        """
        b64 = json.loads(data)
        nonce = base64.b64decode(b64["nonce"])
        ct = base64.b64decode(b64["ciphertext"])
        cipher = AES.new(aes_key, AES.MODE_CTR, nonce=nonce)
        pt = cipher.decrypt(ct)
        print("Decrypt result:", pt)
        return pt

    @staticmethod
    def _gen_preview(raw_data):
        """Generate low resolution thumbnail and return its bytes data.

        Assumes `raw_data` holds an encoded image file (PNG, JPEG, ...) — TODO confirm.
        """
        PREVIEW_SIZE = (210, 294)  # default collection thumbnail size (width, height)
        img = Image.open(io.BytesIO(raw_data))
        # Remember the source format up front; fall back to PNG if it is unset.
        img_format = img.format or "PNG"
        # thumbnail() resizes in place and returns None, so keep using `img`.
        img.thumbnail(PREVIEW_SIZE)
        img_byte_stream = io.BytesIO()
        img.save(img_byte_stream, format=img_format)
        return img_byte_stream.getvalue()

    def update_owner(self, new_owner_id: str):
        """Change owner of this artwork, in the database and on this object."""
        print("Image owner updated: {} ---> {}".format(self.owner_id, new_owner_id))
        self.db.update_collection(self.id, owner_id=new_owner_id)
        # Keep the in-memory object in step with the stored row.
        self.owner_id = new_owner_id


### Transections

属性
- id: auto-increment id
- timestamp: the timestamp when the transection is created.
- type
- content
- collection_id
- src_user_id
- dest_user_id
- status

In [None]:
class Transection:
    """Plain record describing one transection between two users."""

    def __init__(self, timestamp, type, content, collection_id, src_user_id, dest_user_id, status):
        """Store every argument on the instance under the same name."""
        # when and what
        self.timestamp, self.type, self.content = timestamp, type, content
        # which collection, and between whom
        self.collection_id = collection_id
        self.src_user_id, self.dest_user_id = src_user_id, dest_user_id
        # current state of the transection
        self.status = status

    @classmethod
    def new(cls, type, content, collection_id, src_user_id, dest_user_id, status):
        """Build a transection stamped with the current time (`time.time()`)."""
        created_at = time.time()
        return cls(
            created_at,
            type,
            content,
            collection_id,
            src_user_id,
            dest_user_id,
            status,
        )