# In [1]:  (Jupyter notebook cell marker)
import cv2
import torch
from torchvision import transforms, models 
from PIL import Image
import torch.nn as nn
from matplotlib import pyplot as plt


# Preprocessing applied to each PIL image before it is handed to the model:
# resize to 224x224, convert to a tensor, then normalize every channel.
# NOTE(review): the mean/std values look like the ImageNet statistics that
# torchvision documents for its pretrained models -- confirm they match what
# was used during training.
_preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = transforms.Compose(_preprocessing_steps)


class CustomVGG16(nn.Module):
    """VGG16 backbone with an extra BatchNorm/ReLU pair and a fresh classifier.

    The convolutional stack comes from ``torchvision.models.vgg16`` with its
    pretrained weights; a ``BatchNorm2d(512)`` + ``ReLU`` pair is spliced in
    after the first 24 feature layers.  The original classifier is replaced
    by a new three-layer fully connected head ending in ``num_classes``
    outputs.

    Args:
        num_classes: Size of the final linear layer (number of logits).
        dropout_rate: Drop probability used by both dropout layers of the
            classifier head.
    """

    def __init__(self, num_classes=7, dropout_rate=0.7):
        super().__init__()

        backbone = models.vgg16(pretrained=True)

        # Mark every feature layer from index 8 onward as trainable so those
        # layers take part in fine-tuning.
        backbone.features[8:].requires_grad_(True)

        # Rebuild the feature extractor with BatchNorm + ReLU inserted at
        # position 24; the layers that follow are shifted back by two slots.
        feature_layers = list(backbone.features.children())
        self.features = nn.Sequential(
            *feature_layers[:24],
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            *feature_layers[24:],
        )
        self.avgpool = backbone.avgpool

        # New classification head: 512*7*7 -> 4096 -> 4096 -> num_classes.
        hidden_units = 4096
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, hidden_units),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout_rate),
            nn.Linear(hidden_units, hidden_units),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout_rate),
            nn.Linear(hidden_units, num_classes),
        )

    def forward(self, x):
        """Run features -> avgpool -> flatten (from dim 1) -> classifier."""
        pooled = self.avgpool(self.features(x))
        return self.classifier(pooled.flatten(1))


# Rebuild the network, restore the saved weights onto the CPU and switch to
# evaluation mode for inference.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
model = CustomVGG16(num_classes=7)
checkpoint_state = torch.load('best_model_checkpoint.pth', map_location='cpu')
model.load_state_dict(checkpoint_state)
model = model.to('cpu').eval()


# Human-readable label for each output index of the model.
# NOTE(review): the order must match the class indices used in training.
class_names = [
    'Automotive_commercial',
    'Entertainment_commercial',
    'Financial_commercial',
    'Food_commercial',
    'Healthcare_commercial',
    'Insurance_commercial',
    'Technology_Electronics_commercial',
]


# Live demo: grab frames from capture device 0, classify each one with
# `model`, and show the frame with the predicted class name drawn on it.
# Press "q" in the preview window to stop.
cap = cv2.VideoCapture(0)

if not cap.isOpened():
    print("Unable to read capture feed.")

try:
    while True:
        # read() reports failure through `ret` (e.g. the device could not be
        # opened or the stream ended); stop instead of processing a bad frame.
        ret, frame = cap.read()
        if not ret:
            break

        # OpenCV delivers frames in BGR order while the PIL/torchvision
        # pipeline expects RGB.  cvtColor returns a new array, so `frame`
        # itself stays in BGR and can be displayed as-is afterwards.
        rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        pil_image = Image.fromarray(rgb_frame)

        # Add the batch dimension the model expects: (C, H, W) -> (1, C, H, W).
        batch_t = transform(pil_image).unsqueeze(0)

        with torch.no_grad():
            out = model(batch_t)

        predicted = torch.argmax(out, dim=1)
        label = class_names[predicted.item()]

        # Draw on the untouched BGR frame.  No colour conversion is applied
        # here: imshow expects BGR, and converting again would swap the red
        # and blue channels in the preview.
        cv2.putText(frame, label, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
        cv2.imshow('frame', frame)

        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    # Always free the camera and close the window, even when the loop is
    # interrupted (Ctrl-C / kernel interrupt) or inference raises.
    cap.release()
    cv2.destroyAllWindows()




# Notebook output: KeyboardInterrupt (execution of the cell was interrupted).