In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import torch.optim.lr_scheduler as lr_scheduler
import matplotlib.pyplot as plt
import numpy as np
import os
import math
from torchvision.transforms import ToPILImage
from torchvision import models
from torchvision import transforms
from PIL import Image
from google.colab import drive
# Colab-only: expose Google Drive under /content/drive, which is where the
# dataset folders and checkpoint files referenced later in this notebook live.
drive.mount('/content/drive')

In [None]:
# Spatial size handed to Resize for every image that goes through `transform`.
size = (256, 256)

# Preprocessing steps, applied in order:
#   1. Resize to `size`
#   2. ToTensor: PIL image -> float tensor [C, H, W] with values in [0, 1]
#   3. Normalize with mean 0.5 / std 0.5 per channel, i.e. (x - 0.5) / 0.5,
#      which maps [0, 1] to [-1, 1]
_preprocess_steps = [
    transforms.Resize(size),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
]
transform = transforms.Compose(_preprocess_steps)

def load_image(image_path, transform=transform):
    """
    Load a single image from disk as a batched tensor.

    Args:
        image_path: Path of the image file to open.
        transform: Callable applied to the RGB PIL image. Defaults to the
            module-level ``transform`` pipeline. It must return a tensor,
            because ``unsqueeze`` is called on its result.

    Returns:
        The transformed image with a leading batch dimension of size 1
        (``[1, 3, H, W]`` with the default pipeline).
    """
    # The context manager closes the underlying file handle deterministically
    # instead of leaving it to garbage collection. `convert` returns a new,
    # fully loaded image, so it stays valid after the file is closed.
    with Image.open(image_path) as raw_image:
        rgb_image = raw_image.convert('RGB')  # force exactly 3 channels

    return transform(rgb_image).unsqueeze(0)  # [3, H, W] -> [1, 3, H, W]

