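# Train

Evaluate the detection + classification pipeline. MegaDetector V6 proposes boxes, and a trained classifier checkpoint (Swin V2-S weights family) labels each box. The labelled boxes are then scored against the ground-truth annotation on a held-out 20 % split of the images.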

In [1]:
import torch
import torchvision

In [2]:
# Release unoccupied memory held by PyTorch's CUDA caching allocator from earlier runs in this kernel.
torch.cuda.empty_cache()
# Set FORCE_CPU = True to run on the CPU even when a GPU is available.
# (The former commented-out `device = "cpu"` line sat *before* the auto-detection
# assignment, so uncommenting it was silently overridden.)
FORCE_CPU = False
device = "cuda:0" if torch.cuda.is_available() and not FORCE_CPU else "cpu"
print("Device: " + device)
print(f"Devices count: {torch.cuda.device_count()}")

Device: cuda:0
Devices count: 1


In [3]:
import pandas
import numpy
import pickle

In [4]:
from pathlib import Path

In [5]:
from misc.data import DetectorDataset, SimpleClassifierDataset, SupplementaryDataset, concatenate_collate_fn, detection_results_to_annotations

In [6]:
from PytorchWildlife.models import detection as pw_detection
from PytorchWildlife.models import classification as pw_classification
from PytorchWildlife.data import transforms as pw_trans
from PytorchWildlife import utils as pw_utils

In [7]:
# Classifier architecture and pretrained-weights identifiers.
# The weights name/subname pair selects the checkpoint sub-directory (see `model_path`).
classifier_model_name      = "swin_v2_s"
classifier_weights_name    = "Swin_V2_S_Weights"
classifier_weights_subname = "IMAGENET1K_V1"

## Data

In [8]:
# Dataset root: an image folder plus a CSV of ground-truth boxes.
data_path = Path("./data/train_data_minprirodi/")
images_path, annotation_path = data_path.joinpath("images"), data_path.joinpath("annotation.csv")

In [9]:
# Checkpoints live under models/<weights family>/<weights variant>/ inside the dataset root.
model_path = data_path / "models" / classifier_weights_name / classifier_weights_subname

In [10]:
# Ground truth: one row per annotated box (Name = image file, Bbox = comma-separated
# normalized coordinates, Class = integer label).
annotation = pandas.read_csv(filepath_or_buffer=annotation_path)
annotation

Unnamed: 0,Name,Bbox,Class
0,1001958.jpg,"0.7075520833333333,0.5319444444444444,0.282812...",1
1,1001958.jpg,"0.09505208333333333,0.6305555555555555,0.19010...",0
2,1001958.jpg,"0.031510416666666666,0.7434027777777777,0.0630...",0
3,1002155.jpg,"0.8135416666666667,0.6976851851851852,0.371875...",0
4,1002155.jpg,"0.3221354166666667,0.7939814814814815,0.477604...",0
...,...,...,...
1980,1997546.jpg,"0.34661458333333334,0.490625,0.6015625,0.48958...",1
1981,1997602.jpg,"0.7317708333333334,0.25601851851851853,0.30625...",0
1982,1999067.jpg,"0.5630208333333333,0.5020833333333333,0.519791...",1
1983,1999067.jpg,"0.9572916666666667,0.5881944444444445,0.085416...",0


In [11]:
# Detector dataset built from the ground-truth annotation.
# NOTE(review): `dataset` is reassigned to a SimpleClassifierDataset further down before
# anything reads this instance — confirm whether this cell is still needed.
dataset = DetectorDataset(
    images_path,
    annotation,
)

In [12]:
from sklearn.model_selection import train_test_split

# Split by image name rather than by row, so all boxes of one image end up in the same split.
unique_names = annotation["Name"].unique()
train_names, test_names = train_test_split(
    unique_names,
    test_size=0.2,
    random_state=42,
)

## Model

In [13]:
# Pretrained MegaDetector V6 (PytorchWildlife), placed on the selected device.
detection_model = pw_detection.MegaDetectorV6(
    device=device,
    pretrained=True,
)

Ultralytics 8.3.28 🚀 Python-3.10.12 torch-2.5.1+cu124 CUDA:0 (Tesla V100-SXM3-32GB, 32494MiB)
YOLOv9c summary (fused): 384 layers, 25,321,561 parameters, 0 gradients, 102.3 GFLOPs


# Run MegaDetector over every image once and cache the raw results on disk, so that
# Restart & Run All does not repeat the expensive detection pass.
# Delete the cache file to force re-detection (e.g. after the image set changes).
detection_results_path = data_path / "detection_results.dat"
if not detection_results_path.exists():
    results = detection_model.batch_image_detection(images_path)
    with open(detection_results_path, "wb") as file:
        pickle.dump(results, file)

In [14]:
# Load the cached detection results written by the detection cell.
# pickle can execute arbitrary code while loading — only read cache files this notebook produced.
results = pickle.loads((data_path / "detection_results.dat").read_bytes())

In [15]:
# Flatten the raw detector output into a table with one row per predicted box.
annotations_pred = detection_results_to_annotations(
    results,
)

In [16]:
# Display the per-box detection table (one row per predicted box).
annotations_pred

Unnamed: 0,Name,Bbox,xyxy_normalized_coords,label,confidence
0,1001958.jpg,"0.7119067311286926, 0.5377235412597656, 0.2593...","[0.58222866, 0.3885526, 0.8415848, 0.6868945]",0,0.921108
1,1001958.jpg,"0.0919116884469986, 0.6257917284965515, 0.1838...","[0.0, 0.49359453, 0.18382338, 0.7579889]",0,0.915630
2,1001958.jpg,"0.3535197675228119, 0.8627723455429077, 0.2466...","[0.23018798, 0.7784824, 0.47685155, 0.9470624]",0,0.765915
3,1001958.jpg,"0.0270781759172678, 0.7385634183883667, 0.0541...","[0.0, 0.6670405, 0.05415635, 0.8100863]",0,0.599017
4,1002155.jpg,"0.8208271265029907, 0.7090132236480713, 0.3577...","[0.6419688, 0.42098355, 0.9996855, 0.99704283]",0,0.941775
...,...,...,...,...,...
1976,1997546.jpg,"0.354478120803833, 0.4903988838195801, 0.59009...","[0.059431933, 0.2530929, 0.64952433, 0.7277049]",0,0.952174
1977,1997602.jpg,"0.7356047630310059, 0.2534518837928772, 0.2846...","[0.59326077, 0.00030263266, 0.87794876, 0.5066...",0,0.767034
1978,1999067.jpg,"0.5610121488571167, 0.512933075428009, 0.49320...","[0.31441194, 0.25900587, 0.8076124, 0.7668603]",0,0.945998
1979,1999067.jpg,"0.961925745010376, 0.5905055999755859, 0.07592...","[0.9239653, 0.32892227, 0.99988616, 0.8520889]",0,0.849700


In [17]:
# Persist the flattened detections; index=False keeps the RangeIndex out of the CSV.
annotations_pred.to_csv(path_or_buf=data_path / "annotations_pred.csv", index=False)

In [18]:
# Detections carry no ground-truth label, but a "Class" column is added here
# (presumably required by SimpleClassifierDataset — confirm).
# Predicted boxes do not correspond row-by-row to annotation.csv: in the run above there
# are 1980 detections vs 1984 ground-truth boxes, and the boxes and images differ.
# Copying annotation["Class"] by index therefore attached labels from unrelated boxes.
# Use an explicit integer placeholder (same dtype as the ground-truth Class) instead.
# The class of each test box is filled in later from the classifier output.
annotations_pred["Class"] = 0

In [19]:
# Resize each sample to 232x232 with bicubic interpolation. The split datasets below
# reuse this transform via `dataset.transform`.
resize_transform = torchvision.transforms.Resize(
    (232, 232),
    interpolation=torchvision.transforms.InterpolationMode.BICUBIC,
)
dataset = SimpleClassifierDataset(images_path, annotations_pred, resize_transform)

In [20]:
# Route each predicted box to the split its image was assigned to; re-index from 0.
is_train_image = annotations_pred["Name"].isin(train_names)
is_test_image = annotations_pred["Name"].isin(test_names)
train_annotation = annotations_pred.loc[is_train_image].reset_index(drop=True)
test_annotation = annotations_pred.loc[is_test_image].reset_index(drop=True)

In [21]:
# Both splits reuse the full dataset's transform, so preprocessing is identical.
train_dataset, test_dataset = (
    SimpleClassifierDataset(images_path, split_annotation, dataset.transform)
    for split_annotation in (train_annotation, test_annotation)
)

In [22]:
# Batch sizes for the training and evaluation loaders.
train_batch_size, test_batch_size = 24, 128

In [23]:
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=train_batch_size,
    shuffle=True,
)
test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=test_batch_size,
    # Keep dataset order: predictions are later attached to test_dataset.annotation by position.
    shuffle=False,
)

In [24]:
# The checkpoint is a whole pickled model object (`model.predict` is called on it), not a
# state_dict. It therefore needs weights_only=False; torch >= 2.6 defaults to True and would
# refuse to load it. Only load checkpoints you trust: unpickling can run arbitrary code.
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine as well.
model = torch.load(model_path / "e17_0.93.pt", map_location=device, weights_only=False)

In [25]:
# Run inference on the test split; only the second returned value (the predictions) is kept.
(_, y_pred) = model.predict(
    test_dataloader,
    device,
)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:29<00:00,  7.32s/it]


In [26]:
# NOTE(review): y_pred appears to hold per-class log-probabilities (it is exp()'d before
# thresholding) — confirm against model.predict. Rows line up with
# test_dataset.annotation because test_dataloader does not shuffle.
positive_class_threshold = 0.7
positive_class_probability = numpy.exp(y_pred[:, 1])
test_annotations_pred = test_dataset.annotation.copy()
test_annotations_pred["Class"] = positive_class_probability > positive_class_threshold

In [27]:
from misc.metrics import *

In [28]:
# Ground-truth boxes restricted to the held-out test images, re-indexed from 0.
test_annotation_true = (
    annotation[annotation["Name"].isin(test_names)]
    .reset_index(drop=True)
    .copy()
)
metric_value, metric_data = calculate_metric(test_annotation_true, test_annotations_pred)

In [29]:
# Scalar score returned by calculate_metric for the test split.
metric_value

np.float64(0.8469809760132341)

In [30]:
# Predicted box strings as reported in calculate_metric's per-row data.
metric_data["Bbox_pred"]

0      0.11784205585718155, 0.49985140562057495, 0.23...
1      0.48835498094558716, 0.5186216831207275, 0.403...
2      0.2436753213405609, 0.5376400947570801, 0.3343...
3      0.39207592606544495, 0.6405500769615173, 0.443...
4      0.3681947886943817, 0.7512457370758057, 0.3220...
                             ...                        
398    0.902854323387146, 0.6882379651069641, 0.18693...
399    0.7073821425437927, 0.6112831234931946, 0.2363...
400    0.28912046551704407, 0.5781852602958679, 0.244...
401    0.5738311409950256, 0.4737164080142975, 0.2051...
402    0.3902233839035034, 0.8498777151107788, 0.3204...
Name: Bbox_pred, Length: 403, dtype: object

In [31]:
# Ground-truth box strings as reported in calculate_metric's per-row data.
metric_data["Bbox_true"]

0      0.121875,0.4996527777777778,0.24375,0.39930555...
1      0.48854166666666665,0.51484375,0.4114583333333...
2      0.24296875,0.5392361111111111,0.3578125,0.4729...
3      0.390625,0.6425925925925926,0.446875,0.5444444...
4      0.36770833333333336,0.7393518518518518,0.36458...
                             ...                        
398    0.8979166666666667,0.6885416666666667,0.197916...
399    0.7057291666666666,0.6159722222222223,0.255208...
400    0.19739583333333333,0.5857638888888889,0.0875,...
401    0.5729166666666666,0.471875,0.209375,0.5493055...
402    0.3859375,0.8388888888888889,0.35,0.3185185185...
Name: Bbox_true, Length: 403, dtype: object

In [32]:
# Number of rows whose predicted box calculate_metric did not mark as correct.
incorrect_bbox_mask = ~metric_data["correct_Bbox"]
incorrect_bbox_mask.sum()

np.int64(15)

In [33]:
# Of the rows with an incorrect box, how many have a truthy ground-truth class (Class_true).
incorrect_positive_mask = ~metric_data["correct_Bbox"] & metric_data["Class_true"]
incorrect_positive_mask.sum()

np.int64(2)