In [1]:
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader, ConcatDataset, random_split
from torchvision import transforms
from PIL import Image
import torch
import random
import numpy as np

# Set seed for reproducibility
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)

# Define transforms
resize_transform = transforms.Resize((128, 128))
to_tensor_transform = transforms.ToTensor()
normalize_transform = transforms.Normalize([0.5] * 3, [0.5] * 3)
color_jitter = transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1)

def paired_augmentation(input_img, output_img):
    if random.random() < 0.5:
        input_img = transforms.functional.hflip(input_img)
        output_img = transforms.functional.hflip(output_img)
    if random.random() < 0.5:
        input_img = transforms.functional.vflip(input_img)
        output_img = transforms.functional.vflip(output_img)
    angle = random.uniform(-15, 15)
    input_img = transforms.functional.rotate(input_img, angle)
    output_img = transforms.functional.rotate(output_img, angle)
    input_img = color_jitter(input_img)
    return input_img, output_img

# PyTorch Dataset class
class HFDataset(Dataset):
    def __init__(self, hf_ds, augment=False):
        self.ds = hf_ds
        self.augment = augment

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        item = self.ds[idx]
        input_img = item["input"].convert("RGB")
        output_img = item["output"].convert("RGB")
        prompt = item["prompt"].lower()

        input_img = resize_transform(input_img)
        output_img = resize_transform(output_img)

        if self.augment:
            input_img, output_img = paired_augmentation(input_img, output_img)

        input_img = to_tensor_transform(input_img)
        output_img = to_tensor_transform(output_img)

        input_img = normalize_transform(input_img)
        output_img = normalize_transform(output_img)

        return {
            "input_image": input_img,
            "output_image": output_img,
            "prompt": prompt
        }

# 1. Load synthetic dataset and take first 3000
synthetic_raw = load_dataset("bhavya777/synthetic-colored-shapes")
synthetic_ds = synthetic_raw["train"].select(range(4000))
synthetic_pt = HFDataset(synthetic_ds, augment=True)

# 2. Load augmented dataset (no extra augment)
augmented_raw = load_dataset("bhavya777/augmented-colored-shapes")
augmented_pt = HFDataset(augmented_raw["train"], augment=False)

# 3. Combine
full_dataset = ConcatDataset([synthetic_pt, augmented_pt])

# 4. Split
total_len = len(full_dataset)
train_len = int(0.8 * total_len)
val_len = total_len - train_len

train_ds, val_ds = random_split(full_dataset, [train_len, val_len], generator=torch.Generator().manual_seed(42))

# 5. Dataloaders
train_loader = DataLoader(train_ds, batch_size=8, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=8)


In [6]:
import torch
from torch import nn
from diffusers import UNet2DConditionModel
from transformers import CLIPTokenizer, CLIPTextModel

class ColorDenoisingUNet(nn.Module):
    def __init__(self, device="cuda" if torch.cuda.is_available() else "cpu"):
        super().__init__()
        self.device = device

        # CLIP text encoder
        self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
        self.text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32").to(device)

        # UNet architecture
        self.unet = UNet2DConditionModel(
            sample_size=128,
            in_channels=3,
            out_channels=3,
            down_block_types=(
                "DownBlock2D",
                "CrossAttnDownBlock2D",
                "CrossAttnDownBlock2D",
            ),
            mid_block_type="UNetMidBlock2DCrossAttn",
            up_block_types=(
                "CrossAttnUpBlock2D",
                "CrossAttnUpBlock2D",
                "UpBlock2D",
            ),
            # block_out_channels=(32,64, 128),
            block_out_channels=(64,32,64),
            
            # block_out_channels=(64, 128,256),
            layers_per_block=1,
            attention_head_dim=12,
            cross_attention_dim=512,
            only_cross_attention=True,
            dropout=0.1,
            act_fn="silu",
        ).to(device)

    def get_text_embedding(self, prompt, x):
        if isinstance(prompt, str):
            prompt = [prompt] * x.size(0)

        text_inputs = self.tokenizer(
            prompt, padding="max_length", max_length=77, return_tensors="pt"
        ).to(self.device)

        text_embeddings = self.text_encoder(**text_inputs).last_hidden_state
        return text_embeddings  # (B, 77, 512)

    def forward(self, x: torch.Tensor, timestep: torch.Tensor, prompt):
        # print(f"[DEBUG] Input x shape: {x.shape}")               # (B, 3, 128, 128)
        # print(f"[DEBUG] Timestep shape: {timestep.shape}")       # (B,)
        # print(f"[DEBUG] Prompt: {prompt}")                       # batched prompt or string

        encoder_hidden_states = self.get_text_embedding(prompt, x)
        # print(f"[DEBUG] Encoder hidden states shape: {encoder_hidden_states.shape}")  # (B, 77, 512)

        out = self.unet(
            sample=x,
            timestep=timestep,
            encoder_hidden_states=encoder_hidden_states
        ).sample

        # print(f"[DEBUG] Output shape: {out.shape}")              # Should be (B, 3, 128, 128)
        return out


