# Stock Trading

The cell below defines the **abstract class** whose API you need to implement. **Do NOT modify it** - use the dedicated cell further below for your implementation instead.

In [2]:
# DO NOT MODIFY THIS CELL

from abc import ABC, abstractmethod  
      

# abstract class to represent a stock trading platform
# Abstract class describing the API of a stock trading platform.
# Every method (including __init__) is decorated with @abstractmethod, so a
# subclass can only be instantiated once it overrides all of them.
# The method bodies below are placeholders that return empty lists.
# NOTE(review): "trade value" is not defined here; presumably price * quantity
# of a transaction -- confirm against the assignment spec.
class AbstractStockTradingPlatform(ABC):
    
    # constructor
    @abstractmethod
    def __init__(self):
        pass           
        
    # adds transactionRecord to the set of completed transactions
    # transactionRecord : presumably a [stock name, price, quantity, timestamp]
    #   list, the format produced by TransactionDataGenerator -- confirm
    @abstractmethod
    def logTransaction(self, transactionRecord):
        pass

    # returns a list with all transactions of a given stockName,
    # sorted by increasing trade value.
    # stockName : str
    @abstractmethod
    def sortedTransactions(self, stockName): 
        sortedList = []
        return sortedList    
    
    # returns a list of transactions of a given stockName with minimum trade value
    # (a list, because several transactions may share the same trade value)
    # stockName : str
    @abstractmethod
    def minTransactions(self, stockName): 
        minList = []
        return minList    
    
    # returns a list of transactions of a given stockName with maximum trade value
    # (a list, because several transactions may share the same trade value)
    # stockName : str
    @abstractmethod
    def maxTransactions(self, stockName): 
        maxList = []
        return maxList    

    # returns a list of transactions of a given stockName, 
    # with the largest trade value below a given thresholdValue.
    # NOTE(review): "below" does not say whether a trade value equal to
    # thresholdValue qualifies -- confirm inclusive vs. strict semantics.
    # stockName : str
    # thresholdValue : double
    @abstractmethod
    def floorTransactions(self, stockName, thresholdValue): 
        floorList = []
        return floorList    

    # returns a list of transactions of a given stockName, 
    # with the smallest trade value above a given thresholdValue.
    # NOTE(review): "above" does not say whether a trade value equal to
    # thresholdValue qualifies -- confirm inclusive vs. strict semantics.
    # stockName : str
    # thresholdValue : double
    @abstractmethod
    def ceilingTransactions(self, stockName, thresholdValue): 
        ceilingList = []
        return ceilingList    

        
    # returns a list of transactions of a given stockName,  
    # whose trade value is within the range [fromValue, toValue]
    # (both bounds inclusive).
    # stockName : str
    # fromValue : double
    # toValue : double
    @abstractmethod
    def rangeTransactions(self, stockName, fromValue, toValue): 
        rangeList = []
        return rangeList    

Use the cell below to define any data structures and auxiliary Python functions you may need. Leave the implementation of the main API to the next code cell instead.

In [3]:
# ADD AUXILIARY DATA STRUCTURE DEFINITIONS AND HELPER CODE HERE

