In [None]:
import numpy as np
import torch
from torch import nn, optim
from tqdm.autonotebook import tqdm
from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader, random_split, RandomSampler
from torch.utils.tensorboard import SummaryWriter
import argparse  
from sklearn.metrics import jaccard_score
import os

from model_segmentation import *
from data_segmentation import create_dataset

# Report which device training will use. `device` is not defined in this
# notebook; presumably it comes from `from model_segmentation import *` — TODO confirm.
print('running on...', device, sep=' ')

In [None]:
# Per-epoch image-wise accuracies (one entry per epoch), appended by
# train_model and summarised after training.
train_acc_list = []
val_acc_list = []


def _area_ratios(area_pred, area_true):
    """Return the predicted/true smoke-area ratio for each image.

    Conventions: 1 when both areas are zero (correctly empty), 0 when only
    the true area is zero (prediction on an empty label), otherwise
    ``area_pred / area_true``.
    """
    ratios = []
    for pred, true in zip(area_pred, area_true):
        if pred == 0 and true == 0:
            ratios.append(1)
        elif true == 0:
            ratios.append(0)
        else:
            ratios.append(pred / true)
    return ratios


def _batch_metrics(output, y):
    """Compute IoUs, image-wise accuracy and area ratios for one batch.

    Args:
        output: raw model logits. They are indexed as ``[k][0]`` and summed
            over axes (1, 2, 3), so presumably shaped (B, 1, H, W) — TODO confirm.
        y: ground-truth masks, summed over axes (1, 2), so presumably shaped
            (B, H, W) — TODO confirm.

    Returns:
        Tuple ``(ious, accuracy, arearatios)``:
        * ious: Jaccard scores, only for images where both the prediction and
          the label contain at least one positive pixel;
        * accuracy: fraction of images whose "mask is non-empty" decision
          matches the label (no offsets added);
        * arearatios: per-image ratios from ``_area_ratios``.
    """
    # Logit >= 0 is equivalent to sigmoid probability >= 0.5.
    pred = (output.detach().cpu().numpy() >= 0).astype(float)
    truth = y.detach().cpu().numpy()

    ious = []
    for k in range(truth.shape[0]):
        pred_k = pred[k][0]
        # IoU is only recorded when both masks are non-empty, so skip the
        # (discarded) computation otherwise.
        if np.sum(pred_k) != 0 and np.sum(truth[k]) != 0:
            ious.append(jaccard_score(truth[k].flatten(), pred_k.flatten(),
                                      zero_division=1))

    area_pred = np.sum(pred, axis=(1, 2, 3))
    area_true = np.sum(truth, axis=(1, 2))

    # Image-level binary labels: does the mask contain any smoke at all?
    y_bin = (area_true != 0).astype(int)
    pred_bin = (area_pred != 0).astype(int)
    accuracy = accuracy_score(y_bin, pred_bin)

    return ious, accuracy, _area_ratios(area_pred, area_true)


