In [None]:
########### define the training and testing function ###########
import os

import numpy as np
import torch
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader
# Training configuration. `model` must already be defined by an earlier cell.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
epochs = 10  # NOTE: reassigned in the training cell below before the loop runs
lr = 3e-4
# Optimise only trainable parameters so frozen layers (requires_grad=False) are skipped.
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)

batch_size = 24  # not used in this cell; presumably consumed by DataLoader setup elsewhere
#%% define training and validation functions
# IoU metric plus the per-epoch train/validate helpers used by the training cell below
import pandas as pd
def IoU(pr, gt, th=0.5, eps=1e-7):
    """Mean intersection-over-union between predicted logits and target masks.

    Args:
        pr: raw model outputs (logits). A sigmoid is applied before thresholding.
        gt: target masks, binarised with the same threshold ``th``.
        th: binarisation threshold for both the sigmoid output and the target.
        eps: smoothing term added to numerator and denominator, so two empty
            masks score 1.0 instead of dividing by zero.

    Returns:
        Python float: IoU computed over the last two (spatial) axes, then
        averaged over every remaining axis.
    """
    spatial = (-2, -1)
    pred_mask = torch.sigmoid(pr).gt(th)
    true_mask = gt.gt(th)

    overlap = torch.logical_and(true_mask, pred_mask).sum(dim=spatial)
    # |A ∪ B| = |A| + |B| - |A ∩ B|, with eps keeping the denominator positive.
    combined = true_mask.sum(dim=spatial) + pred_mask.sum(dim=spatial) - overlap + eps

    per_sample = (overlap + eps) / combined
    return per_sample.mean().item()


from tqdm import tqdm
from collections import OrderedDict

##################### training function ##########
def train(dataloader, model, criterion, optimizer, epoch, scheduler=None, device=None):
    """Run one training epoch and return its mean loss and IoU.

    Args:
        dataloader: mapping whose ``'train'`` entry iterates over ``(imgs, masks)`` batches.
        model: network to optimise; moved to ``device`` and switched to train mode.
        criterion: loss function called as ``criterion(logits, masks)``.
        optimizer: optimiser; gradients are zeroed and a step is taken once per batch.
        epoch: index of the current epoch; used only to label the progress bar.
        scheduler: accepted for interface compatibility but not used here.
            NOTE(review): it is never stepped -- step it in the caller if needed.
        device: device to train on. Defaults to CUDA when available, otherwise CPU,
            so the function also runs on machines without a GPU.

    Returns:
        OrderedDict with ``'loss'`` (mean batch loss) and ``'iou'`` (mean batch IoU).
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.train()

    train_loss, train_iou = [], []
    bar = tqdm(dataloader['train'], desc=f"train epoch {epoch}")
    for imgs, masks in bar:
        imgs, masks = imgs.to(device), masks.to(device)
        optimizer.zero_grad()
        y_hat = model(imgs)
        loss = criterion(y_hat, masks)
        loss.backward()
        optimizer.step()
        train_loss.append(loss.item())
        # The metric needs no gradients; detaching keeps IoU's sigmoid out of the autograd graph.
        train_iou.append(IoU(y_hat.detach(), masks))

    log = OrderedDict([('loss', np.mean(train_loss)),
                       ('iou', np.mean(train_iou)),
                       ])
    return log

def validate(dataloader, model, criterion, device=None):
    """Evaluate the model on ``dataloader['val']`` without tracking gradients.

    Args:
        dataloader: mapping whose ``'val'`` entry iterates over ``(imgs, masks)`` batches.
        model: network to evaluate; moved to ``device`` and switched to eval mode.
        criterion: loss function called as ``criterion(logits, masks)``.
        device: device to evaluate on. Defaults to CUDA when available, otherwise CPU,
            so the function also runs on machines without a GPU.

    Returns:
        OrderedDict with ``'loss'`` (mean batch loss) and ``'iou'`` (mean batch IoU).
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.eval()

    test_loss, test_iou = [], []
    bar = tqdm(dataloader['val'])
    with torch.no_grad():
        for imgs, masks in bar:
            imgs, masks = imgs.to(device), masks.to(device)
            y_hat = model(imgs)
            loss = criterion(y_hat, masks)
            test_loss.append(loss.item())
            test_iou.append(IoU(y_hat, masks))
            bar.set_description(f"test_loss {np.mean(test_loss):.5f} test_iou {np.mean(test_iou):.5f}")

    log = OrderedDict([('loss', np.mean(test_loss)),
                       ('iou', np.mean(test_iou)),
                       ])
    return log

In [None]:
# Train with early stopping on validation IoU, logging every epoch to CSV and
# checkpointing the best model. Relies on `model`, `optimizer` and `dataloader`
# from earlier cells.
criterion = torch.nn.BCEWithLogitsLoss()
LOG_COLUMNS = ['epoch', 'lr', 'loss', 'iou', 'val_loss', 'val_iou']
log = pd.DataFrame(index=[], columns=LOG_COLUMNS)
early_stop = 20    # epochs without val IoU improvement before stopping (None disables)
epochs = 10000     # upper bound; early stopping is expected to end training first
best_iou = 0
name = 'DensUnet'  # outputs go to models/<name>/

save_dir = os.path.join('models', name)
# Create the output folder up front; to_csv/torch.save fail if it does not exist.
os.makedirs(save_dir, exist_ok=True)

# DataFrame.append was removed in pandas 2.0, so collect plain records and
# rebuild the frame each epoch instead.
history = []
trigger = 0  # epochs since the last improvement in val IoU
for epoch in range(epochs):
    print('Epoch [%d/%d]' % (epoch, epochs))
    train_log = train(dataloader, model, criterion, optimizer, epoch)
    val_log = validate(dataloader, model, criterion)
    print('loss %.4f - iou %.4f - val_loss %.4f - val_iou %.4f'
          % (train_log['loss'], train_log['iou'], val_log['loss'], val_log['iou']))

    history.append({
        'epoch': epoch,
        # Log the optimizer's actual LR so the CSV stays correct if it is ever scheduled.
        'lr': optimizer.param_groups[0]['lr'],
        'loss': train_log['loss'],
        'iou': train_log['iou'],
        'val_loss': val_log['loss'],
        'val_iou': val_log['iou'],
    })
    log = pd.DataFrame(history, columns=LOG_COLUMNS)
    log.to_csv(os.path.join(save_dir, 'log.csv'), index=False)

    trigger += 1
    if val_log['iou'] > best_iou:
        torch.save(model.state_dict(), os.path.join(save_dir, 'model.pth'))
        best_iou = val_log['iou']
        print("=> saved best model")
        trigger = 0

    if early_stop is not None and trigger >= early_stop:
        print("=> early stopping")
        break

    torch.cuda.empty_cache()
# NOTE(review): `model` now holds the last epoch's weights; the best ones are in model.pth.
print("done training")