In [107]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
from torch.utils.data import DataLoader, Dataset

import numpy as np
import os

In [108]:
# 包含四个下采样层和四个上采样层的 U-Net 结构
class UNet(nn.Module):
    def __init__(self, in_channels, out_channels=1):
        super(UNet, self).__init__()

        # 定义下采样层
        self.down = nn.ModuleList()
        # shape变化: x -> (N, 512, w/2, h/2)
        self.down.append(self._make_conv_block(in_channels, 64))
        self.down.append(self._make_conv_block(64, 128))
        self.down.append(self._make_conv_block(128, 256))
        self.down.append(self._make_conv_block(256, 512))
        # 最大池化层，使尺寸下降，下采样
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # 定义上采样层
        self.up = nn.ModuleList()
        # (N, 512, w/2, h/2) -> (N, 256, w, h)
        self.up.append(self._make_upconv(512, 256))
        # (N, 256, w, h) + (N, 256, w, h) -> (N, 256, w, h)
        self.up.append(self._make_conv_block(256 + 256, 256))
        # (N, 256, w, h) -> (N, 128, 2w, 2h)
        self.up.append(self._make_upconv(256, 128))
        # (N, 128, 2w, 2h) + (N, 128, 2w, 2h) -> (N, 128, 2w, 2h)
        self.up.append(self._make_conv_block(128 + 128, 128))
        # (N, 128, 2w, 2h) -> (N, 64, 4w, 4h)
        self.up.append(self._make_upconv(128, 64))
        # (N, 64, 4w, 4h) + (N, 64, 4w, 4h) -> (N, 64, 4w, 4h)
        self.up.append(self._make_conv_block(64 + 64, 64))
        # (N, 64, 4w, 4h) -> (N, 32, 8w, 8h)
        self.up.append(self._make_upconv(64, 32))
        # (N, 32, 8w, 8h) + (N, 32, 8w, 8h) -> (N, 32, 8w, 8h)
        self.up.append(self._make_conv_block(32 + 32, 32))

        # 定义输出层
        self.out = nn.Conv2d(32, out_channels, kernel_size=1)

    def forward(self, x):
        skips = []
        for downsample in self.down:
            x = downsample(x)
            skips.append(x)
            x = self.pool(x)
        print(f"x.shape最开始: {x.shape}")
        
        skips = list(reversed(skips[:-1]))

        for i, upsample in enumerate(self.up):
            skip = skips[i]
            print(f"skip.shape: {skip.shape}")
            x = upsample(x)
            x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
            print(f"x.shape然后{x.shape}")
            x = torch.cat([x, skip], dim=1)

        # 输出 sigma 平方值
        sigma2 = self.out(x)

        return sigma2

    def _make_conv_block(self, in_channels, out_channels):
        """
        定义卷积块：
        不改变特征映射的空间维度
        仅改变通道数
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def _make_upconv(self, in_channels, out_channels):
        """
        上采样块：
        将尺寸扩大到原来的两倍
        """
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

In [109]:
def custom_loss(sigma_sq, v, v_t):
    loss = 0.5 * torch.log(sigma_sq) + (v_t - v) ** 2 / (2 * sigma_sq)
    return torch.mean(loss)

In [110]:
class MyDataset(Dataset):
    def __init__(self, data_path):
        self.data_path = data_path
        self.data_files = sorted(os.listdir(self.data_path))

    def __len__(self):
        return len(self.data_files)

    """
    `__getitem__` 方法会在每次加载一个数据时被调用，
    它会从指定路径中读取 `.npy` 文件，并将其转换为一个 PyTorch 张量。
    然后，使用 PyTorch 提供的 `DataLoader` 类，将数据划分为批次进行训练。
    """
    def __getitem__(self, index):
        # Load data from file
        data = np.load(os.path.join(self.data_path, self.data_files[index]))
        # data = data[0:4]
        # Convert to tensor
        data = torch.from_numpy(data).float()
        return data

In [111]:
def load_data(data_path, batch_size):
    # Create data loader
    dataset = MyDataset(data_path)
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    return data_loader

In [112]:
def train(model, optimizer, data_loader, num_epochs, device):
    model.to(device)

    for epoch in range(num_epochs):
        epoch_loss = 0.0
        epoch_metric = 0.0
        num_batches = 0

        for batch in data_loader:
            
            batch = batch.to(device)
            
            inputs = batch[:, :4, :, :]
            v = batch[:, 2:4, :, :]
            v_t = batch[:, 4:6, :, :]
            
            # 将梯度清零
            optimizer.zero_grad()
            # 前向传递
            sigma2 = model(inputs)
            # 计算损失和评估指标
            loss = custom_loss(inputs, v, v_t)
            metric = -loss.item()
            # 反向传播和优化
            loss.backward()
            optimizer.step()
            # 更新损失和评估指标
            epoch_loss += loss.item()
            epoch_metric += metric
            num_batches += 1

        # 计算平均损失和评估指标
        avg_loss = epoch_loss / num_batches
        avg_metric = epoch_metric / num_batches

        # 打印训练进度
        print(f"Epoch {epoch+1}/{num_epochs}: Loss={avg_loss:.4f}, Metric={avg_metric:.4f}")
        
    torch.save(model.state_dict(), 'model.pt')

In [113]:
"""
------------------------------训练部分------------------------------
"""
# 加载数据
data_path = '/home/panding/code/UR/data-chair'
batch_size = 1

my_data_loader = load_data(data_path, batch_size)

# 初始化模型、优化器和设备
net = UNet(in_channels=4, out_channels=1)
Adam_optimizer = optim.Adam(net.parameters(), lr=0.001)
my_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 训练循环
my_num_epochs = 10

In [114]:
train(model=net, optimizer=Adam_optimizer, data_loader=my_data_loader, num_epochs=my_num_epochs, device=my_device)

x.shape最开始: torch.Size([1, 512, 24, 32])
skip.shape: torch.Size([1, 256, 96, 128])
x.shape然后torch.Size([1, 256, 96, 128])
skip.shape: torch.Size([1, 128, 192, 256])
x.shape然后torch.Size([1, 256, 192, 256])
skip.shape: torch.Size([1, 64, 384, 512])


RuntimeError: Given transposed=1, weight of size [256, 128, 2, 2], expected input[1, 384, 192, 256] to have 256 channels, but got 384 channels instead