In [13]:
# Mount Google Drive at /content/drive; the checkpoint, image and metric
# paths configured further down all live under this mount point.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


## Librerías

In [2]:
# Install the third-party packages used by this notebook.
# NOTE(review): tqdm and lpips are requested by more than one command below,
# and no versions are pinned.
!pip install torch torchvision tqdm requests
!pip install einops tqdm
!pip install albumentations lpips tqdm
!pip install pytorch-fid pytorch-ssim lpips
!pip install piq
!pip install lpips



In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import models
from torchvision.models import vgg16
from torchvision import utils as vutils
from torchvision import datasets, transforms

import math
import numpy as np
import os
import json
import re
import albumentations
from PIL import Image
import matplotlib.pyplot as plt
import hashlib
import pickle
import uuid
import argparse

from tqdm import tqdm
import scipy.linalg
from piq import ssim, psnr
# from LPIPS import LPIPS
from scipy.spatial import distance
from collections import namedtuple
import requests
import shutil

from skimage.metrics import structural_similarity as ssim_metric, peak_signal_noise_ratio as psnr_metric


# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

import lpips
# Module-level LPIPS model (AlexNet trunk), put in eval mode and moved to `device`.
lpips_fn = lpips.LPIPS(net='alex').eval().to(device)

Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]




Loading model from: /usr/local/lib/python3.12/dist-packages/lpips/weights/v0.1/alex.pth


## Configuraciones iniciales

In [14]:
# ============================
# Path configuration
# ============================
# Directory for model checkpoints.
BASE_CKPT_DIR = "/content/drive/MyDrive/Proyecto_Grado/VQ_GAN/VQCheckpointsVQ1"
# Directory for generated images.
GEN_IMG_DIR = "/content/drive/MyDrive/Proyecto_Grado/VQ_GAN/generatedVQ1"
# JSON files for the metric and loss histories.
METRICS_PATH = "/content/drive/MyDrive/Proyecto_Grado/VQ_GAN/metrics_history.json"
LOSS_VQ_PATH = "/content/drive/MyDrive/Proyecto_Grado/VQ_GAN/loss_history_vq.json"
LOSS_TRANS_PATH = "/content/drive/MyDrive/Proyecto_Grado/VQ_GAN/loss_history_transformer.json"

# Create the two output directories up front (no error if they already exist).
os.makedirs(BASE_CKPT_DIR, exist_ok=True)
os.makedirs(GEN_IMG_DIR, exist_ok=True)

# Funciones

In [15]:
# helper.py
class GroupNorm(nn.Module):
    """Thin wrapper around ``nn.GroupNorm`` with 32 groups, eps=1e-6 and affine params.

    Args:
        channels: number of input channels (passed as ``num_channels``).
    """

    def __init__(self, channels):
        super().__init__()
        self.gn = nn.GroupNorm(32, channels, eps=1e-6, affine=True)

    def forward(self, x):
        normalized = self.gn(x)
        return normalized


class Swish(nn.Module):
    """Swish activation: the input multiplied by its own sigmoid."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate


class ResidualBlock(nn.Module):
    """Residual unit: two (GroupNorm -> Swish -> 3x3 conv) stages plus a skip path.

    When ``in_channels`` differs from ``out_channels`` the skip path goes
    through a 1x1 convolution (``channel_up``); otherwise the input is added
    unchanged.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output feature map.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        stages = [
            GroupNorm(in_channels),
            Swish(),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            GroupNorm(out_channels),
            Swish(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
        ]
        self.block = nn.Sequential(*stages)

        # The 1x1 projection only exists when the channel count changes.
        if in_channels != out_channels:
            self.channel_up = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        if self.in_channels == self.out_channels:
            return x + self.block(x)
        return self.channel_up(x) + self.block(x)


class UpSampleBlock(nn.Module):
    """Doubles the spatial size with ``F.interpolate`` and applies a 3x3 convolution.

    Args:
        channels: number of input (and output) channels.
    """

    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        upsampled = F.interpolate(x, scale_factor=2.0)
        return self.conv(upsampled)


class DownSampleBlock(nn.Module):
    """Halves the spatial size with a stride-2 3x3 convolution.

    The convolution itself uses no padding; instead the input is zero-padded
    by one pixel on the right and bottom only before the convolution.

    Args:
        channels: number of input (and output) channels.
    """

    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        # F.pad order for the last two dims: (left, right, top, bottom).
        padded = F.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)


class NonLocalBlock(nn.Module):
    """Single-head self-attention over all spatial positions, added residually.

    Queries, keys and values come from 1x1 convolutions of the group-normalised
    input; attention scores are scaled by ``channels ** -0.5``.

    NOTE(review): ``proj_out`` is created (and therefore part of the state
    dict) but ``forward`` never applies it. Applying it now would change the
    output of already-trained checkpoints, so it is left untouched.

    Args:
        channels: number of input (and output) channels.
    """

    def __init__(self, channels):
        super().__init__()
        self.in_channels = channels

        self.gn = GroupNorm(channels)
        self.q = nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0)
        self.k = nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0)
        self.v = nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        normed = self.gn(x)
        query = self.q(normed)
        key = self.k(normed)
        value = self.v(normed)

        b, c, h, w = query.shape
        positions = h * w

        query = query.reshape(b, c, positions).permute(0, 2, 1)  # (b, hw, c)
        key = key.reshape(b, c, positions)                       # (b, c, hw)
        value = value.reshape(b, c, positions)                   # (b, c, hw)

        # (b, hw, hw): row i holds the weights of query position i over all keys.
        scores = torch.bmm(query, key) * (int(c)**(-0.5))
        weights = F.softmax(scores, dim=2).permute(0, 2, 1)

        attended = torch.bmm(value, weights).reshape(b, c, h, w)
        return x + attended

In [16]:
# codebook.py
class Codebook(nn.Module):
    """Vector-quantisation layer: snaps each latent vector to its nearest code.

    Args:
        args: object exposing ``num_codebook_vectors``, ``latent_dim`` and ``beta``.

    ``forward`` takes a (batch, latent_dim, H, W) tensor and returns
    ``(z_q, min_encoding_indices, loss)`` where ``z_q`` has the input's shape,
    the indices are a flat vector with one entry per spatial position, and
    ``loss`` is a scalar.
    """

    def __init__(self, args):
        super().__init__()
        self.num_codebook_vectors = args.num_codebook_vectors
        self.latent_dim = args.latent_dim
        self.beta = args.beta

        self.embedding = nn.Embedding(self.num_codebook_vectors, self.latent_dim)
        # Uniform init in [-1/K, 1/K] with K = number of codebook vectors.
        bound = 1.0 / self.num_codebook_vectors
        self.embedding.weight.data.uniform_(-bound, bound)

    def forward(self, z):
        # Channels-last so that every row of the flattened view is one latent vector.
        z = z.permute(0, 2, 3, 1).contiguous()
        flat_latents = z.view(-1, self.latent_dim)
        codes = self.embedding.weight

        # Squared distances via ||z||^2 + ||e||^2 - 2 z.e
        latent_sq = torch.sum(flat_latents**2, dim=1, keepdim=True)
        code_sq = torch.sum(codes**2, dim=1)
        cross = torch.matmul(flat_latents, codes.t())
        d = latent_sq + code_sq - 2*(cross)

        min_encoding_indices = torch.argmin(d, dim=1)
        z_q = self.embedding(min_encoding_indices).view(z.shape)

        # First term sends gradients to z only, the beta-weighted term to the
        # embedding only (see the detach() calls).
        # NOTE(review): beta weights the embedding-side term here; confirm this
        # is the intended placement before changing it.
        latent_term = torch.mean((z_q.detach() - z)**2)
        code_term = torch.mean((z_q - z.detach())**2)
        loss = latent_term + self.beta * code_term

        # Straight-through estimator: forward value is z_q, gradient flows to z.
        z_q = z + (z_q - z).detach()

        # Back to channels-first.
        z_q = z_q.permute(0, 3, 1, 2)

        return z_q, min_encoding_indices, loss

In [17]:
# Encoder / Decoder
class Encoder(nn.Module):
    """Convolutional encoder mapping images to ``args.latent_dim`` feature maps.

    Layout: a stem convolution, then for every consecutive pair of widths in
    ``channels`` a run of ``num_res_blocks`` ResidualBlocks followed by a
    DownSampleBlock (omitted after the last stage), and finally two
    ResidualBlocks, GroupNorm, Swish and a 3x3 convolution to ``latent_dim``.

    With ``channels = [64, 128, 256]`` there are two stages, so exactly one
    DownSampleBlock is inserted. The ``resolution`` / ``attn_resolutions``
    bookkeeping currently has no effect because the attention layer is
    commented out.

    Args:
        args: object exposing ``image_channels`` and ``latent_dim``.
    """

    def __init__(self, args):
        super().__init__()
        # channels = [128, 128, 128, 256, 256, 512]
        channels = [64, 128, 256]
        attn_resolutions = [16]
        num_res_blocks = 2
        resolution = 256

        stage_widths = list(zip(channels[:-1], channels[1:]))
        last_stage = len(stage_widths) - 1

        layers = [nn.Conv2d(args.image_channels, channels[0], 3, 1, 1)]
        for stage, (stage_in, stage_out) in enumerate(stage_widths):
            width = stage_in
            for _ in range(num_res_blocks):
                layers.append(ResidualBlock(width, stage_out))
                width = stage_out
            if resolution in attn_resolutions:
                # layers.append(NonLocalBlock(width))
                pass
            if stage != last_stage:
                layers.append(DownSampleBlock(stage_out))
            resolution //= 2

        final_width = channels[-1]
        layers.extend([
            ResidualBlock(final_width, final_width),
            # NonLocalBlock(final_width),
            ResidualBlock(final_width, final_width),
            GroupNorm(final_width),
            Swish(),
            nn.Conv2d(final_width, args.latent_dim, 3, 1, 1),
        ])
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

