In [1]:
import os
import random
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import datasets, transforms, models

from sklearn.metrics import classification_report, f1_score, accuracy_score

from pathlib import Path
from PIL import ImageFile
# Let PIL load truncated / partially corrupted image files instead of raising an error.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# --- Reproducibility: seed every RNG source this notebook relies on ---
SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)
del _seed_fn

# --- Device: prefer CUDA, then Apple MPS, otherwise fall back to CPU ---
if torch.cuda.is_available():
    _device_name = "cuda"
elif torch.backends.mps.is_available():
    _device_name = "mps"
else:
    _device_name = "cpu"
DEVICE = torch.device(_device_name)
del _device_name
print("✅ Usando dispositivo:", DEVICE)


✅ Usando dispositivo: mps


In [None]:
# Root directory of the full MURA dataset. Set the MURA_DIR environment variable to
# point elsewhere instead of editing this machine-specific absolute path.
DATA_DIR = Path(os.environ.get("MURA_DIR", "/Users/alvarosanchez/Downloads/MURA-v1.1"))

# Standard input size for ImageNet-pretrained models.
IMG_SIZE = 224

# ImageNet channel statistics, shared by both pipelines so they cannot drift apart.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training transforms: resize + light augmentation, then tensor conversion and normalization.
transform_train = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(p=0.5),  # basic augmentation
    transforms.RandomRotation(degrees=10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Validation transforms: no random augmentation, so evaluation is deterministic.
transform_val = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Fail fast with an actionable message if a split folder is missing, instead of a
# lower-level error raised from inside ImageFolder.
for _split in ("train", "valid"):
    if not (DATA_DIR / _split).is_dir():
        raise FileNotFoundError(f"No existe {DATA_DIR / _split}; define MURA_DIR o ajusta DATA_DIR")

# ImageFolder assigns each image the label of its first-level subfolder under the split.
# NOTE(review): in the stock MURA-v1.1 layout that first level is the body part
# (e.g. train/XR_ELBOW/patient*/study*_positive/...), so labels would be body parts rather
# than normal/abnormal — confirm the folders were reorganized as intended.
train_dataset = datasets.ImageFolder(DATA_DIR / "train", transform=transform_train)
val_dataset = datasets.ImageFolder(DATA_DIR / "valid", transform=transform_val)

# Both splits must share the same class -> index mapping; otherwise validation labels
# would silently mean something different from training labels.
assert train_dataset.class_to_idx == val_dataset.class_to_idx, (
    f"Clases distintas: train={train_dataset.classes} valid={val_dataset.classes}"
)

BATCH_SIZE = 32

# A dedicated, seeded generator makes the shuffle order independent of other consumers of
# the global torch RNG (e.g. model initialization in later cells, or re-running cells).
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

print(f"📊 Clases: {train_dataset.classes}")
print(f"🧾 Total imágenes - Train: {len(train_dataset)} | Valid: {len(val_dataset)}")
