In [None]:
import matplotlib.pyplot as plt
import numpy as np
import sklearn
import sklearn.datasets
import torch
import tqdm

In [None]:
# Pin the global PyTorch and NumPy random generators so a fresh
# "Restart & Run All" reproduces the same initial weights and batches.
torch.manual_seed(0)
np.random.seed(seed=0)

In [None]:
class AutoEncoder(torch.nn.Module):
    """Single-hidden-layer auto-encoder with sigmoid activations.

    Attributes:
        encoder: Linear map from ``input_dim`` to ``hidden_dim``.
        decoder: Linear map from ``hidden_dim`` back to ``input_dim``.
    """

    def __init__(self, input_dim, hidden_dim):
        """Create the encoder and decoder linear layers.

        Args:
            input_dim: Number of features of the input (and of the output).
            hidden_dim: Number of units in the hidden code.
        """
        super().__init__()
        # Encoder is created first so parameter initialisation consumes the
        # RNG in the same order as before.
        self.encoder = torch.nn.Linear(in_features=input_dim, out_features=hidden_dim)
        self.decoder = torch.nn.Linear(in_features=hidden_dim, out_features=input_dim)

    def forward(self, x):
        """Return the reconstruction of ``x``.

        The input is encoded, squashed with a sigmoid, decoded and squashed
        again, so every output value lies strictly between 0 and 1.
        """
        hidden_code = self.encoder(x).sigmoid()
        return self.decoder(hidden_code).sigmoid()

