## Console for running classifier training

In [1]:
# Build the training/validation datasets and their dataloaders

from maviratrain.data.classification_dataset import make_training_dataloaders

# Dataset locations: the training split, then the validation split.
# The commented-out lines are an alternative dataset location.
train_path = "../data/classifier/classifier354-r2-s2-n2/train/"
# train_path = "/mnt/disks/localssd/data/classifier362-r2-s3-n4/train/"
val_path = "../data/classifier/classifier354-r2-s2-n2/val/"
# val_path = "/mnt/disks/localssd/data/classifier362-r2-s3-n4/val/"

# Non-default dataloader parameters (the defaults are found in
# maviratrain.utils.constants); the same dict is passed for both splits
dataloader_params = dict(batch_size=64, num_workers=3)

# Additional transforms to apply to the data; none are specified here
additional_transforms = None

# Create the PyTorch datasets/dataloaders for both splits in one call
(
    train_dataset,
    train_dataloader,
    val_dataset,
    val_dataloader,
) = make_training_dataloaders(
    train_data_path=train_path,
    val_data_path=val_path,
    train_dataloader_params=dataloader_params,
    val_dataloader_params=dataloader_params,
    additional_transforms=additional_transforms,
)

# Shape of the first element of the first training sample, as a plain list
input_dims = [*train_dataset[0][0].shape]

In [2]:
# # Create ViT model (disabled alternative to the EfficientNet cell; uncomment this cell to use it)

# from maviratrain.models.classifier_model import create_vit

# vit_kwargs = {
#     "image_size": train_dataset[0][0].shape[-1],
#     "patch_size": 16,
#     "num_layers": 6,
#     "num_heads": 6,
#     "hidden_dim": 360,
#     "mlp_dim": 1024,
#     "num_classes": len(train_dataset.classes),
# }
# model = create_vit(vit_kwargs=vit_kwargs)

In [3]:
# Build the EfficientNet-B3 classifier

# from torchvision.models import EfficientNet_B3_Weights

from maviratrain.models.classifier_model import create_efficientnet_b3

# The number of output classes is taken from the training dataset's class
# list. weights=None is passed here; the commented-out alternative passes
# the IMAGENET1K_V1 weights enum instead.
model = create_efficientnet_b3(
    num_classes=len(train_dataset.classes),
    weights=None,
    # weights=EfficientNet_B3_Weights.IMAGENET1K_V1,
)

In [4]:
# # Create SimpleModel (disabled alternative to the EfficientNet cell; uncomment this cell to use it)

# from maviratrain.models.classifier_model import SimpleModel

# model = SimpleModel(
#     input_dims=input_dims, num_classes=len(train_dataset.classes)
# )

In [5]:
# # Create TestModel (disabled alternative to the EfficientNet cell; uncomment this cell to use it)

# from maviratrain.models.classifier_model import TestModel

# model = TestModel(
#     input_dims=input_dims, num_classes=len(train_dataset.classes)
# )

In [6]:
# Configure the optimizer and the learning rate schedulers

from torch.optim import AdamW, lr_scheduler

# Epoch budget: length of the warmup phase and of the whole training run
warmup_epochs, total_epochs = 5, 50

# AdamW over all model parameters
# NOTE(review): fused=True may only be supported on certain devices,
# depending on the installed PyTorch version — confirm for the target machine
optimizer = AdamW(
    model.parameters(),
    lr=0.001,
    weight_decay=0.01,
    betas=(0.9, 0.999),
    eps=1e-10,
    fused=True,
)

# Linear warmup: the LR factor runs from 0.1 to 1 over warmup_epochs iterations
scheduler1 = lr_scheduler.LinearLR(
    optimizer,
    total_iters=warmup_epochs,
    start_factor=0.1,
    end_factor=1,
    last_epoch=-1,
)

# Cosine annealing with a cycle length of total_epochs, created with
# last_epoch=warmup_epochs rather than the default -1.
# NOTE(review): presumably this relies on scheduler1 having been constructed
# first on the same optimizer (PyTorch expects 'initial_lr' in the param
# groups when last_epoch != -1) — keep this order and verify.
# NOTE(review): how the two schedulers are stepped is decided by the Trainer;
# confirm the combined warmup + cosine schedule is the intended one.
scheduler2 = lr_scheduler.CosineAnnealingLR(
    optimizer,
    last_epoch=warmup_epochs,
    T_max=total_epochs,
)

# Bundle handed to the Trainer: optimizer first, then the two schedulers
optimization = [optimizer, scheduler1, scheduler2]

In [7]:
# Define the loss function

from torch.nn import CrossEntropyLoss

# Unweighted cross-entropy without label smoothing; reduction="sum" adds up
# the per-sample losses instead of averaging them
loss_fn = CrossEntropyLoss(reduction="sum", label_smoothing=0, weight=None)

In [8]:
# Assemble the Trainer

from maviratrain.train.train_classifier import Trainer

# The Trainer receives the loss function, the [optimizer, scheduler1,
# scheduler2] list and the [train, val] dataloaders built in earlier cells
trainer = Trainer(
    loss_fn=loss_fn,
    optimization=optimization,  # type: ignore
    loaders=[train_dataloader, val_dataloader],  # type: ignore
)

In [None]:
# Run training

# Train for total_epochs epochs. `model` is rebound to the first returned
# value; the other two are presumably the numbers of steps and epochs
# trained, judging by their names — verify against Trainer.train.
(model, n_steps_trained, n_epochs_trained) = trainer.train(
    n_epochs=total_epochs, model=model
)

17:36:49 - Val Epoch: 0        Loss: 5.8796        Accuracy: 0.32        Top-5 Accuracy: 1.56