In [9]:
import os
import math
import torch
import torch.nn.functional as F
from torch.optim import AdamW, lr_scheduler
from tqdm import tqdm
from torchvision.utils import make_grid, save_image
import wandb
import piq
import lpips
from kornia.color import rgb_to_lab

# ---------------- CONFIGS ----------------

loss_config = {
    "mse": True,
    "lpips": True,
    "ssim": True,
    "color": True
}
loss_weights = {
    "mse": 1.0,
    "lpips": 0.5,
    "ssim": 0.2,
    "color": 0.3
}
device = "cuda"
epochs = 5  # Change this as needed

# ---------------- LOSSES ----------------

def color_loss_lab(output, target):
    output = torch.clamp((output + 1) / 2, 0, 1)
    target = torch.clamp((target + 1) / 2, 0, 1)
    return F.l1_loss(rgb_to_lab(output), rgb_to_lab(target))

# ---------------- OPTIMIZER & SCHEDULER ----------------

def get_optimizer_and_scheduler(model, train_dataloader, epochs, base_lr=1e-4, wd=1e-4, warmup_ratio=0.1):
    optimizer = AdamW(model.parameters(), lr=base_lr, weight_decay=wd)
    steps_per_epoch = len(train_dataloader)
    total_steps = epochs * steps_per_epoch
    warmup_steps = int(warmup_ratio * total_steps)

    def lr_lambda(current_step):
        if current_step < warmup_steps:
            return float(current_step) / float(max(1, warmup_steps))
        progress = float(current_step - warmup_steps) / float(max(1, total_steps - warmup_steps))
        return 0.5 * (1.0 + math.cos(progress * math.pi))
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
    return optimizer, scheduler, total_steps, warmup_steps

# ---------------- TRAIN LOOP ----------------

