## **Tarea 3**

In [4]:
import os 
import torch
import torch.nn as nn
from torchvision import datasets, transforms
import hydra
import pytorch_lightning as pl
from omegaconf import DictConfig, OmegaConf
import matplotlib.pyplot as plt
import sys
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
import numpy as np
from sklearn.manifold import TSNE
import seaborn as sns
import pandas as pd

# Detect whether this code is being executed from inside a notebook kernel.
def is_notebook() -> bool:
    """Return True when the ``ipykernel`` module is present in ``sys.modules``.

    Returns:
        bool: ``True`` if ``'ipykernel'`` is a key of ``sys.modules``,
        ``False`` otherwise (including when the lookup raises ``NameError``).
    """
    try:
        loaded_modules = sys.modules
    except NameError:
        return False
    return 'ipykernel' in loaded_modules

class Autoencoder(nn.Module):
    """Convolutional autoencoder for single-channel 28x28 images.

    The encoder applies three stride-2 convolutions (28 -> 14 -> 7 -> 4 spatial
    size) followed by two linear layers that project to ``latent_dim``. The
    decoder mirrors it with two linear layers and three transposed
    convolutions (4 -> 7 -> 14 -> 28).

    The decoder ends in a ``Sigmoid``, so reconstructions lie in [0, 1].

    Args:
        latent_dim: Size of the latent vector produced by ``encode``.
    """

    def __init__(self, latent_dim=8):
        super().__init__()
        self.latent_dim = latent_dim

        # Encoder: three stride-2 convolutions, then a two-layer MLP.
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, 3, stride=2, padding=1),   # 28x28 -> 14x14
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 32, 3, stride=2, padding=1),  # 14x14 -> 7x7
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),  # 7x7 -> 4x4
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Flatten(),
            # 64 * 4 * 4 must match the conv stack's output; adjust it if
            # layers are added or the input size changes.
            nn.Linear(64 * 4 * 4, 128),
            nn.ReLU(),
            nn.Linear(128, latent_dim)
        )

        # Decoder: mirror of the encoder.
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64 * 4 * 4),
            nn.ReLU(),
            nn.Unflatten(1, (64, 4, 4)),
            # output_padding=0 here so that 4x4 -> 7x7; with output_padding=1
            # the stack would produce 32x32 and no longer match the input.
            nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=0),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=1),  # 7x7 -> 14x14
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, 3, stride=2, padding=1, output_padding=1),   # 14x14 -> 28x28
            nn.Sigmoid()
        )

    def encode(self, x):
        """Map a batch of images to latent vectors of size ``latent_dim``."""
        return self.encoder(x)

    def decode(self, z):
        """Map a batch of latent vectors back to image space."""
        return self.decoder(z)

    def forward(self, x):
        """Encode ``x`` and return its reconstruction."""
        z = self.encode(x)
        return self.decode(z)

