In [23]:
# Import các library, classs và hàm cần thiết

import argparse
import os
import numpy as np
import math

import torchvision.transforms as transforms
from torchvision.utils import save_image, make_grid

from torch.utils.data import DataLoader, TensorDataset
from torchvision import datasets
from torchvision.datasets import ImageFolder
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch
import torch.optim as optim

In [24]:
# Cài và import wandb
!pip install wandb
import wandb
# Gán key trực tiếp
os.environ["WANDB_API_KEY"] = "..."

# Login tự động
wandb.login(key=os.environ["WANDB_API_KEY"])





True

In [25]:
# # Download Input
# import kagglehub
# path = kagglehub.dataset_download("jessicali9530/celeba-dataset")
# print("Path to dataset files:", path)
# !cp -r {path} ./celeba-dataset

In [26]:
# # Download Output
# import kagglehub
# path = kagglehub.dataset_download("romaingraux/bitmojis")
# print("Path to dataset files:", path)
# !cp -r {path} ./bitmojis-dataset

In [27]:
celeba_root = '/kaggle/input/celebahq-resized-256x256/celeba_hq_256'
bitmoji_root = '/kaggle/input/bitmojis/bitmojis'

In [28]:
# Thiết lập các tham số cho mô hình GAN

parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=10, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=16, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=224, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=3, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=50, help="interval between image sampling")
opt = parser.parse_args(args=[])
print(opt)

Namespace(n_epochs=10, batch_size=16, lr=0.0002, b1=0.5, b2=0.999, n_cpu=8, latent_dim=100, img_size=224, channels=3, sample_interval=50)


In [None]:
wandb.init(
    project = "DTN",
    name = "Emoji Creation new",
    config = opt,
)

0,1
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss/CONST,██▇▇▇▆▆▅▅▄▄▄▃▃▃▃▃▃▃▃▂▂▂▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁
loss/D,█▄▄▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁
loss/G,█▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss/GAN,▃▄▁▃▂▅█▆▆▅▅▅▅▆▅▆▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▆▆▅▅▅▅▅▅▅
loss/TID,█▅▅▄▃▂▂▂▂▂▁▁▁▁▁▂▁▁▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss/TV,█▅▅▄▃▂▂▂▂▂▁▂▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,0.0
loss/CONST,0.00047
loss/D,3.2959
loss/G,2.36442
loss/GAN,2.19717
loss/TID,0.11188
loss/TV,0.164


In [30]:
# Tạo thư mục images để lưu các ảnh kết quả

os.makedirs("images", exist_ok=True)
os.makedirs("samples", exist_ok=True)


# Thiết lập kích thước ảnh và check xem có sử dụng GPU hay không

img_shape = (opt.channels, opt.img_size, opt.img_size)

cuda = True if torch.cuda.is_available() else False

In [31]:
!pip install facenet_pytorch



In [32]:
import torch.nn as nn
import torch.nn.functional as F
from facenet_pytorch import InceptionResnetV1

def create_face_extractor():
    """Simple FaceNet-based feature extractor"""
    facenet = InceptionResnetV1(pretrained='vggface2').eval()

    # Freeze parameters
    for param in facenet.parameters():
        param.requires_grad = False

    return facenet  # Output: 512D features

