In [None]:
#load packages
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.optim import SGD
from torch import tensor
from torch import max as torchmax
from torch import save as torchsave
from torch.cuda.amp import GradScaler, autocast
from torch.cuda import is_available
import torch.nn as nn
import torchvision.transforms as transforms


import pandas as pd

import os

from PIL import Image
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

from timm.models import create_model

from sklearn.metrics import f1_score

import time

from livelossplot import PlotLosses


In [None]:
# --- Input locations: image folders and the metadata CSV files ---
TRAIN_DATA_DIR = "/dir/to/traindata/"
TEST_DATA_DIR = "/dir/to/testdata/"
TRAIN_LABELS_DIR = os.path.abspath("/dir/to/trainmetadata/SnakeCLEF2022-TrainMetadata.csv")
TEST_LABELS_DIR = os.path.abspath("/dir/to/testmetadata/SnakeCLEF2022-TestMetadata.csv")

# --- Output location: folder that receives the saved model weights ---
MODEL_DIR = os.path.join("/dir/to/save/models/")

# --- Training hyper-parameters ---
BATCH_SIZE = 32
NUM_EPOCHS = 30
IMAGE_SIZE = 384
learning_rate = 0.1
# NOTE(review): not referenced by any of the cells visible here.
threshold_early_stopping = 8

# --- Training metadata table, read from the CSV configured above ---
trainingDataset = pd.read_csv(TRAIN_LABELS_DIR)

In [None]:
# Prefix every relative file path with the training image root so each row
# carries the full location of its image.
trainingDataset = trainingDataset.assign(
    image_path=TRAIN_DATA_DIR + trainingDataset.file_path
)


In [None]:
# Row count of the training metadata table.
NUM_TRAINING_SAMPLES = len(trainingDataset)