class Decoder(nn.Module):
    """Convolutional decoder mapping ``args.latent_dim`` feature maps to images.

    Layout: a stem convolution and two ResidualBlocks at ``channels[0]``, then
    for every width in ``channels`` a run of ``num_res_blocks`` ResidualBlocks,
    a NonLocalBlock when the tracked resolution is in ``attn_resolutions``,
    and (for every stage but the first, while the tracked resolution is below
    ``args.image_size``) an UpSampleBlock. A GroupNorm, Swish and 3x3
    convolution to ``args.image_channels`` close the network.

    NOTE(review): the tracked resolution starts at ``image_size // 4``, so up
    to two UpSampleBlocks are added, whereas the Encoder in this file inserts
    a single DownSampleBlock. The decoded image may therefore be larger than
    the encoder input - confirm this is intended.

    Args:
        args: object exposing ``image_size``, ``latent_dim`` and ``image_channels``.
    """

    def __init__(self, args):
        super().__init__()
        channels = [256, 128, 64]
        attn_resolutions = [16]
        num_res_blocks = 2
        resolution = args.image_size // 4  # start at 32 when the image is 128

        width = channels[0]
        layers = [
            nn.Conv2d(args.latent_dim, width, 3, 1, 1),
            ResidualBlock(width, width),
            ResidualBlock(width, width),
        ]

        for stage, stage_width in enumerate(channels):
            for _ in range(num_res_blocks):
                layers.append(ResidualBlock(width, stage_width))
                width = stage_width

            if resolution in attn_resolutions:
                layers.append(NonLocalBlock(width))

            # Never upsample after the first stage, and stop once the tracked
            # resolution has reached the target image size.
            if stage != 0 and resolution < args.image_size:
                layers.append(UpSampleBlock(width))
                resolution *= 2

        layers.extend([
            GroupNorm(width),
            Swish(),
            nn.Conv2d(width, args.image_channels, 3, 1, 1),
        ])
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

In [18]:
# discriminator.py
"""
PatchGAN Discriminator (https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py#L538)
"""
class Discriminator(nn.Module):
    """PatchGAN-style convolutional discriminator producing a map of logits.

    Args:
        args: object exposing ``image_channels``.
        num_filters_last: base number of filters of the first convolution.
        n_layers: number of (Conv -> BatchNorm -> LeakyReLU) stages after the
            stem. Widths double per stage up to 8x the base; every stage uses
            stride 2 except the last, which uses stride 1.
    """

    def __init__(self, args, num_filters_last=64, n_layers=3):
        super().__init__()

        base = num_filters_last
        layers = [nn.Conv2d(args.image_channels, base, 4, 2, 1), nn.LeakyReLU(0.2)]

        mult = 1
        for depth in range(1, n_layers + 1):
            prev_mult, mult = mult, min(2 ** depth, 8)
            stride = 2 if depth < n_layers else 1
            layers.extend([
                nn.Conv2d(base * prev_mult, base * mult, 4, stride, 1, bias=False),
                nn.BatchNorm2d(base * mult),
                nn.LeakyReLU(0.2, True),
            ])

        # Final 1-channel convolution: one logit per receptive-field patch.
        layers.append(nn.Conv2d(base * mult, 1, 4, 1, 1))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

In [19]:
# mingpt.py
"""
taken from: https://github.com/karpathy/minGPT/
GPT model:
- the initial stem consists of a combination of token encoding and a positional encoding
- the meat of it is a uniform sequence of Transformer blocks
    - each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block
    - all blocks feed into a central residual pathway similar to resnets
- the final decoder is a linear projection into a vanilla Softmax classifier
"""
class GPTConfig:
    """Base GPT configuration shared by all GPT variants.

    ``vocab_size`` and ``block_size`` are required; any extra keyword
    argument is stored as an attribute of the same name and overrides the
    class-level dropout defaults below.
    """
    embd_pdrop = 0.1
    resid_pdrop = 0.1
    attn_pdrop = 0.1

    def __init__(self, vocab_size, block_size, **kwargs):
        self.vocab_size = vocab_size
        self.block_size = block_size
        for option_name in kwargs:
            setattr(self, option_name, kwargs[option_name])


