In [1]:
import csv
import itertools
import re
from itertools import product, chain

import sqlparse

# Global schema registry: maps each (lower-cased) table name to the list of its
# column names. It starts empty and is filled in by getTableSchema().
tableSchema= {}

In [2]:
# loading the tables schema
def getTableSchema():
    """Load the table schemas from ``./files/metadata.txt`` into ``tableSchema``.

    The whole file is read and lower-cased. Every table is expected to be
    written as a ``<begin_table>`` line, the table name, one column name per
    line and an ``<end_table>`` line. For each table the global ``tableSchema``
    gets an entry ``name -> [column, ...]``.

    Returns:
        None. The result is stored in the module-level ``tableSchema`` dict.
    """
    # the context manager guarantees the file handle is closed after reading
    with open("./files/metadata.txt", "r") as f:
        metaData= f.read().lower()

    # everything before the first begin_table marker is not a table, so skip it
    tableBlocks= metaData.split('<begin_table>\n')[1:]
    for block in tableBlocks:
        # keep only the text in front of end_table: the table name and its attributes
        tableInfo= block.split('<end_table>\n')[0]
        # the last piece after splitting on newlines is not an attribute, drop it
        tableInfo= tableInfo.split('\n')[:-1]
        if not tableInfo:
            # an empty block has no table name to register
            continue
        tableSchema[tableInfo[0]]= tableInfo[1:]

In [3]:
def readTable(tableName):
    """Read ``./files/<tableName>.csv`` into a list-of-lists dataset.

    Args:
        tableName: name of a table that must be a key of ``tableSchema``.

    Returns:
        A list whose first element is the header row, made of
        ``"<table>.<column>"`` strings (lower-cased), followed by one list of
        ``int`` values per non-empty CSV row.

    Raises:
        ValueError: if a CSV cell cannot be converted to ``int``.
    """
    if tableName not in tableSchema:
        print("Invalid Table Name, {}!".format(tableName))
        exit()

    header= [tableName.lower()+'.'+colName.lower() for colName in tableSchema[tableName]]
    # newline='' is what the csv module expects; the with-block closes the file
    with open('./files/'+tableName+'.csv', 'r', newline='') as tableFile:
        # blank lines come back from csv.reader as empty lists; skip them so
        # later column indexing does not hit rows without values
        tableRows= [list(map(int, row)) for row in csv.reader(tableFile) if row]

    return [header] + tableRows

In [4]:
def prepData(tableList):
    """Build the cross-product dataset of all tables named in ``tableList``.

    The first row of the result is the concatenation of the header rows of the
    tables; every following row is one combination of a row from each table,
    with earlier tables varying slowest.
    """
    if len(tableList) < 1:
        print ("Invalid Query..Please provide Table Name(s).")
        exit()

    dataset= readTable(tableList[0])
    for tableName in tableList[1:]:
        currentTable= readTable(tableName)
        # extend the running header in place with the new table's columns
        dataset[0]+= currentTable[0]
        combinedRows= [
            leftRow + rightRow
            for leftRow, rightRow in product(dataset[1:], currentTable[1:])
        ]
        dataset= [dataset[0]] + combinedRows
    return dataset

In [5]:
def selectAll(tableList):
    """Return the complete dataset (all attributes) that prepData builds for ``tableList``."""
    return prepData(tableList)

In [6]:
def dropUnwantedAttributes(dataset, attributeList):
    """Keep only the columns named in ``attributeList``.

    ``dataset[0]`` is the header row of ``"table.column"`` names; the columns
    are matched on the part after the dot. The rows are trimmed in place and
    the same ``dataset`` object is returned. ``'*'`` as first attribute keeps
    everything. When the list is empty or names an unknown column a message is
    printed and ``None`` is returned without touching the data.
    """
    if len(attributeList) < 1:
        print ("Invalid Query..Please provide Attribute Name(s).")
        return None
    if attributeList[0] == '*':
        return dataset

    shortNames= [qualifiedName.split('.')[1] for qualifiedName in dataset[0]]

    # every requested attribute has to exist before anything is removed
    for wanted in attributeList:
        if wanted not in shortNames:
            print("Invalid Attribute Name, {}!".format(wanted))
            return None

    keepColumn= [name in attributeList for name in shortNames]
    for row in dataset:
        # positions beyond the header width are never dropped
        row[:]= [
            value
            for position, value in enumerate(row)
            if position >= len(keepColumn) or keepColumn[position]
        ]
    return dataset

