In [None]:
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
import numpy as np
import pandas as pd
import os, json, time, cv2, shutil, argparse

from torch.autograd import Variable
from datetime import datetime
from collections import OrderedDict
from utils.dataloader import get_loader, test_dataset
from utils.utils import clip_gradient, adjust_lr, AvgMeter
from CaraNet import caranet
import torchinfo
import subprocess

#### Helpers

In [None]:
def get_vram_usage(gpu_index=0):
    """Return the GPU memory currently in use, as reported by ``nvidia-smi``.

    ``nvidia-smi --query-gpu=memory.used`` prints one line per GPU; the value
    on row ``gpu_index`` is returned.

    Args:
        gpu_index: row of nvidia-smi's output to read (nvidia-smi's own
            enumeration order — may differ from CUDA's ordering, confirm on
            multi-GPU hosts). Defaults to the first GPU.

    Returns:
        int: memory used, in the units nvidia-smi reports with ``nounits``
        (normally MiB), or None if the query failed for any reason (tool
        missing, non-zero exit, unparsable output, index out of range).
    """
    cmd = ['nvidia-smi', '--query-gpu=memory.used', '--format=csv,noheader,nounits']
    try:
        result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
        if result.returncode != 0:
            print("Error:", result.stderr, flush=True)
            return None
        # One line per GPU; parsing the whole output with int() would fail on
        # multi-GPU machines.
        per_gpu = result.stdout.strip().splitlines()
        return int(per_gpu[gpu_index].strip())
    except Exception as e:
        # Best-effort monitoring: report the problem but never break training.
        print("An error occurred:", e, flush=True)
    return None

# Calculate the loss
def structure_loss(pred, mask):
    weit = 1 + 5*torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask)
    wbce = F.binary_cross_entropy_with_logits(pred, mask, reduce='none')
    wbce = (weit*wbce).sum(dim=(2, 3)) / weit.sum(dim=(2, 3))

    pred = torch.sigmoid(pred)
    inter = ((pred * mask)*weit).sum(dim=(2, 3))
    union = ((pred + mask)*weit).sum(dim=(2, 3))
    wiou = 1 - (inter + 1)/(union - inter+1)
    
    return (wbce + wiou).mean()

def evaluate(model, data_path, testsize=352):
    """Compute the mean Dice score of ``model`` over an image/mask dataset.

    Images are read from ``{data_path}/images/`` and masks from
    ``{data_path}/masks/`` via ``test_dataset``. For each image only the first
    model output (``res5``) is scored. It is upsampled to the mask resolution,
    passed through a sigmoid and min-max normalised. The mask is divided by its
    maximum.

    Args:
        model: network returning four maps; only the first is scored.
        data_path: dataset root containing ``images/`` and ``masks/``.
        testsize: resize size passed to ``test_dataset`` (default 352, the
            previously hard-coded value).

    Returns:
        float: mean per-image Dice score.

    Raises:
        ValueError: if the dataset contains no images.

    Note: the model is left in eval mode.
    """
    model.eval()
    device = next(model.parameters()).device
    image_root = '{}/images/'.format(data_path)
    gt_root = '{}/masks/'.format(data_path)
    test_loader = test_dataset(image_root, gt_root, testsize)
    print('[test_size]', test_loader.size)
    if test_loader.size == 0:
        # Fail with a clear message instead of a ZeroDivisionError below.
        raise ValueError(f'No test images found under {data_path!r}')

    smooth = 1
    dice_total = 0.0
    # Inference only: skip building the autograd graph to save memory.
    with torch.no_grad():
        for _ in range(test_loader.size):
            image, gt, _name = test_loader.load_data()
            gt = np.asarray(gt, np.float32)
            gt /= (gt.max() + 1e-8)

            res5, _res3, _res2, _res1 = model(image.to(device))
            res = F.interpolate(res5, size=gt.shape, mode='bilinear', align_corners=False)
            res = res.sigmoid().cpu().numpy().squeeze()
            res = (res - res.min()) / (res.max() - res.min() + 1e-8)

            intersection = (res.reshape(-1) * gt.reshape(-1)).sum()
            dice = (2 * intersection + smooth) / (res.sum() + gt.sum() + smooth)
            # Each per-image score is rounded to 4 decimals before averaging,
            # as the original implementation did, so reported values stay
            # comparable with earlier runs.
            dice_total += float('{:.4f}'.format(dice))

    return dice_total / test_loader.size

