<a href="https://colab.research.google.com/github/TSANT17/ProyectoPDI/blob/main/Proyecto_StyleMatcher.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Santiago Parrado
Carlos Lopez
Sebastian Vidal
Hernan Gutierrez

In [None]:
!pip install mediapipe
!pip install facenet-pytorch
!pip install gradio
!pip install torch torchvision
!pip install scikit-learn



In [None]:
import shutil
from google.colab import drive
# Mount Google Drive at /content/drive; every dataset and checkpoint path used below lives under it.
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import os
import shutil

# Copy the trained checkpoints from Drive to local disk for faster loading.
# These files are written by train_model() (torch.save) in the training cell further down,
# so on a fresh Drive they do not exist yet: skip missing ones instead of aborting "Run All".
for _ckpt_name in ('model_anime_best.pt', 'model_painting_best.pt'):
    _src = os.path.join('/content/drive/MyDrive/StyleMatcher', _ckpt_name)
    _dst = os.path.join('/content', _ckpt_name)
    if os.path.exists(_src):
        shutil.copy(_src, _dst)
    else:
        print(f"Checkpoint no encontrado, se omite: {_src}")

'/content/model_painting_best.pt'

**Normalizar (redimensionar a 512×512) todas las categorías**


In [None]:
import os
import cv2
from glob import glob
from tqdm import tqdm
import shutil

# Input and output roots: every `<phase>/<class>/` folder under root_input is mirrored into root_output.
# NOTE(review): the feature-extraction and training cells below read from `dataset/`, not `dataset_resized/`.
root_input = '/content/drive/My Drive/StyleMatcher/dataset'
root_output = '/content/drive/My Drive/StyleMatcher/dataset_resized'
# (width, height) passed to cv2.resize; images are stretched, the aspect ratio is not preserved.
target_size = (512, 512)

def resize_and_save(input_path, output_path, size=None):
    """Resize every readable image in `input_path` and write it to `output_path`.

    Args:
        input_path: folder whose direct children are read with ``cv2.imread``;
            entries OpenCV cannot decode (non-images, sub-folders) are skipped.
        output_path: destination folder (created if missing). Files keep their
            original basename, so existing files with that name are overwritten.
        size: ``(width, height)`` for ``cv2.resize``. ``None`` uses the module-level
            ``target_size`` (looked up at call time), so existing calls are unchanged.
    """
    if size is None:
        size = target_size
    os.makedirs(output_path, exist_ok=True)
    img_paths = glob(os.path.join(input_path, '*'))

    for img_path in tqdm(img_paths, desc=f'Redimensionando {input_path}'):
        try:
            img = cv2.imread(img_path)
            if img is None:
                continue
            resized = cv2.resize(img, size)
            out_path = os.path.join(output_path, os.path.basename(img_path))
            # cv2.imwrite reports many failures by returning False rather than raising.
            if not cv2.imwrite(out_path, resized):
                print(f"Error con {img_path}: no se pudo escribir {out_path}")
        except Exception as e:
            # Best-effort batch job: report the offending file and keep going.
            print(f"Error con {img_path}: {e}")

def process_dataset(root_input, root_output, phases=('train', 'val')):
    """Mirror ``root_input/<phase>/<class>/`` into ``root_output`` with resized images.

    Args:
        root_input: dataset root containing one sub-folder per phase.
        root_output: destination root; the same ``<phase>/<class>`` layout is created.
        phases: phase sub-folders to process (default ``('train', 'val')``).

    Raises:
        FileNotFoundError: if a requested phase folder does not exist.
    """
    for phase in phases:
        phase_path = os.path.join(root_input, phase)
        for class_name in sorted(os.listdir(phase_path)):
            input_folder = os.path.join(phase_path, class_name)
            # Stray files (e.g. desktop.ini, .DS_Store) are not classes; without this
            # check they would produce empty output folders named after the file.
            if not os.path.isdir(input_folder):
                continue
            output_folder = os.path.join(root_output, phase, class_name)
            resize_and_save(input_folder, output_folder)

# Run the resize over every class folder of train/ and val/ (slow on Drive: one read + write per image).
process_dataset(root_input, root_output)

Redimensionando /content/drive/My Drive/StyleMatcher/dataset/train/anime:  62%|██████▏   | 590/950 [15:27<09:49,  1.64s/it]

# Imports y configuración inicial

In [None]:
import os
import cv2
import torch
import numpy as np
from glob import glob
from tqdm import tqdm
from torchvision import models, transforms
from facenet_pytorch import InceptionResnetV1
import mediapipe as mp

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")  # GPU when visible, else CPU


# Modelo de estilo con VGG19 multicapa

In [None]:
class VGG19StyleExtractor(torch.nn.Module):
    """Multi-layer VGG19 style descriptor.

    Runs the input through VGG19 ``features`` up to the deepest requested layer and,
    at each requested index, global-average-pools the activation to one value per
    channel. With the default indices (conv1_1, conv2_1, conv3_1, conv4_1, conv5_1)
    the result has ``64 + 128 + 256 + 512 + 512 = 1472`` values per image, the size
    StylizerDatasetV3 / StylizerUNet expect for the style vector.

    All intermediate ReLU / MaxPool layers are applied. Stacking only the selected
    conv modules back-to-back does not reproduce VGG19's activations. Style vectors
    saved with such a stack are not comparable to these and should be re-extracted.

    Args:
        layer_ids: indices into ``vgg19().features`` to tap (any iterable of ints).
    """

    def __init__(self, layer_ids=(0, 5, 10, 19, 28)):
        super().__init__()
        self.layer_ids = frozenset(layer_ids)
        vgg = models.vgg19(pretrained=True).features
        # Keep the whole prefix so every activation/pooling between taps is applied.
        self.layers = vgg[:max(self.layer_ids) + 1]
        for param in self.parameters():
            param.requires_grad = False
        self.eval()
        self.to(device)

    def forward(self, x):
        """Return a ``(batch, total_channels)`` tensor of pooled activations, in layer order."""
        features = []
        for idx, layer in enumerate(self.layers):
            x = layer(x)
            if idx in self.layer_ids:
                features.append(torch.nn.functional.adaptive_avg_pool2d(x, (1, 1)).flatten(1))
        return torch.cat(features, dim=1)

# Instantiate the model
vgg_model = VGG19StyleExtractor().to(device)

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth
100%|██████████| 548M/548M [00:02<00:00, 194MB/s]


# Modelo de identidad con FaceNet

In [None]:
# FaceNet (InceptionResnetV1, VGGFace2 weights) on the selected device, in inference mode.
identity_model = InceptionResnetV1(pretrained='vggface2').to(device).eval()


  0%|          | 0.00/107M [00:00<?, ?B/s]

# Detector de pose con MediaPipe

In [None]:
# MediaPipe Pose; static_image_mode=True treats each process() call as an independent still image
# (no tracking between calls), which matches the per-file extraction in extract_all_features().
mp_pose = mp.solutions.pose
pose_estimator = mp_pose.Pose(static_image_mode=True)

# Transforms para VGG y FaceNet

