# Experimento 2

Jeffrey Leiva Cascante 2021016720

Richard León Chinchilla 2019003759


In [None]:
import sys
import os
import hydra
from Config.config import Configuration
from DenoisingAutoencoder import DenoisingAutoEncoder
from torch.utils.data import DataLoader, random_split
from torchvision import transforms, datasets
import pytorch_lightning as L
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.loggers import WandbLogger
from DataModule import ButterflyDataModule
import wandb
import torch
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import numpy as np
import io
from sklearn.cluster import KMeans
from threadpoolctl import threadpool_limits
from PIL import Image

# Authenticate with Weights & Biases before any run is created.
wandb.login()

# Hydra config directory: <parent of the current working directory>/proyecto-transfer-learning/Config.
# Each path component is passed separately so os.path.join inserts the
# separator of the host OS (a literal backslash only works on Windows).
notebook_dir = os.path.dirname(os.path.abspath(os.getcwd()))
config_dir = os.path.join(notebook_dir, 'proyecto-transfer-learning', 'Config')

# Drop every "--"-prefixed argument before Hydra parses sys.argv.
# NOTE(review): presumably these are flags injected by the notebook kernel
# launcher that Hydra would reject as overrides - confirm.
sys.argv = [arg for arg in sys.argv if not arg.startswith("--")]

@hydra.main(config_path=config_dir, config_name="config", version_base=None)
def trainining_denoising_autoencoder(config: Configuration):
    """Train the denoising autoencoder and log latent-space plots to W&B.

    Steps: build the data module and the model from ``config``, fit with
    early stopping on ``val_loss``, encode the validation dataloader, project
    the latent vectors to 2-D with t-SNE, and log two scatter plots: one
    colored by the labels yielded by the dataloader and one colored by
    K-means cluster assignments.

    Args:
        config: Configuration object supplied by Hydra. This function reads
            ``DATASET.DATA_DIR``, ``TRAIN.BATCH_SIZE``, ``TRAIN.NUM_WORKERS``,
            ``TRAIN.LEARNING_RATE``, ``TRAIN.NUM_EPOCHS``,
            ``TRAIN.ACCELERATION``, ``TRAIN.PRECISION`` and
            ``MODEL.LATENT_DIM``.
    """
    # NOTE(review): the positional `False` and `0` arguments are defined by
    # ButterflyDataModule - confirm their meaning against its signature.
    data_module = ButterflyDataModule(config.DATASET.DATA_DIR,config.TRAIN.BATCH_SIZE, False, 0,
                                      config.TRAIN.NUM_WORKERS)
    data_module.setup()

    model = DenoisingAutoEncoder(config.MODEL.LATENT_DIM,config.TRAIN.LEARNING_RATE)

    # Stop once val_loss has not improved for 3 consecutive checks.
    early_stop_callback = EarlyStopping(
        monitor='val_loss',
        patience=3,
        verbose=False,
        mode='min'
    )

    wandb_logger = WandbLogger(
        project="Denoising-Autoencoder",
        log_model=True,
    )

    trainer = L.Trainer(
        max_epochs=config.TRAIN.NUM_EPOCHS,
        accelerator=config.TRAIN.ACCELERATION,
        precision=config.TRAIN.PRECISION,
        callbacks=[early_stop_callback],
        logger=wandb_logger,
        # Always a single device. The former `None` fallback for CPU-only
        # hosts is not an accepted `devices` value in Lightning 2.x and
        # fails device parsing; `1` is valid for every accelerator.
        devices=1
    )

    trainer.fit(model, data_module)

    # Extract the latent vectors of the validation set.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.eval()
    latent_vectors, labels = extract_latent_vectors(model,data_module.val_dataloader(),device)

    # Convert the latent tensor to a NumPy array for scikit-learn.
    latent_vectors_np = latent_vectors.cpu().numpy()

    # Project to 2-D with t-SNE (fixed seed for reproducible layouts).
    # NOTE(review): the single-thread limit is presumably a workaround for a
    # native thread-pool issue during t-SNE - confirm before removing.
    tsne = TSNE(n_components=2, random_state=0)
    with threadpool_limits(limits=1):
        latent_2d = tsne.fit_transform(latent_vectors_np)

    # Plot the latent space colored by the labels from the dataloader.
    plot_latent_space(latent_2d,labels,"Espacio Latente con Labels Reales",logger=wandb_logger)

    # K-means on the full-dimensional latent vectors (not on the 2-D
    # projection), with one cluster per distinct label value.
    n_clusters = len(np.unique(labels))
    kmeans = KMeans(n_clusters=n_clusters, random_state=42)
    cluster_labels = kmeans.fit_predict(latent_vectors_np)

    # Same 2-D projection, now colored by cluster assignment.
    plot_latent_space(latent_2d,cluster_labels,"Espacio Latente con Labels de Clusters",logger=wandb_logger)

    # Close the W&B run.
    wandb_logger.experiment.finish()

