In [None]:
# -*- coding: utf-8 -*-
# -*- author : Vincent Roduit -*-
# -*- date : 2023-11-25 -*-
# -*- Last revision: 2023-11-25 -*-
# -*- python version : 3.11.6 -*-
# -*- Description: Notebook that summarizes the results -*-

# <center> CS-433 Machine Learning </center>
## <center> Ecole Polytechnique Fédérale de Lausanne </center>
### <center>Road Segmentation </center>
--- 

### Preparing environment for Google Colaboratory

In [None]:
# Colab-only cell: mount Google Drive at /content/drive so that the project
# folder stored on Drive is reachable from this runtime.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Make the project's `source` folder on the mounted Drive the working directory.
# Assumes the local modules imported below (data_processing, cnn, ...) live in
# this folder — adjust the path if the Drive layout differs.
%cd /content/drive/MyDrive/ml-project-2-team-slo/source

### Imports

In [None]:
#import libraries
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

#import model parameters
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.optim import Adam

%load_ext autoreload
%autoreload 2

In [None]:
#import files
from data_processing import*
from visualization import visualize, visualize_patch
import constants
from test_data import TestData

#import models
from cnn import Basic_CNN
from logistic_regression import LogisticRegression

In [None]:
# Set random seeds for reproducibility.
# Seed every RNG library this notebook imports, not only PyTorch: NumPy's
# global generator is seeded too, in case the data pipeline draws from it
# (assumption — confirm in data_processing).
SEED = 0
np.random.seed(SEED)
# The trailing ';' keeps the returned torch.Generator repr out of the cell output.
torch.manual_seed(SEED);

## 1. Data wrangling and visualization

In [None]:
# Build the main data-processing object with standardization switched off,
# then run its pipeline. Later cells read `imgs`, `gt_imgs`, `X_train`,
# `train_dataloader` and `validate_dataloader` from this object.
myDatas = AdvancedProcessing(standardize=False)
myDatas.proceed()

In [None]:
# Show the sample at index 4: the images and the ground-truth images are both
# passed in (the exact layout is decided by visualization.visualize).
visualize(myDatas.imgs, myDatas.gt_imgs, index=4)

In [None]:
# Display the first training patch. The axes are permuted (0, 1, 2) -> (1, 2, 0),
# i.e. the leading axis is moved last — presumably channels-first to
# channels-last for plotting; TODO confirm the layout of X_train.
visualize_patch(myDatas.X_train[0].transpose(1,2,0))

## 2. Define and train models

### 2.1 Logistic regression

A first attempt is to try a linear model. The first approach here is a simple logistic regression. In order to use a logistic regression, one needs to extract features from the images. One choice is to use the mean and the standard deviation as features. The following section presents this approach.

In [None]:
# Separate BasicProcessing object for the logistic-regression baseline:
# load the data, then create the patches (read in the next cell as
# `imgs_patches` / `gt_imgs_patches`).
LogisticData = BasicProcessing()
LogisticData.load_data()
LogisticData.create_patches()

In [None]:
# Hand the image patches and their ground-truth patches to the project's
# LogisticRegression wrapper, then compute its vectors (presumably what fills
# LogReg.X and LogReg.Y used in the next cell — confirm in logistic_regression).
LogReg = LogisticRegression(LogisticData.imgs_patches, LogisticData.gt_imgs_patches)
LogReg.compute_vectors()

In [None]:
# Scatter the first two columns of LogReg.X against each other, coloured by
# LogReg.Y, to check visually whether the classes can be separated linearly.
fig, ax = plt.subplots()
ax.scatter(LogReg.X[:, 0], LogReg.X[:, 1], c=LogReg.Y, edgecolors="k", cmap=plt.cm.Paired)
# Axis names are deliberately generic: the column order of LogReg.X (mean vs.
# standard deviation) is defined in logistic_regression — rename once confirmed.
ax.set_xlabel("Feature 0 (LogReg.X[:, 0])")
ax.set_ylabel("Feature 1 (LogReg.X[:, 1])")
ax.set_title("Patch features coloured by label")
plt.show()

A problem already arises: the data are not linearly separable.

In [None]:
# Fit the logistic regression, then run its prediction step; the next cell
# reads the resulting LogReg.accuracy and LogReg.f1.
LogReg.train()
LogReg.predict()

In [None]:
# Report accuracy and F1 score as percentages, rounded to two decimals.
print(f'From this model, the accuracy is {LogReg.accuracy*100:.2f}% and the F1 score is {LogReg.f1*100:.2f}%')

These unsatisfactory results lead us to move to convolutional neural networks, which are better suited to image data.

### 2.2 Basic Convolutional Neural Network

In [None]:
# Baseline CNN, constructed with the window size defined in `constants`.
cnn = Basic_CNN(constants.WINDOW_SIZE)

# Training components.
# - BCEWithLogitsLoss applies the sigmoid internally, so the model is expected
#   to output raw logits.
# - ReduceLROnPlateau in 'min' mode multiplies the learning rate by 0.1 when the
#   monitored value plateaus (patience=2).
#   NOTE(review): `verbose` is deprecated in recent PyTorch releases — confirm
#   against the installed version.
criterion = nn.BCEWithLogitsLoss()
optimizer = Adam(cnn.parameters(), lr=1e-3)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=2, verbose=True)

# Train with num_epochs=10 on the train / validation dataloaders built by `myDatas`.
cnn.train_model(optimizer, scheduler, criterion, myDatas.train_dataloader, myDatas.validate_dataloader, num_epochs=10)

In [None]:
# Hand the CNN trained above to the TestData helper and run its pipeline
# (what `proceed()` produces is defined in test_data.py).
myTestData = TestData(model=cnn)
myTestData.proceed()