In [None]:
!pip install diffusers==0.25.0 huggingface_hub==0.22.2 transformers accelerate 

In [None]:
!pip install --upgrade diffusers huggingface_hub transformers accelerate

In [None]:
import os
import copy
import math
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from tqdm.auto import tqdm
from diffusers import DDPMPipeline, DDPMScheduler

# Train on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device


Flax classes are deprecated and will be removed in Diffusers v1.0.0. We recommend migrating to PyTorch classes or pinning your version of Diffusers.
Flax classes are deprecated and will be removed in Diffusers v1.0.0. We recommend migrating to PyTorch classes or pinning your version of Diffusers.


device(type='cuda')

In [None]:
from google.colab import drive
# Mount Google Drive so the dataset and checkpoint folders below are reachable.
drive.mount("/content/drive")

Mounted at /content/drive


In [None]:
import os
import shutil

# Source dataset folders on the mounted Drive (read once) ...
DRIVE_TRAIN_SOURCE, DRIVE_TEST_SOURCE = (
    "/content/drive/MyDrive/" + split for split in ("dl_train", "dl_test")
)

# ... and their local-disk copies, which the data loaders actually read from.
LOCAL_TRAIN_DIR, LOCAL_TEST_DIR = (
    f"/content/local_{split}_data" for split in ("train", "test")
)

def copy_data_to_local(source_path, dest_path, data_name="Data"):
    """Copy a dataset directory (e.g. from Google Drive) to local disk for faster reads.

    The copy is skipped when ``dest_path`` already exists. To make that check
    trustworthy, the tree is first copied into a ``<dest_path>.partial`` sibling and
    only renamed to ``dest_path`` once the whole copy succeeded. A failed copy
    therefore never leaves a half-filled ``dest_path`` behind that a later run would
    mistake for a complete dataset.

    Args:
        source_path: directory to copy from.
        dest_path: target directory; must not exist for a copy to happen.
        data_name: label used in the progress messages.

    Errors are reported with a message rather than raised (best-effort, as before);
    any partial copy is removed so the next run retries from scratch.
    """
    dest_path = os.path.normpath(dest_path)
    if os.path.exists(dest_path):
        print(f"{data_name} đã có sẵn ở Local.")
        return

    print(f"Đang copy {data_name} từ Drive sang Local (Tăng tốc độ)...")
    partial_path = dest_path + ".partial"
    try:
        # Leftover from an interrupted earlier attempt.
        if os.path.exists(partial_path):
            shutil.rmtree(partial_path)
        shutil.copytree(source_path, partial_path)
        # Publish the directory only after every file has been copied.
        os.replace(partial_path, dest_path)
        print(f"Copy {data_name} hoàn tất!")
    except Exception as e:
        shutil.rmtree(partial_path, ignore_errors=True)
        print(f"Lỗi khi copy {data_name}: {e}")

copy_data_to_local(DRIVE_TRAIN_SOURCE, LOCAL_TRAIN_DIR, data_name="Train Set")
copy_data_to_local(DRIVE_TEST_SOURCE, LOCAL_TEST_DIR, data_name="Test Set")

# Each split has three parallel sub-folders: targets, masked inputs, and masks.
GT_DIR, MASKED_DIR, MASK_DIR = (
    f"{LOCAL_TRAIN_DIR}/{sub}" for sub in ("ground_truth", "masked_images", "masks")
)
TEST_GT_DIR, TEST_MASKED_DIR, TEST_MASK_DIR = (
    f"{LOCAL_TEST_DIR}/{sub}" for sub in ("ground_truth", "masked_images", "masks")
)
NUM_TEST_IMAGES = 500

# Checkpoints go straight to Drive so they survive the Colab session.
CKPT_DIR = "/content/drive/MyDrive/ver4_DDPM_checkpoints"
os.makedirs(CKPT_DIR, exist_ok=True)

Đang copy Train Set từ Drive sang Local (Tăng tốc độ)...
Copy Train Set hoàn tất!
Đang copy Test Set từ Drive sang Local (Tăng tốc độ)...
Copy Test Set hoàn tất!


