In [None]:
!pip -q install datasets scikit-image opencv-python

In [None]:
from datasets import load_dataset
# Load the dataset from the Hugging Face Hub and print its structure.
# In the run recorded below this yields a DatasetDict with `train` and `validation` splits
# whose features are metadata columns (e.g. `coco_url`, `captions`); no image column is listed.
ds = load_dataset("nickpai/coco2017-colorization")  # splits seen in the recorded output: train + validation
print(ds)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md: 0.00B [00:00, ?B/s]

data/train-00000-of-00001.parquet:   0%|          | 0.00/27.5M [00:00<?, ?B/s]

data/validation-00000-of-00001.parquet:   0%|          | 0.00/1.26M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/112268 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/5000 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['license', 'file_name', 'coco_url', 'height', 'width', 'date_captured', 'flickr_url', 'image_id', 'ids', 'captions'],
        num_rows: 112268
    })
    validation: Dataset({
        features: ['license', 'file_name', 'coco_url', 'height', 'width', 'date_captured', 'flickr_url', 'image_id', 'ids', 'captions'],
        num_rows: 5000
    })
})


In [None]:
!pip -q install datasets scikit-image opencv-python requests


In [None]:
from datasets import load_dataset

# Load the dataset and keep a handle on each of its two splits.
ds = load_dataset("nickpai/coco2017-colorization")
train_meta = ds["train"]
val_meta = ds["validation"]

In [None]:
import os, io, time, random, requests, numpy as np, torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from skimage import color
from pathlib import Path

# Directory that holds downloaded images; created (with parents) if it does not exist yet.
CACHE_DIR = Path("/content/coco_cache")
CACHE_DIR.mkdir(parents=True, exist_ok=True)

In [None]:
# ====================== CONFIG ======================
# Tunable constants for the data pipeline defined further down in this cell.
CROP        = 256              # side length (px) of the square image every dataset item is resized to
RESIZE_MIN  = 224              # lower bound (inclusive) of the random short-side resize applied to training items
RESIZE_MAX  = 288              # upper bound (inclusive) of that random short-side resize
BATCH_SIZE  = 16               # samples per DataLoader batch; adjust per GPU
NUM_WORKERS = 2                # `num_workers` passed to every DataLoader below
VAL_FRAC    = 0.05             # fraction of the hub `train` split (COCO train2017) held out for validation
SEED        = 42               # seeds the shuffle that decides the train/val split

# ================== SETUP & METADATA =================
!pip -q install datasets scikit-image opencv-python requests
from datasets import load_dataset
ds = load_dataset("nickpai/coco2017-colorization")
train_meta, val_meta = ds["train"], ds["validation"]  # we'll split train_meta into train/val; val_meta becomes test

from pathlib import Path
CACHE_DIR = Path("/content/coco_cache"); CACHE_DIR.mkdir(parents=True, exist_ok=True)

# =================== HELPERS =========================
import io, time, requests, numpy as np, torch, random, cv2
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from skimage import color

def fetch_image(url, retries=3, timeout=15):
    """Return the image at ``url`` as an RGB ``PIL.Image``, backed by an on-disk cache.

    The cache entry lives in ``CACHE_DIR`` under the last path component of ``url``.
    The original downloaded bytes are cached (not a re-encoded copy), so a cache hit
    decodes to the same pixels as the first fetch.

    Args:
        url: Address of the image to download.
        retries: Maximum number of download attempts; values below 1 are treated as 1.
        timeout: Timeout in seconds handed to ``requests.get``.

    Returns:
        The decoded image converted to RGB.

    Raises:
        Exception: whatever the last failed attempt raised (network error, HTTP error
            status, or image decode error).
    """
    import os  # local import: only needed for the per-process temp-file suffix

    fname = url.split("/")[-1]
    fpath = CACHE_DIR / fname

    if fpath.is_file():
        try:
            with Image.open(fpath) as cached:
                return cached.convert("RGB")
        except Exception:
            # Corrupt or truncated cache entry (e.g. an interrupted write): drop it and
            # fall through to a fresh download instead of failing forever.
            fpath.unlink(missing_ok=True)

    attempts = max(1, int(retries))
    err = None
    for k in range(attempts):
        try:
            r = requests.get(url, timeout=timeout)
            r.raise_for_status()
            with Image.open(io.BytesIO(r.content)) as downloaded:
                img = downloaded.convert("RGB")
        except Exception as e:
            err = e
            # Linear back-off between attempts; no point sleeping after the last one.
            if k + 1 < attempts:
                time.sleep(1.0 * (k + 1))
            continue

        # Best-effort cache write. Write to a per-process temp file and rename it into
        # place so a concurrent reader (another DataLoader worker) never sees a partial file.
        tmp_path = fpath.with_name(f"{fname}.{os.getpid()}.tmp")
        try:
            tmp_path.write_bytes(r.content)
            tmp_path.replace(fpath)
        except OSError:
            # Caching is an optimisation; a failed write must not fail the fetch.
            try:
                tmp_path.unlink(missing_ok=True)
            except OSError:
                pass
        return img

    raise err