class CausalSelfAttention(nn.Module):
    """Multi-head masked self-attention followed by an output projection.

    Args:
        config: object exposing ``n_embd``, ``n_head``, ``attn_pdrop``,
            ``resid_pdrop``, ``block_size`` and optionally ``n_unmasked``.

    ``forward`` returns ``(y, present)`` where ``y`` has the input's shape and
    ``present`` is the stacked (key, value) pair computed for this call.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        width = config.n_embd
        # Per-head key / query / value projections, computed jointly.
        self.key = nn.Linear(width, width)
        self.query = nn.Linear(width, width)
        self.value = nn.Linear(width, width)
        # Dropout on the attention weights and on the projected output.
        self.attn_drop = nn.Dropout(config.attn_pdrop)
        self.resid_drop = nn.Dropout(config.resid_pdrop)
        # Output projection.
        self.proj = nn.Linear(width, width)
        # Lower-triangular mask: a position may only attend to itself and the past.
        size = config.block_size
        mask = torch.tril(torch.ones(size, size))
        if hasattr(config, "n_unmasked"):
            # The leading n_unmasked x n_unmasked square may attend in both directions.
            mask[:config.n_unmasked, :config.n_unmasked] = 1
        self.register_buffer("mask", mask.view(1, 1, size, size))
        self.n_head = config.n_head

    def forward(self, x, layer_past=None):
        B, T, C = x.size()
        head_dim = C // self.n_head

        def split_heads(projected):
            # (B, T, C) -> (B, n_head, T, head_dim)
            return projected.view(B, T, self.n_head, head_dim).transpose(1, 2)

        k = split_heads(self.key(x))
        q = split_heads(self.query(x))
        v = split_heads(self.value(x))

        # Only the keys/values of this call are returned, not the concatenated cache.
        present = torch.stack((k, v))
        if layer_past is not None:
            past_key, past_value = layer_past
            k = torch.cat((past_key, k), dim=-2)
            v = torch.cat((past_value, v), dim=-2)

        # Scaled dot-product scores: (B, nh, T, hs) x (B, nh, hs, T') -> (B, nh, T, T')
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        if layer_past is None:
            # The causal mask is applied only when no cache is supplied.
            att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))

        att = self.attn_drop(F.softmax(att, dim=-1))
        y = att @ v  # (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # merge the heads back together

        y = self.resid_drop(self.proj(y))
        return y, present  # TODO: check that this does not break anything


class Block(nn.Module):
    """Pre-LayerNorm Transformer block: self-attention then a 4x-wide GELU MLP.

    Args:
        config: object exposing ``n_embd`` and ``resid_pdrop`` plus everything
            ``CausalSelfAttention`` reads.

    ``forward`` returns the transformed tensor, or ``(tensor, present)`` when
    ``layer_past`` is given or ``return_present`` is true.
    """

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        self.ln1 = nn.LayerNorm(width)
        self.ln2 = nn.LayerNorm(width)
        self.attn = CausalSelfAttention(config)
        self.mlp = nn.Sequential(
            nn.Linear(width, 4 * width),
            nn.GELU(),
            nn.Linear(4 * width, width),
            nn.Dropout(config.resid_pdrop),
        )

    def forward(self, x, layer_past=None, return_present=False):
        # TODO: check that training still works
        if return_present:
            # Returning the key/value pair is only allowed in eval mode.
            assert not self.training
        # layer_past: (key, value) pair, each of shape (B, nh, T, hs)
        attn_out, present = self.attn(self.ln1(x), layer_past=layer_past)

        x = x + attn_out
        x = x + self.mlp(self.ln2(x))

        wants_present = layer_past is not None or return_present
        return (x, present) if wants_present else x


class GPT(nn.Module):
    """Full GPT model with a context size of ``block_size`` tokens.

    Args:
        vocab_size: number of distinct token ids.
        block_size: maximum sequence length.
        n_layer, n_head, n_embd: depth, number of heads and embedding width.
        embd_pdrop, resid_pdrop, attn_pdrop: dropout probabilities.
        n_unmasked: size of the leading square of the attention mask that is
            left fully unmasked.

    ``forward`` returns ``(logits, None)``.
    """

    def __init__(self, vocab_size, block_size, n_layer=12, n_head=8, n_embd=256,
                 embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0):
        super().__init__()
        config = GPTConfig(
            vocab_size=vocab_size,
            block_size=block_size,
            embd_pdrop=embd_pdrop,
            resid_pdrop=resid_pdrop,
            attn_pdrop=attn_pdrop,
            n_layer=n_layer,
            n_head=n_head,
            n_embd=n_embd,
            n_unmasked=n_unmasked,
        )
        # Input stem: token embedding plus a learned positional embedding of
        # shape (1, block_size, n_embd).
        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
        self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
        self.drop = nn.Dropout(config.embd_pdrop)
        # Transformer body.
        body = [Block(config) for _ in range(config.n_layer)]
        self.blocks = nn.Sequential(*body)
        # Output head: final LayerNorm and a bias-free projection to the vocabulary.
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.block_size = config.block_size
        self.apply(self._init_weights)
        self.config = config

    def get_block_size(self):
        return self.block_size

    def _init_weights(self, module):
        """Normal(0, 0.02) for Linear/Embedding weights, zero biases, unit LayerNorm."""
        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
            return
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            has_bias = isinstance(module, nn.Linear) and module.bias is not None
            if has_bias:
                module.bias.data.zero_()

    def forward(self, idx, embeddings=None):
        token_embeddings = self.tok_emb(idx)  # one learnable vector per token id

        if embeddings is not None:
            # Explicit embeddings are prepended along the sequence axis.
            token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)

        t = token_embeddings.shape[1]
        assert t <= self.block_size, "Cannot forward, model block size is exhausted."
        position_embeddings = self.pos_emb[:, :t, :]  # one learnable vector per position
        hidden = self.drop(token_embeddings + position_embeddings)
        hidden = self.ln_f(self.blocks(hidden))
        logits = self.head(hidden)

        return logits, None

In [20]:
from PIL import Image, UnidentifiedImageError

# utils.py

# --------------------------------------------- #
#                  Data Utils
# --------------------------------------------- #

class ImagePaths(Dataset):
    """Dataset over every entry of a directory, yielding CHW float32 arrays.

    Each image is converted to RGB, resized so its shortest side equals
    ``size``, centre-cropped to ``size`` x ``size`` and rescaled from
    [0, 255] to [-1, 1].

    Args:
        path: directory whose entries are treated as image files.
        size: side length of the square output images.

    Entries that cannot be opened as images are replaced by uniform noise of
    the requested size and a warning is printed.
    """

    def __init__(self, path, size=None):
        self.size = size

        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping reproducible across runs and machines.
        self.images = [os.path.join(path, file) for file in sorted(os.listdir(path))]
        self._length = len(self.images)

        self.rescaler = albumentations.SmallestMaxSize(max_size=self.size)
        self.cropper = albumentations.CenterCrop(height=self.size, width=self.size)
        self.preprocessor = albumentations.Compose([self.rescaler, self.cropper])

    def __len__(self):
        return self._length

    def preprocess_image(self, image_path):
        """Load one image and return it as a (3, size, size) float32 array in [-1, 1]."""
        try:
            # Image.open is lazy and keeps the file open; the context manager
            # closes it once convert() has produced an independent RGB copy.
            with Image.open(image_path) as handle:
                image = handle.convert("RGB")
        except (UnidentifiedImageError, OSError):
            print(f"‚ö†Ô∏è Imagen inv√°lida: {image_path}, reemplazada por ruido")
            image = Image.fromarray(
                np.uint8(np.random.rand(self.size, self.size, 3) * 255)
            )

        image = np.array(image).astype(np.uint8)
        image = self.preprocessor(image=image)["image"]
        image = (image / 127.5 - 1.0).astype(np.float32)
        image = image.transpose(2, 0, 1)  # HWC -> CHW
        return image

    def __getitem__(self, i):
        example = self.preprocess_image(self.images[i])
        return example


def load_data(args):
    """Build the shuffled training DataLoader over ``args.dataset_path``.

    Images are loaded at a fixed size of 128; the loader uses 4 persistent
    workers and pinned memory.

    Args:
        args: object exposing ``dataset_path`` and ``batch_size``.

    Returns:
        A ``DataLoader`` over an ``ImagePaths`` dataset.
    """
    dataset = ImagePaths(args.dataset_path, size=128)
    loader_options = dict(
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        persistent_workers=True,
    )
    return DataLoader(dataset, **loader_options)


# --------------------------------------------- #
#                  Module Utils
#            for Encoder, Decoder etc.
# --------------------------------------------- #

def weights_init(m):
    """Initialise one module in place; intended for use with ``Module.apply``.

    Modules whose class name contains ``Conv`` get weights drawn from
    N(0, 0.02); those containing ``BatchNorm`` get weights from N(1, 0.02)
    and zero bias. Every other module is left untouched.
    """
    kind = type(m).__name__
    if 'Conv' in kind:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif 'BatchNorm' in kind:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)


def plot_images(images):
    """Show the first image of each of four batches side by side.

    Args:
        images: mapping with the keys ``"input"``, ``"rec"``,
            ``"half_sample"`` and ``"full_sample"``, each a tensor whose
            first element is indexed as a channels-first image.
    """
    # Look up every panel before creating the figure so that a missing key
    # fails before any plotting happens.
    x = images["input"]
    reconstruction = images["rec"]
    half_sample = images["half_sample"]
    full_sample = images["full_sample"]
    panels = (x, reconstruction, half_sample, full_sample)

    fig, axarr = plt.subplots(1, 4)
    for position, batch in enumerate(panels):
        first_image = batch.cpu().detach().numpy()[0]
        axarr[position].imshow(first_image.transpose(1, 2, 0))  # CHW -> HWC
    plt.show()

In [21]:
# vqgan.py
class VQGAN(nn.Module):
    """Encoder, codebook and decoder with 1x1 convolutions around the quantiser.

    Args:
        args: object exposing ``device`` and ``latent_dim`` plus everything the
            Encoder, Decoder and Codebook constructors read.
    """

    def __init__(self, args):
        super().__init__()
        target = args.device
        latent_dim = args.latent_dim
        self.encoder = Encoder(args).to(device=target)
        self.decoder = Decoder(args).to(device=target)
        self.codebook = Codebook(args).to(device=target)
        self.quant_conv = nn.Conv2d(latent_dim, latent_dim, 1).to(device=target)
        self.post_quant_conv = nn.Conv2d(latent_dim, latent_dim, 1).to(device=target)

    def forward(self, imgs):
        """Encode, quantise and decode; returns ``(decoded_images, codebook_indices, q_loss)``."""
        codebook_mapping, codebook_indices, q_loss = self.encode(imgs)
        decoded_images = self.decode(codebook_mapping)
        return decoded_images, codebook_indices, q_loss

    def encode(self, imgs):
        """Return ``(quantised_latents, codebook_indices, q_loss)`` for a batch of images."""
        pre_quant = self.quant_conv(self.encoder(imgs))
        codebook_mapping, codebook_indices, q_loss = self.codebook(pre_quant)
        return codebook_mapping, codebook_indices, q_loss

    def decode(self, z):
        """Decode quantised latents back to images."""
        return self.decoder(self.post_quant_conv(z))

    def calculate_lambda(self, perceptual_loss, gan_loss):
        """Adaptive GAN-loss weight from gradient norms at the decoder's last layer.

        Returns 0.8 times the ratio of the perceptual-loss gradient norm to
        the GAN-loss gradient norm (plus 1e-4), clamped to [0, 1e4] and
        detached from the graph.
        """
        last_layer_weight = self.decoder.model[-1].weight
        perceptual_loss_grads = torch.autograd.grad(perceptual_loss, last_layer_weight, retain_graph=True)[0]
        gan_loss_grads = torch.autograd.grad(gan_loss, last_layer_weight, retain_graph=True)[0]

        ratio = torch.norm(perceptual_loss_grads) / (torch.norm(gan_loss_grads) + 1e-4)
        ratio = torch.clamp(ratio, 0, 1e4).detach()
        return 0.8 * ratio

    @staticmethod
    def adopt_weight(disc_factor, i, threshold, value=0.):
        """Return ``value`` while step ``i`` is below ``threshold``, else ``disc_factor``."""
        return value if i < threshold else disc_factor

    def load_checkpoint(self, path, map_location=None):
        """Load a previously trained state dict; tensors are mapped to CPU unless ``map_location`` is given."""
        # NOTE(review): torch.load unpickles the file - only load checkpoints you trust.
        location = map_location or torch.device("cpu")
        checkpoint = torch.load(path, map_location=location)
        self.load_state_dict(checkpoint)
        print(f"‚úÖ Checkpoint cargado desde {path}")

In [22]:
# transformer.py
class VQGANTransformer(nn.Module):
    """Autoregressive GPT prior over the token indices of a frozen, pretrained VQGAN.

    Args:
        args: object exposing ``sos_token``, ``pkeep``, ``checkpoint_path``,
            ``num_codebook_vectors`` and ``device`` plus everything the VQGAN
            constructor reads.
    """

    def __init__(self, args):
        super(VQGANTransformer, self).__init__()

        self.sos_token = args.sos_token
        self.vqgan = self.load_vqgan(args)

        transformer_config = {
            "vocab_size": args.num_codebook_vectors,
            "block_size": 512,
            "n_layer": 24,
            "n_head": 16,
            "n_embd": 1024
        }
        self.transformer = GPT(**transformer_config)

        # Probability of keeping a ground-truth token when corrupting the input.
        self.pkeep = args.pkeep

    @staticmethod
    def load_vqgan(args):
        """Build the VQGAN, load ``args.checkpoint_path`` and return it in eval mode."""
        model = VQGAN(args)
        model.load_checkpoint(args.checkpoint_path)  # loaded on CPU, moved below
        # Use the configured device (the VQGAN submodules are already built on
        # args.device) rather than a hard-coded "cuda", which fails on CPU-only hosts.
        model = model.eval().to(args.device)
        return model

    @torch.no_grad()
    def encode_to_z(self, x):
        """Return the quantised latents and the indices reshaped to (batch, tokens)."""
        quant_z, indices, _ = self.vqgan.encode(x)
        indices = indices.view(quant_z.shape[0], -1)
        return quant_z, indices

    @torch.no_grad()
    def z_to_image(self, indices, p1=16, p2=16):
        """Decode token indices to images.

        ``p1`` and ``p2`` are the height and width of the token grid the flat
        indices are reshaped to, so ``indices`` must hold ``p1 * p2`` tokens per image.
        NOTE(review): confirm the 16x16 default matches the encoder's latent grid.
        """
        dim = self.vqgan.codebook.embedding.embedding_dim
        ix_to_vectors = self.vqgan.codebook.embedding(indices).reshape(indices.shape[0], p1, p2, dim)
        ix_to_vectors = ix_to_vectors.permute(0, 3, 1, 2)
        image = self.vqgan.decode(ix_to_vectors)
        return image

    def forward(self, x):
        """Return ``(logits, target)`` for next-token prediction on the images ``x``."""
        _, indices = self.encode_to_z(x)

        # Start-of-sequence tokens live on the same device as the indices.
        sos_tokens = torch.ones(x.shape[0], 1, device=indices.device) * self.sos_token
        sos_tokens = sos_tokens.long()

        # Keep each token with probability pkeep, otherwise replace it by a random code.
        mask = torch.bernoulli(self.pkeep * torch.ones(indices.shape, device=indices.device))
        mask = mask.round().to(dtype=torch.int64)
        random_indices = torch.randint_like(indices, self.transformer.config.vocab_size)
        new_indices = mask * indices + (1 - mask) * random_indices

        new_indices = torch.cat((sos_tokens, new_indices), dim=1)

        target = indices

        # Drop the last token so the input length equals the target length.
        logits, _ = self.transformer(new_indices[:, :-1])

        return logits, target

    def top_k_logits(self, logits, k):
        """Return a copy of ``logits`` with everything below the k-th largest value set to -inf."""
        top_values, _ = torch.topk(logits, k)
        out = logits.clone()
        out[out < top_values[..., [-1]]] = -float("inf")
        return out

    @torch.no_grad()
    def sample(self, x, c, steps, temperature=1.0, top_k=100):
        """Autoregressively append ``steps`` tokens to ``x`` given the conditioning ``c``.

        Returns the sequence without the conditioning prefix. The transformer
        is run in eval mode and put back into whatever mode it was in before
        the call.
        """
        was_training = self.transformer.training
        self.transformer.eval()
        try:
            x = torch.cat((c, x), dim=1)
            for _ in range(steps):
                logits, _ = self.transformer(x)
                logits = logits[:, -1, :] / temperature

                if top_k is not None:
                    logits = self.top_k_logits(logits, top_k)

                probs = F.softmax(logits, dim=-1)
                ix = torch.multinomial(probs, num_samples=1)
                x = torch.cat((x, ix), dim=1)

            x = x[:, c.shape[1]:]
        finally:
            # Restore the previous mode instead of unconditionally enabling training.
            self.transformer.train(was_training)
        return x

    @torch.no_grad()
    def log_images(self, x):
        """Return a dict of input / reconstruction / half / full samples and their concatenation."""
        log = dict()

        _, indices = self.encode_to_z(x)
        sos_tokens = torch.ones(x.shape[0], 1, device=indices.device) * self.sos_token
        sos_tokens = sos_tokens.long()

        # Half-sample: keep the first half of the tokens, sample the rest.
        start_indices = indices[:, :indices.shape[1] // 2]
        sample_indices = self.sample(start_indices, sos_tokens, steps=indices.shape[1] - start_indices.shape[1])
        half_sample = self.z_to_image(sample_indices)

        # Full-sample: start from an empty sequence and sample every token.
        start_indices = indices[:, :0]
        sample_indices = self.sample(start_indices, sos_tokens, steps=indices.shape[1])
        full_sample = self.z_to_image(sample_indices)

        # Reconstruction from the true indices.
        x_rec = self.z_to_image(indices)

        log["input"] = x
        log["rec"] = x_rec
        log["half_sample"] = half_sample
        log["full_sample"] = full_sample

        return log, torch.cat((x, x_rec, half_sample, full_sample))

# Métricas

In [26]:
# =========================
# FeatureStats class
# =========================
class FeatureStats:
    """Running first and second moments of feature vectors.

    Args:
        capture_mean_cov: stored on the instance; the running sums are
            accumulated regardless of its value.
        max_items: optional cap on the number of feature rows accumulated.
            ``None`` (the default) means no cap.

    Attributes:
        num_items: number of rows accumulated so far.
        num_features: feature dimension, fixed by the first batch.
        raw_mean: float64 sum of all rows (NOT yet divided by ``num_items``).
        raw_cov: float64 sum of outer products of all rows (NOT a covariance).
    """

    def __init__(self, capture_mean_cov=True, max_items=None):
        self.capture_mean_cov = capture_mean_cov
        self.max_items = max_items
        self.num_items = 0
        self.num_features = None
        self.raw_mean = None
        self.raw_cov = None

    def append_torch(self, x):
        """Accumulate a (batch, features) tensor, honouring ``max_items``."""
        x = x.detach().cpu().numpy()
        if x.ndim != 2:
            raise ValueError(f"expected a 2-D (batch, features) tensor, got shape {x.shape}")

        if self.max_items is not None:
            # Only take as many rows as still fit under the cap.
            remaining = self.max_items - self.num_items
            if remaining <= 0:
                return
            x = x[:remaining]

        if self.num_features is None:
            self.num_features = x.shape[1]
            self.raw_mean = np.zeros([self.num_features], dtype=np.float64)
            self.raw_cov = np.zeros([self.num_features, self.num_features], dtype=np.float64)
        self.num_items += x.shape[0]

        # Accumulate in float64 to limit round-off over many batches.
        x64 = x.astype(np.float64)
        self.raw_mean += x64.sum(axis=0)
        self.raw_cov += x64.T @ x64

    def get_mean_cov(self):
        """Return ``(mean, cov)`` with ``cov = E[x x^T] - mean mean^T`` (population covariance).

        Raises:
            ValueError: if nothing has been accumulated yet.
        """
        if self.num_items == 0 or self.raw_mean is None:
            raise ValueError("FeatureStats is empty: append features before calling get_mean_cov()")
        mean = self.raw_mean / self.num_items
        cov = self.raw_cov / self.num_items - np.outer(mean, mean)
        return mean, cov

    def save(self, pkl_file):
        """Pickle the instance attributes (a plain dict) to ``pkl_file``."""
        with open(pkl_file, "wb") as f:
            pickle.dump(self.__dict__, f)

    @staticmethod
    def load(pkl_file):
        """Rebuild a ``FeatureStats`` from a file written by ``save``."""
        # NOTE(review): pickle.load can execute arbitrary code - only load files you trust.
        with open(pkl_file, "rb") as f:
            s = pickle.load(f)
        obj = FeatureStats()
        obj.__dict__.update(s)
        return obj

# =========================
# Inception loader
# =========================
# One feature extractor per device, built on first request.
_inception_cache = {}

def get_inception_v3(device="cuda"):
    """Return a cached pretrained InceptionV3 whose classifier is replaced by identity.

    The model is created once per ``device`` value, switched to eval mode,
    moved to that device and reused on later calls.
    """
    cached = _inception_cache.get(device)
    if cached is not None:
        return cached

    inception = models.inception_v3(pretrained=True, transform_input=False)
    inception.fc = torch.nn.Identity()  # drop the classifier, keep the pooled features
    inception.eval().to(device)
    _inception_cache[device] = inception
    return _inception_cache[device]

# =========================
# Compute dataset stats
# =========================
def compute_dataset_stats(dataset, device="cuda", batch_size=64,
                          cache_dir="/content/drive/MyDrive/Proyecto_Grado/VQ_GAN/VQCheckpointsV1",
                          max_items=None):
    """Compute (or load from cache) InceptionV3 feature statistics of a dataset.

    Args:
        dataset: dataset yielding image tensors, or tuples/lists whose first
            element is the image tensor.
        device: device the Inception model runs on.
        batch_size: batch size of the internal DataLoader.
        cache_dir: directory holding the pickled statistics.
        max_items: optional cap on the number of images used. ``None`` uses
            the whole dataset.

    Returns:
        A ``FeatureStats`` instance.
    """
    os.makedirs(cache_dir, exist_ok=True)

    # Cache key built from the dataset's string form and length. A capped run
    # gets its own key so it is never mistaken for full-dataset statistics;
    # the key for max_items=None is unchanged, so existing cache files stay valid.
    dataset_id = str(dataset) + f"-{len(dataset)}"
    if max_items is not None:
        dataset_id += f"-max{max_items}"
    md5 = hashlib.md5(dataset_id.encode("utf-8")).hexdigest()
    cache_file = os.path.join(cache_dir, f"dataset_stats_{md5}.pkl")

    # Reuse previously saved statistics when available.
    if os.path.isfile(cache_file):
        print(f"üìÇ Stats encontrados: {cache_file}")
        return FeatureStats.load(cache_file)

    print("üîé Calculando stats del dataset real...")
    inception = get_inception_v3(device)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    stats = FeatureStats(capture_mean_cov=True, max_items=max_items)
    seen = 0  # images accumulated so far, tracked locally to enforce max_items

    for batch in tqdm(loader, desc="Extrayendo features"):
        images = batch[0] if isinstance(batch, (list, tuple)) else batch

        if max_items is not None:
            remaining = max_items - seen
            if remaining <= 0:
                break
            # Trim the final batch so exactly max_items images are used.
            images = images[:remaining]

        # Grayscale input: repeat the single channel three times.
        if images.shape[1] == 1:
            images = images.repeat(1, 3, 1, 1)

        # Resize to the 299x299 input used for InceptionV3.
        images = F.interpolate(images, size=(299, 299), mode="bilinear", align_corners=False)

        with torch.no_grad():
            feats = inception(images.to(device))
        stats.append_torch(feats)
        seen += feats.shape[0]

        if max_items is not None and seen >= max_items:
            break

    stats.save(cache_file)
    print(f"‚úÖ Stats guardados en: {cache_file}")
    return stats

In [27]:
stats_dir = "/content/drive/MyDrive/Proyecto_Grado/VQ_GAN/VQCheckpointsVQ1"
stats_path = os.path.join(stats_dir, "dataset_stats_c40b27d8db3f4b1d1ede4a5413f38844.pkl")

# Dataset transforms: resize to 128x128, convert to tensor, map [0, 1] to [-1, 1].
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

dataset = datasets.ImageFolder(
    "/content/drive/MyDrive/Proyecto_Grado/Data/frames_extraidos_MedGAN",
    transform=transform
)

# === Load the stats file if present, otherwise compute it ===
# `stats` is a plain dict (the attribute dict of a FeatureStats) in both branches.
if os.path.exists(stats_path):
    print(f"‚úÖ Stats ya existen, cargando desde: {stats_path}")
    # NOTE(review): pickle.load can execute arbitrary code - only load files you trust.
    with open(stats_path, "rb") as f:
        stats = pickle.load(f)
    if not isinstance(stats, dict):
        # Tolerate files that hold a pickled FeatureStats object instead of its dict.
        stats = vars(stats)
else:
    print("‚ö†Ô∏è Stats no encontrados, calculando desde cero...")
    # Run on the detected device and keep the internal cache next to stats_path.
    stats = vars(compute_dataset_stats(dataset, device=device, batch_size=64, cache_dir=stats_dir))
    with open(stats_path, "wb") as f:
        pickle.dump(stats, f)

# raw_mean / raw_cov are running SUMS; normalise them by the item count to
# obtain the actual mean and covariance of the real-image features.
num_real = stats["num_items"]
mu_real = np.asarray(stats["raw_mean"], dtype=np.float64) / num_real
sigma_real = np.asarray(stats["raw_cov"], dtype=np.float64) / num_real - np.outer(mu_real, mu_real)

print("mu_real:", mu_real.shape, "sigma_real:", sigma_real.shape)

‚úÖ Stats ya existen, cargando desde: /content/drive/MyDrive/Proyecto_Grado/VQ_GAN/VQCheckpointsVQ1/dataset_stats_c40b27d8db3f4b1d1ede4a5413f38844.pkl
mu_real: (2048,) sigma_real: (2048, 2048)


In [28]:
# LPIPS.py
# Download URL of each pretrained checkpoint, keyed by model name.
URL_MAP = {
    "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
}

# Local file name under which each checkpoint is stored, keyed by model name.
CKPT_MAP = {
    "vgg_lpips": "vgg.pth"
}


def download(url, local_path, chunk_size=1024):
    """Stream ``url`` to ``local_path`` with a progress bar.

    The data is first written to ``local_path + ".part"`` and moved into
    place only after the whole response has been received, so an interrupted
    download never leaves a truncated file at ``local_path``.

    Args:
        url: address to fetch.
        local_path: destination file; its parent directory is created if needed.
        chunk_size: number of bytes requested per read.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    target_dir = os.path.dirname(local_path)
    if target_dir:  # a bare file name has no directory to create
        os.makedirs(target_dir, exist_ok=True)

    partial_path = local_path + ".part"
    try:
        # timeout bounds each connect/read wait so a stalled server cannot hang the cell.
        with requests.get(url, stream=True, timeout=60) as r:
            # Fail loudly instead of saving an HTML error page as the checkpoint.
            r.raise_for_status()
            total_size = int(r.headers.get("content-length", 0))
            with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
                with open(partial_path, "wb") as f:
                    for data in r.iter_content(chunk_size=chunk_size):
                        if data:
                            f.write(data)
                            # The last chunk is usually shorter than chunk_size.
                            pbar.update(len(data))
        os.replace(partial_path, local_path)
    except BaseException:
        # Remove the incomplete file, then let the original error propagate.
        if os.path.exists(partial_path):
            os.remove(partial_path)
        raise


