In [77]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import torchvision.datasets as datasets
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
import numpy as np
from tqdm.notebook import tqdm

In [78]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('DEVICE:',device)

DEVICE: cpu


In [79]:
trans = transforms.Compose([transforms.ToTensor()])

train_dataset = datasets.MNIST(root=r'C:\Users\Administrator\Desktop\Dataset',
                                                  download=True,
                                                  transform=trans,
                                                  train=True)

test_dataset = datasets.MNIST(root=r'C:\Users\Administrator\Desktop\Dataset',
                                                 download=True,
                                                 transform=trans,
                                                 train=False)

train_loader = DataLoader(dataset=train_dataset,batch_size=128,shuffle=True,drop_last=True)
test_loader = DataLoader(dataset=test_dataset,batch_size=32,shuffle=False,drop_last=True)

In [80]:
class Encoder(nn.Module):
    def __init__(self,encoded_space_dim,fc2_input_dim):
        super().__init__()

        self.encoder_cnn = nn.Sequential(
            nn.Conv2d(in_channels=1,out_channels=8,kernel_size=3,stride=2,padding=1),
            nn.ReLU(True),
            nn.Conv2d(in_channels=8,out_channels=16,kernel_size=3,stride=2,padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.Conv2d(in_channels=16,out_channels=32,kernel_size=3,stride=2,padding=0),
            nn.ReLU(True)
        )

        self.flatten = nn.Flatten(start_dim=1)
        self.encoder_lin = nn.Sequential(
            nn.Linear(3*3*32,128),
            nn.ReLU(True),
            nn.Linear(128,encoded_space_dim)
        )

    def forward(self,x):
        x = self.encoder_cnn(x)
        x = self.flatten(x)
        x = self.encoder_lin(x)
        return x

In [81]:
class Decoder(nn.Module):
    def __init__(self, encoded_space_dim,fc2_input_dim):
        super().__init__()

        self.decoder_lin = nn.Sequential(
            nn.Linear(encoded_space_dim,128),
            nn.ReLU(True),
            nn.Linear(128,3*3*32),
            nn.ReLU(True)
        )

        self.unflatten = nn.Unflatten(dim=1, unflattened_size=(32,3,3))

        self.decoder_conv = nn.Sequential(
            nn.ConvTranspose2d(in_channels=32,out_channels=16,kernel_size=3,stride=2,output_padding=0),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.ConvTranspose2d(in_channels=16,out_channels=8,kernel_size=3,stride=2,padding=1,output_padding=1),
            nn.BatchNorm2d(8),
            nn.ReLU(True),
            nn.ConvTranspose2d(in_channels=8,out_channels=1,kernel_size=3,stride=2,padding=1,output_padding=1)
        )

    def forward(self,x):
        x = self.decoder_lin(x)
        x = self.unflatten(x)
        x = self.decoder_conv(x)
        x = torch.sigmoid(x)
        return x

In [82]:
encoder = Encoder(encoded_space_dim=4, fc2_input_dim=128)
decoder = Decoder(encoded_space_dim=4, fc2_input_dim=128)
encoder = encoder.to(device)
decoder = decoder.to(device)

In [83]:
def add_noise(inputs,noise_factor=0.3):
    noisy = inputs + torch.randn_like(inputs) * noise_factor
    noisy = torch.clip(noisy,0.,1.)
    return noisy

In [84]:
params_to_optimizer = [{'params':encoder.parameters()},{'params':decoder.parameters()}]
optimizer = torch.optim.Adam(params_to_optimizer,lr=0.001,weight_decay=1e-5)
loss_fn = nn.MSELoss()

num_epochs = 30

In [85]:
for epoch in tqdm(range(num_epochs)):
    train_loss = list()
    encoder.train()
    decoder.train()
    for img_batch,_ in tqdm(train_loader):
        image_noisy = add_noise(img_batch,noise_factor=0.3)
        image_noisy = image_noisy.to(device)

        encoded_data = encoder(image_noisy)
        decoded_data = decoder(encoded_data)

        loss = loss_fn(decoded_data, image_noisy)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss.append(loss.detach().cpu().numpy())
    avg_loss = np.mean(train_loss)

    encoder.eval()
    decoder.eval()
    with torch.no_grad():
        origin_images = list()
        outputs = list()

        for image_batch,_ in test_dataset:
            image_batch = image_batch.to(device)
            encoded_data = encoder(img_batch)
            decoded_data = decoder(encoded_data)
            outputs.append(decoded_data.cpu())
            origin_images.append(image_batch.cpu())

        outputs = torch.cat(outputs)
        origin_images = torch.cat(origin_images)
        val_loss = loss_fn(outputs,origin_images)
    print(f'[TRAIN LOSS: {avg_loss}] [VAL LOSS: {val_loss}]')

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/468 [00:00<?, ?it/s]

  return F.mse_loss(input, target, reduction=self.reduction)


RuntimeError: [enforce fail at C:\cb\pytorch_1000000000000\work\c10\core\impl\alloc_cpu.cpp:72] data. DefaultCPUAllocator: not enough memory: you tried to allocate 40140800000000 bytes.