In [None]:
import os
import math
import torch
import torch.nn as nn
import requests
import numpy as np
from PIL import Image
from tqdm import tqdm  
from abc import abstractmethod
import matplotlib.pyplot as plt
from torchvision import transforms
import torch.nn.functional as F
import torch_optimizer as optim
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader


from DDIM_C import GaussianDiffusion
from ldm.modules.diffusionmodules.openaimodel import UNetModel


# Pick the compute device once; later cells move models and batches onto `device`.
if torch.cuda.is_available():
    device = torch.device('cuda')
    # Report the name of GPU index 0 so the run log records the hardware used.
    device_name = torch.cuda.get_device_name(0)
    print(f"GPU Name: {device_name}")
else:
    device = torch.device('cpu')
    print("No GPU available, using CPU.")

GPU Name: NVIDIA GeForce RTX 4090


In [None]:
# Step 1: preprocessing, paired dataset and data loader.

# Shared preprocessing for both images of a pair: convert the PIL image to a
# tensor, then normalise each of the three channels as (x - 0.5) / 0.5.
transform2 = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(std=[0.5, 0.5, 0.5], mean=[0.5, 0.5, 0.5]),
    ]
)

class PalmDataset(Dataset):
    """Paired dataset of palm images and their label images.

    Files in ``palm_dir`` and ``label_dir`` are matched purely by position
    after sorting the file names, so both directories must contain the same
    number of images in corresponding order.

    Args:
        palm_dir: Directory holding the palm images.
        label_dir: Directory holding the label images.
        transform: Optional callable applied to each image of a pair after
            resizing.

    Raises:
        ValueError: If the two directories hold a different number of images.
            Index-based pairing would then be misaligned or run out of range.
    """

    def __init__(self, palm_dir, label_dir, transform=None):
        self.palm_dir = palm_dir
        self.label_dir = label_dir
        self.transform = transform

        self.palm_images = self._list_images(palm_dir)
        self.label_images = self._list_images(label_dir)

        # Fail fast: pairing is positional, so unequal counts can only yield
        # wrong pairs (extra labels) or an IndexError mid-epoch (missing labels).
        if len(self.palm_images) != len(self.label_images):
            raise ValueError(
                f"Cannot pair images by index: found {len(self.palm_images)} image(s) in "
                f"'{palm_dir}' but {len(self.label_images)} image(s) in '{label_dir}'."
            )

    @staticmethod
    def _list_images(directory):
        """Return the sorted names of the regular .png/.jpg/.jpeg files in ``directory``."""
        return sorted(
            name for name in os.listdir(directory)
            if os.path.isfile(os.path.join(directory, name))
            and name.lower().endswith(('.png', '.jpg', '.jpeg'))
        )

    def __len__(self):
        """Number of (palm, label) pairs."""
        return len(self.palm_images)

    def __getitem__(self, idx):
        """Load pair ``idx``: palm resized to 256x256, label resized to 64x64, both RGB.

        Returns:
            Tuple ``(palm_image, label_image)``; each is the output of
            ``transform`` when one is set, otherwise a PIL image.
        """
        palm_path = os.path.join(self.palm_dir, self.palm_images[idx])
        label_path = os.path.join(self.label_dir, self.label_images[idx])

        # convert()/resize() return new in-memory images, so the file handles
        # can be released as soon as the pixels have been read.
        with Image.open(palm_path) as palm_file:
            palm_image = palm_file.convert("RGB").resize((256, 256))
        with Image.open(label_path) as label_file:
            label_image = label_file.convert("RGB").resize((64, 64))

        if self.transform:
            palm_image = self.transform(palm_image)
            label_image = self.transform(label_image)

        return palm_image, label_image



# Source directories for the paired data.
# NOTE(review): the palm images come from a `train/` folder while the label
# images come from a `test/` folder — confirm these two sets really correspond
# one-to-one, since PalmDataset pairs files by sorted position.
real_image_folder, bezier_image_folder = (
    './datasets/palm2bezier/train/trainA',
    './datasets/palm2bezier/test/binary_images_fake_bezier',
)

dataset = PalmDataset(
    palm_dir=real_image_folder,
    label_dir=bezier_image_folder,
    transform=transform2,
)

