In [1]:
import csv
import re
from itertools import product, chain

import sqlparse

# Maps each table name to its list of column names; populated by getTableSchema().
tableSchema= {}

In [2]:
# Loading the table schema.
def getTableSchema():
    """Load the table schema from ./files/metadata.txt into ``tableSchema``.

    The metadata file is read as a sequence of blocks of the form::

        <begin_table>
        table_name
        column_1
        column_2
        <end_table>

    For every block, the first non-blank line becomes a key of the
    module-level ``tableSchema`` dict and the remaining non-blank lines
    become its list of column names. Blocks without any content are skipped.

    Raises:
        OSError: if the metadata file cannot be opened or read.
    """
    # The context manager guarantees the file handle is closed, even on error.
    with open("./files/metadata.txt", "r") as metadataFile:
        metaData = metadataFile.read()

    # Text before the first <begin_table> marker is not part of any table.
    for block in metaData.split('<begin_table>')[1:]:
        # Keep only the text up to the matching <end_table> marker.
        body = block.split('<end_table>')[0]
        # Strip every line and drop blanks so stray whitespace or a missing
        # final newline cannot leak into table/column names.
        tableInfo = [line.strip() for line in body.splitlines() if line.strip()]
        if tableInfo:
            tableSchema[tableInfo[0]] = tableInfo[1:]

In [3]:
def readTable(tableName):
    """Read ./files/<tableName>.csv and return it as a list of rows.

    Args:
        tableName: name of a table that must be a key of ``tableSchema``.

    Returns:
        A list whose first element is the header row
        (``"<tableName>.<column>"`` for each column in the schema) followed
        by one list of ints per non-blank CSV line.

    Raises:
        SystemExit: via ``exit()`` when the table name is not in the schema.
        OSError: if the CSV file cannot be opened.
        ValueError: if a cell cannot be converted to ``int``.

    Side effects:
        Prints every row of the table (header included).
    """
    if tableName not in tableSchema:
        print("Invalid Table Name, {}!".format(tableName))
        exit()

    header = [tableName + '.' + colName for colName in tableSchema[tableName]]

    # newline='' is what the csv module expects; the context manager closes
    # the file once all rows have been read.
    with open('./files/' + tableName + '.csv', 'r', newline='') as tableFile:
        # csv.reader yields [] for blank lines; skip them so that a trailing
        # empty line does not become an empty data row.
        rows = [list(map(int, row)) for row in csv.reader(tableFile) if row]

    tableData = [header] + rows
    for row in tableData:
        print (row)
    return tableData

In [4]:
def prepData(tableList):
    """Build the working dataset for a query: the cross product of all tables.

    Args:
        tableList: list of table names, in the order they appear in the query.

    Returns:
        A list of rows. Row 0 is the concatenated header of all tables; the
        remaining rows are every combination of one data row per table, with
        earlier tables varying slowest. For a single table this is exactly
        what ``readTable`` returned.

    Calls ``exit()`` after printing a message when ``tableList`` is empty.
    """
    if len(tableList) == 0:
        print ("Invalid Query..Please provide Table Name(s).")
        exit()

    joined = readTable(tableList[0])
    for nextName in tableList[1:]:
        incoming = readTable(nextName)
        # Grow the existing header list in place with the new table's columns.
        header = joined[0]
        header.extend(incoming[0])
        # Pair every accumulated row with every row of the new table.
        crossed = [left + right for left, right in product(joined[1:], incoming[1:])]
        joined = [header] + crossed
    return joined

In [5]:
def selectAll(tableList):
    """Handle ``select *``: print and return the full dataset for the tables.

    Args:
        tableList: list of table names to combine via ``prepData``.

    Returns:
        The dataset produced by ``prepData`` (header row first).
    """
    combined = prepData(tableList)
    for record in combined:
        print(record)
    return combined

