## Load data



In [None]:
# extract data in processed_icons.zip file from Google Drive
import zipfile
from google.colab import drive
import os

# Where the archive is stored on Drive and where its contents are unpacked.
uploaded_zip_path = '/content/drive/My Drive/RM/processed_icons.zip'
extracted_dir = '/content/extracted_icons'

# The Drive mount must exist before the archive path can be read.
drive.mount('/content/drive')

os.makedirs(extracted_dir, exist_ok=True)
target_is_empty = len(os.listdir(extracted_dir)) == 0
if target_is_empty:
    # Unpack only into an empty folder so re-running this cell skips the
    # extraction when files are already present.
    with zipfile.ZipFile(uploaded_zip_path, mode='r') as archive:
        archive.extractall(path=extracted_dir)
print("extract data complete")

In [None]:
# Check number of icon images
import os

extracted_icons_dir = "/content/extracted_icons"
image_suffixes = ('.png', '.jpg', '.jpeg')

# Walk the whole tree and tally every file whose lower-cased name ends with
# one of the image suffixes.
png_jpg_count = sum(
    1
    for _dirpath, _dirnames, filenames in os.walk(extracted_icons_dir)
    for filename in filenames
    if filename.lower().endswith(image_suffixes)
)

print(f"Number of .png or .jpg files in 'extracted_icons': {png_jpg_count}")


Number of .png or .jpg files in 'extracted_icons': 1296492


## Train Model

In [None]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image

# Define dataset class
class IconDataset(Dataset):
    """Dataset of (sketch, color icon) image pairs.

    The directory layout scanned is ``root_dir/<class>/<subfolder>/``. Inside
    each subfolder, a file whose name contains ``"sketch_icon"`` is paired
    with a file whose name contains ``"color_icon"``. Subfolders that lack
    either file are skipped.

    Args:
        root_dir: Directory containing the class folders.
        transform: Optional callable applied to the RGB color icon.
        sketch_transform: Optional callable applied to the grayscale sketch.
    """

    def __init__(self, root_dir, transform=None, sketch_transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.sketch_transform = sketch_transform
        self.filepairs = []

        # Find paired images (sketch and color icons)
        print("Finding image pairs...")
        # os.listdir() returns entries in arbitrary order. Sorting every level
        # keeps the index -> pair mapping identical from run to run, which any
        # seeded split or resumed training relies on.
        for cls in sorted(os.listdir(root_dir)):
            class_dir = os.path.join(root_dir, cls)
            if not os.path.isdir(class_dir):
                continue
            for subfolder in sorted(os.listdir(class_dir)):
                subfolder_dir = os.path.join(class_dir, subfolder)
                if not os.path.isdir(subfolder_dir):
                    continue
                pair = self._find_pair(subfolder_dir)
                if pair is not None:
                    self.filepairs.append(pair)

        print(f"Found {len(self.filepairs)} valid image pairs.")

    @staticmethod
    def _find_pair(subfolder_dir):
        """Return ``(sketch_path, color_icon_path)`` for one subfolder, or None.

        When several files match the same marker, the last one in sorted
        order is used.
        """
        sketch_path = None
        color_icon_path = None
        for file in sorted(os.listdir(subfolder_dir)):
            if "sketch_icon" in file:
                sketch_path = os.path.join(subfolder_dir, file)
            elif "color_icon" in file:
                color_icon_path = os.path.join(subfolder_dir, file)

        if sketch_path and color_icon_path:
            return sketch_path, color_icon_path
        return None

    def __len__(self):
        """Number of complete sketch/color pairs found under ``root_dir``."""
        return len(self.filepairs)

    def __getitem__(self, idx):
        """Load pair ``idx`` and return ``(sketch_icon, color_icon, label)``.

        The sketch is converted to mode ``'L'`` and the color icon to
        ``'RGB'`` before the optional transforms are applied. ``label`` is
        always the dummy value 0.
        """
        sketch_icon_path, color_icon_path = self.filepairs[idx]

        # Open inside ``with`` so the underlying file is closed as soon as the
        # converted copy has been produced.
        with Image.open(sketch_icon_path) as sketch_file:
            sketch_icon = sketch_file.convert('L')  # Grayscale for sketches
        with Image.open(color_icon_path) as color_file:
            color_icon = color_file.convert('RGB')  # RGB for color icons

        # Apply transformations
        if self.sketch_transform:
            sketch_icon = self.sketch_transform(sketch_icon)
        if self.transform:
            color_icon = self.transform(color_icon)

        # Using a dummy label of 0
        label = 0
        return sketch_icon, color_icon, label

def _build_transform(num_channels):
    """Resize to 512x512, convert to a tensor, then normalize every channel
    with mean 0.5 and std 0.5."""
    return transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * num_channels, std=[0.5] * num_channels),
    ])


