In [1]:
import sys
sys.path.append("../../")
from utils.class_names import class_names
from utils.dataloaders import get_dataloaders

In [None]:
import torch
import torch.nn as nn
import numpy as np
from torch.amp import autocast, GradScaler
from sklearn.utils.class_weight import compute_class_weight

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

## Feature extraction and head creation

#### Load model

In [4]:
MODEL_PATH = "../../speciesnet/models/speciesnet-pytorch-v4.0.1a-v1/always_crop_99710272_22x8_v12_epoch_00148.pt"
NUM_CLASSES = len(class_names)  # Number of target Polish species

In [5]:
model = torch.load(MODEL_PATH, map_location="cuda", weights_only=False)

for p in model.parameters():
    p.requires_grad = False

model.eval()

GraphModule(
  (initializers): Module()
  (SpeciesNet/efficientnetv2-m/rescaling/mul): OnnxBinaryMathOperation()
  (SpeciesNet/efficientnetv2-m/rescaling/add): OnnxBinaryMathOperation()
  (SpeciesNet/efficientnetv2-m/stem_conv/Conv2D__6): OnnxTranspose()
  (SpeciesNet/efficientnetv2-m/stem_conv/Conv2D): Sequential(
    (0): OnnxPadStatic()
    (1): Conv2d(3, 24, kernel_size=(3, 3), stride=(2, 2))
  )
  (SpeciesNet/efficientnetv2-m/stem_activation/Sigmoid): Sigmoid()
  (SpeciesNet/efficientnetv2-m/stem_activation/mul_1): OnnxBinaryMathOperation()
  (SpeciesNet/efficientnetv2-m/block1a_project_conv/Conv2D): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (SpeciesNet/efficientnetv2-m/block1a_project_activation/Sigmoid): Sigmoid()
  (SpeciesNet/efficientnetv2-m/block1a_project_activation/mul_1): OnnxBinaryMathOperation()
  (SpeciesNet/efficientnetv2-m/block1a_add/add): OnnxBinaryMathOperation()
  (SpeciesNet/efficientnetv2-m/block1b_project_conv/Conv2D): Conv2d(24, 24, 

#### Get number of speciesnet outputs

In [6]:
dummy = torch.randn(1, 480, 480, 3).to('cuda')
out = model(dummy)
n_original_classes = out.shape[1]
print(out.shape)

torch.Size([1, 2498])


In [7]:
scale = model.initializers.onnx_initializer_0.cpu().numpy()
offset = model.initializers.onnx_initializer_1.cpu().numpy()

print("scale:", scale)
print("offset:", offset)

scale: 1.9921875
offset: -1.0


#### Extract features

In [8]:
class SpeciesNetFeatures(nn.Module):
    def __init__(self, graph_module, feature_node_name):
        super().__init__()
        self.graph = graph_module
        self.feature_node_name = feature_node_name
        self._features = None

        # register hook on the ONNX-traced layer
        layer = getattr(self.graph, feature_node_name)
        layer.register_forward_hook(self._hook)

    def _hook(self, module, inp, out):
        self._features = out

    def forward(self, x):
        _ = self.graph(x)              # run through full graph
        return self._features          # return features BEFORE final classifier

In [9]:
FEATURE_NODE = "SpeciesNet/efficientnetv2-m/avg_pool/Mean_Squeeze__3825"
feat_extractor = SpeciesNetFeatures(model, FEATURE_NODE).to("cuda")

#### Check number of features output by model, before classifier

In [10]:
dummy = torch.randn(1, 480, 480, 3).to("cuda")
feat = feat_extractor(dummy)
feature_dim = feat.shape[1]
print("Feature dim:", feature_dim)

Feature dim: 1280


#### Change classifier in new model

In [11]:
from speciesnet_polish_model import SpeciesnetPolish

## Training

##### Dataloaders

In [12]:
IMG_SIZE = 480
BATCH_SIZE = 120
VALID_EXT = (".jpg", ".jpeg", ".png", ".bmp", ".webp")

In [13]:
dataloaders = get_dataloaders(BATCH_SIZE)
train_dataloader = dataloaders.get('train_dataloader')
test_dataloader = dataloaders.get('test_dataloader')
train_dataset = dataloaders.get('train_dataset')

Number of train classes:  26
Number of train images:  39027
Number of test images:  3978


##### The loop

imbalanced dataset:

In [14]:
y_train = [label for _, label in train_dataset.samples]

class_weights = compute_class_weight(
    class_weight='balanced',
    classes=np.unique(y_train),
    y=y_train
)

weights = torch.tensor(class_weights, dtype=torch.float32).cuda()

In [15]:
number_of_epochs = 20
polish_model = SpeciesnetPolish(feat_extractor, NUM_CLASSES).to("cuda")

optimizer = torch.optim.AdamW(polish_model.classifier.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss(weight=weights)
scaler = GradScaler()

best_val = float("inf")
best_state = None

for epoch in range(1, number_of_epochs+1):
    correct_train = 0
    total_train = 0
    correct_test = 0
    total_test = 0
    train_loss = 0
    test_loss = 0
    batch = 0

    # TRAIN
    polish_model.train()
    for images, labels in train_dataloader:
        batch += 1
        print(f'Batch {batch} / {len(train_dataloader)}', end='\r')
        images, labels = images.cuda(), labels.cuda()
        
        with autocast('cuda'):
            logits = polish_model(images)
            loss = criterion(logits, labels)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        train_loss += loss.item()
        pred = logits.argmax(1)
        correct_train += (pred == labels).sum().item()
        total_train += labels.size(0)

    # TEST
    batch = 0
    polish_model.eval()
    with torch.no_grad():
        for images, labels in test_dataloader:
            batch += 1
            print(f'Batch {batch} / {len(test_dataloader)}', end='\r')
            images, labels = images.cuda(), labels.cuda()
            with autocast('cuda'):
                logits = polish_model(images)
                loss = criterion(logits, labels)
            test_loss += loss.item()
            pred = logits.argmax(1)
            correct_test += (pred == labels).sum().item()
            total_test += labels.size(0)

    print(f"Epoch {epoch} train_acc: {correct_train / total_train:.3f} train_loss: {train_loss:.3f}")
    print(f"Epoch {epoch} test_acc: {correct_test / total_test:.3f} test_loss: {test_loss:.3f}")

    if test_loss < best_val:
        best_val = test_loss
        best_state = polish_model.state_dict()

Epoch 1 train_acc: 0.817 train_loss: 402.503
Epoch 1 test_acc: 0.789 test_loss: 36.411
Epoch 2 train_acc: 0.907 train_loss: 130.453
Epoch 2 test_acc: 0.790 test_loss: 32.347
Epoch 3 train_acc: 0.919 train_loss: 104.667
Epoch 3 test_acc: 0.783 test_loss: 33.508
Epoch 4 train_acc: 0.924 train_loss: 94.977
Epoch 4 test_acc: 0.786 test_loss: 32.730
Epoch 5 train_acc: 0.928 train_loss: 86.457
Epoch 5 test_acc: 0.794 test_loss: 32.416
Epoch 6 train_acc: 0.931 train_loss: 81.881
Epoch 6 test_acc: 0.786 test_loss: 33.048
Epoch 7 train_acc: 0.934 train_loss: 77.486
Epoch 7 test_acc: 0.785 test_loss: 33.412
Epoch 8 train_acc: 0.937 train_loss: 74.064
Epoch 8 test_acc: 0.781 test_loss: 33.794
Epoch 9 train_acc: 0.940 train_loss: 70.093
Epoch 9 test_acc: 0.792 test_loss: 32.330
Epoch 10 train_acc: 0.942 train_loss: 67.604
Epoch 10 test_acc: 0.792 test_loss: 32.203
Epoch 11 train_acc: 0.942 train_loss: 66.734
Epoch 11 test_acc: 0.793 test_loss: 33.330
Epoch 12 train_acc: 0.944 train_loss: 63.816
Ep

KeyboardInterrupt: 

In [16]:
name = "speciesnet_polish_3_lr4"
checkpoint = {
    'state_dict': best_state,
    'class_names': class_names,
    'feature_node': FEATURE_NODE,
    'num_classes': NUM_CLASSES
}
torch.save(checkpoint, f"{name}_checkpoint.pt")

In [None]:
labels

tensor([23,  8,  9, 24, 15, 15, 16, 24, 15, 14, 18,  9, 21,  3, 17, 11, 21,  2,
         9, 15,  7,  8, 17, 17, 13,  0, 14, 13,  6, 17,  1,  4,  4,  1,  2, 17,
         9, 23,  7, 19, 15, 17,  2, 11, 13, 22, 19, 15, 11, 14,  0,  8, 11,  1,
         5, 23, 15, 11, 22, 23, 25, 18,  1,  8, 25, 20, 17, 11,  9, 23, 24, 25,
         2, 16,  3, 18, 16, 18, 15,  9, 17, 20, 17, 25,  1, 18, 11,  9, 21,  1,
         1, 16, 20, 12, 25, 14, 22,  6, 19,  7], device='cuda:0')

In [None]:
torch.argmax(logits, dim=1)

tensor([24,  8,  8, 22, 23,  1,  7,  8,  4, 18, 11, 18, 20, 16, 12, 23,  9,  6,
         4,  3, 23,  8, 16,  4, 22, 22, 20,  2,  1,  5, 24,  8, 22,  8, 19, 17,
        22,  8, 16, 11,  1,  9, 20,  0, 19, 17,  5, 18,  5, 18,  3,  3,  8, 22,
         0,  8, 13,  3, 18, 17, 23,  4,  7,  3, 18, 22, 19, 24, 20,  9,  5, 18,
         4, 12, 16, 20, 17, 22,  5,  1, 24, 12, 14, 12, 22, 22,  7, 17, 23,  0,
        22, 22, 18, 22,  0, 19, 20, 11,  6,  0], device='cuda:0')