In [1]:
import torch
from torchvision import transforms

from aim_perception.training import Trainer
from aim_perception.models import ResNet, ModelWrapper, MultiClassMlpHead
from aim_perception.loading import AimDatasetConstructor

  from pandas.core.computation.check import NUMEXPR_INSTALLED


In [2]:
import os

# Dataset location.
# The root directory is machine-specific, so it can be overridden through the
# AIM_DATASET_ROOT environment variable; when that variable is unset or empty
# the original default path is used unchanged.
abs_path = os.environ.get('AIM_DATASET_ROOT') or '/home/ubuntu/aim/vehicle_dataset'
# Handed to AimDatasetConstructor as `data_subdir` and `csv_path` respectively.
# NOTE(review): presumably both are resolved relative to `abs_path` - confirm in the loader.
image_path = 'data'
label_path = 'ground_truth.csv'

In [3]:

# Per-image preprocessing, listed in the order it is handed to the constructor:
# convert to a tensor, resize to a fixed square, then normalise each channel.
image_size = (96, 96)
# NOTE(review): presumably per-channel statistics of this dataset - confirm their origin.
channel_mean = [0.4886, 0.4855, 0.4838]
channel_std = [0.2456, 0.2443, 0.2490]
preprocessing = [
    transforms.ToTensor(),
    transforms.Resize(size=image_size),
    transforms.Normalize(mean=channel_mean, std=channel_std),
]

# Build the dataset constructor from the paths configured above.
# The cell output shows train/val/test percentages printed at this point.
dataset_constructor = AimDatasetConstructor(
    root_dir=abs_path,
    data_subdir=image_path,
    csv_path=label_path,
    transforms=preprocessing,
)

Train percent: 75.00150253026095
Val percent: 14.99885807700167
Test percent: 9.99963939273737


In [4]:
# Fetch every split from the constructor, then unpack into the three names
# used by the rest of the notebook.
all_splits = dataset_constructor.get_all_datasets()
train, val, test = all_splits

In [5]:
import torch
from torch import nn
# NOTE(review): resnet18 / ResNet18_Weights are not used in the visible cells;
# kept in case a later cell relies on them.
from torchvision.models import resnet18, ResNet18_Weights

# ---- Hyperparameters ----
epochs = 16
batch_size = 128
momentum = 0.9
learning_rate = 1e-1
weight_decay = 1e-5
seed = 0

# Fall back to the CPU when no GPU is visible instead of failing on a
# hard-coded 'cuda' string; on a GPU machine this resolves to 'cuda' as before.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed torch's global RNG so weight initialisation and the training shuffle
# order are repeatable from this cell onward. This does not cover anything
# that ran before this cell (e.g. how the dataset splits were drawn).
torch.manual_seed(seed)

# Create loaders
train_loader = torch.utils.data.DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=2)
# Validation is evaluation only: keep a fixed order (no shuffling) and use a
# larger batch than training.
val_loader = torch.utils.data.DataLoader(val, batch_size=batch_size*2, shuffle=False, num_workers=2)

# Create model
backbone = ResNet.resnet_18(in_channels=3, depthwise_separable=False)
# NOTE(review): input_size=512 presumably matches the backbone's output feature
# width - confirm against ResNet.resnet_18. num_targets=10 agrees with the ten
# classes (0-9) listed in the classification report at the end of the notebook.
head = MultiClassMlpHead(
    input_size=512,
    inner_dim=128,
    num_targets=10,
    bias=True,
    dropout=0.05,
    norm=nn.BatchNorm1d
)
model = ModelWrapper(backbone=backbone, head=head)

# Optimizer and loss
# Class weights come from the training split and must live on the same device
# as the logits the loss will receive.
criterion = torch.nn.CrossEntropyLoss(weight=train.get_class_weights().to(device))
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=weight_decay, momentum=momentum)
# StepLR multiplies the learning rate by `gamma` every `step_size` calls to
# scheduler.step(). NOTE(review): presumably the Trainer steps it once per
# epoch - confirm.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.1)



In [6]:
# Configure the training run with the loss, optimizer and scheduler built in
# the previous cell. The log printed by this cell reports validation metrics
# at iterations 101, 201, 301, ... of each epoch.
trainer = Trainer(
    epochs=epochs,
    validate_every=100,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    wandb_project='test',
)

# Run training; the returned object exposes `classification_report`, which a
# later cell prints.
# NOTE(review): the name `eval` shadows the Python builtin. It is kept here
# because that later cell reads `eval.classification_report`.
eval = trainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Currently logged in as: fagundo. Use `wandb login --relogin` to force relogin


Moving model to cuda...
[1,   101] train loss: 2.240 | val loss: 2.138 | val bal-acc: 0.330
[1,   201] train loss: 2.086 | val loss: 2.080 | val bal-acc: 0.388
[1,   301] train loss: 2.029 | val loss: 2.042 | val bal-acc: 0.435
[1,   401] train loss: 2.026 | val loss: 1.974 | val bal-acc: 0.499
-------- Finished Epoch 1 --------
[2,   101] train loss: 1.966 | val loss: 1.952 | val bal-acc: 0.522
[2,   201] train loss: 1.963 | val loss: 1.941 | val bal-acc: 0.530
[2,   301] train loss: 1.917 | val loss: 1.913 | val bal-acc: 0.556
[2,   401] train loss: 1.874 | val loss: 1.919 | val bal-acc: 0.551
-------- Finished Epoch 2 --------
[3,   101] train loss: 1.860 | val loss: 1.896 | val bal-acc: 0.573
[3,   201] train loss: 1.858 | val loss: 1.879 | val bal-acc: 0.590
[3,   301] train loss: 1.831 | val loss: 1.844 | val bal-acc: 0.627
[3,   401] train loss: 1.892 | val loss: 1.915 | val bal-acc: 0.552
-------- Finished Epoch 3 --------
[4,   101] train loss: 1.841 | val loss: 1.847 | val ba

In [8]:
# Show the per-class precision / recall / F1 table stored on the training result.
# print() is used so the multi-line table keeps its line breaks.
report = eval.classification_report
print(report)

              precision    recall  f1-score   support

           0       0.79      0.79      0.79      1267
           1       0.68      0.80      0.74       277
           2       0.89      0.92      0.90      1192
           3       0.88      0.74      0.80      4548
           4       0.57      0.81      0.67       245
           5       0.32      0.58      0.42       212
           6       0.76      0.84      0.80       748
           7       0.75      0.76      0.75      2208
           8       0.55      0.59      0.57       638
           9       0.66      0.77      0.71      1143

    accuracy                           0.77     12478
   macro avg       0.68      0.76      0.71     12478
weighted avg       0.78      0.77      0.77     12478