In [None]:
class NoiseScheduler:
    """
    Variance schedule plus the closed-form forward step and its inverse for a
    DDPM-style diffusion process.

    Attributes (1-D tensors of length ``timesteps`` stored on ``device``):
        betas: per-step noise variances beta_t.
        alphas: 1 - beta_t.
        alphas_cumprod: cumulative product of ``alphas`` (alpha_bar_t).
        sqrt_alphas_cumprod: sqrt(alpha_bar_t).
        sqrt_one_minus_alphas_cumprod: sqrt(1 - alpha_bar_t).
    """

    def __init__(self, timesteps=1000, beta_schedule="linear", device='cpu'):
        """
        Args:
            timesteps: Number of diffusion steps T.
            beta_schedule: ``"linear"`` or ``"cosine"``.
            device: Device (string or ``torch.device``) holding the tables.

        Raises:
            ValueError: If ``beta_schedule`` is not a known schedule name.
        """
        self.timesteps = timesteps
        self.device = torch.device(device)

        if beta_schedule == "linear":
            betas = self._linear_beta_schedule(timesteps)
        elif beta_schedule == "cosine":
            betas = self._cosine_beta_schedule(timesteps)
        else:
            raise ValueError(f"Unknown beta schedule: {beta_schedule}")

        # Schedules are built on CPU; moving `betas` once is enough because
        # every derived table inherits its device.
        self.betas = betas.to(self.device)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)  # alpha_bar_t

        # Lookup tables used by get_noisy_image / denoise_image.
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)

    def _linear_beta_schedule(self, timesteps, start=0.0001, end=0.02):
        """Return ``timesteps`` betas evenly spaced from ``start`` to ``end`` (CPU, float32)."""
        return torch.linspace(start, end, timesteps, dtype=torch.float32)

    def _cosine_beta_schedule(self, timesteps, s=0.008):
        """
        Return betas derived from a squared-cosine alpha_bar curve (CPU, float32).

        ``s`` is the small offset added to t/T inside the cosine. The result is
        clipped to [0.0001, 0.999] for numerical stability.
        """
        t = torch.arange(timesteps + 1, dtype=torch.float32)
        f_t = torch.cos(((t / timesteps) + s) / (1 + s) * math.pi / 2) ** 2
        alphas_bar = f_t / f_t[0]
        betas = 1 - (alphas_bar[1:] / alphas_bar[:-1])
        return torch.clip(betas, 0.0001, 0.999)

    def _coefficients(self, t, reference):
        """
        Look up sqrt(alpha_bar_t) and sqrt(1 - alpha_bar_t) for ``t``.

        ``t`` may be a Python int, a 0-dim tensor or a 1-D batch of indices.
        A 1-D ``t`` is reshaped to (B, 1, ..., 1) so it broadcasts against
        ``reference``; any other shape is used exactly as indexed. Both
        results are returned on ``reference``'s device.
        """
        t = torch.as_tensor(t, device=self.device)
        sqrt_alpha_bar = self.sqrt_alphas_cumprod[t]
        sqrt_one_minus_alpha_bar = self.sqrt_one_minus_alphas_cumprod[t]

        if t.ndim == 1:
            broadcast_shape = (-1,) + (1,) * (reference.ndim - 1)
            sqrt_alpha_bar = sqrt_alpha_bar.view(broadcast_shape)
            sqrt_one_minus_alpha_bar = sqrt_one_minus_alpha_bar.view(broadcast_shape)

        return (
            sqrt_alpha_bar.to(reference.device),
            sqrt_one_minus_alpha_bar.to(reference.device),
        )

    def get_noisy_image(self, x_0, t, noise=None):
        """
        Forward diffusion: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise.

        Args:
            x_0: Clean input tensor.
            t: Timestep index/indices (int, 0-dim tensor, or 1-D batch).
            noise: Noise to mix in; standard normal noise shaped like ``x_0``
                (and on its device) is drawn when omitted.

        Returns:
            Tuple ``(x_t, noise)``; the noise is returned for use as the
            training target.
        """
        if noise is None:
            noise = torch.randn_like(x_0)

        sqrt_alpha_bar_t, sqrt_one_minus_alpha_bar_t = self._coefficients(t, x_0)
        x_t = sqrt_alpha_bar_t * x_0 + sqrt_one_minus_alpha_bar_t * noise
        return x_t, noise

    def denoise_image(self, x_t, t, noise_unet):
        """
        Invert the forward step given a noise estimate:
        x_0 = (x_t - sqrt(1 - alpha_bar_t) * noise_unet) / sqrt(alpha_bar_t).

        Args:
            x_t: Noisy tensor at timestep ``t``.
            t: Timestep index/indices (int, 0-dim tensor, or 1-D batch).
            noise_unet: Predicted noise, same shape as ``x_t``.

        Returns:
            The implied clean-image estimate x_0.
        """
        sqrt_alpha_bar_t, sqrt_one_minus_alpha_bar_t = self._coefficients(t, x_t)
        x_0 = (x_t - sqrt_one_minus_alpha_bar_t * noise_unet) / sqrt_alpha_bar_t
        return x_0

In [None]:
class SinusoidalPositionalEmbedding(nn.Module):
    """
    Sinusoidal embedding of scalar time steps.

    Maps a batch of time steps of shape (B,) to a (B, dim) tensor laid out as
    [sin(t * f_0..f_k), cos(t * f_0..f_k)], where the frequencies f_i fall
    geometrically from 1 towards 1/10000. For odd ``dim`` a trailing zero
    column is appended so the output width is always exactly ``dim``.
    """
    def __init__(self, dim):
        """
        Args:
            dim: Width of the produced embedding (must be >= 2).

        Raises:
            ValueError: If ``dim`` is smaller than 2.
        """
        super().__init__()
        if dim < 2:
            raise ValueError(f"dim must be >= 2, got {dim}")
        self.dim = dim
        self.half_dim = dim // 2

        # max(..., 1) guards dim in {2, 3}: there half_dim - 1 == 0 and the
        # division would otherwise raise ZeroDivisionError.
        log_step = math.log(10000) / max(self.half_dim - 1, 1)
        frequencies = torch.exp(torch.arange(self.half_dim) * -log_step)

        # Non-persistent buffer: follows module.to(device) but is kept out of
        # the state_dict, so the state_dict layout is unaffected.
        self.register_buffer("embeddings", frequencies, persistent=False)

    def forward(self, time):
        """
        Args:
            time: Tensor of shape (B,) holding the time steps.

        Returns:
            Tensor of shape (B, dim).
        """
        # (B, 1) * (half_dim,) -> (B, half_dim); .to() is a no-op when the
        # buffer already sits on time's device.
        angles = time.unsqueeze(1) * self.embeddings.to(time.device)

        time_embedding = torch.cat((angles.sin(), angles.cos()), dim=-1)
        if self.dim % 2 == 1:
            # sin/cos only fill 2 * half_dim == dim - 1 columns; pad the last.
            time_embedding = F.pad(time_embedding, (0, 1))
        return time_embedding