# Color icons are RGB (three channels); sketches are grayscale (one channel).
transform = _build_transform(3)
sketch_transform = _build_transform(1)

# Index the sketch/color pairs found under the extracted archive folder.
data_dir = "/content/extracted_icons/processed_icons"
icon_dataset = IconDataset(
    root_dir=data_dir,
    transform=transform,
    sketch_transform=sketch_transform,
)

# Split the dataset into training, validation, and testing sets
def split_dataset(dataset, train_split=0.8, val_split=0.1, seed=42):
    """Randomly split ``dataset`` into train, validation and test subsets.

    Args:
        dataset: Any object with ``__len__`` accepted by
            ``torch.utils.data.random_split``.
        train_split: Fraction of samples placed in the training subset.
        val_split: Fraction of samples placed in the validation subset.
        seed: Seed for the permutation. With a fixed seed the same samples land
            in the same subset on every run, so resuming training from a
            checkpoint after a restart does not move held-out samples into the
            training subset. Pass ``None`` to use the global RNG instead.

    Returns:
        A list ``[train_subset, val_subset, test_subset]``; the test subset
        receives whatever remains after the first two.

    Raises:
        ValueError: If the requested fractions leave a negative test size.
    """
    total = len(dataset)
    train_size = int(total * train_split)
    val_size = int(total * val_split)
    test_size = total - train_size - val_size
    if test_size < 0:
        raise ValueError(
            f"train_split + val_split must not exceed 1 (got {train_split} + {val_split})"
        )

    lengths = [train_size, val_size, test_size]
    if seed is None:
        return torch.utils.data.random_split(dataset, lengths)
    rng = torch.Generator().manual_seed(seed)
    return torch.utils.data.random_split(dataset, lengths, generator=rng)

# Partition the icon pairs into the three subsets consumed below.
train_dataset, val_dataset, test_dataset = split_dataset(icon_dataset)

# One loader per subset; shuffling is enabled for the training loader only.
batch_size = 16
train_loader, val_loader, test_loader = (
    DataLoader(subset, batch_size=batch_size, shuffle=reshuffle)
    for subset, reshuffle in (
        (train_dataset, True),
        (val_dataset, False),
        (test_dataset, False),
    )
)

