In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image

In [2]:
class ResidualBlock(nn.Module):
    """Basic two-convolution residual block.

    Main path: 3x3 conv (with ``stride``) -> BN -> ReLU -> 3x3 conv -> BN.
    Skip path: identity, or a 1x1 conv + BN projection whenever the stride or
    the channel count changes, so both paths have the same shape and can be
    summed. The sum is passed through a final ReLU.

    Args:
        inchannel: number of input channels.
        outchannel: number of output channels.
        stride: stride of the first 3x3 conv and of the projection shortcut.
    """

    def __init__(self, inchannel, outchannel, stride=1):
        super().__init__()
        # Order matters: it fixes the state_dict keys (left.0 ... left.4) that
        # saved checkpoints are loaded against.
        main_path = [
            nn.Conv2d(inchannel, outchannel, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(outchannel),
            nn.ReLU(inplace=True),
            nn.Conv2d(outchannel, outchannel, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(outchannel),
        ]
        self.left = nn.Sequential(*main_path)

        needs_projection = stride != 1 or inchannel != outchannel
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(inchannel, outchannel, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(outchannel),
            )
        else:
            # An empty Sequential returns its input unchanged (identity skip).
            self.shortcut = nn.Sequential()

    def forward(self, x):
        main = self.left(x)
        return F.relu(main + self.shortcut(x))

class ResNet(nn.Module):
    """ResNet-18-style classifier assembled from a pluggable residual block.

    Layout: 5x5 stem conv (stride 1, padding 1) + BN + ReLU, then four stages
    of two blocks each with 64/128/256/512 channels (stages 2-4 start with
    stride 2), a 4x4 average pool, flatten, and a linear head.

    Args:
        ResidualBlock: block class, called as ``block(in_channels, out_channels, stride)``.
        num_classes: number of output logits.
        fc_in_features: size of the flattened feature vector fed to ``fc``.
            It depends on the input resolution. The default 14336 = 512 * 4 * 7
            matches 3x135x240 inputs when each block downsamples like a
            3x3 / padding-1 conv (as this notebook's ``ResidualBlock`` does).
            Pass a different value for other input sizes.
    """

    def __init__(self, ResidualBlock, num_classes=5, fc_in_features=14336):
        super(ResNet, self).__init__()
        # Running channel count; make_layer advances it as stages are appended.
        self.inchannel = 64
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=5, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        self.layer1 = self.make_layer(ResidualBlock, 64, 2, stride=1)
        self.layer2 = self.make_layer(ResidualBlock, 128, 2, stride=2)
        self.layer3 = self.make_layer(ResidualBlock, 256, 2, stride=2)
        self.layer4 = self.make_layer(ResidualBlock, 512, 2, stride=2)
        self.fc = nn.Linear(fc_in_features, num_classes)

    def make_layer(self, block, channels, num_blocks, stride):
        """Stack ``num_blocks`` blocks into one stage.

        Only the first block uses ``stride``; the rest use stride 1. Each
        appended block sets ``self.inchannel`` to ``channels`` so the next
        block (and the next stage) receives the right input width.
        """
        layers = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            layers.append(block(self.inchannel, channels, block_stride))
            self.inchannel = channels
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return ``(batch, num_classes)`` logits; ``x`` is assumed to be an NCHW batch with 3 channels."""
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = torch.flatten(out, 1)  # flatten to (batch, features)
        return self.fc(out)
    
def ResNet18(num_classes=5):
    """Build the ResNet-18-style model using this notebook's ``ResidualBlock``.

    Args:
        num_classes: number of output logits (defaults to 5, as before).

    Returns:
        A freshly initialised ``ResNet`` instance.
    """
    return ResNet(ResidualBlock, num_classes)

In [3]:
class ImageClassifier:
    """Single-image inference wrapper around a trained ``ResNet``.

    Loads a state dict into ``ResNet(ResidualBlock, num_classes)``. For each
    image it optionally crops a fixed region, resizes to 135x240, converts to
    a tensor, and returns softmax probabilities keyed by class name.

    Args:
        model_path: path to a file whose ``torch.load`` result is a state dict
            for the model.
        class_names: one name per output logit, in logit order.
        num_classes: number of model outputs; must equal ``len(class_names)``.
        crop_box: ``(left, upper, right, lower)`` pixel box cropped before
            resizing, or ``None`` to use the whole image. The default is the
            region this classifier has always used.

    Raises:
        ValueError: if ``len(class_names) != num_classes``.
    """

    def __init__(self, model_path, class_names, num_classes=5, crop_box=(300, 80, 1620, 950)):
        # Fail fast, before loading weights. A mismatch would otherwise raise
        # IndexError or silently drop classes when results are mapped to names.
        if len(class_names) != num_classes:
            raise ValueError(
                f"got {len(class_names)} class names for a model with {num_classes} outputs"
            )
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self._load_model(model_path, num_classes)
        self.class_names = class_names
        self.crop_box = crop_box
        self.transform = transforms.Compose([
            transforms.Resize((135, 240)),
            transforms.ToTensor(),
        ])

    def _load_model(self, model_path, num_classes):
        """Build the network on ``self.device``, load weights from ``model_path``, switch to eval mode."""
        model = ResNet(ResidualBlock, num_classes).to(self.device)
        # map_location lets a checkpoint saved on GPU load when self.device is
        # the CPU fallback; without it torch.load restores tensors to their
        # saved device and fails on CPU-only machines.
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints (consider weights_only=True on PyTorch versions that support it).
        checkpoint = torch.load(model_path, map_location=self.device)
        model.load_state_dict(checkpoint)
        model.eval()  # BatchNorm uses running statistics during inference
        return model

    def classify_image(self, image_path):
        """Return ``{class_name: probability}`` (softmax over the logits) for the image at ``image_path``."""
        # The context manager closes the file handle. convert("RGB") guarantees
        # the 3 channels the first conv expects (e.g. for RGBA PNGs or grayscale).
        with Image.open(image_path) as raw:
            region = raw.crop(self.crop_box) if self.crop_box is not None else raw
            batch = self.transform(region.convert("RGB")).unsqueeze(0).to(self.device)

        with torch.no_grad():
            logits = self.model(batch)
        return self._probabilities_by_class(logits)

    def _probabilities_by_class(self, logits):
        """Map a ``(1, num_classes)`` logit tensor to ``{class_name: probability}``."""
        # Take row 0 instead of squeeze() so a single-class model still yields a
        # 1-D array that can be indexed.
        probabilities = F.softmax(logits, dim=1)[0].detach().cpu().numpy()
        return {name: probabilities[i] for i, name in enumerate(self.class_names)}

In [4]:
# Example usage: classify one test image and print the probability of each class.
if __name__ == "__main__":
    class_names = ['tri', 'T', 'sin', 'for', 'dou']  # order must match the model's output logits
    model_path = "1resnet18_weights.pth"
    image_path = "test/image000005.jpg"

    classifier = ImageClassifier(model_path, class_names)
    probabilities = classifier.classify_image(image_path)

    for class_name in probabilities:
        print(f"{class_name}: {probabilities[class_name]:.5f}")

tri: 0.00013
T: 0.99972
sin: 0.00001
for: 0.00014
dou: 0.00000


  return F.conv2d(input, weight, bias, self.stride,
