In [None]:
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from PIL import Image
from diffusers import DDPMScheduler
from transformers import T5Tokenizer, T5EncoderModel
from pytorch_wavelets import DWTForward, DWTInverse
from torch import optim
from datetime import datetime
from diffusers import StableDiffusionXLImg2ImgPipeline


# Configuration: every path and hyper-parameter a reader might tune lives here.
# NOTE(review): the absolute Windows paths below are machine-specific; adjust them
# before running on another machine.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_dir = r'C:\Datasets\Celeb-HQ-Split\train'
test_dir = r'C:\Datasets\Celeb-HQ-Split\test'
output_dir = r'C:\Datasets\Denoised from models\SDXL_Enhanced_Celeb-HQ'
checkpoint_dir = r"C:\Datasets\Checkpoints\sdxl_t5_litevae_txt"
BATCH_SIZE = 32
IMG_SIZE = (1024, 1024)      # (width, height) every image is letterboxed to
NUM_SAMPLES_PER_IMAGE = 1    # img2img generations written per test image
STRENGTH = 0.9               # `strength` passed to the SDXL img2img pipeline
GUIDANCE_SCALE = 10          # `guidance_scale` passed to the SDXL img2img pipeline
SEED = 42                    # seeds weight init, DataLoader shuffling and diffusion noise

# Seed torch's global RNG (CPU and CUDA) so Restart & Run All is reproducible.
torch.manual_seed(SEED)
os.makedirs(checkpoint_dir, exist_ok=True)


In [None]:
def resize_with_padding(image, target_size=(1024, 1024)):
    """Letterbox ``image`` onto a black canvas of ``target_size`` without distortion.

    The image is scaled (up or down) by the largest factor that keeps both
    dimensions inside ``target_size``, then pasted centred on a black RGB canvas.

    Args:
        image: PIL image to resize.
        target_size: ``(width, height)`` of the returned canvas.

    Returns:
        A new RGB PIL image whose size is exactly ``target_size``.
    """
    old_width, old_height = image.size
    target_width, target_height = target_size
    ratio = min(target_width / old_width, target_height / old_height)
    # round() rather than int(): the float product can land a hair below the exact
    # value (49 * (1 / 49) == 0.9999999999999999), which int() would truncate and
    # leave a 1-px black seam, or even a 0-px axis. max(1, ...) also keeps
    # extreme aspect ratios from collapsing an axis to zero pixels.
    new_size = (
        max(1, round(old_width * ratio)),
        max(1, round(old_height * ratio)),
    )
    resized = image.resize(new_size, Image.LANCZOS)
    canvas = Image.new("RGB", target_size, (0, 0, 0))
    offset = ((target_width - new_size[0]) // 2, (target_height - new_size[1]) // 2)
    canvas.paste(resized, offset)
    return canvas

class CustomImageDataset(Dataset):
    """Image-folder dataset yielding ``(image, path)`` pairs.

    Every ``.png`` / ``.jpg`` / ``.jpeg`` file (extension matched
    case-insensitively) directly inside ``root_dir`` is listed in sorted order.
    On access it is loaded as RGB, letterboxed with ``resize_with_padding``, and
    passed through ``transform`` when one is given.

    Args:
        root_dir: directory containing the images (not searched recursively).
        transform: optional callable applied to the padded PIL image.
        target_size: ``(width, height)`` to letterbox to; ``None`` means the
            notebook-level ``IMG_SIZE``.
    """

    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root_dir, transform=None, target_size=None):
        self.root_dir = root_dir
        self.transform = transform
        self.target_size = target_size
        # os.listdir order is arbitrary; sort for a stable, reproducible sample
        # order (matters for the unshuffled test loader and its output names).
        self.image_paths = [
            os.path.join(root_dir, img_name)
            for img_name in sorted(os.listdir(root_dir))
            if img_name.lower().endswith(self.IMAGE_EXTENSIONS)
        ]

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        # The context manager releases the file handle once the pixels are converted.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        size = IMG_SIZE if self.target_size is None else self.target_size
        image = resize_with_padding(image, size)
        if self.transform:
            image = self.transform(image)
        return image, img_path