class TimeEmbeddingMLP(nn.Module):
    """
    Two-layer perceptron (Linear -> SiLU -> Linear) applied to a time embedding.

    Maps inputs whose last dimension is ``dim`` to outputs whose last
    dimension is ``out_dim``, going through ``hidden_dim`` in between.
    """
    def __init__(self, dim, hidden_dim, out_dim):
        super().__init__()
        layers = [
            nn.Linear(dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, out_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        projected = self.net(x)
        return projected

In [None]:
class AdaptiveGroupNorm(nn.Module):
    """
    Adaptive Group Normalization (AdaGN).

    Normalizes the input with a non-affine GroupNorm, then modulates it with a
    per-channel scale and shift that are linearly projected from a
    conditioning vector (the time embedding).
    """
    def __init__(self, num_groups, num_channels, time_emb_dim):
        super().__init__()
        # affine=False: the learnable scale/shift come from `time_proj` instead.
        self.norm = nn.GroupNorm(num_groups, num_channels, affine=False)
        # One projection yields both halves: [scale | shift], each num_channels wide.
        self.time_proj = nn.Linear(time_emb_dim, 2 * num_channels)

    def forward(self, x, time_emb):
        """
        Args:
            x: Feature map of shape (B, C, H, W).
            time_emb: Conditioning vectors of shape (B, time_emb_dim).

        Returns:
            Tensor shaped like ``x``: ``norm(x) * (1 + scale) + shift``.
        """
        modulation = self.time_proj(time_emb)             # (B, 2C)
        scale, shift = torch.chunk(modulation, 2, dim=1)  # (B, C) each

        # Add trailing singleton axes so both broadcast over H and W.
        scale = scale[:, :, None, None]
        shift = shift[:, :, None, None]

        # `1 + scale` keeps the features intact when the projected scale is 0.
        return self.norm(x) * (1 + scale) + shift

In [None]:
class ContractingBlock(nn.Module):
    """
    Encoder stage of the U-Net.

    Applies two 3x3 convolutions, each followed by AdaGN (conditioned on the
    time embedding) and LeakyReLU(0.2), then a stride-2 3x3 convolution that
    downsamples the feature map.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the block. The norm layers use
            8 groups, so this must be divisible by 8.
        time_emb_dim: Width of the time embedding passed to ``forward``.
        use_dropout: If True, Dropout(0.3) is applied after the second norm,
            before the activation.
    """
    def __init__(self, in_channels, out_channels, time_emb_dim, use_dropout=False):
        super().__init__()
        # padding=1 keeps the spatial size; downsampling is done by `downsample`.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.activation = nn.LeakyReLU(0.2)
        self.norm1 = AdaptiveGroupNorm(num_groups=8, num_channels=out_channels, time_emb_dim=time_emb_dim)

        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.norm2 = AdaptiveGroupNorm(num_groups=8, num_channels=out_channels, time_emb_dim=time_emb_dim)

        self.downsample = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1)

        # `dropout` only exists as an attribute when requested.
        if use_dropout:
            self.dropout = nn.Dropout(0.3)
        self.use_dropout = use_dropout

    def forward(self, x, time_emb):
        """
        Args:
            x: Feature map of shape (B, in_channels, H, W).
            time_emb: Time embedding of shape (B, time_emb_dim).

        Returns:
            Downsampled feature map with ``out_channels`` channels.
        """
        x = self.conv1(x)
        x = self.norm1(x, time_emb)
        x = self.activation(x)

        x = self.conv2(x)
        x = self.norm2(x, time_emb)
        if self.use_dropout:
            x = self.dropout(x)
        x = self.activation(x)

        # No residual shortcut here: the input has in_channels while x has
        # out_channels, so adding them would first need a channel projection.
        return self.downsample(x)

class ExpandingBlock(nn.Module):
    """
    Decoder stage of the U-Net.

    Upsamples the input with a 2x2 stride-2 transposed convolution,
    concatenates the skip connection along the channel axis, then applies two
    3x3 convolutions, each followed by AdaGN (conditioned on the time
    embedding) and ReLU.

    Args:
        in_channels: Channels of the incoming (lower-resolution) feature map.
        out_channels: Channels after upsampling and of the block output. The
            skip connection must also have this many channels, because the
            first convolution expects ``2 * out_channels`` inputs. The norm
            layers use 8 groups, so this must be divisible by 8.
        time_emb_dim: Width of the time embedding passed to ``forward``.
        use_dropout: If True, Dropout(0.3) is applied after the second norm,
            before the activation.
    """
    def __init__(self, in_channels, out_channels, time_emb_dim, use_dropout=False):
        super().__init__()
        self.upsample = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)

        # Input channels are doubled by the concatenated skip connection.
        self.conv1 = nn.Conv2d(out_channels * 2, out_channels, kernel_size=3, padding=1)
        self.norm1 = AdaptiveGroupNorm(num_groups=8, num_channels=out_channels, time_emb_dim=time_emb_dim)

        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.norm2 = AdaptiveGroupNorm(num_groups=8, num_channels=out_channels, time_emb_dim=time_emb_dim)

        self.activation = nn.ReLU()
        # `dropout` only exists as an attribute when requested.
        if use_dropout:
            self.dropout = nn.Dropout(0.3)
        self.use_dropout = use_dropout

    def forward(self, x, skip_con_x, time_emb):
        """
        Args:
            x: Feature map of shape (B, in_channels, H, W).
            skip_con_x: Skip connection of shape (B, out_channels, H', W').
            time_emb: Time embedding of shape (B, time_emb_dim).

        Returns:
            Feature map of shape (B, out_channels, H', W').
        """
        x = self.upsample(x)

        # Only the spatial size matters for concatenation along dim=1. Resize
        # when the upsampled map does not land exactly on the skip resolution
        # (e.g. an odd size was downsampled on the way in).
        if x.shape[2:] != skip_con_x.shape[2:]:
            x = F.interpolate(x, size=skip_con_x.shape[2:], mode='nearest')

        x = torch.cat([x, skip_con_x], dim=1)  # channel-wise concatenation

        x = self.conv1(x)
        x = self.norm1(x, time_emb)
        x = self.activation(x)

        x = self.conv2(x)
        x = self.norm2(x, time_emb)
        if self.use_dropout:
            x = self.dropout(x)
        x = self.activation(x)
        return x