def train_model(model, train_dataloader, val_dataloader=None,
                epochs=5, device="cuda",
                loss_config=None, loss_weights=None,
                base_lr=1e-4, wd=1e-4, warmup_ratio=0.1,save_dir = "chuchu"):

    model.to(device)
    lpips_loss = lpips.LPIPS(net='vgg').to(device)
    optimizer, scheduler, total_steps, warmup_steps = get_optimizer_and_scheduler(
        model, train_dataloader, epochs, base_lr, wd, warmup_ratio)
    wandb.init(
        project="assignment-color-adding-unet",
        name="head-12(64,32,64)_colorL-ssim-lpips-mse-smol-unet-15-epochs"
    )

    global_step = 0
    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        loop = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{epochs} [Train]")

        for batch in loop:
            input_image = batch["input_image"].to(device)
            output_image = batch["output_image"].to(device)
            prompt = batch["prompt"]
            timestep = torch.zeros(input_image.size(0), dtype=torch.long).to(device)

            outputs = model(input_image, timestep=timestep, prompt=prompt)
            outputs_01 = torch.clamp((outputs + 1) / 2, 0, 1)
            output_image_01 = torch.clamp((output_image + 1) / 2, 0, 1)

            # Compute selected losses
            loss = 0.0
            log_dict = {}
            if loss_config.get("mse"):
                mse = F.mse_loss(outputs, output_image)
                loss += loss_weights["mse"] * mse
                log_dict["batch_mse"] = mse.item()
            if loss_config.get("lpips"):
                lp = lpips_loss(outputs, output_image).mean()
                loss += loss_weights["lpips"] * lp
                log_dict["batch_lpips"] = lp.item()
            if loss_config.get("ssim"):
                ssim = piq.ssim(outputs_01, output_image_01, data_range=1.0).mean()
                loss += loss_weights["ssim"] * (1 - ssim)
                log_dict["batch_ssim"] = ssim.item()
            if loss_config.get("color"):
                col = color_loss_lab(outputs, output_image)
                loss += loss_weights["color"] * col
                log_dict["batch_color_loss"] = col.item()

            # Optimize
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()
            global_step += 1

            total_loss += loss.item()
            loop.set_postfix(loss=loss.item())

            log_dict["batch_loss"] = loss.item()
            log_dict["epoch"] = epoch + 1
            wandb.log(log_dict)

        # End of epoch logging
        avg_train_loss = total_loss / len(train_dataloader)
        wandb.log({"train_loss": avg_train_loss, "epoch": epoch + 1})
        print(f"Epoch {epoch+1} avg train loss: {avg_train_loss:.4f}")

        # ---------- Validation & Image Saving ----------
        if val_dataloader:
            model.eval()
            val_loss = 0.0
            val_lpips_total = 0.0
            val_mse_total = 0.0
            val_ssim_total = 0.0
            val_color_total = 0.0
            grid = None

            with torch.no_grad():
                for batch_idx, val_batch in enumerate(tqdm(val_dataloader, desc=f"Epoch {epoch+1} [Val]")):
                    val_input = val_batch["input_image"].to(device)
                    val_target = val_batch["output_image"].to(device)
                    val_prompt = val_batch["prompt"]
                    val_timestep = torch.zeros(val_input.size(0), dtype=torch.long).to(device)

                    val_output = model(val_input, timestep=val_timestep, prompt=val_prompt)
                    val_output_01 = torch.clamp((val_output + 1) / 2, 0, 1)
                    val_target_01 = torch.clamp((val_target + 1) / 2, 0, 1)

                    val_mse = F.mse_loss(val_output, val_target)
                    val_lpips = lpips_loss(val_output, val_target).mean()
                    val_ssim = piq.ssim(val_output_01, val_target_01, data_range=1.0).mean()
                    val_col_loss = color_loss_lab(val_output, val_target)

                    val_combined_loss = (
                        (loss_weights["mse"] * val_mse if loss_config.get("mse") else 0) +
                        (loss_weights["lpips"] * val_lpips if loss_config.get("lpips") else 0) +
                        (loss_weights["ssim"] * (1 - val_ssim) if loss_config.get("ssim") else 0) +
                        (loss_weights["color"] * val_col_loss if loss_config.get("color") else 0)
                    )

                    val_loss += val_combined_loss.item()
                    val_mse_total += val_mse.item()
                    val_lpips_total += val_lpips.item()
                    val_ssim_total += val_ssim.item()
                    val_color_total += val_col_loss.item()

                    # Save images from first batch only
                    if batch_idx == 0:
                        save_dir = save_dir
                        os.makedirs(save_dir, exist_ok=True)
                        n = min(4, val_output.size(0))
                        outputs = val_output[:n]
                        gts = val_target[:n]
                        # Interleave output and GT: [output1, gt1, output2, gt2, ...]
                        pairs = []
                        for i in range(n):
                            pairs.append(outputs[i])
                            pairs.append(gts[i])
                        comparison = torch.stack(pairs, dim=0)
                        grid = make_grid(comparison, nrow=2*n, normalize=True, value_range=(-1, 1))
                        save_image(grid, f"{save_dir}/epoch_{epoch+1}.png")
                        wandb.log({f"val_images_epoch_{epoch+1}": wandb.Image(grid)})

            avg_val_loss = val_loss / len(val_dataloader)
            avg_val_mse = val_mse_total / len(val_dataloader)
            avg_val_lpips = val_lpips_total / len(val_dataloader)
            avg_val_ssim = val_ssim_total / len(val_dataloader)
            avg_val_color = val_color_total / len(val_dataloader)

            wandb.log({
                "val_loss": avg_val_loss,
                "val_mse": avg_val_mse,
                "val_lpips": avg_val_lpips,
                "val_ssim": avg_val_ssim,
                "val_color_loss": avg_val_color,
                "epoch": epoch + 1
            })
            print(f"Epoch {epoch+1} avg val loss: {avg_val_loss:.4f}")

    wandb.finish()

# ---------------- USAGE EXAMPLE ----------------

# Assume you have your model, train_dataloader, val_dataloader defined
# model = ColorDenoisingUNet().to(device)
# train_dataloader = ...
# val_dataloader = ...