class TreeNode(object):
    """Node of a self-balancing AVL tree keyed by trade value.

    The key ``val`` is ``round(price * quant, 2)``. Transactions sharing the
    same key are stored in the same node, in the parallel lists ``price`` /
    ``quant`` / ``time`` (in insertion order).

    The tree helpers below are instance methods, but they operate on the
    ``root`` argument they receive rather than on ``self``; ``self`` is only
    used to reach the other helpers.
    """

    def __init__(self, price, quant, time):
        self.val = round(price * quant, 2)  # node key: the trade value
        self.price = [price]
        self.quant = [quant]
        self.time = [time]
        self.left = None
        self.right = None
        self.height = 1  # a freshly created node is a leaf

    def _records(self):
        """Return this node's transactions as ``(val, price, quant, time)`` tuples."""
        return [(self.val, p, q, t)
                for p, q, t in zip(self.price, self.quant, self.time)]

    def insertNode(self, root, price, quant, time):
        """Insert one transaction into the subtree rooted at ``root``.

        Returns the new root of that subtree, which may differ from ``root``
        after a rebalancing rotation. A transaction whose trade value already
        exists is appended to the matching node instead of creating a new one.
        """
        key = round(price * quant, 2)
        if not root:
            return TreeNode(price, quant, time)
        if key == root.val:
            root.price.append(price)
            root.quant.append(quant)
            root.time.append(time)
            return root  # tree shape unchanged, so no rebalancing is needed
        if key < root.val:
            root.left = self.insertNode(root.left, price, quant, time)
        else:
            root.right = self.insertNode(root.right, price, quant, time)

        root.height = 1 + max(self.getHeight(root.left), self.getHeight(root.right))

        balanceFactor = self.getBalance(root)
        if balanceFactor > 1:  # left-heavy
            if self.getBalance(root.left) < 0:  # left-right case
                root.left = self.leftRotate(root.left)
            return self.rightRotate(root)
        if balanceFactor < -1:  # right-heavy
            if self.getBalance(root.right) > 0:  # right-left case
                root.right = self.rightRotate(root.right)
            return self.leftRotate(root)
        return root

    def leftRotate(self, z):
        """Rotate the subtree at ``z`` left; return its new root (``z.right``)."""
        y = z.right
        T2 = y.left
        y.left = z
        z.right = T2
        # z is now below y, so its height must be refreshed first.
        z.height = 1 + max(self.getHeight(z.left), self.getHeight(z.right))
        y.height = 1 + max(self.getHeight(y.left), self.getHeight(y.right))
        return y

    def rightRotate(self, z):
        """Rotate the subtree at ``z`` right; return its new root (``z.left``)."""
        y = z.left
        T3 = y.right
        y.right = z
        z.left = T3
        # z is now below y, so its height must be refreshed first.
        z.height = 1 + max(self.getHeight(z.left), self.getHeight(z.right))
        y.height = 1 + max(self.getHeight(y.left), self.getHeight(y.right))
        return y

    def getHeight(self, root):
        """Height of the subtree at ``root`` (0 for an empty subtree)."""
        if not root:
            return 0
        return root.height

    def getBalance(self, root):
        """Left height minus right height of ``root`` (0 for an empty subtree)."""
        if not root:
            return 0
        return self.getHeight(root.left) - self.getHeight(root.right)

    def getMinValueNode(self, root):
        """Node with the smallest trade value, or None for an empty subtree."""
        node = root
        while node is not None and node.left is not None:
            node = node.left
        return node

    def getMaxValueNode(self, root):
        """Node with the largest trade value, or None for an empty subtree."""
        node = root
        while node is not None and node.right is not None:
            node = node.right
        return node

    def getFloorValueNode(self, root, threshold, lis=None):
        """Node with the largest trade value <= ``threshold``.

        Returns None (falsy) when every trade value exceeds ``threshold``.
        ``lis`` is unused; it is kept only for call compatibility.
        """
        best = None
        node = root
        while node is not None:
            if node.val == threshold:
                return node
            if node.val < threshold:
                best = node  # candidate; a closer one may lie to the right
                node = node.right
            else:
                node = node.left
        return best

    def getCeilingValueNode(self, root, threshold, lis=None):
        """Node with the smallest trade value >= ``threshold``.

        Returns None (falsy) when every trade value is below ``threshold``.
        ``lis`` is unused; it is kept only for call compatibility.
        """
        best = None
        node = root
        while node is not None:
            if node.val == threshold:
                return node
            if node.val > threshold:
                best = node  # candidate; a closer one may lie to the left
                node = node.left
            else:
                node = node.right
        return best

    def rangeOrdered(self, root, lb, ub, lis=None):
        """Append transactions with ``lb <= val <= ub`` to ``lis`` in key order.

        A fresh list is created when ``lis`` is None, so separate calls never
        share results. Subtrees that cannot contain keys inside ``[lb, ub]``
        are skipped. Returns ``lis``.
        """
        if lis is None:
            lis = []
        if root is None:
            return lis
        if root.val > lb:  # left subtree may still hold keys >= lb
            self.rangeOrdered(root.left, lb, ub, lis)
        if lb <= root.val <= ub:
            lis.extend(root._records())
        if root.val < ub:  # right subtree may still hold keys <= ub
            self.rangeOrdered(root.right, lb, ub, lis)
        return lis

    def inOrder(self, root, lis=None):
        """Append every transaction under ``root`` to ``lis`` in key order.

        A fresh list is created when ``lis`` is None, so separate calls never
        share results. Returns ``lis``.
        """
        if lis is None:
            lis = []
        if root is None:
            return lis
        self.inOrder(root.left, lis)
        lis.extend(root._records())
        self.inOrder(root.right, lis)
        return lis

    

