In [1]:
import os
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import DatasetFolder, ImageFolder
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms, models
from torchsummary import summary
from tqdm import tqdm
from PIL import Image
import csv
import pandas as pd
import numpy as np
import cv2

# Seed PyTorch's random number generator so random draws made later in the
# notebook are repeatable from a fresh kernel.
torch.manual_seed(42)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
# Maps each traffic-sign name to its integer index. The names are listed in
# index order, so a name's position in the list *is* its index (0..61).
# Spelling/capitalisation of the keys is kept exactly as-is because they are
# looked up verbatim.
# NOTE(review): presumably these are the class indices of the 62-way
# classifier; the dict is not referenced elsewhere in the visible code.
signs_dict = {
    sign_name: sign_index
    for sign_index, sign_name in enumerate([
        "Speed limit (5km)",
        "Speed limit (15km)",
        "Speed limit (20km)",
        "Speed limit (30km)",
        "Speed limit (40km)",
        "Speed limit (50km)",
        "Speed limit (60km)",
        "Speed limit (70km)",
        "speed limit (80km)",
        "speed limit (100km)",
        "speed limit (120km)",
        "End of speed limit",
        "End of speed limit (50km)",
        "End of speed limit (80km)",
        "Dont overtake from Left",
        "No stopping",
        "No Uturn",
        "No Car",
        "No horn",
        "No entry",
        "No passage",
        "Dont Go Right",
        "Dont Go Left or Right",
        "Dont Go Left",
        "Dont Go straight",
        "Dont Go straight or Right",
        "Dont Go straight or left",
        "Go right or straight",
        "Go left or straight",
        "Village",
        "Uturn",
        "ZigZag Curve",
        "Bicycles crossing",
        "Keep Right",
        "Keep Left",
        "Roundabout mandatory",
        "Watch out for cars",
        "Slow down and give way",
        "Continuous detours",
        "Slow walking",
        "Horn",
        "Uphill steep slope",
        "Downhill steep slope",
        "Under Construction",
        "Heavy Vehicle Accidents",
        "Parking inspection",
        "Stop at intersection",
        "Train Crossing",
        "Fences",
        "Dangerous curve to the right",
        "Go Right",
        "Go Left or right",
        "Dangerous curve to the left",
        "Go Left",
        "Go straight",
        "Go straight or right",
        "Children crossing",
        "Care bicycles crossing",
        "Danger Ahead",
        "Traffic signals",
        "Zebra Crossing",
        "Road Divider",
    ])
}

