In [15]:
import torch
import cv2
from torchvision import transforms
from PIL import Image
from torch import nn
import numpy as np

# Định nghĩa lại lớp mô hình để khôi phục mô hình đã lưu
class TrafficSignModel(nn.Module):
    def __init__(self):
        super(TrafficSignModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(128 * 16 * 16, 512)
        self.fc2 = nn.Linear(512, 43)  # Cập nhật số lớp đầu ra là 43

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.max_pool2d(x, 2)
        x = torch.relu(self.conv2(x))
        x = torch.max_pool2d(x, 2)
        x = x.view(x.size(0), -1)  # Flatten
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x


# Tải mô hình đã huấn luyện từ tệp
model = TrafficSignModel()
model.load_state_dict(torch.load('traffic_sign_model.pth'))
model.eval()  # Chuyển mô hình sang chế độ evaluation

# Định nghĩa các phép biến đổi giống như khi huấn luyện
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Hàm dự đoán biển báo từ ảnh
def predict_traffic_sign(image):
    image = Image.fromarray(image)  # Chuyển đổi từ numpy array sang PIL image
    image = transform(image).unsqueeze(0)  # Áp dụng phép biến đổi và thêm batch dimension

    with torch.no_grad():
        output = model(image)  # Dự đoán
        _, predicted = torch.max(output, 1)  # Lấy nhãn dự đoán

    return predicted.item()


# Các nhãn biển báo giao thông
traffic_sign_labels = {
    0: 'Thang hoac Phai, cam queo Trai',
    1: 'Cam di nguoc chieu',
    2: 'Cam re trai',
    3: 'Canh bao co tre em',
    4: 'Cam dau xe',
    5: 'Di cham thoi',
    6: 'Cam dung va do xe',
    7: 'Huong di theo vach ke duong',
    8: 'Gap khuc phai',
    9: 'Huong phai di vung phai'
    # Thêm đủ các nhãn ở đây
}


# Hàm phát hiện hình tròn đỏ
def detect_circle_red(contour):  # Kiểm tra hình tròn đỏ
    area = cv2.contourArea(contour)
    if area < 500: return False
    perimeter = cv2.arcLength(contour, True)
    if perimeter == 0: return False
    x, y, w, h = cv2.boundingRect(contour)
    aspect_ratio = float(w) / h
    circularity = (4 * np.pi * area) / (perimeter ** 2)

    # Điều kiện cho hình tròn đỏ hợp lệ
    dks = [
        {"circularity_range": (0.65, 1), "aspect_ratio_range": (0.75, 1.2), "height_range": (37, 120), "perimeter_min": 175},
        {"circularity_range": (0.23, 0.24), "aspect_ratio_range": (0.75, 1.2), "height_range": (37, 120), "area_range": (595, 700), "perimeter_range": (165, 180)},
        {"circularity_range": (0.12, 0.13), "area_range": (1590, 1658), "perimeter_range": (390, 420)}
    ]
    for dk in dks:
        if (dk["circularity_range"][0] <= circularity <= dk["circularity_range"][1] and
            dk["aspect_ratio_range"][0] <= aspect_ratio <= dk["aspect_ratio_range"][1] and
            dk["height_range"][0] < h < dk["height_range"][1] and perimeter >= dk["perimeter_min"] and
            dk["area_range"][0] <= area <= dk["area_range"][1] and
            dk["perimeter_range"][0] <= perimeter <= dk["perimeter_range"][1]):
            return True
    return False


# Hàm phát hiện hình tròn xanh
def detect_circle_blue(contour):  # Kiểm tra hình tròn xanh
    area = cv2.contourArea(contour)
    if area < 2300: return False
    perimeter = cv2.arcLength(contour, True)
    if perimeter == 0: return False
    x, y, w, h = cv2.boundingRect(contour)
    aspect_ratio = float(w) / h
    circularity = (4 * np.pi * area) / (perimeter ** 2)

    # Điều kiện cho hình tròn xanh hợp lệ
    small_circle = 0.67 <= circularity <= 1 and 0.9 <= aspect_ratio <= 1.2 and 37 < h < 150
    medium_circle = 0.36 <= circularity < 0.67 and 0.9 <= aspect_ratio <= 1.2 and 37 < h < 150 and area > 8500 and perimeter > 500
    large_circle = 0.25 <= circularity < 0.36 and 0.9 <= aspect_ratio <= 1.2 and 37 < h < 150 and area > 14500 and perimeter > 700

    if small_circle or medium_circle or large_circle:
        return True
    return False


# Hàm phát hiện hình tam giác đỏ
def detect_triangle_red(contour):  # Kiểm tra hình tam giác đỏ
    area = cv2.contourArea(contour)
    if area < 2300: return False
    perimeter = cv2.arcLength(contour, True)
    approx = cv2.approxPolyDP(contour, 0.04 * perimeter, True)
    if len(approx) == 3:  # Hình tam giác có 3 cạnh
        x, y, w, h = cv2.boundingRect(approx)
        aspect_ratio = float(w) / h
        if area < 1400 and perimeter < 150: return False
        if 0.9 < aspect_ratio < 1 and 30 < w < 150 and 30 < h < 150:
            return True
    return False


# Hàm phát hiện hình chữ nhật xanh
def detect_rectangle_blue(contour):  # Kiểm tra hình chữ nhật xanh
    area = cv2.contourArea(contour)
    if area < 1700: return False
    perimeter = cv2.arcLength(contour, True)
    approx = cv2.approxPolyDP(contour, 0.02 * perimeter, True)
    if len(approx) != 4: return False
    x, y, w, h = cv2.boundingRect(approx)
    aspect_ratio = float(w) / h
    if 44 < w < 90 and 32 < h < 60 and 1200 < area < 6000 and 140 < perimeter < 400:
        return True
    return False


# Hàm phát hiện hình tròn xanh
def detect_circle_red(contour):  # Kiểm tra hình tròn đỏ
    # Tính các đặc trưng hình học
    x, y, w, h = cv2.boundingRect(contour)
    aspect_ratio = float(w) / h
    area = cv2.contourArea(contour)
    perimeter = cv2.arcLength(contour, True)
    
    # Kiểm tra độ tròn
    circularity = 4 * np.pi * (area / (perimeter ** 2)) if perimeter > 0 else 0

    dks = [
        {"circularity_range": (0.65, 1), "aspect_ratio_range": (0.75, 1.2), "height_range": (37, 120), "perimeter_min": 175},
        {"circularity_range": (0.23, 0.24), "aspect_ratio_range": (0.75, 1.2), "height_range": (37, 120), "area_range": (595, 700), "perimeter_range": (165, 180)},
        {"circularity_range": (0.12, 0.13), "area_range": (1590, 1658), "perimeter_range": (390, 420)}
    ]
    
    for dk in dks:
        # Kiểm tra sự tồn tại của các khóa trước khi truy cập
        if "aspect_ratio_range" in dk:
            if (dk["circularity_range"][0] <= circularity <= dk["circularity_range"][1] and
                dk["aspect_ratio_range"][0] <= aspect_ratio <= dk["aspect_ratio_range"][1] and
                dk["height_range"][0] < h < dk["height_range"][1] and perimeter >= dk["perimeter_min"] and
                "area_range" in dk and dk["area_range"][0] <= area <= dk["area_range"][1] and
                "perimeter_range" in dk and dk["perimeter_range"][0] <= perimeter <= dk["perimeter_range"][1]):
                return True
    return False



# Hàm xử lý ảnh
def preprocess_image(image):  # Xử lý ảnh
    hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
    lower_red1 = np.array([0, 150, 50])
    upper_red1 = np.array([10, 255, 150])
    lower_red2 = np.array([150, 100, 20])
    upper_red2 = np.array([180, 255, 150])
    mask_red1 = cv2.inRange(hsv, lower_red1, upper_red1)
    mask_red2 = cv2.inRange(hsv, lower_red2, upper_red2)
    mask_red = cv2.bitwise_or(mask_red1, mask_red2)

    lower_blue = np.array([90, 50, 70])
    upper_blue = np.array([140, 255, 255])
    mask_blue = cv2.inRange(hsv, lower_blue, upper_blue)

    mask_combined = cv2.bitwise_or(mask_red, mask_blue)
    return mask_combined


# Hàm phát hiện biển báo giao thông trong ảnh
def detect_traffic_signs(image):  # Phát hiện biển báo giao thông
    mask = preprocess_image(image)
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    traffic_signs = []

    for contour in contours:
        if detect_circle_red(contour):  # Kiểm tra hình tròn đỏ
            x, y, w, h = cv2.boundingRect(contour)
            traffic_signs.append(('Circle Red', x, y, w, h, 'Red circle sign'))
        elif detect_circle_blue(contour):  # Kiểm tra hình tròn xanh
            x, y, w, h = cv2.boundingRect(contour)
            traffic_signs.append(('Circle Blue', x, y, w, h, 'Blue circle sign'))
        elif detect_triangle_red(contour):  # Kiểm tra hình tam giác đỏ
            x, y, w, h = cv2.boundingRect(contour)
            traffic_signs.append(('Triangle Red', x, y, w, h, 'Red triangle sign'))
        elif detect_rectangle_blue(contour):  # Kiểm tra hình chữ nhật xanh
            x, y, w, h = cv2.boundingRect(contour)
            traffic_signs.append(('Rectangle Blue', x, y, w, h, 'Blue rectangle sign'))

    return traffic_signs


# Quá trình phân loại biển báo trong video
def process_video(video_path):
    cap = cv2.VideoCapture(video_path)
    
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        
        # Phát hiện biển báo trong frame
        traffic_signs = detect_traffic_signs(frame)

        # Xử lý từng biển báo được phát hiện
        for sign in traffic_signs:
            label, x, y, w, h, info = sign
            sign_image = frame[y:y+h, x:x+w]
            sign_image_resized = cv2.resize(sign_image, (64, 64))  # Resize để phù hợp với kích thước đầu vào của model
            sign_image_resized = sign_image_resized.astype('float32') / 255.0
            sign_image_resized = np.expand_dims(sign_image_resized, axis=0)  # Thêm dimension batch
            
            # Dự đoán loại biển báo sử dụng model
            prediction = model.predict(sign_image_resized)
            predicted_label = np.argmax(prediction, axis=1)

            # Lấy tên biển báo từ traffic_sign_labels
            traffic_sign_name = traffic_sign_labels.get(predicted_label[0], 'Unknown')  # 'Unknown' nếu không có trong từ điển
            
            # Vẽ bounding box và tên biển báo lên frame
            cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 0), 2)  # Vẽ khung bao biển báo
            cv2.putText(frame, traffic_sign_name, (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)

        # Hiển thị frame với bounding box và tên biển báo
        cv2.imshow('Traffic Sign Detection and Classification', frame)
        
        # Thoát khi nhấn 'q'
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    cap.release()
    cv2.destroyAllWindows()


video_path = 'video1.mp4'  # Đường dẫn tới video
process_video(video_path)  # Bắt đầu xử lý video

  model.load_state_dict(torch.load('traffic_sign_model.pth'))


KeyError: 'perimeter_min'

In [1]:
import torch
import cv2
from torchvision import transforms
from PIL import Image
from torch import nn
import numpy as np

# Định nghĩa lại lớp mô hình để khôi phục mô hình đã lưu
class TrafficSignModel(nn.Module):
    def __init__(self):
        super(TrafficSignModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(128 * 16 * 16, 512)
        self.fc2 = nn.Linear(512, 43)  # Cập nhật số lớp đầu ra là 43

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.max_pool2d(x, 2)
        x = torch.relu(self.conv2(x))
        x = torch.max_pool2d(x, 2)
        x = x.view(x.size(0), -1)  # Flatten
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x


# Tải mô hình đã huấn luyện từ tệp
model = TrafficSignModel()
model.load_state_dict(torch.load('traffic_sign_model.pth'))
model.eval()  # Chuyển mô hình sang chế độ evaluation

# Định nghĩa các phép biến đổi giống như khi huấn luyện
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Hàm dự đoán biển báo từ ảnh
def predict_traffic_sign(image):
    image = Image.fromarray(image)  # Chuyển đổi từ numpy array sang PIL image
    image = transform(image).unsqueeze(0)  # Áp dụng phép biến đổi và thêm batch dimension

    with torch.no_grad():
        output = model(image)  # Dự đoán
        _, predicted = torch.max(output, 1)  # Lấy nhãn dự đoán

    return predicted.item()


def process_video(video_path):
    cap = cv2.VideoCapture(video_path)
    
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        
        # Phát hiện biển báo trong frame
        traffic_signs = detect_traffic_signs(frame)

        # Xử lý từng biển báo được phát hiện
        for sign in traffic_signs:
            # Giải nén các giá trị trả về (label, x, y, w, h, info)
            label, x, y, w, h, info = sign  # Điều chỉnh giải nén phù hợp

            sign_image = frame[y:y+h, x:x+w]
            sign_image_resized = cv2.resize(sign_image, (64, 64))  # Resize để phù hợp với kích thước đầu vào của model
            sign_image_resized = sign_image_resized.astype('float32') / 255.0
            sign_image_resized = np.expand_dims(sign_image_resized, axis=0)  # Thêm dimension batch
            
            # Dự đoán loại biển báo sử dụng model
            prediction = model.predict(sign_image_resized)
            predicted_label = np.argmax(prediction, axis=1)

            # Lấy tên biển báo từ traffic_sign_labels
            traffic_sign_name = traffic_sign_labels.get(predicted_label[0], 'Unknown')  # 'Unknown' nếu không có trong từ điển
            
            # Vẽ bounding box và tên biển báo lên frame
            cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 0), 2)  # Vẽ khung bao biển báo
            cv2.putText(frame, traffic_sign_name, (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)

        # Hiển thị frame với bounding box và tên biển báo
        cv2.imshow('Traffic Sign Detection and Classification', frame)
        
        # Thoát khi nhấn 'q'
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    cap.release()
    cv2.destroyAllWindows()




# Các nhãn biển báo giao thông
traffic_sign_labels = {
    0: 'Thang hoac Phai, cam queo Trai',
    1: 'Cam di nguoc chieu',
    2: 'Cam re trai',
    3: 'Canh bao co tre em',
    4: 'Cam dau xe',
    5: 'Di cham thoi',
    6: 'Cam dung va do xe',
    7: 'Huong di theo vach ke duong',
    8: 'Gap khuc phai',
    9: 'Huong phai di vung phai'
    # Thêm đủ các nhãn ở đây
}


def detect_circle_red(contour):
    
    area=cv2.contourArea(contour)
    
    #Loai tru cac hinh tron nho, tranh phat hien sai
    if area<500:
        return False

    perimeter=cv2.arcLength(contour,True)
    
    if perimeter == 0:
        return False

    x,y,w,h=cv2.boundingRect(contour)
    aspect_ratio=float(w)/h
    circularity=(4*np.pi*area)/(perimeter**2)
    
    #Danh sach dieu kien
    dks=[
        {
            "circularity_range": (0.65,1),
            "aspect_ratio_range": (0.75,1.2),
            "height_range": (37,120),
            "perimeter_min": 175
        },
        {
            "circularity_range": (0.23,0.24),
            "aspect_ratio_range": (0.75,1.2),
            "height_range": (37,120),
            "area_range": (595,700),
            "perimeter_range": (165,180)
        },
        {
            "circularity_range": (0.12,0.13),
            "area_range": (1590,1658),
            "perimeter_range": (390,420),
        }
    ]

    #kemtra
    for dk in dks:
        if (dk.get("circularity_range",(0,1))[0]<=circularity<=dk.get("circularity_range",(0,1))[1] and
            dk.get("aspect_ratio_range",(0,float("inf")))[0]<=aspect_ratio<=dk.get("aspect_ratio_range",(0,float("inf")))[1] and
            dk.get("height_range",(0,float("inf")))[0]<h<dk.get("height_range",(0,float("inf")))[1] and
            perimeter>=dk.get("perimeter_min",0) and
            dk.get("area_range",(0,float("inf")))[0]<=area<=dk.get("area_range",(0,float("inf")))[1] and
            dk.get("perimeter_range",(0,float("inf")))[0]<=perimeter<=dk.get("perimeter_range",(0,float("inf")))[1]):
            return True

    return False



#Ham phat hien hinh tron mau xanh
def detect_circle_blue(contour):
    area=cv2.contourArea(contour)
    if area<2300:
        return False

    perimeter=cv2.arcLength(contour,True)
    if perimeter == 0:
        return False

    x,y,w,h=cv2.boundingRect(contour)
    aspect_ratio=float(w)/h
    circularity=(4*np.pi*area)/(perimeter**2)

    #Dieu kien cho cac loai hinh tron khac nhau
    small_circle=0.67<=circularity<=1 and 0.9<=aspect_ratio<=1.2 and 37<h<150
    medium_circle=0.36<=circularity<0.67 and 0.9<=aspect_ratio<=1.2 and 37<h<150 and area>8500 and perimeter>500
    large_circle=0.25<=circularity<0.36 and 0.9<=aspect_ratio<=1.2 and 37<h<150 and area>14500 and perimeter>700

    #Dieu kien loai tru
    exclusion_area_perimeter=area<2500 and perimeter<210

    #Kiem tra cac dieu kien
    if exclusion_area_perimeter:
        return False
    if small_circle or medium_circle or large_circle:
        return True

    return False



def detect_triangle_red(contour):
    
    area=cv2.contourArea(contour)
    
    if area<2300:
        return False

    perimeter=cv2.arcLength(contour,True)
    approx=cv2.approxPolyDP(contour,0.04*perimeter,True)

    if len(approx) == 3:  #Hinh tam giac co 3 canh
        
        x,y,w,h=cv2.boundingRect(approx)
        aspect_ratio=float(w)/h

        #Dieu kien loai tru cho hinh tam giac nho
        if area<1400 and perimeter<150:
            return False

        #Dieu kien cho hinh tam giac hop le
        valid_triangle=0.9<aspect_ratio<1 and 30<w<150 and 30<h<150
        
        if valid_triangle:
            return True

    return False


def detect_rectangle_blue(contour):
    
    area=cv2.contourArea(contour)
    
    if area<1700:
        return False

    perimeter=cv2.arcLength(contour,True)
    approx=cv2.approxPolyDP(contour,0.02*perimeter,True)

    if len(approx)!=4:
        return False

    x,y,w,h=cv2.boundingRect(approx)
    aspect_ratio=float(w)/h

    #Cac dieu kien kich thuoc cho hinh chu nhat mong muon
    large_rectangle=w<150 and area>19000
    medium_rectangle=44<w<90 and 32<h<60 and 1200<area<6000 and 140<perimeter<400
    unwanted_rectangle=95<w<153 and 50<h<86 and perimeter<460 and area<8300

    #Loai tru cac truong hop chu vi va dien tich khong phu hop
    high_perimeter_exclusion=perimeter>700 and area<10000
    low_area_exclusion=perimeter>100 and area<900

    #Cac dieu kien ty le khung hinh va kich thuoc
    small_aspect_ratio=0.9<aspect_ratio<2 and 20<w<90 and 20<h<185
    large_aspect_ratio=0.9<aspect_ratio<2 and 90<w<300 and 20<h<185

    #Kiem tra cac dieu kien loai tru va mong muon
    if large_rectangle or medium_rectangle:
        return True
    
    if unwanted_rectangle or high_perimeter_exclusion or low_area_exclusion:
        return False
    
    if small_aspect_ratio or large_aspect_ratio:
        return True

    return False

# Hàm xử lý ảnh
def preprocess_image(image):  # Xử lý ảnh
    hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
    lower_red1 = np.array([0, 150, 50])
    upper_red1 = np.array([10, 255, 150])
    lower_red2 = np.array([150, 100, 20])
    upper_red2 = np.array([180, 255, 150])
    mask_red1 = cv2.inRange(hsv, lower_red1, upper_red1)
    mask_red2 = cv2.inRange(hsv, lower_red2, upper_red2)
    mask_red = cv2.bitwise_or(mask_red1, mask_red2)

    lower_blue = np.array([90, 50, 70])
    upper_blue = np.array([140, 255, 255])
    mask_blue = cv2.inRange(hsv, lower_blue, upper_blue)

    mask_combined = cv2.bitwise_or(mask_red, mask_blue)
    return mask_combined


# Hàm phát hiện biển báo giao thông trong ảnh
def detect_traffic_signs(image):  # Phát hiện biển báo giao thông
    mask = preprocess_image(image)
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    traffic_signs = []
    for contour in contours:
        if detect_circle_red(contour):  # Kiểm tra hình tròn đỏ
            traffic_signs.append(('circle_red', contour))
        elif detect_circle_blue(contour):  # Kiểm tra hình tròn xanh
            traffic_signs.append(('circle_blue', contour))
        elif detect_triangle_red(contour):  # Kiểm tra hình tam giác đỏ
            traffic_signs.append(('triangle_red', contour))
        elif detect_rectangle_blue(contour):  # Kiểm tra hình chữ nhật xanh
            traffic_signs.append(('rectangle_blue', contour))

    return traffic_signs

# Main function to run the program
video_path = 'video1.mp4'  # Đường dẫn tới video
process_video(video_path)  # Bắt đầu xử lý video


  model.load_state_dict(torch.load('traffic_sign_model.pth'))


RuntimeError: Error(s) in loading state_dict for TrafficSignModel:
	size mismatch for fc2.weight: copying a param with shape torch.Size([10, 512]) from checkpoint, the shape in current model is torch.Size([43, 512]).
	size mismatch for fc2.bias: copying a param with shape torch.Size([10]) from checkpoint, the shape in current model is torch.Size([43]).

In [None]:
import torch
import cv2
from torchvision import transforms
from PIL import Image
from torch import nn
import numpy as np

# Định nghĩa lại lớp mô hình để khôi phục mô hình đã lưu
class TrafficSignModel(nn.Module):
    def __init__(self):
        super(TrafficSignModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(128 * 16 * 16, 512)
        self.fc2 = nn.Linear(512, 43)  # Cập nhật số lớp đầu ra là 43

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.max_pool2d(x, 2)
        x = torch.relu(self.conv2(x))
        x = torch.max_pool2d(x, 2)
        x = x.view(x.size(0), -1)  # Flatten
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x


# Tải mô hình đã huấn luyện từ tệp
model = TrafficSignModel()
model.load_state_dict(torch.load('traffic_sign_model.pth'))
model.eval()  # Chuyển mô hình sang chế độ evaluation

# Định nghĩa các phép biến đổi giống như khi huấn luyện
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


def detect_circle_red(contour):
    
    area=cv2.contourArea(contour)
    
    #Loai tru cac hinh tron nho, tranh phat hien sai
    if area<500:
        return False

    perimeter=cv2.arcLength(contour,True)
    
    if perimeter == 0:
        return False

    x,y,w,h=cv2.boundingRect(contour)
    aspect_ratio=float(w)/h
    circularity=(4*np.pi*area)/(perimeter**2)
    
    #Danh sach dieu kien
    dks=[
        {
            "circularity_range": (0.65,1),
            "aspect_ratio_range": (0.75,1.2),
            "height_range": (37,120),
            "perimeter_min": 175
        },
        {
            "circularity_range": (0.23,0.24),
            "aspect_ratio_range": (0.75,1.2),
            "height_range": (37,120),
            "area_range": (595,700),
            "perimeter_range": (165,180)
        },
        {
            "circularity_range": (0.12,0.13),
            "area_range": (1590,1658),
            "perimeter_range": (390,420),
        }
    ]

    #kemtra
    for dk in dks:
        if (dk.get("circularity_range",(0,1))[0]<=circularity<=dk.get("circularity_range",(0,1))[1] and
            dk.get("aspect_ratio_range",(0,float("inf")))[0]<=aspect_ratio<=dk.get("aspect_ratio_range",(0,float("inf")))[1] and
            dk.get("height_range",(0,float("inf")))[0]<h<dk.get("height_range",(0,float("inf")))[1] and
            perimeter>=dk.get("perimeter_min",0) and
            dk.get("area_range",(0,float("inf")))[0]<=area<=dk.get("area_range",(0,float("inf")))[1] and
            dk.get("perimeter_range",(0,float("inf")))[0]<=perimeter<=dk.get("perimeter_range",(0,float("inf")))[1]):
            return True

    return False



#Ham phat hien hinh tron mau xanh
def detect_circle_blue(contour):
    area=cv2.contourArea(contour)
    if area<2300:
        return False

    perimeter=cv2.arcLength(contour,True)
    if perimeter == 0:
        return False

    x,y,w,h=cv2.boundingRect(contour)
    aspect_ratio=float(w)/h
    circularity=(4*np.pi*area)/(perimeter**2)

    #Dieu kien cho cac loai hinh tron khac nhau
    small_circle=0.67<=circularity<=1 and 0.9<=aspect_ratio<=1.2 and 37<h<150
    medium_circle=0.36<=circularity<0.67 and 0.9<=aspect_ratio<=1.2 and 37<h<150 and area>8500 and perimeter>500
    large_circle=0.25<=circularity<0.36 and 0.9<=aspect_ratio<=1.2 and 37<h<150 and area>14500 and perimeter>700

    #Dieu kien loai tru
    exclusion_area_perimeter=area<2500 and perimeter<210

    #Kiem tra cac dieu kien
    if exclusion_area_perimeter:
        return False
    if small_circle or medium_circle or large_circle:
        return True

    return False



def detect_triangle_red(contour):
    
    area=cv2.contourArea(contour)
    
    if area<2300:
        return False

    perimeter=cv2.arcLength(contour,True)
    approx=cv2.approxPolyDP(contour,0.04*perimeter,True)

    if len(approx) == 3:  #Hinh tam giac co 3 canh
        
        x,y,w,h=cv2.boundingRect(approx)
        aspect_ratio=float(w)/h

        #Dieu kien loai tru cho hinh tam giac nho
        if area<1400 and perimeter<150:
            return False

        #Dieu kien cho hinh tam giac hop le
        valid_triangle=0.9<aspect_ratio<1 and 30<w<150 and 30<h<150
        
        if valid_triangle:
            return True

    return False


def detect_rectangle_blue(contour):
    
    area=cv2.contourArea(contour)
    
    if area<1700:
        return False

    perimeter=cv2.arcLength(contour,True)
    approx=cv2.approxPolyDP(contour,0.02*perimeter,True)

    if len(approx)!=4:
        return False

    x,y,w,h=cv2.boundingRect(approx)
    aspect_ratio=float(w)/h

    #Cac dieu kien kich thuoc cho hinh chu nhat mong muon
    large_rectangle=w<150 and area>19000
    medium_rectangle=44<w<90 and 32<h<60 and 1200<area<6000 and 140<perimeter<400
    unwanted_rectangle=95<w<153 and 50<h<86 and perimeter<460 and area<8300

    #Loai tru cac truong hop chu vi va dien tich khong phu hop
    high_perimeter_exclusion=perimeter>700 and area<10000
    low_area_exclusion=perimeter>100 and area<900

    #Cac dieu kien ty le khung hinh va kich thuoc
    small_aspect_ratio=0.9<aspect_ratio<2 and 20<w<90 and 20<h<185
    large_aspect_ratio=0.9<aspect_ratio<2 and 90<w<300 and 20<h<185

    #Kiem tra cac dieu kien loai tru va mong muon
    if large_rectangle or medium_rectangle:
        return True
    
    if unwanted_rectangle or high_perimeter_exclusion or low_area_exclusion:
        return False
    
    if small_aspect_ratio or large_aspect_ratio:
        return True

    return False

# Hàm dự đoán biển báo từ ảnh
def predict_traffic_sign(image):
    image = Image.fromarray(image)  # Chuyển đổi từ numpy array sang PIL image
    image = transform(image).unsqueeze(0)  # Áp dụng phép biến đổi và thêm batch dimension

    with torch.no_grad():
        output = model(image)  # Dự đoán
        _, predicted = torch.max(output, 1)  # Lấy nhãn dự đoán

    return predicted.item()


def process_video(video_path):
    cap = cv2.VideoCapture(video_path)
    
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        
        # Phát hiện biển báo trong frame
        traffic_signs = detect_traffic_signs(frame)

        # Xử lý từng biển báo được phát hiện
        for sign in traffic_signs:
            # Giải nén các giá trị trả về (label, contour)
            sign_type, contour = sign  # Chỉ cần 2 giá trị: loại biển báo và contour

            x, y, w, h = cv2.boundingRect(contour)
            sign_image = frame[y:y+h, x:x+w]
            sign_image_resized = cv2.resize(sign_image, (64, 64))  # Resize để phù hợp với kích thước đầu vào của model
            sign_image_resized = sign_image_resized.astype('float32') / 255.0
            sign_image_resized = np.expand_dims(sign_image_resized, axis=0)  # Thêm dimension batch

            # Dự đoán loại biển báo sử dụng model
            prediction = model(torch.tensor(sign_image_resized).permute(0, 3, 1, 2).float())  # Chuyển đổi tensor về đúng định dạng
            predicted_label = torch.argmax(prediction, axis=1).item()

            # Lấy tên biển báo từ traffic_sign_labels
            traffic_sign_name = traffic_sign_labels.get(predicted_label, 'Unknown')  # 'Unknown' nếu không có trong từ điển
            
            # Vẽ bounding box và tên biển báo lên frame
            cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 0), 2)  # Vẽ khung bao biển báo
            cv2.putText(frame, traffic_sign_name, (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)

        # Hiển thị frame với bounding box và tên biển báo
        cv2.imshow('Traffic Sign Detection and Classification', frame)
        
        # Thoát khi nhấn 'q'
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    cap.release()
    cv2.destroyAllWindows()

# Các nhãn biển báo giao thông
traffic_sign_labels = {
    0: 'Thang hoac Phai, cam queo Trai',
    1: 'Cam di nguoc chieu',
    2: 'Cam re trai',
    3: 'Canh bao co tre em',
    4: 'Cam dau xe',
    5: 'Di cham thoi',
    6: 'Cam dung va do xe',
    7: 'Huong di theo vach ke duong',
    8: 'Gap khuc phai',
    9: 'Huong phai di vung phai'
    # Thêm đủ các nhãn ở đây
}

# Hàm phát hiện biển báo giao thông trong ảnh
def detect_traffic_signs(image):  # Phát hiện biển báo giao thông
    mask = preprocess_image(image)
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    traffic_signs = []
    for contour in contours:
        if detect_circle_red(contour):  # Kiểm tra hình tròn đỏ
            traffic_signs.append(('circle_red', contour))
        elif detect_circle_blue(contour):  # Kiểm tra hình tròn xanh
            traffic_signs.append(('circle_blue', contour))
        elif detect_triangle_red(contour):  # Kiểm tra hình tam giác đỏ
            traffic_signs.append(('triangle_red', contour))
        elif detect_rectangle_blue(contour):  # Kiểm tra hình chữ nhật xanh
            traffic_signs.append(('rectangle_blue', contour))

    return traffic_signs

# Hàm xử lý ảnh
def preprocess_image(image):  # Xử lý ảnh
    hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
    lower_red1 = np.array([0, 150, 50])
    upper_red1 = np.array([10, 255, 150])
    lower_red2 = np.array([150, 100, 20])
    upper_red2 = np.array([180, 255, 150])
    mask_red1 = cv2.inRange(hsv, lower_red1, upper_red1)
    mask_red2 = cv2.inRange(hsv, lower_red2, upper_red2)
    mask_red = cv2.bitwise_or(mask_red1, mask_red2)

    lower_blue = np.array([90, 50, 70])
    upper_blue = np.array([140, 255, 255])
    mask_blue = cv2.inRange(hsv, lower_blue, upper_blue)

    mask_combined = cv2.bitwise_or(mask_red, mask_blue)
    return mask_combined


# Main function to run the program
video_path = 'video1.mp4'  # Đường dẫn tới video
process_video(video_path)  # Bắt đầu xử lý video


  model.load_state_dict(torch.load('traffic_sign_model.pth'))


RuntimeError: Error(s) in loading state_dict for TrafficSignModel:
	size mismatch for fc2.weight: copying a param with shape torch.Size([10, 512]) from checkpoint, the shape in current model is torch.Size([43, 512]).
	size mismatch for fc2.bias: copying a param with shape torch.Size([10]) from checkpoint, the shape in current model is torch.Size([43]).

In [3]:
import cv2
import torch
import numpy as np
import os
from torchvision import transforms

# Load the trained model
model = torch.load('traffic_sign_model.pth')
model.eval()

# Define traffic sign labels (example)
traffic_sign_labels = {
    0: 'Thang hoac Phai, cam queo Trai',
    1: 'Cam di nguoc chieu',
    2: 'Cam re trai',
    3: 'Canh bao co tre em',
    4: 'Cam dau xe',
    5: 'Di cham thoi',
    6: 'Cam dung va do xe',
    7: 'Huong di theo vach ke duong',
    8: 'Gap khuc phai',
    9: 'Huong phai di vung phai',
    # Add more labels here
}

# Define the transform for input image preprocessing (matching the model's input size)
transform = transforms.Compose([
    transforms.ToPILImage(),  # Convert numpy array to PIL image
    transforms.Resize((64, 64)),  # Resize to the model's expected size
    transforms.ToTensor(),  # Convert to Tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize if required
])

# Clean image function
def clean_images():
    file_list = os.listdir('./')
    for file_name in file_list:
        if '.png' in file_name:
            os.remove(file_name)

# Process image function (preprocessing steps)
def preprocess_image(image):
    hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)

    lower_red1 = np.array([0, 150, 50])
    upper_red1 = np.array([10, 255, 150])
    lower_red2 = np.array([150, 100, 20])
    upper_red2 = np.array([180, 255, 150])
    mask_red1 = cv2.inRange(hsv, lower_red1, upper_red1)
    mask_red2 = cv2.inRange(hsv, lower_red2, upper_red2)
    mask_red = cv2.bitwise_or(mask_red1, mask_red2)

    lower_blue = np.array([105, 100, 120])
    upper_blue = np.array([110, 255, 255])
    mask_blue = cv2.inRange(hsv, lower_blue, upper_blue)

    mask_red = cv2.GaussianBlur(mask_red, (5, 5), 0)
    mask_red = cv2.erode(mask_red, None, iterations=2)
    mask_red = cv2.dilate(mask_red, None, iterations=2)

    mask_blue = cv2.GaussianBlur(mask_blue, (5, 5), 0)
    mask_blue = cv2.erode(mask_blue, None, iterations=2)
    mask_blue = cv2.dilate(mask_blue, None, iterations=2)

    return mask_red, mask_blue
