In [1]:
!pip install openai-clip



In [89]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from PIL import Image
import clip
from tqdm import tqdm

class StyleTransfer:
    def __init__(self, clip_model_name="ViT-B/32"):
        # Set device
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        # Load CLIP model
        self.model, self.preprocess = clip.load(clip_model_name, device=self.device)

        # Freeze the CLIP model parameters
        for param in self.model.parameters():
            param.requires_grad = False

    def get_text_features(self, text):
        """Get CLIP text features for style"""
        text_tokens = clip.tokenize([text]).to(self.device)
        with torch.no_grad():
            text_features = self.model.encode_text(text_tokens)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        return text_features

    def transfer_style(self,
                      content_image_path,
                      style_text,
                      num_steps=300,
                      lr=0.5,
                      content_weight=0.0,
                      style_weight=1.0,
                      tv_weight=0.001):
        """Perform style transfer"""
        # Load and preprocess content image
        content_pil = Image.open(content_image_path).convert("RGB")

        # Get original size for later
        original_size = content_pil.size

        # Process for CLIP input (resize to what CLIP expects)
        content_clip = self.preprocess(content_pil).unsqueeze(0).to(self.device)

        # IMPORTANT: Create a fresh tensor for optimization, detached from computation graph
        # Start with random noise and then add content
        opt_img = torch.randn(content_clip.shape, device=self.device) * 0.1
        # Add content image (without gradients)
        with torch.no_grad():
            opt_img = torch.clamp(opt_img + content_clip.detach(), 0, 1)
        # Make sure it's a leaf tensor requiring gradients
        opt_img.requires_grad_(True)

        # Get text features for style
        text_features = self.get_text_features(style_text)

        # Get content features (if needed for content preservation)
        if content_weight > 0:
            with torch.no_grad():
                content_features = self.model.encode_image(content_clip)
                content_features = content_features / content_features.norm(dim=-1, keepdim=True)

        # Setup optimizer with higher learning rate
        optimizer = optim.Adam([opt_img], lr=lr)

        # Use a scheduler for better convergence
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_steps)

        # Track best result
        best_loss = float('inf')
        best_img = None

        # Main optimization loop
        for i in tqdm(range(num_steps)):
            optimizer.zero_grad()

            # Get current image features
            image_features = self.model.encode_image(opt_img)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)

            # Style loss - cosine distance to style text embedding
            style_loss = style_weight * (1 - torch.cosine_similarity(image_features, text_features))

            # Content loss - if requested
            if content_weight > 0:
                content_loss = content_weight * (1 - torch.cosine_similarity(image_features, content_features))
            else:
                content_loss = torch.tensor(0.0, device=self.device)

            # Total variation regularization for smoothness
            diff_y = torch.abs(opt_img[:, :, :-1, :] - opt_img[:, :, 1:, :])
            diff_x = torch.abs(opt_img[:, :, :, :-1] - opt_img[:, :, :, 1:])
            tv_loss = tv_weight * (torch.sum(diff_x) + torch.sum(diff_y))

            # Combine losses
            total_loss = style_loss + content_loss + tv_loss

            # Backward pass
            total_loss.backward()

            # Check gradients - helpful for debugging
            if i == 0:
                grad_norm = opt_img.grad.norm().item()
                print(f"Gradient norm: {grad_norm}")
                if grad_norm < 1e-4:
                    print("WARNING: Very small gradients. Trying higher learning rate or different model.")

            # Update image
            optimizer.step()
            scheduler.step()

            # Clamp values to valid image range
            with torch.no_grad():
                opt_img.data.clamp_(0, 1)

            # Track best result
            if total_loss.item() < best_loss:
                best_loss = total_loss.item()
                best_img = opt_img.clone().detach()

            # Print progress
            if i % 20 == 0 or i == num_steps - 1:
                print(f"Step {i}, Style: {style_loss.item():.4f}, Content: {content_loss.item():.4f}, TV: {tv_loss.item():.4f}, Total: {total_loss.item():.4f}")

        # Use best result
        final_img = best_img if best_img is not None else opt_img.detach()

        # Convert to PIL and resize back to original dimensions
        to_pil = transforms.ToPILImage()
        result = to_pil(final_img.squeeze().cpu())

        # Resize back to original size if needed
        if result.size != original_size:
            result = result.resize(original_size, Image.LANCZOS)

        return result

# Example usage
def main(save_image = "imaged"):
    styler = StyleTransfer()
    content_path = "face.jpg"
    style_text = "Woman standing in a red dress, blue hair."

    result = styler.transfer_style(
        content_path,
        style_text,
        num_steps=100,
        lr=1,           # Higher learning rate for better optimization
        content_weight=0.0,  # Small content weight to maintain some structure
        style_weight=1,    # Emphasis on style
        tv_weight=1e-3,      # Total variation for smoothness
    )

    result.save(f"{save_image}.jpg")
    print("Style transfer complete!")

if __name__ == "__main__":
  main()

Using device: cuda


  5%|▌         | 5/100 [00:00<00:02, 43.33it/s]

Gradient norm: 0.5797377228736877
Step 0, Style: 0.8789, Content: 0.0000, TV: 10.9085, Total: 11.7812


 25%|██▌       | 25/100 [00:00<00:01, 43.52it/s]

Step 20, Style: 0.5669, Content: 0.0000, TV: 50.7933, Total: 51.3438


 50%|█████     | 50/100 [00:01<00:01, 44.54it/s]

Step 40, Style: 0.5234, Content: 0.0000, TV: 30.7401, Total: 31.2500


 64%|██████▍   | 64/100 [00:01<00:01, 35.03it/s]

Step 60, Style: 0.5059, Content: 0.0000, TV: 17.3518, Total: 17.8594


 84%|████████▍ | 84/100 [00:02<00:00, 33.36it/s]

Step 80, Style: 0.4834, Content: 0.0000, TV: 6.6991, Total: 7.1836


100%|██████████| 100/100 [00:02<00:00, 36.86it/s]


Step 99, Style: 0.4512, Content: 0.0000, TV: 1.3135, Total: 1.7646
Style transfer complete!
