# Bricks API

In [None]:
from typing import Any, Optional, Tuple, Mapping
from torch import nn

# Type aliases naming the role of each value that flows through a brick.
# Apart from BrickState they are all ``Any``, so they document intent but give
# no static checking; the demo cell below exercises them with torch tensors.
BrickState = Mapping[str, Any]  # what a brick learns in fit(), e.g. {"mean": ...}
FitData = Any  # data a brick is fitted on (e.g. anchors)
ProcessedFitData = Any  # fit data after being passed through the brick
Data = Any  # data transformed by forward()
ProcessedData = Any  # output of forward() / input of reverse()


class Brick(nn.Module):
    """Base class for a fittable data transform ("brick").

    Subclasses implement ``fit`` to learn a state from ``fit_data``, plus
    ``forward`` (and optionally ``reverse``) to apply the transform (and its
    inverse). The learned state is exposed through the ``state`` property,
    which refuses access until the brick has been marked fitted.
    """

    def __init__(self):
        super().__init__()
        self._state: BrickState = {}
        self._fitted = False

    def set_fitted(self) -> None:
        """Mark the brick as fitted without changing its state.

        Composite bricks use this when the learned state lives in sub-bricks.
        """
        self._fitted = True

    def set_state(self, state: BrickState) -> None:
        """Store ``state`` (by reference, not copied) and mark the brick fitted."""
        self._state = state
        self.set_fitted()

    @property
    def state(self) -> BrickState:
        """Return the state learned by ``fit``.

        Raises:
            AssertionError: if the brick has not been fitted yet.
        """
        # Explicit raise instead of ``assert``: an assert is stripped under
        # ``python -O``, which would silently return the empty initial state.
        # AssertionError is kept so existing callers see the same exception type.
        if not self._fitted:
            raise AssertionError(
                "Brick is not fitted yet, call fit() before accessing state"
            )
        return self._state

    def fit(
        self,
        fit_data: FitData,
    ) -> ProcessedFitData:
        """Learn the brick's state from ``fit_data``; return it processed.

        Must be implemented by subclasses.
        """
        raise NotImplementedError()

    def forward(
        self,
        data: Data,
        fit_data: Optional[FitData] = None,
    ) -> Tuple[ProcessedData, Optional[ProcessedFitData]]:
        """Transform ``data``; if ``fit_data`` is given, fit on it first.

        Returns ``(processed_data, processed_fit_data_or_None)``.
        Must be implemented by subclasses.
        """
        raise NotImplementedError()

    def reverse(
        self,
        processed_data: ProcessedData,
        fit_data: Optional[FitData] = None,
    ) -> Tuple[Data, Optional[ProcessedFitData]]:
        """Invert ``forward`` on ``processed_data``; optionally fit first.

        Returns ``(data, processed_fit_data_or_None)``.
        Must be implemented by invertible subclasses.
        """
        raise NotImplementedError()


class Center(Brick):
    """Subtract the mean (along axis 0) learned from the fit data.

    State: ``{"mean": fit_data.mean(axis=0)}``.
    """

    def __init__(self):
        super().__init__()

    def fit(
        self,
        fit_data: FitData,
    ) -> ProcessedFitData:
        """Learn the mean of ``fit_data`` along axis 0; return ``fit_data`` centered."""
        mean = fit_data.mean(axis=0)
        self.set_state({"mean": mean})
        return self(fit_data, fit_data=None)[0]

    def forward(
        self, data: Data, fit_data: Optional[FitData] = None
    ) -> Tuple[ProcessedData, Optional[ProcessedFitData]]:
        """Return ``(data - mean, processed_fit_data)``.

        If ``fit_data`` is given, the brick is (re)fitted on it first and the
        centered ``fit_data`` is returned as the second element; otherwise the
        second element is ``None``.
        """
        if fit_data is not None:
            fit_data = self.fit(fit_data)
        return data - self.state["mean"], fit_data

    def reverse(
        self, processed_data: ProcessedData, fit_data: Optional[FitData] = None
    ) -> Tuple[Data, Optional[ProcessedFitData]]:
        """Return ``(processed_data + mean, processed_fit_data)``.

        If ``fit_data`` is given, the brick is (re)fitted on it first and the
        *centered* ``fit_data`` (as produced by ``fit``) is returned.
        """
        if fit_data is not None:
            fit_data = self.fit(fit_data)
        return processed_data + self.state["mean"], fit_data


