# Convert a SQL `CREATE TABLE` statement into a SQLAlchemy-style Python model class

In [324]:
import re
import textwrap

In [325]:
# A quoted ("...") or bare SQL identifier; bare names may be schema-qualified.
_SQL_IDENT = r'("[^"]+"|[^\s"(),]+)'

# Table name on the header line, e.g. `CREATE TABLE IF NOT EXISTS "users" (`.
_TABLE_NAME_RE = re.compile(
    rf'\bTABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?{_SQL_IDENT}',
    re.IGNORECASE,
)

# Single-column constraint: FOREIGN KEY ("col") REFERENCES "table"("col").
_FOREIGN_KEY_RE = re.compile(
    rf'FOREIGN\s+KEY\s*\(\s*{_SQL_IDENT}\s*\)\s*'
    rf'REFERENCES\s+{_SQL_IDENT}\s*\(\s*{_SQL_IDENT}\s*\)',
    re.IGNORECASE,
)

# Distinguishes constraint lines from column definition lines.
_FOREIGN_KEY_PREFIX_RE = re.compile(r'FOREIGN\s+KEY\b', re.IGNORECASE)


def _has_phrase(tokens: list[str], *phrase: str) -> bool:
    """Return True if ``phrase`` occurs as consecutive items of ``tokens``."""
    width = len(phrase)
    return any(
        tuple(tokens[i:i + width]) == phrase
        for i in range(len(tokens) - width + 1)
    )


def _sqlalchemy_type(tokens: list[str]) -> 'str | None':
    """Map upper-cased column tokens to a SQLAlchemy type name, or None.

    TEXT/VARCHAR -> String, INTEGER/SMALLINT/SERIAL -> Integer,
    BOOLEAN -> Boolean, TIMESTAMP -> DateTime, DATE -> Date,
    NUMERIC(...) -> Float.
    """
    if 'TEXT' in tokens or any(t.startswith('VARCHAR') for t in tokens):
        return 'String'
    if {'INTEGER', 'SMALLINT', 'SERIAL'} & set(tokens):
        return 'Integer'
    if 'BOOLEAN' in tokens:
        return 'Boolean'
    if 'TIMESTAMP' in tokens:
        return 'DateTime'
    if 'DATE' in tokens:
        return 'Date'
    if any(t.startswith('NUMERIC') for t in tokens):
        return 'Float'
    return None


def _parse_foreign_keys(lines: list[str]) -> dict[str, str]:
    """Map each constrained local column to its ``"table.column"`` target.

    Raises:
        ValueError: If a constraint is not a single-column FOREIGN KEY.
    """
    foreign_keys: dict[str, str] = {}
    for line in lines:
        match = _FOREIGN_KEY_RE.match(line)
        if match is None:
            # Fail loudly instead of silently dropping (e.g. composite) keys.
            raise ValueError(f'Unsupported FOREIGN KEY constraint: {line!r}')
        column, table, ref_column = (part.strip('"') for part in match.groups())
        foreign_keys[column] = f'{table}.{ref_column}'
    return foreign_keys


def _column_declaration(line: str, foreign_keys: dict[str, str]) -> str:
    """Render one SQL column definition line as a ``Column(...)`` assignment.

    Raises:
        ValueError: If the column's SQL type is not recognised.
    """
    name, *rest = line.split()
    name = name.strip('"')
    tokens = [token.upper() for token in rest]  # SQL keywords are case-insensitive
    dtype = _sqlalchemy_type(tokens)
    if dtype is None:
        raise ValueError(f'Unknown dtype for column {name!r}: {" ".join(rest)!r}')
    foreign_key = (
        f", ForeignKey('{foreign_keys[name]}')" if name in foreign_keys else ''
    )
    if _has_phrase(tokens, 'PRIMARY', 'KEY'):
        # Primary keys are implicitly NOT NULL, so nullable is fixed to False.
        return (
            f'{name} = Column({dtype}{foreign_key}, '
            f'primary_key=True, nullable=False)'
        )
    nullable = not _has_phrase(tokens, 'NOT', 'NULL')
    unique = 'UNIQUE' in tokens
    return (
        f'{name} = Column({dtype}{foreign_key}, '
        f'nullable={nullable}, unique={unique})'
    )


def convert(sql_script: str) -> str:
    """Translate a ``CREATE TABLE`` statement into a SQLAlchemy-style model class.

    The parser is line-based:

    - the table name must appear on the first line;
    - each column definition or ``FOREIGN KEY (...) REFERENCES t(...)``
      constraint must sit on its own line;
    - the closing ``);`` must be on the last line.

    Keywords and type names are matched case-insensitively.

    Args:
        sql_script: The ``CREATE TABLE`` statement; surrounding blank lines and
            common indentation are ignored.

    Returns:
        Python source for a class deriving from ``Base``. The class is named
        via ``to_camel_case``, which is defined outside this cell. It has a
        blank line after ``__tablename__`` and one ``Column(...)`` attribute
        per SQL column.

    Raises:
        ValueError: If the table name cannot be found, a FOREIGN KEY
            constraint is not single-column, or a column type is unsupported.
    """
    lines = textwrap.dedent(sql_script).strip('\n').splitlines()
    header = _TABLE_NAME_RE.search(lines[0]) if lines else None
    if header is None:
        raise ValueError('Expected "CREATE TABLE <name>" on the first line')
    table_name = header.group(1).strip('"')

    # Everything between the header and the closing `);` line. Blank lines are
    # dropped so they cannot break the per-line parsing below.
    body = [line.strip(' ,\t') for line in lines[1:-1]]
    body = [line for line in body if line]
    constraint_lines = [line for line in body if _FOREIGN_KEY_PREFIX_RE.match(line)]
    column_lines = [line for line in body if not _FOREIGN_KEY_PREFIX_RE.match(line)]

    # Constraints may follow the columns, so collect them first.
    foreign_keys = _parse_foreign_keys(constraint_lines)
    columns = [_column_declaration(line, foreign_keys) for line in column_lines]

    declaration = [
        f'class {to_camel_case(table_name)}(Base):',
        f'    __tablename__ = "{table_name}"',
        '',  # blank line between __tablename__ and the columns
    ]
    declaration.extend('    ' + column for column in columns)
    return '\n'.join(declaration)

In [326]:
# Sample input: a table with a SERIAL primary key and two foreign-key columns.
src = '''
CREATE TABLE IF NOT EXISTS "attempt_test_option" (
   "id" SERIAL PRIMARY KEY,
   "attempt_id" INTEGER NOT NULL,
   "option_id" INTEGER NOT NULL,
   FOREIGN KEY ("attempt_id") REFERENCES "attempt"("id"),
   FOREIGN KEY ("option_id") REFERENCES "test_option"("id")
);
'''

# Render the generated model source into the cell output.
print(convert(sql_script=src))


# Model for `src` as emitted by convert(); ForeignKey is positional, so it must
# precede the keyword arguments (the reverse order is a SyntaxError).
class AttemptTestOption(Base):
    __tablename__ = "attempt_test_option"

    id = Column(Integer, primary_key=True, nullable=False)
    attempt_id = Column(Integer, ForeignKey('attempt.id'), nullable=False, unique=False)
    option_id = Column(Integer, ForeignKey('test_option.id'), nullable=False, unique=False)