In [33]:
class Generator(nn.Module):
    def __init__(self, latent_dim=512, img_size=64):  # DTN specs
        super(Generator, self).__init__()
        self.init_size = img_size // 32  # Start from 2x2

        # Project 512D features to spatial dimensions
        self.l1 = nn.Sequential(
            nn.Linear(latent_dim, 512 * self.init_size ** 2)
        )

        # DTN-style conv blocks (paper: 5 upscaling blocks + 1x1 convs)
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(512),

            # Block 1: 2x2 -> 4x4
            nn.ConvTranspose2d(512, 512, 4, 2, 1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, 1),  # 1x1 conv (paper mentions this)

            # Block 2: 4x4 -> 8x8
            nn.ConvTranspose2d(512, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, 1),  # 1x1 conv

            # Block 3: 8x8 -> 16x16
            nn.ConvTranspose2d(256, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, 1),  # 1x1 conv

            # Block 4: 16x16 -> 32x32
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 1),  # 1x1 conv

            # Block 5: 32x32 -> 64x64
            nn.ConvTranspose2d(64, 32, 4, 2, 1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, 1),  # 1x1 conv

            # Final layer
            nn.Conv2d(32, 3, 3, stride=1, padding=1),  # RGB output
            nn.Tanh(),  # [-1, 1] range
        )

    def forward(self, z):
        # z = face features from f network (batch, 512)
        out = self.l1(z)
        out = out.view(out.shape[0], 512, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img  # (batch, 3, 64, 64)

In [34]:
class Discriminator(nn.Module):
    def __init__(self, img_size=152):  # DTN target domain size
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            """DTN-style: Conv + BatchNorm + LeakyReLU(0.2)"""
            block = [
                nn.Conv2d(in_filters, out_filters, 4, 2, 1),  # stride=2
                nn.LeakyReLU(0.2, inplace=True)
            ]
            if bn:
                block.append(nn.BatchNorm2d(out_filters))
            return block

        # 6 blocks (paper specification)
        self.model = nn.Sequential(
            *discriminator_block(3, 64, bn=False),    # 152x152 -> 76x76
            *discriminator_block(64, 128),            # 76x76 -> 38x38
            *discriminator_block(128, 256),           # 38x38 -> 19x19
            *discriminator_block(256, 512),           # 19x19 -> 9x9
            *discriminator_block(512, 1024),          # 9x9 -> 4x4
            *discriminator_block(1024, 2048),         # 4x4 -> 2x2
        )

        # Ternary classification (not binary!)
        ds_size = img_size // (2 ** 6)  # After 6 blocks
        self.adv_layer = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),  # Global average pooling
            nn.Flatten(),
            nn.Linear(2048, 3),       # 3 classes instead of 1
        )

    def forward(self, img):
        # img: (batch, 3, 152, 152)
        out = self.model(img)
        validity = self.adv_layer(out)
        return validity  # (batch, 3) - probabilities for each class

