In [1]:
# Move to the project root exactly once: a bare `cd ..` climbs one more level
# every time the cell is re-run, so remember the target on first execution.
import os
from pathlib import Path

PROJECT_ROOT = globals().get("PROJECT_ROOT") or Path.cwd().parent
os.chdir(PROJECT_ROOT)
PROJECT_ROOT

  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


/home/va0831/Projects/FlowMatchingMnist


In [2]:
# conditional_mnist_diffusion_flow.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms, utils
from torch.utils.data import DataLoader
from tqdm import tqdm
import os

# --- Configuration ---
device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_size = 128
timesteps = 100  # number of diffusion steps (length of the beta schedule)
img_shape = (1, 28, 28)  # (channels, height, width) of one MNIST image
SEED = 0
# Seed torch's RNG on all devices (noise sampling, timestep draws and the
# DataLoader shuffle all draw from it) so runs are repeatable.
torch.manual_seed(SEED)
os.makedirs("outputs/diffusion", exist_ok=True)
os.makedirs("outputs/flow_matching", exist_ok=True)

In [3]:
# --- Dataset ---
def _scale_to_symmetric(tensor):
    """Map ToTensor's [0, 1] pixel range onto [-1, 1]."""
    return tensor * 2 - 1


transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Lambda(_scale_to_symmetric)]
)
dataset = datasets.MNIST(
    root='data',
    train=True,
    download=True,
    transform=transform,
)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [5]:
# --- Models ---
class DiffusionModel(nn.Module):
    """Small conditional CNN that predicts the noise added to an image.

    The timestep and the class label are each embedded into a 64-d vector and
    summed. NOTE(review): only the first two of those 64 conditioning
    dimensions are broadcast into image channels, so the other 62 embedding
    outputs never influence the prediction. Widening this would change the
    first Conv2d's input channels and break existing checkpoints.
    """

    def __init__(self):
        super().__init__()
        # Raw timestep index (cast to float, not normalized) -> 64-d vector.
        self.time_embed = nn.Sequential(
            nn.Linear(1, 64), nn.ReLU(), nn.Linear(64, 64)
        )
        # One learned 64-d vector per class label (10 classes).
        self.label_embed = nn.Embedding(10, 64)
        # Input channels: 1 image channel + 2 conditioning channels.
        self.net = nn.Sequential(
            nn.Conv2d(1 + 1 + 1, 32, 3, padding=1), nn.ReLU(),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(),
            nn.Conv2d(64, 1, 3, padding=1)
        )

    def forward(self, x, t, y):
        """Predict the noise in ``x``.

        Args:
            x: single-channel images, shape (B, 1, H, W).
            t: integer timesteps, B elements (viewed as (B, 1)).
            y: integer class labels in [0, 10), shape (B,).

        Returns:
            Tensor of shape (B, 1, H, W).
        """
        t_embed = self.time_embed(t.view(-1, 1).float())
        y_embed = self.label_embed(y)
        cond = t_embed + y_embed  # (B, 64)
        # Keep only the two conditioning dims the conv net consumes, then
        # broadcast them to the input's own spatial size (not a fixed 28x28).
        height, width = x.shape[-2:]
        cond_img = cond[:, :2, None, None].expand(-1, 2, height, width)
        x = torch.cat([x, cond_img], dim=1)
        return self.net(x)

# --- Diffusion training ---
def train_diffusion(epochs=1000, lr=1e-3):
    """Train a conditional DiffusionModel on ``dataloader`` to predict noise.

    Every 10 epochs it writes samples of digit 9 to
    outputs/diffusion/sample_epoch{N}.png and checkpoints the weights to
    outputs/diffusion/diffusion_model.pth. It saves the final weights to that
    checkpoint path as well.

    Args:
        epochs: number of passes over ``dataloader`` (default 1000).
        lr: Adam learning rate (default 1e-3).
    """
    #model = nn.DataParallel(DiffusionModel(), device_ids=[0,1,2,3,4,5]).to(device)
    model = DiffusionModel().to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    # Linear beta schedule; generate_diffusion builds the same one for sampling.
    betas = torch.linspace(1e-4, 0.02, timesteps).to(device)
    alphas = 1 - betas
    alpha_hat = torch.cumprod(alphas, dim=0)

    for epoch in range(epochs):
        pbar = tqdm(dataloader, desc=f"[Diffusion {epoch}]", leave=False, ncols=80)
        for step, (x, y) in enumerate(pbar):
            x = x.to(device)
            y = y.to(device)
            # Forward process: x_t = sqrt(a_hat) * x0 + sqrt(1 - a_hat) * eps.
            t = torch.randint(0, timesteps, (x.size(0),), device=device)
            a_hat = alpha_hat[t].view(-1, 1, 1, 1)
            noise = torch.randn_like(x)
            x_t = a_hat.sqrt() * x + (1 - a_hat).sqrt() * noise

            noise_pred = model(x_t, t, y)
            mse = F.mse_loss(noise_pred, noise)
            # Auxiliary term (weight 0.1): KL(true || predicted) between the
            # per-image softmax distributions over pixels of the noise.
            norm_pred = F.log_softmax(noise_pred.view(noise_pred.size(0), -1), dim=1)
            norm_true = F.softmax(noise.view(noise.size(0), -1), dim=1)
            kl = F.kl_div(norm_pred, norm_true, reduction='batchmean')
            loss = mse + 0.1 * kl

            opt.zero_grad()
            loss.backward()
            opt.step()

            pbar.set_postfix({"loss": f"{loss.item():.4f}"})
            # Debug value ranges for the first 5 batches only. Use the loop
            # index (tqdm refreshes pbar.n lazily) and tqdm.write so the
            # progress bar is not corrupted by the output.
            if epoch == 0 and step < 5:
                tqdm.write(f"x_t: {x_t.min().item()} {x_t.max().item()}")
                tqdm.write(f"noise_pred: {noise_pred.min().item()} {noise_pred.max().item()}")

        if (epoch + 1) % 10 == 0:
            generate_diffusion(9, model=model, save_path=f"outputs/diffusion/sample_epoch{epoch+1}.png")
            torch.save(model.state_dict(), "outputs/diffusion/diffusion_model.pth")

    torch.save(model.state_dict(), "outputs/diffusion/diffusion_model.pth")

