In [1]:
import csv
import sqlparse
import itertools
from itertools import product, chain

tableSchema= {}

In [2]:
# loading the tabes schema 
def getTableSchema():
    f = open("./files/metadata.txt", "r")
    metaData= f.read()
#     split the metadata based on begin_table
    metaData= metaData.split('<begin_table>\n')
#     split all entries based on end_table to get only the table name and its attributes in each list entry
    metaData= [data.split('<end_table>\n') for data in metaData]
    
#     for each table store its attibute list corresponding to table name after splitting it based on next line char
    for entry in range (1, len(metaData)):
        tableInfo= metaData[entry][0]
        tableInfo= tableInfo.split('\n')[:-1]
        tableSchema[tableInfo[0]]= tableInfo[1:]

In [3]:
def readTable(tableName):
    if(tableName not in tableSchema.keys()):
        print("Invalid Table Name, {}!".format(tableName))
        exit()
    tableFile= open('./files/'+tableName+'.csv', 'r')
    tableReader= csv.reader(tableFile)
    tableData= [[tableName+'.'+colName for colName in tableSchema[tableName]]]
    for row in tableReader:
        tableData.append(row)    
        
    tableData[1:] = [list(map(int, tableData[tableRow])) for tableRow in range (1, len(tableData))]
#     for row in tableData:
#         print (row)
    return tableData

In [4]:
def prepData(tableList):
#     append the rows of all tables to each other and prepare cross-product dataset
    if(len(tableList)==0):
        print ("Invalid Query..Please provide Table Name(s).")
        exit()
    dataset= readTable(tableList[0])
    for tableNum in range (1, len(tableList)): 
        currentTable= readTable(tableList[tableNum])
        dataset[0]+= currentTable[0]
        finalTable= [dataset[0]]
        for ind1 in range (1, len(dataset)):
            for ind2 in range (1, len(currentTable)):
                finalTable.append(dataset[ind1]+currentTable[ind2])
        dataset= finalTable
    return dataset

In [5]:
def selectAll(tableList):    
#     select all the attirbutes and display the result
    dataset= prepData(tableList)
#     for row in dataset:
#         print (row)
    return dataset

In [6]:
def selectSome(tableList, attributeList):
    if(len(attributeList)==0):
        print ("Invalid Query..Please provide Attribute Name(s).")
        exit()
        return 
    
    dataset= prepData(tableList)
    actualCols= []
    indexNums= []
    for index in range (0, len(dataset[0])):
        attributeName= dataset[0][index].split('.')[1].upper()
        actualCols.append(attributeName)
        if(attributeName not in attributeList):
            indexNums.append(index)
#     invalid query if any such attribute is demanded which does not exist
    for attribute in attributeList:
        if attribute not in actualCols:
            print("Invalid Attribute Name, {}!".format(attribute))
            exit()
            return
#     pop all those attibutes and respective values from all records which are not demanded
    indexNums.sort(reverse= True)    #pop from last index to avoid ambiguities
    for row in dataset:
        for index in indexNums:
            if index < len(row):
                row.pop(index)
#     for row in dataset:
#         print (row)
    return dataset

In [7]:
def selectDistinct(dataset):
    tempData= dataset[1:]
    dataset= [dataset[0]]
    for entry in tempData:
        if entry not in dataset:
            dataset.append(entry)
    
    for row in dataset:
        print(row)

In [8]:
# check if the query has aggregate function present
def isAggregate(subQuery):
    if 'max' in subQuery or 'min' in subQuery or 'avg' in subQuery or 'sum' in subQuery or 'count' in subQuery:
        return True
    return False

In [9]:
def solveAggregate(tableList, attributeList, function):
#     generate error when multiple tables are used to perform aggregate query on
    if(len(tableList)>1):
        print("Invalid Query..Aggregate Function cannot be applied to multiple tables/attributes.")
        exit() 
        return
        