In [None]:
# ImageNet statistics that torchvision's pretrained VGG19 weights expect.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# VGG19 input: HxWx3 uint8 RGB array -> normalized 3x224x224 float tensor.
vgg_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # Without this, VGG sees an input distribution it was never trained on and the
    # style activations are skewed.
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# FaceNet input: HxWx3 uint8 RGB array -> 3x160x160 tensor rescaled from [0, 1] to [-1, 1].
facenet_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((160, 160)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

# Función para extraer todos los vectores de una imagen

In [None]:
def extract_all_features(img_path):
    """Compute the style, identity and pose descriptors of one image file.

    Args:
        img_path: path readable by ``cv2.imread``.

    Returns:
        ``(style_vec, identity_vec, pose_map)`` where
        - ``style_vec`` is the flattened numpy output of ``vgg_model``;
        - ``identity_vec`` is the flattened numpy output of ``identity_model`` (FaceNet);
        - ``pose_map`` is a 256x256 ``uint8`` image with 255 at each MediaPipe landmark
          that falls inside the frame and 0 elsewhere (all zeros if no person is found).
        Returns ``(None, None, None)`` when OpenCV cannot read the file.
    """
    bgr = cv2.imread(img_path)
    if bgr is None:
        return None, None, None
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    vgg_input = vgg_transform(rgb).unsqueeze(0).to(device)
    face_input = facenet_transform(rgb).unsqueeze(0).to(device)
    with torch.no_grad():
        style_vec = vgg_model(vgg_input).cpu().numpy().flatten()
        identity_vec = identity_model(face_input).cpu().numpy().flatten()

    side = 256
    pose_map = np.zeros((side, side), dtype=np.uint8)
    detection = pose_estimator.process(rgb)
    landmarks = detection.pose_landmarks.landmark if detection.pose_landmarks else []
    for lm in landmarks:
        # Normalized coordinates -> pixel indices; int() truncates toward zero,
        # so slightly negative coordinates land on pixel 0.
        col, row = int(lm.x * side), int(lm.y * side)
        if 0 <= col < side and 0 <= row < side:
            pose_map[row, col] = 255
    return style_vec, identity_vec, pose_map

# Función para procesar una carpeta completa

In [None]:
def process_folder(image_folder, style_out, input_out):
    """Extract features for every file in ``image_folder`` and save them to disk.

    For each image ``<name>.<ext>`` it writes:
    - ``style_out/<name>_style.npy``     (style vector)
    - ``input_out/<name>_identity.npy``  (identity vector)
    - ``input_out/<name>_pose.png``      (pose map)
    Files OpenCV cannot read are skipped. Both output folders are created if missing.
    NOTE(review): images sharing a stem (e.g. a.jpg and a.png) write to the same outputs.
    """
    for folder in (style_out, input_out):
        os.makedirs(folder, exist_ok=True)

    for img_path in tqdm(glob(os.path.join(image_folder, '*')), desc=f"Procesando {image_folder}"):
        stem, _ = os.path.splitext(os.path.basename(img_path))
        style_vec, identity_vec, pose_map = extract_all_features(img_path)
        if style_vec is None:
            continue
        np.save(os.path.join(style_out, f"{stem}_style.npy"), style_vec)
        np.save(os.path.join(input_out, f"{stem}_identity.npy"), identity_vec)
        cv2.imwrite(os.path.join(input_out, f"{stem}_pose.png"), pose_map)

# Ejecutar el extractor sobre tus datasets

In [None]:
# Per category: source images ("input"), destination of the style vectors ("style"),
# and destination of the identity vectors + pose maps ("others"); see process_folder().
datasets = {
    "anime": {
        "input": "/content/drive/My Drive/StyleMatcher/dataset/train/anime",
        "style": "/content/drive/My Drive/StyleMatcher/dataset_final/inputs/anime_styles",
        "others": "/content/drive/My Drive/StyleMatcher/dataset_final/inputs/anime_inputs",
    },
    "painting": {
        "input": "/content/drive/My Drive/StyleMatcher/dataset/train/painting",
        "style": "/content/drive/My Drive/StyleMatcher/dataset_final/inputs/painting_styles",
        "others": "/content/drive/My Drive/StyleMatcher/dataset_final/inputs/painting_inputs",
    }
}

# Extract and save features for every category.
for key, paths in datasets.items():
    process_folder(paths["input"], paths["style"], paths["others"])

Procesando /content/drive/My Drive/StyleMatcher/dataset/train/anime: 100%|██████████| 950/950 [15:11<00:00,  1.04it/s]
Procesando /content/drive/My Drive/StyleMatcher/dataset/train/painting: 100%|██████████| 750/750 [11:33<00:00,  1.08it/s]


In [None]:
# NOTE(review): `extract_and_save_pose` is not defined in the cells above; this call raised
# NameError in the recorded run. Define or import it before running the pose cells.
extract_and_save_pose(
    image_path='/content/drive/My Drive/StyleMatcher/input/pose 1.jpg',
    save_image_path='/content/drive/My Drive/StyleMatcher/input/pose1_pose.jpg',
    save_json_path='/content/drive/My Drive/StyleMatcher/input/pose1_pose.json'
)

NameError: name 'extract_and_save_pose' is not defined

In [None]:
# NOTE(review): `extract_and_save_pose` is not defined in the cells above (the pose 1 call raised NameError).
extract_and_save_pose(
    image_path='/content/drive/My Drive/StyleMatcher/input/pose 2.jpg',
    save_image_path='/content/drive/My Drive/StyleMatcher/input/pose2_pose.jpg',
    save_json_path='/content/drive/My Drive/StyleMatcher/input/pose2_pose.json'
)

In [None]:
# NOTE(review): `extract_and_save_pose` is not defined in the cells above (the pose 1 call raised NameError).
extract_and_save_pose(
    image_path='/content/drive/My Drive/StyleMatcher/input/pose 3.jpg',
    save_image_path='/content/drive/My Drive/StyleMatcher/input/pose3_pose.jpg',
    save_json_path='/content/drive/My Drive/StyleMatcher/input/pose3_pose.json'
)

In [None]:
# NOTE(review): `extract_and_save_pose` is not defined in the cells above (the pose 1 call raised NameError).
extract_and_save_pose(
    image_path='/content/drive/My Drive/StyleMatcher/input/pose 4.jpg',
    save_image_path='/content/drive/My Drive/StyleMatcher/input/pose4_pose.jpg',
    save_json_path='/content/drive/My Drive/StyleMatcher/input/pose4_pose.json'
)

In [None]:
# NOTE(review): `extract_and_save_pose` is not defined in the cells above (the pose 1 call raised NameError).
extract_and_save_pose(
    image_path='/content/drive/My Drive/StyleMatcher/input/pose 5.jpg',
    save_image_path='/content/drive/My Drive/StyleMatcher/input/pose5_pose.jpg',
    save_json_path='/content/drive/My Drive/StyleMatcher/input/pose5_pose.json'
)

In [None]:
# NOTE(review): `extract_and_save_pose` is not defined in the cells above (the pose 1 call raised NameError).
extract_and_save_pose(
    image_path='/content/drive/My Drive/StyleMatcher/input/pose 6.jpg',
    save_image_path='/content/drive/My Drive/StyleMatcher/input/pose6_pose.jpg',
    save_json_path='/content/drive/My Drive/StyleMatcher/input/pose6_pose.json'
)

In [None]:
# NOTE(review): `extract_and_save_pose` is not defined in the cells above (the pose 1 call raised NameError).
extract_and_save_pose(
    image_path='/content/drive/My Drive/StyleMatcher/input/pose 7.jpg',
    save_image_path='/content/drive/My Drive/StyleMatcher/input/pose7_pose.jpg',
    save_json_path='/content/drive/My Drive/StyleMatcher/input/pose7_pose.json'
)

In [None]:
# NOTE(review): `extract_and_save_pose` is not defined in the cells above (the pose 1 call raised NameError).
extract_and_save_pose(
    image_path='/content/drive/My Drive/StyleMatcher/input/pose 8.jpg',
    save_image_path='/content/drive/My Drive/StyleMatcher/input/pose8_pose.jpg',
    save_json_path='/content/drive/My Drive/StyleMatcher/input/pose8_pose.json'
)

In [None]:
# NOTE(review): `extract_and_save_pose` is not defined in the cells above (the pose 1 call raised NameError).
extract_and_save_pose(
    image_path='/content/drive/My Drive/StyleMatcher/input/pose 9.jpg',
    save_image_path='/content/drive/My Drive/StyleMatcher/input/pose9_pose.jpg',
    save_json_path='/content/drive/My Drive/StyleMatcher/input/pose9_pose.json'
)

In [None]:
# NOTE(review): `extract_and_save_pose` is not defined in the cells above (the pose 1 call raised NameError).
extract_and_save_pose(
    image_path='/content/drive/My Drive/StyleMatcher/input/pose 10.jpg',
    save_image_path='/content/drive/My Drive/StyleMatcher/input/pose10_pose.jpg',
    save_json_path='/content/drive/My Drive/StyleMatcher/input/pose10_pose.json'
)

In [None]:
# NOTE(review): `extract_style_features` is not defined in the cells above. Unlike the estilo2-estilo6
# cells, this call passes layer=0 and saves to test_01_style.npy rather than estilo1_style.npy; confirm intended.
extract_style_features(
    image_path='/content/drive/My Drive/StyleMatcher/input/estilo1.jpg',
    save_path='/content/drive/My Drive/StyleMatcher/input/test_01_style.npy',
    layer=0
)

**Debería** aparecer un archivo `.npy` que contiene las matrices de Gram para las capas seleccionadas. Esta será tu “huella digital” del estilo, y luego se usará como guía para pintar la imagen generada.

**Recomendaciones:**
Para cada estilo (anime, cartoon, pintura), selecciona unas 5–10 imágenes base representativas.

Puedes extraer sus estilos y tener un banco de estilos preprocesado.

Esto será útil si luego quieres permitir seleccionar estilo desde una galería sin tener que volver a calcularlo.





**Después de mejorar el código ya se puede seguir con el paso 3.**

In [None]:
# NOTE(review): `extract_style_features` is not defined in the cells above. `style_features` is
# overwritten by the next style cell; presumably the file written to save_path is the lasting output.
style_features = extract_style_features(
    image_path='/content/drive/My Drive/StyleMatcher/input/estilo2.jpg',
    save_path='/content/drive/My Drive/StyleMatcher/input/estilo2_style.npy'
)


In [None]:
# NOTE(review): `extract_style_features` is not defined in the cells above; `style_features` is
# overwritten by the next style cell.
style_features = extract_style_features(
    image_path='/content/drive/My Drive/StyleMatcher/input/estilo3.jpg',
    save_path='/content/drive/My Drive/StyleMatcher/input/estilo3_style.npy'
)


In [None]:
# NOTE(review): `extract_style_features` is not defined in the cells above; `style_features` is
# overwritten by the next style cell.
style_features = extract_style_features(
    image_path='/content/drive/My Drive/StyleMatcher/input/estilo4.jpg',
    save_path='/content/drive/My Drive/StyleMatcher/input/estilo4_style.npy'
)


In [None]:
# NOTE(review): `extract_style_features` is not defined in the cells above; `style_features` is
# overwritten by the next style cell.
style_features = extract_style_features(
    image_path='/content/drive/My Drive/StyleMatcher/input/estilo5.jpg',
    save_path='/content/drive/My Drive/StyleMatcher/input/estilo5_style.npy'
)


In [None]:
# NOTE(review): `extract_style_features` is not defined in the cells above; only the last
# style cell's result remains in `style_features`.
style_features = extract_style_features(
    image_path='/content/drive/My Drive/StyleMatcher/input/estilo6.jpg',
    save_path='/content/drive/My Drive/StyleMatcher/input/estilo6_style.npy'
)


# StylizerDatasetV3

# Importaciones

In [None]:
import os
import torch
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import torch
import torch.nn as nn
import torch.nn.functional as F

# Definición del Dataset

In [None]:
class StylizerDatasetV3(Dataset):
    """Training samples pairing an image with its precomputed pose, identity and style inputs.

    A sample ``<base>`` is kept only if all of these exist:
    - an image ``image_dir/<base>.jpg`` or ``<base>.png`` (extension matched case-insensitively)
    - ``input_dir/<base>_pose.png`` and ``input_dir/<base>_identity.npy``
    - ``style_dir/<base>_style.npy``

    Each item is a dict:
    - ``"image"``: ``self.transform(RGB image)`` (default: Resize 256x256 + ToTensor, values in [0, 1])
    - ``"pose"``: 1x256x256 float tensor (grayscale pose map, values in [0, 1])
    - ``"identity"``: first 512 values of the identity vector, float32
    - ``"style"``: first 1472 values of the style vector, float32
    """

    IMAGE_EXTENSIONS = ('.jpg', '.png')

    def __init__(self, image_dir, style_dir, input_dir, transform=None):
        self.image_dir = image_dir
        self.style_dir = style_dir
        self.input_dir = input_dir
        self.transform = transform if transform else transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor()
        ])
        self.pose_transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor()
        ])
        self.samples = self._gather_valid_samples()

    def _gather_valid_samples(self):
        """Return the (sorted) base names that have an image and all three feature files.

        Also fills ``self._image_files`` (base -> actual image filename) so that
        ``__getitem__`` opens exactly the file validated here rather than re-guessing
        the extension (which failed for e.g. ``x.JPG`` on case-sensitive filesystems).
        A base present with both extensions is kept once.
        """
        self._image_files = {}
        valid = []
        for img_name in sorted(os.listdir(self.image_dir)):
            base, ext = os.path.splitext(img_name)
            if ext.lower() not in self.IMAGE_EXTENSIONS or base in self._image_files:
                continue
            pose_path = os.path.join(self.input_dir, f"{base}_pose.png")
            id_path = os.path.join(self.input_dir, f"{base}_identity.npy")
            style_path = os.path.join(self.style_dir, f"{base}_style.npy")
            if os.path.exists(pose_path) and os.path.exists(id_path) and os.path.exists(style_path):
                self._image_files[base] = img_name
                valid.append(base)
        return valid

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        base = self.samples[idx]

        # Image: the exact file recorded by _gather_valid_samples.
        img_path = os.path.join(self.image_dir, self._image_files[base])
        with Image.open(img_path) as img:
            image = self.transform(img.convert("RGB"))

        # Pose map: single channel, resized to the generator's 256x256 input.
        pose_path = os.path.join(self.input_dir, f"{base}_pose.png")
        with Image.open(pose_path) as pose_img:
            pose = self.pose_transform(pose_img.convert("L"))

        # Identity and style vectors
        identity = np.load(os.path.join(self.input_dir, f"{base}_identity.npy"))
        style = np.load(os.path.join(self.style_dir, f"{base}_style.npy"))

        # Truncate to the sizes StylizerUNet expects (id_dim=512, style_dim=1472).
        # NOTE(review): shorter vectors are not padded and would break default batching.
        identity = torch.tensor(identity[:512], dtype=torch.float32)
        style = torch.tensor(style[:1472], dtype=torch.float32)

        return {
            "image": image,
            "pose": pose,
            "identity": identity,
            "style": style
        }

# Generador StylizerUNet

In [None]:
class UNetBlock(nn.Module):
    """One U-Net stage: 4x4 stride-2 convolution + BatchNorm + activation.

    ``down=True`` halves the spatial size (Conv2d followed by LeakyReLU(0.2));
    ``down=False`` doubles it (ConvTranspose2d followed by ReLU). ``use_dropout``
    appends Dropout(0.5). The sub-module names ``block.0/1/2`` and ``block.dropout``
    are kept stable because saved state_dict keys depend on them.
    """

    def __init__(self, in_channels, out_channels, down=True, use_dropout=False):
        super().__init__()
        conv_type = nn.Conv2d if down else nn.ConvTranspose2d
        stage = [
            conv_type(in_channels, out_channels, 4, 2, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, inplace=True) if down else nn.ReLU(inplace=True),
        ]
        self.block = nn.Sequential(*stage)
        if use_dropout:
            self.block.add_module("dropout", nn.Dropout(0.5))

    def forward(self, x):
        return self.block(x)

class StylizerUNet(nn.Module):
    """U-Net generator: pose map in, image out, conditioned on identity + style vectors.

    The identity and style vectors are concatenated, projected to a 512x4x4 map and
    fused with the deepest encoder features; the decoder mirrors the encoder with
    skip connections, and the raw pose map is concatenated again before the last conv.
    Spatial comments below assume a 1x256x256 pose map (what StylizerDatasetV3 yields).

    NOTE(review): the output goes through tanh ([-1, 1]) while StylizerDatasetV3's default
    image transform yields targets in [0, 1]; confirm this range mismatch is intended.
    """
    def __init__(self, id_dim=512, style_dim=1472):
        super().__init__()
        self.emb_dim = id_dim + style_dim  # 1984

        self.embedding_expand = nn.Sequential(
            nn.Linear(self.emb_dim, 512 * 4 * 4),
            nn.ReLU(True)
        )

        # Encoder
        self.enc1 = UNetBlock(1, 64, down=True)    # 256 → 128
        self.enc2 = UNetBlock(64, 128, down=True)  # 128 → 64
        self.enc3 = UNetBlock(128, 256, down=True) # 64 → 32
        self.enc4 = UNetBlock(256, 512, down=True) # 32 → 16
        self.enc5 = UNetBlock(512, 512, down=True) # 16 → 8
        self.enc6 = UNetBlock(512, 512, down=True) # 8 → 4
        self.enc7 = UNetBlock(512, 512, down=True) # 4 → 2

        # Bottleneck (embedding + encoder)
        self.middle = nn.Sequential(
            nn.Conv2d(512 + 512, 512, 3, padding=1),
            nn.ReLU(True)
        )

        # Decoder (input channels = previous decoder output + matching encoder skip)
        self.dec1 = UNetBlock(512, 512, down=False, use_dropout=True)        # 2 → 4
        self.dec2 = UNetBlock(1024, 512, down=False, use_dropout=True)       # 4 → 8
        self.dec3 = UNetBlock(1024, 512, down=False, use_dropout=True)       # 8 → 16
        self.dec4 = UNetBlock(1024, 512, down=False)                         # 16 → 32
        self.dec5 = UNetBlock(768, 256, down=False)                          # 32 → 64
        self.dec6 = UNetBlock(384, 128, down=False)                          # 64 → 128
        self.dec7 = UNetBlock(192, 64, down=False)                           # 128 → 256
        self.final = nn.Conv2d(65, 3, kernel_size=3, padding=1)                      # 64 + 1 (pose) → 3 channels; 256 → 256 (size unchanged)
        self.tanh = nn.Tanh()

    def forward(self, pose_map, identity_vec, style_vec):
        """Generate a 3-channel image in [-1, 1].

        Args:
            pose_map: (B, 1, H, W) pose tensor (H = W = 256 assumed by the comments).
            identity_vec: (B, id_dim) identity vectors.
            style_vec: (B, style_dim) style vectors.
        """
        emb = torch.cat([identity_vec, style_vec], dim=1)
        emb = self.embedding_expand(emb).view(-1, 512, 4, 4)

        # Encoder
        e1 = self.enc1(pose_map)
        e2 = self.enc2(e1)
        e3 = self.enc3(e2)
        e4 = self.enc4(e3)
        e5 = self.enc5(e4)
        e6 = self.enc6(e5)
        e7 = self.enc7(e6)

        # Fusion
        emb_up = F.interpolate(emb, size=e7.shape[2:])  # nearest resize of the 4x4 embedding map to e7's size (2x2 for 256 input)
        bottleneck = self.middle(torch.cat([e7, emb_up], dim=1))

        # Decoder with skip connections
        d1 = self.dec1(bottleneck)
        d2 = self.dec2(torch.cat([d1, e6], dim=1))
        d3 = self.dec3(torch.cat([d2, e5], dim=1))
        d4 = self.dec4(torch.cat([d3, e4], dim=1))
        d5 = self.dec5(torch.cat([d4, e3], dim=1))
        d6 = self.dec6(torch.cat([d5, e2], dim=1))
        d7 = self.dec7(torch.cat([d6, e1], dim=1))

        # Resize pose_map to the decoder output size if they differ, then concatenate it as a 65th channel
        pose_resized = F.interpolate(pose_map, size=d7.shape[2:])
        out = self.final(torch.cat([d7, pose_resized], dim=1))
        return self.tanh(out)