In [None]:
# Training hyper-parameters.
IMG_SIZE = 256       # side length images and masks are resized to
BATCH_SIZE = 8
EPOCHS = 10          # the training loop runs epochs start_epoch..EPOCHS (inclusive)
LR = 1e-5            # AdamW learning rate
GRAD_CLIP = 1.0      # max gradient norm passed to clip_grad_norm_
USE_AMP = True       # autocast + GradScaler mixed precision
EMA_DECAY = 0.999    # decay of the exponential moving average of the UNet weights
NUM_WORKERS = 2      # DataLoader worker processes
SEED = 42            # seeds torch (CPU and CUDA): shuffling, noise, timesteps

# Seed early so data order, sampled noise and timesteps are reproducible across runs.
torch.manual_seed(SEED)

In [None]:
class ImageCompletionDataset(Dataset):
    """Paired (ground truth, masked image, mask) samples for image completion.

    Each of the three folders is listed independently, filtered to
    .png/.jpg/.jpeg files (case-insensitive) and sorted. Sample ``i`` combines
    the ``i``-th file of every listing, so the folders are expected to contain
    names that sort into the same order. Only the file counts are checked.

    Items are dicts with keys ``"gt"`` and ``"masked"`` (RGB tensors from
    ``ToTensor``) and ``"mask"`` (a 1-channel float tensor of 0.0/1.0).
    """

    @staticmethod
    def _list_images(folder):
        """Sorted names of the .png/.jpg/.jpeg files (case-insensitive) in `folder`."""
        return sorted(
            entry for entry in os.listdir(folder)
            if entry.lower().endswith(('.png', '.jpg', '.jpeg'))
        )

    def __init__(self, root_gt, root_masked, root_mask, size=256):
        self.gt_files = self._list_images(root_gt)
        self.masked_files = self._list_images(root_masked)
        self.mask_files = self._list_images(root_mask)

        n_gt, n_masked, n_mask = len(self.gt_files), len(self.masked_files), len(self.mask_files)
        assert n_gt == n_masked == n_mask, \
            f"Counts mismatch: {n_gt}, {n_masked}, {n_mask}"

        self.root_gt, self.root_masked, self.root_mask = root_gt, root_masked, root_mask

        # Images and masks share the same resize + to-tensor pipeline (two separate instances).
        def _make_transform():
            return transforms.Compose([transforms.Resize((size, size)), transforms.ToTensor()])

        self.img_tf = _make_transform()
        self.mask_tf = _make_transform()

    def __len__(self):
        return len(self.gt_files)

    def __getitem__(self, idx):
        def _load(folder, names, mode):
            return Image.open(os.path.join(folder, names[idx])).convert(mode)

        gt_tensor = self.img_tf(_load(self.root_gt, self.gt_files, "RGB"))
        masked_tensor = self.img_tf(_load(self.root_masked, self.masked_files, "RGB"))
        mask_gray = self.mask_tf(_load(self.root_mask, self.mask_files, "L"))

        # Pixels darker than 0.5 become 1.0, everything else 0.0.
        mask_binary = (mask_gray < 0.5).float()

        return {"gt": gt_tensor, "masked": masked_tensor, "mask": mask_binary}


In [None]:
# Use the config-cell constants instead of re-hard-coding 256 / 2 here.
dataset = ImageCompletionDataset(GT_DIR, MASKED_DIR, MASK_DIR, size=IMG_SIZE)
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    # Pinned host memory only helps host->GPU copies; without CUDA it just warns.
    pin_memory=torch.cuda.is_available(),
)

print("Dataset size:", len(dataset))

Dataset size: 5000


In [None]:
# Evaluation split, resized to the same IMG_SIZE as the training data.
test_dataset_eval = ImageCompletionDataset(
    root_gt=TEST_GT_DIR,
    root_masked=TEST_MASKED_DIR,
    root_mask=TEST_MASK_DIR,
    size=IMG_SIZE,
)
print("Test Dataset size:", len(test_dataset_eval))

Test Dataset size: 500


