In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.amp import autocast, GradScaler
from torchvision import transforms, datasets
from sklearn.utils.class_weight import compute_class_weight

from utils.dataloaders import IMG_SIZE, CroppedImageDataset, is_valid_image
from utils.class_names import class_names, class_to_idx

In [None]:
# Mini-batch size passed as `batch_size` to both the training and the
# validation DataLoader in this notebook.
BATCH_SIZE = 100

### Dataloaders

In [None]:
# Training-time preprocessing: bicubic resize to IMG_SIZE x IMG_SIZE, then
# light augmentation (random horizontal flip + mild colour jitter), then
# conversion to a tensor.
training_data_transforms = transforms.Compose(
    [
        transforms.Resize(
            (IMG_SIZE, IMG_SIZE),
            interpolation=transforms.InterpolationMode.BICUBIC,
        ),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05),
        transforms.ToTensor(),
    ]
)

# Training images are read from ../../data, with every candidate file passed
# through is_valid_image.
# NOTE(review): the original comment says the filter accepts only files inside
# `pictures_cropped`; the helper lives in utils.dataloaders — confirm there.
dataset = datasets.ImageFolder(
    root="../../data",
    is_valid_file=is_valid_image,
    transform=training_data_transforms,
)
print('Number of classes: ', len(dataset.classes))
print('Number of images: ', len(dataset))

# Shuffled loader with pinned host memory; batches are moved to the GPU in
# the training loop.
dataloader = DataLoader(
    dataset,
    shuffle=True,
    batch_size=BATCH_SIZE,
    pin_memory=True,
    num_workers=4,
)

# The label indices derived from the folder layout must be identical to the
# project-wide class_to_idx mapping, otherwise labels would be misaligned.
if not (dataset.class_to_idx == class_to_idx):
    raise ValueError("Mapping doesn't match")

# Validation frame: a 20% random sample of y_clean_thin.csv joined (inner
# merge on `image_path`) with megadetector_results.csv.
# The sample is seeded so the validation subset -- and therefore val_loss and
# the best-checkpoint selection -- is the same on every Restart & Run All.
VAL_SAMPLE_SEED = 42

df = (
    pd.read_csv('../../y_clean_thin.csv', index_col=0)
    .sample(frac=0.2, random_state=VAL_SAMPLE_SEED)
    .reset_index(drop=True)
)
df_mega = pd.read_csv('../../megadetector_results.csv', index_col=0)
df = df.merge(df_mega, on='image_path')
# Paths in the CSV are made relative to this notebook's directory.
df.image_path = '../../' + df.image_path
# Drop every row with a missing value in any column.
# NOTE(review): presumably this removes images without a detection — confirm
# against the megadetector_results.csv schema.
df = df.dropna()

# Validation data is served in a fixed order (no shuffling).
val_ds = CroppedImageDataset(df)
val_loader = DataLoader(
    val_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=4,
)

### Loop

In [None]:
# Integer class label of every training sample, in dataset order
# (each entry of dataset.samples is a (path, label) pair).
y_train = [sample[1] for sample in dataset.samples]

# 'balanced' weighting from scikit-learn: rarer classes get larger weights.
present_classes = np.unique(y_train)
class_weights = compute_class_weight(class_weight='balanced', classes=present_classes, y=y_train)

# float32 copy on the GPU, used as the per-class weight of the loss.
weights = torch.tensor(class_weights, dtype=torch.float32)
weights = weights.cuda()

In [None]:
# Train the classifier head with mixed precision, evaluate after every epoch,
# and save the weights of the epoch with the lowest validation loss.
#
# NOTE(review): `model`, `feat_extractor` and `FEATURE_NODE` are not defined in
# any cell visible in this notebook; they must be defined or imported before
# this cell for a Restart & Run All to succeed.
number_of_epochs = 20
polish_model = model(feat_extractor, len(class_names)).to("cuda")

# Only the parameters of `polish_model.classifier` are handed to the optimizer.
optimizer = torch.optim.Adam(polish_model.classifier.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss(weight=weights)
scaler = GradScaler()

best_val = float("inf")
best_state = None

for epoch in range(1, number_of_epochs+1):
    # TRAIN
    # NOTE(review): train() switches the whole module to training mode, not
    # just the classifier that is being optimized — confirm this is intended
    # for any normalisation/dropout layers inside the feature extractor.
    train_loss = 0
    polish_model.train()
    for images, labels in dataloader:
        images, labels = images.cuda(), labels.cuda()

        optimizer.zero_grad()

        with autocast('cuda'):
            logits = polish_model(images)
            loss = criterion(logits, labels)

        # Scaled backward pass / optimizer step for mixed-precision training.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Losses are summed over batches (not averaged).
        train_loss += loss.item()

    # TEST
    polish_model.eval()
    val_loss = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.cuda(), labels.cuda()
            with autocast('cuda'):
                logits = polish_model(images)
                loss = criterion(logits, labels)
            val_loss += loss.item()

    print(f"epoch {epoch}: train_loss={train_loss:.2f}  val_loss={val_loss:.2f}")

    if val_loss < best_val:
        best_val = val_loss
        # state_dict() returns references to the live tensors, which later
        # optimizer steps overwrite in place. Clone them so `best_state`
        # really holds this epoch's weights rather than the final ones.
        best_state = {
            name: tensor.detach().clone()
            for name, tensor in polish_model.state_dict().items()
        }

# Bundle the best weights with the metadata stored alongside them.
checkpoint = {
    'state_dict': best_state,
    'class_names': class_names,
    'feature_node': FEATURE_NODE,
    'num_classes': len(class_names)
}
torch.save(checkpoint, "speciesnet_polish_2_checkpoint.pt")