# Create Folders

In [None]:
import os

# Create the output folders used by the cells below: ./log (index, seed and
# loss logs) and ./model (checkpoints), plus ./result. exist_ok=True makes the
# cell safe to re-run and, unlike `!mkdir`, it does not require IPython.
for folder in ("log", "model", "result"):
    os.makedirs(folder, exist_ok=True)

# Import

In [None]:
import gc
import time
import torch
import datetime
from utils import *
import torch.nn as nn
from Model import Unet
import torch.optim as optim
from Dataset import NavigationDataset
from torch.utils.data import DataLoader
from Loss import TverskyLoss, WeightedCrossEntropyLoss

# GLOBAL VARS

In [None]:
# Image dimensions in pixels (presumably the input size used by the dataset /
# model elsewhere — not referenced in the visible cells; confirm).
IMG_WIDTH = 428
IMG_HEIGHT = 240
# Backward-compatible alias for the original, misspelled constant name.
IMG_HIGHT = IMG_HEIGHT

# Release leftovers from a previous run of the notebook. Collect unreachable
# Python objects first so any tensors they own go back to PyTorch's caching
# allocator, then release the now-unused cached blocks. Calling
# empty_cache() first would leave the memory freed by gc.collect() cached.
gc.collect()
torch.cuda.empty_cache()

# Train function

In [None]:
def _compute_batch_losses(model, images, ground_truths, loss_bce, loss_tversky):
    """Run ``model`` on one batch and return its deep-supervision losses.

    The model must return exactly five outputs (the main prediction followed
    by the auxiliary outputs d2, d3, d4, b); each is scored against the same
    ``ground_truths`` with both loss functions.

    Returns:
        (total, bce_main, tversky_main): ``total`` is the sum of all ten loss
        terms (used for backprop); ``bce_main`` / ``tversky_main`` are the terms
        for the main output only (used for logging).
    """
    outs, d2, d3, d4, b = model(images)
    heads = (outs, d2, d3, d4, b)
    bce_terms = [loss_bce(head, ground_truths) for head in heads]
    tversky_terms = [loss_tversky(head, ground_truths) for head in heads]
    total = sum(bce_terms) + sum(tversky_terms)
    # Reuse the main-output terms instead of recomputing them (the original
    # loop evaluated both losses on `outs` a second time just for logging).
    return total, bce_terms[0], tversky_terms[0]


