# Multi-Task Image Encoder Tutorial

This notebook walks through installing, training, and using the multi-task image encoder, from cloning the repository to exporting the trained encoder to ONNX.

### Step 1: Clone the Repository

First, we'll clone the project repository from GitHub. **If your copy of the project lives at a different URL (for example, a fork), replace the URL below with your own.**

In [None]:
!git clone https://github.com/TalonBvV/ImageEncoder.git image-encoder
%cd image-encoder
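To confirm that the clone worked and that the notebook is now inside the project folder, you can check for the files this tutorial uses later. This is a minimal sketch: the file list is just the set of names the notebook itself mentions (`setup.py`, `requirements.txt`, `train.py`, `export.py`), assumed to sit at the repository root, and `missing_project_files` is a hypothetical helper, not part of the repository.

```python
from pathlib import Path

# Files this notebook refers to (assumed to sit at the repository root).
EXPECTED_FILES = ["setup.py", "requirements.txt", "train.py", "export.py"]


def missing_project_files(root=".", expected=EXPECTED_FILES):
    """Return the names in `expected` that do not exist under `root`."""
    root_path = Path(root)
    return [name for name in expected if not (root_path / name).exists()]


missing = missing_project_files()
print("All expected files found." if not missing else f"Missing: {missing}")
```

An empty result means the `%cd` step landed in the right place; otherwise, check the clone output above.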

### Step 2: Install Dependencies

Now, we'll install the project and all its dependencies using the `setup.py` and `requirements.txt` files.

In [None]:
!pip install .
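Before moving on, it can help to confirm that the third-party packages imported in Step 5 are available in this kernel. The sketch below only checks that the modules can be found; the list (`torch`, `torchvision`, `PIL`) comes from the imports used later in this notebook. It deliberately leaves out the Lightning package, whose exact import name depends on the project's requirements (an assumption we don't resolve here), and `check_modules` is a hypothetical helper.

```python
import importlib.util

# Top-level modules imported in Step 5 (local modules such as lightning_module are excluded).
REQUIRED_MODULES = ["torch", "torchvision", "PIL"]


def check_modules(names=REQUIRED_MODULES):
    """Map each top-level module name to True if Python can locate it."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


for name, found in check_modules().items():
    print(f"{name:12s} {'OK' if found else 'MISSING'}")
```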

### Step 3: Download a Sample Dataset

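We need some images to train on. As a demonstration, we'll download TensorFlow's small flower-photos dataset, which unpacks into a `flower_photos/` folder containing one subfolder of `.jpg` images per flower class.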

In [None]:
!wget -q -O flower_photos.tgz http://download.tensorflow.org/example_images/flower_photos.tgz
!tar -xzf flower_photos.tgz
IMAGE_DIR = 'flower_photos'
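A quick way to confirm that the archive extracted correctly is to count the images in each class subfolder. This sketch assumes the one-subfolder-per-class layout that the `flower_photos/*/*.jpg` glob in Step 5 also relies on; `count_images_per_class` is a hypothetical helper.

```python
from collections import Counter
from pathlib import Path


def count_images_per_class(image_dir, pattern="*.jpg"):
    """Count files matching `pattern` in each immediate subdirectory of `image_dir`."""
    counts = Counter()
    for path in Path(image_dir).glob(f"*/{pattern}"):
        counts[path.parent.name] += 1  # the parent folder name is the class label
    return dict(sorted(counts.items()))


print(count_images_per_class("flower_photos"))
```

An empty dictionary means no images were found, which usually points to a failed download or a different extraction path.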

### Step 4: Configure and Run Training

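We need to tell our training script where to find the images. The `sed` commands below edit `train.py` in place so that it points at the downloaded dataset and trains for only 2 epochs, which keeps the demonstration quick. They assume `train.py` sets `IMAGE_DIR` and `MAX_EPOCHS` as simple `NAME = value` assignments.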

In [None]:
# Modify the train.py script to use our new image directory and run for just 2 epochs
!sed -i "s|IMAGE_DIR = .*|IMAGE_DIR = 'flower_photos'|g" train.py
!sed -i "s|MAX_EPOCHS = .*|MAX_EPOCHS = 2|g" train.py

print('--- Modified train.py ---')
!cat train.py
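The `sed` patterns above rewrite any line containing `IMAGE_DIR = ` or `MAX_EPOCHS = `, and they succeed silently when nothing matches. If that is a concern, here is a stricter alternative sketched in Python: it only rewrites an unindented assignment and raises an error unless exactly one line matches. `set_module_constant` and `patch_file` are hypothetical helpers, and like the `sed` commands they assume `train.py` uses plain `NAME = value` assignments.

```python
import re
from pathlib import Path


def set_module_constant(source, name, value):
    """Return `source` with the top-level line `name = ...` replaced by `name = repr(value)`.

    Raises ValueError unless exactly one unindented assignment to `name` is found.
    """
    pattern = re.compile(rf"^{re.escape(name)}\s*=(?!=).*$", flags=re.MULTILINE)
    # A function replacement avoids backslash escapes in `value` being interpreted by re.
    new_source, count = pattern.subn(lambda _match: f"{name} = {value!r}", source)
    if count != 1:
        raise ValueError(f"expected exactly one assignment to {name}, found {count}")
    return new_source


def patch_file(path, **constants):
    """Apply set_module_constant for each keyword argument and write the file back."""
    file_path = Path(path)
    text = file_path.read_text()
    for name, value in constants.items():
        text = set_module_constant(text, name, value)
    file_path.write_text(text)
```

In this notebook, the equivalent of the two `sed` commands would be `patch_file("train.py", IMAGE_DIR="flower_photos", MAX_EPOCHS=2)`.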

In [None]:
# Run the training!
!python train.py

### Step 5: Use the Trained Encoder for Inference

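After training, a checkpoint file is saved under `tb_logs/`. We can now load this checkpoint, extract the trained encoder, and use it to compute a latent vector for an image (here, simply the first flower photo found on disk, which was also part of the training data).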

In [None]:
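import glob
import os

import torch
from PIL import Image
from torchvision import transforms
from lightning_module import MultiTaskImageEncoder

# Find the most recent checkpoint (re-running training creates version_1, version_2, ...)
checkpoints = glob.glob('tb_logs/image_encoder_v1/version_*/checkpoints/*.ckpt')
if not checkpoints:
    raise FileNotFoundError('No checkpoint found under tb_logs/image_encoder_v1/; did training finish?')
checkpoint_path = max(checkpoints, key=os.path.getmtime)
print(f'Found checkpoint: {checkpoint_path}')

# Load the model onto the CPU, where the input tensor below also lives
model = MultiTaskImageEncoder.load_from_checkpoint(checkpoint_path, map_location='cpu')
encoder = model.encoder
encoder.eval()  # Evaluation mode (affects layers such as dropout and batch norm)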

# Load a sample image
sample_image_path = glob.glob('flower_photos/*/*.jpg')[0]
img = Image.open(sample_image_path).convert('RGB')

# Preprocess the image (assumed to match training: resize to 128x128, scale pixels to [0, 1])
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])
img_tensor = transform(img).unsqueeze(0) # Add batch dimension

# Get the latent vector
with torch.no_grad():
    latent_vector = encoder(img_tensor)

print(f'Successfully encoded image {sample_image_path}')
print(f'Latent vector shape: {latent_vector.shape}')
print(f'Latent vector (first 10 values): {latent_vector[0, :10]}')
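One common way to use latent vectors is to compare them, for example to find visually similar images. As an illustration (not something the repository is known to provide), the sketch below computes cosine similarity in plain Python. To apply it to encoder outputs, convert them first, e.g. with `latent_vector[0].tolist()`, assuming the encoder returns a 2-D `(batch, dim)` tensor as the indexing in the cell above implies.

```python
import math


def cosine_similarity(a, b):
    """Cosine similarity between two equal-length sequences of numbers."""
    if len(a) != len(b):
        raise ValueError(f"length mismatch: {len(a)} vs {len(b)}")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    return dot / (norm_a * norm_b)


# Toy vectors; with real encodings, pass e.g. latent_vector[0].tolist().
print(cosine_similarity([1.0, 2.0, 3.0], [2.0, 4.0, 6.0]))
```

Values near 1 indicate vectors pointing in the same direction. For batches of tensors, `torch.nn.functional.cosine_similarity(a, b, dim=1)` computes the same quantity without leaving PyTorch.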

### Step 6: Export the Encoder to ONNX

Finally, we'll use the `export.py` script to save the trained encoder in the ONNX format, which runtimes other than PyTorch can load for deployment.

In [None]:
# Point export.py at the checkpoint found in Step 5 and uncomment its export call.
# IPython substitutes the Python variable {checkpoint_path} into the shell command.
!sed -i "s|# CHECKPOINT_PATH = .*|CHECKPOINT_PATH = '{checkpoint_path}'|g" export.py
!sed -i "s|# export_encoder_to_onnx.*|export_encoder_to_onnx(CHECKPOINT_PATH)|g" export.py

print('--- Modified export.py ---')
!cat export.py

In [None]:
# Run the export!
!python export.py

In [None]:
# Confirm that export.py wrote encoder.onnx to the working directory
!ls -l encoder.onnx
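To sanity-check the export, one option is to run `encoder.onnx` with ONNX Runtime and compare its output with the PyTorch encoder from Step 5. This is a sketch under several assumptions: the `onnxruntime` and `numpy` packages are installed (this notebook does not install them), the exported model has a single input that accepts the same `(1, 3, 128, 128)` float32 batch used in Step 5, and `run_onnx_encoder` and `max_abs_diff` are hypothetical helpers.

```python
def run_onnx_encoder(model_path, batch):
    """Run an ONNX model on a float32 array and return its first output.

    Imports are local so this cell can be defined even if onnxruntime is absent.
    """
    import numpy as np
    import onnxruntime as ort

    session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
    input_name = session.get_inputs()[0].name
    outputs = session.run(None, {input_name: np.asarray(batch, dtype=np.float32)})
    return outputs[0]


def max_abs_diff(a, b):
    """Largest absolute element-wise difference between two flat sequences."""
    if len(a) != len(b):
        raise ValueError(f"length mismatch: {len(a)} vs {len(b)}")
    return max((abs(x - y) for x, y in zip(a, b)), default=0.0)


# Usage in this notebook (requires img_tensor and latent_vector from Step 5):
# onnx_out = run_onnx_encoder("encoder.onnx", img_tensor.numpy())
# diff = max_abs_diff(onnx_out.ravel().tolist(), latent_vector.numpy().ravel().tolist())
# print(f"max |ONNX - PyTorch| = {diff:.2e}")
```

Tiny differences on the order of floating-point rounding are normal; large ones suggest that the export or the preprocessing differs between the two paths.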