<img src='https://github.com/Deci-AI/super-gradients/blob/master/documentation/assets/SG_img/SG%20-%20Horizontal%20Glow.png?raw=true'>

## If you want to learn more about the ResNet family of model architectures, be sure to check out [my FREE course on Udemy](https://www.udemy.com/course/supergradients-resnet/).

In this notebook, you'll use the SuperGradients training library to classify the color and type of apparel images.

Before you dive into the code, it's worth talking about the difference between <span style="color:red">multi-class</span>, <span style="color:green">multi-task</span>, and <span style="color:orange">multi-label learning</span>.

<span style="color:red">

In multi-class learning, you predict a single label for each input, and that label is one element from a set of mutually exclusive classes.

For example, imagine you're working with a dataset of clothing images, and you want to predict whether each item is a shirt, a pair of pants, or a pair of shoes. To add another dimension, you also want to predict the colour of each item, such as red, blue, or green. With multi-class learning, you can predict both the clothing item and its colour at the same time by treating every combination as its own class: "red shirt" vs "blue shirt" vs "brown shoe" vs "black shoe", and so on.

You typically train a neural network on this problem with the cross-entropy loss function.
</span>
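To make the multi-class framing concrete, here is a minimal sketch with hypothetical label sets and random logits (not this notebook's data). Every (colour, item) pair becomes one class, the network emits one logit per pair, and `nn.CrossEntropyLoss` takes a single class index per image:

```python
import itertools

import torch
import torch.nn as nn

# Hypothetical label sets, used only for illustration
colours = ["red", "blue", "green"]
items = ["shirt", "pants", "shoes"]

# Multi-class framing: every (colour, item) pair becomes one mutually exclusive class
combined_classes = [f"{c}_{i}" for c, i in itertools.product(colours, items)]
num_classes = len(combined_classes)  # 3 * 3 = 9

# Random logits for a batch of 4 images: one row per image, one column per combined class
torch.manual_seed(0)
logits = torch.randn(4, num_classes)

# Each image has exactly one target, given as an index into combined_classes
targets = torch.tensor([
    combined_classes.index("red_shirt"),
    combined_classes.index("blue_shirt"),
    combined_classes.index("green_shoes"),
    combined_classes.index("red_pants"),
])

loss = nn.CrossEntropyLoss()(logits, targets)

# The prediction is the single highest-scoring combination for each image
predicted = [combined_classes[i] for i in logits.argmax(dim=1).tolist()]
```

Note how the number of classes grows multiplicatively with each attribute (here 3 × 3 = 9).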

<span style="color:green">

In multi-task learning, you have multiple problems that need to be solved simultaneously. 

For instance, you could predict the clothing item and its colour as separate tasks. The idea here is that solving one task could help solve the other; for example, certain colours might be more common for certain types of clothing. In this case, you assign each output (clothing item and colour) its own loss function.

You can then combine the losses by summing or averaging them, using weights to balance the importance of each task.
</span>
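A minimal sketch of the multi-task framing, again with hypothetical label sets and random logits: each task gets its own head and its own cross-entropy loss, and the per-task losses are combined with weights. The weights of 0.5 are an assumption for illustration; equal weights reduce to a plain average.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size = 4
n_items, n_colours = 3, 3  # hypothetical: shirt/pants/shoes and red/blue/green

# Multi-task framing: one output head (set of logits) per task
item_logits = torch.randn(batch_size, n_items)
colour_logits = torch.randn(batch_size, n_colours)

# Each task has its own class-index targets
item_targets = torch.tensor([0, 0, 2, 1])
colour_targets = torch.tensor([0, 1, 2, 0])

criterion = nn.CrossEntropyLoss()
item_loss = criterion(item_logits, item_targets)
colour_loss = criterion(colour_logits, colour_targets)

# Hypothetical task weights that balance the two objectives
w_item, w_colour = 0.5, 0.5
total_loss = w_item * item_loss + w_colour * colour_loss
```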

<span style="color:orange">

Multi-label learning is a special case of multi-task learning. 

In this scenario, a single image can carry several labels at once, such as multiple clothing items and their colours. To solve this, you break the task down into multiple binary classification problems. If the possible labels are "shirt 👔", "pants 👖", and "shoes 👟", and the possible colours are <span style="color:red"> "red",</span> <span style="color:lightblue"> "blue",</span> and <span style="color:green">"green",</span> you would need to train the network to answer questions like "is there a shirt in the image?" and "is the shirt in the image red?" Use multi-label learning when your labels are not mutually exclusive.

You would use `BCEWithLogitsLoss` for multi-label learning, since it combines a sigmoid activation and binary cross-entropy loss into a single function, making it efficient and numerically stable.
</span>
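A minimal sketch of the multi-label framing with a hypothetical vocabulary of six labels: targets are multi-hot vectors, each label is an independent yes/no decision, and `nn.BCEWithLogitsLoss` gives the same value as a sigmoid followed by `nn.BCELoss` (the fused version is the one to prefer for stability):

```python
import torch
import torch.nn as nn

# Hypothetical multi-label vocabulary: three items followed by three colours
labels = ["shirt", "pants", "shoes", "red", "blue", "green"]

torch.manual_seed(0)
logits = torch.randn(2, len(labels))  # raw scores for 2 images

# Multi-hot targets: any number of labels can be 1 for the same image
targets = torch.tensor([
    [1., 1., 0., 1., 1., 0.],  # e.g. an image with a shirt and pants, in red and blue
    [0., 0., 1., 0., 0., 1.],  # e.g. an image with green shoes only
])

# BCEWithLogitsLoss applies the sigmoid internally...
fused = nn.BCEWithLogitsLoss()(logits, targets)
# ...so it matches an explicit sigmoid followed by plain binary cross-entropy
separate = nn.BCELoss()(torch.sigmoid(logits), targets)

# At inference time, each label is predicted independently, here with a 0.5 threshold
predicted = (torch.sigmoid(logits) > 0.5).int()
```

The PyTorch documentation notes that the fused loss uses the log-sum-exp trick, which avoids the saturation a separate sigmoid can hit for large-magnitude logits.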

## What type of learning are we going to perform here?

The answer will be revealed, but I want you to think about it as you read through the notebook. We need to predict the *color* and *type* of an image of clothing; what would make sense here?

#### What is SuperGradients?

[SuperGradients](https://github.com/Deci-AI/super-gradients) is an open-source, PyTorch-based training library that has a number of pre-trained models for you to use, training recipes that will get you amazing accuracy, and many [training tricks](https://www.deeplearningdaily.community/t/tips-for-training-your-neural-networks/307) that you can use with just the "flip of a switch". For this example, you'll use a ResNet50 to perform the classification. You can check out our [model zoo](https://github.com/Deci-AI/super-gradients/blob/master/src/super_gradients/training/Computer_Vision_Models_Pretrained_Checkpoints.md) and use any of the pretrained models we have available.

Feel free to reach out to me on my community forum, [Deep Learning Daily (free and open to all)](https://www.deeplearningdaily.community/), should you have any questions.


In [None]:
%%capture 
#  NOTE: You MUST restart the notebook after installation is complete, else you will get an import error
!pip install super-gradients==3.0.7 
!pip install imutils

In [None]:
import os
import torch
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, random_split, DataLoader
from PIL import Image
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from torchvision import transforms
import imutils
import torch.nn.functional as F
import torch.nn as nn
from torchvision.utils import make_grid
from torchvision.datasets import ImageFolder
from imutils import paths
import cv2 as cv
import re
import requests

import pathlib
from pathlib import Path

%matplotlib inline
import warnings
warnings.filterwarnings('ignore')

from sklearn.model_selection import train_test_split

from torchmetrics import Metric

import super_gradients
from super_gradients.common.object_names import Models
from super_gradients.training import Trainer
from super_gradients.training import training_hyperparams
from super_gradients.training.metrics.classification_metrics import Accuracy, Top5
from super_gradients.training.utils.early_stopping import EarlyStop
from super_gradients.training import models
from super_gradients.training.utils.callbacks import Phase
from super_gradients.common.registry import register_metric, register_model

# Config class

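This class stores the settings used throughout the notebook: the dataset path, input size, normalization statistics, batch size, model and recipe names, label sets, and checkpoint directory.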


In [None]:
class config:
    # specify the paths to datasets
    ROOT_DIR = Path("../input/apparel-images-dataset")

    # set the input height and width
    INPUT_HEIGHT = 224
    INPUT_WIDTH = 224

    # ImageNet channel-wise mean and standard deviation, used for normalization
    IMAGENET_MEAN = [0.485, 0.456, 0.406]
    IMAGENET_STD = [0.229, 0.224, 0.225]
    
    IMAGE_TYPE = '.jpg'
    BATCH_SIZE = 128
    MODEL_NAME = 'resnet50'
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
    TRAINING_PARAMS = 'training_hyperparams/imagenet_resnet50_train_params'
    
    COLOUR_LABELS = ['black', 'blue', 'brown', 'green', 'red', 'white']
    N_COLOUR_LABELS = len(COLOUR_LABELS)
    ARTICLE_LABELS = ['dress', 'pants', 'shirt', 'shoes', 'shorts']
    N_ARTICLE_LABELS = len(ARTICLE_LABELS)
    NUM_CLASSES = N_COLOUR_LABELS + N_ARTICLE_LABELS
    
    CHECKPOINT_DIR = 'checkpoints'

# Split the data into training, validation, and test sets

In [None]:
all_images = list(paths.list_images(config.ROOT_DIR))
train_images, dummy_list = train_test_split(all_images, test_size=.20, shuffle=True, random_state=42)
val_images, test_images = train_test_split(dummy_list, test_size=.50, shuffle=True, random_state=42)
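The two calls above yield an 80/10/10 split: the first holds out 20% of the images and the second halves that hold-out set into validation and test. A quick check on a hypothetical list of 1,000 items:

```python
from sklearn.model_selection import train_test_split

# Hypothetical stand-in for the list of image paths
items = list(range(1000))

# First split: hold out 20% of the data
train, holdout = train_test_split(items, test_size=.20, shuffle=True, random_state=42)
# Second split: divide the held-out 20% in half
val, test = train_test_split(holdout, test_size=.50, shuffle=True, random_state=42)

# 80% train, 10% validation, 10% test, with no item in more than one set
sizes = (len(train), len(val), len(test))
```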

A helper function to one-hot encode a single label:

In [None]:
def encode_label(label, class_list):
    """One-hot encode a single label.

    Args:
        label: The label to encode (e.g. 'red').
        class_list: A list of all possible labels for this task.

    Returns:
        A tensor of length len(class_list) with a 1 at the label's index and 0 elsewhere.
    """
    # Create a tensor of zeros with the same length as the class list
    target = torch.zeros(len(class_list))
    # Find the index of the label in the class list (raises ValueError if it is missing)
    idx = class_list.index(label)
    # Set the corresponding index in the target tensor to 1
    target[idx] = 1
    return target
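For reference, `torch.nn.functional.one_hot` produces the same encoding. The sketch below uses a hypothetical helper, `one_hot`, with the notebook's label lists, and shows how the colour and article encodings are laid out in the concatenated 11-element target that the dataset defined below returns:

```python
import torch
import torch.nn.functional as F

COLOUR_LABELS = ['black', 'blue', 'brown', 'green', 'red', 'white']
ARTICLE_LABELS = ['dress', 'pants', 'shirt', 'shoes', 'shorts']

def one_hot(label, class_list):
    """Hypothetical helper: same result as encode_label, built with F.one_hot."""
    idx = torch.tensor(class_list.index(label))
    # F.one_hot returns integers, so cast to float to match encode_label
    return F.one_hot(idx, num_classes=len(class_list)).float()

# Colour encoding first, article encoding second, in one 11-element target
target = torch.cat([one_hot('red', COLOUR_LABELS), one_hot('shoes', ARTICLE_LABELS)])
# Positions 0-5 hold the colour, positions 6-10 hold the article
hot_positions = target.nonzero().flatten().tolist()
```

Here "red" is colour index 4 and "shoes" is article index 3, which lands at position 6 + 3 = 9 in the combined target.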

# Defining a custom dataset


This line of code: `labels = re.findall(r'\w+\_\w+', img_path)[0].split('_')`

It finds the substrings of the image path that consist of word characters joined by an underscore, takes the first one (the folder name, such as `black_dress`), and splits it on the underscore into a list of strings, such as `['black', 'dress']`.

You'll unpack the result of this and define `colour_label` and `article_label`, which are one-hot encoded tensors that are then concatenated. 

This will be your target tensor.

Take a look at the naming convention of the image folders below, and it will make sense.

In [None]:
!ls ../input/apparel-images-dataset

In [None]:
class ApparelDataset(Dataset):
    def __init__(self, image_list, transform=None):
        self.transform=transform
        self.image_list=image_list
        
    def __len__(self):
        return len(self.image_list)
    
    def __getitem__(self, idx):
        img_path = self.image_list[idx]
        # Convert to RGB so grayscale or RGBA images still have the 3 channels Normalize expects
        img = Image.open(img_path).convert('RGB')
        
        if self.transform:
            img = self.transform(img)
        # Find the first "<colour>_<article>" substring in the path (the folder name, e.g. "black_dress")
        labels = re.findall(r'\w+\_\w+', img_path)[0].split('_')
        colour_label = encode_label(labels[0], config.COLOUR_LABELS)
        article_label = encode_label(labels[1], config.ARTICLE_LABELS)
        # Concatenate the two one-hot vectors into a single target of length NUM_CLASSES
        return img, torch.cat([colour_label, article_label])
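Here is the label-parsing regex applied to made-up paths that follow the `<colour>_<article>` folder convention. Because `\w` also matches underscores, the pattern returns the *first* path component that contains an underscore, so it relies on no earlier directory in the path having one. Reading the image's parent folder name with `pathlib` is a more explicit alternative; `labels_from_parent` is a hypothetical helper, not part of the original notebook:

```python
import re
from pathlib import Path

# Made-up example path; real filenames in the dataset may differ
path = "../input/apparel-images-dataset/red_shoes/0001.jpg"
labels = re.findall(r'\w+\_\w+', path)[0].split('_')

# If an earlier directory contains an underscore, the first match is the wrong component
risky_path = "/data/my_datasets/apparel/red_shoes/0001.jpg"
risky_labels = re.findall(r'\w+\_\w+', risky_path)[0].split('_')

def labels_from_parent(img_path):
    """Hypothetical helper: read the labels from the image's parent folder name."""
    colour, article = Path(img_path).parent.name.split('_')
    return [colour, article]
```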

# Data augmentations

I've instantiated a number of augmentations that you're free to play around with.

For illustrative purposes, I've only included `RandAugment` in the pipeline. I encourage you to try some of the others.

In [None]:
# initialize our data augmentation functions
resize = transforms.Resize(size=(config.INPUT_HEIGHT,config.INPUT_WIDTH))
make_tensor = transforms.ToTensor()
normalize = transforms.Normalize(mean=config.IMAGENET_MEAN, std=config.IMAGENET_STD)
center_cropper = transforms.CenterCrop((config.INPUT_HEIGHT,config.INPUT_WIDTH))
random_horizontal_flip = transforms.RandomHorizontalFlip(p=0.75)
random_vertical_flip = transforms.RandomVerticalFlip(p=0.75)
random_rotation = transforms.RandomRotation(degrees=90)
random_crop = transforms.RandomCrop(size=(200,200))
augmix = transforms.AugMix(severity = 3, mixture_width=3, alpha=0.2)
auto_augment = transforms.AutoAugment()
random_augment = transforms.RandAugment()

# initialize our training and validation set data augmentation pipeline
train_transforms = transforms.Compose([
  resize, 
  random_augment,
  make_tensor,
  normalize
])

val_transforms = transforms.Compose([resize, make_tensor, normalize])

## Instantiate datasets

In [None]:
train_dataset = ApparelDataset(train_images, transform = train_transforms)
val_dataset = ApparelDataset(val_images, transform = val_transforms)
test_dataset = ApparelDataset(test_images, transform = val_transforms)

## Instantiate dataloaders

In [None]:
train_loader = DataLoader(train_dataset, config.BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, config.BATCH_SIZE)
test_loader = DataLoader(test_dataset, config.BATCH_SIZE)

# Model definition

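This defines a PyTorch module named `MultilabelClassifier` that inherits from `nn.Module`. The `@register_model("multilabelmodel")` decorator registers the module with the SuperGradients training library under the name `multilabelmodel`, which is the name used to load it again with `models.get()` after training.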

In the constructor (`__init__()`), the ResNet model is loaded using `models.get()`, which is a utility function that retrieves a pre-trained model by name (`config.MODEL_NAME`) from the [SuperGradients model zoo](https://github.com/Deci-AI/super-gradients/blob/master/src/super_gradients/training/Computer_Vision_Models_Pretrained_Checkpoints.md). 

The last two layers of the ResNet model are removed to obtain a feature extractor, which is stored in `self.model_wo_fc`. 

`self.avgpool` is an instance of `nn.AdaptiveAvgPool2d`, which computes the spatial average of each feature map. `self.colour` and `self.article` are instances of `nn.Sequential` that define the classification heads for colour and article labels, respectively. 

Each head consists of a dropout layer, a fully connected layer, and a softmax activation function.

The `forward()` method is where the actual computation takes place. 

The input tensor `x` is passed through the ResNet feature extractor (`self.model_wo_fc`) and the average pooling layer (`self.avgpool`). The resulting feature tensor is flattened and fed into the colour and article classification heads, and the predicted probabilities are concatenated along the second dimension (`dim=1`) into a single tensor, `pred_tensor`, which is returned as the output of the module.

Note that `nn.Softmax()` is applied to the output of each classification head, which normalizes the predicted probabilities across all possible labels for that head. This ensures that the predicted probabilities for each head sum to 1 and can be interpreted as probabilities.

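However, this is not strictly necessary. `nn.CrossEntropyLoss`, which the loss function below uses, expects raw logits and applies a log-softmax internally, so with the softmax layers in place the normalization is effectively applied twice. Predictions taken with `argmax` are the same either way.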
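A small numeric illustration of the double normalization, using a single made-up prediction over three classes. Fed raw logits, `nn.CrossEntropyLoss` can approach zero for a confident, correct prediction; fed softmax outputs, which lie in [0, 1], it cannot:

```python
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
target = torch.tensor([0])  # the correct class is index 0

# A very confident, correct prediction expressed as raw logits
logits = torch.tensor([[10.0, 0.0, 0.0]])
loss_on_logits = criterion(logits, target)

# The same prediction after a softmax layer, passed to CrossEntropyLoss as if it were logits
probs = torch.softmax(logits, dim=1)
loss_on_probs = criterion(probs, target)
```

With three classes, the loss on softmax outputs can never drop below log(1 + 2/e), roughly 0.55, which can weaken the training signal once predictions are already good. Removing the `nn.Softmax` layers and applying `torch.softmax` only when you need probabilities is one way around this.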

In [None]:
@register_model("multilabelmodel")
class MultilabelClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.resnet = models.get(config.MODEL_NAME, pretrained_weights='imagenet')
        self.model_wo_fc = nn.Sequential(*(list(self.resnet.children())[:-2]))
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.colour = nn.Sequential(nn.Dropout(0.5),
                                     nn.Linear(2048, config.N_COLOUR_LABELS),
                                     nn.Softmax(dim=1))
        self.article = nn.Sequential(nn.Dropout(0.5),
                                      nn.Linear(2048, config.N_ARTICLE_LABELS),
                                      nn.Softmax(dim=1))
    
    def forward(self, x):
        x = self.model_wo_fc(x)
        x = self.avgpool(x)
        predicted_colour = self.colour(x.view(x.size(0), -1))
        predicted_article = self.article(x.view(x.size(0), -1))
        pred_tensor = torch.cat([predicted_colour, predicted_article], dim=1)
        return pred_tensor
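To see the output layout without downloading ResNet50, here is a hypothetical stand-in, `ToyTwoHeadModel`, with the same two-head structure but a tiny convolutional backbone. Each head's slice sums to 1 on its own, so the full 11-column output sums to 2 per image, which is why the loss and metric below split it at `N_COLOUR_LABELS`:

```python
import torch
import torch.nn as nn

N_COLOUR_LABELS, N_ARTICLE_LABELS = 6, 5  # same sizes as in config

class ToyTwoHeadModel(nn.Module):
    """Hypothetical stand-in for MultilabelClassifier with a tiny backbone instead of ResNet50."""
    def __init__(self, feat_dim=16):
        super().__init__()
        # Tiny convolutional backbone standing in for the ResNet feature extractor
        self.backbone = nn.Sequential(nn.Conv2d(3, feat_dim, kernel_size=3, padding=1), nn.ReLU())
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.colour = nn.Sequential(nn.Dropout(0.5), nn.Linear(feat_dim, N_COLOUR_LABELS), nn.Softmax(dim=1))
        self.article = nn.Sequential(nn.Dropout(0.5), nn.Linear(feat_dim, N_ARTICLE_LABELS), nn.Softmax(dim=1))

    def forward(self, x):
        # (B, feat_dim, 1, 1) -> (B, feat_dim)
        features = torch.flatten(self.avgpool(self.backbone(x)), 1)
        # Columns 0-5 are the colour head, columns 6-10 the article head
        return torch.cat([self.colour(features), self.article(features)], dim=1)

model = ToyTwoHeadModel().eval()  # eval mode disables dropout
with torch.no_grad():
    out = model(torch.randn(4, 3, 32, 32))
```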

# Loss function definition

Here you define a custom loss function called `CustomLoss`, which extends the PyTorch `nn.Module` class. This custom loss function calculates the average of the cross-entropy loss between the predicted outputs and the ground truth targets.

In the `__init__` method, the `super()` function is called to initialize the parent class `nn.Module`.

In the forward method, the loss is computed with PyTorch's `nn.CrossEntropyLoss` module, which receives the predicted outputs and the targets. The `preds` tensor contains the model's outputs and has shape `(batch_size, num_classes)`, where `num_classes` is the total number of classes (here 6 colours + 5 articles = 11). Because the targets are one-hot vectors rather than class indices, this relies on `nn.CrossEntropyLoss` accepting class-probability targets, which PyTorch supports from version 1.10.

Next, the colour loss and article loss are calculated separately using the cross-entropy loss function. The colour loss is calculated by taking the first `N_COLOUR_LABELS` columns of the preds tensor, while the article loss is calculated by taking the remaining columns.

Finally, the average loss is calculated as the average of the colour and article loss. This is returned as the output of the forward method.

In [None]:
class CustomLoss(nn.Module):
    def __init__(self):
        super().__init__()
        # A single cross-entropy criterion, reused for both tasks
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, preds, target):
        n = config.N_COLOUR_LABELS
        # The first n columns hold the colour task, the remaining columns the article task
        colour_loss = self.criterion(preds[:, :n], target[:, :n])
        article_loss = self.criterion(preds[:, n:], target[:, n:])
        # Equal-weight average of the two task losses
        average_loss = (colour_loss + article_loss) / 2
        return average_loss
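A quick sanity check of the loss logic on random stand-in outputs: with one-hot targets, the probability-target form used in `CustomLoss` gives the same value as the more common class-index form of `nn.CrossEntropyLoss` (this assumes PyTorch 1.10 or later):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

N_COLOUR_LABELS = 6  # same as config.N_COLOUR_LABELS

torch.manual_seed(0)
preds = torch.randn(3, 11)  # stand-in model outputs for a batch of 3

# One-hot targets: colour indices [4, 0, 2] and article indices [3, 1, 0]
colour_idx = torch.tensor([4, 0, 2])
article_idx = torch.tensor([3, 1, 0])
target = torch.cat([F.one_hot(colour_idx, 6), F.one_hot(article_idx, 5)], dim=1).float()

criterion = nn.CrossEntropyLoss()
# Probability-target form, as in CustomLoss
colour_loss = criterion(preds[:, :N_COLOUR_LABELS], target[:, :N_COLOUR_LABELS])
article_loss = criterion(preds[:, N_COLOUR_LABELS:], target[:, N_COLOUR_LABELS:])
average_loss = (colour_loss + article_loss) / 2

# Class-index form of the same loss, for comparison
index_form = (criterion(preds[:, :N_COLOUR_LABELS], colour_idx)
              + criterion(preds[:, N_COLOUR_LABELS:], article_idx)) / 2
```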

# Accuracy metric definition

This code defines a custom metric called `MyAccuracy` that inherits from the TorchMetrics `Metric` class and is registered with SuperGradients as `my_accuracy`. This metric measures the accuracy of a multi-task classification model that predicts an image's colour and article labels.

In the constructor of MyAccuracy, four state variables are created using the `add_state` method. These state variables keep track of the number of correct predictions and the total number of predictions made for each task (colour and article). The `dist_reduce_fx` argument specifies how to aggregate the state variables across different devices in a distributed setting. 

Here, the sum is used to add the values of the state variables across devices.

The update method of `MyAccuracy` is called after each batch of data is processed. It takes two arguments: `preds`, a tensor of shape (batch_size, num_labels) representing the model's predictions, and `target`, a tensor of the same shape representing the ground truth labels.

The `preds` tensor is split into colour and article predictions using slicing in the update method. The `argmax` function is used to get the indices of the predicted labels for each task. The same is done for the target tensor to get the true indices of the labels.

The number of correct predictions and the total number of predictions are then updated for each task using the `self.color_correct`, `self.color_total`, `self.article_correct`, and `self.article_total` state variables.

The `compute` method of `MyAccuracy` is called after all the batches have been processed. For each task, it divides the number of correct predictions by the total number of predictions, then returns the average of the two task accuracies as the total accuracy.

In [None]:
@register_metric("my_accuracy")
class MyAccuracy(Metric):
    def __init__(self):
        super().__init__()
        self.add_state("color_correct", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("color_total", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("article_correct", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("article_total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds, target):
        # Split the output into color and article predictions
        color_output, article_output = preds[:, :config.N_COLOUR_LABELS], preds[:, config.N_COLOUR_LABELS:]

        # Get the predicted indices for each task
        color_pred = torch.argmax(color_output, dim=1)
        article_pred = torch.argmax(article_output, dim=1)

        # Get the true indices for each task
        color_true, article_true = target[:, :config.N_COLOUR_LABELS], target[:, config.N_COLOUR_LABELS:]
        color_true = torch.argmax(color_true, dim=1)
        article_true = torch.argmax(article_true, dim=1)

        # Calculate the accuracy for each task
        self.color_correct += torch.sum(color_pred == color_true)
        self.color_total += color_true.numel()
        self.article_correct += torch.sum(article_pred == article_true)
        self.article_total += article_true.numel()

    def compute(self):
        # Calculate the total accuracy as the average of the two task accuracies
        color_acc = self.color_correct.float() / self.color_total
        article_acc = self.article_correct.float() / self.article_total
        total_acc = (color_acc + article_acc) / 2
        return total_acc
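A worked example of what `MyAccuracy` reports, using made-up predictions for four images. For comparison, it also computes a stricter exact-match score (an image counts only if both its colour and article are right), which is not part of the original metric:

```python
# Hypothetical predicted and true class indices for a batch of 4 images
colour_pred, colour_true = [4, 0, 2, 1], [4, 0, 2, 5]    # 3 of 4 correct
article_pred, article_true = [3, 1, 0, 2], [3, 1, 4, 2]  # 3 of 4 correct

colour_acc = sum(p == t for p, t in zip(colour_pred, colour_true)) / len(colour_true)
article_acc = sum(p == t for p, t in zip(article_pred, article_true)) / len(article_true)

# MyAccuracy reports the unweighted mean of the two task accuracies
total_acc = (colour_acc + article_acc) / 2

# Stricter alternative (not in the notebook): both labels must be right for the same image
exact_match = sum(
    cp == ct and ap == at
    for cp, ct, ap, at in zip(colour_pred, colour_true, article_pred, article_true)
) / len(colour_true)
```

Here the averaged accuracy is 0.75, while only images 0 and 1 have both labels right, so the exact-match score is 0.5.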

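# Training hyperparameters

Start from SuperGradients' ResNet50 ImageNet training recipe (`config.TRAINING_PARAMS`), then override the loss, metrics, optimizer, EMA, weight averaging, and schedule for this task.
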
In [None]:
training_params = training_hyperparams.get(config.TRAINING_PARAMS)

training_params['loss'] = CustomLoss()

training_params["train_metrics_list"] = ["my_accuracy"]
training_params["valid_metrics_list"] = ["my_accuracy"]
training_params["metric_to_watch"] = "my_accuracy"

# Silent mode reduces clutter in the notebook; set it to False to see the full training output
training_params["silent_mode"] = True
training_params["optimizer"] = 'AdamW'
training_params["ema"] = True
# CustomLoss takes no constructor arguments, so this label-smoothing setting is not applied to it
training_params["criterion_params"] = {'smooth_eps': 0.10}
training_params["average_best_models"] = True
training_params["max_epochs"] = 5
training_params["initial_lr"] = 0.001

# Train the model!

In [None]:
multi_label_model = MultilabelClassifier()
model_trainer = Trainer(experiment_name='0_Baseline_Experiment', ckpt_root_dir=config.CHECKPOINT_DIR)
model_trainer.train(model= multi_label_model, 
              training_params=training_params, 
              train_loader=train_loader,
              valid_loader=val_loader)

We are going to leverage weight averaging, a post-training method that takes the best model weights saved during training and averages them into a single model.

This technique helps your model overcome the optimizer's tendency to alternate between adjacent local minima in the later stages of training.

It turns out that, more often than not, you can average these model weights and obtain a model that performs better than the individual models.

You can compare the averaged model with the model trained during the last epoch and see which one performs better. If the last epoch does better, you can keep that one!

Weight averaging doesn’t cost you anything other than keeping a few additional weights on the disk, and can yield a substantial boost in performance and stability.

To use the single "best" checkpoint instead of the averaged weights, swap `average_model.pth` for `ckpt_best.pth` below.
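SuperGradients does the averaging for you when `average_best_models` is enabled. Purely to illustrate the idea, here is a plain-PyTorch sketch, not SuperGradients' implementation, that averages the floating-point entries of several state dicts from the same architecture. `average_state_dicts` is a hypothetical helper:

```python
import copy

import torch
import torch.nn as nn

def average_state_dicts(state_dicts):
    """Hypothetical helper: average state dicts from models with the same architecture."""
    averaged = copy.deepcopy(state_dicts[0])
    for key in averaged:
        if averaged[key].is_floating_point():
            # Element-wise mean of this parameter/buffer across all checkpoints
            averaged[key] = torch.stack([sd[key] for sd in state_dicts]).mean(dim=0)
        # Non-floating entries (e.g. BatchNorm's num_batches_tracked) are kept from the first one
    return averaged

# Three stand-ins for the best checkpoints saved during training
torch.manual_seed(0)
snapshots = [nn.Linear(4, 2).state_dict() for _ in range(3)]

model = nn.Linear(4, 2)
model.load_state_dict(average_state_dicts(snapshots))
```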


In [None]:
best_full_model = models.get('multilabelmodel',
                        num_classes=config.NUM_CLASSES,
                        checkpoint_path=os.path.join(model_trainer.checkpoints_dir_path, "average_model.pth"))

## Evaluate the model on unseen data



In [None]:
model_trainer.test(model=best_full_model,
            test_loader=test_loader,
            test_metrics_list=['my_accuracy'])

In [None]:
from typing import Tuple, Union
import requests
import torchvision
import random
import textwrap

def pred_and_plot_image(image_path: Union[str, pathlib.Path],  # local path or image URL
                        subplot: Tuple[int, int, int],  # subplot tuple for `subplot()` function
                        model: torch.nn.Module = best_full_model,
                        image_size: Tuple[int, int] = (config.INPUT_HEIGHT, config.INPUT_WIDTH),
                        transform: transforms.Compose = None,
                        device: torch.device=config.DEVICE):

    # Load local files from disk and anything else from its URL;
    # convert to RGB so grayscale or RGBA images still have 3 channels
    if isinstance(image_path, pathlib.Path):
        img = Image.open(image_path).convert('RGB')
    else: 
        img = Image.open(requests.get(image_path, stream=True).raw).convert('RGB')

    # create transformation for image (if one doesn't exist)
    if transform is None:
        transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=config.IMAGENET_MEAN,
                                 std=config.IMAGENET_STD),
        ])
    transformed_image = transform(img)

    # make sure the model is on the target device
    model.to(device)

    # turn on model evaluation mode and inference mode
    model.eval()
    with torch.inference_mode():
        # add an extra dimension to image (model requires samples in [batch_size, color_channels, height, width])
        transformed_image = transformed_image.unsqueeze(dim=0)

        # make a prediction on image with an extra dimension and send it to the target device
        target_image_pred = model(transformed_image.to(device))

    # take the argmax within each head to get the predicted colour and article indices
    color = torch.argmax(target_image_pred[:, :config.N_COLOUR_LABELS], dim=1).item()
    article = torch.argmax(target_image_pred[:, config.N_COLOUR_LABELS:], dim=1).item()
    
    predicted_color = config.COLOUR_LABELS[color]
    predicted_article = config.ARTICLE_LABELS[article]
    
    target_image_pred_label = predicted_color + '_' +  predicted_article

    # plot image with predicted label and probability 
    plt.subplot(*subplot)
    plt.imshow(img)
    if isinstance(image_path, pathlib.Path):
        # actual label
        ground_truth = re.findall(r'\w+\_\w+', str(image_path))[0]
        title = f"Ground Truth: {ground_truth} | Pred: {target_image_pred_label}"
    else:
        title = f"Pred: {target_image_pred_label}"
    plt.title("\n".join(textwrap.wrap(title, width=20)))  # wrap text using textwrap.wrap() function
    plt.axis(False)
    

def plot_random_test_images(model, test_images):
    num_images_to_plot = 30
    test_image_path_list = [pathlib.Path(p) for p in test_images] 
    test_image_path_sample = random.sample(population=test_image_path_list,  # randomly select 'k' image paths to pred and plot
                                           k=num_images_to_plot)

    # set up subplots
    num_rows = int(np.ceil(num_images_to_plot / 5))
    fig, ax = plt.subplots(num_rows, 5, figsize=(15, num_rows * 3))
    ax = ax.flatten()

    # Make predictions on and plot the images
    for i, image_path in enumerate(test_image_path_sample):
        pred_and_plot_image(model=model, 
                            image_path=image_path,
                            subplot=(num_rows, 5, i+1),  # subplot tuple for `subplot()` function
                            image_size=(config.INPUT_HEIGHT, config.INPUT_WIDTH))

    # adjust spacing between subplots
    plt.subplots_adjust(wspace=1)
    plt.show()

In [None]:
plot_random_test_images(best_full_model, test_images)

In [None]:
plot_random_test_images(best_full_model, test_images)

In [None]:
pred_and_plot_image(image_path='https://sneakernews.com/wp-content/uploads/2020/11/Air-Jordan-13-GS-DC9443-007-01.jpg', subplot=(1, 1, 1))

In [None]:
pred_and_plot_image(image_path='https://images.solecollector.com/complex/image/upload/c_fill,dpr_auto,f_auto,fl_lossy,g_face,q_auto,w_1280/cjapz33ntap9njssipiw_beyuom.jpg', subplot=(1, 1, 1))

In [None]:
pred_and_plot_image(image_path='https://assets.adidas.com/images/w_600,f_auto,q_auto/0e8981a897f446468189af0000b7e69e_9366/Adicolor_SST_Track_Suit_Blue_IB8636_01_laydown.jpg', subplot=(1, 1, 1))


# Your homework

Copy/fork this notebook and try some different architectures.

If you have a question you can leave a comment on this notebook, or visit the community and post it in the [Q&A section](https://www.deeplearningdaily.community/c/qanda/8).

## Use a different pretrained model

You can change the model you use. Take a look at the [SG model zoo](https://github.com/Deci-AI/super-gradients/blob/master/src/super_gradients/training/Computer_Vision_Models_Pretrained_Checkpoints.md).

For example, if you wanted to use RegNet you would do the following:

```
regnet_imagenet_model = models.get(model_name='regnetY800', num_classes=config.NUM_CLASSES, pretrained_weights='imagenet')
regnet_params = training_hyperparams.get('training_hyperparams/imagenet_regnetY_train_params')
```

Note you can also pass `model_name='regnetY200'`, `model_name='regnetY400'`, or `model_name='regnetY600'` to try other variants of the architecture.

For ResNet50, you would do:

```
resnet_imagenet_model = models.get(model_name='resnet50', num_classes=config.NUM_CLASSES, pretrained_weights='imagenet')
resnet_params = training_hyperparams.get('training_hyperparams/imagenet_resnet50_train_params')
```

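Note you can also pass `model_name='resnet18'` or `model_name='resnet34'` to try other variants of the architecture. Both produce 512-dimensional features instead of ResNet50's 2048, so if you use them inside `MultilabelClassifier`, change the `nn.Linear(2048, ...)` input size accordingly.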
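The heads in `MultilabelClassifier` hard-code an input width of 2048, which matches ResNet50's final feature map. One way to avoid editing that number for each backbone is to measure it with a dummy forward pass; `infer_feature_dim` below is a hypothetical helper, demonstrated on a toy backbone rather than a real SuperGradients model:

```python
import torch
import torch.nn as nn

def infer_feature_dim(backbone, input_size=(3, 224, 224)):
    """Hypothetical helper: run a dummy image through a backbone and return its channel count."""
    was_training = backbone.training
    backbone.eval()
    with torch.no_grad():
        features = backbone(torch.zeros(1, *input_size))
    backbone.train(was_training)  # restore the original mode
    # For a (1, C, H, W) feature map, C is the input width the linear heads need
    return features.shape[1]

# Toy backbone standing in for e.g. a ResNet18 feature extractor with 512 output channels
toy_backbone = nn.Sequential(nn.Conv2d(3, 512, kernel_size=7, stride=32), nn.ReLU())
feat_dim = infer_feature_dim(toy_backbone)
colour_head = nn.Linear(feat_dim, 6)  # instead of a hard-coded nn.Linear(2048, ...)
```

Inside `MultilabelClassifier.__init__`, you could call it on `self.model_wo_fc` and pass the result to the `nn.Linear` layers in place of 2048.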

For MobileNetV2, you would do:

```
mobilenet_imagenet_model = models.get(model_name='mobilenet_v2', num_classes=config.NUM_CLASSES, pretrained_weights='imagenet')
mobilenet_params = training_hyperparams.get('training_hyperparams/imagenet_mobilenetv2_train_params')
```

For MobileNetV3, you would do:

```
mobilenet_imagenet_model = models.get(model_name='mobilenet_v3_large', num_classes=config.NUM_CLASSES, pretrained_weights='imagenet')
mobilenet_params = training_hyperparams.get('training_hyperparams/imagenet_mobilenetv3_train_params')
```

Note you can also pass `model_name='mobilenet_v3_small'` to try another variant of the architecture.


For ViT, you would do:


```
vit_imagenet_model = models.get(model_name='vit_base', num_classes=NUM_CLASSES, pretrained_weights='imagenet')
vit_params = training_hyperparams.get("training_hyperparams/imagenet_vit_train_params")
```

Note you can also pass `model_name='vit_large'` to try another variant of the architecture.


I encourage you to play around with different optimizers; all you have to do is change the value of `training_params["optimizer"]`. You can use one of `['Adam', 'SGD', 'RMSProp']` out of the box, and you can tune the optimizer's arguments through `training_params["optimizer_params"]` as well.

In general, play with and tweak the training recipes.

## Training recipes

SuperGradients has a number of [training recipes](https://github.com/Deci-AI/super-gradients/tree/master/src/super_gradients/recipes) you can use. [See here](https://github.com/Deci-AI/super-gradients/blob/master/src/super_gradients/recipes/training_hyperparams/default_train_params.yaml) for more information about the training params.

If you're using Weights and Biases to track your experiments, you would add the following to your training hyperparameters:

```
sg_logger: wandb_sg_logger
sg_logger_params:
  project_name: <YOUR PROJECT NAME>
  entity: <YOUR W&B ENTITY>
  api_server: <YOUR W&B SERVER URL>  # optional, only needed for a self-hosted W&B server
  save_checkpoints_remote: True
  save_tensorboard_remote: True
  save_logs_remote: True
```