# In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm

# In [4]:
from typing import Tuple

# Small epsilon constant. NOTE(review): BernoulliMM does not reference it;
# confirm whether it is still needed.
EPS: float = 1e-07

class BernoulliMM(nn.Module):

    def __init__(
        self, 
        dims: Tuple[int, int],
        n_components: int = 10
    ):
        super().__init__()
        self._mixing_weights = nn.Parameter(
            torch.empty(n_components, dtype=torch.double),
            requires_grad=True,
        )
        self._component_probs = nn.Parameter(
            torch.empty((n_components,) + dims, dtype=torch.double),
            requires_grad=True
        )
        self._n_components = n_components
        self._initialized_parameters = False

    @property
    @torch.no_grad()
    def n_components(self):
        return self._n_components

    @property
    @torch.no_grad()
    def weights(self):
        return F.softmax(self._mixing_weights)

    @property
    @torch.no_grad()
    def probabilities(self):
        return F.sigmoid(self._component_probs)

    @torch.no_grad()
    def initialize_parameters(self, component_probs: torch.Tensor, mixing_weights: torch.Tensor):
        self._mixing_weights[:] = mixing_weights
        self._component_probs[:, :, :] = component_probs
        self._initialized_parameters = True

    def loglik(self, data: torch.Tensor):
        # broadcast the data
        batch_size = data.size(0)
        data = data.unsqueeze(1).expand(-1, self._n_components, -1, -1)  # (batch, n_components, dim_1, dim_2)
        probs = self._component_probs.sigmoid().unsqueeze(0).expand(batch_size, -1, -1, -1)
        component_prob = F.binary_cross_entropy(
            input=probs, 
            target=data, 
            reduction="none"
        ).neg().sum(dim=[-1, -2]).exp()
        mixing_weights = F.softmax(self._mixing_weights, dim=0).unsqueeze(0)
        ll = torch.sum(mixing_weights * component_prob, dim=-1).log()
        return ll

    @torch.no_grad()
    def sample(size: int):
        pass

    def fit(
        self,
        train_loader,
        val_loader,
        n_epochs: int,
        optimizer: torch.optim.Optimizer,
    ):
        """Train by minimising the mean negative log-likelihood per batch.

        Each epoch runs one optimisation pass over ``train_loader`` and then a
        gradient-free evaluation pass over ``val_loader``. Each pass shows a
        progress bar and prints its epoch-average loss.

        Args:
            train_loader: Sized iterable yielding ``(images, labels)`` pairs;
                the labels are ignored.
            val_loader: Same shape of iterable, used only for evaluation.
            n_epochs: Number of passes over the training data.
            optimizer: Optimizer that updates the model (presumably built over
                ``self.parameters()`` - confirm at the call site).

        Returns:
            ``(train_loss, val_loss)``: lists holding one epoch-average loss per
            epoch (``nan`` for an epoch whose loader produced no batches).

        Raises:
            RuntimeError: if ``initialize_parameters()`` was not called first.
        """
        if not self._initialized_parameters:
            raise RuntimeError(
                "Please initialize the parameters with initialize_parameters() before running fit()"
            )
        train_loss = []
        val_loss = []
        # (label, loader, history) per phase; only "Train" updates parameters.
        phases = (
            ("Train", train_loader, train_loss),
            ("Val", val_loader, val_loss),
        )
        for epoch in range(n_epochs):
            for phase, loader, history in phases:
                is_train = phase == "Train"
                self.train(is_train)
                running_loss = 0.0
                n_batches = 0
                # `import tqdm` binds the module, so the bar class is tqdm.tqdm;
                # the `with` guarantees the bar is closed for both phases.
                with torch.set_grad_enabled(is_train), tqdm.tqdm(
                    total=len(loader)
                ) as pbar:
                    for img_batch, _ in loader:
                        if is_train:
                            optimizer.zero_grad()
                        loss = self.loglik(img_batch).mean().neg()
                        if is_train:
                            loss.backward()
                            optimizer.step()
                        batch_loss = loss.item()
                        running_loss += batch_loss
                        n_batches += 1
                        pbar.update(1)
                        pbar.set_description(f"{phase} Loss: {batch_loss}")
                # An empty loader would otherwise raise ZeroDivisionError.
                epoch_loss = running_loss / n_batches if n_batches else float("nan")
                history.append(epoch_loss)
                print(f"Epoch {epoch} {phase} Loss: {epoch_loss}")
        return train_loss, val_loss