In [None]:
from sklearn.decomposition import PCA

import utils
from utils import *

In [None]:
# params — everything a reader might tune for a run
lr = 0.02
batch_size = 65536
# batch_size = None

# Optional Trainer kwarg presets; spread the one you need into L.Trainer(...).
quick_run = {
    'max_epochs': None,
    # 'limit_train_batches': 0.1,
    # 'limit_val_batches': 0.1,
    # 'limit_test_batches': 0.1,
}
fast_dev_run_kwargs = {'fast_dev_run': True, 'enable_checkpointing': False}
overfit_batches_kwargs = {'overfit_batches': 1, 'enable_checkpointing': False}  # overfit a single batch
large_model = {'precision': "16-mixed"}
grad_accum = {'accumulate_grad_batches': 7}
resume_training = {'ckpt_path': 'path/to/ckpt'}

arcitecture_name = 'autoencoder'  # one of: 'autoencoder', 'convencoder'
experiemnt_name = "mnist-autoencoder"
run_name = arcitecture_name

# Build from the untouched root in `utils` (not the notebook global) so re-running
# this cell does not nest another `<arcitecture_name>` directory each time.
dir_artifacts = utils.dir_artifacts / arcitecture_name
dir_artifacts.mkdir(parents=True, exist_ok=True)

In [None]:
class AutoEncoder(nn.Module):
    """Fully connected autoencoder: n_in -> h1 -> h2 (latent) -> h1 -> n_in.

    Batched inputs of shape (N, ...) are flattened per sample before the first
    Linear layer, so both flat (N, n_in) and image-shaped (e.g. (N, 1, 28, 28))
    batches work in ``forward`` and ``encode`` alike.

    Args:
        n_in: number of input features (784 for flattened 28x28 MNIST).
        h1: width of the hidden layer.
        h2: width of the latent code.
    """

    def __init__(self, n_in, h1, h2):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(n_in, h1), nn.ReLU(), nn.Linear(h1, h2))
        self.decoder = nn.Sequential(nn.Linear(h2, h1), nn.ReLU(), nn.Linear(h1, n_in))

    @staticmethod
    def _flatten(x):
        """Collapse all non-batch dims of a >2-D batch; 1-D/2-D inputs pass through unchanged."""
        # flatten (unlike view) also accepts non-contiguous tensors.
        return x.flatten(start_dim=1) if x.dim() > 2 else x

    def forward(self, x):
        """Return the reconstruction of ``x`` with shape (N, n_in)."""
        encoded = self.encoder(self._flatten(x))
        return self.decoder(encoded)

    def encode(self, x):
        """Return the latent code of ``x`` with shape (N, h2)."""
        return self.encoder(self._flatten(x))


class ConvAutoEncoder(nn.Module):
    """Convolutional autoencoder for single-channel 28x28 images with a 3x1x1 bottleneck.

    Spatial sizes: 28 -> 14 -> 7 -> 1 through the encoder and 1 -> 7 -> 14 -> 28
    back through the decoder. Inputs are reinterpreted as (N, 1, 28, 28), so flat
    (N, 784) batches are accepted as well.
    """

    @staticmethod
    def _as_image(x):
        # Reinterpret each sample as a 1-channel 28x28 image.
        return x.view(x.size(0), 1, 28, 28)

    def __init__(self):
        super().__init__()
        down = [
            nn.Conv2d(1, 16, 3, stride=2, padding=1),  # 1x28x28 -> 16x14x14
            nn.ReLU(),
            nn.Conv2d(16, 8, 3, stride=2, padding=1),  # -> 8x7x7
            nn.ReLU(),
            nn.Conv2d(8, 3, 7),  # -> 3x1x1: the 3-number latent code
        ]
        up = [
            nn.ConvTranspose2d(3, 8, 7),  # 3x1x1 -> 8x7x7
            nn.ReLU(),
            nn.ConvTranspose2d(8, 16, 3, stride=2, padding=1, output_padding=1),  # -> 16x14x14
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, 3, stride=2, padding=1, output_padding=1),  # -> 1x28x28
            nn.Sigmoid(),  # squash reconstructions into [0, 1]
        ]
        self.encoder = nn.Sequential(*down)
        self.decoder = nn.Sequential(*up)

    def forward(self, x):
        """Return the (N, 1, 28, 28) reconstruction of ``x``."""
        return self.decoder(self.encoder(self._as_image(x)))

    def encode(self, x):
        """Return the (N, 3, 1, 1) latent code of ``x``."""
        return self.encoder(self._as_image(x))


