In [None]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
import pandas as pd
import torch.optim as optim
import torch.nn as nn

In [None]:
class PotholeDataset(Dataset):
    """Detection dataset that serves every image as a list of overlapping patches.

    The annotation CSV must provide the columns ``image_path``, ``x``, ``y``,
    ``w`` and ``h`` (one row per bounding box). One dataset item corresponds
    to one distinct ``image_path``; the image is tiled into patches and each
    box is assigned to every patch that fully contains it.

    Args:
        csv_file: Path (or file-like object) handed to ``pandas.read_csv``.
        transform: Optional callable applied to every cropped patch.
        patch_size: ``(width, height)`` of a full-sized patch in pixels.
        overlap: Number of pixels shared by neighbouring patches. The tiling
            stride is ``patch dimension - overlap``.

    Raises:
        ValueError: If ``overlap`` is not smaller than both patch dimensions,
            because the tiling stride would then be zero or negative.
    """

    def __init__(self, csv_file, transform=None, patch_size=(800, 800), overlap=100):
        pw, ph = patch_size
        # A non-positive stride makes range() either raise (step == 0) or
        # silently produce no patches (step < 0); reject it up front.
        if overlap >= min(pw, ph):
            raise ValueError(
                f"overlap ({overlap}) must be smaller than both patch dimensions {tuple(patch_size)}"
            )

        self.data = pd.read_csv(csv_file)
        self.transform = transform
        self.patch_size = patch_size
        self.overlap = overlap
        # Distinct image paths in order of first appearance, computed once
        # instead of on every __len__/__getitem__ call.
        self.image_paths = self.data['image_path'].unique()

    def __len__(self):
        """Return the number of distinct images listed in the CSV."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load image ``idx`` and return ``(patches, patch_boxes, patch_labels)``.

        The three returned lists are parallel, with one entry per patch; see
        ``create_patches`` for their exact contents.
        """
        img_path = self.image_paths[idx]
        # convert() returns an independent, fully loaded copy, so the file
        # handle can be released as soon as the conversion is done.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert("RGB")

        rows = self.data.loc[self.data['image_path'] == img_path, ['x', 'y', 'w', 'h']]
        # CSV stores (x, y, w, h); the patches use corner format (x1, y1, x2, y2).
        boxes = [
            [x, y, x + w, y + h]
            for x, y, w, h in rows.itertuples(index=False, name=None)
        ]
        labels = [1] * len(boxes)  # every annotated box gets class id 1

        return self.create_patches(image, boxes, labels)

    def create_patches(self, image, boxes, labels):
        """Tile ``image`` into overlapping patches and localise the boxes.

        Args:
            image: Object exposing ``size`` as ``(width, height)`` and
                ``crop((left, upper, right, lower))`` (a PIL image).
            boxes: Boxes in full-image coordinates as ``[x1, y1, x2, y2]``.
            labels: One class id per entry of ``boxes``.

        Returns:
            Tuple ``(patches, patch_boxes, patch_labels)`` of equal-length
            lists in row-major tiling order:
              * ``patches``: the cropped (and, if set, transformed) patches;
                patches touching the right/bottom edge may be smaller than
                ``patch_size``.
              * ``patch_boxes``: float32 tensors of shape ``(N, 4)`` in
                patch-local coordinates; ``(0, 4)`` when the patch has no box.
              * ``patch_labels``: int64 tensors of shape ``(N,)``; ``(0,)``
                when the patch has no box.

        Only boxes lying entirely inside a patch are kept for that patch;
        boxes crossing the patch border are dropped for it, not clipped.
        """
        patches = []
        patch_boxes = []
        patch_labels = []
        width, height = image.size
        pw, ph = self.patch_size

        for y_offset in range(0, height, ph - self.overlap):
            for x_offset in range(0, width, pw - self.overlap):
                # Clamp the patch to the image bounds.
                x_end = min(x_offset + pw, width)
                y_end = min(y_offset + ph, height)
                patch = image.crop((x_offset, y_offset, x_end, y_end))

                local_boxes = []
                local_labels = []
                for (x_min, y_min, x_max, y_max), label in zip(boxes, labels):
                    fully_inside = (
                        x_min >= x_offset and x_max <= x_end
                        and y_min >= y_offset and y_max <= y_end
                    )
                    # Also skip degenerate boxes with zero/negative extent.
                    if fully_inside and x_max > x_min and y_max > y_min:
                        local_boxes.append([
                            x_min - x_offset,
                            y_min - y_offset,
                            x_max - x_offset,
                            y_max - y_offset,
                        ])
                        local_labels.append(label)

                if self.transform:
                    patch = self.transform(patch)
                patches.append(patch)

                if local_boxes:
                    patch_boxes.append(torch.tensor(local_boxes, dtype=torch.float32))
                else:
                    # Patch without objects: keep boxes and labels consistent
                    # (zero of each) and give boxes the (0, 4) shape that
                    # torchvision detection models expect for negatives.
                    patch_boxes.append(torch.zeros((0, 4), dtype=torch.float32))
                patch_labels.append(torch.tensor(local_labels, dtype=torch.int64))

        return patches, patch_boxes, patch_labels