class Scale(Brick):
    """Divide by the standard deviation (along axis 0) learned from the fit data.

    State: ``{"std": ...}``, the std of ``fit_data`` along axis 0 with exact
    zeros replaced by 1. Constant features are therefore left unscaled instead
    of turning into ``inf``/``nan``.
    """

    def __init__(self):
        super().__init__()

    def fit(
        self,
        fit_data: FitData,
    ) -> ProcessedFitData:
        """Learn the std of ``fit_data`` along axis 0; return ``fit_data`` scaled."""
        std = fit_data.std(axis=0)
        # A constant feature has std == 0, and dividing by it yields inf (or
        # nan for 0/0 after centering), which then spreads through any later
        # matmul. Replace exact zeros with 1. ``std + (std == 0)`` avoids
        # in-place mutation and works for torch tensors as well as numpy
        # arrays and numpy scalars.
        std = std + (std == 0)
        self.set_state({"std": std})
        return self(fit_data, fit_data=None)[0]

    def forward(
        self, data: Data, fit_data: Optional[FitData] = None
    ) -> Tuple[ProcessedData, Optional[ProcessedFitData]]:
        """Return ``(data / std, processed_fit_data)``.

        If ``fit_data`` is given, the brick is (re)fitted on it first and the
        scaled ``fit_data`` is returned as the second element; otherwise the
        second element is ``None``.
        """
        if fit_data is not None:
            fit_data = self.fit(fit_data)
        return data / self.state["std"], fit_data

    def reverse(
        self, processed_data: ProcessedData, fit_data: Optional[FitData] = None
    ) -> Tuple[Data, Optional[ProcessedFitData]]:
        """Return ``(processed_data * std, processed_fit_data)``.

        If ``fit_data`` is given, the brick is (re)fitted on it first and the
        *scaled* ``fit_data`` (as produced by ``fit``) is returned.
        """
        if fit_data is not None:
            fit_data = self.fit(fit_data)
        return processed_data * self.state["std"], fit_data


class CenterScale(Brick):
    """Standardize data: center it, then scale it (Center -> Scale).

    This brick never calls ``set_state``. The learned mean and std live in the
    ``center`` and ``scale`` sub-bricks.
    """

    def __init__(self):
        super().__init__()
        self.center = Center()
        self.scale = Scale()

    def fit(
        self,
        fit_data: FitData,
    ) -> ProcessedFitData:
        """Fit ``center`` then ``scale`` on ``fit_data``; return it standardized."""
        fit_data = self.center.fit(fit_data)
        fit_data = self.scale.fit(fit_data)
        # Mark the composite fitted; its state lives in the sub-bricks.
        self.set_fitted()
        return fit_data

    def forward(
        self, data: Data, fit_data: Optional[FitData] = None
    ) -> Tuple[ProcessedData, Optional[ProcessedFitData]]:
        """Return ``(standardized data, processed_fit_data)``.

        If ``fit_data`` is given, both sub-bricks are (re)fitted on it first
        and the standardized ``fit_data`` is returned as the second element.
        """
        if fit_data is not None:
            fit_data = self.fit(fit_data)
        processed_data, _ = self.center(data)
        processed_data, _ = self.scale(processed_data)
        return processed_data, fit_data

    def reverse(
        self, processed_data: ProcessedData, fit_data: Optional[FitData] = None
    ) -> Tuple[Data, Optional[ProcessedFitData]]:
        """Undo the standardization: un-scale, then un-center.

        The steps run in the opposite order to ``forward``. If ``fit_data`` is
        given, the sub-bricks are (re)fitted first and the *standardized*
        ``fit_data`` is returned as the second element.
        """
        if fit_data is not None:
            fit_data = self.fit(fit_data)
        processed_data, _ = self.scale.reverse(processed_data)
        processed_data, _ = self.center.reverse(processed_data)
        return processed_data, fit_data


class Relative(Brick):
    """Re-express data as dot products with a stored set of anchors.

    ``fit`` keeps ``fit_data`` itself (by reference) as the anchors, and
    ``forward`` returns ``data @ anchors.T``. For 2-D inputs this gives one
    column per anchor. This brick does not override ``reverse``.
    """

    def __init__(self):
        super().__init__()

    def fit(
        self,
        fit_data: FitData,
    ) -> ProcessedFitData:
        """Store ``fit_data`` as the anchors; return the anchors projected onto themselves."""
        self.set_state({"anchors": fit_data})
        projected_anchors, _ = self(fit_data, fit_data=None)
        return projected_anchors

    def forward(
        self, data: Data, fit_data: Optional[FitData] = None
    ) -> Tuple[ProcessedData, Optional[ProcessedFitData]]:
        """Project ``data`` onto the anchors, fitting on ``fit_data`` first if given."""
        processed_fit = self.fit(fit_data) if fit_data is not None else None
        anchors = self.state["anchors"]
        return data @ anchors.T, processed_fit


