<a href="https://colab.research.google.com/github/Pmilivojevic/PyTorch/blob/main/AutoEncoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.nn.modules.activation import ReLU
import matplotlib.pyplot as plt
from torch.ao.nn.quantized.modules.conv import ConvTranspose2d

# Pick the compute device once: the GPU when PyTorch can see one, the CPU otherwise.
if torch.cuda.is_available():
  device = 'cuda'
else:
  device = 'cpu'

In [None]:
# Images are converted to float tensors; the range check further down shows
# the resulting values lie in [0, 1].
transform = transforms.ToTensor()

# CIFAR-10 training split, downloaded into the Drive folder on request.
cifar10_dataset = datasets.CIFAR10(
    '/content/drive/MyDrive/Colab Notebooks/Dataset',
    train=True,
    transform=transform,
    download=True,
)

# Shuffled mini-batches of 64 images.
data_loader = torch.utils.data.DataLoader(
    cifar10_dataset,
    batch_size=64,
    shuffle=True,
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /content/drive/MyDrive/Colab Notebooks/Dataset/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:02<00:00, 66711358.02it/s]


Extracting /content/drive/MyDrive/Colab Notebooks/Dataset/cifar-10-python.tar.gz to /content/drive/MyDrive/Colab Notebooks/Dataset


In [None]:
# Pull a single batch and print its value range as a sanity check on the
# input scaling.
dataiter = iter(data_loader)
images, labels = next(dataiter)
lowest, highest = images.min(), images.max()
print(lowest, highest)

tensor(0.) tensor(1.)


In [None]:

class AutoEncoderL(nn.Module):
  """Fully connected autoencoder over flattened 3x32x32 images.

  The encoder narrows 3072 -> 1024 -> 256 -> 64 -> 32 -> 9 with ReLU between
  the linear layers (none after the 9-unit bottleneck). The decoder mirrors
  those widths back up to 3072 and ends in a Sigmoid, so every output value
  lies in (0, 1).
  """

  def __init__(self):
    super().__init__()

    # Layer widths from the flattened input down to the bottleneck.
    widths = [3*32*32, 1024, 256, 64, 32, 9]

    enc_layers = []
    for fan_in, fan_out in zip(widths[:-1], widths[1:]):
      enc_layers.append(nn.Linear(fan_in, fan_out))
      enc_layers.append(nn.ReLU())
    enc_layers.pop()  # the bottleneck output is left un-activated
    self.encoder = nn.Sequential(*enc_layers)

    mirrored = widths[::-1]
    dec_layers = []
    for fan_in, fan_out in zip(mirrored[:-1], mirrored[1:]):
      dec_layers.append(nn.Linear(fan_in, fan_out))
      dec_layers.append(nn.ReLU())
    dec_layers[-1] = nn.Sigmoid()  # final activation squashes into (0, 1)
    self.decoder = nn.Sequential(*dec_layers)

  def forward(self, x):
    """Encode `x` (last dimension of size 3072) and return its reconstruction."""
    return self.decoder(self.encoder(x))

In [None]:
class AutoEncoderC(nn.Module):
  """Convolutional autoencoder for 3-channel images.

  Given a 3x32x32 input, three stride-2 convolutions halve the spatial size
  each time (32 -> 16 -> 8 -> 4) and a final 4x4 convolution collapses it to
  a 128x1x1 code. The decoder reverses this with transposed convolutions and
  ends in a Sigmoid, so reconstructed values lie in (0, 1).
  """

  def __init__(self):
    super().__init__()

    self.encoder = nn.Sequential(
        nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1),
        nn.ReLU(),
        nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
        nn.ReLU(),
        nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
        nn.ReLU(),
        nn.Conv2d(64, 128, kernel_size=4),
    )

    # output_padding=1 makes each stride-2 transposed conv exactly double
    # the spatial size, mirroring the encoder's halving.
    self.decoder = nn.Sequential(
        nn.ConvTranspose2d(128, 64, kernel_size=4),
        nn.ReLU(),
        nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
        nn.ReLU(),
        nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),
        nn.ReLU(),
        nn.ConvTranspose2d(16, 3, kernel_size=3, stride=2, padding=1, output_padding=1),
        nn.Sigmoid(),
    )

  def forward(self, x):
    """Return the reconstruction of the image batch `x`."""
    return self.decoder(self.encoder(x))

