In [1]:
pip install diffusers transformers accelerate bitsandbytes wandb pycocotools kagglehub

Note: you may need to restart the kernel to use updated packages.


In [3]:
# import os
# from huggingface_hub import hf_hub_download

# # --- CONFIGURATION ---
# REPO_ID = "ritishshrirao/controlnet-coco-multi"
# LOCAL_DIR = "pretrained_weights"

# def download_models():
#     os.makedirs(LOCAL_DIR, exist_ok=True)
    
#     print(f"Downloading weights from {REPO_ID}...")
    
#     # 1. Download Segmentation Weights
#     print("... Fetching Segmentation Model")
#     seg_path = hf_hub_download(
#         repo_id=REPO_ID,
#         filename="diffusion_pytorch_model.safetensors",
#         subfolder="segmentation",
#         local_dir=LOCAL_DIR
#     )
    
#     # 2. Download BBox Weights
#     print("... Fetching BBox Model")
#     bbox_path = hf_hub_download(
#         repo_id=REPO_ID,
#         filename="diffusion_pytorch_model.safetensors",
#         subfolder="bbox",
#         local_dir=LOCAL_DIR
#     )
    
#     print("\n‚úÖ Download Complete.")
#     print(f"Segmentation: {seg_path}")
#     print(f"BBox: {bbox_path}")

# download_models()

Downloading weights from ritishshrirao/controlnet-coco-multi...
... Fetching Segmentation Model


segmentation/diffusion_pytorch_model.saf(‚Ä¶):   0%|          | 0.00/1.45G [00:00<?, ?B/s]

... Fetching BBox Model


bbox/diffusion_pytorch_model.safetensors:   0%|          | 0.00/1.45G [00:00<?, ?B/s]


‚úÖ Download Complete.
Segmentation: pretrained_weights/segmentation/diffusion_pytorch_model.safetensors
BBox: pretrained_weights/bbox/diffusion_pytorch_model.safetensors


In [2]:
%%writefile train_multi_controlnet1.py
import os
import torch
import numpy as np
import random
import shutil
import requests
import zipfile
from tqdm.auto import tqdm
from PIL import Image, ImageDraw
from pycocotools.coco import COCO
from torchvision import transforms
import bitsandbytes as bnb
import wandb

# Accelerate & Diffusers
from accelerate import Accelerator
from diffusers import (
    StableDiffusionControlNetPipeline, 
    ControlNetModel, 
    MultiControlNetModel,
    DDPMScheduler,
    AutoencoderKL,
    UNet2DConditionModel
)
from transformers import AutoTokenizer, CLIPTextModel

# Enable TF32 for H100 speedup
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# --- CONFIGURATION ---
class Config:
    # Paths
    DATA_ROOT = "data"
    COCO_ROOT = os.path.join(DATA_ROOT, "coco2017")
    TRAIN_IMG_DIR = os.path.join(COCO_ROOT, "train2017")
    TRAIN_ANN_FILE = os.path.join(COCO_ROOT, "annotations/instances_train2017.json")
    
    MODEL_ID = "runwayml/stable-diffusion-v1-5"
    OUTPUT_DIR = "controlnet-coco-multi-h100-new"
    
    # --- RESUME CONFIGURATION ---
    # Path to the folder containing 'segmentation' and 'bbox' subfolders downloaded via the script
    PRETRAINED_WEIGHTS_DIR = "pretrained_weights" 
    
    # Set to "latest" to try loading full accelerator state (optimizer+model).
    # If not found, it will fall back to loading weights from PRETRAINED_WEIGHTS_DIR if available.
    RESUME_FROM_CHECKPOINT = "latest"
    
    # Hyperparameters
    RESOLUTION = 512
    BATCH_SIZE = 64
    GRAD_ACCUM_STEPS = 1
    LEARNING_RATE = 2e-5   
    NUM_EPOCHS = 15
    
    # Logging
    LOG_INTERVAL = 200       
    LOG_BATCH_SIZE = 8
    SAVE_INTERVAL = 1000
    MAX_SAMPLES = None
    
    # Data Loading
    NUM_WORKERS = 16
    
    # Dropout Probabilities
    PROMPT_DROPOUT_PROB = 0.4
    PROB_SEG_ONLY = 0.35
    PROB_BBOX_ONLY = 0.35

