# Database importations

## Importations

In [1]:
import pandas as pd
import numpy as np

from transformers import ViTImageProcessor, ViTForImageClassification, Trainer, TrainingArguments
import torch
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from datasets import Dataset

from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score

from PIL import Image
import requests
import os
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm





## Dataset creation

In [2]:
images_path = [os.path.join('cropped_image', f) for f in os.listdir('cropped_image') if f.endswith('.png')]

images_path = [i for i in tqdm(images_path) if len(np.array(Image.open(i)).shape) == 3] # remove grayscale images

labels = [i.split('\\')[1].split('_')[:-1] for i in images_path]
labels = ['_'.join(i) for i in labels]

100%|██████████| 4414/4414 [00:13<00:00, 334.93it/s]


In [3]:
df_total = pd.DataFrame({'path': images_path, 'label': labels})

In [4]:
df_total.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 4130 entries, 0 to 4129
Data columns (total 2 columns):
 #   Column  Non-Null Count  Dtype 
---  ------  --------------  ----- 
 0   path    4130 non-null   object
 1   label   4130 non-null   object
dtypes: object(2)
memory usage: 64.7+ KB


In [5]:
df_total.nunique()

path     4130
label     484
dtype: int64

In [6]:
unique_labels = df_total['label'].unique()
num_unique_labels = [i for i in range(len(unique_labels))]

df_labels = pd.DataFrame({'label': unique_labels, 'label_id': num_unique_labels})

df_total = df_total.merge(df_labels, on='label', how='left')
df_total

Unnamed: 0,path,label,label_id
0,cropped_image\A73EGS-P_1.png,A73EGS-P,0
1,cropped_image\A73EGS-P_3.png,A73EGS-P,0
2,cropped_image\A73EGS-P_4.png,A73EGS-P,0
3,cropped_image\A73EGS-P_5.png,A73EGS-P,0
4,cropped_image\A73EGS-P_6.png,A73EGS-P,0
...,...,...,...
4125,cropped_image\zeus_faber_3.png,zeus_faber,483
4126,cropped_image\zeus_faber_4.png,zeus_faber,483
4127,cropped_image\zeus_faber_5.png,zeus_faber,483
4128,cropped_image\zeus_faber_6.png,zeus_faber,483


## Processing images

Here we use a pre-trained Vision Transformer (ViT) model, so we first pre-process the images in 4 steps

- <b>Image resizing</b>: image resized to a fixed resolution of $224 \times 224$ pixels. 

- <b>Normalization</b>: pixel values of the image are normalized across the RGB channels. This involves scaling the pixel values to have a mean of 0.5 and a standard deviation of 0.5 for each channel. 

- <b>Patch Extraction</b>: image is then divided into non-overlapping patches, each of size 16x16 pixels. This results in a sequence of patches that the model processes, treating each patch similarly to how tokens are treated in natural language processing tasks.

- <b>Linear Embedding</b>: Each 16x16 patch is flattened into a 1-dimensional vector and then linearly transformed into an embedding of a specified dimension. This step translates the raw pixel data into a format suitable for input into the transformer's architecture.

- <b>Position Embedding Addition</b>: Position embeddings are added to each patch to retain spatial information. These embeddings encode the position of each patch within the original image, allowing the model to understand the spatial relationships between patches.

- <b>Class Token Addition</b>: A special classification token ([CLS]) is prepended to the sequence of patch embeddings. The final hidden state corresponding to this token is used as the aggregate representation of the entire image for classification purposes.

All these steps are explained deeper on <a href = "https://huggingface.co/google/vit-base-patch16-224">the Hugging Face page</a> of the model and in the original publication <a href="https://arxiv.org/abs/2010.11929">An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale</a> by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby. In the publication we could also find all informations on the model architecture etc. 



We then instantiate a Vision Transformer model specifically designed for image classification.

It uses a pre-trained model identified by google that has been trained on the ImageNet-21k dataset

This model is loaded with pre-trained weights and is used to adjust the classification head to match the number of classes in the specific task.

In [7]:
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', num_labels=len(unique_labels))

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


This function loads an image from a path, pre-processes it using the defined processor (resizing, normalization, tensor conversion) and returns the pixel values in tensor form, ready for use by the model.

In [8]:
def preprocess_image(image_path, processor):
    image = Image.open(image_path)
    inputs = processor(images=image, return_tensors="pt")
    return inputs['pixel_values']  # Extract only pixel values (tensor)


This function takes a DataFrame containing image paths and their labels. It preprocesses each image using preprocess_image to obtain the pixel values and extracts the labels as tensors. It then creates a Dataset object containing the pixels and labels, ready for training or evaluation.

In [9]:
def preprocess_dataset(df, path_col, labels_idx_col, processor):
    preprocessed_images = []
    labels = torch.tensor(df[labels_idx_col].values)

    for i in tqdm(df[path_col].values, desc='Preprocessing images'):
        pixel_values = preprocess_image(i, processor)  # Extract only pixel values (tensor)
        preprocessed_images.append(pixel_values.squeeze())  # Squeeze the tensor if necessary to remove extra dimensions
    
    dataset_dict = {'pixel_values': preprocessed_images, 'labels': labels}
    dataset = Dataset.from_dict(dataset_dict)

    return dataset

