In [None]:
!wget https://github.com/MNRKhan/aps360-project/raw/master/modules/data_loader.py
!wget https://github.com/MNRKhan/aps360-project/raw/master/modules/extractNet_connected_vgg11.py
!wget https://github.com/MNRKhan/aps360-project/raw/master/modules/helper.py
!wget https://github.com/MNRKhan/aps360-project/raw/master/modules/metrics.py
!wget https://github.com/MNRKhan/aps360-project/raw/master/modules/trainer.py
!wget https://github.com/MNRKhan/aps360-project/raw/master/modules/visualizer.py

In [None]:
import numpy as np
import random
import torch
from torch.utils.data import DataLoader
from torchvision import transforms

from data_loader import *
from extractNet_connected_vgg11 import *
from helper import *
from metrics import *
from trainer import *
from visualizer import *

In [None]:
!rm -rf __MACOSX
!rm -rf *.zip
!wget https://github.com/MNRKhan/aps360-project/raw/master/datasets/train2014/data_person_vehicle.zip
!unzip data_person_vehicle.zip
!rm -rf __MACOSX
!rm -rf *.zip

In [None]:
# Hyperparameters: mini-batch size and learning rate handed to trainModel.
batch_size, lr = 64, 1e-3

In [None]:
# Seed the PyTorch, NumPy and Python RNGs so the random split and the
# loader shuffling are repeatable from a fresh kernel.
torch.manual_seed(360)
np.random.seed(360)
random.seed(360)

# Form dataset: convert each image to a tensor, then normalize all three
# channels with mean 0.5 and std 0.5.
# NOTE(review): size=1000 presumably caps the number of samples loaded from
# ./data -- confirm against ImageMaskDataset in data_loader.py.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
dataset = ImageMaskDataset("./data", transform, size=1000)

# Dataset sizes: 60% train, 20% validation. The test split takes the remainder,
# so the three lengths always add up to len(dataset) despite int() truncation.
size = len(dataset)
train_size = int(0.6 * size)
valid_size = int(0.2 * size)
test_size = size - train_size - valid_size

# Splitting datasets (uses the global torch RNG seeded above)
train_data, valid_data, test_data = torch.utils.data.random_split(dataset, [train_size, valid_size, test_size])

# Making dataloaders. batch_size is taken from the hyperparameter cell; it is
# intentionally NOT redefined here, so editing that cell actually takes effect
# for both the loaders and the training call.
# drop_last=True discards a trailing partial batch from the training loader.
train = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=True)
valid = DataLoader(valid_data, batch_size=batch_size, shuffle=True, num_workers=0)

In [None]:
# Report how many samples the full dataset, the training split and the
# validation split contain.
for label, count in (
    ("Full data set: ", size),
    ("Training size: ", train_size),
    ("Validation size: ", valid_size),
):
    print(label, count)

In [None]:
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print("Model is being trained on:", device)

# Release cached GPU memory held by PyTorch's allocator before training starts.
torch.cuda.empty_cache()

In [None]:
# Instantiate the network (presumably defined in the downloaded
# extractNet_connected_vgg11.py module -- it is pulled in by a star import).
model = extractNet_connected_vgg11()
# Move the model to the device selected earlier. As the last expression of the
# cell, the value returned by .to() is also rendered as the cell output.
model.to(device)

In [None]:
# Run the training loop from trainer.py on the train/validation loaders,
# using the hyperparameters and device chosen in the earlier cells.
trainModel(
    model,
    train,
    valid,
    batch_size=batch_size,
    lr=lr,
    num_epochs=80,
    checkpoint=False,
    device=device,
)

In [None]:
# Put the model in evaluation mode before exporting it and running inference
# (assumes a torch.nn.Module, whose eval() returns the module itself -- confirm).
model = model.eval()
# Move the model to the CPU so the weights are saved from, and later used on, the CPU.
model.to("cpu")
# Persist the weights to a file named after the number of training epochs (80).
torch.save(model.state_dict(), "epoch80")
# Read the saved weights back from disk.
# NOTE(review): model_state is not used by any cell visible here.
model_state = torch.load("epoch80")

In [None]:
# After training (train set): qualitative check on one training sample.
# Shows, in order: the input image, the ground-truth mask, the thresholded
# prediction, and the prediction after denoising.
model.to("cpu")

img, target = train_data[-112]
target = target.squeeze(0).detach().numpy()

# Inference only: no_grad() keeps autograd from recording the forward pass,
# which saves memory and makes the result directly convertible to NumPy.
with torch.no_grad():
    out = torch.sigmoid(model(img.unsqueeze(0)))  # unsqueeze adds the batch dimension
out = out.squeeze(0).squeeze(0).numpy()  # drop the two leading size-1 dimensions
out = thresholdProbMask(out)
denoised = denoise(out)

# Move axis 0 (channels) to the end so the image is laid out (H, W, C) for display.
imshow(np.transpose(img.squeeze(0).detach().numpy(), [1,2,0]))
imshow(target)
imshow(out)
imshow(denoised)

In [None]:
# After training (validation set): qualitative check on one validation sample.
# Shows, in order: the input image, the ground-truth mask, the thresholded
# prediction, the denoised prediction, and the image masked by that prediction.
img, target = valid_data[-2]
target = target.squeeze(0).detach().numpy()

# Inference only: no_grad() keeps autograd from recording the forward pass,
# which saves memory and makes the result directly convertible to NumPy.
with torch.no_grad():
    out = torch.sigmoid(model(img.unsqueeze(0)))  # unsqueeze adds the batch dimension
out = out.squeeze(0).squeeze(0).numpy()  # drop the two leading size-1 dimensions
out = thresholdProbMask(out)
denoised = denoise(out)

# Move axis 0 (channels) to the end so the image is laid out (H, W, C) for display.
img_hwc = np.transpose(img.squeeze(0).detach().numpy(), [1,2,0])

imshow(img_hwc)
imshow(target)
imshow(out)
imshow(denoised)
# Append a trailing axis to the mask so it broadcasts over the colour channels.
# axis=-1 replaces the original axis=3, which is out of range for a 2-D mask and
# raises AxisError on NumPy >= 1.18; wherever axis=3 was valid, -1 is identical.
# NOTE(review): assumes denoise() returns an (H, W) array -- confirm in helper.py.
imshow(np.expand_dims(denoised, axis=-1) * img_hwc)