In [4]:
import torch.nn as nn
import torch
import os
import re
from matplotlib.pyplot import imread
import numpy as np
import cv2
from torch.utils import data
from torchvision.transforms import ToTensor
from torch.utils import data
from tqdm.notebook import tqdm

# Directory that holds the preprocessed copper-layer PNG renders.
folder = "../data/processed_data"

# Full-resolution images as loaded from disk: top copper (label 1) and
# bottom copper (label 0).
positive_dataset, negative_dataset = [], []

# The same images resized for the network input; these are what the
# Dataset below actually serves.
positive_reshaped_dataset, negative_reshaped_dataset = [], []

In [5]:
def _load_copper_image(path, size):
    """Read one image file and return ``(original, resized)`` numpy arrays.

    ``resized`` is ``original`` resampled by ``cv2.resize`` to ``dsize=size``
    with bicubic interpolation.
    """
    # NOTE(review): matplotlib's imread typically returns float32 in [0, 1] for
    # PNGs and 4 channels (RGBA) when the file has alpha; GroundPlaneNet's
    # conv expects 3 channels — confirm the renders are RGB.
    orig_img = np.array(imread(path))
    resized_img = cv2.resize(orig_img, dsize=size, interpolation=cv2.INTER_CUBIC)
    return orig_img, resized_img


def load_ground_training_data(folder, size=(240, 240)):
    """Load top/bottom copper-layer renders from ``folder`` into the module-level lists.

    Files whose name ends in ``top-copper.png`` are appended to
    ``positive_dataset`` (original) and ``positive_reshaped_dataset``
    (resized). Files ending in ``bottom-copper.png`` go to
    ``negative_dataset`` / ``negative_reshaped_dataset``. All other files are
    ignored. Only the top level of ``folder`` is scanned.

    Args:
        folder: directory to scan.
        size: ``dsize`` passed to ``cv2.resize`` for the resized copies. The
            default (240, 240) is the input size GroundPlaneNet's 116*116
            feature layer was computed for.

    Note:
        The lists are appended to, not replaced. Calling this twice (e.g.
        re-running this cell without re-running the cell that creates the
        lists) duplicates every sample.
    """
    # os.listdir order is arbitrary; sort so sample order (and therefore the
    # Dataset's index -> image mapping) is reproducible between runs.
    for name in sorted(os.listdir(folder)):
        # The two suffixes are mutually exclusive, so elif is equivalent to
        # the original pair of ifs.
        if name.endswith("top-copper.png"):
            orig_img, resized_img = _load_copper_image(os.path.join(folder, name), size)
            positive_dataset.append(orig_img)
            positive_reshaped_dataset.append(resized_img)
        elif name.endswith("bottom-copper.png"):
            orig_img, resized_img = _load_copper_image(os.path.join(folder, name), size)
            negative_dataset.append(orig_img)
            negative_reshaped_dataset.append(resized_img)
load_ground_training_data(folder)



In [25]:
class GroundPlaneNet(nn.Module):
    """Binary classifier: one strided, dilated conv followed by a linear layer.

    The conv reduces the image to a single-channel feature map. That map is
    flattened per sample and mapped to one raw score (logit) per image.

    No sigmoid is applied in ``forward``: training uses
    ``nn.BCEWithLogitsLoss``, which applies the sigmoid internally. Apply
    ``torch.sigmoid`` to the output yourself to get a probability at
    inference time.

    Args:
        input_size: side length of the square input images. The default 240
            reproduces the original hard-coded 116 * 116 feature size, so
            previously saved state_dicts still load.
        in_channels: number of input image channels (3 by default).
    """

    def __init__(self, input_size=240, in_channels=3):
        super().__init__()
        kernel_size, stride, dilation = 5, 2, 2
        self.conv1 = nn.Conv2d(in_channels, 1, kernel_size, stride=stride, dilation=dilation)
        # Conv2d output side with padding=0:
        # floor((L - dilation * (kernel_size - 1) - 1) / stride) + 1  -> 240 maps to 116.
        side = (input_size - dilation * (kernel_size - 1) - 1) // stride + 1
        if side < 1:
            raise ValueError(
                "input_size=%d is too small for the 5x5 dilated conv" % input_size)
        self.fc1 = nn.Linear(side * side, 1)

    def forward(self, x):
        """Map a batch shaped (N, in_channels, input_size, input_size) to logits (N, 1)."""
        x = self.conv1(x)
        # Flatten per sample, keeping the batch dimension. The previous
        # view(-1, 116*116) could silently regroup a mis-sized input into a
        # different number of "samples". Now fc1 rejects it instead.
        x = torch.flatten(x, start_dim=1)
        return self.fc1(x)
    
