In [1]:
import os
import random

import torch
from torchvision import transforms

from aim_perception.training import Trainer
from aim_perception.models import ResNet, ModelWrapper, MultiClassMlpHead
from aim_perception.loading import AimDatasetConstructor

  from pandas.core.computation.check import NUMEXPR_INSTALLED


In [2]:
# Dataset location. Set the AIM_DATASET_DIR environment variable to point at a
# different checkout instead of editing this hard-coded default.
abs_path = os.environ.get('AIM_DATASET_DIR', '/home/ubuntu/aim/vehicle_dataset')
# Passed to the dataset constructor alongside root_dir=abs_path; presumably
# resolved relative to it — confirm against AimDatasetConstructor.
image_path = 'data'
label_path = 'ground_truth.csv'

# Seed Python and torch RNGs before any dataset/model construction so that loader
# shuffling, weight initialisation and (if it is random) the train/val/test split
# are reproducible under Restart & Run All.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)  # seeds the RNG on all devices, including CUDA

In [3]:

# Preprocessing pipeline applied to every sample: tensor conversion, resize, then
# per-channel normalisation. The mean/std triples are presumably statistics of
# this vehicle dataset — confirm before reusing them elsewhere.
IMAGE_SIZE = (96, 96)
CHANNEL_MEAN = [0.4886, 0.4855, 0.4838]
CHANNEL_STD = [0.2456, 0.2443, 0.2490]

preprocessing = [
    transforms.ToTensor(),
    transforms.Resize(size=IMAGE_SIZE),
    transforms.Normalize(mean=CHANNEL_MEAN, std=CHANNEL_STD),
]

# Building the constructor prints the train/val/test split percentages.
dataset_constructor = AimDatasetConstructor(
    root_dir=abs_path,
    data_subdir=image_path,
    csv_path=label_path,
    transforms=preprocessing,
)

Train percent: 75.00150253026095
Val percent: 14.99885807700167
Test percent: 9.99963939273737


In [4]:
# Fetch the three splits from the constructor, unpacked in (train, val, test) order.
split_datasets = dataset_constructor.get_all_datasets()
train, val, test = split_datasets

In [5]:
import torch
from torch import nn
from torchvision.models import resnet18, ResNet18_Weights

# ---- Hyperparameters ----
epochs = 16
batch_size = 128
momentum = 0.9
learning_rate = 1e-1
weight_decay = 1e-5
lr_step_size = 4   # StepLR: multiply the LR by lr_gamma every lr_step_size scheduler steps
lr_gamma = 0.1
num_workers = 2

# Device for tensors created in this cell (the loss weights). Falls back to CPU
# instead of crashing when CUDA is unavailable; the Trainer moves the model itself.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# ---- Data loaders ----
train_loader = torch.utils.data.DataLoader(
    train, batch_size=batch_size, shuffle=True, num_workers=num_workers
)
# Validation uses twice the training batch size.
# NOTE(review): shuffling validation only matters if the Trainer evaluates on a
# subset of batches; if it always consumes the full loader, shuffle=False is
# cheaper and deterministic — confirm before changing.
val_loader = torch.utils.data.DataLoader(
    val, batch_size=batch_size * 2, shuffle=True, num_workers=num_workers
)

# ---- Model ----
# CrossEntropyLoss requires one weight per class, so the same vector sizes the
# classifier head; hard-coding the class count separately could silently drift.
class_weights = train.get_class_weights()
num_classes = len(class_weights)

backbone = ResNet.resnet_18(in_channels=3, depthwise_separable=False)
head = MultiClassMlpHead(
    input_size=512,  # presumably the ResNet-18 feature width — confirm against the backbone
    inner_dim=128,
    num_targets=num_classes,
    bias=True,
    dropout=0.05,
    norm=nn.BatchNorm1d,
)
model = ModelWrapper(backbone=backbone, head=head)

# ---- Optimisation ----
# Class-weighted loss to counter the imbalanced label distribution.
criterion = torch.nn.CrossEntropyLoss(weight=class_weights.to(device))
optimizer = torch.optim.SGD(
    model.parameters(), lr=learning_rate, weight_decay=weight_decay, momentum=momentum
)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_step_size, gamma=lr_gamma)



In [6]:
# Train the model; the log below shows a validation pass roughly every 100
# training iterations. Runs are reported to the W&B project named 'test'.
trainer = Trainer(
    epochs=epochs,
    validate_every=100,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    wandb_project='test',
)
# NOTE(review): `eval` shadows the Python builtin; the name is kept because the
# classification-report cell reads it.
eval = trainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Currently logged in as: fagundo. Use `wandb login --relogin` to force relogin


Moving model to cuda...
[1,   101] train loss: 2.212 | val loss: 2.083 | val bal-acc: 0.393
[1,   201] train loss: 2.020 | val loss: 2.043 | val bal-acc: 0.427
[1,   301] train loss: 1.943 | val loss: 1.949 | val bal-acc: 0.525
[1,   401] train loss: 1.933 | val loss: 1.908 | val bal-acc: 0.568
-------- Finished Epoch 1 --------
[2,   101] train loss: 1.916 | val loss: 1.882 | val bal-acc: 0.590
[2,   201] train loss: 1.844 | val loss: 1.877 | val bal-acc: 0.592
[2,   301] train loss: 1.858 | val loss: 1.858 | val bal-acc: 0.614
[2,   401] train loss: 1.809 | val loss: 1.851 | val bal-acc: 0.621
-------- Finished Epoch 2 --------
[3,   101] train loss: 1.808 | val loss: 1.828 | val bal-acc: 0.645
[3,   201] train loss: 1.803 | val loss: 1.816 | val bal-acc: 0.658
[3,   301] train loss: 1.797 | val loss: 1.816 | val bal-acc: 0.656
[3,   401] train loss: 1.792 | val loss: 1.801 | val bal-acc: 0.668
-------- Finished Epoch 3 --------
[4,   101] train loss: 1.753 | val loss: 1.781 | val ba

In [7]:
# The report is a pre-formatted text table, so print it rather than showing its repr.
report_text = eval.classification_report
print(report_text)

              precision    recall  f1-score   support

           0       0.80      0.79      0.80      1267
           1       0.70      0.83      0.76       277
           2       0.88      0.92      0.90      1192
           3       0.89      0.77      0.83      4548
           4       0.68      0.85      0.75       245
           5       0.42      0.60      0.49       212
           6       0.82      0.88      0.85       748
           7       0.74      0.78      0.76      2208
           8       0.56      0.61      0.59       638
           9       0.69      0.77      0.73      1143

    accuracy                           0.79     12478
   macro avg       0.72      0.78      0.75     12478
weighted avg       0.80      0.79      0.79     12478

