In [34]:
import tiny_utils # custom module
import numpy as np
import pandas as pd 
import matplotlib.pyplot as plt 

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F

import segmentation_models_pytorch as smp
from train import UNET
from tiny_utils import ShipDatabaseSegmation
from sklearn.model_selection import train_test_split
from skimage.morphology import binary_opening, disk, label

In [37]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

model = UNET(in_channels=3, out_channels=1).to(DEVICE)
base_model = model.load_state_dict(torch.load("base_model_checkpoint.pth.tar")['state_dict']) 

In [46]:
# loading data
data = pd.read_csv("data/train_ship_segmentations_v2.csv")
data = data.dropna() 
data = data.sample(frac=0.125, replace=False, random_state=42)

train, valid = train_test_split(data, test_size = 0.05, random_state=42)
print(train.shape[0], 'training set')
print(valid.shape[0], 'validation set')


transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor() 
])
ship_dataset_train = ShipDatabaseSegmation(train, "data/train_v2", transforms=transforms)
ship_dataset_valid = ShipDatabaseSegmation(valid, "data/train_v2", transforms=transforms)

train_loader = torch.utils.data.DataLoader(ship_dataset_train, batch_size=5, shuffle=True, num_workers=8)
valid_loader = torch.utils.data.DataLoader(ship_dataset_valid, batch_size=1, shuffle=False, num_workers=4)

9704 training set
511 validation set


In [None]:
# loading U-Net Layer to Base UNET
loss = nn.BCEWithLogitsLoss()
loss.__name__ = "bceWithLogitLoss"
device = "cuda"
metrics = [smp.utils.metrics.IoU(threshold=0.5),]

test_epoch_UNET = smp.utils.train.ValidEpoch(model, 
                                            loss=loss, 
                                            metrics=metrics, 
                                            device=device,
                                            verbose=True,)

valid_logs = test_epoch_UNET.run(valid_loader) #  on data validation
train_logs = test_epoch_UNET.run(train_loader) #  on data train

valid: 100%|███████████████████████████████████████████████████████████| 508/508 [01:40<00:00,  5.08it/s, bceWithLogitLoss - 0.03445, iou_score - 7.982e-10]
valid:  21%|████████████                                              | 352/1690 [06:18<23:54,  1.07s/it, bceWithLogitLoss - 0.03504, iou_score - 2.412e-11]

In [None]:
# terrible IoU scores. Further parameter tweaking is needed. 