This functions perform a prediction on a given image. It loads the image, preprocesses it with preprocess_image, sends the data to the same device as the model, and uses the model to obtain logits. The index of the highest score among the logits is returned as the predicted class.

In [28]:

def predict(image_path, processor, model):
    image = Image.open(image_path)
    inputs = preprocess_image(image_path, processor).to(model.device)
    outputs = model(inputs)
    logits = outputs.logits
    predicted_class_idx = logits.argmax(-1).item()
    return predicted_class_idx

def predict_from_image(image, processor, model):
    inputs = processor(images=image, return_tensors="pt").to(model.device)
    outputs = model(**inputs)
    logits = outputs.logits
    predicted_class_idx = logits.argmax(-1).item()
    return predicted_class_idx

In [11]:
processed_dataset = preprocess_dataset(df_total, 'path', 'label_id', processor)

Preprocessing images: 100%|██████████| 4130/4130 [00:28<00:00, 144.10it/s]


In [12]:
processed_dataset_split = processed_dataset.train_test_split(test_size=0.2)

## Training of the model

This code configures training parameters, including batch sizes, number of epochs, learning rate and save and evaluate strategies. A Trainer object is then created to automate model training and evaluation on the appropriate device (CPU or GPU), using the datasets and processor provided.

CHOICE OF MODEL :

The Vision Transformer (ViT) model is chosen for this project because of its ability to capture fine details and global relationships in images, which is essential for differentiating fish species. Pre-trained on large datasets such as ImageNet, it can be adjusted efficiently on specific datasets, even of small size. Its robustness to angle, lighting and background variations makes it a powerful choice for species classification in complex visual environments.

In [13]:
# Training arguments
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)

training_args = TrainingArguments(
    output_dir='./results',          # Output directory
    evaluation_strategy="epoch",    # Evaluate after each epoch
    save_strategy="epoch",          # Save after each epoch
    per_device_train_batch_size=16, # Adjust based on GPU/CPU memory
    per_device_eval_batch_size=16,  # Adjust based on GPU/CPU memory
    num_train_epochs=3,             # Number of epochs
    warmup_steps=500,               # Warmup steps for learning rate scheduler
    learning_rate=5e-5,             # Learning rate
    logging_dir='./logs',           # Logging directory
    logging_steps=10,
    load_best_model_at_end=True
)

# Define Trainer
trainer = Trainer(
    model=model.to(device),
    args=training_args,
    train_dataset=processed_dataset_split["train"],
    eval_dataset=processed_dataset_split["test"],
    tokenizer=processor
)




