In [1]:
import tensorflow as tf
import pandas as pd
import plotly.express as px
from pathlib import Path
import numpy as np
from sklearn.utils import compute_class_weight
from sklearn.dummy import DummyClassifier

from datasets import deep_weeds
from sklearn import metrics

2024-07-06 10:45:07.795110: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:479] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-07-06 10:45:07.807489: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:10575] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-07-06 10:45:07.807504: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1442] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-07-06 10:45:07.815647: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
data_dir = Path('../data/images/')
train_path = data_dir / 'train'
val_split = 0.1

In [3]:
train_ds, val_ds = deep_weeds.get_train_val_dataloader(train_path)

Using 12607 files for training.
Using 1400 files for validation.


2024-07-06 10:45:10.000412: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:984] could not open file to read NUMA node: /sys/bus/pci/devices/0000:0c:00.0/numa_node
Your kernel may have been built without NUMA support.
2024-07-06 10:45:10.016921: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:984] could not open file to read NUMA node: /sys/bus/pci/devices/0000:0c:00.0/numa_node
Your kernel may have been built without NUMA support.
2024-07-06 10:45:10.017189: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:984] could not open file to read NUMA node: /sys/bus/pci/devices/0000:0c:00.0/numa_node
Your kernel may have been built without NUMA support.
2024-07-06 10:45:10.018969: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:984] could not open file to read NUMA node: /sys/bus/pci/devices/0000:0c:00.0/numa_node
Your kernel may have been built without NUMA support.
2024-07-06 10:45:10.019317: I external/local_xla/xla/stream_executor

## 1 - Get class balance

In [4]:
ys = np.array([y.numpy() for _,y in train_ds])

2024-07-06 10:45:13.220710: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


In [5]:
class_weight = compute_class_weight(class_weight='balanced', classes=np.unique(ys), y=ys)
class_weight

array([1.90323068, 0.21467859, 1.87520452, 1.69585687, 1.83108206,
       1.96738452, 1.8701973 , 1.79586895, 1.8263074 ])

In [6]:
focal_loss_alpha = class_weight / class_weight.sum()
focal_loss_alpha

array([0.12705305, 0.01433119, 0.12518212, 0.1132095 , 0.12223666,
       0.13133574, 0.12484786, 0.11988596, 0.12191792])

In [7]:
for idx, n in enumerate(class_weight):
    print(f'{idx}: {n}')

0: 1.9032306763285025
1: 0.2146785866326096
2: 1.8752045217908673
3: 1.6958568738229756
4: 1.8310820624546116
5: 1.9673845193508115
6: 1.870197300103842
7: 1.7958689458689459
8: 1.826307402578589


## 2 - Get baseline results

In [8]:
X = np.zeros_like(ys)

In [9]:
def get_metrics(y, y_pred):
    return {
        'accuracy': metrics.accuracy_score(y, y_pred),
        'recall': metrics.recall_score(y, y_pred, average='macro'),
        'precision': metrics.precision_score(y, y_pred, average='macro'),
    }

In [10]:
strategies = ['stratified', 'prior', 'uniform', 'most_frequent']
for strategy in strategies:
    model = DummyClassifier(strategy='stratified')
    model.fit(X, ys)
    y_pred = model.predict(X)
    dummy_metrics = get_metrics(ys, y_pred)

    print(f'Strategy: {strategy},\n{dummy_metrics}\n')

Strategy: stratified,
{'accuracy': 0.29110811453954155, 'recall': 0.11036791609539978, 'precision': 0.11029397614556909}

Strategy: prior,
{'accuracy': 0.29087015150313317, 'recall': 0.11041313864309091, 'precision': 0.1104412037533023}

Strategy: uniform,
{'accuracy': 0.29872293170460856, 'recall': 0.11380884613749968, 'precision': 0.11390514671741592}

Strategy: most_frequent,
{'accuracy': 0.29420163401285, 'recall': 0.10693324987816817, 'precision': 0.10685589429711115}

