In [None]:
%pip install numpy
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image, ImageDraw
import os
import numpy as np

# --- 0. Configuration parameters ---
# NOTE(review): the saved output of this cell reports "Epoch [n/20]", so it was
# produced with a different NUM_EPOCHS; re-run the cell to refresh it.
NUM_EPOCHS = 30          # number of full passes over the dataset
LEARNING_RATE = 0.001    # step size handed to the Adam optimizer
BATCH_SIZE = 8           # images per optimisation step
DUMMY_DATA_COUNT = 40    # NOTE(review): not referenced in the visible cells - confirm it is still needed

# --- 1. Definición del Modelo U-Net Simple ---
class SimpleUNet(nn.Module):
    """Small two-level U-Net that maps an image to a same-sized image.

    With the default arguments it takes a 1-channel (grayscale) tensor and
    returns a 3-channel (RGB) tensor whose values are squashed into [0, 1]
    by a sigmoid.

    The encoder halves the spatial size twice, so each spatial side of the
    input must be at least 4 pixels. Sides that are not multiples of 4 are
    supported: the upsampled decoder maps are zero-padded to the size of the
    matching skip connection before concatenation.

    Args:
        in_channels: Number of channels of the input tensor (default 1).
        out_channels: Number of channels of the output tensor (default 3).
    """

    def __init__(self, in_channels=1, out_channels=3):
        super(SimpleUNet, self).__init__()

        # Encoder (downsampling path)
        self.enc_conv1 = self.conv_block(in_channels, 64)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.enc_conv2 = self.conv_block(64, 128)
        self.pool2 = nn.MaxPool2d(2, 2)

        # Bottleneck
        self.bottleneck = self.conv_block(128, 256)

        # Decoder (upsampling path)
        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec_conv2 = self.conv_block(256, 128)  # 128 (from upconv) + 128 (from skip)
        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec_conv1 = self.conv_block(128, 64)  # 64 (from upconv) + 64 (from skip)

        # Final layer: 1x1 convolution followed by a sigmoid so outputs lie in [0, 1]
        self.out_conv = nn.Conv2d(64, out_channels, kernel_size=1)
        self.final_activation = nn.Sigmoid()

    def conv_block(self, in_channels, out_channels):
        """Return two 3x3 same-padding convolutions, each followed by ReLU."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

    @staticmethod
    def _match_size(upsampled, skip):
        """Zero-pad ``upsampled`` so its height and width equal those of ``skip``.

        MaxPool2d floors odd sizes, so after the stride-2 transposed
        convolution the decoder map can be one pixel smaller than the encoder
        map it is concatenated with.
        """
        pad_h = skip.size(2) - upsampled.size(2)
        pad_w = skip.size(3) - upsampled.size(3)
        if pad_h == 0 and pad_w == 0:
            return upsampled
        # F.pad takes (left, right, top, bottom) for the last two dimensions.
        return nn.functional.pad(
            upsampled,
            [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2],
        )

    def forward(self, x):
        """Run the network on a batch ``x`` of shape (N, in_channels, H, W).

        Returns:
            Tensor of shape (N, out_channels, H, W) with values in [0, 1].
        """
        # Encoder
        e1 = self.enc_conv1(x)
        p1 = self.pool1(e1)
        e2 = self.enc_conv2(p1)
        p2 = self.pool2(e2)

        # Bottleneck
        b = self.bottleneck(p2)

        # Decoder with skip connections
        d2 = self._match_size(self.upconv2(b), e2)
        d2 = torch.cat([d2, e2], dim=1)
        d2 = self.dec_conv2(d2)

        d1 = self._match_size(self.upconv1(d2), e1)
        d1 = torch.cat([d1, e1], dim=1)
        d1 = self.dec_conv1(d1)

        # Output
        output = self.out_conv(d1)
        output = self.final_activation(output)

        return output

# --- 2. Generador de Datos (Dataset) ---
class ColorizationDataset(Dataset):
    """Dataset of (grayscale, colour) pairs built from the PNG files in a folder.

    Args:
        root_dir: Directory that is scanned (non-recursively) for PNG files.
        transform: Optional callable applied to the RGB PIL image before the
            grayscale version is derived from it.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> file mapping stable across runs and machines. The extension
        # check is case-insensitive so that files named "*.PNG" are included.
        self.image_files = sorted(
            f for f in os.listdir(root_dir) if f.lower().endswith('.png')
        )
        self.transform = transform

    def __len__(self):
        """Return the number of PNG files found in ``root_dir``."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(grayscale_image, color_image)`` for the file at ``idx``.

        Both items are whatever type ``transform`` produces (PIL images when
        no transform is given).
        """
        img_path = os.path.join(self.root_dir, self.image_files[idx])
        # Close the file handle as soon as the pixel data has been converted.
        with Image.open(img_path) as img:
            color_image = img.convert("RGB")

        if self.transform is not None:
            color_image = self.transform(color_image)

        # Derive the grayscale network input from the (transformed) colour target
        grayscale_image = transforms.Grayscale()(color_image)

        return grayscale_image, color_image


