# IEEE AI Workshop: Self-Supervised Learning and JEPA in Autonomous Driving


## TCP Overview:


![teaser](assets/teaser_.png)

> Trajectory-guided Control Prediction for End-to-end Autonomous Driving: A Simple yet Strong Baseline  
> [Penghao Wu*](https://scholar.google.com/citations?user=9mssd5EAAAAJ&hl=en), [Xiaosong Jia*](https://jiaxiaosong1002.github.io/), [Li Chen*](https://scholar.google.com/citations?user=ulZxvY0AAAAJ&hl=en), [Junchi Yan](https://thinklab.sjtu.edu.cn/), [Hongyang Li](https://lihongyang.info/), [Yu Qiao](http://mmlab.siat.ac.cn/yuqiao/)    
>  - [arXiv Paper](https://arxiv.org/abs/2206.08129), NeurIPS 2022
>  - [Blog in Chinese](https://zhuanlan.zhihu.com/p/532665469)



## Environment Setup
Before diving into the practical exercises, ensure all participants have their environments set up correctly.


For Training:
```bash
conda env create -f conda_env/tcp_trainer.yml --name TCPTrainer
```

For Evaluation:
```bash
conda env create -f conda_env/tcp_runner.yml --name TCPEval
```


```python
# Create the environments
!conda env create -f conda_env/tcp_trainer.yml --name TCPTrainer
!conda env create -f conda_env/tcp_runner.yml --name TCPEval

# Install the package that provides the JEPA encoder
!pip install vjepa-encoder

# Clone the repository used by the examples below
!git clone https://huggingface.co/jonathanzkoch/vjepa-self-driving vjepa/
```

# V-JEPA: A Vision Transformer pretrained on large-scale video data that achieves state-of-the-art results when adapted to downstream tasks

<img src="https://github.com/facebookresearch/jepa/assets/7530871/72df7ef0-2ef5-48bb-be46-27963db91f3d" width=40%>
&emsp;&emsp;&emsp;&emsp;&emsp;

<img src="https://github.com/facebookresearch/jepa/assets/7530871/f26b2e96-0227-44e2-b058-37e7bf1e10db" width=40%>


### Using V-JEPA:

 1. Install the pip package (already done above).
 2. Pass an image through the network. That's it!
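The encoder cell that follows feeds `embed_image` both a single H x W x C NumPy image and a batched (batch, channels, height, width) tensor. If you want to convert camera frames into the batched layout yourself, here is a minimal NumPy sketch. The helper `to_batched_chw` is hypothetical, and the [0, 1] scaling is an assumption about what the encoder expects, not something the package documents here.

```python
import numpy as np


def to_batched_chw(img_hwc: np.ndarray) -> np.ndarray:
    """Convert one (H, W, C) image into a float32 batch of shape (1, C, H, W).

    uint8 images are rescaled from [0, 255] to [0, 1] (assumed input range);
    float images are assumed to already be in [0, 1] and are only cast.
    """
    if img_hwc.ndim != 3:
        raise ValueError(f"expected an (H, W, C) image, got shape {img_hwc.shape}")
    img = img_hwc.astype(np.float32)
    if img_hwc.dtype == np.uint8:
        img /= 255.0
    # Move channels first, then add a leading batch dimension.
    return np.transpose(img, (2, 0, 1))[np.newaxis, ...]


frame = np.zeros((360, 480, 3), dtype=np.uint8)
print(to_batched_chw(frame).shape)  # (1, 3, 360, 480)
```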

```python
from vjepa_encoder.vision_encoder import JepaEncoder
import numpy
import torch

# Required: edit this YAML so the checkpoint placeholder
# /path/to/vjepa/insert/MODEL_NAME.tar points at your model file.
# A FileNotFoundError here means the YAML path does not exist; check the clone step above.
encoder = JepaEncoder.load_model(
    "vjepa/params-encoder.yaml"
)

# A single random image in (height, width, channels) layout:
img = numpy.random.random(size=(360, 480, 3))

print("Input Img:", img.shape)
embedding = encoder.embed_image(img)

print(embedding)
print(embedding.shape)

# A random batch of 32 images in (batch, channels, height, width) layout:
x = torch.rand((32, 3, 256, 900))

embedding = encoder.embed_image(x)
print(embedding)
print(embedding.shape)
```

## Training TCP with JEPA

1. Download and configure the dataset from: https://drive.usercontent.google.com/download?id=1HZxlSZ_wUVWkNTWMXXcSQxtYdT7GogSm&export=download&authuser=0&confirm=t&uuid=3bbfcf3e-2be2-4aff-94c0-8d6e41ac91de&at=APZUnTXwtDK4X3Vdq_R30bHy58Vi%3A1712945448485 

2. (Already done for you.) Modify the existing vision backbone to use the JEPA encoder.

3. Run the training script (it's that easy!).

```python
import sys

from TCP.train import main

# main() takes no arguments: passing a Namespace raises
# "TypeError: main() takes 0 positional arguments but 1 was given".
# It presumably parses its options with argparse, so supply them via sys.argv.
# (The original TCP train.py appends --id to --logdir itself, giving log/TCP.)
sys.argv = [
    "train.py",
    "--id", "TCP",
    "--epochs", "10",
    "--lr", "0.0001",
    "--val_every", "2",
    "--batch_size", "64",
    "--logdir", "log",
    "--gpus", "1",
]

main()
```

## Running the Evaluation:
Running the evaluation requires a configured CARLA setup. I am running Ubuntu, and you can download it from the GitHub fork I created. Then:

1. Set the Python path to include CARLA (run this from the repository root, since `$(pwd)` appends the current directory):
```bash
export PYTHONPATH=$PYTHONPATH:$(pwd)
```

2. Make sure you update the `TEAM_CONFIG` variable in the eval script to point to the checkpoint file.

3. Run the evaluation script (it's that easy):
```bash
bash ./leaderboard/scripts/run_evaluation.sh 
```


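```python
import os

# `!export ...` runs in a throwaway subshell, so the variable never reaches the
# next command. Set it in this kernel's environment instead; `!` commands inherit it.
os.environ["PYTHONPATH"] = os.pathsep.join(
    p for p in (os.environ.get("PYTHONPATH", ""), os.getcwd()) if p
)
!bash ./leaderboard/scripts/run_evaluation.sh
```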

# TCP_JEPA: Trajectory and Control Prediction using JEPA Encoder

TCP_JEPA is a neural network module for trajectory and control prediction in autonomous driving. It uses the JEPA (Joint-Embedding Predictive Architecture) encoder, here V-JEPA, for perception and feeds its features into separate branches for trajectory and control prediction.
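What makes JEPA different from reconstruction-based pretraining is that it predicts the *embeddings* of masked parts of the input from the visible parts, instead of reconstructing pixels; V-JEPA applies this to video with an L1 loss and a target encoder updated as an exponential moving average (EMA) of the context encoder. The toy NumPy sketch below shows only the shape of that objective: single linear maps stand in for the ViT encoders and predictor, the mean context embedding plus a positional embedding stands in for the predictor's mask tokens, and the token count, mask size, and momentum are illustrative values, not the settings behind the checkpoint used in this workshop.

```python
import numpy as np

rng = np.random.default_rng(0)
num_tokens, dim, num_masked = 16, 8, 6

# Toy "encoders": one linear map each. The target encoder starts as a copy of
# the context encoder and is only ever updated by EMA (never by gradients).
w_context = rng.normal(size=(dim, dim))
w_target = w_context.copy()
w_predictor = rng.normal(size=(dim, dim))
pos_emb = rng.normal(size=(num_tokens, dim))  # tells the predictor *where* to predict

tokens = rng.normal(size=(num_tokens, dim))  # patch/tubelet tokens of one clip
mask = np.zeros(num_tokens, dtype=bool)
mask[rng.choice(num_tokens, size=num_masked, replace=False)] = True

# Context branch: encode only the visible tokens, summarise them, and predict an
# embedding for every masked position (simplified stand-in for the real predictor).
context = tokens[~mask] @ w_context
summary = context.mean(axis=0, keepdims=True)
pred = (summary + pos_emb[mask]) @ w_predictor

# Target branch: encode the full input, then keep the masked positions.
target = (tokens @ w_target)[mask]

# The loss is computed in embedding space (L1), not in pixel space.
loss = np.abs(pred - target).mean()

# After the (omitted) gradient step on w_context and w_predictor, the target
# encoder tracks the context encoder via EMA.
momentum = 0.998
w_target = momentum * w_target + (1.0 - momentum) * w_context
print(f"latent L1 loss: {loss:.3f}")
```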

## Overview

The TCP_JEPA module takes an image, state information, and a target point as input and generates predictions for trajectory and control. It consists of the following main components:

1. **Perception**: The perception component uses the JEPA encoder to extract features from the input image. The JEPA encoder is loaded from a configuration file and wrapped in a `PerceptWrapper` class.

2. **Attentive Pooler**: The attentive pooler is used to pool the features extracted by the JEPA encoder. It uses a single query and applies attention to obtain a compact representation of the image features.

3. **Measurements**: The measurements component processes the state information, which includes speed, position, and other relevant measurements. It applies a series of linear layers and activation functions to transform the state information into a feature representation.

4. **Trajectory Prediction**: The trajectory prediction component combines the image features and measurement features to predict the future trajectory of the vehicle. It uses a GRU (Gated Recurrent Unit) cell to generate the trajectory autoregressively.

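5. **Control Prediction**: The control prediction component also combines the image features and measurement features, and predicts the vehicle's future control actions autoregressively with a GRU cell. At each step it attends over the spatial JEPA features, computing the attention weights from its own hidden state and the trajectory branch's hidden state for that step.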

6. **Output Branches**: The module includes several output branches for predicting speed, value (for both trajectory and control), and the parameters of the control distribution (mean and standard deviation).
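The `mu`/`sigma` heads are named after a mean and standard deviation, but in the original TCP code (which follows Roach) the two softplus-activated outputs are used as the α and β parameters of a Beta distribution over each control dimension. Assuming this fork keeps that convention, a deterministic action can be read off as the Beta mode, falling back to the mean when the mode is not well defined, and then rescaled from [0, 1] to [-1, 1]. The NumPy sketch below illustrates that mapping; `beta_to_action` is a hypothetical helper, not the repository's function.

```python
import numpy as np


def beta_to_action(alpha: np.ndarray, beta: np.ndarray) -> np.ndarray:
    """Map Beta(alpha, beta) parameters to a point estimate in [-1, 1]."""
    alpha = np.asarray(alpha, dtype=np.float64)
    beta = np.asarray(beta, dtype=np.float64)
    x = np.empty_like(alpha)

    # Unimodal case: mode = (a - 1) / (a + b - 2).
    both_gt1 = (alpha > 1) & (beta > 1)
    x[both_gt1] = (alpha[both_gt1] - 1) / (alpha[both_gt1] + beta[both_gt1] - 2)
    # Density is monotone, so the mode sits at an endpoint.
    x[(alpha <= 1) & (beta > 1)] = 0.0
    x[(alpha > 1) & (beta <= 1)] = 1.0
    # U-shaped or uniform case: fall back to the mean a / (a + b).
    both_le1 = (alpha <= 1) & (beta <= 1)
    x[both_le1] = alpha[both_le1] / np.maximum(alpha[both_le1] + beta[both_le1], 1e-5)

    return 2.0 * x - 1.0  # rescale [0, 1] -> [-1, 1]


alpha = np.array([[3.0, 3.0, 1.0]])
beta = np.array([[3.0, 1.0, 3.0]])
print(beta_to_action(alpha, beta))  # [[ 0.  1. -1.]]
```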

## Data Shapes

The TCP_JEPA module expects the following input shapes:

- `img`: The input image tensor with shape `(batch_size, channels, height, width)`.
- `state`: The state information tensor with shape `(batch_size, state_dim)`, where `state_dim` is the dimensionality of the state information (1+2+6 in the provided code).
- `target_point`: The target point tensor with shape `(batch_size, 2)`, representing the target position.
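In the original TCP code, the `1+2+6` state is the concatenation of the ego speed (1 value), the target point (2 values), and a one-hot navigation command (6 values); this sketch assumes the JEPA variant keeps that layout. It builds the state for a small batch and passes it through a two-layer ReLU MLP in the spirit of the `measurements` branch. The helper names, the hidden size of 128, and the random weights are illustrative assumptions, and any normalisation the real data loader applies (for example to speed) is omitted.

```python
import numpy as np

NUM_COMMANDS = 6  # one-hot navigation command (assumed, as in original TCP)


def build_state(speed, target_point, command_idx):
    """Concatenate speed (B,), target point (B, 2) and command ids (B,) -> (B, 9)."""
    speed = np.asarray(speed, dtype=np.float32).reshape(-1, 1)
    target_point = np.asarray(target_point, dtype=np.float32).reshape(-1, 2)
    command = np.eye(NUM_COMMANDS, dtype=np.float32)[np.asarray(command_idx)]
    return np.concatenate([speed, target_point, command], axis=1)


def measurements_mlp(state, rng, hidden=128):
    """Two linear layers with ReLU, standing in for the `measurements` branch."""
    w1 = rng.normal(scale=0.1, size=(state.shape[1], hidden)).astype(np.float32)
    w2 = rng.normal(scale=0.1, size=(hidden, hidden)).astype(np.float32)
    h = np.maximum(state @ w1, 0.0)
    return np.maximum(h @ w2, 0.0)


rng = np.random.default_rng(0)
state = build_state(speed=[3.2, 0.0],
                    target_point=[[1.0, 12.0], [-2.0, 8.0]],
                    command_idx=[3, 0])
print(state.shape)                         # (2, 9) == (batch_size, 1 + 2 + 6)
print(measurements_mlp(state, rng).shape)  # (2, 128)
```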

The module generates various output tensors with the following shapes:

- `pred_speed`: The predicted speed tensor with shape `(batch_size, 1)`.
- `pred_value_traj`: The predicted value for trajectory with shape `(batch_size, 1)`.
- `pred_features_traj`: The predicted features for trajectory with shape `(batch_size, feature_dim)`.
- `pred_wp`: The predicted waypoints tensor with shape `(batch_size, pred_len, 2)`, where `pred_len` is the length of the predicted trajectory.
- `pred_value_ctrl`: The predicted value for control with shape `(batch_size, 1)`.
- `pred_features_ctrl`: The predicted features for control with shape `(batch_size, feature_dim)`.
- `mu_branches`: The predicted mean of the control distribution with shape `(batch_size, dim_out)`.
- `sigma_branches`: The predicted standard deviation of the control distribution with shape `(batch_size, dim_out)`.
- `future_feature`: The predicted future features with shape `(pred_len, batch_size, feature_dim)`.
- `future_mu`: The predicted future mean of the control distribution with shape `(pred_len, batch_size, dim_out)`.
- `future_sigma`: The predicted future standard deviation of the control distribution with shape `(pred_len, batch_size, dim_out)`.

## Code Snippets

Here are a few code snippets that highlight important aspects of the TCP_JEPA module:

1. Loading the JEPA encoder:

```python
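# Load the JEPA encoder from its YAML config; "embed_image" names the
# encoder method that the wrapper calls on each input image.
self.perception = PerceptWrapper(
    JepaEncoder.load_model(
        config_file_path=jepa_config,
        device=self.config.__dict__.get("jepa_device"),  # None if not configured
    ),
    "embed_image",
)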
```
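The snippet hands the loaded encoder and the method name `"embed_image"` to `PerceptWrapper`, whose source is not shown here. One plausible reading is that the wrapper keeps the encoder as a `backbone` attribute (the commented-out `self.perception.backbone.freeze_encoder()` suggests this) and forwards calls to the named method. The pure-Python sketch below illustrates that dispatch pattern under this assumption; in the real module the wrapper is presumably an `nn.Module`, and `PerceptWrapperSketch` and `FakeEncoder` are hypothetical names.

```python
from typing import Any


class PerceptWrapperSketch:
    """Hypothetical stand-in for `PerceptWrapper`: holds a backbone and forwards
    calls to the method named at construction time (e.g. "embed_image")."""

    def __init__(self, backbone: Any, method_name: str):
        if not callable(getattr(backbone, method_name, None)):
            raise AttributeError(f"{type(backbone).__name__} has no method {method_name!r}")
        self.backbone = backbone
        self.method_name = method_name

    def __call__(self, *args, **kwargs):
        # Look the method up on every call so a swapped backbone is respected.
        return getattr(self.backbone, self.method_name)(*args, **kwargs)


class FakeEncoder:
    """Dummy backbone used only to exercise the wrapper."""

    def embed_image(self, img):
        return f"embedding of {img}"


perception = PerceptWrapperSketch(FakeEncoder(), "embed_image")
print(perception("frame_0"))  # embedding of frame_0
```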

2. Attentive pooling of image features:

```python
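# Collapse the encoder's token sequence into a single vector
# by cross-attending from one learned query (num_queries=1).
self.attn_pool = AttentivePooler(
    num_queries=1,
    embed_dim=jepa_embed_dim,
    num_heads=8,
    mlp_ratio=4.0,
    depth=1,
)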
```
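With `num_queries=1`, the pooler amounts to multi-head cross-attention from one learned query vector to every token the encoder emits, so a token sequence of any length collapses to one `embed_dim` vector. The NumPy sketch below shows only that attention step, with random weights standing in for learned ones; the real `AttentivePooler` includes further layers (such as normalisation and the MLP sized by `mlp_ratio`) that are omitted here, and `attentive_pool` is a hypothetical name.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def attentive_pool(tokens, query, w_q, w_k, w_v, num_heads):
    """Cross-attention from learned queries to all tokens.

    tokens: (B, N, D) encoder outputs; query: (Q, D) learned queries.
    Returns (B, Q, D): one pooled vector per query.
    """
    b, n, d = tokens.shape
    q_len = query.shape[0]
    hd = d // num_heads
    q = (query @ w_q).reshape(q_len, num_heads, hd).transpose(1, 0, 2)     # (H, Q, hd)
    k = (tokens @ w_k).reshape(b, n, num_heads, hd).transpose(0, 2, 1, 3)  # (B, H, N, hd)
    v = (tokens @ w_v).reshape(b, n, num_heads, hd).transpose(0, 2, 1, 3)  # (B, H, N, hd)
    scores = q[None] @ k.transpose(0, 1, 3, 2) / np.sqrt(hd)               # (B, H, Q, N)
    attn = softmax(scores, axis=-1)                                        # weights over tokens
    out = attn @ v                                                         # (B, H, Q, hd)
    return out.transpose(0, 2, 1, 3).reshape(b, q_len, d)                  # merge heads


rng = np.random.default_rng(0)
D, H = 16, 8
tokens = rng.normal(size=(2, 50, D))
query = rng.normal(size=(1, D))  # num_queries=1
w_q, w_k, w_v = (rng.normal(size=(D, D)) for _ in range(3))
pooled = attentive_pool(tokens, query, w_q, w_k, w_v, num_heads=H)
print(pooled.shape)  # (2, 1, 16)
```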

3. Trajectory prediction using GRU:

```python
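# x: current waypoint (B, 2); z: GRU hidden state (both set before the loop)
output_wp = []
traj_hidden_state = []
for _ in range(self.config.pred_len):
    x_in = torch.cat([x, target_point], dim=1)  # condition each step on the goal
    z = self.decoder_traj(x_in, z)              # GRU cell update
    traj_hidden_state.append(z)                 # reused by the control branch
    dx = self.output_traj(z)                    # predicted offset
    x = dx + x                                  # accumulate into the next waypoint
    output_wp.append(x)
pred_wp = torch.stack(output_wp, dim=1)                    # (B, pred_len, 2)
traj_hidden_state = torch.stack(traj_hidden_state, dim=1)  # (B, pred_len, hidden)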
```
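Each step feeds the current waypoint and the target point into the GRU, predicts an offset `dx`, and adds it to the previous waypoint, so the output is a cumulative path rather than independent points. The NumPy sketch below runs that loop end to end. The gate equations follow the standard GRU formulation used by `torch.nn.GRUCell` (biases omitted); the hidden size of 256, `pred_len=4`, the zero starting waypoint, and the random weights are illustrative assumptions, and `GRUCellSketch`/`rollout_waypoints` are hypothetical names.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


class GRUCellSketch:
    """Minimal GRU cell (same gate equations as torch.nn.GRUCell, no biases)."""

    def __init__(self, input_size, hidden_size, rng):
        s = 1.0 / np.sqrt(hidden_size)
        self.w_x = rng.uniform(-s, s, size=(input_size, 3 * hidden_size))
        self.w_h = rng.uniform(-s, s, size=(hidden_size, 3 * hidden_size))

    def __call__(self, x, h):
        gx_r, gx_z, gx_n = np.split(x @ self.w_x, 3, axis=1)
        gh_r, gh_z, gh_n = np.split(h @ self.w_h, 3, axis=1)
        r = sigmoid(gx_r + gh_r)      # reset gate
        z = sigmoid(gx_z + gh_z)      # update gate
        n = np.tanh(gx_n + r * gh_n)  # candidate state
        return (1.0 - z) * n + z * h


def rollout_waypoints(z, target_point, cell, w_out, pred_len):
    """Autoregressive waypoint decoding, mirroring the loop above."""
    x = np.zeros((z.shape[0], 2))  # start at the ego position (assumed)
    output_wp = []
    for _ in range(pred_len):
        x_in = np.concatenate([x, target_point], axis=1)  # (B, 4)
        z = cell(x_in, z)                                 # hidden-state update
        dx = z @ w_out                                    # predicted offset (B, 2)
        x = dx + x                                        # cumulative waypoint
        output_wp.append(x)
    return np.stack(output_wp, axis=1)                    # (B, pred_len, 2)


rng = np.random.default_rng(0)
hidden, pred_len = 256, 4
cell = GRUCellSketch(input_size=4, hidden_size=hidden, rng=rng)
w_out = rng.normal(scale=0.01, size=(hidden, 2))
z0 = rng.normal(size=(3, hidden))  # stands in for the fused image + measurement feature
target_point = np.array([[0.0, 10.0]] * 3)
wp = rollout_waypoints(z0, target_point, cell, w_out, pred_len)
print(wp.shape)  # (3, 4, 2)
```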

4. Control prediction using GRU:

```python
for t in range(self.config.pred_len):
    x_in = torch.cat([x, mu, sigma], dim=1)
    h = self.decoder_ctrl(x_in, h)
    # Attention weights over image tokens, from the control state
    # and the trajectory branch's hidden state at step t
    wp_att = self.wp_att(torch.cat([h, traj_hidden_state[:, t]], 1))
    # Weighted sum of img_embedding (B, C, N) over its N tokens -> (B, C)
    new_feature_emb = torch.bmm(img_embedding, wp_att.unsqueeze(1).transpose(1, 2)).squeeze(dim=-1)
    merged_feature = torch.cat([h, new_feature_emb], 1)
    dx = self.output_ctrl(merged_feature)
    x = dx + x
    # Control distribution parameters for this step
    policy = self.policy_head(x)
    mu = self.dist_mu(policy)
    sigma = self.dist_sigma(policy)
    future_feature.append(x)
    future_mu.append(mu)
    future_sigma.append(sigma)
```
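Inside the control loop, `wp_att` is a per-sample weight vector over the spatial tokens, and the `torch.bmm(...)` line computes a weighted sum of the image embedding over those tokens. Assuming `img_embedding` has shape `(batch_size, channels, num_tokens)` and `wp_att` has shape `(batch_size, num_tokens)`, the NumPy sketch below checks that the batched matrix product equals the explicit weighted sum; `np.matmul` follows the same batched semantics as `torch.bmm` for 3-D inputs. The sizes are arbitrary illustrative values.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, channels, num_tokens = 2, 8, 10
img_embedding = rng.normal(size=(batch, channels, num_tokens))
wp_att = rng.random(size=(batch, num_tokens))

# Same reshaping as the snippet:
# (B, N) -> (B, 1, N) -> (B, N, 1); (B, C, N) @ (B, N, 1) -> (B, C, 1) -> (B, C)
att_col = np.swapaxes(wp_att[:, np.newaxis, :], 1, 2)
via_bmm = np.matmul(img_embedding, att_col).squeeze(-1)

# Explicit weighted sum over the token axis.
via_sum = (img_embedding * wp_att[:, np.newaxis, :]).sum(axis=-1)
print(via_bmm.shape, np.allclose(via_bmm, via_sum))  # (2, 8) True
```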

## Possible Code Changes

Here are a few possible code changes or improvements that could be considered:

1. Freezing the JEPA encoder: If the JEPA encoder is pre-trained and not intended to be fine-tuned, you can freeze its weights by uncommenting the line `# self.perception.backbone.freeze_encoder()`.

2. Adjusting the dimensions of the fully connected layers: The dimensions of the fully connected layers in the various components (e.g., `measurements`, `join_traj`, `join_ctrl`) can be adjusted based on the specific requirements of the task or the available computational resources.

3. Modifying the autoregressive generation loop: The number of steps in the autoregressive generation loop for trajectory and control prediction is determined by `self.config.pred_len`. This value can be adjusted based on the desired length of the predicted trajectory and control actions.

4. Experimenting with different attention mechanisms: The attentive pooler currently uses a single query and applies attention to pool the image features. Different attention mechanisms or a varying number of queries could be explored to potentially improve the feature representation.

5. Incorporating additional input modalities: The TCP_JEPA module currently uses an image, state information, and a target point as input. Depending on the available data and the specific requirements of the task, additional input modalities such as lidar, radar, or high-level semantic information could be incorporated to enhance the prediction accuracy.

These are just a few examples of possible code changes or improvements. The actual modifications would depend on the specific goals, constraints, and performance requirements of the autonomous driving system.

## Conclusion:

- TCP is a simple way to start learning about self-driving cars
- Getting set up with TCP requires minimal effort