Este modelo está listo para usarse en entrenamiento. Admite entradas:

pose_map: tensor de forma [B, 1, 256, 256]

identity_vec: tensor [B, 512]

style_vec: tensor [B, 1472] (5 capas de VGG19: 64 + 128 + 256 + 512 + 512 canales)

Y genera imágenes de salida [B, 3, 256, 256].

# ENTRENAMIENTO

# Importaciones y configuraciones

In [None]:
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from torchvision import models
from tqdm import tqdm


device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")  # GPU when visible, else CPU

# Cargar dataset y modelo

In [None]:
# Dataset
# NOTE(review): train_model() further down builds its own dataset, loader, model and optimizer,
# so these globals are not used by the training loop.
anime_dataset = StylizerDatasetV3(
    image_dir="/content/drive/My Drive/StyleMatcher/dataset/train/anime",
    style_dir="/content/drive/My Drive/StyleMatcher/dataset_final/inputs/anime_styles",
    input_dir="/content/drive/My Drive/StyleMatcher/dataset_final/inputs/anime_inputs"
)

loader = DataLoader(anime_dataset, batch_size=4, shuffle=True, num_workers=2)

# Model
model = StylizerUNet().to(device)

# Función de pérdida L1 + Perceptual

In [None]:
class PerceptualLoss(nn.Module):
    """L1 distance between VGG19 feature maps of two images, summed over selected layers.

    Args:
        layer_ids: indices into ``vgg19().features`` whose outputs are compared
            (default conv1_1, conv2_1, conv3_1). Any iterable of ints is accepted; it is
            materialized once, so generators work too.

    NOTE(review): inputs are fed to VGG unchanged. Pretrained VGG expects ImageNet-normalized
    RGB, so normalize ``input``/``target`` first if they are raw [0, 1] or tanh [-1, 1] images.
    """

    def __init__(self, layer_ids=(0, 5, 10)):
        super().__init__()
        # Tuple default (no shared mutable list) and a single materialization: the old code
        # called max() on the argument and later tested membership on it, which silently
        # matched nothing when a generator was passed.
        layer_ids = tuple(layer_ids)
        vgg = models.vgg19(pretrained=True).features[:max(layer_ids) + 1].eval()
        for param in vgg.parameters():
            param.requires_grad = False
        self.vgg = vgg.to(device)
        self.layers = layer_ids
        self.criterion = nn.L1Loss()

    def forward(self, input, target):
        """Return the summed per-layer L1 loss between the features of `input` and `target`."""
        loss = 0
        x, y = input, target
        for i, layer in enumerate(self.vgg):
            x = layer(x)
            y = layer(y)
            if i in self.layers:
                loss = loss + self.criterion(x, y)
        return loss

# Configurar optimizador y pérdida

In [None]:
# NOTE(review): train_model() further down creates its own optimizer (AdamW) and loss (SmoothL1Loss);
# none of these objects are used by it. PerceptualLoss() also loads a VGG19 here.
criterion_l1 = nn.L1Loss()
criterion_perceptual = PerceptualLoss()

optimizer = torch.optim.Adam(model.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Loop de entrenamiento

In [None]:
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
from tqdm import tqdm

# === Training configuration (tuned for an A100 GPU) ===
# Let cuDNN benchmark convolution algorithms; inputs are always 256x256, so the search pays off.
torch.backends.cudnn.benchmark = True

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
epochs, batch_size, lr = 30, 128, 1e-4  # batch_size: an A100 may fit more before running out of memory

def _run_training_epoch(model, loader, optimizer, criterion, desc):
    """Run one optimization pass over `loader`; return the per-sample mean loss."""
    model.train()
    total_loss, n_samples = 0.0, 0
    pbar = tqdm(loader, desc=desc)
    for batch in pbar:
        imgs = batch["image"].to(device, non_blocking=True)
        poses = batch["pose"].to(device, non_blocking=True)
        ids = batch["identity"].to(device, non_blocking=True)
        styles = batch["style"].to(device, non_blocking=True)

        optimizer.zero_grad()
        outputs = model(poses, ids, styles)

        # Safety net: match the target resolution if the generator output differs.
        if outputs.shape != imgs.shape:
            outputs = F.interpolate(outputs, size=imgs.shape[2:])

        loss = criterion(outputs, imgs)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        # Weight by batch size so a smaller final batch does not skew the epoch mean.
        total_loss += batch_loss * imgs.size(0)
        n_samples += imgs.size(0)
        pbar.set_postfix({"loss": batch_loss})
    return total_loss / n_samples


def train_model(category, image_dir, style_dir, input_dir, save_path,
                num_epochs=None, train_batch_size=None, learning_rate=None):
    """Train a StylizerUNet for one category and keep the checkpoint with the lowest epoch loss.

    Args:
        category: label used in progress and log messages.
        image_dir, style_dir, input_dir: forwarded to StylizerDatasetV3.
        save_path: destination of the best ``state_dict`` (overwritten on every improvement).
        num_epochs, train_batch_size, learning_rate: optional overrides; ``None`` uses the
            module-level ``epochs``, ``batch_size`` and ``lr`` (read at call time).

    Returns:
        The lowest average training loss reached.

    Raises:
        ValueError: if the dataset has no valid samples (e.g. wrong paths, or the
            _pose/_identity/_style files have not been extracted yet).

    NOTE(review): "best" is selected on the training loss; there is no validation split.
    """
    n_epochs = epochs if num_epochs is None else num_epochs
    bs = batch_size if train_batch_size is None else train_batch_size
    base_lr = lr if learning_rate is None else learning_rate

    print(f"\n🚀 Entrenando modelo para: {category.upper()} ({n_epochs} épocas)")

    dataset = StylizerDatasetV3(image_dir, style_dir, input_dir)
    if len(dataset) == 0:
        raise ValueError(
            f"No hay muestras válidas para '{category}' en {image_dir} "
            f"(¿faltan los archivos _pose/_identity/_style?)"
        )
    loader = DataLoader(dataset, batch_size=bs, shuffle=True,
                        num_workers=4, pin_memory=True)

    model = StylizerUNet(id_dim=512, style_dim=1472).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=base_lr, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)
    criterion = torch.nn.SmoothL1Loss()

    best_loss = float("inf")
    for epoch in range(n_epochs):
        avg_loss = _run_training_epoch(
            model, loader, optimizer, criterion,
            desc=f"[{category}] Época {epoch+1}/{n_epochs}",
        )
        scheduler.step()

        # Keep the best model seen so far
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(model.state_dict(), save_path)
            print(f"💾 Nuevo mejor modelo guardado en {save_path} (loss={best_loss:.4f})")

    print(f"✅ Entrenamiento completo de {category}. Mejor loss: {best_loss:.4f}")
    return best_loss

# === Train the anime model, then the painting model ===
# Each run overwrites its checkpoint under "My Drive/StyleMatcher" whenever the epoch loss improves.
# NOTE(review): the checkpoint-copy cell near the top reads from "/content/drive/MyDrive/...";
# confirm both spellings resolve to the same Drive folder in your Colab session.
train_model(
    category="anime",
    image_dir="/content/drive/My Drive/StyleMatcher/dataset/train/anime",
    style_dir="/content/drive/My Drive/StyleMatcher/dataset_final/inputs/anime_styles",
    input_dir="/content/drive/My Drive/StyleMatcher/dataset_final/inputs/anime_inputs",
    save_path="/content/drive/My Drive/StyleMatcher/model_anime_best.pt"
)

train_model(
    category="painting",
    image_dir="/content/drive/My Drive/StyleMatcher/dataset/train/painting",
    style_dir="/content/drive/My Drive/StyleMatcher/dataset_final/inputs/painting_styles",
    input_dir="/content/drive/My Drive/StyleMatcher/dataset_final/inputs/painting_inputs",
    save_path="/content/drive/My Drive/StyleMatcher/model_painting_best.pt"
)



🚀 Entrenando modelo para: ANIME (30 épocas)


[anime] Época 1/30: 100%|██████████| 8/8 [09:04<00:00, 68.11s/it, loss=0.142]


💾 Nuevo mejor modelo guardado en /content/drive/My Drive/StyleMatcher/model_anime_best.pt (loss=0.2210)


[anime] Época 2/30: 100%|██████████| 8/8 [00:08<00:00,  1.07s/it, loss=0.0783]


💾 Nuevo mejor modelo guardado en /content/drive/My Drive/StyleMatcher/model_anime_best.pt (loss=0.1003)


[anime] Época 3/30: 100%|██████████| 8/8 [00:13<00:00,  1.68s/it, loss=0.0531]


💾 Nuevo mejor modelo guardado en /content/drive/My Drive/StyleMatcher/model_anime_best.pt (loss=0.0603)


[anime] Época 4/30: 100%|██████████| 8/8 [00:08<00:00,  1.05s/it, loss=0.0382]


💾 Nuevo mejor modelo guardado en /content/drive/My Drive/StyleMatcher/model_anime_best.pt (loss=0.0429)


[anime] Época 5/30: 100%|██████████| 8/8 [00:08<00:00,  1.06s/it, loss=0.0318]


💾 Nuevo mejor modelo guardado en /content/drive/My Drive/StyleMatcher/model_anime_best.pt (loss=0.0337)


[anime] Época 6/30: 100%|██████████| 8/8 [00:08<00:00,  1.05s/it, loss=0.0284]


💾 Nuevo mejor modelo guardado en /content/drive/My Drive/StyleMatcher/model_anime_best.pt (loss=0.0285)


[anime] Época 7/30: 100%|██████████| 8/8 [00:13<00:00,  1.65s/it, loss=0.0225]


💾 Nuevo mejor modelo guardado en /content/drive/My Drive/StyleMatcher/model_anime_best.pt (loss=0.0248)


[anime] Época 8/30: 100%|██████████| 8/8 [00:09<00:00,  1.22s/it, loss=0.02]


💾 Nuevo mejor modelo guardado en /content/drive/My Drive/StyleMatcher/model_anime_best.pt (loss=0.0226)


[anime] Época 9/30: 100%|██████████| 8/8 [00:08<00:00,  1.07s/it, loss=0.0197]


💾 Nuevo mejor modelo guardado en /content/drive/My Drive/StyleMatcher/model_anime_best.pt (loss=0.0209)


[anime] Época 10/30: 100%|██████████| 8/8 [00:08<00:00,  1.11s/it, loss=0.0189]


💾 Nuevo mejor modelo guardado en /content/drive/My Drive/StyleMatcher/model_anime_best.pt (loss=0.0196)


[anime] Época 11/30: 100%|██████████| 8/8 [00:08<00:00,  1.06s/it, loss=0.0187]


💾 Nuevo mejor modelo guardado en /content/drive/My Drive/StyleMatcher/model_anime_best.pt (loss=0.0186)


[anime] Época 12/30: 100%|██████████| 8/8 [00:08<00:00,  1.05s/it, loss=0.0178]


💾 Nuevo mejor modelo guardado en /content/drive/My Drive/StyleMatcher/model_anime_best.pt (loss=0.0176)


[anime] Época 13/30: 100%|██████████| 8/8 [00:08<00:00,  1.07s/it, loss=0.0172]


💾 Nuevo mejor modelo guardado en /content/drive/My Drive/StyleMatcher/model_anime_best.pt (loss=0.0170)


[anime] Época 14/30: 100%|██████████| 8/8 [00:08<00:00,  1.05s/it, loss=0.0163]


💾 Nuevo mejor modelo guardado en /content/drive/My Drive/StyleMatcher/model_anime_best.pt (loss=0.0164)


[anime] Época 15/30: 100%|██████████| 8/8 [00:08<00:00,  1.07s/it, loss=0.0153]


💾 Nuevo mejor modelo guardado en /content/drive/My Drive/StyleMatcher/model_anime_best.pt (loss=0.0158)