# Patch preprocessing pipeline: the only step is ToTensor.
transform = transforms.Compose([transforms.ToTensor()])


In [None]:
# Location of the annotation CSV. Set the POTHOLE_CSV environment variable to
# point at your own copy; the original author's local path is the fallback.
csv_file_path = os.environ.get(
    "POTHOLE_CSV",
    "/Users/vivekreddypalsani/Documents/mahindra/mahindra-3rd/3rd-fall/DIP/project/Dataset/train/Train_data/Cropped_data/csv/combined_samples.csv",
)
dataset = PotholeDataset(csv_file=csv_file_path, transform=transform)

# batch_size=1 because every item is already a variable-length list of
# patches; the default collate function then adds a leading batch dimension
# of size 1 to each patch / box / label tensor.
data_loader = DataLoader(dataset, batch_size=1, shuffle=True)

# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour
# of `weights=` — confirm against the installed version before switching.
model = models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

# Two classes: the dataset labels every box with 1; class 0 is left for background.
num_classes = 2
# Swap the pretrained box-predictor head for one sized to our class count.
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = models.detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes)

model.train()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Give the optimizer only the parameters that can actually receive gradients.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.SGD(trainable_params, lr=0.005, momentum=0.9, weight_decay=0.0005)

In [None]:
# Training loop: one optimizer step per image, with gradients accumulated
# over all patches of that image.
num_epochs = 10
model.train()  # make sure training mode is active even if the cell is re-run
for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    epoch_loss = 0.0

    for patches, patch_boxes, patch_labels in data_loader:
        optimizer.zero_grad()

        # With batch_size=1 and the default collate function, each of these is
        # a list with one entry PER PATCH, and every entry carries a leading
        # batch dimension of size 1. Iterate over the lists themselves —
        # indexing `patches[0]` first would visit only the first patch.
        num_patches = len(patches)

        for patch, boxes, labels in zip(patches, patch_boxes, patch_labels):
            # Strip the batch dimension added by the DataLoader.
            patch = patch[0].to(device)
            boxes = boxes[0].to(device)
            labels = labels[0].to(device)

            if boxes.numel() == 0:
                # Patch without objects: torchvision expects a (0, 4) box
                # tensor and an empty label tensor for negative samples.
                target = {"boxes": torch.zeros((0, 4), dtype=torch.float32, device=device),
                          "labels": torch.zeros((0,), dtype=torch.int64, device=device)}
            else:
                target = {"boxes": boxes, "labels": labels}

            loss_dict = model([patch], [target])
            losses = sum(loss_dict.values())
            epoch_loss += losses.item()

            # Average over the image's patches so the accumulated gradient
            # (and thus the effective step size) does not grow with the
            # number of patches an image is split into.
            (losses / num_patches).backward()

        optimizer.step()

    print(f"Epoch Loss: {epoch_loss:.4f}")

print("Training complete.")