In [None]:
class GraphEncoder(torch.nn.Module):
    """Stack of sigmoid auto-encoders trained with a sparsity penalty.

    Layer ``k`` is an ``AutoEncoder`` whose input width is the hidden width of
    layer ``k - 1`` (``input_dim`` for the first layer). Encoding applies every
    encoder in order; decoding applies every decoder in reverse order. A
    sigmoid follows each linear map.

    Attributes:
        autoencoders: ``torch.nn.ModuleList`` holding one ``AutoEncoder`` per
            entry of ``hidden_dims``.
    """

    _VALID_TRAIN_MODES = ('layerwise', 'endtoend')

    def __init__(self, input_dim, hidden_dims):
        """Build one auto-encoder per entry of ``hidden_dims``.

        Args:
            input_dim: Number of input features.
            hidden_dims: Iterable of hidden widths, outermost layer first.
        """
        super().__init__()
        self.autoencoders = torch.nn.ModuleList()
        prev_dim = input_dim
        for hidden_dim in hidden_dims:
            self.autoencoders.append(AutoEncoder(prev_dim, hidden_dim))
            prev_dim = hidden_dim

    def _encode_stack(self, X):
        """Apply every encoder (followed by a sigmoid) in order."""
        for autoencoder in self.autoencoders:
            X = torch.sigmoid(autoencoder.encoder(X))
        return X

    def _decode_stack(self, X):
        """Apply every decoder (followed by a sigmoid) in reverse order."""
        for autoencoder in reversed(self.autoencoders):
            X = torch.sigmoid(autoencoder.decoder(X))
        return X

    def forward(self, X, train_mode, **kwargs):
        """Encode and reconstruct ``X``.

        Args:
            X: Input batch. In ``'layerwise'`` mode it must already have the
                input width of the selected layer.
            train_mode: ``'layerwise'`` runs a single auto-encoder,
                ``'endtoend'`` runs the whole stack.
            **kwargs: ``layer_number`` (int) selects the auto-encoder in
                ``'layerwise'`` mode.

        Returns:
            Tuple ``(encoded, decoded)``.

        Raises:
            ValueError: If ``layer_number`` is missing or out of range in
                ``'layerwise'`` mode, or if ``train_mode`` is not recognised.
        """
        if train_mode == 'layerwise':
            layer_number = kwargs.get('layer_number', None)
            if layer_number is None or layer_number < 0 or layer_number >= len(self.autoencoders):
                raise ValueError("Invalid layer number for layerwise training")
            layer = self.autoencoders[layer_number]
            encoded = torch.sigmoid(layer.encoder(X))
            decoded = torch.sigmoid(layer.decoder(encoded))
            return encoded, decoded

        if train_mode == 'endtoend':
            encoded = self._encode_stack(X)
            decoded = self._decode_stack(encoded)
            return encoded, decoded

        # Previously an unknown mode fell through and returned None, which
        # only surfaced later as an unpacking error in the caller.
        raise ValueError(
            f"Unknown train_mode {train_mode!r}; expected 'layerwise' or 'endtoend'"
        )

    @staticmethod
    def _sparsity_penalty(encoded, rho, eps=1e-6):
        """KL divergence between target sparsity ``rho`` and mean activations.

        Computes ``sum_j rho*log(rho/rho_hat_j) + (1-rho)*log((1-rho)/(1-rho_hat_j))``
        where ``rho_hat_j`` is the batch mean of hidden unit ``j``.
        """
        # Sigmoid outputs can saturate to exactly 0 or 1 in float32; clamping
        # keeps both logarithms finite instead of producing inf/NaN losses.
        rho_hat = torch.mean(encoded, dim=0).clamp(eps, 1 - eps)
        return torch.sum(
            rho * torch.log(rho / rho_hat)
            + (1 - rho) * torch.log((1 - rho) / (1 - rho_hat))
        )

    @staticmethod
    def _sample_batch(X, batch_size):
        """Return a random contiguous window of ``batch_size`` rows of ``X``."""
        # torch.randint's upper bound is exclusive, so "+ 1" is needed for the
        # window ending on the last row to be reachable (and for
        # batch_size == len(X) to be valid at all).
        start = torch.randint(0, X.shape[0] - batch_size + 1, (1,)).item()
        return X[start:start + batch_size]

    def _training_step(self, train_model, X_batch, optimizer, rho, beta, **forward_kwargs):
        """Run one optimisation step on ``X_batch`` and return the detached loss."""
        optimizer.zero_grad()
        encoded, decoded = train_model(X_batch, **forward_kwargs)
        reconstruction = torch.nn.functional.mse_loss(decoded, X_batch, reduction='sum')
        loss = reconstruction + beta * self._sparsity_penalty(encoded, rho)
        loss.backward()
        optimizer.step()
        return loss.detach()

    def train(self,
              X=True,
              compile=False,
              train_mode=None,
              iters=None,
              optimizer=None,
              rho=0.01,
              beta=1.0,
              batch_size=None):
        """Fit the stack on ``X``, or toggle training mode when given a bool.

        This method shadows ``torch.nn.Module.train(mode)``. To keep
        ``model.train()``, ``model.train(False)`` and ``model.eval()`` working,
        a boolean first argument is forwarded to the base implementation.

        Args:
            X: Training data with samples along the first dimension, or a
                bool to switch training/evaluation mode.
            compile: Wrap the model with ``torch.compile`` when this is
                ``True`` or the string ``"True"`` (case-insensitive). Any
                other value trains the plain module.
            train_mode: ``'layerwise'`` trains each auto-encoder in turn on
                the output of the previous encoders; ``'endtoend'`` trains the
                whole stack jointly.
            iters: Number of optimisation steps (per layer in ``'layerwise'``
                mode).
            optimizer: Optimizer used for every step.
            rho: Target mean activation of each hidden unit.
            beta: Weight of the sparsity penalty relative to the summed
                squared reconstruction error.
            batch_size: Rows per batch; defaults to the whole dataset and is
                capped at the dataset size.

        Returns:
            ``self`` when called with a bool (as ``nn.Module.train`` does),
            otherwise ``None``.

        Raises:
            ValueError: Unknown ``train_mode``, empty ``X`` or non-positive
                ``batch_size``.
            TypeError: ``iters`` or ``optimizer`` missing when fitting.
        """
        if isinstance(X, bool):
            return super().train(X)

        if train_mode not in self._VALID_TRAIN_MODES:
            raise ValueError(
                f"Unknown train_mode {train_mode!r}; expected 'layerwise' or 'endtoend'"
            )
        if iters is None or optimizer is None:
            raise TypeError("train() needs both 'iters' and 'optimizer' to fit on data")

        n_samples = X.shape[0]
        if n_samples == 0:
            raise ValueError("Cannot train on an empty dataset")
        if batch_size is None:
            batch_size = n_samples
        elif batch_size <= 0:
            raise ValueError("batch_size must be a positive integer")
        batch_size = min(batch_size, n_samples)

        use_compile = compile is True or (
            isinstance(compile, str) and compile.strip().lower() == "true"
        )
        train_model = torch.compile(self) if use_compile else self

        if train_mode == 'layerwise':
            for layer_number in range(len(self.autoencoders)):
                for _ in tqdm.tqdm(range(iters), desc=f"Training layer {layer_number}"):
                    X_batch = self._sample_batch(X, batch_size)
                    self._training_step(
                        train_model, X_batch, optimizer, rho, beta,
                        train_mode='layerwise', layer_number=layer_number,
                    )
                # The next layer is trained on this layer's codes; no graph is
                # needed because earlier layers are not updated through them.
                with torch.no_grad():
                    X = torch.sigmoid(self.autoencoders[layer_number].encoder(X))
        else:
            for _ in tqdm.tqdm(range(iters)):
                X_batch = self._sample_batch(X, batch_size)
                self._training_step(
                    train_model, X_batch, optimizer, rho, beta,
                    train_mode='endtoend',
                )

    @torch.no_grad()
    def encode(self, X):
        """Return the innermost code of ``X`` without tracking gradients."""
        return self._encode_stack(X)

    @torch.no_grad()
    def encode_decode(self, X):
        """Return the full-stack reconstruction of ``X`` without tracking gradients."""
        return self._decode_stack(self._encode_stack(X))