class MNISTAutoEncoder(L.LightningModule):
    """LightningModule that trains a wrapped autoencoder with an MSE reconstruction loss.

    Args:
        model: the network to train (e.g. ``AutoEncoder`` or ``ConvAutoEncoder``); its
            output is flattened per sample before being compared with the input.
        lr: learning rate passed to Adam.

    Note:
        ``model`` is excluded from the saved hyperparameters, so restore with
        ``MNISTAutoEncoder.load_from_checkpoint(path, model=<same architecture>)``.
    """

    def __init__(self, model, lr: float = 2e-2):
        super().__init__()
        self.model = model
        self.loss_func = F.mse_loss
        self.lr = lr
        # Dummy batch Lightning can use for summaries / graph logging.
        # NOTE(review): assumes the datamodule yields flat 784-pixel inputs — confirm.
        self.example_input_array = torch.randn(5, 784)
        # The wrapped nn.Module's weights already live in the checkpoint state_dict;
        # Lightning recommends ignoring nn.Module arguments rather than pickling
        # them into hparams (loggers would otherwise record the module repr as a param).
        self.save_hyperparameters(ignore=['model'])

    def forward(self, x):
        """Return the wrapped model's reconstruction of ``x``."""
        return self.model(x)

    def _step(self, batch, idx, set_name: str):
        """Shared train/valid/test step: compute, log (``<set_name>_loss``) and return the loss."""
        x, _ = batch  # labels are not needed for reconstruction
        x_hat = self(x)
        # Flatten both sides per sample so flat (N, 784) and image-shaped
        # (N, 1, 28, 28) inputs/outputs are compared element-wise.
        n = x.size(0)
        loss = self.loss_func(x_hat.reshape(n, -1), x.reshape(n, -1))
        self.log(f'{set_name}_loss', loss, prog_bar=True)
        return loss

    def training_step(self, batch, idx):
        return self._step(batch, idx, 'train')

    def validation_step(self, batch, idx):
        return self._step(batch, idx, 'valid')

    def test_step(self, batch, idx):
        return self._step(batch, idx, 'test')

    def predict_step(self, batch, idx, dataloader_idx=0):
        """Return reconstructions of the inputs (first element) of ``batch``."""
        return self(batch[0])

    def configure_optimizers(self):
        """Adam over all parameters at the configured learning rate."""
        return optim.Adam(params=self.parameters(), lr=self.lr)


In [None]:
# Train the selected architecture inside one MLflow run, then predict, test and log artifacts.
mlflow.set_tracking_uri('file://' + dir_mlruns.as_posix())
mlflow.set_experiment(experiment_name=experiemnt_name)
mlflow.pytorch.autolog()
try:
    with mlflow.start_run(run_name=run_name) as run:
        # Factories, so only the selected architecture is actually instantiated.
        arcitectures = {
            'autoencoder': lambda: AutoEncoder(784, 64, 3),
            'convencoder': ConvAutoEncoder,
        }
        arcitecture = arcitectures[arcitecture_name]()
        callbacks = [
            ModelCheckpoint(every_n_epochs=2),
            EarlyStopping(monitor="valid_loss"),  # metric logged by MNISTAutoEncoder._step
            StochasticWeightAveraging(swa_lrs=1e-2),
        ]
        # Attach Lightning's logger to the run opened above instead of starting a new one.
        mlf_logger = MLFlowLogger(
            experiment_name=experiemnt_name,
            run_id=run.info.run_id,
            log_model=True,
            tracking_uri=uri_mlruns,
        )
        trainer = L.Trainer(callbacks=callbacks, logger=mlf_logger, **quick_run)

        model = MNISTAutoEncoder(model=arcitecture, lr=lr)
        data = MNISTDataModule(dir_mnist.as_posix(), batch_size=batch_size)
        trainer.fit(model, datamodule=data)

        predictions = trainer.predict(model, datamodule=data)
        trainer.test(model, datamodule=data)
        mlflow.log_artifacts(dir_artifacts.as_posix())
finally:
    # Switch autologging off again even if training raised.
    mlflow.pytorch.autolog(disable=True)

launch_mlflow_ui(uri=uri_mlruns, run=run)

In [None]:
# Reconstruct and encode the whole predict split for visualisation.
data.setup('predict')
x = data.ds_predict.x
y = data.ds_predict.y

# Inference only: eval mode + no_grad avoid building an autograd graph over the full split.
model.eval()
with torch.no_grad():
    x_hat = model(x)
    encoded = model.model.encode(x)
# ConvAutoEncoder.encode returns (N, 3, 1, 1); flatten to (N, latent_dim) for a 2-D DataFrame.
encoded = encoded.reshape(encoded.size(0), -1).cpu().numpy()

columns = [f'e{i}' for i in range(encoded.shape[1])]
df = pd.DataFrame(encoded, columns=columns)
df['lbl'] = y.cpu().numpy().astype(str)

# Downstream cells work with numpy arrays.
x, y, x_hat = (t.detach().cpu().numpy() for t in (x, y, x_hat))

In [None]:
i = 16  # index of the predict-set sample to compare
fig, axes = plt.subplots(1, 2)
for ax, img, title in zip(axes, (x[i], x_hat[i]), ('original', 'reconstruction')):
    ax.imshow(img.reshape(28, 28), cmap='gray')
    ax.set_title(title)
    ax.axis('off')
fig.suptitle(f'sample {i}, label {y[i]}')
plt.show()

In [None]:
fig_latent = px.scatter_3d(
    df, *columns, color='lbl',
    title=f'{arcitecture_name} latent space, coloured by digit label',
)
fig_latent.update_traces(marker_size=3)
fig_latent

In [None]:
# Linear baseline: project the raw pixels onto their top-3 principal components
# to compare with the learned latent space. random_state pins the result if
# scikit-learn picks its randomized SVD solver.
pca = PCA(n_components=3, random_state=0)
columns = [f'pc{i}' for i in range(pca.n_components)]
df = pd.DataFrame(pca.fit_transform(x), columns=columns)
df['lbl'] = y.astype(str)

In [None]:
fig_pca = px.scatter_3d(
    df, *columns, color='lbl',
    title='PCA (3 components) of raw pixels, coloured by digit label',
)
fig_pca.update_traces(marker_size=3)
fig_pca