In [239]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from debugpy.launcher import channel
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torchvision
from torchmetrics import StructuralSimilarityIndexMeasure as SSIM
from PIL import Image
from torch.nn import functional as F

In [240]:
# Directory holding the frame images; passed to SeqDataset, which expects numerically named files.
dir_path = 'overDataSet'
# Number of passes over the training loader.
num_epochs = 100
# Initial step size handed to the Adam optimizer.
learning_rate = 0.001
# Number of (sequence, label) samples per mini-batch.
BATCH_SIZE = 16
# Number of consecutive input frames per sample; the frame right after them is the label.
SEQ_SIZE = 17
# Interval, in epochs, meant for periodic saving — TODO confirm where it is consumed.
save_interval = 5

In [241]:
# Dataset class
class SeqDataset(Dataset):
    """Sliding-window dataset over a directory of numerically named frames.

    Sample ``i`` is ``(frames[i : i + seq_size], frames[i + seq_size])``, i.e.
    ``seq_size`` consecutive frames as the input and the following frame as
    the label.

    Args:
        dir_path: Directory containing the frames. File stems must be integers
            (``0.png``, ``1.png``, ...); entries whose stem is not an integer
            (for example ``.ipynb_checkpoints``) are ignored.
        seq_size: Number of input frames per sample.
        transform: Optional callable applied to every PIL image. It has to
            return tensors, otherwise ``torch.stack`` in ``__getitem__`` fails.
    """

    def __init__(self, dir_path, seq_size=17, transform=None):
        # Keep the directory on the instance so __getitem__ does not depend on
        # a module-level variable that happens to share the same name.
        self.dir_path = dir_path
        self.seq_size = seq_size
        self.transform = transform

        numbered = []
        for name in os.listdir(dir_path):
            frame_no = self._frame_number(name)
            if frame_no is not None:
                numbered.append((frame_no, name))
        # Numeric order, so that "10.png" comes after "9.png".
        numbered.sort()
        self.img_paths = [name for _, name in numbered]

    @staticmethod
    def _frame_number(filename):
        """Return the integer stem of ``filename`` or ``None`` if it has none."""
        stem = os.path.splitext(filename)[0]
        try:
            return int(stem)
        except ValueError:
            return None

    def _load(self, filename):
        """Open one frame, apply the transform and release the file handle."""
        path = os.path.join(self.dir_path, filename)
        with Image.open(path) as img:
            if self.transform:
                return self.transform(img)
            # Image.open is lazy; copy() forces the pixel data into memory so
            # the image stays usable after the file is closed.
            return img.copy()

    def __len__(self):
        # Every sample needs seq_size inputs plus one label frame. Never report
        # a negative length when the directory holds too few frames.
        return max(0, len(self.img_paths) - self.seq_size)

    def __getitem__(self, index):
        """Return ``(stacked input frames, label frame)`` for window ``index``.

        Raises:
            IndexError: If ``index`` is outside ``[0, len(self))``.
        """
        if not 0 <= index < len(self):
            raise IndexError(f'index {index} out of range for {len(self)} samples')
        window = self.img_paths[index:index + self.seq_size + 1]
        frames = [self._load(name) for name in window]
        return torch.stack(frames[:-1]), frames[-1]

In [242]:
# Image preprocessing
transform = transforms.Compose([
    transforms.Resize((128, 128)),  # resize every frame to 128x128
    transforms.ToTensor(),  # convert the PIL image to a Tensor
])

# Build the dataset and the data loader
train_data = SeqDataset(dir_path=dir_path, seq_size=SEQ_SIZE, transform=transform)
# drop_last=True discards a trailing incomplete batch; shuffle=False keeps the
# samples in index order.
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, drop_last=True, shuffle=False)

In [243]:
# Custom loss function (the SSIM-based variant below is currently disabled)
# ssim = SSIM(data_range=1.0, k1=0.01, k2=0.03,  reduction='sum')

