# Aircraft Damage Classification and Multimodal Captioning with Cross-Attention

In [None]:
!pip install torch torchvision transformers tensorflow matplotlib scikit-learn -q

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.models as models
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
from transformers import BlipProcessor, BlipForConditionalGeneration
from PIL import Image
import tensorflow as tf
from tensorflow.keras import layers, models as keras_models
import os
import numpy as np

## Step 2: Load and Preprocess Dataset

In [None]:
# Preprocessing applied to every image: resize to the 224x224 input size used
# by VGG16, convert to a float tensor in [0, 1], then normalize each channel
# with the ImageNet mean/std that torchvision's pretrained models expect.
data_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

data_dir = './sample_data'  # Replace with your dataset path
# ImageFolder derives the class list (dataset.classes) from the sub-directory
# names found under data_dir.
dataset = datasets.ImageFolder(root=data_dir, transform=data_transform)
# Shuffled mini-batches of 32 images for PyTorch-side processing.
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

## Step 3: Feature Extraction Using VGG16

In [None]:
# Load VGG16 with its ImageNet weights. The `weights` enum replaces the
# deprecated `pretrained=True` flag and selects the same checkpoint.
vgg16 = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)

# Freeze every parameter: the network is used purely as a fixed feature
# extractor, so no gradients are needed.
for param in vgg16.parameters():
    param.requires_grad = False

# Inference mode, so any train-only behaviour (e.g. dropout) is disabled.
vgg16.eval()

# Keep every top-level child module except the last one (the classifier head
# in torchvision's VGG), leaving the convolutional/pooling stages.
feature_extractor = nn.Sequential(*list(vgg16.children())[:-1])
feature_extractor.eval()

## Step 4: Build and Compile Keras VGG16 Model

In [None]:
from tensorflow.keras.applications import VGG16

# Convolutional base: ImageNet-pretrained VGG16 without its dense top,
# taking 224x224 RGB inputs.
base_model = VGG16(include_top=False, weights='imagenet', input_shape=(224, 224, 3))

# Freeze the base so only the new classification head is trained.
for layer in base_model.layers:
    layer.trainable = False

# One output unit per class folder discovered by the ImageFolder dataset.
n_classes = len(dataset.classes)

# Classification head stacked on the frozen base:
# flatten -> 256-unit ReLU -> dropout(0.5) -> softmax over the classes.
keras_model = keras_models.Sequential()
keras_model.add(base_model)
keras_model.add(layers.Flatten())
keras_model.add(layers.Dense(256, activation='relu'))
keras_model.add(layers.Dropout(0.5))
keras_model.add(layers.Dense(n_classes, activation='softmax'))

# categorical_crossentropy expects one-hot encoded labels.
keras_model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)

## Step 5: Train the Keras VGG16 Classifier (Random Placeholder Data)

In [None]:
# Placeholder training set: 100 random 224x224 RGB arrays with uniformly random
# class ids. This exercises the training loop only; swap in real images and
# labels for meaningful results.
n_classes = len(dataset.classes)
train_images = np.random.rand(100, 224, 224, 3)
random_class_ids = np.random.randint(0, n_classes, size=100)
# One-hot encode the ids to match the categorical_crossentropy loss.
train_labels = tf.keras.utils.to_categorical(random_class_ids, num_classes=n_classes)

# Train for 5 epochs, holding out 20% of the samples for validation.
history = keras_model.fit(train_images, train_labels, validation_split=0.2, epochs=5)

## Step 6: Plot Accuracy Curves

In [None]:
# Draw the per-epoch accuracy recorded by fit() for the training and
# validation splits on a single set of axes.
for metric_key, curve_label in (
    ('accuracy', 'Training Accuracy'),
    ('val_accuracy', 'Validation Accuracy'),
):
    plt.plot(history.history[metric_key], label=curve_label)