def detect_circle_red(contour):
    
    area=cv2.contourArea(contour)
    
    #Loai tru cac hinh tron nho, tranh phat hien sai
    if area<500:
        return False

    perimeter=cv2.arcLength(contour,True)
    
    if perimeter == 0:
        return False

    x,y,w,h=cv2.boundingRect(contour)
    aspect_ratio=float(w)/h
    circularity=(4*np.pi*area)/(perimeter**2)
    
    #Danh sach dieu kien
    dks=[
        {
            "circularity_range": (0.65,1),
            "aspect_ratio_range": (0.75,1.2),
            "height_range": (37,120),
            "perimeter_min": 175
        },
        {
            "circularity_range": (0.23,0.24),
            "aspect_ratio_range": (0.75,1.2),
            "height_range": (37,120),
            "area_range": (595,700),
            "perimeter_range": (165,180)
        },
        {
            "circularity_range": (0.12,0.13),
            "area_range": (1590,1658),
            "perimeter_range": (390,420),
        }
    ]

    #kemtra
    for dk in dks:
        if (dk.get("circularity_range",(0,1))[0]<=circularity<=dk.get("circularity_range",(0,1))[1] and
            dk.get("aspect_ratio_range",(0,float("inf")))[0]<=aspect_ratio<=dk.get("aspect_ratio_range",(0,float("inf")))[1] and
            dk.get("height_range",(0,float("inf")))[0]<h<dk.get("height_range",(0,float("inf")))[1] and
            perimeter>=dk.get("perimeter_min",0) and
            dk.get("area_range",(0,float("inf")))[0]<=area<=dk.get("area_range",(0,float("inf")))[1] and
            dk.get("perimeter_range",(0,float("inf")))[0]<=perimeter<=dk.get("perimeter_range",(0,float("inf")))[1]):
            return True

    return False



