In [None]:
import os
import torch
import torchvision
import xml.etree.ElementTree as ET
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms as T
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from tqdm import tqdm
import gc

class VOCDataset(Dataset):
    """Pascal-VOC-style object-detection dataset rooted at ``root``.

    Layout read by this class::

        root/ImageSets/Main/<image_set>.txt   one image id per line
        root/Annotations/<id>.xml             VOC XML annotations
        root/JPEGImages/<id>.jpg | <id>.png   images

    Class names are collected from *every* ``.xml`` file in ``Annotations``
    (not only the ids of ``image_set``), so train/val/test instances built on
    the same root share the same label indices. Index 0 is ``"__background__"``.

    Args:
        root: dataset root directory.
        image_set: split file name (without ``.txt``) listing the image ids.
        transforms: optional callable applied to the PIL image.
    """

    IMAGE_EXTENSIONS = (".jpg", ".png")

    def __init__(self, root, image_set="train", transforms=None):
        self.root = root
        self.transforms = transforms
        image_set_file = os.path.join(root, "ImageSets", "Main", f"{image_set}.txt")
        with open(image_set_file) as f:
            # Skip blank lines (e.g. a trailing newline) so they never become "" ids.
            self.image_ids = [line.strip() for line in f if line.strip()]
        self.img_dir = os.path.join(root, "JPEGImages")
        self.ann_dir = os.path.join(root, "Annotations")
        self.classes = ["__background__"] + self._find_classes()
        self.class_to_idx = {name: i for i, name in enumerate(self.classes)}
        # Index-aligned with image_ids: image_files[i] is the file that
        # __getitem__(i) opens (previously this listed the whole folder and
        # did not correspond to the split's indices).
        self.image_files = self._resolve_image_files()

    def _find_classes(self):
        """Return the sorted object class names found in all XML annotations."""
        names = set()
        for fname in os.listdir(self.ann_dir):
            if not fname.endswith(".xml"):
                continue  # ignore stray non-annotation files (e.g. .DS_Store)
            tree = ET.parse(os.path.join(self.ann_dir, fname))
            names.update(obj.find("name").text for obj in tree.findall("object"))
        return sorted(names)

    def _resolve_image_files(self):
        """Map each id in ``image_ids`` to an existing ``.jpg``/``.png`` file name.

        ``.jpg`` wins when both exist; ids without a matching file fall back
        to ``<id>.jpg`` so the error surfaces when the image is opened.
        """
        by_stem = {}
        for fname in sorted(os.listdir(self.img_dir)):
            if fname.endswith(self.IMAGE_EXTENSIONS):
                by_stem.setdefault(os.path.splitext(fname)[0], fname)
        return [by_stem.get(image_id, f"{image_id}.jpg") for image_id in self.image_ids]

    def _load_annotation(self, image_id):
        """Parse ``<image_id>.xml`` into ``(boxes, labels)`` Python lists.

        Each box is ``[xmin, ymin, xmax, ymax]`` as floats read from the XML;
        each label is an index into ``self.classes``.
        """
        root = ET.parse(os.path.join(self.ann_dir, f"{image_id}.xml")).getroot()
        boxes, labels = [], []
        for obj in root.findall("object"):
            labels.append(self.class_to_idx[obj.find("name").text])
            bndbox = obj.find("bndbox")
            boxes.append([float(bndbox.find(tag).text) for tag in ("xmin", "ymin", "xmax", "ymax")])
        return boxes, labels

    def __getitem__(self, idx):
        """Return ``(image, target)`` for ``idx``.

        Images without any annotated object are skipped: the following index
        (wrapping around) is tried instead, and ``target["image_id"]`` holds
        the index actually loaded.

        Raises:
            IndexError: if ``idx`` is out of range.
            RuntimeError: if no image of the split has an annotated object.
        """
        n = len(self)
        if not -n <= idx < n:
            raise IndexError(f"index {idx} out of range for dataset of size {n}")
        start = idx % n
        # Bounded iteration replaces the old unbounded recursion, which
        # could overflow the stack on long runs of empty images.
        for step in range(n):
            cur = (start + step) % n
            boxes, labels = self._load_annotation(self.image_ids[cur])
            if boxes:
                break
        else:
            raise RuntimeError(f"no image of this split under '{self.root}' has annotated objects")

        img_path = os.path.join(self.img_dir, self.image_files[cur])
        # Context manager closes the underlying file handle.
        with Image.open(img_path) as im:
            img = im.convert("RGB")

        target = {
            "boxes": torch.as_tensor(boxes, dtype=torch.float32),
            "labels": torch.as_tensor(labels, dtype=torch.int64),
            "image_id": torch.tensor([cur]),
        }

        if self.transforms:
            img = self.transforms(img)

        return img, target

    def __len__(self):
        """Number of image ids listed in the split file."""
        return len(self.image_ids)
    

