# End-to-End Ego Lane Estimation based on Sequential Transfer Learning for Self-Driving Cars

Following the [paper here](https://openaccess.thecvf.com/content_cvpr_2017_workshops/w13/papers/Kim_End-To-End_Ego_Lane_CVPR_2017_paper.pdf), this notebook attempts to use sequential transfer learning to improve ego-lane estimation.

In [3]:
import torch
from torch import nn, optim
from models import SymVGG16

# Optimisation settings shared by the cells below.
learning_rate = 0.05
momentum = 0.95
epochs = 10

# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Network, loss function, and SGD optimizer used for training.
model = SymVGG16()
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    model.parameters(),
    lr=learning_rate,
    momentum=momentum,
)

In [None]:
# data
import loader

# Build both datasets from "data"; the second one is requested with test=True.
training_dataset = loader.Dataset("data")
test_dataset = loader.Dataset("data", test=True)

# Only the training batches are shuffled; the test batches keep dataset order.
training_dataloader = torch.utils.data.DataLoader(
    training_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=2,
)
test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=64,
    shuffle=False,
    num_workers=2,
)

In [None]:
# testing and training
from tqdm.notebook import tqdm

# Metrics collected during training: one loss value and one accuracy value
# are appended each time the periodic evaluation runs.
loss_history, acc_history = [], []

# Evaluation interval: metrics are recorded when the epoch counter is a
# multiple of n.
n = 10

def test():
    """Compute the accuracy of ``model`` over ``test_dataloader``.

    The module-level ``model`` is switched to evaluation mode and is left in
    that mode; a caller that continues training afterwards must call
    ``model.train()`` itself.

    Returns:
        The fraction of label elements that were predicted correctly
        (``correct / total``), or ``0.0`` when the dataloader yields no
        labels at all.
    """
    correct = 0
    total = 0

    model.eval()
    # Evaluation needs no gradients: disabling autograd avoids building a
    # graph for every batch, which only costs memory and time here.
    with torch.no_grad():
        for inputs, labels in test_dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            # The index of the largest score along dim 1 is the prediction.
            _, predicted = torch.max(outputs, 1)
            total += labels.numel()
            correct += (predicted == labels).sum().item()

    # Guard against an empty test set instead of raising ZeroDivisionError.
    if total == 0:
        return 0.0
    return correct / total

# Number of fully completed training epochs. Kept at module level so that a
# later call to train() continues from where an interrupted call stopped.
current_epoch = 0


def train():
    """Train ``model`` until ``current_epoch`` reaches ``epochs``.

    Runs one pass over ``training_dataloader`` per epoch. Whenever the
    (absolute) epoch index is a multiple of ``n``, the last batch loss of
    that epoch is appended to ``loss_history`` and the accuracy returned by
    ``test()`` is appended to ``acc_history``.

    A ``KeyboardInterrupt`` stops training without propagating the
    exception. ``current_epoch`` only counts completed epochs, so calling
    ``train()`` again restarts at the beginning of the interrupted epoch.
    """
    global current_epoch
    batches_per_epoch = len(training_dataloader)

    # Both bars use absolute totals and start at the progress already made,
    # so position and total stay consistent when training is resumed.
    progress_bar = tqdm(
        total=epochs * batches_per_epoch,
        initial=current_epoch * batches_per_epoch,
    )
    epoch_bar = tqdm(total=epochs, initial=current_epoch)

    model.train()
    try:
        while current_epoch < epochs:
            loss = None
            for inputs, labels in training_dataloader:
                inputs = inputs.to(device)
                labels = labels.to(device)

                outputs = model(inputs)
                loss = criterion(outputs, labels.to(torch.long))

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                progress_bar.update(1)
                progress_bar.set_description("loss: %.8f" % loss.item())

            # Periodic evaluation; skipped if the epoch produced no batches
            # (loss would otherwise be undefined).
            if loss is not None and current_epoch % n == 0:
                loss_history.append(loss.item())

                acc = test()
                acc_history.append(acc)
                # test() leaves the model in eval mode; switch back.
                model.train()
                epoch_bar.set_description("accuracy: %.3f" % acc)

            # Advance the shared counter so the loop terminates and a later
            # call resumes from the right epoch.
            current_epoch += 1
            epoch_bar.update(1)
    except KeyboardInterrupt:
        # Deliberate: interrupting the cell pauses training. current_epoch
        # already holds the number of completed epochs, so nothing to save.
        pass
    finally:
        progress_bar.close()
        epoch_bar.close()

In [None]:
# Run the training loop defined above.
train()