In [244]:
class EncoderMUG2d_LSTM(nn.Module):
    """Per-frame CNN encoder followed by an LSTM over the frame sequence.

    Every frame goes through seven stride-2 convolutions (each halves the
    spatial size, so a 128x128 frame ends up as a 1024x1x1 feature), then a
    linear layer to ``encode_dim``. The per-frame codes are fed to an LSTM and
    the final hidden state is returned.

    Args:
        input_nc: Number of channels of the input frames.
        encode_dim: Size of the per-frame code fed to the LSTM.
        lstm_hidden_size: Hidden size of the LSTM.
        seq_len: Nominal sequence length. Stored on the instance only;
            ``forward`` takes the actual length from its input.
        num_lstm_layers: Number of stacked LSTM layers.
        bidirectional: Whether the LSTM is bidirectional.
    """

    def __init__(self, input_nc=3, encode_dim=1024, lstm_hidden_size=1024, seq_len=SEQ_SIZE, num_lstm_layers=1, bidirectional=False):
        super(EncoderMUG2d_LSTM, self).__init__()
        self.seq_len = seq_len
        self.num_directions = 2 if bidirectional else 1
        self.num_lstm_layers = num_lstm_layers
        self.lstm_hidden_size = lstm_hidden_size
        # Input frames: input_nc*128*128. The size comments assume 128x128 input.
        self.encoder = nn.Sequential(
            nn.Conv2d(input_nc, 32, 4, 2, 1),  # 32*64*64
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(32, 64, 4, 2, 1),  # 64*32*32
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(64, 128, 4, 2, 1),  # 128*16*16
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(128, 256, 4, 2, 1),  # 256*8*8
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(256, 512, 4, 2, 1),  # 512*4*4
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(512, 512, 4, 2, 1),  # 512*2*2
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(512, 1024, 4, 2, 1),  # 1024*1*1
            nn.BatchNorm2d(1024),
            nn.LeakyReLU(0.2, inplace=True),
        )

        self.fc = nn.Linear(1024, encode_dim)
        # Build the LSTM from the same settings init_hidden uses, so the
        # initial state always matches the layer count, direction count and
        # hidden size of the module.
        self.lstm = nn.LSTM(
            encode_dim,
            lstm_hidden_size,
            num_layers=num_lstm_layers,
            bidirectional=bidirectional,
            batch_first=True,
        )

    def init_hidden(self, x):
        """Return zero ``(h0, c0)`` for a batch-first input ``x``.

        Both tensors have shape
        ``(num_directions * num_lstm_layers, batch, lstm_hidden_size)`` and
        share the dtype and device of ``x``.
        """
        batch_size = x.size(0)
        state_shape = (
            self.num_directions * self.num_lstm_layers,
            batch_size,
            self.lstm_hidden_size,
        )
        return x.new_zeros(state_shape), x.new_zeros(state_shape)

    def forward(self, x):
        """Encode a batch of frame sequences.

        Args:
            x: Tensor of shape ``[batch, seq, channels, height, width]``.

        Returns:
            The LSTM's final hidden state ``hn`` with shape
            ``[num_directions * num_lstm_layers, batch, lstm_hidden_size]``.

        Raises:
            ValueError: If ``x`` is not 5-dimensional.
        """
        if x.dim() != 5:
            raise ValueError(
                f'expected input of shape [batch, seq, C, H, W], got {tuple(x.shape)}')
        # Take batch and sequence length from the tensor itself rather than
        # from module-level constants.
        batch, steps = x.shape[:2]
        # [batch, seq, C, H, W] -> [batch*seq, C, H, W]
        frames = x.reshape(batch * steps, *x.shape[2:])
        # [batch*seq, C, H, W] -> [batch*seq, 1024, 1, 1] -> [batch*seq, 1024]
        feats = self.encoder(frames).flatten(1)
        # [batch*seq, 1024] -> [batch*seq, encode_dim]
        feats = self.fc(feats)
        # [batch*seq, encode_dim] -> [batch, seq, encode_dim]
        feats = feats.view(batch, steps, -1)
        h0, c0 = self.init_hidden(feats)
        _, (hn, _) = self.lstm(feats, (h0, c0))
        return hn
 
class DecoderMUG2d(nn.Module):
    """Transposed-convolution decoder that turns a code vector into an image.

    The code is projected to 1024 features, viewed as a 1024x1x1 map and
    upsampled by seven transposed convolutions. Spatial size per stage:
    1 -> 4 -> 10 -> 13 -> 28 -> 31 -> 64 -> 128, so the output is
    ``output_nc`` x 128 x 128 with values squashed into [0, 1] by a sigmoid.

    Args:
        output_nc: Number of channels of the generated image.
        encode_dim: Size of the last dimension of the input code.
    """

    def __init__(self, output_nc=3, encode_dim=1024):
        super().__init__()

        self.project = nn.Sequential(
            nn.Linear(encode_dim, 1024 * 1 * 1),
            nn.ReLU(inplace=True)
        )

        # (in_channels, out_channels, stride) for the six ConvT+BN+ReLU stages;
        # all use kernel size 4 and no padding.
        stage_specs = [
            (1024, 512, 1),  # 512*4*4
            (512, 256, 2),   # 256*10*10
            (256, 128, 1),   # 128*13*13
            (128, 64, 2),    # 64*28*28
            (64, 32, 1),     # 32*31*31
            (32, 16, 2),     # 16*64*64
        ]
        layers = []
        for in_ch, out_ch, stride in stage_specs:
            layers.append(nn.ConvTranspose2d(in_ch, out_ch, 4, stride=stride))
            layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.ReLU(True))
        # Final stage: 16*64*64 -> output_nc*128*128, squashed to [0, 1].
        layers.append(nn.ConvTranspose2d(16, output_nc, 4, stride=2, padding=1))
        layers.append(nn.Sigmoid())
        self.decoder = nn.Sequential(*layers)

    def forward(self, x):
        """Decode ``x`` (last dimension ``encode_dim``) into ``[N, output_nc, 128, 128]``."""
        feature_map = self.project(x).view(-1, 1024, 1, 1)
        return self.decoder(feature_map)
 