# --- DOWNLOAD UTILS ---
def download_file(url, save_path):
    if os.path.exists(save_path): return
    print(f"Downloading {url} to {save_path}...")
    response = requests.get(url, stream=True)
    with open(save_path, "wb") as file:
        for data in response.iter_content(1024):
            file.write(data)

def unzip_file(zip_path, extract_to):
    print(f"Extracting {zip_path}...")
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(extract_to)

def setup_coco_dataset():
    os.makedirs(Config.DATA_ROOT, exist_ok=True)
    os.makedirs(Config.COCO_ROOT, exist_ok=True)
    
    train_images_url = "http://images.cocodataset.org/zips/train2017.zip"
    annotations_url = "http://images.cocodataset.org/annotations/annotations_trainval2017.zip"
    
    zip_train = os.path.join(Config.DATA_ROOT, "train2017.zip")
    zip_ann = os.path.join(Config.DATA_ROOT, "annotations.zip")
    
    if not os.path.exists(Config.TRAIN_IMG_DIR):
        download_file(train_images_url, zip_train)
        unzip_file(zip_train, Config.COCO_ROOT)
        if os.path.exists(zip_train): os.remove(zip_train)
        
    if not os.path.exists(os.path.dirname(Config.TRAIN_ANN_FILE)):
        download_file(annotations_url, zip_ann)
        unzip_file(zip_ann, Config.COCO_ROOT)
        if os.path.exists(zip_ann): os.remove(zip_ann)

# --- DATASET CLASS ---
class COCOMultiDataset(torch.utils.data.Dataset):
    def __init__(self, img_dir, ann_file, tokenizer, size=512, max_samples=None):
        self.img_dir = img_dir
        self.coco = COCO(ann_file)
        self.img_ids = self.coco.getImgIds()
        self.tokenizer = tokenizer
        self.size = size
        self.img_ids = [img_id for img_id in self.img_ids if len(self.coco.getAnnIds(imgIds=img_id)) > 0]
        if max_samples: self.img_ids = self.img_ids[:max_samples]

        self.image_transforms = transforms.Compose([
            transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(size),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]), 
        ])
        
        self.cond_transforms = transforms.Compose([
            transforms.Resize(size, interpolation=transforms.InterpolationMode.NEAREST),
            transforms.CenterCrop(size),
            transforms.ToTensor(), 
        ])
        
        self.color_map = self._generate_color_map()

    def _generate_color_map(self):
        cats = self.coco.loadCats(self.coco.getCatIds())
        palette = {}
        for cat in cats:
            import hashlib
            h = hashlib.md5(str(cat['id']).encode()).hexdigest()
            palette[cat['id']] = (int(h[0:2], 16), int(h[2:4], 16), int(h[4:6], 16))
        return palette

    def draw_segmentation_map(self, img_shape, anns):
        mask = np.zeros((img_shape[1], img_shape[0], 3), dtype=np.uint8)
        anns = sorted(anns, key=lambda x: x['area'], reverse=True)
        for ann in anns:
            color = self.color_map.get(ann['category_id'], (255, 255, 255))
            binary_mask = self.coco.annToMask(ann)
            mask[binary_mask == 1] = color
        return Image.fromarray(mask)

    def draw_bbox_map(self, img_shape, anns):
        mask = np.zeros((img_shape[1], img_shape[0], 3), dtype=np.uint8)
        canvas = Image.fromarray(mask)
        draw = ImageDraw.Draw(canvas)
        anns = sorted(anns, key=lambda x: x['area'], reverse=True)
        for ann in anns:
            x, y, w, h = ann['bbox']
            color = self.color_map.get(ann['category_id'], (255, 255, 255))
            draw.rectangle([x, y, x+w, y+h], fill=color, outline=None)
        return canvas

    def __len__(self):
        return len(self.img_ids)

    def __getitem__(self, idx):
        img_id = self.img_ids[idx]
        img_info = self.coco.loadImgs(img_id)[0]
        try:
            image_path = os.path.join(self.img_dir, img_info['file_name'])
            image = Image.open(image_path).convert("RGB")
        except Exception:
            return self.__getitem__((idx + 1) % len(self))
            
        anns = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id))
        seg_image = self.draw_segmentation_map(image.size, anns)
        bbox_image = self.draw_bbox_map(image.size, anns)
        
        cat_ids = [ann['category_id'] for ann in anns]
        cat_names = list(set([cat['name'] for cat in self.coco.loadCats(cat_ids)]))
        text_prompt = "" if random.random() < Config.PROMPT_DROPOUT_PROB else \
                      f"A photorealistic image containing {', '.join(cat_names)}" if cat_names else "A photorealistic image"
        
        return {
            "pixel_values": self.image_transforms(image),
            "seg_pixel_values": self.cond_transforms(seg_image),
            "bbox_pixel_values": self.cond_transforms(bbox_image),
            "input_ids": self.tokenizer(
                text_prompt, max_length=self.tokenizer.model_max_length, 
                padding="max_length", truncation=True, return_tensors="pt"
            ).input_ids[0],
            "raw_prompt": text_prompt
        }

