In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import pandas as pd
import numpy as np
from torch.utils.data import DataLoader
import timm
from timm import optim, scheduler
import torch
from torch.optim.lr_scheduler import ExponentialLR
from sklearn import metrics as skmet
import os
import json

import transforms as my_transforms
from dataset import EchoNetFrames

In [3]:
def train_one_epoch(model, optimizer, train_dataloader, loss_function, device):
    """Run one optimisation pass over ``train_dataloader``.

    Every batch is indexed with the keys ``'img'``, ``'ESV'`` and ``'EDV'``.
    Column 0 of the model output is scored against ESV and column 1 against
    EDV; the two losses are summed and one optimizer step is taken per batch.

    Args:
        model: network invoked as ``model(inputs)``; its output is indexed as
            ``outputs[:, 0]`` and ``outputs[:, 1]``.
        optimizer: optimizer that updates the model's parameters.
        train_dataloader: sized iterable of batches (``len()`` is used for the
            progress line).
        loss_function: callable ``(prediction, target) -> scalar tensor``.
        device: device the batch tensors are moved to.

    Returns:
        The mean of the per-batch summed losses, or NaN when the dataloader
        yields no batches.
    """
    model.train()

    num_steps_per_epoch = len(train_dataloader)

    # Gradients are cleared *after* each step below, so anything already stored
    # in .grad when we enter would be added to the first update. Start clean.
    optimizer.zero_grad()

    losses = []
    for ix, batch in enumerate(train_dataloader):
        inputs = batch['img'].to(device)
        esv_true = batch['ESV'].to(device).type(torch.float32)
        edv_true = batch['EDV'].to(device).type(torch.float32)

        outputs = model(inputs)
        esv_pred = outputs[:, 0]
        edv_pred = outputs[:, 1]

        loss_esv = loss_function(esv_pred, esv_true)
        loss_edv = loss_function(edv_pred, edv_true)
        loss = loss_esv + loss_edv

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Read the scalar once; it is used for both the history and the log line.
        loss_value = loss.item()
        losses.append(loss_value)
        print(f"\tBatch {ix+1} of {num_steps_per_epoch}. Loss={loss_value:0.3f}", end='\r')

    # Blank out the carriage-return progress line.
    print(' '*100, end='\r')

    if not losses:
        # np.mean([]) would emit a RuntimeWarning before returning NaN.
        return np.float64('nan')

    return np.mean(losses)
            
def evaluate(model, val_dataloader, loss_function, device):
    """Score ``model`` on every batch of ``val_dataloader`` without gradients.

    Batches are indexed with the keys ``'img'``, ``'ESV'`` and ``'EDV'``.
    Column 0 of the model output is treated as the ESV prediction and column 1
    as the EDV prediction. An ejection-fraction value is derived for both the
    targets and the predictions as ``(EDV - ESV) / EDV``.

    The model is switched to eval mode and is left in that mode on return.

    Args:
        model: network invoked as ``model(inputs)``.
        val_dataloader: iterable of batches; it does not need to be sized.
        loss_function: callable ``(prediction, target) -> scalar tensor``.
        device: device the batch tensors are moved to.

    Returns:
        A 4-tuple ``(mean_loss, metrics_esv, metrics_edv, metrics_ef)`` where
        ``mean_loss`` is the unweighted mean of the per-batch summed losses and
        each ``metrics_*`` is the result of ``compute_metrics``.

    Raises:
        ValueError: if the dataloader yields no batches.
    """
    model.eval()

    esv_true_ls = []
    edv_true_ls = []
    esv_pred_ls = []
    edv_pred_ls = []
    losses = []

    # Nothing in this loop needs autograd, so disable it for the whole pass.
    with torch.no_grad():
        for batch in val_dataloader:
            inputs = batch['img'].to(device)
            esv_true = batch['ESV'].to(device).type(torch.float32)
            edv_true = batch['EDV'].to(device).type(torch.float32)

            outputs = model(inputs)
            esv_pred = outputs[:, 0]
            edv_pred = outputs[:, 1]

            loss_esv = loss_function(esv_pred, esv_true)
            loss_edv = loss_function(edv_pred, edv_true)
            losses.append((loss_esv + loss_edv).item())

            esv_true_ls.append(esv_true.cpu().numpy())
            edv_true_ls.append(edv_true.cpu().numpy())
            esv_pred_ls.append(esv_pred.cpu().numpy())
            edv_pred_ls.append(edv_pred.cpu().numpy())

    if not losses:
        # np.concatenate([]) would raise a ValueError with an unhelpful message.
        raise ValueError("val_dataloader yielded no batches; nothing to evaluate")

    esv_true_ar = np.concatenate(esv_true_ls)
    esv_pred_ar = np.concatenate(esv_pred_ls)
    edv_true_ar = np.concatenate(edv_true_ls)
    edv_pred_ar = np.concatenate(edv_pred_ls)

    # NOTE(review): an EDV of exactly 0 (true or predicted) yields inf/NaN here.
    ef_true_ar = (edv_true_ar - esv_true_ar) / edv_true_ar
    ef_pred_ar = (edv_pred_ar - esv_pred_ar) / edv_pred_ar

    metrics_esv = compute_metrics(esv_true_ar, esv_pred_ar)
    metrics_edv = compute_metrics(edv_true_ar, edv_pred_ar)
    metrics_ef = compute_metrics(ef_true_ar, ef_pred_ar)

    return np.mean(losses), metrics_esv, metrics_edv, metrics_ef