class LitAutoencoder(pl.LightningModule):
    """Lightning module that trains an ``Autoencoder`` with an MSE objective.

    Args:
        config: Configuration object exposing ``model.learning_rate``,
            ``model.weight_decay`` and ``model.latent_dim``.
    """

    def __init__(self, config):
        super().__init__()
        self.save_hyperparameters()
        model_cfg = config.model
        self.learning_rate = model_cfg.learning_rate
        self.weight_decay = model_cfg.weight_decay
        self.autoencoder = Autoencoder(latent_dim=model_cfg.latent_dim)

    def forward(self, x):
        """Return the reconstruction of ``x``."""
        return self.autoencoder(x)

    def _common_step(self, batch, batch_idx, step_type):
        """Run one encode/decode pass and log its metrics.

        Args:
            batch: ``(inputs, labels)`` pair; the labels are ignored.
            batch_idx: Index of the batch (unused).
            step_type: Prefix used for the logged metric names.

        Returns:
            dict with the loss, PSNR, reconstructions, originals and latents.
        """
        inputs, _ = batch
        latent = self.autoencoder.encode(inputs)
        recon = self.autoencoder.decode(latent)
        mse = nn.functional.mse_loss(recon, inputs)

        # PSNR in dB computed with a peak value of 1.0.
        psnr = 20 * torch.log10(1.0 / torch.sqrt(mse))
        # Mean L2 norm of the latent vectors in the batch.
        latent_norm = torch.norm(latent, dim=1).mean()

        scalars = (('loss', mse), ('psnr', psnr), ('latent_norm', latent_norm))
        for metric_name, metric_value in scalars:
            self.log(f'{step_type}/{metric_name}', metric_value)

        return dict(
            loss=mse,
            psnr=psnr,
            reconstructions=recon,
            originals=inputs,
            latent=latent,
        )

    def training_step(self, batch, batch_idx):
        """Return only the loss tensor for the training batch."""
        return self._common_step(batch, batch_idx, 'train')['loss']

    def validation_step(self, batch, batch_idx):
        """Return the full result dict for the validation batch."""
        return self._common_step(batch, batch_idx, 'val')

    def test_step(self, batch, batch_idx):
        """Return the full result dict for the test batch."""
        return self._common_step(batch, batch_idx, 'test')

    def configure_optimizers(self):
        """Build AdamW plus a ReduceLROnPlateau scheduler driven by ``val/loss``."""
        adamw = torch.optim.AdamW(
            self.parameters(),
            lr=self.learning_rate,
            weight_decay=self.weight_decay,
        )
        plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(
            adamw,
            mode='min',
            factor=0.5,
            patience=5,
            min_lr=1e-6,
        )
        scheduler_cfg = {
            "scheduler": plateau,
            "monitor": "val/loss",
            "frequency": 1,
        }
        return {"optimizer": adamw, "lr_scheduler": scheduler_cfg}