# Trains the model for one epoch
def train(train_loader, model, optimizer, epoch, evaluate_path, exp_config, total_step):
    model.train()
    # ---- multi-scale training ----
    size_rates = [0.75, 1, 1.25]
    loss_record1, loss_record2, loss_record3, loss_record5 = AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()
    epoch_loss_record = AvgMeter()
    train_mdice_record = AvgMeter()
    
    for i, pack in enumerate(train_loader, start=1):
        for rate in size_rates:
            optimizer.zero_grad()
            # ---- data prepare ----
            images, gts = pack
            images = Variable(images).cuda()
            gts = Variable(gts).cuda()
            # ---- rescale ----
            trainsize = int(round(exp_config['trainsize']*rate/32)*32)
            if rate != 1:
                images = F.upsample(images, size=(trainsize, trainsize), mode='bilinear', align_corners=True)
                gts = F.upsample(gts, size=(trainsize, trainsize), mode='bilinear', align_corners=True)
            # ---- forward ----
            lateral_map_5, lateral_map_3, lateral_map_2, lateral_map_1 = model(images)
            # ---- loss function ----
            loss5 = structure_loss(lateral_map_5, gts)
            loss3 = structure_loss(lateral_map_3, gts)
            loss2 = structure_loss(lateral_map_2, gts)
            loss1 = structure_loss(lateral_map_1, gts)
            
            # Structure Loss
            loss = loss5 + loss3 + loss2 + loss1
            
            ## Dice Metric
            batch_dice = []

            # For each ground truth mask in this batch
            for gt_id in range(0, len(gts)):
                gt = gts[gt_id].cpu()

                gt = np.asarray(gt, np.float32)
                gt /= (gt.max() + 1e-8)

                res = lateral_map_5[gt_id]
                res = res.sigmoid().data.cpu().numpy().squeeze()
                res = (res - res.min()) / (res.max() - res.min() + 1e-8)
                
                input = res
                target = np.array(gt)
                N = gt.shape
                smooth = 1
                input_flat = np.reshape(input,(-1))
                target_flat = np.reshape(target,(-1))
        
                intersection = (input_flat*target_flat)
                
                # Calculate the image dice metric and append it to the batch metrics
                img_dice =  (2 * intersection.sum() + smooth) / (input.sum() + target.sum() + smooth)
                batch_dice.append(img_dice)

            # ---- backward ----
            loss.backward()
            clip_gradient(optimizer, exp_config['clip'])
            optimizer.step()
            # ---- recording loss ----
            if rate == 1:
                loss_record5.update(loss5.data, exp_config['batchsize'])
                loss_record3.update(loss3.data, exp_config['batchsize'])
                loss_record2.update(loss2.data, exp_config['batchsize'])
                loss_record1.update(loss1.data, exp_config['batchsize'])
                epoch_loss_record.update(loss.detach().cpu(), exp_config['batchsize'])
                train_mdice_record.update(torch.tensor(batch_dice).mean(), exp_config['batchsize'])
        # ---- train visualization ----
        if i % 20 == 0 or i == total_step:
            print('{} Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], '
                  ' lateral-5: [{:0.4f}], lateral-3: [{:0.4f}], lateral-2: [{:0.4f}], lateral-1: [{:0.4f}], '
                  'train_loss: [{:0.4f}], train_mdice: [{:0.4f}]'.
                  format(datetime.now(), epoch, exp_config['epoch'], i, total_step,
                          loss_record5.show(),loss_record3.show(),loss_record2.show(),loss_record1.show()
                          ,epoch_loss_record.show().numpy(), train_mdice_record.show()))
    
    save_path = f'{exp_config["models_path"]}/'
    os.makedirs(save_path, exist_ok=True)

    test_mdice = evaluate(model, evaluate_path)

    train_results_dict = {
        'train_loss': ['{:.4f}'.format(epoch_loss_record.show().numpy())]
        # , 'test_loss': [test_loss]
        , 'train_mdice': ['{:.4f}'.format(train_mdice_record.show().numpy())]
        , 'test_mdice': ['{:.4f}'.format(test_mdice)]
        # , 'train_miou':
        # , 'test_miou':
        , 'epoch': [epoch]
        , 'exp_name': [exp_config['exp_name']]
    }

    ## Create dataframe and log to disk
    # Wide Format
    train_results_df = pd.DataFrame.from_dict(train_results_dict, orient='columns')
    output_path = f'{exp_config["logs_path"]}/results_train_wide.csv'
    train_results_df.to_csv(output_path, mode='a', index=False, header=not os.path.exists(output_path))

    # Long Format
    train_results_df_long = pd.melt(
        frame=train_results_df, id_vars=['epoch', 'exp_name']
        , value_vars=['train_loss', 'train_mdice', 'test_mdice']
        , var_name='metric', value_name='value')
    train_results_df_long = train_results_df_long.sort_values('epoch')
    output_path = f'{exp_config["logs_path"]}/results_train_long.csv'
    train_results_df_long.to_csv(output_path, mode='a', index=False, header=not os.path.exists(output_path))
    
    ## Log Results to Disk
    # Create files if they doesn't exist
    if not os.path.exists(f'{exp_config["logs_path"]}/log.txt'):
        with open(f'{exp_config["logs_path"]}/log.txt', 'w') as file:
            file.close()

    if not os.path.exists(f'{exp_config["logs_path"]}/best.txt'):
        with open(f'{exp_config["logs_path"]}/best.txt', 'w') as file:
            file.write('0')
            file.close()

    # Append the mdice metric to log file
    with open(f'{exp_config["logs_path"]}/log.txt', 'a') as file:
        file.write(str(test_mdice) + '\n')
        file.close()
    
    # Fetch the best mdice metric recorded until now
    fp = open(f'{exp_config["logs_path"]}/best.txt', 'r')
    best = fp.read()
    fp.close()
    
    if test_mdice > float(best):
        # Update the new best mdice in the file
        fp = open(f'{exp_config["logs_path"]}/best.txt','w')
        fp.write(str(test_mdice))
        fp.close()
        # Update the new best mdice in the local variable
        fp = open(f'{exp_config["logs_path"]}/best.txt','r')
        best = fp.read()
        fp.close()
        # Save the best model found until now
        torch.save(model.state_dict(), save_path + 'CaraNet-best.pth' )
        print('[Saving Snapshot:]', save_path + 'CaraNet-best.pth', test_mdice,'[best:]',best)