In [14]:
trainer.train()

  context_layer = torch.nn.functional.scaled_dot_product_attention(
  2%|▏         | 10/621 [00:14<13:02,  1.28s/it]

{'loss': 6.1977, 'grad_norm': 1.7483935356140137, 'learning_rate': 1.0000000000000002e-06, 'epoch': 0.05}


  3%|▎         | 20/621 [00:26<12:20,  1.23s/it]

{'loss': 6.2022, 'grad_norm': 1.6503273248672485, 'learning_rate': 2.0000000000000003e-06, 'epoch': 0.1}


  5%|▍         | 30/621 [00:39<12:07,  1.23s/it]

{'loss': 6.1953, 'grad_norm': 1.9379818439483643, 'learning_rate': 3e-06, 'epoch': 0.14}


  6%|▋         | 40/621 [00:51<11:55,  1.23s/it]

{'loss': 6.1844, 'grad_norm': 1.7241199016571045, 'learning_rate': 4.000000000000001e-06, 'epoch': 0.19}


  8%|▊         | 50/621 [01:04<11:43,  1.23s/it]

{'loss': 6.1934, 'grad_norm': 1.6593527793884277, 'learning_rate': 5e-06, 'epoch': 0.24}


 10%|▉         | 60/621 [01:17<12:57,  1.39s/it]

{'loss': 6.1747, 'grad_norm': 1.7218977212905884, 'learning_rate': 6e-06, 'epoch': 0.29}


 11%|█▏        | 70/621 [01:31<12:21,  1.35s/it]

{'loss': 6.1806, 'grad_norm': 1.6053146123886108, 'learning_rate': 7.000000000000001e-06, 'epoch': 0.34}


 13%|█▎        | 80/621 [01:44<12:07,  1.35s/it]

{'loss': 6.1832, 'grad_norm': 1.6769832372665405, 'learning_rate': 8.000000000000001e-06, 'epoch': 0.39}


 14%|█▍        | 90/621 [01:58<12:00,  1.36s/it]

{'loss': 6.1746, 'grad_norm': 1.719382405281067, 'learning_rate': 9e-06, 'epoch': 0.43}


 16%|█▌        | 100/621 [02:12<12:10,  1.40s/it]

{'loss': 6.1743, 'grad_norm': 1.7115355730056763, 'learning_rate': 1e-05, 'epoch': 0.48}


 18%|█▊        | 110/621 [02:25<11:42,  1.38s/it]

{'loss': 6.18, 'grad_norm': 1.7326316833496094, 'learning_rate': 1.1000000000000001e-05, 'epoch': 0.53}


 19%|█▉        | 120/621 [02:39<11:33,  1.38s/it]

{'loss': 6.1602, 'grad_norm': 2.5346083641052246, 'learning_rate': 1.2e-05, 'epoch': 0.58}


 21%|██        | 130/621 [02:53<11:02,  1.35s/it]

{'loss': 6.1458, 'grad_norm': 1.8798775672912598, 'learning_rate': 1.3000000000000001e-05, 'epoch': 0.63}


 23%|██▎       | 140/621 [03:07<10:58,  1.37s/it]

{'loss': 6.1525, 'grad_norm': 1.7022583484649658, 'learning_rate': 1.4000000000000001e-05, 'epoch': 0.68}


 24%|██▍       | 150/621 [03:20<10:44,  1.37s/it]

{'loss': 6.148, 'grad_norm': 1.8175896406173706, 'learning_rate': 1.5e-05, 'epoch': 0.72}


 26%|██▌       | 160/621 [03:34<10:35,  1.38s/it]

{'loss': 6.1213, 'grad_norm': 1.8142632246017456, 'learning_rate': 1.6000000000000003e-05, 'epoch': 0.77}


 27%|██▋       | 170/621 [03:47<09:25,  1.25s/it]

{'loss': 6.1382, 'grad_norm': 1.962619662284851, 'learning_rate': 1.7000000000000003e-05, 'epoch': 0.82}


 29%|██▉       | 180/621 [03:59<09:05,  1.24s/it]

{'loss': 6.1176, 'grad_norm': 1.820120096206665, 'learning_rate': 1.8e-05, 'epoch': 0.87}


 31%|███       | 190/621 [04:12<08:54,  1.24s/it]

{'loss': 6.1198, 'grad_norm': 2.07236909866333, 'learning_rate': 1.9e-05, 'epoch': 0.92}


 32%|███▏      | 200/621 [04:24<08:48,  1.25s/it]

{'loss': 6.0903, 'grad_norm': 1.9042634963989258, 'learning_rate': 2e-05, 'epoch': 0.97}


                                                 
 33%|███▎      | 207/621 [05:28<05:41,  1.21it/s]

{'eval_loss': 6.0888447761535645, 'eval_runtime': 56.5691, 'eval_samples_per_second': 14.602, 'eval_steps_per_second': 0.919, 'epoch': 1.0}


 34%|███▍      | 210/621 [05:35<1:07:42,  9.88s/it]

{'loss': 6.0786, 'grad_norm': 1.9302380084991455, 'learning_rate': 2.1e-05, 'epoch': 1.01}


 35%|███▌      | 220/621 [05:47<09:36,  1.44s/it]  

{'loss': 6.0332, 'grad_norm': 2.0529959201812744, 'learning_rate': 2.2000000000000003e-05, 'epoch': 1.06}


 37%|███▋      | 230/621 [05:59<07:59,  1.23s/it]

{'loss': 6.0281, 'grad_norm': 2.090045690536499, 'learning_rate': 2.3000000000000003e-05, 'epoch': 1.11}


 39%|███▊      | 240/621 [06:12<08:14,  1.30s/it]

{'loss': 6.0056, 'grad_norm': 2.1024014949798584, 'learning_rate': 2.4e-05, 'epoch': 1.16}


 40%|████      | 250/621 [06:25<08:27,  1.37s/it]

{'loss': 6.0018, 'grad_norm': 1.8880449533462524, 'learning_rate': 2.5e-05, 'epoch': 1.21}


 42%|████▏     | 260/621 [06:38<07:17,  1.21s/it]

{'loss': 5.998, 'grad_norm': 2.0043134689331055, 'learning_rate': 2.6000000000000002e-05, 'epoch': 1.26}


 43%|████▎     | 270/621 [06:50<07:03,  1.21s/it]

{'loss': 5.9829, 'grad_norm': 1.992096185684204, 'learning_rate': 2.7000000000000002e-05, 'epoch': 1.3}


 45%|████▌     | 280/621 [07:02<06:52,  1.21s/it]

{'loss': 5.9811, 'grad_norm': 2.0707788467407227, 'learning_rate': 2.8000000000000003e-05, 'epoch': 1.35}


 47%|████▋     | 290/621 [07:15<06:59,  1.27s/it]

{'loss': 5.9985, 'grad_norm': 2.255643367767334, 'learning_rate': 2.9e-05, 'epoch': 1.4}


 48%|████▊     | 300/621 [07:28<07:11,  1.34s/it]

{'loss': 5.9585, 'grad_norm': 2.2047250270843506, 'learning_rate': 3e-05, 'epoch': 1.45}


 50%|████▉     | 310/621 [07:41<06:14,  1.20s/it]

{'loss': 5.9629, 'grad_norm': 2.7350809574127197, 'learning_rate': 3.1e-05, 'epoch': 1.5}


 52%|█████▏    | 320/621 [07:53<06:39,  1.33s/it]

{'loss': 5.9664, 'grad_norm': 2.1389377117156982, 'learning_rate': 3.2000000000000005e-05, 'epoch': 1.55}


 53%|█████▎    | 330/621 [08:09<07:25,  1.53s/it]

{'loss': 5.9168, 'grad_norm': 2.0459482669830322, 'learning_rate': 3.3e-05, 'epoch': 1.59}


 55%|█████▍    | 340/621 [08:24<07:07,  1.52s/it]

{'loss': 5.952, 'grad_norm': 2.1704373359680176, 'learning_rate': 3.4000000000000007e-05, 'epoch': 1.64}


 56%|█████▋    | 350/621 [08:39<06:57,  1.54s/it]

{'loss': 5.9026, 'grad_norm': 2.275416612625122, 'learning_rate': 3.5e-05, 'epoch': 1.69}


 58%|█████▊    | 360/621 [08:54<06:26,  1.48s/it]

{'loss': 5.871, 'grad_norm': 2.0861480236053467, 'learning_rate': 3.6e-05, 'epoch': 1.74}


 60%|█████▉    | 370/621 [09:09<06:31,  1.56s/it]

{'loss': 5.9518, 'grad_norm': 2.106750249862671, 'learning_rate': 3.7e-05, 'epoch': 1.79}


 61%|██████    | 380/621 [09:24<05:52,  1.46s/it]

{'loss': 5.8844, 'grad_norm': 2.122260332107544, 'learning_rate': 3.8e-05, 'epoch': 1.84}


 63%|██████▎   | 390/621 [09:38<05:20,  1.39s/it]

{'loss': 5.8941, 'grad_norm': 1.987959861755371, 'learning_rate': 3.9000000000000006e-05, 'epoch': 1.88}


 64%|██████▍   | 400/621 [09:53<05:35,  1.52s/it]

{'loss': 5.8597, 'grad_norm': 2.0320940017700195, 'learning_rate': 4e-05, 'epoch': 1.93}


 66%|██████▌   | 410/621 [10:08<05:15,  1.50s/it]

{'loss': 5.8599, 'grad_norm': 2.066953182220459, 'learning_rate': 4.1e-05, 'epoch': 1.98}


                                                 
 67%|██████▋   | 414/621 [11:27<03:23,  1.02it/s]

{'eval_loss': 5.8559794425964355, 'eval_runtime': 74.3851, 'eval_samples_per_second': 11.104, 'eval_steps_per_second': 0.699, 'epoch': 2.0}


 68%|██████▊   | 420/621 [11:39<18:02,  5.39s/it]  

{'loss': 5.7417, 'grad_norm': 2.1574792861938477, 'learning_rate': 4.2e-05, 'epoch': 2.03}


 69%|██████▉   | 430/621 [11:53<04:20,  1.36s/it]

{'loss': 5.672, 'grad_norm': 2.179776430130005, 'learning_rate': 4.3e-05, 'epoch': 2.08}


 71%|███████   | 440/621 [12:06<04:13,  1.40s/it]

{'loss': 5.6693, 'grad_norm': 2.058309555053711, 'learning_rate': 4.4000000000000006e-05, 'epoch': 2.13}


 72%|███████▏  | 450/621 [12:20<03:53,  1.37s/it]

{'loss': 5.6403, 'grad_norm': 2.137808084487915, 'learning_rate': 4.5e-05, 'epoch': 2.17}


 74%|███████▍  | 460/621 [12:34<03:42,  1.38s/it]

{'loss': 5.656, 'grad_norm': 2.1934173107147217, 'learning_rate': 4.600000000000001e-05, 'epoch': 2.22}


 76%|███████▌  | 470/621 [12:46<03:04,  1.22s/it]

{'loss': 5.5869, 'grad_norm': 2.2831220626831055, 'learning_rate': 4.7e-05, 'epoch': 2.27}


 77%|███████▋  | 480/621 [12:58<02:50,  1.21s/it]

{'loss': 5.645, 'grad_norm': 2.2567434310913086, 'learning_rate': 4.8e-05, 'epoch': 2.32}


 79%|███████▉  | 490/621 [13:11<02:38,  1.21s/it]

{'loss': 5.5948, 'grad_norm': 2.18607759475708, 'learning_rate': 4.9e-05, 'epoch': 2.37}


 81%|████████  | 500/621 [13:23<02:26,  1.21s/it]

{'loss': 5.6041, 'grad_norm': 2.1543853282928467, 'learning_rate': 5e-05, 'epoch': 2.42}


 82%|████████▏ | 510/621 [13:35<02:15,  1.22s/it]

{'loss': 5.6027, 'grad_norm': 2.295408010482788, 'learning_rate': 4.586776859504133e-05, 'epoch': 2.46}


 84%|████████▎ | 520/621 [13:47<02:01,  1.20s/it]

{'loss': 5.559, 'grad_norm': 2.264561176300049, 'learning_rate': 4.1735537190082645e-05, 'epoch': 2.51}


 85%|████████▌ | 530/621 [13:59<01:50,  1.21s/it]

{'loss': 5.5517, 'grad_norm': 2.201387405395508, 'learning_rate': 3.760330578512397e-05, 'epoch': 2.56}


 87%|████████▋ | 540/621 [14:12<01:42,  1.27s/it]

{'loss': 5.5406, 'grad_norm': 2.156916618347168, 'learning_rate': 3.347107438016529e-05, 'epoch': 2.61}


 89%|████████▊ | 550/621 [14:24<01:28,  1.25s/it]

{'loss': 5.5478, 'grad_norm': 2.256890296936035, 'learning_rate': 2.9338842975206616e-05, 'epoch': 2.66}


 90%|█████████ | 560/621 [14:37<01:18,  1.29s/it]

{'loss': 5.5286, 'grad_norm': 2.048412799835205, 'learning_rate': 2.5206611570247934e-05, 'epoch': 2.71}


 92%|█████████▏| 570/621 [14:51<01:11,  1.39s/it]

{'loss': 5.5643, 'grad_norm': 2.1092517375946045, 'learning_rate': 2.1074380165289255e-05, 'epoch': 2.75}


 93%|█████████▎| 580/621 [15:04<00:52,  1.28s/it]

{'loss': 5.4578, 'grad_norm': 2.4177799224853516, 'learning_rate': 1.694214876033058e-05, 'epoch': 2.8}


 95%|█████████▌| 590/621 [15:17<00:40,  1.30s/it]

{'loss': 5.4192, 'grad_norm': 2.231384754180908, 'learning_rate': 1.2809917355371901e-05, 'epoch': 2.85}


 97%|█████████▋| 600/621 [15:30<00:27,  1.30s/it]

{'loss': 5.4881, 'grad_norm': 2.283243417739868, 'learning_rate': 8.677685950413224e-06, 'epoch': 2.9}


 98%|█████████▊| 610/621 [15:43<00:15,  1.38s/it]

{'loss': 5.4326, 'grad_norm': 2.2273459434509277, 'learning_rate': 4.5454545454545455e-06, 'epoch': 2.95}


100%|██████████| 621/621 [15:57<00:00,  1.14it/s]

{'loss': 5.4405, 'grad_norm': 2.721954107284546, 'learning_rate': 4.132231404958678e-07, 'epoch': 3.0}


                                                 
100%|██████████| 621/621 [17:00<00:00,  1.14it/s]

{'eval_loss': 5.593548774719238, 'eval_runtime': 61.5946, 'eval_samples_per_second': 13.41, 'eval_steps_per_second': 0.844, 'epoch': 3.0}


100%|██████████| 621/621 [17:01<00:00,  1.65s/it]

{'train_runtime': 1021.7472, 'train_samples_per_second': 9.701, 'train_steps_per_second': 0.608, 'train_loss': 5.890784626037794, 'epoch': 3.0}





TrainOutput(global_step=621, training_loss=5.890784626037794, metrics={'train_runtime': 1021.7472, 'train_samples_per_second': 9.701, 'train_steps_per_second': 0.608, 'total_flos': 7.71418806058156e+17, 'train_loss': 5.890784626037794, 'epoch': 3.0})

In [None]:
predictions = trainer.predict(processed_dataset_split["test"])
preds = np.argmax(predictions.predictions, axis=1)

accuracy = accuracy_score(processed_dataset_split["test"]['labels'], preds)
print(f'Accuracy: {accuracy}')

recall = recall_score(processed_dataset_split["test"]['labels'], preds, average='weighted')
print(f'Recall: {recall}')

precision = precision_score(processed_dataset_split["test"]['labels'], preds, average='weighted')
print(f'Precision: {precision}')

f1 = f1_score(processed_dataset_split["test"]['labels'], preds, average='weighted')
print(f'F1: {f1}')


100%|██████████| 52/52 [00:59<00:00,  1.15s/it]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.35108958837772397
Recall: 0.35108958837772397
Precision: 0.30419012949101254
F1: 0.28487774174108027


['model\\preprocessor_config.json']

The results show that the model performs poorly, with an accuracy of 32.8%, a precision of 29.8%, a recall of 32.8% and an F1-score of 26.9%. These values indicate that the model struggles to differentiate classes correctly, with unreliable predictions and a limited balance between precision and recall. This suggests that the model requires improvement, either at the training level (hyperparameters, number of epochs), or at the data level (augmentation, class balancing, or enrichment), to better generalize and capture the essential features of the data.


LIMITATIONS OF THE MODEL IN RELATION TO THE DATASET

The results obtained are not satisfactory, mainly due to the large number of fish species to be classified and the low number of images available per species, with an average of only 5 images per class. This imbalance in the data makes it difficult for the model to learn representative features for each species, limiting its ability to generalize and make accurate predictions. A richer, more balanced dataset would be required to improve the model's performance.

It is indeed far from its usual performances that are usually around 90% of accuracy according to these sources: <a href = "https://dataloop.ai/library/model/google_vit-base-patch16-224/">DataLoop</a>, <a href = "https://paperswithcode.com/sota/image-classification-on-imagenet?p=deepvit-towards-deeper-vision-transformer">Papers With Code</a> and in <a href = "https://arxiv.org/pdf/2010.11929">the original publication</a> on Table 2 and in the part C. ADDITIONAL RESULTS


# Web scrapping and Description generation

In this part we will scrap informations from the web to constitute a database of descriptions of the fishes we have in our dataset. 

We will use the information of <a href = "https://www.fishbase.se/">fishbase.se</a> which is a global biodiversity information system on all species currently known in the world. At present, FishBase covers >35,800 fish species compiled from >63,000 references.

To scrap data we will use the <a href = "https://www.crummy.com/software/BeautifulSoup/">BeautifulSoup</a> library which allows us to get informations. 

In [13]:
import pandas as pd 
import numpy as np

from urllib.request import Request, urlopen
from bs4 import BeautifulSoup as bs
from urllib.request import URLopener
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
import torch
from tqdm import tqdm
import math
import time

## Exploration

Let's first try to understand how to scrap the website.

An example of what the pages look like is <a href = "https://www.fishbase.se/summary/zeus-faber.html">here</a>.

In [14]:
def gen_soup(url):
    req = Request(url, headers={'User-Agent': 'Safari'})
    webpage = urlopen(req, timeout=15).read()
    to_ret = bs(webpage,"html.parser")
    return to_ret 

In [15]:
df_total['label']

0         A73EGS-P
1         A73EGS-P
2         A73EGS-P
3         A73EGS-P
4         A73EGS-P
           ...    
4125    zeus_faber
4126    zeus_faber
4127    zeus_faber
4128    zeus_faber
4129    zeus_faber
Name: label, Length: 4130, dtype: object

In [87]:
genus = df_total['label'].iloc[4012].split('_')[0] # let's use an example in our dataset
species = df_total['label'].iloc[4012].split('_')[1]

ret = gen_soup(f"https://www.fishbase.se/summary/SpeciesSummary.php?ID=14550&genusname={genus}&speciesname={species}")

Let's scrap the "Short description" part of the page

In [88]:
# description

to_take = str(ret).split('description')[1].split('smallSpace')[1].split('<span>')[1].split('</span>')[0].split('>')
detailed = []
for i in range(len(to_take)):
    if 'href' in to_take[i]:
        detailed.append(to_take[i].split('<a href=')[0])
    else:
        detailed.append(to_take[i])

detailed = detailed[1:-1]
result = [i.split('<')[0] if '<' in i else i for i in detailed]
result = [i.split('(Ref')[0] if 'Ref' in i else i for i in result if i != '']
result = [i for i in result if not i.isdigit()]

result.remove(' ')

description = ' '.join(result).replace('  ', ' ')
description

'Dorsal spines (total): 7; Dorsal  soft rays (total): 20 - 24; Anal  spines : 3; Anal  soft rays'

Let's do the same for the Biology part

In [94]:
# biology

bio_tab = str(ret).split('					Biology					')[1].split('smallSpace')[1].split('<span>')[1].split('</span>')[0].split('Ref')

bio_tab = [bio_tab[i] for i in range(len(bio_tab)) if i % 2 == 0]
bio_tab = [i.split('>')[-1] if '>' in i else i for i in bio_tab]
bio_tab = [i.replace(').', '') if ').' in i else i for i in bio_tab]
bio_tab = [i.replace(')', '') if ')' in i else i for i in bio_tab]
bio_tab = [i.replace('(', '') if '(' in i else i for i in bio_tab]
bio_tab = [i.replace('\t', '') if '\t' in i else i for i in bio_tab]
bio_tab = [i.replace('\n', '') if '\n' in i else i for i in bio_tab]
bio_tab = [i.replace('\r', '') if '\r' in i else i for i in bio_tab]

biology = ' '.join(bio_tab).replace('  ', ' ')
if '<' in biology:
    biology = biology.split('<')[0]
biology


'Adults occur near surface waters of lagoon and seaward reefs, in surge zones along sandy beaches , '

Finally, we do the same for the Distribution part and we can concatenate the 3 parts to have one "clean" sentence for the specie. 
The sentence is a good start but it is not truly user-friendly to read. To do that we will use an LLM to generate better description.

In [95]:
distribution = str(ret).split('					Distribution					')[1].split('smallSpace')[1].split('<span>')[1].split('</span>')[0].split('\t')[-1].split('<')[0]

description = "Short description: " + description
distribution = "Distribution: " + distribution
biology = "Biology: " + biology

full_desc = description + '\n' + distribution + '\n' + biology
full_desc

'Short description: Dorsal spines (total): 7; Dorsal  soft rays (total): 20 - 24; Anal  spines : 3; Anal  soft rays\nDistribution: Indo-Pacific:  Red Sea to the Line and Mangaréva islands, north to southern Japan, south to Lord Howe and Rapa.\nBiology: Adults occur near surface waters of lagoon and seaward reefs, in surge zones along sandy beaches , '

## Structuration 

Let's create a function to scrap automatically the data of the fishes of our dataset to create descriptions.

In [115]:
def scrap_fishbase(genus, species):
    ret = gen_soup(f"https://www.fishbase.se/summary/SpeciesSummary.php?ID=14550&genusname={genus}&speciesname={species}")

    try:
        to_take = str(ret).split('					Short description					')[1].split('smallSpace')[1].split('<span>')[1].split('</span>')[0].split('>')
        detailed = []
        for i in range(len(to_take)):
            if 'href' in to_take[i]:
                detailed.append(to_take[i].split('<a href=')[0])
            else:
                detailed.append(to_take[i])

        detailed = detailed[1:-1]
        result = [i.split('<')[0] if '<' in i else i for i in detailed]
        result = [i.split('(Ref')[0] if 'Ref' in i else i for i in result if i != '']
        result = [i for i in result if not i.isdigit()]

        result.remove(' ')

        description = ' '.join(result).replace('  ', ' ')
    except:
        description = ''

    # biology
    try:
        bio_tab = str(ret).split('					Biology					')[1].split('smallSpace')[1].split('<span>')[1].split('</span>')[0].split('Ref')

        bio_tab = [bio_tab[i] for i in range(len(bio_tab)) if i % 2 == 0]
        bio_tab = [i.split('>')[-1] if '>' in i else i for i in bio_tab]
        bio_tab = [i.replace(').', '') if ').' in i else i for i in bio_tab]
        bio_tab = [i.replace(')', '') if ')' in i else i for i in bio_tab]
        bio_tab = [i.replace('(', '') if '(' in i else i for i in bio_tab]
        bio_tab = [i.replace('\t', '') if '\t' in i else i for i in bio_tab]
        bio_tab = [i.replace('\n', '') if '\n' in i else i for i in bio_tab]
        bio_tab = [i.replace('\r', '') if '\r' in i else i for i in bio_tab]

        biology = ' '.join(bio_tab).replace('  ', ' ')
        if '<' in biology:
            biology = biology.split('<')[0]
    except:
        biology = ''

    try:
        distribution = str(ret).split('					Distribution					')[1].split('smallSpace')[1].split('<span>')[1].split('</span>')[0].split('\t')[-1].split('<')[0]
    except:
        distribution = ''
    
    if description != '':
        description = "Short description: " + description
    if distribution != '':
        distribution = "Distribution: " + distribution
    if biology != '':
        biology = "Biology: " + biology

    full_desc = description + '\n' + distribution + '\n' + biology

    if full_desc == '\n\n':
        return 'No information found'

    return full_desc

def scrap_to_apply(x):
    if '_' in x:
        genus = x.split('_')[0].replace(' ', '')
        species = x.split('_')[1].replace(' ', '')

        return scrap_fishbase(genus, species)
    else:   
        return 'No information found'
    

## Scraping

Let's start scrapping the website

In [117]:
tqdm.pandas()
descs = pd.DataFrame(df_total['label'].unique())
descs.columns = ['label']
descs['description'] = descs['label'].progress_apply(scrap_to_apply)
descs

Unnamed: 0,label,description
0,A73EGS-P,No information found
1,acanthaluteres_brownii,\nDistribution: Eastern Indian Ocean: southern...
2,acanthaluteres_spilomelanurus,Short description: Dorsal spines (total): 2; D...
3,acanthaluteres_vittiger,Short description: Dorsal spines (total): 2; D...
4,acanthistius_cinctus,\nDistribution: Southwest Pacific.\nBiology: O...
...,...,...
479,wetmorella_albofasciata,Short description: Dorsal spines (total): 9; D...
480,wetmorella_nigropinnata,Short description: Dorsal spines (total): 9; D...
481,xiphocheilus_typus,Short description: Dorsal spines (total): 12; ...
482,zenarchopterus_dispar,Short description: Dorsal spines (total): 0; D...


In [118]:
# let' save the informations found

descs.to_csv('fishbase_info.csv', index=False)

In [119]:
descs

Unnamed: 0,label,description
0,A73EGS-P,No information found
1,acanthaluteres_brownii,\nDistribution: Eastern Indian Ocean: southern...
2,acanthaluteres_spilomelanurus,Short description: Dorsal spines (total): 2; D...
3,acanthaluteres_vittiger,Short description: Dorsal spines (total): 2; D...
4,acanthistius_cinctus,\nDistribution: Southwest Pacific.\nBiology: O...
...,...,...
479,wetmorella_albofasciata,Short description: Dorsal spines (total): 9; D...
480,wetmorella_nigropinnata,Short description: Dorsal spines (total): 9; D...
481,xiphocheilus_typus,Short description: Dorsal spines (total): 12; ...
482,zenarchopterus_dispar,Short description: Dorsal spines (total): 0; D...


## Generation of descriptions

Let's now generate some well-written descriptions using LLM with the information we've found. 

We can write descriptions of 1000 tokens (token = part of a word or a word) maximum for readability of the users. 

In [15]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [16]:
pipe_rewrite = pipeline("text-generation", model="meta-llama/Llama-3.2-3B-Instruct", device = device, max_new_tokens=1000 , pad_token_id = 128001)

Loading checkpoint shards: 100%|██████████| 2/2 [00:57<00:00, 28.77s/it]


<h2> Model description </h2>

Llama-3.2-3B-Instruct is a 3-billion-parameter language model developed by Meta, optimized for instruction-following tasks. It employs an auto-regressive transformer architecture and has undergone supervised fine-tuning (SFT) and reinforcement learning with human feedback (RLHF) to align its outputs with human preferences for helpfulness and safety.

As we use a little* model of only 3B from developped by Meta: Llama 3.2 3B, we need to well construct our prompt.

Sources: <a href = "https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct">HuggingFace</a>, <a href = "https://www.llama.com/">llama.com</a>


<h2>Prompting</h2>

Prompting a small language model effectively is crucial because it ensures that the AI generates accurate, relevant, and context-specific responses. Well-designed prompts guide the model by clearly defining its role, task, and scope, minimizing ambiguity and maximizing the utility of the output. For the given prompt, the method employed includes: 
1. Role specification ("You are an AI writing assistant") to establish the AI's function 

2. Task definition ("write a non-personal and complete fish description from the informations sent by the user") to clarify expectations

3. Explicit constraints ("Do not return any text other than the rewritten description") to prevent irrelevant or extraneous content. This structured approach ensures the prompt is precise, aligned with the desired outcome, and easy for the AI to interpret.

So, we decide to prompt it the following instructions: 

- <u>System instruction:</u> You are an AI writing assistant. Your task is to write a non personal and complete fish description from the informations sent by the user. Do not return any text other than the rewritten description.

- <u>Message form:</u> Write a one line non personal and complete fish description from the informations sent below. Do not add any new information or return any text other than the rewritten message\nThe informations: ``given information``

Sources: <a href = "https://platform.openai.com/docs/guides/prompt-engineering">OpenAI</a>

<h2>Implementation</h2>

We used a hugging face pipeline which is an easy-to-implement and effective way to use this type of model. 
<br><br><br><br><br><br><br>
*we said the model is little because now there are models of more that 100B parameters like Mistral Large2, GPT-4o etc...



In [17]:
def generate_description(df_key_words:pd.DataFrame, pipe_rewrite, output_csv:str):
    key_words = df_key_words["description"].values

    system_prompt_rewrite = "You are an AI writing assistant. Your task is to write user friendly and complete fish description from the informations sent by the user. Do not return any text other than the rewritten description."
    user_prompt_rewrite = "Write user friendly and complete fish description from the informations sent below. Do not add any new information or return any text other than the rewritten message\nThe informations:"

    messages = [[{"role":"system", "content":system_prompt_rewrite},
                  {"role":"user", "content":f"{user_prompt_rewrite} {desc}"}] for desc in key_words]
    descriptions = []
    for i in tqdm(range(len(messages))):
        desc = pipe_rewrite(messages[i])
        desc = desc[0]['generated_text'][2]['content']
        descriptions.append(desc)

        df_to_append = pd.DataFrame(data={"label":[df_key_words["label"].values[i]],"generated_description": [desc]})

        # Append the row to the CSV file
        df_to_append.to_csv(output_csv, mode='a', header=not pd.io.common.file_exists(output_csv), index=False)


    #descriptions = [desc[0]['generated_text'][2]['content'] for desc in descriptions]
    df_to_ret = pd.DataFrame(data={"label": df_key_words["label"].values, "generated_description": descriptions})

    return df_to_ret

In [41]:
descs = pd.read_csv('fishbase_info.csv')
df_descriptions = generate_description(descs[:31], pipe_rewrite, "fish_descriptions_generated.csv")
df_descriptions

100%|██████████| 484/484 [00:00<?, ?it/s]


# Combination of the components

In [38]:
def full_pipeline(img, processor, vision_model, full_df, df_labels):
    pred = predict_from_image(img, processor, vision_model)
    label = full_df[df_labels['label_id'] == pred]['label'].values[0]
    desc = full_df[df_labels['label'] == label]['generated_description'].values[0]
    return label, desc

In [40]:
# example of utilisation
df_descriptions = pd.read_csv('fish_descriptions_generated.csv')

img = Image.open(df_total['path'].iloc[19])

label, desc = full_pipeline(img, processor, model, df_descriptions, df_labels)

print(f'Label: {label}')
print(f'Description: {desc}')
print("\n")

Label: acanthaluteres_vittiger
Description: This species of fish is identified by its distinctive dorsal spines, featuring a total of two. It boasts a notable characteristic of having 30-35 dorsal soft rays. Additionally, it possesses anal soft rays. Found in the Eastern Indian Ocean, this fish inhabits a vast area, including southern Australia, stretching from southern Western Australia to New South Wales and Tasmania.




  label = full_df[df_labels['label_id'] == pred]['label'].values[0]
  desc = full_df[df_labels['label'] == label]['generated_description'].values[0]


With this code, we can now predict the fish species and generate a description of the fish species from an image of the fish. 

However, currently the metrics shows the classification image model does not offer good results in comparison of its results on the benchmarks. We could improve its scores by using more data and training it for more epochs. The problematics of finding data could be solved by the users of the app: they take photos, some specialists can classify the photo of the fish taken and this data would be used to train the model. 

Moreover, we could add more species to our data to offer a better experience to the users and bring more help to the scientist.