def get_ckpt_path(name, root):
    """Return the local path of checkpoint ``name`` under ``root``, downloading it if absent.

    Args:
        name: key present in ``URL_MAP`` and ``CKPT_MAP``.
        root: directory in which the checkpoint file lives.
    """
    assert name in URL_MAP
    path = os.path.join(root, CKPT_MAP[name])
    if os.path.exists(path):
        return path

    url = URL_MAP[name]
    print(f"Downloading {name} model from {url} to {path}")
    download(url, path)
    return path


class LPIPS(nn.Module):
    """Perceptual distance built from five VGG16 feature maps and 1x1 linear heads.

    The pretrained weights are loaded on construction and every parameter is
    frozen.
    """

    def __init__(self):
        super().__init__()
        self.scaling_layer = ScalingLayer()
        self.channels = [64, 128, 256, 512, 512]
        self.vgg = VGG16()
        # One linear head per VGG feature map, in the same order as `channels`.
        self.lins = nn.ModuleList([NetLinLayer(width) for width in self.channels])

        self.load_from_pretrained()

        for param in self.parameters():
            param.requires_grad_(False)

    def load_from_pretrained(self, name="vgg_lpips"):
        """Fetch (if needed) and load the checkpoint registered under ``name``."""
        # NOTE(review): strict=False silently ignores missing/unexpected keys, and
        # torch.load unpickles the file - only load checkpoints you trust.
        ckpt = get_ckpt_path(name, "vgg_lpips")
        state = torch.load(ckpt, map_location=torch.device("cpu"))
        self.load_state_dict(state, strict=False)

    def forward(self, real_x, fake_x):
        """Return the summed per-layer distance between ``real_x`` and ``fake_x``."""
        # Resize the fake batch when its spatial size differs from the real one.
        if real_x.shape[2:] != fake_x.shape[2:]:
            fake_x = F.interpolate(fake_x, size=real_x.shape[2:], mode="bilinear", align_corners=False)

        features_real = self.vgg(self.scaling_layer(real_x))
        features_fake = self.vgg(self.scaling_layer(fake_x))

        layer_scores = []
        for lin, feat_real, feat_fake in zip(self.lins, features_real, features_fake):
            # Squared difference of channel-normalised features.
            diff = (norm_tensor(feat_real) - norm_tensor(feat_fake)) ** 2
            layer_scores.append(spatial_average(lin.model(diff)))

        return sum(layer_scores)