In [35]:
class DTNLoss:
    def __init__(self, alpha=300, beta=10, gamma=0.1):
        self.alpha = alpha  # L_CONST weight (critical!)
        self.beta = beta    # L_TID weight
        self.gamma = gamma  # L_TV weight

        self.mse_loss = nn.MSELoss()
        self.ce_loss = nn.CrossEntropyLoss()

    def discriminator_loss(self, D, g_outputs_s, g_outputs_t, x_t_real):
        """
        L_D = -E[log D1(g(f(x_s)))] - E[log D2(g(f(x_t)))] - E[log D3(x_t_real)]
        """
        batch_size = g_outputs_s.size(0)
        device = g_outputs_s.device

        # Class labels
        label_s = torch.zeros(batch_size, dtype=torch.long, device=device)      # Class 0
        label_t = torch.ones(batch_size, dtype=torch.long, device=device)       # Class 1
        label_real = torch.full((batch_size,), 2, dtype=torch.long, device=device) # Class 2

        # Discriminator predictions
        d_s = D(g_outputs_s)
        d_t = D(g_outputs_t)
        d_real = D(x_t_real)

        # Losses
        loss_s = self.ce_loss(d_s, label_s)
        loss_t = self.ce_loss(d_t, label_t)
        loss_real = self.ce_loss(d_real, label_real)

        return (loss_s + loss_t + loss_real)

    def generator_loss(self, D, f, g, x_s, x_t, x_t_real):
        """
        L_G = L_GAN_G + α*L_CONST + β*L_TID + γ*L_TV
        """
        batch_size = x_s.size(0)
        device = x_s.device

        # Forward pass
        f_x_s = f(x_s)  # (batch, 512)
        f_x_t = f(x_t)  # (batch, 512)

        g_f_x_s = g(f_x_s)  # (batch, 3, 64, 64)
        g_f_x_t = g(f_x_t)  # (batch, 3, 64, 64)

        # Upscale for discriminator (64x64 -> 152x152)
        g_f_x_s_up = F.interpolate(g_f_x_s, size=(152, 152), mode='bilinear')
        g_f_x_t_up = F.interpolate(g_f_x_t, size=(152, 152), mode='bilinear')

        # L_GAN_G: Fool discriminator (want class 2 = real)
        target_real = torch.full((batch_size,), 2, dtype=torch.long, device=device)

        d_s = D(g_f_x_s_up)
        d_t = D(g_f_x_t_up)

        loss_gan_s = self.ce_loss(d_s, target_real)
        loss_gan_t = self.ce_loss(d_t, target_real)
        loss_gan = loss_gan_s + loss_gan_t

        # L_CONST: f-constancy
        f_g_f_x_s = f(F.interpolate(g_f_x_s, size=(160, 160), mode='bilinear'))
        loss_const = self.mse_loss(f_x_s, f_g_f_x_s)

        # L_TID: Identity mapping for target domain
        x_t_64 = F.interpolate(x_t, size=(64, 64), mode='bilinear')
        loss_tid = self.mse_loss(x_t_64, g_f_x_t)

        # L_TV: Total variation
        loss_tv = self.total_variation_loss(g_f_x_s)

        # Combine losses
        total_loss = loss_gan + self.alpha * loss_const + self.beta * loss_tid + self.gamma * loss_tv

        return total_loss, {
            'gan': loss_gan.item(),
            'const': loss_const.item(),
            'tid': loss_tid.item(),
            'tv': loss_tv.item()
        }

    def total_variation_loss(self, images):
        """Safe isotropic TV loss với epsilon"""
        diff_h = torch.pow(images[:, :, :, 1:] - images[:, :, :, :-1], 2)
        diff_v = torch.pow(images[:, :, 1:, :] - images[:, :, :-1, :], 2)
        
        # CRITICAL: Add epsilon before sqrt
        diff_h_matched = diff_h[:, :, :-1, :]
        diff_v_matched = diff_v[:, :, :, :-1]
        gradient_magnitude = torch.sqrt(diff_h_matched + diff_v_matched + 1e-8)
        
        return torch.mean(gradient_magnitude)

In [36]:
import os
import glob
import numpy as np
import random
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from facenet_pytorch import MTCNN

def crop_face_tight(pil_image, target_size=(160, 160), padding_ratio=0.3):
    """Face cropping for source"""
    try:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        margin = int(target_size[0] * padding_ratio)
        detector = MTCNN(image_size=target_size, margin=margin, keep_all=False, device=device)
        
        face_tensor = detector(pil_image)
        if face_tensor is not None:
            face_array = face_tensor.permute(1, 2, 0).cpu().numpy()
            face_array = ((face_array + 1) * 127.5).clip(0, 255).astype(np.uint8)
            return Image.fromarray(face_array)
        else:
            return generous_center_crop(pil_image, target_size, crop_ratio=0.7)
    except Exception as e:
        return generous_center_crop(pil_image, target_size, crop_ratio=0.7)

def crop_face_emoji(pil_image, target_size=(152, 152), padding_ratio=0.25):
    """Face cropping for emoji (unchanged from before)"""
    try:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        margin = int(target_size[0] * padding_ratio)
        detector = MTCNN(image_size=target_size, margin=margin, keep_all=False, device=device)
        
        face_tensor = detector(pil_image)
        if face_tensor is not None:
            face_array = face_tensor.permute(1, 2, 0).cpu().numpy()
            face_array = ((face_array + 1) * 127.5).clip(0, 255).astype(np.uint8)
            return Image.fromarray(face_array)
        else:
            return generous_center_crop_emoji(pil_image, target_size, crop_ratio=0.75)
    except Exception as e:
        return generous_center_crop_emoji(pil_image, target_size, crop_ratio=0.75)