def train_model(model, epochs, opt, loss, batch_size,
                train_datadir='/path/to/images',
                val_datadir='//path/to/train',
                seglabeldir='//path/to/segmentation_labels'):
    """Train ``model`` for ``epochs`` epochs, validating after each epoch.

    After each epoch the function does the following:
    * logs loss, IoU, image-wise accuracy and area-ratio statistics to
      TensorBoard;
    * saves a per-epoch checkpoint;
    * steps the LR scheduler on the mean validation loss;
    * appends the epoch accuracies to ``train_acc_list`` / ``val_acc_list``.

    Args:
        model: segmentation network (presumably one output logit channel).
        epochs: number of training epochs.
        opt: optimizer over ``model``'s parameters.
        loss: loss called as ``loss(logits, target)``, e.g.
            ``nn.BCEWithLogitsLoss``.
        batch_size: mini-batch size for both data loaders.
        train_datadir, val_datadir, seglabeldir: dataset locations passed to
            ``create_dataset``. The defaults are the original placeholders.

    Returns:
        The trained ``model``.

    Note:
        Reads the module-level globals ``device``, ``writer``, ``scheduler``,
        ``args`` and ``save_dir``.
    """
    # NOTE(review): the validation set defaults to a directory named "train";
    # confirm it is really a held-out split.
    data_train = create_dataset(datadir=train_datadir,
                                seglabeldir=seglabeldir, mult=3)
    data_val = create_dataset(datadir=val_datadir,
                              seglabeldir=seglabeldir, mult=3)

    # Every epoch draws a fresh random two-thirds subsample (with replacement).
    # NOTE(review): the validation subsample also changes each epoch, which
    # makes the loss fed to the scheduler noisy.
    train_sampler = RandomSampler(data_train, replacement=True,
                                  num_samples=int(2 * len(data_train) / 3))
    val_sampler = RandomSampler(data_val, replacement=True,
                                num_samples=int(2 * len(data_val) / 3))

    train_dl = DataLoader(data_train, batch_size=batch_size, num_workers=2,
                          pin_memory=True, sampler=train_sampler)
    val_dl = DataLoader(data_val, batch_size=batch_size, num_workers=2,
                        pin_memory=True, sampler=val_sampler)

    for epoch in range(epochs):
        # ---------------- training ----------------
        model.train()
        train_loss_total = 0.0
        train_acc_total = 0.0
        train_ious = []
        train_arearatios = []
        n_train = 0
        progress = tqdm(train_dl, desc="Train Loss: ", total=len(train_dl))
        for batch in progress:
            x = batch['img'].float().to(device)
            y = batch['fpt'].float().to(device)

            output = model(x)
            # The target gets a channel axis, presumably to match a
            # single-channel output.
            loss_batch = loss(output, y.unsqueeze(dim=1))

            ious, acc, ratios = _batch_metrics(output, y)
            train_ious.extend(ious)
            train_acc_total += acc
            train_arearatios.extend(ratios)

            train_loss_total += loss_batch.item()
            n_train += 1
            progress.set_description("Train Loss: {:.4f}".format(
                train_loss_total / n_train))

            opt.zero_grad()
            loss_batch.backward()
            opt.step()

        train_loss = train_loss_total / n_train if n_train else float('nan')
        train_acc = train_acc_total / n_train if n_train else float('nan')
        train_acc_list.append(train_acc)

        writer.add_scalar("training loss", train_loss, epoch)
        writer.add_scalar("training iou", np.average(train_ious), epoch)
        writer.add_scalar("training acc", train_acc, epoch)
        writer.add_scalar('training arearatio mean',
                          np.average(train_arearatios), epoch)
        writer.add_scalar('training arearatio std',
                          np.std(train_arearatios), epoch)
        writer.add_scalar('learning_rate', opt.param_groups[0]['lr'], epoch)

        torch.cuda.empty_cache()

        # ---------------- validation ----------------
        model.eval()
        val_loss_total = 0.0
        val_acc_total = 0.0
        val_ious = []
        val_arearatios = []
        n_val = 0
        progress = tqdm(val_dl, desc="val Loss: ", total=len(val_dl))
        # Evaluation needs no gradients; skipping autograd saves memory/time.
        with torch.no_grad():
            for batch in progress:
                x = batch['img'].float().to(device)
                y = batch['fpt'].float().to(device)

                output = model(x)
                loss_batch = loss(output, y.unsqueeze(dim=1))
                val_loss_total += loss_batch.item()

                ious, acc, ratios = _batch_metrics(output, y)
                val_ious.extend(ious)
                val_acc_total += acc
                val_arearatios.extend(ratios)

                n_val += 1
                progress.set_description("val Loss: {:.4f}".format(
                    val_loss_total / n_val))

        val_loss = val_loss_total / n_val if n_val else float('nan')
        val_acc = val_acc_total / n_val if n_val else float('nan')
        val_acc_list.append(val_acc)

        writer.add_scalar("val loss", val_loss, epoch)
        writer.add_scalar("val iou", np.average(val_ious), epoch)
        writer.add_scalar("val acc", val_acc, epoch)
        writer.add_scalar('val arearatio mean',
                          np.average(val_arearatios), epoch)
        writer.add_scalar('val arearatio std',
                          np.std(val_arearatios), epoch)

        print(("Epoch {:d}: train loss={:.3f}, val loss={:.3f}, "
               "train iou={:.3f}, val iou={:.3f}, "
               "train acc={:.3f}, val acc={:.3f}").format(
                   epoch + 1, train_loss, val_loss,
                   np.average(train_ious), np.average(val_ious),
                   train_acc, val_acc))

        # Per-epoch checkpoint in the working directory.
        torch.save(model.state_dict(),
                   'ep{:0d}_lr{:.0e}_bs{:02d}_mo{:.1f}_{:03d}.model'.format(
                       args.ep, args.lr, args.bs, args.mo, epoch))

        writer.flush()
        scheduler.step(val_loss)
        torch.cuda.empty_cache()

        # Overwrite save_dir/segmentation.model with the latest weights each
        # epoch; after the final epoch it holds the trained model.
        save_path = os.path.join(save_dir, 'segmentation.model')
        torch.save(model.state_dict(), save_path)
        print(f"Trained model saved at {save_path}")

    return model


# setup argument parser
parser = argparse.ArgumentParser()
parser.add_argument('-ep', type=int, default=5, help='Number of epochs')
parser.add_argument('-bs', type=int, nargs='?', default=30, help='Batch size')
parser.add_argument('-lr', type=float, nargs='?', default=0.3, help='Learning rate')
parser.add_argument('-mo', type=float, nargs='?', default=0.7, help='Momentum')
# Parse an empty list so the defaults are used inside a notebook kernel
# instead of the kernel's own command-line arguments.
args = parser.parse_args(args=[])

save_dir = '/content/drive/MyDrive/smokeplumes_ccps/segmentation/'
# torch.save does not create missing directories; create it up front rather
# than failing after the first epoch of training.
os.makedirs(save_dir, exist_ok=True)

# setup tensorboard writer
writer = SummaryWriter('runs/'+"ep{:0d}_lr{:.0e}_bs{:03d}_mo{:.1f}/".format(
    args.ep, args.lr, args.bs, args.mo))

# initialize loss function (applies the sigmoid internally, so the model
# outputs raw logits)
loss = nn.BCEWithLogitsLoss()

# initialize optimizer
# NOTE(review): `model` is not defined in this notebook; presumably it comes
# from `from model_segmentation import *` — TODO confirm.
opt = optim.SGD(model.parameters(), lr=args.lr, momentum=args.mo)

# initialize scheduler: halve the LR when the validation loss plateaus
scheduler = optim.lr_scheduler.ReduceLROnPlateau(opt, 'min',
                                                 factor=0.5, threshold=1e-4,
                                                 min_lr=1e-6)

# run training; always close the writer so buffered events are flushed even
# if training raises
try:
    train_model(model, args.ep, opt, loss, args.bs)
finally:
    writer.close()

# Report the best per-epoch accuracies (guarded for a zero-epoch run).
print("\n")
if train_acc_list:
    print("Train accuracy: ", round(max(train_acc_list), 2))
if val_acc_list:
    print("Validation accuracy: ", round(max(val_acc_list), 2))