#Ham phat hien hinh tron mau xanh
def detect_circle_blue(contour):
    area=cv2.contourArea(contour)
    if area<2300:
        return False

    perimeter=cv2.arcLength(contour,True)
    if perimeter == 0:
        return False

    x,y,w,h=cv2.boundingRect(contour)
    aspect_ratio=float(w)/h
    circularity=(4*np.pi*area)/(perimeter**2)

    #Dieu kien cho cac loai hinh tron khac nhau
    small_circle=0.67<=circularity<=1 and 0.9<=aspect_ratio<=1.2 and 37<h<150
    medium_circle=0.36<=circularity<0.67 and 0.9<=aspect_ratio<=1.2 and 37<h<150 and area>8500 and perimeter>500
    large_circle=0.25<=circularity<0.36 and 0.9<=aspect_ratio<=1.2 and 37<h<150 and area>14500 and perimeter>700

    #Dieu kien loai tru
    exclusion_area_perimeter=area<2500 and perimeter<210

    #Kiem tra cac dieu kien
    if exclusion_area_perimeter:
        return False
    if small_circle or medium_circle or large_circle:
        return True

    return False



def detect_triangle_red(contour):
    
    area=cv2.contourArea(contour)
    
    if area<2300:
        return False

    perimeter=cv2.arcLength(contour,True)
    approx=cv2.approxPolyDP(contour,0.04*perimeter,True)

    if len(approx) == 3:  #Hinh tam giac co 3 canh
        
        x,y,w,h=cv2.boundingRect(approx)
        aspect_ratio=float(w)/h

        #Dieu kien loai tru cho hinh tam giac nho
        if area<1400 and perimeter<150:
            return False

        #Dieu kien cho hinh tam giac hop le
        valid_triangle=0.9<aspect_ratio<1 and 30<w<150 and 30<h<150
        
        if valid_triangle:
            return True

    return False


