# 🧑🏽‍🍳 How to Train a Computer Vision Model with Focoos

🐍 Setup Focoos

# 🎨 Fine-tune a model in 3 steps

This section covers the steps to create a model and train it using the focoos library. The following example demonstrates how to interact with the Focoos API to manage models, datasets, and training jobs.

In this guide, we will perform the following steps:

0. 🐍 Connect with Focoos
1. 📦 Load or select a dataset
2. 🏃‍♂️ Train the model
3. 🧪 Test your model


## 🐍 Connect with Focoos

In [None]:
from focoos.hub import FocoosHUB

FOCOOS_API_KEY = None  # write here your API key
hub = FocoosHUB(api_key=FOCOOS_API_KEY)
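
If you prefer not to hard-code the key in the notebook, one option is to read it from an environment variable. The sketch below uses only the standard library; the variable name `FOCOOS_API_KEY` and the helper `load_api_key` are assumptions made for illustration, not part of the Focoos library.

```python
import os


def load_api_key(env=os.environ, var_name="FOCOOS_API_KEY"):
    """Return the API key stored under `var_name`, or raise if it is missing."""
    # `var_name` is a hypothetical variable name chosen for this sketch.
    key = env.get(var_name, "").strip()
    if not key:
        raise RuntimeError(f"Set the {var_name} environment variable first.")
    return key


# Example with an explicit mapping so the snippet runs without a real key.
print(load_api_key({"FOCOOS_API_KEY": "my-secret-key"}))
```

The returned value can then be passed as `api_key` when creating the `FocoosHUB` object.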

## 📦 Let's create a dataset

To use a dataset hosted on the hub, you can download it to your local environment through the `hub` object.
Look up the reference of your dataset on the platform and use it in the following cell.
In the next cell, we download the example dataset [Football Player Detection](https://app.focoos.ai/datasets/3a7cec8afb6b4780) with reference `3a7cec8afb6b4780`.

In [None]:
dataset = hub.get_remote_dataset("3a7cec8afb6b4780")
print(dataset)

dataset_path = dataset.download_data()

Now that the dataset is downloaded, we can magically 🪄 instantiate it with `AutoDataset`, in the same form that will be used for training. You can optionally specify augmentations for training using the `DatasetAugmentations` dataclass.

In [None]:
from focoos.data.auto_dataset import AutoDataset
from focoos.data.default_aug import DatasetAugmentations
from focoos.ports import DatasetSplitType

task = dataset.task  # see ports.Task for more information
layout = dataset.layout  # see ports.DatasetLayout for more information
auto_dataset = AutoDataset(dataset_name=dataset_path, task=task, layout=layout)

train_augs = DatasetAugmentations(
    resolution=512,
    color_augmentation=1.0,
    horizontal_flip=0.5,
    vertical_flip=0.0,
    rotation=0.0,
    aspect_ratio=0.0,
    scale_ratio=0.0,
    crop=True,
)
valid_augs = DatasetAugmentations(resolution=512)
# Optionally, you can also get the default augmentations for the task
# train_augs, valid_augs = get_default_by_task(task, 512)

train_dataset = auto_dataset.get_split(augs=train_augs.get_augmentations(), split=DatasetSplitType.TRAIN)
valid_dataset = auto_dataset.get_split(augs=valid_augs.get_augmentations(), split=DatasetSplitType.VAL)
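
To build some intuition for what a setting like `horizontal_flip=0.5` implies for detection labels, here is a minimal standalone sketch. It assumes the value is the probability of mirroring each sample and that boxes are stored as `(x_min, y_min, x_max, y_max)` in pixels. The helpers `hflip_box` and `maybe_hflip` are hypothetical and do not reflect how Focoos implements its augmentations.

```python
import random


def hflip_box(box, image_width):
    """Mirror an (x_min, y_min, x_max, y_max) box around the vertical axis."""
    x_min, y_min, x_max, y_max = box
    return (image_width - x_max, y_min, image_width - x_min, y_max)


def maybe_hflip(boxes, image_width, p=0.5, rng=random):
    """Flip all boxes with probability p, otherwise return them unchanged."""
    if rng.random() < p:
        return [hflip_box(b, image_width) for b in boxes]
    return list(boxes)


boxes = [(10, 20, 110, 220)]
print(hflip_box(boxes[0], image_width=512))  # (402, 20, 502, 220)
print(maybe_hflip(boxes, image_width=512, p=0.5))
```

When an image is mirrored, its boxes have to be mirrored with it, which is why augmentations are applied to the image and its annotations together.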

Let's also visualize a few augmented inputs!

In [None]:
display(train_dataset.show_sample_image())

## 🏃‍♂️ Train the model

The first step to personalize your model is to instantiate one. You can get a model using the `ModelManager` as follows.
Check the list of available models on the Focoos platform! You can also load one of your own trained models from the hub.

In [None]:
from focoos.model_manager import ModelManager

model = ModelManager.get("hub://fai-detr-m-coco")

The next step is to train the model by calling its `train` method. Pass it the hyperparameters, encapsulated in `TrainerArgs`, together with the training and validation datasets, and watch the magic happen.

In [None]:
from focoos.ports import TrainerArgs

args = TrainerArgs(
    run_name="football-tutorial",  # the name of the experiment
    output_dir="./experiments",  # the folder where the model is saved
    batch_size=16,  # how many images in each iteration
    max_iters=500,  # how many iterations the training lasts
    eval_period=100,  # how often (in iterations) the model is evaluated on the validation set
    learning_rate=0.0001,  # learning rate
    weight_decay=0.0001,  # regularization strength (set it properly to avoid under/overfitting)
    sync_to_hub=True,  # Use this to see the model under training on the platform
)

# Let's go!
model.train(args, train_dataset, valid_dataset, hub=hub)
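
To get a feel for these hyperparameters, it can help to work out how much data the run will process. The sketch below is plain arithmetic, independent of Focoos. The dataset size of 1000 images is a made-up figure for illustration, and the evaluation count assumes one evaluation every `eval_period` iterations.

```python
def training_budget(batch_size, max_iters, eval_period, num_train_images):
    """Summarize how much data a run with these settings processes."""
    images_seen = batch_size * max_iters
    return {
        "images_seen": images_seen,
        # Average number of passes over the training set.
        "approx_epochs": images_seen / num_train_images,
        # Assumes one evaluation every `eval_period` iterations.
        "num_evaluations": max_iters // eval_period,
    }


# num_train_images=1000 is a hypothetical dataset size.
budget = training_budget(batch_size=16, max_iters=500, eval_period=100, num_train_images=1000)
print(budget)
```

With the values used in this tutorial, the model sees 8000 images in total. If the run is too short or too long for your dataset, `max_iters` is the first value to adjust.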

## 🧪 Test your model
Let's visualize some predictions!

In [None]:
# Load the trained model: the name must match the `run_name` used for training
model = ModelManager.get(name="football-tutorial", models_dir="./experiments")

In [None]:
import random

from PIL import Image

from focoos.utils.vision import annotate_image

# randrange excludes the upper bound, so the index is always valid
index = random.randrange(len(valid_dataset))

print("Ground truth:")
display(valid_dataset.show_sample_image(index, use_augmentations=False))

image = Image.open(valid_dataset[index]["file_name"])
outputs = model(image)

print("Prediction:")
annotate_image(image, outputs, task=task, classes=model.model_info.classes)
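
Beyond a visual check, a common way to quantify how well a predicted box matches a ground-truth box is intersection over union (IoU). The following sketch is a generic standalone implementation and is not taken from the Focoos library. It assumes boxes are given as `(x_min, y_min, x_max, y_max)`, and the two example boxes are made up.

```python
def box_iou(box_a, box_b):
    """Intersection over union of two (x_min, y_min, x_max, y_max) boxes."""
    ax1, ay1, ax2, ay2 = box_a
    bx1, by1, bx2, by2 = box_b
    # Width and height of the overlap, clamped to zero for disjoint boxes.
    inter_w = max(0.0, min(ax2, bx2) - max(ax1, bx1))
    inter_h = max(0.0, min(ay2, by2) - max(ay1, by1))
    inter = inter_w * inter_h
    union = (ax2 - ax1) * (ay2 - ay1) + (bx2 - bx1) * (by2 - by1) - inter
    return inter / union if union > 0 else 0.0


ground_truth = (0, 0, 10, 10)  # made-up example boxes
prediction = (5, 0, 15, 10)
print(box_iou(ground_truth, prediction))  # roughly 0.33
```

An IoU of 1.0 means the boxes coincide exactly, while 0.0 means they do not overlap at all.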