# Transforms and loaders.
# Pixel pipeline: PIL image -> float tensor in [0, 1] (ToTensor) -> rescaled to [-1, 1].
# NOTE(review): a later cell redefines CustomImageDataset, data_transforms and these
# loaders, so everything built here is overwritten on Run All.
data_transforms = transforms.Compose(
    [transforms.ToTensor(), transforms.Lambda(lambda t: t.mul(2).sub(1))]
)

train_dataset, test_dataset = (
    CustomImageDataset(split_dir, transform=data_transforms)
    for split_dir in (train_dir, test_dir)
)
# Shuffle only the training split; the test split is iterated in dataset order.
train_loader, test_loader = (
    DataLoader(split, batch_size=BATCH_SIZE, shuffle=is_train)
    for split, is_train in ((train_dataset, True), (test_dataset, False))
)


In [None]:
class LiteVAE(nn.Module):
    """Lightweight autoencoder operating on the Haar-wavelet lowpass band.

    ``encode`` takes one level of the Haar DWT of ``x``, discards the detail
    (highpass) bands, and compresses the lowpass band with three convolutions
    into a 4-channel latent. ``decode`` mirrors this with transposed
    convolutions followed by an inverse DWT.

    Shapes, assuming an even-sized ``(B, 3, H, W)`` input (TODO confirm for odd
    sizes): lowpass ``(B, 3, H/2, W/2)`` -> latent ``(B, 4, H/8, W/8)``; ``decode``
    produces a ``(B, 3, H/2, W/2)`` lowpass band before the inverse DWT.

    NOTE(review): the decoder ends in ``Tanh``, limiting the reconstructed lowpass
    band to [-1, 1]. For inputs in [-1, 1] an orthonormal Haar lowpass can span
    roughly [-2, 2], so reconstructions may be range-compressed. Confirm.
    """

    def __init__(self):
        super().__init__()
        # One-level Haar analysis / synthesis pair, zero-padded at the borders.
        self.dwt = DWTForward(J=1, mode='zero', wave='haar')
        self.iwt = DWTInverse(mode='zero', wave='haar')

        # 3 -> 64 -> 128 -> 4 channels; each 4x4 / stride-2 / pad-1 conv halves H and W,
        # the final 3x3 / stride-1 / pad-1 conv keeps them.
        encoder_layers = [
            nn.Conv2d(3, 64, 4, 2, 1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.ReLU(),
            nn.Conv2d(128, 4, 3, 1, 1),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        # Mirror of the encoder; each 4x4 / stride-2 / pad-1 transposed conv doubles H and W.
        decoder_layers = [
            nn.ConvTranspose2d(4, 128, 4, 2, 1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, 3, 1, 1),
            nn.Tanh(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def encode(self, x):
        """Map an image batch to latents via its DWT lowpass band."""
        lowpass, _detail_bands = self.dwt(x)
        return self.encoder(lowpass)

    def decode(self, z):
        """Map latents back to image space through the inverse DWT."""
        lowpass = self.decoder(z)
        # Only a lowpass band is produced; ``[None]`` stands in for the single level
        # of detail bands (presumably treated as zeros by DWTInverse; confirm).
        return self.iwt((lowpass, [None]))

class TinyUNet(nn.Module):
    """Toy noise-prediction network for 4-channel latents.

    Four 3x3 convolutions (padding 1, so spatial size is preserved) map a
    4-channel latent to a 4-channel prediction of the same height and width.
    Despite the name, there is no down/up-sampling path and there are no skip
    connections.

    ``t`` and ``encoder_hidden_states`` are accepted so the call matches a
    diffusion-UNet signature, but they are unused: the output depends on ``x`` only.
    """

    def __init__(self):
        super().__init__()

        def conv3x3(in_channels, out_channels):
            # Size-preserving 3x3 convolution.
            return nn.Conv2d(in_channels, out_channels, 3, padding=1)

        self.block1 = nn.Sequential(conv3x3(4, 32), nn.ReLU(), conv3x3(32, 64), nn.ReLU())
        self.block2 = nn.Sequential(conv3x3(64, 64), nn.ReLU(), conv3x3(64, 4))

    def forward(self, x, t, encoder_hidden_states=None):
        features = self.block1(x)
        prediction = self.block2(features)
        return prediction


In [None]:
# Build the models. Only the UNet is optimized; the LiteVAE and the T5 text encoder
# are used purely for inference, so freeze them explicitly (eval mode, no grads).
# NOTE(review): LiteVAE is randomly initialised and never trained or loaded here,
# so its latents are a fixed random projection. Confirm this is intended.
vae = LiteVAE().to(device).eval().requires_grad_(False)
unet = TinyUNet().to(device)
text_encoder = T5EncoderModel.from_pretrained("t5-base").to(device).eval().requires_grad_(False)
tokenizer = T5Tokenizer.from_pretrained("t5-base")
scheduler = DDPMScheduler.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="scheduler")
optimizer = optim.AdamW(unet.parameters(), lr=1e-4)


In [None]:
class CustomImageDataset(Dataset):
    """Image-folder dataset yielding ``(image, path)`` pairs.

    Every ``.png`` / ``.jpg`` / ``.jpeg`` file (extension matched
    case-insensitively) directly inside ``root_dir`` is listed in sorted order.
    On access it is loaded as RGB, letterboxed with ``resize_with_padding``, and
    passed through ``transform`` when one is given.

    Args:
        root_dir: directory containing the images (not searched recursively).
        transform: optional callable applied to the padded PIL image.
        target_size: ``(width, height)`` to letterbox to; ``None`` means the
            notebook-level ``IMG_SIZE``.
    """

    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root_dir, transform=None, target_size=None):
        self.root_dir = root_dir
        self.transform = transform
        self.target_size = target_size
        # os.listdir order is arbitrary; sort for a stable, reproducible sample
        # order (matters for the unshuffled test loader and its output names).
        self.image_paths = [
            os.path.join(root_dir, img_name)
            for img_name in sorted(os.listdir(root_dir))
            if img_name.lower().endswith(self.IMAGE_EXTENSIONS)
        ]

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        # The context manager releases the file handle once the pixels are converted.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        # Use the configured IMG_SIZE instead of a hard-coded (1024, 1024).
        size = IMG_SIZE if self.target_size is None else self.target_size
        image = resize_with_padding(image, size)
        if self.transform:
            image = self.transform(image)
        return image, img_path

# Define transformations.
# ToTensor yields floats in [0, 1]; the lambda rescales them to [-1, 1].
data_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Lambda(lambda t: t.mul(2).sub(1)),
    ]
)

