In [None]:
import os
import torch
import math
import torch.nn as nn
import torchvision.transforms as transforms
import torch.nn.functional as F
import torchvision.utils as vutils
from torchvision.utils import save_image
import matplotlib.pyplot as plt
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import math
from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
from diffusers.models import AutoencoderKL
from torchvision.utils import make_grid, save_image


# Preferred GPU index for this notebook. The index is only used when the
# machine actually exposes that many GPUs; otherwise fall back to the first
# visible GPU, and to the CPU when CUDA is unavailable.
PREFERRED_GPU_INDEX = 3

if torch.cuda.is_available():
    _gpu_index = PREFERRED_GPU_INDEX if torch.cuda.device_count() > PREFERRED_GPU_INDEX else 0
    device = torch.device(f"cuda:{_gpu_index}")
else:
    device = torch.device("cpu")

# -----------------------------------
# 📁 Custom Dataset Loader
# -----------------------------------
class ImageFolderDataset(Dataset):
    """Dataset over the image files located directly inside one folder.

    Only files whose name ends in .jpg/.jpeg/.png (case-insensitive) are
    picked up; sub-directories are not traversed. Every item is returned as
    an RGB image, passed through ``transform`` when one is given.

    Args:
        root: Directory containing the image files.
        transform: Optional callable applied to each loaded PIL image.
    """

    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, root, transform=None):
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping reproducible across runs and machines.
        self.image_paths = [
            os.path.join(root, fname)
            for fname in sorted(os.listdir(root))
            if fname.lower().endswith(self.IMAGE_EXTENSIONS)
        ]
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # convert() returns a fully loaded copy, so the underlying file handle
        # can be closed as soon as the conversion is done.
        with Image.open(self.image_paths[idx]) as handle:
            img = handle.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img

# -----------------------------------
# 🔧 Transform & Load Images
# -----------------------------------
# Side length (in pixels) every image is resized to before encoding.
# Defined before the transform so the resize and this constant cannot drift apart.
image_size = 256

# Resize -> tensor in [0, 1] -> per-channel normalisation to [-1, 1].
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])

image_dir = '/path/to/dataset'
dataset = ImageFolderDataset(image_dir, transform=transform)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=2)


In [None]:
# Load the pretrained "stabilityai/sd-vae-ft-mse" autoencoder, move it to the
# working device and switch it to evaluation mode.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
vae = vae.to(device)
vae.eval()

In [None]:
def modulate(x, shift, scale):
    """Apply a per-sample affine modulation: ``x * (1 + scale) + shift``.

    ``shift`` and ``scale`` get a new axis inserted at position 1 so that
    they broadcast across dimension 1 of ``x``.
    """
    broadcast_scale = scale.unsqueeze(1)
    broadcast_shift = shift.unsqueeze(1)
    return x * (1 + broadcast_scale) + broadcast_shift

class TimestepEmbedder(nn.Module):
    """Embed scalar timesteps into ``hidden_size``-dimensional vectors.

    A sinusoidal feature vector of width ``freq_size`` is computed for every
    timestep and then passed through a two-layer MLP with a SiLU in between.

    Args:
        hidden_size: Width of the output embedding.
        freq_size: Width of the sinusoidal feature vector fed to the MLP.
    """

    def __init__(self, hidden_size, freq_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(freq_size, hidden_size), nn.SiLU(), nn.Linear(hidden_size, hidden_size)
        )
        # Remembered so forward() builds features that match the first Linear layer.
        self.freq_size = freq_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """Return sinusoidal features of shape ``(len(t), dim)`` for 1-D tensor ``t``.

        The first ``dim // 2`` columns are cosines and the next ``dim // 2``
        are sines; an extra zero column is appended when ``dim`` is odd.
        """
        half = dim // 2
        # Build the frequency table directly on t's device instead of on the
        # CPU followed by a transfer.
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(half, dtype=torch.float32, device=t.device) / half
        )
        args = t[:, None].float() * freqs[None]
        emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=-1)
        return emb

    def forward(self, t):
        # Use the configured feature width; a hard-coded 256 here would break
        # any instance built with a non-default freq_size.
        t_freq = self.timestep_embedding(t, self.freq_size)
        return self.mlp(t_freq)

