In [1]:
import sys
sys.path.append("../../")
from utils.class_names import class_names
from utils.dataloaders import get_dataloaders

In [2]:
import torch
import torch.nn as nn
import numpy as np
from torch.amp import autocast, GradScaler
from sklearn.utils.class_weight import compute_class_weight

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

## Feature extraction and head creation

#### Load model

In [4]:
MODEL_PATH = "../../speciesnet/models/speciesnet-pytorch-v4.0.1a-v1/always_crop_99710272_22x8_v12_epoch_00148.pt"
NUM_CLASSES = len(class_names)  # Number of target Polish species

In [5]:
model = torch.load(MODEL_PATH, map_location="cuda", weights_only=False)

for p in model.parameters():
    p.requires_grad = False

model.eval()

GraphModule(
  (initializers): Module()
  (SpeciesNet/efficientnetv2-m/rescaling/mul): OnnxBinaryMathOperation()
  (SpeciesNet/efficientnetv2-m/rescaling/add): OnnxBinaryMathOperation()
  (SpeciesNet/efficientnetv2-m/stem_conv/Conv2D__6): OnnxTranspose()
  (SpeciesNet/efficientnetv2-m/stem_conv/Conv2D): Sequential(
    (0): OnnxPadStatic()
    (1): Conv2d(3, 24, kernel_size=(3, 3), stride=(2, 2))
  )
  (SpeciesNet/efficientnetv2-m/stem_activation/Sigmoid): Sigmoid()
  (SpeciesNet/efficientnetv2-m/stem_activation/mul_1): OnnxBinaryMathOperation()
  (SpeciesNet/efficientnetv2-m/block1a_project_conv/Conv2D): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (SpeciesNet/efficientnetv2-m/block1a_project_activation/Sigmoid): Sigmoid()
  (SpeciesNet/efficientnetv2-m/block1a_project_activation/mul_1): OnnxBinaryMathOperation()
  (SpeciesNet/efficientnetv2-m/block1a_add/add): OnnxBinaryMathOperation()
  (SpeciesNet/efficientnetv2-m/block1b_project_conv/Conv2D): Conv2d(24, 24, 

#### Get number of speciesnet outputs

In [6]:
dummy = torch.randn(1, 480, 480, 3).to('cuda')
out = model(dummy)
n_original_classes = out.shape[1]
print(out.shape)

torch.Size([1, 2498])


In [7]:
scale = model.initializers.onnx_initializer_0.cpu().numpy()
offset = model.initializers.onnx_initializer_1.cpu().numpy()

print("scale:", scale)
print("offset:", offset)

scale: 1.9921875
offset: -1.0


#### Extract features

In [8]:
class SpeciesNetFeatures(nn.Module):
    def __init__(self, graph_module, feature_node_name):
        super().__init__()
        self.graph = graph_module
        self.feature_node_name = feature_node_name
        self._features = None

        # register hook on the ONNX-traced layer
        layer = getattr(self.graph, feature_node_name)
        layer.register_forward_hook(self._hook)

    def _hook(self, module, inp, out):
        self._features = out

    def forward(self, x):
        _ = self.graph(x)              # run through full graph
        return self._features          # return features BEFORE final classifier

In [9]:
FEATURE_NODE = "SpeciesNet/efficientnetv2-m/avg_pool/Mean_Squeeze__3825"
feat_extractor = SpeciesNetFeatures(model, FEATURE_NODE).to("cuda")

#### Check number of features output by model, before classifier

In [10]:
dummy = torch.randn(1, 480, 480, 3).to("cuda")
feat = feat_extractor(dummy)
feature_dim = feat.shape[1]
print("Feature dim:", feature_dim)

Feature dim: 1280


#### Change classifier in new model

In [11]:
from speciesnet_polish_model import SpeciesnetPolish

## Training

##### Dataloaders

In [12]:
BATCH_SIZE = 120

In [13]:
dataloaders = get_dataloaders(BATCH_SIZE)
train_dataloader = dataloaders.get('train_dataloader')
test_dataloader = dataloaders.get('test_dataloader')
train_dataset = dataloaders.get('train_dataset')

Number of train classes:  26
Number of train images:  39027
Number of test images:  4000


##### The loop

imbalanced dataset:

In [14]:
y_train = [label for _, label in train_dataset.samples]

class_weights = compute_class_weight(
    class_weight='balanced',
    classes=np.unique(y_train),
    y=y_train
)

weights = torch.tensor(class_weights, dtype=torch.float32).cuda()

In [15]:
number_of_epochs = 20
polish_model = SpeciesnetPolish(feat_extractor, NUM_CLASSES).to("cuda")

optimizer = torch.optim.AdamW(polish_model.classifier.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss(weight=weights)
scaler = GradScaler()

best_val = float("inf")
best_state = None

for epoch in range(1, number_of_epochs+1):
    correct_train = 0
    total_train = 0
    correct_test = 0
    total_test = 0
    train_loss = 0
    test_loss = 0
    batch = 0

    # TRAIN
    polish_model.train()
    for images, labels in train_dataloader:
        batch += 1
        print(f'Batch {batch} / {len(train_dataloader)}', end='\r')
        images, labels = images.cuda(), labels.cuda()
        
        with autocast('cuda'):
            logits = polish_model(images)
            loss = criterion(logits, labels)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        train_loss += loss.item()
        pred = logits.argmax(1)
        correct_train += (pred == labels).sum().item()
        total_train += labels.size(0)
    print(f"Epoch {epoch} train_acc: {correct_train / total_train:.3f} train_loss: {train_loss:.3f}")

    # TEST
    batch = 0
    polish_model.eval()
    with torch.no_grad():
        for images, labels in test_dataloader:
            batch += 1
            print(f'Batch {batch} / {len(test_dataloader)}', end='\r')
            images, labels = images.cuda(), labels.cuda()
            with autocast('cuda'):
                logits = polish_model(images)
                loss = criterion(logits, labels)
            test_loss += loss.item()
            pred = logits.argmax(1)
            correct_test += (pred == labels).sum().item()
            total_test += labels.size(0)

    print(f"Epoch {epoch} test_acc: {correct_test / total_test:.3f} test_loss: {test_loss:.3f}")

    if test_loss < best_val:
        best_val = test_loss
        best_state = polish_model.state_dict()

Epoch 1 train_acc: 0.827 train_loss: 398.372
Epoch 1 test_acc: 0.881 test_loss: 26.012
Epoch 2 train_acc: 0.911 train_loss: 126.779
Epoch 2 test_acc: 0.881 test_loss: 22.185
Epoch 3 train_acc: 0.923 train_loss: 102.098
Epoch 3 test_acc: 0.878 test_loss: 21.251
Epoch 4 train_acc: 0.927 train_loss: 91.362
Epoch 4 test_acc: 0.866 test_loss: 22.587
Epoch 5 train_acc: 0.932 train_loss: 83.234
Epoch 5 test_acc: 0.883 test_loss: 20.592
Batch 16 / 3426

KeyboardInterrupt: 

In [16]:
name = "speciesnet_polish_lr4"
checkpoint = {
    'state_dict': best_state,
    'class_names': class_names,
    'feature_node': FEATURE_NODE,
    'num_classes': NUM_CLASSES
}
torch.save(checkpoint, f"{name}_checkpoint.pt")
torch.save(polish_model, f'{name}_model.pt')