In [1]:
import os

import torch
from torch.utils.data import DataLoader

from dataset import MosquitoDataset
from model import build_SSD
from train import train


In [2]:
datafolder = "simple_dataset"
batch_size = 8
img_size = 300  # presumably the SSD300 input resolution expected by build_SSD — confirm
images_dir = f"../datasets/{datafolder}/images"
labels_dir = f"../datasets/{datafolder}/labels"


def collate_fn(batch):
    """Transpose a list of per-sample tuples into a tuple of per-field tuples.

    e.g. ``[(img1, tgt1), (img2, tgt2)] -> ((img1, img2), (tgt1, tgt2))``.
    Samples are kept as tuples instead of being stacked by the default collate,
    presumably because each target holds a variable number of boxes — confirm
    against ``MosquitoDataset.__getitem__``.

    NOTE: to use ``num_workers > 0`` with the "spawn" start method (Windows/macOS
    default), this function must live in an importable module (e.g. dataset.py);
    functions defined in a notebook cannot be resolved by spawned workers.
    """
    return tuple(zip(*batch))


# One dataset/loader per split; only the training loader is shuffled.
train_dataset = MosquitoDataset(f"{images_dir}/train", f"{labels_dir}/train", img_size=img_size)
val_dataset = MosquitoDataset(f"{images_dir}/val", f"{labels_dir}/val", img_size=img_size)
test_dataset = MosquitoDataset(f"{images_dir}/test", f"{labels_dir}/test", img_size=img_size)

trainloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
valloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
testloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

In [3]:
# Seed torch's global RNG before building the model so any random weight
# initialisation in build_SSD, and the shuffle order trainloader draws when it is
# iterated later, are reproducible across kernel restarts. (Python's `random` and
# NumPy are not seeded here — add them if MosquitoDataset uses them for augmentation.)
seed = 42
torch.manual_seed(seed)

num_classes = 7  # 6 mosquito classes + 1 background
model = build_SSD(num_classes=num_classes)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=0.0005)
# StepLR multiplies the LR by gamma=0.1 every step_size=3 scheduler steps. If train()
# steps it once per epoch (assumed — confirm), a 10-epoch run ends at lr=1e-6.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

In [24]:
N_epochs = 10
save_dir = f"models/{datafolder}/"
# Create the checkpoint directory up front: presumably train() writes
# save_dir + save_name, and a missing folder would otherwise only surface as an
# error after training time has already been spent.
os.makedirs(save_dir, exist_ok=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE: model/optimizer/scheduler come from the cell above. Re-running only this cell
# (e.g. after an interrupt) resumes from their current state instead of starting fresh;
# re-run the model cell first for a clean run.
train_losses, val_losses, lr_history = train(
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    trainloader=trainloader,
    valloader=valloader,
    device=device,
    N_epochs=N_epochs,
    save_dir=save_dir,
    save_name="best_model.pth"
)

  2%|▏         | 14/797 [00:08<08:13,  1.59it/s]


KeyboardInterrupt: 

In [None]:
# Restore the best checkpoint written by the training cell.
# map_location=device lets a checkpoint saved on a GPU be loaded on a CPU-only machine;
# weights_only=True restricts unpickling to tensors and plain containers (all a
# state_dict needs), so loading the file cannot execute arbitrary code.
model.load_state_dict(
    torch.load(f"{save_dir}/best_model.pth", map_location=device, weights_only=True)
)