In [None]:
import cv2
import os
from torchvision.datasets import Imagenette
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
import numpy as np
import glob

#cartoonify filter
def cartoonify(img, d, sigmaColor, sigmaSpace, edge_size=9, blur_ksize=7, edge_c=9):
    """Turn a BGR image into a flat-colored, black-outlined "cartoon" version.

    Colors are smoothed with an edge-preserving bilateral filter. Outlines come
    from an adaptive mean threshold of the median-blurred grayscale image. They
    are burned in by zeroing (blackening) every pixel the threshold marks as an edge.

    Args:
        img: 8-bit 3-channel image in OpenCV's BGR order.
        d: pixel-neighborhood diameter of the bilateral filter.
        sigmaColor: bilateral filter sigma in color space.
        sigmaSpace: bilateral filter sigma in coordinate space.
        edge_size: block size of the adaptive threshold; must be odd and > 1.
            Larger values compare each pixel with a wider neighborhood.
        blur_ksize: median-blur aperture applied before thresholding; must be
            odd and > 1 (default 7, the previously hard-coded value).
        edge_c: constant subtracted from the local mean (default 9, the
            previously hard-coded value). Pixels at or below ``mean - edge_c``
            become outline pixels, so larger values give fewer lines.

    Returns:
        The cartoonified image, same size and type as ``img``.
    """
    # adaptiveThreshold / medianBlur reject even sizes and size 1; fail early
    # with a readable message instead of an OpenCV error.
    assert edge_size % 2 == 1 and edge_size > 1, "edge_size must be an odd integer > 1"
    assert blur_ksize % 2 == 1 and blur_ksize > 1, "blur_ksize must be an odd integer > 1"

    # Edge-preserving smoothing gives the flat "painted" color regions.
    color = cv2.bilateralFilter(img, d=d, sigmaColor=sigmaColor, sigmaSpace=sigmaSpace)

    # The median blur suppresses texture noise so only strong edges survive thresholding.
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    gray_blur = cv2.medianBlur(gray, blur_ksize)

    # THRESH_BINARY: 255 where the pixel is brighter than (local mean - edge_c),
    # 0 otherwise, so dark edge lines become 0 in the mask.
    edges = cv2.adaptiveThreshold(gray_blur, 255,
                                  cv2.ADAPTIVE_THRESH_MEAN_C,
                                  cv2.THRESH_BINARY, edge_size, edge_c)

    # Keep the smoothed colors where the mask is 255; edge pixels come out black.
    cartoon = cv2.bitwise_and(color, color, mask=edges)
    return cartoon

# Output folders for the generated (input, target) training pairs.
# NOTE: "celeba_pairs" is a leftover name; the images come from Imagenette.
save_input_dir = "data/celeba_pairs/input"
save_target_dir = "data/celeba_pairs/target"
os.makedirs(save_input_dir, exist_ok=True)
os.makedirs(save_target_dir, exist_ok=True)


dataset_root = "/content/drive/MyDrive/CV_project/"

# Load the local copy if present and download only when it is missing. This
# keeps the cell re-runnable: passing download=True for an already-extracted
# dataset can make torchvision's Imagenette raise instead of reusing it.
imagenette_kwargs = dict(root="data", split="train", size="320px")
try:
    imagenette = Imagenette(download=False, **imagenette_kwargs)
except RuntimeError:
    imagenette = Imagenette(download=True, **imagenette_kwargs)

# Use the same JPEG quality for both files of a pair. PIL's default (75) is
# lower than cv2.imwrite's (95). Mismatched settings would give the inputs
# heavier compression artifacts than their targets.
JPEG_QUALITY = 95

for idx in tqdm(range(len(imagenette))):
    img, _ = imagenette[idx]  # the class label is not needed
    img = img.convert("RGB")
    filename = f"{idx:06d}.jpg"

    input_path = os.path.join(save_input_dir, filename)
    img.save(input_path, quality=JPEG_QUALITY)

    # OpenCV expects BGR channel order.
    img_cv = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)

    cartoon_cv = cartoonify(img=img_cv, d=20, sigmaColor=100, sigmaSpace=100, edge_size=25)

    target_path = os.path.join(save_target_dir, filename)
    # cv2.imwrite reports failure through its return value instead of raising.
    if not cv2.imwrite(target_path, cartoon_cv, [cv2.IMWRITE_JPEG_QUALITY, JPEG_QUALITY]):
        raise IOError(f"Failed to write {target_path}")


100%|██████████| 9469/9469 [21:08<00:00,  7.47it/s]


In [None]:
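# Move the generated data/ folder to Google Drive. The training cells read the
# pairs from dataset_root + "data/celeba_pairs/...", so run this once (uncomment) first.
#!mv data /content/drive/MyDrive/CV_project/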

In [None]:
#half an hour break ! :)

# import time

# for i in tqdm(range(900)):
#   time.sleep(2)

100%|██████████| 900/900 [30:00<00:00,  2.00s/it]


In [None]:
# Root folder (Google Drive, as mounted in Colab) that holds the generated data/ tree.
# Override it with the CV_PROJECT_ROOT environment variable when running elsewhere.
# os.path.join(..., "") guarantees a trailing separator, because later cells
# build paths by plain string concatenation.
dataset_root = os.path.join(
    os.environ.get("CV_PROJECT_ROOT", "/content/drive/MyDrive/CV_project/"), ""
)

In [None]:
from torch.utils.data import Dataset, DataLoader


class CartoonDataset(Dataset):
    """Paired (input photo, cartoonified target) images stored as matching JPEG files.

    Files are paired by position after sorting each directory's ``*.jpg``
    listing. The constructor checks that both directories contain exactly the
    same file names, so a missing or extra file cannot silently shift every
    following pair.

    Args:
        input_dir: directory with the original images.
        target_dir: directory with the cartoonified images.
        transform: optional callable applied separately to the input and to
            the target. It must be deterministic (e.g. Resize + ToTensor);
            random augmentations would be drawn independently for each image
            and misalign the pair.

    Raises:
        ValueError: if the two directories do not contain the same file names.
    """

    def __init__(self, input_dir, target_dir, transform=None):
        self.input_files = sorted(glob.glob(os.path.join(input_dir, "*.jpg")))
        self.target_files = sorted(glob.glob(os.path.join(target_dir, "*.jpg")))

        input_names = [os.path.basename(p) for p in self.input_files]
        target_names = [os.path.basename(p) for p in self.target_files]
        if input_names != target_names:
            unmatched = sorted(set(input_names) ^ set(target_names))
            raise ValueError(
                f"Input/target folders are not paired: {len(input_names)} vs "
                f"{len(target_names)} files; unmatched names (first 5): {unmatched[:5]}"
            )
        self.transform = transform

    def __len__(self):
        return len(self.input_files)

    def __getitem__(self, idx):
        # The context managers close the file handles immediately; convert()
        # returns an in-memory copy that does not depend on the open file.
        with Image.open(self.input_files[idx]) as im:
            inp = im.convert("RGB")
        with Image.open(self.target_files[idx]) as im:
            tgt = im.convert("RGB")

        if self.transform is not None:
            inp = self.transform(inp)
            tgt = self.transform(tgt)
        return inp, tgt

In [None]:
import torch.nn as nn
import torch
import torch.optim as optim
class UNetBlock(nn.Module):
    """Double-convolution unit used at every U-Net level.

    Applies (3x3 convolution with padding 1 -> BatchNorm -> ReLU) twice, so
    height and width are preserved. The first convolution maps ``in_ch`` to
    ``out_ch`` channels and the second maps ``out_ch`` to ``out_ch``.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        for conv_in in (in_ch, out_ch):
            stages.extend([
                nn.Conv2d(conv_in, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ])
        # One Sequential named `block`, so the state_dict keys (block.0.*,
        # block.1.*, block.3.*, block.4.*) stay checkpoint-compatible.
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        out = self.block(x)
        return out

In [None]:
class UNet(nn.Module):
    """Four-level U-Net that maps an image to an image of the same spatial size.

    Encoder: four UNetBlocks with 2x max-pooling between them, then a
    bottleneck block. Decoder: 2x transposed-convolution upsampling, then
    concatenation with the matching encoder feature map (skip connection),
    then a UNetBlock. A final 1x1 convolution maps to ``out_channels``, and a
    sigmoid squashes the result into [0, 1].

    Inputs whose height or width is not a multiple of 16 (four 2x poolings)
    are edge-padded up to the next multiple and the output is cropped back.
    Arbitrary image sizes are therefore accepted. Inputs that are already
    multiples of 16 are processed exactly as before.

    Args:
        in_channels: channels of the input image (default 3).
        out_channels: channels of the predicted image (default 3).
        base_channels: width of the first level. Each deeper level doubles it
            (default 64 -> 64/128/256/512/1024).
    """

    # Total downsampling factor of the encoder (four 2x max-pools).
    _SIZE_MULTIPLE = 16

    def __init__(self, in_channels=3, out_channels=3, base_channels=64):
        super().__init__()
        c1, c2, c3, c4, c5 = (base_channels * 2 ** k for k in range(5))

        # Submodule names and construction order match the original layout,
        # so saved state_dicts load unchanged and seeded init is identical.
        self.enc1 = UNetBlock(in_channels, c1)
        self.enc2 = UNetBlock(c1, c2)
        self.enc3 = UNetBlock(c2, c3)
        self.enc4 = UNetBlock(c3, c4)

        self.pool = nn.MaxPool2d(2)

        self.middle = UNetBlock(c4, c5)

        # Each decoder block sees upsampled features concatenated with the
        # skip connection, hence twice the channels at its input.
        self.up4 = nn.ConvTranspose2d(c5, c4, 2, stride=2)
        self.dec4 = UNetBlock(c4 * 2, c4)
        self.up3 = nn.ConvTranspose2d(c4, c3, 2, stride=2)
        self.dec3 = UNetBlock(c3 * 2, c3)
        self.up2 = nn.ConvTranspose2d(c3, c2, 2, stride=2)
        self.dec2 = UNetBlock(c2 * 2, c2)
        self.up1 = nn.ConvTranspose2d(c2, c1, 2, stride=2)
        self.dec1 = UNetBlock(c1 * 2, c1)

        self.final = nn.Conv2d(c1, out_channels, kernel_size=1)

    def forward(self, x):
        """Predict the output image for a batch ``x`` of shape (N, C, H, W)."""
        h, w = x.shape[-2:]
        pad_h = (-h) % self._SIZE_MULTIPLE
        pad_w = (-w) % self._SIZE_MULTIPLE
        if pad_h or pad_w:
            # The pad tuple lists the last dim first: (left, right, top, bottom).
            # Replicate padding works for any input size (reflect needs pad < size).
            x = nn.functional.pad(x, (0, pad_w, 0, pad_h), mode="replicate")

        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))

        m = self.middle(self.pool(e4))

        d4 = self.dec4(torch.cat([self.up4(m), e4], dim=1))
        d3 = self.dec3(torch.cat([self.up3(d4), e3], dim=1))
        d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))

        out = torch.sigmoid(self.final(d1))
        # Drop any padding so the output matches the input's spatial size.
        return out[..., :h, :w]

In [None]:
from torch.utils.data import random_split


# Device, data pipeline, model, loss and optimizer for training.
SEED = 42  # fixes the train/val split and the model's initial weights
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Deterministic transforms only: CartoonDataset applies this separately to the
# input and to the target, so random augmentations would desynchronize the pair.
transform = transforms.Compose([
    transforms.Resize((256, 256)),  # 256 is a multiple of 16, the U-Net's downsampling factor
    transforms.ToTensor()           # float tensor in [0, 1], the same range as the sigmoid output
])
input_path = dataset_root + save_input_dir
target_path = dataset_root + save_target_dir

print(input_path)
print(target_path)

dataset = CartoonDataset(input_path, target_path, transform=transform)
# An empty folder (e.g. Drive not mounted, or data not copied yet) would
# otherwise only surface later as a ZeroDivisionError in the training loop.
assert len(dataset) > 0, f"No *.jpg pairs found under {input_path}"

train_size = int(0.7 * len(dataset))
val_size = len(dataset) - train_size

# A seeded generator puts the same images in validation on every run, so a
# re-run (or a resumed checkpoint) is never evaluated on its own training images.
split_generator = torch.Generator().manual_seed(SEED)
train_dataset, val_dataset = random_split(
    dataset, [train_size, val_size], generator=split_generator
)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

torch.manual_seed(SEED)  # reproducible weight initialization
model = UNet().to(device)
criterion = nn.L1Loss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

/content/drive/MyDrive/CV_project/data/celeba_pairs/input
/content/drive/MyDrive/CV_project/data/celeba_pairs/target


In [None]:
# Train with checkpointing on validation L1 loss, LR decay on plateau and early stopping.
epochs = 20
lr_patience = 1             # bad epochs ReduceLROnPlateau tolerates before lowering the LR
patience = lr_patience + 2  # early stopping must outlast lr_patience, otherwise training
                            # stops on the very epoch the LR would have been reduced
checkpoint_path = "cartoonifier.pth"  # relative to the working directory
best_val_loss = float("inf")
counter = 0
best_saved = False
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.1, patience=lr_patience
)


def run_epoch(loader, train):
    """Run one pass over `loader` and return the per-sample mean loss.

    With train=True the model is put in train mode and updated after every
    batch. Otherwise it runs in eval mode without gradient tracking.
    """
    model.train(train)
    total_loss, n_samples = 0.0, 0
    with torch.set_grad_enabled(train):
        for inp, tgt in tqdm(loader, desc="train" if train else "val", leave=False):
            inp, tgt = inp.to(device), tgt.to(device)
            out = model(inp)
            loss = criterion(out, tgt)
            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            # criterion returns a batch mean. Weight it by batch size so a
            # smaller final batch does not skew the epoch average.
            total_loss += loss.item() * inp.size(0)
            n_samples += inp.size(0)
    if n_samples == 0:
        raise ValueError("DataLoader yielded no samples")
    return total_loss / n_samples


for epoch in tqdm(range(epochs)):
    train_loss = run_epoch(train_loader, train=True)
    val_loss = run_epoch(val_loader, train=False)

    # Step the scheduler before the early-stopping check so an LR reduction is
    # never skipped on the epoch that triggers the stop.
    scheduler.step(val_loss)

    # Report every epoch, including the one that triggers early stopping.
    print(
        f"Epoch {epoch+1}/{epochs} "
        f"| Train Loss: {train_loss:.4f} "
        f"| Val Loss: {val_loss:.4f} "
        f"| LR: {optimizer.param_groups[0]['lr']:.6f}"
    )

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        counter = 0
        torch.save(model.state_dict(), checkpoint_path)
        best_saved = True
    else:
        counter += 1
        if counter >= patience:
            print("Early stopping!")
            break

# Leave the best (not the last) weights in `model` for the cells that follow.
if best_saved:
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))

# Romanian: "Training finished. Best model saved as <checkpoint_path>".
print(f"Antrenare terminată. Cel mai bun model salvat ca {checkpoint_path}")

  0%|          | 0/20 [00:00<?, ?it/s]

Train batch: 0/415
Train batch: 20/415
Train batch: 40/415
Train batch: 60/415
Train batch: 80/415
Train batch: 100/415
Train batch: 120/415
Train batch: 140/415
Train batch: 160/415
Train batch: 180/415
Train batch: 200/415
Train batch: 220/415
Train batch: 240/415
Train batch: 260/415
Train batch: 280/415
Train batch: 300/415
Train batch: 320/415
Train batch: 340/415
Train batch: 360/415
Train batch: 380/415
Train batch: 400/415
Val batch: 0/178
Val batch: 20/178
Val batch: 40/178
Val batch: 60/178
Val batch: 80/178
Val batch: 100/178
Val batch: 120/178
Val batch: 140/178
Val batch: 160/178


  5%|▌         | 1/20 [08:19<2:38:15, 499.78s/it]

Epoch 1/20 | Train Loss: 0.0462 | Val Loss: 0.0307 | LR: 0.000100
Train batch: 0/415
Train batch: 20/415
Train batch: 40/415
Train batch: 60/415
Train batch: 80/415
Train batch: 100/415
Train batch: 120/415
Train batch: 140/415
Train batch: 160/415
Train batch: 180/415
Train batch: 200/415
Train batch: 220/415
Train batch: 240/415
Train batch: 260/415
Train batch: 280/415
Train batch: 300/415
Train batch: 320/415
Train batch: 340/415
Train batch: 360/415
Train batch: 380/415
Train batch: 400/415
Val batch: 0/178
Val batch: 20/178
Val batch: 40/178
Val batch: 60/178
Val batch: 80/178
Val batch: 100/178
Val batch: 120/178
Val batch: 140/178
Val batch: 160/178


 10%|█         | 2/20 [16:39<2:29:59, 499.98s/it]

Epoch 2/20 | Train Loss: 0.0382 | Val Loss: 0.0277 | LR: 0.000100
Train batch: 0/415
Train batch: 20/415
Train batch: 40/415
Train batch: 60/415
Train batch: 80/415
Train batch: 100/415
Train batch: 120/415
Train batch: 140/415
Train batch: 160/415
Train batch: 180/415
Train batch: 200/415
Train batch: 220/415
Train batch: 240/415
Train batch: 260/415
Train batch: 280/415
Train batch: 300/415
Train batch: 320/415
Train batch: 340/415
Train batch: 360/415
Train batch: 380/415
Train batch: 400/415
Val batch: 0/178
Val batch: 20/178
Val batch: 40/178
Val batch: 60/178
Val batch: 80/178
Val batch: 100/178
Val batch: 120/178
Val batch: 140/178
Val batch: 160/178


 15%|█▌        | 3/20 [24:59<2:21:39, 499.96s/it]

Epoch 3/20 | Train Loss: 0.0346 | Val Loss: 0.0262 | LR: 0.000100
Train batch: 0/415
Train batch: 20/415
Train batch: 40/415
Train batch: 60/415
Train batch: 80/415
Train batch: 100/415
Train batch: 120/415
Train batch: 140/415
Train batch: 160/415
Train batch: 180/415
Train batch: 200/415
Train batch: 220/415
Train batch: 240/415
Train batch: 260/415
Train batch: 280/415
Train batch: 300/415
Train batch: 320/415
Train batch: 340/415
Train batch: 360/415
Train batch: 380/415
Train batch: 400/415
Val batch: 0/178
Val batch: 20/178
Val batch: 40/178
Val batch: 60/178
Val batch: 80/178
Val batch: 100/178
Val batch: 120/178
Val batch: 140/178
Val batch: 160/178


 20%|██        | 4/20 [33:19<2:13:14, 499.65s/it]

Epoch 4/20 | Train Loss: 0.0322 | Val Loss: 0.0244 | LR: 0.000100
Train batch: 0/415
Train batch: 20/415
Train batch: 40/415
Train batch: 60/415
Train batch: 80/415
Train batch: 100/415
Train batch: 120/415
Train batch: 140/415
Train batch: 160/415
Train batch: 180/415
Train batch: 200/415
Train batch: 220/415
Train batch: 240/415
Train batch: 260/415
Train batch: 280/415
Train batch: 300/415
Train batch: 320/415
Train batch: 340/415
Train batch: 360/415
Train batch: 380/415
Train batch: 400/415
Val batch: 0/178
Val batch: 20/178
Val batch: 40/178
Val batch: 60/178
Val batch: 80/178
Val batch: 100/178
Val batch: 120/178
Val batch: 140/178
Val batch: 160/178


 25%|██▌       | 5/20 [41:37<2:04:48, 499.21s/it]

Epoch 5/20 | Train Loss: 0.0309 | Val Loss: 0.0253 | LR: 0.000100
Train batch: 0/415
Train batch: 20/415
Train batch: 40/415
Train batch: 60/415
Train batch: 80/415
Train batch: 100/415
Train batch: 120/415
Train batch: 140/415
Train batch: 160/415
Train batch: 180/415
Train batch: 200/415
Train batch: 220/415
Train batch: 240/415
Train batch: 260/415
Train batch: 280/415
Train batch: 300/415
Train batch: 320/415
Train batch: 340/415
Train batch: 360/415
Train batch: 380/415
Train batch: 400/415
Val batch: 0/178
Val batch: 20/178
Val batch: 40/178
Val batch: 60/178
Val batch: 80/178
Val batch: 100/178
Val batch: 120/178
Val batch: 140/178
Val batch: 160/178


 30%|███       | 6/20 [49:55<1:56:25, 498.94s/it]

Epoch 6/20 | Train Loss: 0.0301 | Val Loss: 0.0223 | LR: 0.000100
Train batch: 0/415
Train batch: 20/415
Train batch: 40/415
Train batch: 60/415
Train batch: 80/415
Train batch: 100/415
Train batch: 120/415
Train batch: 140/415
Train batch: 160/415
Train batch: 180/415
Train batch: 200/415
Train batch: 220/415
Train batch: 240/415
Train batch: 260/415
Train batch: 280/415
Train batch: 300/415
Train batch: 320/415
Train batch: 340/415
Train batch: 360/415
Train batch: 380/415
Train batch: 400/415
Val batch: 0/178
Val batch: 20/178
Val batch: 40/178
Val batch: 60/178
Val batch: 80/178
Val batch: 100/178
Val batch: 120/178
Val batch: 140/178
Val batch: 160/178


 35%|███▌      | 7/20 [58:14<1:48:04, 498.85s/it]

Epoch 7/20 | Train Loss: 0.0283 | Val Loss: 0.0237 | LR: 0.000100
Train batch: 0/415
Train batch: 20/415
Train batch: 40/415
Train batch: 60/415
Train batch: 80/415
Train batch: 100/415
Train batch: 120/415
Train batch: 140/415
Train batch: 160/415
Train batch: 180/415
Train batch: 200/415
Train batch: 220/415
Train batch: 240/415
Train batch: 260/415
Train batch: 280/415
Train batch: 300/415
Train batch: 320/415
Train batch: 340/415
Train batch: 360/415
Train batch: 380/415
Train batch: 400/415
Val batch: 0/178
Val batch: 20/178
Val batch: 40/178
Val batch: 60/178
Val batch: 80/178
Val batch: 100/178
Val batch: 120/178
Val batch: 140/178
Val batch: 160/178


 35%|███▌      | 7/20 [1:06:32<2:03:35, 570.43s/it]

Early stopping!
Antrenare terminată. Cel mai bun model salvat ca cartoon_unet_best.pth