In [10]:
model = ColorDenoisingUNet().to(device)
train_model(
    model,
    train_loader,
    val_dataloader=val_loader,
    epochs=15,
    device=device,
    loss_config=loss_config,
    loss_weights=loss_weights,
    base_lr=1e-4,
    wd=1e-4,
    warmup_ratio=0.1,
    save_dir = "head-12(64,32,64)_colorL-ssim-lpips-mse-smol-unet-15-epochs")

Setting up [LPIPS] perceptual loss: trunk [vgg], v[0.1], spatial [off]
Loading model from: /home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/lpips/weights/v0.1/vgg.pth


0,1
batch_color_loss,▇▆█▆▆▇▇▆█▅▆▆▆▅▆▅▆▄▃▅▄▄▄▄▂▃▃▂▄▃▂▂▂▂▂▁▁▂
batch_loss,█▆█▆▆▇▇▆█▅▆▆▆▅▆▅▆▄▄▅▄▄▄▄▂▃▃▂▄▃▂▂▂▂▂▁▁▁
batch_lpips,▇▅▇█▇██▆█▅▆▅▇▆█▆▇▆▄▅▅▅▅▄▄▃▅▄▆▃▄▃▁▂▃▁▁▂
batch_mse,█▅█▇▅▆▅▆▇▅▆▆█▆▆▆▆▄▆▆▅▅▆▆▅▆▆▄▄▅▃▄▅▃▃▂▂▁
batch_ssim,▂▃▂▄▄▄▃▄▁▆▃▃▄▆▄▃▂▃▅▂▅▄▄▄▇▅▄▆▄▅▆▅▅▅▆█▆▇
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
batch_color_loss,25.73924
batch_loss,9.2669
batch_lpips,0.7278
batch_mse,1.05133
batch_ssim,0.35049
epoch,1.0


Epoch 1/15 [Train]: 100%|██████████| 540/540 [02:48<00:00,  3.20it/s, loss=0.412]


Epoch 1 avg train loss: 2.4073


Epoch 1 [Val]: 100%|██████████| 135/135 [00:16<00:00,  7.96it/s]


Epoch 1 avg val loss: 0.2497


Epoch 2/15 [Train]: 100%|██████████| 540/540 [02:49<00:00,  3.19it/s, loss=0.159]


Epoch 2 avg train loss: 0.2779


Epoch 2 [Val]: 100%|██████████| 135/135 [00:16<00:00,  7.98it/s]


Epoch 2 avg val loss: 0.1300


Epoch 3/15 [Train]: 100%|██████████| 540/540 [02:49<00:00,  3.18it/s, loss=0.177] 


Epoch 3 avg train loss: 0.1545


Epoch 3 [Val]: 100%|██████████| 135/135 [00:16<00:00,  7.98it/s]


Epoch 3 avg val loss: 0.1112


Epoch 4/15 [Train]: 100%|██████████| 540/540 [02:48<00:00,  3.20it/s, loss=0.119] 


Epoch 4 avg train loss: 0.1117


Epoch 4 [Val]: 100%|██████████| 135/135 [00:16<00:00,  7.95it/s]


Epoch 4 avg val loss: 0.0926


Epoch 5 [Val]: 100%|██████████| 135/135 [00:17<00:00,  7.82it/s]it/s, loss=0.1]   


Epoch 5 avg val loss: 0.0692


Epoch 6/15 [Train]: 100%|██████████| 540/540 [02:49<00:00,  3.18it/s, loss=0.0591]


Epoch 6 avg train loss: 0.0789


Epoch 6 [Val]: 100%|██████████| 135/135 [00:17<00:00,  7.89it/s]


Epoch 6 avg val loss: 0.0575


Epoch 7/15 [Train]: 100%|██████████| 540/540 [02:49<00:00,  3.18it/s, loss=0.0867]


Epoch 7 avg train loss: 0.0712


Epoch 7 [Val]: 100%|██████████| 135/135 [00:17<00:00,  7.93it/s]


Epoch 7 avg val loss: 0.0561


Epoch 8/15 [Train]: 100%|██████████| 540/540 [02:48<00:00,  3.20it/s, loss=0.057] 


Epoch 8 avg train loss: 0.0628


Epoch 8 [Val]: 100%|██████████| 135/135 [00:17<00:00,  7.90it/s]


Epoch 8 avg val loss: 0.0465


Epoch 9/15 [Train]: 100%|██████████| 540/540 [02:47<00:00,  3.22it/s, loss=0.0552]


Epoch 9 avg train loss: 0.0571


