Commit ae3bf91

authored Aug 26, 2020
Fixed example implementation of AutoEncoder. (Lightning-AI#3190)
The previous implementation trained an autoencoder but evaluated a classifier. This change replaces the evaluation metric with an autoencoder (reconstruction) metric, so no classification is done. I'm not 100% sure what the original author's intent was, since the example extends a classification model (LitMNIST) but never uses it; the resulting model is an AutoEncoder and does no classification.

1. Small textual changes.
2. forward() now implements encoding rather than decoding, as described in the text.
3. _shared_eval uses MSE loss instead of a classification loss, since no classification weights are learned.
4. MSE is initialized in __init__, since calling MSE directly is not supported.
1 parent 17d8773 commit ae3bf91
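The corrected pattern, distilled into a standalone sketch: forward() encodes, training_step decodes and computes a reconstruction (MSE) loss, and labels are ignored. The Encoder/Decoder here are hypothetical linear stand-ins with an assumed 32-dim bottleneck, not the ones from the Lightning docs:

```python
import torch
from torch import nn
import torch.nn.functional as F

# Hypothetical stand-ins for the Encoder/Decoder used in the docs example.
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(28 * 28, 32)

    def forward(self, x):
        # flatten MNIST-shaped images, then project to a 32-dim code
        return self.net(x.view(x.size(0), -1))

class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(32, 28 * 28)

    def forward(self, z):
        return self.net(z)

class AutoEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        # forward() now encodes, matching the docs text
        return self.encoder(x)

    def training_step(self, batch, batch_idx):
        x, _ = batch  # labels are ignored: no classification is done
        representation = self(x)              # encode
        x_hat = self.decoder(representation)  # decode
        # reconstruction loss, not a classification loss
        return F.mse_loss(x_hat, x.view(x.size(0), -1))

model = AutoEncoder()
x = torch.randn(4, 1, 28, 28)
loss = model.training_step((x, torch.zeros(4)), 0)
```

Calling `model(x)` yields the representation, while the decoder is only touched inside the training step, which is exactly the split the commit enforces.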

File tree

1 file changed: +16 −13 lines changed


docs/source/child_modules.rst (+16 −13)
@@ -16,14 +16,17 @@
     def val_dataloader():
         pass

+    def test_dataloader():
+        pass
+
 Child Modules
 -------------
 Research projects tend to test different approaches to the same dataset.
 This is very easy to do in Lightning with inheritance.

 For example, imagine we now want to train an Autoencoder to use as a feature extractor for MNIST images.
-Recall that `LitMNIST` already defines all the dataloading etc... The only things
-that change in the `Autoencoder` model are the init, forward, training, validation and test step.
+We are extending our Autoencoder from the `LitMNIST`-module which already defines all the dataloading.
+The only things that change in the `Autoencoder` model are the init, forward, training, validation and test step.

 .. testcode::

@@ -39,18 +42,18 @@ that change in the `Autoencoder` model are the init, forward, training, validati
         super().__init__()
         self.encoder = Encoder()
         self.decoder = Decoder()
+        self.metric = MSE()

     def forward(self, x):
-        generated = self.decoder(x)
-        return generated
-
+        return self.encoder(x)
+
     def training_step(self, batch, batch_idx):
         x, _ = batch

-        representation = self.encoder(x)
-        x_hat = self(representation)
+        representation = self(x)
+        x_hat = self.decoder(representation)

-        loss = MSE(x, x_hat)
+        loss = self.metric(x, x_hat)
         return loss

     def validation_step(self, batch, batch_idx):
@@ -60,11 +63,11 @@ that change in the `Autoencoder` model are the init, forward, training, validati
         return self._shared_eval(batch, batch_idx, 'test')

     def _shared_eval(self, batch, batch_idx, prefix):
-        x, y = batch
-        representation = self.encoder(x)
-        x_hat = self(representation)
+        x, _ = batch
+        representation = self(x)
+        x_hat = self.decoder(representation)

-        loss = F.nll_loss(logits, y)
+        loss = self.metric(x, x_hat)
         result = pl.EvalResult()
         result.log(f'{prefix}_loss', loss)
         return result
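One of the listed fixes in isolation: a class-based metric has to be instantiated before it can be called, which is why the diff creates MSE() in __init__ instead of calling MSE directly. A sketch using torch.nn.MSELoss as a stand-in for Lightning's MSE metric:

```python
import torch
from torch import nn

# nn.MSELoss is a class: passing tensors to it would configure a loss
# object rather than compute a value, so instantiate it once
# (e.g. in __init__)...
metric = nn.MSELoss()

x = torch.randn(4, 784)
x_hat = torch.randn(4, 784)

# ...then call the instance to get a scalar reconstruction loss.
loss = metric(x_hat, x)
```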
@@ -78,7 +81,7 @@ and we can train this using the same trainer
     trainer = Trainer()
     trainer.fit(autoencoder)

-And remember that the forward method is to define the practical use of a LightningModule.
+And remember that the forward method should define the practical use of a LightningModule.
 In this case, we want to use the `AutoEncoder` to extract image representations

 .. code-block:: python
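The closing point, as a self-contained sketch: once forward() is defined as the encoder, calling the module is feature extraction. The linear layers and 32-dim bottleneck are assumptions for illustration only:

```python
import torch
from torch import nn

class AutoEncoder(nn.Module):
    """Minimal stand-in with an assumed 32-dim bottleneck."""
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(28 * 28, 32)
        self.decoder = nn.Linear(32, 28 * 28)

    def forward(self, x):
        # practical use: encode only; no decoding on the forward path
        return self.encoder(x.view(x.size(0), -1))

autoencoder = AutoEncoder()
some_images = torch.randn(8, 1, 28, 28)
# image representations, not reconstructions
representations = autoencoder(some_images)
```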