#     generate error when proper parameters are not passed to aggregate functions
    if len(attributeList)!=1:
        print("Invalid number of parameters for Aggregate Function. It must be Function(Column Name).")
        exit()
        return
    
    else:
        values= selectSome(tableList, attributeList)[1:] #eliminating the name of column
        values = list(chain.from_iterable(values)) 
        print (values)
        if(function == 'max'):
            print(max(values))
        elif(function == 'min'):
            print(min(values))
        elif(function == 'sum'):
            sum= sum(values)
            print(sum)
        elif(function == 'avg'):
            sum= sum(values)
            avg=sum/len(values)
            print(avg)
        else:
            print("Invalid Aggregate function!")
            exit()
            return
    

In [10]:
# return the attribute, operator and value of the where condtion in that order
def getConditionUnits(condition):
    if '<=' in condition:
        return condition.split('<=')[0], '<=', int(condition.split('<=')[1])
    elif '>=' in condition:
        return condition.split('>=')[0], '>=', int(condition.split('>=')[1])
    elif '=' in condition:
        return condition.split('=')[0], '=', int(condition.split('=')[1])
    elif '<' in condition:
        return condition.split('<')[0], '<', int(condition.split('<')[1])
    elif '>' in condition:
        return condition.split('>')[0], '>', int(condition.split('>')[1])

In [11]:
# check if the data value in table satisfies the query condition
def isSatisfied(data, operator, value):
    if '<=' == operator:
        return data<=value
    elif '>=' == operator:
        return data>=value
    elif '='== operator:
        return data==value
    elif '<' == operator:
        return data<value
    elif '>' == operator:
        return data>value

In [12]:
# eliminate the records from dataset which do not meet condition specified
def removeRecords(dataset, condition):
    attribute, operator, value= getConditionUnits(condition)
    datasetColumns= dataset[0]
    toCheckIndices= []
#     extract the columns with the attribute asked (can be throught multiple columns)
    for colNum in range (0, len(datasetColumns)):
        if datasetColumns[colNum].split('.')[1].lower()==attribute:
            print (datasetColumns[colNum])
            toCheckIndices.append(colNum)
            
    cleanedDataset= []
    for rowNum in range(1, len(dataset)):
        row= dataset[rowNum]
#         assume that the row satisfies the condition
        toInclude= True
        for col in toCheckIndices:
#             if it does not satisfy make the 'toInclude' flag as False
            if isSatisfied(row[col], operator, value) ==False:
                toInclude= False
                break
#         only if 'toInclude' is still true, add the row to final dataset
        if(toInclude):
            cleanedDataset.append(row)
    
    return cleanedDataset

In [13]:
def union(dataset1, dataset2):
    tuple1=map(tuple,dataset1)
    tuple2=map(tuple,dataset2)
    return list(set(tuple1).union(tuple2))

In [14]:
def intersection(dataset1, dataset2):
    return [value for value in lst1 if value in lst2] 

In [15]:
def meetCondition(dataset, condition):
    if len(condition)==1:
        dataset= removeRecords(dataset, condition[0])
        print (dataset)
    elif len(condition)==3:
        dataset1= removeRecords(dataset, condition[0])
        dataset2= removeRecords(dataset, condition[2])
        print (dataset1, "\n\n\n\n\n\n\n")
        print (dataset2, "\n\n\n\n\n\n\n")
        if(condition[1]=='or'):
            unionDataset= union(dataset1, dataset2)
            print (unionDataset)
        elif(condition[1]=='and'):
            intersectionDataset= intersection(dataset1, dataset2)
            print(intersectionDataset)
        else:
            print("Invalid Condition!")
            exit()
            return
    else:
        print("Invalid Syntax!")
        exit()
        return

