In [4]:
import os
import random

import numpy as np
import torch
from torch.utils.data import DataLoader

from data import MemmapDataset
from models import *
from training_utils import TrainingSession, plot_loss_comparison, plot_comparison_metrics

In [None]:
# Constants for training
DATA_SPLIT = 0.9        # passed to dataset.split(); presumably the training fraction — TODO confirm in data.MemmapDataset
INIT_LR = 0.005
NUM_EPOCHS = 10
BATCH_SIZE = 32
INPUT_IMAGE_SIZE = 224
THRESHOLD = 0.5         # not used in this cell; presumably a mask binarisation cutoff used later — TODO confirm
LOSS = JaccardLoss()
SEED = 42               # fixed seed so Restart & Run All reproduces the split and batch order

# Seed every RNG the split / shuffling might draw from. MemmapDataset.split is not
# visible here, so Python, NumPy and torch generators are all seeded; torch's
# default generator also drives DataLoader(shuffle=True).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Paths to the pre-built .npy arrays. Set the DRONE_DATA_DIR environment variable
# to point elsewhere instead of editing the hard-coded default below.
# NOTE: despite the *_DIR names, IMAGE_DIR and LABEL_DIR are file paths, not directories.
BASE_DIR = os.environ.get("DRONE_DATA_DIR", "A:\\Desktop\\Drone_Datasets\\Original")
IMAGE_DIR = os.path.join(BASE_DIR, "dataset_images.npy")
LABEL_DIR = os.path.join(BASE_DIR, "dataset_labels.npy")

# Memory-map both arrays read-only (mmap_mode="r") so they are not read fully into
# RAM, wrap them in the project dataset, and split into training and test subsets.
dataset = MemmapDataset(np.load(IMAGE_DIR, mmap_mode="r"), np.load(LABEL_DIR, mmap_mode="r"))
trainDS, testDS = dataset.split(DATA_SPLIT)

# create data loaders (only the training loader shuffles)
trainLoader = DataLoader(trainDS, shuffle=True, batch_size=BATCH_SIZE)
testLoader = DataLoader(testDS, shuffle=False, batch_size=BATCH_SIZE)

print(f"Dataset containing {len(dataset)} images loaded.")
# Remove only the notebook-level names; each DataLoader still holds its own
# reference to its dataset, so the loaders remain usable.
del trainDS, testDS, dataset


In [None]:
# Build the ResNet_UNet model and bundle it with both loaders, the loss and the
# optimisation settings into a TrainingSession (from training_utils).
# `Machine` is kept as the session name because later cells reference it.
model = ResNet_UNet(input_image_size=INPUT_IMAGE_SIZE)
Machine = TrainingSession(
    model,
    trainLoader,
    testLoader,
    LOSS,
    init_lr=INIT_LR,
    num_epochs=NUM_EPOCHS,
)
# learn() presumably runs the full training loop for NUM_EPOCHS — implemented in training_utils.
Machine.learn()

In [None]:
# Plot the loss history held by the training session (plotting is implemented in
# training_utils.TrainingSession; presumably populated by Machine.learn() above).
Machine.plot_loss()

In [None]:
# Show which metric series the session exposes, then plot them under one title.
available_metrics = Machine.get_available_metrics()
print("Available Metrics:", available_metrics)
Machine.plot_metrics("Metrics")

In [None]:
# Persist only the trained parameters (the state_dict), not the pickled module.
# The file is written relative to the current working directory. To reuse it,
# build a ResNet_UNet and call load_state_dict(torch.load(...)).
trained_weights = Machine.model.state_dict()
torch.save(trained_weights, "resnet_unet.pth")