def get_model(num_classes, pretrained=True):
    """Build a Faster R-CNN ResNet-50 FPN detector with a new box predictor.

    Args:
        num_classes: number of output classes, background included.
        pretrained: if True (default, the original behavior) start from the
            torchvision COCO checkpoint; if False build the architecture with
            random weights and download nothing (useful when a state_dict is
            loaded right afterwards).

    Returns:
        The model with ``roi_heads.box_predictor`` replaced by a
        ``FastRCNNPredictor`` sized for ``num_classes``.
    """
    if pretrained:
        # `weights="DEFAULT"` replaces the deprecated `pretrained=True`
        # argument and resolves to the same COCO checkpoint.
        model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")
    else:
        # weights_backbone=None as well, otherwise the ImageNet backbone is still downloaded.
        model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=None, weights_backbone=None)
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    return model

def collate_fn(batch):
    """Transpose a list of ``(image, target)`` samples into ``(images, targets)``.

    Detection images have different sizes, so they cannot be stacked into one
    tensor; each field is returned as a tuple instead.
    """
    columns = zip(*batch)
    return tuple(columns)


import pandas as pd
from torchmetrics.detection.mean_ap import MeanAveragePrecision

def evaluate_metrics(experiment_name, model, data_loader, device, epoch, set):
    """Compute detection mAP metrics on ``data_loader`` and append them to a CSV.

    Args:
        experiment_name: output folder; metrics are appended to
            ``<experiment_name>/<set>_metrics.csv`` (header written once).
        model: torchvision detection model; it is switched to eval mode.
        data_loader: yields ``(images, targets)`` batches (see ``collate_fn``).
        device: device on which inference runs.
        epoch: value stored in the ``epoch`` column of the CSV row.
        set: split label used in the CSV file name (e.g. ``"valid"``).
            NOTE: shadows the builtin ``set``; kept for keyword-call compatibility.

    Returns:
        float: the ``map_50`` value, or 0.0 if the metric did not report it.
    """
    csv_path = f"{experiment_name}/{set}_metrics.csv"
    os.makedirs(os.path.dirname(csv_path), exist_ok=True)

    model.eval()
    metric = MeanAveragePrecision()
    with torch.no_grad():
        for images, targets in tqdm(data_loader, desc="\t\tEvaluating"):
            outputs = model([img.to(device) for img in images])
            # The metric stores every prediction/target until compute();
            # keep them on the CPU so they do not pile up in GPU memory.
            preds = [{k: v.cpu() for k, v in out.items()} for out in outputs]
            gts = [{k: v.cpu() for k, v in t.items()} for t in targets]
            metric.update(preds, gts)

    results = metric.compute()
    metric.reset()

    map50 = 0.0
    map50_value = results.get("map_50")
    if isinstance(map50_value, torch.Tensor):
        print(f"\t\tmap_50: {map50_value.item() if map50_value.numel() == 1 else map50_value}")
        map50 = map50_value.item()

    # One CSV row per call: epoch first, scalar metrics as numbers,
    # non-scalar ones (e.g. per-class tensors) as their string repr.
    row = {"epoch": epoch}
    row.update({
        k: v.item() if isinstance(v, torch.Tensor) and v.numel() == 1 else str(v)
        for k, v in results.items()
    })
    write_header = not os.path.exists(csv_path)
    pd.DataFrame([row]).to_csv(csv_path, mode="a", header=write_header, index=False)
    return map50

import matplotlib.pyplot as plt
from torchvision.utils import draw_bounding_boxes
from torchvision.transforms.functional import to_pil_image

