In [None]:
import torch
import torchvision.transforms as transforms
from PIL import Image

# Function to preprocess the image
def preprocess_image(image_path, device):
    """Load an image file and turn it into a normalized, batched tensor.

    The image is converted to RGB, resized to 224x224, converted to a tensor
    and normalized per channel with the mean/std values below (the commonly
    used ImageNet statistics).

    Args:
        image_path: Path of the image file to open with PIL.
        device: torch device (or device string) the tensor is moved to.

    Returns:
        A float tensor of shape (1, 3, 224, 224) on ``device``.
    """
    transform = transforms.Compose([
        # NOTE(review): assumes the encoder expects 224x224 input -- confirm.
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # Open inside a context manager so the file handle is released
    # deterministically; convert() returns an independent in-memory image.
    with Image.open(image_path) as img:
        rgb = img.convert("RGB")
    # Add the batch dimension and move to the requested device.
    return transform(rgb).unsqueeze(0).to(device)

# Function to generate a caption
def generate_caption(image_path, encoder, decoder, vocab, device, max_length=20, *, debug=False):
    """Greedily decode a caption for a single image.

    The encoder turns the preprocessed image into ``features``, which seed the
    decoder LSTM's hidden state (the cell state starts at zero). Starting from
    the ``<SOS>`` token, the highest-scoring next word is chosen at each step
    and fed back in, until ``<EOS>`` is produced or ``max_length`` steps run.

    Args:
        image_path: Path to the input image.
        encoder: Image encoder module, called as ``encoder(image)``.
        decoder: Decoder module exposing ``embed``, ``lstm`` and ``fc``
            submodules.
        vocab: Vocabulary with ``stoi`` (token -> index) and ``itos``
            (index -> token) lookups that contain ``<SOS>`` and ``<EOS>``.
        device: torch device to run inference on.
        max_length: Maximum number of decoding steps.
        debug: If True, print the logits and predicted index at every step.

    Returns:
        The generated caption as a space-separated string, without ``<EOS>``.
    """
    # Both modules are switched to eval mode and moved to ``device``; this
    # mutates the modules passed in by the caller.
    encoder.eval()
    decoder.eval()
    encoder = encoder.to(device)
    decoder = decoder.to(device)

    image = preprocess_image(image_path, device)

    caption = []
    with torch.no_grad():
        features = encoder(image)

        # Image features become the initial hidden state h0; c0 is zeros.
        # NOTE(review): assumes features is (1, hidden_size) and the LSTM has
        # a single layer, and that the decoder was trained with features as
        # h0 -- confirm against the decoder/training code.
        states = (features.unsqueeze(0), torch.zeros_like(features).unsqueeze(0))
        input_word = torch.tensor([[vocab.stoi["<SOS>"]]], device=device)

        for step in range(max_length):
            # Carry the LSTM state forward so each prediction is conditioned
            # on the whole prefix generated so far, not just the last word.
            output, states = decoder.lstm(decoder.embed(input_word), states)
            logits = decoder.fc(output.squeeze(1))
            predicted = logits.argmax(1)  # index of the best-scoring word
            index = predicted.item()

            if debug:
                print(f"Output logits at step {step}: {logits}")
                print(f"Predicted word index: {index}")

            word = vocab.itos[index]
            if word == "<EOS>":
                break
            caption.append(word)

            # Feed the prediction back as the next input, shape (1, 1).
            input_word = predicted.unsqueeze(0)

    return " ".join(caption)

# Example usage: caption one image on the best available device.
# NOTE: `encoder`, `decoder` and `vocab` are not defined in this cell;
# presumably they are created by earlier notebook cells.
image_path = "test/test.jpg"  # Replace with the path to your image
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
caption = generate_caption(image_path, encoder, decoder, vocab, device)
print("Generated Caption:", caption)