In [6]:
def selectSome(tableList, attributeList):
    """Project the dataset for ``tableList`` onto the requested columns.

    Args:
        tableList: list of table names to combine via ``prepData``.
        attributeList: upper-cased, unqualified column names to keep.

    Returns:
        A new list of rows (header row first) containing only the requested
        columns, in the order they were requested. If a name occurs in
        several of the joined tables, every matching column is kept.
        Returns ``None`` if ``exit()`` returns after a validation failure.

    Side effects:
        Prints every row of the projected result. Prints a message and calls
        ``exit()`` when ``attributeList`` is empty or names an unknown column.
    """
    if len(attributeList) == 0:
        print ("Invalid Query..Please provide Attribute Name(s).")
        exit()
        return

    dataset = prepData(tableList)
    # Header cells look like "<table>.<column>"; compare on the column part only.
    actualCols = [qualified.split('.')[1].upper() for qualified in dataset[0]]

    # Collect the column indexes in *request* order so that
    # "select B, A" yields B before A, as SQL requires.
    indexNums = []
    for attribute in attributeList:
        matches = [index for index, name in enumerate(actualCols) if name == attribute]
        if not matches:
            print("Invalid Attribute Name, {}!".format(attribute))
            exit()
            return
        indexNums.extend(matches)

    # Build fresh rows instead of popping from the originals; the length guard
    # tolerates rows that are shorter than the header.
    projected = [[row[index] for index in indexNums if index < len(row)] for row in dataset]
    for row in projected:
        print (row)
    return projected

In [7]:
def isAggregate(subQuery):
    """Return True if ``subQuery`` contains a call to an aggregate function.

    A match requires one of ``max``, ``min``, ``avg``, ``sum`` or ``count``
    as a whole word immediately followed (optional whitespace allowed) by an
    opening parenthesis. Plain column names that merely contain one of these
    words, such as ``account`` or ``summary``, are therefore not matched.

    Args:
        subQuery: the (lower-cased) select-list text of the query.
    """
    return re.search(r'\b(?:max|min|avg|sum|count)\s*\(', subQuery) is not None

In [8]:
def solveAggregate(tableList, attributeList, function):
    """Evaluate an aggregate function over one column of one table.

    Args:
        tableList: list with exactly one table name.
        attributeList: list with exactly one upper-cased column name.
        function: one of ``'max'``, ``'min'``, ``'sum'``, ``'avg'``,
            ``'count'`` or ``'distinct'``.

    Returns:
        The computed value: a number for max/min/sum/avg/count, or the list
        of distinct values in first-seen order for ``'distinct'``. ``max``,
        ``min`` and ``avg`` over an empty column give ``None``. Returns
        ``None`` if ``exit()`` returns after a validation failure.

    Side effects:
        Prints the flattened column values and then the result (one value
        per line for ``'distinct'``). Prints a message and calls ``exit()``
        on invalid arguments.
    """
    if len(tableList) > 1:
        print("Invalid Query..Aggregate Function cannot be applied to multiple tables/attributes.")
        exit()
        return

    if len(attributeList) != 1:
        print("Invalid number of parameters for Aggregate Function. It must be Function(Column Name).")
        exit()
        return

    # Reject unknown functions before any table is read and printed.
    if function not in ('max', 'min', 'sum', 'avg', 'count', 'distinct'):
        print("Invalid Aggregate function!")
        exit()
        return

    # Drop the header row, then flatten the single-column rows into one list.
    values = selectSome(tableList, attributeList)[1:]
    values = list(chain.from_iterable(values))
    print (values)

    if function == 'distinct':
        # dict.fromkeys de-duplicates while keeping first-seen order.
        distinctValues = list(dict.fromkeys(values))
        for value in distinctValues:
            print(value)
        return distinctValues

    # Local names deliberately avoid shadowing the built-in sum().
    if function == 'max':
        result = max(values) if values else None
    elif function == 'min':
        result = min(values) if values else None
    elif function == 'sum':
        result = sum(values)
    elif function == 'avg':
        result = sum(values) / len(values) if values else None
    else:  # 'count'
        result = len(values)

    print(result)
    return result
    