class Net(nn.Module):
    """Next-frame predictor: sequence encoder (``n1``) chained into the image decoder (``n2``)."""

    def __init__(self):
        super().__init__()
        self.n1 = EncoderMUG2d_LSTM()
        self.n2 = DecoderMUG2d()

    def forward(self, x):
        """Encode the frame sequence ``x`` and decode the result into an image batch (B*3*128*128)."""
        encoded = self.n1(x)
        return self.n2(encoded)

In [245]:
def weighted_mse_loss(output, target, weight):
    """Mean over all elements of ``weight * (output - target) ** 2``."""
    squared_error = (output - target) ** 2
    weighted = weight * squared_error
    return weighted.mean()



# Training procedure
def train_model():
    """Train ``Net`` on ``train_loader`` and save its weights.

    Uses the notebook-level settings ``num_epochs``, ``learning_rate`` and
    ``save_interval``. After every epoch the last batch's prediction and
    target are written to ``./conv_autoencoder`` as PNG grids. The weights go
    to ``model_csdn_lstm.pth`` every ``save_interval`` epochs and once more
    when training finishes.

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    # Use one device for the model and the data so the function also runs on
    # CPU-only machines.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Net().to(device)
    model.train()

    num_batches = len(train_loader)
    if num_batches == 0:
        raise ValueError('train_loader is empty: not enough frames for a single batch')

    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    image_dir = './conv_autoencoder'
    os.makedirs(image_dir, exist_ok=True)
    checkpoint_path = 'model_csdn_lstm.pth'

    for epoch in range(num_epochs):
        print(f'Epoch {epoch + 1}/{num_epochs}')
        train_loss = 0.0

        for batch_x, batch_y in train_loader:
            inputs, labels = batch_x.to(device), batch_y.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            # Down-weight near-white background pixels (>= 0.95) so the loss
            # focuses on the non-white content.
            weight = torch.where(labels < 0.95, 1.0, 0.1)
            loss = weighted_mse_loss(outputs, labels, weight)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        print(f'Loss: {train_loss / num_batches:.4f}')

        # Save the decoded and the ground-truth images of the last batch of
        # every epoch.
        fake_image = outputs.detach().cpu().squeeze()
        real_image = labels.detach().cpu().squeeze()
        utils.save_image(fake_image, './conv_autoencoder/decode_image_{}.png'.format(epoch + 1))
        utils.save_image(real_image, './conv_autoencoder/raw_image_{}.png'.format(epoch + 1))

        if epoch >= 40:
            # Halve the learning rate on the optimizer itself. Rebinding the
            # module-level variable would neither reach Adam nor be legal here
            # (it would turn `learning_rate` into an unbound local).
            # NOTE(review): this halves after *every* epoch from 40 on, so the
            # rate shrinks very quickly — confirm this schedule is intended.
            for group in optimizer.param_groups:
                group['lr'] *= 0.5

        # Periodic checkpoint so an interrupted run does not lose everything.
        if save_interval and (epoch + 1) % save_interval == 0:
            torch.save(model.state_dict(), checkpoint_path)

    # Save the final model
    torch.save(model.state_dict(), checkpoint_path)



# Start training when this cell is executed.
train_model()

Epoch 1/100
Loss: 0.0235
Epoch 2/100
Loss: 0.0157
Epoch 3/100
Loss: 0.0113
Epoch 4/100
Loss: 0.0094
Epoch 5/100
Loss: 0.0086
Epoch 6/100
Loss: 0.0083
Epoch 7/100
Loss: 0.0082
Epoch 8/100
Loss: 0.0081
Epoch 9/100
Loss: 0.0080
Epoch 10/100
Loss: 0.0080
Epoch 11/100
Loss: 0.0080
Epoch 12/100
Loss: 0.0079
Epoch 13/100
Loss: 0.0078
Epoch 14/100
Loss: 0.0078
Epoch 15/100
Loss: 0.0077
Epoch 16/100
Loss: 0.0077
Epoch 17/100
Loss: 0.0076
Epoch 18/100
Loss: 0.0075
Epoch 19/100
Loss: 0.0076
Epoch 20/100
Loss: 0.0077
Epoch 21/100
Loss: 0.0077
Epoch 22/100
Loss: 0.0077
Epoch 23/100
Loss: 0.0075
Epoch 24/100
Loss: 0.0073
Epoch 25/100
Loss: 0.0071
Epoch 26/100
Loss: 0.0074
Epoch 27/100
Loss: 0.0071
Epoch 28/100
Loss: 0.0071
Epoch 29/100
Loss: 0.0068
Epoch 30/100
Loss: 0.0067
Epoch 31/100
Loss: 0.0069
Epoch 32/100
Loss: 0.0066
Epoch 33/100
Loss: 0.0066
Epoch 34/100
Loss: 0.0073
Epoch 35/100
Loss: 0.0073
Epoch 36/100
Loss: 0.0066
Epoch 37/100
Loss: 0.0059
Epoch 38/100
Loss: 0.0057
Epoch 39/100
Loss: 0.

KeyboardInterrupt: 