# FeatureMapBlock takes no time conditioning: it only remaps the channel count.
class FeatureMapBlock(nn.Module):
    """
    1x1 convolution that maps ``input_channels`` to ``output_channels``
    without changing the spatial size.
    """
    def __init__(self, input_channels, output_channels):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=input_channels,
            out_channels=output_channels,
            kernel_size=1,
        )

    def forward(self, x):
        mapped = self.conv(x)
        return mapped

In [None]:
class UNet(nn.Module):
    """
    Time-conditioned U-Net with six contracting and six expanding stages.

    The time step is embedded once (sinusoidal embedding followed by a
    two-layer MLP) and the same embedding is passed to every contracting and
    expanding block. The network returns the raw output of the final 1x1
    convolution, with no output activation.

    Channel widths, as multiples of ``hidden_channels``: 1 -> 2 -> 4 -> 8 ->
    16 -> 32 -> 64 on the way down and back to 1 on the way up.

    NOTE: attribute names and their creation order define the state_dict keys,
    so renaming or regrouping submodules breaks loading of saved checkpoints.
    """
    def __init__(self, input_channels, output_channels, hidden_channels=32, time_emb_dim=256):
        """
        Args:
            input_channels: Channels of the input tensor.
            output_channels: Channels of the returned tensor.
            hidden_channels: Base channel width of the network.
            time_emb_dim: Width of the time embedding fed to every block.
        """
        super().__init__()

        # Time embedding: sinusoidal features -> expand to 4x width -> SiLU -> project back.
        self.time_mlp = nn.Sequential(
            SinusoidalPositionalEmbedding(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim * 4), # expand the embedding width by 4x
            nn.SiLU(),
            nn.Linear(time_emb_dim * 4, time_emb_dim) # project back to time_emb_dim for injection into the blocks
        )
        self.time_emb_dim = time_emb_dim # kept as an attribute; not read elsewhere in this class

        # Initial feature mapping: input_channels -> hidden_channels (1x1 conv)
        self.upfeature = FeatureMapBlock(input_channels, hidden_channels)

        # Contracting path: each block doubles the channels and downsamples with a stride-2 conv.
        self.contract1 = ContractingBlock(hidden_channels, hidden_channels * 2, time_emb_dim)
        self.contract2 = ContractingBlock(hidden_channels * 2, hidden_channels * 4, time_emb_dim)
        self.contract3 = ContractingBlock(hidden_channels * 4, hidden_channels * 8, time_emb_dim)
        self.contract4 = ContractingBlock(hidden_channels * 8, hidden_channels * 16, time_emb_dim)
        self.contract5 = ContractingBlock(hidden_channels * 16, hidden_channels * 32, time_emb_dim)
        self.contract6 = ContractingBlock(hidden_channels * 32, hidden_channels * 64, time_emb_dim)

        # Expanding path.
        # Each ExpandingBlock's second argument (out_channels) equals the channel count of the
        # skip connection it is paired with in forward(): the block upsamples to out_channels,
        # concatenates the skip (giving 2 * out_channels) and convolves back down to out_channels.
        self.expand0 = ExpandingBlock(hidden_channels * 64, hidden_channels * 32, time_emb_dim) # upsample h*64 -> h*32, concat with x5 (h*32) -> h*64 into conv1, output h*32
        self.expand1 = ExpandingBlock(hidden_channels * 32, hidden_channels * 16, time_emb_dim)
        self.expand2 = ExpandingBlock(hidden_channels * 16, hidden_channels * 8, time_emb_dim)
        self.expand3 = ExpandingBlock(hidden_channels * 8, hidden_channels * 4, time_emb_dim)
        self.expand4 = ExpandingBlock(hidden_channels * 4, hidden_channels * 2, time_emb_dim)
        self.expand5 = ExpandingBlock(hidden_channels * 2, hidden_channels, time_emb_dim) # final expand block, outputs hidden_channels

        # Final output layer: hidden_channels -> output_channels (1x1 conv), no activation
        self.downfeature = FeatureMapBlock(hidden_channels, output_channels)
        # self.sigmoid = torch.nn.Sigmoid() # intentionally not used: the output is an unbounded noise prediction

    def forward(self, x, time):
        """
        Args:
            x: Input tensor of shape (B, input_channels, H, W).
            time: Time steps of shape (B,), consumed by the sinusoidal embedding.

        Returns:
            Tensor with ``output_channels`` channels at the spatial size of ``x``.
        """
        # Embed the time step once; every block receives the same embedding
        time_emb = self.time_mlp(time)

        # Contracting Path (comments give the channel count of each result)
        x0 = self.upfeature(x) # hidden_channels
        x1 = self.contract1(x0, time_emb) # hidden_channels * 2
        x2 = self.contract2(x1, time_emb) # hidden_channels * 4
        x3 = self.contract3(x2, time_emb) # hidden_channels * 8
        x4 = self.contract4(x3, time_emb) # hidden_channels * 16
        x5 = self.contract5(x4, time_emb) # hidden_channels * 32
        x6 = self.contract6(x5, time_emb) # hidden_channels * 64 (bottleneck)

        # Expanding Path: each block consumes the previous output plus the matching skip
        x7 = self.expand0(x6, x5, time_emb) # hidden_channels * 32
        x8 = self.expand1(x7, x4, time_emb) # hidden_channels * 16
        x9 = self.expand2(x8, x3, time_emb) # hidden_channels * 8
        x10 = self.expand3(x9, x2, time_emb) # hidden_channels * 4
        x11 = self.expand4(x10, x1, time_emb) # hidden_channels * 2
        x12 = self.expand5(x11, x0, time_emb) # hidden_channels (same width and resolution as x0)

        # Final output convolution
        xn = self.downfeature(x12)

        return xn # raw output (noise prediction), no activation applied

