In [14]:
# Import required libraries
import torch
import torchvision
from tqdm import tqdm
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.models import mobilenet_v2, MobileNet_V2_Weights
from sklearn.metrics import precision_score, recall_score, f1_score
import pytorch_lightning as pl
from pytorch_lightning import Trainer
import torch.nn as nn
import torch.nn.functional as F
import sys
import os
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import cv2
from IPython.display import display
import io
from google.colab import files

print("Setting up environment...")

Setting up environment...


In [15]:
# Install YOLOv5
!git clone https://github.com/ultralytics/yolov5.git
%cd yolov5
!pip install -r requirements.txt
%cd ..

# Add YOLOv5 to path
import sys
sys.path.append('./yolov5')

# Download YOLO weights if not exists
if not os.path.exists('yolov5n.pt'):
    !wget https://github.com/ultralytics/yolov5/releases/download/v6.2/yolov5n.pt

fatal: destination path 'yolov5' already exists and is not an empty directory.
/content/yolov5
/content


In [16]:
# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

class CombinedModel(pl.LightningModule):
    def __init__(self, num_classes=10):
        super().__init__()
        self.save_hyperparameters()

        # Classification branch (MobileNetV2)
        self.classifier = mobilenet_v2(weights=MobileNet_V2_Weights.IMAGENET1K_V1)
        self.classifier.classifier[1] = nn.Linear(self.classifier.classifier[1].in_features, num_classes)

        # Object detection branch (YOLOv5)
        from models.yolo import Model as YOLOModel
        self.detector = YOLOModel('./yolov5/models/yolov5n.yaml')

        # Load YOLO weights
        checkpoint = torch.load('yolov5n.pt', map_location='cpu')
        self.detector.load_state_dict(checkpoint['model'].state_dict())

        # Freeze YOLO weights and set to eval mode
        self.detector.eval()
        for param in self.detector.parameters():
            param.requires_grad = False
        for module in self.detector.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.track_running_stats = False

    def forward(self, x):
        x = x.to(self.device)
        class_output = self.classifier(x)

        try:
            det_output = self.detector(x)
        except Exception as e:
            print(f"Detection error: {e}")
            det_output = None

        return class_output, det_output

    def training_step(self, batch, batch_idx):
        images, labels = batch
        class_output, det_output = self(images)
        class_loss = F.cross_entropy(class_output, labels)

        det_loss = torch.tensor(0.0, device=self.device)
        if det_output is not None and hasattr(det_output, 'loss'):
            det_loss = sum(det_output.loss.values())

        total_loss = class_loss + 0.5 * det_loss
        self.log('train_loss', total_loss)
        return total_loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.001, momentum=0.9)

Using device: cuda


In [17]:
class ImageDetector:
    def __init__(self, model=None, model_path=None, device=None):
        self.device = device if device else torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        if model is not None:
            self.model = model
        elif model_path is not None:
            checkpoint = torch.load(model_path, map_location=self.device)
            self.model = CombinedModel().to(self.device)
            self.model.load_state_dict(checkpoint['model_state_dict'])
        else:
            raise ValueError("Either model or model_path must be provided")

        self.model.eval()

        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        self.classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                       'dog', 'frog', 'horse', 'ship', 'truck']

        # COCO classes for YOLOv5
        self.coco_classes = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
                           'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
                           'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
                           'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
                           'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
                           'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
                           'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
                           'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
                           'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
                           'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
                           'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop',
                           'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
                           'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
                           'scissors', 'teddy bear', 'hair drier', 'toothbrush']

    def process_image(self, image_data):
        """Process a single image and return both classification and detection results"""
        try:
            # Convert image data to PIL Image
            if isinstance(image_data, bytes):
                image = Image.open(io.BytesIO(image_data)).convert('RGB')
            elif isinstance(image_data, str):
                image = Image.open(image_data).convert('RGB')
            elif isinstance(image_data, Image.Image):
                image = image_data.convert('RGB')
            else:
                raise ValueError("Unsupported image data type")

            # Save original image for drawing
            original_image = np.array(image)

            # Transform image for model
            image_tensor = self.transform(image).unsqueeze(0).to(self.device)

            # Get predictions
            with torch.no_grad():
                class_output, det_output = self.model(image_tensor)

                # Get classification result
                _, predicted = torch.max(class_output, 1)
                class_name = self.classes[predicted.item()]
                class_probs = torch.softmax(class_output, dim=1)[0]
                class_confidence = class_probs[predicted].item()

                # Get top-3 classifications
                top3_values, top3_indices = torch.topk(class_probs, 3)
                top3_classes = [(self.classes[idx], val.item()) for idx, val in zip(top3_indices, top3_values)]

                # Process detection results
                detections = []
                if det_output is not None and hasattr(det_output, 'pred') and len(det_output.pred[0]):
                    for *xyxy, conf, cls in det_output.pred[0].cpu().numpy():
                        if conf > 0.3:  # Confidence threshold
                            detections.append({
                                'bbox': [int(x) for x in xyxy],
                                'confidence': float(conf),
                                'class': self.coco_classes[int(cls)]
                            })

            return {
                'image': original_image,
                'classification': {
                    'top_class': class_name,
                    'top_confidence': class_confidence,
                    'top3_predictions': top3_classes
                },
                'detections': detections
            }

        except Exception as e:
            print(f"Error processing image: {str(e)}")
            return None

    def visualize_results(self, results):
        """Visualize both classification and detection results"""
        if results is None:
            print("No results to visualize")
            return

        try:
            plt.figure(figsize=(12, 6))
            plt.imshow(cv2.cvtColor(results['image'], cv2.COLOR_BGR2RGB))

            # Draw detection boxes
            if results['detections']:
                for det in results['detections']:
                    x1, y1, x2, y2 = det['bbox']
                    plt.gca().add_patch(plt.Rectangle(
                        (x1, y1), x2-x1, y2-y1,
                        fill=False, color='red', linewidth=2
                    ))
                    # Add label
                    plt.text(x1, y1-10,
                            f"{det['class']} {det['confidence']:.2%}",
                            color='red', fontsize=8,
                            bbox=dict(facecolor='white', alpha=0.8))

            # Add classification results to title
            plt.title(f"Classification: {results['classification']['top_class']} ({results['classification']['top_confidence']:.2%})")
            plt.axis('off')
            plt.show()

            # Print detailed results
            print("\n=== Detection Results ===")
            if results['detections']:
                for i, det in enumerate(results['detections'], 1):
                    print(f"Object {i}: {det['class']} (Confidence: {det['confidence']:.2%})")
            else:
                print("No objects detected")

            print("\n=== Classification Results ===")
            print("Top 3 predictions:")
            for class_name, confidence in results['classification']['top3_predictions']:
                print(f"- {class_name}: {confidence:.2%}")

        except Exception as e:
            print(f"Error visualizing results: {str(e)}")

    def process_uploaded_file(self, file_content):
        """Process an uploaded file and show results"""
        try:
            results = self.process_image(file_content)
            if results is not None:
                self.visualize_results(results)
                return True
            return False
        except Exception as e:
            print(f"Error processing uploaded file: {str(e)}")
            return False

