In [1]:
# IPython `cd` magic: move one directory up so the relative paths used later
# ("data", "outputs/diffusion/images") resolve from the parent directory.
# NOTE(review): not idempotent — re-running this cell climbs another level.
cd ..

/home/va0831/Projects/FlowMatchingMnist


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


In [None]:
# conditional_mnist_diffusion_flow.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms, utils
from torch.utils.data import DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt
import os
import random

# --- Configuration ---
SEED = 0
# Preferred GPU index (4 on the original machine); override with the GPU_INDEX
# environment variable. If that index does not exist, fall back to cuda:0 rather
# than failing with "invalid device ordinal"; use the CPU when CUDA is unavailable.
GPU_INDEX = int(os.environ.get("GPU_INDEX", "4"))
if torch.cuda.is_available():
    device = f"cuda:{GPU_INDEX if GPU_INDEX < torch.cuda.device_count() else 0}"
else:
    device = 'cpu'
batch_size = 128
timesteps = 10  # number of discrete noise levels; sigma = t / timesteps for t in 1..timesteps
img_shape = (1, 28, 28)
os.makedirs("outputs/diffusion/images", exist_ok=True)

# Seed Python's `random` and torch (CPU and all CUDA devices) so that
# Restart & Run All reproduces the same initialization, noise and shuffling.
random.seed(SEED)
torch.manual_seed(SEED)

# --- Dataset ---
# ToTensor scales pixels to [0, 1]; Normalize(mean=0.5, std=0.5) then computes
# (x - 0.5) / 0.5, mapping them to [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# MNIST training split stored under ./data (downloaded if not already present).
dataset = datasets.MNIST(
    root='data',
    train=True,
    transform=transform,
    download=True,
)
dataloader = DataLoader(
    dataset,
    shuffle=True,
    batch_size=batch_size,
)

# --- Conditional Transformer denoiser model ---
class TransformerDenoiser(nn.Module):
    """Predict a clean 28x28 image from a noisy one, conditioned on noise level and label.

    Each image is flattened and projected to a single ``dim``-wide token. The label
    embedding and noise-level embedding are added to it, and the token goes through
    a ``TransformerEncoder`` as a sequence of length 1. Each sample is therefore
    processed independently; samples in a batch never attend to each other.

    Args:
        num_classes: number of label classes for the label embedding.
        dim: token / model width.
        depth: number of ``TransformerEncoderLayer`` layers.
        heads: attention heads per layer (must divide ``dim``).
    """

    def __init__(self, num_classes=10, dim=128, depth=4, heads=4):
        super().__init__()
        self.label_embedding = nn.Embedding(num_classes, dim)
        self.t_embed = nn.Linear(1, dim)
        self.input_proj = nn.Linear(28 * 28, dim)
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=dim, nhead=heads), num_layers=depth
        )
        self.output_proj = nn.Linear(dim, 28 * 28)

    def forward(self, x, noise_level, labels):
        """Return the model's estimate of the clean images.

        Args:
            x: batch of images with 28 * 28 elements per sample, e.g. (B, 1, 28, 28).
            noise_level: a single value shared by the whole batch (a float, or a
                tensor of shape (1,), (1, 1), (1, 1, 1, 1), ...), or one value per
                sample (B elements, e.g. (B,), (B, 1) or (B, 1, 1, 1)).
            labels: integer class labels, shape (B,).

        Returns:
            Tensor of shape (B, 1, 28, 28).

        Raises:
            ValueError: if ``noise_level`` has neither 1 nor B elements.
        """
        B = x.size(0)
        x_embed = self.input_proj(x.reshape(B, -1))  # (B, dim)

        # Normalize noise_level to (B, 1) so nn.Linear(1, dim) yields (B, dim).
        # Passing e.g. a (1, 1, 1, 1) tensor straight through would broadcast the
        # conditioning to 4-D, which the transformer cannot accept.
        t = torch.as_tensor(noise_level, dtype=x_embed.dtype, device=x.device).reshape(-1, 1)
        if t.size(0) == 1:
            t = t.expand(B, 1)
        elif t.size(0) != B:
            raise ValueError(
                f"noise_level must have 1 or {B} elements, got {t.size(0)}"
            )

        cond = self.label_embedding(labels) + self.t_embed(t)  # (B, dim)

        # (1, B, dim): sequence length 1, batch B, in the default batch_first=False
        # layout. A 2-D (B, dim) input would instead be treated as ONE unbatched
        # sequence of length B, letting different samples attend to each other.
        x_cond = (x_embed + cond).unsqueeze(0)
        transformed = self.transformer(x_cond).squeeze(0)  # (B, dim)
        out = self.output_proj(transformed)  # (B, 784)
        return out.view(B, 1, 28, 28)

