In [None]:
import os
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torchvision import models
from torch.utils.data import DataLoader, Dataset
from PIL import Image

# Dataset locations
validation_dir = 'Pin/test'
train_dir = 'Pin/train'

# Hyper-parameters and model options
num_classes = 2
num_epochs = 5
batch_size = 8
input_size = 224
bidirectional = False

# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Data preprocessing
# Random augmentation is applied to the training split only. The validation
# split uses a deterministic resize + centre crop so that the reported
# loss/accuracy are reproducible and comparable from epoch to epoch.
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(input_size),
        transforms.RandomRotation(20),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
    ]),
    'validation': transforms.Compose([
        # Resize the shorter side slightly above the crop size (256 for a
        # 224 crop), then take the central input_size x input_size patch.
        transforms.Resize(int(input_size * 256 / 224)),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
    ]),
}

image_datasets = {x: torchvision.datasets.ImageFolder(train_dir if x == 'train' else validation_dir, data_transforms[x]) for x in ['train', 'validation']}
# Only the training loader is shuffled; evaluation order does not affect the metrics.
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=(x == 'train'), num_workers=4) for x in ['train', 'validation']}

# Class names and the name -> index mapping that ImageFolder derived from the training directory.
classes = image_datasets['train'].classes
classes_index = image_datasets['train'].class_to_idx
print(classes)
print(classes_index)


# Use VGG16 as the feature extractor
vgg16 = models.vgg16(pretrained=True)

# Freeze every pretrained weight; only layers created after this point are trainable.
for weight in vgg16.parameters():
    weight.requires_grad_(False)

# Replace the stock classifier with a small head: 25088 -> 100 -> 2.
head_layers = [
    nn.Linear(in_features=25088, out_features=100, bias=True),
    nn.ReLU(inplace=True),
    nn.Dropout(p=0.5, inplace=False),
    nn.Linear(in_features=100, out_features=2, bias=True),
    nn.ReLU(inplace=True),
    nn.Dropout(p=0.5, inplace=False),
]
vgg16.classifier = nn.Sequential(*head_layers)

vgg16.to(device)

# Add an LSTM layer on top of the base model
class VGG16_LSTM(nn.Module):
    """Base model followed by an LSTM and a linear classification layer.

    The base model's output is reshaped to a sequence of length 1 per sample,
    passed through the LSTM, and the last time step is mapped to class scores.

    Args:
        base_model: Module applied to the input first. Its per-sample output
            must flatten to 2 values, because the LSTM is built with
            ``input_size=2``.
        lstm_hidden_size: Hidden size of the LSTM (default 1).
        lstm_num_layers: Number of stacked LSTM layers (default 2).
        out_classes: Number of output classes. When ``None`` the module-level
            ``num_classes`` is used.
        lstm_bidirectional: Whether the LSTM is bidirectional. When ``None``
            the module-level ``bidirectional`` flag is used.
    """

    def __init__(self, base_model, lstm_hidden_size=1, lstm_num_layers=2,
                 out_classes=None, lstm_bidirectional=None):
        super(VGG16_LSTM, self).__init__()
        # Fall back to the notebook-level settings when no override is given.
        if out_classes is None:
            out_classes = num_classes
        if lstm_bidirectional is None:
            lstm_bidirectional = bidirectional
        lstm_bidirectional = bool(lstm_bidirectional)

        self.base_model = base_model
        # The flag must be given to the LSTM itself: a bidirectional LSTM emits
        # 2 * hidden_size features per step, which is what `fc` is sized for.
        self.lstm = nn.LSTM(
            input_size=2,
            hidden_size=lstm_hidden_size,
            num_layers=lstm_num_layers,
            batch_first=True,
            bidirectional=lstm_bidirectional,
        )
        num_directions = 2 if lstm_bidirectional else 1
        self.fc = nn.Linear(lstm_hidden_size * num_directions, out_classes, bias=True)

    def forward(self, x):
        """Return class scores of shape (batch, out_classes) for input batch ``x``."""
        x = self.base_model(x)
        # Treat each sample's output as a single-step sequence: (batch, 1, features).
        x, _ = self.lstm(x.view(x.size(0), 1, -1))
        # Classify from the last (and only) time step.
        x = self.fc(x[:, -1, :])
        return x

