In [1]:
import os
#os.chdir('/practical/fast-DiT')
import torch
from diffusion import create_diffusion
from diffusers.models import AutoencoderKL
from accelerate import Accelerator
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from PIL import Image
from IPython.display import display
from models import DiT_S_2
# Pick the compute device once for the whole notebook and tell the user when
# no GPU is available.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
    print("GPU not found. Using CPU instead.")

# Autograd is switched off globally; any cell that needs gradients has to
# re-enable them explicitly.
torch.set_grad_enabled(False)

GPU not found. Using CPU instead.


Load Checkpoints

In [2]:
# Configuration for the checkpoint-loading step.
# NOTE(review): the trailing `#@param` comments look like Colab form
# annotations — confirm before editing or moving them.
image_size = 256 #@param [256, 512]
vae_model = "stabilityai/sd-vae-ft-ema" #@param ["stabilityai/sd-vae-ft-mse", "stabilityai/sd-vae-ft-ema"]
# Spatial size handed to DiT_S_2 below: the image size integer-divided by 8
# (presumably the VAE's downsampling factor — TODO confirm).
latent_size = int(image_size) // 8
# Checkpoints are looked up in ./checkpoints relative to the current working directory.
checkpoints_dir = os.path.join(os.getcwd(), "checkpoints")
checkpoint_filename = "0750000.pt"
checkpoint_path = os.path.join(checkpoints_dir, checkpoint_filename)
# NOTE: despite its name, `model_class` holds a DiT_S_2 *instance*, not a class.
model_class = DiT_S_2(input_size=latent_size).to(device)
# VAE fetched by name with from_pretrained and moved to the selected device.
vae = AutoencoderKL.from_pretrained(vae_model).to(device)

def load_checkpoint(checkpoint_path, model_class):
    """
    Load the weights stored in a checkpoint file into an existing model.

    :param checkpoint_path: str, path of the checkpoint file to read
    :param model_class: torch.nn.Module, an already-constructed model *instance*
        (despite the parameter name); it is updated in place
    :return: the same model object, moved to the notebook-level ``device`` and
        switched to eval mode
    :raises FileNotFoundError: if ``checkpoint_path`` does not exist
    """
    # Fail early with a readable message instead of a low-level loader error.
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")

    # NOTE(review): the captured cell output shows a warning being emitted for
    # this torch.load call. Consider passing an explicit `weights_only` value
    # once the checkpoint contents are confirmed, and only load checkpoints
    # that come from a trusted source.
    saved_state = torch.load(checkpoint_path, map_location=device)

    # Only the "model" entry of the checkpoint dictionary is used.
    model_class.load_state_dict(saved_state["model"])
    model_class.to(device)
    model_class.eval()
    return model_class


# Load the trained weights into the model built above; as the last expression
# of the cell, the returned model is also displayed.
load_checkpoint(checkpoint_path=checkpoint_path, model_class=model_class)

  checkpoint = torch.load(checkpoint_path, map_location=device)