def generous_center_crop(pil_image, target_size=(160, 160), crop_ratio=0.7):
    """Generous center crop for real faces"""
    w, h = pil_image.size
    crop_w, crop_h = w * crop_ratio, h * crop_ratio
    left = (w - crop_w) / 2
    top = (h - crop_h) / 2
    cropped = pil_image.crop((left, top, left + crop_w, top + crop_h))
    return cropped.resize(target_size, Image.LANCZOS)

def generous_center_crop_emoji(pil_image, target_size=(152, 152), crop_ratio=0.75):
    """Generous center crop for emoji"""
    w, h = pil_image.size
    crop_w, crop_h = w * crop_ratio, h * crop_ratio
    left = (w - crop_w) / 2
    top = (h - crop_h) / 2
    cropped = pil_image.crop((left, top, left + crop_w, top + crop_h))
    return cropped.resize(target_size, Image.LANCZOS)

In [37]:
class CelebADataset(Dataset):
    def __init__(self, data_dir, transform=None, face_crop=True, target_size=(152, 152), 
                 padding_ratio=0.3, subset_size=None, subset_method='random', seed=42):
        self.data_dir = data_dir
        self.transform = transform
        self.face_crop = face_crop
        self.target_size = target_size
        self.padding_ratio = padding_ratio
        
        # Get all image paths (CelebA chủ yếu là JPG)
        all_image_paths = glob.glob(os.path.join(data_dir, "*.jpg"))
        total_images = len(all_image_paths)
        
        # Apply subsetting if requested
        if subset_size and subset_size < total_images:
            self.image_paths = self._create_subset(all_image_paths, subset_size, subset_method, seed)
            print(f"📊 Using subset: {len(self.image_paths)}/{total_images} CelebA images ({subset_method})")
        else:
            self.image_paths = all_image_paths
            print(f"📊 Using full dataset: {len(self.image_paths)} CelebA images")
        
        if face_crop:
            print(f"✅ Source face cropping ENABLED with padding_ratio={padding_ratio}")

    def _create_subset(self, all_paths, subset_size, method='random', seed=42):
        random.seed(seed)  # For reproducible results
        
        if method == 'random':
            return random.sample(all_paths, subset_size)
        elif method == 'first':
            return all_paths[:subset_size]
        elif method == 'interval':
            interval = len(all_paths) // subset_size
            return all_paths[::interval][:subset_size]
        elif method == 'balanced':
            return random.sample(all_paths, subset_size)
        else:
            return random.sample(all_paths, subset_size)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        try:
            img_path = self.image_paths[idx]
            image = Image.open(img_path).convert('RGB')  # CelebA JPG → RGB trực tiếp
            
            if self.face_crop:
                image = crop_face_tight(image, self.target_size, self.padding_ratio)
            else:
                image = image.resize(self.target_size)
            
            if self.transform:
                image = self.transform(image)
                
            return image, 0
            
        except Exception as e:
            print(f"Error loading {img_path}: {e}")
            black_img = Image.new('RGB', self.target_size, (0, 0, 0))
            if self.transform:
                black_img = self.transform(black_img)
            return black_img, 0