def custom_collate_fn(batch):
    """Collate a list of ``(palm, label)`` tensor pairs into two batched tensors.

    Args:
        batch: Sequence of samples; element 0 of each sample is the palm
            tensor and element 1 is the label tensor.

    Returns:
        Tuple ``(palm_images, label_images)``, each stacked along a new
        leading batch dimension.
    """
    palm_images = torch.stack([sample[0] for sample in batch], dim=0)
    label_images = torch.stack([sample[1] for sample in batch], dim=0)
    return (palm_images, label_images)

# Shuffled mini-batches of 32 pairs, loaded by 16 worker processes into pinned
# host memory and assembled by custom_collate_fn.
train_loader = DataLoader(
    dataset=dataset,
    collate_fn=custom_collate_fn,
    batch_size=32,
    shuffle=True,
    num_workers=16,
    pin_memory=True,
)

In [None]:
# Step 2: load the pretrained VQ autoencoder weights and build a frozen encoder.
ckpt_path = 'vqmodel_checkpoint.ckpt'
# map_location places the loaded tensors on `device`; weights_only=True limits
# unpickling to tensors and plain containers.
checkpoint = torch.load(ckpt_path, weights_only=True, map_location=device)


class VQModel(torch.nn.Module):
    """Encoder-only slice of a VQ autoencoder.

    Only the taming ``Encoder`` and the 1x1 ``quant_conv`` projection are
    built. The decoder, the vector quantizer and ``post_quant_conv`` are
    deliberately not instantiated, so ``encode`` returns the projected encoder
    output without any quantisation step, and there is no ``decode``.

    Args:
        ddconfig: Keyword arguments forwarded to ``Encoder``; must contain
            the key ``"z_channels"``.
        embed_dim: Number of output channels of ``quant_conv``.
        n_embed: Not used by this encoder-only variant.
    """

    def __init__(self, ddconfig, embed_dim=3, n_embed=8192):
        super().__init__()
        # Imported here so taming is only required when a model is instantiated.
        # (Decoder is imported but not used by this encoder-only variant.)
        from taming.modules.diffusionmodules.model import Encoder, Decoder

        # Attribute names must stay `encoder` / `quant_conv`: they determine the
        # state-dict keys that are matched against the checkpoint.
        self.encoder = Encoder(**ddconfig)
        self.quant_conv = torch.nn.Conv2d(
            in_channels=ddconfig["z_channels"],
            out_channels=embed_dim,
            kernel_size=1,
        )

    def encode(self, x):
        """Run ``x`` through the encoder and project it with ``quant_conv``."""
        features = self.encoder(x)
        return self.quant_conv(features)


# Build the encoder-only VQ model for 3-channel inputs at resolution 256, with
# three channel-multiplier levels and no attention layers.
vq_model = VQModel(
    ddconfig=dict(
        double_z=False,
        z_channels=3,
        resolution=256,
        in_channels=3,
        out_ch=3,
        ch=128,
        ch_mult=[1, 2, 4],
        num_res_blocks=2,
        attn_resolutions=[],
        dropout=0.0,
    ),
    n_embed=8192,
    embed_dim=3,
)


# The weights may be nested under 'model_state_dict' or 'state_dict'; otherwise
# the loaded object is treated as the state dict itself.
if 'model_state_dict' in checkpoint:
    state_dict = checkpoint['model_state_dict']
elif 'state_dict' in checkpoint:
    state_dict = checkpoint['state_dict']
else:
    state_dict = checkpoint

# Collect the model's parameter names once instead of rebuilding the model's
# state dict for every checkpoint key.
model_keys = set(vq_model.state_dict().keys())

# Keep only the entries this encoder-only model actually has (the checkpoint
# may also hold decoder / quantizer weights that are not instantiated here).
filtered_state_dict = {k: v for k, v in state_dict.items() if k in model_keys}

# With strict=False a complete key mismatch (e.g. a different key prefix) would
# otherwise pass silently and leave the frozen encoder randomly initialised.
if not filtered_state_dict:
    raise RuntimeError(
        f"No parameter in '{ckpt_path}' matches the VQModel keys; "
        "the encoder would stay randomly initialised."
    )

load_result = vq_model.load_state_dict(filtered_state_dict, strict=False)
if load_result.missing_keys:
    print(
        f"Warning: {len(load_result.missing_keys)} VQModel parameter(s) were not found in "
        f"the checkpoint and keep their initial values: {load_result.missing_keys}"
    )