In [3]:
class CA_Block(nn.Module):
    """Attention block that gates a feature map separately along height and width.

    The input is average-pooled along each spatial axis, the two pooled
    strips are concatenated and passed through a shared 1x1 bottleneck
    convolution + ReLU, then split again and projected back to ``channel``
    maps. A sigmoid turns each projection into a gate that multiplies the
    input.

    Args:
        channel: Number of channels of the input feature map.
        h: Height of the input feature map.
        w: Width of the input feature map.
        reduction: Divisor applied to ``channel`` to get the bottleneck width.
    """

    def __init__(self, channel, h, w, reduction=16):
        super().__init__()

        bottleneck = channel // reduction

        self.h = h
        self.w = w

        # Pool across the width (-> h x 1) and across the height (-> 1 x w).
        self.avg_pool_x = nn.AdaptiveAvgPool2d((h, 1))
        self.avg_pool_y = nn.AdaptiveAvgPool2d((1, w))

        # Shared 1x1 bottleneck applied to the concatenated pooled strips.
        self.conv_1x1 = nn.Conv2d(
            in_channels=channel,
            out_channels=bottleneck,
            kernel_size=1,
            stride=1,
            bias=False,
        )

        self.relu = nn.ReLU()
        # Registered (so it is part of the state_dict) but not applied in
        # forward(); kept so previously saved models still match this class.
        self.bn = nn.BatchNorm2d(bottleneck)

        # Per-axis 1x1 projections back to the full channel count.
        self.F_h = nn.Conv2d(
            in_channels=bottleneck,
            out_channels=channel,
            kernel_size=1,
            stride=1,
            bias=False,
        )
        self.F_w = nn.Conv2d(
            in_channels=bottleneck,
            out_channels=channel,
            kernel_size=1,
            stride=1,
            bias=False,
        )

        self.sigmoid_h = nn.Sigmoid()
        self.sigmoid_w = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled by the height gate and the width gate.

        ``x`` is expected to be ``(N, channel, h, w)``: the pooled strips are
        split with sizes ``[h, w]`` and the gates are expanded to ``x``'s shape.
        """
        # (N, C, h, 1) -> (N, C, 1, h) so both strips can be joined on dim 3.
        strip_h = self.avg_pool_x(x).permute(0, 1, 3, 2)
        strip_w = self.avg_pool_y(x)

        joined = torch.cat((strip_h, strip_w), 3)
        joined = self.relu(self.conv_1x1(joined))

        part_h, part_w = joined.split([self.h, self.w], 3)

        # Undo the earlier permute so the height gate is (N, C, h, 1).
        gate_h = self.sigmoid_h(self.F_h(part_h.permute(0, 1, 3, 2)))
        gate_w = self.sigmoid_w(self.F_w(part_w))

        return x * gate_h.expand_as(x) * gate_w.expand_as(x)

In [4]:
class ResnetCA(nn.Module):
    """ResNet backbone followed by a CA_Block and a 62-way linear classifier.

    The backbone is every child of ``model`` except its last two; the
    second-to-last child is reused as the pooling layer, and the last child
    is replaced by a fresh ``nn.Linear(512, 62)``.

    Args:
        model: Network whose children are reused as described above. When
            ``None`` (the default) a new, randomly initialised
            ``torchvision.models.resnet18`` is built for this instance.
    """

    def __init__(self, model=None):
        super(ResnetCA, self).__init__()
        # Build the default backbone here rather than in the signature: a
        # default argument is evaluated once, at class-definition time, so
        # every default-constructed instance would otherwise share (and
        # jointly train) the very same backbone modules.
        if model is None:
            model = models.resnet18()
        layers = list(model.children())

        self.resnet = nn.Sequential(*layers[:-2])
        # The attention block is configured for 512-channel, 7x7 feature maps.
        self.ca = CA_Block(channel=512, h=7, w=7)
        self.avg_pool = layers[-2]
        self.fc = nn.Linear(in_features=512, out_features=62, bias=True)

    def forward(self, x):
        """Return the ``(N, 62)`` class scores for a batch of images ``x``."""
        x = self.resnet(x)
        x = self.ca(x)
        # Keeps the feature map in the (N, 512, 7, 7) layout the head expects.
        x = x.view(-1, 512, 7, 7)
        x = self.avg_pool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x


resnet_ca = ResnetCA()



In [5]:
class SuperResolutionTransform:
    """Callable transform that upsamples an image with OpenCV's dnn_superres.

    Args:
        model_path: Path of the super-resolution model file passed to
            ``readModel``.
        scale: Upscaling factor passed to ``setModel`` together with the
            ``"espcn"`` algorithm name.
    """

    def __init__(self, model_path, scale=4):
        upscaler = cv2.dnn_superres.DnnSuperResImpl_create()
        upscaler.readModel(model_path)
        upscaler.setModel("espcn", scale)
        self.sr = upscaler

    def __call__(self, img):
        """Upsample ``img`` (treated as RGB) and return the result as a PIL image."""
        # The image is converted RGB -> BGR before upsampling and back afterwards.
        bgr_input = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
        bgr_output = self.sr.upsample(bgr_input)
        rgb_output = cv2.cvtColor(bgr_output, cv2.COLOR_BGR2RGB)
        return Image.fromarray(rgb_output)

In [6]:
# Dataset definition
class ImageDataset(Dataset):
    """Unlabelled image dataset backed by the files of a single folder.

    Args:
        folder_path: Directory whose regular files are treated as images.
        transform: Optional callable applied to each RGB PIL image.

    Each item is an ``(image, file_name)`` pair, where ``file_name`` is the
    entry name inside ``folder_path`` (not the full path).
    """

    def __init__(self, folder_path, transform=None):
        self.folder_path = folder_path
        # os.listdir() yields entries in arbitrary order and also returns
        # sub-directories (e.g. Jupyter's ``.ipynb_checkpoints``), which
        # Image.open() cannot read. Keep regular files only, sorted, so that
        # sample indices are stable across runs and platforms.
        self.image_files = sorted(
            entry
            for entry in os.listdir(folder_path)
            if os.path.isfile(os.path.join(folder_path, entry))
        )
        self.transform = transform

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        file_name = self.image_files[idx]
        # The context manager releases the file handle once the pixels have
        # been decoded by convert().
        with Image.open(os.path.join(self.folder_path, file_name)) as raw_image:
            image = raw_image.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image, file_name

# Inference-time preprocessing pipeline.
# ColorJitter was removed from this pipeline: it applies *random* brightness /
# contrast / saturation / hue changes, which is a training-time augmentation.
# At inference it only made predictions depend on the RNG state.
transform = transforms.Compose([
    SuperResolutionTransform("ESPCN_x4.pb"),  # super-resolution upsampling
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
    
# Load the trained model.
model_path = "./model/best_model.pth"
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on a GPU still loads on a CPU-only machine.
# NOTE: torch.load() unpickles the file and can execute arbitrary code; only
# load checkpoints from a trusted source.
# NOTE(review): PyTorch >= 2.6 defaults to weights_only=True, which refuses a
# fully pickled module; pass weights_only=False there if this call fails.
model = torch.load(model_path, map_location=device)
model.eval()
model.to(device)

# Build the test dataset and its loader.
test_dataset = ImageDataset("./data/test_set/unknow/", transform=transform)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# Predict every image and collect (file name, predicted class index) rows.
results = []
# A single no_grad() context covers the whole loop instead of being
# re-entered for every batch.
with torch.no_grad():
    for images, filenames in test_loader:
        outputs = model(images.to(device))
        _, predicted = torch.max(outputs, 1)
        # Pair every file name with its prediction so the loop stays correct
        # for any batch size, not only batch_size=1.
        results.extend(zip(filenames, predicted.tolist()))

# Save the predictions to CSV.
df = pd.DataFrame(results, columns=["ImageID", "label"])
df.to_csv("predictions.csv", index=False)