In [1]:
# !pip install diffusers[training] datasets accelerate tensorboard
# !pip install torchvision

In [3]:
import torch
from datasets import load_dataset
from torchvision import transforms
from diffusers import DDPMScheduler, UNet2DModel, DDPMPipeline
from diffusers.optimization import get_cosine_schedule_with_warmup
from accelerate import Accelerator
from tqdm.auto import tqdm
from PIL import Image
import os

# Hyperparameters and paths for the whole run, gathered in one place
config = {
    "image_size": 128,  # side length (pixels) every image is resized to
    "train_batch_size": 16,
    "eval_batch_size": 16,  # images sampled per periodic evaluation
    "num_epochs": 50,
    "gradient_accumulation_steps": 1,
    "learning_rate": 1e-4,
    "lr_warmup_steps": 500,
    "save_image_epochs": 10,  # sample-grid cadence, in epochs
    "save_model_epochs": 10,  # checkpoint cadence, in epochs
    "mixed_precision": "fp16",
    "output_dir": "nwpu-resisc45-diffusion",
    "seed": 0,
}

# Fetch the training split of the dataset
print("Loading NWPU-RESISC45 dataset...")
dataset = load_dataset("timm/resisc45", split="train")

# Show the schema, then read the class names off the "label" feature
print(f"Dataset features: {dataset.features}")
label_names = dataset.features["label"].names
print(f"Available labels: {label_names}")

# Natural landscape / terrain classes to train on, given by name
selected_label_names = [
    "forest", "mountain", "meadow", "cloud", "lake",
    "river", "sea_ice", "desert", "island", "beach",
]

# Translate the chosen names into label indices, in the dataset's label order
selected_indices = [
    idx for idx in range(len(label_names)) if label_names[idx] in selected_label_names
]
print(f"Selected categories: {[label_names[i] for i in selected_indices]}")

# Filter for selected categories
def filter_categories(example):
    """Return True when the example's label index is one of ``selected_indices``.

    Serves as the predicate handed to ``dataset.filter``; ``example`` is a
    mapping that provides a ``"label"`` entry.
    """
    label = example["label"]
    return label in selected_indices

# Drop every example whose label is outside the selected categories
dataset = dataset.filter(filter_categories)
print(f"Filtered dataset size: {len(dataset)} images")

# Per-image preprocessing: resize to a square, random horizontal flip,
# convert to a tensor, then normalize with mean 0.5 and std 0.5
target_size = (config["image_size"], config["image_size"])
preprocess = transforms.Compose(
    [
        transforms.Resize(target_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)

def transform(examples):
    """Convert a batch of images to RGB and run each through ``preprocess``.

    ``examples`` is a mapping whose ``"image"`` entry is an iterable of images
    exposing ``convert``. The result is a dict with a single ``"images"`` key
    holding the preprocessed items in the same order.
    """
    processed = []
    for raw_image in examples["image"]:
        rgb_image = raw_image.convert("RGB")
        processed.append(preprocess(rgb_image))
    return {"images": processed}

# Apply preprocessing lazily, each time examples are read from the dataset
dataset.set_transform(transform)

# Seed the global torch RNG so weight initialisation, dataloader shuffling,
# random flips and noise sampling are repeatable across runs. Previously
# config["seed"] was only used for the sampling generator.
torch.manual_seed(config["seed"])

# Create dataloader
train_dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=config["train_batch_size"], shuffle=True
)

# Initialize model: a six-level UNet with attention in the fifth down block
# and the matching (second) up block
model = UNet2DModel(
    sample_size=config["image_size"],
    in_channels=3,
    out_channels=3,
    layers_per_block=2,
    block_out_channels=(128, 128, 256, 256, 512, 512),
    down_block_types=(
        "DownBlock2D",
        "DownBlock2D",
        "DownBlock2D",
        "DownBlock2D",
        "AttnDownBlock2D",
        "DownBlock2D",
    ),
    up_block_types=(
        "UpBlock2D",
        "AttnUpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
    ),
)

# Initialize noise scheduler
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)

# Initialize optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=config["learning_rate"])

# Learning rate scheduler: warmup followed by cosine decay over every
# batch of every epoch
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=config["lr_warmup_steps"],
    num_training_steps=(len(train_dataloader) * config["num_epochs"]),
)

# Initialize accelerator
accelerator = Accelerator(
    mixed_precision=config["mixed_precision"],
    gradient_accumulation_steps=config["gradient_accumulation_steps"],
)

# Let accelerate wrap the training objects before they are used in the loop
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    model, optimizer, train_dataloader, lr_scheduler
)

# Create output directory
os.makedirs(config["output_dir"], exist_ok=True)

