## Console for running classifier training

In [1]:
# Create datasets and dataloaders

from maviratrain.data.classification_dataset import make_training_dataloaders

# Dataset splits used for training and validation.
# For quick smoke tests, swap in the small splits
# "../data/test/classifier2/train/" and "../data/test/classifier2/val/".
train_path = "../data/classifier/classifier354-r2-s2-n2/train/"
val_path = "../data/classifier/classifier354-r2-s2-n2/val/"

# Non-default dataloader parameters, shared by both splits.
# Defaults live in maviratrain.utils.constants.
dataloader_params = dict(batch_size=64, num_workers=4)

# Additional transforms to apply to the data (currently none).
additional_transforms = None

# Build the PyTorch datasets and dataloaders for both splits.
(
    train_dataset,
    train_dataloader,
    val_dataset,
    val_dataloader,
) = make_training_dataloaders(
    train_data_path=train_path,
    train_dataloader_params=dataloader_params,
    val_data_path=val_path,
    val_dataloader_params=dataloader_params,
    additional_transforms=additional_transforms,
)

# Shape of the first element of the first training sample (presumably the
# input image tensor); the model cells use it to size their input layer.
input_dims = list(train_dataset[0][0].shape)

In [2]:
# # Create ViT model

# from maviratrain.models.classifier_model import create_vit

# vit_kwargs = {
#     "image_size": train_dataset[0][0].shape[-1],
#     "patch_size": 16,
#     "num_layers": 6,
#     "num_heads": 6,
#     "hidden_dim": 360,
#     "mlp_dim": 1024,
#     "num_classes": len(train_dataset.classes),
# }
# model = create_vit(vit_kwargs=vit_kwargs)

In [3]:
# # Create EfficientNet model

# from torchvision.models import EfficientNet_B3_Weights

# from maviratrain.models.classifier_model import create_efficientnet_b3

# model = create_efficientnet_b3(
#     weights=EfficientNet_B3_Weights.IMAGENET1K_V1,
#     num_classes=len(train_dataset.classes),
# )

In [4]:
# # Create SimpleModel model

# from maviratrain.models.classifier_model import SimpleModel

# model = SimpleModel(
#     input_dims=input_dims, num_classes=len(train_dataset.classes)
# )

In [5]:
# Create TestModel model

from maviratrain.models.classifier_model import TestModel

# Output size comes from the training dataset's class list; input size comes
# from the sample shape computed in the dataloader cell.
model = TestModel(
    num_classes=len(train_dataset.classes),
    input_dims=input_dims,
)

In [None]:
# Create optimizers and learning rate schedulers

from torch.optim import AdamW, lr_scheduler

# Schedule lengths, in scheduler steps (presumably epochs -- depends on how
# the Trainer steps the schedulers; TODO confirm).
warmup_iters = 10  # length of the linear warm-up
cosine_t_max = 100  # half-period of the cosine decay; matches n_epochs below

# Set up optimizer
optimizer = AdamW(
    params=model.parameters(),
    lr=0.001,
    betas=(0.9, 0.999),
    eps=1e-10,
    weight_decay=0.01,
    fused=True,
)

# Linear warm-up: scales the lr from 0.1x to 1x of the base lr over
# warmup_iters steps of this scheduler.
scheduler1 = lr_scheduler.LinearLR(
    optimizer,
    start_factor=0.1,
    end_factor=1,
    total_iters=warmup_iters,
    last_epoch=-1,
)

# Cosine decay, built with the default last_epoch=-1 (a fresh schedule).
# Passing last_epoch=10 here makes PyTorch treat the scheduler as *resumed*:
# that only works because scheduler1 already stored "initial_lr" in the
# param groups (otherwise it raises KeyError), and in current PyTorch the
# constructor's first internal step writes the closed-form cosine lr for
# step 11 (~0.97x base lr) into the optimizer, overwriting the 0.1x warm-up
# start. Since LinearLR is chainable (it multiplies the current lr), stepping
# it through its 10 iterations would then reach ~10x the intended peak lr.
# With last_epoch=-1 construction leaves the warm-up lr untouched.
# NOTE(review): if the Trainer steps both schedulers every epoch they compose
# multiplicatively; torch.optim.lr_scheduler.SequentialLR is the built-in way
# to hand over from warm-up to cosine at a fixed milestone instead.
scheduler2 = lr_scheduler.CosineAnnealingLR(optimizer, T_max=cosine_t_max)

optimization = [optimizer, scheduler1, scheduler2]

In [18]:
# Create loss function

from torch.nn import CrossEntropyLoss

# Unweighted cross-entropy without label smoothing, summed over the batch.
# NOTE(review): reduction="sum" makes the loss value scale with batch size
# (and a smaller final batch, if any, contributes less); confirm the Trainer
# normalizes what it reports, or consider the default reduction="mean".
loss_fn = CrossEntropyLoss(reduction="sum", weight=None, label_smoothing=0)

In [19]:
# Set up Trainer

from maviratrain.train.train_classifier import Trainer

# Hand the Trainer the loss function, the [optimizer, *schedulers] list, and
# the dataloaders (train first, then validation).
trainer = Trainer(
    loss_fn=loss_fn,
    optimization=optimization,  # type: ignore
    loaders=[train_dataloader, val_dataloader],  # type: ignore
)

In [21]:
# Train model

# Number of epochs to train for; keep in sync with the scheduler lengths
# configured in the optimizer cell.
n_epochs = 100

# train() returns three values: the model and, judging by the names, the
# number of steps and epochs actually trained.
model, n_steps_trained, n_epochs_trained = trainer.train(
    n_epochs=n_epochs, model=model
)

  Val Epoch: 0        Loss: 14.8320        Accuracy: 0.26        Top-5 Accuracy: 1.44
Train Epoch: 1        Loss: 236.5943        Accuracy: 0.88        Top-5 Accuracy: 3.47        Time: 247.84        
  Val Epoch: 1        Loss: 256.0670        Accuracy: 1.13        Top-5 Accuracy: 4.04
Train Epoch: 2        Loss: 420.0707        Accuracy: 1.74        Top-5 Accuracy: 5.92        Time: 207.13        
  Val Epoch: 2        Loss: 450.9185        Accuracy: 1.19        Top-5 Accuracy: 4.01
Train Epoch: 3        Loss: 592.8663        Accuracy: 2.70        Top-5 Accuracy: 8.42        Time: 209.86        
  Val Epoch: 3        Loss: 726.8832        Accuracy: 0.93        Top-5 Accuracy: 3.50
Train Epoch: 4        Loss: 753.7261        Accuracy: 3.66        Top-5 Accuracy: 10.60        Time: 80.65        
  Val Epoch: 4        Loss: 979.4936        Accuracy: 0.91        Top-5 Accuracy: 3.47
Train Epoch: 5        Loss: 908.3697        Accuracy: 4.63        Top-5 Accuracy: 12.66        Time: 86.55