COCO_50.zip (50 samples for each of the two classes used in this research, "Zebra" and "Giraffe"):


https://drive.google.com/file/d/1_g1yMXFhHRK9EGr17tFs8a3XSMON34oZ/view

# Get Dataset

In [1]:
import os
from google.colab import drive

# Mount Google Drive at /content/drive; later cells read the dataset archive
# from, and save model weights to, paths under /content/drive/MyDrive.
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
# Download and unpack the standalone 7-Zip binary into /content/zip.
# `mkdir -p` and `wget -nc` make the cell safe to re-run: an existing directory
# is not an error and an already-downloaded archive is not fetched again.
%cd /content
!mkdir -p zip
%cd zip
!wget -nc https://www.7-zip.org/a/7z2301-linux-x64.tar.xz
!tar -xf 7z2301-linux-x64.tar.xz
%cd ..

/content
/content/zip
--2025-06-27 10:44:17--  https://www.7-zip.org/a/7z2301-linux-x64.tar.xz
Resolving www.7-zip.org (www.7-zip.org)... 49.12.202.237
Connecting to www.7-zip.org (www.7-zip.org)|49.12.202.237|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1527700 (1.5M) [application/octet-stream]
Saving to: ‘7z2301-linux-x64.tar.xz’


2025-06-27 10:44:18 (2.00 MB/s) - ‘7z2301-linux-x64.tar.xz’ saved [1527700/1527700]

/content


In [3]:
# Extract COCO_50.zip from Drive into /content using the 7-Zip binary unpacked above
# (`x` = extract with full paths, `-o` = output directory, `-y` = answer yes to all prompts).
!/content/zip/7zz x "/content/drive/MyDrive/COCO_50.zip" -o"/content" -y


7-Zip (z) 23.01 (x64) : Copyright (c) 1999-2023 Igor Pavlov : 2023-06-20
 64-bit locale=en_US.UTF-8 Threads:2 OPEN_MAX:1048576, ASM

Scanning the drive for archives:
  0M Scan /content/drive/MyDrive/                                 1 file, 10424099 bytes (10180 KiB)

Extracting archive: /content/drive/MyDrive/COCO_50.zip
--
Path = /content/drive/MyDrive/COCO_50.zip
Type = zip
Physical Size = 10424099

  0%    Everything is Ok

Folders: 13
Files: 400
Size:       11652329
Compressed: 10424099


# Dataset Class

In [18]:
import os
from glob import glob
from torch.utils.data import Dataset
from PIL import Image


class CustomDataset(Dataset):
    """Sketch classification dataset read from a per-class directory layout.

    For every class ``cls`` the constructor lists ``<data_root>/COCOSketch/<cls>/*.png``
    and, for each file found, records two samples: that file and the file with
    the same name under ``<data_root>/sketch/<cls>/``. The class of a sample is
    the name of the directory that contains it, and its integer label is the
    position of that name in ``classes``.

    Args:
        data_root: Directory containing the ``COCOSketch`` and ``sketch`` folders.
        classes: Class (sub-directory) names; defaults to ``["giraffe", "zebra"]``.
        transform: Optional callable applied to each opened PIL image.

    Attributes:
        sketch_paths: Paths of all samples, in a deterministic (sorted) order.
        prompts: Text prompt per class; kept for external use, not used for labels.
    """

    def __init__(self, data_root="/content/COCO_50", classes=None, transform=None):
        self.data_root = data_root
        # A fresh list per instance: a list literal as the default argument would
        # be shared (and mutable) across every instance created without `classes`.
        self.classes = ["giraffe", "zebra"] if classes is None else list(classes)
        self.transform = transform

        self.prompts = {"zebra": "a zebra in an empty savanna, dry grass plains, no trees, clear sky, minimalist nature, photorealistic", "giraffe" : "a giraffe in an open barren plain, flat dry land, no vegetation, soft sunlight, isolated wildlife, high detail"}

        self.sketch_paths = []

        for cls in self.classes:
            img_dir = os.path.join(data_root, "COCOSketch", cls)
            sketch_dir = os.path.join(data_root, "sketch", cls)

            # glob() returns files in arbitrary filesystem order; sorting makes the
            # sample indices stable, so index-based splits are reproducible.
            for img_path in sorted(glob(os.path.join(img_dir, "*.png"))):
                filename = os.path.basename(img_path)
                sketch_path = os.path.join(sketch_dir, filename)

                self.sketch_paths.append(img_path)
                self.sketch_paths.append(sketch_path)

    def __len__(self):
        """Return the number of recorded sample paths."""
        return len(self.sketch_paths)

    def get_sketch(self, idx):
        """Return the untransformed PIL image for sample ``idx``."""
        return Image.open(self.sketch_paths[idx])

    def __getitem__(self, idx):
        """Return ``(sketch, label)`` for sample ``idx``.

        ``sketch`` is the opened image, passed through ``transform`` when one was
        given; ``label`` is the index of the sample's parent-directory name in
        ``self.classes``.
        """
        sketch = Image.open(self.sketch_paths[idx])

        if self.transform:
            sketch = self.transform(sketch)

        # The class is encoded in the directory name. No prompt lookup is done
        # here: the prompt was never returned, and looking it up raised KeyError
        # for any class without an entry in `self.prompts`.
        class_name = os.path.basename(os.path.dirname(self.sketch_paths[idx]))
        label = self.classes.index(class_name)

        return sketch, label

