In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from tqdm import tqdm  

In [2]:
unlabeled_set_size = 0.25
labeled_train_absolute_set_size = 0.1
labeled_test_absolute_set_size = round(1 - (labeled_train_absolute_set_size + unlabeled_set_size), 2)

labeled_train_relative_set_size = round((labeled_train_absolute_set_size / (1 - unlabeled_set_size)), 2)
labeled_test_relative_set_size = 1 - labeled_train_relative_set_size

In [3]:
base_dir = os.path.join('Plant_leave_diseases_dataset', 'original')
os.makedirs('best_models', exist_ok=True)
model_save_path = \
    os.path.join('best_models', f'h1_{int(unlabeled_set_size*100)}-{int(labeled_train_absolute_set_size*100)}-{int(labeled_test_absolute_set_size*100)}_VAE.pth')
encoder_save_path = \
    os.path.join('best_models', f'h1_{int(unlabeled_set_size*100)}-{int(labeled_train_absolute_set_size*100)}-{int(labeled_test_absolute_set_size*100)}_EncoderVAE.pth')

In [4]:
class VAE(nn.Module):
    def __init__(self, image_channels=1, h_dim=256*14*14, z_dim=32):
        super(VAE, self).__init__()
        
        self.encoder = nn.Sequential(
            nn.Conv2d(image_channels, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten()
        )
        
        self.fc1 = nn.Linear(h_dim, z_dim)
        self.fc2 = nn.Linear(h_dim, z_dim)
        self.fc3 = nn.Linear(z_dim, h_dim)
        
        self.decoder = nn.Sequential(
            nn.Unflatten(dim=1, unflattened_size=(256, 14, 14)),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, image_channels, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid()
        )
    
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std
    
    def forward(self, x):
        h = self.encoder(x)
        mu, logvar = self.fc1(h), self.fc2(h)
        z = self.reparameterize(mu, logvar)
        h_decoded = self.fc3(z)
        x_reconstructed = self.decoder(h_decoded)
        return x_reconstructed, mu, logvar

def loss_function(reconstructed_x, x, mu, logvar):
    MSE = F.mse_loss(reconstructed_x, x, reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return MSE + KLD

In [5]:
data_transforms = {
    'all': transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.Grayscale(num_output_channels=1), 
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ]),
}

In [6]:

base_dir = 'Plant_leave_diseases_dataset/original'
full_dataset = datasets.ImageFolder(base_dir, transform=data_transforms['all'])


indices = list(range(len(full_dataset)))

# Get the directory paths of images
image_paths = [sample[0] for sample in full_dataset.samples]

labels = [os.path.split(os.path.dirname(path))[-1] for path in image_paths]

In [7]:
#Obtenemos el 20% de los datos 
val_indices, train_indices = train_test_split(indices, test_size=unlabeled_set_size, stratify=labels, random_state=42)

#Obtenemos las etiquetas de los datos de entrenamiento
val_labels = [labels[i] for i in val_indices]

#dividir el 20% en 10% de entrenamiento y 10% de validación
_, val_indices = train_test_split(val_indices, test_size=labeled_test_relative_set_size, stratify=val_labels, random_state=42)

train_dataset = Subset(full_dataset, train_indices)
val_dataset = Subset(full_dataset, train_indices)

In [8]:
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=4)

print(f"Número de imágenes en el conjunto de entrenamiento: {len(train_loader.dataset)}")
print(f"Número de imágenes en el conjunto de validación: {len(val_loader.dataset)}")

Número de imágenes en el conjunto de entrenamiento: 15372
Número de imágenes en el conjunto de validación: 15372


