# Vision Transformers for Marine Computer Vision

## Overview

This notebook introduces Vision Transformers (ViTs) and their applications in marine computer vision. While Convolutional Neural Networks (CNNs) have been the dominant architecture for computer vision tasks, Transformers—originally designed for natural language processing—have recently shown remarkable performance in vision tasks. This lesson explores how ViTs can be applied to marine science challenges, offering new capabilities for analyzing underwater imagery.

### Learning Objectives

By the end of this lesson, you will:

* Understand the fundamental architecture of Vision Transformers
* Compare ViTs with traditional CNNs for marine image analysis
* Implement a basic ViT model for marine species classification
* Fine-tune pre-trained ViT models on marine datasets
* Evaluate the performance of ViTs in challenging underwater conditions

---

## Introduction to Vision Transformers

Vision Transformers (ViTs) represent a paradigm shift in computer vision. Introduced by Dosovitskiy et al. in their 2020 paper "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale," ViTs adapt the Transformer architecture—originally designed for natural language processing—to visual tasks.

Unlike CNNs, which build up features hierarchically through local convolutions, ViTs divide an image into fixed-size patches, linearly embed each patch, and process the resulting sequence with a standard Transformer encoder. Because self-attention connects every patch to every other patch in every layer, ViTs can model global dependencies from the first layer onward, with fewer built-in inductive biases (such as locality and translation equivariance) than CNNs.
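
To make the patch idea concrete, here is a minimal PyTorch sketch that cuts a batch of random 224×224 RGB images into 16×16 patches, flattens each one, and applies a linear projection. The random input and `embed_dim = 768` are illustrative choices matching ViT-Base; note that for RGB input each flattened 16×16 patch has 3 × 16 × 16 = 768 values.

```python
import torch

# A dummy batch of RGB images: (batch, channels, height, width)
images = torch.randn(2, 3, 224, 224)
patch = 16

# Cut each image into non-overlapping patch x patch tiles.
# unfold(dim, size, step) slides a window along one spatial dimension.
tiles = images.unfold(2, patch, patch).unfold(3, patch, patch)
# tiles: (batch, channels, 14, 14, patch, patch)

# Put the channel axis next to the pixel axes, then flatten each tile into one vector
patches = tiles.permute(0, 2, 3, 1, 4, 5).reshape(images.shape[0], -1, 3 * patch * patch)
# patches: (batch, 196, 768), i.e. 196 "words" of 768 values each

# A learnable linear projection maps each flattened patch to an embedding
embed_dim = 768
proj = torch.nn.Linear(3 * patch * patch, embed_dim)
tokens = proj(patches)  # (batch, 196, embed_dim)
print(patches.shape, tokens.shape)
```

The first patch is simply the top-left 16×16 corner of the image, flattened channel by channel.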

### Key Advantages of ViTs for Marine Applications

1. **Global Context**: ViTs can capture long-range dependencies across the entire image, which is valuable for understanding complex marine scenes with varying scales and relationships.

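2. **Handling Varying Object Scales**: Because every patch can attend to every other patch from the first layer onward, ViTs can relate distant parts of a large object as easily as neighboring ones, which can help with the widely varying subject distances in underwater imagery. They are not inherently scale-invariant, though; robustness to scale still depends largely on training data and augmentation.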

3. **Transfer Learning Capabilities**: Pre-trained ViTs can be effectively fine-tuned on smaller marine datasets, leveraging knowledge from large-scale pre-training.

4. **Robustness to Occlusion**: ViTs have been reported to degrade more gracefully than CNNs when parts of an object are hidden, which is relevant when other organisms, substrate, or suspended particles in turbid water partly obscure a subject.

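5. **Adaptability to Different Lighting Conditions**: Attention weights are computed from each input, so a ViT can, in principle, shift which regions it relies on as illumination changes. In practice, robustness to underwater light attenuation and color casts still depends heavily on training data and augmentation.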

## Vision Transformer Architecture

Let's break down the architecture of a Vision Transformer:

1. **Patch Embedding**: The input image is divided into fixed-size patches (e.g., 16×16 pixels). Each patch is flattened and linearly projected to create a patch embedding.

2. **Position Embedding**: Self-attention treats its input as an unordered set, so on its own it has no notion of where a patch sits in the image. Learnable positional embeddings are therefore added to the patch embeddings to encode each patch's location.

3. **Class Token**: A special learnable embedding called the class token is prepended to the sequence of embedded patches. The final state of this token serves as the image representation for classification tasks.

4. **Transformer Encoder**: The sequence of embedded patches plus the class token is processed by a standard Transformer encoder, which consists of alternating layers of multiheaded self-attention (MSA) and MLP blocks, with layer normalization (LN) applied before each block and residual connections around each block.

5. **MLP Head**: The final representation of the class token is passed through an MLP head to produce class predictions.
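
As a small check on why step 2 is needed, the sketch below shows that self-attention without positional information is permutation-equivariant: shuffling the input patches only shuffles the outputs in the same way. It uses PyTorch's built-in `nn.MultiheadAttention` rather than the attention class implemented later in this notebook, and the sizes are arbitrary.

```python
import torch

torch.manual_seed(0)
embed_dim, n_heads, n_tokens = 64, 4, 10

# Standard multi-head self-attention; no positional embeddings involved
attn = torch.nn.MultiheadAttention(embed_dim, n_heads, batch_first=True)
attn.eval()

x = torch.randn(1, n_tokens, embed_dim)   # a sequence of patch embeddings
perm = torch.randperm(n_tokens)            # shuffle the patch order
xs = x[:, perm]

with torch.no_grad():
    out, _ = attn(x, x, x)
    out_perm, _ = attn(xs, xs, xs)

# Shuffled inputs give identically shuffled outputs: the layer cannot tell where a patch was
print(torch.allclose(out[:, perm], out_perm, atol=1e-5))  # True

# Adding a position embedding (here random) breaks the symmetry
pos = torch.randn(1, n_tokens, embed_dim)
x_pos = x + pos                            # patches at their original positions
xs_pos = xs + pos                          # the same patches, assigned to new positions
with torch.no_grad():
    out_pos, _ = attn(x_pos, x_pos, x_pos)
    out_pos_perm, _ = attn(xs_pos, xs_pos, xs_pos)
print(torch.allclose(out_pos[:, perm], out_pos_perm, atol=1e-5))  # False
```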

In [None]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from PIL import Image

# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## Implementing a Simple Vision Transformer

Let's implement a simplified version of a Vision Transformer to understand its core components. This implementation is for educational purposes and is not optimized for production use.

In [None]:
class PatchEmbedding(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2
        
        self.proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size=patch_size, stride=patch_size
        )
        
    def forward(self, x):
        # x: (batch_size, in_channels, img_size, img_size)
        x = self.proj(x)  # (batch_size, embed_dim, n_patches^0.5, n_patches^0.5)
        x = x.flatten(2)  # (batch_size, embed_dim, n_patches)
        x = x.transpose(1, 2)  # (batch_size, n_patches, embed_dim)
        return x

class Attention(nn.Module):
    def __init__(self, dim, n_heads=12, qkv_bias=True, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.n_heads = n_heads
        self.scale = (dim // n_heads) ** -0.5
        
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        
    def forward(self, x):
        # x: (batch_size, n_patches+1, dim)
        batch_size, n_tokens, dim = x.shape
        
        qkv = self.qkv(x).reshape(batch_size, n_tokens, 3, self.n_heads, dim // self.n_heads)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, batch_size, n_heads, n_tokens, dim_per_head)
        q, k, v = qkv[0], qkv[1], qkv[2]
        
        attn = (q @ k.transpose(-2, -1)) * self.scale  # (batch_size, n_heads, n_tokens, n_tokens)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        
        x = (attn @ v).transpose(1, 2).reshape(batch_size, n_tokens, dim)
        x = self.proj(x)
        x = self.proj_drop(x)
        
        return x

class MLP(nn.Module):
    def __init__(self, in_features, hidden_features, out_features, drop=0.):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)
        
    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

class Block(nn.Module):
    def __init__(self, dim, n_heads, mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, n_heads=n_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, int(dim * mlp_ratio), dim, drop)
        
    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

class VisionTransformer(nn.Module):
    def __init__(
        self, img_size=224, patch_size=16, in_channels=3, n_classes=1000, embed_dim=768,
        depth=12, n_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0.
    ):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, 1 + self.patch_embed.n_patches, embed_dim))
        self.pos_drop = nn.Dropout(drop_rate)
        
        self.blocks = nn.Sequential(*[
            Block(embed_dim, n_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate)
            for _ in range(depth)
        ])
        
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, n_classes)
        
        # Initialize weights
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)
        
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
            
    def forward(self, x):
        # x: (batch_size, in_channels, img_size, img_size)
        batch_size = x.shape[0]
        x = self.patch_embed(x)  # (batch_size, n_patches, embed_dim)
        
        cls_token = self.cls_token.expand(batch_size, -1, -1)  # (batch_size, 1, embed_dim)
        x = torch.cat((cls_token, x), dim=1)  # (batch_size, 1+n_patches, embed_dim)
        x = x + self.pos_embed  # Add positional embedding
        x = self.pos_drop(x)
        
        x = self.blocks(x)
        x = self.norm(x)
        
        # Use the class token for classification
        x = x[:, 0]  # (batch_size, embed_dim)
        x = self.head(x)  # (batch_size, n_classes)
        
        return x
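
The `PatchEmbedding` class above uses a `Conv2d` whose kernel size and stride both equal the patch size, while the architecture description talks about flattening patches and applying a linear projection. The sketch below, with arbitrary small sizes, checks numerically that the two formulations compute the same thing once the convolution kernel is reshaped into a linear weight matrix.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
patch, in_ch, embed_dim = 16, 3, 32

conv = nn.Conv2d(in_ch, embed_dim, kernel_size=patch, stride=patch)
x = torch.randn(2, in_ch, 64, 64)  # 64 / 16 = 4, so a 4 x 4 grid of patches

# Path 1: the Conv2d approach used in PatchEmbedding
conv_tokens = conv(x).flatten(2).transpose(1, 2)  # (2, 16, embed_dim)

# Path 2: explicitly flatten each patch, then apply a linear layer
tiles = x.unfold(2, patch, patch).unfold(3, patch, patch)  # (2, C, 4, 4, p, p)
patches = tiles.permute(0, 2, 3, 1, 4, 5).reshape(2, -1, in_ch * patch * patch)
linear = nn.Linear(in_ch * patch * patch, embed_dim)
with torch.no_grad():
    # Reuse the conv kernel: (embed_dim, C, p, p) -> (embed_dim, C * p * p)
    linear.weight.copy_(conv.weight.reshape(embed_dim, -1))
    linear.bias.copy_(conv.bias)
linear_tokens = linear(patches)

print(torch.allclose(conv_tokens, linear_tokens, atol=1e-5))  # True
```

The convolution form is simply a convenient, efficient way to apply the same linear projection to every non-overlapping patch.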

## Using Pre-trained Vision Transformers

While implementing a ViT from scratch is educational, for practical applications, it's more efficient to use pre-trained models. Let's see how to use a pre-trained ViT from the Hugging Face Transformers library and fine-tune it on a marine dataset.

In [None]:
# Install required packages if not already installed
!pip install transformers datasets

from transformers import ViTForImageClassification, ViTImageProcessor
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from torch.optim import AdamW
from tqdm.notebook import tqdm

# Load a pre-trained ViT model (ViT-Base, 16x16 patches, 224x224 input; pre-trained on ImageNet-21k, fine-tuned on ImageNet-1k)
model_name = "google/vit-base-patch16-224"
processor = ViTImageProcessor.from_pretrained(model_name)
model = ViTForImageClassification.from_pretrained(model_name)

# Move model to the appropriate device
model = model.to(device)

## Preparing a Marine Dataset

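A labeled dataset of marine species images would go here. To keep the example runnable, we use a small subset of the Oxford-IIIT Pet dataset as a stand-in. The images must be preprocessed to match the input the pre-trained ViT expects (resized to 224×224 pixels and normalized).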

In [None]:
# Define a function to preprocess images for the ViT model
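def preprocess_images(examples):
    # Convert to 3-channel RGB first: some images may be grayscale or RGBA
    images = [img.convert("RGB") for img in examples["image"]]
    processed_images = processor(images, return_tensors="pt")["pixel_values"]
    return {"pixel_values": processed_images, "labels": examples["label"]}

# Load a sample dataset (replace with your marine dataset)
# For demonstration, we use the first 100 training images of the Oxford-IIIT Pet dataset.
# Adjust the dataset ID and the "image"/"label" column names to match the dataset you load.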
dataset = load_dataset("oxford-iiit-pet", split="train[:100]")

# Preprocess the dataset
processed_dataset = dataset.map(preprocess_images, batched=True)
processed_dataset.set_format(type="torch", columns=["pixel_values", "labels"])

# Create a DataLoader
dataloader = DataLoader(processed_dataset, batch_size=8, shuffle=True)

# Display a sample image
plt.figure(figsize=(10, 10))
plt.imshow(dataset[0]["image"])
plt.title(f"Label: {dataset[0]['label']}")
plt.axis("off")
plt.show()

## Fine-tuning the ViT Model

Now, let's fine-tune the pre-trained ViT on our dataset. The checkpoint ships with a 1000-class ImageNet classification head, so we replace it with a freshly initialized head that matches the number of classes in our dataset.
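
Marine datasets often store labels as species names or survey codes rather than contiguous integers, while a classification head expects ids 0 to K−1. A minimal sketch of building that mapping, using hypothetical coral genus labels; building it from the union of training and test labels keeps the ids consistent across splits.

```python
def build_label_maps(raw_labels):
    """Map arbitrary label values (e.g. species names) to contiguous ids 0..K-1."""
    classes = sorted(set(raw_labels))
    label2id = {name: i for i, name in enumerate(classes)}
    id2label = {i: name for name, i in label2id.items()}
    return label2id, id2label

# Hypothetical labels from a reef survey
train_labels = ["Porites", "Acropora", "Porites", "Montipora", "Pocillopora"]
label2id, id2label = build_label_maps(train_labels)
encoded = [label2id[name] for name in train_labels]
print(label2id)  # {'Acropora': 0, 'Montipora': 1, 'Pocillopora': 2, 'Porites': 3}
print(encoded)   # [3, 0, 3, 1, 2]
```

The resulting dictionaries can be passed to `from_pretrained` as `id2label` and `label2id` (together with `num_labels=len(label2id)`) so the saved model config carries human-readable class names.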

In [None]:
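# Replace the 1000-class ImageNet head with a freshly initialized head sized for our labels.
# Take the class count from the dataset's label feature rather than from this 100-image subset,
# which may not contain every class (test labels would then fall outside the head's range).
label_feature = dataset.features["label"]
num_classes = getattr(label_feature, "num_classes", None) or max(dataset["label"]) + 1
model = ViTForImageClassification.from_pretrained(
    model_name,
    num_labels=num_classes,
    ignore_mismatched_sizes=True,  # allow the classifier shape to differ from the checkpoint
    attn_implementation="eager",   # lets output_attentions=True return weights later on
).to(device)                       # the new head must be on the same device as the backbone

# Define optimizer and loss function
optimizer = AdamW(model.parameters(), lr=5e-5)
criterion = torch.nn.CrossEntropyLoss()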

# Fine-tuning loop
num_epochs = 3
model.train()

for epoch in range(num_epochs):
    total_loss = 0
    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")
    
    for batch in progress_bar:
        pixel_values = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)
        
        # Forward pass: the model returns logits; compute the cross-entropy loss explicitly
        outputs = model(pixel_values=pixel_values)
        loss = criterion(outputs.logits, labels)
        
        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        progress_bar.set_postfix({"loss": loss.item()})
    
    avg_loss = total_loss / len(dataloader)
    print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}")

## Evaluating the Model

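After fine-tuning, let's evaluate the model on a held-out test split. With only 100 training images and three epochs, expect modest accuracy; the point here is the workflow, which carries over unchanged to a marine species dataset.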

In [None]:
# Load test dataset
test_dataset = load_dataset("oxford-iiit-pet", split="test[:50]")
processed_test_dataset = test_dataset.map(preprocess_images, batched=True)
processed_test_dataset.set_format(type="torch", columns=["pixel_values", "labels"])
test_dataloader = DataLoader(processed_test_dataset, batch_size=8)

# Evaluation loop
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for batch in tqdm(test_dataloader, desc="Evaluating"):
        pixel_values = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)
        
        outputs = model(pixel_values=pixel_values)
        _, predicted = torch.max(outputs.logits, 1)
        
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print(f"Test Accuracy: {accuracy:.2f}%")
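
Overall accuracy can hide poor performance on rare species, which is common in imbalanced marine surveys. A minimal sketch of a confusion matrix and per-class accuracy (recall) using only PyTorch; the tiny tensors are made up to show a case where overall accuracy is 5/7 ≈ 71% but one class is never recognized.

```python
import torch

def per_class_accuracy(preds, labels, num_classes):
    """Confusion matrix (rows = true, cols = predicted) and per-class recall."""
    # Encode each (true, predicted) pair as one index, then count occurrences
    idx = labels * num_classes + preds
    confusion = torch.bincount(idx, minlength=num_classes ** 2).reshape(num_classes, num_classes)
    support = confusion.sum(dim=1)  # number of true examples per class
    correct = confusion.diag()
    # Classes absent from the evaluation set get NaN instead of a misleading 0
    acc = torch.where(support > 0, correct.float() / support.clamp(min=1), torch.tensor(float("nan")))
    return confusion, acc

labels = torch.tensor([0, 0, 0, 0, 1, 1, 2])
preds = torch.tensor([0, 0, 0, 0, 0, 1, 0])
confusion, acc = per_class_accuracy(preds, labels, num_classes=3)
print(acc)  # tensor([1.0000, 0.5000, 0.0000])
```

In the evaluation loop above, you could collect `predicted` and `labels` from every batch, concatenate them with `torch.cat`, and pass them to `per_class_accuracy` along with `num_classes`.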

## Visualizing Attention Maps

Attention maps give a partial, qualitative view of what a Vision Transformer attends to. Let's visualize the attention from the class token to each image patch in the last layer to see which regions the model focuses on when making a prediction.

In [None]:
# Function to extract attention maps from the model
def get_attention_maps(model, image):
    # Preprocess the image
    inputs = processor(images=image, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    
    # Forward pass with output_attentions=True
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)
    
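    # Tuple with one tensor per layer, each of shape (batch, n_heads, n_tokens, n_tokens).
    # Some attention backends may not return weights; if so, load the model with attn_implementation="eager".
    attention_maps = outputs.attentions
    if attention_maps is None:
        raise RuntimeError("No attention weights returned; reload the model with attn_implementation='eager'.")
    
    return attention_maps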

# Get a sample image (converted to RGB, as in preprocessing)
sample_image = test_dataset[0]["image"].convert("RGB")

# Get attention maps
attention_maps = get_attention_maps(model, sample_image)

# Visualize the attention map from the last layer
last_layer_attention = attention_maps[-1].detach().cpu().numpy()

# Average over attention heads
avg_attention = np.mean(last_layer_attention, axis=1)[0]

# Extract attention from the CLS token to all patches
cls_attention = avg_attention[0, 1:]

# Reshape to the 2-D patch grid (14 x 14 for a 224 x 224 input with 16 x 16 patches)
grid_size = int(np.sqrt(cls_attention.shape[0]))
attention_map = cls_attention.reshape(grid_size, grid_size)

# Visualize
plt.figure(figsize=(16, 8))

plt.subplot(1, 2, 1)
plt.imshow(sample_image)
plt.title("Original Image")
plt.axis("off")

plt.subplot(1, 2, 2)
plt.imshow(attention_map, cmap="hot")
plt.title("Attention Map (CLS token to patches)")
plt.axis("off")

plt.tight_layout()
plt.show()
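
The 14×14 map above is much coarser than the photo beside it. A minimal sketch of upsampling the class-token attention to the 224×224 input size so it can be overlaid on the image; the random attention tensor stands in for `attention_maps[-1][0]`, and `cls_attention_heatmap` is a hypothetical helper name.

```python
import torch
import torch.nn.functional as F

def cls_attention_heatmap(attn, image_size=224):
    """Turn one layer's attention (heads, tokens, tokens) into an image-sized CLS heatmap."""
    cls_to_patches = attn.mean(dim=0)[0, 1:]  # average heads, take the CLS row, drop CLS->CLS
    side = int(cls_to_patches.numel() ** 0.5)  # 196 patches -> 14 x 14 grid
    grid = cls_to_patches.reshape(1, 1, side, side)
    heatmap = F.interpolate(grid, size=(image_size, image_size), mode="bilinear", align_corners=False)[0, 0]
    # Rescale to [0, 1] so it can be alpha-blended over the image
    return (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-8)

# Stand-in for attention_maps[-1][0]: 12 heads over 1 CLS token + 196 patch tokens
fake_attn = torch.rand(12, 197, 197).softmax(dim=-1)
heatmap = cls_attention_heatmap(fake_attn)
print(heatmap.shape)  # torch.Size([224, 224])
```

To overlay, show the photo resized to 224×224 (the processor's input size for this checkpoint) and then call `plt.imshow(heatmap.cpu(), cmap="hot", alpha=0.5)`. Raw last-layer attention ignores earlier layers and residual connections, so treat it as a rough guide rather than a full explanation.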

## Applications of Vision Transformers in Marine Science

Vision Transformers have shown promising results in various marine science applications. Here are some key areas where ViTs can be particularly beneficial:

### 1. Marine Species Classification and Counting

ViTs can classify marine species in underwater imagery, including under variable lighting, turbidity, and occlusion, provided the training data covers those conditions. Their ability to capture global context can help distinguish similar-looking species and count individuals in crowded scenes.

### 2. Coral Reef Monitoring

The self-attention mechanism in ViTs can help identify subtle changes in coral reef health over time, such as bleaching events, disease outbreaks, and recovery patterns, and may improve on CNN-based approaches for these tasks.

### 3. Marine Debris Detection

ViTs can be trained to detect and classify marine debris of various sizes and types, from microplastics to abandoned fishing gear, helping in pollution monitoring and cleanup efforts.

### 4. Underwater Infrastructure Inspection

For inspecting underwater structures such as pipelines, cables, and offshore platforms, ViTs can be trained to flag structural anomalies, corrosion, and biofouling, potentially reducing the need for human divers.

### 5. Plankton Analysis

ViTs can analyze plankton imagery from imaging flow cytometers and microscopes, classifying taxa and estimating their abundance, which is crucial for understanding marine food webs and ecosystem health.

## Case Study: Coral Reef Monitoring with ViT

Let's explore a hypothetical case study of using Vision Transformers for coral reef monitoring.

### Problem Statement

A marine conservation organization needs to monitor the health of coral reefs across multiple locations. Traditional methods involve divers manually surveying the reefs, which is time-consuming, expensive, and limited in coverage. The organization has collected thousands of underwater images from autonomous underwater vehicles (AUVs) but needs an efficient way to analyze them.

### Solution Approach

1. **Data Collection**: Gather a diverse dataset of coral reef images, including healthy corals, bleached corals, diseased corals, and various coral species.

2. **Data Annotation**: Annotate the images with labels for coral health status, species identification, and coverage percentage.

3. **Model Selection**: Use a pre-trained ViT model as the backbone and fine-tune it on the coral reef dataset.

4. **Training Strategy**: Implement a multi-task learning approach where the model simultaneously predicts coral health, species, and coverage.

5. **Deployment**: Deploy the model on edge devices installed on AUVs for real-time analysis during surveys.
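
The multi-task training strategy in step 4 can be sketched as task-specific heads on a shared ViT feature vector. Everything here is an illustrative assumption: the class counts, mean squared error for coverage, and equal loss weights. With a Hugging Face model, the shared features could plausibly be the final class-token embedding, e.g. `model.vit(pixel_values).last_hidden_state[:, 0]`; random features are used so the block runs on its own.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiTaskHead(nn.Module):
    """Hypothetical heads on top of a shared ViT class-token feature."""
    def __init__(self, hidden_size=768, n_health=3, n_species=10):
        super().__init__()
        self.health = nn.Linear(hidden_size, n_health)    # e.g. healthy / bleached / diseased
        self.species = nn.Linear(hidden_size, n_species)  # coral species or genus
        self.coverage = nn.Linear(hidden_size, 1)         # fraction of the frame covered by coral

    def forward(self, features):
        return {
            "health": self.health(features),
            "species": self.species(features),
            "coverage": torch.sigmoid(self.coverage(features)).squeeze(-1),  # kept in [0, 1]
        }

def multitask_loss(out, targets, weights=(1.0, 1.0, 1.0)):
    losses = (
        F.cross_entropy(out["health"], targets["health"]),
        F.cross_entropy(out["species"], targets["species"]),
        F.mse_loss(out["coverage"], targets["coverage"]),
    )
    # Weighted sum; the weights are hyperparameters to tune, not known-good values
    return sum(w * l for w, l in zip(weights, losses))

head = MultiTaskHead()
features = torch.randn(4, 768)  # stand-in for a batch of ViT class-token embeddings
targets = {
    "health": torch.tensor([0, 1, 2, 0]),
    "species": torch.tensor([3, 3, 7, 1]),
    "coverage": torch.tensor([0.6, 0.2, 0.35, 0.9]),
}
out = head(features)
loss = multitask_loss(out, targets)
loss.backward()
print(out["health"].shape, out["coverage"].shape, loss.item())
```

Sharing one backbone means a single forward pass through the ViT serves all three tasks, which matters for the edge deployment in step 5.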

### Results

In this hypothetical scenario, suppose the ViT-based model reaches 92% accuracy in coral health classification, outperforming the previous CNN-based model by 7%. Its attention maps suggest that it focuses on specific coral features, such as polyp structure and coloration patterns, when making predictions, offering some insight into its decision-making process.

Real-time analysis would also let the AUVs adaptively survey areas of interest. In this scenario, that raises the efficiency of monitoring operations by 60% and lets the conservation organization cover three times more reef area with the same resources.

## Challenges and Limitations

While Vision Transformers offer significant advantages for marine computer vision, they also come with challenges:

1. **Computational Requirements**: ViTs are computationally intensive, requiring substantial GPU resources for training and inference, which can be a limitation for deployment on resource-constrained underwater vehicles.

2. **Data Hunger**: ViTs typically require larger datasets for training from scratch compared to CNNs, though this can be mitigated through transfer learning from pre-trained models.

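3. **Resolution Limitations**: Standard ViTs are pre-trained at a fixed resolution (e.g., 224×224 pixels). Using other input sizes requires interpolating the position embeddings, and because self-attention compares every token with every other token, compute and memory grow quadratically with the number of patches. This makes high-resolution imagery, where fine-grained details matter, expensive to process.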

4. **Domain Shift**: Pre-trained ViTs are typically trained on terrestrial images, leading to potential domain shift issues when applied to underwater imagery with different optical properties.

5. **Interpretability Challenges**: While attention maps provide some interpretability, understanding the complex interactions between attention heads and layers remains challenging.
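
A quick calculation shows why resolution is costly. This sketch counts tokens (patches plus the class token) and the entries in one head's attention matrix for square inputs, assuming 16×16 patches as in ViT-Base/16.

```python
def vit_sequence_cost(height, width, patch=16):
    """Tokens and per-head attention-matrix entries for one ViT layer."""
    n_patches = (height // patch) * (width // patch)
    n_tokens = n_patches + 1         # +1 for the class token
    return n_tokens, n_tokens ** 2   # every token attends to every token

for side in (224, 448, 1024):
    tokens, entries = vit_sequence_cost(side, side)
    print(f"{side}x{side}: {tokens} tokens, {entries:,} attention entries per head per layer")

# 224x224: 197 tokens, 38,809 attention entries per head per layer
# 448x448: 785 tokens, 616,225 attention entries per head per layer
# 1024x1024: 4097 tokens, 16,785,409 attention entries per head per layer
```

Doubling the side length roughly quadruples the token count and multiplies the attention-matrix size by about 16.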

## Future Directions

The field of Vision Transformers for marine applications is rapidly evolving. Here are some promising future directions:

1. **Efficient ViT Architectures**: Development of more computationally efficient ViT variants like MobileViT and EfficientFormer that can run on edge devices deployed in marine environments.

2. **Hybrid CNN-Transformer Models**: Combining the local processing capabilities of CNNs with the global context modeling of Transformers to create hybrid architectures optimized for underwater imagery.

3. **Self-Supervised Learning**: Leveraging unlabeled marine imagery through self-supervised pre-training approaches like masked image modeling to reduce the dependency on large labeled datasets.

4. **Multi-Modal Transformers**: Integrating multiple data modalities (visual, acoustic, environmental) through transformer-based architectures for more comprehensive marine environment understanding.

5. **Adaptive Resolution Processing**: Developing ViT variants that can process images at variable resolutions or focus computational resources on regions of interest within high-resolution images.
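
One concrete form of masked image modeling is the random patch masking used by masked autoencoders (MAE): most patch tokens are hidden, and the model learns to reconstruct them from the visible rest, so no labels are needed. A minimal sketch of just the masking step, assuming a 75% mask ratio and ViT-Base-sized tokens; the encoder and reconstruction decoder are omitted.

```python
import torch

def random_masking(tokens, mask_ratio=0.75):
    """MAE-style masking: keep a random subset of patch tokens per image.

    tokens: (batch, n_patches, dim) patch embeddings, without the class token.
    Returns kept tokens, a mask in original patch order (1 = masked), and restore indices.
    """
    batch, n, dim = tokens.shape
    n_keep = int(n * (1 - mask_ratio))
    noise = torch.rand(batch, n, device=tokens.device)  # random score per patch
    ids_shuffle = noise.argsort(dim=1)                   # lowest scores are kept
    ids_restore = ids_shuffle.argsort(dim=1)             # inverse permutation
    ids_keep = ids_shuffle[:, :n_keep]
    kept = torch.gather(tokens, 1, ids_keep.unsqueeze(-1).expand(-1, -1, dim))
    mask = torch.ones(batch, n, device=tokens.device)
    mask[:, :n_keep] = 0
    mask = torch.gather(mask, 1, ids_restore)            # back to original patch order
    return kept, mask, ids_restore

tokens = torch.randn(2, 196, 768)
kept, mask, ids_restore = random_masking(tokens)
print(kept.shape, mask.sum(dim=1))  # torch.Size([2, 49, 768]) tensor([147., 147.])
```

Only the 49 visible tokens pass through the encoder, which is part of why this style of pre-training on large unlabeled collections of underwater imagery is relatively cheap.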

## Conclusion

Vision Transformers represent a powerful new tool in the marine computer vision toolkit. Their ability to capture global context, adapt to varying conditions, and transfer knowledge from pre-training makes them particularly well-suited for the challenges of underwater imagery analysis.

As the technology continues to evolve and computational efficiency improves, we can expect ViTs to play an increasingly important role in marine science applications, from species identification and ecosystem monitoring to underwater infrastructure inspection and marine debris detection.

By understanding the fundamentals of ViT architecture and implementation, marine scientists can leverage these advanced models to gain new insights from visual data and address pressing challenges in ocean conservation and research.

## References

1. Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., ... & Houlsby, N. (2020). An image is worth 16x16 words: Transformers for image recognition at scale. arXiv preprint arXiv:2010.11929.

2. Touvron, H., Cord, M., Douze, M., Massa, F., Sablayrolles, A., & Jégou, H. (2021). Training data-efficient image transformers & distillation through attention. In International Conference on Machine Learning (pp. 10347-10357). PMLR.

3. Liu, Z., Lin, Y., Cao, Y., Hu, H., Wei, Y., Zhang, Z., ... & Guo, B. (2021). Swin transformer: Hierarchical vision transformer using shifted windows. In Proceedings of the IEEE/CVF International Conference on Computer Vision (pp. 10012-10022).

4. Caron, M., Touvron, H., Misra, I., Jégou, H., Mairal, J., Bojanowski, P., & Joulin, A. (2021). Emerging properties in self-supervised vision transformers. In Proceedings of the IEEE/CVF International Conference on Computer Vision (pp. 9650-9660).

5. Kirillov, A., Mintun, E., Ravi, N., Mao, H., Rolland, C., Gustafson, L., ... & Girshick, R. (2023). Segment anything. arXiv preprint arXiv:2304.02643.

6. Kiefer, B., Žust, L., Muhovič, J., Kristan, M., Perš, J., Teršek, M., ... & Lin, T. Y. (2025). 3rd Workshop on Maritime Computer Vision (MaCVi) 2025: Challenge Results. arXiv preprint arXiv:2501.10343v1.