In [None]:
class SnakeTrainDataset(Dataset):
    """Map-style dataset yielding ``(image, class id)`` pairs from a metadata table.

    Args:
        data: Table holding the samples. Rows are fetched with
            ``data.iloc[index]`` and must expose ``image_path`` and
            ``class_id``; the dataset length is ``data.shape[0]``.
        transform: Optional callable applied to the RGB image before it is
            returned. ``None`` returns the converted PIL image unchanged.
    """

    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the metadata table."""
        return self.data.shape[0]

    def __getitem__(self, index):
        """Load one sample by positional row index.

        Returns:
            tuple: ``(img, y_label)`` where ``img`` is the (optionally
            transformed) RGB image and ``y_label`` is the row's ``class_id``
            wrapped in a tensor.
        """
        row = self.data.iloc[index]
        # Open inside a context manager so the underlying file handle is
        # released deterministically instead of waiting for garbage
        # collection. ``convert`` returns an independent image, so it stays
        # usable after the source file has been closed.
        with Image.open(row.image_path) as source_image:
            img = source_image.convert("RGB")
        y_label = tensor(row.class_id)

        if self.transform is not None:
            img = self.transform(img)

        return (img, y_label)

In [None]:
# Pretrained EfficientNetV2-M backbone from timm with a classification head
# sized to the number of distinct class ids in the training metadata.
# NOTE(review): this presumes class ids are contiguous 0..K-1 - confirm against the metadata.
model = create_model(
    "tf_efficientnetv2_m_in21k",
    pretrained=True,
    num_classes=len(trainingDataset.class_id.unique()),
    drop_rate=0.3,
)

In [None]:
# Pick the GPU when torch reports CUDA as available, otherwise fall back to the CPU.
if is_available():
    device = "cuda"
else:
    device = "cpu"
# Move the model's parameters to the selected device.
model.to(device)

In [None]:
# Training-time augmentation pipeline: resize to a square 10% larger than the
# target size, take a random IMAGE_SIZE crop, flip randomly along both axes,
# then convert to a tensor and normalise each channel with mean 0.5 / std 0.5.
transform = transforms.Compose([
    transforms.Resize(size=(int(IMAGE_SIZE + IMAGE_SIZE * 0.1),) * 2),
    transforms.RandomCrop(size=(IMAGE_SIZE,) * 2),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
])

# Wrap the metadata table so every row is served as an (image, label) pair.
datasetTrain = SnakeTrainDataset(data=trainingDataset, transform=transform)


In [None]:
# Shuffled mini-batch loader over the training dataset, fed by 4 worker processes.
train_loader = DataLoader(
    datasetTrain,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
)

# Number of full mini-batches in one epoch (a trailing partial batch is not counted).
EPOCH_LENGTH = len(datasetTrain) // BATCH_SIZE

In [None]:
# Stochastic gradient descent with momentum over all model parameters.
optimizer = SGD(
    params=model.parameters(),
    lr=learning_rate,
    momentum=0.9,
)

In [None]:
# Cross-entropy loss; the training loop calls it as criterion(outputs, labels).
criterion = nn.CrossEntropyLoss()


In [None]:
# Make sure the checkpoint folder exists. `exist_ok=True` replaces the separate
# exists-check, so there is no window between checking and creating, and
# re-running the cell is harmless.
os.makedirs(MODEL_DIR, exist_ok=True)

In [None]:
# Number of mini-batches per epoch; the training loop uses it to turn the batch
# index into a fractional epoch for the learning-rate scheduler.
iters = len(train_loader)

In [None]:
# ---------------------------------------------------------------------------
# Training loop: mixed-precision SGD with gradient clipping and a cosine
# learning-rate schedule with warm restarts. After every epoch the training
# loss and macro F1 are plotted and the model weights are saved.
# ---------------------------------------------------------------------------

# Mixed precision is only used when CUDA is available, so the loop also runs
# on the CPU fallback chosen for `device`.
use_amp = is_available()
scaler = GradScaler(enabled=use_amp)
liveloss = PlotLosses()
# First cosine cycle lasts 5 epochs; every following cycle is twice as long.
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=5, T_mult=2)

# `epoch` is 1-based (used for checkpoint names). Iterating over
# 1..NUM_EPOCHS runs exactly NUM_EPOCHS epochs; the previous
# `while epoch <= NUM_EPOCHS` with a leading increment ran one epoch too many.
epoch = 0
for epoch in range(1, NUM_EPOCHS + 1):
    # values sent to the live plot for this epoch
    logs = {}
    # wall-clock start of the epoch
    start = time.time()
    preds_epoch = []
    label_epoch = []
    running_loss = 0.0
    samples_seen = 0
    model.train()

    # iterate over all mini-batches of the training set
    for i, (inputs, labels) in enumerate(train_loader):
        # Move the batch to the device the model lives on instead of
        # hard-coding CUDA.
        inputs = inputs.to(device)
        labels = labels.to(device)

        # forward pass under autocast (mixed precision when enabled)
        with autocast(enabled=use_amp):
            outputs = model(inputs)
            loss = criterion(outputs, labels)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        # Gradients must be unscaled before clipping so the norm threshold
        # applies to their true magnitude.
        scaler.unscale_(optimizer)
        nn.utils.clip_grad_norm_(model.parameters(), 1)
        scaler.step(optimizer)
        scaler.update()

        # The scheduler expects a 0-based fractional epoch. Using the 1-based
        # counter would start the schedule one epoch into the first cycle.
        scheduler.step((epoch - 1) + i / iters)

        _, preds = torchmax(outputs, 1)
        preds_epoch.extend(preds.detach().cpu().tolist())
        label_epoch.extend(labels.detach().cpu().tolist())

        # The criterion returns the batch mean; weight it by the batch size so
        # a smaller final batch is accounted for correctly.
        batch_size = inputs.size(0)
        running_loss += loss.detach() * batch_size
        samples_seen += batch_size

    # average training loss over the samples actually processed this epoch
    epoch_loss = running_loss / samples_seen
    logs['log loss'] = epoch_loss.item()
    # Arguments passed by keyword in the order documented by scikit-learn
    # (ground truth first). Macro F1 is symmetric, so the value is unchanged.
    logs['f1'] = f1_score(y_true=label_epoch, y_pred=preds_epoch, average="macro")

    # update plot
    liveloss.update(logs)
    liveloss.send()

    # Save the weights of this epoch. os.path.join keeps the path correct
    # whether or not MODEL_DIR ends with a separator.
    torchsave(model.state_dict(), os.path.join(MODEL_DIR, "model_" + str(epoch) + ".pth"))
    end = time.time()
    print('{:5.3f}s'.format(end-start))