class RelativeProjector(Brick):
    """Standardize data, then project it onto standardized anchors.

    The pipeline is Center -> Scale -> Relative. ``fit`` fits each stage in
    turn on the output of the previous one, so ``relative`` ends up storing
    the standardized ``fit_data`` (e.g. anchors).
    """

    def __init__(self):
        super().__init__()
        self.center = Center()
        self.scale = Scale()
        self.relative = Relative()

    def fit(
        self,
        fit_data: FitData,
    ) -> ProcessedFitData:
        """Fit every stage in order; return the fully processed ``fit_data``."""
        processed = fit_data
        for stage in (self.center, self.scale, self.relative):
            processed = stage.fit(processed)
        self.set_fitted()
        return processed

    def forward(
        self, data: Data, fit_data: Optional[FitData] = None
    ) -> Tuple[ProcessedData, Optional[ProcessedFitData]]:
        """Run ``data`` through center -> scale -> relative, fitting first if ``fit_data`` is given."""
        processed_fit = None if fit_data is None else self.fit(fit_data)
        centered, _ = self.center(data)
        scaled, _ = self.scale(centered)
        projected, _ = self.relative(scaled)
        return projected, processed_fit


class SVDTranslation(Brick):
    """Work-in-progress translation brick.

    So far only the standardization stages (``center`` -> ``scale``) exist.
    The translation step itself (presumably an SVD-based alignment, judging by
    the name) is not implemented, so ``forward`` raises ``NotImplementedError``.
    """

    def __init__(self):
        super().__init__()
        self.center = Center()
        self.scale = Scale()

    def fit(
        self,
        fit_data: FitData,  # TODO: how to pass source and target data?
    ) -> ProcessedFitData:
        """Fit ``center`` then ``scale`` on ``fit_data``; return it standardized."""
        fit_data = self.center.fit(fit_data)
        fit_data = self.scale.fit(fit_data)
        # Mark the composite as fitted (its state lives in the sub-bricks),
        # matching the other composite bricks in this module.
        self.set_fitted()
        return fit_data

    def forward(
        self, data: Data, fit_data: Optional[FitData] = None
    ) -> Tuple[ProcessedData, Optional[ProcessedFitData]]:
        """Not implemented yet.

        Raises:
            NotImplementedError: always; the translation step does not exist.
        """
        # This brick has no ``relative`` stage and no SVD step yet. Fail
        # explicitly rather than with an AttributeError on a missing submodule.
        raise NotImplementedError(
            "SVDTranslation.forward is not implemented yet: only the center/scale "
            "stages exist; the translation step is still missing"
        )

In [None]:
# Relative
# Smoke test: fit a RelativeProjector on random anchors, then project new
# points. No seed is set, so values change between runs; only shapes are printed.
import torch

a = torch.randn(20, 2)  # 20 points to project, 2 features each
anchors = torch.randn(100, 2)  # 100 anchors in the same 2-feature space

relative_proj = RelativeProjector()

# Fits center -> scale -> relative on the anchors and returns the anchors
# projected onto their own standardized version.
proc_anchors = relative_proj.fit(anchors)
print(proc_anchors.shape)  # expected torch.Size([100, 100]): anchor x anchor

proc_a, _ = relative_proj(a)
print(proc_a.shape)  # expected torch.Size([20, 100]): one column per anchor

# Computational Graph API

#### Idea
- Use bricks as modular components that can be composed to build arbitrarily complex bridges


#### Problems:
- How do we handle different inputs? E.g. SVDTranslator needs both the source and the target data, and we want types and autocomplete for them (no kwargs everywhere)
- How do we type everything nicely? E.g. we want to be able to state that the input of Relative's fit method is a set of anchors
- How do we automate the process of creating the computational graph? We want to compose the bricks easily, without specifying everything manually

#### Goal
- Implement RelativeProjector and SVDTranslator more intuitively with a (possibly automated) computation graph