class ScalingLayer(nn.Module):
    """Applies a fixed per-channel shift and scale to a 3-channel batch."""

    def __init__(self):
        super().__init__()
        shift = torch.Tensor([-.030, -.088, -.188]).view(1, 3, 1, 1)
        scale = torch.Tensor([.458, .448, .450]).view(1, 3, 1, 1)
        # Registered as buffers: they follow .to(device) and are stored in the state_dict.
        self.register_buffer("shift", shift)
        self.register_buffer("scale", scale)

    def forward(self, x):
        centered = x - self.shift
        return centered / self.scale


class NetLinLayer(nn.Module):
    """
    Dropout followed by a bias-free 1x1 convolution, exposed as ``self.model``.

    No ``forward`` is defined on this class; the sequential in ``self.model`` is what gets called.
    """

    def __init__(self, in_channels, out_channels=1):
        super().__init__()
        projection = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
        self.model = nn.Sequential(nn.Dropout(), projection)


class VGG16(nn.Module):
    """
    Frozen, pretrained torchvision VGG16 feature stack cut into five consecutive stages.

    ``forward`` runs the stages in sequence and returns the output of each one.
    """

    def __init__(self):
        super().__init__()
        # Only the first 30 layers of the feature stack are used.
        layers = list(vgg16(pretrained=True).features)[:30]
        self.slice1 = nn.Sequential(*layers[:4])
        self.slice2 = nn.Sequential(*layers[4:9])
        self.slice3 = nn.Sequential(*layers[9:16])
        self.slice4 = nn.Sequential(*layers[16:23])
        self.slice5 = nn.Sequential(*layers[23:])

        for weight in self.parameters():
            weight.requires_grad = False

    def forward(self, x):
        """Return a VGGOutputs namedtuple holding the activations after each of the five stages."""
        activations = []
        h = x
        for stage in (self.slice1, self.slice2, self.slice3, self.slice4, self.slice5):
            h = stage(h)
            activations.append(h)
        vgg_outputs = namedtuple("VGGOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
        return vgg_outputs(*activations)


def norm_tensor(x):
    """
    Scale a tensor so that its vectors along dim 1 have (almost) unit L2 length.
    :param x: tensor with at least two dimensions; dim 1 is the one reduced
    :return: x divided by its L2 norm over dim 1 (1e-10 is added to avoid dividing by zero)
    """
    length = x.pow(2).sum(dim=1, keepdim=True).sqrt()
    return x / (length + 1e-10)


def spatial_average(x):
    """
    Average a 4-D batch over its last two (spatial) dimensions.
    :param x: tensor indexed as batch x channels x spatial x spatial
    :return: tensor of shape batch x channels x 1 x 1 holding the per-channel means
    """
    return torch.mean(x, dim=(2, 3), keepdim=True)

# TRAIN

## Entrenamiento VQ-GAN

In [29]:
# training_vqgan.py
class TrainVQGAN:
    """
    Training harness for the first-stage VQ-GAN.

    Builds the VQ-GAN and its discriminator, reloads the VQ-GAN weights from the most
    recent ``NNNNN_checkpoint`` run folder when one exists, opens a new run folder for
    this session, and keeps JSON histories of the per-epoch losses and metrics.

    NOTE(review): only the VQ-GAN ``state_dict`` is saved and restored here. The
    discriminator and both optimizers start from scratch on every resume.
    """

    def __init__(self, args, base_ckpt_dir, loss_vq_path, metrics_path):
        """
        :param args: namespace with the hyper-parameters read here and in ``train``
            (``device``, ``learning_rate``, ``beta1``, ``beta2``, ...)
        :param base_ckpt_dir: directory that holds the numbered ``NNNNN_checkpoint`` run folders
        :param loss_vq_path: JSON file holding the per-epoch loss history
        :param metrics_path: JSON file holding the SSIM / PSNR / LPIPS history
        """
        self.vqgan = VQGAN(args).to(device=args.device)
        self.discriminator = Discriminator(args).to(device=args.device)
        self.discriminator.apply(weights_init)

        # LPIPS (AlexNet trunk): used as a training loss and as an evaluation metric.
        self.perceptual_loss = lpips.LPIPS(net='alex').eval().to(device=args.device)

        # Log file locations
        self.loss_vq_path = loss_vq_path
        self.metrics_path = metrics_path

        # Optimizers
        self.opt_vq, self.opt_disc = self.configure_optimizers(args)

        # ================================
        # 1. Locate the most recent checkpoint folder
        # ================================
        # On a first run the base directory may not exist yet; os.listdir would raise.
        os.makedirs(base_ckpt_dir, exist_ok=True)
        existing = [d for d in os.listdir(base_ckpt_dir) if re.match(r"^\d{5}_checkpoint$", d)]
        if existing:
            last_num = max(int(d.split("_")[0]) for d in existing)
            last_folder = f"{last_num:05d}_checkpoint"
            last_path = os.path.join(base_ckpt_dir, last_folder)

            # Pick the .pt file with the highest epoch number inside that folder.
            ckpts = [f for f in os.listdir(last_path) if f.startswith("vqgan_epoch_") and f.endswith(".pt")]
            if ckpts:
                ckpts.sort(key=lambda x: int(x.split("_")[-1].split(".")[0]))
                last_ckpt = ckpts[-1]
                ckpt_path = os.path.join(last_path, last_ckpt)
                self.vqgan.load_state_dict(torch.load(ckpt_path, map_location=args.device))
                self.start_epoch = int(last_ckpt.split("_")[-1].split(".")[0])
                print(f"‚úÖ √öltimo checkpoint cargado: {ckpt_path} (√©poca {self.start_epoch})")
            else:
                self.start_epoch = 0
                print("‚ö†Ô∏è No se encontraron .pt en la √∫ltima carpeta, entrenamiento desde cero.")
        else:
            last_num, self.start_epoch = -1, 0
            print("‚ö†Ô∏è No se encontraron carpetas de checkpoints previas, entrenamiento desde cero.")

        # ================================
        # 2. Create the folder for this run
        # ================================
        new_num = last_num + 1
        run_folder = f"{new_num:05d}_checkpoint"
        self.run_path = os.path.join(base_ckpt_dir, run_folder)
        os.makedirs(self.run_path, exist_ok=True)
        print(f"üìÇ Carpeta de checkpoints actual: {self.run_path}")

        # ================================
        # 3. Load the global loss and metric histories
        # ================================
        self.loss_history = []
        if os.path.exists(self.loss_vq_path):
            with open(self.loss_vq_path, "r") as f:
                self.loss_history = json.load(f)

        self.metric_history = []
        if os.path.exists(self.metrics_path):
            with open(self.metrics_path, "r") as f:
                self.metric_history = json.load(f)

    def configure_optimizers(self, args):
        """
        Build one Adam optimizer for the VQ-GAN sub-modules and one for the discriminator.
        :param args: namespace providing ``learning_rate``, ``beta1`` and ``beta2``
        :return: tuple ``(opt_vq, opt_disc)``
        """
        lr = args.learning_rate
        opt_vq = torch.optim.Adam(
            list(self.vqgan.encoder.parameters()) +
            list(self.vqgan.decoder.parameters()) +
            list(self.vqgan.codebook.parameters()) +
            list(self.vqgan.quant_conv.parameters()) +
            list(self.vqgan.post_quant_conv.parameters()),
            lr=lr, eps=1e-08, betas=(args.beta1, args.beta2)
        )
        opt_disc = torch.optim.Adam(
            self.discriminator.parameters(),
            lr=lr, eps=1e-08, betas=(args.beta1, args.beta2)
        )
        return opt_vq, opt_disc

    def compute_metrics(self, real, recon):
        """
        Compute SSIM, PSNR and LPIPS between a batch and its reconstruction.
        :param real: reference batch, NCHW (assumes values in [-1, 1] - TODO confirm against load_data)
        :param recon: reconstructed batch with the same shape as ``real``
        :return: dict with float entries ``ssim``, ``psnr`` and ``lpips`` (batch means)
        """
        # Map to [0, 1] on the CPU for the skimage metrics.
        real_np = ((real.detach().cpu().numpy() + 1) / 2).clip(0, 1)
        recon_np = ((recon.detach().cpu().numpy() + 1) / 2).clip(0, 1)

        # NCHW -> NHWC, the layout skimage is called with (channel_axis=-1).
        real_np = np.transpose(real_np, (0, 2, 3, 1))
        recon_np = np.transpose(recon_np, (0, 2, 3, 1))

        ssim_vals, psnr_vals = [], []
        for r, f in zip(real_np, recon_np):
            ssim_vals.append(ssim_metric(r, f, channel_axis=-1, data_range=1.0))
            psnr_vals.append(psnr_metric(r, f, data_range=1.0))

        # LPIPS on the whole batch. Metrics are evaluation only, so no autograd graph is built.
        with torch.no_grad():
            lpips_val = self.perceptual_loss(real, recon).mean().item()

        return {
            "ssim": float(np.mean(ssim_vals)),
            "psnr": float(np.mean(psnr_vals)),
            "lpips": float(lpips_val)
        }

    def train(self, args):
        """
        Run the adversarial training loop from ``self.start_epoch`` to ``args.epochs``.

        After every epoch the loss history is rewritten to ``self.loss_vq_path`` and a
        checkpoint, a sample reconstruction and the metrics of the last batch are saved.
        :param args: namespace providing ``device``, ``epochs``, ``disc_factor``, ``disc_start``,
            ``lpips_interval``, ``perceptual_loss_factor`` and ``rec_loss_factor``
        """
        train_dataset = load_data(args)
        steps_per_epoch = len(train_dataset)
        global_step, lam_prev = 0, 1.0

        # Sample reconstructions are written to GEN_IMG_DIR; save_image does not create it.
        os.makedirs(GEN_IMG_DIR, exist_ok=True)

        for epoch in range(self.start_epoch, args.epochs):
            rec_losses, perceptual_losses, q_losses = [], [], []
            g_losses, gen_losses, disc_losses = [], [], []

            with tqdm(range(len(train_dataset))) as pbar:
                for i, imgs in zip(pbar, train_dataset):
                    imgs = imgs.to(device=args.device)
                    decoded_images, _, q_loss = self.vqgan(imgs)

                    if imgs.shape != decoded_images.shape:
                        decoded_images = F.interpolate(decoded_images, size=imgs.shape[2:], mode="bilinear")

                    # Discriminator passes.
                    # disc_fake feeds the generator loss, so gradients must reach the VQ-GAN.
                    # disc_fake_for_d feeds the discriminator loss; the reconstruction is detached
                    # so that loss cannot leak gradients into the VQ-GAN parameters.
                    disc_real = self.discriminator(imgs)
                    disc_fake = self.discriminator(decoded_images)
                    disc_fake_for_d = self.discriminator(decoded_images.detach())

                    disc_factor = self.vqgan.adopt_weight(
                        args.disc_factor, epoch * steps_per_epoch + i, threshold=args.disc_start
                    )

                    # Losses. LPIPS is only evaluated every args.lpips_interval steps.
                    # NOTE(review): skipped steps log a perceptual loss of 0, which lowers the epoch mean.
                    if global_step % args.lpips_interval == 0:
                        perceptual_loss = self.perceptual_loss(imgs, decoded_images).mean()
                    else:
                        perceptual_loss = torch.tensor(0.0, device=args.device)

                    rec_loss = torch.abs(imgs - decoded_images).mean()
                    perceptual_rec_loss = (
                        args.perceptual_loss_factor * perceptual_loss +
                        args.rec_loss_factor * rec_loss
                    )

                    g_loss = -torch.mean(disc_fake)
                    # The adaptive weight is refreshed every 50 steps and reused in between.
                    lam = self.vqgan.calculate_lambda(perceptual_rec_loss, g_loss) if (global_step % 50) == 0 else lam_prev
                    lam_prev = lam

                    vq_loss = perceptual_rec_loss + q_loss + disc_factor * lam * g_loss
                    d_loss_real = torch.mean(F.relu(1. - disc_real))
                    d_loss_fake = torch.mean(F.relu(1. + disc_fake_for_d))
                    gan_loss = disc_factor * 0.5 * (d_loss_real + d_loss_fake)

                    # Optimization. The two losses no longer share graph nodes, so
                    # retain_graph is not needed on the first backward pass.
                    self.opt_vq.zero_grad()
                    vq_loss.backward()
                    # vq_loss.backward() also deposited generator-loss gradients on the
                    # discriminator; drop them before the discriminator's own backward pass.
                    self.opt_disc.zero_grad()
                    gan_loss.backward()
                    self.opt_vq.step()
                    self.opt_disc.step()

                    # Accumulate per-step values
                    rec_losses.append(rec_loss.item())
                    perceptual_losses.append(perceptual_loss.item())
                    q_losses.append(q_loss.item())
                    g_losses.append(g_loss.item())
                    gen_losses.append(vq_loss.item())
                    disc_losses.append(gan_loss.item())

                    pbar.set_postfix(
                        Rec=np.round(rec_loss.item(), 5),
                        Perc=np.round(perceptual_loss.item(), 5),
                        Q=np.round(q_loss.item(), 5),
                        G=np.round(g_loss.item(), 5),
                        Gen=np.round(vq_loss.item(), 5),
                        Disc=np.round(gan_loss.item(), 5),
                    )
                    global_step += 1

            # Persist the global loss history
            self.loss_history.append({
                "epoch": epoch + 1,
                "rec_loss": float(np.mean(rec_losses)),
                "perceptual_loss": float(np.mean(perceptual_losses)),
                "q_loss": float(np.mean(q_losses)),
                "g_loss": float(np.mean(g_losses)),
                "gen_loss": float(np.mean(gen_losses)),
                "disc_loss": float(np.mean(disc_losses)),
            })
            with open(self.loss_vq_path, "w") as f:
                json.dump(self.loss_history, f, indent=4)

            # === Save checkpoint, sample image and metrics ===
            # The modulus is 1, so this runs after every epoch; raise it to save less often.
            if (epoch + 1) % 1 == 0:
                ckpt_path = os.path.join(self.run_path, f"vqgan_epoch_{epoch+1}.pt")
                torch.save(self.vqgan.state_dict(), ckpt_path)
                print(f"üíæ Checkpoint guardado: {ckpt_path}")

                # First reconstruction of the last batch, mapped from [-1, 1] to [0, 1].
                gen_path = os.path.join(GEN_IMG_DIR, f"generated_v1_{epoch+1}.png")
                vutils.save_image((decoded_images[0].detach().cpu() + 1) * 0.5, gen_path)
                print(f"üñº Imagen generada guardada: {gen_path}")

                # Metrics on (at most) the first 8 images of the last batch.
                metrics = self.compute_metrics(imgs[:8], decoded_images[:8])
                metrics["epoch"] = epoch + 1
                self.metric_history.append(metrics)
                with open(self.metrics_path, "w") as f:
                    json.dump(self.metric_history, f, indent=4)
                print(f"üìä M√©tricas guardadas: {metrics}")

In [30]:
# === Argument configuration for VQ-GAN training ===
parser = argparse.ArgumentParser()

# (flag, type, default) triples, registered in this order.
for _flag, _type, _default in [
    ('--latent-dim', int, 128),
    ('--image-size', int, 128),
    ('--num-codebook-vectors', int, 512),
    ('--beta', float, 0.25),
    ('--image-channels', int, 3),
    ('--dataset-path', str, '/content/dataset'),
    ('--device', str, "cuda"),
    ('--batch-size', int, 8),
    ('--epochs', int, 1000),
    ('--learning-rate', float, 2.25e-05),
    ('--beta1', float, 0.5),
    ('--beta2', float, 0.9),
    ('--disc-start', int, 10000),
    ('--disc-factor', float, 1.),
    ('--rec-loss-factor', float, 1.),
    ('--perceptual-loss-factor', float, 1.),
]:
    parser.add_argument(_flag, type=_type, default=_default)

# The only option that carries a help text.
parser.add_argument("--lpips_interval", type=int, default=5,
                    help="Cada cu√°ntos pasos calcular LPIPS perceptual loss.")

# Empty argv: inside a notebook only the defaults above are used.
args = parser.parse_args([])

# Always force the GPU
device = torch.device("cuda")

# Dataset location on Drive
data_dir = "/content/drive/MyDrive/Proyecto_Grado/Data"
args.dataset_path = data_dir + "/frames_extraidos"

In [None]:
# Base directory under which every numbered run folder lives.
BASE_CKPT_DIR = "/content/drive/MyDrive/Proyecto_Grado/VQ_GAN/VQCheckpointsVQ1"

# The trainer works out its own start epoch from the checkpoints found there,
# so training can be launched directly.
train_vqgan = TrainVQGAN(
    args,
    base_ckpt_dir=BASE_CKPT_DIR,
    loss_vq_path=LOSS_VQ_PATH,
    metrics_path=METRICS_PATH,
)
train_vqgan.train(args)

Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.12/dist-packages/lpips/weights/v0.1/alex.pth
‚ö†Ô∏è No se encontraron carpetas de checkpoints previas, entrenamiento desde cero.
üìÇ Carpeta de checkpoints actual: /content/drive/MyDrive/Proyecto_Grado/VQ_GAN/VQCheckpointsVQ1/00000_checkpoint


  1%|          | 50/5015 [00:56<1:26:54,  1.05s/it, Disc=0, G=-0.0564, Gen=0.93, Perc=0, Q=0.827, Rec=0.103]

## Entrenamiento del Transformer

In [None]:
# training_transformer.py
class TrainTransformer:
    """
    Second-stage trainer: fits the transformer inside ``VQGANTransformer``.

    Training is started from the constructor, so building the object runs the whole loop.
    """

    def __init__(self, args, run_path):
        """
        :param args: namespace forwarded to ``VQGANTransformer``, ``load_data`` and ``train``
        :param run_path: folder that receives checkpoints, samples and the loss history
        """
        self.model = VQGANTransformer(args).to(device=args.device)
        self.optim = self.configure_optimizers()
        self.run_path = run_path
        os.makedirs(self.run_path, exist_ok=True)
        # Per-epoch mean losses. Must exist before train() runs, because train() appends to it.
        self.loss_history = []
        self.train(args)

    def configure_optimizers(self):
        """
        Build AdamW with two parameter groups over ``self.model.transformer``.

        Linear weights get weight decay 0.01; biases, LayerNorm / Embedding weights and
        ``pos_emb`` get none.
        :return: the configured ``torch.optim.AdamW`` instance
        """
        decay, no_decay = set(), set()
        whitelist_weight_modules = (nn.Linear,)
        blacklist_weight_modules = (nn.LayerNorm, nn.Embedding)

        for mn, m in self.model.transformer.named_modules():
            for pn, p in m.named_parameters():
                fpn = f"{mn}.{pn}" if mn else pn
                if pn.endswith("bias"):
                    no_decay.add(fpn)
                elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules):
                    decay.add(fpn)
                elif pn.endswith("weight") and isinstance(m, blacklist_weight_modules):
                    no_decay.add(fpn)

        # pos_emb is matched by neither suffix rule above, so it is added explicitly.
        no_decay.add("pos_emb")
        param_dict = {pn: p for pn, p in self.model.transformer.named_parameters()}
        optim_groups = [
            {"params": [param_dict[pn] for pn in sorted(decay)], "weight_decay": 0.01},
            {"params": [param_dict[pn] for pn in sorted(no_decay)], "weight_decay": 0.0},
        ]
        return torch.optim.AdamW(optim_groups, lr=4.5e-06, betas=(0.9, 0.95))

    def train(self, args, start_epoch=0):
        """
        Train the transformer with cross-entropy over the predicted indices.

        The loss history is rewritten after every epoch; every 5 epochs a checkpoint,
        a sample image and a metrics file are saved in ``self.run_path``.
        :param args: namespace providing ``device`` and ``epochs``
        :param start_epoch: first epoch index to run (0-based); earlier epochs are skipped
        """
        train_dataset = load_data(args)
        history_path = os.path.join(self.run_path, "transformer_loss_history.json")

        for epoch in range(start_epoch, args.epochs):
            epoch_losses = []
            with tqdm(range(len(train_dataset))) as pbar:
                for _, imgs in zip(pbar, train_dataset):
                    self.optim.zero_grad()
                    imgs = imgs.to(device=args.device)
                    logits, targets = self.model(imgs)
                    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
                    loss.backward()
                    self.optim.step()
                    epoch_losses.append(loss.item())
                    pbar.set_postfix(Transformer_Loss=np.round(loss.item(), 4))

            # === Persist the loss history ===
            self.loss_history.append({"epoch": epoch + 1, "loss": float(np.mean(epoch_losses))})
            with open(history_path, "w") as f:
                json.dump(self.loss_history, f, indent=4)

            # === Every 5 epochs: checkpoint, sample image and metrics ===
            if (epoch + 1) % 5 == 0:
                ckpt_path = os.path.join(self.run_path, f"transformer_epoch_{epoch+1}.pt")
                torch.save(self.model.state_dict(), ckpt_path)
                print(f"üíæ Checkpoint guardado: {ckpt_path}")

                # Sampling and reconstruction are evaluation only: no autograd graph needed.
                with torch.no_grad():
                    # Sample image generated from the first image of the last batch.
                    _, sampled_imgs = self.model.log_images(imgs[0][None])
                    sample_path = os.path.join(self.run_path, f"transformer_sample_epoch_{epoch+1}.jpg")
                    vutils.save_image(sampled_imgs, sample_path, nrow=4)
                    print(f"üñº Imagen de muestra guardada: {sample_path}")

                    decoded_images = self.model.z_to_image(self.model.vqgan.encode_codebook(imgs))

                # Metrics on (at most) the first 8 images of the last batch.
                # NOTE(review): relies on a module-level compute_metrics(real, recon, path, stage=...)
                # that is not defined in this part of the notebook - confirm it exists.
                metrics_path = os.path.join(self.run_path, f"transformer_metrics_epoch_{epoch+1}.json")
                compute_metrics(imgs[:8].detach(), decoded_images[:8].detach(), metrics_path, stage="transformer")