plt.title('Training and Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.show()

## Step 7: Predict and Visualize Results

In [None]:
# Placeholder test set: 9 random images with random "true" class ids.
test_images = np.random.rand(9, 224, 224, 3)
test_labels = np.random.randint(0, len(dataset.classes), size=9)

# The class with the highest softmax probability is the prediction.
predictions = keras_model.predict(test_images)
predicted_classes = predictions.argmax(axis=1)

# 3x3 grid: one panel per test image, titled with predicted and true class id.
fig, axes = plt.subplots(nrows=3, ncols=3, figsize=(10, 10))
axes = axes.flatten()
for idx, ax in enumerate(axes):
    ax.imshow(test_images[idx])
    ax.set_title(f'Pred: {predicted_classes[idx]}\nTrue: {test_labels[idx]}')
    ax.axis('off')
plt.tight_layout()
plt.show()

## Step 8: Implement a Cross-Attention Layer in Keras

In [None]:
class CrossAttention(layers.Layer):
    """Cross-attention layer: a query sequence attends over a key/value sequence.

    Thin wrapper around ``layers.MultiHeadAttention``.

    Args:
        embed_dim: Forwarded to ``MultiHeadAttention`` as ``key_dim``, i.e. the
            size of each attention head for query and key.
        num_heads: Number of attention heads.
        **kwargs: Standard ``Layer`` keyword arguments (``name``, ``dtype``,
            ``trainable``, ...), forwarded to the base class.
    """

    def __init__(self, embed_dim, num_heads=8, **kwargs):
        # Forward base-layer kwargs so the layer can be named and so
        # from_config(), which passes them back, can rebuild it.
        super().__init__(**kwargs)
        # Keep the constructor arguments so get_config() can report them.
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.multi_head_attention = layers.MultiHeadAttention(
            num_heads=num_heads,
            key_dim=embed_dim,
        )

    def call(self, query, key, value):
        """Return the attention output for ``query`` attending to ``key``/``value``."""
        return self.multi_head_attention(query=query, key=key, value=value)

    def get_config(self):
        """Return the constructor arguments needed to re-create this layer."""
        config = super().get_config()
        config.update({
            "embed_dim": self.embed_dim,
            "num_heads": self.num_heads,
        })
        return config

# Example usage of Cross-Attention
# Random stand-ins for real embeddings: a batch of 2 with 10 image tokens and
# 5 text tokens, each a 512-dimensional vector.
image_features = tf.random.normal(shape=(2, 10, 512))
text_features = tf.random.normal(shape=(2, 5, 512))

cross_attention_layer = CrossAttention(512, num_heads=8)

# Image tokens are the queries; text tokens supply both keys and values.
attention_inputs = dict(query=image_features, key=text_features, value=text_features)
output = cross_attention_layer(**attention_inputs)
print(output.shape)

## Step 9: Generate Captions Using the Pretrained BLIP Model

In [None]:
# Processor and model must come from the same checkpoint.
BLIP_CHECKPOINT = "Salesforce/blip-image-captioning-base"
processor = BlipProcessor.from_pretrained(BLIP_CHECKPOINT)
blip_model = BlipForConditionalGeneration.from_pretrained(BLIP_CHECKPOINT)

# Open the file in a context manager so its handle is closed promptly;
# convert() returns a fully loaded RGB copy that remains usable afterwards.
with Image.open('path_to_your_image.jpg') as image_file:
    raw_image = image_file.convert('RGB')

# Preprocess the image into PyTorch tensors, generate caption token ids, and
# decode the first (only) sequence back to text without special tokens.
inputs = processor(raw_image, return_tensors="pt")
out = blip_model.generate(**inputs)
caption = processor.decode(out[0], skip_special_tokens=True)

print(f"Generated Caption: {caption}")

## Step 10: Show BLIP Model Architecture

In [None]:
# Print the loaded BLIP model's string representation to inspect its architecture.
print(blip_model)