class DiTBlock(nn.Module):
    """Transformer block whose attention and MLP branches are conditioned on ``c``.

    The conditioning vector is projected (SiLU + Linear) to six chunks of
    width ``hidden_size``: a shift, scale and gate for the attention branch
    and a shift, scale and gate for the MLP branch.
    """

    def __init__(self, hidden_size, num_heads):
        super().__init__()
        # Submodules are created in a fixed order so that seeded
        # initialisation and state_dict keys stay stable.
        self.norm1 = nn.LayerNorm(hidden_size)
        self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True)
        self.norm2 = nn.LayerNorm(hidden_size)
        self.mlp = Mlp(hidden_size, hidden_size * 4)
        self.adaLN_mod = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size),
        )

    def forward(self, x, c):
        conditioning = self.adaLN_mod(c).chunk(6, dim=1)
        attn_shift, attn_scale, attn_gate, mlp_shift, mlp_scale, mlp_gate = conditioning

        # Gated residual around the attention branch.
        attn_branch = self.attn(modulate(self.norm1(x), attn_shift, attn_scale))
        x = x + attn_gate.unsqueeze(1) * attn_branch

        # Gated residual around the MLP branch.
        mlp_branch = self.mlp(modulate(self.norm2(x), mlp_shift, mlp_scale))
        x = x + mlp_gate.unsqueeze(1) * mlp_branch
        return x

class FinalLayer(nn.Module):
    """Output head: conditioned LayerNorm followed by a linear projection.

    Each token is projected to ``patch_size * patch_size * out_channels``
    values. The conditioning vector ``c`` supplies a shift and a scale for
    the normalised tokens.
    """

    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        values_per_patch = patch_size * patch_size * out_channels
        # Creation order is kept as norm -> linear -> adaLN_mod so that seeded
        # initialisation and state_dict ordering are unaffected.
        self.norm = nn.LayerNorm(hidden_size)
        self.linear = nn.Linear(hidden_size, values_per_patch)
        self.adaLN_mod = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size),
        )

    def forward(self, x, c):
        shift_and_scale = self.adaLN_mod(c).chunk(2, dim=1)
        shift, scale = shift_and_scale
        normalised = self.norm(x)
        return self.linear(modulate(normalised, shift, scale))