class BitmojiDataset(Dataset):
    def __init__(self, data_dir, transform=None, face_crop=True, target_size=(152, 152),
                 padding_ratio=0.25, subset_size=None, subset_method='random', seed=42):
        self.data_dir = data_dir
        self.transform = transform
        self.face_crop = face_crop
        self.target_size = target_size
        self.padding_ratio = padding_ratio
        
        all_image_paths = glob.glob(os.path.join(data_dir, "*.png"))
        total_images = len(all_image_paths)
        
        # Apply subsetting if requested
        if subset_size and subset_size < total_images:
            self.image_paths = self._create_subset(all_image_paths, subset_size, subset_method, seed)
            print(f"📊 Using subset: {len(self.image_paths)}/{total_images} Bitmoji images ({subset_method})")
        else:
            self.image_paths = all_image_paths
            print(f"📊 Using full dataset: {len(self.image_paths)} Bitmoji images")
        
        # Check format distribution
        png_count = sum(1 for p in self.image_paths if p.lower().endswith('.png'))
        jpg_count = len(self.image_paths) - png_count
        print(f"📊 Format distribution: {png_count} PNG, {jpg_count} JPG/JPEG")
        
        if face_crop:
            print(f"Target face cropping ENABLED with padding_ratio={padding_ratio}")

    def _create_subset(self, all_paths, subset_size, method='random', seed=42):
        """Create dataset subset using different methods"""
        random.seed(seed)
        
        if method == 'random':
            return random.sample(all_paths, subset_size)
        elif method == 'first':
            return all_paths[:subset_size]
        elif method == 'interval':
            interval = len(all_paths) // subset_size
            return all_paths[::interval][:subset_size]
        else:
            return random.sample(all_paths, subset_size)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        try:
            img_path = self.image_paths[idx]
            image = Image.open(img_path)
            
            # Proper PNG transparency handling
            if image.mode == 'RGBA':
                # Create white background for transparent PNGs
                background = Image.new('RGB', image.size, (255, 255, 255))
                background.paste(image, mask=image.split()[-1])  # Use alpha as mask
                image = background
            elif image.mode == 'P':
                # Handle palette mode with possible transparency
                image = image.convert('RGBA')
                background = Image.new('RGB', image.size, (255, 255, 255))
                background.paste(image, mask=image.split()[-1] if len(image.split()) == 4 else None)
                image = background
            elif image.mode != 'RGB':
                # Convert any other mode to RGB
                image = image.convert('RGB')
            
            if self.face_crop:
                image = crop_face_emoji(image, self.target_size, self.padding_ratio)
            else:
                image = image.resize(self.target_size)
            
            if self.transform:
                image = self.transform(image)
                
            return image, 0
            
        except Exception as e:
            print(f"Error loading {img_path}: {e}")
            black_img = Image.new('RGB', self.target_size, (0, 0, 0))
            if self.transform:
                black_img = self.transform(black_img)
            return black_img, 0


In [38]:
# Transforms
face_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

emoji_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

source_dataset = CelebADataset(
    data_dir=celeba_root, 
    transform=face_transform,
    face_crop=True,
    target_size=(160, 160),
    padding_ratio=0.35,
    #subset_size=10000,      
    #subset_method='first'  
)

target_dataset = BitmojiDataset(
    data_dir=bitmoji_root, 
    transform=emoji_transform,
    face_crop=False,
    target_size=(152, 152),
    padding_ratio=0.3,
    #subset_size=10000,       
    #subset_method='first'  
)

# Data loaders
source_loader = torch.utils.data.DataLoader(
    source_dataset, 
    batch_size=opt.batch_size, 
    shuffle=True, 
    num_workers=0,
    pin_memory=True,
    drop_last=True
)

target_loader = torch.utils.data.DataLoader(
    target_dataset, 
    batch_size=opt.batch_size, 
    shuffle=True, 
    num_workers=0,
    pin_memory=True,
    drop_last=True
)

print(f"Ready for fast training!")
print(f"Source: {len(source_dataset)} images")
print(f"Target: {len(target_dataset)} images")

📊 Using full dataset: 30000 CelebA images
✅ Source face cropping ENABLED with padding_ratio=0.35
📊 Using full dataset: 130227 Bitmoji images
📊 Format distribution: 130227 PNG, 0 JPG/JPEG
🚀 Ready for fast training!
📊 Source: 30000 images
📊 Target: 130227 images
⏱️  Expected speedup: ~6x faster


In [39]:
# Khởi tạo generator and discriminator và thêm wandb để track
g_net = Generator().cuda()
d_net = Discriminator().cuda()
f_net = create_face_extractor().cuda()

wandb.watch(g_net, log="all", log_freq=200)
wandb.watch(d_net, log="all", log_freq=200)

