In [1]:
import os
import torch
import math
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.utils as vutils
import matplotlib.pyplot as plt
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# Pick the compute device. GPU index 7 is the preferred card, but hard-coding
# it yields an invalid device on hosts with fewer than 8 GPUs, so fall back to
# GPU 0 when index 7 does not exist and to the CPU when CUDA is unavailable.
_PREFERRED_GPU_INDEX = 7
if torch.cuda.is_available():
    _gpu_index = _PREFERRED_GPU_INDEX if torch.cuda.device_count() > _PREFERRED_GPU_INDEX else 0
    device = torch.device(f"cuda:{_gpu_index}")
else:
    device = torch.device("cpu")

# -----------------------------------
# 📁 Custom Dataset Loader
# -----------------------------------
class ImageFolderDataset(Dataset):
    """Dataset over every JPEG/PNG file found directly inside one folder.

    The folder is listed non-recursively; files are kept when their
    lower-cased name ends with one of ``IMAGE_EXTENSIONS``.

    Args:
        root: Directory to list for image files.
        transform: Optional callable applied to each RGB PIL image before it
            is returned from ``__getitem__``.

    Raises:
        OSError: Propagated from ``os.listdir`` when ``root`` cannot be listed.
    """

    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, root, transform=None):
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping stable across runs and machines.
        self.image_paths = [
            os.path.join(root, fname)
            for fname in sorted(os.listdir(root))
            if fname.lower().endswith(self.IMAGE_EXTENSIONS)
        ]
        self.transform = transform

    def __len__(self):
        """Return the number of image files discovered in the folder."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and return it, transformed if requested."""
        # The context manager closes the underlying file as soon as the pixel
        # data has been converted, instead of leaving the handle open.
        with Image.open(self.image_paths[idx]) as raw_image:
            img = raw_image.convert('RGB')
        # Explicit None check: a transform object that happens to be falsy
        # (e.g. an empty container) must still be applied.
        if self.transform is not None:
            img = self.transform(img)
        return img

# -----------------------------------
# 🔧 Transform & Load Images
# -----------------------------------
# Side length (pixels) every image is resized to. Declared before the
# transform so that Resize uses it instead of repeating the literal 256.
image_size = 256

# Resize -> tensor -> per-channel (x - 0.5) / 0.5 normalisation.
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])


# Machine-specific absolute path to the training images; change it to the
# location of the dataset on the host running this notebook.
image_dir = '/aul/homes/amaha038/Mapsgeneration/TerraFlySat_and_MapDatatset/TerraFly_Full_Satellite_Dataset/Philadelphia_Washington_Newyork_Train'
dataset = ImageFolderDataset(image_dir, transform=transform)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=2)


In [2]:
from diffusers import UNet2DModel
import torch.nn.functional as F
class UNetFlowModel(nn.Module):
    """Thin wrapper around a ``UNet2DModel`` that takes a fractional time ``t``.

    Args:
        image_size: Forwarded to the UNet as ``sample_size``.
    """

    def __init__(self, image_size=256):
        super().__init__()
        # 3 channels in and out; time is supplied through the UNet's own
        # ``timestep`` argument rather than as an extra input channel.
        unet_config = dict(
            sample_size=image_size,
            in_channels=3,
            out_channels=3,
            layers_per_block=2,
            block_out_channels=(64, 128, 256, 512),
        )
        self.unet = UNet2DModel(**unet_config)

    def forward(self, x, t):
        """Run the UNet on ``x`` at time ``t`` and return its ``.sample`` output.

        ``t`` is multiplied by 999 and truncated to an integer tensor before
        being handed to the UNet as ``timestep``.
        """
        timestep_index = (t * 999).long()
        unet_output = self.unet(x, timestep=timestep_index)
        return unet_output.sample

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# -----------------------------------
# Build the flow model with its default arguments, move it to the selected
# device, and optimise all of its parameters with AdamW.
model = UNetFlowModel()
model = model.to(device)
optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-4)