In [7]:
def dropUnwantedAttributesFromGroups(dataset, attributeList):
    """Grouped variant of dropUnwantedAttributes.

    ``dataset`` is a list of groups (each a list of rows) whose first group
    holds the header row at ``dataset[0][0]``. Columns whose name (the part
    after the dot) is not in ``attributeList`` are removed in place from every
    row of every group and the same ``dataset`` is returned. ``'*'`` as first
    attribute keeps everything; an empty list or an unknown name prints a
    message and returns ``None``.
    """
    if len(attributeList) < 1:
        print ("Invalid Query..Please provide Attribute Name(s).")
        return None
    if attributeList[0] == '*':
        return dataset

    shortNames= [qualifiedName.split('.')[1] for qualifiedName in dataset[0][0]]

    # reject the request before modifying anything if a name does not exist
    for wanted in attributeList:
        if wanted not in shortNames:
            print("Invalid Attribute Name, {}!".format(wanted))
            return None

    keepColumn= [name in attributeList for name in shortNames]
    for group in dataset:
        for row in group:
            # positions beyond the header width are never dropped
            row[:]= [
                value
                for position, value in enumerate(row)
                if position >= len(keepColumn) or keepColumn[position]
            ]
    return dataset

In [8]:
def selectDistinct(dataset):
    """Return a new dataset with the header row followed by the distinct rows.

    The first occurrence of every row is kept, in the original order. A row is
    also dropped when it is equal to something already in the result, which
    includes the header row itself.
    """
    uniqueRows= [dataset[0]]
    for candidate in dataset[1:]:
        if candidate in uniqueRows:
            continue
        uniqueRows.append(candidate)
    return uniqueRows

In [9]:
def selectDistinctFromGroups(dataset):
    """Remove duplicate rows from a grouped dataset while keeping its shape.

    ``dataset`` is a list of groups (lists of rows); ``dataset[0]`` is the
    group that holds the header row. A row is kept only the first time it is
    seen anywhere in the data groups, and a group that ends up empty is left
    out. The input is not modified.

    The previous implementation wrapped the header group one level too deep
    and appended bare rows next to groups, so the result could not be iterated
    as ``group -> row`` like the other grouped datasets in this file.

    Returns:
        A new grouped dataset ``[headerGroup, group1, group2, ...]``.
    """
    distinctGroups= [dataset[0]]
    seenRows= []
    for group in dataset[1:]:
        keptRows= []
        for entry in group:
            if entry not in seenRows:
                seenRows.append(entry)
                keptRows.append(entry)
        if keptRows:
            distinctGroups.append(keptRows)
    return distinctGroups

In [10]:
# check if the query has aggregate function present
def isAggregate(subQuery):
    """Tell whether ``subQuery`` contains a call to an aggregate function.

    An aggregate is one of ``max``, ``min``, ``avg``, ``sum`` or ``count``
    written as a whole word and followed by an opening parenthesis (optional
    whitespace in between). Plain column names that merely contain one of
    those words, such as ``account`` or ``summary``, are not aggregates.
    Matching is case-sensitive.

    Args:
        subQuery: the select list (or one item of it) as a string.

    Returns:
        True if an aggregate call is present, otherwise False.
    """
    return re.search(r'\b(?:max|min|avg|sum|count)\s*\(', subQuery) is not None