from torchvision import transforms

# Preprocessing applied to every sketch before it reaches a model:
# replicate the grayscale image to 3 channels, resize to 224x224, convert to a
# tensor, then normalise each channel.
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),  # 3 identical channels
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],  # ImageNet stats
                         std=[0.229, 0.224, 0.225])
])


# Dataset over the default data root and classes, with the transform above.
dataset = CustomDataset(transform=transform)

In [19]:
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset

# Label of every sample, indexed like `dataset`. It is derived from the sample's
# parent-directory name, which is exactly how CustomDataset.__getitem__ computes
# it, so no image has to be opened and transformed just to read its class.
labels = [
    dataset.classes.index(os.path.basename(os.path.dirname(path)))
    for path in dataset.sketch_paths
]

# 80/20 split, stratified so both subsets keep the class balance; the fixed
# random_state makes the split repeatable for a given sample order.
train_indices, test_indices = train_test_split(
    list(range(len(labels))),
    test_size=0.2,
    stratify=labels,
    random_state=42
)

train_dataset = Subset(dataset, train_indices)
test_dataset = Subset(dataset, test_indices)

In [20]:
from collections import Counter

# Labels of the train and test subsets. `labels` (built for the stratified
# split) is indexed by dataset position, so it is reused here instead of
# loading and transforming every image a second time via `dataset[i]`.
# Assumes `labels` is still that list, i.e. the split cell ran just before.
train_labels = [labels[i] for i in train_dataset.indices]
test_labels = [labels[i] for i in test_dataset.indices]

train_label_counts = Counter(train_labels)
test_label_counts = Counter(test_labels)

def print_distribution(label_counts, total_samples):
    """Print one line per class with its sample count and percentage share.

    Args:
        label_counts: Mapping from class label to number of samples.
        total_samples: Denominator used for the percentage.
    """
    for cls in label_counts:
        n_samples = label_counts[cls]
        share = 100 * n_samples / total_samples
        print(f"Class {cls}: {n_samples} samples, {share:.2f}%")

# Report the class balance of each subset as a check on the stratified split.
print("Train set distribution:")
print_distribution(train_label_counts, len(train_labels))

print("Test set distribution:")
print_distribution(test_label_counts, len(test_labels))


Train set distribution:
Class 0: 80 samples, 50.00%
Class 1: 80 samples, 50.00%
Test set distribution:
Class 0: 20 samples, 50.00%
Class 1: 20 samples, 50.00%


# Classification

In [21]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms

class SketchClassifier(nn.Module):
    """Small convolutional classifier for 3-channel sketch images.

    ``features`` stacks four Conv(3x3, padding 1) -> BatchNorm -> ReLU stages
    with 32, 64, 128 and 256 output channels. The first three stages are each
    followed by a 2x2 max-pool (halving height and width); the last is followed
    by global average pooling to a 256x1x1 map. ``classifier`` flattens that map
    and applies Linear(256, 128) -> ReLU -> Dropout(0.5) -> Linear(128, num_classes).

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=2):
        super().__init__()

        def conv_stage(in_channels, out_channels):
            # One Conv -> BatchNorm -> ReLU stage; 3x3 kernel with padding 1
            # leaves the spatial size unchanged.
            return [
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            ]

        widths = [3, 32, 64, 128, 256]
        final_stage = len(widths) - 2

        layers = []
        for stage, (c_in, c_out) in enumerate(zip(widths, widths[1:])):
            layers.extend(conv_stage(c_in, c_out))
            if stage < final_stage:
                layers.append(nn.MaxPool2d(2))  # halves height and width
            else:
                layers.append(nn.AdaptiveAvgPool2d((1, 1)))  # 256x1x1
        self.features = nn.Sequential(*layers)

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for image batch ``x``."""
        return self.classifier(self.features(x))

In [22]:
import torch
def train_loop(dataloader, model, loss_fn, optimizer):
    """Train ``model`` for one pass over ``dataloader``.

    Each batch is moved to the device the model's parameters live on, so the
    function does not depend on a global ``device`` being defined elsewhere.
    Whenever the running number of processed samples is a multiple of 80, the
    current batch loss and the running accuracy are printed.

    Args:
        dataloader: Iterable yielding ``(inputs, targets)`` tensor pairs.
        model: Module to optimise; switched to training mode.
        loss_fn: Callable ``loss_fn(predictions, targets)`` returning a scalar loss.
        optimizer: Optimizer over the model's parameters.
    """
    model.train()

    # Run on whatever device the model was placed on (CPU for a parameterless model).
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    correct = 0
    total = 0

    for X, y in dataloader:
        X, y = X.to(run_device), y.to(run_device)

        pred = model(X)
        loss = loss_fn(pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        correct += (pred.argmax(1) == y).type(torch.float).sum().item()
        total += y.size(0)

        if total % 80 == 0:
            accuracy = correct / total
            print(f"Loss: {loss.item():.4f}, Accuracy: {accuracy * 100:.2f}%")

        # No per-batch torch.cuda.empty_cache(): the batch tensors are released
        # when rebound on the next iteration, and flushing the caching allocator
        # every step only slows training down.

In [23]:
def test_loop(dataloader, model, loss_fn):
    model.eval()
    test_loss, correct = 0, 0
    total = 0

    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)

            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
            total += y.size(0)

            del X, y, pred
            torch.cuda.empty_cache()

    accuracy = correct / total
    avg_loss = test_loss / len(dataloader)
    print(f"Test Error: Accuracy: {accuracy * 100:.2f}%, Avg loss: {avg_loss:.4f}")
    return avg_loss

In [None]:
# Train the custom SketchClassifier for a fixed number of epochs.
epochs = 50

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed weight initialisation, dropout and batch shuffling so the run is repeatable.
torch.manual_seed(42)
model = SketchClassifier(num_classes=2).to(device)

loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)
# `verbose=` is deprecated on PyTorch LR schedulers, so it is not passed;
# learning-rate reductions are reported explicitly in the loop below.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)


for epoch in range(epochs):
    print(f"\nEpoch {epoch+1}/{epochs}")
    train_loop(train_loader, model, loss_fn, optimizer)

    # NOTE(review): the held-out test split doubles as the validation signal for
    # the scheduler, so the reported test accuracy is not fully independent.
    val_loss = test_loop(test_loader, model, loss_fn)

    previous_lr = optimizer.param_groups[0]["lr"]
    scheduler.step(val_loss)
    current_lr = optimizer.param_groups[0]["lr"]
    if current_lr != previous_lr:
        print(f"Learning rate reduced from {previous_lr:.2e} to {current_lr:.2e}")


Epoch 1/50
Loss: 0.6144, Accuracy: 73.75%
Loss: 0.5777, Accuracy: 76.25%
Test Error: Accuracy: 80.00%, Avg loss: 0.6442

Epoch 2/50
Loss: 0.6174, Accuracy: 87.50%
Loss: 0.5576, Accuracy: 83.75%
Test Error: Accuracy: 85.00%, Avg loss: 0.5673