# Training function
def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler):
    """Run the noise-prediction training loop, periodically saving samples and the pipeline.

    For each batch, Gaussian noise is added to the images at random timesteps
    and ``model`` is optimized with an MSE loss to predict that noise. On the
    main process, a 4x4 grid of generated images and/or the full pipeline are
    written under ``config["output_dir"]`` at the configured epoch cadences,
    and always after the final epoch.

    Relies on the module-level ``accelerator`` for the backward pass, gradient
    clipping and main-process checks.

    Args:
        config: Mapping providing "num_epochs", "save_image_epochs",
            "save_model_epochs", "eval_batch_size", "image_size", "seed" and
            "output_dir".
        model: Noise-prediction network, called as
            ``model(noisy_images, timesteps, return_dict=False)``.
        noise_scheduler: Provides ``config.num_train_timesteps`` and ``add_noise``.
        optimizer: Optimizer stepped after each batch's backward pass.
        train_dataloader: Sized iterable of batches exposing an ``"images"`` tensor.
        lr_scheduler: Learning-rate scheduler stepped together with the optimizer.
    """
    global_step = 0
    final_epoch = config["num_epochs"] - 1

    for epoch in range(config["num_epochs"]):
        progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
        progress_bar.set_description(f"Epoch {epoch}")

        for batch in train_dataloader:
            clean_images = batch["images"]

            # Sample noise directly on the images' device (no host-to-device copy)
            noise = torch.randn(clean_images.shape, device=clean_images.device)
            bs = clean_images.shape[0]

            # Sample one random timestep per image
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (bs,), device=clean_images.device
            ).long()

            # Add noise to images
            noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)

            with accelerator.accumulate(model):
                # Predict noise
                noise_pred = model(noisy_images, timesteps, return_dict=False)[0]
                loss = torch.nn.functional.mse_loss(noise_pred, noise)
                accelerator.backward(loss)

                # Clip only when gradients are actually synchronized, so that
                # partially accumulated gradients are not clipped
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            progress_bar.update(1)
            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
            progress_bar.set_postfix(**logs)
            global_step += 1

        # Finish this epoch's bar so stale bars do not pile up
        progress_bar.close()

        if accelerator.is_main_process:
            is_final_epoch = epoch == final_epoch
            save_images = (epoch + 1) % config["save_image_epochs"] == 0 or is_final_epoch
            save_model = (epoch + 1) % config["save_model_epochs"] == 0 or is_final_epoch

            # Build the pipeline whenever either action needs it. Previously it
            # was created only in the image branch, so a checkpoint epoch that
            # was not also an image epoch raised UnboundLocalError.
            if save_images or save_model:
                pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)

            if save_images:
                # Dedicated generator so sampling is seeded independently of training
                generator = torch.Generator(device=pipeline.device).manual_seed(config["seed"])
                images = pipeline(
                    generator=generator,
                    batch_size=config["eval_batch_size"],
                    num_inference_steps=1000,
                ).images

                # Paste up to 16 samples into a 4x4 grid
                image_grid = Image.new("RGB", (config["image_size"] * 4, config["image_size"] * 4))
                for i, image in enumerate(images[:16]):
                    row = i // 4
                    col = i % 4
                    image_grid.paste(image, (col * config["image_size"], row * config["image_size"]))

                image_grid.save(f"{config['output_dir']}/samples_epoch_{epoch}.png")

            if save_model:
                pipeline.save_pretrained(config["output_dir"])

# Kick off the training run
print("Starting training...")
train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler)

# Write the trained pipeline to disk (main process only)
if accelerator.is_main_process:
    pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)
    pipeline.save_pretrained(config["output_dir"])
    print(f"Model saved to {config['output_dir']}")

# Reload the saved pipeline and sample from it
print("\nGenerating sample images...")
pipeline = DDPMPipeline.from_pretrained(config["output_dir"])
if torch.cuda.is_available():
    pipeline.to("cuda")
else:
    pipeline.to("cpu")

generator = torch.Generator(device=pipeline.device).manual_seed(42)
images = pipeline(generator=generator, batch_size=16, num_inference_steps=1000).images

# Lay the first 16 samples out as a 4x4 grid
tile = config["image_size"]
final_grid = Image.new("RGB", (tile * 4, tile * 4))
for i, image in enumerate(images[:16]):
    row, col = divmod(i, 4)
    final_grid.paste(image, (col * tile, row * tile))

final_grid.save(f"{config['output_dir']}/final_samples.png")
print(f"Final samples saved to {config['output_dir']}/final_samples.png")

Loading NWPU-RESISC45 dataset...


README.md: 0.00B [00:00, ?B/s]

data/train-00000-of-00001.parquet:   0%|          | 0.00/255M [00:00<?, ?B/s]