In [40]:
# Optimizers
optimizer_G = torch.optim.Adam(g_net.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(d_net.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

In [41]:
# Loss function
criterion = DTNLoss(alpha=100, beta=1, gamma=0.05)
# Định dạng Tensor sử dụng GPU nếu có và device type
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor

In [42]:
# Thêm vào sau phần khởi tạo models, trước training loop:
import os
checkpoint_dir = "checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

def save_checkpoint(g_net, d_net, f_net, optimizer_G, optimizer_D, batches_done, loss_dict):
    """Save model checkpoint"""
    checkpoint = {
        'batches_done': batches_done,
        'generator_state_dict': g_net.state_dict(),
        'discriminator_state_dict': d_net.state_dict(),
        'face_extractor_state_dict': f_net.state_dict(),
        'optimizer_G_state_dict': optimizer_G.state_dict(),
        'optimizer_D_state_dict': optimizer_D.state_dict(),
        'loss_dict': loss_dict,
        'model_config': {
            'latent_dim': 512,
            'img_size': 64,
            'discriminator_img_size': 152
        }
    }
    
    checkpoint_path = f"{checkpoint_dir}/dtn_checkpoint_batch_{batches_done}.pth"
    torch.save(checkpoint, checkpoint_path)
    print(f"Checkpoint saved: {checkpoint_path}")
    return checkpoint_path

In [43]:
import matplotlib.pyplot as plt
import numpy as np
# Tạo face features cố định để theo dõi chất lượng qua các epoch
# Sample một batch từ source để tạo fixed features
fixed_faces, _ = next(iter(source_loader))
fixed_faces = fixed_faces[:16].to(device)
with torch.no_grad():
    z_fixed = f_net(fixed_faces)  # Fixed face features

# ✅ THÊM: Fixed target images để so sánh
fixed_targets, _ = next(iter(target_loader))
fixed_targets = fixed_targets[:16].to(device)

print(f"Fixed features shape: {z_fixed.shape}")

# Sửa training loop để dừng ở batch 1000:
TARGET_BATCH = 1000

# Training loop theo style mong muốn
for epoch in range(opt.n_epochs):
    for i, ((source_imgs, _), (target_imgs, _)) in enumerate(zip(source_loader, target_loader)):

        # Lưu ảnh mẫu sau mỗi batches được xử lí
        batches_done = epoch * len(source_loader) + i

        # Chuẩn bị dữ liệu
        batch_size = min(source_imgs.size(0), target_imgs.size(0))
        source_imgs = source_imgs[:batch_size].to(device)
        target_imgs = target_imgs[:batch_size].to(device)

        # -----------------
        #  Train Generator
        # -----------------

        #if batches_done % 3 != 0:
        optimizer_G.zero_grad()

        # Generate emoji from face features
        g_loss, loss_dict = criterion.generator_loss(
            d_net, f_net, g_net,
            source_imgs, target_imgs, target_imgs
        )

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        #if batches_done % 2 == 0:
        optimizer_D.zero_grad()

        with torch.no_grad():
            f_x_s = f_net(source_imgs)
            f_x_t = f_net(target_imgs)
            g_f_x_s = g_net(f_x_s)
            g_f_x_t = g_net(f_x_t)

            g_f_x_s_up = F.interpolate(g_f_x_s, size=(152, 152), mode='bilinear')
            g_f_x_t_up = F.interpolate(g_f_x_t, size=(152, 152), mode='bilinear')

        d_loss = criterion.discriminator_loss(
            d_net, g_f_x_s_up, g_f_x_t_up, target_imgs
        )

        d_loss.backward()
        optimizer_D.step()

        # In ra thông tin về quá trình huấn luyện
        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %.4f] [G loss: %.4f] [GAN: %.4f] [CONST: %.4f] [TID: %.4f] [TV: %.4f]"
            % (epoch, opt.n_epochs, i, len(source_loader),
                d_loss.item(),
                g_loss.item(),
                loss_dict["gan"], loss_dict["const"],
                loss_dict["tid"], loss_dict["tv"])
        )

        # Wandb log loss
        log_dict = {
            "loss/G": g_loss.item(),
            "loss/D": d_loss.item(),
            "loss/GAN": loss_dict["gan"],
            "loss/CONST": loss_dict["const"],
            "loss/TID": loss_dict["tid"],
            "loss/TV": loss_dict["tv"],
            "epoch": epoch,
        }
        wandb.log(log_dict, step=batches_done)

        if batches_done % opt.sample_interval == 0:
            # Generate current batch samples
            with torch.no_grad():
                face_features = f_net(source_imgs[:16])
                gen_imgs = g_net(face_features)
                gen_fixed = g_net(z_fixed)
                
                # g(f(g(x))) reconstruction
                # Từ generated images → features → reconstruction 
                target_up = F.interpolate(target_imgs, size=(152, 152), mode='bilinear')
                target_imgs_feature = f_net(target_up)
                reconstructed_target = g_net(target_imgs_feature)

            # # Lưu ảnh vào local storage
            # save_image(gen_imgs.data, "images/%d.png" % batches_done, nrow=4, normalize=True)

            # Log grid của ảnh được gen hiện tại
            gen_imgs_vis = (gen_imgs + 1) / 2  # [-1,1] -> [0,1]
            grid_current = make_grid(gen_imgs_vis, nrow=4)
            wandb.log({
                "samples/current_batch": wandb.Image(grid_current, caption=f"step {batches_done}")
            }, step=batches_done)

            # Log grid ảnh từ features cố định
            with torch.no_grad():
                gen_fixed = g_net(z_fixed)
            gen_fixed_vis = (gen_fixed + 1)/2
            grid_fixed = make_grid(gen_fixed_vis, nrow=4)
            wandb.log({
                "samples/fixed_faces": wandb.Image(grid_fixed, caption=f"fixed faces @ step {batches_done}")
            }, step=batches_done)

            # Log source faces for comparison
            source_vis = (fixed_faces + 1) / 2
            grid_source = make_grid(source_vis, nrow=4)
            wandb.log({
                "samples/fixed_source_faces": wandb.Image(grid_source, caption=f"fixed_source faces")
            }, step=batches_done)

            # Log source faces for comparison
            source_cur = (source_imgs[:16] + 1) / 2
            grid_source = make_grid(source_cur, nrow=4)
            wandb.log({
                "samples/current_source_faces": wandb.Image(grid_source, caption=f"current_source faces")
            }, step=batches_done)

            # Log target faces for comparison  
            #target_vis = (fixed_targets + 1) / 2
            #grid_target = make_grid(target_vis, nrow=4)
            #wandb.log({
            #    "samples/fixed_target": wandb.Image(grid_target, 
            #        caption="Fixed target emoji (reference)")
            #}, step=batches_done)
            
            # Log current batch target for comparison
            current_target_vis = (target_imgs[:16] + 1) / 2
            grid_current_target = make_grid(current_target_vis, nrow=4)
            wandb.log({
                "samples/current_batch_target": wandb.Image(grid_current_target,
                    caption=f"Current batch targets @ step {batches_done}")
            }, step=batches_done)

            reconstructed_target_vis = (reconstructed_target + 1) / 2
            grid_reconstructed_target = make_grid(reconstructed_target_vis, nrow=4)
            wandb.log({
                "samples/emoji_to_emoji": wandb.Image(grid_reconstructed_target,
                    caption=f"Current batch targets @ step {batches_done}")
            }, step=batches_done)

            print("logged samples")

        # # Break early if we've processed enough batches for demo
        # if i >= 100:  # Limit for faster demo
        #     break
        if batches_done % 100 == 0 and batches_done > 0:
            save_checkpoint(g_net, d_net, f_net, optimizer_G, optimizer_D, batches_done, loss_dict)
            
        # THÊM: Check if reached target batch
        if batches_done >= TARGET_BATCH:
            print(f"Reached target batch {TARGET_BATCH}, saving checkpoint and stopping...")
            final_checkpoint = save_checkpoint(
                g_net, d_net, f_net, 
                optimizer_G, optimizer_D, 
                batches_done, loss_dict
            )
            print(f"Training completed! Final checkpoint: {final_checkpoint}")
            wandb.finish()
            exit()  # Dừng training
            break