In [18]:
def train_model():
    """Train the combined model"""
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    print("Loading datasets...")
    train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                               download=True, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

    print("Initializing model...")
    model = CombinedModel().to(device)
    trainer = Trainer(max_epochs=4, accelerator='auto', devices=1)

    print("Starting training...")
    try:
        trainer.fit(model, train_loader)
        print("Training completed successfully!")

        torch.save({
            'model_state_dict': model.state_dict(),
            'device': str(device)
        }, 'combined_model.pt')
        print("Model saved successfully!")

        return model
    except Exception as e:
        print(f"Training error: {str(e)}")
        return None

def upload_and_detect(detector):
    """Handle file upload and detection"""
    print("Please upload an image file...")
    uploaded = files.upload()

    for filename, content in uploaded.items():
        print(f"\nProcessing {filename}...")
        detector.process_uploaded_file(content)

# Main execution
if __name__ == "__main__":
    print("Starting training pipeline...")
    model = train_model()

    if model is not None:
        print("\nInitializing image detector...")
        detector = ImageDetector(model=model)

        while True:
            user_input = input("\nWould you like to process an image? (yes/no): ")
            if user_input.lower() != 'yes':
                break
            upload_and_detect(detector)

    print("Program completed!")

Starting training pipeline...
Loading datasets...
Files already downloaded and verified



                 from  n    params  module                                  arguments                     
  0                -1  1      1760  models.common.Conv                      [3, 16, 6, 2, 2]              
  1                -1  1      4672  models.common.Conv                      [16, 32, 3, 2]                
  2                -1  1      4800  models.common.C3                        [32, 32, 1]                   
  3                -1  1     18560  models.common.Conv                      [32, 64, 3, 2]                
  4                -1  2     29184  models.common.C3                        [64, 64, 2]                   
  5                -1  1     73984  models.common.Conv                      [64, 128, 3, 2]               
  6                -1  3    156928  models.common.C3                        [128, 128, 3]                 
  7                -1  1    295424  models.common.Conv                      [128, 256, 3, 2]              
  8                -1  1    296448  

Initializing model...


 16           [-1, 4]  1         0  models.common.Concat                    [1]                           
 17                -1  1     22912  models.common.C3                        [128, 64, 1, False]           
 18                -1  1     36992  models.common.Conv                      [64, 64, 3, 2]                
 19          [-1, 14]  1         0  models.common.Concat                    [1]                           
 20                -1  1     74496  models.common.C3                        [128, 128, 1, False]          
 21                -1  1    147712  models.common.Conv                      [128, 128, 3, 2]              
 22          [-1, 10]  1         0  models.common.Concat                    [1]                           
 23                -1  1    296448  models.common.C3                        [256, 256, 1, False]          
 24      [17, 20, 23]  1    115005  models.yolo.Detect                      [80, [[10, 13, 16, 30, 33, 23], [30, 61, 62, 45, 59, 119], [116, 90,

Starting training...


Training: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=4` reached.


Training completed successfully!
Model saved successfully!

Initializing image detector...

Would you like to process an image? (yes/no): Yes
Please upload an image file...


Saving tesla.jpg to tesla.jpg

Processing tesla.jpg...

=== Detection Results ===
No objects detected

=== Classification Results ===
Top 3 predictions:
- automobile: 96.90%
- cat: 1.62%
- horse: 0.52%

Would you like to process an image? (yes/no): yes
Please upload an image file...


Saving plane1.jpg to plane1 (2).jpg

Processing plane1 (2).jpg...

=== Detection Results ===
No objects detected

=== Classification Results ===
Top 3 predictions:
- airplane: 82.82%
- bird: 8.82%
- automobile: 2.49%

Would you like to process an image? (yes/no): Yes
Please upload an image file...


Saving Ronald.jpeg to Ronald (1).jpeg

Processing Ronald (1).jpeg...

=== Detection Results ===
No objects detected

=== Classification Results ===
Top 3 predictions:
- dog: 49.90%
- bird: 25.79%
- cat: 17.37%

Would you like to process an image? (yes/no): no
Program completed!
