In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from transformers import AutoFeatureExtractor, AutoModelForDepthEstimation
from diffusers import StableDiffusionPipeline
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np

class LoRA(nn.Module):
    """Wrap an ``nn.Linear`` layer with a trainable low-rank additive adapter.

    The module computes ``original_layer(x) + scale * lora_up(lora_down(x))``.
    ``lora_up`` is zero-initialised, so a freshly constructed wrapper produces
    exactly the same output as the wrapped layer.

    Args:
        original_layer: The ``nn.Linear`` layer to wrap.
        rank: Requested adapter rank; it is capped at the smaller of the
            layer's input and output feature counts.

    Raises:
        ValueError: If ``rank`` is smaller than 1.
    """

    def __init__(self, original_layer, rank):
        super().__init__()
        if rank < 1:
            raise ValueError(f"rank must be a positive integer, got {rank}")

        self.original_layer = original_layer

        # nn.Linear stores its weight as (out_features, in_features); unpacking
        # it the other way round only works for square layers.
        weight = original_layer.weight
        out_features, in_features = weight.shape

        # Ensure the rank is no larger than the input/output features
        self.rank = min(rank, in_features, out_features)

        # Create the adapters with the wrapped layer's dtype/device so the two
        # branches of forward() can be added without a mismatch.
        factory = {"device": weight.device, "dtype": weight.dtype}
        self.lora_down = nn.Linear(in_features, self.rank, bias=False, **factory)
        self.lora_up = nn.Linear(self.rank, out_features, bias=False, **factory)

        self.scale = 0.1

        nn.init.normal_(self.lora_down.weight, std=1.0 / self.rank)
        nn.init.zeros_(self.lora_up.weight)

    def forward(self, x):
        """Return the wrapped layer's output plus the scaled low-rank update.

        ``nn.Linear`` acts on the last dimension and accepts any number of
        leading dimensions, so no manual flattening is required (which also
        means non-contiguous inputs are fine).
        """
        lora_output = self.lora_up(self.lora_down(x)) * self.scale
        return self.original_layer(x) + lora_output

class DepthEstimator:
    """Run a pretrained depth-estimation model over batches of image tensors.

    The model and its feature extractor are both loaded from the
    ``intel/dpt-hybrid-midas`` checkpoint; the model is moved to ``device``
    and put in evaluation mode.
    """

    def __init__(self, device='cpu'):
        checkpoint = "intel/dpt-hybrid-midas"
        self.device = device
        self.model = AutoModelForDepthEstimation.from_pretrained(checkpoint).to(device)
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint)
        self.model.eval()

    def estimate_depth(self, images):
        """Return the model's ``predicted_depth`` for a batch of images.

        ``images`` is indexed as a 4-D channel-first tensor. Only the first
        three channels are used, and values are clamped to [-1, 1] before
        being rescaled to 8-bit pixel values.
        """
        # Keep RGB only; any extra channel (e.g. alpha) is dropped
        rgb = images[:, :3, :, :]

        # [-1, 1] -> [0, 255], then channel-last arrays for PIL
        pixel_values = (rgb.clamp(-1, 1) + 1) * 127.5
        channel_last = pixel_values.byte().cpu().numpy().transpose(0, 2, 3, 1)

        pil_images = []
        for frame in channel_last:
            pil_images.append(Image.fromarray(frame))

        model_inputs = self.feature_extractor(images=pil_images, return_tensors="pt")
        model_inputs = model_inputs.to(self.device)

        with torch.no_grad():
            prediction = self.model(**model_inputs)
        return prediction.predicted_depth

