# In [None]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
import torchvision.transforms as T
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, precision_recall_fscore_support, classification_report

# Run configuration
INPUT_SIZE = 250       # side length (pixels) every image is resized to
NUM_CLASSES = 3        # Update this to match your dataset
MAX_OBJECTS = 10       # per-image cap on labelled objects (also the number of prediction slots)
BATCH_SIZE = 32
EPOCHS = 60
LEARNING_RATE = 1e-4

# Create custom dataset class
class CustomDataset(Dataset):
    """Dataset of images with YOLO-style label files and fixed-size targets.

    Every item is ``(image, bboxes, class_ids)``:

    * ``image`` is the RGB image, passed through ``transform`` when one is given.
    * ``bboxes`` is a ``(max_objects, 4)`` float32 array whose rows are
      ``[x_min, y_min, width, height]`` in the same units as the label file.
    * ``class_ids`` is a ``(max_objects, num_classes)`` float32 one-hot array.

    Rows past the number of objects found in the label file stay all-zero.

    Images and labels are paired by position after sorting both directory
    listings, so both directories are expected to contain files with matching
    base names. ``input_size`` is stored but not used by this class.
    """

    IMAGE_EXTENSIONS = ('.jpg', '.png')

    def __init__(self, images_dir, labels_dir, input_size, num_classes, max_objects, transform=None):
        """Index the image and label directories.

        Raises:
            ValueError: if the number of image files and label files differ.
        """
        import warnings

        self.images_dir = images_dir
        self.labels_dir = labels_dir
        self.input_size = input_size
        self.num_classes = num_classes
        self.max_objects = max_objects
        self.transform = transform
        self.image_files = sorted(f for f in os.listdir(images_dir) if f.endswith(self.IMAGE_EXTENSIONS))
        self.label_files = sorted(f for f in os.listdir(labels_dir) if f.endswith('.txt'))

        # Explicit raise instead of `assert`, which is removed under `python -O`.
        if len(self.image_files) != len(self.label_files):
            raise ValueError(
                "Number of images and labels must match "
                f"(found {len(self.image_files)} images and {len(self.label_files)} labels)"
            )

        # Pairing is positional; equal counts do not guarantee that the pairs
        # line up. Warn (rather than fail) so unusual naming schemes still load.
        mismatched = [
            (image_name, label_name)
            for image_name, label_name in zip(self.image_files, self.label_files)
            if os.path.splitext(image_name)[0] != os.path.splitext(label_name)[0]
        ]
        if mismatched:
            warnings.warn(
                f"{len(mismatched)} image/label pairs have different base names "
                f"(first: {mismatched[0][0]!r} vs {mismatched[0][1]!r}); "
                "images may be paired with the wrong labels"
            )

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load the ``idx``-th image and its padded box / class targets."""
        img_path = os.path.join(self.images_dir, self.image_files[idx])
        label_path = os.path.join(self.labels_dir, self.label_files[idx])

        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)

        bboxes, class_ids = self.parse_label_file(label_path)

        return image, bboxes, class_ids

    def parse_label_file(self, label_path):
        """Parse one label file into padded box and one-hot class arrays.

        Each non-blank line must be ``class_id x_center y_center width height``.
        The centre-based box is converted to ``[x_min, y_min, width, height]``.
        At most ``max_objects`` objects are read; the rest of the file is ignored.

        Returns:
            ``(bboxes, class_ids)`` with shapes ``(max_objects, 4)`` and
            ``(max_objects, num_classes)``, both float32.

        Raises:
            ValueError: if a line does not have exactly five fields, a field is
                not numeric, or the class id is outside ``[0, num_classes)``.
        """
        bboxes = np.zeros((self.max_objects, 4), dtype=np.float32)
        class_ids = np.zeros((self.max_objects, self.num_classes), dtype=np.float32)

        slot = 0  # next free target row; blank lines do not consume a row
        with open(label_path, 'r') as file:
            for line_no, line in enumerate(file, start=1):
                if slot >= self.max_objects:
                    break
                parts = line.split()
                if not parts:
                    # Blank lines (e.g. a trailing newline) are not objects.
                    continue
                if len(parts) != 5:
                    raise ValueError(
                        f"{label_path}:{line_no}: expected 5 fields "
                        f"(class x_center y_center width height), got {len(parts)}"
                    )
                class_id = int(parts[0])
                # A negative id would otherwise index from the end of the row
                # and silently mark the wrong class.
                if not 0 <= class_id < self.num_classes:
                    raise ValueError(
                        f"{label_path}:{line_no}: class id {class_id} is outside "
                        f"[0, {self.num_classes})"
                    )
                x_center, y_center, width, height = map(float, parts[1:])
                bboxes[slot] = [x_center - width / 2, y_center - height / 2, width, height]
                class_ids[slot, class_id] = 1.0
                slot += 1

        return bboxes, class_ids

# Preprocessing shared by both splits: resize to a square of side INPUT_SIZE,
# then convert the PIL image to a tensor.
transform = T.Compose([T.Resize((INPUT_SIZE, INPUT_SIZE)), T.ToTensor()])

train_dataset = CustomDataset(
    images_dir='../yolo_data_v3/train/images',
    labels_dir='../yolo_data_v3/train/labels',
    input_size=INPUT_SIZE,
    num_classes=NUM_CLASSES,
    max_objects=MAX_OBJECTS,
    transform=transform,
)
val_dataset = CustomDataset(
    images_dir='../yolo_data_v3/val/images',
    labels_dir='../yolo_data_v3/val/labels',
    input_size=INPUT_SIZE,
    num_classes=NUM_CLASSES,
    max_objects=MAX_OBJECTS,
    transform=transform,
)

# Only the training split is shuffled; validation keeps a fixed order.
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Define the model
class ObjectDetectionModel(nn.Module):
    """Small CNN that regresses a fixed number of boxes and class logits per image.

    The backbone is five ``Conv -> BatchNorm -> LeakyReLU -> MaxPool(2)`` stages
    (each halves the spatial size, 32x overall) followed by two fully connected
    layers. Two linear heads then emit, for each of ``num_anchors`` slots,
    4 box values and ``num_classes`` class logits.

    Args:
        num_classes: number of class logits per slot.
        num_anchors: number of prediction slots per image.
        input_size: side length of the square input images; fixes the size of
            the first fully connected layer.
    """

    def __init__(self, num_classes, num_anchors, input_size):
        super().__init__()
        self.num_classes = num_classes
        self.num_anchors = num_anchors
        self.input_size = input_size

        # Layers are appended in the same order as a hand-written Sequential so
        # that module indices (and therefore state_dict keys) are stable.
        channels = [3, 32, 64, 128, 256, 512]
        layers = []
        for in_channels, out_channels in zip(channels, channels[1:]):
            layers.extend(self._conv_block(in_channels, out_channels))

        grid = input_size // 32  # spatial side after five 2x poolings
        layers.extend([
            nn.Flatten(),
            nn.Dropout(0.5),
            nn.Linear(512 * grid * grid, 1024),
            nn.LeakyReLU(0.1),
            nn.Dropout(0.5),
            nn.Linear(1024, 256),
            nn.LeakyReLU(0.1),
        ])
        self.features = nn.Sequential(*layers)

        self.bbox_head = nn.Linear(256, num_anchors * 4)
        self.class_head = nn.Linear(256, num_anchors * num_classes)

    @staticmethod
    def _conv_block(in_channels, out_channels):
        """Return the layers of one 3x3 conv stage that halves the spatial size."""
        return [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.1),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]

    @staticmethod
    def _xywh_to_xyxy(boxes):
        """Convert ``[x_min, y_min, width, height]`` rows to ``[x1, y1, x2, y2]``."""
        x_min, y_min, width, height = boxes.unbind(dim=-1)
        return torch.stack((x_min, y_min, x_min + width, y_min + height), dim=-1)

    def forward(self, x):
        """Return ``(bboxes, class_logits)``.

        ``bboxes`` has shape ``(batch, num_anchors, 4)`` and ``class_logits``
        has shape ``(batch, num_anchors, num_classes)``.
        """
        x = self.features(x)
        bboxes = self.bbox_head(x).view(-1, self.num_anchors, 4)
        class_logits = self.class_head(x).view(-1, self.num_anchors, self.num_classes)
        return bboxes, class_logits

    def inference(self, x, score_threshold=0.5, iou_threshold=0.5):
        """Run the model and return score-filtered, NMS-pruned detections.

        Returns ``(bboxes, labels, scores)`` for the surviving slots. Boxes are
        returned in the format the model predicts, which is assumed to be
        ``[x_min, y_min, width, height]`` as produced by the dataset in this
        file. Detections from every image in the batch are pooled into one
        flat result, so pass a single-image batch to get per-image detections.

        This method does not switch the module to eval mode or disable
        gradients; the caller is responsible for both.
        """
        bboxes, class_logits = self(x)

        # The classification head is trained with a per-class binary
        # cross-entropy on logits, so the matching score is a per-class
        # sigmoid. A softmax would force every slot (including empty ones) to
        # have a top score of at least 1 / num_classes.
        class_probs = torch.sigmoid(class_logits)

        # Filter out low score boxes
        scores, labels = class_probs.max(dim=-1)
        keep = scores > score_threshold

        bboxes = bboxes[keep]
        scores = scores[keep]
        labels = labels[keep]

        # nms() reads its columns as corner coordinates, so convert for the
        # overlap test only and index the original boxes afterwards.
        keep = nms(self._xywh_to_xyxy(bboxes), scores, iou_threshold)
        keep = torch.as_tensor(keep, dtype=torch.long, device=bboxes.device)

        return bboxes[keep], labels[keep], scores[keep]

def nms(boxes, scores, iou_threshold=0.5):
    """Greedy non-maximum suppression.

    Args:
        boxes: ``(N, 4)`` tensor whose columns are read as ``x1, y1, x2, y2``.
        scores: ``(N,)`` tensor of scores; higher scores are visited first.
        iou_threshold: a box is dropped when its IoU with an already kept box
            is strictly greater than this value.

    Returns:
        A list of integer indices into ``boxes`` for the kept boxes, ordered by
        decreasing score. An empty input yields an empty list.
    """
    if len(boxes) == 0:
        return []

    x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
    areas = (x2 - x1) * (y2 - y1)
    order = scores.argsort(descending=True)

    keep = []
    while order.numel() > 0:
        i = order[0]
        keep.append(i.item())

        if order.numel() == 1:
            break

        rest = order[1:]

        # Intersection rectangle between the current box and each remaining box.
        xx1 = torch.max(x1[i], x1[rest])
        yy1 = torch.max(y1[i], y1[rest])
        xx2 = torch.min(x2[i], x2[rest])
        yy2 = torch.min(y2[i], y2[rest])

        inter = (xx2 - xx1).clamp(min=0) * (yy2 - yy1).clamp(min=0)
        union = areas[i] + areas[rest] - inter

        # Degenerate (zero-area) boxes give union == 0 and therefore 0/0 = NaN.
        # NaN fails the `<=` test below and would suppress the box, so treat a
        # non-positive union as "no overlap".
        iou = (inter / union).masked_fill(union <= 0, 0.0)

        order = rest[iou <= iou_threshold]

    return keep

# One prediction slot per possible labelled object, so predictions and padded
# targets have the same shape.
model = ObjectDetectionModel(
    num_classes=NUM_CLASSES,
    num_anchors=MAX_OBJECTS,
    input_size=INPUT_SIZE,
)

# Define loss functions
def bbox_loss_fn(pred_bboxes, true_bboxes):
    """Mean squared error between predicted and target boxes, averaged over all elements."""
    criterion = nn.MSELoss(reduction='mean')
    return criterion(pred_bboxes, true_bboxes)

def class_loss_fn(pred_class_logits, true_class_ids):
    """Per-class binary cross-entropy on raw logits, averaged over all elements."""
    criterion = nn.BCEWithLogitsLoss(reduction='mean')
    return criterion(pred_class_logits, true_class_ids)

# Adam over every model parameter
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

# Per-epoch history, filled in by the training loop and plotted afterwards
train_losses, val_losses = [], []
train_accuracies, val_accuracies = [], []

# Training loop
def _real_object_labels(pred_class_logits, true_class_ids):
    """Return ``(predicted, true)`` class indices for slots that hold an object.

    Target rows are one-hot for labelled objects and all-zero for padding.
    The argmax of an all-zero row is 0, so padded slots would otherwise be
    scored as if they were genuine class-0 objects; they are filtered out here.
    """
    has_object = true_class_ids.sum(dim=2) > 0
    predicted = pred_class_logits.argmax(dim=2)[has_object]
    true = true_class_ids.argmax(dim=2)[has_object]
    return predicted, true


for epoch in range(EPOCHS):
    model.train()
    running_loss = 0
    correct_train = 0
    total_train = 0
    for images, true_bboxes, true_class_ids in train_loader:
        optimizer.zero_grad()
        pred_bboxes, pred_class_logits = model(images)

        bbox_loss = bbox_loss_fn(pred_bboxes, true_bboxes)
        class_loss = class_loss_fn(pred_class_logits, true_class_ids)

        loss = bbox_loss + class_loss
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        # Accuracy is measured over labelled objects only.
        predicted_train, true_train = _real_object_labels(pred_class_logits, true_class_ids)
        total_train += true_train.numel()
        correct_train += (predicted_train == true_train).sum().item()

    train_losses.append(running_loss / len(train_loader))
    # max(..., 1) avoids a ZeroDivisionError when no slot held an object.
    train_accuracies.append(correct_train / max(total_train, 1))

    # Validation loop
    model.eval()
    val_loss = 0
    correct_val = 0
    total_val = 0
    # Reset every epoch, so after the loop these hold the final epoch's labels.
    all_true_labels = []
    all_pred_labels = []
    with torch.no_grad():
        for images, true_bboxes, true_class_ids in val_loader:
            pred_bboxes, pred_class_logits = model(images)

            bbox_loss = bbox_loss_fn(pred_bboxes, true_bboxes)
            class_loss = class_loss_fn(pred_class_logits, true_class_ids)

            val_loss += bbox_loss.item() + class_loss.item()

            predicted_val, true_val = _real_object_labels(pred_class_logits, true_class_ids)
            total_val += true_val.numel()
            correct_val += (predicted_val == true_val).sum().item()

            # Collect all true and predicted class labels for confusion matrix
            all_true_labels.extend(true_val.cpu().tolist())
            all_pred_labels.extend(predicted_val.cpu().tolist())

    val_losses.append(val_loss / len(val_loader))
    val_accuracies.append(correct_val / max(total_val, 1))

    if all_true_labels:
        precision, recall, f1, _ = precision_recall_fscore_support(all_true_labels, all_pred_labels, average='weighted', zero_division=1)
    else:
        # No labelled objects in the validation data: the metrics are undefined.
        precision = recall = f1 = float('nan')

    print(f"Epoch {epoch + 1}, Train Loss: {train_losses[-1]}, Val Loss: {val_losses[-1]}, Precision: {precision}, Recall: {recall}, F1 Score: {f1}")

# Confusion matrix over the labels gathered during the last validation pass
class_indices = list(range(NUM_CLASSES))
conf_matrix = confusion_matrix(all_true_labels, all_pred_labels, labels=class_indices)
disp = ConfusionMatrixDisplay(confusion_matrix=conf_matrix, display_labels=class_indices)
disp.plot(cmap=plt.cm.Blues)
plt.show()

# Plot loss curves, then accuracy curves: one figure each, with the training
# and validation series drawn against the epoch index.
epoch_axis = range(EPOCHS)
curve_specs = (
    (train_losses, 'Train Loss', val_losses, 'Validation Loss', 'Loss', 'Loss Curves'),
    (train_accuracies, 'Train Accuracy', val_accuracies, 'Validation Accuracy', 'Accuracy', 'Accuracy Curves'),
)
for train_series, train_label, val_series, val_label, y_label, figure_title in curve_specs:
    plt.figure(figsize=(10, 5))
    plt.plot(epoch_axis, train_series, label=train_label)
    plt.plot(epoch_axis, val_series, label=val_label)
    plt.xlabel('Epoch')
    plt.ylabel(y_label)
    plt.title(figure_title)
    plt.legend()
    plt.show()

# Classification report for detailed metrics.
# `labels` is passed explicitly so the report always has one row per entry in
# `target_names`; without it, a class that appears in neither the true nor the
# predicted labels makes classification_report raise a ValueError about the
# number of classes not matching the size of target_names.
report_labels = list(range(NUM_CLASSES))
print(classification_report(
    all_true_labels,
    all_pred_labels,
    labels=report_labels,
    target_names=[f'Class {i}' for i in report_labels],
    zero_division=1,
))

# Visualization function
def visualize_predictions(images, pred_bboxes, pred_class_logits, score_threshold=0.5):
    """Show each image in the batch with its confident predicted boxes drawn on top.

    Args:
        images: tensor indexed as ``(batch, channels, height, width)``; it is
            permuted to channels-last before being passed to ``imshow``.
        pred_bboxes: tensor indexed as ``(batch, slot, 4)`` with rows
            ``[x_min, y_min, width, height]``. The values are treated as
            fractions of the image size and scaled to pixels for drawing.
        pred_class_logits: tensor indexed as ``(batch, slot, class)`` holding
            raw (pre-sigmoid) class logits.
        score_threshold: only slots whose confidence exceeds this are drawn.

    The confidence of a slot is the sigmoid of its largest class logit, so it
    lies in ``(0, 1)`` and the threshold is a probability rather than a raw
    logit value.
    """
    images = images.permute(0, 2, 3, 1).cpu().numpy()
    pred_bboxes = pred_bboxes.cpu().detach().numpy()

    # Pick the class on the logits (sigmoid is monotonic, so the winner is the
    # same) and convert only the winning logit to a probability.
    max_logits, pred_class_ids = pred_class_logits.detach().cpu().max(dim=-1)
    confidences = torch.sigmoid(max_logits).numpy()
    pred_class_ids = pred_class_ids.numpy()

    # Scale by the actual image size rather than a module-level constant.
    img_height, img_width = images.shape[1], images.shape[2]

    for i in range(images.shape[0]):
        plt.imshow(images[i])
        axes = plt.gca()
        for j in range(pred_bboxes.shape[1]):
            confidence = confidences[i, j]
            if confidence <= score_threshold:  # Only show high-confidence predictions
                continue
            bbox = pred_bboxes[i, j]
            class_id = pred_class_ids[i, j]
            left = bbox[0] * img_width
            top = bbox[1] * img_height
            axes.add_patch(plt.Rectangle(
                (left, top),
                bbox[2] * img_width,
                bbox[3] * img_height,
                fill=False,
                edgecolor='red',
                linewidth=2
            ))
            plt.text(
                left,
                top - 5,
                f'Class {class_id}, Conf: {confidence:.2f}',
                bbox=dict(facecolor='red', alpha=0.5),
                fontsize=12,
                color='white'
            )
        plt.show()

# Draw predictions for the first validation batch only
model.eval()
with torch.no_grad():
    first_batch = next(iter(val_loader), None)
    if first_batch is not None:  # an empty loader draws nothing
        images, true_bboxes, true_class_ids = first_batch
        pred_bboxes, pred_class_logits = model(images)
        visualize_predictions(images, pred_bboxes, pred_class_logits)