In [None]:
# Set RESUME_FROM_CHECKPOINT = None to start from the pretrained CelebA-HQ weights.
RESUME_FROM_CHECKPOINT = os.path.join(CKPT_DIR, "epoch_6")

print("Initializing Model Architecture...")
pipeline = DDPMPipeline.from_pretrained("google/ddpm-celebahq-256")
scheduler = DDPMScheduler.from_pretrained("google/ddpm-celebahq-256")

# The UNet's input convolution is the first Conv2d that consumes the 3 image channels.
orig_unet = pipeline.unet
old_conv = None; old_conv_name = None
for name, module in orig_unet.named_modules():
    if isinstance(module, nn.Conv2d) and module.in_channels == 3:
        old_conv = module; old_conv_name = name; break
if old_conv is None:
    raise RuntimeError("No Conv2d with in_channels == 3 found in the pretrained UNet")

# Replacement input conv with 7 channels, laid out as
# [masked image (0-2) | mask (3) | noisy image (4-6)] to match the training loop's torch.cat.
new_first_conv = nn.Conv2d(
    in_channels=7,
    out_channels=old_conv.out_channels,
    kernel_size=old_conv.kernel_size,
    stride=old_conv.stride,
    padding=old_conv.padding,
    dilation=old_conv.dilation,
    groups=old_conv.groups,
    bias=(old_conv.bias is not None)
)

with torch.no_grad():
    # Conditioning channels 0-3 start at zero and channels 4-6 reuse the pretrained kernel,
    # so at initialisation the layer outputs exactly what the pretrained conv produces on the
    # noisy image. (Every weight entry is assigned here, so no random init is needed.)
    new_first_conv.weight.zero_()
    new_first_conv.weight[:, 4:7, :, :] = old_conv.weight
    if new_first_conv.bias is not None:
        # Keep the pretrained bias too; zeroing it would shift every output activation.
        new_first_conv.bias.copy_(old_conv.bias)

def set_module_by_name(model, name, new_module):
    """Replace the submodule at dotted path ``name`` (e.g. ``"conv_in"`` or ``"a.0.b"``)
    inside ``model`` with ``new_module``, resolving each path segment with getattr."""
    *parent_path, leaf = name.split(".")
    parent = model
    for attr in parent_path:
        parent = getattr(parent, attr)
    setattr(parent, leaf, new_module)

# Put the 7-channel conv in place of the original 3-channel input conv.
set_module_by_name(pipeline.unet, old_conv_name, new_first_conv)
pipeline.unet.in_channels = 7
# register_to_config rebuilds the (frozen) config, so both config["in_channels"] and
# config.in_channels read 7 and save_pretrained() writes in_channels=7 to config.json.
# Assigning config["in_channels"] directly leaves the attribute view config.in_channels at 3.
pipeline.unet.register_to_config(in_channels=7)
print("Model skeleton created (7 Channels).")

if RESUME_FROM_CHECKPOINT:
    print(f"RESUMING weights from: {RESUME_FROM_CHECKPOINT}")
    unet_ckpt_dir = os.path.join(RESUME_FROM_CHECKPOINT, "unet")
    drive_bin_path = os.path.join(unet_ckpt_dir, "diffusion_pytorch_model.bin")
    drive_safetensors_path = os.path.join(unet_ckpt_dir, "diffusion_pytorch_model.safetensors")

    if os.path.exists(drive_bin_path):
        # One sequential read from Drive, then a fast load from local disk.
        local_bin_path = "/content/temp_unet_load.bin"
        shutil.copyfile(drive_bin_path, local_bin_path)
        # weights_only=False unpickles arbitrary objects: only load checkpoints you wrote.
        state_dict = torch.load(local_bin_path, map_location="cpu", weights_only=False)
    elif os.path.exists(drive_safetensors_path):
        # pipeline.save_pretrained() writes .safetensors by default; let diffusers read it.
        from diffusers import UNet2DModel
        state_dict = UNet2DModel.from_pretrained(unet_ckpt_dir).state_dict()
    else:
        raise FileNotFoundError(
            f"No UNet weights (.bin or .safetensors) found in {unet_ckpt_dir}"
        )

    # Let a failed (strict) load stop the run: continuing would train from the pretrained
    # weights while the epoch counter, optimizer and EMA are restored from the checkpoint.
    pipeline.unet.load_state_dict(state_dict)
    print("Weights loaded successfully!")