class Vgg19(nn.Module):
    """
    Pretrained VGG-19 feature extractor split into five consecutive slices.

    ``forward`` returns the activation after each slice. Inputs are expected
    in [-1, 1]; they are rescaled to [0, 1] and then normalized with the
    ImageNet mean/std before entering the network.

    Args:
        requires_grad: When False (default) every parameter is frozen.
    """
    def __init__(self, requires_grad=False):
        super(Vgg19, self).__init__()
        vgg_pretrained_features = models.vgg19(pretrained=True).features

        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()

        # (slice, first layer index, one-past-last layer index) into `features`.
        # Each layer keeps its original index as its submodule name.
        slice_layout = (
            (self.slice1, 0, 2),
            (self.slice2, 2, 7),
            (self.slice3, 7, 12),
            (self.slice4, 12, 21),
            (self.slice5, 21, 30),
        )
        for stage, first, stop in slice_layout:
            for layer_idx in range(first, stop):
                stage.add_module(str(layer_idx), vgg_pretrained_features[layer_idx])

        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

        self.normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    def forward(self, X):
        """Return the list of five slice activations for input ``X`` in [-1, 1]."""
        # [-1, 1] -> [0, 1], then ImageNet normalization.
        features = self.normalize((X + 1) / 2)

        out = []
        for stage in (self.slice1, self.slice2, self.slice3, self.slice4, self.slice5):
            features = stage(features)
            out.append(features)
        return out

