## 1. ResNet-Based Auto-Encoder

### 1-1. Basic Block

In [None]:
import torch
import torch.nn as nn

class BasicBlock(nn.Module):
    """Two-conv residual block shared by the encoder and the decoder.

    ``mode="encode"`` builds ``Conv2d`` layers, so the first conv maps a spatial
    size ``H`` to ``ceil(H / stride)``. ``mode="decode"`` builds ``ConvTranspose2d``
    layers with ``output_padding=stride-1``, so the first conv multiplies the
    spatial size by ``stride``. When the residual branch changes the tensor shape,
    a 1x1 (transposed) conv + BatchNorm shortcut projects the identity to match.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output feature map.
        stride: stride of the first conv and of the shortcut projection.
        mode: ``"encode"`` or ``"decode"``.

    Raises:
        ValueError: if ``mode`` is neither ``"encode"`` nor ``"decode"``.
    """

    def __init__(self, in_channels, out_channels, stride, mode):
        super().__init__()

        if mode == 'encode':
            self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
            self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
            self.resize = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)
        elif mode == 'decode':
            # output_padding=stride-1 makes every transposed conv produce exactly
            # `stride` times the input spatial size.
            self.conv1 = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, output_padding=stride-1)
            self.conv2 = nn.ConvTranspose2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
            self.resize = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=1, stride=stride, output_padding=stride-1)
        else:
            # Previously any unknown mode silently produced a decoder block.
            raise ValueError(f"mode must be 'encode' or 'decode', got {mode!r}")

        # Residual branch: conv -> BN -> ReLU -> conv -> BN.
        self.conv = nn.Sequential(
            self.conv1,
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            self.conv2,
            nn.BatchNorm2d(out_channels)
        )

        # Projection shortcut, used only when the residual branch changes the shape.
        self.shortcut = nn.Sequential(
            self.resize,
            nn.BatchNorm2d(out_channels)
        )

        self.act = nn.ReLU()

    def forward(self, x):
        """Return ``ReLU(residual(x) + shortcut(x))``."""
        identity = x
        x = self.conv(x)

        # Plain identity skip when shapes already agree, projection otherwise.
        if x.shape != identity.shape:
            x += self.shortcut(identity)
        else:
            x += identity

        x = self.act(x)

        return x


### 1-2. Encoder & Decoder

