In [None]:
import os
from pathlib import Path

import numpy as np
import PIL.Image as Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader, Dataset
from torchmetrics.image import StructuralSimilarityIndexMeasure
from torchvision.transforms import ToTensor, Resize, Compose


In [None]:
class ImageFile(Dataset):
    """Flat folder of images served as ``(image, image)`` pairs for autoencoder training.

    Collects the regular files directly inside ``image_path`` (non-recursive)
    whose extension is .jpg/.jpeg/.png/.gif, matched case-insensitively, in
    sorted order so that index ``i`` refers to the same file on every run.

    Args:
        image_path: directory containing the images.
        transform: optional callable applied to each RGB PIL image.

    Raises:
        FileNotFoundError: if ``image_path`` is not an existing directory.
    """

    # Lower-case suffixes accepted by the dataset (compared against suffix.lower()).
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.gif')

    def __init__(self, image_path, transform=None):
        self.image_path = Path(image_path)
        self.transform = transform
        # Fail early with a clear message; an empty dataset would otherwise only
        # surface later as an obscure DataLoader/sampler error.
        if not self.image_path.is_dir():
            raise FileNotFoundError(f"Image directory not found: {self.image_path}")
        # Sorted for a deterministic index -> file mapping (directory listing order
        # is filesystem-dependent); the suffix is lower-cased so '.JPG' etc. match too.
        self.image_files = sorted(
            file for file in self.image_path.iterdir()
            if file.is_file() and file.suffix.lower() in self.IMAGE_EXTENSIONS
        )

    def __len__(self):
        """Number of image files found."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB, apply ``transform``, and return it as both input and target."""
        # The context manager closes the file handle (PIL keeps it open for
        # multi-frame formats such as GIF until the image is released).
        with Image.open(self.image_files[idx]) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        # Autoencoder: the target is the input itself.
        return image, image

In [None]:
# Resize every image to 64x64 (the size the Encoder/Decoder layer comments
# assume) and convert it to a float tensor with ToTensor().
transform = Compose([
    Resize((64, 64)),
    ToTensor(),
])

# Dataset location; set the ANIME_FACE_DIR environment variable to point at a
# different copy instead of editing the notebook.
image_path = os.environ.get('ANIME_FACE_DIR', '/opt/data/private/datasets/AnimeFace')

# Seed torch's global RNG so weight initialisation and DataLoader shuffling are
# reproducible under Restart & Run All (some CUDA kernels may still be
# nondeterministic).
SEED = 0
torch.manual_seed(SEED)

In [None]:
# Wrap the image folder in a Dataset whose items are (image, image) pairs, so
# each input doubles as its own reconstruction target.
ImageDataset = ImageFile(image_path=image_path, transform=transform)

# Shuffled mini-batches of 256 images for training.
train_loader = DataLoader(
    dataset=ImageDataset,
    batch_size=256,
    shuffle=True,
)

In [None]:
# Preview the first few dataset images; tensors are CxHxW, imshow wants HxWxC.
n_show = min(10, len(ImageDataset))
fig, ax = plt.subplots(1, n_show, figsize=(2 * n_show, 2), squeeze=False)
for i, axis in enumerate(ax[0]):
    # Named `sample`, not `train`: the notebook later defines `train()` as the
    # training function, and re-running this cell must not clobber it.
    sample = ImageDataset[i][0].permute(1, 2, 0).numpy()
    axis.imshow(sample)
    axis.axis('off')

# Pull one shuffled batch; later cells reuse `images` for shape smoke tests.
images, _ = next(iter(train_loader))
print(images.shape)  # torch.Size([batch_size, 3, 64, 64]); batch_size is 256 above

fig, ax = plt.subplots(1, n_show, figsize=(2 * n_show, 2), squeeze=False)
for i, axis in enumerate(ax[0]):
    axis.imshow(images[i].permute(1, 2, 0).numpy())
    axis.axis('off')

# Value statistics of one image tensor (max, min, mean, std).
first_image = images[0]
first_image.max(), first_image.min(), first_image.mean(), first_image.std()

- **输入图像维度**：假设输入图像的维度为 $W\times H\times C$，其中 $W$ 是图像的宽度，$H$ 是图像的高度，$C$ 是图像的通道数（注意 PyTorch 张量按 $C\times H\times W$ 排列）。
- **卷积核维度**：设卷积核的大小为 $K\times K$，卷积核的数量为 $N$（即输出通道数）。
- **填充（Padding）**：通常用 $P$ 表示在图像周围填充的像素数。如果是对称填充，那么在宽度和高度方向上都分别填充 $P$ 个像素。
- **步幅（Stride）**：用 $S$ 表示，它指的是卷积核在图像上滑动的步长。

经过卷积操作后，输出特征图的维度计算公式如下：

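- **输出宽度**：$W_{out}=\left\lfloor\frac{W + 2P - K}{S}\right\rfloor + 1$
- **输出高度**：$H_{out}=\left\lfloor\frac{H + 2P - K}{S}\right\rfloor + 1$
- **输出通道数**：$C_{out}=N$

例如，编码器第一层取 $K=4,\ S=2,\ P=1$，输入为 $64\times 64$ 时输出尺寸为 $\left\lfloor\frac{64 + 2 - 4}{2}\right\rfloor + 1 = 32$，即宽高减半。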


In [None]:
class Encoder(nn.Module):
    """Convolutional encoder mapping an image batch to one flat latent vector per image.

    Five stride-2 convolutions reduce a 64x64 input to a 1x1 feature map with
    1024 channels; two linear layers then project it to ``latent_dim`` values.

    Args:
        in_channels: number of input image channels (3 for RGB).
        latent_dim: size of the output latent vector (default 256).
    """

    def __init__(self, in_channels=3, latent_dim=256):
        super().__init__()
        self.con_block = nn.Sequential(
            # 3@64x64 -> 64@32x32
            nn.Conv2d(in_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            # 64@32x32 -> 128@16x16
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            # 128@16x16 -> 256@8x8
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(256),
            # 256@8x8 -> 512@4x4
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(512),
            # 512@4x4 -> 1024@1x1  (padding=0: floor((4 - 4) / 2) + 1 = 1)
            nn.Conv2d(512, 1024, kernel_size=4, stride=2, padding=0),
            nn.ReLU(),
            nn.BatchNorm2d(1024),
        )
        self.fc_block = nn.Sequential(
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, latent_dim),
        )

    def forward(self, x):
        """Encode ``x`` of shape (batch, in_channels, H, W) into (batch, latent_dim).

        Raises:
            ValueError: if the convolutions do not reduce the input to a 1x1 map
                (e.g. for 96x96 inputs), since the linear head expects 1024 features.
        """
        x = self.con_block(x)
        # shape: (batch_size, 1024, h, w); the head only accepts h == w == 1.
        if x.shape[-2:] != (1, 1):
            raise ValueError(
                f"Encoder expects inputs that reduce to a 1x1 feature map "
                f"(e.g. 64x64 images); got a {x.shape[-2]}x{x.shape[-1]} map"
            )
        x = x.flatten(start_dim=1)
        # shape: (batch_size, 1024)
        return self.fc_block(x)

# Smoke test: an untrained encoder maps the sample batch to one 256-d code per image.
encoder = Encoder(in_channels=3)
with torch.no_grad():  # shape check only; no autograd graph needed
    images_encoded = encoder(images)
assert images_encoded.shape == (images.shape[0], 256)
images_encoded.shape  # torch.Size([batch_size, 256]); batch_size is 256 with the loader above

反卷积（也称为转置卷积）在尺寸变化上可以看作卷积的逆过程：它把卷积缩小的特征图尺寸还原回去，但在数值上并不是卷积的逆运算。其维度变化公式与卷积类似，但计算方式略有不同。以下是二维反卷积的维度变化公式：
- **输入特征图维度**：假设输入特征图的维度为 $W\times H\times C$，其中 $W$ 是特征图的宽度，$H$ 是特征图的高度，$C$ 是特征图的通道数。
- **卷积核维度**：设卷积核的大小为 $K\times K$，卷积核的数量为 $N$（即输出通道数）。
- **填充（Padding）**：通常用 $P$ 表示在特征图周围填充的像素数。如果是对称填充，那么在宽度和高度方向上都分别填充 $P$ 个像素。需要注意的是，在反卷积中，填充的效果与卷积相反，它会减少输出的尺寸。
- **步幅（Stride）**：用 $S$ 表示，它指的是卷积核在特征图上滑动的步长。

经过反卷积操作后，输出图像的维度计算公式如下：

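- **输出宽度**：$W_{out}=(W - 1)\times S - 2P + K$
- **输出高度**：$H_{out}=(H - 1)\times S - 2P + K$
- **输出通道数**：$C_{out}=N$

以上公式对应 PyTorch `nn.ConvTranspose2d` 在 `dilation=1`、`output_padding=0`（默认值）时的情形；一般形式为 $H_{out}=(H - 1)\times S - 2P + \mathrm{dilation}\times(K - 1) + \mathrm{output\_padding} + 1$。例如解码器中取 $K=4,\ S=2,\ P=1$ 时，输入 $4\times 4$ 得到 $(4 - 1)\times 2 - 2 + 4 = 8$，即宽高加倍。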


In [None]:
class Decoder(nn.Module):
    """Mirror of the encoder: maps a latent vector to an ``out_channels`` x 64 x 64 image.

    Two linear layers expand the code to 1024 features, which are reshaped to a
    1x1 map and upsampled by five transposed convolutions to 64x64. The output
    passes through a sigmoid so pixel values lie in [0, 1], the same range as
    the ToTensor() training targets and the SSIM loss's ``data_range=1.0``.

    Args:
        out_channels: number of channels of the generated image (3 for RGB).
        latent_dim: size of the input latent vector (default 256).
    """

    def __init__(self, out_channels=3, latent_dim=256):
        super().__init__()
        self.fc_block = nn.Sequential(
            # shape: (batch_size, latent_dim) -> (batch_size, 512)
            nn.Linear(latent_dim, 512),
            nn.ReLU(),
            # shape: (batch_size, 512) -> (batch_size, 1024)
            nn.Linear(512, 1024),
            nn.ReLU(),
        )
        self.con_block = nn.Sequential(
            # 1024@1x1 -> 512@4x4    ((1 - 1) * 1 - 0 + 4 = 4)
            nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=1, padding=0),
            nn.ReLU(),
            nn.BatchNorm2d(512),
            # 512@4x4 -> 256@8x8     ((4 - 1) * 2 - 2 + 4 = 8)
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(256),
            # 256@8x8 -> 128@16x16
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            # 128@16x16 -> 64@32x32
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            # 64@32x32 -> out_channels@64x64
            # No ReLU/BatchNorm on the output layer: they would clip the raw pixel
            # logits and make each image depend on the statistics of its batch.
            nn.ConvTranspose2d(64, out_channels, kernel_size=4, stride=2, padding=1),
        )

    def forward(self, x):
        """Decode ``x`` of shape (batch, latent_dim) into images (batch, out_channels, 64, 64) in [0, 1]."""
        x = self.fc_block(x)
        # shape: (batch_size, 1024)
        x = x.view(-1, 1024, 1, 1)
        # shape: (batch_size, 1024, 1, 1)
        x = self.con_block(x)
        # shape: (batch_size, out_channels, 64, 64)
        return torch.sigmoid(x)


# Smoke test: decode the codes back to image-shaped tensors and view a few.
decoder = Decoder(out_channels=3)
with torch.no_grad():  # shape/visual check only; no autograd graph needed
    images_decoded = decoder(images_encoded)
assert images_decoded.shape == (images_encoded.shape[0], 3, 64, 64)
# Only a cell's last expression is displayed, so print the shape explicitly.
print(images_decoded.shape)  # torch.Size([batch_size, 3, 64, 64])

# Show outputs on the same [0, 1] pixel scale used for the input images above.
n_show = min(10, len(images_decoded))
fig, ax = plt.subplots(1, n_show, figsize=(2 * n_show, 2), squeeze=False)
for i, axis in enumerate(ax[0]):
    axis.imshow(images_decoded[i].detach().cpu().clamp(0, 1).permute(1, 2, 0).numpy())
    axis.axis('off')

In [None]:
class AutoEncoder(nn.Module):
    """Encoder followed by Decoder: reconstructs its input image batch.

    Args:
        in_channels: channels of the images fed to the encoder.
        out_channels: channels of the images produced by the decoder.
    """

    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()
        self.encoder = Encoder(in_channels=in_channels)
        self.decoder = Decoder(out_channels=out_channels)

    def forward(self, x):
        """Compress ``x`` to its latent code and decode it back to an image batch."""
        latent = self.encoder(x)
        return self.decoder(latent)

In [None]:
class SSIMLoss(nn.Module):
    """Structural-dissimilarity loss ``1 - SSIM(x, y)``; 0 means identical images.

    Args:
        data_range: value range of the images (max - min). The default 1.0 matches
            images in [0, 1]; it must agree with the model's output range.
        **ssim_kwargs: extra keyword arguments forwarded to torchmetrics'
            ``StructuralSimilarityIndexMeasure`` (e.g. ``kernel_size``).
    """

    def __init__(self, data_range=1.0, **ssim_kwargs):
        super().__init__()
        self.ssim = StructuralSimilarityIndexMeasure(data_range=data_range, **ssim_kwargs)

    def forward(self, x, y):
        """Return ``1 - SSIM`` between prediction ``x`` and target ``y`` (lower is better)."""
        return 1 - self.ssim(x, y)

In [None]:
def mini_batch_train(data_loader, model, optimizer, loss_fn, device='cuda'):
    """Run one training epoch over ``data_loader`` and return the mean batch loss.

    Args:
        data_loader: iterable of ``(x_batch, y_batch)`` tensor pairs.
        model: module to train; it is switched to training mode.
        optimizer: optimizer over ``model``'s parameters.
        loss_fn: callable ``loss_fn(prediction, target)`` returning a scalar tensor.
        device: device the batches are moved to.

    Returns:
        Mean of the per-batch losses as a float, or ``nan`` if the loader is empty.
    """
    # The mode never changes inside the loop, so set it once.
    model.train()
    mini_batch_losses = []
    for x_batch, y_batch in data_loader:
        # Move the batch to the training device.
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)
        # Clear gradients *before* backward so stale gradients left over from
        # earlier code can never leak into this update.
        optimizer.zero_grad()
        # Step 1 - forward pass
        yhat = model(x_batch)
        # Step 2 - loss
        mini_batch_loss = loss_fn(yhat, y_batch)
        # Step 3 - gradients
        mini_batch_loss.backward()
        # Step 4 - parameter update
        optimizer.step()

        mini_batch_losses.append(mini_batch_loss.item())

    if not mini_batch_losses:
        # np.mean([]) would also give nan, but with a RuntimeWarning.
        return float('nan')
    return float(np.mean(mini_batch_losses))

In [None]:
def _evaluate(data_loader, model, loss_fn, device):
    """Return the mean loss of ``model`` over ``data_loader`` without updating it.

    Switches the model to eval mode (BatchNorm uses running statistics) and
    disables gradient tracking. Returns ``nan`` for an empty loader.
    """
    model.eval()
    batch_losses = []
    with torch.no_grad():
        for x_batch, y_batch in data_loader:
            yhat = model(x_batch.to(device))
            batch_losses.append(loss_fn(yhat, y_batch.to(device)).item())
    return float(np.mean(batch_losses)) if batch_losses else float('nan')


def train(model, train_loader, test_loader, loss_fn, optimizer, epochs, device='cuda'):
    """Train ``model`` for ``epochs`` epochs and return the per-epoch training losses.

    Args:
        model: module to train; moved to ``device``.
        train_loader: iterable of ``(x_batch, y_batch)`` training pairs.
        test_loader: optional iterable of held-out pairs. When not ``None``, its
            mean loss is computed in eval mode after every epoch and printed.
        loss_fn: ``loss_fn(prediction, target)`` callable; moved to ``device``
            when it is an ``nn.Module``.
        optimizer: optimizer over ``model``'s parameters.
        epochs: number of passes over ``train_loader``.
        device: device for the model and the batches.

    Returns:
        List with one mean training loss per epoch.
    """
    model.to(device)
    # Plain-function losses (e.g. F.mse_loss) have no .to(); only modules need moving.
    if isinstance(loss_fn, nn.Module):
        loss_fn.to(device)

    losses = []  # mean training loss of each epoch

    for epoch in range(1, epochs + 1):
        # Re-enter training mode each epoch (validation below switches to eval).
        model.train()
        # Inner loop over mini-batches.
        loss = mini_batch_train(train_loader, model, optimizer, loss_fn, device)
        losses.append(loss)
        message = f"Epoch {epoch}/{epochs}, Loss: {loss:.4f}"
        if test_loader is not None:
            val_loss = _evaluate(test_loader, model, loss_fn, device)
            message += f", Val Loss: {val_loss:.4f}"
        print(message)

    return losses

In [None]:
# Training hyper-parameters and target device.
epochs = 5
lr = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Model, AdamW optimizer with light weight decay, and the SSIM-based loss.
model = AutoEncoder(in_channels=3, out_channels=3)
optimizer = optim.AdamW(params=model.parameters(), lr=lr, weight_decay=1e-5)
loss_fn = SSIMLoss()

In [None]:
# Train on the full loader; no held-out split is used, so test_loader is None.
train_losses = train(
    model,
    train_loader,
    test_loader=None,
    loss_fn=loss_fn,
    optimizer=optimizer,
    epochs=epochs,
    device=device,
)

In [None]:
def plot_losses(losses=None):
    """Plot per-epoch training loss on a log scale and return the Figure.

    Args:
        losses: sequence of per-epoch losses; defaults to the notebook-level
            ``train_losses`` produced by the training cell.
    """
    if losses is None:
        losses = train_losses
    # Epochs are printed 1-based during training, so label the x-axis the same way.
    epoch_numbers = range(1, len(losses) + 1)
    # Scope the style to this figure instead of changing every later plot.
    with plt.style.context('fivethirtyeight'):
        fig, ax = plt.subplots(figsize=(10, 4))
        ax.plot(epoch_numbers, losses, label='Training Loss', c='b')
        ax.set_yscale('log')
        ax.set_xlabel('Epochs')
        ax.set_ylabel('Loss')
        ax.legend()
        fig.tight_layout()
    return fig


fig = plot_losses()

In [None]:
# Decode random latent vectors. NOTE(review): this is a plain (non-variational)
# autoencoder, so nothing constrains the encoder's codes to follow N(0, I);
# random normal vectors may decode to images unlike the training data.
LATENT_DIM = 256  # must match the decoder's input size
# Eval mode: BatchNorm must use its running statistics rather than the
# statistics of this 5-sample batch (inference_mode alone does not switch it).
model.eval()
with torch.inference_mode():
    fake_image = torch.randn(5, LATENT_DIM, device=device)
    reconstructed_image = model.decoder(fake_image)
reconstructed_image.shape

In [None]:
# Show the decoded samples on the same [0, 1] pixel scale as the inputs: the
# model is trained with SSIM (data_range=1.0) against ToTensor() targets, so
# outputs are clamped to [0, 1] rather than remapped from [-1, 1].
n_samples = len(reconstructed_image)
fig, ax = plt.subplots(1, n_samples, figsize=(2 * n_samples, 2), squeeze=False)
for i, axis in enumerate(ax[0]):
    axis.imshow(reconstructed_image[i].cpu().clamp(0, 1).permute(1, 2, 0).numpy())
    axis.axis('off')


In [None]:
# for i in range(5):
#     plt.imsave(f'./data/outputs/img_{i}.png', reconstructed_image[i].cpu().permute(1, 2, 0).numpy())

In [None]:
# import torchsummary

In [None]:
# torchsummary.summary(model, input_size=(3, 64, 64), device="cuda")