def compute_metrics(y_true, y_pred):
    """Summarise regression quality for one target.

    Args:
        y_true: ground-truth values.
        y_pred: predicted values, aligned with ``y_true``.

    Returns:
        A dict with the keys ``'r2'``, ``'mae'`` and ``'rmse'`` (in that
        order), where ``'rmse'`` is the square root of the mean squared error.
    """
    return {
        'r2': skmet.r2_score(y_true, y_pred),
        'mae': skmet.mean_absolute_error(y_true, y_pred),
        'rmse': np.sqrt(skmet.mean_squared_error(y_true, y_pred)),
    }

def main(cfg):
    """Train the two-output (ESV, EDV) frame regressor described by ``cfg``.

    Writes ``config.json`` into ``cfg['artifact_folder']`` and, whenever the
    validation loss improves, overwrites ``model_checkpoint.ckpt`` there with
    the model's ``state_dict``.

    Args:
        cfg: JSON-serialisable mapping. Keys read here: ``artifact_folder``,
            ``device``, ``res``, ``transforms`` (``train``/``test``),
            ``in_paths`` (``train``/``val``/``frames``), ``downsample_frac``,
            ``bs_train``, ``bs_val``, ``num_workers``, ``model``,
            ``pretrained``, ``dropout``, ``lr``, ``weight_decay``,
            ``lr_gamma``, ``num_epochs``, ``unfreeze_after_n`` and
            ``lr_unfrozen``.

    Returns:
        A dict ``{'train_loss': [...], 'val_loss': [...]}`` with one entry per
        epoch.
    """
    os.makedirs(cfg['artifact_folder'], exist_ok=True)

    # save the config file to the artifact folder
    with open(os.path.join(cfg['artifact_folder'], 'config.json'), 'w') as f:
        json.dump(cfg, f, indent=4)

    checkpoint_path = os.path.join(cfg['artifact_folder'], 'model_checkpoint.ckpt')

    device = torch.device(cfg['device'])

    # transforms
    tfms = my_transforms.ImageTransforms(cfg['res'])
    tfms_train = tfms.get_transforms(cfg['transforms']['train'])
    tfms_test = tfms.get_transforms(cfg['transforms']['test'])

    # load data
    df_train = pd.read_csv(cfg['in_paths']['train'])
    df_val = pd.read_csv(cfg['in_paths']['val'])
    df_frames = pd.read_csv(cfg['in_paths']['frames'])

    # create datasets
    d_train = EchoNetFrames(df_train, df_frames, transforms=tfms_train, downsample_frac=cfg['downsample_frac'])
    dl_train = DataLoader(d_train, batch_size=cfg['bs_train'], num_workers=cfg['num_workers'], shuffle=True)

    d_val = EchoNetFrames(df_val, df_frames, transforms=tfms_test, downsample_frac=cfg['downsample_frac'])
    dl_val = DataLoader(d_val, batch_size=cfg['bs_val'], num_workers=cfg['num_workers'])

    print("Train data size:", len(d_train))
    print("Validation data size:", len(d_val))

    # network with two outputs: column 0 is used as ESV, column 1 as EDV
    m = timm.create_model(cfg['model'], pretrained=cfg['pretrained'], num_classes=2, in_chans=3, drop_rate=cfg['dropout'])
    m.to(device)

    # freeze model weights
    # don't freeze classifier or first conv/bn
    # NOTE(review): assumes the first two children are the stem conv/bn and the
    # last child is the classifier -- confirm for the architecture in cfg['model'].
    for layer in list(m.children())[2:-1]:
        for p in layer.parameters():
            p.requires_grad = False
    is_frozen = True

    # fit
    # All parameters (frozen ones included) are handed to the optimizer, so
    # unfreezing later only needs requires_grad to be flipped back on.
    optimizer = optim.AdamP(m.parameters(), lr=cfg['lr'], weight_decay=cfg['weight_decay'])
    # Named lr_scheduler so it does not shadow the imported timm `scheduler` module.
    lr_scheduler = ExponentialLR(optimizer, gamma=cfg['lr_gamma'])
    loss_function = torch.nn.functional.mse_loss

    train_loss_ls = []
    test_loss_ls = []

    # Start at +inf so the first epoch always writes a checkpoint, whatever the
    # scale of the loss.
    best_test_loss = float('inf')
    for epoch in range(cfg['num_epochs']):
        print("-"*40)
        print(f"Epoch {epoch+1} of {cfg['num_epochs']}:")

        # maybe unfreeze
        if epoch >= cfg['unfreeze_after_n'] and is_frozen:
            print("Unfreezing model encoder.")
            is_frozen = False
            for p in m.parameters():
                p.requires_grad = True

            for g in optimizer.param_groups:
                g['lr'] = cfg['lr_unfrozen']

        # train for a single epoch
        train_loss = train_one_epoch(m, optimizer, dl_train, loss_function, device)
        train_loss_ls.append(train_loss)
        print("Training:")
        print(f"\tMSE loss = {train_loss:0.3f}")

        # evaluate (the "Test" label in the log refers to the validation split)
        test_loss, met_esv, met_edv, met_ef = evaluate(m, dl_val, loss_function, device)
        test_loss_ls.append(test_loss)
        print("Test:")
        print(f"\tMSE loss = {test_loss:0.3f}")
        for label, mets in (('EF', met_ef), ('ESV', met_esv), ('EDV', met_edv)):
            print(f"\t{label} metrics:")
            for k, v in mets.items():
                print(f"\t\t{k} = {v:0.3f}")

        # keep only the checkpoint with the lowest validation loss so far
        if test_loss < best_test_loss:
            torch.save(m.state_dict(), checkpoint_path)
            best_test_loss = test_loss

        lr_scheduler.step()

    return {'train_loss': train_loss_ls, 'val_loss': test_loss_ls}
        