class TransformerFlowLatent(nn.Module):
    """Patch-based transformer mapping a latent ``x`` and timestep ``t`` to a
    tensor with the same channel count and spatial size as ``x``.

    Pipeline: patch embedding + fixed positional table -> ``depth``
    timestep-conditioned blocks -> conditioned output head -> unpatchify.

    Args:
        latent_size: Side length of the (square) input latent.
        patch_size: Side length of each square patch.
        in_channels: Channel count of the input; the output uses the same count.
        hidden_size: Token width inside the transformer.
        depth: Number of transformer blocks.
        num_heads: Attention heads per block.
    """

    def __init__(self, latent_size=32, patch_size=2, in_channels=4, hidden_size=768, depth=12, num_heads=12):
        super().__init__()
        self.patch_size = patch_size
        self.img_size = latent_size
        self.in_channels = in_channels
        self.out_channels = in_channels

        self.embed = PatchEmbed(latent_size, patch_size, in_channels, hidden_size)
        self.t_embed = TimestepEmbedder(hidden_size)
        # Fixed, non-trainable positional table. Because the parameter is
        # frozen it has to be filled with a real encoding here: left at zero,
        # every patch would look identical to self-attention regardless of
        # where it sits in the latent grid.
        self.pos_embed = nn.Parameter(torch.zeros(1, self.embed.num_patches, hidden_size), requires_grad=False)
        pos_table = self.sincos_pos_embed(hidden_size, latent_size // patch_size)
        self.pos_embed.data.copy_(pos_table.unsqueeze(0))

        self.blocks = nn.ModuleList([DiTBlock(hidden_size, num_heads) for _ in range(depth)])
        self.final = FinalLayer(hidden_size, patch_size, self.out_channels)

    @staticmethod
    def sincos_pos_embed(hidden_size, grid_size, max_period=10000):
        """Build a 2D sin-cos positional table of shape ``(grid_size**2, hidden_size)``.

        Tokens are enumerated row-major (row index varies slowest). The first
        ``hidden_size // 2`` features encode the row coordinate and the
        remaining features encode the column coordinate.
        """
        def axis_embedding(positions, dim):
            # 1D encoding: [sin | cos], zero-padded by one column when dim is odd.
            half = dim // 2
            freqs = torch.exp(
                -math.log(max_period) * torch.arange(half, dtype=torch.float32) / max(half, 1)
            )
            args = positions[:, None] * freqs[None]
            emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
            if dim % 2:
                emb = torch.cat([emb, emb.new_zeros(emb.shape[0], 1)], dim=-1)
            return emb

        coords = torch.arange(grid_size, dtype=torch.float32)
        rows = coords.repeat_interleave(grid_size)  # 0,0,...,1,1,...
        cols = coords.repeat(grid_size)             # 0,1,...,0,1,...
        row_dim = hidden_size // 2
        col_dim = hidden_size - row_dim
        return torch.cat([axis_embedding(rows, row_dim), axis_embedding(cols, col_dim)], dim=-1)

    def unpatchify(self, x):
        """Rearrange ``(B, h*w, p*p*C)`` tokens into a ``(B, C, h*p, w*p)`` map."""
        B = x.shape[0]
        p = self.patch_size
        h = w = self.img_size // p
        C = self.out_channels
        # (B, h, w, p, p, C) -> (B, C, h, p, w, p): interleave patch rows/cols
        # with grid rows/cols before collapsing them.
        x = x.view(B, h, w, p, p, C).permute(0, 5, 1, 3, 2, 4)
        return x.reshape(B, C, h * p, w * p)

    def forward(self, x, t):
        x = self.embed(x) + self.pos_embed
        t_emb = self.t_embed(t)
        for blk in self.blocks:
            x = blk(x, t_emb)
        x = self.final(x, t_emb)
        return self.unpatchify(x)


In [None]:
# -----------------------------------
# Build the network with its default configuration and place it on the
# working device (Module.to moves the parameters in place).
model = TransformerFlowLatent()
model.to(device)

# AdamW with weight decay disabled.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, betas=(0.9, 0.999), weight_decay=0)


In [None]:
# -----------------------------------
# 🏋️ Training Loop
# -----------------------------------

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

num_epochs = 400
save_every_n_epochs = 10
num_samples_to_generate = 6
num_preview_steps = 50  # Euler steps used for the periodic preview grids

save_dir = "/path/to/generated_latent_outputs"
weights_dir = "/path/to/saved_weights"
os.makedirs(save_dir, exist_ok=True)
os.makedirs(weights_dir, exist_ok=True)

best_loss = float("inf")
epoch_losses = []


def euler_sample_latents(flow_model, num_samples, steps, sample_device, latent_shape=(4, 32, 32)):
    """Draw latents by integrating the learned velocity field from noise.

    The network is trained to predict ``latents1 - latents0`` at the
    interpolated point ``xt``, i.e. a velocity, not a latent. A sample is
    therefore obtained by starting from Gaussian noise at t=0 and taking
    ``steps`` forward-Euler steps of size ``1 / steps`` up to t=1.
    Call this inside ``torch.no_grad()``.
    """
    x = torch.randn(num_samples, *latent_shape, device=sample_device)
    dt = 1.0 / steps
    for k in range(steps):
        t_k = torch.full((num_samples,), k * dt, device=sample_device)
        x = x + dt * flow_model(x, t_k)
    return x


for epoch in range(num_epochs):
    model.train()
    pbar = tqdm(dataloader, desc=f"Epoch [{epoch+1}/{num_epochs}]")
    total_loss = 0

    for i, x1 in enumerate(pbar):
        x1 = x1.to(device)

        # The VAE is only used as a fixed encoder; no gradients are needed.
        with torch.no_grad():
            # 0.18215 is the latent scaling applied after encoding and undone
            # (division) before decoding.
            latents1 = vae.encode(x1).latent_dist.sample() * 0.18215
            latents0 = torch.randn_like(latents1)

        # Linear path between noise (t=0) and data (t=1); the regression
        # target is the constant velocity along that path.
        t = torch.rand(latents1.size(0), device=device)
        xt = (1 - t[:, None, None, None]) * latents0 + t[:, None, None, None] * latents1
        target = latents1 - latents0

        pred = model(xt, t)
        loss = F.mse_loss(pred, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        pbar.set_postfix(loss=loss.item())

    avg_loss = total_loss / len(dataloader)
    epoch_losses.append(avg_loss)
    print(f"✅ Epoch {epoch+1}: Avg Loss = {avg_loss:.4f}")

    # Save best weights
    if avg_loss < best_loss:
        best_loss = avg_loss
        torch.save(model.state_dict(), os.path.join(weights_dir, "best_model.pth"))
        print(f"💾 Best model saved at epoch {epoch+1} with loss {best_loss:.4f}")

    # Save grid of samples every N epochs
    if (epoch + 1) % save_every_n_epochs == 0:
        model.eval()  # model.train() at the top of the next epoch restores training mode
        with torch.no_grad():
            # Integrate the flow from noise instead of decoding a single
            # velocity prediction as if it were a latent.
            generated_latents = euler_sample_latents(
                model, num_samples_to_generate, num_preview_steps, device
            )
            decoded_imgs = vae.decode(generated_latents / 0.18215).sample
            decoded_imgs = decoded_imgs.clamp(-1, 1) * 0.5 + 0.5  # [0, 1] for saving
            grid = make_grid(decoded_imgs, nrow=4)  # change layout as needed
            save_image(grid, os.path.join(save_dir, f"epoch_{epoch+1:03d}.png"))
            print(f"🖼️ Saved image grid for epoch {epoch+1}")

# Save final model
torch.save(model.state_dict(), os.path.join(weights_dir, "final_model.pth"))
print("✅ Final model weights saved.")


In [None]:
save_dir_plot = "/pathto/plots"
# savefig does not create missing directories.
os.makedirs(save_dir_plot, exist_ok=True)

# Plot training loss
# The x-axis follows the number of epochs actually recorded, so the plot also
# works when training was stopped before num_epochs was reached.
recorded_epochs = range(1, len(epoch_losses) + 1)
plt.figure(figsize=(8, 5))
plt.plot(recorded_epochs, epoch_losses, marker='o', label="Training Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Training Loss Over Epochs")
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.savefig(os.path.join(save_dir_plot, "training_loss_curve.png"))
plt.close()
print("📈 Saved training loss curve to training_loss_curve.png")


In [None]:
# 🎨 Sampling from the Trained Model
# -----------------------------------
from torchvision.utils import save_image

# Evaluation mode only. Gradients are disabled with torch.no_grad() below
# rather than by freezing the parameters, so the training cell can still be
# re-run after this one.
model.eval()
vae.eval()

# -----------------------------------
# ⚙️ Flow Matching in Latent Space
# -----------------------------------
num_samples = 5
latent_size = 32  # VAE latent resolution for 256x256 images
xt = torch.randn(num_samples, 4, latent_size, latent_size).to(device)  # starting from latent noise
steps = 100
dt = 1.0 / steps

# Integrate through flow field with forward Euler.
# The velocity is evaluated at t_k = k * dt for k = 0 .. steps-1, so the time
# grid matches the step size and the steps cover [0, 1] exactly.
with torch.no_grad():
    for k in range(steps):
        t_vec = torch.full((num_samples,), k * dt, device=xt.device)
        xt = xt + dt * model(xt, t_vec)

# -----------------------------------
# 🔄 Decode Latents into RGB Images
# -----------------------------------
with torch.no_grad():
    decoded = vae.decode(xt / 0.18215).sample  # match SD's latent scale
    decoded = decoded.clamp(-1, 1) * 0.5 + 0.5  # [-1,1] → [0,1] for visualization

# -----------------------------------
# 💾 Save Output Images
# -----------------------------------
output_dir = "/path/to/Generated_Images"
os.makedirs(output_dir, exist_ok=True)

for i, img in enumerate(decoded):
    save_image(img, os.path.join(output_dir, f"image_{i:03d}.png"))

print(f"✅ Saved {len(decoded)} images to: {output_dir}")


In [None]:
# Show preview
import numpy as np

# `decoded` holds the images produced by the sampling cell, already in [0, 1].
preview_grid = vutils.make_grid(decoded.detach().cpu(), nrow=4)
# make_grid returns (C, H, W); imshow expects (H, W, C).
plt.imshow(np.transpose(preview_grid.numpy(), (1, 2, 0)))
plt.axis("off")
plt.title("Generated Samples")
plt.show()