# --- 3. Bucle Principal de Entrenamiento ---
if __name__ == '__main__':
    # Setup
    torch.cuda.empty_cache()  # Clear the CUDA cache
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Usando dispositivo: {device}")

    data_dir = '../data/colors'
    checkpoint_path = '../outputs/simple_unet_colorizer.pth'

    # Create the output folder before training so a missing directory cannot
    # make torch.save fail after all epochs have already been computed.
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

    # Prepare the dataset and the data loader
    transform = transforms.Compose([transforms.Resize((128, 128)), transforms.ToTensor()])
    dataset = ColorizationDataset(root_dir=data_dir, transform=transform)
    if len(dataset) == 0:
        # Fail with an explicit message instead of an opaque sampler error.
        raise RuntimeError(f"No se encontraron imágenes .png en '{data_dir}'")
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

    # Instantiate the model, the loss and the optimizer
    model = SimpleUNet().to(device)
    criterion = nn.MSELoss()  # per-pixel mean squared error between prediction and target
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    print("\n--- ¡Comenzando el Entrenamiento! ---")
    for epoch in range(NUM_EPOCHS):
        running_loss = 0.0
        for gray_imgs, color_imgs in dataloader:
            gray_imgs = gray_imgs.to(device)
            color_imgs = color_imgs.to(device)

            # Forward pass
            outputs = model(gray_imgs)
            loss = criterion(outputs, color_imgs)

            # Backward pass and parameter update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        # Reported value is the mean of the per-batch losses for this epoch
        print(f"Epoch [{epoch+1}/{NUM_EPOCHS}], Loss: {running_loss/len(dataloader):.4f}")

    print("--- ¡Entrenamiento Finalizado! ---")

    # Save the trained weights
    torch.save(model.state_dict(), checkpoint_path)
    print(f"Modelo guardado en '{checkpoint_path}'")
    


Note: you may need to restart the kernel to use updated packages.
Usando dispositivo: cuda

--- ¡Comenzando el Entrenamiento! ---




Epoch [1/20], Loss: 0.0563
Epoch [2/20], Loss: 0.0127
Epoch [3/20], Loss: 0.0063
Epoch [4/20], Loss: 0.0040
Epoch [5/20], Loss: 0.0032
Epoch [6/20], Loss: 0.0028
Epoch [7/20], Loss: 0.0025
Epoch [8/20], Loss: 0.0022
Epoch [9/20], Loss: 0.0021
Epoch [10/20], Loss: 0.0019
Epoch [11/20], Loss: 0.0018
Epoch [12/20], Loss: 0.0018
Epoch [13/20], Loss: 0.0017
Epoch [14/20], Loss: 0.0016
Epoch [15/20], Loss: 0.0015
Epoch [16/20], Loss: 0.0013
Epoch [17/20], Loss: 0.0011
Epoch [18/20], Loss: 0.0010
Epoch [19/20], Loss: 0.0008
Epoch [20/20], Loss: 0.0008
--- ¡Entrenamiento Finalizado! ---
Modelo guardado en '../outputs/simple_unet_colorizer.pth'


In [9]:
# --- 5. Inferencia en una imagen específica ---
def inference(model_path, image_path, output_path):
    """Colorize one image with a saved SimpleUNet checkpoint and write the result.

    Args:
        model_path: Path to a state dict saved from a SimpleUNet.
        image_path: Path to the input image; it is converted to RGB, resized
            to 128x128 and turned into a grayscale tensor.
        output_path: Where the predicted 128x128 image is saved. The parent
            directory is created if it does not exist.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load the model.
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source (consider weights_only=True where the installed
    # torch version supports it).
    model = SimpleUNet().to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    # Same transform as the one used for training
    preprocess = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor()
    ])

    # Close the file handle as soon as the pixel data has been converted.
    with Image.open(image_path) as img:
        color_image = img.convert("RGB")

    # Follow the same order as ColorizationDataset: resize and convert the RGB
    # image to a tensor first, then derive the grayscale input from that tensor.
    gray_tensor = transforms.Grayscale()(preprocess(color_image))
    gray_tensor = gray_tensor.unsqueeze(0).to(device)  # add the batch dimension

    # Inference
    with torch.no_grad():
        output = model(gray_tensor)

    # Convert to a PIL image and save it
    output_image = transforms.ToPILImage()(output.squeeze(0).cpu())
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    output_image.save(output_path)
    print(f"Imagen inferida guardada en: {output_path}")


# Example usage: colorize one image with the checkpoint saved by the training cell.
# NOTE(review): the image path below is an absolute, machine-specific Windows
# path that embeds a user name; replace it with a project-relative path
# before sharing the notebook.
inference(
    '../outputs/simple_unet_colorizer.pth',  # model_path
    r'C:\Users\sergi\Pictures\vlcsnap-2025-02-17-17h59m59s967.png',  # image_path
    '../outputs/inference_result.png',  # output_path
)


Imagen inferida guardada en: ../outputs/inference_result.png
