In [None]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision.transforms as T
import torchvision
import torch.optim as optim

# Custom detection dataset
class MosquitoDataset(Dataset):
    """Detection dataset pairing ``.jpeg`` images with YOLO-style label files.

    Every image ``<name>.jpeg`` found in ``images_folder`` is expected to have
    a label file ``<name>.txt`` in ``labels_folder``.  Each non-blank label
    line holds five whitespace-separated numbers::

        class_id x_center y_center width height

    The four box fields are assumed to be normalised to ``[0, 1]`` relative to
    the image size (the usual YOLO convention -- confirm against the label
    files).  They are converted to absolute-pixel ``(xmin, ymin, xmax, ymax)``
    corners, the box format torchvision detection models expect.

    Args:
        images_folder: Directory containing the ``.jpeg`` images.
        labels_folder: Directory containing the matching ``.txt`` files.
        transform: Optional callable applied to the RGB PIL image.  It must
            not change the image geometry (resize, crop, flip), because the
            boxes are not transformed along with the image.
        label_offset: Value added to every class id read from disk.  Defaults
            to 1 because torchvision detection models reserve label 0 for
            background, while YOLO class ids start at 0.  Pass 0 if the label
            files already use 1-based ids.

    Each item is ``(image, target)`` where ``target`` is a dict with
    ``'boxes'`` (float32 tensor of shape ``(N, 4)``) and ``'labels'`` (int64
    tensor of shape ``(N,)``).
    """

    _IMAGE_SUFFIX = '.jpeg'

    def __init__(self, images_folder, labels_folder, transform=None, label_offset=1):
        self.images_folder = images_folder
        self.labels_folder = labels_folder
        self.transform = transform
        self.label_offset = label_offset
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> sample mapping reproducible across runs and machines.
        self.images = sorted(
            f for f in os.listdir(images_folder) if f.endswith(self._IMAGE_SUFFIX)
        )

    def __len__(self):
        """Return the number of images found in ``images_folder``."""
        return len(self.images)

    @staticmethod
    def _parse_annotations(lines, img_width, img_height, label_offset=1):
        """Convert YOLO label lines to pixel-space corner boxes and labels.

        Args:
            lines: Iterable of label lines (for example an open text file).
            img_width: Image width in pixels, used to scale x coordinates.
            img_height: Image height in pixels, used to scale y coordinates.
            label_offset: Value added to every class id.

        Returns:
            ``(boxes, labels)``: a list of ``[xmin, ymin, xmax, ymax]`` lists
            and a list of integer labels, in file order.

        Raises:
            ValueError: If a non-blank line does not hold exactly five numbers.
        """
        boxes = []
        labels = []
        for line in lines:
            fields = line.split()
            if not fields:
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                continue
            class_id, x_center, y_center, width, height = map(float, fields)
            labels.append(int(class_id) + label_offset)
            boxes.append([
                (x_center - width / 2) * img_width,
                (y_center - height / 2) * img_height,
                (x_center + width / 2) * img_width,
                (y_center + height / 2) * img_height,
            ])
        return boxes, labels

    def __getitem__(self, idx):
        """Load the image at ``idx`` together with its detection target.

        Raises:
            FileNotFoundError: If the image or its label file is missing.
        """
        img_file = self.images[idx]
        img_path = os.path.join(self.images_folder, img_file)
        # Strip only the trailing extension; str.replace would also rewrite a
        # '.jpeg' that happens to occur in the middle of the file name.
        label_file = img_file[:-len(self._IMAGE_SUFFIX)] + '.txt'
        label_path = os.path.join(self.labels_folder, label_file)

        # The context manager closes the file handle right away; convert()
        # returns an independent, fully loaded copy of the image.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert("RGB")
        # Capture the size before any transform so boxes match the pixels
        # that were actually annotated.
        img_width, img_height = image.size

        with open(label_path, 'r') as file:
            boxes, labels = self._parse_annotations(
                file, img_width, img_height, self.label_offset
            )

        # reshape keeps the (N, 4) layout even when the image has no boxes.
        target = {
            'boxes': torch.tensor(boxes, dtype=torch.float32).reshape(-1, 4),
            'labels': torch.tensor(labels, dtype=torch.int64),
        }

        if self.transform is not None:
            image = self.transform(image)

        return image, target

# Image preprocessing: the only step is converting the PIL image to a tensor
transform = T.Compose([T.ToTensor()])


def _detection_collate(batch):
    """Regroup ``[(image, target), ...]`` into ``(images, targets)`` tuples.

    Samples are kept as separate items instead of being stacked, so each
    sample can carry its own number of boxes.
    """
    return tuple(zip(*batch))


# Datasets for the train and val splits (images and labels share a folder)
train_dataset = MosquitoDataset("train", "train", transform=transform)
val_dataset = MosquitoDataset("val", "val", transform=transform)

# Data loaders: only the training split is shuffled
train_loader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    collate_fn=_detection_collate,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=8,
    shuffle=False,
    collate_fn=_detection_collate,
)

# Build the SSD300-VGG16 detector starting from the pretrained weights
model = torchvision.models.detection.ssd300_vgg16(pretrained=True)

# torchvision detection models reserve class index 0 for background, so the
# classifier needs one output per object class plus one.
num_classes = 6 + 1  # 6 object classes + background

# Assigning `classification_head.num_classes` would not resize the head's
# convolution layers, so the pretrained head would keep predicting its
# original class set.  Build a fresh classification head with the right
# number of outputs instead; the backbone and the box-regression head keep
# their pretrained weights.
pretrained_cls_head = model.head.classification_head
in_channels = [conv.in_channels for conv in pretrained_cls_head.module_list]
num_anchors = model.anchor_generator.num_anchors_per_location()
model.head.classification_head = torchvision.models.detection.ssd.SSDClassificationHead(
    in_channels, num_anchors, num_classes
)

# Move the model to the GPU when one is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Optimizer (created after the head swap so it tracks the new parameters)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Train the model
num_epochs = 10  # adjust to the number of epochs you want to train

for epoch in range(num_epochs):
    model.train()  # in training mode the model returns a dict of losses
    epoch_loss = 0

    for images, targets in train_loader:
        # Move every image and every target tensor to the selected device
        images = [image.to(device) for image in images]
        targets = [
            {key: value.to(device) for key, value in target.items()}
            for target in targets
        ]

        optimizer.zero_grad()

        # Forward pass: add the individual loss terms into a single scalar
        loss_dict = model(images, targets)
        losses = sum(loss_dict.values())
        epoch_loss += losses.item()

        # Backward pass and parameter update
        losses.backward()
        optimizer.step()

    # NOTE: this is the sum of the batch losses, not a per-batch average
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss}")

# Save the trained weights
torch.save(model.state_dict(), 'ssd_weights.pth')
print("Modelo guardado como ssd_weights.pth")

# Evaluate the model on the validation split
model.eval()
with torch.no_grad():  # no gradients are needed for inference
    for images, targets in val_loader:
        # Move every image and every target tensor to the selected device
        images = [image.to(device) for image in images]
        targets = [
            {key: value.to(device) for key, value in target.items()}
            for target in targets
        ]

        # Run inference on the batch
        predictions = model(images)
        # Evaluation metrics such as IoU or mAP could be computed here from
        # `predictions` and `targets`; for now the predictions are unused.




Epoch 1/10, Loss: 10624.562443733215


In [None]:
# Write the model's current parameters to a second checkpoint file
state = model.state_dict()
torch.save(state, "ssd_model.pth")
print("Modelo guardado.")
