# ImageBind from Meta

- Download the pre-trained model from Meta.
- Implement a training script that adapts ImageBind to your specific labels and dataset.

In [9]:
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import os
import pandas as pd
import torch.nn as nn
from torch.optim import Adam

# Verify that ImageBind is correctly installed and import the necessary modules
try:
    from imagebind import data
    from imagebind.models import imagebind_model
    from imagebind.models.imagebind_model import ModalityType
except ImportError as e:
    print("Error: ImageBind library is not installed or not found.")
    print(e)

# Define the custom dataset class
class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = Image.open(img_path).convert("RGB")
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        return image, label

# Preprocess your dataset
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Define the dataset path and annotations file
dataset_path = "/Users/colleenjung/Desktop/UChicago/24SummerCorrugated/filtered_data_new/dataset"
annotations_file = "/Users/colleenjung/Desktop/UChicago/24SummerCorrugated/filtered_data_new/annotations.csv"

train_dataset = CustomImageDataset(annotations_file=annotations_file, img_dir=dataset_path, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# Initialize the ImageBind model
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = imagebind_model.imagebind_huge(pretrained=True)
model.train()
model.to(device)

# Define a simple classification head
class ClassificationHead(nn.Module):
    def __init__(self, input_dim, num_classes):
        super(ClassificationHead, self).__init__()
        self.fc = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        return self.fc(x)

# Assuming the model outputs embeddings of a certain dimension, adjust input_dim accordingly
input_dim = 1024  # Replace with the actual dimension of the embeddings
num_classes = 4  # Updated to include 'unknown' class
classification_head = ClassificationHead(input_dim, num_classes).to(device)

# Set up the optimizer and loss function
optimizer = Adam(list(model.parameters()) + list(classification_head.parameters()), lr=1e-4)
criterion = nn.CrossEntropyLoss()

# Training loop
for epoch in range(10):  # Number of epochs
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        
        # Get embeddings from the ImageBind model
        embeddings = model({ModalityType.VISION: images})[ModalityType.VISION]
        
        # Get predictions from the classification head
        outputs = classification_head(embeddings)
        
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
    print(f'Epoch {epoch+1}, Loss: {loss.item()}')

# Save the fine-tuned model
torch.save({
    'model_state_dict': model.state_dict(),
    'classification_head_state_dict': classification_head.state_dict()
}, 'fine_tuned_imagebind.pth')

# Function to load and preprocess a single image for inference
def load_and_transform_image(image_path, device):
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    image = Image.open(image_path).convert("RGB")
    image = transform(image).unsqueeze(0).to(device)  # Add batch dimension and move to device
    return image

# Function to generate text description from an image
def generate_text_from_image(image_path, model, classification_head, device):
    # Load and preprocess image
    image = load_and_transform_image(image_path, device)
    
    # Switch the model to evaluation mode
    model.eval()
    classification_head.eval()
    
    # Perform inference
    with torch.no_grad():
        embeddings = model({ModalityType.VISION: image})[ModalityType.VISION]
        outputs = classification_head(embeddings)
    
    # Get the predicted class
    _, predicted = torch.max(outputs, 1)
    
    # Map predicted class to text label
    class_names = ['unknown', 'busted/gouged', 'crushed core', 'wet/foreign substances']
    predicted_label = class_names[predicted.item()]
    
    return predicted_label

# Example usage for inference
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
classification_head.to(device)

# Load the fine-tuned model
checkpoint = torch.load('fine_tuned_imagebind.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
classification_head.load_state_dict(checkpoint['classification_head_state_dict'])

image_path = "/path/to/your/image.jpg"
generated_text = generate_text_from_image(image_path, model, classification_head, device)
print("Generated Text Description: ", generated_text)


Error: ImageBind library is not installed or not found.
No module named 'imagebind'


FileNotFoundError: [Errno 2] No such file or directory: '/Users/colleenjung/Desktop/UChicago/24SummerCorrugated/filtered_data_new/annotations.csv'