<a href="https://colab.research.google.com/github/Turing-04/road_classifier_satellite/blob/main/train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#!pip install data
!git clone "https://github.com/Turing-04/road_classifier_satellite.git"
 

Cloning into 'road_classifier_satellite'...
remote: Enumerating objects: 11003, done.[K
remote: Counting objects: 100% (3209/3209), done.[K
remote: Compressing objects: 100% (3208/3208), done.[K
remote: Total 11003 (delta 3), reused 3195 (delta 1), pack-reused 7794[K
Receiving objects: 100% (11003/11003), 1.89 GiB | 54.50 MiB/s, done.
Resolving deltas: 100% (209/209), done.
Checking out files: 100% (5086/5086), done.


In [2]:
%cd road_classifier_satellite
%ls


/content/road_classifier_satellite
[0m[01;34manalysis[0m/             mask_to_submission.py  train.py
augmentation.py       [01;34mmodels[0m/                train_unet.ipynb
data_augmentation.py  README.md              train_xception.ipynb
data.py               [01;34mtest[0m/                  [01;34mutilitary[0m/
dataset.py            test.py                utils.py
helpers.py            train_cnn.ipynb
loss.py               [01;34mtraining[0m/


In [3]:
import os
import time
from glob import glob

import torch
from torch.utils.data import DataLoader
import torch.nn as nn

from dataset import DriveDataset
from loss import DiceLoss, DiceBCELoss
import sys
from utils import seeding, create_dir, epoch_time

from models import model_unet
from models import model_resnet
from models import model_cnn2
from models import model_cnn4
from models import model_cnn8
from models import model_cnn16

In [4]:
def train(model, loader, optimizer, loss_fn, device):
    epoch_loss = 0.0

    model.train()
    for x, y in loader:
        x = x.to(device, dtype=torch.float32)
        y = y.to(device, dtype=torch.float32)

        optimizer.zero_grad()
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    epoch_loss = epoch_loss/len(loader)
    return epoch_loss

def evaluate(model, loader, loss_fn, device):
    epoch_loss = 0.0

    model.eval()
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device, dtype=torch.float32)
            y = y.to(device, dtype=torch.float32)

            y_pred = model(x)
            loss = loss_fn(y_pred, y)
            epoch_loss += loss.item()

        epoch_loss = epoch_loss/len(loader)
    return epoch_loss

In [6]:
""" Read command line arguments """
model_name = sys.argv[1]

""" Setup """
seeding(42)
create_dir("weights")

""" Load dataset """
train_x = sorted(glob("training/images/training/*"))
train_y = sorted(glob("training/groundtruth/training/*"))

valid_x = sorted(glob("training/images/validation/*"))
valid_y = sorted(glob("training/groundtruth/validation/*"))

data_str = f"Dataset Size:\nTrain: {len(train_x)} - Valid: {len(valid_x)}\n"
print(data_str)

""" Hyperparameters """
size = (400, 400)
batch_size = 2
num_epochs = 50
lr = 1e-4

""" Create dataloader """
train_dataset = DriveDataset(train_x, train_y)
valid_dataset = DriveDataset(valid_x, valid_y)

train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2
)

valid_loader = DataLoader(
    dataset=valid_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2
)

""" Load the model """
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = model_unet.build_unet()

model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, verbose=True)
loss_fn = DiceBCELoss()

checkpoint_path = "weights/checkpoint_" + model_name + ".pth"
if os.path.exists(checkpoint_path):
    model.load_state_dict(torch.load(checkpoint_path))
    print(f"Checkpoint loaded: {checkpoint_path}")

""" Training the model """
best_valid_loss = float("inf")

for epoch in range(num_epochs):
    start_time = time.time()

    train_loss = train(model, train_loader, optimizer, loss_fn, device)
    valid_loss = evaluate(model, valid_loader, loss_fn, device)

    """ Saving the model """
    if valid_loss < best_valid_loss:
        print(f"Valid loss improved from {best_valid_loss:2.4f} to {valid_loss:2.4f}. Saving checkpoint: {checkpoint_path}")

        best_valid_loss = valid_loss
        torch.save(model.state_dict(), checkpoint_path)

    end_time = time.time()
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    data_str = f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s\n'
    data_str += f'\tTrain Loss: {train_loss:.3f}\n'
    data_str += f'\t Val. Loss: {valid_loss:.3f}\n'
    print(data_str)

Dataset Size:
Train: 2400 - Valid: 100

Checkpoint loaded: weights/checkpoint_-f.pth
Valid loss improved from inf to 0.1495. Saving checkpoint: weights/checkpoint_-f.pth
Epoch: 01 | Epoch Time: 0m 57s
	Train Loss: 0.257
	 Val. Loss: 0.150



KeyboardInterrupt: ignored