def show_predictions(experiment_name, model, dataset, device, score_threshold=0.25, class_names=None, font="arial"):
    """Run ``model`` on every image of ``dataset`` and save annotated copies.

    Predictions with score >= ``score_threshold`` are drawn in green with a
    ``"<class>: <score>"`` caption and saved as
    ``<experiment_name>/predictions/<image id>_pred.jpg``.

    Args:
        experiment_name: output root folder.
        model: torchvision detection model (switched to eval mode).
        dataset: a ``VOCDataset``-like object returning ``(image_tensor, target)``
            where ``target["image_id"]`` indexes ``dataset.image_ids``.
        device: inference device.
        score_threshold: minimum score for a box to be drawn.
        class_names: sequence mapping label index -> name. When None/empty,
            or when a label falls outside it, the numeric label is shown.
        font: TrueType font passed to ``draw_bounding_boxes``; it must be
            resolvable by Pillow on the current machine.
    """
    model.eval()
    output_dir = f"{experiment_name}/predictions"
    os.makedirs(output_dir, exist_ok=True)
    names = list(class_names) if class_names else []

    def _caption(label, score):
        # Caption "<name>: <score>", falling back to the numeric label.
        label_idx = label.item()
        name = names[label_idx] if 0 <= label_idx < len(names) else str(label_idx)
        return f"{name}: {score.item():.2f}"

    with torch.no_grad():
        for idx in range(len(dataset)):
            img, target = dataset[idx]
            output = model(img.to(device).unsqueeze(0))[0]

            # Drop low-confidence predictions.
            keep = output["scores"] >= score_threshold
            pred_boxes = output["boxes"][keep].cpu()
            pred_labels = output["labels"][keep].cpu()
            scores = output["scores"][keep].cpu()

            pred_img = draw_bounding_boxes(
                (img * 255).byte(),
                pred_boxes,
                labels=[_caption(l, s) for l, s in zip(pred_labels, scores)],
                colors="green",
                width=2,
                font_size=20,
                font=font,
            )

            # Name the file after the image actually returned: the dataset may
            # skip annotation-less images, so use target["image_id"], not idx.
            image_id = dataset.image_ids[int(target["image_id"].item())]
            img_pil = to_pil_image(pred_img)
            img_pil.save(os.path.join(output_dir, f"{image_id}_pred.jpg"))

    print(f"saved prediction to: {output_dir}")

def load_custom_fasterrcnn_model(model_path, num_classes):
    """Rebuild the custom Faster R-CNN and load a saved ``state_dict`` into it.

    Args:
        model_path (str): path to the ``.pth`` file holding the saved state_dict.
        num_classes (int): total number of classes, background included.

    Returns:
        torch.nn.Module: the model in eval mode, on CUDA when available, else CPU.
    """
    target_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    network = get_model(num_classes=num_classes)
    # weights_only=True restricts unpickling to tensors and plain containers.
    state = torch.load(model_path, map_location=target_device, weights_only=True)
    network.load_state_dict(state)

    network.to(target_device)
    network.eval()
    return network