class GroundPlaneDataset(data.Dataset):
    """Labelled dataset over the resized copper-layer images.

    Indices ``[0, len(positives))`` are positive samples (label 1). The
    remaining indices are negative samples (label 0).

    Args:
        positives: images with label 1. When omitted, the module-level
            ``positive_reshaped_dataset`` is used, looked up at access time
            as in the original zero-argument form.
        negatives: images with label 0. When omitted, the module-level
            ``negative_reshaped_dataset`` is used, looked up the same way.
    """

    def __init__(self, positives=None, negatives=None):
        super().__init__()
        self._positives = positives
        self._negatives = negatives

    def _sources(self):
        """Return ``(positives, negatives)``, falling back to the module-level lists."""
        positives = positive_reshaped_dataset if self._positives is None else self._positives
        negatives = negative_reshaped_dataset if self._negatives is None else self._negatives
        return positives, negatives

    def __len__(self):
        positives, negatives = self._sources()
        return len(positives) + len(negatives)

    def __getitem__(self, index):
        """Return ``(image_tensor, label)``; ``label`` is a float tensor of shape (1, 1)."""
        positives, negatives = self._sources()
        n_pos = len(positives)
        total = n_pos + len(negatives)
        # Wrap negative indices over the whole dataset. Previously they indexed
        # the positive list from its end, whatever the label split.
        if index < 0:
            index += total
        if not 0 <= index < total:
            raise IndexError("index out of range for dataset of size %d" % total)
        if index < n_pos:
            img, label = positives[index], 1.0
        else:
            img, label = negatives[index - n_pos], 0.0
        # The label keeps the original (1, 1) shape and default float dtype, so
        # the training loop's target reshaping is unaffected.
        return ToTensor()(img), torch.Tensor([[label]])
        
        
def train(net=None, epochs=20, lr=0.001, momentum=0.9, batch_size=4, num_workers=4):
    """Train a model on GroundPlaneDataset with SGD and BCEWithLogitsLoss.

    Prints the mean batch loss after every epoch.

    Args:
        net: model to train in place. Defaults to the module-level ``model``,
            looked up when ``train`` is called, which is what the original
            zero-argument call relied on.
        epochs: number of passes over the dataset.
        lr: SGD learning rate. TODO: tune.
        momentum: SGD momentum. TODO: check whether momentum actually helps.
        batch_size: samples per optimisation step. The final incomplete batch
            is dropped.
        num_workers: DataLoader worker processes.

    Returns:
        The trained model (the same object passed in or looked up).

    Raises:
        ValueError: if the dataset holds fewer than ``batch_size`` samples,
            so that no full batch exists.
    """
    if net is None:
        net = model  # module-level global created by the calling cell
    train_dataloader = data.DataLoader(GroundPlaneDataset(), batch_size=batch_size,
                                       shuffle=True, num_workers=num_workers,
                                       drop_last=True)
    n_batches = len(train_dataloader)
    if n_batches == 0:
        # Fail before training instead of with a ZeroDivisionError at epoch end.
        raise ValueError("dataset has fewer than batch_size=%d samples" % batch_size)

    criterion = nn.BCEWithLogitsLoss()  # model emits logits; sigmoid is applied here
    optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=momentum)

    for epoch in tqdm(range(epochs)):
        net.train()
        running_loss = 0.0
        for inputs, gt in train_dataloader:
            optimizer.zero_grad()
            outputs = net(inputs)
            # Per-sample labels are (1, 1), so a batch arrives as (B, 1, 1).
            # Reshape to exactly the model's (B, 1) output. The former
            # squeeze(0).squeeze(1) gave (B,) when B == 1, and that shape
            # mismatch makes BCEWithLogitsLoss raise.
            loss = criterion(outputs, gt.reshape_as(outputs))
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        # Mean loss over the epoch's batches.
        running_loss /= n_batches
        print('[%d] loss: %.3f' % (epoch + 1, running_loss))
    return net

        
# Seed torch's default RNG so weight initialisation and the DataLoader's
# shuffle order are reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

model = GroundPlaneNet()
train()  # trains the module-level `model` in place
# Only the learned parameters are saved. To reuse them, rebuild with
# GroundPlaneNet() and call load_state_dict().
torch.save(model.state_dict(), "GroundDetectionModel")

HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

0
[1] loss: 0.528
1
[2] loss: 0.457
2
[3] loss: 0.382
3
[4] loss: 0.410
4
[5] loss: 0.425
5
[6] loss: 0.612
6
[7] loss: 0.370
7
[8] loss: 0.284
8
[9] loss: 0.218
9
[10] loss: 0.177
10
[11] loss: 0.146
11
[12] loss: 0.034
12
[13] loss: 0.353
13
[14] loss: 0.087
14
[15] loss: 0.065
15
[16] loss: 0.024
16
[17] loss: 0.026
17
[18] loss: 0.013
18
[19] loss: 0.013
19
[20] loss: 0.011