Epoch 9 [Val]: 100%|██████████| 135/135 [00:16<00:00,  7.96it/s]


Epoch 9 avg val loss: 0.0423


Epoch 10/15 [Train]: 100%|██████████| 540/540 [02:48<00:00,  3.20it/s, loss=0.0497]


Epoch 10 avg train loss: 0.0530


Epoch 10 [Val]: 100%|██████████| 135/135 [00:17<00:00,  7.92it/s]


Epoch 10 avg val loss: 0.0397


Epoch 11/15 [Train]: 100%|██████████| 540/540 [02:48<00:00,  3.21it/s, loss=0.0357]


Epoch 11 avg train loss: 0.0479


Epoch 11 [Val]: 100%|██████████| 135/135 [00:16<00:00,  8.01it/s]


Epoch 11 avg val loss: 0.0382


Epoch 12/15 [Train]: 100%|██████████| 540/540 [02:47<00:00,  3.23it/s, loss=0.0426]


Epoch 12 avg train loss: 0.0452


Epoch 12 [Val]: 100%|██████████| 135/135 [00:17<00:00,  7.92it/s]


Epoch 12 avg val loss: 0.0354


Epoch 13/15 [Train]: 100%|██████████| 540/540 [02:48<00:00,  3.21it/s, loss=0.0486]


Epoch 13 avg train loss: 0.0434


Epoch 13 [Val]: 100%|██████████| 135/135 [00:16<00:00,  7.97it/s]


Epoch 13 avg val loss: 0.0336


Epoch 14/15 [Train]: 100%|██████████| 540/540 [02:47<00:00,  3.23it/s, loss=0.0455]


Epoch 14 avg train loss: 0.0423


Epoch 14 [Val]: 100%|██████████| 135/135 [00:16<00:00,  8.01it/s]


Epoch 14 avg val loss: 0.0338


Epoch 15/15 [Train]: 100%|██████████| 540/540 [02:47<00:00,  3.23it/s, loss=0.0418]


Epoch 15 avg train loss: 0.0419


Epoch 15 [Val]: 100%|██████████| 135/135 [00:16<00:00,  7.97it/s]

Epoch 15 avg val loss: 0.0330





0,1
batch_color_loss,▇█▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
batch_loss,█▇▆▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
batch_lpips,█▆▅▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
batch_mse,█▆▅▄▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
batch_ssim,▁▃▇▆▇▇▅▇▆▇▇▇██▇▇▇███▇█▇█████████████████
epoch,▁▁▁▁▁▂▂▂▂▂▂▂▃▃▃▃▄▄▅▅▅▅▅▅▆▆▆▇▇▇▇▇▇▇▇▇▇▇██
train_loss,█▂▁▁▁▁▁▁▁▁▁▁▁▁▁
val_color_loss,█▄▄▃▂▂▂▂▁▁▁▁▁▁▁
val_loss,█▄▄▃▂▂▂▁▁▁▁▁▁▁▁
val_lpips,█▄▃▂▂▂▁▁▁▁▁▁▁▁▁

0,1
batch_color_loss,0.08932
batch_loss,0.04177
batch_lpips,0.01992
batch_mse,0.00409
batch_ssim,0.9954
epoch,15.0
train_loss,0.04189
val_color_loss,0.07279
val_loss,0.03301
val_lpips,0.01248


In [None]:

from huggingface_hub import HfApi, HfFolder
# Replace with your actual token
token = ""
from huggingface_hub import PyTorchModelHubMixin, login

# Log in with token
login(token=token)

# Make your model compatible
class HFUNet(PyTorchModelHubMixin, nn.Module):
    def __init__(self):
        super().__init__()
        self.model = ColorDenoisingUNet()

    def forward(self, x):
        return self.model(x)

# Instantiate and move to device
hf_model = HFUNet().to(device)
hf_model.model.load_state_dict(model.state_dict())  # copy weights

# Push to hub
hf_model.push_to_hub("head-12-64-32-64_colorL-ssim-lpips-mse-smol-unet-15-epochs")


Processing Files (0 / 0)                : |          |  0.00B /  0.00B            

New Data Upload                         : |          |  0.00B /  0.00B            

  ...ol-unet-15-epochs/model.safetensors:   6%|6         | 15.7MB /  261MB            