# Recompute activations in the backward pass to save memory, then move the pipeline to `device`.
pipeline.unet.enable_gradient_checkpointing()
pipeline.to(device)

# Optimizer and AMP gradient scaler over all UNet parameters.
optimizer = torch.optim.AdamW(params=pipeline.unet.parameters(), lr=LR)
scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)

# EMA shadow weights: a detached copy of every named UNet parameter, on `device`.
ema_params = dict(
    (param_name, param.detach().clone().to(device))
    for param_name, param in pipeline.unet.named_parameters()
)

def update_ema(model, ema_params, decay):
    """In-place EMA step: ``ema = decay * ema + (1 - decay) * param``.

    Only parameters with ``requires_grad`` are updated. ``ema_params`` maps the
    names yielded by ``model.named_parameters()`` to tensors of matching shape.
    """
    blend = 1.0 - decay
    with torch.no_grad():
        for param_name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            ema_params[param_name].mul_(decay).add_(param.data, alpha=blend)

start_epoch = 1
global_step = 0

if RESUME_FROM_CHECKPOINT:
    state_path = os.path.join(RESUME_FROM_CHECKPOINT, "train_state.pt")
    local_state_path = "/content/temp_train_state_load.pt"

    if os.path.exists(state_path):
        print("Loading training state...")
        shutil.copyfile(state_path, local_state_path)

        # weights_only=False unpickles arbitrary objects: only load checkpoints you wrote.
        state_dict = torch.load(local_state_path, map_location=device, weights_only=False)
        optimizer.load_state_dict(state_dict['optimizer_state_dict'])
        scaler.load_state_dict(state_dict['scaler_state_dict'])
        start_epoch = state_dict['epoch'] + 1
        global_step = state_dict.get('global_step', global_step)
        if 'ema_params' in state_dict:
            ema_params = {k: v.to(device) for k, v in state_dict['ema_params'].items()}

        print(f"Resumed! Start Epoch: {start_epoch}")
    else:
        # Weights were resumed, but the epoch counter would restart at 1 and later
        # checkpoints would be written under the wrong epoch numbers.
        print(f"WARNING: {state_path} not found; optimizer, scaler, EMA and epoch "
              f"counter start fresh (from epoch {start_epoch}).")

print(f"READY TO TRAIN from Epoch {start_epoch}")

In [None]:
import os
import shutil
import torch
import torch.nn.functional as F
from tqdm.auto import tqdm
# 4. TRAINING LOOP
# EPOCHS, GRAD_CLIP, USE_AMP and EMA_DECAY come from the config cell (not redefined here).
num_train_timesteps = scheduler.config.num_train_timesteps
print(f"Total training timesteps: {num_train_timesteps}")

SAVE_INTERVAL = 5  # also keep a permanent "epoch_N" copy every SAVE_INTERVAL epochs
temp_save_dir = "/content/temp_save_checkpoint"

print(f"🚀 Training starting from Epoch {start_epoch}...")

