In [None]:
!pip install -q kaggle
!pip install -q torch torchvision torchaudio
!pip install -q tqdm pillow matplotlib

In [None]:
# kagglehub fetches (or reuses a cached copy of) a Kaggle dataset.
import kagglehub

# Download the COCO 2017 dataset identified by its Kaggle handle; the return
# value is printed below as the path to the dataset files. `path` is read
# again by the copy cell that follows.
path = kagglehub.dataset_download("awsaf49/coco-2017-dataset")
print("Path to dataset files:", path)

Using Colab cache for faster access to the 'coco-2017-dataset' dataset.
Path to dataset files: /kaggle/input/coco-2017-dataset


In [None]:
import shutil

# Copy the downloaded dataset tree into a local "coco_dataset" directory
# (relative to the current working directory). dirs_exist_ok=True lets the
# cell be re-run without failing when the destination already exists.
# copytree returns the destination, which is echoed as the cell output.
shutil.copytree(path, "coco_dataset", dirs_exist_ok=True)

'coco_dataset'

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm import tqdm

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
# Kept as a plain string so it can be passed straight to `.to(device)`.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("device:", device)

# =========================
# DATASET
# =========================
class ColorDataset(Dataset):
    """Image-folder dataset yielding ``(L, ab)`` tensor pairs for colorization.

    Each image is converted to RGB, resized to 256x256 and turned into a float
    tensor in ``[0, 1]`` by ``ToTensor``. From that ``(3, 256, 256)`` tensor:

    * ``L``  -- the per-pixel mean over the three channels, shape ``(1, 256, 256)``
    * ``ab`` -- the first two channels (R and G), shape ``(2, 256, 256)``

    NOTE(review): despite the names these are *not* CIE-Lab channels; they are
    an RGB-derived stand-in (the third channel is recoverable as ``3*L - R - G``).
    The targets live in ``[0, 1]``.

    Args:
        folder: Directory containing the image files (not searched recursively).
    """

    # Extensions (lower-case) treated as images; everything else is ignored so
    # stray files such as ".DS_Store" or annotation dumps cannot crash loading.
    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".webp")

    def __init__(self, folder):
        self.folder = folder
        self.files = self._list_images(folder)
        self.tf = transforms.Compose([
            transforms.Resize((256,256)),
            transforms.ToTensor()
        ])

    @classmethod
    def _list_images(cls, folder):
        """Return the image file names in ``folder``, sorted for a stable order.

        ``os.listdir`` returns entries in arbitrary, filesystem-dependent order
        and includes sub-directories; sorting makes index -> sample mapping
        reproducible and the filter keeps only regular files with a known
        image extension (matched case-insensitively).
        """
        return sorted(
            name
            for name in os.listdir(folder)
            if name.lower().endswith(cls.IMAGE_EXTENSIONS)
            and os.path.isfile(os.path.join(folder, name))
        )

    def __len__(self):
        """Number of image files found in the folder."""
        return len(self.files)

    def __getitem__(self, i):
        """Load sample ``i`` and return the ``(L, ab)`` tensor pair."""
        path = os.path.join(self.folder, self.files[i])
        # Use a context manager so the underlying file handle is released
        # immediately instead of lingering until garbage collection;
        # convert() yields an independent in-memory image.
        with Image.open(path) as raw:
            img = self.tf(raw.convert("RGB"))

        L = img.mean(0, keepdim=True)   # grayscale-like input, 1 channel
        ab = img[:2]                    # first two RGB channels as the target

        return L, ab


