# In [23]:
import clip
import torch
from PIL import Image
import torchvision.transforms as transforms

# CLIP-guided image optimization.
#
# Starting from the (preprocessed) input image, the pixels are adjusted by
# gradient descent so that the image's CLIP embedding stays close to the
# original image embedding while also moving toward the embedding of a text
# caption. The result is shown with the default image viewer.

# Number of gradient steps taken on the image.
NUM_STEPS = 1000
# Adam moves every element by roughly `lr` per step. The optimized tensor
# lives in CLIP-normalized space (values roughly within [-2, 2]), so the step
# size has to be a small fraction of that range.
# NOTE(review): 0.05 is a starting point, tune for the desired effect.
LEARNING_RATE = 0.05

# Load the CLIP model and its matching image preprocessing pipeline.
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)
# Only the image pixels are optimized; freezing the weights avoids computing
# and accumulating parameter gradients on every backward pass.
model.requires_grad_(False)

# Load the content image. The context manager closes the file handle once the
# pixel data has been read by `preprocess`.
image_path = "./image.jpg"
with Image.open(image_path) as image:
    # Force RGB so grayscale / RGBA / palette images yield three channels.
    image_input = preprocess(image.convert("RGB")).unsqueeze(0).to(device)

# Caption describing the target style.
caption = "a painting of a forest"
text_input = clip.tokenize([caption]).to(device)

# Encode both optimization targets once. They are constants, so no autograd
# graph is recorded for them (this is what makes `retain_graph=True`
# unnecessary in the loop below). Features are cast to float32 because the
# model may run in half precision on CUDA.
with torch.no_grad():
    image_features = model.encode_image(image_input).float()
    text_features = model.encode_text(text_input).float()

# Start the optimization from a copy of the input image. The copy must be a
# leaf tensor with requires_grad=True, otherwise it never receives a gradient
# and the optimizer silently leaves it unchanged.
stylized_image = image_input.clone().detach().requires_grad_(True)

optimizer = torch.optim.Adam([stylized_image], lr=LEARNING_RATE)

mse_loss = torch.nn.functional.mse_loss

# Optimization loop
for _ in range(NUM_STEPS):
    optimizer.zero_grad()

    # Embed the current state of the image being optimized.
    stylized_image_features = model.encode_image(stylized_image).float()

    # Loss: stay close to the original image embedding while moving toward
    # the caption embedding.
    loss = (
        mse_loss(stylized_image_features, image_features)
        + mse_loss(stylized_image_features, text_features)
    )

    loss.backward()
    optimizer.step()

# Undo CLIP's input normalization before converting to an image; otherwise
# out-of-range values wrap around when converted to 8-bit pixels.
# NOTE(review): these are the mean/std used by the reference CLIP
# preprocessing; confirm they match the Normalize step of `preprocess`.
clip_mean = torch.tensor(
    (0.48145466, 0.4578275, 0.40821073), device=device
).view(3, 1, 1)
clip_std = torch.tensor(
    (0.26862954, 0.26130258, 0.27577711), device=device
).view(3, 1, 1)

with torch.no_grad():
    pixels = (stylized_image.squeeze(0) * clip_std + clip_mean).clamp(0.0, 1.0)

# Convert the optimized tensor to a PIL image and display it.
stylized_image = transforms.functional.to_pil_image(pixels.cpu())
stylized_image.show()