#### Train Model

In [None]:
# Fetch experiment configuration from disk
with open('exp_config.json', 'r') as file:
    exp_config = json.load(file)

torch.cuda.set_device(0)  # set your gpu device

# ---- optional model summary ----
# The training model is (re)built in the training cell so every training run
# starts from fresh weights. Building a second copy here only left a throwaway
# model on the GPU, which could also inflate the nvidia-smi VRAM figures logged
# during training. Flip the flag to inspect the architecture.
SHOW_MODEL_SUMMARY = False
if SHOW_MODEL_SUMMARY:
    summary_model = torchinfo.summary(caranet().cuda(), input_size=[6, 3, 448, 448])
    print(summary_model)

In [None]:
# Fetch experiment configuration from disk (re-read so this cell can be re-run
# on its own and always uses the file's current contents)
with open('exp_config.json', 'r') as file:
    exp_config = json.load(file)

# Optional reproducibility: seed torch/numpy only when the config provides a
# 'seed', so configs without one behave exactly as before.
# NOTE(review): other RNGs (e.g. Python's `random` inside the data loader) and
# cuDNN nondeterminism are not covered by this.
if 'seed' in exp_config:
    torch.manual_seed(exp_config['seed'])
    np.random.seed(exp_config['seed'])

# ---- build models ----
# Rebuilt on every run of this cell so training always starts from fresh weights.
torch.cuda.set_device(0)  # set your gpu device
model = caranet().cuda()

