# Two-stage solution

In [None]:
import torch
import torchvision

In [None]:
# Release blocks held by PyTorch's CUDA caching allocator (e.g. left over from an earlier run of this notebook).
torch.cuda.empty_cache()

# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
# To force CPU execution, assign device = "cpu" after this cell.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda:0"
print("Device: " + device)
print(f"Devices count: {torch.cuda.device_count()}")

In [None]:
import pandas
import numpy
import pickle

In [None]:
from pathlib import Path

In [None]:
from misc.data import SimpleClassifierDataset, SupplementaryDataset, concatenate_collate_fn, detection_results_to_annotations

In [None]:
from PytorchWildlife.models import detection as pw_detection
from PytorchWildlife.models import classification as pw_classification
from PytorchWildlife.data import transforms as pw_trans
from PytorchWildlife import utils as pw_utils

In [None]:
# Backbone selection: a torchvision model-builder name plus the pretrained-weights
# enum class and the member of that enum to load.
classifier_model_name: str = "swin_v2_s"
classifier_weights_name: str = "Swin_V2_S_Weights"
classifier_weights_subname: str = "IMAGENET1K_V1"

In [None]:
# Side length of the square that the datasets below resize every image to.
# It is taken from the configured weights' own inference preset (`resize_size`)
# rather than hard-coded. `weights.transforms()` is handed to `Classifier` in the
# Model section (presumably applied in its forward pass), so a mismatched value
# here would mean resizing twice and losing detail.
# NOTE(review): the former value 232 looks like the Swin-T preset rather than the
# Swin-V2-S one -- confirm; the larger size also raises GPU memory use per batch.
image_size = getattr(
    getattr(torchvision.models, classifier_weights_name), classifier_weights_subname
).transforms().resize_size[0]

## Data

In [None]:
# Root of the training data; the labelled images folder and the annotation CSV sit directly under it.
data_path = Path("./data/train_data_minprirodi/")
images_path = data_path.joinpath("images")
annotation_path = data_path.joinpath("annotation.csv")

In [None]:
# Checkpoint directory, namespaced by the weights enum and member so different backbones don't collide.
model_path = data_path / "models" / classifier_weights_name / classifier_weights_subname

In [None]:
# Load the annotation table; the bare name on the last line makes Jupyter render it.
annotation = pandas.read_csv(filepath_or_buffer=annotation_path)
annotation

In [None]:
# Full labelled dataset, with every image resized to an image_size x image_size square
# using bicubic interpolation. Its transform is reused for the train/test subsets.
dataset = SimpleClassifierDataset(
    images_path,
    annotation,
    torchvision.transforms.Resize(
        size=(image_size,) * 2,
        interpolation=torchvision.transforms.InterpolationMode.BICUBIC,
    ),
)

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

plt.imshow(dataset[53][0].cpu().detach().numpy().swapaxes(0,2).swapaxes(1,0))
plt.show();

In [None]:
from sklearn.model_selection import train_test_split

# Split on distinct image names rather than on rows, so every annotation row for a given
# name lands in the same subset (80/20, fixed seed for reproducibility).
unique_names = pandas.unique(annotation["Name"])
train_names, test_names = train_test_split(
    unique_names,
    test_size=0.2,
    random_state=42,
)

In [None]:
# Keep the rows whose image name belongs to each split, renumbering the index from 0.
train_annotation = annotation.loc[annotation["Name"].isin(train_names)].reset_index(drop=True)
test_annotation = annotation.loc[annotation["Name"].isin(test_names)].reset_index(drop=True)

In [None]:
# One dataset per split, sharing the resize transform of the full dataset above.
train_dataset, test_dataset = (
    SimpleClassifierDataset(images_path, split_annotation, dataset.transform)
    for split_annotation in (train_annotation, test_annotation)
)

In [None]:
# Training batches are small (gradients + activations); evaluation can use larger ones.
train_batch_size: int = 24
test_batch_size: int = 128

In [None]:
# Only the training loader is shuffled; evaluation order stays fixed.
train_dataloader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=train_batch_size,
    shuffle=True,
)
test_dataloader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=test_batch_size,
    shuffle=False,
)

### Negative datasets

In [None]:
# Folders for the supplementary ("negative") datasets.
empty_path = data_path / "images_empty"
# Same directory as images_path above.
# NOTE(review): that folder also holds the test-split images. If train_detector_classifier
# trains on negative_dataloaders, test images may leak into training -- confirm.
all_path = images_path

In [None]:
# One SupplementaryDataset per folder, each with its own bicubic square-resize transform
# matching the labelled dataset's preprocessing.
empty_dataset, all_dataset = (
    SupplementaryDataset(
        folder,
        torchvision.transforms.Resize(
            size=(image_size,) * 2,
            interpolation=torchvision.transforms.InterpolationMode.BICUBIC,
        ),
    )
    for folder in (empty_path, all_path)
)

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

plt.imshow(empty_dataset[53][0].cpu().detach().numpy().swapaxes(0,2).swapaxes(1,0))
plt.show();

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

plt.imshow(all_dataset[53][0].cpu().detach().numpy().swapaxes(0,2).swapaxes(1,0))
plt.show();

In [None]:
# Shuffled loaders over the supplementary datasets, batched like the training set.
# Only all_dataloader is currently passed to training (empty_dataloader is commented out there).
empty_dataloader, all_dataloader = (
    torch.utils.data.DataLoader(dataset=supplementary, batch_size=train_batch_size, shuffle=True)
    for supplementary in (empty_dataset, all_dataset)
)

## Model

In [None]:
from misc.train import *

In [None]:
# Resolve the configured pretrained-weights enum member by name (Enum name lookup).
weights = getattr(torchvision.models, classifier_weights_name)[classifier_weights_subname]

In [None]:
# Look up the torchvision model builder by name, build it with the pretrained weights, move it to the device.
build_classifier = getattr(torchvision.models, classifier_model_name)
classifier = build_classifier(weights=weights).to(device)

In [None]:
# Assemble the trainable model from three parts: the pretrained backbone (switched to
# training mode), a new head, and the weights' own preprocessing preset.
# The head size comes from the weights' category list instead of a hard-coded 1000.
# torchvision sizes the backbone's final layer from that list when pretrained weights
# are given, so the two stay in sync if the weights config changes.
# NOTE(review): assumes ClassifierHead's argument is the backbone's output width
# (its logit count) -- confirm in misc.train.
model = Classifier(
    classifier.train(),
    ClassifierHead(len(weights.meta["categories"])),
    weights.transforms(),
).to(device)

In [None]:
# Adam over all parameters of the assembled model, learning rate 5e-5.
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)

# Train for 20 epochs, checkpointing under model_path. all_dataloader is the only
# supplementary ("negative") loader; add empty_dataloader to the list to include the empty-image folder.
history = train_detector_classifier(
    model,
    train_dataloader,
    test_dataloader,
    optimizer,
    device,
    negative_dataloaders=[all_dataloader],
    model_path=model_path,
    n_epochs=20,
)

In [None]:
# Render whatever train_detector_classifier returned (its structure is defined in misc.train;
# presumably per-epoch metrics -- verify there).
history