class VGGLoss(nn.Module):
    """
    Perceptual loss: weighted sum of L1 distances between VGG-19 features.

    Args:
        layids: Optional iterable of feature-slice indices to include. When
            None, every slice returned by the feature extractor is used.
    """
    def __init__(self,layids = None):
        super(VGGLoss, self).__init__()
        self.vgg = Vgg19()
        self.criterion = nn.L1Loss()
        # One weight per feature slice; later slices contribute more.
        self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]
        self.layids = layids

    def forward(self, x, y):
        """
        Args:
            x: Prediction batch; gradients flow through this branch.
            y: Target batch; treated as a constant.

        Returns:
            Weighted sum of per-slice L1 losses (the int 0 if no slice is selected).
        """
        x_vgg = self.vgg(x)
        # The target is detached below anyway, so skip building its autograd graph.
        with torch.no_grad():
            y_vgg = self.vgg(y)

        # Resolve the default locally instead of writing it back to self.layids:
        # forward() must not change the module's configuration as a side effect.
        layer_ids = self.layids if self.layids is not None else range(len(x_vgg))

        loss = 0
        for i in layer_ids:
            loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
        return loss

In [None]:
# --- Hyperparameters ---
lr = 0.00001        # optimizer learning rate
vgg_lambda = 0.001  # weight of the perceptual (VGG) term in the total loss
epochs = 500        # passes over the image list in the training loop

# --- Runtime objects ---
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mse_loss = nn.MSELoss().to(device)
vgg_loss = VGGLoss().to(device)

In [None]:
# Dataset layout on Drive: three sibling folders that contain files with the
# same names (conditioning input, ground-truth image, agnostic mask).
VITON_DATA_DIR = "/content/drive/MyDrive/AIClothes/Inputs_VITON"
INPUTS_DIR = os.path.join(VITON_DATA_DIR, "inputs_difussion_model")
IMAGES_DIR = os.path.join(VITON_DATA_DIR, "images")
MASKS_DIR = os.path.join(VITON_DATA_DIR, "agnostic_mask")

# os.listdir returns entries in arbitrary order; sort so the sample order
# (and therefore which sample ends up "last") is reproducible across runs.
path = sorted(os.listdir(INPUTS_DIR))

inputs_viton = []
original_images = []
mask_images = []

for img in path:
    # Every tensor is [1, 3, H, W] and is kept resident on `device`.
    input_viton = load_image(os.path.join(INPUTS_DIR, img)).to(device)
    original_image = load_image(os.path.join(IMAGES_DIR, img)).to(device)
    mask_image = load_image(os.path.join(MASKS_DIR, img)).to(device)

    inputs_viton.append(input_viton)
    original_images.append(original_image)
    mask_images.append(mask_image)

In [None]:
# 7 input channels = 3 (noisy image) + 3 (conditioning image) + 1 (mask channel),
# matching the concatenation built in the training loop; 3 output channels for
# the predicted noise.
viton = UNet(input_channels=7, output_channels=3).to(device)
viton_opt = torch.optim.AdamW(params=viton.parameters(), lr=lr, betas=(0.5, 0.999))

