In [1]:
import torch
from transformers import (
    VisionEncoderDecoderModel,
    ViTImageProcessor,
    AutoTokenizer,
    GenerationConfig,
)
from PIL import Image
import os
import requests
from io import BytesIO
import logging
from pathlib import Path

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class ImageCaptioner:
    def __init__(self, model_path, device=None):
        """
        Initialize the Image Captioner
        
        Args:
            model_path (str): Path to the saved model directory
            device (str, optional): Device to run inference on. Defaults to auto-detect.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.model_path = Path(model_path)
        
        # Validate model path
        if not self.model_path.exists():
            raise FileNotFoundError(f"Model path does not exist: {model_path}")
        
        # Load the model components
        self._load_model()
        
    def _load_model(self):
        """Load the trained model, tokenizer, and image processor"""
        logger.info(f"Loading model from {self.model_path}")
        logger.info(f"Using device: {self.device}")
        
        try:
            # Method 1: Try to load complete fine-tuned model
            if (self.model_path / "config.json").exists():
                logger.info("Loading complete fine-tuned model...")
                self.model = VisionEncoderDecoderModel.from_pretrained(self.model_path)
                self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
                self.image_processor = ViTImageProcessor.from_pretrained(self.model_path)
            else:
                # Method 2: Fallback to manual loading (your original approach, but improved)
                logger.info("Loading model with manual configuration...")
                self._load_model_manual()
            
            # Set up generation configuration properly
            self._setup_generation_config()
            
            # Move model to device and set to evaluation mode
            self.model.to(self.device)
            self.model.eval()
            logger.info("Model loaded and ready for inference")
            
        except Exception as e:
            logger.error(f"Failed to load model: {str(e)}")
            raise
    
    def _load_model_manual(self):
        """Manual model loading as fallback"""
        # Load components
        encoder_model_name = "google/vit-base-patch16-224-in21k"
        decoder_model_name = "gpt2"
        
        # Try to load saved components first, fallback to base models
        try:
            self.image_processor = ViTImageProcessor.from_pretrained(self.model_path)
        except:
            self.image_processor = ViTImageProcessor.from_pretrained(encoder_model_name)
            logger.warning("Using base ViT image processor")
        
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        except:
            self.tokenizer = AutoTokenizer.from_pretrained(decoder_model_name)
            # Ensure pad token is set
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            logger.warning("Using base GPT2 tokenizer")
        
        # Create model architecture
        self.model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
            encoder_model_name, decoder_model_name
        )
        
        # Try to load trained weights
        model_state_path = self.model_path / "best_model.pth"
        if model_state_path.exists():
            try:
                state_dict = torch.load(model_state_path, map_location=self.device)
                # Handle potential key mismatches
                missing_keys, unexpected_keys = self.model.load_state_dict(state_dict, strict=False)
                if missing_keys:
                    logger.warning(f"Missing keys in state dict: {len(missing_keys)} keys")
                if unexpected_keys:
                    logger.warning(f"Unexpected keys in state dict: {len(unexpected_keys)} keys")
                logger.info("Loaded trained model weights")
            except Exception as e:
                logger.error(f"Failed to load trained weights: {str(e)}")
                logger.warning("Using base model weights")
        else:
            logger.warning("No trained weights found, using base model")
    
    def _setup_generation_config(self):
        """Set up generation configuration properly"""
        # GPT2 doesn't support beam search in VisionEncoderDecoder, so use greedy/sampling
        generation_config = GenerationConfig(
            decoder_start_token_id=self.tokenizer.bos_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            max_length=64,
            no_repeat_ngram_size=3,
            num_beams=1,  # Must be 1 for GPT2 decoder
            do_sample=True,  # Use sampling since beam search not supported
            temperature=0.7,
            top_p=0.9,
        )
        
        # Save the generation config to the model
        self.model.generation_config = generation_config
    
    def caption_image(self, image_input, max_new_tokens=64, temperature=0.7, 
                     top_p=0.9, do_sample=True):
        """
        Generate a caption for a single image
        
        Args:
            image_input (str): Path to image file or URL
            max_new_tokens (int): Maximum number of new tokens to generate
            temperature (float): Sampling temperature
            top_p (float): Nucleus sampling parameter
            do_sample (bool): Whether to use sampling (True) or greedy (False)
            
        Returns:
            str: Generated caption
        """
        try:
            # Validate input
            if not image_input or not image_input.strip():
                raise ValueError("Empty image input")
            
            # Load image from URL or local path
            image = self._load_image(image_input)
            
            # Preprocess the image
            pixel_values = self.image_processor(
                image, 
                return_tensors="pt"
            ).pixel_values.to(self.device)
            
            # Create generation config for this inference
            # NOTE: GPT2 doesn't support beam search in VisionEncoderDecoder
            generation_config = GenerationConfig(
                max_new_tokens=max_new_tokens,
                do_sample=do_sample,
                temperature=temperature if do_sample else 1.0,
                top_p=top_p if do_sample else 1.0,
                num_beams=1,  # Must be 1 for GPT2 decoder
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
                no_repeat_ngram_size=3,
            )
            
            # Generate caption
            with torch.no_grad():
                generated_ids = self.model.generate(
                    pixel_values, 
                    generation_config=generation_config
                )
            
            # Decode the generated tokens
            caption = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
            
            # Clean up GPU memory if using CUDA
            if self.device == "cuda":
                torch.cuda.empty_cache()
            
            return caption.strip()
            
        except Exception as e:
            logger.error(f"Error processing image {image_input}: {str(e)}")
            return None
    
    def _load_image(self, image_input):
        """Load image from URL or local path"""
        if image_input.startswith(('http://', 'https://')):
            try:
                response = requests.get(image_input, timeout=10)
                response.raise_for_status()
                image = Image.open(BytesIO(response.content)).convert('RGB')
            except requests.RequestException as e:
                raise RuntimeError(f"Failed to download image from URL: {str(e)}")
        else:
            image_path = Path(image_input)
            if not image_path.exists():
                raise FileNotFoundError(f"Image file not found: {image_input}")
            try:
                image = Image.open(image_path).convert('RGB')
            except Exception as e:
                raise RuntimeError(f"Failed to open image file: {str(e)}")
        
        return image
    
    def caption_batch(self, image_inputs, **kwargs):
        """Generate captions for a batch of images"""
        results = []
        for image_input in image_inputs:
            caption = self.caption_image(image_input, **kwargs)
            results.append({
                'image': image_input,
                'caption': caption,
                'success': caption is not None
            })
        return results


def test_sample_images(captioner):
    """Test captioning on sample COCO images"""
    
    # Sample COCO image URLs
    sample_urls = [
        "http://images.cocodataset.org/val2017/000000039769.jpg",
        "http://images.cocodataset.org/val2017/000000397133.jpg", 
        "http://images.cocodataset.org/val2017/000000037777.jpg"
    ]
    
    logger.info("Testing with sample COCO images:")
    print("=" * 60)
    
    for i, url in enumerate(sample_urls, 1):
        print(f"\nImage {i}: {url}")
        print("-" * 40)
        
        # Test both greedy and sampling
        caption_greedy = captioner.caption_image(url, do_sample=False)
        caption_sample = captioner.caption_image(url, do_sample=True, temperature=0.7)
        
        if caption_greedy:
            print(f"Greedy Caption: {caption_greedy}")
        if caption_sample:
            print(f"Sampling Caption: {caption_sample}")
        
        if not caption_greedy and not caption_sample:
            print("Failed to generate caption")
        
        print("-" * 60)


def main():
    """Main function to run image captioning inference"""
    
    # Configuration - UPDATE THIS PATH
    MODEL_PATH = "../vit-gpt2-coco-finetuned-from-scratch"
    
    print("=" * 60)
    print("Image Captioning Inference Script")
    print("=" * 60)
    
    try:
        # Initialize the captioner
        captioner = ImageCaptioner(MODEL_PATH)
        
        # Test with sample images
        test_sample_images(captioner)
        
        # Interactive mode
        print("\nðŸŽ¯ Interactive Mode")
        print("Enter image paths or URLs to caption (or 'quit' to exit)")
        print("\nYou can test with:")
        print("â€¢ Local image paths: /path/to/image.jpg")
        print("â€¢ URLs: http://images.cocodataset.org/val2017/000000039769.jpg")
        print("\nGeneration options:")
        print("â€¢ Add 'greedy' for deterministic mode: image.jpg greedy")
        print("â€¢ Default: sampling mode (more creative)")
        
        while True:
            user_input = input("\nEnter image path or URL: ").strip()
            if user_input.lower() in ['quit', 'exit', 'q']:
                break
            
            if user_input:
                # Parse input for generation mode
                parts = user_input.split()
                image_path = parts[0]
                use_greedy = len(parts) > 1 and 'greedy' in parts[1].lower()
                
                caption = captioner.caption_image(
                    image_path, 
                    do_sample=not use_greedy,
                    temperature=0.8 if not use_greedy else 1.0,
                    top_p=0.9 if not use_greedy else 1.0
                )
                
                if caption:
                    mode = "Greedy" if use_greedy else "Sampling"
                    print(f"Caption ({mode}): {caption}")
                else:
                    print("Failed to generate caption")
            else:
                print("Please enter a valid path or URL")
    
    except Exception as e:
        logger.error(f"Failed to initialize captioner: {str(e)}")
        return
    
    print("\nThanks for using the Image Captioning Script!")


if __name__ == "__main__":
    main()

  from .autonotebook import tqdm as notebook_tqdm
INFO:__main__:Loading model from ..\vit-gpt2-coco-finetuned-from-scratch
INFO:__main__:Using device: cuda
INFO:__main__:Loading model with manual configuration...


Image Captioning Inference Script


Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at gpt2 and are newly initialized: ['transformer.h.0.crossattention.c_attn.bias', 'transformer.h.0.crossattention.c_attn.weight', 'transformer.h.0.crossattention.c_proj.bias', 'transformer.h.0.crossattention.c_proj.weight', 'transformer.h.0.crossattention.q_attn.bias', 'transformer.h.0.crossattention.q_attn.weight', 'transformer.h.0.ln_cross_attn.bias', 'transformer.h.0.ln_cross_attn.weight', 'transformer.h.1.crossattention.c_attn.bias', 'transformer.h.1.crossattention.c_attn.weight', 'transformer.h.1.crossattention.c_proj.bias', 'transformer.h.1.crossattention.c_proj.weight', 'transformer.h.1.crossattention.q_attn.bias', 'transformer.h.1.crossattention.q_attn.weight', 'transformer.h.1.ln_cross_attn.bias', 'transformer.h.1.ln_cross_attn.weight', 'transformer.h.10.crossattention.c_attn.bias', 'transformer.h.10.crossattention.c_attn.weight', 'transformer.h.10.crossattention.c_proj.bias', 'transformer.h.10.cros


Image 1: http://images.cocodataset.org/val2017/000000039769.jpg
----------------------------------------


`generation_config` default values have been modified to match model-specific defaults: {'max_length': 64, 'do_sample': True, 'temperature': 0.7, 'top_p': 0.9, 'decoder_start_token_id': 50256}. If this is not desired, please set these values explicitly.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.53.0. You should pass an instance of `Cache` instead, e.g. `past_key_values=DynamicCache.from_legacy_cache(past_key_values)`.
`generation_config` default values have been modified to match model-specific defaults: {'max_length': 64, 'decoder_start_token_id': 50256}. If this is not desired, please set these values explicitly.


Greedy Caption: cat on bed to a on laptop a and t bear
Sampling Caption: cat on bed front a and mouse
------------------------------------------------------------

Image 2: http://images.cocodataset.org/val2017/000000397133.jpg
----------------------------------------
Greedy Caption: woman on kitchen in with and oven
Sampling Caption: woman at table several in kitchen and.
------------------------------------------------------------

Image 3: http://images.cocodataset.org/val2017/000000037777.jpg
----------------------------------------
Greedy Caption: kitchen a and stove a and oven
Sampling Caption: kitchen a and a with cabinets a and stove
------------------------------------------------------------

ðŸŽ¯ Interactive Mode
Enter image paths or URLs to caption (or 'quit' to exit)

You can test with:
â€¢ Local image paths: /path/to/image.jpg
â€¢ URLs: http://images.cocodataset.org/val2017/000000039769.jpg

Generation options:
â€¢ Add 'greedy' for deterministic mode: image.jpg greedy
â€¢