In [None]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import json
from sklearn.model_selection import train_test_split
from transformers import AutoModel, AutoProcessor
from torch.optim import AdamW
from tqdm import tqdm
import glob
import numpy as np

class GroundingDinoDataset(Dataset):
    def __init__(self, image_paths, labels, processor, label_to_id):
        self.image_paths = image_paths
        self.labels = labels
        self.processor = processor
        self.label_to_id = label_to_id
        
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        label = self.labels[idx]
        
        # Load image
        image = Image.open(image_path).convert('RGB')
        
        # Prepare text prompt (using label as description)
        text = f"a photo of {label}"
        
        return {
            'image': image,
            'text': text,
            'label': self.label_to_id[label],
            'image_path': image_path
        }

def custom_collate_fn(batch):
    """Custom collate function to handle variable-sized images"""
    images = [item['image'] for item in batch]
    texts = [item['text'] for item in batch]
    labels = torch.tensor([item['label'] for item in batch])
    image_paths = [item['image_path'] for item in batch]
    
    return {
        'images': images,
        'texts': texts,
        'labels': labels,
        'image_paths': image_paths
    }

class GroundingDinoFineTuner:
    def __init__(self, model_name="IDEA-Research/grounding-dino-base"):
        self.model_name = model_name
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.processor = None
        self.model = None
        self.label_to_id = {}
        self.id_to_label = {}
        
    def setup_model(self):
        """Initialize the model and processor"""
        self.processor = AutoProcessor.from_pretrained(self.model_name)
        self.model = AutoModel.from_pretrained(self.model_name)
        
        # Freeze some layers to make training more stable
        for param in self.model.parameters():
            param.requires_grad = False
            
        # Unfreeze the last few layers for fine-tuning
        # Adjust based on your needs and computational resources
        for name, param in self.model.named_parameters():
            if any(layer in name for layer in ['class_embed', 'bbox_embed', 'query_embed']):
                param.requires_grad = True
            if 'text_encoder' in name and 'layer.11' in name:  # Last layer of text encoder
                param.requires_grad = True
            if 'vision_model' in name and 'layer.23' in name:  # Last layer of vision model
                param.requires_grad = True
                
        self.model.to(self.device)
        print(f"Model loaded on {self.device}")
        
    def prepare_dataset(self, dataset_path, test_size=0.2, random_state=42):
        """Prepare dataset from folder structure"""
        image_paths = []
        labels = []
        
        # Get all label folders
        label_folders = [f for f in os.listdir(dataset_path) 
                        if os.path.isdir(os.path.join(dataset_path, f))]
        
        # Create label mappings
        self.label_to_id = {label: idx for idx, label in enumerate(label_folders)}
        self.id_to_label = {idx: label for label, idx in self.label_to_id.items()}
        
        # Collect all image paths and labels
        for label in label_folders:
            label_path = os.path.join(dataset_path, label)
            image_files = glob.glob(os.path.join(label_path, "*.jpg"))
            
            for img_path in image_files:
                image_paths.append(img_path)
                labels.append(label)
        
        print(f"Found {len(image_paths)} images across {len(label_folders)} classes")
        print(f"Classes: {list(self.label_to_id.keys())}")
        
        # Split dataset
        train_paths, test_paths, train_labels, test_labels = train_test_split(
            image_paths, labels, test_size=test_size, random_state=random_state, stratify=labels
        )
        
        # Create datasets
        train_dataset = GroundingDinoDataset(train_paths, train_labels, self.processor, self.label_to_id)
        test_dataset = GroundingDinoDataset(test_paths, test_labels, self.processor, self.label_to_id)
        
        return train_dataset, test_dataset
    
    def process_batch(self, batch):
        """Process a batch of images and texts with proper settings for GroundingDino"""
        try:
            inputs = self.processor(
                images=batch['images'], 
                text=batch['texts'], 
                return_tensors="pt", 
                padding=True
            )
            return inputs
        except Exception as e:
            print(f"Error processing batch: {e}")
            # Fallback: process one by one
            processed_inputs = []
            for image, text in zip(batch['images'], batch['texts']):
                try:
                    input_single = self.processor(
                        images=image, 
                        text=text, 
                        return_tensors="pt"
                    )
                    processed_inputs.append(input_single)
                except Exception as e2:
                    print(f"Error processing single item: {e2}")
                    continue
            
            # Manually collate
            if processed_inputs:
                keys = processed_inputs[0].keys()
                batch_inputs = {}
                for key in keys:
                    batch_inputs[key] = torch.cat([inp[key] for inp in processed_inputs], dim=0)
                return batch_inputs
            else:
                raise ValueError("Could not process any items in the batch")
    
    def fine_tune(self, dataset_path, output_dir="./fine_tuned_weights", 
                  batch_size=2, num_epochs=10, learning_rate=1e-5):
        """Fine-tune the GroundingDino model"""
        
        # Setup model
        self.setup_model()
        
        # Prepare datasets
        train_dataset, test_dataset = self.prepare_dataset(dataset_path)
        
        if len(train_dataset) == 0:
            raise ValueError("No training data found!")
        
        # Create data loaders with custom collate function
        train_loader = DataLoader(
            train_dataset, 
            batch_size=batch_size, 
            shuffle=True,
            collate_fn=custom_collate_fn,
            num_workers=0
        )
        test_loader = DataLoader(
            test_dataset, 
            batch_size=batch_size, 
            shuffle=False,
            collate_fn=custom_collate_fn,
            num_workers=0
        )
        
        # Optimizer - only trainable parameters
        trainable_params = [p for p in self.model.parameters() if p.requires_grad]
        optimizer = AdamW(trainable_params, lr=learning_rate, weight_decay=1e-4)
        
        # Training loop
        self.model.train()
        
        for epoch in range(num_epochs):
            total_loss = 0
            progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
            
            for batch_idx, batch in enumerate(progress_bar):
                try:
                    # Process batch
                    inputs = self.process_batch(batch)
                    inputs = {k: v.to(self.device) for k, v in inputs.items()}
                    labels = batch['labels'].to(self.device)
                    
                    # Forward pass
                    outputs = self.model(**inputs)
                    
                    # Compute loss - adapted for GroundingDino
                    loss = self.compute_detection_loss(outputs, labels, batch['texts'])
                    
                    # Backward pass
                    optimizer.zero_grad()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(trainable_params, max_norm=1.0)
                    optimizer.step()
                    
                    total_loss += loss.item()
                    progress_bar.set_postfix({"Loss": loss.item()})
                    
                except Exception as e:
                    print(f"Error in batch {batch_idx}: {e}")
                    continue
            
            if len(train_loader) > 0:
                avg_loss = total_loss / len(train_loader)
                print(f"Epoch {epoch+1}, Average Loss: {avg_loss:.4f}")
                
                # Validate
                self.validate(test_loader)
        
        # Save model and metadata
        self.save_model(output_dir)
        
        return output_dir
    
    def compute_detection_loss(self, outputs, labels, texts):
        """
        Compute loss for GroundingDino - this is a simplified approach
        Since GroundingDino is a detection model, we need to adapt the loss
        """
        try:
            # GroundingDino outputs contain predictions and logits
            # We'll use a simple approach: treat it as classification on the best box
            
            if hasattr(outputs, 'logits') and outputs.logits is not None:
                # Get the classification logits for the best detection
                batch_size = outputs.logits.shape[0]
                total_loss = 0
                
                for i in range(batch_size):
                    # Get the detection with highest confidence
                    max_conf_idx = torch.argmax(outputs.logits[i].max(dim=1).values)
                    
                    # Use cross entropy loss on the predicted class
                    # This is a simplification - you might need to adapt this based on your needs
                    if max_conf_idx < outputs.logits[i].shape[0]:
                        class_logits = outputs.logits[i][max_conf_idx]
                        
                        # Simple classification loss
                        # Note: This assumes your labels match the text queries
                        loss_fn = nn.CrossEntropyLoss()
                        loss = loss_fn(class_logits.unsqueeze(0), labels[i].unsqueeze(0))
                        total_loss += loss
                
                return total_loss / batch_size if batch_size > 0 else torch.tensor(0.0)
            else:
                # Fallback: simple classification loss using hidden states
                return self.compute_classification_loss(outputs, labels)
                
        except Exception as e:
            print(f"Error in loss computation: {e}")
            return self.compute_classification_loss(outputs, labels)
    
    def compute_classification_loss(self, outputs, labels):
        """Fallback classification loss using hidden states"""
        try:
            if hasattr(outputs, 'last_hidden_state'):
                # Use average pooling over sequence dimension
                pooled_output = outputs.last_hidden_state.mean(dim=1)
                classifier = nn.Linear(pooled_output.size(-1), len(self.label_to_id)).to(self.device)
                logits = classifier(pooled_output)
                loss = nn.CrossEntropyLoss()(logits, labels)
                return loss
            else:
                return torch.tensor(0.0, requires_grad=True).to(self.device)
        except:
            return torch.tensor(0.0, requires_grad=True).to(self.device)
    
    def validate(self, test_loader):
        """Simple validation function"""
        self.model.eval()
        correct = 0
        total = 0
        
        with torch.no_grad():
            for batch in test_loader:
                try:
                    inputs = self.process_batch(batch)
                    inputs = {k: v.to(self.device) for k, v in inputs.items()}
                    labels = batch['labels'].to(self.device)
                    
                    outputs = self.model(**inputs)
                    
                    # Simple accuracy calculation
                    if hasattr(outputs, 'last_hidden_state'):
                        pooled_output = outputs.last_hidden_state.mean(dim=1)
                        classifier = nn.Linear(pooled_output.size(-1), len(self.label_to_id)).to(self.device)
                        logits = classifier(pooled_output)
                        predictions = torch.argmax(logits, dim=1)
                        correct += (predictions == labels).sum().item()
                        total += labels.size(0)
                except Exception as e:
                    print(f"Validation error: {e}")
                    continue
        
        if total > 0:
            accuracy = correct / total
            print(f"Validation Accuracy: {accuracy:.4f}")
        else:
            print("No validation samples processed")
            
        self.model.train()
    
    def save_model(self, output_dir):
        """Save fine-tuned model and metadata"""
        os.makedirs(output_dir, exist_ok=True)
        
        # Save model weights
        torch.save(self.model.state_dict(), os.path.join(output_dir, "pytorch_model.bin"))
        
        # Save metadata
        metadata = {
            "label_to_id": self.label_to_id,
            "id_to_label": self.id_to_label,
            "model_name": self.model_name
        }
        
        with open(os.path.join(output_dir, "metadata.json"), "w") as f:
            json.dump(metadata, f, indent=2)
        
        # Save processor
        self.processor.save_pretrained(output_dir)
        
        print(f"Model saved to {output_dir}")

