In [9]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm  # 可选：用于进度条

In [23]:
class ReverseResNet(nn.Module):
    """Map a small input vector to a one-channel 2-D field ``u`` plus its stencil response ``Fv``.

    Pipeline:
      1. ``fc`` (Linear) lifts ``prod(input_shape)`` features to ``C*H*W`` features,
         which are viewed as ``(batch, C, H, W)``.
      2. ``q`` residual blocks ("BB"). Each applies ``N`` pairs of
         Conv2d(C->C, 3x3, padding 1) + LeakyReLU(0.01) ("SB") and then combines
         the result with the block input as ``(block_input + out) * 2``.
      3. ``final_conv`` (3x3, padding 1) reduces C channels to 1, giving ``u``.
      4. ``FvLayer`` convolves ``u`` with the fixed, non-trainable kernel
         ``[1, -2, 1]`` along the last axis with no padding, so ``Fv`` is
         ``W - 2`` wide.

    Args:
        input_shape: shape of one input sample; ``fc`` expects ``prod(input_shape)`` features.
        full_connect_shape: ``(C, H, W)`` of the feature map produced by ``fc``.
        q: number of residual blocks.
        N: number of conv + LeakyReLU pairs inside each block.
    """

    def __init__(self, input_shape, full_connect_shape, q, N):
        super().__init__()
        self.input_shape = input_shape
        self.full_connect_shape = full_connect_shape  # (C, H, W)
        self.q = q
        self.N = N
        self.C, self.H, self.W = full_connect_shape
        channels = self.C

        # Lift the flat input to C*H*W features.
        self.fc = nn.Linear(np.prod(input_shape), np.prod(full_connect_shape))

        # q blocks, each an inner ModuleList of N (Conv2d, LeakyReLU) pairs.
        # Construction order and child indices match a plain append loop, so the
        # state_dict keys (e.g. "bb_layers.0.2.weight") stay stable.
        self.bb_layers = nn.ModuleList(
            nn.ModuleList(
                layer
                for _ in range(N)
                for layer in (
                    nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
                    nn.LeakyReLU(0.01),
                )
            )
            for _ in range(q)
        )

        # Collapse the C channels into the single output channel u.
        self.final_conv = nn.Conv2d(channels, 1, kernel_size=3, stride=1, padding=1)

        # Fixed second-difference stencil [1, -2, 1] along the last axis; frozen via requires_grad=False.
        self.FvLayer = nn.Conv2d(1, 1, kernel_size=(1, 3), padding=0, bias=False)
        stencil = torch.tensor([[[[1, -2, 1]]]], dtype=torch.float32)
        self.FvLayer.weight = nn.Parameter(stencil, requires_grad=False)

    def forward(self, x):
        """Return ``(u, Fv)`` for the batch ``x``.

        ``x`` is assumed to end in ``prod(input_shape)`` features (e.g. ``(batch, 2)``).
        The fc output is viewed as ``(-1, C, H, W)``.
        """
        features = self.fc(x).view(-1, self.C, self.H, self.W)

        for block in self.bb_layers:
            residual = features
            for layer in block:
                features = layer(features)
            # Residual sum, doubled.
            features = (residual + features) * 2

        u = self.final_conv(features)
        return u, self.FvLayer(u)

In [24]:
# Build the model. Seed first so the randomly initialised weights (and every output
# derived from them below) are reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

input_shape = (2,)  # shape of one input sample; fc takes prod(input_shape) features
full_connect_shape = (128, 64, 16)  # (C, H, W) produced by fc; per the original note, H, W match the shape of true_output0
q = 2  # number of residual blocks (BB)
N = 3  # number of conv + LeakyReLU pairs (SB) in each block

model = ReverseResNet(input_shape, full_connect_shape, q, N)

# Show the architecture (the last expression is rendered by Jupyter).
model

