# Эксперименты с автоэнкодером

## Подготовка данных

In [13]:
import torch
import torchvision
import matplotlib.pyplot as plt

from torch import nn
from torchvision.datasets import CIFAR10
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter

from src.utils import plot_reconstructed, grid_plot

In [15]:
data = CIFAR10('../data/', train=True, download=True)

# Per-channel statistics over all images and pixel positions. Assumes
# `data.data` is the raw (N, H, W, C) pixel array in 0..255 (hence the /255
# and the three per-channel values printed below).
pixel_axes = (0, 1, 2)
mean = data.data.mean(axis=pixel_axes) / 255
std = data.data.std(axis=pixel_axes) / 255

for stat_name, stat_value in (('mean', mean), ('std', std)):
    print(f'{stat_name}: {stat_value}')

Files already downloaded and verified
mean: [0.49139968 0.48215841 0.44653091]
std: [0.24703223 0.24348513 0.26158784]


In [16]:
# Convert PIL images to tensors, then standardize each channel with the
# dataset statistics computed above.
normalize = transforms.Normalize(mean, std)
transform = transforms.Compose([transforms.ToTensor(), normalize])

data = CIFAR10(root='../data/', train=True, download=True, transform=transform)
data_loader = torch.utils.data.DataLoader(data, batch_size=16, shuffle=True)

Files already downloaded and verified


## Обучение автоэнкодера

In [19]:
class Autoencoder(nn.Module):
    """Small convolutional autoencoder.

    For 3x32x32 inputs (CIFAR-10 sized) the encoder maps
    3x32x32 -> 12x16x16 -> 4x8x8 with two stride-2 convolutions, and the
    decoder mirrors it with transposed convolutions back to 3x32x32.

    Args:
        output_activation: module applied to the decoder output. Defaults to
            ``nn.Identity()``: the training data is standardized with
            ``transforms.Normalize(mean, std)`` and therefore contains values
            below 0 and above 1, which a sigmoid (range (0, 1)) cannot
            reconstruct. Pass ``nn.Sigmoid()`` when training on raw [0, 1]
            images instead.
    """

    def __init__(self, output_activation=None):
        super().__init__()

        if output_activation is None:
            output_activation = nn.Identity()

        self.encoder = nn.Sequential(
            nn.Conv2d(3, 12, 3, stride=2, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(12, 4, 3, stride=2, padding=1),
            nn.LeakyReLU(),
        )

        # output_padding=1 makes each transposed conv exactly double the
        # spatial size, undoing the stride-2 encoder convolutions.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(4, 12, 3, stride=2, padding=1, output_padding=1),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(12, 3, 3, stride=2, padding=1, output_padding=1),
            # Parameter-free slot, so state_dict keys are unchanged whichever
            # activation is used here.
            output_activation,
        )

    def forward(self, x):
        """Encode ``x`` to the bottleneck and return its reconstruction."""
        coded = self.encoder(x)
        decoded = self.decoder(coded)

        return decoded

In [20]:
net = Autoencoder()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)


num_epochs = 100
tb = SummaryWriter("../outputs/autoencoder_runs")

# Monotonically increasing step so each batch gets its own point on the loss
# curve (using `epoch` as the step piled every batch of an epoch onto one x).
global_step = 0

try:
    for epoch in range(num_epochs):
        for batch, _ in data_loader:
            reconstructed = net(batch)

            loss = criterion(reconstructed, batch)
            # .item() logs a plain float rather than a graph-attached tensor.
            tb.add_scalar("Loss/train", loss.item(), global_step)
            global_step += 1

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Log results with the last batch from the loop; detach so the
        # plotting helper never receives autograd history.
        grid = grid_plot(batch, reconstructed.detach())
        tb.add_image('Original vs Reconstructed', grid, epoch)
finally:
    # Flush and close the writer even if training is interrupted.
    tb.close()

In [None]:
import os

# Paths in this notebook are relative to its own directory ('../data/',
# '../outputs/autoencoder_runs'), so the checkpoint goes to '../outputs/' too.
os.makedirs('../outputs', exist_ok=True)
torch.save(net.state_dict(), '../outputs/autoencoder_model.pth')

In [None]:
# Scratch shape check for an alternative encoder design, run on one real
# batch from the loader so the cell is self-contained.
image, _ = next(iter(data_loader))
print(image.shape)

m = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1),  # 32x32 -> 32x32
    nn.LeakyReLU(),
    nn.MaxPool2d(2, 2),                                    # -> 16x16
    nn.Conv2d(16, 8, kernel_size=3, stride=2, padding=1),  # -> 8x8
    nn.LeakyReLU(),
    nn.MaxPool2d(2, 2),                                    # -> 4x4
    nn.Flatten(),
    # For 32x32 inputs the flattened size is 8 * 4 * 4 = 128 per sample, so a
    # projection layer appended here would be nn.Linear(128, ...).
)

m(image).shape