# Assemble the full network on the target device, with its loss and optimiser.
model = VGG16_LSTM(vgg16)
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.0001)

# Train the model
for epoch in range(num_epochs):
    print("epoch:", epoch)
    for phase in ['train', 'validation']:
        is_training = phase == 'train'
        # train(True) puts the model in training mode, train(False) is
        # equivalent to eval() and is used for the validation pass.
        model.train(is_training)

        running_loss = 0.0
        running_corrects = 0  # plain Python int, so nothing accumulates on the device

        for inputs, labels in dataloaders[phase]:
            inputs, labels = inputs.to(device), labels.to(device)

            # Gradients are only produced (and so only need clearing) while training.
            if is_training:
                optimizer.zero_grad()

            with torch.set_grad_enabled(is_training):
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)

                if is_training:
                    loss.backward()
                    optimizer.step()

            # criterion returns the batch mean, so weight it by the batch size
            # to obtain a per-sample average over the whole epoch.
            running_loss += loss.item() * inputs.size(0)
            running_corrects += (preds == labels).sum().item()

        num_samples = len(image_datasets[phase])
        epoch_loss = running_loss / num_samples
        epoch_acc = running_corrects / num_samples

        print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

    print('Epoch {}/{}'.format(epoch, num_epochs - 1))
    print('-' * 10)

# Persist the weights of the final epoch.
torch.save(model.state_dict(), 'morris_md2_use_lstm.pth')

In [3]:
import numpy as np

def create_custom_vgg16():
    """Build a pretrained VGG16 whose classifier is replaced by a 25088 -> 100 -> 2 head.

    Returns:
        The torchvision VGG16 model with its ``classifier`` attribute swapped
        for a new ``nn.Sequential`` (Linear, ReLU, Dropout, Linear, ReLU, Dropout).
    """
    backbone = models.vgg16(pretrained=True)
    head_layers = [
        nn.Linear(in_features=25088, out_features=100, bias=True),
        nn.ReLU(inplace=True),
        nn.Dropout(p=0.5, inplace=False),
        nn.Linear(in_features=100, out_features=2, bias=True),
        nn.ReLU(inplace=True),
        nn.Dropout(p=0.5, inplace=False),
    ]
    backbone.classifier = nn.Sequential(*head_layers)
    return backbone

# Load the trained model
model_path = 'morris_md2_use_lstm.pth'
vgg16_custom = create_custom_vgg16()
model = VGG16_LSTM(vgg16_custom)
# map_location remaps the stored tensors onto the current device, so a
# checkpoint written during a CUDA run can still be loaded on a CPU-only machine.
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)

# Switch to evaluation mode (disables dropout) before running inference.
model.eval()

# Class names indexed by predicted class id.
# NOTE(review): presumably this matches the class_to_idx mapping printed by the
# training cell - verify before trusting the printed label.
label = np.array(['md2','morris'])

# preprocessing
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor() 
])

def predict(image_path):
    """Classify one image file and print the predicted label name.

    The image is opened with PIL, converted to RGB, run through the
    module-level ``transform``, and passed to the module-level ``model`` on
    ``device``. The index of the largest output is looked up in ``label``
    and printed.

    Args:
        image_path: Path of the image file to classify.
    """
    # Close the file handle promptly, and force 3 channels so grayscale or
    # RGBA files produce the 3-channel tensor the network is fed in training.
    with Image.open(image_path) as img:
        img = img.convert('RGB')
    # Add the batch dimension and move the input to the same device as the
    # model; a CPU tensor cannot be fed to a model that lives on CUDA.
    batch = transform(img).unsqueeze(0).to(device)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        outputs = model(batch)
    # Index of the highest-scoring class
    _, predicted = torch.max(outputs, 1)
    # Convert the class index to its label name
    print(label[predicted.item()])

predict('morris_test2.jpg')

md2