def random_resize_then_center_crop(img_np, out_size, lo=RESIZE_MIN, hi=RESIZE_MAX):
    """Rescale to a random short side, cut the centered square, and resize it.

    Args:
        img_np: Image array whose first two axes are height and width.
        out_size: Side length of the returned square image.
        lo: Smallest short-side length that may be drawn (inclusive).
        hi: Largest short-side length that may be drawn (inclusive).

    Returns:
        An ``out_size`` x ``out_size`` image produced with bicubic interpolation.
    """
    src_h, src_w = img_np.shape[:2]
    short_side = np.random.randint(lo, hi+1)

    # Scale so the shorter edge equals `short_side`; the longer edge keeps the aspect ratio.
    if src_h < src_w:
        dst_h = short_side
        dst_w = int(round(src_w*short_side/src_h))
    else:
        dst_h = int(round(src_h*short_side/src_w))
        dst_w = short_side
    resized = cv2.resize(img_np, (dst_w, dst_h), interpolation=cv2.INTER_CUBIC)

    # Largest centered square of the rescaled image.
    res_h, res_w = resized.shape[:2]
    edge = min(res_h, res_w)
    top = (res_h - edge)//2
    left = (res_w - edge)//2
    square = resized[top:top+edge, left:left+edge]

    return cv2.resize(square, (out_size, out_size), interpolation=cv2.INTER_CUBIC)

def to_L_ab_tensors(img_np):
    """Convert an RGB image into normalised Lab tensors.

    Args:
        img_np: RGB image as an (H, W, 3) float array with values expected in [0, 1];
            values outside that range are clipped before conversion.

    Returns:
        Tuple ``(L, ab)`` of float32 torch tensors in channel-first layout:
        ``L`` is (1, H, W) with lightness / 100 clipped to [0, 1];
        ``ab`` is (2, H, W) with a/b / 110 clipped to [-1, 1].
    """
    def _to_chw(hwc):
        # (H, W, C) numpy array -> contiguous (C, H, W) tensor
        return torch.from_numpy(hwc).permute(2,0,1).contiguous()

    # Interpolation upstream can overshoot slightly, so clamp before the Lab conversion.
    rgb01 = np.clip(img_np, 0.0, 1.0).astype(np.float32)
    lab = color.rgb2lab(rgb01).astype(np.float32)

    # Scale each part to its target range and clamp away tiny overshoots.
    lightness = np.clip(lab[..., :1] / 100.0, 0.0, 1.0)
    chroma = np.clip(lab[..., 1:] / 110.0, -1.0, 1.0)

    return _to_chw(lightness), _to_chw(chroma)