# Load datasets
def load_dataset():
    """Build the train and test ``CustomImageDataset`` objects from the configured dirs.

    Returns:
        ``(train_dataset, test_dataset)``, both using ``data_transforms``.
    """
    datasets = tuple(
        CustomImageDataset(split_dir, transform=data_transforms)
        for split_dir in (train_dir, test_dir)
    )
    return datasets

train_dataset, test_dataset = load_dataset()
# Shuffle only the training split; the test split is iterated in dataset order.
train_loader, test_loader = (
    DataLoader(split, batch_size=BATCH_SIZE, shuffle=is_train)
    for split, is_train in ((train_dataset, True), (test_dataset, False))
)

In [None]:
# Text prompts used as conditioning during training; one is assigned per image.
TRAIN_PROMPTS = (
    "young woman", "a smiling man", "an old man with glasses", "a person with curly hair",
    "a serious expression portrait", "a face with makeup", "a person wearing hat", "a person with long hair",
    "highly detailed, photorealistic portrait of a person with narrow eyes",
    "straight eyebrows, and a pointy nose",
    "The person could be of any gender, age, or ethnicity",
    "The photo is taken in a professional, neutral studio environment with soft lighting and ultra-high resolution",
)


def _prompts_for_batch(step, batch_size, prompts=TRAIN_PROMPTS):
    """Return exactly ``batch_size`` prompts, cycling through ``prompts`` across steps."""
    offset = step * batch_size
    return [prompts[(offset + i) % len(prompts)] for i in range(batch_size)]


