In [47]:
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt

import torchvision.datasets as datasets
import torchvision.transforms as transforms

from torchsummary import summary

In [48]:
# Global seed for torch's default RNG(s). Weight initialisation, the Gaussian
# noise drawn by add_noise, and (with the default sampler) DataLoader shuffling
# all draw from it, so seeding here makes "Restart & Run All" repeatable.
# NOTE: some CUDA kernels are nondeterministic regardless of the seed.
SEED = 0
torch.manual_seed(SEED)

# Prefer the first GPU when one is visible to torch; otherwise run on CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('DEVICE:', device)

DEVICE: cpu


In [49]:
# Where MNIST is downloaded / cached. Defined once instead of repeating a
# machine-specific absolute path; set the MNIST_ROOT environment variable to
# relocate it without editing the notebook. The default is the original path.
DATA_ROOT = os.environ.get('MNIST_ROOT', r'C:\Users\Administrator\Desktop\Dataset')

# ToTensor only: converts PIL images to float tensors in [0, 1], which is the
# range add_noise clips to and the decoder's sigmoid output covers.
transform = transforms.Compose([transforms.ToTensor()])

train_set = datasets.MNIST(root=DATA_ROOT,
                           transform=transform,
                           train=True,
                           download=True
                           )
test_set = datasets.MNIST(root=DATA_ROOT,
                          transform=transform,
                          train=False,
                          download=True
                          )

In [50]:
TRAIN_BATCH_SIZE = 128
TEST_BATCH_SIZE = 32

# Shuffle only the training data. The evaluation loader is only iterated to
# score the model, so a fixed order keeps validation repeatable and avoids
# consuming global RNG state between training epochs.
train_loader = DataLoader(train_set, batch_size=TRAIN_BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_set, batch_size=TEST_BATCH_SIZE, shuffle=False)

In [51]:
class Encoder(nn.Module):
    """Convolutional encoder: (N, 1, 28, 28) images -> (N, encoded_dim) codes.

    Three strided convolutions shrink the spatial size 28 -> 14 -> 7 -> 3. The
    resulting 32x3x3 feature map is flattened and projected through a
    128-unit hidden layer down to ``encoded_dim`` (no final activation).

    Sub-module attribute names and layer order match the original layout, so
    ``state_dict`` keys are unchanged.
    """

    def __init__(self, encoded_dim):
        super().__init__()

        conv_stack = [
            nn.Conv2d(1, 8, kernel_size=3, stride=2, padding=1),    # 28x28 -> 14x14
            nn.ReLU(),
            nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1),   # 14x14 -> 7x7
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=2, stride=2, padding=0),  # 7x7 -> 3x3
            nn.ReLU(),
        ]
        self.encoder_conv = nn.Sequential(*conv_stack)

        self.flatten = nn.Flatten(start_dim=1)

        flat_features = 32 * 3 * 3  # channels * height * width after encoder_conv
        self.encoder_linear = nn.Sequential(
            nn.Linear(flat_features, 128),
            nn.ReLU(),
            nn.Linear(128, encoded_dim),
        )

    def forward(self, x):
        """Encode a batch of images into latent vectors."""
        return self.encoder_linear(self.flatten(self.encoder_conv(x)))