class AVL(object):
    """AVL tree holding the transactions of one stock, keyed by trade value.

    Wraps a root ``TreeNode`` and turns its query results into lists of
    ``(trade value, price, quantity, timestamp)`` tuples. Every query returns
    an empty list when the tree is empty or when no node matches.
    """

    def __init__(self):
        self.root = None  # TreeNode, or None while the tree is empty

    @staticmethod
    def _expand(node):
        """Flatten ``node`` into ``(val, price, quant, time)`` tuples.

        Returns ``[]`` for a falsy ``node``, because the TreeNode
        floor/ceiling helpers signal "no such node" with a falsy value.
        """
        if not node:
            return []
        return [(node.val, p, q, t)
                for p, q, t in zip(node.price, node.quant, node.time)]

    def insert(self, price, quant, time):
        """Add one transaction, rebalancing the tree as needed."""
        if self.root is None:
            self.root = TreeNode(price, quant, time)
        else:
            self.root = self.root.insertNode(self.root, price, quant, time)

    def maxNode(self):
        """Transactions with the largest trade value ([] if the tree is empty)."""
        if self.root is None:
            return []
        return self._expand(self.root.getMaxValueNode(self.root))

    def minNode(self):
        """Transactions with the smallest trade value ([] if the tree is empty)."""
        if self.root is None:
            return []
        return self._expand(self.root.getMinValueNode(self.root))

    def history(self):
        """All transactions, sorted by increasing trade value."""
        if self.root is None:
            return []
        # Pass a fresh accumulator explicitly so every call starts empty.
        return self.root.inOrder(self.root, [])

    def floor(self, threshold):
        """Transactions with the largest trade value not above ``threshold``."""
        if self.root is None:
            return []
        return self._expand(self.root.getFloorValueNode(self.root, threshold))

    def ceil(self, threshold):
        """Transactions with the smallest trade value not below ``threshold``."""
        if self.root is None:
            return []
        return self._expand(self.root.getCeilingValueNode(self.root, threshold))

    def range(self, lb, ub):
        """Transactions with ``lb <= trade value <= ub``, sorted by trade value."""
        if self.root is None:
            return []
        # Pass a fresh accumulator explicitly so every call starts empty.
        return self.root.rangeOrdered(self.root, lb, ub, [])
    




Use the cell below to implement the requested API. 

In [4]:
# IMPLEMENT HERE THE REQUESTED API