# --- Training ---
model = TransformerDenoiser().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)  # halve LR every 10 epochs
criterion = nn.MSELoss()
epochs = 30

for epoch in range(epochs):
    total_loss = 0.0
    model.train()
    for images, labels in tqdm(dataloader, desc=f"Epoch {epoch + 1}/{epochs}", leave=False):
        images, labels = images.to(device), labels.to(device)
        # Draw an independent noise level sigma = t / timesteps (t in 1..timesteps)
        # for every sample, so each batch covers many noise levels rather than one.
        t = torch.randint(1, timesteps + 1, (images.size(0),), device=device)
        noise_level = (t.float() / timesteps).view(-1, 1, 1, 1)  # (B, 1, 1, 1)
        noisy_images = images + torch.randn_like(images) * noise_level

        optimizer.zero_grad()
        outputs = model(noisy_images, noise_level, labels)
        # The model is trained to predict the clean image directly.
        loss = criterion(outputs, images)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    scheduler.step()
    avg_loss = total_loss / len(dataloader)
    print(f"Epoch [{epoch + 1}/{epochs}], Loss: {avg_loss:.4f}")

# --- Generation ---
model.eval()
prompt = random.randint(0, 9)  # digit class to generate
with torch.no_grad():
    label = torch.tensor([prompt], device=device)
    # Start from pure Gaussian noise at the largest noise level (sigma = 1).
    x_t = torch.randn((1, *img_shape), device=device)

    fig, axes = plt.subplots(1, timesteps + 1, figsize=(15, 3))
    # Pixels were normalized to [-1, 1]; fix the colour range so every panel
    # uses the same scale and steps are visually comparable.
    axes[0].imshow(x_t.cpu().squeeze(), cmap='gray', vmin=-1, vmax=1)
    axes[0].set_title("Start (Noise)")
    axes[0].axis('off')

    for t in reversed(range(1, timesteps + 1)):
        step = timesteps - t + 1
        noise_level = torch.tensor([[[[t / timesteps]]]], device=device)
        # The model maps an input noised at sigma = t / timesteps to a clean-image estimate.
        denoised_img = model(x_t, noise_level, label)
        ax = axes[step]
        ax.imshow(denoised_img.cpu().squeeze(), cmap='gray', vmin=-1, vmax=1)
        ax.set_title(f"Step {step}\nDigit {prompt}")
        ax.axis('off')
        # Re-noise the estimate to the next (lower) level so the next input matches
        # what the model saw in training (clean image + noise at that level),
        # instead of feeding a clean estimate labelled as noisy.
        if t > 1:
            x_t = denoised_img + torch.randn_like(denoised_img) * ((t - 1) / timesteps)

fig.tight_layout()
fig.savefig(os.path.join("outputs/diffusion/images", f"generation_digit_{prompt}.png"))
plt.show()


In [11]:
# Debug check: shape of the `noise_level` tensor left in the kernel by the most
# recently executed cell that assigned it. Depends on prior kernel state.
noise_level.shape

torch.Size([1, 1, 1, 1])