In [None]:
# train_clip.ipynb

import math

import matplotlib.pyplot as plt
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
from transformers import CLIPProcessor, CLIPModel

# Set device: run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Load dataset
dataset = load_dataset("JotDe/mscoco_100k")
full_train = dataset['train']

# Split into 80% train / 10% validation / remainder (~10%) test.
# A fixed-seed generator makes the split reproducible: without it every re-run
# of the notebook draws a different partition, so rows held out as "test" in
# one session can end up in the training set of the next.
split_seed = 42
train_size = int(0.8 * len(full_train))
val_size = int(0.1 * len(full_train))
test_size = len(full_train) - train_size - val_size
train_dataset, val_dataset, test_dataset = random_split(
    full_train,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(split_seed),
)

# Initialize processor: the text tokenizer + image preprocessor published with
# the openai/clip-vit-base-patch32 checkpoint.
clip_processor = CLIPProcessor.from_pretrained(
    "openai/clip-vit-base-patch32"
)

def preprocess_data(batch):
    """Collate a list of dataset rows into a CLIP-ready batch on ``device``.

    Args:
        batch: list of rows from the dataset; each row is indexed with
            ``'image'`` and ``'text'`` (presumably a PIL image and a caption
            string -- confirm against the dataset schema).

    Returns:
        dict with ``input_ids``, ``attention_mask`` and ``pixel_values``
        tensors, already moved to ``device`` so it can be passed directly as
        ``clip_model(**batch)``.
    """
    images = [item['image'] for item in batch]
    texts = [item['text'] for item in batch]
    # truncation=True caps captions at the tokenizer's max length; CLIP's text
    # encoder only has position embeddings for that many tokens, so an
    # over-long caption would otherwise fail inside the model.
    # NOTE: moving tensors to `device` inside a collate_fn requires the
    # DataLoader to keep num_workers=0 (CUDA cannot be used from forked workers).
    inputs = clip_processor(
        text=texts,
        images=images,
        return_tensors="pt",
        padding=True,
        truncation=True,
    ).to(device)
    return {
        "input_ids": inputs["input_ids"],
        # padding=True pads captions to the longest in the batch; the mask lets
        # the text encoder ignore those pad positions.
        "attention_mask": inputs["attention_mask"],
        "pixel_values": inputs["pixel_values"],
    }

# Create dataloaders. Both loaders share batch size and collate function;
# only the training loader reshuffles its rows every epoch.
batch_size = 56
shared_loader_args = {"batch_size": batch_size, "collate_fn": preprocess_data}
train_dataloader = DataLoader(train_dataset, shuffle=True, **shared_loader_args)
val_dataloader = DataLoader(val_dataset, shuffle=False, **shared_loader_args)

# Load the pretrained CLIP checkpoint onto the selected device and switch it
# to training mode (nn.Module.to returns the module itself, so it chains).
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
clip_model.train()

# Cross-entropy criterion for the contrastive matching targets, and Adam over
# every CLIP parameter.
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(clip_model.parameters(), lr=5e-5)

# Training loop
# Optimises CLIP's symmetric contrastive objective: in a batch of N
# (image, caption) pairs, image i should score highest against caption i and
# caption i against image i, so both directions use the targets arange(N).
num_epochs = 5
training_losses = []    # mean training loss per epoch
validation_losses = []  # mean validation loss per epoch

# CLIP clips its learnable temperature so logits are never scaled by more
# than 100; keep that bound while fine-tuning.
max_logit_scale = math.log(100.0)


def contrastive_loss(outputs):
    """Return the symmetric (image->text + text->image) CLIP loss for one batch.

    Args:
        outputs: the result of ``clip_model(**batch)``; only its
            ``logits_per_image`` and ``logits_per_text`` fields are read.
    """
    logits_per_image = outputs.logits_per_image
    labels = torch.arange(logits_per_image.size(0), device=logits_per_image.device)
    image_to_text = loss_fn(logits_per_image, labels)
    text_to_image = loss_fn(outputs.logits_per_text, labels)
    return (image_to_text + text_to_image) / 2


for epoch in range(num_epochs):
    # Validation below switches to eval mode, so re-enable training mode each epoch.
    clip_model.train()
    total_loss = 0.0
    with tqdm(train_dataloader, desc=f"Epoch {epoch + 1}", unit="batch") as pbar:
        for batch in pbar:
            # No-op when the collate function already placed the batch on `device`.
            batch = {name: tensor.to(device) for name, tensor in batch.items()}
            optimizer.zero_grad()
            loss = contrastive_loss(clip_model(**batch))
            loss.backward()
            optimizer.step()
            with torch.no_grad():
                clip_model.logit_scale.clamp_(0, max_logit_scale)

            total_loss += loss.item()
            pbar.set_postfix(loss=f"{loss.item():.4f}")

    avg_loss = total_loss / len(train_dataloader)
    training_losses.append(avg_loss)
    summary = f"Epoch {epoch + 1}: train loss {avg_loss:.4f}"

    # Evaluate on the held-out validation split so over-fitting is visible.
    if len(val_dataloader) > 0:
        clip_model.eval()
        val_total = 0.0
        with torch.no_grad():
            for batch in val_dataloader:
                batch = {name: tensor.to(device) for name, tensor in batch.items()}
                val_total += contrastive_loss(clip_model(**batch)).item()
        val_loss = val_total / len(val_dataloader)
        validation_losses.append(val_loss)
        summary += f", val loss {val_loss:.4f}"

    print(summary)

# Plot training loss: one point per completed epoch.
# The x-axis is derived from the recorded losses rather than num_epochs, so the
# plot still works if the training cell was interrupted before every epoch
# finished (otherwise x and y would have different lengths and plt.plot raises).
completed_epochs = range(1, len(training_losses) + 1)
plt.plot(completed_epochs, training_losses, label="Training Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Training Loss Over Epochs")
plt.legend()
plt.show()

# Save model
# Bundles model and optimizer state with the loss history into one file.
# final_loss is None when no epoch completed, rather than raising IndexError
# on an empty history.
checkpoint_path = "clip_finetuned.pt"
torch.save({
    'model_state_dict': clip_model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'final_loss': training_losses[-1] if training_losses else None,
    'training_losses': list(training_losses),
}, checkpoint_path)

print(f"Model saved as {checkpoint_path}")
