<a href="https://colab.research.google.com/github/Turing-04/road_classifier_satellite/blob/main/train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [16]:
# No pip install is needed for `data`: the `data` module imported further down
# is the repository's own data.py, which the clone below makes available.
# Clone the project so its Python modules and the training/ folder exist locally.
!git clone "https://github.com/Turing-04/road_classifier_satellite.git"
 

Cloning into 'road_classifier_satellite'...
remote: Enumerating objects: 290, done.[K
remote: Counting objects: 100% (290/290), done.[K
remote: Compressing objects: 100% (271/271), done.[K
remote: Total 290 (delta 39), reused 260 (delta 18), pack-reused 0[K
Receiving objects: 100% (290/290), 35.54 MiB | 33.82 MiB/s, done.
Resolving deltas: 100% (39/39), done.


In [25]:
# Work from inside the cloned repository: the cells below import its modules
# (data, model, loss, utils) and use paths relative to it (training/..., weights/...).
# NOTE(review): the path is relative, so this presumably only succeeds when run
# once from the directory that contains the clone - confirm before re-running.
%cd road_classifier_satellite
%ls


/content/sample_data/road_classifier_satellite
augmentation.py         mask_to_submission.py        tf.py
create_more_data.ipynb  model.py                     [0m[01;34mtraining[0m/
data.py                 README.md                    train.ipynb
dummy_submission.csv    sample_submission.csv        train.py
environment.yml         segment_aerial_images.ipynb  utils.py
helpers.py              submission_to_mask.py
loss.py                 test.py


In [26]:
import os
import time

import torch
from torch.utils.data import DataLoader
import torch.nn as nn

from glob import glob
from data import DriveDataset
from model import build_unet
from loss import DiceLoss, DiceBCELoss
from utils import seeding, create_dir, epoch_time



In [27]:

def train(model, loader, optimizer, loss_fn, device):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    The model is switched to training mode. Every batch is moved to ``device``
    as float32, then used for a zero_grad / forward / backward / step cycle.

    Args:
        model: Module being optimised; called as ``model(inputs)``.
        loader: Sized iterable yielding ``(inputs, targets)`` pairs.
        optimizer: Optimizer exposing ``zero_grad()`` and ``step()``.
        loss_fn: Callable invoked as ``loss_fn(predictions, targets)``.
        device: Device the batches are moved to.

    Returns:
        Sum of the per-batch loss values divided by ``len(loader)``.
    """
    model.train()

    running_total = 0.0
    for inputs, targets in loader:
        inputs = inputs.to(device, dtype=torch.float32)
        targets = targets.to(device, dtype=torch.float32)

        # Clear gradients left over from the previous batch before the new backward pass.
        optimizer.zero_grad()
        batch_loss = loss_fn(model(inputs), targets)
        batch_loss.backward()
        optimizer.step()

        running_total += batch_loss.item()

    # Unweighted mean over batches (a smaller last batch counts as much as a full one).
    return running_total / len(loader)

def evaluate(model, loader, loss_fn, device):
    """Compute the mean batch loss of ``model`` over ``loader`` without training.

    The model is switched to evaluation mode and the forward passes run under
    ``torch.no_grad()``; no optimizer is involved and no parameter is updated.

    Args:
        model: Module being evaluated; called as ``model(inputs)``.
        loader: Sized iterable yielding ``(inputs, targets)`` pairs.
        loss_fn: Callable invoked as ``loss_fn(predictions, targets)``.
        device: Device the batches are moved to (as float32).

    Returns:
        Sum of the per-batch loss values divided by ``len(loader)``.
    """
    model.eval()

    running_total = 0.0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device, dtype=torch.float32)
            targets = targets.to(device, dtype=torch.float32)
            running_total += loss_fn(model(inputs), targets).item()

    # Unweighted mean over batches, same convention as the training loop.
    return running_total / len(loader)




In [None]:


""" Seeding """
seeding(42)

""" Directories """
create_dir("weights")

""" Load dataset """
# Both lists are sorted so that image i is paired with mask i
# (presumably the two folders use matching file names - verify if files are added).
train_x = sorted(glob("training/images/expanded/*"))
train_y = sorted(glob("training/groundtruth/expanded/*"))

# Fail fast with a readable message instead of an opaque error later in training
# (an empty list usually means the notebook is not running from the repository root).
if not train_x:
    raise FileNotFoundError(
        "No training images found in training/images/expanded/ - check the working directory."
    )
if len(train_x) != len(train_y):
    raise ValueError(
        f"Image/mask count mismatch: {len(train_x)} images vs {len(train_y)} masks."
    )

# TODO: no validation split is used yet. Once one exists, build a second
# DriveDataset/DataLoader (shuffle=False), call evaluate() every epoch, and
# drive both the scheduler and the checkpointing from the validation loss.

data_str = f"Dataset Size:\nTrain: {len(train_x)}"
print(data_str)

""" Hyperparameters """
H = 400
W = 400
size = (H, W)
batch_size = 2
num_epochs = 50
lr = 1e-4
checkpoint_path = "weights/checkpoint.pth"
save_every = 5  # write a checkpoint every `save_every` epochs (and always after the last one)

""" Dataset and loader """
train_dataset = DriveDataset(train_x, train_y)

train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = build_unet()
model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# `verbose=True` is deprecated in recent PyTorch releases; learning-rate changes
# are reported explicitly in the training loop instead.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5)
loss_fn = DiceBCELoss()

if os.path.exists(checkpoint_path):
    # map_location lets a checkpoint written on a GPU runtime be resumed on a CPU-only one.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    data_str = f"Checkpoint loaded: {checkpoint_path}"
    print(data_str)

""" Training the model """
losses = []
for epoch in range(num_epochs):
    start_time = time.time()

    train_loss = train(model, train_loader, optimizer, loss_fn, device)

    # The plateau scheduler only acts when it is stepped with the monitored metric.
    # Without a validation set, the training loss is the only metric available.
    lr_before = optimizer.param_groups[0]["lr"]
    scheduler.step(train_loss)
    lr_after = optimizer.param_groups[0]["lr"]

    # Periodic checkpoint, plus one after the final epoch so the last epochs are not lost.
    is_last_epoch = epoch == num_epochs - 1
    if epoch % save_every == 0 or is_last_epoch:
        torch.save(model.state_dict(), checkpoint_path)

    end_time = time.time()
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    losses.append(train_loss)

    data_str = f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s\n'
    data_str += f'\tTrain Loss: {train_loss:.3f}\n'
    if lr_after != lr_before:
        data_str += f'\tLearning rate reduced: {lr_before:.2e} -> {lr_after:.2e}\n'
    print(data_str)

#should be recorded in analysis/current_perf.md
print(losses)

Dataset Size:
Train: 100 - Valid: 100

