# Example of the `aitlas` toolbox in the context of multi-label image classification

This notebook shows a sample implementation of multi-label image classification with the `aitlas` toolbox, using the MLRSNet multi-label dataset.

In [1]:
from aitlas.datasets import MLRSNetMultiLabelDataset
from aitlas.models import ResNet50MultiLabel
from aitlas.transforms import ResizeCenterCropFlipHVToTensor, ResizeCenterCropToTensor
from aitlas.utils import image_loader

  from .autonotebook import tqdm as notebook_tqdm


## Load the dataset

In [2]:
# Full MLRSNet dataset: the image folder plus its multi-label annotation file.
# NOTE(review): the saved output of this cell is a FileNotFoundError for the
# annotation file below — confirm the file name/location before running.
dataset_config = dict(
    data_dir="./data/MLRSNet/images",
    csv_file="./data/MLRSNet/multilabels.txt",
)
dataset = MLRSNetMultiLabelDataset(dataset_config)

FileNotFoundError: [Errno 2] No such file or directory: './data/MLRSNet/multilabels.txt'

## Show images from the dataset

In [None]:
# Render two individual samples (indices 1000 and 30), then a batch of 15.
fig1, fig2 = [dataset.show_image(idx) for idx in (1000, 30)]
fig3 = dataset.show_batch(15)

## Inspect the data

In [None]:
# Preview samples from the loaded dataset (aitlas helper).
dataset.show_samples()

In [None]:
# Label distribution as a table (presumably image counts per label); left as the
# last expression so Jupyter displays the result.
dataset.data_distribution_table()

In [None]:
# Same label distribution as a bar chart. Assigning the result to `fig` stops
# the cell from echoing the returned object as its output value.
fig = dataset.data_distribution_barchart()

## Load train and test splits

In [None]:
# Training split: shuffled, with the augmenting transform
# ResizeCenterCropFlipHVToTensor (resize/center-crop plus horizontal/vertical
# flips, per its name).
# The transform is declared in the config, exactly like the test split below,
# instead of overwriting `train_dataset.transform` after construction — both
# splits are now fully described by their config dicts.
train_dataset_config = {
    "batch_size": 16,
    "shuffle": True,
    "num_workers": 4,
    "data_dir": "./data/MLRSNet/images",
    "csv_file": "./data/MLRSNet/train.csv",
    "transforms": ["aitlas.transforms.ResizeCenterCropFlipHVToTensor"],
}

train_dataset = MLRSNetMultiLabelDataset(train_dataset_config)

# Test split: fixed order and a deterministic (non-flipping) transform, so
# evaluation is repeatable.
test_dataset_config = {
    "batch_size": 4,
    "shuffle": False,
    "num_workers": 4,
    "data_dir": "./data/MLRSNet/images",
    "csv_file": "./data/MLRSNet/test.csv",
    "transforms": ["aitlas.transforms.ResizeCenterCropToTensor"],
}

test_dataset = MLRSNetMultiLabelDataset(test_dataset_config)

# Displayed as the cell output: number of training and test samples.
len(train_dataset), len(test_dataset)

## Set up and create the model for training

In [None]:
epochs = 10

# Directory that `train_and_evaluate_model` writes checkpoints into.
# Previously this was an absolute, machine-specific path to a single checkpoint
# *file* ("/data/scratch/.../best_checkpoint_1699726140_24.pth.tar"), which is
# neither a directory nor portable, and did not match where the prediction
# section loads the model from ("./experiments/MLRSNet/").
model_directory = "./experiments/MLRSNet/"

model_config = {
    # One output per MLRSNet label (the prediction section lists all 60).
    "num_classes": 60,
    "learning_rate": 0.0001,
    # Presumably initialises the ResNet-50 backbone from pretrained weights —
    # confirm against the aitlas model implementation.
    "pretrained": True,
    # Presumably the per-label probability cutoff for a positive prediction.
    "threshold": 0.5,
    "metrics": ["accuracy", "precision", "recall", "f1_score"],
}

model = ResNet50MultiLabel(model_config)
# Finalise the model before training (aitlas API).
model.prepare()

## Training and evaluation

In [None]:
# Fit on the training split, using the test split as the validation set.
# Checkpoints are written under `model_directory`, tagged with run id "1".
model.train_and_evaluate_model(
    train_dataset=train_dataset,
    val_dataset=test_dataset,
    epochs=epochs,
    model_directory=model_directory,
    run_id="1",
)

## Predictions

In [None]:
# Checkpoint produced by training into ./experiments/MLRSNet/.
# NOTE(review): confirm the exact file name / sub-directory aitlas writes
# (e.g. whether the run id is part of the path).
model_path = "./experiments/MLRSNet/checkpoint.pth.tar"

# Display names for the model outputs. Presumably indexed by output position,
# so this order must match the label order used during training.
labels = ["airplane", "airport", "bare soil", "baseball diamond", "basketball court", "beach", "bridge", "buildings",
          "cars", "cloud", "containers", "crosswalk", "dense residential area", "desert", "dock", "factory", "field",
          "football field", "forest", "freeway", "golf course", "grass", "greenhouse", "gully", "habor", "intersection",
          "island", "lake", "mobile home", "mountain", "overpass", "park", "parking lot", "parkway", "pavement",
          "railway", "railway station", "river", "road", "roundabout", "runway", "sand", "sea", "ships", "snow",
          "snowberg", "sparse residential area", "stadium", "swimming pool", "tanks", "tennis court", "terrace",
          "track", "trail", "transmission tower", "trees", "water", "chaparral", "wetland", "wind turbine"]
# Guard against the label list drifting out of sync with the model's output size.
assert len(labels) == model_config["num_classes"], "label list does not match model output size"

# Same deterministic transform as the test split (no flip augmentation).
transform = ResizeCenterCropToTensor()
model.load_model(model_path)

# Images to run inference on; one predict-and-plot step per file instead of a
# copy-pasted snippet for each image.
predict_image_paths = [
    "./data/predict/image1.tif",
    "./data/predict/image2.tif",
    "./data/predict/image3.tif",
    "./data/predict/image4.tif",
]
for image_path in predict_image_paths:
    image = image_loader(image_path)
    fig = model.predict_image(image, labels, transform)