CommitInfo(commit_url='https://huggingface.co/bhavya777/head-12-64-32-64_colorL-ssim-lpips-mse-smol-unet-15-epochs/commit/3559aafb5cd20e9d3a72710d279593894c765724', commit_message='Push model using huggingface_hub.', commit_description='', oid='3559aafb5cd20e9d3a72710d279593894c765724', pr_url=None, repo_url=RepoUrl('https://huggingface.co/bhavya777/head-12-64-32-64_colorL-ssim-lpips-mse-smol-unet-15-epochs', endpoint='https://huggingface.co', repo_type='model', repo_id='bhavya777/head-12-64-32-64_colorL-ssim-lpips-mse-smol-unet-15-epochs'), pr_revision=None, pr_num=None)

In [14]:
from torchinfo import summary  # if you want detailed stats (optional)

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Total trainable parameters: {count_parameters(model):,}")

Total trainable parameters: 65,212,403


In [None]:
#model1 = 
#model2 - 5epochs,smol-mse-lpips-unet - bhavya777/lpips-mse-smol-unet-5-epochs
#model3-ssim-lpips-mse-smol-unet-5-epochs
model4-colorL-ssim-lpips-mse-smol-unet-5-epochs
model5-colorL-ssim-lpips-mse-big-unet-5-epochss=2
model6-layer = 2 in smol unet - layer_big_colorL-ssim-lpips-mse-smol-unet-10-epochs
model7-used layers=1,hed_dim=8 usuall smol unet head-8_colorL-ssim-lpips-mse-smol-unet-15-epochs
model-8-head-12(64,32,64)_colorL-ssim-lpips-mse-smol-unet-15-epochs


In [None]:
#inference 
import torch
from torch import nn
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from diffusers import UNet2DConditionModel
from transformers import CLIPTokenizer, CLIPTextModel
from huggingface_hub import PyTorchModelHubMixin

# Define your model class
class ColorDenoisingUNet(nn.Module):
    def __init__(self, device="cuda" if torch.cuda.is_available() else "cpu"):
        super().__init__()
        self.device = device

        # CLIP text encoder
        self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
        self.text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32").to(device)

        # UNet model
        self.unet = UNet2DConditionModel(
            sample_size=128,
            in_channels=3,
            out_channels=3,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D"),
            mid_block_type="UNetMidBlock2DCrossAttn",
            up_block_types=("CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"),
            block_out_channels=(64, 128, 256),
            layers_per_block=2,
            attention_head_dim=8,
            cross_attention_dim=512,
            only_cross_attention=True,
            dropout=0.1,
            act_fn="silu",
        ).to(device)

    def get_text_embedding(self, prompt, x):
        if isinstance(prompt, str):
            prompt = [prompt] * x.size(0)
        text_inputs = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(self.device)
        text_embeddings = self.text_encoder(**text_inputs).last_hidden_state
        return text_embeddings

    def forward(self, x: torch.Tensor, timestep: torch.Tensor, prompt):
        encoder_hidden_states = self.get_text_embedding(prompt, x)
        return self.unet(sample=x, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample


# Wrapper for Hugging Face Hub loading
class HFUNet(PyTorchModelHubMixin, nn.Module):
    def __init__(self):
        super().__init__()
        self.model = ColorDenoisingUNet()

    def forward(self, x, timestep, prompt):
        return self.model(x, timestep, prompt)

    @property
    def device(self):
        return self.model.device


# Load model from Hugging Face Hub
model = HFUNet.from_pretrained("head-12-64-32-64_colorL-ssim-lpips-mse-smol-unet-15-epochs").to("cuda" if torch.cuda.is_available() else "cpu")


# Inference function
def infer_batch_prompts(model, image_path, prompts):
    device = model.device

    transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize([0.5]*3, [0.5]*3),
    ])

    image = Image.open(image_path).convert("RGB")
    input_image = transform(image)

    input_batch = input_image.unsqueeze(0).repeat(len(prompts), 1, 1, 1).to(device)
    timestep = torch.zeros(len(prompts), dtype=torch.long).to(device)
    prompt_batch = prompts

    model.eval()
    with torch.no_grad():
        output = model(input_batch, timestep, prompt_batch)
        output = (output.clamp(-1, 1) + 1) / 2

        for i in range(len(prompts)):
            out_img = output[i].permute(1, 2, 0).cpu().numpy()
            plt.imshow(out_img)
            plt.axis('off')
            plt.title(f"Prompt: {prompts[i]}")
            plt.show()
