In [43]:
from datetime import datetime
from torch.utils.data import DataLoader, Dataset
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torchvision.transforms.functional as TF

from PIL import Image
import os
import torch
import random
from torchvision import transforms as T
from src.core import models, metrics, training, training_dino, data, loss_functions
from src.dev import experiments as experiments

# Dataset roots. Both end with a path separator so they can be prefixed
# directly onto a relative image path.
DATA_DIR = '../SnakeCLEF2023-medium_size-train/'
VAL_DIR = '../SnakeCLEF2023-medium_size-val/'

# Seed every random number generator used in this notebook for reproducibility.
seed = 42
random.seed(seed)
# NOTE: PYTHONHASHSEED is read when an interpreter starts, so setting it here
# only affects processes launched afterwards, not the already-running kernel.
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# Request deterministic cuDNN kernels and switch off the cuDNN auto-tuner,
# which may otherwise pick different algorithms from one run to the next.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {device}')

Device: cpu


In [44]:
# Load the train / validation metadata tables.
train_df = pd.read_csv('../snake_csv_files/snakeCLEF2023_bbox_cleaned_train_metadata.csv')
valid_df = pd.read_csv('../snake_csv_files/SnakeCLEF2023-cleaned-metadata-val.csv')

# The class count is the number of distinct values in the training set's
# 'binomial' column (np.unique returns them sorted and de-duplicated).
classes = np.unique(train_df['binomial'])
no_classes = len(classes)

print(f'No. of classes: {no_classes}')
print(f'Train set length: {len(train_df):,d}')
print(f'Validation set length: {len(valid_df):,d}')

No. of classes: 1784
Train set length: 120,550
Validation set length: 10,985


In [46]:
class CFG:
    """Experiment configuration read by the cells that build the data loaders and trainer."""

    # Taken once so that the checkpoint name and the history file name always
    # carry the same timestamp (two separate now() calls can straddle a second).
    _timestamp = datetime.now().strftime("%Y_%m_%d-%I_%M_%S_%p")

    N_CLASS = no_classes  # output size of the classification head
    model_name = 'clef2023_dinov2_focal_' + _timestamp
    history_file = model_name + '.csv'
    data = 'clef2023'
    model = 'dino_v2'
    batch_size = 2  # samples per DataLoader batch
    no_epochs = 5
    total_batch_size = 2  # divided by batch_size to obtain the accumulation steps
    loss = 'efocal'  # key into loss_functions.LOSSES
    optimizer = 'sgd'  # key into training.OPTIMIZERS
    learning_rate = 0.01
    scheduler = 'reduce_lr_on_plateau'  # key into training.SCHEDULERS
    shuffle = True


In [47]:
class LinearClassifier(torch.nn.Module):
    """Two-layer classification head: Linear -> ReLU -> Linear.

    Args:
        in_channels: size of the last dimension of the input features.
        classes: number of output logits.
        hidden_dim: width of the hidden layer. Defaults to 2048, the value
            that used to be hard-coded, so existing callers are unaffected.
    """

    def __init__(self, in_channels, classes, hidden_dim=2048):
        super().__init__()
        self.in_channels = in_channels
        self.classes = classes
        self.hidden_dim = hidden_dim
        # The Sequential keeps its original name and layer order so that the
        # state_dict keys (fc.0.*, fc.2.*) stay the same.
        self.fc = torch.nn.Sequential(
            torch.nn.Linear(in_channels, hidden_dim),
            torch.nn.ReLU(inplace=True),
            torch.nn.Linear(hidden_dim, classes),
        )

    def forward(self, x):
        """Return the logits produced by the head for the feature tensor ``x``."""
        return self.fc(x)

In [48]:
# Classification head over 1536-dimensional inputs, placed on the selected device.
model = LinearClassifier(in_channels=1536, classes=CFG.N_CLASS)
model = model.to(device)

In [49]:
# Patch-grid height/width (multiplied by 14 elsewhere to get the image size)
# and the feature width noted for the vitg14 model.
patch_h = patch_w = 40
feat_dim = 1536  # vitg14

# Fetch the DINOv2 ViT-g/14 model through torch.hub, then move it to the device.
dinov2_vitg14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14')
dinov2_vitg14 = dinov2_vitg14.to(device)

Using cache found in /Users/aartibalana/.cache/torch/hub/facebookresearch_dinov2_main