# Freeze the encoder: it is only used to produce latents, never trained.
vq_model = vq_model.to(device)
vq_model.requires_grad_(False)

vq_model.eval()

Working with z of shape (1, 3, 64, 64) = 12288 dimensions.


VQModel(
  (encoder): Encoder(
    (conv_in): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (down): ModuleList(
      (0): Module(
        (block): ModuleList(
          (0-1): 2 x ResnetBlock(
            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
            (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
        )
        (attn): ModuleList()
        (downsample): Downsample(
          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2))
        )
      )
      (1): Module(
        (block): ModuleList(
          (0): ResnetBlock(
            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
            (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): 

In [None]:
# Step 3: build the denoising U-Net and restore its pretrained weights.
unet_config = {
    "image_size": 64,
    # NOTE(review): 6 input channels is presumably the 3-channel latent
    # concatenated with the 3-channel label image — confirm in DDIM_C.
    "in_channels": 6,
    "out_channels": 3,
    "model_channels": 224,
    "attention_resolutions": [8, 4, 2],
    "num_res_blocks": 2,
    "channel_mult": [1, 2, 3, 4],
    "num_head_channels": 32,
}

unet_model = UNetModel(**unet_config)

# Note: the training cell later saves to this same path, overwriting it.
checkpoint_path = 'ddim_c.ckpt'

# map_location keeps this consistent with the VQ checkpoint load and lets a
# checkpoint saved from a GPU be restored on a CPU-only machine.
pretrained_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)

# Strict load: every U-Net parameter must be present in the checkpoint.
unet_model.load_state_dict(pretrained_dict)

print("Pretrained weights loaded successfully.")

unet_model.to(device)
unet_model.train()

UNetModel(
  (time_embed): Sequential(
    (0): Linear(in_features=224, out_features=896, bias=True)
    (1): SiLU()
    (2): Linear(in_features=896, out_features=896, bias=True)
  )
  (input_blocks): ModuleList(
    (0): TimestepEmbedSequential(
      (0): Conv2d(6, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (1-2): 2 x TimestepEmbedSequential(
      (0): ResBlock(
        (in_layers): Sequential(
          (0): GroupNorm32(32, 224, eps=1e-05, affine=True)
          (1): SiLU()
          (2): Conv2d(224, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (h_upd): Identity()
        (x_upd): Identity()
        (emb_layers): Sequential(
          (0): SiLU()
          (1): Linear(in_features=896, out_features=224, bias=True)
        )
        (out_layers): Sequential(
          (0): GroupNorm32(32, 224, eps=1e-05, affine=True)
          (1): SiLU()
          (2): Dropout(p=0, inplace=False)
          (3): Conv2d(224, 224, kernel_size=(3, 3), 

In [None]:
# Step 4: diffusion process using a linear beta schedule over `timesteps` steps.
timesteps = 1000
gaussian_diffusion = GaussianDiffusion(
    timesteps=timesteps,
    beta_schedule='linear',
    linear_start=0.0015,
    linear_end=0.0155,
)

In [None]:
# Step 5: fine-tune the U-Net on latents from the frozen VQ encoder,
# conditioned on the label images.
num_epochs = 10
global_step = 0

optimizer = torch.optim.AdamW(unet_model.parameters(), lr=1e-06)

# Tie the cosine period to the run length. A hard-coded T_max shorter than
# num_epochs would make the learning rate climb back up mid-run.
cosine_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=num_epochs, eta_min=1e-7
)

for epoch in tqdm(range(num_epochs), desc='Epochs'):
    # Accumulate the loss on-device; it is read back once per epoch.
    epoch_loss_sum = 0.0
    num_batches = 0

    for palm_images, label_images in train_loader:
        optimizer.zero_grad()

        palm_images = palm_images.to(device)
        label_images = label_images.to(device)

        batch_size = palm_images.shape[0]

        # The encoder is frozen; no_grad already yields a graph-free latent.
        with torch.no_grad():
            latent = vq_model.encode(palm_images)

        # One uniformly sampled diffusion timestep per sample.
        t = torch.randint(0, timesteps, (batch_size,), device=device).long()

        loss = gaussian_diffusion.train_losses(unet_model, latent, t, label_images)

        loss.backward()
        optimizer.step()

        epoch_loss_sum = epoch_loss_sum + loss.detach()
        num_batches += 1
        global_step += 1

    if num_batches == 0:
        raise RuntimeError("train_loader produced no batches; nothing to train on.")

    cosine_scheduler.step()

    # Report the epoch mean rather than the last batch's loss, which is noisy
    # because each batch draws different random timesteps.
    mean_loss = float(epoch_loss_sum / num_batches)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {mean_loss:.4f}')

    # Save every 5 epochs and always after the final epoch, so trailing epochs
    # are not lost when num_epochs is not a multiple of 5.
    # Note: this overwrites the checkpoint the U-Net was initialised from.
    if (epoch + 1) % 5 == 0 or (epoch + 1) == num_epochs:
        torch.save(unet_model.state_dict(), 'ddim_c.ckpt')
        print("Model parameters saved to ddim_c.ckpt")


Epochs:   1%|          | 1/100 [04:06<6:46:29, 246.36s/it]

Epoch [1/100], Loss: 0.6326


Epochs:   2%|▏         | 2/100 [08:11<6:41:33, 245.85s/it]

Epoch [2/100], Loss: 0.3956


Epochs:   3%|▎         | 3/100 [12:17<6:37:16, 245.74s/it]

Epoch [3/100], Loss: 0.2766


Epochs:   4%|▍         | 4/100 [16:23<6:33:09, 245.72s/it]

Epoch [4/100], Loss: 0.1791


Epochs:   5%|▌         | 5/100 [20:28<6:29:06, 245.75s/it]

Epoch [5/100], Loss: 0.3144


Epochs:   6%|▌         | 6/100 [24:34<6:24:56, 245.71s/it]

Epoch [6/100], Loss: 0.2835


Epochs:   7%|▋         | 7/100 [28:40<6:20:49, 245.69s/it]

Epoch [7/100], Loss: 0.1631


Epochs:   8%|▊         | 8/100 [32:45<6:16:34, 245.59s/it]

Epoch [8/100], Loss: 0.1770


Epochs:   9%|▉         | 9/100 [36:51<6:12:26, 245.56s/it]

Epoch [9/100], Loss: 0.1427


Epochs:  10%|█         | 10/100 [40:56<6:08:22, 245.58s/it]

Epoch [10/100], Loss: 0.2576


Epochs:  11%|█         | 11/100 [45:02<6:04:18, 245.60s/it]

Epoch [11/100], Loss: 0.1526


Epochs:  12%|█▏        | 12/100 [49:07<6:00:06, 245.52s/it]

Epoch [12/100], Loss: 0.1909


Epochs:  13%|█▎        | 13/100 [53:13<5:56:07, 245.60s/it]

Epoch [13/100], Loss: 0.2029


Epochs:  15%|█▌        | 15/100 [1:01:25<5:48:05, 245.71s/it]

Epoch [15/100], Loss: 0.2489


Epochs:  16%|█▌        | 16/100 [1:05:30<5:43:56, 245.67s/it]

Epoch [16/100], Loss: 0.1966


Epochs:  18%|█▊        | 18/100 [1:13:42<5:35:47, 245.70s/it]

Epoch [18/100], Loss: 0.2377


Epochs:  19%|█▉        | 19/100 [1:17:48<5:31:44, 245.74s/it]

Epoch [19/100], Loss: 0.1295
Epoch [20/100], Loss: 0.1480


Epochs:  20%|██        | 20/100 [1:21:55<5:28:26, 246.33s/it]

Model parameters saved to ddim_c.ckpt


Epochs:  21%|██        | 21/100 [1:26:01<5:24:03, 246.12s/it]

Epoch [21/100], Loss: 0.2048


Epochs:  22%|██▏       | 22/100 [1:30:07<5:19:47, 246.00s/it]

Epoch [22/100], Loss: 0.1423


Epochs:  23%|██▎       | 23/100 [1:34:12<5:15:34, 245.90s/it]

Epoch [23/100], Loss: 0.1937


Epochs:  24%|██▍       | 24/100 [1:38:18<5:11:26, 245.88s/it]

Epoch [24/100], Loss: 0.2108


Epochs:  25%|██▌       | 25/100 [1:42:24<5:07:20, 245.87s/it]

Epoch [25/100], Loss: 0.1810


Epochs:  26%|██▌       | 26/100 [1:46:30<5:03:15, 245.89s/it]

Epoch [26/100], Loss: 0.1791


Epochs:  27%|██▋       | 27/100 [1:50:35<4:59:00, 245.76s/it]

Epoch [27/100], Loss: 0.1918


Epochs:  28%|██▊       | 28/100 [1:54:41<4:54:52, 245.74s/it]

Epoch [28/100], Loss: 0.1809


Epochs:  29%|██▉       | 29/100 [1:58:47<4:50:46, 245.72s/it]

Epoch [29/100], Loss: 0.2182


Epochs:  30%|███       | 30/100 [2:02:52<4:46:39, 245.71s/it]

Epoch [30/100], Loss: 0.1874


Epochs:  31%|███       | 31/100 [2:06:58<4:42:37, 245.76s/it]

Epoch [31/100], Loss: 0.1190


Epochs:  32%|███▏      | 32/100 [2:11:04<4:38:35, 245.82s/it]

Epoch [32/100], Loss: 0.1563


Epochs:  33%|███▎      | 33/100 [2:15:10<4:34:30, 245.83s/it]

Epoch [33/100], Loss: 0.1727


Epochs:  34%|███▍      | 34/100 [2:19:16<4:30:19, 245.75s/it]

Epoch [34/100], Loss: 0.1664


Epochs:  35%|███▌      | 35/100 [2:23:21<4:26:08, 245.66s/it]

Epoch [35/100], Loss: 0.1656


Epochs:  36%|███▌      | 36/100 [2:27:27<4:22:02, 245.67s/it]

Epoch [36/100], Loss: 0.1654


Epochs:  37%|███▋      | 37/100 [2:31:32<4:17:56, 245.67s/it]

Epoch [37/100], Loss: 0.1482


Epochs:  38%|███▊      | 38/100 [2:35:38<4:13:55, 245.73s/it]

Epoch [38/100], Loss: 0.1019


Epochs:  39%|███▉      | 39/100 [2:39:44<4:09:47, 245.70s/it]

Epoch [39/100], Loss: 0.1867
Epoch [40/100], Loss: 0.1599


Epochs:  40%|████      | 40/100 [2:43:52<4:06:23, 246.39s/it]

Model parameters saved to ddim_c.ckpt


Epochs:  41%|████      | 41/100 [2:47:57<4:02:00, 246.10s/it]

Epoch [41/100], Loss: 0.1207


Epochs:  42%|████▏     | 42/100 [2:52:03<3:57:40, 245.87s/it]

Epoch [42/100], Loss: 0.1704


Epochs:  43%|████▎     | 43/100 [2:56:08<3:53:30, 245.79s/it]

Epoch [43/100], Loss: 0.1295


Epochs:  44%|████▍     | 44/100 [3:00:14<3:49:17, 245.66s/it]

Epoch [44/100], Loss: 0.1625


Epochs:  45%|████▌     | 45/100 [3:04:19<3:45:10, 245.64s/it]

Epoch [45/100], Loss: 0.1325


Epochs:  46%|████▌     | 46/100 [3:08:25<3:41:02, 245.60s/it]

Epoch [46/100], Loss: 0.1401


Epochs:  47%|████▋     | 47/100 [3:12:30<3:36:55, 245.58s/it]

Epoch [47/100], Loss: 0.1703


Epochs:  48%|████▊     | 48/100 [3:16:36<3:32:47, 245.54s/it]

Epoch [48/100], Loss: 0.0674


Epochs:  49%|████▉     | 49/100 [3:20:41<3:28:42, 245.54s/it]

Epoch [49/100], Loss: 0.1448


Epochs:  51%|█████     | 51/100 [3:28:53<3:20:33, 245.58s/it]

Epoch [51/100], Loss: 0.1357


Epochs:  52%|█████▏    | 52/100 [3:32:58<3:16:24, 245.51s/it]

Epoch [52/100], Loss: 0.1841


Epochs:  53%|█████▎    | 53/100 [3:37:04<3:12:21, 245.56s/it]

Epoch [53/100], Loss: 0.1089


Epochs:  54%|█████▍    | 54/100 [3:41:09<3:08:16, 245.57s/it]

Epoch [54/100], Loss: 0.1636


Epochs:  55%|█████▌    | 55/100 [3:45:15<3:04:12, 245.60s/it]

Epoch [55/100], Loss: 0.1502


Epochs:  56%|█████▌    | 56/100 [3:49:20<3:00:04, 245.55s/it]

Epoch [56/100], Loss: 0.2019
