In [None]:
import torch
import numpy as np
from torch.nn import functional as F
import torchvision
from torchvision import transforms
from torch import nn, Tensor
from torch.utils.data import DataLoader
from torch.optim import Adam
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

from dataset import *
from utils import *
from rpn import RPNHead

Build the dataset, split it into 80% training / 20% validation subsets, and wrap each subset in a dataloader.

In [None]:
## Build the dataset ##
# Fixed seed so the train/val split is identical on every Restart & Run All.
# Without it, each run validates on a different subset, checkpoints selected on
# val_loss are not comparable across runs, and any later notebook that re-creates
# the split could end up "validating" on images this run trained on.
SEED = 42
# Also seeds torch's global RNG, which DataLoader shuffling draws from
# (and, presumably, RPNHead's weight init — TODO confirm in rpn.py).
torch.manual_seed(SEED)

imgs_path = './data/hw3_mycocodata_img_comp_zlib.h5'
masks_path = './data/hw3_mycocodata_mask_comp_zlib.h5'
labels_path = './data/hw3_mycocodata_labels_comp_zlib.npy'
bboxes_path = './data/hw3_mycocodata_bboxes_comp_zlib.npy'
paths = [imgs_path, masks_path, labels_path, bboxes_path]
# BuildDataset / BuildDataLoader come from the star imports of `dataset` / `utils`.
dataset = BuildDataset(paths)

# Split the whole dataset into TRAIN_FRACTION training and the remainder validation.
TRAIN_FRACTION = 0.8
full_size = len(dataset)
train_size = int(full_size * TRAIN_FRACTION)
val_size = full_size - train_size  # remainder, so the two sizes always sum to full_size
split_generator = torch.Generator().manual_seed(SEED)
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size], generator=split_generator
)

## Build dataloader from training and validation dataset ##
batch_size = 4
num_workers = 4
train_build_loader = BuildDataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
train_loader = train_build_loader.loader()
# Validation order is kept fixed (shuffle=False) so per-epoch val metrics are comparable.
val_build_loader = BuildDataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
val_loader = val_build_loader.loader()

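Training: fit the RPN head, keeping the checkpoints with the lowest validation loss, then save the recorded loss histories to `.npy` files.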

In [None]:
MAX_EPOCHS = 70

# Keep the two checkpoints with the lowest validation loss.
# NOTE(review): `monitor="val_loss"` only works if RPNHead logs a metric under
# exactly that name (e.g. via self.log("val_loss", ...)) — confirm in rpn.py.
checkpoint_callback = ModelCheckpoint(
    save_top_k=2,
    monitor="val_loss",
    mode="min",
    filename="best-val-{epoch:02d}-{val_loss:.2f}",
)
# `gpus=` was deprecated in Lightning 1.7 and removed in 2.0;
# accelerator/devices is the supported way to request a single GPU.
trainer = Trainer(accelerator="gpu", devices=1, max_epochs=MAX_EPOCHS, callbacks=[checkpoint_callback])
# No `.cuda()`: the Trainer moves the LightningModule to the selected device itself.
rpn_net = RPNHead()
trainer.fit(rpn_net, train_dataloaders=train_loader, val_dataloaders=val_loader)

# Convert the loss histories RPNHead accumulates in these attributes to arrays.
# The names are kept as module-level variables for later (plotting) cells.
train_loss = np.array(rpn_net.total_loss)
cls_loss = np.array(rpn_net.cls_loss)
reg_loss = np.array(rpn_net.reg_loss)
val_loss = np.array(rpn_net.val_loss)

# Persist each history as <name>.npy in the working directory.
loss_histories = {
    "total_loss": train_loss,
    "cls_loss": cls_loss,
    "reg_loss": reg_loss,
    "val_loss": val_loss,
}
for loss_name, loss_history in loss_histories.items():
    np.save(f"{loss_name}.npy", loss_history)