def extract_latent_vectors(model, dataloader, device):
    """Encode every batch of ``dataloader`` and collect latents and labels.

    Args:
        model: Object exposing ``encode(x)``; called once per batch.
        dataloader: Iterable yielding ``(x, y)`` pairs of tensors.
        device: Device each input batch is moved to before encoding.

    Returns:
        A ``(latents, labels)`` tuple: ``latents`` is a CPU tensor holding
        the per-batch encodings concatenated in iteration order, and
        ``labels`` is a NumPy array with the concatenated ``y`` batches.
    """
    encoded_batches = []
    label_batches = []

    # Gradients are not needed for extraction.
    with torch.no_grad():
        for inputs, targets in dataloader:
            encoded = model.encode(inputs.to(device))
            # Move each encoding off the device right away so only one
            # batch lives there at a time.
            encoded_batches.append(encoded.cpu())
            label_batches.append(targets)

    all_latents = torch.cat(encoded_batches)
    all_labels = torch.cat(label_batches).cpu().numpy()
    return all_latents, all_labels

def plot_latent_space(latent_2d,labels=None,title="Espacio Latente",logger=None):
    """Scatter-plot a 2-D embedding and log it to W&B or show it on screen.

    Args:
        latent_2d: Array indexed as ``latent_2d[:, 0]`` (x axis) and
            ``latent_2d[:, 1]`` (y axis).
        labels: Optional per-point values used to color the points; when
            given, a legend with one entry per unique value is drawn.
        title: Plot title. When logging, it is also the W&B key and the
            image caption.
        logger: Optional logger whose ``experiment.log`` receives the
            rendered image. When ``None`` the figure is shown with
            ``plt.show()`` instead.
    """
    fig = plt.figure(figsize=(10,10))
    try:
        if labels is not None:
            scatter = plt.scatter(latent_2d[:,0],latent_2d[:,1],c=labels,cmap='tab10')
            unique_labels = np.unique(labels).tolist()
            # num=None yields exactly one handle per unique value, in sorted
            # order, matching np.unique. The default num="auto" returns only
            # a subset of handles once there are more than 9 unique values,
            # which would pair handles with the wrong labels.
            handles, _ = scatter.legend_elements(num=None)
            plt.legend(handles=handles,labels=unique_labels)
        else:
            plt.scatter(latent_2d[:,0],latent_2d[:,1],alpha=0.7)
        plt.title(title)
        plt.xlabel("Componente 1")
        plt.ylabel("Componente 2")
        if logger is not None:
            # Render the figure into an in-memory PNG buffer.
            buf = io.BytesIO()
            plt.savefig(buf,format='png')
            buf.seek(0)
            # Load the buffer as a PIL image.
            image = Image.open(buf)

            # Upload the image to W&B under the plot title.
            wandb_image = wandb.Image(image,caption=title)
            logger.experiment.log({title:wandb_image})
        else:
            plt.show()
    finally:
        # Release the figure even if rendering or logging fails.
        plt.close(fig)


    

# Run the experiment. No argument is passed here: the @hydra.main decorator
# loads the "config" file from config_dir and supplies it as `config`.
trainining_denoising_autoencoder()
        
        