def _run_epoch(model, dataloader, device, loss_bce, loss_tversky, optimizer=None):
    """Make one pass over ``dataloader`` and return average losses per batch.

    If ``optimizer`` is given, each batch is back-propagated and an optimizer
    step is taken (training). Otherwise only the losses are evaluated; the
    caller is responsible for ``model.eval()`` / ``torch.no_grad()``.

    Returns:
        (avg_total, avg_bce_main, avg_tversky_main), each the mean over batches.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    total_sum = 0.0
    bce_sum = 0.0
    tversky_sum = 0.0
    steps = 0

    for images, ground_truths in dataloader:
        images, ground_truths = images.to(device), ground_truths.to(device)

        if optimizer is not None:
            optimizer.zero_grad()

        total, bce_main, tversky_main = _compute_batch_losses(
            model, images, ground_truths, loss_bce, loss_tversky)

        if optimizer is not None:
            total.backward()
            optimizer.step()

        total_sum += total.item()
        bce_sum += bce_main.item()
        tversky_sum += tversky_main.item()
        steps += 1

    if steps == 0:
        raise ValueError("dataloader yielded no batches; cannot average losses")

    return total_sum / steps, bce_sum / steps, tversky_sum / steps


def train(model, inital_learning_rate, max_epoch, train_dataloader, vaild_dataloader):
    """Train ``model`` for ``max_epoch`` epochs with deep-supervision losses.

    For each epoch:
    - set the learning rate via ``lrfn``;
    - run one training pass and one gradient-free validation pass;
    - print the losses;
    - save ``./model/best_model.pth`` when the validation loss improves;
    - log the history with ``record_log``;
    - always save ``./model/newest_model.pth``.

    Args:
        model: network whose forward pass returns five outputs, each compared
            against the same ground truth.
        inital_learning_rate: starting learning rate for Adamax.
        max_epoch: number of epochs to run (epoch indices 0 .. max_epoch - 1).
        train_dataloader: yields ``(images, ground_truths)`` training batches.
        vaild_dataloader: yields ``(images, ground_truths)`` validation batches.
    """
    device = get_device()
    best_val_loss = float('inf')
    train_loss = []
    vaild_loss = []
    learning_rate = []

    optimizer = optim.Adamax(model.parameters(), inital_learning_rate, weight_decay=0.02)

    loss_tversky = TverskyLoss()
    # Positive-class weight 0.65 / 0.35; presumably reflects the label
    # balance of the dataset — TODO confirm.
    loss_bce = WeightedCrossEntropyLoss(pos_weight=0.65 / 0.35)

    # Exactly max_epoch epochs, so the "[epoch/max_epoch]" counter printed
    # below ends at max_epoch instead of overshooting it by one.
    for epoch in range(max_epoch):
        # lrfn (from utils) presumably sets this epoch's learning rate; its
        # return value is used as the optimizer from here on.
        optimizer = lrfn(epoch, optimizer)
        learning_rate.append(optimizer.param_groups[0]['lr'])

        model.train()
        train_avg_loss, train_loss_bce_record_avg, train_loss_tversky_record_avg = _run_epoch(
            model, train_dataloader, device, loss_bce, loss_tversky, optimizer)
        train_loss.append(train_avg_loss)

        model.eval()
        with torch.no_grad():
            val_avg_loss, val_loss_bce_record_avg, val_loss_tversky_record_avg = _run_epoch(
                model, vaild_dataloader, device, loss_bce, loss_tversky)
        vaild_loss.append(val_avg_loss)

        print(f"[{epoch+1:3d}/{max_epoch:3d}] learning rate: {optimizer.param_groups[0]['lr']}")
        print("training:")
        print(f"WBCE loss: {train_loss_bce_record_avg:.8f}, Tversky loss: {train_loss_tversky_record_avg:.8f}, total loss: {train_avg_loss:.8f}")
        print("validating:")
        print(f"WBCE loss: {val_loss_bce_record_avg:.8f}, Tversky loss: {val_loss_tversky_record_avg:.8f}, total loss: {val_avg_loss:.8f}")
        print("====================================================================")

        # Keep the checkpoint with the lowest total validation loss so far.
        if val_avg_loss < best_val_loss:
            best_val_loss = val_avg_loss
            print("Saving better model...")
            torch.save(model.state_dict(), "./model/best_model.pth")
            print("Saving better model complete.")

        record_log(train_loss, vaild_loss, epoch, train_dataloader.batch_size, best_val_loss, learning_rate)

        # Always keep the latest weights as well, e.g. for resuming.
        torch.save(model.state_dict(), "./model/newest_model.pth")

# Main

In [None]:
def _write_indices(path, indices):
    """Write ``indices`` to ``path`` as comma-terminated values, e.g. "3,7,12,"."""
    with open(path, "w") as f:
        f.write("".join(f"{index}," for index in indices))


def main():
    """Split the dataset, log the split, build the model and train it.

    The split indices and the global torch seed are written to ./log so that
    a run can be reproduced; training itself is delegated to ``train``.
    """
    # hyper parameter
    batch_size = 50
    learning_rate = 0.00001
    max_epoch = 80

    # dataset
    print("Loading training dataset...")
    train_set = NavigationDataset(mode="TRAIN", dataset_path="./Dataset/")

    # 10% of the training set is held out as the validation set.
    valid_set_percent = 0.1
    vaild_set_length = int(len(train_set) * valid_set_percent)

    train_set, valid_set = torch.utils.data.random_split(
        dataset=train_set,
        lengths=[len(train_set) - vaild_set_length, vaild_set_length],
        generator=torch.Generator().manual_seed(int(time.time())),
    )

    train_dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    # Validation needs no shuffling. A fixed order also keeps the epoch's
    # validation loss (an average of per-batch losses) from depending on which
    # samples land in the smaller final batch, which affects best-model picks.
    valid_dataloader = DataLoader(valid_set, batch_size=batch_size, shuffle=False)
    print("Loading complete.")

    # Record the indices in each subset; together they fully reproduce the
    # random split, whose time-based seed is not logged.
    _write_indices("./log/train_set_log.txt", train_dataloader.dataset.indices)
    _write_indices("./log/valid_set_log.txt", valid_dataloader.dataset.indices)

    # Initial seed of torch's default (global) RNG — not the seed of the split
    # generator above.
    with open("./log/seed_log.txt", "w") as f:
        f.write(f"{torch.random.initial_seed()},")

    # model: Unet(3, 1) — presumably 3 input channels, 1 output channel; confirm.
    model = Unet(3, 1)
    model.to(get_device())
    # Replicate across all visible GPUs (an empty list when none are available).
    model = nn.DataParallel(model, device_ids=list(range(torch.cuda.device_count())))

    # train
    print("Start training...")
    start_time = datetime.datetime.now()
    train(model, learning_rate, max_epoch, train_dataloader, valid_dataloader)
    end_time = datetime.datetime.now()
    print(f"Training complete.\nCost {str(end_time - start_time)} to training.")
    torch.cuda.empty_cache()

if __name__ == "__main__":
    main()

# Plot result

In [None]:
import matplotlib.pyplot as plt


def _read_loss_log(path):
    """Return the comma-separated float values stored in ``path``.

    The original reader dropped the last field, which suggests the log is
    written with a trailing comma. Blank fields are skipped instead, so both a
    trailing comma (or trailing newline) and a missing one parse correctly
    without losing a real value.
    """
    with open(path, "r") as f:
        return [float(field) for field in f.read().split(",") if field.strip()]


train_loss = _read_loss_log("./log/train_loss_log.txt")
vaild_loss = _read_loss_log("./log/vaild_loss_log.txt")

fig, axes = plt.subplots(1, 1, figsize=(10, 5))

axes.grid(True)
axes.set_title("Loss per epoch")
axes.set_xlabel("epoch")
axes.set_ylabel("Loss")
# Epochs are numbered from 1 on the x-axis.
axes.plot(range(1, len(train_loss) + 1), train_loss, 'b')
axes.plot(range(1, len(vaild_loss) + 1), vaild_loss, 'r')
axes.legend(["train", "valid"])

plt.show()