print("Training completed!")

Fixed features shape: torch.Size([16, 512])
[Epoch 0/10] [Batch 0/1875] [D loss: 3.7330] [G loss: 3.3843] [GAN: 2.2540] [CONST: 0.0040] [TID: 0.7157] [TV: 0.2853]
logged samples
[Epoch 0/10] [Batch 1/1875] [D loss: 6.9208] [G loss: 5.2542] [GAN: 4.1399] [CONST: 0.0038] [TID: 0.7199] [TV: 0.2872]
[Epoch 0/10] [Batch 2/1875] [D loss: 5.1734] [G loss: 4.1643] [GAN: 3.0649] [CONST: 0.0039] [TID: 0.6966] [TV: 0.2914]
[Epoch 0/10] [Batch 3/1875] [D loss: 4.3150] [G loss: 3.4459] [GAN: 2.3349] [CONST: 0.0039] [TID: 0.7097] [TV: 0.2986]
[Epoch 0/10] [Batch 4/1875] [D loss: 4.2921] [G loss: 3.3617] [GAN: 2.2511] [CONST: 0.0039] [TID: 0.7069] [TV: 0.3016]
[Epoch 0/10] [Batch 5/1875] [D loss: 3.5878] [G loss: 3.3515] [GAN: 2.2496] [CONST: 0.0040] [TID: 0.6903] [TV: 0.3030]
[Epoch 0/10] [Batch 6/1875] [D loss: 3.4084] [G loss: 3.3240] [GAN: 2.2507] [CONST: 0.0039] [TID: 0.6636] [TV: 0.3005]
[Epoch 0/10] [Batch 7/1875] [D loss: 3.4519] [G loss: 3.2805] [GAN: 2.2423] [CONST: 0.0036] [TID: 0.6674] [T

0,1
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss/CONST,█▆▆▆▅▃▃▃▂▂▂▂▂▂▂▁▁▁▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss/D,█▅▄▂▃▂▄▂▂▁▂▂▁▂▂▁▁▂▂▁▁▁▁▃▃▂▄▂▁▁▁▁▁▁▁▁▁▂▁▁
loss/G,█▆▅▅▄▄▄▄▄▄▄▃▃▂▂▂▂▂▂▂▁▂▁▂▁▁▁▂▁▁▂▂▁▁▁▁▁▁▁▁
loss/GAN,█▄▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss/TID,█▆▃▃▂▂▂▂▂▁▂▂▂▂▁▁▁▁▁▁▁▁▂▁▁▁▁▂▁▁▁▁▁▁▂▁▁▂▁▂
loss/TV,█▄▄▃▃▃▃▂▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,0.0
loss/CONST,0.00026
loss/D,3.29587
loss/G,2.32977
loss/GAN,2.19723
loss/TID,0.09923
loss/TV,0.14873


[Epoch 1/10] [Batch 0/1875] [D loss: 3.2959] [G loss: 2.3298] [GAN: 2.1972] [CONST: 0.0002] [TID: 0.1012] [TV: 0.1487]


Error: You must call wandb.init() before wandb.log()

In [None]:
# Import thư viện cần thiết
import os
from IPython.display import FileLink, display

# Đường dẫn đến checkpoint file
checkpoint_path = "checkpoints/dtn_checkpoint_batch_200.pth"

# Method 1: Direct download nếu file đã trong /kaggle/working
if os.path.exists(checkpoint_path):
    print(f"Found checkpoint: {checkpoint_path}")
    
    # Tạo download link trực tiếp
    display(FileLink(checkpoint_path))
    
    print("🔗 Click vào link phía trên để download file!")
else:
    print(f"Checkpoint not found at: {checkpoint_path}")
    # List files trong thư mục checkpoints
    if os.path.exists("checkpoints"):
        print("Available checkpoints:")
        for f in os.listdir("checkpoints"):
            print(f"  - {f}")