def _release_memory():
    """Run the garbage collector and, only when CUDA is available, free GPU caches.

    ``torch.cuda.ipc_collect()`` initialises CUDA and raises on CPU-only
    machines, so the CUDA calls are guarded.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()


def _train_one_epoch(model, optimizer, data_loader, device):
    """Run one optimisation pass over ``data_loader``.

    Returns:
        float: mean over batches of the summed loss components.
    """
    model.train()
    total = 0.0
    for images, targets in tqdm(data_loader, desc="\t\t"):
        images = [img.to(device) for img in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        loss = sum(model(images, targets).values())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total += loss.item()
    return total / max(len(data_loader), 1)


def _validation_loss(model, data_loader, device):
    """Mean summed loss per batch on ``data_loader``, without updating weights.

    torchvision detection models only return a loss dict in train mode, so
    the model is put in train mode here with gradients disabled.
    """
    model.train()
    total = 0.0
    with torch.no_grad():
        for images, targets in tqdm(data_loader, desc="\t\t"):
            images = [img.to(device) for img in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            total += sum(model(images, targets).values()).item()
    return total / max(len(data_loader), 1)


def main(experiment_name):
    """Train, validate and test Faster R-CNN on the VOC dataset in ``./dataset``.

    Outputs:
        * ``<experiment_name>/weights/fasterrcnn_voc_{best,last}.pth``
          (best = highest validation map_50),
        * per-epoch CSV metrics written by ``evaluate_metrics``,
        * TensorBoard scalars ``loss/train`` and ``loss/valid`` (mean per batch,
          so the two curves are comparable) under ``logs/faster_rcnn50/``,
        * annotated test-set predictions written by ``show_predictions``.
    """
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    dataset_root = "dataset"
    save_weights_dir = f"{experiment_name}/weights"
    best_weights_path = f"{save_weights_dir}/fasterrcnn_voc_best.pth"
    last_weights_path = f"{save_weights_dir}/fasterrcnn_voc_last.pth"
    num_epochs = 10

    transforms = T.Compose([T.ToTensor()])
    dataset_train = VOCDataset(dataset_root, image_set="train", transforms=transforms)
    dataset_valid = VOCDataset(dataset_root, image_set="val", transforms=transforms)
    dataset_test = VOCDataset(dataset_root, image_set="test", transforms=transforms)
    data_loader_train = DataLoader(dataset_train, batch_size=1, shuffle=True, collate_fn=collate_fn)
    # Evaluation order does not affect the metrics; keep it deterministic.
    data_loader_valid = DataLoader(dataset_valid, batch_size=1, shuffle=False, collate_fn=collate_fn)
    data_loader_test = DataLoader(dataset_test, batch_size=1, shuffle=False, collate_fn=collate_fn)

    class_names = dataset_train.classes
    model = get_model(num_classes=len(class_names))
    model.to(device)

    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

    os.makedirs(save_weights_dir, exist_ok=True)
    # -inf rather than 0.0: the first epoch always writes a "best" checkpoint,
    # so the test phase can load it even if validation mAP stays at 0.
    best_score_valid = float("-inf")
    writer = SummaryWriter(f"logs/faster_rcnn50/{experiment_name}")
    try:
        for epoch in range(num_epochs):
            print(f"Epoca {epoch+1}/{num_epochs}: \n\tLoss sul train set:")
            train_loss = _train_one_epoch(model, optimizer, data_loader_train, device)
            lr_scheduler.step()
            print(f"\t\t-Score: {train_loss:.4f}")

            print("\tLoss sul Valid set:")
            valid_loss = _validation_loss(model, data_loader_valid, device)
            print(f"\t\t-Score: {valid_loss:.4f}")
            writer.add_scalar("loss/train", train_loss, epoch + 1)
            writer.add_scalar("loss/valid", valid_loss, epoch + 1)

            # Per-epoch evaluation on the validation set (drives best checkpoint).
            print("\tEvaluation sul valid set:")
            current_score_valid = evaluate_metrics(
                experiment_name=experiment_name, model=model, data_loader=data_loader_valid,
                device=device, epoch=epoch + 1, set="valid",
            )
            print("\n")

            if current_score_valid > best_score_valid:
                best_score_valid = current_score_valid
                torch.save(model.state_dict(), best_weights_path)
            torch.save(model.state_dict(), last_weights_path)

            _release_memory()
    finally:
        # Flush pending TensorBoard events even if training fails.
        writer.close()

    del model
    _release_memory()
    print("Training complete. Model saved.")

    # Test-set evaluation with the best checkpoint.
    model = load_custom_fasterrcnn_model(model_path=best_weights_path, num_classes=len(class_names))
    print("Score sul test set:\t")
    evaluate_metrics(
        experiment_name=experiment_name, model=model, data_loader=data_loader_test,
        device=device, epoch=num_epochs, set="test",
    )
    _release_memory()

    # Inference: save annotated predictions for the test images.
    show_predictions(
        experiment_name=experiment_name, model=model, dataset=dataset_test, device=device,
        score_threshold=0.25, class_names=class_names,
    )

    del model
    _release_memory()
    

if __name__ == "__main__":
    # Full pipeline run; outputs go under ./prova5 and logs/faster_rcnn50/prova5.
    main(experiment_name="prova5")
    



# Save each epoch's metrics to the CSV file ✅
# Add Weights & Biases (or TensorBoard) logging ✅
# Save the model's best.pth ✅
# Evaluate on the valid set as well as the train set (add loss/val to TensorBoard); save best.pth based on the valid-set mAP (record the epoch in the file) ✅
# Evaluate on the test set with the saved best.pth ✅
# Refactor: split into notebook cells or into several .py files
# Show class labels and probabilities in the predictions ✅
# Change the bounding-box colour per class, with colours chosen at random
# (Optional) make the code memory-efficient on non-CUDA devices too






Epoca 1/10: 
	Loss sul train set:


		: 100%|██████████| 130/130 [00:11<00:00, 11.37it/s]


		-Score: 49.8442
	Loss sul Valid set:


		: 100%|██████████| 25/25 [00:06<00:00,  4.02it/s]


		-Score: 8.8469
	Evaluation sul valid set:


		Evaluating: 100%|██████████| 25/25 [00:05<00:00,  4.18it/s]


		map_50: 0.3200683891773224


Epoca 2/10: 
	Loss sul train set:


		: 100%|██████████| 130/130 [00:10<00:00, 11.85it/s]


		-Score: 35.1666
	Loss sul Valid set:


		: 100%|██████████| 25/25 [00:05<00:00,  4.40it/s]


		-Score: 7.0129
	Evaluation sul valid set:


		Evaluating: 100%|██████████| 25/25 [00:05<00:00,  4.25it/s]


		map_50: 0.3753499984741211


Epoca 3/10: 
	Loss sul train set:


		: 100%|██████████| 130/130 [00:10<00:00, 11.84it/s]


		-Score: 23.1115
	Loss sul Valid set:


		: 100%|██████████| 25/25 [00:05<00:00,  4.48it/s]


		-Score: 5.9534
	Evaluation sul valid set:


		Evaluating: 100%|██████████| 25/25 [00:05<00:00,  4.32it/s]


		map_50: 0.40663155913352966


Epoca 4/10: 
	Loss sul train set:


		: 100%|██████████| 130/130 [00:10<00:00, 11.86it/s]


		-Score: 16.2887
	Loss sul Valid set:


		: 100%|██████████| 25/25 [00:05<00:00,  4.41it/s]


		-Score: 6.0756
	Evaluation sul valid set:


		Evaluating: 100%|██████████| 25/25 [00:06<00:00,  4.07it/s]


		map_50: 0.4163799583911896


Epoca 5/10: 
	Loss sul train set:


		: 100%|██████████| 130/130 [00:11<00:00, 11.72it/s]


		-Score: 14.2460
	Loss sul Valid set:


		: 100%|██████████| 25/25 [00:05<00:00,  4.35it/s]


		-Score: 6.2530
	Evaluation sul valid set:


		Evaluating: 100%|██████████| 25/25 [00:05<00:00,  4.67it/s]


		map_50: 0.41756150126457214


Epoca 6/10: 
	Loss sul train set:


		: 100%|██████████| 130/130 [00:10<00:00, 12.20it/s]


		-Score: 13.6011
	Loss sul Valid set:


		: 100%|██████████| 25/25 [00:05<00:00,  4.32it/s]


		-Score: 6.7496
	Evaluation sul valid set:


		Evaluating: 100%|██████████| 25/25 [00:05<00:00,  4.26it/s]


		map_50: 0.40954074263572693


Epoca 7/10: 
	Loss sul train set:


		: 100%|██████████| 130/130 [00:10<00:00, 12.17it/s]


		-Score: 12.7985
	Loss sul Valid set:


		: 100%|██████████| 25/25 [00:05<00:00,  4.60it/s]


		-Score: 6.5547
	Evaluation sul valid set:


		Evaluating: 100%|██████████| 25/25 [00:05<00:00,  4.72it/s]


		map_50: 0.4075435698032379


Epoca 8/10: 
	Loss sul train set:


		: 100%|██████████| 130/130 [00:10<00:00, 11.98it/s]


		-Score: 12.7087
	Loss sul Valid set:


		: 100%|██████████| 25/25 [00:05<00:00,  4.40it/s]


		-Score: 6.6675
	Evaluation sul valid set:


		Evaluating: 100%|██████████| 25/25 [00:05<00:00,  4.23it/s]


		map_50: 0.40744587779045105


Epoca 9/10: 
	Loss sul train set:


		: 100%|██████████| 130/130 [00:10<00:00, 12.06it/s]


		-Score: 12.5950
	Loss sul Valid set:


		: 100%|██████████| 25/25 [00:05<00:00,  4.47it/s]


		-Score: 6.6241
	Evaluation sul valid set:


		Evaluating: 100%|██████████| 25/25 [00:05<00:00,  4.42it/s]


		map_50: 0.40746641159057617


Epoca 10/10: 
	Loss sul train set:


		: 100%|██████████| 130/130 [00:10<00:00, 12.03it/s]


		-Score: 12.4641
	Loss sul Valid set:


		: 100%|██████████| 25/25 [00:05<00:00,  4.50it/s]


		-Score: 6.6412
	Evaluation sul valid set:


		Evaluating: 100%|██████████| 25/25 [00:05<00:00,  4.28it/s]


		map_50: 0.40723878145217896


Training complete. Model saved.
Score sul test set:	


		Evaluating: 100%|██████████| 15/15 [00:00<00:00, 16.37it/s]


		map_50: 0.42436742782592773
saved prediction to: prova5/predictions