# --- VALIDATION HELPER ---
def log_validation(accelerator, controlnet_seg, controlnet_bbox, unet, vae, text_encoder, tokenizer, val_batch, step):
    if not accelerator.is_main_process: return
    print(f"Running Validation at step {step}...")

    try:
        multi_controlnet = MultiControlNetModel([controlnet_seg, controlnet_bbox])
        pipeline = StableDiffusionControlNetPipeline.from_pretrained(
            Config.MODEL_ID,
            vae=vae, text_encoder=text_encoder, tokenizer=tokenizer,
            unet=unet, controlnet=multi_controlnet,
            safety_checker=None, torch_dtype=torch.bfloat16
        ).to(accelerator.device)
        pipeline.set_progress_bar_config(disable=True)
        
        log_images = []
        num_samples = min(len(val_batch["raw_prompt"]), Config.LOG_BATCH_SIZE)
        
        def to_pil(tensor):
            tensor = tensor.detach().cpu().float()
            if tensor.shape[0] == 3: tensor = tensor.permute(1, 2, 0)
            if tensor.min() < 0: tensor = (tensor + 1) / 2.0
            tensor = tensor.clamp(0, 1).numpy()
            return Image.fromarray((tensor * 255).astype(np.uint8))

        for i in range(num_samples):
            prompt = val_batch["raw_prompt"][i] or "A photorealistic image"
            gt_image = to_pil(val_batch["pixel_values"][i])
            seg_image = to_pil(val_batch["seg_pixel_values"][i])
            bbox_image = to_pil(val_batch["bbox_pixel_values"][i])
            
            generator = torch.Generator(device=accelerator.device).manual_seed(42 + i)
            with torch.autocast("cuda", dtype=torch.bfloat16):
                pred_both = pipeline(prompt, image=[seg_image, bbox_image], num_inference_steps=20, generator=generator, controlnet_conditioning_scale=[1.0, 1.0]).images[0]
                pred_seg = pipeline(prompt, image=[seg_image, bbox_image], num_inference_steps=20, generator=generator, controlnet_conditioning_scale=[1.0, 0.0]).images[0]
                pred_bbox = pipeline(prompt, image=[seg_image, bbox_image], num_inference_steps=20, generator=generator, controlnet_conditioning_scale=[0.0, 1.0]).images[0]

            log_images.extend([
                wandb.Image(seg_image, caption=f"{i} Seg In"),
                wandb.Image(bbox_image, caption=f"{i} BBox In"),
                wandb.Image(gt_image, caption=f"{i} Truth"),
                wandb.Image(pred_both, caption=f"{i} Both"),
                wandb.Image(pred_seg, caption=f"{i} Seg Only"),
                wandb.Image(pred_bbox, caption=f"{i} BBox Only")
            ])
        
        tracker = accelerator.get_tracker("wandb")
        if tracker: tracker.log({"validation": log_images}, step=step)
        
        del pipeline, multi_controlnet
        torch.cuda.empty_cache()

    except Exception as e:
        print(f"Skipping validation log due to error: {e}")