[anime] Época 16/30: 100%|██████████| 8/8 [00:07<00:00,  1.01it/s, loss=0.0185]


💾 Nuevo mejor modelo guardado en /content/drive/My Drive/StyleMatcher/model_anime_best.pt (loss=0.0157)


[anime] Época 17/30: 100%|██████████| 8/8 [00:08<00:00,  1.02s/it, loss=0.0154]


💾 Nuevo mejor modelo guardado en /content/drive/My Drive/StyleMatcher/model_anime_best.pt (loss=0.0151)


[anime] Época 18/30: 100%|██████████| 8/8 [00:08<00:00,  1.09s/it, loss=0.0154]


💾 Nuevo mejor modelo guardado en /content/drive/My Drive/StyleMatcher/model_anime_best.pt (loss=0.0148)


[anime] Época 19/30: 100%|██████████| 8/8 [00:07<00:00,  1.03it/s, loss=0.0141]


💾 Nuevo mejor modelo guardado en /content/drive/My Drive/StyleMatcher/model_anime_best.pt (loss=0.0144)


[anime] Época 20/30: 100%|██████████| 8/8 [00:08<00:00,  1.06s/it, loss=0.0142]


💾 Nuevo mejor modelo guardado en /content/drive/My Drive/StyleMatcher/model_anime_best.pt (loss=0.0143)


[anime] Época 21/30: 100%|██████████| 8/8 [00:07<00:00,  1.06it/s, loss=0.0145]


💾 Nuevo mejor modelo guardado en /content/drive/My Drive/StyleMatcher/model_anime_best.pt (loss=0.0142)


[anime] Época 22/30: 100%|██████████| 8/8 [00:09<00:00,  1.16s/it, loss=0.015]


💾 Nuevo mejor modelo guardado en /content/drive/My Drive/StyleMatcher/model_anime_best.pt (loss=0.0140)


[anime] Época 23/30: 100%|██████████| 8/8 [00:07<00:00,  1.04it/s, loss=0.0138]


💾 Nuevo mejor modelo guardado en /content/drive/My Drive/StyleMatcher/model_anime_best.pt (loss=0.0138)


[anime] Época 24/30: 100%|██████████| 8/8 [00:08<00:00,  1.01s/it, loss=0.0151]
[anime] Época 25/30: 100%|██████████| 8/8 [00:12<00:00,  1.56s/it, loss=0.0148]


💾 Nuevo mejor modelo guardado en /content/drive/My Drive/StyleMatcher/model_anime_best.pt (loss=0.0137)


[anime] Época 26/30: 100%|██████████| 8/8 [00:07<00:00,  1.05it/s, loss=0.0154]


💾 Nuevo mejor modelo guardado en /content/drive/My Drive/StyleMatcher/model_anime_best.pt (loss=0.0136)


[anime] Época 27/30: 100%|██████████| 8/8 [00:08<00:00,  1.07s/it, loss=0.0161]
[anime] Época 28/30: 100%|██████████| 8/8 [00:12<00:00,  1.57s/it, loss=0.0124]


💾 Nuevo mejor modelo guardado en /content/drive/My Drive/StyleMatcher/model_anime_best.pt (loss=0.0134)


[anime] Época 29/30: 100%|██████████| 8/8 [00:07<00:00,  1.04it/s, loss=0.0136]
[anime] Época 30/30: 100%|██████████| 8/8 [00:07<00:00,  1.08it/s, loss=0.0135]


✅ Entrenamiento completo de anime. Mejor loss: 0.0134

🚀 Entrenando modelo para: PAINTING (30 épocas)


[painting] Época 1/30: 100%|██████████| 6/6 [09:15<00:00, 92.55s/it, loss=0.0869]


💾 Nuevo mejor modelo guardado en /content/drive/My Drive/StyleMatcher/model_painting_best.pt (loss=0.1103)


[painting] Época 2/30: 100%|██████████| 6/6 [00:07<00:00,  1.17s/it, loss=0.0492]


💾 Nuevo mejor modelo guardado en /content/drive/My Drive/StyleMatcher/model_painting_best.pt (loss=0.0628)


[painting] Época 3/30: 100%|██████████| 6/6 [00:06<00:00,  1.11s/it, loss=0.0436]


💾 Nuevo mejor modelo guardado en /content/drive/My Drive/StyleMatcher/model_painting_best.pt (loss=0.0441)


[painting] Época 4/30: 100%|██████████| 6/6 [00:07<00:00,  1.22s/it, loss=0.0399]


💾 Nuevo mejor modelo guardado en /content/drive/My Drive/StyleMatcher/model_painting_best.pt (loss=0.0399)


[painting] Época 5/30: 100%|██████████| 6/6 [00:06<00:00,  1.05s/it, loss=0.0355]


💾 Nuevo mejor modelo guardado en /content/drive/My Drive/StyleMatcher/model_painting_best.pt (loss=0.0376)


[painting] Época 6/30: 100%|██████████| 6/6 [00:07<00:00,  1.17s/it, loss=0.0345]


💾 Nuevo mejor modelo guardado en /content/drive/My Drive/StyleMatcher/model_painting_best.pt (loss=0.0344)


[painting] Época 7/30: 100%|██████████| 6/6 [00:07<00:00,  1.20s/it, loss=0.0306]


💾 Nuevo mejor modelo guardado en /content/drive/My Drive/StyleMatcher/model_painting_best.pt (loss=0.0312)


[painting] Época 8/30: 100%|██████████| 6/6 [00:06<00:00,  1.16s/it, loss=0.0284]


💾 Nuevo mejor modelo guardado en /content/drive/My Drive/StyleMatcher/model_painting_best.pt (loss=0.0291)


[painting] Época 9/30: 100%|██████████| 6/6 [00:07<00:00,  1.17s/it, loss=0.0254]


💾 Nuevo mejor modelo guardado en /content/drive/My Drive/StyleMatcher/model_painting_best.pt (loss=0.0273)


[painting] Época 10/30: 100%|██████████| 6/6 [00:07<00:00,  1.17s/it, loss=0.0256]


💾 Nuevo mejor modelo guardado en /content/drive/My Drive/StyleMatcher/model_painting_best.pt (loss=0.0261)


[painting] Época 11/30: 100%|██████████| 6/6 [00:06<00:00,  1.15s/it, loss=0.0246]


💾 Nuevo mejor modelo guardado en /content/drive/My Drive/StyleMatcher/model_painting_best.pt (loss=0.0248)


[painting] Época 12/30: 100%|██████████| 6/6 [00:07<00:00,  1.21s/it, loss=0.0254]


💾 Nuevo mejor modelo guardado en /content/drive/My Drive/StyleMatcher/model_painting_best.pt (loss=0.0237)


[painting] Época 13/30: 100%|██████████| 6/6 [00:07<00:00,  1.18s/it, loss=0.0232]


💾 Nuevo mejor modelo guardado en /content/drive/My Drive/StyleMatcher/model_painting_best.pt (loss=0.0229)


[painting] Época 14/30: 100%|██████████| 6/6 [00:06<00:00,  1.14s/it, loss=0.0217]


💾 Nuevo mejor modelo guardado en /content/drive/My Drive/StyleMatcher/model_painting_best.pt (loss=0.0222)


[painting] Época 15/30: 100%|██████████| 6/6 [00:07<00:00,  1.18s/it, loss=0.0227]


💾 Nuevo mejor modelo guardado en /content/drive/My Drive/StyleMatcher/model_painting_best.pt (loss=0.0217)


[painting] Época 16/30: 100%|██████████| 6/6 [00:07<00:00,  1.18s/it, loss=0.0216]


💾 Nuevo mejor modelo guardado en /content/drive/My Drive/StyleMatcher/model_painting_best.pt (loss=0.0211)


[painting] Época 17/30: 100%|██████████| 6/6 [00:07<00:00,  1.18s/it, loss=0.0211]


💾 Nuevo mejor modelo guardado en /content/drive/My Drive/StyleMatcher/model_painting_best.pt (loss=0.0207)


[painting] Época 18/30: 100%|██████████| 6/6 [00:07<00:00,  1.18s/it, loss=0.0198]


💾 Nuevo mejor modelo guardado en /content/drive/My Drive/StyleMatcher/model_painting_best.pt (loss=0.0202)


[painting] Época 19/30: 100%|██████████| 6/6 [00:07<00:00,  1.19s/it, loss=0.0211]


💾 Nuevo mejor modelo guardado en /content/drive/My Drive/StyleMatcher/model_painting_best.pt (loss=0.0201)


[painting] Época 20/30: 100%|██████████| 6/6 [00:06<00:00,  1.16s/it, loss=0.0195]


💾 Nuevo mejor modelo guardado en /content/drive/My Drive/StyleMatcher/model_painting_best.pt (loss=0.0197)


[painting] Época 21/30: 100%|██████████| 6/6 [00:07<00:00,  1.17s/it, loss=0.0192]


💾 Nuevo mejor modelo guardado en /content/drive/My Drive/StyleMatcher/model_painting_best.pt (loss=0.0197)


[painting] Época 22/30: 100%|██████████| 6/6 [00:07<00:00,  1.19s/it, loss=0.0197]


💾 Nuevo mejor modelo guardado en /content/drive/My Drive/StyleMatcher/model_painting_best.pt (loss=0.0194)


[painting] Época 23/30: 100%|██████████| 6/6 [00:07<00:00,  1.19s/it, loss=0.0177]


💾 Nuevo mejor modelo guardado en /content/drive/My Drive/StyleMatcher/model_painting_best.pt (loss=0.0192)


[painting] Época 24/30: 100%|██████████| 6/6 [00:07<00:00,  1.17s/it, loss=0.0208]


💾 Nuevo mejor modelo guardado en /content/drive/My Drive/StyleMatcher/model_painting_best.pt (loss=0.0191)


[painting] Época 25/30: 100%|██████████| 6/6 [00:06<00:00,  1.17s/it, loss=0.0206]


💾 Nuevo mejor modelo guardado en /content/drive/My Drive/StyleMatcher/model_painting_best.pt (loss=0.0190)


[painting] Época 26/30: 100%|██████████| 6/6 [00:07<00:00,  1.19s/it, loss=0.019]
[painting] Época 27/30: 100%|██████████| 6/6 [00:06<00:00,  1.03s/it, loss=0.0188]


💾 Nuevo mejor modelo guardado en /content/drive/My Drive/StyleMatcher/model_painting_best.pt (loss=0.0189)


[painting] Época 28/30: 100%|██████████| 6/6 [00:07<00:00,  1.18s/it, loss=0.0189]


💾 Nuevo mejor modelo guardado en /content/drive/My Drive/StyleMatcher/model_painting_best.pt (loss=0.0188)


[painting] Época 29/30: 100%|██████████| 6/6 [00:07<00:00,  1.18s/it, loss=0.019]
[painting] Época 30/30: 100%|██████████| 6/6 [00:06<00:00,  1.04s/it, loss=0.0184]

✅ Entrenamiento completo de painting. Mejor loss: 0.0188





In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, models
from PIL import Image
import gradio as gr
import numpy as np
import torch.nn.functional as F

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Transformaciones
def image_loader(img, max_size=400, shape=None):
    """Convert a PIL image into a normalized (1, 3, H, W) tensor on `device`.

    Args:
        img: PIL image; converted to RGB if it is in another mode.
        max_size: upper bound for the longer side. Images already within the
            bound keep their size; larger ones are downscaled preserving the
            aspect ratio.
        shape: optional explicit output size passed straight to
            `transforms.Resize` (e.g. a (height, width) pair); overrides
            `max_size`.

    Returns:
        Tensor of shape (1, 3, H, W), normalized with the standard ImageNet
        mean/std below, on `device`.
    """
    if img.mode != 'RGB':
        img = img.convert("RGB")

    if shape is not None:
        size = shape
    else:
        width, height = img.size
        # transforms.Resize(int) scales the *shorter* side to that value, which
        # upscaled small images (shorter side -> longer side length) and let the
        # longer side of big images exceed max_size. Use an explicit (h, w).
        scale = min(1.0, max_size / max(width, height))
        size = (max(1, round(height * scale)), max(1, round(width * scale)))

    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])])
    image = transform(img).unsqueeze(0)
    return image.to(device)

def im_convert(tensor):
    """Turn a normalized (1, 3, H, W) tensor back into an HxWx3 image in [0, 1].

    Undoes the mean/std normalization applied by `image_loader` and clips the
    result so it can be displayed directly.
    """
    std = [0.229, 0.224, 0.225]
    mean = [0.485, 0.456, 0.406]
    chw = tensor.detach().cpu().clone().squeeze(0)
    hwc = chw.numpy().transpose(1, 2, 0)
    return np.clip(hwc * std + mean, 0, 1)

# Modelo VGG19
# Frozen VGG19 convolutional trunk, used only as a fixed feature extractor.
vgg = models.vgg19(pretrained=True).features.to(device).eval()
vgg.requires_grad_(False)

