In [1]:
# NOTE(review): the star import is presumably what supplies A, DataLoader, F,
# LabelledImageDataset and HighwayClassifier used below; prefer explicit imports.
from train import *

# Explicit imports stay after the star import so it cannot shadow these names.
import os

import matplotlib.pyplot as plt
import torch

# Side length (pixels) passed to A.Resize for both the training and validation images.
image_size = 150

# Prefer the GPU but fall back to the CPU, so the notebook still runs on
# machines without CUDA instead of failing at the first `.to(device)`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed PyTorch's global RNG so weight initialisation and DataLoader shuffling
# are repeatable after "Restart & Run All" (GPU kernels may still be
# nondeterministic).
seed = 42
torch.manual_seed(seed)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Samples per mini-batch, shared by the training and validation loaders.
batch_size = 16


def _build_transform(size):
    """Return the preprocessing pipeline: resize to ``size`` x ``size``, then normalize.

    ``A.Normalize()`` is called without arguments, so the library defaults apply.
    """
    return A.Compose([
        A.Resize(size, size),
        A.Normalize(),
    ])


# Both splits get the same deterministic preprocessing (no augmentation).
train_data = LabelledImageDataset('dataset/train', transform=_build_transform(image_size))
# Only the training split is shuffled.
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

val_data = LabelledImageDataset('dataset/val', transform=_build_transform(image_size))
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)

In [None]:
# visualize
def display_img(img, label):
    """Print the class name for ``label`` and draw ``img`` with ``plt.imshow``.

    A truthy ``label`` is reported as 'primary', a falsy one as 'footway'.
    ``img`` is permuted with ``(1, 2, 0)`` before plotting, i.e. its first
    axis is moved last (presumably channels-first -> channels-last; confirm
    against the dataset).
    """
    if label:
        class_name = 'primary'
    else:
        class_name = 'footway'
    print(f"Label: {class_name}")
    axis_last = img.permute(1, 2, 0)
    plt.imshow(axis_last)

# Pull a single batch to sanity-check the tensor shapes.
# NOTE(review): the variables are named train_* but the batch comes from
# val_loader; the names are kept because the next cell reads them.
sample_batch = next(iter(val_loader))
train_features, train_labels = sample_batch
print(f"Feature batch shape: {train_features.shape}")
print(f"Labels batch shape: {train_labels.shape}")


In [None]:
# Show every image of the sampled batch, one figure per image, with its class.
for image, label in zip(train_features, train_labels):
    class_name = 'primary' if label else 'footway'
    # Title with the readable class name; passing the raw label tensor
    # rendered its tensor repr in the figure title instead.
    plt.title(class_name)
    # squeeze() drops singleton dimensions; permute moves axis 0 last for imshow.
    plt.imshow(image.squeeze().permute(1, 2, 0))
    plt.show()
    print(f"Label: {class_name}")

In [None]:
from torchvision.utils import make_grid
import matplotlib.pyplot as plt

def show_batch(dl):
    """Plot the first batch yielded by ``dl`` as a single image grid.

    Only the first batch is consumed; nothing is drawn when ``dl`` yields no
    batches. Axis ticks are hidden and ``make_grid`` is asked for 16 images
    per row.
    """
    batches = iter(dl)
    try:
        images, _labels = next(batches)
    except StopIteration:
        return

    _fig, axes = plt.subplots(figsize=(16, 12))
    axes.set_xticks([])
    axes.set_yticks([])
    grid = make_grid(images, nrow=16)
    axes.imshow(grid.permute(1, 2, 0))
        
# Preview the first batch from the validation loader as an image grid.
show_batch(val_loader)

In [5]:
# Optimisation hyper-parameters.
num_epochs = 30
lr = 0.001

# Instantiate the classifier and move it to the configured device.
model = HighwayClassifier()
model = model.to(device)

# Adam over all model parameters with the learning rate above.
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# NOTE(review): binary_cross_entropy expects probabilities in [0, 1]; this
# assumes the model's final layer is a sigmoid — confirm in train.py.
loss_fn = F.binary_cross_entropy



In [None]:
# Bare expression: the notebook renders the model's repr as the cell output.
model

In [6]:

# Fit the model on the training data and report the mean batch loss after each epoch.
for epoch in range(num_epochs):  # loop over the dataset multiple times
    # Explicitly enable training mode. The evaluation cell calls model.eval(),
    # so re-running this cell afterwards would otherwise keep training in
    # eval mode.
    model.train()

    running_loss = 0.0
    num_batches = 0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        # accumulate statistics for the epoch average
        running_loss += loss.item()
        num_batches += 1

    # An empty loader would otherwise surface as an opaque ZeroDivisionError.
    if num_batches == 0:
        raise ValueError('train_loader yielded no batches; check the dataset path')

    # print every epoch
    print(f'[{epoch + 1}] loss: {running_loss / num_batches:.3f}')


print('Finished Training')

[1] loss: 0.679
[2] loss: 0.498
[3] loss: 0.449
[4] loss: 0.376
[5] loss: 0.512
[6] loss: 0.371
[7] loss: 0.365
[8] loss: 0.236
[9] loss: 0.218
[10] loss: 0.178
[11] loss: 0.140
[12] loss: 0.113
[13] loss: 0.110
[14] loss: 0.044
[15] loss: 0.011
[16] loss: 0.001
[17] loss: 0.008
[18] loss: 0.210
[19] loss: 0.059
[20] loss: 0.092
[21] loss: 0.004
[22] loss: 0.000
[23] loss: 0.000
[24] loss: 0.000
[25] loss: 0.000
[26] loss: 0.000
[27] loss: 0.000
[28] loss: 0.000
[29] loss: 0.000
[30] loss: 0.000
Finished Training


In [7]:
# torch.save does not create missing parent directories, so make sure the
# target folder exists before writing the weights.
os.makedirs('weights', exist_ok=True)
torch.save(model.state_dict(), 'weights/model_weights.pth')

In [12]:
# Switch the model to evaluation mode before scoring the validation set.
model.eval()

# val_loader is reused from the data-loading cell; rebuilding the identical
# dataset and loader here was a redundant copy.


# visualize
def display_eval(img, pred, actual):
    """Print the predicted and actual class names and draw ``img``.

    Truthy ``pred`` / ``actual`` values are reported as 'primary', falsy ones
    as 'footway'. ``img`` is permuted with ``(1, 2, 0)`` before ``plt.imshow``.
    """
    print(f"Predicted: {'primary' if pred else 'footway'}")
    print(f"Actual: {'primary' if actual else 'footway'}")
    plt.imshow(img.permute(1, 2, 0))


num_correct = 0
num_total = 0
# Gradients are not needed for evaluation; disabling autograd saves memory and time.
with torch.no_grad():
    for val_features, val_labels in val_loader:
        # Score the whole batch in one forward pass instead of image by image.
        probs = model(val_features.to(device)).reshape(-1).cpu()
        preds = probs > 0.5
        targets = val_labels.reshape(-1)
        # Same comparison as before (boolean prediction against the stored
        # label value), but reduced to a plain Python int so the accuracy is
        # a number rather than a tensor.
        num_correct += (preds.to(targets.dtype) == targets).sum().item()
        num_total += targets.numel()
        # To inspect individual predictions, call
        # display_eval(image, pred, label) followed by plt.show() per sample.

# Guard against an empty validation set instead of dividing by zero.
accuracy = num_correct / num_total if num_total else float('nan')
print(f"Accuracy: {accuracy:.4f}")


Accuracy: tensor([0.9200])