params = model.parameters()
if exp_config['optimizer'] == 'Adam':
    optimizer = torch.optim.Adam(params, exp_config['lr'])
else:
    # Any optimizer name other than 'Adam' falls back to SGD with momentum.
    optimizer = torch.optim.SGD(params, exp_config['lr'], weight_decay=1e-4, momentum=0.9)
print(optimizer)

image_root = '{}/images/'.format(exp_config['train_path'])
gt_root = '{}/masks/'.format(exp_config['train_path'])
train_loader = get_loader(image_root, gt_root, batchsize=exp_config['batchsize'],
                          trainsize=exp_config['trainsize'], augmentation=exp_config['augmentation'])
total_step = len(train_loader)

# Learning-rate schedule arguments passed positionally to adjust_lr (presumably
# decay rate and decay epoch — confirm in utils.utils). Overridable from the
# config; the defaults are the previously hard-coded values.
lr_decay_rate = exp_config.get('decay_rate', 0.1)
lr_decay_epoch = exp_config.get('decay_epoch', 200)

# train() and the memory log below write into logs_path; make sure it exists.
os.makedirs(exp_config['logs_path'], exist_ok=True)
memory_log_path = f'{exp_config["logs_path"]}/cnn_memory_centralized.csv'

print("#"*20, "Start Training", "#"*20)

train_start_time = time.time()
for epoch in range(1, exp_config['epoch'] + 1):
    adjust_lr(optimizer, exp_config['lr'], epoch, lr_decay_rate, lr_decay_epoch)
    train(train_loader, model, optimizer, epoch, exp_config['evaluate_path'], exp_config, total_step)

    # Log GPU memory after each epoch (VRAM is None if nvidia-smi could not be queried).
    df_memory = pd.DataFrame.from_dict({'epoch': [epoch], 'VRAM': [get_vram_usage()]})
    df_memory.to_csv(memory_log_path, index=False, mode='a', header=not os.path.exists(memory_log_path))

train_end_time = time.time()
train_elapsed_time = train_end_time - train_start_time

#### Log Model Results

In [None]:
# Find the epoch whose test mDice was best. log.txt gets one line per epoch,
# appended in order starting at epoch 1 (the training loop runs range(1, epoch + 1)).
# NOTE(review): assumes log.txt holds only this run's epochs, i.e. the clean-up
# at the end of the previous run removed it.
with open(f'{exp_config["logs_path"]}/log.txt', 'r') as file:
    test_mdice_history = [float(value) for value in file.read().split()]

# Compare numerically (an argmax over strings is lexicographic) and convert the
# 0-based list index into the 1-based epoch number stored in the results CSV.
# argmax returns the first maximum. This matches the saved checkpoint, which is
# only replaced on strict improvement.
best_epoch = int(np.argmax(test_mdice_history)) + 1

# Export experiment summary to disk
results_all_df = pd.read_csv(f'{exp_config["logs_path"]}/results_train_wide.csv')
results_df = results_all_df[(results_all_df['exp_name'] == exp_config['exp_name'])
                            & (results_all_df['epoch'] == best_epoch)]
if results_df.empty:
    raise ValueError(f"No logged results for experiment {exp_config['exp_name']!r} at epoch {best_epoch}")