# --- MAIN FUNCTION ---
def main():
    setup_coco_dataset()

    accelerator = Accelerator(
        gradient_accumulation_steps=Config.GRAD_ACCUM_STEPS,
        mixed_precision="bf16",
        log_with="wandb",
    )
    
    if accelerator.is_main_process:
        wandb_key = os.getenv("wandb")
        if wandb_key: wandb.login(key=wandb_key)
        cfg_dict = {k: v for k, v in Config.__dict__.items() if not k.startswith("__")}
        accelerator.init_trackers("controlnet-coco-multi-h100", config=cfg_dict)

    tokenizer = AutoTokenizer.from_pretrained(Config.MODEL_ID, subfolder="tokenizer", use_fast=False)
    noise_scheduler = DDPMScheduler.from_pretrained(Config.MODEL_ID, subfolder="scheduler")
    text_encoder = CLIPTextModel.from_pretrained(Config.MODEL_ID, subfolder="text_encoder", torch_dtype=torch.bfloat16)
    vae = AutoencoderKL.from_pretrained(Config.MODEL_ID, subfolder="vae", torch_dtype=torch.bfloat16)
    unet = UNet2DConditionModel.from_pretrained(Config.MODEL_ID, subfolder="unet", torch_dtype=torch.bfloat16)
    
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    unet.requires_grad_(False)
    
    # --- LOAD WEIGHTS LOGIC ---
    seg_path = os.path.join(Config.PRETRAINED_WEIGHTS_DIR, "segmentation")
    bbox_path = os.path.join(Config.PRETRAINED_WEIGHTS_DIR, "bbox")
    
    # 1. Initialize empty models first
    print("Initializing ControlNets...")
    controlnet_seg = ControlNetModel.from_unet(unet)
    controlnet_bbox = ControlNetModel.from_unet(unet)

    # 2. Try loading from PRETRAINED_WEIGHTS_DIR (The Safetensors you downloaded)
    # Only do this if we are NOT loading a full checkpoint state later
    loaded_pretrained = False
    
    # Check if a full checkpoint exists
    full_checkpoint_exists = False
    if os.path.exists(Config.OUTPUT_DIR):
        dirs = [d for d in os.listdir(Config.OUTPUT_DIR) if d.startswith("checkpoint")]
        if dirs: full_checkpoint_exists = True

    if not full_checkpoint_exists and os.path.exists(seg_path) and os.path.exists(bbox_path):
        if accelerator.is_main_process:
            print(f"üîÑ No full checkpoint found. Loading Pretrained Weights from {Config.PRETRAINED_WEIGHTS_DIR}...")
            
            try:
                controlnet_seg = ControlNetModel.from_pretrained(seg_path)
                print("‚úÖ Segmentation Weights Loaded.")
                
                controlnet_bbox = ControlNetModel.from_pretrained(bbox_path)
                print("‚úÖ BBox Weights Loaded.")
                loaded_pretrained = True
            except Exception as e:
                print(f"‚ùå Failed to load pretrained weights: {e}")
                print("Falling back to scratch training.")

    controlnet_seg.train()
    controlnet_bbox.train()
    controlnet_seg.enable_gradient_checkpointing()
    controlnet_bbox.enable_gradient_checkpointing()
    unet.enable_gradient_checkpointing()

    dataset = COCOMultiDataset(Config.TRAIN_IMG_DIR, Config.TRAIN_ANN_FILE, tokenizer, max_samples=Config.MAX_SAMPLES)
    train_dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=Config.BATCH_SIZE, shuffle=True,
        num_workers=Config.NUM_WORKERS, pin_memory=True, drop_last=True
    )

    val_batch = next(iter(train_dataloader))
    params = list(controlnet_seg.parameters()) + list(controlnet_bbox.parameters())
    optimizer = bnb.optim.AdamW8bit(params, lr=Config.LEARNING_RATE)

    controlnet_seg, controlnet_bbox, optimizer, train_dataloader = accelerator.prepare(
        controlnet_seg, controlnet_bbox, optimizer, train_dataloader
    )
    
    vae.to(accelerator.device)
    text_encoder.to(accelerator.device)
    unet.to(accelerator.device)

    global_step = 0
    start_epoch = 0
    
    # 3. Full Resume Logic (Accelerate Save State)
    if Config.RESUME_FROM_CHECKPOINT == "latest" and os.path.exists(Config.OUTPUT_DIR):
        dirs = [d for d in os.listdir(Config.OUTPUT_DIR) if d.startswith("checkpoint")]
        if dirs:
            path = sorted(dirs, key=lambda x: int(x.split("-")[1]))[-1]
            accelerator.print(f"‚è© Resuming full training state from {path}")
            accelerator.load_state(os.path.join(Config.OUTPUT_DIR, path))
            global_step = int(path.split("-")[1])
            start_epoch = global_step // len(train_dataloader)
        elif loaded_pretrained:
            accelerator.print(f"‚ÑπÔ∏è  Starting new run, but initialized with weights from {Config.PRETRAINED_WEIGHTS_DIR}")
        else:
            accelerator.print("‚ö†Ô∏è  No checkpoint or pretrained weights found. Starting from scratch.")

    if accelerator.is_main_process: print(f"üöÄ Starting training from Step {global_step}...")
    
    for epoch in range(start_epoch, Config.NUM_EPOCHS):
        progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process, desc=f"Epoch {epoch}")
        
        for batch in train_dataloader:
            with accelerator.accumulate([controlnet_seg, controlnet_bbox]):
                latents = vae.encode(batch["pixel_values"].to(dtype=torch.bfloat16)).latent_dist.sample() * vae.config.scaling_factor
                noise = torch.randn_like(latents)
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],), device=latents.device).long()
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
                
                encoder_hidden_states = text_encoder(batch["input_ids"])[0]
                
                # --- SIMPLIFIED LOGIC ---
                r = random.random()
                
                active_down = []
                active_mid = []

                # 1. SEGMENTATION
                if r < Config.PROB_SEG_ONLY or r >= (Config.PROB_SEG_ONLY + Config.PROB_BBOX_ONLY):
                    real_seg = batch["seg_pixel_values"].to(dtype=torch.bfloat16)
                    d, m = controlnet_seg(
                        noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states, 
                        controlnet_cond=real_seg, return_dict=False
                    )
                    active_down.append(d)
                    active_mid.append(m)

                # 2. BBOX
                if (r >= Config.PROB_SEG_ONLY and r < (Config.PROB_SEG_ONLY + Config.PROB_BBOX_ONLY)) or r >= (Config.PROB_SEG_ONLY + Config.PROB_BBOX_ONLY):
                    real_bbox = batch["bbox_pixel_values"].to(dtype=torch.bfloat16)
                    d, m = controlnet_bbox(
                        noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states, 
                        controlnet_cond=real_bbox, return_dict=False
                    )
                    active_down.append(d)
                    active_mid.append(m)

                # --- AGGREGATION ---
                if len(active_down) == 0:
                    down_block_res = None
                    mid_block_res = None
                else:
                    down_block_res = active_down[0]
                    mid_block_res = active_mid[0]
                    for i in range(1, len(active_down)):
                        down_block_res = [a + b for a, b in zip(down_block_res, active_down[i])]
                        mid_block_res = mid_block_res + active_mid[i]
                
                if down_block_res is not None:
                    down_block_res = [res.to(dtype=torch.bfloat16) for res in down_block_res]
                    mid_block_res = mid_block_res.to(dtype=torch.bfloat16)

                model_pred = unet(
                    noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states,
                    down_block_additional_residuals=down_block_res,
                    mid_block_additional_residual=mid_block_res,
                ).sample
                
                loss = torch.nn.functional.mse_loss(model_pred.float(), noise.float(), reduction="mean")
                accelerator.backward(loss)
                optimizer.step()
                optimizer.zero_grad()

            if accelerator.sync_gradients:
                global_step += 1
                progress_bar.update(1)
                
                if accelerator.is_main_process:
                    accelerator.log({"train_loss": loss.item()}, step=global_step)
                    if global_step % Config.LOG_INTERVAL == 0:
                        u_seg = accelerator.unwrap_model(controlnet_seg)
                        u_bbox = accelerator.unwrap_model(controlnet_bbox)
                        log_validation(accelerator, u_seg, u_bbox, unet, vae, text_encoder, tokenizer, val_batch, global_step)
                    if global_step % Config.SAVE_INTERVAL == 0:
                        save_path = os.path.join(Config.OUTPUT_DIR, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        try:
                            checkpoints = sorted([d for d in os.listdir(Config.OUTPUT_DIR) if d.startswith("checkpoint-")], key=lambda x: int(x.split("-")[1]))
                            if len(checkpoints) > 2:
                                for ckpt in checkpoints[:-2]: shutil.rmtree(os.path.join(Config.OUTPUT_DIR, ckpt))
                        except Exception: pass
    
    if accelerator.is_main_process:
        accelerator.unwrap_model(controlnet_seg).save_pretrained(os.path.join(Config.OUTPUT_DIR, "final_seg"))
        accelerator.unwrap_model(controlnet_bbox).save_pretrained(os.path.join(Config.OUTPUT_DIR, "final_bbox"))
        accelerator.end_training()

if __name__ == "__main__":
    main()

Overwriting train_multi_controlnet1.py


In [None]:
!accelerate launch --mixed_precision=bf16 train_multi_controlnet1.py

The following values were not passed to `accelerate launch` and had defaults used instead:
	`--num_processes` was set to a value of `1`
	`--num_machines` was set to a value of `1`
	`--dynamo_backend` was set to a value of `'no'`
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /teamspace/studios/this_studio/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mritishshrirao[0m ([33mritishtest1[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: [38;5;178m‚¢ø[0m Waiting for wandb.init()...
[34m[1mwandb[0m: [38;5;178m‚£ª[0m Waiting for wandb.init()...
[34m[1mwandb[0m: Tracking run with wandb version 0.23.1
[34m[1mwandb[0m: Run data is saved locally in [35m[1m/teamspace/studios/this_studio/wandb/run-20251216_084654-o9bsleru[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33mswift-sunset-12[0m
[34m[1mwandb[0m: ‚≠êÔ∏è View projec

In [18]:
%%writefile upload_hf.py

import os
import shutil
import torch

from safetensors.torch import load_file
from huggingface_hub import HfApi, create_repo
from diffusers import ControlNetModel, UNet2DConditionModel

# =========================
# CONFIGURATION
# =========================
HF_USERNAME = "ritishshrirao"
REPO_NAME = "controlnet-coco-multi"
MODEL_ID = f"{HF_USERNAME}/{REPO_NAME}"

BASE_MODEL_ID = "runwayml/stable-diffusion-v1-5"

# Explicit checkpoint path (recommended)
CHECKPOINT_PATH = (
    "/teamspace/studios/this_studio/"
    "controlnet-coco-multi-h100-new/checkpoint-6000"
)

# =========================
# UTILS
# =========================
def load_weights_to_model(model, weights_path):
    print(f"Loading weights from {weights_path}...")

    if not os.path.exists(weights_path):
        raise FileNotFoundError(weights_path)

    if weights_path.endswith(".safetensors"):
        state_dict = load_file(weights_path)
    else:
        state_dict = torch.load(weights_path, map_location="cpu")

    # Remove DDP prefix if present
    cleaned = {}
    for k, v in state_dict.items():
        cleaned[k[7:]] = v if k.startswith("module.") else v

    missing, unexpected = model.load_state_dict(cleaned, strict=False)

    if missing:
        print(f"‚ö†Ô∏è Missing keys: {len(missing)}")
    if unexpected:
        print(f"‚ö†Ô∏è Unexpected keys: {len(unexpected)}")

    print("‚úÖ Weights loaded.")
    return model


# =========================
# MAIN
# =========================
def extract_and_upload():
    # --- Auth ---
    hf_token = os.getenv("HF")
    if not hf_token:
        raise RuntimeError("‚ùå HF environment variable not set")

    if not os.path.exists(CHECKPOINT_PATH):
        raise RuntimeError(f"‚ùå Checkpoint not found: {CHECKPOINT_PATH}")

    print(f"üìÇ Processing Checkpoint: {CHECKPOINT_PATH}")

    # --- Weight files (Accelerate order) ---
    seg_weights = os.path.join(CHECKPOINT_PATH, "model.safetensors")
    bbox_weights = os.path.join(CHECKPOINT_PATH, "model_1.safetensors")

    if not os.path.exists(seg_weights) or not os.path.exists(bbox_weights):
        raise RuntimeError(
            "‚ùå Expected model.safetensors and model_1.safetensors"
        )

    # --- Build clean architectures ---
    print("üèóÔ∏è Initializing model architecture...")
    unet = UNet2DConditionModel.from_pretrained(
        BASE_MODEL_ID, subfolder="unet"
    )

    c_seg = ControlNetModel.from_unet(unet)
    c_bbox = ControlNetModel.from_unet(unet)

    # --- Load weights ---
    c_seg = load_weights_to_model(c_seg, seg_weights)
    c_bbox = load_weights_to_model(c_bbox, bbox_weights)

    # --- Save clean safetensors ---
    export_dir = "temp_clean_models"
    if os.path.exists(export_dir):
        shutil.rmtree(export_dir)

    os.makedirs(export_dir, exist_ok=True)

    print("üíæ Saving clean safetensors models...")
    c_seg.save_pretrained(
        os.path.join(export_dir, "segmentation"),
        safe_serialization=True,
    )
    c_bbox.save_pretrained(
        os.path.join(export_dir, "bbox"),
        safe_serialization=True,
    )

    # --- Upload ---
    api = HfApi(token=hf_token)

    print(f"üåê Creating / checking repo: {MODEL_ID}")
    create_repo(
        repo_id=MODEL_ID,
        exist_ok=True,
        private=False,
        token=hf_token,
    )

    print("üöÄ Uploading to Hugging Face Hub...")

    api.upload_folder(
        folder_path=os.path.join(export_dir, "segmentation"),
        repo_id=MODEL_ID,
        path_in_repo="segmentation",
        commit_message="Upload segmentation ControlNet",
    )

    api.upload_folder(
        folder_path=os.path.join(export_dir, "bbox"),
        repo_id=MODEL_ID,
        path_in_repo="bbox",
        commit_message="Upload bbox ControlNet",
    )

    print("\n‚úÖ Upload complete!")
    print(f"üëâ https://huggingface.co/{MODEL_ID}/tree/main")

    shutil.rmtree(export_dir)


if __name__ == "__main__":
    extract_and_upload()


Overwriting upload_hf.py


In [19]:
!python upload_hf.py

üìÇ Processing Checkpoint: /teamspace/studios/this_studio/controlnet-coco-multi-h100-new/checkpoint-5000
üèóÔ∏è Initializing model architecture...
Loading weights from /teamspace/studios/this_studio/controlnet-coco-multi-h100-new/checkpoint-5000/model.safetensors...
‚ö†Ô∏è Missing keys: 340
‚ö†Ô∏è Unexpected keys: 340
‚úÖ Weights loaded.
Loading weights from /teamspace/studios/this_studio/controlnet-coco-multi-h100-new/checkpoint-5000/model_1.safetensors...
‚ö†Ô∏è Missing keys: 340
‚ö†Ô∏è Unexpected keys: 340
‚úÖ Weights loaded.
üíæ Saving clean safetensors models...
üåê Creating / checking repo: ritishshrirao/controlnet-coco-multi
üöÄ Uploading to Hugging Face Hub...
Processing Files (0 / 0)      : |                  |  0.00B /  0.00B            
New Data Upload               : |                  |  0.00B /  0.00B            [A

  ...pytorch_model.safetensors:   0%|              | 1.06MB / 1.45GB            [A[A

Processing Files (0 / 1)      :   0%|              | 1.06MB / 1.

In [12]:
ls /teamspace/studios/this_studio/controlnet-coco-multi-h100-new/checkpoint-5000

model.safetensors  model_1.safetensors  optimizer.bin  random_states_0.pkl