class StockTradingPlatform(AbstractStockTradingPlatform):
    """Stock trading platform backed by one AVL tree per stock name.

    Each stock's transactions live in an ``AVL`` tree keyed by trade value
    (price * quantity). Query methods return lists of
    ``(trade value, price, quantity, timestamp)`` tuples as produced by
    ``AVL``; a stock name that was never logged yields ``[]``.
    """

    def __init__(self):
        # stock name -> AVL tree of that stock's transactions
        self.stockList = {}

    def _tree(self, stockName):
        """Return the AVL tree for ``stockName``, or None if never traded."""
        return self.stockList.get(stockName)

    def logTransaction(self, transactionRecord):
        """Record one transaction.

        transactionRecord: sequence ``[stock name, price, quantity, timestamp]``
        (the format produced by ``TransactionDataGenerator``).
        """
        stockName = transactionRecord[0]
        # Read every field before touching self.stockList, so a malformed
        # record cannot leave an empty tree registered for stockName.
        price, quantity, timestamp = (transactionRecord[1],
                                      transactionRecord[2],
                                      transactionRecord[3])
        tree = self.stockList.get(stockName)
        if tree is None:
            tree = self.stockList[stockName] = AVL()
        tree.insert(price, quantity, timestamp)

    def sortedTransactions(self, stockName):
        """All transactions of ``stockName`` sorted by increasing trade value."""
        tree = self._tree(stockName)
        return tree.history() if tree is not None else []

    def minTransactions(self, stockName):
        """Transactions of ``stockName`` with the minimum trade value."""
        tree = self._tree(stockName)
        return tree.minNode() if tree is not None else []

    def maxTransactions(self, stockName):
        """Transactions of ``stockName`` with the maximum trade value."""
        tree = self._tree(stockName)
        return tree.maxNode() if tree is not None else []

    def floorTransactions(self, stockName, thresholdValue):
        """Transactions of ``stockName`` with the largest trade value at or
        below ``thresholdValue``."""
        tree = self._tree(stockName)
        return tree.floor(thresholdValue) if tree is not None else []

    def ceilingTransactions(self, stockName, thresholdValue):
        """Transactions of ``stockName`` with the smallest trade value at or
        above ``thresholdValue``."""
        tree = self._tree(stockName)
        return tree.ceil(thresholdValue) if tree is not None else []

    def rangeTransactions(self, stockName, fromValue, toValue):
        """Transactions of ``stockName`` with trade value in
        ``[fromValue, toValue]`` (inclusive), sorted by trade value."""
        tree = self._tree(stockName)
        return tree.range(fromValue, toValue) if tree is not None else []

The cell below provides helper code that you can use within your experimental framework to generate random transaction data. **Do NOT modify it**.

In [5]:
# DO NOT MODIFY THIS CELL

import random
from datetime import timedelta
from datetime import datetime

# Generates random (but seeded, hence reproducible) transaction data.
class TransactionDataGenerator:
    def __init__(self):
        # names that generated transactions are drawn from
        self.stockNames = ["Barclays", "HSBA", "Lloyds Banking Group", "NatWest Group", 
                      "Standard Chartered", "3i", "Abrdn", "Hargreaves Lansdown", 
                      "London Stock Exchange Group", "Pershing Square Holdings", 
                      "Schroders", "St. James's Place plc."]
        # bounds for getTradeValue(); they coincide with the smallest and
        # largest price * quantity generateTransactionData() can produce
        # (50.00 * 10 and 100.00 * 1000)
        self.minTradeValue = 500.00
        self.maxTradeValue = 100000.00
        # timestamp of the first generated transaction (1 January 2022, 01:00:00)
        self.startDate = datetime.strptime('1/1/2022 1:00:00', '%d/%m/%Y %H:%M:%S')
        # seeds the global `random` module: constructing a generator resets the
        # shared random state, making all subsequent draws reproducible
        random.seed(20221603)
          
    # returns the name of a traded stock at random
    def getStockName(self):
        return random.choice(self.stockNames)

    # returns the trade value of a transaction at random,
    # uniform in [minTradeValue, maxTradeValue] and rounded to 2 decimals
    def getTradeValue(self):
        return round(random.uniform(self.minTradeValue, self.maxTradeValue), 2)
    
    # returns a list of N randomly generated transactions,
    # where each transaction is represented as a list [stock name, price, quantity, timestamp]
    # price: uniform in [50.00, 100.00], rounded to 2 decimals
    # quantity: integer in [10, 1000]
    # timestamp: string 'dd/mm/YYYY HH:MM:SS', 3 seconds apart starting at startDate
    # N : int
    def generateTransactionData(self, N):   
        # all N slots initially alias one empty list, but every slot is
        # reassigned in the loop below, so the aliasing is harmless
        listTransactions = [[]]*N
        listDates = [self.startDate + timedelta(seconds=3*x) for x in range(0, N)]
        listDatesFormatted = [x.strftime('%d/%m/%Y %H:%M:%S') for x in listDates]
        for i in range(N):
            stockName = random.choice(self.stockNames)
            price = round(random.uniform(50.00, 100.00), 2)
            quantity = random.randint(10,1000)
            listTransactions[i] = [stockName, price, quantity, listDatesFormatted[i]]   
        return listTransactions