def detect_rectangle_blue(contour):
    
    area=cv2.contourArea(contour)
    
    if area<1700:
        return False

    perimeter=cv2.arcLength(contour,True)
    approx=cv2.approxPolyDP(contour,0.02*perimeter,True)

    if len(approx)!=4:
        return False

    x,y,w,h=cv2.boundingRect(approx)
    aspect_ratio=float(w)/h

    #Cac dieu kien kich thuoc cho hinh chu nhat mong muon
    large_rectangle=w<150 and area>19000
    medium_rectangle=44<w<90 and 32<h<60 and 1200<area<6000 and 140<perimeter<400
    unwanted_rectangle=95<w<153 and 50<h<86 and perimeter<460 and area<8300

    #Loai tru cac truong hop chu vi va dien tich khong phu hop
    high_perimeter_exclusion=perimeter>700 and area<10000
    low_area_exclusion=perimeter>100 and area<900

    #Cac dieu kien ty le khung hinh va kich thuoc
    small_aspect_ratio=0.9<aspect_ratio<2 and 20<w<90 and 20<h<185
    large_aspect_ratio=0.9<aspect_ratio<2 and 90<w<300 and 20<h<185

    #Kiem tra cac dieu kien loai tru va mong muon
    if large_rectangle or medium_rectangle:
        return True
    
    if unwanted_rectangle or high_perimeter_exclusion or low_area_exclusion:
        return False
    
    if small_aspect_ratio or large_aspect_ratio:
        return True

    return False