def train_t5_litevae_model(dataloader, epochs=2, sleep_per_step=0):
    """Train the notebook-level ``unet`` to predict diffusion noise on LiteVAE latents.

    Uses the globals ``vae``, ``unet``, ``text_encoder``, ``tokenizer``,
    ``scheduler``, ``optimizer``, ``device`` and ``checkpoint_dir``. Each step
    encodes the images with ``vae``, adds DDPM noise at random timesteps, encodes
    one text prompt per image with T5, and minimises the MSE between the UNet
    output and the added noise. After every epoch the UNet and VAE
    ``state_dict``s and a small text log are written to ``checkpoint_dir``.

    Args:
        dataloader: iterable of ``(images, img_paths)`` batches, e.g. a DataLoader
            over ``CustomImageDataset``.
        epochs: number of passes over ``dataloader``.
        sleep_per_step: seconds to sleep after each optimizer step (0 disables).

    NOTE(review): ``vae`` is never optimized here, so every saved VAE checkpoint is
    identical. ``TinyUNet.forward`` as written ignores ``encoder_hidden_states``,
    so the text conditioning does not currently affect the loss.
    """
    print("🚀 Training SDXL with T5 + LiteVAE started...\n", flush=True)
    # Sample timesteps over the scheduler's actual range instead of a hard-coded 1000.
    num_train_timesteps = scheduler.config.num_train_timesteps
    start_time = time.time()
    for epoch in range(epochs):
        unet.train()
        total_loss = 0.0
        num_batches = 0
        epoch_start = time.time()
        print(f"📘 Epoch {epoch + 1}/{epochs}...\n", flush=True)

        for step, (images, img_paths) in enumerate(dataloader):
            image_name = os.path.basename(img_paths[0])
            print(f"🖼️  Training on batch {step + 1}/{len(dataloader)} (first image: {image_name})", flush=True)

            images = images.to(device)
            with torch.no_grad():
                latents = vae.encode(images)

            noise = torch.randn_like(latents)
            timesteps = torch.randint(0, num_train_timesteps, (latents.shape[0],), device=device)
            noisy_latents = scheduler.add_noise(latents, noise, timesteps)

            # Exactly one prompt per image, so the text batch matches the latent batch
            # (repeating the whole prompt list per image would encode 12x as many).
            prompts = _prompts_for_batch(step, images.size(0))
            inputs = tokenizer(prompts, return_tensors="pt", padding="max_length", truncation=True, max_length=77).to(device)
            with torch.no_grad():
                encoder_hidden_states = text_encoder(**inputs).last_hidden_state

            noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states)
            loss = F.mse_loss(noise_pred, noise)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1

            # torch.cuda.synchronize() raises when CUDA is unavailable (CPU fallback).
            if device.type == "cuda":
                torch.cuda.synchronize()
            if sleep_per_step:
                time.sleep(sleep_per_step)

        # Guard against an empty dataloader instead of dividing by zero.
        avg_loss = total_loss / num_batches if num_batches else float("nan")
        epoch_time = time.time() - epoch_start
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

        print(f"\n✅ Epoch {epoch+1} complete | Avg Loss: {avg_loss:.4f} | Duration: {epoch_time:.2f} sec\n", flush=True)
        torch.save(unet.state_dict(), os.path.join(checkpoint_dir, f"unet_epoch_{epoch+1}.pt"))
        torch.save(vae.state_dict(), os.path.join(checkpoint_dir, f"vae_epoch_{epoch+1}.pt"))
        with open(os.path.join(checkpoint_dir, f"epoch_{epoch+1}_log.txt"), "w") as f:
            f.write(f"Epoch: {epoch+1}\nAverage Loss: {avg_loss:.4f}\nTimestamp: {timestamp}\nEpoch Duration: {epoch_time:.2f} sec\n")

    total_time = time.time() - start_time
    print(f"🎯 Training finished in {(total_time/60):.2f} minutes.", flush=True)
    print(f"📁 Logs saved to: {checkpoint_dir}", flush=True)
    