In [16]:
def processQuery(query):
    if(query[len(query)-1] != ';'): 
        print ("Invalid Syntax..SemiColon Missing!")
        exit()
        return
    
    query= query[:-1]
    parsedQuery = sqlparse.parse(query)[0].tokens
    queryType = sqlparse.sql.Statement(parsedQuery).get_type()
    identifierList = [str(item).lower() for item in sqlparse.sql.IdentifierList(parsedQuery).get_identifiers()]
    print (identifierList)
    
    selectCount= identifierList.count('select')
    fromCount= identifierList.count('from')
    whereCount= identifierList.count('where')
    distinctCount= identifierList.count('distinct')
    groupbyCount= identifierList.count('group by')
    orderbyCount= identifierList.count('order by')
    havingCount= identifierList.count('having')
    
#     if either select or from clause is absent....and any other has a count greater than 1
    if(selectCount!=1 or fromCount!=1 or whereCount>1 or distinctCount>1):
            print ("Invalid Query!")
            exit()
            return
        
#    fetch the table names to get the dataset to process queries on
    tableList= identifierList[identifierList.index('from')+1].split(',')
    tableList= [tableName.strip() for tableName in tableList]
    
#     a plain query of select and from clauses only
#     if(whereCount ==0 and groupbyCount ==0 and orderbyCount ==0 and havingCount ==0):
        
# #         if aggregate function query
#         if(isAggregate(identifierList[1])):
#             print ("Solving the aggregate function..")
#             function= identifierList[1].split('(')[0]
#             attributeList= identifierList[1].split('(')[1].split(')')[0].split(',')
#             attributeList= [attributeName.strip().upper() for attributeName in attributeList]
#             solveAggregate(tableList, attributeList, function)
            
# #        if distinct keyword is present
#         elif(distinctCount==1):
#             print ("Solving the distinct function..")
#             attributeList= identifierList[2].split(',')
#             attributeList= [attributeName.strip().upper() for attributeName in attributeList]
#             dataset= selectSome(tableList, attributeList)
#             dataset= selectDistinct(dataset)      
            
# #         if distinct keyword is not present
#         elif(distinctCount ==0):
#             print ("Solving the normal query..")
#             attributeList= identifierList[1].split(',')
#             attributeList= [attributeName.strip().upper() for attributeName in attributeList]
# #             fetch all attributes of table(s)
#             if(len(attributeList)==1 and attributeList[0]=='*'):
#                 dataset= selectAll(tableList)
# #             fetch some attribute(s) of table(s)
#             elif(len(attributeList)>=1 and attributeList[0]!='*'):
#                 dataset= selectSome(tableList, attributeList)
#             else:
#                 print("Invalid Query")
#                 exit()
#                 return
    if(True):
        print ("Solving the where query..")
        attributeList= identifierList[1].split(',')
        attributeList= [attributeName.strip().upper() for attributeName in attributeList]
#         fetch all attributes of table(s)
        dataset= selectAll(tableList)
#         extract the condition
        condition= identifierList[-1].split(' ')[1:]
        dataset= meetCondition(dataset, condition)

In [None]:
if __name__ == "__main__":
    getTableSchema()
    while(1):
        print("Your Query?")
        query= input()
        processQuery(query)
        print ()

Your Query?
select * from table1 where a>800 OR b>100;
['select', '*', 'from', 'table1', 'where a>800 or b>100']
Solving the where query..
table1.A
table1.B
[[922, 158, 5727], [858, 731, 3668], [922, 158, 1234], [922, 158, 5235]] 







[[922, 158, 5727], [640, 773, 5058], [-551, 811, 1534], [-952, 311, 1318], [-354, 646, 7063], [-497, 335, 4549], [411, 803, 10519], [-900, 718, 9020], [858, 731, 3668], [922, 158, 1234], [922, 158, 5235]] 







[(-354, 646, 7063), (-900, 718, 9020), (-551, 811, 1534), (-952, 311, 1318), (411, 803, 10519), (-497, 335, 4549), (922, 158, 1234), (922, 158, 5727), (640, 773, 5058), (858, 731, 3668), (922, 158, 5235)]

Your Query?
