In [100]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from torch.nn import functional as F

In [101]:
dir_path = 'overDataSet'
num_epochs = 20
learning_rate = 0.001
BATCH_SIZE = 16
SEQ_SIZE = 17
save_interval = 5

In [102]:
# 数据集类
class SeqDataset(Dataset):
    def __init__(self, dir_path, seq_size=17, transform=None):
        self.img_paths = sorted(os.listdir(dir_path), key=lambda x: int(x[:-4]))  # 按数字顺序排序
        self.seq_size = seq_size
        self.transform = transform

    def __len__(self):
        return len(self.img_paths) - self.seq_size

    def __getitem__(self, index):
        img_seq = []
        for i in range(self.seq_size):
            img_path = os.path.join(dir_path, self.img_paths[index + i])
            img = Image.open(img_path)
            if self.transform:
                img = self.transform(img)
            img_seq.append(img)
        label_path = os.path.join(dir_path, self.img_paths[index + self.seq_size])
        label = Image.open(label_path)
        if self.transform:
            label = self.transform(label)
        return torch.stack(img_seq), label

In [103]:
# 图像变换
transform = transforms.Compose([
    transforms.Resize((128, 128)),  # 调整大小
    transforms.ToTensor(),  # 转换为Tensor
])

# 创建数据集和数据加载器
train_data = SeqDataset(dir_path=dir_path, seq_size=SEQ_SIZE, transform=transform)
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

In [104]:
# 编码器
class EncoderMUG2d_LSTM(nn.Module):
    def __init__(self, input_nc=3, encode_dim=1024, lstm_hidden_size=1024, seq_len=SEQ_SIZE, num_lstm_layers=1, bidirectional=False):
        super(EncoderMUG2d_LSTM, self).__init__()
        self.seq_len = seq_len
        self.lstm_hidden_size = lstm_hidden_size
        self.num_directions = 2 if bidirectional else 1
        self.num_lstm_layers = num_lstm_layers

        self.encoder = nn.Sequential(
            nn.Conv2d(input_nc, 32, 4, 2, 1),  # 32*64*64
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(32, 64, 4, 2, 1),  # 64*32*32
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(64, 128, 4, 2, 1),  # 128*16*16
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(128, 256, 4, 2, 1),  # 256*8*8
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(256, 512, 4, 2, 1),  # 512*4*4
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(512, 512, 4, 2, 1),  # 512*2*2 
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(512, 1024, 4, 2, 1),  # 1024*1*1
            nn.BatchNorm2d(1024),
            nn.LeakyReLU(0.2, inplace=True),
        )

        self.fc = nn.Linear(1024, encode_dim)
        self.lstm = nn.LSTM(encode_dim, lstm_hidden_size, num_layers=num_lstm_layers, batch_first=True, bidirectional=bidirectional)

    def forward(self, x):
        B = x.size(0)
        print(f"B (batch_size): {B}")
        print(f"x shape before reshape: {x.shape}")
        x = x.view(B * SEQ_SIZE, 3, 128, 128)  # [batchsize*seqsize, 3, 128, 128]
        print(f"x shape after reshape: {x.shape}")
        x = self.encoder(x)  # [batchsize*seqsize, 1024, 1, 1]
        print(f"x shape after encoder: {x.shape}")
        x = x.view(-1, 1024)  # [batchsize * seqsize, 1024]
        print(f"x shape after view: {x.shape}")
        x = self.fc(x)  # [batchsize * seqsize, encode_dim]
        print(f"x shape after fc: {x.shape}")
        x = x.view(B, SEQ_SIZE, x.size(1))  # [batchsize , seqsize , encode_dim][16,17,1024]
        print(f"x shape after final view: {x.shape}")
        
        h0 = torch.zeros(self.num_directions * self.num_lstm_layers, B, self.lstm_hidden_size).to(x.device)
        c0 = torch.zeros(self.num_directions * self.num_lstm_layers, B, self.lstm_hidden_size).to(x.device)
        print(f"h0 shape: {h0.shape}")
        print(f"c0 shape: {c0.shape}")
        output, (hn, cn) = self.lstm(x, (h0, c0))
        return hn[-1]  # 取最后一层的隐藏状态

# 解码器
class DecoderMUG2d(nn.Module):
    def __init__(self, output_nc=3, encode_dim=1024):
        super(DecoderMUG2d, self).__init__()
        self.project = nn.Sequential(
            nn.Linear(encode_dim, 1024 * 1 * 1),
            nn.ReLU(inplace=True)
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(1024, 512, 4),  # 512*4*4
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.ConvTranspose2d(512, 256, 4, stride=2),  # 256*10*10
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4),  # 128*13*13
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, stride=2),  # 64*28*28
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, output_nc, 4, stride=2, padding=1),  # 3*128*128
            nn.Sigmoid(),
        )

    def forward(self, x):
        x = self.project(x)
        x = x.view(-1, 1024, 1, 1)
        return self.decoder(x)

# 综合模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.encoder = EncoderMUG2d_LSTM()
        self.decoder = DecoderMUG2d()

    def forward(self, x):
        encoded = self.encoder(x)
        return self.decoder(encoded)

In [105]:
# 训练过程
def train_model():
    model = Net()
    if torch.cuda.is_available():
        model.cuda()
    
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    loss_func = nn.MSELoss()

    for epoch in range(num_epochs):
        print(f'Epoch {epoch + 1}/{num_epochs}')
        train_loss = 0.0

        for batch_x, batch_y in train_loader:
            inputs, labels = batch_x.cuda(), batch_y.cuda()
            #print(f"inputs shape: {inputs.shape}")  # 确保输入形状正确
            optimizer.zero_grad()
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_func(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        print(f'Loss: {train_loss / len(train_loader):.4f}')

        # 保存模型
        if (epoch + 1) % save_interval == 0:
            torch.save(model.state_dict(), f'model_epoch_{epoch + 1}.pth')



train_model()

Epoch 1/20
B (batch_size): 16
x shape before reshape: torch.Size([16, 17, 3, 128, 128])
x shape after reshape: torch.Size([272, 3, 128, 128])
x shape after encoder: torch.Size([272, 1024, 1, 1])
x shape after view: torch.Size([272, 1024])
x shape after fc: torch.Size([272, 1024])
x shape after final view: torch.Size([16, 17, 1024])
h0 shape: torch.Size([1, 16, 1024])
c0 shape: torch.Size([1, 16, 1024])


  return F.mse_loss(input, target, reduction=self.reduction)


RuntimeError: The size of tensor a (56) must match the size of tensor b (128) at non-singleton dimension 3