In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

from src.config import Config
from src.dataset import prepare_datasets
from src.models import MushroomCNN
from src.trainer import Trainer


In [None]:
# Build the train / validation / test splits from Config.DATA_DIR.
# Per the author's notes (the logic lives in src.dataset.prepare_datasets and
# is not visible in this notebook):
#   - the training set is augmented, with under-represented classes augmented
#     more heavily so the class counts become roughly balanced;
#   - the validation set is NOT augmented and holds the same number of samples
#     for every class.
(
    train_dataset,
    val_dataset,
    test_dataset,
) = prepare_datasets(Config.DATA_DIR)

In [4]:
# Model, loss, optimizer, data loaders and LR scheduler.
#
# Place the model on Config.DEVICE, the same device handed to the Trainer in
# the next cell, instead of a hard-coded `.cuda()`. This lets the notebook run
# on CPU-only machines and keeps the model and the Trainer on the same device.
model = MushroomCNN(Config.NUM_CLASSES).to(Config.DEVICE)
# To try the VGG19 variant instead, MushroomVGG19 must first be added to the
# imports cell (it is not imported above; presumably it lives in src.models):
# model = MushroomVGG19(Config.NUM_CLASSES).to(Config.DEVICE)

# NOTE(review): `criterion` is not passed to the Trainer in the next cell;
# presumably the Trainer builds its own loss — confirm against src.trainer.
criterion = nn.CrossEntropyLoss()

LEARNING_RATE = 1e-3  # initial Adam learning rate
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

NUM_WORKERS = 4  # DataLoader worker processes, shared by both loaders

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=Config.BATCH_SIZE,
    shuffle=True,  # reshuffle training samples on every pass
    num_workers=NUM_WORKERS,
)

val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=Config.BATCH_SIZE,
    shuffle=False,  # fixed iteration order for validation
    num_workers=NUM_WORKERS,
)

# Multiply the LR by 0.1 once the monitored value has not improved for 5
# consecutive scheduler.step() calls. mode='min' means "improvement" is a
# decrease, e.g. a validation loss.
# NOTE(review): confirm which metric the Trainer passes to scheduler.step().
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.1,
    patience=5,
)

In [None]:
# Wire the objects built in the previous cell into a Trainer running on
# Config.DEVICE, then train for Config.NUM_EPOCHS epochs.
trainer_kwargs = dict(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    device=Config.DEVICE,
    scheduler=scheduler,
)
trainer = Trainer(**trainer_kwargs)

# NOTE(review): the result is bound to `val_acc`, which suggests Trainer.train
# returns a validation accuracy, and 'cnn' is presumably used to label or save
# this run — both are defined in src.trainer; confirm there.
val_acc = trainer.train(
    num_epochs=Config.NUM_EPOCHS,
    model_name='cnn',
)