# Run a single training epoch over the training loader with no per-step throttling.
train_t5_litevae_model(dataloader=train_loader, epochs=1, sleep_per_step=0)


In [None]:


def _tensor_to_pil(image_tensor):
    """Convert a ``(C, H, W)`` tensor normalized to [-1, 1] back into a PIL image."""
    # ToPILImage expects floats in [0, 1]; passing [-1, 1] values directly makes
    # negative pixels wrap around when they are cast to uint8.
    unit_range = ((image_tensor.detach().cpu() + 1) / 2).clamp(0, 1)
    return transforms.ToPILImage()(unit_range)


@torch.no_grad()
def generate_and_save_images_with_custom_model(dataloader, output_dir):
    """Run SDXL img2img on every image from ``dataloader`` and save the results as PNGs.

    For each image, ``NUM_SAMPLES_PER_IMAGE`` samples are generated. Sample ``i``
    uses prompt ``i`` (cycling through the list below), together with the
    notebook-level ``STRENGTH`` and ``GUIDANCE_SCALE``. Each sample is written to
    ``<output_dir>/<image stem>_sample_<i>.png``.

    Args:
        dataloader: iterable of ``(images, img_paths)`` batches whose images are
            ``(C, H, W)`` tensors normalized to [-1, 1] (as produced by
            ``data_transforms``).
        output_dir: directory for the generated PNGs; created if missing.

    NOTE(review): despite the name, this loads the stock SDXL base pipeline. The
    UNet/LiteVAE trained in this notebook are not used here.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Half precision is meant for CUDA; many CPU ops lack float16 kernels, so use
    # float32 when the notebook fell back to the CPU.
    dtype = torch.float16 if device.type == "cuda" else torch.float32
    model = StableDiffusionXLImg2ImgPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        torch_dtype=dtype,
        variant="fp16",
        use_safetensors=True
    ).to(device)
    model.enable_attention_slicing()

    prompts = [
        "young woman", "a smiling man", "an old man with glasses", "a person with curly hair",
        "a serious expression portrait", "a face with makeup", "a person wearing hat", "a person with long hair",
        "highly detailed, photorealistic portrait of a person with narrow eyes",
        "straight eyebrows, and a pointy nose",
        "The person could be of any gender, age, or ethnicity",
        "The photo is taken in a professional, neutral studio environment with soft lighting and ultra-high resolution"
    ]

    for images, img_paths in dataloader:
        # Process every image in the batch, not just the first one.
        for image_tensor, img_path in zip(images, img_paths):
            img_pil = _tensor_to_pil(image_tensor)
            filename = os.path.splitext(os.path.basename(img_path))[0]

            for i in range(NUM_SAMPLES_PER_IMAGE):
                prompt = prompts[i % len(prompts)]
                result = model(
                    prompt=prompt,
                    image=img_pil,
                    strength=STRENGTH,
                    guidance_scale=GUIDANCE_SCALE
                )
                save_path = os.path.join(output_dir, f"{filename}_sample_{i}.png")
                result.images[0].save(save_path)

    print(f"🖼️  All generated images saved to: {output_dir}")

# Run SDXL img2img over the test loader, writing PNGs into the configured output_dir.
generate_and_save_images_with_custom_model(dataloader=test_loader, output_dir=output_dir)