In [None]:
# === Create a numbered folder for this Transformer training run ===
base_ckpt_dir = "/content/drive/MyDrive/Proyecto_Grado/VQ_GAN/TransformerCheckpointsV1"
os.makedirs(base_ckpt_dir, exist_ok=True)

# The new run gets the next free NNNNN_checkpoint number (00000 when none exist).
existing = [d for d in os.listdir(base_ckpt_dir) if re.match(r"^\d{5}_checkpoint$", d)]
if existing:
    last_num = max(int(d.split("_")[0]) for d in existing)
    new_num = last_num + 1
else:
    new_num = 0

run_folder = f"{new_num:05d}_checkpoint"
run_path = os.path.join(base_ckpt_dir, run_folder)
os.makedirs(run_path, exist_ok=True)

print(f"‚úÖ Carpeta de entrenamiento de Transformer creada: {run_path}")

# === Argument configuration ===
parser = argparse.ArgumentParser()
parser.add_argument('--latent-dim', type=int, default=128)
parser.add_argument('--image-size', type=int, default=128)
parser.add_argument('--num-codebook-vectors', type=int, default=512)
parser.add_argument('--beta', type=float, default=0.25)
parser.add_argument('--image-channels', type=int, default=3)
parser.add_argument('--dataset-path', type=str, default='/content/dataset')
parser.add_argument('--checkpoint-path', type=str, default='./checkpoints/vqgan_last_ckpt.pt')
parser.add_argument('--device', type=str, default="cuda")
parser.add_argument('--batch-size', type=int, default=20)
parser.add_argument('--epochs', type=int, default=1000)
parser.add_argument('--learning-rate', type=float, default=2.25e-05)
parser.add_argument('--beta1', type=float, default=0.5)
parser.add_argument('--beta2', type=float, default=0.9)
parser.add_argument('--disc-start', type=int, default=10000)
parser.add_argument('--disc-factor', type=float, default=1.0)
parser.add_argument('--l2-loss-factor', type=float, default=1.)
parser.add_argument('--perceptual-loss-factor', type=float, default=1.)
parser.add_argument('--pkeep', type=float, default=0.5)
parser.add_argument('--sos-token', type=int, default=0)
args = parser.parse_args([])

