In [1]:
# Step one directory up so the relative "data" and "outputs/..." paths used in
# later cells resolve against the parent folder (the cell output shows the result).
# NOTE(review): this is not idempotent - every re-run of the cell climbs one more level.
cd ..

/home/va0831/Projects/FlowMatchingMnist


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


In [2]:
# conditional_mnist_diffusion_flow.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms, utils
from torch.utils.data import DataLoader
from tqdm import tqdm
import os

# --- Configuration ---
# Use the GPU when PyTorch can see one; fall back to the CPU otherwise.
device = 'cpu' if not torch.cuda.is_available() else 'cuda'
batch_size = 128
# NOTE(review): `timesteps` is not referenced by the flow-matching cells visible here.
timesteps = 100
# Shape of one sample as passed to torch.randn(n, *img_shape) during generation.
img_shape = (1, 28, 28)

# Make sure both output folders exist before anything tries to write into them.
for _out_dir in ("outputs/diffusion", "outputs/flow_matching"):
    os.makedirs(_out_dir, exist_ok=True)

In [3]:
# --- Dataset ---
# Preprocessing: convert to a tensor, then apply x -> 2x - 1 so the pixel range
# is symmetric around zero (matching the zero-mean Gaussian noise used in training).
_to_tensor = transforms.ToTensor()
_rescale = transforms.Lambda(lambda img: 2 * img - 1)
transform = transforms.Compose([_to_tensor, _rescale])

# MNIST training split, downloaded into ./data on first use.
dataset = datasets.MNIST(
    root='data',
    train=True,
    download=True,
    transform=transform,
)

# Reshuffled mini-batches of `batch_size` samples each epoch.
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
)

In [4]:
class FlowModel(nn.Module):
    """Class-conditional velocity field for flow matching, implemented as an MLP.

    The network input is the flattened image concatenated with one scalar time
    value and a one-hot class label; the output is a velocity tensor with the
    same shape as the input image.

    Args:
        img_dim: Number of values in one flattened image (default ``28 * 28``).
        num_classes: Number of label classes for the one-hot encoding (default 10).
        hidden_dim: Width of the two hidden layers (default 512).

    The layer order inside ``self.net`` is unchanged, so with the default
    arguments the ``state_dict`` keys and shapes match existing checkpoints.
    """

    def __init__(self, img_dim=28 * 28, num_classes=10, hidden_dim=512):
        super().__init__()
        self.img_dim = img_dim
        self.num_classes = num_classes
        self.net = nn.Sequential(
            # +1 for the scalar time input, +num_classes for the one-hot label.
            nn.Linear(img_dim + 1 + num_classes, hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, img_dim)
        )

    def forward(self, x, t, y):
        """Predict the velocity at state ``x``, time ``t`` and class ``y``.

        Args:
            x: Batch of images; everything after the first (batch) dimension
                is flattened and must contain ``img_dim`` values.
            t: Time value(s). Accepts a Python number, a 0-d tensor, a tensor
                with a single element (broadcast to the whole batch), a
                ``(batch,)`` tensor or a ``(batch, 1)`` tensor.
            y: Integer class labels of shape ``(batch,)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        batch = x.size(0)
        # reshape (unlike view) also accepts non-contiguous inputs such as
        # transposed or sliced tensors.
        x_flat = x.reshape(batch, -1)

        # Bring t to x's dtype/device regardless of how it was passed in.
        t_embed = torch.as_tensor(t, dtype=x_flat.dtype, device=x_flat.device)
        if t_embed.numel() == 1:
            # A single time value applies to every sample in the batch.
            t_embed = t_embed.reshape(1).expand(batch)
        t_embed = t_embed.reshape(batch, 1)

        y_onehot = F.one_hot(y, num_classes=self.num_classes).to(x_flat.dtype)
        x_cat = torch.cat([x_flat, t_embed, y_onehot], dim=1)
        return self.net(x_cat).reshape(x.shape)

# --- Flow Matching training ---
def train_flow(epochs=1000, lr=1e-3, sample_every=10, kl_weight=0.1):
    """Train a class-conditional FlowModel on the module-level ``dataloader``.

    For every batch a noise sample and a uniform time ``t`` in [0, 1) are drawn,
    the point ``x_t = (1 - t) * noise + t * x_real`` is formed, and the model is
    trained to predict the velocity ``x_real - noise``. The loss is the MSE
    between predicted and target velocity plus ``kl_weight`` times a KL term
    computed on the softmax of the flattened velocities.

    Args:
        epochs: Number of passes over the dataloader (default 1000).
        lr: Adam learning rate (default 1e-3).
        sample_every: Every this many epochs a sample grid of the digit 9 is
            written and the checkpoint is saved; 0 or None disables both.
        kl_weight: Weight of the KL term in the loss (default 0.1).

    Side effects:
        Writes ``outputs/flow_matching/sample_epoch{N}.png`` and
        ``outputs/flow_matching/flow_model.pth``. Returns nothing.
    """
    checkpoint_path = "outputs/flow_matching/flow_model.pth"

    # Multi-GPU alternative kept from the original for reference:
    # model = nn.DataParallel(FlowModel(), device_ids=[0,1,2,3,4,5]).to(device)
    model = FlowModel().to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        # Sampling happens in between epochs; make sure we are back in train mode.
        model.train()
        pbar = tqdm(dataloader, desc=f"[Flow {epoch}]", leave=False, ncols=80)
        for x_real, y in pbar:
            x_real = x_real.to(device)
            y = y.to(device)
            x_noise = torch.randn_like(x_real)
            t = torch.rand(x_real.size(0), device=device)

            # Broadcast the per-sample time over the image dimensions once.
            t_b = t.view(-1, 1, 1, 1)
            x_t = (1 - t_b) * x_noise + t_b * x_real
            v_target = x_real - x_noise

            v_pred = model(x_t, t, y)
            mse = F.mse_loss(v_pred, v_target)
            # NOTE(review): this auxiliary term treats the flattened velocities as
            # logits of a distribution, which is not part of standard flow
            # matching; it is kept unchanged to preserve training behaviour.
            norm_pred = F.log_softmax(v_pred.view(v_pred.size(0), -1), dim=1)
            norm_true = F.softmax(v_target.view(v_target.size(0), -1), dim=1)
            kl = F.kl_div(norm_pred, norm_true, reduction='batchmean')
            loss = mse + kl_weight * kl

            opt.zero_grad()
            loss.backward()
            opt.step()

            pbar.set_postfix({"loss": f"{loss.item():.4f}"})

        if sample_every and (epoch + 1) % sample_every == 0:
            generate_flow(9, model=model, save_path=f"outputs/flow_matching/sample_epoch{epoch+1}.png")
            torch.save(model.state_dict(), checkpoint_path)

    # Always leave a checkpoint of the final weights behind.
    torch.save(model.state_dict(), checkpoint_path)

# --- Conditional Flow generation ---
@torch.no_grad()
def generate_flow(label, model=None, save_path=None, show=False, n_samples=64, steps=50):
    """Sample images of class ``label`` by Euler-integrating the learned velocity.

    Starting from Gaussian noise, ``steps`` explicit Euler steps of size
    ``1 / steps`` are taken from t = 0 towards t = 1. The result is mapped from
    [-1, 1] back to [0, 1] and written as an image grid with 8 images per row.

    Args:
        label: Integer class label to condition every sample on.
        model: Model to sample from. When None, a fresh FlowModel is created
            and loaded from ``outputs/flow_matching/flow_model.pth``.
        save_path: Output image path; defaults to
            ``outputs/flow_matching/flow_gen_{label}.png``.
        show: When True, additionally display the first sample with matplotlib.
        n_samples: Number of images to generate (default 64).
        steps: Number of Euler steps (default 50).

    Raises:
        ValueError: If ``n_samples`` or ``steps`` is smaller than 1.
    """
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")
    if steps < 1:
        raise ValueError(f"steps must be >= 1, got {steps}")

    was_training = False
    if model is None:
        # Multi-GPU alternative kept from the original for reference:
        # model = nn.DataParallel(FlowModel(), device_ids=[0,1,2,3,4,5])
        model = FlowModel().to(device)
        # map_location lets a checkpoint written on a GPU load on a CPU-only machine.
        state = torch.load("outputs/flow_matching/flow_model.pth", map_location=device)
        model.load_state_dict(state)
    else:
        # Remember the caller's mode so sampling mid-training does not change it.
        was_training = model.training
    model.eval()

    x = torch.randn(n_samples, *img_shape).to(device)
    y = torch.full((n_samples,), label, dtype=torch.long, device=device)
    dt = 1.0 / steps

    for i in range(steps):
        t = torch.full((x.size(0),), i * dt, device=device)
        v = model(x, t, y)
        x = x + v * dt

    if was_training:
        model.train()

    # Undo the [-1, 1] scaling applied to the training data.
    img = (x + 1) / 2

    out_path = save_path or f"outputs/flow_matching/flow_gen_{label}.png"
    out_dir = os.path.dirname(out_path)
    if out_dir:
        # A custom save_path may point into a folder that does not exist yet.
        os.makedirs(out_dir, exist_ok=True)
    utils.save_image(img, out_path, nrow=8)

    if show:
        # matplotlib is not imported at file level; import it only when needed
        # so the show=True path no longer fails with a NameError on `plt`.
        import matplotlib.pyplot as plt

        plt.imshow(img[0].cpu().squeeze().numpy(), cmap='gray')
        plt.title(f'Generated {label}')
        plt.axis('off')
        plt.show()

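# Example usage (run from a separate cell):
#   train_flow()      # train the model and write checkpoints / sample grids
#   generate_flow(9)  # sample a grid of the digit 9 from the saved checkpoint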

In [5]:
# Start training with the default settings. This is the long-running cell of the
# notebook: it iterates over the full MNIST dataloader for many epochs and writes
# checkpoints and sample grids under outputs/flow_matching/.
train_flow()

                                                                                