ReverseResNet(
  (fc): Linear(in_features=2, out_features=131072, bias=True)
  (bb_layers): ModuleList(
    (0-1): 2 x ModuleList(
      (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): LeakyReLU(negative_slope=0.01)
      (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): LeakyReLU(negative_slope=0.01)
      (4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (5): LeakyReLU(negative_slope=0.01)
    )
  )
  (final_conv): Conv2d(128, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (FvLayer): Conv2d(1, 1, kernel_size=(1, 3), stride=(1, 1), bias=False)
)


In [25]:
# Smoke test: one forward pass through the (untrained) model to check the output shapes.
# eval() is set for good practice; the model contains no dropout or batch-norm layers.
model.eval()

# One sample matching input_shape == (2,), plus a leading batch axis -> shape (1, 2).
# Build it on the model's device so this cell still works after Trainer has moved the model.
model_device = next(model.parameters()).device
input_tensor = torch.tensor([1.0, 2.0], dtype=torch.float32, device=model_device).unsqueeze(0)  # shape: (1, 2)

# Forward pass without tracking gradients.
with torch.no_grad():
    u, Fv = model(input_tensor)

print("输入张量形状:", input_tensor.shape)
print("输出 u 的形状:", u.shape)
print("输出 Fv 的形状:", Fv.shape)

输入张量形状: torch.Size([1, 2])
输出 u 的形状: torch.Size([1, 1, 64, 16])
输出 Fv 的形状: torch.Size([1, 1, 64, 14])


In [26]:
class Trainer:
    """Minimal training loop for a model whose forward returns a ``(u, Fv)`` pair.

    Each batch's loss is ``criterion_u(u, target_u) + criterion_Fv(Fv, target_Fv)``.
    :meth:`fit` appends one value per epoch to ``train_losses`` and ``val_losses``;
    each value is the mean of that epoch's per-batch losses.
    """

    def __init__(self, model, optimizer, criterion_u, criterion_Fv, device="cuda"):
        """
        Args:
            model: module returning ``(u, Fv)``; it is moved to ``device`` here.
            optimizer: optimizer over ``model``'s parameters.
            criterion_u: loss applied to the ``u`` output.
            criterion_Fv: loss applied to the ``Fv`` output.
            device: device used for the model and every batch.
        """
        self.model = model.to(device)
        self.optimizer = optimizer
        self.criterion_u = criterion_u
        self.criterion_Fv = criterion_Fv
        self.device = device
        self.train_losses = []
        self.val_losses = []

    def _move_batch(self, inputs, targets_u, targets_Fv):
        """Move one batch to ``self.device`` and add a channel axis (dim 1) to both targets.

        NOTE(review): assumes the loader yields channel-less targets, e.g. ``(batch, H, W)``,
        so they line up with ``(batch, 1, H, W)`` model outputs. Confirm against the dataset.
        """
        return (
            inputs.to(self.device),
            targets_u.unsqueeze(1).to(self.device),
            targets_Fv.unsqueeze(1).to(self.device),
        )

    def _train_step(self, inputs, targets_u, targets_Fv):
        """Run one optimisation step on a batch and return its total loss as a float."""
        self.model.train()
        self.optimizer.zero_grad()

        preds_u, preds_Fv = self.model(inputs)

        loss_u = self.criterion_u(preds_u, targets_u)
        loss_Fv = self.criterion_Fv(preds_Fv, targets_Fv)
        total_loss = loss_u + loss_Fv

        total_loss.backward()
        self.optimizer.step()

        return total_loss.item()

    def _val_step(self, inputs, targets_u, targets_Fv):
        """Return the total loss on a batch without tracking gradients or updating weights."""
        self.model.eval()
        with torch.no_grad():
            preds_u, preds_Fv = self.model(inputs)
            loss_u = self.criterion_u(preds_u, targets_u)
            loss_Fv = self.criterion_Fv(preds_Fv, targets_Fv)
            return (loss_u + loss_Fv).item()

    def fit(self, train_loader, val_loader, num_epochs=50, early_stop_patience=5, save_path="best_model.pth"):
        """Train for up to ``num_epochs`` epochs with validation-based early stopping.

        Each loader must yield ``(inputs, (targets_u, targets_Fv))``. After every
        epoch, the mean per-batch train/val losses are recorded and printed.
        Whenever the validation loss improves, ``model.state_dict()`` is saved to
        ``save_path``. Training stops after ``early_stop_patience`` consecutive
        epochs without improvement. On return the model holds its last-epoch
        weights, not necessarily the best ones.

        Raises:
            ValueError: if either loader has no batches, because the epoch mean
                would be undefined.
        """
        if len(train_loader) == 0 or len(val_loader) == 0:
            raise ValueError("train_loader and val_loader must each contain at least one batch")

        best_val_loss = float('inf')
        early_stop_counter = 0

        for epoch in range(num_epochs):
            # Training phase.
            train_loss = 0.0
            for inputs, (targets_u, targets_Fv) in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
                train_loss += self._train_step(*self._move_batch(inputs, targets_u, targets_Fv))

            # Validation phase.
            val_loss = 0.0
            for inputs, (targets_u, targets_Fv) in val_loader:
                val_loss += self._val_step(*self._move_batch(inputs, targets_u, targets_Fv))

            # Mean of the per-batch losses.
            train_loss /= len(train_loader)
            val_loss /= len(val_loader)
            self.train_losses.append(train_loss)
            self.val_losses.append(val_loss)

            # Report before the early-stopping check so the final epoch is logged as well.
            print(f"Epoch {epoch+1}/{num_epochs} | "
                  f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

            # Checkpoint on improvement; otherwise count towards early stopping.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save(self.model.state_dict(), save_path)
                early_stop_counter = 0
            else:
                early_stop_counter += 1
                if early_stop_counter >= early_stop_patience:
                    print(f"Early stopping at epoch {epoch+1}")
                    break

    def plot_loss(self):
        """Plot the recorded per-epoch train and validation losses."""
        import matplotlib.pyplot as plt
        fig, ax = plt.subplots(figsize=(10, 5))
        ax.plot(self.train_losses, label='Train Loss')
        ax.plot(self.val_losses, label='Validation Loss')
        ax.set(title='Training history', xlabel='Epoch', ylabel='Loss')
        ax.legend()
        plt.show()

In [None]:
# Adam over every model parameter. The FvLayer stencil has requires_grad=False, so it
# receives no gradient. Both outputs are scored with mean-squared error.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion_u, criterion_Fv = nn.MSELoss(), nn.MSELoss()

In [None]:
# Use the GPU when one is available; a hard-coded "cuda" fails on CPU-only machines.
device = "cuda" if torch.cuda.is_available() else "cpu"

trainer = Trainer(
    model=model,
    optimizer=optimizer,
    criterion_u=criterion_u,
    criterion_Fv=criterion_Fv,
    device=device,
)

In [None]:
# Data loaders for Trainer.fit.
# Each dataset item must be `(input, (target_u, target_Fv))`. fit() adds a channel axis
# to both targets, so per-sample targets are (H, W) for u and (H, W - 2) for Fv,
# i.e. (64, 16) and (64, 14) with full_connect_shape = (128, 64, 16).
# TODO: define `train_dataset` and `val_dataset` (custom Datasets yielding that
# nested structure) before running this cell.
BATCH_SIZE = 16  # assumed value; tune for the available memory
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Train with early stopping; the best-validation weights are written to best_model.pth.
trainer.fit(train_loader, val_loader,
            num_epochs=50, early_stop_patience=5, save_path="best_model.pth")

In [None]:
trainer.plot_loss();  # per-epoch loss curves recorded by trainer.fit

In [5]:
# Demo: a progress bar that interleaves log lines without breaking the bar.
# Import under an alias so the global `tqdm` that Trainer.fit uses is not rebound to the
# notebook-widget variant; tqdm.auto picks the widget or console bar automatically.
import time

from tqdm.auto import tqdm as tqdm_auto

for step in tqdm_auto(range(100)):
    time.sleep(0.1)  # stand-in for real work

    # tqdm.write prints above the bar instead of corrupting it.
    if step % 10 == 0:
        tqdm_auto.write(f"Processing step {step}")

  0%|          | 0/100 [00:00<?, ?it/s]

Processing step 0
Processing step 10
Processing step 20
Processing step 30
Processing step 40
Processing step 50
Processing step 60
Processing step 70
Processing step 80
Processing step 90


In [4]:
# Same demo with the console tqdm: tqdm.write prints above the bar rather than splitting it.
import time
from tqdm import tqdm

for step in tqdm(range(100)):
    time.sleep(0.1)
    if not step % 10:
        tqdm.write(f"Step {step} completed")

  2%|██▏                                                                                                          | 2/100 [00:00<00:10,  9.73it/s]

Step 0 completed


 12%|████████████▉                                                                                               | 12/100 [00:01<00:09,  9.77it/s]

Step 10 completed


 22%|███████████████████████▊                                                                                    | 22/100 [00:02<00:07,  9.77it/s]

Step 20 completed


 32%|██████████████████████████████████▌                                                                         | 32/100 [00:03<00:06,  9.76it/s]

Step 30 completed


 42%|█████████████████████████████████████████████▎                                                              | 42/100 [00:04<00:05,  9.76it/s]

Step 40 completed


 52%|████████████████████████████████████████████████████████▏                                                   | 52/100 [00:05<00:04,  9.76it/s]

Step 50 completed


 62%|██████████████████████████████████████████████████████████████████▉                                         | 62/100 [00:06<00:03,  9.76it/s]

Step 60 completed


 72%|█████████████████████████████████████████████████████████████████████████████▊                              | 72/100 [00:07<00:02,  9.76it/s]

Step 70 completed


 82%|████████████████████████████████████████████████████████████████████████████████████████▌                   | 82/100 [00:08<00:01,  9.77it/s]

Step 80 completed


 92%|███████████████████████████████████████████████████████████████████████████████████████████████████▎        | 92/100 [00:09<00:00,  9.77it/s]

Step 90 completed


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:10<00:00,  9.80it/s]
