In [1]:
!pip install diffusers[torch]

Collecting diffusers[torch]
  Downloading diffusers-0.31.0-py3-none-any.whl.metadata (18 kB)
Collecting huggingface-hub>=0.23.2 (from diffusers[torch])
  Downloading huggingface_hub-0.26.2-py3-none-any.whl.metadata (13 kB)
Collecting regex!=2019.12.17 (from diffusers[torch])
  Downloading regex-2024.11.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (40 kB)
Collecting safetensors>=0.3.1 (from diffusers[torch])
  Downloading safetensors-0.4.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.8 kB)
Collecting accelerate>=0.31.0 (from diffusers[torch])
  Downloading accelerate-1.1.1-py3-none-any.whl.metadata (19 kB)
Downloading accelerate-1.1.1-py3-none-any.whl (333 kB)
Downloading huggingface_hub-0.26.2-py3-none-any.whl (447 kB)
Downloading regex-2024.11.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (781 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m781.7/781.7 kB[0m [31m74.2 MB/s[0m eta [36m0:00:00[0m
[?

In [17]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision
from diffusers import UNet2DModel, DDPMScheduler
from PIL import Image
import os
from tqdm.auto import tqdm
import numpy as np

In [18]:
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

In [19]:
!nvidia-smi

Sun Nov 24 20:29:53 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.127.05             Driver Version: 550.127.05     CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla V100-SXM2-16GB           On  |   00000000:00:1E.0 Off |                    0 |
| N/A   33C    P0             48W /  300W |    7083MiB /  16384MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [20]:
def get_device():
    """
    Get the appropriate device (CUDA, MPS, or CPU)
    """
    if torch.cuda.is_available():
        return "cuda"
    elif torch.backends.mps.is_available():
        return "mps"
    return "cpu"


class PianoRollDataset(Dataset):
    def __init__(self, image_dir):
        self.image_dir = image_dir
        self.image_files = [f for f in os.listdir(image_dir) if f.endswith('.jpg')]
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ])

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        image_path = os.path.join(self.image_dir, self.image_files[idx])
        # Convert grayscale to RGB while maintaining binary values
        gray_image = Image.open(image_path).convert('L')
        # Convert to binary image first (0 or 255)
        binary_image = gray_image.point(lambda x: 0 if x < 128 else 255, '1')
        # Convert to RGB
        rgb_image = binary_image.convert('RGB')
        image = self.transform(rgb_image)
        return image

In [24]:
def save_images(images, path, step):
    """Save a batch of images during training for monitoring."""
    images = (images / 2 + 0.5).clamp(0, 1)
    # Convert to binary
    images = (images > 0.5).float()
    grid = torchvision.utils.make_grid(images)
    # Convert to PIL image
    grid_image = torchvision.transforms.ToPILImage()(grid)
    os.makedirs(path, exist_ok=True)
    grid_image.save(f"{path}/sample_{step}.png")


def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, device):
    progress_bar = tqdm(total=config["num_epochs"] * len(train_dataloader))
    global_step = 0

    for epoch in range(config["num_epochs"]):
        model.train()
        for batch in train_dataloader:
            clean_images = batch.to(device)
            batch_size = clean_images.shape[0]

            # Sample noise and add to images
            noise = torch.randn(clean_images.shape).to(device)
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (batch_size,),
                device=device
            ).long()
            noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)

            # Get model prediction
            noise_pred = model(noisy_images, timesteps, return_dict=False)[0]

            # Calculate loss
            loss = F.mse_loss(noise_pred, noise)
            print(f"loss:{loss}")

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            progress_bar.update(1)
            global_step += 1

            # Save sample images periodically
            if global_step % config["sample_interval"] == 0:
                model.eval()
                with torch.no_grad():
                    # Generate sample images
                    sample = torch.randn(8, 3, config["image_height"], config["image_width"]).to(device)
                    timesteps = torch.linspace(999, 0, 50).long().to(device)
                    for t in timesteps:
                        residual = model(sample, t.repeat(8), return_dict=False)[0]
                        sample = noise_scheduler.step(residual, t, sample).prev_sample
                save_images(sample, config["sample_dir"], global_step)
                model.train()

            if global_step % config["save_interval"] == 0:
                # Save checkpoint
                torch.save({
                    'step': global_step,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': loss,
                }, f"checkpoint_{global_step}.pt")

        # Save model after each epoch
        torch.save(model.state_dict(), f"model_epoch_{epoch}.pt")

In [26]:
def main():
    # Configuration
    config = {
        "image_height": 768,
        "image_width": 512,
        "batch_size": 2,
        "num_epochs": 1,
        "learning_rate": 1e-4,
        "save_interval": 100,
        "sample_interval": 1000,  # Interval for generating sample images
        "data_dir": "piano_roll_images",  # Your image directory
        "sample_dir": "samples"  # Directory to save generated samples
    }

    # Initialize device
    device = torch.device("mps" if torch.backends.mps.is_available() else "cuda")
    print(f"Using device: {device}")

    # Create dataset and dataloader
    dataset = PianoRollDataset(config["data_dir"])
    dataloader = DataLoader(
        dataset,
        batch_size=config["batch_size"],
        shuffle=True,
        num_workers=4
    )

    # Initialize model with 3 input/output channels for RGB
    model = UNet2DModel(
        sample_size=(config["image_height"], config["image_width"]),
        in_channels=3,  # RGB input
        out_channels=3,  # RGB output
        layers_per_block=1,
        block_out_channels=(32, 64, 128),  # Further reduced channels
        down_block_types=(
            "DownBlock2D",
            "DownBlock2D", 
            "DownBlock2D",
        ),
        up_block_types=(
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
        ),
    ).to(device)

    # Initialize noise scheduler
    noise_scheduler = DDPMScheduler(
        num_train_timesteps=1000,
        beta_start=0.0001,
        beta_end=0.02,
        beta_schedule="linear"
    )

    # Initialize optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=config["learning_rate"])

    # Train model
    train_loop(config, model, noise_scheduler, optimizer, dataloader, device)

In [27]:
 main()

Using device: cuda


  0%|          | 0/8510 [00:00<?, ?it/s]

loss:1.1097028255462646
loss:1.0769217014312744
loss:1.031741738319397
loss:1.0278807878494263
loss:0.9642127156257629
loss:0.9427040815353394
loss:0.9079720377922058
loss:0.8790071606636047
loss:0.8823263049125671
loss:0.8238559365272522
loss:0.8435823917388916
loss:0.7746265530586243
loss:0.7508037090301514
loss:0.7509687542915344
loss:0.7120912671089172
loss:0.7170543670654297
loss:0.6677703261375427
loss:0.6417092084884644
loss:0.6230636239051819
loss:0.6051941514015198
loss:0.5898979306221008
loss:0.5724374055862427
loss:0.5518147945404053
loss:0.5372375845909119
loss:0.5228298306465149
loss:0.5078498125076294
loss:0.5415394902229309
loss:0.49455907940864563
loss:0.607220470905304
loss:0.45511844754219055
loss:0.4498027265071869
loss:0.4321652054786682
loss:0.5000439882278442
loss:0.409685879945755
loss:0.4003109037876129
loss:0.3899511396884918
loss:0.3990858495235443
loss:0.4028222858905792
loss:0.36161813139915466
loss:0.37377089262008667
loss:0.44119352102279663
loss:0.4341203

loss:0.08878806978464127
loss:0.1717263013124466
loss:0.06847652792930603
loss:0.04384036362171173
loss:0.047140080481767654
loss:0.060672249644994736
loss:0.04368949681520462
loss:0.044839873909950256
loss:0.043523598462343216
loss:0.0496414378285408
loss:0.3162239193916321
loss:0.043275196105241776
loss:0.05280445143580437
loss:0.0518839955329895
loss:0.042904410511255264
loss:0.04555659741163254
loss:0.049672260880470276
loss:0.04416251927614212
loss:0.08712614327669144
loss:0.05570948123931885
loss:0.05229908227920532
loss:0.11347603797912598
loss:0.042128197848796844
loss:0.061556048691272736
loss:0.055761199444532394
loss:0.048121389001607895
loss:0.046290747821331024
loss:0.05886713042855263
loss:0.05446017533540726
loss:0.059621281921863556
loss:0.07038373500108719
loss:0.05576719343662262
loss:0.04805457219481468
loss:0.042365819215774536
loss:0.04266088455915451
loss:0.0414525605738163
loss:0.04451172426342964
loss:0.04638538882136345
loss:0.044960279017686844
loss:0.05126512

loss:0.0376620851457119
loss:0.03807871416211128
loss:0.024675684049725533
loss:0.03988773375749588
loss:0.0320945642888546
loss:0.06155632808804512
loss:0.031036533415317535
loss:0.025122715160250664
loss:0.04231763631105423


KeyboardInterrupt: 

In [36]:
def save_images(images, path, step):
    """Save a batch of images during training for monitoring."""
    images = (images / 2 + 0.5).clamp(0, 1)
    # Convert to binary
    images = (images > 0.5).float()
    grid = torchvision.utils.make_grid(images)
    # Convert to PIL image
    grid_image = torchvision.transforms.ToPILImage()(grid)
    grid_image = grid_image.resize((512*4, 768))
    os.makedirs(path, exist_ok=True)
    grid_image.save(f"{path}/sample_{step}.png")


def generate_images(
        checkpoint_path,
        image_height=768,
        image_width=512,
        output_dir="generated_images"
):
    # 配置
    config = {
        "image_height": image_height,
        "image_width": image_width,
        "sample_dir": output_dir
    }

    # 初始化设备
    device = torch.device("mps" if torch.backends.mps.is_available() else "cuda")
    print(f"Using device: {device}")

    # 初始化模型
    model = UNet2DModel(
        sample_size=[image_height, image_width],
        in_channels=3,
        out_channels=3,
        layers_per_block=1,
        block_out_channels=(32, 64, 128),
        down_block_types=(
            "DownBlock2D",
            "DownBlock2D",
            "DownBlock2D",
        ),
        up_block_types=(
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
        ),
    ).to(device)

    noise_scheduler = DDPMScheduler(
        num_train_timesteps=1000,
        beta_start=0.0001,
        beta_end=0.02,
        beta_schedule="linear"
    )

    # 加载checkpoint
    checkpoint = torch.load(checkpoint_path, map_location=device)
    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        model.load_state_dict(checkpoint)
    model.eval()

    with torch.no_grad():
        # Generate sample images
        sample = torch.randn(4, 3, config["image_height"], config["image_width"]).to(device)
        timesteps = torch.linspace(999, 0, 50).long().to(device)
        for t in timesteps:
            residual = model(sample, t.repeat(4), return_dict=False)[0]
            sample = noise_scheduler.step(residual, t, sample).prev_sample
    save_images(sample, config["sample_dir"], 500)

In [37]:
generate_images(
        checkpoint_path="checkpoint_500.pt",  # 你的checkpoint路径
        image_height=768,
        image_width=512,
        output_dir="generated_images"
    )

Using device: cuda