In [None]:
# Convolutional autoencoder, placed on the selected device
# (Module.to moves the parameters in place).
model = AutoEncoderC()
model.to(device)

# Pixel-wise mean squared error between reconstruction and input.
criterion = nn.MSELoss()

# Adam with a small weight decay on all model parameters.
optimizer = optim.Adam(
    params=model.parameters(),
    lr=1e-3,
    weight_decay=1e-5,
)

In [None]:
num_epochs = 180
outputs = []  # one (epoch, input batch, reconstruction) tuple per epoch, kept for plotting

for epoch in range(num_epochs):
  for image, _ in data_loader:
    # The model was moved to `device`, so every batch has to follow it;
    # feeding CPU tensors to a CUDA model raises a device-mismatch error.
    # For the fully connected AutoEncoderL, additionally flatten the batch:
    #   image = image.reshape(image.shape[0], -1)
    image = image.to(device)

    recon = model(image)
    loss = criterion(recon, image)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

  # Loss reported is that of the last mini-batch of the epoch.
  print(f'Epoch: {epoch+1}, Loss: {loss.item():.4f}')
  # Store detached CPU copies: keeping `recon` as-is would hold on to its
  # autograd graph (and device memory) for every one of the stored epochs.
  outputs.append((epoch, image.detach().cpu(), recon.detach().cpu()))

Epoch: 1, Loss: 0.0178
Epoch: 2, Loss: 0.0146
Epoch: 3, Loss: 0.0104
Epoch: 4, Loss: 0.0096
Epoch: 5, Loss: 0.0086
Epoch: 6, Loss: 0.0080
Epoch: 7, Loss: 0.0090
Epoch: 8, Loss: 0.0075
Epoch: 9, Loss: 0.0061
Epoch: 10, Loss: 0.0079
Epoch: 11, Loss: 0.0064
Epoch: 12, Loss: 0.0050
Epoch: 13, Loss: 0.0061
Epoch: 14, Loss: 0.0060
Epoch: 15, Loss: 0.0064
Epoch: 16, Loss: 0.0054
Epoch: 17, Loss: 0.0060
Epoch: 18, Loss: 0.0059
Epoch: 19, Loss: 0.0045
Epoch: 20, Loss: 0.0049
Epoch: 21, Loss: 0.0050
Epoch: 22, Loss: 0.0040
Epoch: 23, Loss: 0.0056
Epoch: 24, Loss: 0.0062
Epoch: 25, Loss: 0.0059
Epoch: 26, Loss: 0.0049
Epoch: 27, Loss: 0.0052
Epoch: 28, Loss: 0.0048
Epoch: 29, Loss: 0.0061
Epoch: 30, Loss: 0.0052
Epoch: 31, Loss: 0.0050
Epoch: 32, Loss: 0.0063
Epoch: 33, Loss: 0.0050
Epoch: 34, Loss: 0.0054
Epoch: 35, Loss: 0.0048
Epoch: 36, Loss: 0.0052
Epoch: 37, Loss: 0.0048
Epoch: 38, Loss: 0.0050
Epoch: 39, Loss: 0.0053
Epoch: 40, Loss: 0.0045
Epoch: 41, Loss: 0.0050
Epoch: 42, Loss: 0.0054
E

In [None]:
# Show every 6th stored epoch: up to 9 input images on the top row and their
# reconstructions on the bottom row. Iterating over len(outputs) rather than
# num_epochs keeps the cell usable when training was stopped early.
for k in range(0, len(outputs), 6):
  imgs = outputs[k][1].cpu().detach().numpy()
  recons = outputs[k][2].cpu().detach().numpy()

  # A fresh figure per epoch; drawing on the implicit current figure would
  # paint all epochs over the same 2x9 grid and leave only the last visible.
  fig, axes = plt.subplots(2, 9, figsize=(18, 4))
  fig.suptitle(f'Epoch {outputs[k][0] + 1}: inputs (top) vs. reconstructions (bottom)')

  for row_axes, batch in zip(axes, (imgs, recons)):
    for ax, sample in zip(row_axes, batch[:9]):
      # reshape handles flattened samples too; transpose turns
      # channel-first (C, H, W) into the (H, W, C) layout imshow expects.
      ax.imshow(sample.reshape(-1, 32, 32).transpose(1, 2, 0))
    for ax in row_axes:
      ax.axis('off')  # also blanks any unused slot when the batch has < 9 samples

  plt.show()
  plt.close(fig)  # release the figure so dozens of them do not pile up in memory