if __name__=='__main__':
    from config import config_pretrain_echonet as cfg
    import argparse

    # The artifact folder can be overridden with --artifact-folder; without the
    # flag the previous hard-coded location is used.
    parser = argparse.ArgumentParser('Pretrain frame classifier')
    parser.add_argument('--artifact-folder', type=str, metavar='DIR',
                        default='/zfs/wficai/pda/model_run_artifacts/echonet_pretrain',
                        help='path to artifact folder')
    # parse_known_args (not parse_args) so the extra arguments a Jupyter kernel
    # is launched with are ignored instead of aborting the cell.
    args, _ = parser.parse_known_args()

    cfg['artifact_folder'] = args.artifact_folder
    main(cfg)

Train data size: 1315340
Validation data size: 228836
----------------------------------------
Epoch 1 of 10:
Unfreezing model encoder.
Training:                                                                                           
	MSE loss = 0.163
Test:
	MSE loss = 0.119
	EF metrics:
		r2 = 0.170
		mae = 0.087
		rmse = 0.110
	ESV metrics:
		r2 = 0.431
		mae = 0.138
		rmse = 0.209
	EDV metrics:
		r2 = 0.390
		mae = 0.203
		rmse = 0.276
----------------------------------------
Epoch 2 of 10:
Training:                                                                                           
	MSE loss = 0.110
Test:
	MSE loss = 0.099
	EF metrics:
		r2 = 0.313
		mae = 0.077
		rmse = 0.100
	ESV metrics:
		r2 = 0.546
		mae = 0.121
		rmse = 0.186
	EDV metrics:
		r2 = 0.484
		mae = 0.187
		rmse = 0.253
----------------------------------------
Epoch 3 of 10:
Training:                                                                                           
	MSE loss = 0.093
Test:
	MSE lo