class ILoraDepthTrainer:
    """Train LoRA adapters in a Stable Diffusion UNet against pseudo depth maps.

    On construction every ``nn.Linear`` whose qualified name contains
    ``attn1`` or ``attn2`` inside ``pipeline.unet`` is wrapped in a ``LoRA``
    module. ``train`` then regresses the UNet output towards depth maps
    produced by ``depth_estimator``.

    Args:
        stable_diffusion_model: Pipeline object exposing a ``unet`` module.
        depth_estimator: Object exposing ``estimate_depth(images)``.
        rank: LoRA rank passed to every adapter.
        device: Device the adapters and batches are moved to.
    """

    def __init__(self, stable_diffusion_model, depth_estimator, rank=8, device='cpu'):
        self.device = device
        self.pipeline = stable_diffusion_model
        self.depth_estimator = depth_estimator
        self.replace_attention_layers(rank)

    def replace_attention_layers(self, rank):
        """Freeze the UNet and wrap its attention linear layers in ``LoRA``.

        Intended to be called once (it is called from ``__init__``).
        """
        unet = self.pipeline.unet

        # Freeze the base weights first: the adapters created below are new
        # parameters and therefore stay trainable, so collecting
        # ``requires_grad`` parameters afterwards yields the adapters only.
        unet.requires_grad_(False)

        attention_patterns = ('attn1', 'attn2')

        # Snapshot the matches before touching the tree; replacing children
        # while the named_modules() generator is live is fragile.
        targets = [
            (name, module)
            for name, module in unet.named_modules()
            if isinstance(module, nn.Linear)
            and any(pattern in name for pattern in attention_patterns)
        ]

        for name, module in targets:
            *parent_path, attr_name = name.split('.')
            parent_module = unet
            for part in parent_path:
                parent_module = getattr(parent_module, part)
            setattr(parent_module, attr_name, LoRA(module, rank).to(self.device))

    @staticmethod
    def _min_max_normalize(tensor, eps=1e-8):
        """Rescale ``tensor`` to [0, 1]; a constant tensor maps to zeros."""
        lowest = tensor.min()
        span = (tensor.max() - lowest).clamp_min(eps)
        return (tensor - lowest) / span

    @staticmethod
    def _match_depth_to(depth, reference):
        """Give ``depth`` the same shape, dtype and device as ``reference``.

        A missing channel axis is inserted, the map is bilinearly resized to
        the reference's spatial size, and the single depth channel is
        broadcast over every reference channel.
        """
        if depth.dim() == reference.dim() - 1:
            depth = depth.unsqueeze(1)
        depth = depth.to(device=reference.device, dtype=reference.dtype)
        if depth.shape[-2:] != reference.shape[-2:]:
            depth = nn.functional.interpolate(
                depth, size=reference.shape[-2:], mode='bilinear', align_corners=False
            )
        return depth.expand_as(reference)

    def _run_unet(self, images):
        """Run the UNet on ``images`` at timestep 0 with random conditioning."""
        batch_size = images.size(0)
        timesteps = torch.zeros(batch_size, device=self.device)
        encoder_hidden_states = torch.randn(
            batch_size, 77, 768, device=self.device
        )
        return self.pipeline.unet(
            images,
            timesteps,
            encoder_hidden_states=encoder_hidden_states
        ).sample

    def _depth_loss(self, prediction, pseudo_depth):
        """MSE between the min-max normalised prediction and depth target."""
        target = self._match_depth_to(
            self._min_max_normalize(pseudo_depth), prediction
        )
        return nn.functional.mse_loss(self._min_max_normalize(prediction), target)

    def train(self, dataloader, optimizer, epochs=10):
        """Optimise the UNet output towards the estimator's depth maps.

        Args:
            dataloader: Iterable yielding 1-tuples ``(images,)``.
            optimizer: Optimizer stepped once per batch.
            epochs: Number of passes over ``dataloader``.

        Returns:
            List with the average loss of each epoch (``nan`` for an epoch
            that produced no batches).
        """
        self.pipeline.unet.train()
        epoch_averages = []

        for epoch in range(epochs):
            epoch_loss = 0.0
            num_batches = 0
            for batch_idx, (images,) in enumerate(dataloader):
                # Move the batch to the training device
                images = images.to(self.device)

                # Generate pseudo-ground truth depth
                with torch.no_grad():
                    pseudo_depth = self.depth_estimator.estimate_depth(images)

                # Forward pass
                noise_pred = self._run_unet(images)

                # The depth map generally has neither the channel count nor
                # the resolution of the UNet output, so it is aligned first.
                loss = self._depth_loss(noise_pred, pseudo_depth)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()
                num_batches += 1

                if batch_idx % 10 == 0:
                    print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.6f}")

            # Count batches ourselves: works for loaders without __len__ and
            # avoids a ZeroDivisionError on an empty loader.
            avg_epoch_loss = epoch_loss / num_batches if num_batches else float('nan')
            epoch_averages.append(avg_epoch_loss)
            print(f"Epoch {epoch} completed. Average loss: {avg_epoch_loss:.6f}")

        return epoch_averages

    def save_sample_outputs(self, images, epoch, batch_idx):
        """
        Save a side-by-side PNG of the first image in the batch, its
        pseudo-ground-truth depth and the UNet prediction.
        """
        images = images.to(self.device)
        with torch.no_grad():
            pseudo_depth = self.depth_estimator.estimate_depth(images)
            predicted_depth = self._run_unet(images)

        fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))

        # Original image: first three channels only, clamped so the uint8
        # conversion cannot wrap around.
        first_rgb = images[0, :3].clamp(-1, 1).cpu().numpy()
        img_np = ((first_rgb + 1) * 127.5).astype(np.uint8).transpose(1, 2, 0)
        ax1.imshow(img_np)
        ax1.set_title('Original Image')
        ax1.axis('off')

        # Ground truth depth
        ax2.imshow(pseudo_depth[0].squeeze().cpu().numpy(), cmap='viridis')
        ax2.set_title('Ground Truth Depth')
        ax2.axis('off')

        # Predicted depth: imshow needs a 2-D map, so average the channels
        ax3.imshow(predicted_depth[0].mean(dim=0).cpu().numpy(), cmap='viridis')
        ax3.set_title('Predicted Depth')
        ax3.axis('off')

        fig.savefig(f'depth_comparison_epoch_{epoch}_batch_{batch_idx}.png')
        plt.close(fig)