# Detect traffic signs (you can use your existing functions for detecting red/blue shapes)
def detect_traffic_signs(image):
    mask_red, mask_blue = preprocess_image(image)
    contours_red, _ = cv2.findContours(mask_red, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    contours_blue, _ = cv2.findContours(mask_blue, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    traffic_signs = []
    red_bboxes = []

    for contour in contours_red:
        # Use your circle and triangle detection functions
        if detect_circle_red(contour):  # Assuming detect_circle_red is defined as before
            x, y, w, h = cv2.boundingRect(contour)
            red_bboxes.append((x, y, w, h))
            traffic_signs.append(('Red Circle', x, y, w, h))

        elif detect_triangle_red(contour):  # Assuming detect_triangle_red is defined as before
            x, y, w, h = cv2.boundingRect(contour)
            red_bboxes.append((x, y, w, h))
            traffic_signs.append(('Red Triangle', x, y, w, h))

    for contour in contours_blue:
        x, y, w, h = cv2.boundingRect(contour)
        # Skip red area to avoid double detection
        is_inside_red_area = any(rx <= x <= rx + rw and ry <= y <= ry + rh for rx, ry, rw, rh in red_bboxes)
        if not is_inside_red_area:
            if detect_circle_blue(contour):  # Assuming detect_circle_blue is defined as before
                traffic_signs.append(('Blue Circle', x, y, w, h))
            elif detect_rectangle_blue(contour):  # Assuming detect_rectangle_blue is defined as before
                traffic_signs.append(('Blue Rectangle', x, y, w, h))

    return traffic_signs

# Main function to process video
def main(video_file):
    vidcap = cv2.VideoCapture(video_file)
    frame_width = int(vidcap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    frame_rate = vidcap.get(cv2.CAP_PROP_FPS)

    output_file = "output_video.avi"
    fourcc = cv2.VideoWriter_fourcc(*'MJPG')
    out = cv2.VideoWriter(output_file, fourcc, frame_rate, (frame_width, frame_height))

    while True:
        success, frame = vidcap.read()
        if not success:
            break

        traffic_signs = detect_traffic_signs(frame)

        for shape, x, y, w, h in traffic_signs:
            sign_image = frame[y:y + h, x:x + w]  # Crop the detected sign
            sign_image_resized = transform(sign_image)  # Preprocess the image for model input

            sign_image_resized = sign_image_resized.unsqueeze(0)  # Add batch dimension
            with torch.no_grad():
                output = model(sign_image_resized)  # Get model output
                _, predicted = torch.max(output, 1)  # Get predicted class

            label = traffic_sign_labels.get(predicted.item(), "Unknown")
            color = (0, 255, 0)  # Green for bounding box

            # Draw bounding box and label
            cv2.rectangle(frame, (x, y), (x + w, y + h), color, 2)
            cv2.putText(frame, label, (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)

        # Write the frame with bounding boxes and labels to the output video
        out.write(frame)

        # Show the frame with annotations
        cv2.imshow('Traffic Sign Detection', frame)

        # Exit on pressing 'q'
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    vidcap.release()
    out.release()
    cv2.destroyAllWindows()

# Run the program with the video file
if __name__ == '__main__':
    main('video1.mp4')


  model = torch.load('traffic_sign_model.pth')


AttributeError: 'collections.OrderedDict' object has no attribute 'eval'

In [4]:
import torch
import torch.nn as nn
import cv2
import numpy as np

# Define the model architecture (must match the model architecture used during training)
class TrafficSignModel(nn.Module):
    def __init__(self):
        super(TrafficSignModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        
        # Use a dummy tensor to calculate the size of the feature map after convolution and pooling
        self.dummy_tensor = torch.zeros(1, 3, 32, 32)  # Example input size
        self._initialize_fc_layers()

    def _initialize_fc_layers(self):
        # Get the size of the output after passing through conv layers
        with torch.no_grad():
            x = self.pool(torch.relu(self.conv1(self.dummy_tensor)))
            x = self.pool(torch.relu(self.conv2(x)))
            # Flatten the tensor for the fully connected layer
            flattened_size = x.numel()
        
        # Initialize fc1 with the dynamically calculated input size
        self.fc1 = nn.Linear(flattened_size, 512)
        self.fc2 = nn.Linear(512, 10)  # Assuming 10 classes for traffic signs (adjust if necessary)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(x.size(0), -1)  # Flatten the tensor
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Define traffic sign labels (example)
traffic_sign_labels = {
    0: 'Thang hoac Phai, cam queo Trai',
    1: 'Cam di nguoc chieu',
    2: 'Cam re trai',
    3: 'Canh bao co tre em',
    4: 'Cam dau xe',
    5: 'Di cham thoi',
    6: 'Cam dung va do xe',
    7: 'Huong di theo vach ke duong',
    8: 'Gap khuc phai',
    9: 'Huong phai di vung phai'
}

# Initialize the model
model = TrafficSignModel()

# Load the trained weights
model.load_state_dict(torch.load('traffic_sign_model.pth'))
model.eval()  # Set the model to evaluation mode

# Function to predict traffic sign from an image
def predict_traffic_sign(image_path):
    # Load the image
    test_image = cv2.imread(image_path)  # Example image
    if test_image is None:
        print("Error: Image not found.")
        return
    
    # Preprocess the image: resize, transpose, and convert to tensor
    test_image = cv2.resize(test_image, (32, 32))  # Resize to match model input size (32x32)
    test_image = test_image.transpose((2, 0, 1))  # Change to (C, H, W)
    test_image = torch.tensor(test_image, dtype=torch.float32).unsqueeze(0)  # Add batch dimension
    
    # Run the model on the image
    with torch.no_grad():  # No need to compute gradients during inference
        output = model(test_image)
        _, predicted_class = torch.max(output, 1)  # Get the class with the highest probability

    # Get the predicted label
    predicted_label = traffic_sign_labels[predicted_class.item()]
    print(f'Predicted Traffic Sign: {predicted_label}')
    return predicted_label

# Example usage: Predict traffic sign for an input image
image_path = '00243_00000_000010.png'  # Change this to your image file path
predicted_label = predict_traffic_sign(image_path)


  model.load_state_dict(torch.load('traffic_sign_model.pth'))


RuntimeError: Error(s) in loading state_dict for TrafficSignModel:
	size mismatch for fc1.weight: copying a param with shape torch.Size([512, 32768]) from checkpoint, the shape in current model is torch.Size([512, 8192]).

In [5]:
import torch
import torch.nn as nn
import cv2
import numpy as np

# Define the model architecture (must match the model architecture used during training)
class TrafficSignModel(nn.Module):
    def __init__(self):
        super(TrafficSignModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        
        # Use a dummy tensor to calculate the size of the feature map after convolution and pooling
        self.dummy_tensor = torch.zeros(1, 3, 32, 32)  # Example input size
        self._initialize_fc_layers()

    def _initialize_fc_layers(self):
        # Get the size of the output after passing through conv layers
        with torch.no_grad():
            x = self.pool(torch.relu(self.conv1(self.dummy_tensor)))
            x = self.pool(torch.relu(self.conv2(x)))
            # Flatten the tensor for the fully connected layer
            flattened_size = x.numel()  # This will give you the size for fc1
        
        # Initialize fc1 with the dynamically calculated input size
        self.fc1 = nn.Linear(flattened_size, 512)
        self.fc2 = nn.Linear(512, 43)  # Assuming 43 classes for traffic signs

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(x.size(0), -1)  # Flatten the tensor
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Define traffic sign labels (example)
traffic_sign_labels = {
    0: 'Thang hoac Phai, cam queo Trai',
    1: 'Cam di nguoc chieu',
    2: 'Cam re trai',
    3: 'Canh bao co tre em',
    4: 'Cam dau xe',
    5: 'Di cham thoi',
    6: 'Cam dung va do xe',
    7: 'Huong di theo vach ke duong',
    8: 'Gap khuc phai',
    9: 'Huong phai di vung phai',
    10: 'Biển báo 10', 
}

# Initialize the model
model = TrafficSignModel()

# Load the trained weights
model.load_state_dict(torch.load('traffic_sign_model.pth'))
model.eval()  # Set the model to evaluation mode

# Function to predict traffic sign from an image
def predict_traffic_sign(image_path):
    # Load the image
    test_image = cv2.imread(image_path)  # Example image
    if test_image is None:
        print("Error: Image not found.")
        return
    
    # Preprocess the image: resize, transpose, and convert to tensor
    test_image = cv2.resize(test_image, (32, 32))  # Resize to match model input size (32x32)
    test_image = test_image.transpose((2, 0, 1))  # Change to (C, H, W)
    test_image = torch.tensor(test_image, dtype=torch.float32).unsqueeze(0)  # Add batch dimension
    
    # Run the model on the image
    with torch.no_grad():  # No need to compute gradients during inference
        output = model(test_image)
        _, predicted_class = torch.max(output, 1)  # Get the class with the highest probability

    # Get the predicted label
    predicted_label = traffic_sign_labels[predicted_class.item()]
    print(f'Predicted Traffic Sign: {predicted_label}')
    return predicted_label

# Example usage: Predict traffic sign for an input image
image_path = '00243_00000_000010.png'  # Change this to your image file path
predicted_label = predict_traffic_sign(image_path)


  model.load_state_dict(torch.load('traffic_sign_model.pth'))


RuntimeError: Error(s) in loading state_dict for TrafficSignModel:
	size mismatch for fc1.weight: copying a param with shape torch.Size([512, 32768]) from checkpoint, the shape in current model is torch.Size([512, 8192]).
	size mismatch for fc2.weight: copying a param with shape torch.Size([10, 512]) from checkpoint, the shape in current model is torch.Size([43, 512]).
	size mismatch for fc2.bias: copying a param with shape torch.Size([10]) from checkpoint, the shape in current model is torch.Size([43]).

In [7]:
import torch
import numpy as np
import cv2
from torchvision import transforms
from PIL import Image

# Load PyTorch model
model = torch.load('traffic_sign_model.pth')
# model.eval()

# Định dạng chuyển đổi đầu vào cho mô hình (có thể thay đổi theo mô hình của bạn)
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Hàm xử lý và dự đoán ảnh
def predict_traffic_sign(image):
    img = Image.fromarray(image)  # Convert numpy array to PIL image
    img = transform(img).unsqueeze(0)  # Thêm batch dimension
    with torch.no_grad():
        output = model(img)
        _, predicted = torch.max(output, 1)
    return predicted.item()

# Xử lý video với OpenCV
def process_video(video_path):
    cap = cv2.VideoCapture(video_path)
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break

        # Dự đoán loại biển báo
        label = predict_traffic_sign(frame)  # Gọi hàm dự đoán của PyTorch

        # Vẽ thông tin lên ảnh (ví dụ: label biển báo)
        cv2.putText(frame, str(label), (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)

        # Hiển thị kết quả
        cv2.imshow('Traffic Sign Detection', frame)

        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
    cap.release()
    cv2.destroyAllWindows()

process_video('video1.mp4')


  model = torch.load('traffic_sign_model.pth')


TypeError: 'collections.OrderedDict' object is not callable