In [None]:
import torch 
from torch import nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.utils import make_grid
import numpy as np
from tqdm import tqdm
import os, time
from datetime import datetime

from skimage.metrics import structural_similarity as ssim_metric
from skimage.metrics import peak_signal_noise_ratio as psnr_metric
from skimage.metrics import normalized_root_mse as nrmse_metric
from torchvision.utils import save_image
from pytorch_fid import fid_score

In [None]:
def load_latest_checkpoint(checkpoint_dir, model, optimizer, map_location="cpu"):
    """
    Find the most recently modified checkpoint (.pth file) and restore it.

    Parameters:
    - checkpoint_dir (str): directory that holds the checkpoint files
    - model (torch.nn.Module): model whose weights are restored in place
    - optimizer (torch.optim.Optimizer): optimizer whose state is restored in place
    - map_location: forwarded to ``torch.load``. Defaults to "cpu" so that a
      checkpoint written on a GPU machine can still be opened on a machine
      without that device; ``load_state_dict`` then copies the values into the
      model's existing parameters.

    Returns:
    - int: the epoch stored in the checkpoint (0 if no checkpoint is found or
      the checkpoint has no 'epoch' entry)
    """

    # 1. Collect the .pth files; a missing directory is treated like an empty one
    #    instead of letting os.listdir raise FileNotFoundError.
    if os.path.isdir(checkpoint_dir):
        checkpoint_files = [f for f in os.listdir(checkpoint_dir) if f.endswith('.pth')]
    else:
        checkpoint_files = []

    # 2. Nothing to restore: report it and start from epoch 0
    if not checkpoint_files:
        print("No Check Point Files. Start New Training.")
        return 0

    # 3. Pick the file with the newest modification time
    latest_checkpoint = max(
        (os.path.join(checkpoint_dir, f) for f in checkpoint_files),
        key=os.path.getmtime,
    )
    print(f"The Latest Check Point: {latest_checkpoint}")

    # 4. Read the checkpoint from disk
    checkpoint = torch.load(latest_checkpoint, map_location=map_location)

    # 5. Apply the saved state to the model and the optimizer
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    # 6. Return the saved epoch
    return checkpoint.get('epoch', 0)

In [None]:
# Display the installed PyTorch version so the environment used for this run is recorded.
torch.__version__

In [None]:
# Ask the inline Matplotlib backend to render figures in the 'retina' format.
%config InlineBackend.figure_format='retina'

In [None]:
# Seed PyTorch's global random number generator so that runs are repeatable.
torch.manual_seed(316)

## Load Data

In [None]:
# Device setup
def get_device():
    """Return the torch.device to compute on: MPS if usable, otherwise CUDA, otherwise CPU."""
    mps_backend = torch.backends.mps
    # MPS is chosen only when it is both available at runtime and compiled into this build.
    if mps_backend.is_available() and mps_backend.is_built():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")
# Resolve the compute device once; later cells move the model and batches onto it.
device = get_device()

# Per-image preprocessing: reduce to a single grayscale channel, then convert to a tensor.
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
])

### Download

In [None]:
# Dataset
class AlohaImageDataset(Dataset):
    """
    Map-style dataset that serves one camera image per record of a Hugging Face dataset.

    Parameters:
    - hf_dataset: indexable collection of records; each record is indexed by
      ``camera_key`` to obtain the image
    - camera_key (str): key of the image field to read from every record
    - image_transform (callable, optional): preprocessing applied to each image.
      When omitted, the notebook-level ``transform`` pipeline is used, which is
      what this class always did before the parameter existed.

    Each item is a ``(image, 0)`` tuple; the constant 0 is a placeholder label so
    the dataset can be iterated as ``for images, _ in loader``.
    """

    def __init__(self, hf_dataset, camera_key="observation.images.top", image_transform=None):
        self.data = hf_dataset
        self.camera_key = camera_key
        # Fall back to the module-level pipeline only when the caller supplies none,
        # so the dataset no longer has to depend on a hidden global.
        self.transform = transform if image_transform is None else image_transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        image = self.data[idx][self.camera_key]
        image = self.transform(image)
        return image, 0