# Dataset
# NOTE(review): BASE_DIR is not defined in the cells shown here (the VQ-GAN cell uses
# data_dir = ".../Proyecto_Grado/Data") - confirm it exists before running.
args.dataset_path = os.path.join(BASE_DIR, "frames_extraidos")

# Train the Transformer.
# run_path has to be given to the constructor: it is a required argument and the
# constructor starts training immediately, so assigning it afterwards is too late.
train_transformer = TrainTransformer(args, run_path)

# Generar Data


In [None]:
# sample_tranformer.py
import os
import torch
from torchvision import utils as vutils
from tqdm import tqdm

# Drive locations used for sampling.
# NOTE(review): the training cells save into numbered run folders under
# .../VQ_GAN/VQCheckpointsVQ1 and .../VQ_GAN/TransformerCheckpointsV1, not into
# Data/Checkpoints. The recorded FileNotFoundError for vqgan_epoch_100.pt comes from
# this path - point both checkpoint paths at the run folder you want to sample from.
data_dir = "/content/drive/MyDrive/Proyecto_Grado/Data"
output_checkpoints = f"{data_dir}/Checkpoints"
results_dir = f"{data_dir}/Resultados_Transformer"
os.makedirs(results_dir, exist_ok=True)

# Manual hyper-parameter configuration (no argparse).
# NOTE(review): latent_dim / image_size / num_codebook_vectors are 256 / 256 / 1024 here,
# while the training cells default to 128 / 128 / 512. Confirm they match the checkpoints
# being loaded, otherwise load_state_dict will fail on mismatched shapes.
class Args:
    latent_dim = 256
    image_size = 256
    num_codebook_vectors = 1024
    beta = 0.25
    image_channels = 3
    dataset_path = f"{data_dir}/frames_extraidos_MedGAN"
    checkpoint_path = os.path.join(output_checkpoints, "vqgan_epoch_100.pt")  # Adjust if you use a different checkpoint
    device = "cuda"
    batch_size = 20
    epochs = 100
    learning_rate = 2.25e-05
    beta1 = 0.5
    beta2 = 0.9
    disc_start = 10000
    disc_factor = 1.
    l2_loss_factor = 1.
    perceptual_loss_factor = 1.
    pkeep = 0.5
    sos_token = 0

