In [None]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch import nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.datasets import MNIST
import torch.utils.data as data
import pytorch_lightning as pl

In [None]:
# Download MNIST (if it is not already present) into the current working
# directory; ToTensor turns each image into a tensor.
train_set = MNIST(os.getcwd(), download=True, train=True, transform=transforms.ToTensor())
test_set = MNIST(os.getcwd(), download=True, train=False, transform=transforms.ToTensor())

train_set_size = int(len(train_set) * 0.8)
valid_set_size = len(train_set) - train_set_size

# Split the training data 8 : 2 into train_set and valid_set.
# A seeded generator keeps the split identical across kernel restarts, so a run
# resumed from a checkpoint never trains on samples that were validation data
# in the run that produced the checkpoint.
split_generator = torch.Generator().manual_seed(42)
train_set, valid_set = data.random_split(
    train_set, [train_set_size, valid_set_size], generator=split_generator
)

In [None]:
# Wrap every split in a DataLoader. No extra options are passed, so each
# loader runs with the DataLoader defaults.
train_loader, valid_loader, test_loader = (
    data.DataLoader(split) for split in (train_set, valid_set, test_set)
)

In [None]:
class Encoder(nn.Module):
    """Two-layer MLP that compresses a flattened input into a latent code.

    Args:
        input_dim: Number of features per flattened input (default 28 * 28).
        hidden_dim: Width of the hidden layer (default 64).
        latent_dim: Size of the latent code that is returned (default 3).
    """

    def __init__(self, input_dim=28 * 28, hidden_dim=64, latent_dim=3):
        super().__init__()
        # The attribute name `l1` and the Sequential layout are kept so that
        # state_dict keys stay the same as before the sizes were parameterized.
        self.l1 = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim),
        )

    def forward(self, x):
        """Map ``x`` of shape (..., input_dim) to a code of shape (..., latent_dim)."""
        return self.l1(x)


class Decoder(nn.Module):
    """Two-layer MLP that expands a latent code back into a flattened output.

    Args:
        latent_dim: Size of the incoming latent code (default 3).
        hidden_dim: Width of the hidden layer (default 64).
        output_dim: Number of features per reconstructed output (default 28 * 28).
    """

    def __init__(self, latent_dim=3, hidden_dim=64, output_dim=28 * 28):
        super().__init__()
        # The attribute name `l1` and the Sequential layout are kept so that
        # state_dict keys stay the same as before the sizes were parameterized.
        self.l1 = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x):
        """Map ``x`` of shape (..., latent_dim) to an output of shape (..., output_dim)."""
        return self.l1(x)


In [None]:

class LitAutoEncoder(pl.LightningModule):
    """LightningModule that trains an encoder/decoder pair to reconstruct its input.

    Args:
        encoder: Module applied to the flattened input batch.
        decoder: Module applied to the encoder output; its result is compared
            with the flattened input using mean squared error.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        # NOTE(review): called without arguments, so the encoder and decoder
        # modules themselves appear to be recorded as hyperparameters. Consider
        # ignore=["encoder", "decoder"] if nothing relies on that — confirm.
        self.save_hyperparameters()
        self.encoder = encoder
        self.decoder = decoder

    def _reconstruction_loss(self, batch):
        """Flatten the inputs, run encoder -> decoder, and return the MSE against the input."""
        x, _ = batch  # the second element of the batch is not used
        x = x.view(x.size(0), -1)  # keep the batch dimension, flatten the rest
        x_hat = self.decoder(self.encoder(x))
        return F.mse_loss(x_hat, x)

    def training_step(self, batch, batch_idx):
        """Compute, log and return the reconstruction loss for one training batch."""
        loss = self._reconstruction_loss(batch)
        # Logged so the training curve can be compared with val_loss / test_loss.
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the reconstruction loss of one validation batch as ``val_loss``."""
        self.log("val_loss", self._reconstruction_loss(batch))

    def test_step(self, batch, batch_idx):
        """Log the reconstruction loss of one test batch as ``test_loss``."""
        self.log("test_loss", self._reconstruction_loss(batch))

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with learning rate 1e-3."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)

In [None]:
# Build the autoencoder from a freshly initialised encoder/decoder pair.
model = LitAutoEncoder(encoder=Encoder(), decoder=Decoder())

## Trainer parameter 수정

In [None]:
# Train for three epochs; default_root_dir sets the save location to 'ckpt/'.
trainer = pl.Trainer(default_root_dir='ckpt/', max_epochs=3)
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=valid_loader)

# Evaluate the trained model on the test split.
trainer.test(model, dataloaders=test_loader)

# ckpt로 훈련하기

In [None]:
# Resume training from the most recent checkpoint written under ckpt/.
# The checkpoint file name encodes the epoch and step, and the version_N
# directory changes from run to run, so the file is looked up here instead of
# hard-coding one particular name.
CKPT_ROOT = os.path.join('ckpt', 'lightning_logs')
ckpt_files = [
    os.path.join(dirpath, filename)
    for dirpath, _dirnames, filenames in os.walk(CKPT_ROOT)
    for filename in filenames
    if filename.endswith('.ckpt')
]
if not ckpt_files:
    raise FileNotFoundError(
        f"No .ckpt file found under {CKPT_ROOT!r}; run the training cell first."
    )
resume_ckpt = max(ckpt_files, key=os.path.getmtime)  # newest checkpoint wins

model = LitAutoEncoder(Encoder(), Decoder())

# max_epochs is the total epoch count, including the epochs already stored in
# the checkpoint. Leaving it unset would let the resumed run continue up to
# Lightning's default of 1000 epochs.
RESUME_MAX_EPOCHS = 5
trainer = pl.Trainer(max_epochs=RESUME_MAX_EPOCHS)

trainer.fit(model,
            train_dataloaders=train_loader,
            val_dataloaders=valid_loader,
            ckpt_path=resume_ckpt)