In [52]:
class Decoder(nn.Module):
    """Convolutional decoder: (N, encoded_dim) codes -> (N, 1, 28, 28) images.

    Mirrors ``Encoder``. Two linear layers expand the code to a 32x3x3 feature
    map. Three transposed convolutions upsample it 3 -> 7 -> 14 -> 28, and a
    sigmoid squashes the output into [0, 1].

    Sub-module attribute names and layer order match the original layout, so
    ``state_dict`` keys are unchanged.
    """

    def __init__(self, encoded_dim):
        super().__init__()

        flat_features = 32 * 3 * 3  # must match the encoder's flattened size
        self.decoder_linear = nn.Sequential(
            nn.Linear(encoded_dim, 128),
            nn.ReLU(),
            nn.Linear(128, flat_features),
            nn.ReLU(),
        )

        self.unflatten = nn.Unflatten(dim=1, unflattened_size=(32, 3, 3))

        upsample_stack = [
            nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, output_padding=0),              # 3x3 -> 7x7
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 8, kernel_size=3, stride=2, padding=1, output_padding=1),   # 7x7 -> 14x14
            nn.BatchNorm2d(8),
            nn.ReLU(),
            nn.ConvTranspose2d(8, 1, kernel_size=3, stride=2, padding=1, output_padding=1),    # 14x14 -> 28x28
        ]
        self.decoder_conv = nn.Sequential(*upsample_stack)

    def forward(self, x):
        """Decode latent vectors into images with pixel values in [0, 1]."""
        feature_map = self.unflatten(self.decoder_linear(x))
        return torch.sigmoid(self.decoder_conv(feature_map))

In [53]:
ENCODED_DIM = 4  # latent size; the encoder's output and decoder's input must agree

# Build each half of the autoencoder and move it to the selected device
# (Module.to returns the module itself, so this can be chained).
encoder = Encoder(encoded_dim=ENCODED_DIM).to(device)
decoder = Decoder(encoded_dim=ENCODED_DIM).to(device)

# One parameter group per sub-network; neither overrides the shared
# lr / weight_decay given to Adam below.
params_to_oprimize = [{'params': module.parameters()} for module in (encoder, decoder)]
optimizer = optim.Adam(params=params_to_oprimize, lr=1e-3, weight_decay=1e-5)
loss_function = nn.MSELoss()

In [54]:
# Number of full passes over train_loader made by the training loop (range(EPOCHS)).
EPOCHS: int = 5

In [55]:
def add_noise(inputs, noise_factor=0.3):
    """Corrupt ``inputs`` with additive Gaussian noise and clip back to [0, 1].

    Args:
        inputs: float tensor of pixel intensities (expected in [0, 1]).
        noise_factor: standard deviation of the zero-mean Gaussian noise.

    Returns:
        A new tensor of the same shape, dtype and device as ``inputs``.
    """
    gaussian = noise_factor * torch.randn_like(inputs)
    return torch.clamp(inputs + gaussian, min=0.0, max=1.0)

In [56]:
# Std of the Gaussian corruption; used identically for training and validation.
NOISE_FACTOR = 0.3

for epoch in range(EPOCHS):
    # --- training: reconstruct the CLEAN image from a noisy copy (denoising objective) ---
    encoder.train()
    decoder.train()
    train_loss = []
    for imgs, _ in train_loader:
        imgs = imgs.to(device)
        image_noisy = add_noise(imgs, noise_factor=NOISE_FACTOR)
        optimizer.zero_grad()
        encoded_data = encoder(image_noisy)
        decoded_data = decoder(encoded_data)
        # The target must be the clean batch. Scoring against image_noisy trains the
        # network to reproduce the noise rather than remove it.
        loss = loss_function(decoded_data, imgs)
        loss.backward()
        optimizer.step()
        # detach() so the list stores loss values without keeping each step's
        # autograd history (grad_fn chain) referenced for the whole epoch.
        train_loss.append(loss.detach())
    # --- validation: same corruption as training, scored against the clean images ---
    encoder.eval()
    decoder.eval()
    with torch.no_grad():
        val_out_list = []
        val_labels_list = []
        for val_imgs, _ in test_loader:
            val_imgs = val_imgs.to(device)
            val_noisy = add_noise(val_imgs, noise_factor=NOISE_FACTOR)
            encoded_data = encoder(val_noisy)
            decoded_data = decoder(encoded_data)
            val_out_list.append(decoded_data)
            val_labels_list.append(val_imgs)
        val_out = torch.cat(val_out_list)
        val_labels = torch.cat(val_labels_list)
        val_loss = loss_function(val_out, val_labels)