In [2]:
import os
import random

import torch
from torchvision import transforms

from aim_perception.training import Trainer
from aim_perception.models import ResNet, ModelWrapper, MultiClassMlpHead
from aim_perception.loading import AimDatasetConstructor


In [3]:
# Dataset location. Set the VEHICLE_DATASET_DIR environment variable to point at
# the dataset on another machine instead of editing this machine-specific default.
abs_path = os.environ.get('VEHICLE_DATASET_DIR', '/Users/mjfagundo/Documents/vehicle_dataset')
image_path = 'data'              # passed as `data_subdir` (presumably relative to abs_path)
label_path = 'ground_truth.csv'  # passed as `csv_path` (presumably relative to abs_path)

# Seed the RNGs before any data splitting, weight initialisation or loader
# shuffling happens so a fresh Restart & Run All is reproducible.
# NOTE(review): if AimDatasetConstructor draws from numpy's RNG for the split,
# seed numpy as well.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

In [4]:

# Per-channel normalization statistics (one value per input channel).
# NOTE(review): presumably computed over the training images — record provenance.
CHANNEL_MEAN = [0.4886, 0.4855, 0.4838]
CHANNEL_STD = [0.2456, 0.2443, 0.2490]

dataset_constructor = AimDatasetConstructor(
    root_dir=abs_path,
    csv_path=label_path,
    data_subdir=image_path,
    transforms=[
        transforms.ToTensor(),
        # Resize operates on the tensor produced by ToTensor. For tensors,
        # torchvision < 0.17 does not antialias by default (aliasing when
        # downsampling), so request it explicitly.
        transforms.Resize(size=(64, 64), antialias=True),
        transforms.Normalize(mean=CHANNEL_MEAN, std=CHANNEL_STD),
    ],
)

Train percent: 75.00150253026095
Val percent: 14.99885807700167
Test percent: 9.99963939273737


In [5]:
# Materialize the train / validation / test partitions built by the constructor.
dataset_splits = dataset_constructor.get_all_datasets()
train, val, test = dataset_splits

In [6]:
from torch import nn
# NOTE(review): resnet18 / ResNet18_Weights are not used in the visible cells;
# kept in case later cells rely on them.
from torchvision.models import resnet18, ResNet18_Weights

# Hyperparameters
epochs = 6
batch_size = 128
momentum = 0.9
learning_rate = 1e-2
weight_decay = 1e-5
num_workers = 2

# Create loaders.
# Validation uses twice the training batch size.
# NOTE(review): presumably safe because evaluation keeps no activations for
# backprop — confirm it fits in memory. Shuffling the validation set only
# matters if the Trainer evaluates a subset of its batches.
train_loader = torch.utils.data.DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
val_loader = torch.utils.data.DataLoader(val, batch_size=batch_size * 2, shuffle=True, num_workers=num_workers)

# CrossEntropyLoss requires its `weight` tensor to hold one entry per class,
# so its length is the class count. Deriving the head's output size from it
# keeps the two from ever disagreeing (and computes the weights only once).
class_weights = train.get_class_weights()
num_classes = len(class_weights)

# Create model
backbone = ResNet.resnet_18(in_channels=3, depthwise_separable=True)
head = MultiClassMlpHead(
    input_size=512,  # NOTE(review): assumed to equal the backbone's output feature width — confirm
    inner_dim=128,
    num_targets=num_classes,
    bias=True,
    dropout=0.05,
    norm=nn.BatchNorm1d,
)
model = ModelWrapper(backbone=backbone, head=head)

# Optimizer and loss
criterion = torch.nn.CrossEntropyLoss(weight=class_weights)
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=weight_decay, momentum=momentum)
# Multiply the learning rate by 0.1 every 2 scheduler steps.
# NOTE(review): a step is an epoch only if the Trainer calls scheduler.step() once per epoch — confirm.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1)



In [7]:
# Run training with the hyperparameters from the setup cell. `epochs` is passed
# through rather than written as a literal so the configured run length is the
# one that actually trains.
trainer = Trainer(
    epochs=epochs,
    validate_every=100,  # NOTE(review): unit (steps vs. batches) is defined by Trainer — confirm
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    wandb_project='test',
)
trainer(model=model, train_loader=train_loader, val_loader=val_loader)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mfagundo[0m. Use [1m`wandb login --relogin`[0m to force relogin
