<a href="https://colab.research.google.com/github/Akbaradityafirmansyah/autoencoder-cifar10/blob/main/autoencoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR10
from PIL import Image
import matplotlib.pyplot as plt

# Dataset Autoencoder CIFAR-10
class CIFAR10AutoencoderDataset(Dataset):
    """Pair each CIFAR-10 image with two transformed views of itself.

    ``__getitem__`` returns ``(transform_input(img), transform_output(img))``
    built from the same underlying image; the class label is discarded. Only
    the first ``len(self)`` images of the chosen split are exposed.

    Args:
        train: Use the training split if True, the test split otherwise.
        transform_input: Callable applied to the image to build the model
            input. ``None`` returns the image unchanged.
        transform_output: Callable applied to the same image to build the
            target. ``None`` returns the image unchanged.
        max_samples: Upper bound on the number of images exposed; clamped to
            the size of the split. ``None`` exposes the whole split.
        root: Directory where CIFAR-10 is stored / downloaded to.
    """

    def __init__(self, train=True, transform_input=None, transform_output=None,
                 max_samples=100, root='./data'):
        self.dataset = CIFAR10(root=root, train=train, download=True)
        self.transform_input = transform_input
        self.transform_output = transform_output
        self.max_samples = max_samples

    def __len__(self):
        # Clamp to the real split size so a sampler never asks for an index
        # the underlying dataset does not have.
        total = len(self.dataset)
        if self.max_samples is None:
            return total
        return max(0, min(self.max_samples, total))

    def __getitem__(self, idx):
        """Return ``(input_img, output_img)`` for the ``idx``-th exposed image.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        n = len(self)
        # Resolve indices against the exposed subset (not the full split) so
        # negative indices and iteration behave like a normal sequence.
        pos = idx + n if idx < 0 else idx
        if not 0 <= pos < n:
            raise IndexError(f"index {idx} out of range for dataset of size {n}")
        img, _ = self.dataset[pos]
        input_img = img if self.transform_input is None else self.transform_input(img)
        output_img = img if self.transform_output is None else self.transform_output(img)
        return input_img, output_img

# Autoencoder Model
class Autoencoder(nn.Module):
    """Convolutional autoencoder for 3-channel images.

    The encoder applies two stride-2 convolutions (3 -> 64 -> 128 channels).
    The decoder mirrors them with transposed convolutions (128 -> 64 -> 3)
    and maps the result into [0, 1] with a sigmoid. For an input of shape
    (N, 3, H, W) with H and W divisible by 4, the output has the same shape.
    """

    def __init__(self):
        super().__init__()
        # kernel 4 / stride 2 / padding 1 halves an even H and W
        # (the transposed version doubles them back).
        geom = dict(kernel_size=4, stride=2, padding=1)
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, **geom),
            nn.ReLU(),
            nn.Conv2d(64, 128, **geom),
            nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, **geom),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, **geom),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Encode ``x`` into a feature map and decode it back to image space."""
        return self.decoder(self.encoder(x))

# Transforms: both views resize the image to 128x128 before converting it to
# a tensor. The input keeps its colours; the target is the grayscale version
# of the same image, replicated over 3 channels so it matches the model's
# 3-channel output.
transform_input = transforms.Compose(
    [
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
    ]
)

transform_output = transforms.Compose(
    [
        transforms.Resize((128, 128)),
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
    ]
)

# Dataset and DataLoader: the first 100 training images, shuffled into
# batches of 8.
dataset = CIFAR10AutoencoderDataset(
    train=True,
    transform_input=transform_input,
    transform_output=transform_output,
    max_samples=100,
)
loader = DataLoader(dataset, batch_size=8, shuffle=True)

# Model setup: use the GPU when CUDA is available, otherwise the CPU.
# The loss is mean squared error between reconstruction and target.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = Autoencoder()
model = model.to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Training loop.
# The reported epoch loss is the per-sample mean MSE: each batch's mean loss is
# weighted by its batch size. When the dataset size is not a multiple of the
# batch size, the last batch is smaller, and a plain average of batch means
# would over-weight it.
loss_list = []
epochs = 20
# The evaluation code further down switches the model to eval mode; make sure
# training (e.g. on a notebook re-run) always happens in train mode.
model.train()
for epoch in range(epochs):
    total_loss = 0.0
    seen = 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        output = model(x)
        loss = criterion(output, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        n_batch = x.size(0)
        total_loss += loss.item() * n_batch
        seen += n_batch
    # Guard against an empty loader instead of raising ZeroDivisionError.
    avg_loss = total_loss / max(seen, 1)
    loss_list.append(avg_loss)
    print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}")

# Save the training-loss curve (one point per epoch) to results/loss_plot.png.
os.makedirs("results", exist_ok=True)
fig, ax = plt.subplots()
ax.plot(range(1, epochs + 1), loss_list, marker='o')
ax.set_title('Training Loss per Epoch')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.grid(True)
fig.savefig('results/loss_plot.png')
plt.close(fig)

# Save a few example reconstructions (input / target / output) under results/.
def _to_hwc(t):
    """Convert a (C, H, W) CPU tensor into an (H, W, C) NumPy array for matplotlib."""
    return t.permute(1, 2, 0).numpy()


os.makedirs("results", exist_ok=True)
model.eval()
# Don't index past a dataset that exposes fewer than 5 samples.
num_examples = min(5, len(dataset))
with torch.no_grad():
    for i in range(num_examples):
        x, y = dataset[i]
        # squeeze(0) removes only the batch dimension added by unsqueeze(0);
        # a bare squeeze() would also drop any other size-1 dimension.
        out = _to_hwc(model(x.unsqueeze(0).to(device)).cpu().squeeze(0))
        inp_img = _to_hwc(x)
        tgt_img = _to_hwc(y)

        plt.imsave(f"results/output_{i}.png", out)
        plt.imsave(f"results/input_{i}.png", inp_img)
        plt.imsave(f"results/target_{i}.png", tgt_img)

        # Side-by-side comparison figure: input, target, output.
        fig, axs = plt.subplots(1, 3, figsize=(9, 3))
        panels = zip(axs, (inp_img, tgt_img, out), ('Input', 'Target', 'Output'))
        for ax, img, title in panels:
            ax.imshow(img)
            ax.set_title(title)
            ax.axis('off')
        fig.tight_layout()
        fig.savefig(f"results/comparison_{i}.png")
        plt.close(fig)


Epoch [1/20], Loss: 0.0416
Epoch [2/20], Loss: 0.0103
Epoch [3/20], Loss: 0.0054
Epoch [4/20], Loss: 0.0031
Epoch [5/20], Loss: 0.0024
Epoch [6/20], Loss: 0.0019
Epoch [7/20], Loss: 0.0015
Epoch [8/20], Loss: 0.0013
Epoch [9/20], Loss: 0.0010
Epoch [10/20], Loss: 0.0008
Epoch [11/20], Loss: 0.0006
Epoch [12/20], Loss: 0.0006
Epoch [13/20], Loss: 0.0006
Epoch [14/20], Loss: 0.0005
Epoch [15/20], Loss: 0.0005
Epoch [16/20], Loss: 0.0005
Epoch [17/20], Loss: 0.0005
Epoch [18/20], Loss: 0.0004
Epoch [19/20], Loss: 0.0004
Epoch [20/20], Loss: 0.0004