# =========================
# MODEL
# =========================
class Model(nn.Module):
    """Small fully convolutional net mapping 1 input channel to 2 output channels.

    Three 3x3 convolutions (stride 1, padding 1, so height and width are
    preserved) widen the features 1 -> 64 -> 128 and then project to 2
    channels. ReLU follows the first two convolutions and Tanh bounds the
    final output to [-1, 1].
    """

    def __init__(self):
        super().__init__()
        widths = (1, 64, 128, 2)
        activations = (nn.ReLU, nn.ReLU, nn.Tanh)

        # Build conv/activation pairs in order so the Sequential indices
        # (and therefore the state_dict keys net.0, net.2, net.4) are fixed.
        stack = []
        for c_in, c_out, act in zip(widths[:-1], widths[1:], activations):
            stack.append(
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1)
            )
            stack.append(act())
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` of shape (N, 1, H, W) through the stack -> (N, 2, H, W)."""
        out = self.net(x)
        return out


# =========================
# CONFIG
# =========================
# Dataset root, derived from the relative "coco_dataset" directory that the
# copy cell writes to, so the two locations cannot drift apart. With Colab's
# default working directory (/content) this resolves to the same absolute
# paths that were previously hard-coded.
DATA_ROOT = os.path.abspath(os.path.join("coco_dataset", "coco2017"))

BATCH_SIZE = 16
NUM_WORKERS = 2
LEARNING_RATE = 1e-4
NUM_EPOCHS = 5

CHECKPOINT_DIR = "checkpoints"
BEST_PATH = os.path.join(CHECKPOINT_DIR, "best_model.pth")
FINAL_PATH = os.path.join(CHECKPOINT_DIR, "final_model.pth")

train_dir = os.path.join(DATA_ROOT, "train2017")
val_dir   = os.path.join(DATA_ROOT, "val2017")

# =========================
# DATA / MODEL / OPTIMIZER
# =========================
train_ds = ColorDataset(train_dir)
val_ds   = ColorDataset(val_dir)

# Shuffle only the training set; validation order is irrelevant to the loss.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_loader   = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

model = Model().to(device)
opt = optim.Adam(model.parameters(), lr=LEARNING_RATE)
loss_fn = nn.MSELoss()

os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# =========================
# TRAINING LOOP
# =========================
# Lowest validation loss seen so far. Start at +inf so the first epoch always
# produces a "best" checkpoint regardless of the loss scale.
best = float("inf")

try:
    for epoch in range(NUM_EPOCHS):
        # ---- train ----
        model.train()
        tl = 0.0
        for L,ab in tqdm(train_loader):
            L,ab = L.to(device), ab.to(device)

            pred = model(L)
            loss = loss_fn(pred, ab)

            opt.zero_grad()
            loss.backward()
            opt.step()
            tl += loss.item()

        # ---- validate (no gradients, eval-mode layers) ----
        model.eval()
        vl = 0.0
        with torch.no_grad():
            for L,ab in val_loader:
                L,ab = L.to(device), ab.to(device)
                pred = model(L)
                vl += loss_fn(pred,ab).item()

        # Mean of the per-batch losses.
        tl /= len(train_loader)
        vl /= len(val_loader)

        print("epoch",epoch,"train",tl,"val",vl)

        # Keep the weights of the epoch with the lowest validation loss.
        if vl < best:
            best = vl
            torch.save(model.state_dict(), BEST_PATH)
finally:
    # Always persist the latest weights, even when the run is interrupted
    # (e.g. KeyboardInterrupt) or fails part-way; the exception still
    # propagates afterwards, so the completion message below is only printed
    # for a run that finished every epoch.
    torch.save(model.state_dict(), FINAL_PATH)

print("Training complete → final_model.pth saved")

device: cuda


100%|██████████| 7393/7393 [25:23<00:00,  4.85it/s]


epoch 0 train 0.0062483214694283775 val 0.005675940721566542


100%|██████████| 7393/7393 [25:26<00:00,  4.84it/s]


epoch 1 train 0.005626103225104026 val 0.005562068856568239


100%|██████████| 7393/7393 [25:23<00:00,  4.85it/s]


epoch 2 train 0.005558233540063004 val 0.0055006225697529585


100%|██████████| 7393/7393 [25:23<00:00,  4.85it/s]


epoch 3 train 0.0055323173976962485 val 0.005502750853838298


100%|██████████| 7393/7393 [25:22<00:00,  4.86it/s]


epoch 4 train 0.005520502617280342 val 0.005461273959632546


100%|██████████| 7393/7393 [25:21<00:00,  4.86it/s]


epoch 5 train 0.005503141142422378 val 0.005616389952165584


100%|██████████| 7393/7393 [25:22<00:00,  4.86it/s]


epoch 6 train 0.005484155490785337 val 0.005486196644913655


  1%|          | 81/7393 [00:17<26:07,  4.66it/s]


KeyboardInterrupt: 

In [None]:
from google.colab import files

# Download the best-validation checkpoint. The file name must match the path
# the training loop passes to torch.save ("checkpoints/best_model.pth");
# requesting any other name raises FileNotFoundError.
files.download("checkpoints/best_model.pth")

FileNotFoundError: Cannot find file: checkpoints/model.pth