In [55]:
class SnakeTrainDataset(Dataset):
    """Image / label dataset backed by a metadata DataFrame.

    Every row of ``data`` must provide ``image_path`` and ``class_id``; when
    bounding-box cropping is enabled it must also provide ``xmin``, ``ymin``,
    ``xmax`` and ``ymax``.

    Args:
        root: directory that the ``image_path`` values are relative to. A
            trailing path separator is optional.
        data: pandas DataFrame with one row per image.
        bbox: any value other than ``False`` crops each image to its
            bounding box before the transform is applied.
        transform: optional callable applied to the PIL image.
    """

    _BBOX_COLUMNS = ('xmin', 'ymin', 'xmax', 'ymax')

    def __init__(self, root, data, bbox=False, transform=None):
        self.root = root
        self.data = data
        self.transform = transform
        self.bbox = bbox

    def __len__(self):
        """Number of rows in the metadata table."""
        return self.data.shape[0]

    def _image_path(self, image_path):
        """Join ``root`` and ``image_path`` regardless of how the separator is split between them."""
        if not self.root:
            return str(image_path)
        # Drop leading separators so os.path.join does not discard the root.
        return os.path.join(self.root, str(image_path).lstrip('/\\'))

    def __getitem__(self, index):
        """Return ``(image, label)`` for the row at position ``index``."""
        row = self.data.iloc[index]

        # convert() loads the pixel data into a new image, so the underlying
        # file can be closed as soon as the with-block exits.
        with Image.open(self._image_path(row.image_path)) as source:
            img = source.convert("RGB")
        img = TF.adjust_sharpness(img, 20.0)

        if self.bbox is not False:
            # PIL expects the box as (left, upper, right, lower).
            img = img.crop(tuple(row[column] for column in self._BBOX_COLUMNS))

        label = torch.tensor(row.class_id)

        if self.transform is not None:
            img = self.transform(img)

        return (img, label)

In [56]:
# DINO transforms, datasets and data loaders

# Input resolution: patch_h x patch_w patches, each presumably 14 px wide
# (the factor 14 matches the vitg14 model name).
IMAGE_SIZE = (patch_h * 14, patch_w * 14)
# Number of leading rows of each metadata table to use; None uses every row.
SUBSET_SIZE = 50


def _resize_and_normalize():
    """Return fresh instances of the steps shared by the train and validation pipelines."""
    return [
        T.Resize(IMAGE_SIZE),
        T.CenterCrop(IMAGE_SIZE),
        T.ToTensor(),
        T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ]


# Training: Gaussian blur with a sigma sampled from (0.1, 2.0) on every call,
# followed by the shared steps.
transform_train = T.Compose([T.GaussianBlur(9, sigma=(0.1, 2.0)), *_resize_and_normalize()])
# Validation: only the shared steps, so the same image always yields the same
# tensor and validation metrics are comparable between epochs.
transform_valid = T.Compose(_resize_and_normalize())

# Pass transform=None to get the untransformed PIL images instead.
train_dataset = SnakeTrainDataset(DATA_DIR, train_df.iloc[:SUBSET_SIZE], bbox=True, transform=transform_train)
valid_dataset = SnakeTrainDataset(VAL_DIR, valid_df.iloc[:SUBSET_SIZE], bbox=True, transform=transform_valid)

train_dataloader = DataLoader(train_dataset, batch_size=CFG.batch_size, shuffle=CFG.shuffle)
# Validation order does not need to be randomised.
valid_dataloader = DataLoader(valid_dataset, batch_size=CFG.batch_size, shuffle=False)



In [57]:
# Look up the configured loss, optimizer and scheduler entries in the
# project's registries, keyed by the names stored on CFG.
loss_fn, opt_fn, sched_fn = (
    loss_functions.LOSSES[CFG.loss],
    training.OPTIMIZERS[CFG.optimizer],
    training.SCHEDULERS[CFG.scheduler],
)

In [58]:
# Build the loss instance and the trainer.

criterion = loss_fn()

# Keyword options are collected first to keep the constructor call short.
trainer_options = dict(
    validloader=valid_dataloader,
    # Ratio of the total batch size to the per-step batch size.
    accumulation_steps=CFG.total_batch_size // CFG.batch_size,
    path='.',
    model_filename=CFG.model_name,
    history_filename=CFG.history_file,
    device=device,
)

trainer = training_dino.Trainer(
    dinov2_vitg14, model, train_dataloader, criterion, opt_fn, sched_fn,
    **trainer_options,
)


In [None]:
# Run training for the configured number of epochs at the configured learning rate.
trainer.train(
    no_epochs=CFG.no_epochs,
    lr=CFG.learning_rate,
)