import torch

def generate_normalized_images(batch_size, height, width, channels=4):
    """
    Generate uniformly random images with values in [-1, 1).

    Args:
        batch_size: Number of images.
        height: Image height in pixels.
        width: Image width in pixels.
        channels: Number of channels per image; defaults to 4 (RGBA).

    Returns:
        Tensor of shape (batch_size, channels, height, width).
    """
    # torch.rand samples [0, 1); scale and shift to [-1, 1)
    return torch.rand(batch_size, channels, height, width) * 2 - 1

def main():
    """Load the models, inject LoRA adapters and run a short demo training.

    The demo trains on random image tensors, so it exercises the training
    loop rather than producing a useful model.
    """
    # Set device
    device = "cpu"

    # Load models
    stable_diffusion_model = StableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        safety_checker=None
    ).to(device)
    depth_estimator = DepthEstimator(device=device)

    # Create trainer
    trainer = ILoraDepthTrainer(
        stable_diffusion_model,
        depth_estimator,
        rank=4,
        device=device
    )

    # Create example dataset: 8 random 4-channel images in the [-1, 1] range
    example_images = generate_normalized_images(8, 512, 512)
    dataset = TensorDataset(example_images)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

    # Optimise the LoRA adapter weights only. Collecting every parameter with
    # requires_grad would also pick up the base UNet weights, turning this
    # into a full fine-tune instead of a LoRA one.
    lora_parameters = [
        param
        for name, param in trainer.pipeline.unet.named_parameters()
        if param.requires_grad and ('lora_down' in name or 'lora_up' in name)
    ]
    if not lora_parameters:
        raise RuntimeError("No trainable LoRA parameters were found in the UNet")

    # Initialize optimizer
    optimizer = torch.optim.AdamW(
        lora_parameters,
        lr=1e-4,
        weight_decay=0.01
    )

    # Train
    trainer.train(dataloader, optimizer, epochs=5)

if __name__ == "__main__":
    main()

Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

You have disabled the safety checker for <class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .
