In [1]:
import itertools
import math
from typing import Union

from multipledispatch import dispatch

In [2]:
# Coordinate of a block corner: a list or a tuple of ints, one entry per axis.
Point = Union[
    list[int],
    tuple[int, ...],
]

In [15]:
class Block:
    """Axis-aligned, inclusive integer box in ``dim`` dimensions.

    The block covers every integer cell ``c`` with
    ``start[i] <= c[i] <= end[i]`` on each axis ``i``. It can be built from
    two corner sequences, ``Block(start, end)``, or from a spec string such
    as ``Block('x=10..12,y=10..12,z=10..12')``.
    """

    def __init__(self, start: "Union[Point, str]", end: "Union[Point, None]" = None) -> None:
        """Create a block from two corners or from a spec string.

        Args:
            start: Either the lower corner (any sequence of ints), or a spec
                string ``'x=a..b,y=c..d,...'``; in the latter case ``end``
                must be omitted.
            end: The upper corner (inclusive) when ``start`` is a corner.

        Raises:
            TypeError: if the arguments match neither supported form.
            ValueError: if the corners differ in length or the spec string
                is malformed.
        """
        if isinstance(start, str):
            if end is not None:
                raise TypeError('Block(spec) does not take an end corner')
            start, end = self._parse_spec(start)
        elif end is None:
            raise TypeError('Block(start, end) requires an end corner')
        self.start = tuple(start)
        self.end = tuple(end)
        if len(self.start) != len(self.end):
            raise ValueError(
                f'corner lengths differ: {len(self.start)} != {len(self.end)}'
            )
        self.dim = len(self.start)

    @staticmethod
    def _parse_spec(coords: str) -> tuple[list[int], list[int]]:
        """Split ``'x=a..b,y=c..d,...'`` into lower and upper corner lists."""
        start: list[int] = []
        end: list[int] = []
        for axis in coords.split(','):
            # The label before '=' is not checked; axis order is positional.
            _, _, bounds = axis.partition('=')
            lo, hi = bounds.split('..')
            start.append(int(lo))
            end.append(int(hi))
        return start, end

    def __repr__(self) -> str:
        return f'{self.start} => {self.end}'

    def volume(self) -> int:
        """Return the number of integer cells covered (bounds are inclusive)."""
        return math.prod(hi - lo + 1 for lo, hi in zip(self.start, self.end))

    def check_overlap(self, other: "Block") -> bool:
        """Return True if the two blocks share at least one cell.

        Bounds are inclusive, so blocks meeting on a shared coordinate overlap.
        """
        return all(
            s_lo <= o_hi and o_lo <= s_hi
            for s_lo, s_hi, o_lo, o_hi in zip(
                self.start, self.end, other.start, other.end
            )
        )

    def slice(self, other: "Block") -> list["Block"]:
        """Return the parts of this block that lie outside ``other``.

        The result is a list of disjoint, non-empty blocks whose union is
        ``self`` minus ``other``. Returns ``[self]`` when the blocks do not
        overlap and ``[]`` when ``other`` fully covers ``self``. Assumes
        ``start <= end`` on every axis of ``self``.
        """
        if not self.check_overlap(other):
            return [self]

        # Per axis, cut self's interval into at most three non-empty pieces.
        # Index 0 is always the overlap with other; "below" / "above" pieces
        # are added only when they contain at least one cell, so no
        # zero-volume children are ever produced.
        axis_pieces: list[list[tuple[int, int]]] = []
        for lo, hi, o_lo, o_hi in zip(self.start, self.end, other.start, other.end):
            pieces = [(max(lo, o_lo), min(hi, o_hi))]
            if lo < o_lo:
                pieces.append((lo, o_lo - 1))
            if o_hi < hi:
                pieces.append((o_hi + 1, hi))
            axis_pieces.append(pieces)

        children = []
        # Each choice of one piece per axis is a disjoint sub-block of self.
        # The all-zero choice (overlap on every axis) is the intersection with
        # other, which is exactly the part being removed.
        for choice in itertools.product(*(range(len(p)) for p in axis_pieces)):
            if not any(choice):
                continue
            ranges = [pieces[k] for pieces, k in zip(axis_pieces, choice)]
            children.append(
                Block([lo for lo, _ in ranges], [hi for _, hi in ranges])
            )
        return children



In [16]:
# Read the puzzle input. Lines are expected to look like
# "<on|off> x=a..b,y=c..d,z=e..f" (as consumed by the cells below).
# An explicit encoding avoids depending on the platform's locale default.
with open('input', 'rt', encoding='utf-8') as f:
    data = f.read()
# split('\n') keeps a trailing '' element when the file ends with a newline.
lines = data.split('\n')

In [30]:
# Seed the block list from the first instruction. The reactor starts fully
# off, so an initial "off" step contributes no block at all.
first_state, first_spec = lines[0].split(' ')[:2]
blocklist = [Block(first_spec)] if first_state == 'on' else []

In [31]:
# Apply the remaining instructions in order. Each existing block is sliced
# against the new block, keeping only its parts outside the new block; the
# new block is then appended if the step turns cells on.
for line in lines[1:]:
    # Skip blank lines, e.g. the trailing '' left by a final newline. This
    # also keeps the last instruction when the file has no trailing newline.
    if not line.strip():
        continue
    blockdata = line.split(' ')
    newb = Block(blockdata[1])
    newlist = []
    for b in blocklist:
        newlist += b.slice(newb)
    if blockdata[0] == 'on':
        newlist.append(newb)
    blocklist = newlist

In [32]:
# Sum the cell counts of every block remaining in the final block list.
vol = sum(block.volume() for block in blocklist)

In [33]:
# Notebook cell output: display the total volume summed in the previous cell.
vol

1235484513229032