In [4]:
import sys
from pathlib import Path
sys.path.insert(0, '../src')

from plant_care_ai.training.class_selection import get_most_popular_classes

project_root = Path.cwd().parent 
data_path = project_root / "data" / "plantnet_300K"

top_100, all_counts = get_most_popular_classes(str(data_path), top_k=100)
print(f"Top class: {top_100[0]} with {all_counts[top_100[0]]} samples")

Loaded 243916 samples, 1081 classes
Top class: 1363227 with 7208 samples


In [5]:
class_to_idx = {species_id: i for i, species_id in enumerate(top_100)}
idx_to_class = {i: species_id for i, species_id in enumerate(top_100)}
num_classes = len(top_100)

[v for v in class_to_idx.items()][:3]
[v for v in idx_to_class.items()][:3]

[(0, '1363227'), (1, '1392475'), (2, '1356022')]

In [6]:
from plant_care_ai.training.train import PlantTrainer

config = {
    "data_dir": "../data/plantnet_300K",
    "checkpoint_dir": "../checkpoints",
    "experiment_name": "efficientnet_v2_top100",
    
    "subset_classes": top_100,
    "train_samples_per_class": 150,
    "val_samples_per_class": 30,
    
    "model": "efficientnetv2",
    "variant": "b0",
    "img_size": 224,
    
    "batch_size": 32,
    "epochs": 20,
    "lr": 1e-3,
    "weight_decay": 0.05,
    "label_smoothing": 0.1,
    "augm_strength": 0.4,
    "num_workers": 4
}

In [7]:
trainer = PlantTrainer(config)
trainer.prepare_data()

sample_img, sample_label = trainer.train_loader.dataset[0]
print(f"Sample label: {sample_label}")
print(f"Label type: {type(sample_label)}")
print(f"Model output size: {trainer.num_classes}")

  self.scaler = torch.amp.GradScaler("cuda")


Loaded 243916 samples, 1081 classes
Loaded 31118 samples, 1081 classes
Total train samples in dataset: 243916
Selected classes: 100
Train samples: 15000
Val samples: 3000
Sample label: 26
Label type: <class 'int'>
Model output size: 100


In [4]:
trainer = PlantTrainer(config)
trainer.prepare_data()
trainer.build_model()
trainer.setup_training()
history = trainer.train()

Loaded 243916 samples, 1081 classes
Loaded 31118 samples, 1081 classes
Total train samples in dataset: 243916
Selected classes: 100

Checkpoint dir: ../checkpoints/efficientnet_v2_top100
Epochs: 20
Train samples: 15000
Val samples: 3000
Epoch 1/20


                                                                               


Results:
	Train: Loss=4.3811, Acc=3.25%
	Val:   Loss=4.2507, Top-1=4.63%, Top-5=18.97%
	LR: 0.000994
Best model saved (acc: 4.63%)
Epoch 2/20


                                                                               


Results:
	Train: Loss=4.0772, Acc=6.86%
	Val:   Loss=4.3713, Top-1=6.83%, Top-5=25.63%
	LR: 0.000976
Best model saved (acc: 6.83%)
Epoch 3/20


                                                                                     


Results:
	Train: Loss=3.9137, Acc=9.75%
	Val:   Loss=3.8051, Top-1=11.83%, Top-5=38.90%
	LR: 0.000946
Best model saved (acc: 11.83%)
Epoch 4/20


                                                                                

KeyboardInterrupt: 

^^ terrible accuracy and loss value :(

In [16]:
import torch
from pathlib import Path
\
checkpoint_path = "../checkpoints/efficientnet_v2_top100/best.pth"
checkpoint = torch.load(checkpoint_path)

from plant_care_ai.training.class_selection import get_most_popular_classes
data_path = "../data/plantnet_300K"
top_100, _ = get_most_popular_classes(data_path, top_k=100)

class_to_idx = {species_id: i for i, species_id in enumerate(top_100)}
idx_to_class = {i: species_id for i, species_id in enumerate(top_100)}

checkpoint["num_classes"] = len(top_100)
checkpoint["class_to_idx"] = class_to_idx
checkpoint["idx_to_class"] = idx_to_class

torch.save(checkpoint, checkpoint_path)

Loaded 243916 samples, 1081 classes


In [10]:
import json


mapping_file = data_path / "plantnet300K_species_id_2_name.json"
print(f"\nLoading plant names from: {mapping_file}")

with open(mapping_file, 'r', encoding='utf-8') as f:
    id_to_name = json.load(f)

print(f"Loaded {len(id_to_name)} plant names")
print(f"Example: {list(id_to_name.items())[0]}")


Loading plant names from: /home/szczuru/projects/Plant-Care-Assistant-App/ai/data/plantnet_300K/plantnet300K_species_id_2_name.json
Loaded 1081 plant names
Example: ('1355868', 'Lactuca virosa L.')


In [11]:
import json
import sys
from pathlib import Path

sys.path.insert(0, '../src')
project_root = Path.cwd().parent 
data_path = project_root / "data" / "plantnet_300K"
checkpoint_path = project_root / "checkpoints" / "efficientnet_v2_top100" / "best.pth"

from plant_care_ai.inference.classifier import PlantClassifier
classifier = PlantClassifier.from_checkpoint(checkpoint_path)
classifier.set_name_mapping(id_to_name)

print("\nFinding test image...")
from plant_care_ai.data.dataset import PlantNetDataset
from plant_care_ai.data.preprocessing import get_inference_pipeline

dataset = PlantNetDataset(
    str(data_path),
    split="train",
    transform=get_inference_pipeline(224)
)

test_img_path = None
true_species_id = None

for img_path, species_id in dataset.paths:
    if species_id in classifier.idx_to_class.values():
        test_img_path = img_path
        true_species_id = species_id
        break

print(f"Test image: {Path(test_img_path).name}")
print(f"True class: {id_to_name.get(true_species_id, 'Unknown')} (ID: {true_species_id})")

result = classifier.predict(test_img_path, top_k=5)

print(f"Processing time: {result['processing_time_ms']:.1f}ms\n")

for rank, pred in enumerate(result['predictions'], 1):
    is_correct = "CORRECT" if pred['class_id'] == true_species_id else ""
    print(f"{rank}. {pred.get('class_name', 'Unknown')} ({pred['confidence']:.2%}) {is_correct}")
    print(f"   ID: {pred['class_id']}")

print("Top prediction:", result['predictions'][0])

from PIL import Image
im = Image.open(test_img_path)
im.show()

Loaded checkpoint from /home/szczuru/projects/Plant-Care-Assistant-App/ai/checkpoints/efficientnet_v2_top100/best.pth
Model: efficientnetv2
Classes: 100
Best validation accuracy: 11.83%

Finding test image...
Loaded 243916 samples, 1081 classes
Test image: 15030c83645760de2d3b5d3665e2fbe511c7085b.jpg
True class: Anthurium andraeanum Linden ex Andr√© (ID: 1409238)
Processing time: 105.2ms

1. Tagetes erecta L. (9.98%) 
   ID: 1374048
2. Tagetes patula L. (8.33%) 
   ID: 1364159
3. Papaver orientale L. (7.33%) 
   ID: 1394404
4. Kniphofia uvaria (L.) Hook. (5.33%) 
   ID: 1393393
5. Calendula officinalis L. (4.98%) 
   ID: 1357330
Top prediction: {'class_id': '1374048', 'confidence': 0.0997985303401947, 'class_name': 'Tagetes erecta L.'}