In [None]:
# -----------------------------------
# 🏋️ Training Loop
# -----------------------------------
loader = dataloader
num_epochs = 5
save_interval = 1  # NOTE(review): not read by this loop; no checkpoint is written here
batch_size = 16    # NOTE(review): not read by this loop; the batch size comes from `loader`
image_size = 256
# One average loss per completed epoch. This list must be named `losses`:
# it is appended to at the end of every epoch and read by the plotting cell.
# (It was previously created under the name `epoch_loss`, which left `losses`
# undefined and raised NameError after the first full epoch.)
losses = []


for epoch in range(num_epochs):
    pbar = tqdm(loader, desc=f"Epoch [{epoch+1}/{num_epochs}]")
    model.train()
    epoch_loss = 0.0  # running sum of the batch losses of this epoch

    for x1 in pbar:
        x1 = x1.to(device)                          # data batch
        x0 = torch.randn_like(x1)                   # Gaussian noise, same shape as the data
        t = torch.rand(x1.size(0), device=device)   # one uniform time in [0, 1) per sample

        # Reshape t to (B, 1, 1, 1) once so it broadcasts over the image dims.
        t_b = t[:, None, None, None]
        xt = (1 - t_b) * x0 + t_b * x1              # straight-line interpolation noise -> data
        target = x1 - x0                            # derivative of xt with respect to t

        pred = model(xt, t)
        loss = F.mse_loss(pred, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        epoch_loss += batch_loss
        pbar.set_postfix(batch_loss=batch_loss)

    avg_epoch_loss = epoch_loss / len(loader)
    print(f"✅ Epoch {epoch+1}/{num_epochs} - Average Loss: {avg_epoch_loss:.4f}")
    losses.append(avg_epoch_loss)


Epoch [1/5]:   1%|▊                                                                                                                   | 4/594 [01:01<2:26:51, 14.93s/it, batch_loss=0.912]

In [None]:
# Plot loss
# `losses` receives one average value per epoch, so the x-axis counts epochs
# (starting at 1), not optimisation steps. Markers keep the plot readable
# when only one or two epochs have been recorded.
epoch_numbers = range(1, len(losses) + 1)
plt.plot(epoch_numbers, losses, marker="o")
plt.title("Training Loss")
plt.xlabel("Epoch")
plt.ylabel("MSE Loss")
plt.show()

In [None]:
# 🎨 Sampling from the Trained Model
# -----------------------------------
from torchvision.utils import save_image

# Switch the model to eval mode. Gradients are disabled with torch.no_grad()
# around the sampling loop instead of model.requires_grad_(False), so the
# parameters stay trainable if the training cell is run again afterwards.
model.eval()

# Generating starting noise. It must be image-shaped (N, C, H, W): the model
# is a 2-D UNet with 3 input channels and cannot take a flattened
# (N, 196608) tensor.
num_samples = 5
sample_shape = (3, 256, 256)
xt = torch.randn(num_samples, *sample_shape, device=device)
steps = 1000
dt = 1.0 / steps

# Flow sampling: forward Euler from t = 0 to t = 1. The model is evaluated at
# the left edge of each interval (t_i = i * dt), which matches the step size
# dt; torch.linspace(0, 1, steps) would use a spacing of 1 / (steps - 1) and
# evaluate the model at t = 1 as well.
with torch.no_grad():
    for i in range(steps):
        t_vec = torch.full((num_samples,), i * dt, device=device)
        xt = xt + dt * model(xt, t_vec)

# Unnormalize: clamp to [-1, 1], then map to [0, 1]
samples = xt.view(-1, *sample_shape)
samples = samples.clamp(-1, 1) * 0.5 + 0.5  # [0, 1]

# Folder to save (placeholder path: set it to a writable directory before running)
output_dir = "/path/to/save/dir"
os.makedirs(output_dir, exist_ok=True)

# Saving each image separately
for i, img in enumerate(samples):
    save_image(img, os.path.join(output_dir, f"image_{i:03d}.png"))

print(f"✅ Saved {len(samples)} images to: {output_dir}")

In [None]:
# Show preview
import numpy as np

# Tile the samples into a single grid image, four per row.
preview_grid = vutils.make_grid(samples.cpu(), nrow=4)
# Move axis 0 to the end (presumably channels-first -> channels-last for imshow).
preview_image = np.transpose(preview_grid, (1, 2, 0))

plt.imshow(preview_image)
plt.title("Generated Samples")
plt.axis("off")
plt.show()