# Extracción de características
def get_features(image, model, layers=None):
    """Run `image` through `model` and collect the activations of chosen layers.

    Args:
        image: input tensor fed to the first child module of `model`.
        model: module whose `_modules` are applied in order (e.g. VGG19 `.features`).
        layers: mapping from child-module name (the index as a string for
            nn.Sequential) to the feature name it is stored under. Defaults to
            the VGG19 layers used by the style transfer below.

    Returns:
        dict mapping feature name -> activation tensor.
    """
    if layers is None:
        layers = {
            '0': 'conv1_1',
            '5': 'conv2_1',
            '10': 'conv3_1',
            '19': 'conv4_1',
            '21': 'conv4_2',
            '28': 'conv5_1'
        }
    features = {}
    remaining = set(layers)
    x = image
    for name, layer in model._modules.items():
        # Stop once every requested layer has been captured; the deeper layers
        # would only waste compute (this runs once per optimization step).
        if not remaining:
            break
        x = layer(x)
        if name in layers:
            features[layers[name]] = x
            remaining.discard(name)
    return features

def gram_matrix(tensor):
    """Channel-by-channel Gram matrix of a single (1, C, H, W) feature map.

    Returns a (C, C) tensor of inner products between flattened channels.
    The view requires a batch size of 1.
    """
    _, channels, height, width = tensor.size()
    flat = tensor.view(channels, height * width)
    return flat @ flat.t()

# Función principal de inferencia
def style_transfer(content_img, style_img, steps=200, lr=0.003,
                   content_weight=1e4, style_weight=1e2):
    """Optimize a copy of `content_img` so its VGG19 features carry the style of `style_img`.

    The content loss compares `conv4_2` activations with the content image.
    The style loss compares Gram matrices at conv1_1..conv5_1, weighted per
    layer and divided by the squared channel count.

    Args:
        content_img: PIL image giving the content and the output resolution.
        style_img: PIL image giving the style; resized to the content's size.
        steps: number of Adam iterations.
        lr: Adam learning rate.
        content_weight: weight of the content loss term.
        style_weight: weight of the style loss term.

    Returns:
        HxWx3 numpy array in [0, 1] (produced by `im_convert`).
    """
    content = image_loader(content_img)
    # Resize the style image to the content's (H, W) so both share one resolution.
    style = image_loader(style_img, shape=content.shape[-2:])

    style_weights = {
        'conv1_1': 1.0,
        'conv2_1': 0.75,
        'conv3_1': 0.2,
        'conv4_1': 0.2,
        'conv5_1': 0.2
    }

    content_features = get_features(content, vgg)
    style_features = get_features(style, vgg)
    # Only layers that contribute to the style loss need a Gram matrix.
    style_grams = {layer: gram_matrix(style_features[layer]) for layer in style_weights}

    # Move first, then mark as requiring grad, so `target` stays a leaf tensor
    # that the optimizer can update.
    target = content.clone().to(device).requires_grad_(True)
    optimizer = optim.Adam([target], lr=lr)

    for _ in range(steps):
        target_features = get_features(target, vgg)
        content_loss = torch.mean((target_features['conv4_2'] - content_features['conv4_2']) ** 2)

        style_loss = 0
        for layer, weight in style_weights.items():
            target_feature = target_features[layer]
            target_gram = gram_matrix(target_feature)
            layer_loss = weight * torch.mean((target_gram - style_grams[layer]) ** 2)
            style_loss += layer_loss / (target_feature.shape[1] ** 2)

        total_loss = content_weight * content_loss + style_weight * style_loss

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

    return im_convert(target)

# Interfaz Gradio
demo = gr.Interface(
    fn=style_transfer,
    inputs=[
        gr.Image(label="Imagen base (contenido)", type="pil"),
        gr.Image(label="Imagen estilo", type="pil"),
    ],
    outputs=gr.Image(label="Resultado estilizado"),
    title="🎨 Style Transfer con VGG19",
    description="Generación de imagen estilizada combinando contenido y estilo. No necesita entrenamiento.",
)
# debug=True: blocking launch that reports app errors in the cell output.
demo.launch(debug=True)


In [None]:
# Print the tensor shapes of the first anime example.
sample = anime_dataset[0]
for label, key in (("Image:", "image"), ("Pose:", "pose"), ("Identity:", "identity"), ("Style:", "style")):
    print(label, sample[key].shape)

In [None]:
# Peek at the first batch StylizerDatasetV3 yields for the anime split.
anime_probe_loader = DataLoader(
    StylizerDatasetV3(
        image_dir="/content/drive/My Drive/StyleMatcher/dataset/train/anime",
        style_dir="/content/drive/My Drive/StyleMatcher/dataset_final/inputs/anime_styles",
        input_dir="/content/drive/My Drive/StyleMatcher/dataset_final/inputs/anime_inputs"
    ),
    batch_size=1,
)
sample = next(iter(anime_probe_loader))

for label, key in (("🟢 image:", "image"), ("🟠 pose:", "pose"),
                   ("🔵 identity:", "identity"), ("🔴 style:", "style")):
    print(label, sample[key].shape)

In [None]:
# Obtener 1 batch del dataloader
dataset = StylizerDatasetV3(
    image_dir="/content/drive/My Drive/StyleMatcher/dataset/train/anime",
    style_dir="/content/drive/My Drive/StyleMatcher/dataset_final/inputs/anime_styles",
    input_dir="/content/drive/My Drive/StyleMatcher/dataset_final/inputs/anime_inputs"
)

from torch.utils.data import DataLoader
loader = DataLoader(dataset, batch_size=1)

model = StylizerUNet(id_dim=512, style_dim=1472).to("cuda" if torch.cuda.is_available() else "cpu")

# Run one forward pass and report the tensor shapes.
batch = next(iter(loader))
img, pose, id_vec, style_vec = (batch[key] for key in ("image", "pose", "identity", "style"))

for label, value in (("📦 image:", img), ("🟠 pose:", pose),
                     ("🔵 identity:", id_vec), ("🔴 style:", style_vec)):
    print(label, value.shape)

# Move the conditioning inputs to the model's device (img is only inspected, so it stays put).
device = "cuda" if torch.cuda.is_available() else "cpu"
pose, id_vec, style_vec = (t.to(device) for t in (pose, id_vec, style_vec))

with torch.no_grad():
    output = model(pose, id_vec, style_vec)
    print("✅ output:", output.shape)


In [None]:
# Compare the forward-pass result of the previous cell with the input image.
# The previous cell stores the prediction in `output`; `outputs` is never
# defined, so the old version raised NameError.
print("🎯 output.shape:", output.shape)
print("📦 image.shape:", img.shape)
print("✅ same shape:", output.shape == img.shape)
print("📊 output dtype:", output.dtype)
print("📊 image dtype:", img.dtype)
print("🧠 output device:", output.device)
print("🧠 image device:", img.device)


In [None]:
# Train the anime stylizer from scratch for 30 epochs, keeping the best weights on Drive.
anime_training_config = dict(
    dataset_name="anime",
    image_dir="/content/drive/My Drive/StyleMatcher/dataset/train/anime",
    style_dir="/content/drive/My Drive/StyleMatcher/dataset_final/inputs/anime_styles",
    input_dir="/content/drive/My Drive/StyleMatcher/dataset_final/inputs/anime_inputs",
    save_path="/content/drive/My Drive/StyleMatcher/model_anime_best.pt",
    checkpoint_path=None,
    epochs=30,
)
train_model(**anime_training_config)


# Ejemplo con 1 imagen

In [None]:
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torch.utils.data import Dataset # Ensure Dataset is imported
from torchvision import transforms     # Ensure transforms is imported
from PIL import Image                # Ensure Image is imported
import numpy as np                   # Ensure numpy is imported
import json                          # Ensure json is imported
import cv2                           # Ensure cv2 is imported

# (StylizerDataset definition remains the same as in the previous suggestion)
class StylizerDataset(Dataset):
    """Single-sample dataset used to overfit the stylizer on one example.

    For ``name_prefix`` N it reads:
      - ``{input_dir}/N.jpg``           source image
      - ``{input_dir}/N_pose.json``     JSON with a "pose" list of {"x", "y"} landmarks
      - ``{input_dir}/N_identity.npy``  identity embedding
      - ``{input_dir}/N_style.npy``     style vector
      - ``{target_dir}/N_target.jpg``   expected output image

    Raises:
        FileNotFoundError: at construction time if any of those files is missing.
    """

    def __init__(self, input_dir, target_dir, name_prefix='test_01'):
        import os  # local import: this cell does not import os itself

        self.input_dir = input_dir
        self.target_dir = target_dir
        self.name = name_prefix
        self.img_size = (256, 256)  # (height, width) for the images and the pose heatmap

        self.transform = transforms.Compose([
            transforms.Resize(self.img_size),
            transforms.ToTensor()
        ])

        # Validate eagerly: the files are otherwise only opened in __getitem__,
        # so a try/except FileNotFoundError around construction could never fire.
        missing = [p for p in self._paths().values() if not os.path.exists(p)]
        if missing:
            raise FileNotFoundError(f"Missing StylizerDataset files: {missing}")

    def _paths(self):
        """Return the path of every component of the configured sample."""
        base = f"{self.input_dir}/{self.name}"
        return {
            'image': f"{base}.jpg",
            'pose': f"{base}_pose.json",
            'identity': f"{base}_identity.npy",
            'style': f"{base}_style.npy",
            'target': f"{self.target_dir}/{self.name}_target.jpg",
        }

    @staticmethod
    def _load_rgb(path):
        """Open an image as RGB and release the underlying file handle."""
        with Image.open(path) as im:
            return im.convert('RGB')

    @staticmethod
    def _load_vector(path):
        """Load a .npy embedding as a float32 tensor with singleton dimensions squeezed out."""
        vec = torch.tensor(np.load(path).squeeze(), dtype=torch.float32)
        if vec.ndim == 0:  # a one-element array squeezes down to a scalar
            vec = vec.unsqueeze(0)
        return vec

    def load_pose_heatmap(self, path_json):
        """Rasterize pose landmarks into a (1, H, W) float heatmap in [0, 1].

        Each landmark is drawn as a filled circle of radius 4. Its coordinates
        are scaled by the image size, so they are presumably normalized to
        [0, 1] (as MediaPipe reports them) — confirm against the JSON producer.
        """
        with open(path_json, 'r') as f:
            data = json.load(f)
        heatmap = np.zeros(self.img_size, dtype=np.uint8)
        for lm in data.get('pose', []):
            x = int(lm['x'] * self.img_size[1])
            y = int(lm['y'] * self.img_size[0])
            cv2.circle(heatmap, (x, y), 4, 255, -1)
        return torch.tensor(heatmap, dtype=torch.float32).unsqueeze(0) / 255.0

    def __getitem__(self, idx):
        # idx is ignored: the dataset always holds exactly one sample.
        paths = self._paths()
        return {
            'image': self.transform(self._load_rgb(paths['image'])),
            'pose': self.load_pose_heatmap(paths['pose']),
            'identity': self._load_vector(paths['identity']),
            'style': self._load_vector(paths['style']),
            'target': self.transform(self._load_rgb(paths['target'])),
        }

    def __len__(self):
        return 1


# Preparar dataset
input_dir = "/content/drive/My Drive/StyleMatcher/dataset_mini/inputs"
target_dir = "/content/drive/My Drive/StyleMatcher/dataset_mini/targets"
# The sample 'test_01' needs its image, pose, identity/style embeddings and target.
try:
    dataset = StylizerDataset(input_dir, target_dir)
    # Load the sample eagerly so a missing file surfaces here, inside the try,
    # instead of later in the training loop.
    _ = dataset[0]
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False)
except FileNotFoundError as e:
    print(f"Error: Make sure the dataset files exist at {input_dir} and {target_dir}. Details: {e}")
    # Re-raise instead of exit(): in a notebook exit() shuts down the kernel.
    raise


# Initialize the model.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# NOTE(review): this uses a MiniStylizedUNet(style_dim=..., identity_dim=...)
# signature, while the MiniStylizedUNet class defined later in this notebook
# takes only `input_vector_size` — confirm which definition is live here.
model = MiniStylizedUNet(style_dim=64, identity_dim=512).to(device)

# Loss and optimizer.
criterion = nn.L1Loss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Función para mostrar resultados
def show_result(pred, target, epoch):
    """Plot a generated image next to its expected target.

    Args:
        pred: generated image tensor in (C, H, W) layout.
        target: expected image tensor in (C, H, W) layout.
        epoch: epoch number shown in the figure title.
    """
    fig, axes = plt.subplots(1, 2, figsize=(8, 4))
    for axis, image, title in zip(axes, (pred, target), ("Generado", "Esperado")):
        # matplotlib wants H, W, C; the tensors are C, H, W.
        axis.imshow(image.permute(1, 2, 0).detach().cpu().numpy())
        axis.set_title(title)
        axis.axis('off')
    plt.suptitle(f"Resultado Epoch {epoch}")
    plt.show()