class QuickGroundingDino:
    def __init__(self, weights_path):
        """Quick loader for fine-tuned weights"""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.weights_path = weights_path
        self.model = None
        self.processor = None
        self.label_to_id = {}
        self.id_to_label = {}
        
        self.load_model()
    
    def load_model(self):
        """Load fine-tuned model and metadata"""
        # Load metadata
        with open(os.path.join(self.weights_path, "metadata.json"), "r") as f:
            metadata = json.load(f)
        
        self.label_to_id = metadata["label_to_id"]
        self.id_to_label = {int(k): v for k, v in metadata["id_to_label"].items()}
        model_name = metadata["model_name"]
        
        # Load processor and model
        self.processor = AutoProcessor.from_pretrained(self.weights_path)
        self.model = AutoModel.from_pretrained(model_name)
        
        # Load fine-tuned weights
        model_path = os.path.join(self.weights_path, "pytorch_model.bin")
        self.model.load_state_dict(torch.load(model_path, map_location=self.device))
        self.model.to(self.device)
        self.model.eval()
        
        print(f"Model loaded from {self.weights_path}")
        print(f"Available classes: {list(self.label_to_id.keys())}")
    
    def predict(self, image_path):
        """Make prediction on a single image"""
        try:
            # Load image
            image = Image.open(image_path).convert('RGB')
            
            # Try all possible labels and find the best match
            best_label = "unknown"
            best_confidence = 0.0
            all_predictions = {}
            
            for label in self.label_to_id.keys():
                text = f"a photo of {label}"
                
                try:
                    # Process inputs
                    inputs = self.processor(images=image, text=text, return_tensors="pt")
                    inputs = {k: v.to(self.device) for k, v in inputs.items()}
                    
                    # Get model output
                    with torch.no_grad():
                        outputs = self.model(**inputs)
                    
                    # Calculate confidence score
                    confidence = self.calculate_confidence(outputs, label)
                    all_predictions[label] = confidence
                    
                    if confidence > best_confidence:
                        best_confidence = confidence
                        best_label = label
                        
                except Exception as e:
                    print(f"Error processing label {label}: {e}")
                    all_predictions[label] = 0.0
            
            return {
                "predicted_label": best_label,
                "confidence": best_confidence,
                "all_predictions": all_predictions
            }
            
        except Exception as e:
            print(f"Error in prediction: {e}")
            return {
                "predicted_label": "error",
                "confidence": 0.0,
                "all_predictions": {}
            }
    
    def calculate_confidence(self, outputs, label):
        """Calculate confidence score from model outputs"""
        try:
            # Method 1: Use detection confidence if available
            if hasattr(outputs, 'logits') and outputs.logits is not None:
                # Get maximum detection confidence
                max_confidence = torch.sigmoid(outputs.logits).max().item()
                return max_confidence
            
            # Method 2: Use hidden states for classification
            elif hasattr(outputs, 'last_hidden_state'):
                pooled_output = outputs.last_hidden_state.mean(dim=1)
                classifier = nn.Linear(pooled_output.size(-1), len(self.label_to_id)).to(self.device)
                logits = classifier(pooled_output)
                probabilities = torch.softmax(logits, dim=1)
                return probabilities[0, self.label_to_id[label]].item()
            
            else:
                return 0.5  # Default confidence
                
        except:
            return 0.0