# =================== DATASET =========================
class CocoURLColorizeFromURLs(Dataset):
    """Map-style dataset that downloads images by URL and yields Lab tensors.

    Each item is a dict with:
        "L":   (1, crop, crop) tensor, lightness in [0, 1]
        "ab":  (2, crop, crop) tensor, a/b channels in [-1, 1]
        "url": URL of the image that was actually loaded. This equals ``urls[i]``
               unless that URL failed and a substitute image was used.

    Args:
        urls: Sequence of image URLs.
        crop: Side length of the square output image.
        train: If True, use the random-resize pipeline; otherwise a centered square
            crop followed by a resize.
        hflip: Enable random horizontal flips (p=0.5). Only effective when ``train`` is True.
        resize_lo: Lower bound of the random short-side resize (training only).
        resize_hi: Upper bound of the random short-side resize (training only).
    """

    # Upper bound on substitute images tried when the requested URL cannot be fetched.
    MAX_FALLBACKS = 5

    def __init__(self, urls, crop=CROP, train=True, hflip=True,
                 resize_lo=RESIZE_MIN, resize_hi=RESIZE_MAX):
        self.urls = urls
        self.crop = crop
        self.train = train
        self.hflip = (hflip and train)
        self.lo, self.hi = resize_lo, resize_hi

    def __len__(self):
        return len(self.urls)

    def _load_with_fallback(self, i):
        """Fetch ``urls[i]``; on failure try random substitutes. Returns ``(image, url_used)``."""
        url = self.urls[i]
        try:
            return fetch_image(url), url
        except Exception:
            # Occasional broken URL: substitute a random other image so one bad link does
            # not abort a whole epoch, but give up after a bounded number of tries.
            for _ in range(self.MAX_FALLBACKS):
                alt_url = self.urls[random.randrange(len(self.urls))]
                try:
                    return fetch_image(alt_url), alt_url
                except Exception:
                    continue
            # Every substitute failed as well: surface the original error.
            raise

    def __getitem__(self, i):
        img, url = self._load_with_fallback(i)
        # PIL RGB image -> float32 array in [0, 1]
        img_np = np.asarray(img, dtype=np.float32) / 255.0

        if self.train:
            img_np = random_resize_then_center_crop(img_np, self.crop, self.lo, self.hi)
            # Honour the `hflip` switch; previously the flip was applied unconditionally.
            if self.hflip and random.random() < 0.5:
                # .copy() removes the negative stride, which torch.from_numpy cannot handle
                img_np = img_np[:, ::-1, :].copy()
        else:
            # centered square crop + resize, no random augmentation
            h, w = img_np.shape[:2]
            side = min(h, w)
            y0 = (h - side)//2
            x0 = (w - side)//2
            img_np = img_np[y0:y0+side, x0:x0+side]
            img_np = cv2.resize(img_np, (self.crop, self.crop), interpolation=cv2.INTER_CUBIC)

        L, ab = to_L_ab_tensors(img_np)
        return {"L": L, "ab": ab, "url": url}

# =================== SPLITS =========================
# Split COCO train2017 into train/val; use official val2017 as TEST.
# Read the `coco_url` column directly instead of iterating row by row, which would build
# a dict of every column for every row just to pick out one field.
# list() guarantees a plain mutable list for the in-place shuffle.
urls_all = list(train_meta["coco_url"])
random.Random(SEED).shuffle(urls_all)  # dedicated RNG: the split does not depend on global random state
n = len(urls_all)
n_val = int(round(VAL_FRAC * n))
urls_val = urls_all[:n_val]
urls_train = urls_all[n_val:]
urls_test = list(val_meta["coco_url"])  # official val2017 as test

print(f"Train: {len(urls_train)} | Val: {len(urls_val)} | Test: {len(urls_test)}")

# =================== LOADERS ========================
def collate(batch):
    """Stack the per-sample "L" and "ab" tensors along a new leading batch axis.

    Args:
        batch: List of sample dicts, each holding at least the keys "L" and "ab".

    Returns:
        Dict with the keys "L" and "ab" only; any other sample keys are dropped.
    """
    return {
        key: torch.stack([sample[key] for sample in batch], dim=0)
        for key in ("L", "ab")
    }

# One dataset per split; only the training split is built in training mode.
train_ds = CocoURLColorizeFromURLs(
    urls_train,
    crop=CROP,
    train=True,
    hflip=True,
)
val_ds = CocoURLColorizeFromURLs(
    urls_val,
    crop=CROP,
    train=False,
    hflip=False,
)
test_ds = CocoURLColorizeFromURLs(
    urls_test,
    crop=CROP,
    train=False,
    hflip=False,
)

# Training batches are shuffled and the final incomplete batch is dropped;
# val/test keep their order and every sample.
train_loader = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    collate_fn=collate,
    drop_last=True,
)
val_loader = DataLoader(
    val_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    collate_fn=collate,
)
test_loader = DataLoader(
    test_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    collate_fn=collate,
)

print("Batches -> train:", len(train_loader), "| val:", len(val_loader), "| test:", len(test_loader))

# ================= SANITY CHECK =====================
# Pull a single training batch and report each tensor's shape and value range.
batch = next(iter(train_loader))
L = batch["L"]
ab = batch["ab"]
print(
    "L:",
    L.shape,
    f"[{L.min().item():.3f}, {L.max().item():.3f}]",
)
print(
    "ab:",
    ab.shape,
    f"[{ab.min().item():.3f}, {ab.max().item():.3f}]",
)


Train: 106655 | Val: 5613 | Test: 5000
Batches -> train: 6665 | val: 351 | test: 313
L: torch.Size([16, 1, 256, 256]) [0.000, 1.000]
ab: torch.Size([16, 2, 256, 256]) [-0.817, 0.859]
