In [7]:
import os
import torch
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

class SketchyCOCODataset(Dataset):
    def __init__(self, sketch_dir, gt_dir, transform=None):
        self.sketch_dir = sketch_dir
        self.gt_dir = gt_dir
        self.transform = transform
        
        # Recursively find all image files in sketch and GT directories
        sketch_files = self._get_all_files(sketch_dir)
        gt_files = self._get_all_files(gt_dir)

        # Match files by their basenames
        sketch_files_dict = {os.path.basename(f): f for f in sketch_files}
        gt_files_dict = {os.path.basename(f): f for f in gt_files}

        # Find common files
        common_files = set(sketch_files_dict.keys()) & set(gt_files_dict.keys())

        # Create aligned lists of file paths
        self.sketch_files = [sketch_files_dict[f] for f in sorted(common_files)]
        self.gt_files = [gt_files_dict[f] for f in sorted(common_files)]

        assert len(self.sketch_files) == len(self.gt_files), "Mismatch between sketches and GT images"
        
    def _get_all_files(self, root_dir):
        """Recursively get all image file paths from a directory."""
        image_files = []
        for subdir, _, files in os.walk(root_dir):
            for file in files:
                if file.endswith(('.png', '.jpg', '.jpeg')):  # Include valid image formats
                    image_files.append(os.path.join(subdir, file))
        return sorted(image_files)  # Sort to ensure matching order

    def __len__(self):
        return len(self.sketch_files)

    def __getitem__(self, idx):
        # Get relative paths
        sketch_path = self.sketch_files[idx]
        gt_path = self.gt_files[idx]
        
        sketch = Image.open(sketch_path).convert("RGB")
        gt_image = Image.open(gt_path).convert("RGB")
        
        if self.transform:
            sketch = self.transform(sketch)
            gt_image = self.transform(gt_image)
        
        return sketch, gt_image

from torchvision import transforms

# Paths to SketchyCOCO folders (update these paths as needed)
train_sketch_dir = "./Object/Sketch/train"
train_gt_dir = "./Object/GT/train"
val_sketch_dir = "./Object/Sketch/val"
val_gt_dir = "./Object/GT/val"

# Define transformations for images (resize, normalize, etc.)
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to CLIP's input size
    transforms.ToTensor()
])

# Create datasets and dataloaders
train_dataset = SketchyCOCODataset(train_sketch_dir, train_gt_dir, transform=transform)
val_dataset = SketchyCOCODataset(val_sketch_dir, val_gt_dir, transform=transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32)

print(f"Loaded {len(train_dataset)} training samples and {len(val_dataset)} validation samples.")

Loaded 22718 training samples and 1340 validation samples.


In [9]:
# Test DataLoader
for i, (sketches, images) in enumerate(train_loader):
    print(f"Batch {i+1}: Sketch shape: {sketches.shape}, Image shape: {images.shape}")
    if i == 2:  # Test only a few batches
        break


Batch 1: Sketch shape: torch.Size([32, 3, 224, 224]), Image shape: torch.Size([32, 3, 224, 224])
Batch 2: Sketch shape: torch.Size([32, 3, 224, 224]), Image shape: torch.Size([32, 3, 224, 224])
Batch 3: Sketch shape: torch.Size([32, 3, 224, 224]), Image shape: torch.Size([32, 3, 224, 224])


In [10]:
print(f"Number of sketch files: {len(train_dataset.sketch_files)}")
print(f"Number of GT files: {len(train_dataset.gt_files)}")

Number of sketch files: 22718
Number of GT files: 22718


In [11]:
# Load pre-trained CLIP model and processor
model_name = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(model_name)
processor = CLIPProcessor.from_pretrained(model_name)

# Move model to GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Define optimizer and loss function
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-6)
loss_fn = torch.nn.CrossEntropyLoss()

# Training loop
# Training loop with debugging prints
for epoch in range(10):  # Number of epochs
    print(f"Starting Epoch {epoch+1}...")
    model.train()
    total_loss = 0
    
    for batch_idx, (sketches, images) in enumerate(train_loader):
        print(f"Processing Batch {batch_idx+1}...")
        
        try:
            # Move data to GPU
            sketches, images = sketches.to(device), images.to(device)

            # Forward pass through vision encoder for sketches and images
            sketch_features = model.vision_model(pixel_values=sketches).pooler_output  # Sketch embeddings
            image_features = model.vision_model(pixel_values=images).pooler_output    # Image embeddings

            # Normalize embeddings (optional but improves stability)
            sketch_features = sketch_features / sketch_features.norm(dim=-1, keepdim=True)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)

            # Compute similarity scores (dot product of normalized embeddings)
            logits_per_image_sketches = torch.matmul(image_features, sketch_features.t())

            # Ground truth labels (diagonal matrix)
            labels = torch.arange(len(sketches)).to(device)

            # Compute contrastive loss
            loss_img2sketches = loss_fn(logits_per_image_sketches, labels)

            loss = loss_img2sketches

            # Backpropagation and optimization step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        except Exception as e:
            print(f"Error in Batch {batch_idx+1}: {e}")
    
    print(f"Epoch {epoch+1}, Loss: {total_loss / len(train_loader)}")


# Save the fine-tuned model
model.save_pretrained("fine_tuned_clip")
processor.save_pretrained("fine_tuned_clip")

Starting Epoch 1...
Processing Batch 1...
Processing Batch 2...
Processing Batch 3...
Processing Batch 4...
Processing Batch 5...
Processing Batch 6...
Processing Batch 7...
Processing Batch 8...
Processing Batch 9...
Processing Batch 10...
Processing Batch 11...
Processing Batch 12...
Processing Batch 13...
Processing Batch 14...
Processing Batch 15...
Processing Batch 16...
Processing Batch 17...
Processing Batch 18...
Processing Batch 19...
Processing Batch 20...
Processing Batch 21...
Processing Batch 22...
Processing Batch 23...
Processing Batch 24...
Processing Batch 25...
Processing Batch 26...
Processing Batch 27...
Processing Batch 28...
Processing Batch 29...
Processing Batch 30...
Processing Batch 31...
Processing Batch 32...
Processing Batch 33...
Processing Batch 34...
Processing Batch 35...
Processing Batch 36...
Processing Batch 37...
Processing Batch 38...
Processing Batch 39...
Processing Batch 40...
Processing Batch 41...
Processing Batch 42...
Processing Batch 43...


KeyboardInterrupt: 