Use the cell below for the Python code needed to realise your **experimental framework** (i.e., to generate test data, to instantiate the `StockTradingPlatform` class, to thoroughly experiment with its API functions, and to experimentally measure their performance). You may use the previously provided ``TransactionDataGenerator`` class to generate random transaction data.

In [6]:
import random
import timeit

# ADD YOUR EXPERIMENTAL FRAMEWORK CODE HERE




The cell below exemplifies **debug** code I will invoke on your submission — it does not represent an experimental framework (which should be much more comprehensive). **Do NOT modify it**. 

In [7]:
# DO NOT MODIFY THIS CELL

import timeit

# Fresh platform plus a generator (its constructor reseeds `random`).
testPlatform = StockTradingPlatform()
testDataGen = TransactionDataGenerator()

# number of transactions loaded into the platform
numTransactions = 1000000
testData = testDataGen.generateTransactionData(numTransactions)

# each query API below is called this many times; the mean time is reported
numRuns = 100

print("Examples of transactions:", testData[0], testData[numTransactions//2], testData[numTransactions-1])

#
# testing the logTransaction() API 
# (total time for all numTransactions insertions)
#
starttime = timeit.default_timer()
for i in range(numTransactions):
    testPlatform.logTransaction(testData[i])
endtime = timeit.default_timer()
print("\nExecution time to load", numTransactions, "transactions:", round(endtime-starttime,4))

#
# testing the various API functions
# Note: getStockName()/getTradeValue() are called inside each timed loop,
# so their cost is included in the reported means.
#
starttime = timeit.default_timer()
for i in range(numRuns):
    output = testPlatform.sortedTransactions(testDataGen.getStockName())
endtime = timeit.default_timer()
print("\nMean execution time sortedTransactions:", round((endtime-starttime)/numRuns,4))

starttime = timeit.default_timer()
for i in range(numRuns):
    output = testPlatform.minTransactions(testDataGen.getStockName())
endtime = timeit.default_timer()
print("\nMean execution time minTransactions:", round((endtime-starttime)/numRuns,4))

starttime = timeit.default_timer()
for i in range(numRuns):
    output = testPlatform.maxTransactions(testDataGen.getStockName())
endtime = timeit.default_timer()
print("\nMean execution time maxTransactions:", round((endtime-starttime)/numRuns,4))


starttime = timeit.default_timer()
for i in range(numRuns):
    output = testPlatform.floorTransactions(testDataGen.getStockName(), testDataGen.getTradeValue())
endtime = timeit.default_timer()
print("\nMean execution time floorTransactions:", round((endtime-starttime)/numRuns,4))


starttime = timeit.default_timer()
for i in range(numRuns):
    output = testPlatform.ceilingTransactions(testDataGen.getStockName(), testDataGen.getTradeValue())
endtime = timeit.default_timer()
print("\nMean execution time ceilingTransactions:", round((endtime-starttime)/numRuns,4))


starttime = timeit.default_timer()
for i in range(numRuns):
    # sort two random trade values so that fromValue <= toValue
    rangeValues = sorted([testDataGen.getTradeValue(), testDataGen.getTradeValue()])
    output = testPlatform.rangeTransactions(testDataGen.getStockName(), rangeValues[0], rangeValues[1])
endtime = timeit.default_timer()
print("\nMean execution time rangeTransactions:", round((endtime-starttime)/numRuns,4))

Examples of transactions: ['NatWest Group', 51.89, 96, '01/01/2022 01:00:00'] ['Barclays', 59.84, 245, '18/01/2022 09:40:00'] ['Pershing Square Holdings', 84.27, 319, '04/02/2022 18:19:57']

Execution time to load 1000000 transactions: 101.2316

Mean execution time sortedTransactions: 0.3006

Mean execution time minTransactions: 0.0

Mean execution time maxTransactions: 0.0

Mean execution time floorTransactions: 0.0

Mean execution time ceilingTransactions: 0.0

Mean execution time rangeTransactions: 0.212
