# In [None]:
import os
import cv2
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torch.cuda.amp import GradScaler, autocast

# ==========================================
# 1. TU CLASE DATASET (Con una pequeña mejora)
# ==========================================
class PreloadedDataset(Dataset):
    """Detection dataset pairing images with YOLO-format label files.

    Every image in ``img_dir`` is matched with ``<stem>.txt`` in ``lbl_dir``.
    Images are read from disk on every access (nothing is cached), converted
    to RGB, resized to an ``img_size`` x ``img_size`` square and scaled to
    ``[0, 1]``. Label rows (``class cx cy w h``, normalised) are converted to
    absolute ``[x1, y1, x2, y2]`` boxes in the resized image, and class ids
    are shifted by one so that 0 stays reserved for the background class.
    """

    def __init__(self, img_dir, lbl_dir, img_size=800):
        """Index the image files found in ``img_dir``.

        Args:
            img_dir: Directory containing ``.jpg`` / ``.jpeg`` / ``.png`` images.
            lbl_dir: Directory containing one YOLO ``.txt`` file per image.
            img_size: Side length, in pixels, of the square every image is
                resized to.
        """
        self.img_dir = img_dir
        self.lbl_dir = lbl_dir
        self.img_size = img_size

        # Sorted so that an index always maps to the same file across runs.
        self.img_files = sorted(
            f for f in os.listdir(img_dir)
            if f.lower().endswith((".jpg", ".png", ".jpeg"))
        )

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.img_files)

    def __getitem__(self, idx):
        """Load one sample.

        Args:
            idx: Position of the image in the sorted file list.

        Returns:
            ``(image, target)`` where ``image`` is a float tensor of shape
            ``(3, img_size, img_size)`` with values in ``[0, 1]`` and
            ``target`` is a dict with ``"boxes"``, ``"labels"`` and
            ``"image_id"`` tensors.

        Raises:
            FileNotFoundError: If OpenCV cannot read the image file.
        """
        img_name = self.img_files[idx]
        img_path = os.path.join(self.img_dir, img_name)
        lbl_path = os.path.join(self.lbl_dir, os.path.splitext(img_name)[0] + ".txt")

        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising;
            # fail here with the offending path rather than inside cvtColor.
            raise FileNotFoundError(f"Could not read image: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # The resize ignores aspect ratio; YOLO coordinates are normalised,
        # so the boxes computed below still line up with the resized image.
        img = cv2.resize(img, (self.img_size, self.img_size))

        # HWC uint8 -> CHW float in [0, 1].
        img_tensor = torch.from_numpy(img).float().permute(2, 0, 1) / 255.0

        boxes, labels = self.load_yolo_labels(lbl_path, self.img_size, self.img_size)

        target = {
            "boxes": boxes,
            "labels": labels,
            "image_id": torch.tensor([idx]),  # sometimes needed for evaluation
        }

        return img_tensor, target

    @staticmethod
    def _empty_target():
        """Return correctly shaped empty ``(boxes, labels)`` tensors."""
        return (
            torch.zeros((0, 4), dtype=torch.float32),
            torch.zeros((0,), dtype=torch.int64),
        )

    def load_yolo_labels(self, path, w, h):
        """Parse a YOLO label file into absolute corner-format boxes.

        Rows with fewer than five fields are skipped and any field after the
        fifth is ignored. Boxes are clamped to the image rectangle, and boxes
        left without positive width and height are dropped.

        Args:
            path: Path of the label file; a missing file means "no objects".
            w: Image width in pixels used to de-normalise x coordinates.
            h: Image height in pixels used to de-normalise y coordinates.

        Returns:
            ``(boxes, labels)``: a float32 tensor of shape ``(N, 4)`` holding
            ``[x1, y1, x2, y2]`` and an int64 tensor of shape ``(N,)`` holding
            ``class_id + 1``.
        """
        if not os.path.exists(path):
            return self._empty_target()

        boxes = []
        labels = []

        with open(path, encoding="utf-8") as f:
            for line in f:
                parts = line.split()
                if len(parts) < 5:
                    continue
                cls = int(float(parts[0]))
                # Only the first four numbers describe the box; extra columns
                # (e.g. a confidence score) must not break the unpacking.
                x, y, bw, bh = map(float, parts[1:5])

                # Centre/size (normalised) -> corners (pixels), kept inside
                # the image so slightly overshooting labels remain valid.
                x1 = min(max((x - bw / 2) * w, 0.0), w)
                y1 = min(max((y - bh / 2) * h, 0.0), h)
                x2 = min(max((x + bw / 2) * w, 0.0), w)
                y2 = min(max((y + bh / 2) * h, 0.0), h)

                # Keep only boxes with positive width and height.
                if x2 > x1 and y2 > y1:
                    boxes.append([x1, y1, x2, y2])
                    labels.append(cls + 1)  # 0 is reserved for background

        if not boxes:
            return self._empty_target()

        return (
            torch.tensor(boxes, dtype=torch.float32),
            torch.tensor(labels, dtype=torch.int64),
        )

# ==========================================
# 2. FUNCIÓN DE UNIÓN (CRÍTICA)
# ==========================================
# Faster R-CNN no acepta batches normales porque cada imagen tiene n cajas distintas
def collate_fn(batch):
    """Transpose a list of samples into per-field tuples.

    The samples are regrouped field by field without stacking anything, so
    each image can keep its own number of boxes.

    Args:
        batch: Sequence of samples, e.g. ``[(img0, tgt0), (img1, tgt1)]``.

    Returns:
        A tuple with one tuple per field, e.g.
        ``((img0, img1), (tgt0, tgt1))``; an empty tuple for an empty batch.
    """
    per_field = zip(*batch)
    return tuple(per_field)

# ==========================================
# 3. MODELO Y CONFIGURACIÓN
# ==========================================
def get_model(num_classes):
    """Build a Faster R-CNN (ResNet-50 FPN) with a fresh box predictor.

    The detector is created with torchvision's ``"DEFAULT"`` weights and its
    box predictor is swapped for a new ``FastRCNNPredictor``.

    Args:
        num_classes: Number of outputs of the new predictor.

    Returns:
        The detection model with the replaced ``roi_heads.box_predictor``.
    """
    detector = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")
    heads = detector.roi_heads
    # Size the new head from the input width of the existing classifier layer.
    feature_dim = heads.box_predictor.cls_score.in_features
    heads.box_predictor = FastRCNNPredictor(feature_dim, num_classes)
    return detector

# ==========================================
# 4. BUCLE DE ENTRENAMIENTO LOCAL
# ==========================================
if __name__ == "__main__":
    # Dataset locations: images plus one YOLO-format label file per image.
    TRAIN_IMG_DIR = "Data/train/images"
    TRAIN_LBL_DIR = "Data/train/labels"
    VAL_IMG_DIR   = "Data/val/images"
    VAL_LBL_DIR   = "Data/val/labels"

    # Training schedule.
    NUM_EPOCHS = 12
    LOG_EVERY = 20        # iterations between loss printouts
    CHECKPOINT_EVERY = 5  # epochs between intermediate checkpoints

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    # torch.cuda.amp only applies on CUDA; on the CPU fallback run in plain
    # fp32 instead of relying on autocast/GradScaler to disable themselves.
    use_amp = device.type == "cuda"

    # 1. Datasets and loaders.
    # Images are resized to 800 px, chosen to help detect railings.
    train_dataset = PreloadedDataset(TRAIN_IMG_DIR, TRAIN_LBL_DIR, img_size=800)
    # NOTE(review): the validation set is built but no evaluation pass uses it yet.
    val_dataset   = PreloadedDataset(VAL_IMG_DIR, VAL_LBL_DIR, img_size=800)

    train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True,
                              num_workers=2, collate_fn=collate_fn)

    # 2. Model.
    # NOTE(review): the dataset emits label = yolo_class_id + 1, so every class
    # id in the label files must be <= num_classes - 2. The dataset comment
    # mentions classes 1-5, which would need num_classes = 6 — confirm.
    num_classes = 5  # 4 classes + 1 background
    model = get_model(num_classes)
    model.to(device)

    # 3. Optimisation: SGD with momentum (the author's choice over Adam),
    # halving the learning rate every 3 epochs.
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.5)

    # 4. Mixed precision to cut memory use (target: GPUs with 6-8 GB of VRAM).
    scaler = GradScaler(enabled=use_amp)

    print(f"Entrenando en {device} con ResNet-50...")

    for epoch in range(NUM_EPOCHS):
        model.train()
        for i, (images, targets) in enumerate(train_loader):
            images = [image.to(device) for image in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

            optimizer.zero_grad()

            with autocast(enabled=use_amp):
                # In training mode the model returns a dict of loss terms.
                loss_dict = model(images, targets)
                losses = sum(loss_dict.values())

            # With the scaler disabled these calls reduce to a plain
            # backward() / optimizer.step().
            scaler.scale(losses).backward()
            scaler.step(optimizer)
            scaler.update()

            if i % LOG_EVERY == 0:
                print(f"Epoch: {epoch}, Iter: {i}, Loss: {losses.item():.4f}")

        lr_scheduler.step()

        # Intermediate checkpoint (epochs 0, 5, 10 with the current settings).
        if epoch % CHECKPOINT_EVERY == 0:
            torch.save(model.state_dict(), f"checkpoint_epoch_{epoch}.pth")

    torch.save(model.state_dict(), "faster_rcnn_fachadas_final.pth")
    print("¡Entrenamiento finalizado!")