In [11]:
def solveAggregate(dataset, aggregateAttributes, aggregateFunctions):
    """Evaluate aggregate functions over the whole (ungrouped) dataset.

    Args:
        dataset: list of rows whose first row is the header of
            ``"table.column"`` names.
        aggregateAttributes: column names (part after the dot) to aggregate,
            one per function; ``'*'`` is accepted together with ``count``.
        aggregateFunctions: names out of ``max``, ``min``, ``sum``, ``avg``,
            ``count``, paired positionally with ``aggregateAttributes``.

    Returns:
        ``(columnDisplay, aggregateResults)``: the labels such as ``"max(a)"``
        and the matching values. ``max``, ``min`` and ``avg`` give ``None``
        when there are no data rows. ``None`` is returned (after printing a
        message) when no attribute is given, an attribute does not exist or a
        function name is unknown. ``dataset`` is not modified.
    """
    if len(aggregateAttributes)==0:
        print ("Invalid Query..Please provide Attribute Name(s).")
        return

    header= dataset[0]
    shortNames= [column.split('.')[1] for column in header]
    for attribute in aggregateAttributes:
        if attribute != '*' and attribute not in shortNames:
            print("Invalid Attribute Name, {}!".format(attribute))
            return

    # one entry per supported function; empty input yields None where the
    # built-in would raise (max/min) or divide by zero (avg)
    reducers= {
        'max': lambda values: max(values) if values else None,
        'min': lambda values: min(values) if values else None,
        'sum': sum,
        'avg': lambda values: sum(values)/len(values) if values else None,
        'count': len,
    }

    rows= dataset[1:]
    columnDisplay= []
    aggregateResults= []
    for function, attribute in zip(aggregateFunctions, aggregateAttributes):
        if function not in reducers:
            print("Invalid Aggregate function!")
            return
        if attribute == '*':
            if function != 'count':
                # only count(*) is meaningful for the star attribute
                print("Invalid Attribute Name, {}!".format(attribute))
                return
            values= rows
        else:
            # when several columns share the name, the last one is used
            matching= [column for column in header if column.split('.')[1] == attribute]
            attributeIndex= header.index(matching[-1])
            values= [row[attributeIndex] for row in rows]
        columnDisplay.append(function+'('+attribute+')')
        aggregateResults.append(reducers[function](values))
    return columnDisplay, aggregateResults

