In [18]:
from torchvision import datasets, transforms
import torch
import torch.optim as optim
import torch.nn as nn
import matplotlib.pyplot as plt

transform = transforms.ToTensor()

mnist = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transforms.ToTensor())

loader = torch.utils.data.DataLoader(
    dataset=mnist,
    batch_size=32,
    shuffle=True)

lr=1e-3 
weight_decay=1e-5
n_epochs = 5

### Обучить AE собственной архитектуры на MNIST(с линейными слоями)

In [19]:
class ael(nn.Module):
    def __init__(self):
        super().__init__()
        self.enc = nn.Sequential(
            nn.Linear(784, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 16),
            nn.ReLU(),
            nn.Linear(16, 4)         
        )
        self.dec = nn.Sequential(
            nn.Linear(4, 16),
            nn.ReLU(),
            nn.Linear(16, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, 784),
            nn.Sigmoid()
        )
    def forward(self, x):
        enc = self.enc(x)
        dec = self.dec(enc)
        return dec

In [20]:
model = ael()
optimizer = torch.optim.Adam(
    model.parameters(), 
    lr=lr, 
    weight_decay=weight_decay)
criterion = nn.MSELoss()

for epoch in range(n_epochs):
    for img, _ in loader:
        img = img.reshape(-1, 784)
        recon = model(img)
        loss = criterion(recon, img)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f'epoch: {epoch+1}, loss={loss.item():.3f}') 

epoch: 1, loss=0.041
epoch: 2, loss=0.040
epoch: 3, loss=0.033
epoch: 4, loss=0.035
epoch: 5, loss=0.032


### Обучить AE собственной архитектуры на MNIST(со сверточными слоями)

In [21]:
class aecnn(nn.Module):
    def __init__(self):
        super().__init__()
        self.enc = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7)
        )
        self.dec = nn.Sequential(
            nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=7),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=3, stride=2, padding=1, output_padding=1), 
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=16, out_channels=1, kernel_size=3, stride=2, padding=1, output_padding=1), 
            nn.Sigmoid()
        )
    def forward(self, x):   
        enc = self.enc(x) 
        dec = self.dec(enc) 
        return dec

In [22]:
model = aecnn()
optimizer = torch.optim.Adam(
    model.parameters(), 
    lr=lr, 
    weight_decay=weight_decay)
criterion = nn.MSELoss()
n_epochs = 5

for epoch in range(n_epochs):
    for (img, _) in loader:
        rec = model(img)
        loss = criterion(rec, img)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f'epoch: {epoch+1}, loss={loss.item():.3f}') 

epoch: 1, loss=0.007
epoch: 2, loss=0.005
epoch: 3, loss=0.004
epoch: 4, loss=0.003
epoch: 5, loss=0.003


### Обучить VAE собственной архитектуры на MNIST

In [23]:
class vae(nn.Module):
    def __init__(self):
        super().__init__()
        self.mu = nn.Sequential(
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=64, out_channels=1, kernel_size=7)
        )
        self.sigma = nn.Sequential(
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=64, out_channels=1, kernel_size=7)
        )
        self.enc = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7)
        )
        self.dec= nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=64, kernel_size=5, stride=3),
            nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=7),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=3, stride=2, padding=1, output_padding=1), 
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=16, out_channels=1, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid()
        )
    def forward(self, x):
        enc = self.enc(x) 
        mu = self.mu(enc)
        sigma = torch.exp(self.sigma(enc))
        sample = mu + torch.normal(torch.zeros_like(sigma), torch.ones_like(sigma)) * sigma
        dec = self.dec(sample)
        return dec

In [25]:
model = vae()
optimizer = torch.optim.Adam(
    model.parameters(), 
    lr=lr, 
    weight_decay=weight_decay
)
criterion = nn.MSELoss()

for epoch in range(n_epochs):
    for (img, _) in loader:
        recon = model(img)
        loss = criterion(recon, img)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f'epoch: {epoch+1}, loss={loss.item():.3f}') 

epoch: 1, loss=0.015
epoch: 2, loss=0.011
epoch: 3, loss=0.009
epoch: 4, loss=0.008
epoch: 5, loss=0.008