def weights_init(m):
    """
    Per-module initializer intended for ``nn.Module.apply``.

    - Conv2d / ConvTranspose2d: Xavier-normal weights, zero bias (if present).
    - GroupNorm: weight set to 1 and bias to 0, when those parameters exist
      (they are None for ``affine=False``).
    Any other module type is left untouched.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        torch.nn.init.xavier_normal_(m.weight)
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0)

    if isinstance(m, nn.GroupNorm):
        gamma = getattr(m, 'weight', None)
        beta = getattr(m, 'bias', None)
        if gamma is not None:
            torch.nn.init.constant_(gamma, 1)
        if beta is not None:
            torch.nn.init.constant_(beta, 0)

# Re-initialize every submodule in place. The optimizer created above keeps
# referencing the same parameter tensors, so it does not need to be rebuilt.
# Loading a checkpoint's model state afterwards overwrites these values.
viton = viton.apply(weights_init)

In [None]:
#lrscheduler = lr_scheduler.CosineAnnealingLR(viton_opt, T_max=epochs, eta_min=1e-8) # A very small min LR

In [None]:
# Checkpoint to resume from (written by the save cell of an earlier run).
CHECKPOINT_PATH ='/content/drive/MyDrive/AIClothes/Models/Diffusion/allimg_viton_adamW_schedulelr1e-05_vgg0.001_epoch_499.pth'

# map_location remaps the stored tensors onto the current device, so a
# checkpoint saved on a GPU runtime also loads on a CPU-only runtime.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)

viton.load_state_dict(checkpoint['model_state_dict'])
viton_opt.load_state_dict(checkpoint['optimizer_state_dict'])

# Loss history so far; the training loop keeps appending to this list.
losses = checkpoint['loss']

In [None]:
# Forward-process schedule; the sampling cell further down reuses this object.
scheduler = NoiseScheduler(timesteps=1000, beta_schedule="linear", device=device)

#losses = []  # uncomment when training from scratch (otherwise `losses` comes from the loaded checkpoint)

viton.train()

for epoch in range(epochs):

    for original_image, input_viton, mask_image in zip(original_images, inputs_viton, mask_images):

        # One random timestep per step (batch size is 1), shape (1,).
        x_0_original = original_image
        t = torch.randint(0, scheduler.timesteps, (1, )).to(device)
        x_t_original, noise = scheduler.get_noisy_image(x_0_original, t)

        # Noise the conditioning image with the SAME noise and timestep.
        x_0_agnostic = input_viton
        x_t_agnostic, _ = scheduler.get_noisy_image(x_0_agnostic, t, noise)

        # U-Net input: [noisy image (3), conditioning image (3), mask (1)] = 7 channels.
        mask_channel = mask_image[:, 0:1, :, :]
        input_original = torch.concat([x_t_original, input_viton, mask_channel], dim=1)
        input_agnostic = torch.concat([x_t_agnostic, input_viton, mask_channel], dim=1)

        # `t` is already a (1,)-shaped tensor on `device`; pass it as is rather
        # than re-wrapping it in torch.tensor([...]), which forces a host sync.
        unet_output_original = viton(input_original, t)
        unet_output_agnostic = viton(input_agnostic, t)

        # Clean-image estimate implied by the noise predicted for the agnostic branch.
        denoise_agnostic = scheduler.denoise_image(x_t_agnostic, t, unet_output_agnostic)

        # Noise-prediction MSE plus a weighted perceptual term against the ground truth.
        loss = mse_loss(unet_output_original, noise) + (vgg_lambda * vgg_loss(denoise_agnostic, x_0_original))

        losses.append(loss.item())

        viton_opt.zero_grad()
        loss.backward()
        viton_opt.step()

        #lrscheduler.step()

        # Visual check for every sample during the final epoch only.
        if epoch == epochs-1:

            current_lr = viton_opt.param_groups[0]['lr']
            # The mean covers the whole `losses` history, including entries
            # restored from the checkpoint.
            print(f"Epoch {epoch}, Mean Loss: {np.mean(losses)} at timestep {t} and lr: {current_lr}")

            # Panels: ground truth | noised ground truth | denoised agnostic estimate.
            fig, axs = plt.subplots(1, 3, figsize=(20, 10))
            axs[0].imshow(((original_image[0]+1) / 2).permute(1, 2, 0).detach().cpu().numpy())
            axs[0].axis('off')

            axs[1].imshow(((x_t_original[0]+1)/2).permute(1, 2, 0).detach().cpu().numpy())
            axs[1].axis('off')

            axs[2].imshow(((denoise_agnostic[0]+1)/2).permute(1, 2, 0).detach().cpu().numpy())
            axs[2].axis('off')

            # Render next to the log line above and release the figure, so
            # one open figure per image does not pile up in memory.
            plt.show()
            plt.close(fig)

# Checkpointing is done once, in the next cell, after the loop has finished.

In [None]:
# Snapshot everything needed to resume: last epoch index, model weights,
# optimizer state and the full loss history.
model_state = viton.state_dict()
optimizer_state = viton_opt.state_dict()

checkpoint = {
    'epoch': epoch,
    'model_state_dict': model_state,
    'optimizer_state_dict': optimizer_state,
    'loss': losses,
}

# The `500+epoch` offset in the file name assumes this run resumed from the
# epoch-499 checkpoint; adjust it when resuming from a different one.
checkpoint_file = f'/content/drive/MyDrive/AIClothes/Models/allimg_viton_adamW_schedulelr{lr}_vgg{vgg_lambda}_epoch_{500+epoch}.pth'
torch.save(checkpoint, checkpoint_file)

In [None]:
# CHECKPOINT_PATH ='/content/drive/MyDrive/AIClothes/Models/viton_adamW_schedulelr0.0001_vgg0.001_epoch_2999.pth'

# checkpoint = torch.load(CHECKPOINT_PATH)

# viton.load_state_dict(checkpoint['model_state_dict'])
# viton_opt.load_state_dict(checkpoint['optimizer_state_dict'])

# losses = checkpoint['loss']

In [None]:
# Sample to generate: pick the conditioning image and mask explicitly instead
# of relying on whatever the last loop iteration left in these names.
input_viton = inputs_viton[-1]
mask_image = mask_images[-1]
mask_channel = mask_image[:, 0:1, :, :]

viton.eval()

# 100 decreasing timesteps from T-10 down to 5 (truncated to integers).
timesteps_to_sample = torch.linspace(scheduler.timesteps - 10, 5, 100).to(device).long()

# Start from the conditioning image noised to the first (largest) timestep.
x_current, _ = scheduler.get_noisy_image(input_viton, timesteps_to_sample[0])

final_image = []

with torch.no_grad():
    # --- Deterministic DDIM denoising loop ---
    for i, t in enumerate(timesteps_to_sample):
        # Next (smaller) timestep of the custom sequence; 0 after the last one.
        t_prev = timesteps_to_sample[i + 1] if i < len(timesteps_to_sample) - 1 else 0

        # U-Net input layout must match training:
        # [current noisy image, conditioning image, mask channel].
        input_to_unet = torch.cat([x_current, input_viton, mask_channel], dim=1)

        # `t` is a 0-dim tensor here; the U-Net expects shape (B,) = (1,).
        predicted_noise = viton(input_to_unet, t.reshape(1))

        # 1. Estimate the clean image from the current state and predicted noise.
        alpha_bar_t_val = scheduler.alphas_cumprod[t].item()
        x_0_pred = (x_current - math.sqrt(1.0 - alpha_bar_t_val) * predicted_noise) / math.sqrt(alpha_bar_t_val)

        # Clean images live in [-1, 1]; clamp the estimate to keep it stable.
        x_0_pred = torch.clamp(x_0_pred, -1.0, 1.0)

        # 2. Re-noise the estimate to the previous timestep (DDIM update).
        # The noisy state itself is NOT clamped: x_t legitimately exceeds
        # [-1, 1] at high noise levels, and the U-Net was trained on unclamped
        # x_t, so clamping here would feed it out-of-distribution inputs.
        alpha_bar_t_prev_val = scheduler.alphas_cumprod[t_prev].item()
        x_current = math.sqrt(alpha_bar_t_prev_val) * x_0_pred + math.sqrt(1.0 - alpha_bar_t_prev_val) * predicted_noise

        if (i+1) % 10 == 0:

            plt.imshow(((x_0_pred[0].permute(1, 2, 0).cpu().numpy()) + 1) / 2)
            plt.title(f"DDIM Step {i+1} (t={t.item()})")
            plt.axis('off')
            plt.show()

# The last clean-image estimate (x_0_pred) is the final result. x_current
# still carries the small residual noise of timestep 0.
final_image.append(x_0_pred)
# ... (rest of your post-processing and saving/plotting code) ...