In [81]:
from contextlib import contextmanager

@contextmanager
def test():
    """Demo context manager that traces its own lifecycle on stdout.

    Prints 'connected' on entry and yields ``1`` as the ``as`` target.
    When the ``with`` body finishes without raising, 'olll' is printed;
    'disconnect' is printed in every case because it sits in ``finally``.
    """
    handle = 1  # value bound by ``with test() as ...``
    try:
        # Setup phase: runs when the ``with`` statement is entered.
        print('connected')
        yield handle
        # Only reached when the ``with`` body did not raise.
        print('olll')
    finally:
        # Teardown phase: runs on both normal and exceptional exit.
        print('disconnect')


In [82]:
# Enter the context manager; `conn` receives the value yielded by test()
# (the literal 1), which is printed between the setup and teardown messages.
with test() as conn:
    print(conn)

connected
1
olll
disconnect


In [89]:
class CustomContextManager:
    """Hand-rolled, simplified stand-in for ``contextlib.contextmanager``.

    Wraps a generator function that yields once: the code before the
    ``yield`` is the setup, the yielded value becomes the ``as`` target, and
    the code after the ``yield`` is the teardown.

    Used as a decorator, the decorated name becomes an instance of this
    class. Calling that instance returns a *fresh* manager, so nested or
    repeated ``with`` statements never share generator state, and any call
    arguments are forwarded to the wrapped generator function.

    Args:
        func: Generator function to wrap.
        args: Positional arguments passed to ``func`` on entry.
        kwargs: Keyword arguments passed to ``func`` on entry.
    """

    def __init__(self, func, args=(), kwargs=None):
        self.func = func
        self.args = tuple(args)
        self.kwargs = dict(kwargs or {})
        self.gen = None

    def __enter__(self):
        """Start the generator and return the value it yields."""
        print('enter')
        self.gen = self.func(*self.args, **self.kwargs)
        return next(self.gen)

    def __exit__(self, *args, **kwargs):
        """Finish the generator, forwarding any exception from the body.

        Without an exception the generator is simply resumed so its
        post-yield code runs. With an exception, it is thrown into the
        generator at the ``yield`` so ``except``/``finally`` blocks there see
        it and the success-path code after the ``yield`` is skipped.

        Returns:
            True if the generator handled (swallowed) the exception, which
            tells the ``with`` statement to suppress it; False otherwise.
        """
        print('exit')
        # The with-protocol passes (exc_type, exc_value, traceback); pad so a
        # manual ``__exit__()`` call with no arguments still works.
        exc_type, exc_value = (args + (None, None))[:2]
        try:
            if exc_type is None:
                try:
                    next(self.gen)
                except StopIteration:
                    pass
                return False

            if exc_value is None:
                exc_value = exc_type()
            try:
                self.gen.throw(exc_value)
            except StopIteration as stop:
                # The generator finished after catching the exception, so it
                # is suppressed -- unless the body itself raised this very
                # StopIteration instance.
                return stop is not exc_value
            except RuntimeError as err:
                if err is exc_value:
                    return False
                # PEP 479 turns a StopIteration escaping a generator into a
                # RuntimeError; unwrap it so the caller sees the original.
                if isinstance(exc_value, StopIteration) and err.__cause__ is exc_value:
                    return False
                raise
            except BaseException as err:
                # Same exception came back out: let the with statement
                # re-raise it. Anything else is a new error from the
                # generator and must propagate.
                if err is exc_value:
                    return False
                raise
            # The generator caught the exception and yielded again; treat it
            # as handled. The extra yield is discarded by close() below.
            return True
        finally:
            # Always run pending ``finally`` blocks inside the generator.
            self.gen.close()

    def __call__(self, *args, **kwargs):
        """Return a new, independent manager bound to the given arguments."""
        return type(self)(self.func, args, kwargs)

@CustomContextManager
def test():
    """Demo generator driven by ``CustomContextManager``.

    Prints 'connected', yields ``1`` to the ``with`` statement, prints 'olll'
    once it is resumed, and always prints 'disconnect' from ``finally``.
    """
    handle = 1  # value handed back from ``__enter__``
    try:
        # Setup phase, executed by the first ``next()`` on the generator.
        print('connected')
        yield handle
        # Runs when the generator is resumed after the ``with`` body.
        print('olll')
    finally:
        # Teardown phase: executed however the generator is finished.
        print('disconnect')

# Run a body inside the CustomContextManager-wrapped test(); the value bound
# to `conn` is not used here, the body only prints a marker line.
with test() as conn:
    print('....')

enter
connected
....
exit
olll
disconnect


In [107]:
import sqlite3
from contextlib import contextmanager
from dataclasses import dataclass

# Database path handed to sqlite3.connect() in create_sql_connection().
DB_NAME = 'problems_v9.db'
# Table name interpolated into the SELECT issued by get_problem().
PROBLEM_TABLE = 'problems'
# Not referenced by the visible cells; presumably the table that stores
# per-submission test cases -- TODO confirm against the code that uses it.
SUBMISSION_TEST_CASE_TABLE = 'submission_test_cases'

class DB:
    """Plain holder pairing a database connection with a cursor on it.

    Attributes:
        cursor: Cursor used to execute statements.
        conn: Connection the cursor belongs to.
    """

    def __init__(self, cursor: sqlite3.Cursor, conn: sqlite3.Connection) -> None:
        # Stored as-is; this class takes no part in closing either object.
        self.cursor, self.conn = cursor, conn

@contextmanager
def create_sql_connection(db_name=None):
    """Open a SQLite connection and yield it, with a cursor, wrapped in ``DB``.

    Args:
        db_name: Database path to open. When omitted, the module-level
            ``DB_NAME`` is used; it is looked up at call time, so rebinding
            ``DB_NAME`` in a later cell takes effect without redefining this
            function. Pass e.g. ``':memory:'`` for a throwaway database.

    Yields:
        DB: Holder exposing ``conn`` and a fresh ``cursor``. The connection's
        ``row_factory`` is ``sqlite3.Row``, so fetched rows support access by
        column name.

    The connection is closed when the ``with`` block exits, whether or not it
    raised. Nothing is committed here; callers that write must commit
    themselves before leaving the block.
    """
    # Connect before the try block: if connect() raises there is no
    # connection to clean up, so no None sentinel is needed.
    conn = sqlite3.connect(DB_NAME if db_name is None else db_name)
    try:
        conn.row_factory = sqlite3.Row
        yield DB(cursor=conn.cursor(), conn=conn)
    finally:
        conn.close()

def get_problem(problem_id: str):
    """Fetch the row of the problems table whose ``questionId`` matches.

    Args:
        problem_id: Value bound to the ``questionId = ?`` placeholder.

    Returns:
        The first matching row as produced by ``fetchone()``, or ``None``
        when no row matches.
    """
    with create_sql_connection() as _db:
        # The table name comes from a module constant; the id is passed as a
        # bound parameter rather than interpolated into the SQL text.
        query = f"SELECT * FROM {PROBLEM_TABLE} WHERE questionId = ?"
        params = (problem_id, )
        result = _db.cursor.execute(query, params)
        # Fetch while the connection is still open.
        row = result.fetchone()
    return row
        
        

In [108]:
# Look up the problem whose questionId is '605'; as the last expression in
# the cell, the row returned by get_problem() is what gets displayed.
get_problem('605')

<sqlite3.Row at 0x111d51990>