# Quickstart Guide
This notebook guides you through training a disaster detection model with three classes (Earthquake, Flood and Wildfire) using a Vision Transformer (ViT).

**Note**: The easiest way to train the model is in Google Colab or Kaggle, which let you dive in with no setup. We recommend enabling the GPU runtime to train the model efficiently.

**Note**: You need to have at least Python 3.6 to run the scripts.
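If you want the notebook to fail early on an older interpreter, a minimal check like the following can go in the first cell. It is a sketch that only compares version tuples from the standard `sys` module. The 3.6 floor comes from the note above, and the helper name `python_is_supported` is hypothetical.

```python
import sys

# Minimum interpreter version stated in this guide
MIN_PYTHON = (3, 6)

def python_is_supported(version_info=sys.version_info, minimum=MIN_PYTHON):
    """Return True if the running (or given) Python version meets the minimum."""
    # Compare only (major, minor) so patch releases do not matter
    return tuple(version_info[:2]) >= minimum

if not python_is_supported():
    # %-formatting instead of f-strings so this line still parses on old interpreters
    raise RuntimeError(
        "Python %d.%d or newer is required, found %s"
        % (MIN_PYTHON + (sys.version.split()[0],))
    )
```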

## Install HugsVision

First, install HugsVision if it is not already available.

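In [None]:
try:
    import hugsvision
except ImportError:
    # Install the package quietly only when it is missing
    # (`!pip` is IPython/Jupyter shell syntax, so this cell only runs in a notebook)
    !pip install -q hugsvision
    import hugsvision

print(hugsvision.__version__)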

## Downloading Data

First, download the `Disaster Images Dataset` from Kaggle [here](https://www.kaggle.com/datasets/mikolajbabula/disaster-images-dataset-cnn-model/code). It weighs around 3 GB.
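Before loading, it can help to confirm that the images are laid out the way an image-folder loader expects. The sketch below assumes an ImageFolder-style layout, `./dataset/<class_name>/<image files>`, and a hand-picked list of image extensions; both are assumptions to adjust to your copy of the data. It counts the image files in each class sub-folder using only `pathlib`.

```python
from pathlib import Path

# Extensions treated as images (an assumption; adjust to your data)
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp"}

def count_images_per_class(root):
    """Count image files in each immediate sub-folder of `root`.

    Assumes an ImageFolder-style layout: root/<class_name>/<image files>.
    Files directly inside `root` are ignored.
    """
    root = Path(root)
    counts = {}
    for class_dir in sorted(p for p in root.iterdir() if p.is_dir()):
        # rglob also picks up images in nested folders inside a class folder
        counts[class_dir.name] = sum(
            1 for f in class_dir.rglob("*")
            if f.is_file() and f.suffix.lower() in IMAGE_EXTENSIONS
        )
    return counts
```

Calling `count_images_per_class("./dataset/")` should list one entry per class, which makes an empty or misnamed folder easy to spot before training.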

## Loading Data

Once the images are arranged in `./dataset/` with one sub-folder per class, we can start loading the data.

In [None]:
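from hugsvision.dataio.VisionDataset import VisionDataset

# Load the images from ./dataset/ and split them into train and test sets
train, test, id2label, label2id = VisionDataset.fromImageFolder(
  "./dataset/",
  test_ratio   = 0.15,  # hold out 15% of the images for testing
  balanced     = True,
  augmentation = True,
)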
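`fromImageFolder` returns `id2label` and `label2id` alongside the two splits. The plain-Python sketch below illustrates one plausible way to produce such outputs: class ids taken from the sorted folder names (the convention torchvision's `ImageFolder` uses) and a per-class (stratified) split with `test_ratio = 0.15`. It illustrates the concepts under those assumptions and is not HugsVision's internal code; in particular, it does not reproduce what `balanced=True` does. The helper names are hypothetical.

```python
import random

def build_label_maps(class_names):
    """Map class names to integer ids in sorted order (ImageFolder convention)."""
    id2label = {i: name for i, name in enumerate(sorted(class_names))}
    label2id = {name: i for i, name in id2label.items()}
    return id2label, label2id

def stratified_split(samples, test_ratio=0.15, seed=42):
    """Split (path, label) pairs so each label keeps roughly the same test share."""
    rng = random.Random(seed)
    by_label = {}
    for path, label in samples:
        by_label.setdefault(label, []).append((path, label))
    train, test = [], []
    for label in sorted(by_label):
        items = by_label[label]
        rng.shuffle(items)
        # Take test_ratio of *this* class, so rare classes still appear in the test set
        n_test = round(len(items) * test_ratio)
        test.extend(items[:n_test])
        train.extend(items[n_test:])
    return train, test
```

Splitting per class rather than over the whole pool keeps the class proportions of the test set close to those of the full dataset.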

## Choose an image classification model on HuggingFace

Now we choose the base model that we will fine-tune for our task.

Our options are limited, since not many models are available on HuggingFace for this task yet.

To be compatible with `HugsVision`, the model must be exported in `PyTorch` and support the `image-classification` task.

Models that meet these criteria are listed [here](https://huggingface.co/models?filter=pytorch&pipeline_tag=image-classification&sort=downloads).

At the time of writing, I recommend the following models:

* `google/vit-base-patch16-224-in21k`
* `google/vit-base-patch16-224`
* `facebook/deit-base-distilled-patch16-224`
* `microsoft/beit-base-patch16-224`

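**Note:** If the checkpoint's classification head does not have one output per class in our dataset, pass `ignore_mismatched_sizes=True` to `from_pretrained` so that the mismatched head is re-initialised. This is the case for `google/vit-base-patch16-224`, whose head was trained on the 1,000 ImageNet classes.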
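To see why this flag matters, compare parameter shapes: a checkpoint fine-tuned on ImageNet-1k stores a classifier weight of shape `(1000, hidden_size)`, while a three-class head needs `(3, hidden_size)`. The toy function below is a hypothetical, simplified illustration of the idea behind `ignore_mismatched_sizes=True` (copy the weights whose shapes match, leave the rest to be freshly initialised). It works on plain shape tuples, is not the transformers implementation, and the parameter names in the example are assumptions.

```python
def split_loadable_weights(checkpoint_shapes, model_shapes, ignore_mismatched_sizes=False):
    """Decide which checkpoint tensors can be copied into the model.

    Both arguments map parameter names to shape tuples (a toy stand-in for
    real state dicts). Returns (loaded, reinitialised) lists of names.
    """
    loaded, reinitialised = [], []
    for name, shape in model_shapes.items():
        ckpt_shape = checkpoint_shapes.get(name)
        if ckpt_shape == shape:
            loaded.append(name)
        elif ckpt_shape is None:
            # Parameter absent from the checkpoint: it must be initialised fresh
            reinitialised.append(name)
        elif ignore_mismatched_sizes:
            # Shape differs (e.g. 1000-class vs 3-class head): drop it and re-initialise
            reinitialised.append(name)
        else:
            raise ValueError(f"size mismatch for {name}: checkpoint {ckpt_shape} vs model {shape}")
    return loaded, reinitialised

HIDDEN = 768  # hidden size of ViT-Base
# Hypothetical parameter names, for illustration only
checkpoint = {"encoder.weight": (HIDDEN, HIDDEN), "classifier.weight": (1000, HIDDEN), "classifier.bias": (1000,)}
model = {"encoder.weight": (HIDDEN, HIDDEN), "classifier.weight": (3, HIDDEN), "classifier.bias": (3,)}
```

With the flag set, the backbone weights are reused and only the two classifier tensors start from scratch, which is exactly the part fine-tuning has to learn for our three labels.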

In [None]:
huggingface_model = 'google/vit-base-patch16-224'

## Train the model

Once the model is chosen, we can build the `Trainer` and start fine-tuning.

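**Note**: Import the `FeatureExtractor` and `ForImageClassification` classes that match your chosen model (for example, `BeitFeatureExtractor` and `BeitForImageClassification` for `microsoft/beit-base-patch16-224`).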
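One way to keep the imports in step with `huggingface_model` is a small lookup from model family to class names, sketched below with a hypothetical helper `classes_for`. The class names are the transformers classes for the recommended checkpoints; as an alternative, transformers also provides `AutoFeatureExtractor` and `AutoModelForImageClassification`, which select the right classes from the checkpoint's config.

```python
# Hypothetical mapping: checkpoint prefix -> (feature extractor, model) class names
MODEL_CLASSES = {
    "google/vit-": ("ViTFeatureExtractor", "ViTForImageClassification"),
    "facebook/deit-": ("DeiTFeatureExtractor", "DeiTForImageClassification"),
    "microsoft/beit-": ("BeitFeatureExtractor", "BeitForImageClassification"),
}

def classes_for(model_id):
    """Return the transformers class names to import for a given checkpoint id."""
    for prefix, names in MODEL_CLASSES.items():
        if model_id.startswith(prefix):
            return names
    raise KeyError(f"No class mapping for {model_id!r}")
```

In the notebook, the names can then be resolved with `getattr(transformers, name)` after `import transformers`.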

In [None]:
from hugsvision.nnet.VisionClassifierTrainer import VisionClassifierTrainer
from transformers import ViTFeatureExtractor, ViTForImageClassification

trainer = VisionClassifierTrainer(
    model_name   = "MyDisasterModel",
    train        = train,
    test         = test,
    output_dir   = "./out/",
    max_epochs   = 15,
    batch_size   = 32,
    lr           = 2e-5,
    fp16         = True,  # mixed-precision training; requires a GPU runtime
    model = ViTForImageClassification.from_pretrained(
        huggingface_model,
        num_labels = len(label2id),
        label2id   = label2id,
        id2label   = id2label,
        # Re-initialise the classification head if its size differs from our label count
        ignore_mismatched_sizes = True,
    ),
    feature_extractor = ViTFeatureExtractor.from_pretrained(
        huggingface_model, ignore_mismatched_sizes=True
    ),
)
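With `batch_size = 32` and `max_epochs = 15`, the number of optimisation steps depends on the size of the training split. The arithmetic below is a worked example with a hypothetical split of 2,000 training images (substitute `len(train)` for your data). It assumes the last incomplete batch is kept, which is the default behaviour of PyTorch's `DataLoader`, and ignores gradient accumulation.

```python
import math

def training_steps(n_train, batch_size=32, max_epochs=15):
    """Return (steps_per_epoch, total_steps), keeping the last partial batch."""
    steps_per_epoch = math.ceil(n_train / batch_size)
    return steps_per_epoch, steps_per_epoch * max_epochs

# Hypothetical training-set size, for illustration only
per_epoch, total = training_steps(2000)
print(per_epoch, total)  # 63 945
```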

## Evaluate F1-Score

The F1-score gives a better picture of prediction quality across all labels than accuracy alone, and helps reveal whether the model struggles with a specific label.
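The F1-score of a label is the harmonic mean of its precision P and recall R, F1 = 2PR / (P + R). The plain-Python sketch below computes it per label from two parallel lists, assuming (as the names suggest) that `hyp` holds predicted labels and `ref` the true ones. HugsVision computes its own report; this sketch only shows what the per-label numbers mean, and the helper name `per_label_f1` is hypothetical.

```python
def per_label_f1(hyp, ref):
    """Compute F1 for each label from parallel lists of predictions and references."""
    scores = {}
    for label in sorted(set(hyp) | set(ref)):
        tp = sum(1 for h, r in zip(hyp, ref) if h == label and r == label)
        fp = sum(1 for h, r in zip(hyp, ref) if h == label and r != label)
        fn = sum(1 for h, r in zip(hyp, ref) if h != label and r == label)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        scores[label] = (2 * precision * recall / (precision + recall)
                         if precision + recall else 0.0)
    return scores

# Toy example: one Flood image is mistaken for Wildfire
ref_toy = ["Flood", "Flood", "Wildfire", "Earthquake"]
hyp_toy = ["Flood", "Wildfire", "Wildfire", "Earthquake"]
print({k: round(v, 3) for k, v in per_label_f1(hyp_toy, ref_toy).items()})
# {'Earthquake': 1.0, 'Flood': 0.667, 'Wildfire': 0.667}
```

A single confusion lowers the recall of one label and the precision of another, so both show up with a reduced F1 even though overall accuracy is still 75%.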

In [None]:
hyp, ref = trainer.evaluate_f1_score()

## Make a prediction

Before making a prediction, rename the `./out/MODEL_PATH/config.json` file in the model output to `./out/MODEL_PATH/preprocessor_config.json`, where `MODEL_PATH` is the output folder of your training run.
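The rename can also be done from Python. The sketch below assumes you pass the folder that holds the `config.json` mentioned above (the `MODEL_PATH` placeholder stays a placeholder, so substitute your own run folder). It uses only `pathlib` and refuses to overwrite an existing `preprocessor_config.json`; the helper name is hypothetical.

```python
from pathlib import Path

def rename_to_preprocessor_config(run_dir):
    """Rename <run_dir>/config.json to <run_dir>/preprocessor_config.json."""
    run_dir = Path(run_dir)
    src = run_dir / "config.json"
    dst = run_dir / "preprocessor_config.json"
    if dst.exists():
        raise FileExistsError(f"{dst} already exists; not overwriting it")
    if not src.is_file():
        raise FileNotFoundError(f"{src} not found")
    src.rename(dst)
    return dst

# Usage (replace MODEL_PATH with your own run folder):
# rename_to_preprocessor_config("./out/MODEL_PATH/")
```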

In [None]:
import os.path
from transformers import ViTFeatureExtractor, ViTForImageClassification
from hugsvision.inference.VisionClassifierInference import VisionClassifierInference

# Paths from one specific training run: replace them with your own output folder and image
path1 = "./out/MyDisasterModel/10_2024-03-22-00-08-07/feature_extractor/"
path2 = "./out/MyDisasterModel/10_2024-03-22-00-08-07/model/"
img  = "C:/Users/rohan/OneDrive/Desktop/DisasterManage/Disaster-Management/Disasterpics/china-earthquake-21.jpg"

# Fail early with a clear message if the image path is wrong
if not os.path.isfile(img):
    raise FileNotFoundError(img)

# Reload the fine-tuned feature extractor and model saved during training
classifier = VisionClassifierInference(
    feature_extractor = ViTFeatureExtractor.from_pretrained(path1),
    model = ViTForImageClassification.from_pretrained(path2),
)

label = classifier.predict(img_path=img)
print("Predicted class:", label)
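To classify a whole folder rather than a single image, the same `predict(img_path=...)` call can be looped over every file. The helper below is a hypothetical sketch: it accepts any callable taking that keyword (for example `classifier.predict`) and returns per-file labels plus a tally. The folder path in the usage comment and the extension list are assumptions.

```python
from collections import Counter
from pathlib import Path

def predict_folder(predict, folder, extensions=(".jpg", ".jpeg", ".png")):
    """Run `predict(img_path=...)` on every image directly inside `folder`.

    Returns a dict {file name: label} and a Counter of predicted labels.
    """
    predictions = {}
    for path in sorted(Path(folder).iterdir()):
        if path.is_file() and path.suffix.lower() in extensions:
            predictions[path.name] = predict(img_path=str(path))
    return predictions, Counter(predictions.values())

# Usage with the classifier created above (folder path is hypothetical):
# preds, tally = predict_folder(classifier.predict, "./Disasterpics/")
# print(tally.most_common())
```

Passing the prediction function in as an argument keeps the loop independent of HugsVision, so it can be tested with a stand-in function before the real model is loaded.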