In [None]:
import os
import torch
import math
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.utils as vutils
import matplotlib.pyplot as plt
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# Preferred CUDA device index. The original notebook hard-coded "cuda:1", which raises
# "invalid device ordinal" on single-GPU machines (e.g. Colab); fall back to GPU 0 there,
# and to the CPU when CUDA is unavailable.
GPU_INDEX = 1
if torch.cuda.is_available():
    device = torch.device(f"cuda:{GPU_INDEX if GPU_INDEX < torch.cuda.device_count() else 0}")
else:
    device = torch.device("cpu")

# -----------------------------------
# 📁 Custom Dataset Loader
# -----------------------------------
class ImageFolderDataset(Dataset):
    """Unlabelled image dataset over the files directly inside ``root``.

    Only regular files whose names end in .jpg/.jpeg/.png (case-insensitive) are
    used; subdirectories are not searched. Paths are sorted so the index -> file
    mapping is stable across runs (``os.listdir`` returns entries in arbitrary order).

    Args:
        root: directory containing the images.
        transform: optional callable applied to each RGB PIL image; when ``None``
            the PIL image itself is returned.
    """

    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, root, transform=None):
        self.image_paths = sorted(
            os.path.join(root, fname)
            for fname in os.listdir(root)
            if fname.lower().endswith(self.IMAGE_EXTENSIONS)
            # Skip e.g. a directory named "foo.png", which would crash in __getitem__.
            and os.path.isfile(os.path.join(root, fname))
        )
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and apply ``transform`` if one was given."""
        # Open inside a context manager so the file handle is released right away
        # (the original leaked one handle per sample); convert() returns a converted
        # copy, so it stays valid after the file is closed.
        with Image.open(self.image_paths[idx]) as raw:
            img = raw.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img

# -----------------------------------
# 🔧 Transform & Load Images
# -----------------------------------
transform = transforms.Compose(
    [
        # Fixed spatial size: the MLP's default width 196608 equals 3 * 256 * 256.
        transforms.Resize((256, 256)),
        # PIL image -> float tensor in [0, 1], channels first.
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 rescales [0, 1] to [-1, 1]; the sampling cell undoes this.
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

# Placeholder: point this at the folder of training images (e.g. an upload location in Colab).
image_dir = 'path/to/dataset'
dataset = ImageFolderDataset(image_dir, transform=transform)
dataloader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
    num_workers=2,
)


In [None]:
# MLP model: predicts the flow velocity (x1 - x0) from a flattened noisy image xt and its time t
class Block(nn.Module):
    """One hidden layer of the MLP: a square linear map followed by a ReLU.

    Args:
        channels: input and output width of the layer.
    """

    def __init__(self, channels):
        super().__init__()
        # Submodule names `ff` / `act` are part of the state_dict keys; keep them.
        self.ff = nn.Linear(channels, channels)
        self.act = nn.ReLU()

    def forward(self, x):
        hidden = self.ff(x)
        activated = self.act(hidden)
        return activated

class MLP(nn.Module):
    """Fully-connected network conditioned on a scalar time per sample.

    ``forward(x, t)`` projects the flattened input to ``channels``, adds a projected
    sinusoidal embedding of ``t``, runs ``layers`` hidden ``Block``s and projects back
    to ``channels_data``, so the output has the same width as the input.

    Args:
        channels_data: width of the flattened input/output. The default 196608 equals
            3 * 256 * 256 (the 256x256 RGB images used in this notebook).
        layers: number of hidden ``Block``s.
        channels: hidden width.
        channels_t: width of the sinusoidal time embedding.
    """

    def __init__(self, channels_data=196608, layers=5, channels=1024, channels_t=512):
        super().__init__()
        self.channels_t = channels_t
        self.in_projection = nn.Linear(channels_data, channels)
        self.t_projection = nn.Linear(channels_t, channels)
        self.blocks = nn.Sequential(*[Block(channels) for _ in range(layers)])
        self.out_projection = nn.Linear(channels, channels_data)

    def gen_t_embedding(self, t, max_positions=10000):
        """Sinusoidal embedding of ``t``.

        Args:
            t: time values; a 0-d tensor, a ``(B,)`` tensor or a ``(B, 1)`` tensor
                (flattened to ``(B,)``). ``t`` is scaled by ``max_positions`` before use.
            max_positions: scale applied to ``t`` and base of the geometric frequency range.

        Returns:
            Tensor of shape ``(B, channels_t)``: ``[sin, cos]`` halves, zero-padded by one
            column when ``channels_t`` is odd.
        """
        # Flatten so scalar and column-vector times also yield a (B, channels_t) result.
        t = t.reshape(-1) * max_positions
        half_dim = self.channels_t // 2
        # Guard half_dim == 1 (channels_t of 2 or 3), which previously divided by zero;
        # with a single frequency the exponent is 0 either way.
        emb = math.log(max_positions) / max(half_dim - 1, 1)
        emb = torch.arange(half_dim, device=t.device).float().mul(-emb).exp()
        emb = t[:, None] * emb[None, :]
        emb = torch.cat([emb.sin(), emb.cos()], dim=1)
        if self.channels_t % 2 == 1:
            emb = nn.functional.pad(emb, (0, 1), mode='constant')
        return emb

    def forward(self, x, t):
        """Map flattened inputs ``x`` (assumed ``(B, channels_data)``) and times ``t`` to ``(B, channels_data)``."""
        x = self.in_projection(x)
        t = self.gen_t_embedding(t)
        t = self.t_projection(t)
        x = x + t
        x = self.blocks(x)
        x = self.out_projection(x)
        return x

In [None]:
# -----------------------------------
# Seed torch's RNGs (CPU and CUDA) so weight init, and the random draws made after this
# cell in the same run order, are reproducible.
SEED = 0
torch.manual_seed(SEED)

model = MLP().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)



In [None]:
# -----------------------------------
# 🏋️ Training Loop
# -----------------------------------
num_epochs = 5
losses = []

# The sampling cell may have put the model in eval mode and frozen its parameters
# (requires_grad_(False)); without this, re-running training afterwards fails in
# loss.backward() because nothing requires grad.
model.train().requires_grad_(True)

for epoch in range(num_epochs):
    loop = tqdm(dataloader, desc=f"Epoch [{epoch+1}/{num_epochs}]")
    for batch in loop:
        x1 = batch.to(device).flatten(start_dim=1)  # data samples, flattened per image
        x0 = torch.randn_like(x1)                   # noise; randn_like already uses x1's device
        target = x1 - x0                            # velocity of the straight path x0 -> x1

        # Random time per sample and the point at that time on the straight path.
        t = torch.rand(x1.size(0), device=device)
        t_col = t[:, None]
        xt = (1 - t_col) * x0 + t_col * x1

        pred = model(xt, t)
        loss = nn.functional.mse_loss(pred, target)  # mean squared error, as before

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        loss_value = loss.item()  # one host sync per step instead of two
        loop.set_postfix(loss=loss_value)
        losses.append(loss_value)

In [None]:
# Plot the per-step training loss collected by the training loop.
fig, ax = plt.subplots()
ax.plot(losses)
ax.set(title="Training Loss", xlabel="Step", ylabel="MSE Loss")
plt.show()

In [None]:
# 🎨 Sampling from the Trained Model
# -----------------------------------
from torchvision.utils import save_image

# Eval mode for inference. Gradients are disabled with torch.no_grad() below rather than
# by freezing the parameters, so the training cell can still be re-run afterwards.
model.eval()

# Starting noise: one flattened 3x256x256 image per sample (3 * 256 * 256 = 196608).
num_samples = 5
xt = torch.randn(num_samples, 3 * 256 * 256, device=device)
steps = 1000
dt = 1 / steps

# Euler integration of dx/dt = model(x, t) from t = 0 (noise) to t = 1 (data).
# Step k evaluates the velocity at the start of its interval, t = k * dt, so the time
# grid matches the step size. The previous torch.linspace(0, 1, steps) grid was spaced
# 1 / (steps - 1) apart while stepping by 1 / steps.
with torch.no_grad():
    for k in range(steps):
        t_vec = torch.full((num_samples,), k * dt, device=device)
        xt = xt + dt * model(xt, t_vec)

# Reshape and unnormalize: invert Normalize(0.5, 0.5), mapping [-1, 1] back to [0, 1].
samples = xt.view(-1, 3, 256, 256)
samples = samples.clamp(-1, 1) * 0.5 + 0.5  # [0, 1]

# Folder to save (placeholder path; set before running)
output_dir = "/path/to/save/dir"
os.makedirs(output_dir, exist_ok=True)

# Save each image separately
for i, img in enumerate(samples):
    save_image(img, os.path.join(output_dir, f"image_{i:03d}.png"))

print(f"✅ Saved {len(samples)} images to: {output_dir}")

In [None]:
# Show a grid preview of the generated samples (4 images per row).
import numpy as np

grid = vutils.make_grid(samples.cpu(), nrow=4)
grid_hwc = np.transpose(grid, (1, 2, 0))  # CHW -> HWC for imshow
fig, ax = plt.subplots()
ax.imshow(grid_hwc)
ax.axis("off")
ax.set_title("Generated Samples")
plt.show()