In [9]:
def train_vae(vae, train_loader, val_loader, optimizer, device, num_epochs=10, patience=5, model_save_path='best_vae.pth', encoder_save_path='best_encoder.pth'):
    best_val_loss = float('inf')
    epochs_no_improve = 0
    vae.train()
    
    for epoch in range(num_epochs):
        total_train_loss = 0
        vae.train()  # Asegurarse de que el modelo esté en modo de entrenamiento
        
        # Añadir la barra de progreso para el epoch actual
        with tqdm(total=len(train_loader), desc=f'Epoch {epoch+1}/{num_epochs}', unit='batch') as pbar:
            for batch in train_loader:
                images, _ = batch
                images = images.to(device)
                optimizer.zero_grad()
                reconstructed_images, mu, logvar = vae(images)
                loss = loss_function(reconstructed_images, images, mu, logvar)
                loss.backward()
                optimizer.step()
                total_train_loss += loss.item()
                pbar.update(1)
        
        avg_train_loss = total_train_loss / len(train_loader.dataset)
        
        # Paso de validación
        vae.eval()  # Asegurarse de que el modelo esté en modo de evaluación
        total_val_loss = 0
        with torch.no_grad():
            for batch in val_loader:
                images, _ = batch
                images = images.to(device)
                reconstructed_images, mu, logvar = vae(images)
                MSE = F.mse_loss(reconstructed_images, images, reduction='sum')
                KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
                loss = MSE + KLD
                total_val_loss += loss.item()
                print(f'Batch MSE: {MSE.item()}, KLD: {KLD.item()}, Total Loss: {loss.item()}')
        
        avg_val_loss = total_val_loss / len(val_loader.dataset)
        
        # Verificar si el modelo actual es el mejor
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            epochs_no_improve = 0
            # Guardar el modelo completo de la VAE
            torch.save(vae.state_dict(), model_save_path)
            # Guardar la parte del encoder de la VAE
            torch.save(vae.encoder.state_dict(), encoder_save_path)
            print(f'\nMejor modelo guardado con pérdida de validación: {avg_val_loss:.4f}')
        else:
            epochs_no_improve += 1
        
        # Verificar early stopping
        if epochs_no_improve >= patience:
            print(f'\nEarly stopping activado. No ha mejorado en {patience} épocas.')
            break
        
        print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}')

In [10]:
def load_best_vae(vae, model_save_path='best_vae.pth', encoder_save_path='best_encoder.pth'):
    vae.load_state_dict(torch.load(model_save_path))
    vae.encoder.load_state_dict(torch.load(encoder_save_path))
    vae.eval()



In [11]:
def evaluate_vae(vae, dataloader, device):
    vae.eval()
    total_loss = 0
    with torch.no_grad():
        for batch in dataloader:
            images, _ = batch
            images = images.to(device)
            reconstructed_images, mu, logvar = vae(images)
            loss = loss_function(reconstructed_images, images, mu, logvar)
            total_loss += loss.item()
    
    avg_loss = total_loss / len(dataloader.dataset)
    print(f'Pérdida promedio: {avg_loss:.4f}')

    # Visualizar algunas de las imágenes originales y reconstruidas
    images, _ = next(iter(dataloader))
    images = images.to(device)
    reconstructed_images, _, _ = vae(images)
    
    images = images.cpu().detach().numpy()
    reconstructed_images = reconstructed_images.cpu().detach().numpy()
    
    fig, axes = plt.subplots(2, 10, figsize=(20, 4))
    for i in range(10):
        # Imágenes originales
        axes[0, i].imshow(images[i].squeeze(), cmap='gray')
        axes[0, i].axis('off')
        
        # Imágenes reconstruidas
        axes[1, i].imshow(reconstructed_images[i].squeeze(), cmap='gray')
        axes[1, i].axis('off')
    
    axes[0, 0].set_title('Imágenes Originales')
    axes[1, 0].set_title('Imágenes Reconstruidas')
    plt.show()

In [12]:
# Determinar el dispositivo
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(f'Usando dispositivo: {device}')

# Definir y entrenar el modelo con lr = 0.001
vae = VAE().to(device)
optimizer = torch.optim.Adam(vae.parameters(), lr=0.001)

# Entrenar la VAE
train_vae(vae, train_loader, val_loader, optimizer, device, num_epochs=10, patience=5, model_save_path=model_save_path, encoder_save_path=encoder_save_path)

# Cargar el mejor modelo y encoder, y evaluar
load_best_vae(vae, model_save_path=model_save_path, encoder_save_path=encoder_save_path)
evaluate_vae(vae, val_loader, device)

Usando dispositivo: mps


Epoch 1/10: 100%|██████████| 3843/3843 [01:48<00:00, 35.38batch/s]