def train_grounding_dino_safe():
    """Safe training function with error handling"""
    try:
        fine_tuner = GroundingDinoFineTuner()
        
        output_dir = fine_tuner.fine_tune(
            dataset_path="fine_tuning_dataset",
            output_dir="./fine_tuned_grounding_dino",
            batch_size=1,  # Start with batch size 1 for stability
            num_epochs=3,  # Start with fewer epochs
            learning_rate=1e-5
        )
        
        return output_dir
        
    except Exception as e:
        print(f"Training failed: {e}")
        return None


In [11]:
print("Starting training...")
weights_path = train_grounding_dino_safe()
if weights_path:
    print(f"Training completed! Weights saved to: {weights_path}")
else:
    print("Training failed!")

Starting training...
Model loaded on cuda
Found 611 images across 1 classes
Classes: ['pull up']


Epoch 1/3:   1%|          | 3/488 [01:58<5:18:50, 39.45s/it, Loss=0]


KeyboardInterrupt: 

In [None]:
def test_grounding_dino(weights_path, test_image_path):
    """Quick test function"""
    if not os.path.exists(weights_path):
        print(f"Weights path {weights_path} does not exist!")
        return
    
    if not os.path.exists(test_image_path):
        print(f"Test image {test_image_path} does not exist!")
        return
    
    try:
        model = QuickGroundingDino(weights_path)
        result = model.predict(test_image_path)
        
        print(f"\n=== Prediction Results ===")
        print(f"Image: {os.path.basename(test_image_path)}")
        print(f"Predicted Label: {result['predicted_label']}")
        print(f"Confidence: {result['confidence']:.4f}")
        print(f"\nAll predictions:")
        for label, score in sorted(result['all_predictions'].items(), key=lambda x: x[1], reverse=True):
            print(f"  {label}: {score:.4f}")
        
        return result
    except Exception as e:
        print(f"Testing failed: {e}")
        return None

In [None]:
weights_path = ''
test_image_path = ''
test_grounding_dino(weights_path, test_image_path)