# Define Generator
class Generator(nn.Module):
    """Convolutional encoder-decoder that maps a 3-channel image to a
    3-channel image.

    The encoder is four 4x4 stride-2 convolutions (3 -> 64 -> 128 -> 256 ->
    512 channels). The decoder is four 4x4 stride-2 transposed convolutions
    back down to 3 channels, ending in ``Tanh``.
    """

    def __init__(self):
        super().__init__()
        # Every (transposed) convolution shares the same geometry.
        geometry = dict(kernel_size=4, stride=2, padding=1)

        # Encoder: the first stage has no batch norm, the remaining three do.
        down = [nn.Conv2d(3, 64, **geometry), nn.ReLU(inplace=True)]
        for c_in, c_out in ((64, 128), (128, 256), (256, 512)):
            down += [
                nn.Conv2d(c_in, c_out, **geometry),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
        self.encoder = nn.Sequential(*down)

        # Decoder: three normalized up-sampling stages, then the output stage.
        up = []
        for c_in, c_out in ((512, 256), (256, 128), (128, 64)):
            up += [
                nn.ConvTranspose2d(c_in, c_out, **geometry),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
        up += [nn.ConvTranspose2d(64, 3, **geometry), nn.Tanh()]
        self.decoder = nn.Sequential(*up)

    def forward(self, x):
        """Encode ``x`` and decode it again; the result passes through Tanh."""
        latent = self.encoder(x)
        return self.decoder(latent)

# Define Discriminator
class Discriminator(nn.Module):
    """Convolutional classifier over a 6-channel input.

    Four 4x4 stride-2 convolutions (6 -> 64 -> 128 -> 256 -> 512 channels)
    with LeakyReLU(0.2) are followed by a 4x4 convolution to one channel. The
    resulting map is averaged over its spatial dimensions and passed through a
    sigmoid.
    """

    def __init__(self):
        super().__init__()
        geometry = dict(kernel_size=4, stride=2, padding=1)

        # The first stage has no batch norm; the next three stages do.
        stages = [nn.Conv2d(6, 64, **geometry), nn.LeakyReLU(0.2, inplace=True)]
        for c_in, c_out in ((64, 128), (128, 256), (256, 512)):
            stages.extend([
                nn.Conv2d(c_in, c_out, **geometry),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ])
        # Output a single channel
        stages.append(nn.Conv2d(512, 1, kernel_size=4))

        self.model = nn.Sequential(*stages)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return one sigmoid score per sample, shaped ``(batch, 1)``."""
        score_map = self.model(x)
        # Global average pooling over the two spatial dimensions
        pooled = score_map.mean(dim=[2, 3])
        return self.sigmoid(pooled)

# Run on the GPU when CUDA is available, otherwise stay on the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(f"Using device: {device}")

# Both networks are created on the selected device.
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# Binary cross-entropy on the discriminator's sigmoid output.
criterion = nn.BCELoss()

# The two optimizers share the same Adam hyper-parameters.
adam_settings = dict(lr=0.0002, betas=(0.5, 0.999))
optimizer_G = torch.optim.Adam(generator.parameters(), **adam_settings)
optimizer_D = torch.optim.Adam(discriminator.parameters(), **adam_settings)

# Checkpoints are written to this Drive folder; create it if it is missing.
save_dir = '/content/drive/My Drive/RM/saved_model'
os.makedirs(save_dir, exist_ok=True)

# Load checkpoints if available
def load_checkpoint():
    """Restore generator and discriminator weights from ``save_dir`` if present.

    Looks for ``generator_checkpoint.pth`` and ``discriminator_checkpoint.pth``
    and loads each one that exists into the corresponding notebook-level
    model. A missing file is skipped without a message, leaving that model
    with its current weights.
    """
    targets = (
        ("Generator", generator, 'generator_checkpoint.pth'),
        ("Discriminator", discriminator, 'discriminator_checkpoint.pth'),
    )
    for label, model, filename in targets:
        ckpt_path = os.path.join(save_dir, filename)
        if not os.path.exists(ckpt_path):
            continue
        # map_location places the tensors on the current device. Without it, a
        # checkpoint saved from a CUDA session cannot be restored on a
        # CPU-only runtime.
        state_dict = torch.load(ckpt_path, map_location=device)
        model.load_state_dict(state_dict)
        print(f"Loaded {label} checkpoint.")

# Load the checkpoint of the model to resume training
load_checkpoint()


def save_checkpoint():
    """Write the current generator and discriminator weights to ``save_dir``."""
    torch.save(generator.state_dict(), os.path.join(save_dir, 'generator_checkpoint.pth'))
    torch.save(discriminator.state_dict(), os.path.join(save_dir, 'discriminator_checkpoint.pth'))


# Training loop with checkpoint saving
epochs = 10
# One epoch is tens of thousands of steps, so also save part-way through an
# epoch; an interrupted session then loses at most this many steps of progress.
checkpoint_every = 1000
print("Starting training...")
for epoch in range(epochs):
    for i, (sketches, real_images, _) in enumerate(train_loader):
        sketches, real_images = sketches.to(device), real_images.to(device)
        # The last batch can be smaller, so size the labels per batch. A
        # separate name avoids overwriting the notebook-level `batch_size`
        # that was used to build the DataLoaders.
        current_batch = sketches.size(0)
        real_labels = torch.ones(current_batch, 1, device=device)
        fake_labels = torch.zeros(current_batch, 1, device=device)
        sketches = sketches.repeat(1, 3, 1, 1)  # Convert 1-channel sketches to 3-channel

        # One generator forward pass serves both update steps below.
        fake_images = generator(sketches)

        # Train Discriminator
        optimizer_D.zero_grad()
        real_outputs = discriminator(torch.cat((sketches, real_images), dim=1))  # Real input
        # detach(): the discriminator update must not back-propagate into the
        # generator; doing so only costs time and memory.
        fake_outputs = discriminator(torch.cat((sketches, fake_images.detach()), dim=1))  # Fake input
        d_loss = criterion(real_outputs, real_labels) + criterion(fake_outputs, fake_labels)
        d_loss.backward()
        optimizer_D.step()

        # Train Generator: score the same fake batch with the updated
        # discriminator, this time keeping the graph back to the generator.
        optimizer_G.zero_grad()
        outputs = discriminator(torch.cat((sketches, fake_images), dim=1))
        g_loss = criterion(outputs, real_labels)
        g_loss.backward()
        optimizer_G.step()

        if (i + 1) % 10 == 0:
            print(f"Epoch [{epoch+1}/{epochs}], Step [{i+1}/{len(train_loader)}], D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}")

        if (i + 1) % checkpoint_every == 0:
            save_checkpoint()

    # Save checkpoints
    save_checkpoint()
    print(f"Checkpoints saved for epoch {epoch+1}.")

print("Training complete.")

Finding image pairs...
Found 648246 valid image pairs.
Using device: cuda
Starting training...
Epoch [1/10], Step [10/32413], D Loss: 1.1385, G Loss: 1.1567
Epoch [1/10], Step [20/32413], D Loss: 0.7459, G Loss: 2.5479
Epoch [1/10], Step [30/32413], D Loss: 0.2542, G Loss: 3.7096
Epoch [1/10], Step [40/32413], D Loss: 1.2030, G Loss: 2.6620
Epoch [1/10], Step [50/32413], D Loss: 0.7146, G Loss: 1.6203
Epoch [1/10], Step [60/32413], D Loss: 0.9042, G Loss: 1.8591
Epoch [1/10], Step [70/32413], D Loss: 1.2598, G Loss: 0.2452
Epoch [1/10], Step [80/32413], D Loss: 1.0510, G Loss: 0.9623
Epoch [1/10], Step [90/32413], D Loss: 1.0960, G Loss: 4.1219
Epoch [1/10], Step [100/32413], D Loss: 1.4111, G Loss: 0.5492
Epoch [1/10], Step [110/32413], D Loss: 1.0753, G Loss: 1.7104
Epoch [1/10], Step [120/32413], D Loss: 0.8812, G Loss: 1.8177
Epoch [1/10], Step [130/32413], D Loss: 0.9675, G Loss: 1.0166
Epoch [1/10], Step [140/32413], D Loss: 1.1612, G Loss: 0.7356
Epoch [1/10], Step [150/32413], 

## Test model

In [None]:
import torch
import torch.nn as nn
import os

class Generator(nn.Module):
    """Encoder-decoder network taking a 3-channel image and producing a
    3-channel image.

    Layer order is what fixes the state-dict keys (``encoder.<i>`` /
    ``decoder.<i>``) used when weights are loaded into this class.
    """

    def __init__(self):
        super().__init__()
        conv_args = dict(kernel_size=4, stride=2, padding=1)

        # Down-sampling path: 3 -> 64 -> 128 -> 256 -> 512 channels.
        encoder_layers = [nn.Conv2d(3, 64, **conv_args), nn.ReLU(inplace=True)]
        for in_ch, out_ch in ((64, 128), (128, 256), (256, 512)):
            encoder_layers.append(nn.Conv2d(in_ch, out_ch, **conv_args))
            encoder_layers.append(nn.BatchNorm2d(out_ch))
            encoder_layers.append(nn.ReLU(inplace=True))
        self.encoder = nn.Sequential(*encoder_layers)

        # Up-sampling path: 512 -> 256 -> 128 -> 64 -> 3 channels, Tanh output.
        decoder_layers = []
        for in_ch, out_ch in ((512, 256), (256, 128), (128, 64)):
            decoder_layers.append(nn.ConvTranspose2d(in_ch, out_ch, **conv_args))
            decoder_layers.append(nn.BatchNorm2d(out_ch))
            decoder_layers.append(nn.ReLU(inplace=True))
        decoder_layers.append(nn.ConvTranspose2d(64, 3, **conv_args))
        decoder_layers.append(nn.Tanh())
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Run ``x`` through the encoder, then the decoder."""
        return self.decoder(self.encoder(x))

#Discriminator
class Discriminator(nn.Module):
    """Scores a 6-channel input with a single sigmoid value per sample.

    Layer order is what fixes the state-dict keys (``model.<i>``) used when
    weights are loaded into this class.
    """

    def __init__(self):
        super().__init__()
        conv_args = dict(kernel_size=4, stride=2, padding=1)

        # 6 -> 64 -> 128 -> 256 -> 512 channels with LeakyReLU(0.2); batch
        # norm is applied on every stage except the first.
        blocks = [nn.Conv2d(6, 64, **conv_args), nn.LeakyReLU(0.2, inplace=True)]
        for in_ch, out_ch in ((64, 128), (128, 256), (256, 512)):
            blocks.append(nn.Conv2d(in_ch, out_ch, **conv_args))
            blocks.append(nn.BatchNorm2d(out_ch))
            blocks.append(nn.LeakyReLU(0.2, inplace=True))
        # Output a single channel
        blocks.append(nn.Conv2d(512, 1, kernel_size=4))

        self.model = nn.Sequential(*blocks)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Average the one-channel score map over height and width, then
        squash it with a sigmoid; the result has shape ``(batch, 1)``."""
        scores = self.model(x)
        averaged = scores.mean(dim=[2, 3])  # Global Average Pooling
        return self.sigmoid(averaged)

# Use CUDA when it is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Freshly initialized networks that receive the saved weights below.
generator = Generator().to(device)
discriminator = Discriminator().to(device)

save_dir = "/content/"
pic_dir = "/content/181107.png_output_image.png"
epoch = 1

# Weight files are named after the epoch they were saved at.
generator_weights = os.path.join(save_dir, f"Icon_generator_epoch_{epoch}.pth")
discriminator_weights = os.path.join(save_dir, f"Icon_discriminator_epoch_{epoch}.pth")

# Load the saved state dictionaries, generator first.
for network, weights_path in (
    (generator, generator_weights),
    (discriminator, discriminator_weights),
):
    with open(weights_path, 'rb') as f:
        state_dict = torch.load(f, map_location=device)
    network.load_state_dict(state_dict)

# Switch both models to evaluation mode; the final expression is left bare so
# the cell still displays the discriminator's structure.
generator.eval()
discriminator.eval()

Using device: cuda


  state_dict = torch.load(f, map_location=device)
  state_dict = torch.load(f, map_location=device)


Discriminator(
  (model): Sequential(
    (0): Conv2d(6, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1))
  )
  (sigmoid): Sigmoid()
)

In [None]:
# Show what is in /content, the directory the neighbouring cells read the
# saved weights and the input image from.
!ls /content/

181107.png_output_image.png	Icon_generator_epoch_1.pth
Icon_discriminator_epoch_1.pth	sample_data


In [None]:
# prompt: display result from this model by using 181107.png_output_image.png as input image.

import cv2
import numpy as np
from PIL import Image
import torch
import torchvision.transforms as transforms
import os

# Assuming the model and image are in the /content directory
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the saved generator model
generator = Generator().to(device)
save_dir = "/content/"
epoch = 1
with open(os.path.join(save_dir, f"Icon_generator_epoch_{epoch}.pth"), 'rb') as f:
    state_dict = torch.load(f, map_location=device)
generator.load_state_dict(state_dict)
generator.eval()  # use the stored batch-norm statistics during inference

# Load and preprocess the input image
img_path = "/content/181107.png_output_image.png"
# img_path = "/content/acorn-line-drawing-9.jpg"
try:
    # Match the training preprocessing: sketches were loaded as single-channel
    # grayscale ('L') and that channel was then replicated three times.
    # Converting L -> RGB performs the same replication, so a colored or
    # alpha-channel input reaches the generator in the form it was trained on.
    with Image.open(img_path) as source_image:
        img = source_image.convert("L").convert("RGB")
    transform = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    img_tensor = transform(img).unsqueeze(0).to(device)  # Add batch dimension
    with torch.no_grad():
        output_tensor = generator(img_tensor)

    # Post-process the output tensor: drop the batch axis and reorder the
    # array from (channels, height, width) to (height, width, channels).
    output_image = output_tensor.squeeze(0).cpu().detach().numpy().transpose(1, 2, 0)
    # The generator ends in Tanh, so rescale from [-1, 1] to [0, 255].
    output_image = ((output_image + 1) * 127.5).astype(np.uint8)
    output_image = Image.fromarray(output_image)
    output_image.save("colored_icon_from_saved_model.png")
    print("Colored icon saved as 'colored_icon_from_saved_model.png'")

except FileNotFoundError:
    print(f"Error: Image file not found at {img_path}")
except Exception as e:
    print(f"An error occurred: {e}")

Colored icon saved as 'colored_icon_from_saved_model.png'


  state_dict = torch.load(f, map_location=device)
