In [12]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

# Load MNIST dataset
transform = transforms.Compose([transforms.ToTensor()])
train_data = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_data = datasets.MNIST(root='./data', train=False, transform=transform, download=True)
train_size = int(0.8 * len(train_data))
val_size = len(train_data) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(train_data, [train_size, val_size])
train_loader = torch.utils.data.DataLoader(train_data, batch_size=64, shuffle=True)
val_loader =  torch.utils.data.DataLoader(val_dataset, batch_size=64, shuffle=False)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=64, shuffle=False)

# CNN 模型
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc = nn.Linear(16 * 14 * 14, 128)
        
    def forward(self, x):
        x = self.pool(torch.relu(self.conv(x)))
        x = x.view(x.size(0), -1)
        x = torch.relu(self.fc(x))
        return x

# LSTM 模型
class LSTM(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(LSTM, self).__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)
        
    def forward(self, x):
        _, (hidden, _) = self.lstm(x)
        x = self.fc(hidden[-1])
        return x

# 結合 CNN 和 LSTM
class CNN_LSTM(nn.Module):
    def __init__(self, cnn, lstm):
        super(CNN_LSTM, self).__init__()
        self.cnn = cnn
        self.lstm = lstm
        self.fc = nn.Linear(128 + 10, 10)  # Adjust output dimensions as necessary
    
    def forward(self, x_img, x_seq):
        x_cnn = self.cnn(x_img)
        x_lstm = self.lstm(x_seq)
        x = torch.cat((x_cnn, x_lstm), dim=1)
        x = self.fc(x)
        return x


In [13]:
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=10):
    model.train()  
    for epoch in range(num_epochs):
        running_loss = 0.0
        correct = 0
        total = 0
        for i, (images, labels) in enumerate(train_loader):
            x_img = images
            x_seq = images.view(images.size(0), 28, 28)  # 將圖像展平為 (batch_size, 28, 28)，作為序列輸入給 LSTM
            labels = labels

            optimizer.zero_grad()

            outputs = model(x_img, x_seq)

            loss = criterion(outputs, labels)
            loss.backward()

            # 更新參數
            optimizer.step()

            # 統計Loss
            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        epoch_loss = running_loss / len(train_loader)
        epoch_acc = correct / total * 100
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%")

        # Validation
        model.eval()
        with torch.no_grad():
            correct = 0
            total = 0
            for images, labels in val_loader:
                x_img = images
                x_seq = images.view(images.size(0), 28, 28)
                labels = labels

                outputs = model(x_img, x_seq)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
            val_acc = correct / total * 100
            print(f"Validation Accuracy: {val_acc:.2f}%")

In [14]:
def evaluate_model(model, test_loader, criterion):
    model.eval()  #
    test_correct = 0
    test_total = 0
    test_loss = 0.0
    with torch.no_grad():
        for test_images, test_labels in test_loader:
            test_x_seq = test_images.view(test_images.size(0), 28, 28)
            test_outputs = model(test_images, test_x_seq)
            loss = criterion(test_outputs, test_labels)
            test_loss += loss.item()

            _, test_predicted = torch.max(test_outputs, 1)
            test_total += test_labels.size(0)
            test_correct += (test_predicted == test_labels).sum().item()

    test_loss = test_loss / len(test_loader)
    test_acc = test_correct / test_total * 100
    print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.2f}%")


In [15]:
# 初始化 CNN 和 LSTM 模型
cnn = CNN()  # 實例化 CNN 模型
lstm = LSTM(input_dim=28, hidden_dim=64, output_dim=10)  # 实例化 LSTM 模型

model = CNN_LSTM(cnn, lstm)


In [17]:
# 定義損失函數和優化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=10)
evaluate_model(model, test_loader, criterion)


Epoch 1/10, Loss: 0.3241, Accuracy: 90.24%
Validation Accuracy: 95.93%
Epoch 2/10, Loss: 0.1002, Accuracy: 96.96%
Validation Accuracy: 97.47%
Epoch 3/10, Loss: 0.0608, Accuracy: 98.15%
Validation Accuracy: 98.67%
Epoch 4/10, Loss: 0.0447, Accuracy: 98.62%
Validation Accuracy: 98.97%
Epoch 5/10, Loss: 0.0333, Accuracy: 98.95%
Validation Accuracy: 98.98%
Epoch 6/10, Loss: 0.0274, Accuracy: 99.17%
Validation Accuracy: 99.22%
Epoch 7/10, Loss: 0.0212, Accuracy: 99.36%
Validation Accuracy: 99.62%
Epoch 8/10, Loss: 0.0165, Accuracy: 99.52%
Validation Accuracy: 99.65%
Epoch 9/10, Loss: 0.0145, Accuracy: 99.55%
Validation Accuracy: 99.79%
Epoch 10/10, Loss: 0.0107, Accuracy: 99.66%
Validation Accuracy: 99.87%
Test Loss: 0.0426, Test Accuracy: 98.70%
