In [11]:
import os
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F


In [12]:
class ImageSequenceDataset(Dataset):
    def __init__(self, image_folder):
        self.image_folder = image_folder
        self.images = sorted(os.listdir(image_folder), key=lambda x: int(x.split('.')[0]))  # 按序号排序
        self.transform = transforms.Compose([
            transforms.Resize((256, 256)),  # 调整图片大小
            transforms.ToTensor(),  # 转换为张量
        ])

    def __len__(self):
        return len(self.images) - 1  # 返回序列数量

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_folder, self.images[idx])
        img = Image.open(img_path).convert('RGB')  # 读取图片并转换为RGB
        img = self.transform(img)
        
        next_img_path = os.path.join(self.image_folder, self.images[idx + 1])
        next_img = Image.open(next_img_path).convert('RGB')
        next_img = self.transform(next_img)

        return img, next_img


In [13]:
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(32 * 64 * 64, 128)  # 假设输入图像经过pool后为64x64
        self.fc2 = nn.Linear(128, 3 * 256 * 256)  # 输出下一张图像

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 32 * 64 * 64)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        x = x.view(-1, 3, 256, 256)  # 变形为图像
        return x


In [14]:
# 设置参数
image_folder = 'overDataSet'  # 替换为你的图片文件夹路径
batch_size = 4
learning_rate = 0.001
num_epochs = 3

# 创建数据集和数据加载器
dataset = ImageSequenceDataset(image_folder)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# 初始化模型、损失函数和优化器
model = SimpleCNN()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# 训练模型
for epoch in range(num_epochs):
    for imgs, next_imgs in dataloader:
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, next_imgs)
        loss.backward()
        optimizer.step()
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')


Epoch [1/3], Loss: 0.0068
Epoch [2/3], Loss: 0.0101
Epoch [3/3], Loss: 0.0185


In [15]:
# 预测下一张图片
model.eval()
with torch.no_grad():
    last_img_path = os.path.join(image_folder, '240.png')  # 最后一张图片的名称
    last_img = Image.open(last_img_path).convert('RGB')
    last_img = transforms.Resize((256, 256))(last_img)
    last_img = transforms.ToTensor()(last_img).unsqueeze(0)  # 添加批量维度

    predicted_img = model(last_img)
    predicted_img = predicted_img.squeeze(0).permute(1, 2, 0)  # 变形为HWC
    predicted_img = (predicted_img.numpy() * 255).astype('uint8')

    # 保存或显示预测结果
    Image.fromarray(predicted_img).save('predicted_image.png')