args = Args()

# Load the model
transformer = VQGANTransformer(args).to(args.device)
transformer_ckpt_path = os.path.join(output_checkpoints, "transformer_epoch_100.pt")  # Adjust if you use a different one
transformer.load_state_dict(torch.load(transformer_ckpt_path, map_location=args.device))
# Modules are created in training mode; switch to eval so dropout does not perturb sampling.
transformer.eval()
print(f"‚úÖ Loaded Transformer checkpoint: {transformer_ckpt_path}")

# Generate N images
n = 10  # Number of image grids to generate
# Pure inference: no autograd graph is needed while sampling and decoding.
with torch.no_grad():
    for i in tqdm(range(n)):
        # Empty prompt for 4 samples; generation starts from the SOS token alone.
        start_indices = torch.zeros((4, 0)).long().to(args.device)
        sos_tokens = torch.ones(start_indices.shape[0], 1) * args.sos_token
        sos_tokens = sos_tokens.long().to(args.device)
        sample_indices = transformer.sample(start_indices, sos_tokens, steps=256)
        sampled_imgs = transformer.z_to_image(sample_indices)

        save_path = os.path.join(results_dir, f"transformer_sample_{i}.jpg")
        vutils.save_image(sampled_imgs, save_path, nrow=4)
        print(f"üíæ Saved: {save_path}")

FileNotFoundError: [Errno 2] No such file or directory: '/content/drive/MyDrive/Proyecto_Grado/Data/Checkpoints/vqgan_epoch_100.pt'