In [None]:
import io
from PIL import Image
import torch
from torchvision import transforms
import matplotlib.pyplot as plt
from kafka import KafkaConsumer

In [None]:
# Filesystem path of the model checkpoint file (UCF_Crime.pth).
model_path = "/home/jovyan/UCF_Crime.pth"

In [None]:
# Kafka consumer setup: subscribe to 'frame_topic' on the broker at
# localhost:9092. auto_offset_reset='earliest' asks the client to begin with
# the oldest available message when it has no valid offset to resume from.
consumer_options = {
    'bootstrap_servers': 'localhost:9092',
    'auto_offset_reset': 'earliest',
}
consumer = KafkaConsumer('frame_topic', **consumer_options)

In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import torchvision.models as models
import torch.nn.functional as F  # Add this line

# Define your neural network architecture
class FeatureExtractor(nn.Module):
    """Backbone built from torchvision's DenseNet121.

    The network is created with ``pretrained=True`` and then re-wrapped in an
    ``nn.Sequential`` that holds every top-level child module except the last
    one.
    """

    def __init__(self):
        super().__init__()
        backbone = models.densenet121(pretrained=True)
        # Drop the final top-level child of the DenseNet and keep the rest.
        kept_children = list(backbone.children())[:-1]
        # The attribute must stay named ``densenet`` so that parameter keys in
        # saved state dicts continue to match.
        self.densenet = nn.Sequential(*kept_children)

    def forward(self, x):
        """Pass ``x`` through the truncated DenseNet and return its output."""
        return self.densenet(x)

class Classifier(nn.Module):
    """Classification head applied to a 1024-channel feature map.

    The input is average-pooled to 1x1, flattened, and passed through three
    hidden linear layers (1024 -> 256 -> 1024 -> 512), each followed by ReLU
    and dropout (p=0.5), and finally a linear layer producing ``num_classes``
    raw scores (no softmax is applied here).

    Args:
        num_classes: Number of output units of the final linear layer.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        # Layer names and creation order are kept stable so state-dict keys
        # and default initialisation stay the same.
        self.fc1 = nn.Linear(1024, 256)
        self.fc2 = nn.Linear(256, 1024)
        self.fc3 = nn.Linear(1024, 512)
        self.fc4 = nn.Linear(512, num_classes)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Return the unnormalised class scores for feature map ``x``."""
        hidden = torch.flatten(self.avg_pool(x), 1)
        # Every hidden layer shares the same linear -> ReLU -> dropout pattern.
        for hidden_layer in (self.fc1, self.fc2, self.fc3):
            hidden = self.dropout(F.relu(hidden_layer(hidden)))
        return self.fc4(hidden)

class FinalModel(nn.Module):
    """End-to-end network: ``FeatureExtractor`` followed by ``Classifier``.

    Args:
        num_classes: Forwarded to ``Classifier`` as the number of outputs.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.feature_extractor = FeatureExtractor()
        self.classifier = Classifier(num_classes)

    def forward(self, x):
        """Extract features from ``x`` and return the classifier's output."""
        return self.classifier(self.feature_extractor(x))

# Instantiate the network with 15 output classes.
model = FinalModel(15)

# Load the weights from the .pth file. The checkpoint location is defined once
# in `model_path`; `model_weights_path` is kept as an alias so any code that
# refers to the old name keeps working.
model_weights_path = model_path
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (recent PyTorch versions offer `weights_only=True`).
model.load_state_dict(torch.load(model_weights_path, map_location=torch.device('cpu')))
model.eval()  # Set the model to evaluation mode (e.g. dropout is disabled)

In [None]:
# Accumulator for predicted class indices; starts out empty.
predicted_classes = list()

In [None]:
# Human-readable class names, indexed by predicted class index.
# NOTE(review): the order presumably mirrors the label encoding used when the
# model was trained — confirm before reordering.
CLASS_LABELS = [
    'Abuse',
    'Arrest',
    'Arson',
    'Assault',
    'Normals',
    'Explosion',
    'Fighting',
    'Burglary',
    'Protest',
    'RoadAccidents',
    'Robbery',
    'Shooting',
    'Shoplifting',
    'Stealing',
    'Vandalism',
]

In [None]:
# Image preprocessing pipeline applied to each frame before inference.
preprocessing_steps = [
    # Resize to 64x64 (presumably the resolution the model was trained on — TODO confirm).
    transforms.Resize((64, 64)),
    # Convert the PIL image to a tensor.
    transforms.ToTensor(),
    # Normalize each channel with the given mean and standard deviation.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transform = transforms.Compose(preprocessing_steps)

In [None]:
# Kafka consumer loop: for every message, decode the payload as an image,
# display it, classify it with `model`, and report/record the prediction.
# Any failure on a single message is reported and the loop moves on.
for message in consumer:
    print("\nMessage Received from Kafka Producer")
    try:
        # Decode the raw message bytes into an RGB PIL image and preprocess it.
        image_stream = io.BytesIO(message.value)
        image = Image.open(image_stream).convert('RGB')
        processed_image = transform(image).unsqueeze(0)  # Add batch dimension

        # Display the decoded frame on its own figure and always close it
        # afterwards, so figures do not accumulate while the stream runs.
        fig, ax = plt.subplots()
        try:
            ax.imshow(image)
            ax.axis('off')
            plt.show()
        finally:
            plt.close(fig)

        # Perform inference without tracking gradients.
        with torch.no_grad():
            output = model(processed_image)
            probabilities = torch.softmax(output, dim=1)
            predicted_class = torch.argmax(probabilities, dim=1).item()

        # Record the prediction in the list set aside for predicted indices.
        predicted_classes.append(predicted_class)

        if CLASS_LABELS:
            print("Predicted Class Label:", CLASS_LABELS[predicted_class])

        # Send prediction back to Kafka or do further processing
        print("Predicted Class:", predicted_class)

    except Exception as e:
        print("Failed to process or display image:", e)