In [None]:
class Encoder(nn.Module):
    """ResNet-style encoder.

    A 3x3 stem conv (3 -> 16 channels, stride 1) is followed by three stages of
    ``BasicBlock(mode="encode")`` with 16, 64 and 256 output channels. Each
    ``cfgs`` entry is ``(num_blocks, out_channels)``. The first block of every
    stage has stride 2, so the spatial size shrinks by 8 overall
    (e.g. 64x64 -> 8x8).
    """

    def __init__(self):
        super().__init__()

        self.cfgs = [(2, 16), (2, 64), (2, 256)]
        self.conv = nn.Sequential(
            nn.Conv2d(3, 16, 3, 1, 1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
        )

        stages = []
        width_in = 16
        for depth, width_out in self.cfgs:
            # downsampling block first, then shape-preserving blocks
            stages.append(BasicBlock(width_in, width_out, stride=2, mode="encode"))
            stages.extend(
                BasicBlock(width_out, width_out, stride=1, mode="encode")
                for _ in range(depth - 1)
            )
            width_in = width_out
        self.encode = nn.Sequential(*stages)

    def forward(self, x):
        """Return the latent feature map for image batch ``x``."""
        return self.encode(self.conv(x))

class Decoder(nn.Module):
    """ResNet-style decoder mirroring ``Encoder``.

    Three stages of ``BasicBlock(mode="decode")`` with 256, 64 and 16 output
    channels. The first block of every stage upsamples by 2, so a
    ``(B, 256, h, w)`` latent becomes ``(B, 16, 8h, 8w)``. A final 3x3 transposed
    conv maps 16 -> 3 channels, and a Sigmoid squashes the result into [0, 1].
    """

    def __init__(self):
        super().__init__()

        self.cfgs = [(2, 256), (2, 64), (2, 16)]

        stages = []
        width_in = 256
        for depth, width_out in self.cfgs:
            # upsampling block first, then shape-preserving blocks
            stages.append(BasicBlock(width_in, width_out, stride=2, mode="decode"))
            stages.extend(
                BasicBlock(width_out, width_out, stride=1, mode="decode")
                for _ in range(depth - 1)
            )
            width_in = width_out
        self.decode = nn.Sequential(*stages)

        self.de_conv = nn.Sequential(
            nn.ConvTranspose2d(16, 3, 3, 1, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Reconstruct an image batch from latent ``x``."""
        return self.de_conv(self.decode(x))

### 1-3. Network

In [None]:
class CNNAutoEncoder(nn.Module):
    """Convolutional auto-encoder made of ``Encoder`` followed by ``Decoder``.

    ``forward`` returns a ``(latent, reconstruction)`` pair.
    """

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        code = self.encoder(x)
        return code, self.decoder(code)

In [None]:
# Smoke test: run one random 64x64 RGB image through the CNN auto-encoder and show tensor shapes.
net = CNNAutoEncoder()
random_input = torch.randn((1, 3, 64, 64))
latent, random_output = net(random_input)
print(f"latent shape: {latent.shape}, output shape: {random_output.shape}")

## 2. Transformer-Based Auto-Encoder

### 2-1. Patch Embedding

In [None]:
class PatchEmbed(nn.Module):
    """Split an image into non-overlapping square patches and embed each one.

    A ``Conv2d`` with ``kernel_size == stride == patch_size`` maps each
    ``patch_size x patch_size`` patch to an ``embed_dim`` vector.

    Args:
        img_size: side length of the square input image.
        patch_size: side length of each square patch; must divide ``img_size``.
        in_chans: number of input channels.
        embed_dim: size of each patch embedding.

    Raises:
        ValueError: if ``img_size`` is not divisible by ``patch_size``.
    """

    def __init__(self, img_size, patch_size, in_chans=3, embed_dim=768):
        super().__init__()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if img_size % patch_size != 0:
            raise ValueError(
                f"img_size ({img_size}) is not divisible by patch_size ({patch_size})"
            )
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size
        )

    def forward(self, x):
        """Map ``(B, C, H, W)`` images to ``(B, num_patches, embed_dim)`` tokens.

        Tokens are ordered row-major over the patch grid.
        """
        x = self.proj(x)         # (B, embed_dim, H / p, W / p)
        x = x.flatten(2)         # (B, embed_dim, num_patches)
        return x.transpose(1, 2)

### 2-2. Network

In [None]:
class TransformerAutoEncoder(nn.Module):
    """Patch-based transformer auto-encoder.

    The image is split into ``(img_size // patch_size) ** 2`` patches. These are
    embedded with ``PatchEmbed`` plus a learned positional embedding, then encoded
    by a ``TransformerEncoder``; the encoder output is returned as the latent.

    A ``TransformerDecoder`` cross-attends from learned per-position query tokens
    to that latent. A linear head then maps every decoded token to its
    ``in_chans * patch_size * patch_size`` pixels, and ``_unpatchify`` reassembles
    them into the image (the exact inverse of the patch layout).

    All attention layers use ``batch_first=True`` because tokens are laid out as
    ``(batch, num_patches, embed_dim)``.

    Args:
        img_size: side length of the square input image.
        patch_size: side length of each square patch.
        in_chans: number of image channels.
        embed_dim: token width.
        num_heads: attention heads per layer.
        num_layers: layers in the encoder and, separately, in the decoder.
    """

    def __init__(self, img_size=64, patch_size=16, in_chans=3, embed_dim=768, num_heads=12, num_layers=12):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.in_chans = in_chans
        num_patches = (img_size // patch_size) ** 2

        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=embed_dim,
                nhead=num_heads,
                batch_first=True,
            ),
            num_layers=num_layers
        )
        # Learned decoder queries (one per patch position). Feeding the input
        # embeddings as `tgt` would give the decoder a residual path around the
        # encoder, so the latent would not be needed for reconstruction.
        self.dec_query = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model=embed_dim,
                nhead=num_heads,
                batch_first=True,
            ),
            num_layers=num_layers
        )
        self.proj = nn.Linear(embed_dim, in_chans * patch_size * patch_size)

        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.dec_query, std=0.02)

    def forward(self, x):
        """Return ``(latent, reconstruction)`` for an image batch ``x``.

        ``latent`` is ``(B, num_patches, embed_dim)``; ``reconstruction`` is
        ``(B, in_chans, img_size, img_size)``.
        """
        batch_size = x.shape[0]

        tokens = self.patch_embed(x) + self.pos_embed
        latent = self.encoder(tokens)

        queries = self.dec_query.expand(batch_size, -1, -1)
        decoded = self.decoder(queries, latent)

        patches = self.proj(decoded)  # (B, num_patches, in_chans * p * p)
        return latent, self._unpatchify(patches)

    def _unpatchify(self, patches):
        """Rebuild ``(B, C, H, W)`` images from row-major ``(B, N, C*p*p)`` patch vectors.

        Each patch vector is read as ``(C, p, p)``. Patches are placed on the
        ``img_size // patch_size`` grid in the same row-major order that
        ``PatchEmbed`` emits tokens.
        """
        p = self.patch_size
        grid = self.img_size // p
        batch_size = patches.shape[0]
        x = patches.reshape(batch_size, grid, grid, self.in_chans, p, p)
        x = x.permute(0, 3, 1, 4, 2, 5)  # (B, C, grid_h, p, grid_w, p)
        return x.reshape(batch_size, self.in_chans, grid * p, grid * p)

In [None]:
# Smoke test: run one random 64x64 RGB image through the transformer auto-encoder and show tensor shapes.
net = TransformerAutoEncoder()
random_input = torch.randn((1, 3, 64, 64))
latent, random_output = net(random_input)
print(f"latent shape: {latent.shape}, output shape: {random_output.shape}")

## 3. Self-Supervised Reconstruction Training on CIFAR-10

### 3-1. Libraries & Hyperparameters

In [None]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt

from torch.utils.data import DataLoader
from torch.nn import MSELoss
from torch.optim import Adam
from tqdm import tqdm

from torchvision.transforms import v2
from torchvision.datasets import CIFAR10
import torchmetrics

In [None]:
# Training hyper-parameters
IMG_SIZE = 64        # spatial size produced by the data transforms
BATCH_SIZE = 8
LR = 1e-3            # Adam learning rate
EPOCH = 50           # number of training epochs

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

### 3-2. Plot Functions

In [None]:
def plot_loss(save_name, train_losses, val_losses):
    """Save the train/validation loss curves to ``./result/{save_name}_loss.png``.

    Epochs are numbered from 1 on the x-axis. The ``./result`` directory must
    already exist.
    """
    title = f"{save_name}_loss"

    plt.figure(figsize=(10, 5))
    for series, series_label in ((train_losses, 'Train_Loss'), (val_losses, 'Validation_Loss')):
        epochs_axis = range(1, len(series) + 1)
        plt.plot(epochs_axis, series, label=series_label, marker='o')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')

    plt.title(title)
    plt.legend()
    plt.grid()
    plt.savefig(f'./result/{title}.png')
    plt.close()

In [None]:
def plot_img(save_name, view_data, decoded_data, epoch):
    """Save a 2x5 grid: originals on the top row, reconstructions on the bottom.

    The first five images of ``view_data`` and ``decoded_data`` are moved to the
    CPU and converted from CHW to HWC. Each image's min/max is printed, then it is
    clipped to [0, 1] for display. The grid is written to
    ``./result/epoch_{save_name}/epoch_{epoch}_images.png``.
    """
    fig, axes = plt.subplots(nrows=2, ncols=5, figsize=(15, 6))
    for slot, ax in enumerate(axes.flat):
        ax.axis('off')
        top_row = slot < 5

        if top_row:
            image = view_data[slot].detach().cpu().permute(1, 2, 0).numpy()
            print(f"Original Image - Min: {image.min()}, Max: {image.max()}")
        else:
            image = decoded_data[slot - 5].detach().cpu().permute(1, 2, 0).numpy()
            print(f"Decoded Image - Min: {image.min()}, Max: {image.max()}")

        ax.imshow(np.clip(image, 0, 1))
        ax.set_title('Original' if top_row else 'Decoded')

    plt.tight_layout()
    plt.savefig(f'./result/epoch_{save_name}/epoch_{epoch}_images.png')
    plt.close(fig)

### 3-3. Dataset & DataLoader

In [None]:
# CIFAR-10 preprocessing.
# CIFAR-10 images are 32x32, so they are *resized* to IMG_SIZE; CenterCrop to a
# larger size would zero-pad them into a black border.
# No Normalize: the auto-encoders reconstruct their own input, the CNN decoder
# ends in a Sigmoid, and plot_img clips to [0, 1], so images must stay in [0, 1].
train_transforms = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.uint8),
    v2.Resize(size=(IMG_SIZE, IMG_SIZE), antialias=True),
    v2.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
    v2.RandomHorizontalFlip(p=0.5),
    v2.ToDtype(torch.float32, scale=True),
])

val_transforms = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.uint8),
    v2.Resize(size=(IMG_SIZE, IMG_SIZE), antialias=True),
    v2.ToDtype(torch.float32, scale=True),
])

# CIFAR-10 train / test splits (downloaded into the shared datasets folder if missing)
train_dataset = CIFAR10(
    root='../../datasets/CIFAR10',
    train=True,
    download=True,
    transform=train_transforms,
)
val_dataset = CIFAR10(
    root='../../datasets/CIFAR10',
    train=False,
    download=True,
    transform=val_transforms,
)

# Data loaders: only the training split is shuffled
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=4,
)

### 3-4. Train & Evaluation

In [None]:
# train code
def train(net, train_loader, criterion, optimizer, scaler, device):
    """Run one reconstruction-training epoch and return the mean batch loss.

    Labels from the loader are ignored; each batch is its own target.

    Args:
        net: model whose forward returns ``(latent, reconstruction)``.
        train_loader: iterable of ``(inputs, labels)`` batches.
        criterion: loss called as ``criterion(reconstruction, inputs)``.
        optimizer: optimizer over ``net``'s parameters.
        scaler: a ``GradScaler`` for mixed precision, or ``None`` for plain fp32.
        device: ``torch.device`` or device string the batches are moved to.
    """
    device_type = torch.device(device).type
    train_loss = 0.0

    net.train()
    for inputs, _ in tqdm(train_loader):
        inputs = inputs.to(device)
        optimizer.zero_grad()

        if scaler is not None:
            # Device-aware autocast (the torch.cuda.amp.autocast spelling is
            # deprecated); mixed precision is only enabled on CUDA.
            with torch.autocast(device_type=device_type, enabled=device_type == "cuda"):
                _, outputs = net(inputs)
                loss = criterion(outputs, inputs)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            _, outputs = net(inputs)
            loss = criterion(outputs, inputs)

            loss.backward()
            optimizer.step()
        train_loss += loss.item()

    # max(..., 1) avoids ZeroDivisionError on an empty loader.
    return train_loss / max(len(train_loader), 1)


In [None]:
# evaluate code
def eval(net, val_loader, criterion, device):
    """Evaluate reconstruction quality and return the mean batch loss.

    Prints the batch-averaged loss, PSNR and SSIM. Labels are ignored; each batch
    is compared with its own reconstruction. Note: this name shadows the builtin
    ``eval``; it is kept because ``trainer`` calls it by this name.

    Args:
        net: model whose forward returns ``(latent, reconstruction)``.
        val_loader: iterable of ``(inputs, labels)`` batches.
        criterion: loss called as ``criterion(reconstruction, inputs)``.
        device: ``torch.device`` or device string the batches are moved to.
    """
    device_type = torch.device(device).type
    use_amp = device_type == "cuda"
    val_loss = 0.0
    psnr = 0.0
    ssim = 0.0
    psnr_metric = torchmetrics.image.PeakSignalNoiseRatio().to(device)
    ssim_metric = torchmetrics.image.StructuralSimilarityIndexMeasure().to(device)

    net.eval()
    with torch.no_grad():
        for inputs, _ in tqdm(val_loader):
            inputs = inputs.to(device)

            # Mixed precision only on CUDA; elsewhere run in full precision.
            with torch.autocast(device_type=device_type, enabled=use_amp):
                _, outputs = net(inputs)
                loss = criterion(outputs, inputs)

            # Autocast outputs may be float16 while `inputs` is float32; compute the
            # metrics in float32 so preds and target share a dtype.
            outputs = outputs.float()

            val_loss += loss.item()
            psnr += psnr_metric(outputs, inputs).item()
            ssim += ssim_metric(outputs, inputs).item()

    num_batches = max(len(val_loader), 1)  # avoid ZeroDivisionError on an empty loader
    val_loss /= num_batches
    psnr /= num_batches
    ssim /= num_batches
    print(f"Test Loss: {val_loss:.4f}, PSNR: {psnr:.2f}, SSIM: {ssim:.4f}")

    return val_loss

In [None]:
def trainer(net, save_name, train_loader, val_loader, criterion, optimizer, scaler, device, epochs=None):
    """Train ``net`` as an auto-encoder, checkpointing and plotting as it goes.

    Each epoch runs ``train`` and ``eval``. Whenever the validation loss improves,
    the weights are saved to ``./pth/{save_name}.pth``. A reconstruction grid of
    five fixed training images is written by ``plot_img`` to
    ``./result/epoch_{save_name}/``. After the last epoch, ``plot_loss`` saves the
    loss curves.

    Args:
        net: model whose forward returns ``(latent, reconstruction)``.
        save_name: tag used in checkpoint and figure file names.
        train_loader, val_loader: loaders of ``(inputs, labels)`` batches.
        criterion, optimizer, scaler, device: forwarded to ``train`` / ``eval``.
        epochs: number of epochs; defaults to the global ``EPOCH``.
    """
    num_epochs = EPOCH if epochs is None else epochs

    # Five samples from the first (shuffled) training batch, reused every epoch
    # so the saved grids are comparable across epochs.
    view_data = next(iter(train_loader))[0][:5].to(device)

    train_losses = []
    val_losses = []
    best_loss = float('inf')

    os.makedirs(f"result/epoch_{save_name}", exist_ok=True)
    os.makedirs("pth", exist_ok=True)  # torch.save does not create missing directories

    for epoch in range(num_epochs):
        train_loss = train(net, train_loader, criterion, optimizer, scaler, device)
        train_losses.append(train_loss)

        val_loss = eval(net, val_loader, criterion, device)
        val_losses.append(val_loss)

        print(f"Epoch [{epoch+1}/{num_epochs}]")
        print(f"  Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

        if val_loss < best_loss:
            best_loss = val_loss
            torch.save(net.state_dict(), f'./pth/{save_name}.pth')

        # `eval` left the net in eval mode; no_grad avoids building an unused graph.
        with torch.no_grad():
            _, decoded_data = net(view_data)

        plot_img(save_name, view_data, decoded_data, epoch)

    plot_loss(save_name, train_losses, val_losses)

## 4. Train CNN Auto-Encoder

In [None]:
save_name = "cnn_cifar"
net = CNNAutoEncoder().to(device)

criterion = MSELoss()
optimizer = Adam(net.parameters(), lr=LR)
# A CUDA GradScaler disables itself (with a warning) when CUDA is unavailable.
# On CPU, pass None instead so `train` takes its plain fp32 branch.
scaler = torch.cuda.amp.GradScaler() if device.type == "cuda" else None

In [None]:
trainer(
    net=net,
    save_name=save_name,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    scaler=scaler,
    device=device,
)

## 5. Train Transformer Auto-Encoder

In [None]:
save_name = "transformer_cifar"
net = TransformerAutoEncoder().to(device)

criterion = MSELoss()
optimizer = Adam(net.parameters(), lr=LR)
# A CUDA GradScaler disables itself (with a warning) when CUDA is unavailable.
# On CPU, pass None instead so `train` takes its plain fp32 branch.
scaler = torch.cuda.amp.GradScaler() if device.type == "cuda" else None

In [None]:
trainer(
    net=net,
    save_name=save_name,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    scaler=scaler,
    device=device,
)