Epoch 3/50
Loss: 0.4948, Accuracy: 87.50%
Loss: 0.4450, Accuracy: 87.50%
Test Error: Accuracy: 90.00%, Avg loss: 0.5305

Epoch 4/50
Loss: 0.5062, Accuracy: 80.00%
Loss: 0.4239, Accuracy: 83.75%
Test Error: Accuracy: 92.50%, Avg loss: 0.5145

Epoch 5/50
Loss: 0.4260, Accuracy: 83.75%
Loss: 0.7640, Accuracy: 82.50%
Test Error: Accuracy: 70.00%, Avg loss: 0.5465

Epoch 6/50
Loss: 0.4396, Accuracy: 82.50%
Loss: 0.4131, Accuracy: 88.12%
Test Error: Accuracy: 92.50%, Avg loss: 0.4666

Epoch 7/50
Loss: 0.6552, Accuracy: 85.00%
Loss: 0.3832, Accuracy: 88.75%
Test Error: Accuracy: 90.00%, Avg loss: 0.4842

Epoch 8/50
Loss: 0.4101, Accuracy: 92.50%
Loss: 0.3739, Accuracy: 91.88%
Test Error: Accuracy: 90.00%, Avg loss: 0.4255

Epoch 9/50
Loss: 0.3226, Accura

In [30]:
import torch

# Persist only the weights (state_dict) of the current `model` to Google Drive.
# NOTE(review): this saves whatever `model` is bound to when the cell runs. A
# later cell rebinds `model` to a ResNet18, so run this cell right after the
# SketchClassifier training cell — confirm the file holds the intended weights.
save_path = "/content/drive/MyDrive/SketchClassifier.pth"
torch.save(model.state_dict(), save_path)
print(f"Model saved to {save_path}")


# To reload the saved weights later:
# model = SketchClassifier(num_classes=2)
# model.load_state_dict(torch.load("/content/drive/MyDrive/SketchClassifier.pth"))
# model.to(device)
# model.eval()

Model saved to /content/drive/MyDrive/SketchClassifier.pth


In [27]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms, models
import torch.optim as optim
from tqdm import tqdm

# Transfer learning baseline: ImageNet-pretrained ResNet18 with a frozen
# backbone and a new 2-class linear head trained on the sketch dataset.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=8)

# `pretrained=True` is deprecated in torchvision; the weights enum below selects
# the same ImageNet checkpoint explicitly.
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

# Freeze the backbone; only the replacement head created below is trainable.
for param in model.parameters():
    param.requires_grad = False

# Seed the head initialisation and batch shuffling so the run is repeatable.
torch.manual_seed(42)
model.fc = nn.Linear(model.fc.in_features, 2)
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.fc.parameters(), lr=0.0001, weight_decay=1e-5)  # L2 regularization
# `verbose=` is deprecated on PyTorch LR schedulers, so it is not passed;
# learning-rate reductions are reported explicitly in the loop below.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                       factor=0.5, patience=5)

num_epochs = 50
for epoch in range(num_epochs):
    # NOTE(review): train() also puts the frozen backbone's BatchNorm layers in
    # training mode, so their running statistics keep updating even though the
    # weights are frozen — confirm this is intended.
    model.train()
    total_loss = 0.0

    # Loop variables are named *_batch so they do not overwrite the notebook-level
    # `labels` list that the split/statistics cells rely on.
    for sketch_batch, label_batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        sketch_batch, label_batch = sketch_batch.to(device), label_batch.to(device)

        optimizer.zero_grad()
        outputs = model(sketch_batch)
        loss = criterion(outputs, label_batch)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")

    # The scheduler is driven by the average training loss of the epoch.
    previous_lr = optimizer.param_groups[0]["lr"]
    scheduler.step(avg_loss)
    current_lr = optimizer.param_groups[0]["lr"]
    if current_lr != previous_lr:
        print(f"Learning rate reduced from {previous_lr:.2e} to {current_lr:.2e}")

# Final evaluation on the held-out test split.
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for sketch_batch, label_batch in test_loader:
        sketch_batch, label_batch = sketch_batch.to(device), label_batch.to(device)
        outputs = model(sketch_batch)
        _, preds = torch.max(outputs, 1)
        total += label_batch.size(0)
        correct += (preds == label_batch).sum().item()