In [None]:
# Load the "lerobot/aloha_sim_insertion_human_image" dataset through Hugging Face `datasets`.
# NOTE(review): presumably downloads and caches the data on first use — confirm network access.
dataset = load_dataset("lerobot/aloha_sim_insertion_human_image")

### Split

In [None]:
# Wrap the "train" split of the loaded dataset so it yields (image, label) pairs.
full_dataset = AlohaImageDataset(dataset["train"])

# 80 % of the samples go to training, the remainder to validation.
len_trainset = int(len(full_dataset) * 0.8)
len_valset = len(full_dataset) - len_trainset

# Use a dedicated, explicitly seeded generator instead of the global RNG.
# Otherwise the membership of the two subsets depends on how much randomness was
# consumed before this cell ran; re-running it, or resuming from a checkpoint in
# a later session, could then move validation images into the training set.
SPLIT_SEED = 316  # same value the notebook seeds torch with
trainset, valset = random_split(
    full_dataset,
    [len_trainset, len_valset],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

### Training Set

In [None]:
# Training loader: one image per batch, served in a fixed (unshuffled) order.
batch_size = 1
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=False)

### Validation Set

In [None]:
# Validation loader: same batch size as training, fixed (unshuffled) order.
valloader = DataLoader(valset, batch_size=batch_size, shuffle=False)

In [None]:
# Sanity check: pull the first training batch and print its tensor shape.
# NOTE(review): expected to be (batch, 1, H, W) given the grayscale transform;
# later cells assume 480x640 frames — confirm against the printed value.
image, _ = next(iter(trainloader))
print(image.shape)

## Build Neural Network