DiT(
  (x_embedder): PatchEmbed(
    (proj): Conv2d(4, 384, kernel_size=(2, 2), stride=(2, 2))
    (norm): Identity()
  )
  (t_embedder): TimestepEmbedder(
    (mlp): Sequential(
      (0): Linear(in_features=256, out_features=384, bias=True)
      (1): SiLU()
      (2): Linear(in_features=384, out_features=384, bias=True)
    )
  )
  (y_embedder): LabelEmbedder(
    (embedding_table): Embedding(1001, 384)
  )
  (blocks): ModuleList(
    (0-11): 12 x DiTBlock(
      (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=False)
      (attn): Attention(
        (qkv): Linear(in_features=384, out_features=1152, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=384, out_features=384, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=False)
      (mlp): Mlp(
        (fc1): Linear(in_features=384, out_fea

Compute GNS

In [3]:
def compute_gns(model, diffusion, dataloader, device=None):
    """
    Compute the Gradient Noise Scale (GNS) of a model: tr(Σ) / |G|^2.

    One gradient is computed per batch. ``G`` is the mean of those per-batch
    gradients and ``tr(Σ)`` is the trace of their (population, i.e. divided by
    the number of batches) covariance. Both are accumulated with a streaming
    (Welford) update, so only one flattened gradient and one running mean are
    kept in memory; the full covariance matrix is never materialised.

    :param model: torch.nn.Module, model instance; temporarily put in train mode,
        its previous train/eval mode is restored before returning
    :param diffusion: diffusion object exposing ``num_timesteps`` and
        ``training_losses(model, x, t, model_kwargs)`` returning a dict with a "loss" entry
    :param dataloader: iterable yielding ``(x, y)`` batches
    :param device: optional device for the batches; defaults to the device of the
        model's parameters (falls back to "cuda"/"cpu" for a parameter-less model)
    :return: float, GNS value (0.0 when only a single batch is provided)
    :raises ValueError: if the dataloader yields no batch
    :raises ZeroDivisionError: if the mean gradient is exactly zero
    """
    if device is None:
        # Batches must live on the same device as the model, which is not
        # necessarily CUDA (e.g. an Accelerator may have picked another backend).
        first_param = next(model.parameters(), None)
        if first_param is not None:
            device = first_param.device
        else:
            device = "cuda" if torch.cuda.is_available() else "cpu"

    was_training = model.training
    model.train()

    num_batches = 0
    grad_mean = None   # running mean of the flattened per-batch gradients
    sq_dev_sum = 0.0   # running sum of squared deviations from the mean (Welford's M2)

    try:
        # The notebook disables autograd globally, so it has to be re-enabled
        # here for backward() to work.
        with torch.enable_grad():
            for x, y in dataloader:
                x, y = x.to(device), y.to(device)
                t = torch.randint(0, diffusion.num_timesteps, (x.shape[0],), device=device)
                model_kwargs = dict(y=y)

                # Compute the loss and its gradient for this batch only
                model.zero_grad()
                loss_dict = diffusion.training_losses(model, x, t, model_kwargs)
                loss = loss_dict["loss"].mean()
                loss.backward()

                # Flatten the gradient of the current batch into one vector
                # (reshape also handles non-contiguous gradients; cat copies).
                batch_grad = torch.cat([
                    p.grad.detach().reshape(-1)
                    for p in model.parameters()
                    if p.grad is not None
                ])

                # Welford update: numerically stable running mean and squared deviation
                num_batches += 1
                if grad_mean is None:
                    grad_mean = torch.zeros_like(batch_grad)
                delta = batch_grad - grad_mean
                grad_mean += delta / num_batches
                sq_dev_sum += (delta * (batch_grad - grad_mean)).sum().item()
    finally:
        model.train(was_training)

    if num_batches == 0:
        raise ValueError("compute_gns needs a dataloader that yields at least one batch")

    # tr(Σ): sum over parameters of the per-parameter variance across batches
    trace_sigma = sq_dev_sum / num_batches

    # |G|^2: squared norm of the mean gradient
    norm_g_squared = (torch.norm(grad_mean) ** 2).item()

    return trace_sigma / norm_g_squared

In [4]:
# --- Configuration (repeated from the checkpoint cell so this cell is self-contained) ---
image_size = 256 #@param [256, 512]
vae_model = "stabilityai/sd-vae-ft-ema" #@param ["stabilityai/sd-vae-ft-mse", "stabilityai/sd-vae-ft-ema"]
latent_size = int(image_size) // 8
checkpoints_dir = os.path.join(os.getcwd(), "checkpoints")
checkpoint_filename = "0750000.pt"
checkpoint_path = os.path.join(checkpoints_dir, checkpoint_filename)
BATCH_SIZE = 32
SEED = 0
# Scaling applied to Stable Diffusion VAE latents before they are fed to DiT.
VAE_SCALING_FACTOR = 0.18215

# Setup accelerator first, so the model and the VAE are placed on the device
# that is actually used afterwards.
accelerator = Accelerator()
device = accelerator.device

# Create model, VAE and diffusion
model = DiT_S_2(input_size=latent_size).to(device)
model_class = model  # old name kept as an alias of the same instance (no second copy in memory)
vae = AutoencoderKL.from_pretrained(vae_model).to(device)
vae.eval()
diffusion = create_diffusion(timestep_respacing="")  # default 1000 steps

# Load checkpoint (optional)
if checkpoint_path:
    model = load_checkpoint(checkpoint_path, model)

# Setup ImageNet DataLoader.
# Images are resized to the full training resolution and normalised to [-1, 1],
# the input range expected by the VAE; the VAE then maps them to the 4-channel
# latents of spatial size `latent_size` that the DiT model consumes.
data_transforms = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

imagenet_dir = os.path.join(os.getcwd(), "dataset")
dataset = ImageFolder(imagenet_dir, transform=data_transforms)
# ImageFolder orders samples by class, so batches are shuffled (with a fixed
# seed) to make every batch a random sample, as the GNS estimate assumes.
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
    num_workers=4,
    pin_memory=True,
)


class LatentBatches:
    """Wrap an image DataLoader and yield (VAE latent, label) batches instead of raw images."""

    def __init__(self, image_loader, vae, device, scaling_factor=VAE_SCALING_FACTOR):
        self.image_loader = image_loader
        self.vae = vae
        self.device = device
        self.scaling_factor = scaling_factor

    def __len__(self):
        return len(self.image_loader)

    def __iter__(self):
        for images, labels in self.image_loader:
            # The VAE is frozen: never track gradients through it, even when the
            # consumer of this iterator has autograd enabled.
            with torch.no_grad():
                posterior = self.vae.encode(images.to(self.device)).latent_dist
                latents = posterior.sample().mul_(self.scaling_factor)
            yield latents, labels


latent_batches = LatentBatches(dataloader, vae, device)

# Sanity check on a single batch only (iterating the whole dataset just to print
# shapes would be a full extra pass over the data).
inputs, targets = next(iter(latent_batches))
print(inputs.shape)  # (batch_size, channels, height, width)

# Compute GNS
gns_value = compute_gns(model, diffusion, latent_batches)
print(f"Gradient Noise Scale (GNS): {gns_value:.4f}")

  checkpoint = torch.load(checkpoint_path, map_location=device)


FileNotFoundError: [Errno 2] No such file or directory: '/Users/egecimsir/Desktop/GenAI Practical/Critical-Multitask-Batch-Sizes-in-Diffusion-Models/dataset'