print(f"Test Accuracy: {100 * correct / total:.2f}%")

Epoch 1/50: 100%|██████████| 20/20 [00:00<00:00, 23.34it/s]


Epoch [1/50], Loss: 0.6882


Epoch 2/50: 100%|██████████| 20/20 [00:01<00:00, 18.59it/s]


Epoch [2/50], Loss: 0.6482


Epoch 3/50: 100%|██████████| 20/20 [00:01<00:00, 19.12it/s]


Epoch [3/50], Loss: 0.6220


Epoch 4/50: 100%|██████████| 20/20 [00:01<00:00, 19.00it/s]


Epoch [4/50], Loss: 0.5871


Epoch 5/50: 100%|██████████| 20/20 [00:00<00:00, 24.05it/s]


Epoch [5/50], Loss: 0.5409


Epoch 6/50: 100%|██████████| 20/20 [00:00<00:00, 24.89it/s]


Epoch [6/50], Loss: 0.5197


Epoch 7/50: 100%|██████████| 20/20 [00:00<00:00, 25.10it/s]


Epoch [7/50], Loss: 0.5166


Epoch 8/50: 100%|██████████| 20/20 [00:00<00:00, 24.85it/s]


Epoch [8/50], Loss: 0.4529


Epoch 9/50: 100%|██████████| 20/20 [00:00<00:00, 24.91it/s]


Epoch [9/50], Loss: 0.4458


Epoch 10/50: 100%|██████████| 20/20 [00:00<00:00, 24.35it/s]


Epoch [10/50], Loss: 0.4450


Epoch 11/50: 100%|██████████| 20/20 [00:00<00:00, 24.61it/s]


Epoch [11/50], Loss: 0.4193


Epoch 12/50: 100%|██████████| 20/20 [00:00<00:00, 24.71it/s]


Epoch [12/50], Loss: 0.4169


Epoch 13/50: 100%|██████████| 20/20 [00:00<00:00, 24.93it/s]


Epoch [13/50], Loss: 0.4340


Epoch 14/50: 100%|██████████| 20/20 [00:00<00:00, 25.22it/s]


Epoch [14/50], Loss: 0.3664


Epoch 15/50: 100%|██████████| 20/20 [00:00<00:00, 24.66it/s]


Epoch [15/50], Loss: 0.3775


Epoch 16/50: 100%|██████████| 20/20 [00:00<00:00, 24.83it/s]


Epoch [16/50], Loss: 0.3659


Epoch 17/50: 100%|██████████| 20/20 [00:00<00:00, 21.25it/s]


Epoch [17/50], Loss: 0.3811


Epoch 18/50: 100%|██████████| 20/20 [00:01<00:00, 19.12it/s]


Epoch [18/50], Loss: 0.3268


Epoch 19/50: 100%|██████████| 20/20 [00:01<00:00, 18.82it/s]


Epoch [19/50], Loss: 0.3303


Epoch 20/50: 100%|██████████| 20/20 [00:00<00:00, 21.59it/s]


Epoch [20/50], Loss: 0.2915


Epoch 21/50: 100%|██████████| 20/20 [00:00<00:00, 25.05it/s]


Epoch [21/50], Loss: 0.3293


Epoch 22/50: 100%|██████████| 20/20 [00:00<00:00, 24.34it/s]


Epoch [22/50], Loss: 0.3123


Epoch 23/50: 100%|██████████| 20/20 [00:00<00:00, 24.59it/s]


Epoch [23/50], Loss: 0.2878


Epoch 24/50: 100%|██████████| 20/20 [00:00<00:00, 24.51it/s]


Epoch [24/50], Loss: 0.2656


Epoch 25/50: 100%|██████████| 20/20 [00:00<00:00, 24.51it/s]


Epoch [25/50], Loss: 0.3052


Epoch 26/50: 100%|██████████| 20/20 [00:00<00:00, 24.95it/s]


Epoch [26/50], Loss: 0.2678


Epoch 27/50: 100%|██████████| 20/20 [00:00<00:00, 24.69it/s]


Epoch [27/50], Loss: 0.2717


Epoch 28/50: 100%|██████████| 20/20 [00:00<00:00, 24.63it/s]