In [12]:
def solveAggregateOnGroups(dataset, groupbyCol, attributeList, aggregateAttributes, aggregateFunctions):
    """Evaluate aggregate functions separately for every group.

    Args:
        dataset: grouped dataset; ``dataset[0][0]`` is the header row of
            ``"table.column"`` names and ``dataset[1:]`` are the groups.
        groupbyCol: name (part after the dot) of the grouping column.
        attributeList: all selected attribute names, aggregated ones included;
            ``'*'`` as first entry selects every column.
        aggregateAttributes: the attributes that are aggregated, one per
            function; ``'*'`` is accepted together with ``count``.
        aggregateFunctions: names out of ``max``, ``min``, ``sum``, ``avg``,
            ``count``, paired positionally with ``aggregateAttributes``.

    Returns:
        ``(columnDisplay, aggregateResults)`` with one entry per function and
        group. When ``attributeList`` contains something besides the aggregated
        attributes, each result is a ``(groupValue, aggregate)`` tuple,
        otherwise it is the bare aggregate. ``None`` is returned (after
        printing a message) for an empty attribute list, an unknown attribute,
        a grouping column that is not among the selected columns, or an
        unknown function. Neither ``dataset`` nor ``attributeList`` is
        modified.
    """
    if len(attributeList)==0:
        print ("Invalid Query..Please provide Attribute Name(s).")
        return

    header= dataset[0][0]
    shortNames= [column.split('.')[1] for column in header]

    # work out which columns the select list makes visible
    if attributeList[0]=='*':
        visible= list(range(len(header)))
    else:
        for attribute in attributeList:
            if attribute != '*' and attribute not in shortNames:
                print("Invalid Attribute Name, {}!".format(attribute))
                return
        visible= [index for index, name in enumerate(shortNames) if name in attributeList]

    def lookupColumn(shortName):
        # the last visible column with this name wins; None when there is none
        matches= [index for index in visible if shortNames[index]==shortName]
        if not matches:
            return None
        return header.index(header[matches[-1]])

    # attributes that are selected without being aggregated (on a copy, so the
    # caller's list stays intact)
    plainAttributes= list(attributeList)
    for element in aggregateAttributes:
        if element in plainAttributes:
            plainAttributes.remove(element)
    labelWithGroup= len(plainAttributes)>0

    groupbyColIndex= None
    if labelWithGroup:
        groupbyColIndex= lookupColumn(groupbyCol)
        if groupbyColIndex is None:
            print("Group By column {} must be one of the selected attributes.".format(groupbyCol))
            return

    reducers= {
        'max': lambda values: max(values) if values else None,
        'min': lambda values: min(values) if values else None,
        'sum': sum,
        'avg': lambda values: sum(values)/len(values) if values else None,
        'count': len,
    }

    columnDisplay= []
    aggregateResults= []
    for function, attribute in zip(aggregateFunctions, aggregateAttributes):
        if function not in reducers:
            print("Invalid Aggregate function!")
            return
        countAllRows= (attribute=='*' and function=='count')
        attributeIndex= None
        if not countAllRows:
            attributeIndex= lookupColumn(attribute)
            if attributeIndex is None:
                print("Invalid Attribute Name, {}!".format(attribute))
                return

        for group in dataset[1:]:
            if countAllRows:
                values= group
            else:
                values= [row[attributeIndex] for row in group]
            result= reducers[function](values)
            columnDisplay.append(function+'('+attribute+')')
            if labelWithGroup:
                # max/min take the label from the row holding the extreme
                # value, the other functions from the first row of the group
                if function in ('max', 'min'):
                    labelRow= group[values.index(result)]
                else:
                    labelRow= group[0]
                aggregateResults.append((labelRow[groupbyColIndex], result))
            else:
                aggregateResults.append(result)
    return columnDisplay, aggregateResults

In [13]:
# check if the data value in table satisfies the query condition
def isSatisfied(data, operator, value):
    """Evaluate ``data <operator> value`` for a comparison operator given as text.

    Args:
        data: the value taken from the table.
        operator: one of ``<=``, ``>=``, ``=``, ``<``, ``>``, ``!=`` or ``<>``
            (the last two both mean "not equal").
        value: the value from the query condition.

    Returns:
        The boolean result of the comparison, or ``None`` when the operator is
        not one of the supported ones.
    """
    comparisons= {
        '<=': lambda: data<=value,
        '>=': lambda: data>=value,
        '=': lambda: data==value,
        '<': lambda: data<value,
        '>': lambda: data>value,
        '!=': lambda: data!=value,
        '<>': lambda: data!=value,
    }
    compare= comparisons.get(operator)
    if compare is None:
        # unknown operator: keep the historical "no answer" result
        return None
    return compare()

In [14]:
# eliminate the records from dataset which do not meet condition specified
def removeRecords(dataset, attribute, operator, value):
    """Return the header plus the rows of ``dataset`` that pass the condition.

    Every column whose name (the part after the dot) equals ``attribute`` is
    checked with isSatisfied; a row is left out as soon as one of those checks
    comes back equal to ``False``. When no column has that name a message is
    printed and ``exit()`` is called.
    """
    header= dataset[0]
    # the attribute can occur in several columns (e.g. the same name in two tables)
    columnsToCheck= [
        position
        for position, qualifiedName in enumerate(header)
        if qualifiedName.split('.')[1]==attribute
    ]

    if not columnsToCheck:
        print("Invalid Attribute Field, please check and try again.")
        exit()
        return

    keptRecords= [header]
    for record in dataset[1:]:
        # any() stops at the first column that fails the condition
        rejected= any(
            isSatisfied(record[position], operator, value)==False
            for position in columnsToCheck
        )
        if not rejected:
            keptRecords.append(record)
    return keptRecords