# Entrenamiento
# Overfit on the single sample; log and plot every 50 epochs and at the last one.
model.train()
total_epochs = 300
for epoch in range(total_epochs):
    for batch in dataloader:
        img, pose, identity, style, target = (
            batch[key].to(device) for key in ('image', 'pose', 'identity', 'style', 'target')
        )
        # Full-strength stylization: one scalar of 1.0 per sample in the batch.
        strength = torch.ones((identity.size(0), 1), device=device)

        output = model(img, pose, style, identity, strength)
        loss = criterion(output, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    is_last_epoch = epoch == total_epochs - 1
    if epoch % 50 == 0 or is_last_epoch:
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")
        show_result(output[0], target[0], epoch)

In [None]:
# ----------------------------------------------
# Script para verificar resolución y diversidad de pose en imágenes
# ----------------------------------------------
import os
import cv2
import mediapipe as mp
import matplotlib.pyplot as plt
from tqdm import tqdm

# Rutas
dataset_dir = "/content/drive/My Drive/StyleMatcher/dataset/train"
categories = ["anime", "painting"]  # must match the folder names exactly

# --- Report which category folders exist (missing ones are only reported) ---
print("Verificando la existencia de las carpetas de categorías:")
for category in categories:
    folder_path = os.path.join(dataset_dir, category)
    if os.path.exists(folder_path):
        print(f"✅ La carpeta '{folder_path}' existe.")
    else:
        print(f"❌ ERROR: La carpeta '{folder_path}' no existe.")
print("-" * 30)

# MediaPipe pose detector in still-image mode (each image processed independently).
mp_pose = mp.solutions.pose
pose_detector = mp_pose.Pose(static_image_mode=True)

# Thresholds
min_resolution = (200, 200)  # minimum acceptable (height, width)
max_images_per_category = 20

# Visualización
def show_pose_landmarks(image_path, landmarks, category):
    """Display an image with its pose landmarks drawn as green dots.

    Args:
        image_path: image file, re-read with OpenCV and converted BGR -> RGB.
        landmarks: iterable of points with normalized `.x` / `.y` attributes,
            scaled here by the image width / height.
        category: label used in the plot title.
    """
    rgb = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
    height, width = rgb.shape[:2]

    for point in landmarks:
        center = (int(point.x * width), int(point.y * height))
        cv2.circle(rgb, center, 3, (0, 255, 0), -1)

    plt.imshow(rgb)
    plt.title(f"{category} - Pose detectada")
    plt.axis('off')
    plt.show()

# Proceso
# For each category: check resolution and pose detection on up to
# max_images_per_category images, plotting the detected landmarks.
for category in categories:
    folder = os.path.join(dataset_dir, category)
    # The existence check above only reports missing folders; skip them here,
    # otherwise os.listdir raises FileNotFoundError and aborts the whole cell.
    if not os.path.isdir(folder):
        print(f"⚠️ Se omite la categoría '{category}': no existe {folder}")
        continue
    # Sorted so the inspected subset is the same on every run.
    images = sorted(f for f in os.listdir(folder) if f.lower().endswith(('.jpg', '.png', '.jpeg')))
    print(f"\n🔍 Verificando categoría: {category} (máx. {max_images_per_category})")

    # The slice already caps the loop at max_images_per_category, so no
    # separate counter/break is needed.
    for img_file in tqdm(images[:max_images_per_category], desc=f"Procesando imágenes de {category}"):
        img_path = os.path.join(folder, img_file)
        img = cv2.imread(img_path)
        if img is None:
            print(f"⚠️ No se pudo leer la imagen: {img_file}")
            continue

        h, w, _ = img.shape
        if h < min_resolution[0] or w < min_resolution[1]:
            print(f"⚠️ Imagen muy pequeña: {img_file} ({w}x{h})")
            continue

        # Pose detection (MediaPipe expects RGB).
        results = pose_detector.process(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        if not results or not results.pose_landmarks:
            print(f"❌ Sin pose detectada: {img_file}")
            continue

        show_pose_landmarks(img_path, results.pose_landmarks.landmark, category)

# Script para generar dataset_final con 900 imágenes de anime y pintura (train)

In [None]:
import os
import cv2
import json
import torch
import numpy as np
from PIL import Image
from tqdm import tqdm
import mediapipe as mp
import torchvision.transforms as transforms
from facenet_pytorch import InceptionResnetV1
from torchvision.models import vgg19

# Configuración de rutas

In [None]:
# Source images and the output roots of the generated dataset.
source_dir = "/content/drive/My Drive/StyleMatcher/dataset/train"
base_input_out = "/content/drive/My Drive/StyleMatcher/dataset_final/inputs"
base_target_out = "/content/drive/My Drive/StyleMatcher/dataset_final/targets"

input_dirs = dict(
    anime=os.path.join(base_input_out, "anime_inputs"),
    painting=os.path.join(base_input_out, "painting_inputs"),
)
target_dirs = dict(
    anime=os.path.join(base_target_out, "anime"),
    painting=os.path.join(base_target_out, "painting"),
)

# Create every output folder: input folders first, then target folders.
for d in list(input_dirs.values()) + list(target_dirs.values()):
    os.makedirs(d, exist_ok=True)

# Modelos necesarios

In [None]:
# Pretrained feature extractors, used for inference only.
vgg = vgg19(pretrained=True).features.eval()
vgg_layers = [0, 5]  # VGG module indices pooled into the style vector (layers 0 and 5 for a richer style descriptor)
facenet = InceptionResnetV1(pretrained='vggface2').eval()

transform_style = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
transform_identity = transforms.Compose([transforms.Resize((160, 160)), transforms.ToTensor()])
transform_target = transforms.Compose([transforms.Resize((256, 256))])

# MediaPipe pose detector in still-image mode.
mp_pose = mp.solutions.pose.Pose(static_image_mode=True)


# Funciones auxiliares

In [None]:
def extract_pose_landmarks(image_path):
    """Detect body-pose landmarks in an image with MediaPipe.

    Returns:
        List of {"x": ..., "y": ...} dicts, one per landmark, as reported by
        MediaPipe. Returns None if the image cannot be read or no pose is
        detected.
    """
    image = cv2.imread(image_path)
    # cv2.imread returns None (it does not raise) for missing or unreadable
    # files; cvtColor would then fail with an opaque cv2.error.
    if image is None:
        return None
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    results = mp_pose.process(image_rgb)
    if not results.pose_landmarks:
        return None
    return [{"x": lm.x, "y": lm.y} for lm in results.pose_landmarks.landmark]

def extract_face_embedding(image_path):
    """Return the FaceNet (VGGFace2) embedding of the whole image as a numpy array.

    The image is resized by `transform_identity`; no face detection or cropping
    is done here.
    """
    batch = transform_identity(Image.open(image_path).convert("RGB"))[None]
    with torch.no_grad():
        embedding = facenet(batch)
    return embedding.numpy()

def extract_style_vector(image_path):
    """Style descriptor: spatially averaged VGG19 activations.

    Runs the image through `vgg` and global-average-pools the output of every
    module index listed in `vgg_layers`, concatenating the results along the
    channel axis.

    Returns:
        numpy array of shape (1, total channels at the selected layers).
    """
    image = Image.open(image_path).convert("RGB")
    tensor = transform_style(image).unsqueeze(0)
    last_layer = max(vgg_layers)
    features = []
    # no_grad is required: vgg's parameters keep requires_grad=True, so without
    # it the pooled features require grad and .numpy() raises RuntimeError.
    with torch.no_grad():
        x = tensor
        for i, layer in enumerate(vgg):
            x = layer(x)
            if i in vgg_layers:
                pooled = torch.mean(x.view(x.size(0), x.size(1), -1), dim=2)
                features.append(pooled)
            if i >= last_layer:
                break  # deeper layers do not contribute to the descriptor
    return torch.cat(features, dim=1).numpy()

# Proceso principal

In [None]:
categories = ["anime", "painting"]
max_images = 900

# For the first `max_images` images (sorted by name) of each category, write
# the input image, the 256x256 target, and the pose / identity / style files
# under a sequential name "<category>_<idx:04d>".
for category in categories:
    folder = os.path.join(source_dir, category)
    images = sorted([f for f in os.listdir(folder) if f.lower().endswith(('.jpg', '.png', '.jpeg'))])[:max_images]

    for idx, filename in tqdm(enumerate(images), total=len(images), desc=f"Procesando {category}"):
        path = os.path.join(folder, filename)
        base_name = f"{category}_{str(idx).zfill(4)}"
        try:
            # Context manager closes the source file; convert() returns an in-memory copy.
            with Image.open(path) as src:
                img = src.convert("RGB")
            img.save(os.path.join(input_dirs[category], f"{base_name}.jpg"))
            transform_target(img).save(os.path.join(target_dirs[category], f"{base_name}_target.jpg"))

            # NOTE(review): image and target are written before the pose check,
            # so images without a detected pose leave an incomplete example
            # (no pose/identity/style files) — confirm this is intended.
            pose = extract_pose_landmarks(path)
            if pose is None:
                continue
            with open(os.path.join(input_dirs[category], f"{base_name}_pose.json"), 'w') as f:
                json.dump({"pose": pose}, f)

            identity = extract_face_embedding(path)
            np.save(os.path.join(input_dirs[category], f"{base_name}_identity.npy"), identity)

            style = extract_style_vector(path)
            np.save(os.path.join(input_dirs[category], f"{base_name}_style.npy"), style)

        except Exception as e:
            # Report instead of silently skipping: a failing extractor would
            # otherwise leave every example incomplete with no visible error.
            print(f"❌ Fallo con {filename}: {e}")

In [None]:
import os
import json
import cv2
import torch
import numpy as np
from PIL import Image
from tqdm import tqdm
import mediapipe as mp
import torchvision.transforms as transforms
from torchvision.models import vgg19
from facenet_pytorch import InceptionResnetV1

# Modelos
vgg = vgg19(pretrained=True).features.eval()
facenet = InceptionResnetV1(pretrained='vggface2').eval()
mp_pose = mp.solutions.pose.Pose(static_image_mode=True)

# Preprocessing: small normalized image for the style vector, 160x160 for FaceNet.
transform_style = transforms.Compose(
    [
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
transform_identity = transforms.Compose([transforms.Resize((160, 160)), transforms.ToTensor()])

# Funciones auxiliares
def extract_pose_landmarks(image_path):
    """Return MediaPipe pose landmarks as a list of {"x", "y"} dicts.

    Returns None when the image cannot be read or no pose is detected.
    """
    bgr = cv2.imread(image_path)
    if bgr is None:
        return None
    detection = mp_pose.process(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
    if not detection.pose_landmarks:
        return None
    return [{"x": point.x, "y": point.y} for point in detection.pose_landmarks.landmark]

def extract_face_embedding(image_path):
    """Return the FaceNet embedding of the whole (resized) image as a numpy array."""
    rgb = Image.open(image_path).convert("RGB")
    batch = transform_identity(rgb)[None]
    with torch.no_grad():
        embedding = facenet(batch).numpy()
    return embedding

def extract_style_vector(image_path):
    """Style descriptor computed from the first VGG19 module only.

    Applies `vgg[0]` (the first convolution, conv1_1) to the image preprocessed
    by `transform_style` and averages each channel over all spatial positions.

    Returns:
        numpy array of shape (1, output channels of vgg[0]).
    """
    batch = transform_style(Image.open(image_path).convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        activations = vgg[0](batch)
        flat = activations.view(activations.size(0), activations.size(1), -1)
        pooled = flat.mean(dim=2)
    return pooled.numpy()

# Carpetas
bases = {
    "anime": "/content/drive/My Drive/StyleMatcher/dataset_final/inputs/anime_inputs",
    "painting": "/content/drive/My Drive/StyleMatcher/dataset_final/inputs/painting_inputs"
}

# Compute any missing pose / identity / style file for every .jpg in each inputs folder.
for categoria, base_dir in bases.items():
    print(f"🧩 Procesando {categoria}")
    # Style vectors go to the sibling "<categoria>_styles" folder
    # (".../inputs/anime_inputs" -> ".../inputs/anime_styles"), which is where the
    # training cell reads "<name>_style.npy". Only the last path component is
    # rewritten: base_dir.replace("inputs", "styles") also renamed the parent
    # "inputs" folder and wrote to ".../dataset_final/styles/anime_styles".
    style_dir = os.path.join(
        os.path.dirname(base_dir),
        os.path.basename(base_dir).replace("inputs", "styles"),
    )
    os.makedirs(style_dir, exist_ok=True)
    files = sorted([f for f in os.listdir(base_dir) if f.endswith(".jpg")])

    for file in tqdm(files):
        path_base = os.path.join(base_dir, file)
        name = os.path.splitext(file)[0]

        # Pose
        pose_path = os.path.join(base_dir, name + "_pose.json")
        if not os.path.exists(pose_path):
            pose = extract_pose_landmarks(path_base)
            if pose is not None:
                with open(pose_path, 'w') as f:
                    json.dump({"pose": pose}, f)

        # Identity: failures are reported and no longer skip the style step below.
        identity_path = os.path.join(base_dir, name + "_identity.npy")
        if not os.path.exists(identity_path):
            try:
                np.save(identity_path, extract_face_embedding(path_base))
            except Exception as e:
                print(f"⚠️ Identidad fallida para {file}: {e}")

        # Style
        style_path = os.path.join(style_dir, name + "_style.npy")
        if not os.path.exists(style_path):
            try:
                np.save(style_path, extract_style_vector(path_base))
            except Exception as e:
                print(f"⚠️ Estilo fallido para {file}: {e}")


# Generador

In [None]:
import os
import cv2
import json
import torch
import numpy as np
from PIL import Image
from tqdm import tqdm
import mediapipe as mp
import torchvision.transforms as transforms
from facenet_pytorch import InceptionResnetV1
from torchvision.models import vgg19

# 🔧 Rutas base
ruta_base = "/content/drive/My Drive/StyleMatcher/dataset_final/inputs"
ruta_target = "/content/drive/My Drive/StyleMatcher/dataset_final/targets"
# Source image folder for each category.
categorias = dict(
    anime=os.path.join(ruta_base, "anime_styles"),
    painting=os.path.join(ruta_base, "painting_styles"),
)

# 🔧 Models and preprocessing
vgg = vgg19(pretrained=True).features.eval()
vgg_layers = [0]  # only the first VGG module feeds the style vector

facenet = InceptionResnetV1(pretrained='vggface2').eval()

transform_style = transforms.Compose(
    [
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
transform_identity = transforms.Compose([transforms.Resize((160, 160)), transforms.ToTensor()])
transform_target = transforms.Compose([transforms.Resize((256, 256))])

mp_pose = mp.solutions.pose.Pose(static_image_mode=True)

# 🔧 Funciones
def extract_pose_landmarks(image_path):
    """Run MediaPipe pose detection; return [{"x", "y"}, ...] or None if unreadable / no pose."""
    image = cv2.imread(image_path)
    if image is not None:
        landmarks = mp_pose.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)).pose_landmarks
        if landmarks:
            return [{"x": p.x, "y": p.y} for p in landmarks.landmark]
    return None

def extract_face_embedding(image_path):
    """FaceNet embedding of the full image (resized by `transform_identity`) as numpy."""
    batch = transform_identity(Image.open(image_path).convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        return facenet(batch).numpy()

def extract_style_vector(image_path):
    """Style descriptor: channel means of the VGG19 activation at module `vgg_layers[-1]`.

    Applies the `vgg` modules in order, up to and including index
    `vgg_layers[-1]`, then averages each channel over all spatial positions.

    Returns:
        numpy array of shape (1, channels at that layer).
    """
    image = Image.open(image_path).convert("RGB")
    tensor = transform_style(image).unsqueeze(0)
    # no_grad is required: vgg's parameters keep requires_grad=True here, so
    # without it `pooled` requires grad and .numpy() raises RuntimeError.
    with torch.no_grad():
        x = tensor
        for i, layer in enumerate(vgg):
            x = layer(x)
            if i == vgg_layers[-1]:
                break
        pooled = torch.mean(x.view(x.size(0), x.size(1), -1), dim=2)
    return pooled.numpy()

# 🔁 Recorremos cada categoría
for clase, carpeta in categorias.items():
    output_input_dir = os.path.join(ruta_base, f"{clase}_inputs")
    output_target_dir = os.path.join(ruta_target, clase)
    for needed_dir in (output_input_dir, output_target_dir):
        os.makedirs(needed_dir, exist_ok=True)

    # NOTE(review): examples are named "<clase>_<idx:04d>". The earlier
    # dataset_final cell uses the same scheme in these same "<clase>_inputs" and
    # targets folders, so this loop likely overwrites those pose/identity/style/
    # target files — confirm this is intended.
    imagenes = sorted(
        f for f in os.listdir(carpeta) if f.lower().endswith(('.jpg', '.png', '.jpeg'))
    )

    for idx, nombre_archivo in tqdm(enumerate(imagenes), total=len(imagenes), desc=f"Procesando {clase}"):
        nombre_base = f"{clase}_{idx:04d}"
        ruta_imagen = os.path.join(carpeta, nombre_archivo)
        input_prefix = os.path.join(output_input_dir, nombre_base)

        try:
            # Target: resized copy of the source image.
            img = Image.open(ruta_imagen).convert("RGB")
            target = transform_target(img)
            target.save(os.path.join(output_target_dir, f"{nombre_base}_target.jpg"))

            # Pose (images without a detected pose keep only their target).
            pose = extract_pose_landmarks(ruta_imagen)
            if pose is None:
                continue
            with open(input_prefix + "_pose.json", "w") as f:
                json.dump({"pose": pose}, f)

            # Identity and style vectors.
            identidad = extract_face_embedding(ruta_imagen)
            np.save(input_prefix + "_identity.npy", identidad)
            estilo = extract_style_vector(ruta_imagen)
            np.save(input_prefix + "_style.npy", estilo)

        except Exception as e:
            print(f"❌ Fallo con {nombre_archivo}: {e}")
            continue


# Entrenamiento

In [None]:
import os
import json
import numpy as np
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split

# -------------------------------------------------
# Configuración general
# -------------------------------------------------
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("🖥️ Usando dispositivo:", device)

BATCH_SIZE = 4
# Flattened conditioning vector: 512 pose values + 449 identity values + 64 style values.
INPUT_VECTOR_SIZE = 512 + 449 + 64
EPOCHS = 40

# -------------------------------------------------
# Dataset personalizado
# -------------------------------------------------
class StylizerDatasetV2(Dataset):
    """Pairs of (conditioning vector, target image) for MiniStylizedUNet.

    An example named N is kept only if all of these files exist:
    ``{input_dir}/N.jpg``, ``{input_dir}/N_pose.json``,
    ``{input_dir}/N_identity.npy``, ``{style_dir}/N_style.npy`` and
    ``{target_dir}/N_target.jpg``.

    Each item is ``(vector, target)``:
      - ``vector``: float32 concatenation of a 512-value pose map (x, y pairs
        of up to 256 landmarks, zero-padded), the identity embedding truncated
        to at most 449 values, and the flattened style vector.
      - ``target``: float tensor (3, 256, 256) with values in [0, 1].
    """

    def __init__(self, input_dir, target_dir, style_dir):
        self.input_dir = input_dir
        self.target_dir = target_dir
        self.style_dir = style_dir
        self.names = []

        # Keep only examples whose files are all present. Sorted for a
        # deterministic example order.
        for file in sorted(os.listdir(input_dir)):
            if not file.endswith(".jpg"):
                continue
            # splitext keeps dots inside the stem ("a.b.jpg" -> "a.b");
            # split(".")[0] truncated it to "a" and looked up the wrong files.
            name = os.path.splitext(file)[0]
            pose_path = os.path.join(input_dir, name + "_pose.json")
            identity_path = os.path.join(input_dir, name + "_identity.npy")
            style_path = os.path.join(style_dir, name + "_style.npy")
            target_path = os.path.join(target_dir, name + "_target.jpg")

            if all(os.path.exists(p) for p in [pose_path, identity_path, style_path, target_path]):
                self.names.append(name)

    def __len__(self):
        return len(self.names)

    def __getitem__(self, idx):
        name = self.names[idx]

        # Target image: RGB, 256x256, scaled to [0, 1], CHW. The context
        # manager closes the file handle.
        with Image.open(os.path.join(self.target_dir, name + "_target.jpg")) as img:
            target = img.convert("RGB").resize((256, 256))
        target = torch.FloatTensor(np.array(target)).permute(2, 0, 1) / 255.0

        # Pose: interleaved (x, y) landmark coordinates in a fixed 512-slot map.
        with open(os.path.join(self.input_dir, name + "_pose.json"), "r") as f:
            pose_data = json.load(f)["pose"]
        pose_map = np.zeros((512,), dtype=np.float32)
        for i, lm in enumerate(pose_data[:256]):
            pose_map[i * 2] = lm["x"]
            pose_map[i * 2 + 1] = lm["y"]

        # Identity embedding, truncated to 449 values.
        identity = np.load(os.path.join(self.input_dir, name + "_identity.npy")).flatten()
        identity = identity[:449]

        # Style vector.
        style = np.load(os.path.join(self.style_dir, name + "_style.npy")).flatten()

        vector = np.concatenate([pose_map, identity, style], axis=0)
        vector = torch.tensor(vector, dtype=torch.float32)

        return vector, target

# -------------------------------------------------
# Modelo U-Net modificado
# -------------------------------------------------
class MiniStylizedUNet(nn.Module):
    """Decoder-only generator: flat conditioning vector -> 3x256x256 image in (0, 1).

    A two-layer MLP produces a 256-channel 4x4 seed map. Four stride-2
    transposed convolutions upsample it 4->8->16->32->64, and a final stride-4
    one reaches 256x256, followed by a sigmoid. There are no skip connections.
    """

    def __init__(self, input_vector_size):
        super().__init__()
        self.input_vector_size = input_vector_size

        self.fc = nn.Sequential(
            nn.Linear(input_vector_size, 256),
            nn.ReLU(),
            nn.Linear(256, 256 * 4 * 4),
            nn.ReLU(),
        )

        # Modules are appended in the original order, so state_dict keys
        # (decoder.0, decoder.2, ..., decoder.8) match existing checkpoints.
        stages = []
        for c_in, c_out in ((256, 128), (128, 64), (64, 32), (32, 16)):
            stages += [nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1), nn.ReLU()]
        stages += [nn.ConvTranspose2d(16, 3, kernel_size=4, stride=4, padding=0), nn.Sigmoid()]
        self.decoder = nn.Sequential(*stages)

    def forward(self, x):
        seed = self.fc(x).view(-1, 256, 4, 4)
        return self.decoder(seed)

# -------------------------------------------------
# Entrenamiento con Fine-Tuning
# -------------------------------------------------
def _evaluate(model, loader, loss_fn):
    """Mean per-batch `loss_fn` of `model` over `loader`, without gradients (leaves model in eval mode)."""
    model.eval()
    total = 0.0
    with torch.no_grad():
        for vec, target in loader:
            vec = vec.to(device)
            target = target.to(device)
            total += loss_fn(model(vec), target).item()
    return total / len(loader)


def entrenar(dataset_dir, style_dir, target_dir, nombre_categoria):
    """Train or fine-tune a MiniStylizedUNet for one category, keeping the best checkpoint.

    The StylizerDatasetV2 examples are split 90/10 (random, unseeded) into
    train/validation. Training resumes from
    ``model_{nombre_categoria}_best.pt`` on Drive if it exists, and the weights
    are saved whenever the validation L1 loss beats the best seen so far.

    Args:
        dataset_dir: folder with "<name>.jpg", "<name>_pose.json", "<name>_identity.npy".
        style_dir: folder with "<name>_style.npy".
        target_dir: folder with "<name>_target.jpg".
        nombre_categoria: category name used in logs and in the checkpoint filename.
    """
    dataset = StylizerDatasetV2(dataset_dir, target_dir, style_dir)
    if len(dataset) == 0:
        print(f"⚠️ No se encontraron ejemplos completos para {nombre_categoria}.")
        return

    train_size = int(0.9 * len(dataset))
    val_size = len(dataset) - train_size
    if train_size == 0:
        # A single example leaves the train split empty, and the average
        # training loss below would divide by zero.
        print(f"⚠️ Se necesitan al menos 2 ejemplos para {nombre_categoria} (hay {len(dataset)}).")
        return
    train_set, val_set = random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=1)

    model = MiniStylizedUNet(INPUT_VECTOR_SIZE).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    loss_fn = nn.L1Loss()

    model_path = f"/content/drive/My Drive/StyleMatcher/model_{nombre_categoria}_best.pt"
    best_val_loss = float("inf")

    # Resume from an existing checkpoint if there is one.
    if os.path.exists(model_path):
        print(f"📦 Cargando modelo existente: {model_path}")
        # NOTE(review): the training log earlier in this notebook saves another
        # model (train_model) to this same path; load_state_dict raises if the
        # architectures differ — confirm these checkpoints are not shared.
        model.load_state_dict(torch.load(model_path, map_location=device))
        print("✅ Modelo cargado correctamente")
        # Score the loaded weights first. Starting from +inf let epoch 1
        # overwrite the existing checkpoint even when it was worse.
        best_val_loss = _evaluate(model, val_loader, loss_fn)
        print(f"📊 Pérdida de validación inicial: {best_val_loss:.4f}")

    for epoch in range(EPOCHS):
        print(f"\n📘 Época {epoch+1}/{EPOCHS} — Categoría: {nombre_categoria.upper()}")
        model.train()
        total_loss = 0
        for vec, target in tqdm(train_loader, desc="Entrenando"):
            vec = vec.to(device)
            target = target.to(device)
            output = model(vec)
            loss = loss_fn(output, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        avg_train_loss = total_loss / len(train_loader)
        print(f"📉 Pérdida entrenamiento: {avg_train_loss:.4f}")

        # Validation
        avg_val_loss = _evaluate(model, val_loader, loss_fn)
        print(f"✅ Validación pérdida: {avg_val_loss:.4f}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), model_path)
            print(f"💾 Nuevo mejor modelo guardado en {model_path}")

    print(f"🏁 Fine-tuning de {nombre_categoria} completado")

# -------------------------------------------------
# Lanzar entrenamientos (Fine-Tuning)
# -------------------------------------------------
if __name__ == "__main__":
    print("🚀 Comenzando fine-tuning...")

    # One fine-tuning run per category, executed in order.
    fine_tuning_jobs = [
        dict(
            dataset_dir="/content/drive/My Drive/StyleMatcher/dataset_final/inputs/anime_inputs",
            style_dir="/content/drive/My Drive/StyleMatcher/dataset_final/inputs/anime_styles",
            target_dir="/content/drive/My Drive/StyleMatcher/dataset_final/targets/anime",
            nombre_categoria="anime",
        ),
        dict(
            dataset_dir="/content/drive/My Drive/StyleMatcher/dataset_final/inputs/painting_inputs",
            style_dir="/content/drive/My Drive/StyleMatcher/dataset_final/inputs/painting_styles",
            target_dir="/content/drive/My Drive/StyleMatcher/dataset_final/targets/painting",
            nombre_categoria="painting",
        ),
    ]
    for job in fine_tuning_jobs:
        entrenar(**job)


In [None]:
# Reference paths for the StyleMatcher data on Drive. They are stored as strings
# in a dict: bare path lines are not valid Python and made this cell raise a
# SyntaxError under Restart & Run All.
STYLEMATCHER_DIR = "/content/drive/My Drive/StyleMatcher"

RUTAS = {
    # Source images per category ("ruta imagenes").
    "imagenes": {
        "anime": f"{STYLEMATCHER_DIR}/dataset/train/anime",
        "painting": f"{STYLEMATCHER_DIR}/dataset/train/painting",
    },
    # Style reference folders per category ("ruta estilos").
    "estilos": {
        "anime": f"{STYLEMATCHER_DIR}/dataset_final/inputs/anime_styles",
        "painting": f"{STYLEMATCHER_DIR}/dataset_final/inputs/painting_styles",
    },
    # Pose / identity input folders per category ("ruta posiciones e identidad").
    "posiciones_identidad": {
        "painting": f"{STYLEMATCHER_DIR}/dataset_final/inputs/painting_inputs",
        "anime": f"{STYLEMATCHER_DIR}/dataset_final/inputs/anime_inputs",
    },
}

In [None]:
# === INTERFAZ GRADIO FINAL ===
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision.models as models
from PIL import Image
import numpy as np

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Modelo de Estilo: VGG19 ---
# ImageNet-pretrained VGG19 truncated to features[:29] (layers 0-28), in eval mode
# and frozen: it is only used as a fixed feature extractor by extract_style.
# `weights=` replaces the deprecated `pretrained=True` flag (same IMAGENET1K_V1 weights).
vgg = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features[:29].eval().to(device)
vgg.requires_grad_(False)

def extract_style(img_tensor):
    """Build a style descriptor from VGG19 activations.

    Feeds ``img_tensor`` through the truncated ``vgg`` feature stack. After the
    layers at indices 0, 5, 10, 19 and 28, it averages the activation over
    dims 2 and 3 (the spatial dims). The per-layer channel means are then
    concatenated along dim 1.

    NOTE(review): assumes an NCHW batch. With the torchvision VGG19 layout this
    yields 64+128+256+512+512 = 1472 values per image, which matches
    StylizerUNet's default style_dim. Confirm this if the VGG slice changes.
    """
    tap_points = (0, 5, 10, 19, 28)
    pooled = []
    activation = img_tensor
    for index, module in enumerate(vgg):
        activation = module(activation)
        if index in tap_points:
            pooled.append(activation.mean(dim=(2, 3)))
    return torch.cat(pooled, dim=1)

# --- Modelo de Identidad: FaceNet ---
# facenet-pytorch is installed by the setup cell at the top of the notebook.
from facenet_pytorch import InceptionResnetV1

# FaceNet (InceptionResnetV1 with VGGFace2 weights) used as a frozen identity
# encoder. Freezing the parameters means a forward pass outside torch.no_grad()
# on an input that does not require grad builds no autograd graph.
facenet = InceptionResnetV1(pretrained='vggface2').eval().to(device)
facenet.requires_grad_(False)

# --- Modelo de Estilización ---
class UNetBlock(nn.Module):
    """One U-Net stage: 4x4 stride-2 (transposed) conv -> BatchNorm -> activation.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the block.
        down: True for an encoder stage (Conv2d + LeakyReLU(0.2)), which halves H and W.
            False for a decoder stage (ConvTranspose2d + ReLU), which doubles H and W.
        use_dropout: if True, append Dropout(p=0.5) under the child name "dropout".

    The layers live in ``self.block`` as children "0", "1", "2" (+ "dropout").
    Saved state_dict keys depend on that layout.
    """

    def __init__(self, in_channels, out_channels, down=True, use_dropout=False):
        super().__init__()
        conv_cls = nn.Conv2d if down else nn.ConvTranspose2d
        act = nn.LeakyReLU(0.2, inplace=True) if down else nn.ReLU(inplace=True)
        self.block = nn.Sequential(
            conv_cls(in_channels, out_channels, 4, 2, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            act,
        )
        if use_dropout:
            self.block.add_module("dropout", nn.Dropout(0.5))

    def forward(self, x):
        """Apply the block to ``x``."""
        return self.block(x)

class StylizerUNet(nn.Module):
    """Pose-driven U-Net generator conditioned on identity and style vectors.

    Data flow:
    - The encoder downsamples a 1-channel pose map through seven stride-2 UNetBlocks.
    - The concatenated (identity, style) vector is projected to a 512x4x4 map,
      resized to the bottleneck resolution and fused with the deepest features.
    - The decoder upsamples through seven blocks with encoder skip connections.
    - The raw pose map is re-injected at the end, and a 3x3 conv maps to 3
      channels squashed by tanh.

    Submodule names (embedding_expand, enc1..enc7, middle, dec1..dec7, final,
    tanh) form the state_dict keys that load_state_dict matches against the
    saved checkpoints. Keep them.

    Args:
        id_dim: length of the identity vector.
        style_dim: length of the style vector.
    """

    # (in_channels, out_channels) for enc1 .. enc7, all downsampling.
    _ENCODER_SPEC = ((1, 64), (64, 128), (128, 256), (256, 512),
                     (512, 512), (512, 512), (512, 512))
    # (in_channels, out_channels, use_dropout) for dec1 .. dec7. From dec2 on,
    # the input width includes the concatenated encoder skip.
    _DECODER_SPEC = ((512, 512, True), (1024, 512, True), (1024, 512, True),
                     (1024, 512, False), (768, 256, False), (384, 128, False),
                     (192, 64, False))

    def __init__(self, id_dim=512, style_dim=1472):
        super().__init__()
        self.emb_dim = id_dim + style_dim
        self.embedding_expand = nn.Sequential(
            nn.Linear(self.emb_dim, 512 * 4 * 4),
            nn.ReLU(True)
        )
        for idx, (c_in, c_out) in enumerate(self._ENCODER_SPEC, start=1):
            setattr(self, f"enc{idx}", UNetBlock(c_in, c_out, down=True))
        # Bottleneck input: 512 encoder channels + 512 conditioning channels.
        self.middle = nn.Sequential(nn.Conv2d(1024, 512, 3, padding=1), nn.ReLU(True))
        for idx, (c_in, c_out, dropout) in enumerate(self._DECODER_SPEC, start=1):
            setattr(self, f"dec{idx}", UNetBlock(c_in, c_out, down=False, use_dropout=dropout))
        # 64 decoder channels + 1 re-injected pose channel.
        self.final = nn.Conv2d(65, 3, 3, padding=1)
        self.tanh = nn.Tanh()

    def forward(self, pose_map, identity_vec, style_vec):
        """Generate an image from a pose map plus identity and style vectors.

        NOTE(review): assumes pose_map is (N, 1, H, W) and both vectors are (N, D)
        with matching N. Confirm this against the callers.
        """
        # Project the joint embedding into a 512x4x4 conditioning map.
        cond = self.embedding_expand(torch.cat((identity_vec, style_vec), dim=1))
        cond = cond.view(-1, 512, 4, 4)

        # Encoder: keep each stage's output (e1..e7) for the skip connections.
        skips = []
        h = pose_map
        for idx in range(1, 8):
            h = getattr(self, f"enc{idx}")(h)
            skips.append(h)

        # Bottleneck: fuse e7 with the conditioning map resized to e7's size.
        cond = F.interpolate(cond, size=h.shape[2:])
        h = self.dec1(self.middle(torch.cat((h, cond), dim=1)))

        # Decoder: dec2..dec7 take the previous output plus the mirrored encoder
        # feature (e6 for dec2, e5 for dec3, ..., e1 for dec7).
        for idx, skip in zip(range(2, 8), reversed(skips[:-1])):
            h = getattr(self, f"dec{idx}")(torch.cat((h, skip), dim=1))

        # Re-inject the raw pose map at the output resolution before the final conv.
        pose_up = F.interpolate(pose_map, size=h.shape[2:])
        return self.tanh(self.final(torch.cat((h, pose_up), dim=1)))

# --- Transformaciones de entrada ---
# Every input is resized to 256x256 and converted to a float tensor
# (ToTensor scales 8-bit pixels to [0, 1]).
transform = T.Compose([
    T.Resize((256, 256)),
    T.ToTensor()
])

# Checkpoint file for each stylization mode (copied from Drive in an earlier cell).
MODEL_PATHS = {
    "anime": "/content/model_anime_best.pt",
    "painting": "/content/model_painting_best.pt",
}

# Stylizers already built and loaded, keyed by mode. Loading once avoids rebuilding
# the network and re-reading the checkpoint on every request. A checkpoint that is
# overwritten on disk is not picked up until this dict is cleared.
_stylizer_cache = {}


def _load_stylizer(mode):
    """Return the eval-mode StylizerUNet for ``mode``, loading its checkpoint on first use."""
    model = _stylizer_cache.get(mode)
    if model is None:
        model = StylizerUNet().to(device)
        model.load_state_dict(torch.load(MODEL_PATHS[mode], map_location=device))
        model.eval()
        _stylizer_cache[mode] = model
    return model


def generate_image(pose_img, id_img, style_img, mode):
    """Stylize a pose map with a given identity and style (Gradio callback).

    Args:
        pose_img: pose map as a numpy image array; converted to 1-channel grayscale.
        id_img: identity image as a numpy array; converted to RGB and encoded by FaceNet.
        style_img: style reference as a numpy array; converted to RGB and encoded
            with extract_style (VGG19).
        mode: "anime" or "painting"; selects which fine-tuned checkpoint to use.

    Returns:
        PIL.Image: the generated RGB image.

    Raises:
        gr.Error: if an image is missing or ``mode`` is not a known model. Gradio
            shows the message in the UI instead of an opaque traceback.
    """
    if pose_img is None or id_img is None or style_img is None:
        raise gr.Error("Sube las tres imágenes: pose, identidad y estilo.")
    if mode not in MODEL_PATHS:
        raise gr.Error("Elige un modelo: anime o painting.")

    model = _load_stylizer(mode)

    # Inference only: no autograd graph for the embedding models or the stylizer.
    with torch.no_grad():
        pose = transform(Image.fromarray(pose_img).convert("L")).unsqueeze(0).to(device)
        id_tensor = transform(Image.fromarray(id_img).convert("RGB")).unsqueeze(0).to(device)
        style_tensor = transform(Image.fromarray(style_img).convert("RGB")).unsqueeze(0).to(device)

        # L2-normalize both conditioning vectors along the feature dimension.
        identity_vec = F.normalize(facenet(id_tensor), dim=1)
        style_vec = F.normalize(extract_style(style_tensor), dim=1)

        output = model(pose, identity_vec, style_vec)

    # NOTE(review): the network ends in tanh (range [-1, 1]), but only [0, 1] is kept
    # here. Confirm this matches how the training targets were scaled.
    image = output.squeeze(0).cpu().permute(1, 2, 0).clamp(0, 1).numpy()
    return Image.fromarray((image * 255).astype(np.uint8))


# === INTERFAZ ===
# Gradio UI: three image inputs plus the model selector. They are passed to
# generate_image in this order (pose, identity, style, mode).
demo = gr.Interface(
    fn=generate_image,
    inputs=[
        gr.Image(label="Mapa de Pose (PNG)", type="numpy"),
        gr.Image(label="Imagen de Identidad (PNG)", type="numpy"),
        gr.Image(label="Imagen de Estilo (PNG)", type="numpy"),
        # Preselect a model so generate_image never receives mode=None.
        gr.Radio(["anime", "painting"], value="anime", label="Modelo a Usar")
    ],
    outputs=gr.Image(label="Imagen Estilizada"),
    title="StyleMatcher - Generador de Imágenes",
    description="Sube una imagen de pose, una de identidad y una de estilo. Elige el modelo (anime o pintura) para generar tu imagen estilizada."
)
# debug=True keeps the cell running so errors and logs appear in its output.
demo.launch(debug=True)




  0%|          | 0.00/107M [00:00<?, ?B/s]

It looks like you are running Gradio on a hosted a Jupyter notebook. For the Gradio app to work, sharing must be enabled. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://9b4bd9954ee632a5bf.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7860 <> https://9b4bd9954ee632a5bf.gradio.live