# --- Conditional diffusion sampling ---
@torch.no_grad()
def generate_diffusion(label, model=None, save_path=None, show=False, n_samples=64):
    """Draw class-conditional samples with DDPM ancestral sampling.

    Args:
        label: class label to condition every sample on.
        model: trained DiffusionModel. If None, it is built and loaded from
            "outputs/diffusion/diffusion_model.pth" and put in eval mode.
        save_path: output image path. Defaults to
            "outputs/diffusion/diffusion_gen_{label}.png".
        show: if True, also display the first sample with matplotlib.
        n_samples: number of images to draw (saved as a grid, 8 per row).
    """
    if model is None:
        #model = nn.DataParallel(DiffusionModel(), device_ids=[0,1,2,3,4,5]).to(device)
        model = DiffusionModel().to(device)
        # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
        model.load_state_dict(
            torch.load("outputs/diffusion/diffusion_model.pth", map_location=device)
        )
        model.eval()

    x = torch.randn(n_samples, *img_shape, device=device)
    y = torch.full((n_samples,), label, dtype=torch.long, device=device)

    # Same linear beta schedule as train_diffusion.
    betas = torch.linspace(1e-4, 0.02, timesteps, device=device)
    alphas = 1 - betas
    alpha_hat = torch.cumprod(alphas, dim=0)

    for t in reversed(range(timesteps)):
        t_batch = torch.full((n_samples,), t, device=device, dtype=torch.long)
        eps_pred = model(x, t_batch, y)
        beta_t = betas[t]
        alpha_t = alphas[t]
        alpha_bar_t = alpha_hat[t]
        alpha_bar_prev = alpha_hat[t - 1] if t > 0 else alpha_hat.new_tensor(1.0)

        # Predicted clean image; training data is scaled to [-1, 1], so clip.
        x0_pred = (x - (1 - alpha_bar_t).sqrt() * eps_pred) / alpha_bar_t.sqrt()
        x0_pred = x0_pred.clamp(-1, 1)

        # Sample x_{t-1} from the posterior q(x_{t-1} | x_t, x0) (Ho et al.
        # 2020, eqs. 6-7). This keeps x at the noise level the model expects
        # for timestep t-1; at t == 0 the mean reduces to x0_pred.
        coef_x0 = alpha_bar_prev.sqrt() * beta_t / (1 - alpha_bar_t)
        coef_xt = alpha_t.sqrt() * (1 - alpha_bar_prev) / (1 - alpha_bar_t)
        mean = coef_x0 * x0_pred + coef_xt * x
        if t > 0:
            posterior_var = beta_t * (1 - alpha_bar_prev) / (1 - alpha_bar_t)
            x = mean + posterior_var.sqrt() * torch.randn_like(x)
        else:
            x = mean

    img = (x + 1) / 2  # back to [0, 1] for saving/display
    utils.save_image(img, save_path or f"outputs/diffusion/diffusion_gen_{label}.png", nrow=8)
    if show:
        # NOTE(review): `plt` (matplotlib.pyplot) is not imported in the
        # visible cells; show=True raises NameError unless it is imported
        # elsewhere in the notebook.
        plt.imshow(img[0].cpu().squeeze().numpy(), cmap='gray')
        plt.title(f'Generated {label}')
        plt.axis('off')
        plt.show()

# --- Ejecutar ---
# train_diffusion()
# generate_diffusion(9)

In [None]:
# Training is long-running (1000 epochs by default). Set RUN_TRAINING = False
# to Restart & Run All and reuse outputs/diffusion/diffusion_model.pth instead.
RUN_TRAINING = True
if RUN_TRAINING:
    train_diffusion()

[Diffusion 807]:  15%|█▊          | 69/469 [00:01<00:06, 60.06it/s, loss=0.0722]