In [3]:
import attr
from contextlib import contextmanager

In [4]:
class Transaction:
    """A layer of uncommitted writes stacked on top of a parent store.

    The parent is any object exposing ``get(key)``, ``contains(key)`` and
    ``_propogate_commit(adds, deletes)``: either another ``Transaction``
    (nested transaction) or the database itself.

    Pending state is kept in two structures that never share a key:
      * ``_adds``    - key -> value written in this layer
      * ``_deletes`` - keys visible through the parent that this layer hides
    """

    def __init__(self, parent):
        self._adds = dict()
        self._deletes = set()
        self._parent = parent

    def __str__(self):
        return f"self: adds: {self._adds}, deletes: {self._deletes}\n\tparent: {self._parent}"

    def add(self, key, val):
        """Set ``key`` to ``val`` in this layer, reviving it if it was deleted here."""
        self._deletes.discard(key)
        self._adds[key] = val

    def delete(self, key):
        """Remove ``key`` from this layer's view.

        Raises:
            KeyError: if the key was already deleted here, or is known neither
                to this layer nor to the parent.
        """
        if key in self._deletes:
            # key must have already been deleted
            raise KeyError(f"Key {key} not found!")
        in_parent = self._parent.contains(key)
        if key not in self._adds and not in_parent:
            # key not found in self or parent!
            raise KeyError(f"Key {key} not found!")
        self._adds.pop(key, None)
        if in_parent:
            # The parent would still expose the key, so it must be masked.
            self._deletes.add(key)

    def get(self, key):
        """Return the value for ``key``, consulting the parent when this layer has no entry.

        Raises:
            KeyError: if the key was deleted in this layer; lookups that fall
                through raise whatever the parent's ``get`` raises.
        """
        if key in self._deletes:
            raise KeyError(f"key {key} not found!")
        if key in self._adds:
            return self._adds[key]
        return self._parent.get(key)

    def contains(self, key):
        """Return True if ``key`` is visible through this layer."""
        if key in self._deletes:
            # A pending delete hides the key even though the parent still has it.
            return False
        return key in self._adds or self._parent.contains(key)

    def rollback(self):
        """Discard every pending add and delete of this layer."""
        self._reset_state()

    def commit(self):
        """Hand the pending changes to the parent and clear this layer."""
        added_keys = set(self._adds.keys())
        assert not (added_keys & self._deletes), \
            f"Keys of deletes({self._deletes}) shouldn't be in adds ({added_keys})!"
        self._parent._propogate_commit(self._adds, self._deletes)
        self._reset_state()

    def _reset_state(self):
        self._adds.clear()
        self._deletes.clear()

    def _propogate_commit(self, adds, deletes):
        """Merge a child transaction's committed changes into this layer.

        Args:
            adds: mapping of key -> value written by the child.
            deletes: set of keys removed by the child.
        """
        for key, val in adds.items():
            # A child's write overrides a delete recorded earlier in this layer.
            self._deletes.discard(key)
            self._adds[key] = val
        for key in deletes:
            # Drop any value staged here; additionally mask the key when the
            # parent would otherwise still expose it.
            self._adds.pop(key, None)
            if self._parent.contains(key):
                self._deletes.add(key)

In [5]:
class InMemoryDb:
    """Dictionary-backed key/value store that acts as the root of a transaction chain.

    The database exposes the same ``get`` / ``contains`` / ``_propogate_commit``
    methods that a ``Transaction`` calls on its parent, so it can be passed
    directly as the parent of a top-level transaction.
    """

    @contextmanager
    def begin_transaction(self, parent_transaction=None):
        """Yield a new transaction; commit on clean exit, roll back on error.

        Args:
            parent_transaction: enclosing transaction for a nested one, or
                ``None`` to open a top-level transaction on the database.

        Any exception raised inside the ``with`` body discards the pending
        changes and is re-raised. Changes rolled back explicitly by the
        caller leave nothing to commit, so the final commit is a no-op.
        """
        # Db itself is duck-typed to be the transaction so pass itself to a transaction at level 1
        parent = self if parent_transaction is None else parent_transaction
        t = Transaction(parent)
        try:
            yield t
        except BaseException:
            # A failed block must not persist its partial writes.
            t.rollback()
            raise
        else:
            # if rollback hasn't been called then we'll commit uncommitted stuff
            t.commit()

    def __init__(self, filename=None):
        self._filename = filename
        self._data = self._load_data(self._filename)

    def __str__(self):
        return f"data: {self._data}, filename: {self._filename}"

    def _load_data(self, filename):
        """Return the initial contents; loading from ``filename`` is not implemented yet."""
        if filename:
            pass  # TODO: load data from file
        return dict()

    def _snapshot_data(self, filename):
        """Persist the current contents to ``filename``; currently a placeholder."""
        if filename:
            pass  # TODO: write a snapshot of in-memory db to file

    def _propogate_commit(self, adds, deletes):
        """Apply a committed transaction's changes to the stored data.

        Args:
            adds: mapping of key -> value to insert or overwrite.
            deletes: set of existing keys to remove; must be disjoint from ``adds``.
        """
        assert not (deletes - set(self._data.keys())), "Every key in deletes must exist in data!"
        assert not (set(adds.keys()) & deletes), "Keys of deletes shouldn't be in adds!"
        # apply the add and deletes to the main data and create a disk snapshot to persist the changes.
        self._data.update(adds)
        for key in deletes:
            del self._data[key]
        self._snapshot_data(self._filename)

    def contains(self, key):
        """Return True if ``key`` is currently stored."""
        return key in self._data

    def get(self, key):
        """Return the stored value for ``key``; raises ``KeyError`` when absent."""
        return self._data[key]

In [6]:
if __name__ == "__main__":
    # Demo: stage writes in an outer transaction, open a nested one, roll the
    # nested one back, then let the outer one commit on leaving its block.
    db = InMemoryDb()
    with db.begin_transaction() as t1:
        for key, value in ((1, 100), (2, 200), (3, 200)):
            t1.add(key, value)
        t1.delete(2)
        print(f"key: 1, value: {t1.get(1)}")
        print(f"key: 3, valye: {t1.get(3)}")

        # Key 2 was added and then deleted, so looking it up must fail.
        try:
            t1.get(2)
        except KeyError as e:
            print(f"Caught KeyError: {e}, as expected")
        else:
            assert False, "Should have raised an exception!"

        print(t1)
        print(f"Before commit: {db}")

        with db.begin_transaction(t1) as t2:
            t2.add(2, 200)
            print(f"key: 2, value: {t2.get(2)}")
            # The nested write is not visible from the outer transaction.
            try:
                t1.get(2)
            except KeyError as e:
                print(f"Caught KeyError: {e}, as expected!")
            print(t2)
            t2.rollback()
        print(t1)

    print(f"after commit: {db}")

key: 1, value: 100
key: 3, valye: 200
Caught KeyError: 2, as expected
self: adds: {1: 100, 3: 200}, deletes: set()
	parent: data: {}, filename: None
Before commit: data: {}, filename: None
key: 2, value: 200
Caught KeyError: 2, as expected!
self: adds: {2: 200}, deletes: set()
	parent: self: adds: {1: 100, 3: 200}, deletes: set()
	parent: data: {}, filename: None
self: adds: {1: 100, 3: 200}, deletes: set()
	parent: data: {}, filename: None
after commit: data: {1: 100, 3: 200}, filename: None