for epoch in range(start_epoch, EPOCHS + 1):
    pipeline.unet.train()
    running_loss = 0.0
    pbar = tqdm(dataloader, desc=f"Epoch {epoch}/{EPOCHS}")

    for step, batch in enumerate(pbar):
        gt = batch["gt"].to(device)
        masked = batch["masked"].to(device)
        mask = batch["mask"].to(device)

        # Rescale images from ToTensor's [0, 1] to [-1, 1].
        B = gt.shape[0]
        gt_in = gt * 2.0 - 1.0
        masked_in = masked * 2.0 - 1.0

        noise = torch.randn_like(gt_in)  # already on gt_in's device
        timesteps = torch.randint(0, num_train_timesteps, (B,), device=device).long()
        noisy_gt = scheduler.add_noise(gt_in, noise, timesteps)

        # Channel layout must match the first-conv surgery: [masked (3) | mask (1) | noisy (3)].
        mask_channel = mask.to(dtype=noisy_gt.dtype)
        model_input = torch.cat([masked_in, mask_channel, noisy_gt], dim=1)

        optimizer.zero_grad()
        with torch.amp.autocast('cuda', enabled=USE_AMP):
            noise_pred = pipeline.unet(model_input, timesteps).sample
            loss_mse = F.mse_loss(noise_pred, noise, reduction='none')
            # Weight 1.0 where mask == 1, 0.05 elsewhere.
            loss_weights = mask_channel * 1.0 + (1 - mask_channel) * 0.05
            loss = (loss_mse * loss_weights).mean()

        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)  # clip the true (unscaled) gradients
        torch.nn.utils.clip_grad_norm_(pipeline.unet.parameters(), GRAD_CLIP)
        scaler.step(optimizer)
        scaler.update()
        update_ema(pipeline.unet, ema_params, EMA_DECAY)

        running_loss += loss.item()
        global_step += 1

        if global_step % 10 == 0:
            current_avg_loss = running_loss / (step + 1)
            pbar.set_postfix({"loss": f"{current_avg_loss:.5f}"})

    epoch_loss = running_loss / max(len(dataloader), 1)
    print(f"Epoch {epoch} finished. Avg loss: {epoch_loss:.6f}")

    # SAFE SAVE STRATEGY: write locally first, then copy the finished folder to Drive.
    print(f"Saving checkpoint for Epoch {epoch} (Raw Weights)...")

    if os.path.exists(temp_save_dir):
        shutil.rmtree(temp_save_dir)
    os.makedirs(temp_save_dir, exist_ok=True)

    # safe_serialization=False writes unet/diffusion_pytorch_model.bin, the file the resume
    # cell loads; the diffusers default (.safetensors) would not be found there.
    pipeline.save_pretrained(temp_save_dir, safe_serialization=False)

    torch.save({
        'epoch': epoch,
        'global_step': global_step,
        'optimizer_state_dict': optimizer.state_dict(),
        'scaler_state_dict': scaler.state_dict(),
        'ema_params': ema_params,
    }, os.path.join(temp_save_dir, "train_state.pt"))
    drive_latest_path = os.path.join(CKPT_DIR, "checkpoint_latest")
    print(f"   -> Overwriting 'checkpoint_latest'...")
    shutil.copytree(temp_save_dir, drive_latest_path, dirs_exist_ok=True)
    print("   -> Updated successfully.")

    if epoch % SAVE_INTERVAL == 0 or epoch == EPOCHS:
        drive_epoch_path = os.path.join(CKPT_DIR, f"epoch_{epoch}")
        print(f"   -> Saving milestone 'epoch_{epoch}'...")
        shutil.copytree(temp_save_dir, drive_epoch_path, dirs_exist_ok=True)
        print("   -> Milestone saved.")

    print("Checkpoint handling finished.")

print("Training finished.")

Total training timesteps: 1000
üöÄ Training starting from Epoch 7...


Epoch 7/10:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 7 finished. Avg loss: 0.005430
Saving checkpoint for Epoch 7 (Raw Weights)...
   -> Overwriting 'checkpoint_latest'...
   -> Updated successfully.
Checkpoint handling finished.


Epoch 8/10:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 8 finished. Avg loss: 0.005141
Saving checkpoint for Epoch 8 (Raw Weights)...
   -> Overwriting 'checkpoint_latest'...
   -> Updated successfully.
Checkpoint handling finished.


Epoch 9/10:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 9 finished. Avg loss: 0.004959
Saving checkpoint for Epoch 9 (Raw Weights)...
   -> Overwriting 'checkpoint_latest'...
   -> Updated successfully.
Checkpoint handling finished.


Epoch 10/10:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 10 finished. Avg loss: 0.005399
Saving checkpoint for Epoch 10 (Raw Weights)...
   -> Overwriting 'checkpoint_latest'...
   -> Updated successfully.
   -> Saving milestone 'epoch_10'...
   -> Milestone saved.
Checkpoint handling finished.
Training finished.