In [15]:
def union(dataset1, dataset2):
    """Return the set union of the rows of two datasets in a stable order.

    Both datasets carry a header as first row; the header of ``dataset1`` is
    used for the result. The data rows of ``dataset1`` come first, followed by
    the rows of ``dataset2`` that were not seen before, and every distinct row
    appears once. Rows are copied, so the inputs are left untouched.

    Building the result through a plain ``set`` (as before) gave the same rows
    but in an arbitrary order.
    """
    unionDataset= [dataset1[0]]
    seenRows= set()
    for row in dataset1[1:] + dataset2[1:]:
        rowKey= tuple(row)
        if rowKey not in seenRows:
            seenRows.add(rowKey)
            unionDataset.append(list(row))
    return unionDataset

In [16]:
def intersection(dataset1, dataset2):
    """Return the rows of ``dataset1`` that also occur in ``dataset2``.

    Both datasets carry a header as first row; the header of ``dataset1`` is
    used for the result. Order and repetitions of the rows of ``dataset1`` are
    preserved. The rows of ``dataset2`` are put into a set once, so each
    membership test is a hash lookup instead of a scan over ``dataset2``.
    """
    rowsInSecond= set(map(tuple, dataset2[1:]))
    commonRows= [row for row in dataset1[1:] if tuple(row) in rowsInSecond]
    return [dataset1[0]] + commonRows

In [17]:
def meetCondition(dataset, condition):
    """Filter ``dataset`` by a tokenised WHERE condition.

    ``condition`` holds either one comparison, e.g. ``['a', '<=', '80']``, or
    two comparisons joined by ``'and'`` / ``'or'``, e.g.
    ``['a', '<=', '80', 'and', 'b', '=', '431']``. The compared values are
    converted with ``int``. Any other shape prints a message and calls
    ``exit()``.
    """
    tokenCount= len(condition)

    if tokenCount==3:
        # a single comparison
        return removeRecords(dataset, condition[0], condition[1], int(condition[2]))

    if tokenCount==7:
        # two comparisons evaluated independently, then combined
        firstMatches= removeRecords(dataset, condition[0], condition[1], int(condition[2]))
        secondMatches= removeRecords(dataset, condition[4], condition[5], int(condition[6]))
        connective= condition[3]
        if connective=='or':
            return union(firstMatches, secondMatches)
        if connective=='and':
            return intersection(firstMatches, secondMatches)
        print("Invalid Condition!")
        exit()
        return

    print("Invalid Syntax!")
    exit()
    return

In [18]:
def groupBy(dataset, groupbyCol):
    """Split the data rows into groups that share the value of one column.

    Args:
        dataset: list of rows whose first row is the header of
            ``"table.column"`` names.
        groupbyCol: name (part after the dot) of the column to group on.

    Returns:
        ``[[header], group1, group2, ...]`` where every group is the list of
        rows with one particular value in the column, in order of first
        appearance of that value. All rows with the same value end up in the
        same group even when they are not adjacent in the input (the earlier
        ``itertools.groupby`` call only merged adjacent rows). ``None`` is
        returned, after printing a message, when the column cannot be found.
    """
    try:
        header= dataset[0]
        matching= [name for name in header if name.split('.')[1]==groupbyCol]
    except (TypeError, IndexError, AttributeError):
        # no usable header row (e.g. dataset is None or a name has no dot)
        matching= []

    if not matching:
        print('{} column not in the table.'.format(groupbyCol))
        return None

    # when several columns share the name, the last one is used
    groupbyColIndex= header.index(matching[-1])

    # dicts keep insertion order, so groups appear in first-seen order
    groups= {}
    for row in dataset[1:]:
        groups.setdefault(row[groupbyColIndex], []).append(row)

    return [[header]] + list(groups.values())