data/validation-00000-of-00001.parquet:   0%|          | 0.00/85.1M [00:00<?, ?B/s]

data/test-00000-of-00001.parquet:   0%|          | 0.00/85.2M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/18900 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/6300 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/6300 [00:00<?, ? examples/s]

Dataset features: {'image': Image(mode=None, decode=True), 'label': ClassLabel(names=['airplane', 'airport', 'baseball_diamond', 'basketball_court', 'beach', 'bridge', 'chaparral', 'church', 'circular_farmland', 'cloud', 'commercial_area', 'dense_residential', 'desert', 'forest', 'freeway', 'golf_course', 'ground_track_field', 'harbor', 'industrial_area', 'intersection', 'island', 'lake', 'meadow', 'medium_residential', 'mobile_home_park', 'mountain', 'overpass', 'palace', 'parking_lot', 'railway', 'railway_station', 'rectangular_farmland', 'river', 'roundabout', 'runway', 'sea_ice', 'ship', 'snowberg', 'sparse_residential', 'stadium', 'storage_tank', 'tennis_court', 'terrace', 'thermal_power_station', 'wetland']), 'image_id': Value('string')}
Available labels: ['airplane', 'airport', 'baseball_diamond', 'basketball_court', 'beach', 'bridge', 'chaparral', 'church', 'circular_farmland', 'cloud', 'commercial_area', 'dense_residential', 'desert', 'forest', 'freeway', 'golf_course', 'groun

Filter:   0%|          | 0/18900 [00:00<?, ? examples/s]

Filtered dataset size: 4155 images
Starting training...


  0%|          | 0/260 [00:00<?, ?it/s]

  0%|          | 0/260 [00:00<?, ?it/s]

  0%|          | 0/260 [00:00<?, ?it/s]

  0%|          | 0/260 [00:00<?, ?it/s]

  0%|          | 0/260 [00:00<?, ?it/s]

  0%|          | 0/260 [00:00<?, ?it/s]

  0%|          | 0/260 [00:00<?, ?it/s]

  0%|          | 0/260 [00:00<?, ?it/s]

  0%|          | 0/260 [00:00<?, ?it/s]

  0%|          | 0/260 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/260 [00:00<?, ?it/s]

  0%|          | 0/260 [00:00<?, ?it/s]

  0%|          | 0/260 [00:00<?, ?it/s]

  0%|          | 0/260 [00:00<?, ?it/s]

  0%|          | 0/260 [00:00<?, ?it/s]

  0%|          | 0/260 [00:00<?, ?it/s]

  0%|          | 0/260 [00:00<?, ?it/s]

  0%|          | 0/260 [00:00<?, ?it/s]

  0%|          | 0/260 [00:00<?, ?it/s]

  0%|          | 0/260 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/260 [00:00<?, ?it/s]

  0%|          | 0/260 [00:00<?, ?it/s]

  0%|          | 0/260 [00:00<?, ?it/s]

  0%|          | 0/260 [00:00<?, ?it/s]

  0%|          | 0/260 [00:00<?, ?it/s]

  0%|          | 0/260 [00:00<?, ?it/s]

  0%|          | 0/260 [00:00<?, ?it/s]

  0%|          | 0/260 [00:00<?, ?it/s]

  0%|          | 0/260 [00:00<?, ?it/s]

  0%|          | 0/260 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/260 [00:00<?, ?it/s]

  0%|          | 0/260 [00:00<?, ?it/s]

  0%|          | 0/260 [00:00<?, ?it/s]

  0%|          | 0/260 [00:00<?, ?it/s]

  0%|          | 0/260 [00:00<?, ?it/s]

  0%|          | 0/260 [00:00<?, ?it/s]

  0%|          | 0/260 [00:00<?, ?it/s]

  0%|          | 0/260 [00:00<?, ?it/s]

  0%|          | 0/260 [00:00<?, ?it/s]

  0%|          | 0/260 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/260 [00:00<?, ?it/s]

  0%|          | 0/260 [00:00<?, ?it/s]

  0%|          | 0/260 [00:00<?, ?it/s]

  0%|          | 0/260 [00:00<?, ?it/s]

  0%|          | 0/260 [00:00<?, ?it/s]

  0%|          | 0/260 [00:00<?, ?it/s]

  0%|          | 0/260 [00:00<?, ?it/s]

  0%|          | 0/260 [00:00<?, ?it/s]

  0%|          | 0/260 [00:00<?, ?it/s]

  0%|          | 0/260 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

Model saved to nwpu-resisc45-diffusion

Generating sample images...


Loading pipeline components...:   0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

Final samples saved to nwpu-resisc45-diffusion/final_samples.png