Epoch [28/50], Loss: 0.2362


Epoch 29/50: 100%|██████████| 20/20 [00:00<00:00, 24.50it/s]


Epoch [29/50], Loss: 0.2552


Epoch 30/50: 100%|██████████| 20/20 [00:00<00:00, 24.72it/s]


Epoch [30/50], Loss: 0.2446


Epoch 31/50: 100%|██████████| 20/20 [00:00<00:00, 24.79it/s]


Epoch [31/50], Loss: 0.2515


Epoch 32/50: 100%|██████████| 20/20 [00:00<00:00, 22.71it/s]


Epoch [32/50], Loss: 0.2448


Epoch 33/50: 100%|██████████| 20/20 [00:01<00:00, 18.64it/s]


Epoch [33/50], Loss: 0.2500


Epoch 34/50: 100%|██████████| 20/20 [00:01<00:00, 18.62it/s]


Epoch [34/50], Loss: 0.2731


Epoch 35/50: 100%|██████████| 20/20 [00:00<00:00, 20.12it/s]


Epoch [35/50], Loss: 0.1994


Epoch 36/50: 100%|██████████| 20/20 [00:00<00:00, 24.19it/s]


Epoch [36/50], Loss: 0.2525


Epoch 37/50: 100%|██████████| 20/20 [00:00<00:00, 24.73it/s]


Epoch [37/50], Loss: 0.2431


Epoch 38/50: 100%|██████████| 20/20 [00:00<00:00, 24.83it/s]


Epoch [38/50], Loss: 0.2658


Epoch 39/50: 100%|██████████| 20/20 [00:00<00:00, 24.74it/s]


Epoch [39/50], Loss: 0.2178


Epoch 40/50: 100%|██████████| 20/20 [00:00<00:00, 24.82it/s]


Epoch [40/50], Loss: 0.2073


Epoch 41/50: 100%|██████████| 20/20 [00:00<00:00, 24.53it/s]


Epoch [41/50], Loss: 0.2245


Epoch 42/50: 100%|██████████| 20/20 [00:00<00:00, 25.02it/s]


Epoch [42/50], Loss: 0.1822


Epoch 43/50: 100%|██████████| 20/20 [00:00<00:00, 24.78it/s]


Epoch [43/50], Loss: 0.2267


Epoch 44/50: 100%|██████████| 20/20 [00:00<00:00, 24.29it/s]


Epoch [44/50], Loss: 0.2412


Epoch 45/50: 100%|██████████| 20/20 [00:00<00:00, 24.13it/s]


Epoch [45/50], Loss: 0.1782


Epoch 46/50: 100%|██████████| 20/20 [00:00<00:00, 24.35it/s]


Epoch [46/50], Loss: 0.2230


Epoch 47/50: 100%|██████████| 20/20 [00:00<00:00, 24.09it/s]


Epoch [47/50], Loss: 0.2170


Epoch 48/50: 100%|██████████| 20/20 [00:01<00:00, 18.85it/s]


Epoch [48/50], Loss: 0.2090


Epoch 49/50: 100%|██████████| 20/20 [00:01<00:00, 18.61it/s]


Epoch [49/50], Loss: 0.2362


Epoch 50/50: 100%|██████████| 20/20 [00:01<00:00, 19.37it/s]


Epoch [50/50], Loss: 0.2222
Test Accuracy: 100.00%


In [28]:
# Save the model weights only (state_dict) to Google Drive; `model` here is
# whatever the name is currently bound to, so run this right after the
# ResNet18 training cell.
save_path = "/content/drive/MyDrive/ResNet18_SketchClassifier.pth"
torch.save(model.state_dict(), save_path)
print(f"Model saved to {save_path}")

# To reload later, rebuild the same architecture (ResNet18 with a 2-class head)
# before loading the weights:
# model = models.resnet18(pretrained=True)
# for param in model.parameters():
#     param.requires_grad = False
# model.fc = nn.Linear(model.fc.in_features, 2)

# model.load_state_dict(torch.load("/content/drive/MyDrive/ResNet18_SketchClassifier.pth"))
# model.to(device)
# model.eval()

Model saved to /content/drive/MyDrive/ResNet18_SketchClassifier.pth