In [None]:
# Autoencoder
class Autoencoder4x(nn.Module):
    """
    Small convolutional autoencoder for single-channel images.

    The encoder maps 1 input channel to 4 feature channels and halves the height
    and width; the decoder maps the 4 channels back to 1 and doubles the height
    and width again, ending in a sigmoid so outputs lie in [0, 1].
    """

    def __init__(self):
        super().__init__()
        # Encoder: 3x3 conv (1 -> 4 channels), ReLU, then 2x2 max-pool (halves resolution)
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=4, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # Decoder: stride-2 transposed conv (4 -> 1 channel) restores the resolution
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(in_channels=4, out_channels=1, kernel_size=2, stride=2),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``(reconstruction, latent)`` for the input batch ``x``."""
        latent = self.encoder(x)
        return self.decoder(latent), latent

# Build the network, place its parameters on the selected device, and print its layers.
model = Autoencoder4x()
model.to(device)
print(model)

## Training and Validation

In [None]:
# Reconstruction objective: mean squared error between the output and the input image.
criterion = nn.MSELoss()
# Adam over all model parameters with a learning rate of 1e-3.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [None]:
def add_noise(inputs, noise_factor=0.3):
    """
    Return a copy of ``inputs`` corrupted with Gaussian noise and limited to [0, 1].

    Standard-normal noise with the same shape as ``inputs`` is scaled by
    ``noise_factor``, added to the input, and the sum is clamped to [0.0, 1.0].
    """
    noise = noise_factor * torch.randn_like(inputs)
    return (inputs + noise).clamp(0.0, 1.0)

### Checkpoint

In [None]:
# Run configuration: checkpoint folder tag and total number of epochs to reach.
DATASET = "CNN"
n_epochs = 20

# Checkpoints live in one fixed folder per tag so a later session can find them.
# (Alternative kept for reference: a timestamped folder per run.)
# CKPT_DIR = f"ckpt/{DATASET}_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{n_epochs}"
CKPT_DIR = f"ckpt/{DATASET}"
os.makedirs(CKPT_DIR, exist_ok=True)
ckpt_path = os.path.join(CKPT_DIR, "autoencoder.pth")

# Resume from the saved checkpoint when one exists; otherwise begin at epoch 0.
start_epoch = 0
if not os.path.exists(ckpt_path):
    print("Starting new training")
else:
    checkpoint = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # The checkpoint stores the last finished epoch, so continue with the next one.
    start_epoch = checkpoint['epoch'] + 1
    print(f"Resuming training from epoch {start_epoch}")

In [None]:
# Training
# Per-epoch mean losses for the epochs executed in this session.
train_loss_history = []
val_loss_history = []
start_time = time.time()

for epoch in tqdm(range(start_epoch, n_epochs)):
    # ---- training pass ----
    train_loss = 0
    model.train()
    for images, _ in trainloader:
        images = images.to(device)
        # Clear gradients BEFORE the forward/backward pass. Doing it after the
        # step lets gradients left behind by an interrupted run of this cell
        # accumulate into the first update of the next run.
        optimizer.zero_grad()
        # image_noised = add_noise(images)
        output, _ = model(images)
        # denoised, _ = model(image_noised)
        loss = criterion(output, images)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    # Mean of the per-batch losses
    train_loss /= len(trainloader)
    train_loss_history.append(train_loss)

    # ---- validation pass (no gradient tracking, no parameter updates) ----
    val_loss = 0
    model.eval()
    with torch.no_grad():
        for images, _ in valloader:
            images = images.to(device)
            # image_noised = add_noise(images)
            output, _ = model(images)
            loss = criterion(output, images)
            val_loss += loss.item()
    val_loss /= len(valloader)
    val_loss_history.append(val_loss)

    # Six decimals: MSE on [0, 1] images is often below 1e-3, which three
    # decimals would round to 0.000.
    print(f"Epoch {epoch+1}/{n_epochs} | Train Loss: {train_loss:.6f} | Val Loss: {val_loss:.6f}")

    # Write the checkpoint to a temporary file and swap it into place, so an
    # interruption while saving cannot leave a truncated checkpoint at ckpt_path.
    tmp_ckpt_path = ckpt_path + ".tmp"
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }, tmp_ckpt_path)
    os.replace(tmp_ckpt_path, ckpt_path)

print(f"Training complete in {time.time() - start_time:.2f} seconds.")

## Metrics

In [None]:
# Training-loss curve on its own (one point per epoch run in this session).
fig, ax = plt.subplots(figsize=(6, 6))
ax.plot(train_loss_history, label="Train Loss")
ax.set(xlabel="epoch", ylabel="loss", title="Train Loss")
ax.legend()
plt.show()

In [None]:
# Training and validation loss on shared axes, to compare how the two evolve.
fig, ax = plt.subplots(figsize=(6, 6))
for history, label in (
    (train_loss_history, "Train Loss"),
    (val_loss_history, "Validation Loss"),
):
    ax.plot(history, label=label)
ax.set(xlabel="epoch", ylabel="loss", title="Train Loss & Validation Loss")
ax.legend()
plt.show()

## Test

In [None]:
# Number of images per evaluation batch (also the number shown in the figures).
n_test_img = 6

# Evaluation loader: draws shuffled batches from the validation subset.
testloader = DataLoader(valset, batch_size=n_test_img, shuffle=True)

In [None]:
model.eval()
with torch.no_grad():
    # Take one shuffled batch to use for the figures in the next cell.
    sample_images, _ = next(iter(testloader))
    sample_images = sample_images[:8].to(device)
    reconstructed, latent = model(sample_images)

# Upsample the latent feature maps to 480x640 for display.
latent_vis = F.interpolate(latent, size=(480, 640), mode='bilinear', align_corners=False)

# Metric computation: every image is resized to h x w before it is compared.
h, w = 120, 160
FID_REAL = "fid_temp/real_120x160"
FID_FAKE = "fid_temp/fake_120x160"
for fid_dir in (FID_REAL, FID_FAKE):
    os.makedirs(fid_dir, exist_ok=True)
    # Remove PNGs left over from earlier runs. File names depend on the batch
    # layout, so a run with a different batch size or dataset size would
    # otherwise leave stale images behind and they would be counted in the FID.
    for stale_name in os.listdir(fid_dir):
        if stale_name.endswith(".png"):
            os.remove(os.path.join(fid_dir, stale_name))

psnr_list, ssim_list, nrmse_list = [], [], []

with torch.no_grad():
    for batch_idx, (images, _) in enumerate(testloader):
        images = images.to(device)
        # noised = add_noise(images)
        decoded, _ = model(images)

        resized_original = F.interpolate(images, size=(h, w), mode='bilinear', align_corners=False)
        resized_decoded = F.interpolate(decoded, size=(h, w), mode='bilinear', align_corners=False)

        original_np = resized_original.cpu().numpy()
        decoded_np = resized_decoded.cpu().numpy()
        for i in range(images.size(0)):
            # Real / reconstructed pairs are written to disk for the FID computation.
            save_image(resized_original[i], f"{FID_REAL}/{batch_idx}_{i}.png")
            save_image(resized_decoded[i], f"{FID_FAKE}/{batch_idx}_{i}.png")

            # Per-image metrics on the single grayscale channel (values in [0, 1]).
            psnr_list.append(psnr_metric(original_np[i, 0], decoded_np[i, 0], data_range=1.0))
            ssim_list.append(ssim_metric(original_np[i, 0], decoded_np[i, 0], data_range=1.0))
            nrmse_list.append(nrmse_metric(original_np[i, 0], decoded_np[i, 0]))

# FID between the folder of real images and the folder of reconstructions.
fid_val = fid_score.calculate_fid_given_paths([FID_REAL, FID_FAKE], batch_size=32, device=device, dims=2048)
print("=== Autoencoder Reconstruction Metrics ===")
print(f"[{h}×{w}] PSNR={np.mean(psnr_list):.3f}, NRMSE={np.mean(nrmse_list):.3f}, "
      f"SSIM={np.mean(ssim_list):.3f}, FID={fid_val:.3f}")

In [None]:
def _show_image_grid(batch, title, **imshow_kwargs):
    """Tile a batch of images into one grid, draw it in grayscale, and return the grid tensor."""
    grid_tensor = make_grid(batch.cpu(), nrow=8, padding=2)
    plt.figure(figsize=(12, 2))
    plt.title(title)
    plt.imshow(grid_tensor.permute(1, 2, 0).squeeze(), cmap="gray", **imshow_kwargs)
    plt.axis("off")
    plt.show()
    return grid_tensor


# 1. Original images
grid = _show_image_grid(sample_images, "Original Images")

# 2. Compressed representation: mean over the latent channels -> (B, 1, H, W),
#    rescaled to [0, 1] with the batch-wide min and max so it can be drawn with
#    vmin=0, vmax=1 without clipping.
latent_mean = latent.mean(dim=1, keepdim=True)
latent_norm = (latent_mean - latent_mean.min()) / (latent_mean.max() - latent_mean.min() + 1e-8)
# Upsample the NORMALIZED map; plotting the raw mean would clip every value above 1.
latent_vis = F.interpolate(latent_norm, size=(480, 640), mode='bilinear', align_corners=False)
grid_latent = _show_image_grid(
    latent_vis,
    "Compressed Latent Images (Upsampled to 480x640)",
    vmin=0,
    vmax=1,
)

# 3. Reconstructed images
grid_reconstructed = _show_image_grid(reconstructed, "Reconstructed Images")

In [None]:
from fvcore.nn import FlopCountAnalysis
from ptflops import get_model_complexity_info

# Measure the model in inference mode.
model.eval()

# Complexity for a single-channel 480x640 input.
# NOTE(review): the shape tuple presumably excludes the batch dimension — confirm with ptflops.
macs, params = get_model_complexity_info(
    model, (1, 480, 640), as_strings=True, print_per_layer_stat=False
)

for report_line in (
    f"[ptflops] MACs (FLOPs): {macs}",
    f"[ptflops] Parameters: {params}",
):
    print(report_line)