In [9]:
def processQuery(query):
    """Validate one SQL query string and dispatch it to the matching handler.

    Only plain ``select ... from ...`` queries are executed: ``select *``
    goes to ``selectAll``, a column list to ``selectSome`` and an aggregate
    call to ``solveAggregate``. DISTINCT queries and queries with
    GROUP BY / ORDER BY / HAVING keywords are reported as unsupported.

    Args:
        query: the raw query text; must end with ``;`` (surrounding
            whitespace is ignored).

    Side effects:
        Prints the lower-cased token list and whatever the handlers print.
        Prints a message and calls ``exit()`` for invalid queries; the
        ``return`` after each ``exit()`` stops processing if ``exit()``
        itself returns.
    """
    # Strip first so trailing whitespace is tolerated, and use endswith() so
    # an empty query is rejected instead of raising IndexError.
    query = query.strip()
    if not query.endswith(';'):
        print ("Invalid Syntax..SemiColon Missing!")
        exit()
        return

    parsedQuery = sqlparse.parse(query)[0].tokens
    # Lower-case every token; the final entry is dropped (presumably the
    # terminating ';' -- verify against sqlparse's tokenisation).
    identifierList = [str(item).lower() for item in sqlparse.sql.IdentifierList(parsedQuery).get_identifiers()][:-1]

    print (identifierList)
    selectCount = identifierList.count('select')
    fromCount = identifierList.count('from')
    whereCount = identifierList.count('where')
    distinctCount = identifierList.count('distinct')
    groupbyCount = identifierList.count('group by')
    orderbyCount = identifierList.count('order by')
    havingCount = identifierList.count('having')

    # Exactly one select and one from are required; where/distinct may
    # appear at most once.
    if selectCount != 1 or fromCount != 1 or whereCount > 1 or distinctCount > 1:
        print ("Invalid Query!")
        exit()
        return

    # A plain query of select and from clauses only.
    if whereCount == 0 and groupbyCount == 0 and orderbyCount == 0 and havingCount == 0:
        # The last token holds the comma-separated table names.
        tableList = [tableName.strip() for tableName in identifierList[-1].split(',')]

        # Aggregate function query, e.g. "max(a)".
        if isAggregate(identifierList[1]):
            function = identifierList[1].split('(')[0]
            attributeList = identifierList[1].split('(')[1].split(')')[0].split(',')
            attributeList = [attributeName.strip().upper() for attributeName in attributeList]
            solveAggregate(tableList, attributeList, function)
        # DISTINCT keyword is not present.
        elif distinctCount == 0:
            attributeList = identifierList[1].split(',')
            attributeList = [attributeName.strip().upper() for attributeName in attributeList]
            # Fetch all attributes of the table(s).
            if len(attributeList) == 1 and attributeList[0] == '*':
                dataset = selectAll(tableList)
            # Fetch some attribute(s) of the table(s).
            elif len(attributeList) >= 1 and attributeList[0] != '*':
                dataset = selectSome(tableList, attributeList)
            else:
                print("Invalid Query")
                exit()
                return
        else:
            # TODO: DISTINCT is recognised but not implemented; tell the user
            # instead of silently producing no output.
            print("DISTINCT queries are not supported yet.")
    else:
        # TODO: WHERE / GROUP BY / ORDER BY / HAVING handling is not implemented.
        print("Queries with WHERE, GROUP BY, ORDER BY or HAVING clauses are not supported yet.")


In [None]:
# Interactive prompt: load the schema once, then read and run queries until
# the input ends or the user interrupts.
if __name__ == "__main__":
    getTableSchema()
    while True:
        print("Your Query?")
        try:
            query = input()
        except (EOFError, KeyboardInterrupt):
            # End of input or an interrupt (Ctrl-C / kernel interrupt): leave
            # the loop cleanly instead of surfacing a traceback.
            print ()
            break
        processQuery(query)
        print ()

Your Query?
select max(A) from table1;
['select', 'max(a)', 'from', 'table1']
['table1.A', 'table1.B', 'table1.C']
[922, 158, 5727]
[640, 773, 5058]
[775, 85, 10164]
[-551, 811, 1534]
[-952, 311, 1318]
[-354, 646, 7063]
[-497, 335, 4549]
[411, 803, 10519]
[-900, 718, 9020]
[858, 731, 3668]
['table1.A']
[922]
[640]
[775]
[-551]
[-952]
[-354]
[-497]
[411]
[-900]
[858]
[922, 640, 775, -551, -952, -354, -497, 411, -900, 858]
922

Your Query?