In [None]:
# Load the 8x8 handwritten-digit images as a (n_samples, 64) NumPy array.
X, y = sklearn.datasets.load_digits(return_X_y=True, as_frame=False)
# Pixel intensities in this dataset are integers in 0..16 (not 0..255), so
# dividing by 16 maps them onto [0, 1], the range of the sigmoid outputs
# that reconstruct them.
X = X / 16.0

In [None]:
# Re-seed here so this cell gives the same result when re-run on its own.
torch.manual_seed(0)
np.random.seed(0)

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
NB_EPOCHS = 5000
BATCH_SIZE = 256
# as_tensor accepts both the NumPy array produced by the loading cell and an
# already-converted tensor, so re-running this cell does not trigger
# torch.tensor's "copy construct from a tensor" warning.
X = torch.as_tensor(X, dtype=torch.float32).to(DEVICE)

graph_encoder = GraphEncoder(input_dim=X.shape[1], hidden_dims=[32, 16, 8]).to(DEVICE)
optimizer = torch.optim.Adam(graph_encoder.parameters(), lr=0.01)
# One "epoch" is counted as the number of whole batches that fit in the data.
nb_iters = NB_EPOCHS * (X.shape[0] // BATCH_SIZE)
# The previous value, the string "compile", is not a value that enables
# compilation, so training has always run uncompiled; False states that
# explicitly without changing behavior.
graph_encoder.train(
    X,
    compile=False,
    train_mode="endtoend",
    iters=nb_iters,
    optimizer=optimizer,
    rho=0.01,
    beta=1.0,
    batch_size=BATCH_SIZE,
)

In [None]:
# Reconstruct every sample with the trained stack and move the result to
# host memory.
Xh = graph_encoder.encode_decode(X).cpu()

In [None]:
# Draw one random sample index, then show the input image followed by its
# reconstruction, each as an 8x8 grayscale picture.
i = int(torch.randint(low=0, high=X.shape[0], size=(1,)))
original_image = X[i].cpu().reshape(8, 8)
plt.imshow(original_image, cmap='gray')
plt.show()
reconstructed_image = Xh[i].cpu().reshape(8, 8)
plt.imshow(reconstructed_image, cmap='gray')