exp_config_simplified = {key: value for key, value in exp_config.items()
                         if key not in ('train_path', 'test_path', 'evaluate_path', 'models_path', 'logs_path')}

# results_train_wide.csv is appended to across runs, so several rows may match;
# the last one belongs to the run that just finished.
dict_exp_summary = {
    'exp_name': [exp_config['exp_name']]
    , 'exp_type': [exp_config['exp_type']]
    , 'exp_resource': [exp_config['exp_resource']]
    , 'exp_device': [exp_config['exp_device']]
    , 'exp_configuration': [json.dumps(exp_config_simplified)]
    , 'elapsed_time': [float(f'{train_elapsed_time:.4f}')]
    , 'epoch_best_model': [best_epoch]
    , 'train_loss_mean': [results_df['train_loss'].iloc[-1]]
    , 'train_mdice_mean': [results_df['train_mdice'].iloc[-1]]
    , 'test_mdice_mean': [results_df['test_mdice'].iloc[-1]]
}

#### Test Model

In [None]:
testsize = 352
best_model_path = f'{exp_config["models_path"]}/CaraNet-best.pth'
# Generate masks for the same test set that `evaluate` scores below (previously
# a hard-coded kvasir path that could disagree with exp_config['test_path']).
data_path = exp_config['test_path']
# Output directory for predicted masks; overridable via an optional 'masks_path' key.
masks_path = exp_config.get('masks_path', '../../../data/logs/experiments/cnn_caranet_kvasir/masks/')
image_root = f'{data_path}/images/'
gt_root = f'{data_path}/masks/'
test_loader = test_dataset(image_root, gt_root, testsize)

# Start from an empty masks directory so no predictions from a previous run remain.
if os.path.isdir(masks_path):
    shutil.rmtree(masks_path)
os.makedirs(masks_path)

## Instantiate a model with the weights from the best model trained
model = caranet()
# map_location='cpu' lets the checkpoint load regardless of the device it was
# saved from; the model is moved to the GPU afterwards.
weights = torch.load(best_model_path, map_location='cpu')
# Drop any 'total_ops' / 'total_params' entries (presumably counters added by a
# FLOPs profiler rather than model weights — confirm).
new_state_dict = OrderedDict(
    (k, v) for k, v in weights.items()
    if 'total_ops' not in k and 'total_params' not in k
)
model.load_state_dict(new_state_dict)
model.cuda()
model.eval()

# Mean Dice of the best checkpoint on the test set
inference_results = evaluate(model, data_path)
print('[best model test mdice]', inference_results)

## For each image in the dataset generate its corresponding predicted mask and output to disk
with torch.no_grad():
    for _ in range(test_loader.size):
        image, gt, name = test_loader.load_data()
        # Only the mask's shape is needed: predictions are resized to it.
        gt = np.asarray(gt, np.float32)

        res5, res3, res2, res1 = model(image.cuda())
        res = F.interpolate(res5, size=gt.shape, mode='bilinear', align_corners=False)
        res = res.sigmoid().cpu().numpy().squeeze()
        res = (res - res.min()) / (res.max() - res.min() + 1e-8)

        # Scale the normalised [0, 1] map to 8-bit explicitly instead of handing
        # cv2.imwrite a float array.
        cv2.imwrite(os.path.join(masks_path, name), np.round(res * 255).astype(np.uint8))

In [None]:
# Append this run's summary row to the experiment summary CSV; the header is
# written only when the file does not exist yet.
filename_summary = f'{exp_config["logs_path"]}/exp_summary_train.csv'
exp_summary_df = pd.DataFrame.from_dict(dict_exp_summary, orient='columns')
write_header = not os.path.exists(filename_summary)
exp_summary_df.to_csv(filename_summary, index=False, mode='a', header=write_header)

# Logs Clean-Up: remove the per-run best-model bookkeeping files.
for bookkeeping_name in ('best.txt', 'log.txt'):
    os.remove(f'{exp_config["logs_path"]}/{bookkeeping_name}')