In [19]:
def orderBy(dataset, orderbyCol, orderbyStyle):
    """Sort the data rows of ``dataset`` on one column.

    Args:
        dataset: list of rows whose first row is the header of
            ``"table.column"`` names.
        orderbyCol: name (part after the dot) of the column to sort on.
        orderbyStyle: ``'asc'`` or ``'desc'``, in any letter case (the same
            spellings orderByOnGroups accepts).

    Returns:
        A new dataset with the header followed by the sorted rows; rows with
        equal keys keep their relative order. ``None`` is returned, after
        printing a message, for an unknown style or column. The input is not
        modified.
    """
    style= orderbyStyle.lower() if isinstance(orderbyStyle, str) else orderbyStyle
    if style not in ('asc', 'desc'):
        print("Invalid Ordering Style.")
        return None

    try:
        header= dataset[0]
        matching= [name for name in header if name.split('.')[1]==orderbyCol]
    except (TypeError, IndexError, AttributeError):
        # no usable header row (e.g. dataset is None or a name has no dot)
        matching= []

    if not matching:
        print ("Invalid Attribute Field in Order By Clause, please check and try again.")
        return None

    # when several columns share the name, the last one is used
    orderbyColIndex= header.index(matching[-1])
    orderedRows= sorted(dataset[1:], key=lambda row: row[orderbyColIndex], reverse=(style=='desc'))
    return [header] + orderedRows

In [20]:
def orderByOnGroups(dataset, orderbyCol, orderbyStyle):
    """Sort the rows inside every group of a grouped dataset on one column.

    Args:
        dataset: grouped dataset; ``dataset[0][0]`` is the header row of
            ``"table.column"`` names and ``dataset[1:]`` are the groups.
        orderbyCol: name (part after the dot) of the column to sort on.
        orderbyStyle: ``'asc'`` or ``'desc'``, in any letter case.

    Returns:
        A new grouped dataset with the header group followed by the groups in
        their original order, each with its rows sorted. ``None`` is returned,
        after printing a message, for an unknown style or column. The input
        groups are not modified.
    """
    style= orderbyStyle.lower() if isinstance(orderbyStyle, str) else orderbyStyle
    if style not in ('asc', 'desc'):
        print("Invalid Ordering Style.")
        return None

    try:
        header= dataset[0][0]
        matching= [name for name in header if name.split('.')[1]==orderbyCol]
    except (TypeError, IndexError, AttributeError):
        # no usable header row (e.g. dataset is None or a name has no dot)
        matching= []

    if not matching:
        print ("Invalid Attribute Field in Order By Clause, please check and try again.")
        return None

    # when several columns share the name, the last one is used
    orderbyColIndex= header.index(matching[-1])
    orderedDataset= [dataset[0]]
    for group in dataset[1:]:
        orderedDataset.append(
            sorted(group, key=lambda row: row[orderbyColIndex], reverse=(style=='desc'))
        )
    return orderedDataset

In [21]:
def _splitAttributes(selectClause):
    """Split a comma-separated select list into stripped attribute names."""
    return [attributeName.strip() for attributeName in selectClause.split(',')]


def _printRows(rows):
    """Print every row on its own line with the values joined by commas."""
    for row in rows:
        print(','.join(str(value) for value in row))


def _tokenAfter(identifierList, keyword):
    """Return the token that follows ``keyword``, or None when nothing follows it."""
    position= identifierList.index(keyword)+1
    if position < len(identifierList):
        return identifierList[position]
    return None


