In [5]:
import os
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

if torch.cuda.is_available():
    device = torch.device('cuda')  # 使用GPU
    print("GPU is available.")
else:
    device = torch.device('cpu')  # 使用CPU
    print("GPU is not available, using CPU.")


GPU is available.


In [6]:
class ImageSequenceDataset(Dataset):
    def __init__(self, image_folder):
        self.image_folder = image_folder
        self.images = sorted(os.listdir(image_folder), key=lambda x: int(x.split('.')[0]))  # 按序号排序
        self.transform = transforms.Compose([
            transforms.Resize((256, 256)),  # 调整图片大小
            transforms.ToTensor(),  # 转换为张量
        ])

    def __len__(self):
        return len(self.images) - 1  # 返回序列数量

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_folder, self.images[idx])
        img = Image.open(img_path).convert('RGB')  # 读取图片并转换为RGB
        img = self.transform(img)
        
        next_img_path = os.path.join(self.image_folder, self.images[idx + 1])
        next_img = Image.open(next_img_path).convert('RGB')
        next_img = self.transform(next_img)

        return img, next_img


In [7]:
class ImageLSTM(nn.Module):
    def __init__(self):
        super(ImageLSTM, self).__init__()
        self.lstm = nn.LSTM(input_size=256 * 256 * 3, hidden_size=256, num_layers=2, batch_first=True)
        self.fc = nn.Linear(256, 256 * 256 * 3)  # 输出下一张图像

    def forward(self, x):
        batch_size, seq_length, C, H, W = x.size()
        x = x.view(batch_size, seq_length, -1)  # 重塑为 (batch_size, seq_length, input_size)
        out, _ = self.lstm(x)  # LSTM层
        out = self.fc(out[:, -1, :])  # 取最后一个时间步的输出
        out = out.view(batch_size, 3, 256, 256)  # 变形为图像
        return out


In [8]:
# 设置参数
image_folder = 'overDataSet'  # 替换为你的图片文件夹路径
batch_size = 4
learning_rate = 0.001
num_epochs = 10

# 创建数据集和数据加载器
dataset = ImageSequenceDataset(image_folder)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# 初始化模型、损失函数和优化器
model = ImageLSTM().to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# 训练模型
for epoch in range(num_epochs):
    for i, (imgs, next_imgs) in enumerate(dataloader):
        imgs = imgs.to(device)
        next_imgs = next_imgs.to(device)
        inputs = imgs.unsqueeze(1)  # 为LSTM添加时间步维度 (batch_size, seq_length, C, H, W)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, next_imgs)  # 计算损失
        loss.backward()
        optimizer.step()

        if i % 10 == 0:  # 每10个batch输出一次loss
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i}/{len(dataloader)}], Loss: {loss.item():.4f}')


Epoch [1/10], Step [0/60], Loss: 0.0074
Epoch [1/10], Step [10/60], Loss: 0.0117
Epoch [1/10], Step [20/60], Loss: 0.0052
Epoch [1/10], Step [30/60], Loss: 0.0180
Epoch [1/10], Step [40/60], Loss: 0.0156
Epoch [1/10], Step [50/60], Loss: 0.0087
Epoch [2/10], Step [0/60], Loss: 0.0137
Epoch [2/10], Step [10/60], Loss: 0.0103
Epoch [2/10], Step [20/60], Loss: 0.0092
Epoch [2/10], Step [30/60], Loss: 0.0145
Epoch [2/10], Step [40/60], Loss: 0.0100
Epoch [2/10], Step [50/60], Loss: 0.0097
Epoch [3/10], Step [0/60], Loss: 0.0099
Epoch [3/10], Step [10/60], Loss: 0.0065
Epoch [3/10], Step [20/60], Loss: 0.0098
Epoch [3/10], Step [30/60], Loss: 0.0086
Epoch [3/10], Step [40/60], Loss: 0.0150
Epoch [3/10], Step [50/60], Loss: 0.0073
Epoch [4/10], Step [0/60], Loss: 0.0044
Epoch [4/10], Step [10/60], Loss: 0.0118
Epoch [4/10], Step [20/60], Loss: 0.0061
Epoch [4/10], Step [30/60], Loss: 0.0037
Epoch [4/10], Step [40/60], Loss: 0.0049
Epoch [4/10], Step [50/60], Loss: 0.0143
Epoch [5/10], Step [

In [15]:
# 预测下一张图片
model.eval()
with torch.no_grad():
    last_img_path = os.path.join(image_folder, '240.png')  # 最后一张图片的名称
    last_img = Image.open(last_img_path).convert('RGB')
    last_img = transforms.Resize((256, 256))(last_img)
    last_img = transforms.ToTensor()(last_img).unsqueeze(0)  # 添加批量维度

    # 输入为当前图像，seq_length = 1
    last_images = last_img.unsqueeze(1).to(device)  # 变为 (1, 1, 3, 256, 256)

    predicted_img = model(last_images)
    predicted_img = predicted_img.squeeze(0).permute(1, 2, 0)  # 变形为HWC
    predicted_img = (predicted_img.cpu().numpy() * 255).astype('uint8')

    # 保存或显示预测结果
    Image.fromarray(predicted_img).save('LSTMpredicted_image.png')
