# Abstract Factory

## What is it?

This pattern provides a single interface for creating objects that belong to a family or that depend on one another, without specifying their concrete classes.

## Why?

Different classes often have to be used in different contexts, or together with different objects, even when they perform the same role. The _abstract factory_ pattern keeps those specific classes from being _hardcoded_, which makes the code more flexible.

## Structure

![structure](assets/estrutura.png)
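
The diagram is only available as an image, so here is a minimal, generic Python sketch of the participants usually found in this pattern: abstract products, concrete products grouped into families, an abstract factory, and one concrete factory per family. All names (`AbstractProductA`, `Factory1`, `client`, etc.) are hypothetical placeholders and are not taken from the diagram.

```python
import abc


# Abstract products: one interface per kind of object in the family
class AbstractProductA(abc.ABC):
    @abc.abstractmethod
    def name(self) -> str: ...


class AbstractProductB(abc.ABC):
    @abc.abstractmethod
    def combine(self, a: AbstractProductA) -> str: ...


# Concrete products of family 1
class ProductA1(AbstractProductA):
    def name(self) -> str:
        return "A1"


class ProductB1(AbstractProductB):
    def combine(self, a: AbstractProductA) -> str:
        return f"B1+{a.name()}"


# Concrete products of family 2
class ProductA2(AbstractProductA):
    def name(self) -> str:
        return "A2"


class ProductB2(AbstractProductB):
    def combine(self, a: AbstractProductA) -> str:
        return f"B2+{a.name()}"


# Abstract factory: one creation method per kind of product
class AbstractFactory(abc.ABC):
    @abc.abstractmethod
    def create_a(self) -> AbstractProductA: ...

    @abc.abstractmethod
    def create_b(self) -> AbstractProductB: ...


# Concrete factories: each one builds a single, consistent family
class Factory1(AbstractFactory):
    def create_a(self) -> AbstractProductA:
        return ProductA1()

    def create_b(self) -> AbstractProductB:
        return ProductB1()


class Factory2(AbstractFactory):
    def create_a(self) -> AbstractProductA:
        return ProductA2()

    def create_b(self) -> AbstractProductB:
        return ProductB2()


def client(factory: AbstractFactory) -> str:
    # The client only sees the abstract interfaces, never the concrete classes
    a = factory.create_a()
    b = factory.create_b()
    return b.combine(a)


print(client(Factory1()))  # B1+A1
print(client(Factory2()))  # B2+A2
```

`client` never names a concrete class. Switching families only requires passing a different factory.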

## Example 1

A SaaS company uses _machine learning_ to prevent fraud. Two kinds of models run in production: a simple linear model called _Logistic Regression_ and a model based on decision trees called _Random Forest_. Each model uses a different hyper-parameter optimization method, _Grid Search_ and _Random Search_ respectively.

Let's see how this would work in practice.

```python
import abc
import random
from typing import List  # referenced by the string annotations below


class BaseModel(abc.ABC):
    """
    Define the base interface for models.
    """

    @abc.abstractmethod
    def fit(self, X: "List[List[float]]", y: "List[int]"):
        """
        Train the model.
        """
        pass

    def predict(self, X):
        # Mock prediction: only reports which model was called
        print(f"{self.__class__.__name__} has predicted something...")


class LogisticRegression(BaseModel):
    def __init__(self, l1_reg: float = 0, l2_reg: float = 1):
        self._l1_reg = l1_reg
        self._l2_reg = l2_reg

    def fit(self, X: "List[List[float]]", y: "List[int]"):
        # Mock training: only reports the hyper-parameters in use
        print(f"Fitted using l1 of: {self._l1_reg} and l2 of {self._l2_reg}")

    def __repr__(self):
        return f"LogisticRegression(l1_reg={self._l1_reg}, l2_reg={self._l2_reg})"


class RandomForest(BaseModel):
    def __init__(self, max_depth: int = 50, n_estimators: int = 20):
        self._max_depth = max_depth
        self._n_estimators = n_estimators

    def fit(self, X: "List[List[float]]", y: "List[int]"):
        print(
            f"Fitted using max_depth of: {self._max_depth} and {self._n_estimators} estimators"
        )

    def __repr__(self):
        return f"RandomForest(max_depth={self._max_depth}, n_estimators={self._n_estimators})"


class BaseTuner(abc.ABC):
    """
    Define the interface for hyper-parameter tuning classes.
    """

    @abc.abstractmethod
    def tune(self) -> "BaseModel":
        """
        Tune the parameters of a model and return a new, tuned model.
        """
        pass


class GridSearch(BaseTuner):
    def __init__(self, model: "BaseModel", params: dict):
        # Only the model's class is kept; tune() builds a fresh instance
        self._model = model.__class__
        self._params = params

    def tune(self) -> "BaseModel":
        print("grid searching...")
        # Simulated search: draws a random value per parameter instead of
        # evaluating the grid
        best_params = {param: random.randint(10, 50) for param in self._params}
        return self._model(**best_params)


class RandomSearch(BaseTuner):
    def __init__(self, model: "BaseModel", params: dict):
        # This tuner only supports RandomForest models
        if not isinstance(model, RandomForest):
            raise ValueError(f"{model.__class__} is not supported.")
        self._model = model.__class__
        self._params = params

    def tune(self) -> "BaseModel":
        print("randomly searching...")
        # Simulated search, same as GridSearch
        best_params = {param: random.randint(10, 50) for param in self._params}
        return self._model(**best_params)
```
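
The `GridSearch` and `RandomSearch` tuners above only simulate a search by drawing random integers. For reference, here is a minimal sketch of what the two strategies do. It assumes a hypothetical `score` callable that evaluates one parameter combination, for instance a cross-validation score where higher is better. `grid_search`, `random_search` and `fake_score` are illustrative names and are not part of the example's code.

```python
import itertools
import random
from typing import Callable, Dict, Iterable, Optional


def grid_search(params: Dict[str, Iterable], score: Callable[[dict], float]) -> dict:
    """Evaluate every combination of parameter values and return the best one."""
    names = list(params)
    best, best_score = None, float("-inf")
    # Cartesian product of all value lists: one candidate per grid point
    for values in itertools.product(*(list(params[n]) for n in names)):
        candidate = dict(zip(names, values))
        s = score(candidate)
        if s > best_score:
            best, best_score = candidate, s
    return best


def random_search(
    params: Dict[str, Iterable],
    score: Callable[[dict], float],
    n_iter: int = 10,
    seed: Optional[int] = None,
) -> dict:
    """Evaluate n_iter randomly sampled combinations and return the best one."""
    rng = random.Random(seed)
    pools = {n: list(values) for n, values in params.items()}
    best, best_score = None, float("-inf")
    for _ in range(n_iter):
        # Sample one value per parameter independently
        candidate = {n: rng.choice(pool) for n, pool in pools.items()}
        s = score(candidate)
        if s > best_score:
            best, best_score = candidate, s
    return best


def fake_score(p: dict) -> float:
    # Hypothetical objective: pretends max_depth=30 is optimal
    return -abs(p["max_depth"] - 30)


print(grid_search({"max_depth": range(10, 50, 5)}, fake_score))  # {'max_depth': 30}
print(random_search({"max_depth": range(10, 50, 5)}, fake_score, n_iter=5, seed=0))
```

The cost of grid search grows with the product of the sizes of all the value lists. Random search evaluates a fixed budget of `n_iter` sampled combinations instead.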

In day-to-day use...

```python
X = [list(range(3)) for _ in range(10)]
y = [random.randrange(0, 2) for _ in range(10)]
```

```python
X
```

```
[[0, 1, 2],
 [0, 1, 2],
 [0, 1, 2],
 [0, 1, 2],
 [0, 1, 2],
 [0, 1, 2],
 [0, 1, 2],
 [0, 1, 2],
 [0, 1, 2],
 [0, 1, 2]]
```

```python
y
```

```
[1, 1, 1, 0, 1, 1, 0, 0, 1, 1]
```

```python
rf = RandomForest()

params = {"max_depth": range(10, 50, 5)}

tuned_rf = RandomSearch(rf, params).tune()
tuned_rf
```

```
randomly searching...
RandomForest(max_depth=14, n_estimators=20)
```

```python
lr = LogisticRegression()

params = {"l1_reg": range(10, 50, 5)}

tuned_lr = GridSearch(lr, params).tune()
tuned_lr
```

```
grid searching...
LogisticRegression(l1_reg=16, l2_reg=1)
```

Through a moment of inattention, however, a user introduces a _bug_ by trying to optimize the model with the wrong _tuner_.

```python
lr = LogisticRegression()

params = {"l1_reg": range(10, 50, 5)}

tuned_lr = RandomSearch(lr, params).tune()
tuned_lr
```

```
ValueError: <class '__main__.LogisticRegression'> is not supported.
```

The _Abstract Factory_ pattern lets us mitigate this problem: each factory hands out a model together with the tuner that is compatible with it.

```python
class ModelStackFactory(abc.ABC):
    """
    Create a model together with the tuner that is compatible with it.
    """

    @abc.abstractmethod
    def get_model(self, **params: "Any") -> "BaseModel":
        """
        Instantiate a model.
        """
        pass

    @abc.abstractmethod
    def get_tuner(self, model: "BaseModel", params: dict = None) -> "BaseTuner":
        """
        Return the tuner instance used to tune the model's parameters.
        """
        pass


class RandomForestFactory(ModelStackFactory):
    def get_model(self, **params: "Any") -> "BaseModel":
        return RandomForest(**params)

    def get_tuner(self, model: "BaseModel", params: dict = None) -> "BaseTuner":
        # Default search space when none is given
        params = params or {"max_depth": range(10, 100, 10)}
        return RandomSearch(model, params)


class LogisticRegressionFactory(ModelStackFactory):
    def get_model(self, **params: "Any") -> "BaseModel":
        return LogisticRegression(**params)

    def get_tuner(self, model: "BaseModel", params: dict = None) -> "BaseTuner":
        # Default search space when none is given
        params = params or {"l1_reg": range(10, 100, 10)}
        return GridSearch(model, params)
```

```python
factory = RandomForestFactory()

model = factory.get_model(max_depth=50)
tuner = factory.get_tuner(model)
tuned = tuner.tune()
tuned
```

```
randomly searching...
RandomForest(max_depth=50, n_estimators=20)
```

```python
factory = LogisticRegressionFactory()

model = factory.get_model(l2_reg=50)
tuner = factory.get_tuner(model)
tuned = tuner.tune()
tuned
```

```
grid searching...
LogisticRegression(l1_reg=31, l2_reg=1)
```
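
In the last output, the tuned model reports `l2_reg=1` even though the model was created with `l2_reg=50`. This happens because `GridSearch` stores only the model's class and rebuilds the model from the tuned parameters alone. The sketch below shows one possible adjustment: it keeps the model instance and merges its current settings with the tuned ones. The `get_params()` method is a hypothetical helper, and the search itself is still simulated.

```python
import random


class LogisticRegression:
    def __init__(self, l1_reg: float = 0, l2_reg: float = 1):
        self._l1_reg = l1_reg
        self._l2_reg = l2_reg

    def get_params(self) -> dict:
        # Hypothetical helper: exposes the current hyper-parameters
        return {"l1_reg": self._l1_reg, "l2_reg": self._l2_reg}

    def __repr__(self):
        return f"LogisticRegression(l1_reg={self._l1_reg}, l2_reg={self._l2_reg})"


class GridSearch:
    def __init__(self, model, params: dict):
        # Keep the instance (not just its class) so its settings survive tuning
        self._model = model
        self._params = params

    def tune(self):
        print("grid searching...")
        # Still simulated: picks a random candidate for each tuned parameter
        best_params = {p: random.choice(list(v)) for p, v in self._params.items()}
        # Start from the model's current settings, then override the tuned ones
        merged = {**self._model.get_params(), **best_params}
        return type(self._model)(**merged)


tuned = GridSearch(LogisticRegression(l2_reg=50), {"l1_reg": range(10, 100, 10)}).tune()
print(tuned)  # l2_reg stays 50; l1_reg is one of 10, 20, ..., 90
```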

## Example 2

Our company runs local relational and document databases but plans to migrate to a _Cloud Provider_. The _Abstract Factory_ pattern can make this migration simple.

```python
import abc
import os
from typing import Iterable


class SQLClient(abc.ABC):
    @abc.abstractmethod
    def query(self, query: str) -> Iterable[dict]:
        """
        Run a SQL query and return an iterable with the results.
        """
        pass


class CloudSQL(SQLClient):
    def __init__(self, project_id: str, auth_key: str):
        self._project_id = project_id
        self._auth_key = auth_key

    def _auth(self) -> bool:
        # Mock authentication: succeeds if both credentials are non-empty
        return bool(self._project_id and self._auth_key)

    def query(self, query: str) -> Iterable[dict]:
        if not self._auth():
            raise ValueError("Invalid credentials")

        print("Querying Cloud SQL.")
        return [{"id": i, "customer": f"cust_{i}"} for i in range(10)]


class PGSQL(SQLClient):
    def __init__(self, db: str):
        self._db = db

    def query(self, query: str) -> Iterable[dict]:
        print("Querying PostgreSQL.")
        return [{"id": i, "customer": f"cust_{i}"} for i in range(10)]


class DocumentClient(abc.ABC):
    @abc.abstractmethod
    def query(self, query: dict) -> Iterable[dict]:
        """
        Query the document database and return an iterable with the results.
        """
        pass


class MongoClient(DocumentClient):
    def query(self, query: dict) -> Iterable[dict]:
        print("Querying mongo")
        return [{"id": i, "document": f"doc_{i}"} for i in range(10)]


class MongoAtlasClient(DocumentClient):
    def query(self, query: dict) -> Iterable[dict]:
        print("Querying mongo atlas")
        return [{"id": i, "document": f"doc_{i}"} for i in range(10)]


class DAOFactory(abc.ABC):
    @abc.abstractmethod
    def create_sql_client(self) -> "SQLClient":
        """
        Create a client for a SQL database.
        """
        pass

    @abc.abstractmethod
    def create_document_client(self) -> "DocumentClient":
        """
        Create a client for a document-based database.
        """
        pass


class LocalDAOFactory(DAOFactory):
    """Local family: PostgreSQL + MongoDB."""

    def __init__(self):
        self._db_name = "blah"

    def create_sql_client(self) -> "SQLClient":
        return PGSQL(self._db_name)

    def create_document_client(self) -> "DocumentClient":
        return MongoClient()


class CloudDAOFactory(DAOFactory):
    """Cloud family: Cloud SQL + MongoDB Atlas."""

    def __init__(self):
        self._project_id = "design-patterns"
        # Falls back to a dummy key when CLOUD_AUTH_KEY is not set
        self._auth_key = os.getenv("CLOUD_AUTH_KEY", "123")

    def create_sql_client(self) -> "SQLClient":
        return CloudSQL(self._project_id, self._auth_key)

    def create_document_client(self) -> "DocumentClient":
        return MongoAtlasClient()
```

```python
factories = {
    "cloud": CloudDAOFactory,
    "local": LocalDAOFactory,
}
```

```python
# define the current context
ctx = os.getenv("DAO_CONTEXT", "local")

# get the correct factory
factory_cls = factories[ctx]
factory = factory_cls()

sql_client = factory.create_sql_client()
print(sql_client.query("SELECT * FROM mock"))

doc_client = factory.create_document_client()
print(doc_client.query({"id": 1}))
```

```
Querying PostgreSQL.
[{'id': 0, 'customer': 'cust_0'}, {'id': 1, 'customer': 'cust_1'}, {'id': 2, 'customer': 'cust_2'}, {'id': 3, 'customer': 'cust_3'}, {'id': 4, 'customer': 'cust_4'}, {'id': 5, 'customer': 'cust_5'}, {'id': 6, 'customer': 'cust_6'}, {'id': 7, 'customer': 'cust_7'}, {'id': 8, 'customer': 'cust_8'}, {'id': 9, 'customer': 'cust_9'}]
Querying mongo
[{'id': 0, 'document': 'doc_0'}, {'id': 1, 'document': 'doc_1'}, {'id': 2, 'document': 'doc_2'}, {'id': 3, 'document': 'doc_3'}, {'id': 4, 'document': 'doc_4'}, {'id': 5, 'document': 'doc_5'}, {'id': 6, 'document': 'doc_6'}, {'id': 7, 'document': 'doc_7'}, {'id': 8, 'document': 'doc_8'}, {'id': 9, 'document': 'doc_9'}]
```
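
Client code depends only on `DAOFactory`, so supporting another environment means writing one more concrete factory. Below is a minimal sketch that assumes a hypothetical in-memory family for unit tests. `InMemorySQL`, `InMemoryDocuments` and `InMemoryDAOFactory` are not part of the example above. The abstract classes are repeated in simplified form so the block runs on its own.

```python
import abc
from typing import Iterable, List


class SQLClient(abc.ABC):
    @abc.abstractmethod
    def query(self, query: str) -> Iterable[dict]: ...


class DocumentClient(abc.ABC):
    @abc.abstractmethod
    def query(self, query: dict) -> Iterable[dict]: ...


class DAOFactory(abc.ABC):
    @abc.abstractmethod
    def create_sql_client(self) -> SQLClient: ...

    @abc.abstractmethod
    def create_document_client(self) -> DocumentClient: ...


# New family (hypothetical): in-memory clients, e.g. for unit tests
class InMemorySQL(SQLClient):
    def __init__(self, rows: List[dict]):
        self._rows = rows

    def query(self, query: str) -> Iterable[dict]:
        # Ignores the SQL text and returns a copy of every stored row
        return list(self._rows)


class InMemoryDocuments(DocumentClient):
    def __init__(self, docs: List[dict]):
        self._docs = docs

    def query(self, query: dict) -> Iterable[dict]:
        # Returns documents whose fields match every key/value in the query
        return [d for d in self._docs if all(d.get(k) == v for k, v in query.items())]


class InMemoryDAOFactory(DAOFactory):
    def __init__(self):
        self._rows = [{"id": 1, "customer": "cust_1"}]
        self._docs = [{"id": 1, "document": "doc_1"}, {"id": 2, "document": "doc_2"}]

    def create_sql_client(self) -> SQLClient:
        return InMemorySQL(self._rows)

    def create_document_client(self) -> DocumentClient:
        return InMemoryDocuments(self._docs)


factory = InMemoryDAOFactory()
print(factory.create_sql_client().query("SELECT * FROM mock"))
# [{'id': 1, 'customer': 'cust_1'}]
print(factory.create_document_client().query({"id": 2}))
# [{'id': 2, 'document': 'doc_2'}]
```

Adding `"test": InMemoryDAOFactory` to the `factories` dictionary would let `DAO_CONTEXT=test` select this family without touching the client code.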


## Pros and cons

__Pros__:

- It has the same advantages as the _Factory Method_ pattern:
    - it abstracts the object creation process, simplifying the interface for the client
    - it encourages the developer to program to an interface rather than to a class (promoting _decoupling_)
- It makes it easy to create new families of objects
- It ensures that the objects created by a given _factory_ work well together

__Cons__:

- Adding a new creation method (a new kind of product) is costly, because every concrete factory subclass has to implement it
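
The cost appears as soon as the abstract factory gains a new kind of product. The sketch below assumes a hypothetical `create_cache_client` method added to a simplified `DAOFactory`, whose methods return strings instead of real clients. With `abc`, a concrete factory that has not been updated can no longer be instantiated.

```python
import abc


class DAOFactory(abc.ABC):
    @abc.abstractmethod
    def create_sql_client(self): ...

    @abc.abstractmethod
    def create_document_client(self): ...

    # New kind of product added later (hypothetical): every concrete
    # factory must now implement it as well
    @abc.abstractmethod
    def create_cache_client(self): ...


class LocalDAOFactory(DAOFactory):
    # Written before create_cache_client existed, so it does not implement it
    def create_sql_client(self):
        return "postgres"

    def create_document_client(self):
        return "mongo"


error_message = None
try:
    LocalDAOFactory()
except TypeError as exc:
    # abc refuses to instantiate a subclass with unimplemented abstract methods
    error_message = str(exc)
print(error_message)
```

The exact message differs between Python versions. In every case the failure happens when the factory is instantiated, not when the missing method is called.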

## Reflection

Given that _Abstract Factories_ are made up of _Factory Methods_, and each _Factory Method_ has a single signature, does that mean we can create objects in only one way? (e.g. stdout logger vs. file logger)
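
Here is one possible way to explore the question, sketched with Python's standard `logging` module. The factory method keeps a single signature, `create_logger(name)`, while each concrete factory decides how the logger is built. Configuration specific to one variant (here, the file path) goes into that factory's constructor. `LoggerFactory`, `StdoutLoggerFactory` and `FileLoggerFactory` are hypothetical names used only for illustration.

```python
import abc
import logging
import os
import sys
import tempfile


class LoggerFactory(abc.ABC):
    @abc.abstractmethod
    def create_logger(self, name: str) -> logging.Logger:
        """Same signature for every variant."""


class StdoutLoggerFactory(LoggerFactory):
    def create_logger(self, name: str) -> logging.Logger:
        logger = logging.getLogger(name)
        if not logger.handlers:  # avoid duplicate handlers on repeated calls
            logger.addHandler(logging.StreamHandler(sys.stdout))
        logger.setLevel(logging.INFO)
        return logger


class FileLoggerFactory(LoggerFactory):
    def __init__(self, path: str):
        # Variant-specific configuration lives in the constructor,
        # not in the factory method's signature
        self._path = path

    def create_logger(self, name: str) -> logging.Logger:
        logger = logging.getLogger(name)
        if not logger.handlers:  # avoid duplicate handlers on repeated calls
            logger.addHandler(logging.FileHandler(self._path))
        logger.setLevel(logging.INFO)
        return logger


log_path = os.path.join(tempfile.mkdtemp(), "app.log")
for factory in (StdoutLoggerFactory(), FileLoggerFactory(log_path)):
    # Client code is identical for both variants
    factory.create_logger(f"demo.{type(factory).__name__}").info("hello")
```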