class ReconstructionCallback(pl.Callback):
    """Save reconstruction grids during training and a t-SNE plot at the end.

    Args:
        val_samples: Fixed batch of images reconstructed at every snapshot.
        save_dir: Directory where the PNG files are written (created if
            missing).
        val_loader: Loader iterated to collect latent vectors and labels for
            the t-SNE plot.
        epoch_interval: A reconstruction grid is saved every this many epochs.
        num_samples: Maximum number of images shown per grid row.
    """

    def __init__(self, val_samples, save_dir, val_loader, epoch_interval=5, num_samples=10):
        super().__init__()
        self.val_samples = val_samples
        self.epoch_interval = epoch_interval
        self.num_samples = num_samples
        self.save_dir = save_dir
        self.val_loader = val_loader
        os.makedirs(save_dir, exist_ok=True)

    def on_validation_epoch_end(self, trainer, pl_module):
        """Save a reconstruction grid on interval epochs and on the last epoch."""
        # The sanity-check pass also reports epoch 0; skip it so its plot is
        # not written and then overwritten by the real epoch 0.
        if trainer.sanity_checking:
            return

        epoch = trainer.current_epoch
        is_last_epoch = epoch == trainer.max_epochs - 1
        if epoch % self.epoch_interval != 0 and not is_last_epoch:
            return

        val_samples = self.val_samples.to(pl_module.device)
        with torch.no_grad():
            reconstructed = pl_module(val_samples)
        fig = self.plot_reconstruction(val_samples, reconstructed, epoch)

        save_path = os.path.join(self.save_dir, f'reconstruction_epoch_{epoch}.png')
        fig.savefig(save_path)
        plt.close(fig)

    def on_train_end(self, trainer, pl_module):
        """Produce the latent-space plot once training has finished.

        Doing this here rather than on ``max_epochs - 1`` means the plot is
        also generated when training stops before ``max_epochs`` is reached
        (for example through an early-stopping callback).
        """
        self.visualize_latent_space(pl_module)

    def plot_reconstruction(self, originals, reconstructed, epoch):
        """Build a two-row figure: originals on top, reconstructions below.

        Args:
            originals: Batch of input images.
            reconstructed: Batch of reconstructions aligned with ``originals``.
            epoch: Epoch number shown in the figure title.

        Returns:
            The matplotlib figure; the caller is responsible for closing it.
        """
        fig = plt.figure(figsize=(20, 4))
        fig.suptitle(f'Epoch {epoch}')

        originals = originals[:self.num_samples].cpu().detach()
        reconstructed = reconstructed[:self.num_samples].cpu().detach()

        # Never index past the images actually available.
        count = min(len(originals), len(reconstructed))

        for i in range(count):
            ax = fig.add_subplot(2, count, i + 1)
            ax.imshow(originals[i].squeeze(0).numpy(), cmap='gray')
            if i == 0:
                ax.set_title("Original")
            ax.axis("off")

            ax = fig.add_subplot(2, count, i + 1 + count)
            ax.imshow(reconstructed[i].squeeze(0).numpy(), cmap='gray')
            if i == 0:
                ax.set_title("Reconstrucción")
            ax.axis("off")

        fig.tight_layout()
        return fig

    def visualize_latent_space(self, pl_module):
        """Encode ``val_loader``, project the latents with t-SNE and save a plot.

        The figure is written to ``latent_space_tsne.png`` inside ``save_dir``.
        The module's train/eval mode is restored before returning.
        """
        latent_vectors = []
        labels = []

        was_training = pl_module.training
        pl_module.eval()
        try:
            with torch.no_grad():
                for x, y in self.val_loader:
                    z = pl_module.autoencoder.encode(x.to(pl_module.device))
                    latent_vectors.append(z.cpu())
                    labels.extend(y.numpy())
        finally:
            pl_module.train(was_training)

        latent_vectors = torch.cat(latent_vectors, dim=0).numpy()

        tsne = TSNE(n_components=2, random_state=42)
        latent_2d = tsne.fit_transform(latent_vectors)

        fig = plt.figure(figsize=(10, 10))
        scatter = plt.scatter(latent_2d[:, 0], latent_2d[:, 1], c=labels, cmap='tab10')
        plt.colorbar(scatter)
        plt.title('Visualización t-SNE del Espacio Latente')
        fig.savefig(os.path.join(self.save_dir, 'latent_space_tsne.png'))
        plt.close(fig)

