In [2]:
import numpy as np
import torch
import math
import fastai
from PIL import Image
from diffusers import DiffusionPipeline
from diffusers.utils import pt_to_pil
from dataloader import get_imagenette_dataloader
from quantize import quantize_img, plot_imgs

device = "cuda"

In [3]:
def method_helper(o): return list(filter(lambda x: x[0] != "_", dir(o)))

In [4]:
# stage 2
stage_2 = DiffusionPipeline.from_pretrained(
    "DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16, class_labels=None 
)
# stage_2.enable_model_cpu_offload()


A mixture of fp16 and non-fp16 filenames will be loaded.
Loaded fp16 filenames:
[unet/diffusion_pytorch_model.fp16.safetensors, safety_checker/model.fp16.safetensors, text_encoder/model.fp16-00001-of-00002.safetensors, text_encoder/model.fp16-00002-of-00002.safetensors]
Loaded non-fp16 filenames:
[watermarker/diffusion_pytorch_model.safetensors
If this behavior is not expected, please check your folder structure.
Keyword arguments {'class_labels': None} are not expected by IFSuperResolutionPipeline and will be ignored.


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [5]:
scheduler = stage_2.scheduler

# Load the UNet model - the core denoiser
unet = stage_2.unet.to(device)

In [6]:
from fastai.vision.all import Callback, ImageDataLoaders, Resize
import torch
import torch.nn as nn

class DDPMCB(Callback):
    """Custom FastAI callback for training a UNet diffusion model with DeepFloyd's DDPM Scheduler."""

    def __init__(self, unet, scheduler, timesteps=1000):
        self.unet = unet
        self.scheduler = scheduler
        self.timesteps = timesteps  # Number of diffusion steps

    def before_batch(self):
        """Add noise to the input images before passing them to the model."""
        # Get the real images from the batch
        images = self.xb[0]  # Shape: (batch_size, 3, H, W)
        batch_size = images.shape[0]

        # Generate noise
        t = torch.randint(0, self.timesteps, (batch_size,), device=images.device).long()
        noise = torch.randn_like(images)[:,:3,...]

        # Apply noise
        noisy_images = self.scheduler.add_noise(images[:,:3,...], noise, t)

        noisy_images = torch.cat([noisy_images, noisy_images], dim=1)  # Shape: (batch_size, 6, H, W)
        gt_image_true = torch.cat([images[:,:3,...], images[:,:3,...]], dim=1)  # Shape: (batch_size, 6, H, W)
 
        self.learn.xb = (noisy_images, images, t)
        self.learn.yb = (gt_image_true,)

    def after_pred(self):
        """Compute loss: Train UNet to predict noise (MSE loss)."""
        gt_image_pred = self.pred.to(torch.float32)  # Model's prediction
        gt_image_true = self.yb[0]  # Ground truth (now 6-channel)
        
        self.learn.loss = nn.functional.mse_loss(gt_image_pred, gt_image_true)

    # def after_batch(self):
    #     """Zero gradients manually after each batch (recommended for diffusion)."""
    #     self.learn.opt.zero_grad()

In [7]:
def ignore_category(f):
    def wrapper(frame):
        if not isinstance(frame, TensorCategory):
            frame = f(frame)
        return frame
    return wrapper

In [7]:
from fastai.vision.all import *
import torch
import numpy as np
from PIL import Image, ImageOps
from torchvision.transforms import functional as TF
import cv2


def quantize_image(image: Image.Image, num_colors: int) -> Image.Image:
    """Quantizes a PIL image using the median cut algorithm."""
    return image.convert("RGB").quantize(colors=num_colors, method=Image.MEDIANCUT).convert("RGB")

def compute_luminance(image: Image.Image) -> Image.Image:
    """Computes the luminance (grayscale) channel of a PIL image."""
    return ImageOps.grayscale(image)

def compute_gradients(luminance: Image.Image) -> tuple:
    """Computes x- and y-direction gradients of the luminance channel."""
    lum_array = np.array(luminance).astype(np.float32)
    grad_x = cv2.Sobel(lum_array, cv2.CV_32F, 1, 0, ksize=3)
    grad_y = cv2.Sobel(lum_array, cv2.CV_32F, 0, 1, ksize=3)
    return grad_x, grad_y

def threshold_gradients(grad_x, grad_y, threshold=8) -> tuple:
    """Thresholds gradients to create binary images."""
    grad_x = (np.abs(grad_x) > threshold).astype(np.float32)
    grad_y = (np.abs(grad_y) > threshold).astype(np.float32)
    return grad_x, grad_y

@ignore_category
def conditioning_transform(image_tensor: torch.Tensor) -> torch.Tensor:
    """
    Applies all transformatiqons to generate the 7-channel input.
    
    Args:
        image_tensor (torch.Tensor): A tensor image of shape (C, H, W) in range [0, 1].
    
    Returns:
        torch.Tensor: A 7-channel tensor of shape (7, H, W) in range [0, 1].
    """
    # Convert tensor to PIL image for processing
    image = TF.to_pil_image(image_tensor)

    # 1-3: Quantized Image
    num_colors = 2 ** np.random.randint(2, 8)
    quantized_image = quantize_image(image, num_colors)
    quantized_tensor = TF.to_tensor(quantized_image)

    # 4: Quantization Level Channel (normalized to [0,1])
    quant_level_channel = torch.full((1, quantized_tensor.shape[1], quantized_tensor.shape[2]), num_colors / 256)

    # 5: Luminance Channel
    luminance = compute_luminance(image)
    luminance_tensor = TF.to_tensor(luminance)

    # 6: Gradient-Based Conditioning
    grad_x, grad_y = compute_gradients(luminance)
    grad_x_tensor = torch.tensor(grad_x, dtype=torch.float32).unsqueeze(0) / 255.0
    grad_y_tensor = torch.tensor(grad_y, dtype=torch.float32).unsqueeze(0) / 255.0

    # 7: Texture Indicator (1 if texture is present, 0 otherwise)
    texture_indicator = torch.ones_like(quant_level_channel)

    # Stack all channels into a 7-channel tensor
    stacked_tensor = torch.cat([quantized_tensor, quant_level_channel, grad_x_tensor, grad_y_tensor, texture_indicator], dim=0)

    return stacked_tensor.half()

In [8]:
clip_processor = stage_2.feature_extractor

@ignore_category
def clip_preprocess(frame):
    return torch.tensor(clip_processor(frame)["pixel_values"][0])
    
@ignore_category
def to_f16(frame):
    torch.tensor(frame, dtype=torch.float16)

In [9]:
import torchvision.transforms as transforms
dls = ImageDataLoaders.from_folder(
    "/mnt/wd/datasets/imagenette2",
    valid_pct=0.1,
    item_tfms=[clip_preprocess, conditioning_transform],
    bs=4,
    num_workers=16
)

In [10]:
dls.one_batch()[0].shape

torch.Size([4, 7, 224, 224])

In [11]:
# dls.one_batch()

In [12]:
# dls = ImageDataLoaders.from_folder("/mnt/wd/datasets/imagenette2", valid_pct=0.1, bs=4, item_tfms=Resize(224))

In [13]:
# def preprocess(frame)
#     return clip_processor(frame)["pixel_values"][0]

In [14]:
# clip_processor(dls.one_batch()[0], rescale=False)["pixel_values"][0].shape

In [15]:
one_batch = dls.one_batch()[0]

In [16]:
import torch.nn as nn
from transformers import ViTModel, ViTImageProcessor
class ViTImageEncoder(nn.Module):
    """
    Uses a pre-trained ViT model to extract embeddings from a quantized image.
    This replaces the CNN encoder with a stronger transformer-based encoder.
    """
    def __init__(self,channels_in,  model_name="google/vit-base-patch16-224", output_dim=1024):
        super().__init__()
        self.vit = ViTModel.from_pretrained(model_name)
        self.feature_extractor = ViTImageProcessor.from_pretrained(model_name)
        self.fc = nn.Linear(self.vit.config.hidden_size, output_dim)  # Resize to match UNet's expected size
        self.vit.embeddings.patch_embeddings.projection = nn.Conv2d(channels_in, 768, kernel_size=(16, 16), stride=(16, 16))
        self.vit.config.num_channels = channels_in
        self.vit.embeddings.patch_embeddings.num_channels = channels_in
    
    def forward(self, x):
        features = self.vit(x).last_hidden_state  # Extract token embeddings
        pooled_features = features.mean(dim=1)  # Global Average Pooling (B, D)
        return self.fc(pooled_features).unsqueeze(1)  # Shape: (batch, 1, output_dim)

# =============================
# 2️⃣ Load DeepFloyd IF UNet & Scheduler
# =============================


encoder = ViTImageEncoder(7, output_dim=unet.config.encoder_hid_dim).to(device)

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [17]:
encoder(dls.one_batch()[0])

tensor([[[ 0.1545,  0.8223,  0.3134,  ...,  0.5701,  0.5903, -0.5687]],

        [[ 0.0820,  0.7439,  0.3273,  ...,  0.4758,  0.4894, -0.4048]],

        [[ 0.1745,  0.8123,  0.1718,  ...,  0.4505,  0.5863, -0.4291]],

        [[ 0.0928,  0.8939,  0.2339,  ...,  0.5358,  0.4837, -0.5152]]],
       device='cuda:0', grad_fn=<UnsqueezeBackward0>)

In [18]:
class CTModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.unet = unet
        self.unet.class_embedding = None
        self.vit = ViTImageEncoder(7, output_dim=self.unet.config.encoder_hid_dim).to(device)

        for param in self.unet.parameters():
            param.requires_grad = False
        

    def forward(self, noisy_images, images, t):
        encoded = self.vit(images).expand(-1, 77, -1).half()

        return self.unet(noisy_images.half(), t.half(), encoded.half())[0]

In [19]:
model = CTModel()

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [20]:
one_batch = dls.one_batch()
one_batch[0].shape
images = one_batch[0]
images = torch.cat([images, images], dim=1)

In [21]:
with torch.no_grad():
    x = model(images, one_batch[0], torch.tensor([1.0]*4, dtype=torch.float16, device="cuda"))
# x

RuntimeError: Given groups=1, weight of size [160, 6, 3, 3], expected input[4, 14, 224, 224] to have 6 channels, but got 14 channels instead

In [None]:
x[0].size()

In [None]:
learn = Learner(dls, model, loss_func=torch.nn.MSELoss(), cbs=[DDPMCB(unet,scheduler)]).to_fp16()
# learn = Learner(dls, model.half(), loss_func=torch.nn.MSELoss(), cbs=[DDPMCB(unet,scheduler)])
# from fastai.learner import AvgSmoothLoss

# class FP16AvgSmoothLoss(AvgSmoothLoss):
#     def accumulate(self, learn):
#         self.count += 1
#         loss_fp16 = to_detach(learn.loss.mean()).half()  # Ensure FP16
#         self.val = torch.lerp(loss_fp16, self.val.half(), self.beta)  # Convert self.val to FP16

# learn.recorder.metrics = []

learn.lr_find()

In [None]:
lr = 10e-05
learn.fit_one_cycle(1, lr)

In [None]:
# If lr_max is not provided, use the suggested learning rate from the finder
    lr_max = lr_max or lr_max_suggested
    print(f"Using learning rate: {lr_max:.2e}")

    # 🚀 Step 2: Train the model with OneCycle policy
    learn.fit_one_cycle(epochs, lr_max)

    return learn  # Return trained Learner

In [None]:
x = x.expand(-1, 77, -1)
one_batch = torch.cat([one_batch, one_batch], dim=1)