## Requirements

```
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
pip install pandas matplotlib transformers pillow wandb scikit-learn ftfy regex tqdm
pip install git+https://github.com/openai/CLIP.git
```


In [1]:
import os
import json
import torch
import clip
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.preprocessing import MultiLabelBinarizer

import wandb

In [2]:
# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

Using device: cuda


In [3]:
# Load CLIP model
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)

In [4]:
# Custom dataset
class FoodIngredientsDataset(Dataset):
    def __init__(self, data_dir, split='train', transform=None, mlb=None):
        self.data_dir = data_dir
        self.split = split
        self.transform = transform
        self.data = self.load_data()
        self.mlb = mlb
        
        if self.mlb is None:
            self.mlb = MultiLabelBinarizer()
            all_ingredients = [item['suffix'].split(', ') for item in self.data]
            self.mlb.fit(all_ingredients)
        
        self.num_classes = len(self.mlb.classes_)

    def load_data(self):
        with open(os.path.join(self.data_dir, f'{self.split}.jsonl'), 'r') as f:
            return [json.loads(line) for line in f]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        image_path = item['image']
        image = Image.open(image_path).convert('RGB')
        
        if self.transform:
            image = self.transform(image)
        
        ingredients = item['suffix'].split(', ')
        labels = self.mlb.transform([ingredients])[0]
        
        return image, torch.FloatTensor(labels)

In [5]:
# Data loading function
def load_data(data_dir, batch_size):
    train_dataset = FoodIngredientsDataset(data_dir, 'train', transform=preprocess)
    mlb = train_dataset.mlb
    val_dataset = FoodIngredientsDataset(data_dir, 'test', transform=preprocess, mlb=mlb)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader, mlb

In [6]:
# Training function
def train(model, train_loader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    for images, labels in tqdm(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        image_features = model.encode_image(images)
        logits = model.visual.proj(image_features)

        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(train_loader)

In [7]:
# Validation function
def validate(model, val_loader, criterion, device):
    model.eval()
    total_loss = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, labels in tqdm(val_loader):
            images = images.to(device)
            labels = labels.to(device)

            image_features = model.encode_image(images)
            logits = model.visual.proj(image_features)

            loss = criterion(logits, labels)
            total_loss += loss.item()

            preds = (torch.sigmoid(logits) > 0.5).float()
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)

    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average='samples')
    recall = recall_score(all_labels, all_preds, average='samples')
    f1 = f1_score(all_labels, all_preds, average='samples')

    return total_loss / len(val_loader), accuracy, precision, recall, f1

In [8]:
# Main training loop
def run_training():
    # Hyperparameters
    BATCH_SIZE = 32
    LEARNING_RATE = 1e-5
    EPOCHS = 1
    DATA_DIR = './data/food-ingredients-101'

    # Initialize wandb
    wandb.init(project="clip-finetuning", config={
        "batch_size": BATCH_SIZE,
        "learning_rate": LEARNING_RATE,
        "epochs": EPOCHS,
    })

    # Load data
    train_loader, val_loader, mlb = load_data(DATA_DIR, BATCH_SIZE)

    # Modify the model's projection layer to match the number of classes
    num_classes = len(mlb.classes_)
    
    # Get the current output dimension of the visual projection
    current_dim = model.visual.proj.shape[0]
    
    # Create a new linear layer
    new_proj = torch.nn.Linear(current_dim, num_classes).to(device)
    
    # Initialize the weights of the new layer
    torch.nn.init.normal_(new_proj.weight, std=0.001)
    torch.nn.init.zeros_(new_proj.bias)
    
    # Replace the old projection layer with the new one
    model.visual.proj = new_proj

    # Initialize optimizer and loss function
    optimizer = Adam(model.parameters(), lr=LEARNING_RATE)
    criterion = BCEWithLogitsLoss()

    # Training loop
    for epoch in range(EPOCHS):
        print(f"Epoch {epoch+1}/{EPOCHS}")
        
        train_loss = train(model, train_loader, optimizer, criterion, device)
        val_loss, accuracy, precision, recall, f1 = validate(model, val_loader, criterion, device)

        print(f"Train Loss: {train_loss:.4f}")
        print(f"Val Loss: {val_loss:.4f}, Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")

        # Log to wandb
        wandb.log({
            "epoch": epoch + 1,
            "train_loss": train_loss,
            "val_loss": val_loss,
            "accuracy": accuracy,
            "precision": precision,
            "recall": recall,
            "f1": f1
        })

        # Save checkpoint
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': val_loss,
            'mlb': mlb,
        }, f"clip_checkpoint_epoch_{epoch+1}.pt")

    wandb.finish()

In [9]:
run_training()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mahirsch[0m ([33mhtwbe[0m). Use [1m`wandb login --relogin`[0m to force relogin


TypeError: cannot assign 'torch.nn.modules.linear.Linear' as parameter 'proj' (torch.nn.Parameter or None expected)