# Shared entry point used by both the notebook and the Hydra script paths.
def main(cfg: DictConfig):
    """Train and test the autoencoder on FashionMNIST.

    Args:
        cfg: Configuration providing ``seed`` and the ``dataset``, ``model``,
            ``trainer`` and ``visualization`` sections read below.

    Returns:
        tuple: ``(model, trainer)`` after ``fit`` and ``test`` have run.
    """
    pl.seed_everything(cfg.seed)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=cfg.dataset.normalize.mean,
                             std=cfg.dataset.normalize.std)
    ])

    def _make_loader(dataset, shuffle):
        # All three loaders share batch size and worker count.
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=cfg.dataset.batch_size,
            shuffle=shuffle,
            num_workers=cfg.dataset.num_workers
        )

    full_train_dataset = datasets.FashionMNIST(
        root=cfg.dataset.root,
        train=True,
        download=True,
        transform=transform
    )

    # Split the training set into train/validation parts by val_split.
    train_size = int((1 - cfg.dataset.val_split) * len(full_train_dataset))
    val_size = len(full_train_dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(
        full_train_dataset,
        [train_size, val_size]
    )

    train_loader = _make_loader(train_dataset, shuffle=True)
    val_loader = _make_loader(val_dataset, shuffle=False)
    test_loader = _make_loader(
        datasets.FashionMNIST(
            root=cfg.dataset.root,
            train=False,
            download=True,
            transform=transform
        ),
        shuffle=False
    )

    model = LitAutoencoder(cfg)

    # Fixed validation images reused for every reconstruction snapshot.
    val_samples, _ = next(iter(val_loader))
    val_samples = val_samples[:cfg.visualization.num_samples]

    callbacks = [
        ReconstructionCallback(
            val_samples,
            save_dir=cfg.visualization.save_dir,
            val_loader=val_loader,
            epoch_interval=cfg.visualization.epoch_interval,
            num_samples=cfg.visualization.num_samples
        ),
        ModelCheckpoint(
            monitor='val/loss',
            dirpath=os.path.join(cfg.visualization.save_dir, 'checkpoints'),
            # The metric is logged as 'val/loss', so the template must use
            # that key; a '{val_loss}' placeholder is unknown to Lightning and
            # is rendered as 0. Metric-name auto-insertion is disabled because
            # it would put the '/' of 'val/loss' into the file name.
            filename='autoencoder-epoch={epoch:02d}-val_loss={val/loss:.4f}',
            auto_insert_metric_name=False,
            save_top_k=3,
            mode='min'
        ),
        EarlyStopping(
            monitor='val/loss',
            patience=10,
            mode='min'
        )
    ]

    trainer = pl.Trainer(
        max_epochs=cfg.trainer.max_epochs,
        callbacks=callbacks,
        accelerator=cfg.trainer.accelerator,
        devices=cfg.trainer.devices,
        log_every_n_steps=cfg.trainer.log_every_n_steps,
        gradient_clip_val=cfg.trainer.gradient_clip_val,
        precision=cfg.trainer.precision,
        check_val_every_n_epoch=cfg.trainer.check_val_every_n_epoch
    )

    trainer.fit(model, train_loader, val_loader)
    trainer.test(model, test_loader)

    return model, trainer

if not is_notebook():
    # Regular script: let Hydra load conf/config and pass it to main().
    @hydra.main(config_path="conf", config_name="config", version_base=None)
    def hydra_main(cfg: DictConfig):
        main(cfg)

    if __name__ == "__main__":
        hydra_main()
else:
    # Notebook: Hydra's CLI entry point is not used, so the configuration is
    # built by hand with OmegaConf and main() is called directly.
    # NOTE(review): mean/std of 0.5 map pixels to [-1, 1] while the decoder
    # ends in a Sigmoid ([0, 1]) — confirm this mismatch is intended.
    cfg = OmegaConf.create({
        'seed': 42,
        'dataset': {
            'name': 'FashionMNIST',
            'root': './data',
            'batch_size': 128,
            'num_workers': 4,
            'val_split': 0.2,
            'normalize': {'mean': [0.5], 'std': [0.5]},
        },
        'model': {
            'latent_dim': 8,
            'learning_rate': 0.001,
            'weight_decay': 1e-5,
        },
        'trainer': {
            'max_epochs': 50,
            'accelerator': 'auto',
            'devices': 1,
            'log_every_n_steps': 50,
            'gradient_clip_val': 0.5,
            'precision': 32,
            'check_val_every_n_epoch': 1,
        },
        'visualization': {
            'num_samples': 10,
            'epoch_interval': 5,
            'save_dir': 'visualization_results',
        },
    })
    main(cfg)

Seed set to 42


Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data\FashionMNIST\raw\train-images-idx3-ubyte.gz



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
100%|██████████| 26.4M/26.4M [00:06<00:00, 4.33MB/s]


Extracting ./data\FashionMNIST\raw\train-images-idx3-ubyte.gz to ./data\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data\FashionMNIST\raw\train-labels-idx1-ubyte.gz



[A
100%|██████████| 29.5k/29.5k [00:00<00:00, 186kB/s]


Extracting ./data\FashionMNIST\raw\train-labels-idx1-ubyte.gz to ./data\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data\FashionMNIST\raw\t10k-images-idx3-ubyte.gz



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
100%|██████████| 4.42M/4.42M [00:01<00:00, 2.59MB/s]


Extracting ./data\FashionMNIST\raw\t10k-images-idx3-ubyte.gz to ./data\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data\FashionMNIST\raw\t10k-labels-idx1-ubyte.gz

