In [2]:
import os
import torch
import clip
import numpy
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch.nn.utils as nn_utils  # For gradient norm computation
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler
import torch.nn.functional as F
import matplotlib.pyplot as plt
from diffusers import UNet2DConditionModel, StableDiffusionPipeline
from transformers import CLIPProcessor, CLIPModel

class ImageTextDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        super().__init__()
        self.data_dir = data_dir
        self.transform = transform

        # Get all .jpg files and check for corresponding non-empty .txt files
        self.image_filenames = sorted([f for f in os.listdir(data_dir) if f.endswith(".jpg")])
        
        # Filter out pairs with empty or missing captions
        self.image_filenames = [
            f for f in self.image_filenames 
            if os.path.exists(os.path.join(data_dir, f.replace(".jpg", ".txt"))) 
            and os.path.getsize(os.path.join(data_dir, f.replace(".jpg", ".txt"))) > 0
        ]
        
        print(f"Total valid image-caption pairs: {len(self.image_filenames)}")

    def __len__(self):
        return len(self.image_filenames)
    
    def __getitem__(self, idx):
        img_name = self.image_filenames[idx]
        img_path = os.path.join(self.data_dir, img_name)
    
        # Load image
        try:
            image = Image.open(img_path).convert("RGB")
        except Exception as e:
            raise RuntimeError(f"Error loading image {img_name}: {e}")
    
        # Load corresponding caption
        caption_file = img_name.replace(".jpg", ".txt")
        caption_path = os.path.join(self.data_dir, caption_file)
    
        try:
            with open(caption_path, "r") as file:
                caption = file.read().strip()
        except Exception as e:
            raise RuntimeError(f"Error loading caption {caption_file}: {e}")
    
        if self.transform:
            image = self.transform(image)
    
        return image, caption


# Load the pre-trained CLIP model
device = "cuda" 
model, preprocess = clip.load("ViT-L/14", device=device)

# Paths to your image and caption directories
data_dir = r"C:\Users\swtir\OneDrive\Documents\Deep_Machine_Learning_ease_of_working_locally\deep-machine-learning\Final_Project\images_small"

# Define image transformations (resizing, normalizing as per CLIP's expected input)
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Ensure the right size for the model
    transforms.ToTensor(),
])

# Create dataset and dataloader with optimized loading
small_dataset = ImageTextDataset(data_dir=data_dir, transform=transform)
dataloader = DataLoader(small_dataset, batch_size=16, shuffle=True, pin_memory=True)

total_batches = len(dataloader)
print(f"Total number of batches: {total_batches}")

Total valid image-caption pairs: 5653
Total number of batches: 354


In [None]:
def fine_tune_clip(model, processor, dataloader, epochs, learning_rate=1e-5, accumulation_steps=15, print_every=1, lr_warmup=1000):
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    
    # Scheduler to gradually warm up learning rate
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: min(step / lr_warmup, 1.0))
    
    loss_fn = torch.nn.CrossEntropyLoss()
    scaler = torch.cuda.amp.GradScaler()  # Mixed-precision training

    model.train()
    train_losses = []

    for epoch in range(epochs):
        total_loss = 0
        optimizer.zero_grad()

        print(f"\nEpoch {epoch + 1}/{epochs}")
        for i, (images, captions) in enumerate(dataloader):
            if images is None or captions is None:
                print(f"Warning: No data received at batch {i+1}")
                continue

            # Move data to device and truncate text input to max_length=77 tokens
            inputs = processor(text=captions, images=images, return_tensors="pt", padding=True, truncation=True, max_length=77).to(device)
            pixel_values = inputs['pixel_values']
            input_ids = inputs['input_ids']
            attention_mask = inputs['attention_mask']

            # Forward pass
            with torch.cuda.amp.autocast():  # Automatic mixed-precision
                outputs = model(input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask, return_loss=True)
                loss = outputs.loss / accumulation_steps

            # Backward pass
            scaler.scale(loss).backward()

            # Gradient accumulation and update
            if (i + 1) % accumulation_steps == 0:
                scaler.step(optimizer)  # Gradient scaling to avoid NaNs
                scaler.update()
                optimizer.zero_grad()

            total_loss += loss.item() * accumulation_steps

            if print_every and (i + 1) % print_every == 0:
                print(f"Batch {i+1}/{len(dataloader)}, Loss: {loss.item() * accumulation_steps:.4f}")
        
        # Adjust learning rate
        scheduler.step()

        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch+1}/{epochs} completed. Avg Loss: {avg_loss:.4f}")

    print("Fine-tuning completed.")
    return model


# Set up the model, processor, optimizer, and fine-tuning process
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Assuming your ImageTextDataset is correctly defined and data loaded
batch_size = 16
dataloader = DataLoader(small_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)

# Fine-tune the CLIP model
fine_tuned_model = fine_tune_clip(model, processor, dataloader, epochs=1, learning_rate=1e-5, accumulation_steps=8, print_every=1)

In [3]:
# Function to get text embeddings from the CLIP model
def get_text_embedding(text, clip_model, device):
    # Tokenize and encode the text with the CLIP model
    text_tokens = clip.tokenize([text]).to(device)
    text_embedding = clip_model.encode_text(text_tokens)
    
    # Normalize the embedding
    text_embedding = text_embedding / text_embedding.norm(dim=-1, keepdim=True)
    
    return text_embedding

device = torch.device("cuda")

# Load CLIP model
clip_model, _ = clip.load("ViT-L/14", device=device)

# Load pre-trained Stable Diffusion pipeline
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to(device)

# Disable safety checker
pipe.safety_checker = None

# Example text prompt
text_prompt = "A monkey on a tree smoking"

# Step 1: Get the text embedding from the CLIP model
text_embedding = get_text_embedding(text_prompt, clip_model, device)
text_embedding = text_embedding.to(torch.float32)  # Ensure float32 dtype

# Step 2: Adjust embedding shape for cross-attention in Stable Diffusion
batch_size = 1
sequence_length = pipe.tokenizer.model_max_length  # Sequence length expected by Stable Diffusion (usually 77)
embedding_dim = text_embedding.shape[-1]

# Expand the dimensions to match [batch_size, sequence_length, embedding_dim]
text_embedding = text_embedding.unsqueeze(0).expand(batch_size, sequence_length, embedding_dim)

# Create a dummy negative prompt embedding for classifier-free guidance
negative_prompt = ""
negative_embedding = get_text_embedding(negative_prompt, clip_model, device)
negative_embedding = negative_embedding.to(torch.float32).unsqueeze(0).expand(batch_size, sequence_length, embedding_dim)

# Step 3: Use the pipeline for generating the image with prompt embeddings
with torch.no_grad():
    generated_images = pipe(
        prompt_embeds=text_embedding,
        negative_prompt_embeds=negative_embedding,
        num_inference_steps=50,
        guidance_scale=7.5,  # Adjust guidance scale as needed
    ).images

# Step 4: Display the generated image
generated_image = generated_images[0]
generated_image.show()

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)


  0%|          | 0/50 [00:00<?, ?it/s]

In [4]:
num_epochs = 1
scaler = GradScaler()
def calculate_clip_score(clip_model, generated_images, text_embeddings):
    # Ensure generated images are in the format expected by CLIP (e.g., normalized, resized)
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),  # CLIP expects 224x224 images
    ])
    
    # If images are in latent space, convert them to the appropriate format
    generated_images = preprocess(generated_images)

    # Encode the generated images into CLIP image embeddings
    image_embeddings = clip_model.encode_image(generated_images).float()

    # Normalize both embeddings
    image_embeddings = image_embeddings / image_embeddings.norm(dim=-1, keepdim=True)
    text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)

    # Compute cosine similarity (CLIP Score)
    clip_score = F.cosine_similarity(image_embeddings, text_embeddings, dim=-1)
    return clip_score.mean().item()

# Set up smaller UNet model (without .half())
small_unet = UNet2DConditionModel(
    sample_size=64,  # Reduce resolution to 64x64
    in_channels=4,
    out_channels=4,
    layers_per_block=1,
    block_out_channels=(64, 128, 256),  # Reduce the number of channels
    down_block_types=('CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'CrossAttnDownBlock2D'),
    up_block_types=('UpBlock2D', 'UpBlock2D', 'UpBlock2D'),
    attention_head_dim=32,  # Reduce attention heads for memory efficiency
    cross_attention_dim=768  # CLIP model dimension
)

# Ensure UNet and VAE are on the correct device (no .half())
pipe.unet = small_unet.to(device)
pipe.unet.enable_gradient_checkpointing()  # Enable gradient checkpointing to save memory
pipe.vae = pipe.vae.to(device)

# Dataloader with reduced batch size
dataloader = DataLoader(small_dataset, batch_size=4, shuffle=True)  # Reduce batch size for memory efficiency

# Optimizer setup
optimizer = torch.optim.AdamW(pipe.unet.parameters(), lr=0.00001)
clip_loss_fn = torch.nn.CosineEmbeddingLoss()  # CLIP-based perceptual loss

# Fine-tuning loop
for epoch in range(num_epochs):
    for images, captions in dataloader:
        images = images.to(device)

        # Generate text embeddings using fine-tuned CLIP
        text_embeddings = model.encode_text(clip.tokenize(captions, truncate=True).to(device))

        with autocast():  # Mixed precision context
            # Encode the images into latent space using the VAE
            latents = pipe.vae.encode(images).latent_dist.sample().to(device)

            # Add noise to the latent representations
            noise = torch.randn_like(latents).to(device)

            # Sample a random timestep
            timesteps = torch.randint(0, pipe.scheduler.config.num_train_timesteps, (latents.shape[0],), device=device).long()

            # Pass the noisy latents and text embeddings to the UNet
            generated_latents = pipe.unet(noise, timesteps, encoder_hidden_states=text_embeddings.unsqueeze(1).repeat(1, 77, 1)).sample

            # Decode the generated latents back into image space using the VAE decoder
            decoded_images = pipe.vae.decode(generated_latents).sample

            # Compute diffusion loss (e.g., MSE between generated images and real images)
            diffusion_loss = torch.nn.functional.mse_loss(decoded_images, images)

            # Optionally add a CLIP-based perceptual loss
            decoded_image_embeddings = model.encode_image(decoded_images)

            # CLIP loss based on cosine similarity
            target = torch.ones(decoded_image_embeddings.shape[0], device=device)  # CosineEmbeddingLoss expects target 1 for matching pairs
            clip_loss = clip_loss_fn(decoded_image_embeddings, text_embeddings, target)

            # Total loss: diffusion loss + CLIP perceptual loss
            alpha = 0.8  # Adjust this weight to balance the two losses
            loss = alpha * diffusion_loss + (1 - alpha) * clip_loss

        # Backpropagation and optimization using mixed precision
        optimizer.zero_grad()

        # Perform backpropagation with scaled loss
        scaler.scale(loss).backward()

        # Unscale the gradients (letting GradScaler handle this)
        scaler.unscale_(optimizer)

        # Clip gradients to avoid exploding gradients (after unscaling)
        torch.nn.utils.clip_grad_norm_(pipe.unet.parameters(), max_norm=1.0)

        # Step the optimizer
        scaler.step(optimizer)

        # Update the scaler for the next iteration
        scaler.update()

        # Calculate CLIP Score
        with torch.no_grad():
            clip_score = calculate_clip_score(model, decoded_images, text_embeddings)
        print(f"CLIP Score: {clip_score:.4f}")

        # Free memory at the end of the batch
        del images, latents, generated_latents, decoded_images
        torch.cuda.empty_cache()

    print(f"Epoch {epoch+1}/{num_epochs} completed, Loss: {loss.item():.4f}, CLIP Score: {clip_score:.4f}")


CLIP Score: 0.1252
CLIP Score: 0.1044
CLIP Score: 0.1229
CLIP Score: 0.1398
CLIP Score: 0.0885
CLIP Score: 0.1332
CLIP Score: 0.1327
CLIP Score: 0.1402
CLIP Score: 0.1023
CLIP Score: 0.1168
CLIP Score: 0.1251
CLIP Score: 0.1132
CLIP Score: 0.1195
CLIP Score: 0.1107
CLIP Score: 0.1277
CLIP Score: 0.1505
CLIP Score: 0.1088
CLIP Score: 0.1397
CLIP Score: 0.1199
CLIP Score: 0.1401
CLIP Score: 0.1209
CLIP Score: 0.1158
CLIP Score: 0.1217
CLIP Score: 0.1243
CLIP Score: 0.1217
CLIP Score: 0.1273
CLIP Score: 0.1434
CLIP Score: 0.1034
CLIP Score: 0.1426
CLIP Score: 0.1335
CLIP Score: 0.1100
CLIP Score: 0.1272
CLIP Score: 0.1347
CLIP Score: 0.1311
CLIP Score: 0.1136
CLIP Score: 0.0857
CLIP Score: 0.1270
CLIP Score: 0.1305
CLIP Score: 0.1397
CLIP Score: 0.1193
CLIP Score: 0.1290
CLIP Score: 0.1303
CLIP Score: 0.1345
CLIP Score: 0.1412
CLIP Score: 0.1215
CLIP Score: 0.1583
CLIP Score: 0.1507
CLIP Score: 0.1301
CLIP Score: 0.1265
CLIP Score: 0.1517
CLIP Score: 0.1402
CLIP Score: 0.1427
CLIP Score: 

In [12]:
# Save the fine-tuned models
torch.save(pipe.unet.state_dict(), 'fine_tuned_unet.pth')
torch.save(pipe.vae.state_dict(), 'fine_tuned_vae.pth')
print("Fine-tuned models saved.")


Fine-tuned models saved.


In [13]:
# Load the fine-tuned models
pipe.unet.load_state_dict(torch.load('fine_tuned_unet.pth'))
pipe.vae.load_state_dict(torch.load('fine_tuned_vae.pth'))

# Ensure the models are moved to the appropriate device
pipe.unet.to(device)
pipe.vae.to(device)
print("Fine-tuned models loaded.")


Fine-tuned models loaded.