def processQuery(query):
    """Parse one SQL-like query, run it against the CSV tables and print the result.

    The query must end with ``;``. It is tokenised with ``sqlparse`` and the
    lower-cased tokens drive the evaluation: ``from`` names the tables (their
    cross product is the working dataset), an optional ``where`` filters it,
    then ``group by``, aggregate functions, ``distinct`` and ``order by`` are
    applied as present. Result rows are printed as comma-separated lines.

    Compared with the earlier version this one
      * treats an empty query as a missing semicolon instead of raising,
      * no longer prints the token list and the raw dataset (debug output),
      * handles a ``None`` result of the aggregate helpers instead of failing
        while unpacking it,
      * prints the sorted aggregate values for ``order by`` (``list.sort()``
        returns ``None``, which is what used to be printed) and stops there,
      * lets ``where`` and ``order by`` work together with ``distinct``,
      * skips later stages once an earlier stage left no dataset.

    Args:
        query: the query text typed by the user.

    Returns:
        None. Everything is reported through ``print``.
    """
    if not query or query[-1] != ';':
        print ("Invalid Syntax..SemiColon Missing!")
        exit()
        return

    statements= sqlparse.parse(query[:-1])
    if not statements:
        print ("Invalid Query!")
        return
    parsedQuery= statements[0].tokens
    identifierList= [str(item).lower() for item in sqlparse.sql.IdentifierList(parsedQuery).get_identifiers()]

    selectCount= identifierList.count('select')
    fromCount= identifierList.count('from')
    whereCount, whereClause= 0, ''
    distinctCount= identifierList.count('distinct')
    groupbyCount= identifierList.count('group by')
    orderbyCount= identifierList.count('order by')

    # the where clause is expected right after the table list, whose position
    # shifts by one when 'distinct' is present
    if distinctCount==1 and len(identifierList)>5 and 'where' in identifierList[5]:
        whereCount= 1
        whereClause= identifierList[5]
    if distinctCount==0 and len(identifierList)>4 and 'where' in identifierList[4]:
        whereCount= 1
        whereClause= identifierList[4]

#     if either select or from clause is absent....and any other has a count greater than 1
    if(selectCount!=1 or fromCount!=1 or whereCount>1 or distinctCount>1):
        print ("Invalid Query!")
        return

#    fetch the table names to get the dataset to process queries on
    tableClause= _tokenAfter(identifierList, 'from')
    if tableClause is None:
        print ("Invalid Query!")
        return
    tableList= [tableName.strip() for tableName in tableClause.split(',')]
    dataset= selectAll(tableList)    #fetch all attributes of table(s)

    # token right after 'select': the select list, or 'distinct' when present
    selectClause= identifierList[1]

# ------------------------------------------------------------------------------------------------------------ #

    if whereCount==1:
#         extract the condition; split() also copes with repeated or trailing blanks
        condition= whereClause.split()[1:]
        dataset= meetCondition(dataset, condition)
        # with distinct the select list is not at index 1, so that case is
        # printed by the distinct stage below
        if groupbyCount==0 and not isAggregate(selectClause) and orderbyCount==0 and distinctCount==0:
            if dataset:
                attributeList= _splitAttributes(selectClause)
                if(len(attributeList)>=1 and attributeList[0]!='*'):
                    dataset= dropUnwantedAttributes(dataset, attributeList)
                if dataset:
                    _printRows(dataset)

    groupbyCol= None
    if groupbyCount==1:
        groupbyCol= _tokenAfter(identifierList, 'group by')
        if groupbyCol is None:
            print ("Invalid Query!")
            return
        if dataset:
            dataset= groupBy(dataset, groupbyCol)
            # groupBy gives None for an unknown column; nothing to print then
            if dataset and not isAggregate(selectClause) and orderbyCount==0:
                attributeList= _splitAttributes(selectClause)
                dataset= dropUnwantedAttributesFromGroups(dataset, attributeList)
                if dataset:
                    for group in dataset:
                        _printRows(group)

    if isAggregate(selectClause):
        attributeList= _splitAttributes(selectClause)
        aggregateFunctions, aggregateAttributes =[], []
        for attributeNum in range (0, len(attributeList)):
            item= attributeList[attributeNum]
            # an aggregate item looks like name(attribute)
            if '(' in item and isAggregate(item):
                functionName= item.split('(')[0].strip()
                attributeName= item.split('(')[1].split(')')[0].strip()
                aggregateFunctions.append(functionName)
                aggregateAttributes.append(attributeName)
                attributeList[attributeNum]= attributeName

        if dataset:
            if groupbyCount==0:
                aggregateOutcome= solveAggregate(dataset, aggregateAttributes, aggregateFunctions)
            else:
                aggregateOutcome= solveAggregateOnGroups(dataset, groupbyCol, attributeList, aggregateAttributes, aggregateFunctions)
            # the helpers return None after printing their own error message
            if aggregateOutcome is None:
                return
            columnDisplay, aggregateResults= aggregateOutcome
            if aggregateResults is None:
                return
            elif orderbyCount==1:
                orderbyArgs= (_tokenAfter(identifierList, 'order by') or '').split()
                if len(orderbyArgs)<2:
                    print("Please provide Style by which you want to Order.")
                    return
                orderbyStyle= orderbyArgs[1]
                if orderbyStyle=='ASC' or orderbyStyle=='asc':
                    aggregateResults.sort()
                elif orderbyStyle=='DESC' or orderbyStyle=='desc':
                    aggregateResults.sort(reverse= True)
                else:
                    print("Invalid Ordering Style.")
                    return
                # list.sort() works in place and returns None, so print the list itself
                print (aggregateResults)
                return
            else:
                print(','.join(str(column) for column in columnDisplay))
                print(','.join(str(result) for result in aggregateResults))
                return

    # distinct is only evaluated on ungrouped data
    if distinctCount==1 and groupbyCount==0 and dataset and len(identifierList)>2:
        attributeList= _splitAttributes(identifierList[2])

        dataset= dropUnwantedAttributes(dataset, attributeList)
        if dataset:
            dataset= selectDistinct(dataset)
            if orderbyCount==0:
                _printRows(dataset)

    if orderbyCount==1 and dataset:
        # with distinct the select list sits one token further to the right
        attributeList= _splitAttributes(identifierList[2 if distinctCount==1 else 1])
        if groupbyCount==0:
            dataset= dropUnwantedAttributes(dataset, attributeList)
        else:
            dataset= dropUnwantedAttributesFromGroups(dataset, attributeList)
        if dataset:
            orderbyArgs= (_tokenAfter(identifierList, 'order by') or '').split()
            if len(orderbyArgs)<1:
                print("Please provide Attribute by which you want to Order.")
                return
            if len(orderbyArgs)<2:
                print("Please provide Style by which you want to Order.")
                return
            orderbyCol, orderbyStyle= orderbyArgs[0], orderbyArgs[1]
            if groupbyCount==0:
                dataset= orderBy(dataset, orderbyCol, orderbyStyle)
                if dataset:
                    _printRows(dataset)
            else:
                dataset= orderByOnGroups(dataset, orderbyCol, orderbyStyle)
                if dataset:
                    for group in dataset:
                        _printRows(group)

    if whereCount==0 and not isAggregate(selectClause) and groupbyCount==0 and orderbyCount==0 and distinctCount==0:
        attributeList= _splitAttributes(selectClause)
        # the dataset loaded above is still untouched on this path, so it is
        # reused instead of reading the table files a second time
#             fetch all attributes of table(s)
        if(len(attributeList)==1 and attributeList[0]=='*'):
            if dataset:
                _printRows(dataset)

#             fetch some attribute(s) of table(s)
        elif(len(attributeList)>=1 and attributeList[0]!='*'):
            dataset= dropUnwantedAttributes(dataset, attributeList)
            if dataset:
                _printRows(dataset)
        else:
            print("Invalid Query")
            return

In [None]:
if __name__ == "__main__":
    # load the schema once, then keep answering queries until the process stops
    getTableSchema()
    while True:
        print ("Your Query?")
        processQuery(input())
        print ()

Your Query?
