In [1]:
%run C:/Users/MohammedSB/Desktop/projects/Utils/Utils.ipynb

In [None]:
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler

In [None]:
import torch.nn.functional as F
    
def one_epoch(dataloader, models, device, criterion, optimizer=None, show_output=True):
    """Run one pass over ``dataloader`` and collect loss and classification metrics.

    The pass is a training pass when ``optimizer`` is given (models are put in
    train mode and weights are updated after every batch) and an evaluation
    pass otherwise (models in eval mode, gradients disabled).

    Args:
        dataloader: Iterable of batches. Each batch is indexed with the keys
            ``"image"``, ``"features"`` and ``"label"``; ``dataloader.dataset``
            must support ``len()``.
        models: Mapping with the keys ``"image_model"``, ``"tabular_model"``
            and ``"fusion_model"``. A model that is not used is ``None``.
        device: Device the batch tensors are moved to before the forward pass.
        criterion: Loss callable, invoked as ``criterion(output, target)``.
        optimizer: Optimizer stepped after each batch; ``None`` selects
            evaluation mode.
        show_output: When false, the tqdm progress bar is disabled.

    Returns:
        dict with the keys ``"Loss"``, ``"Average Loss"``, ``"Correct"``,
        ``"Accuracy"`` (percentage), ``"Size"``, ``"Precision"``, ``"Recall"``,
        ``"F1 Score"``, ``"y_prob"``, ``"y_true"`` and ``"y_pred"``.
        Averages are taken over ``len(dataloader.dataset)``.
    """
    image_model = models["image_model"]
    tabular_model = models["tabular_model"]
    fusion_model = models["fusion_model"]

    training = optimizer is not None

    # Identity checks instead of truthiness: container modules (for example an
    # empty nn.Sequential) define __len__ and may evaluate as falsy.
    for model in (image_model, tabular_model, fusion_model):
        if model is not None:
            model.train(training)

    # Plain Python numbers so that no autograd history is kept alive across
    # batches and the totals are well-defined even if the loader yields nothing.
    total_loss, correct = 0.0, 0
    preds, probs, targets = torch.tensor([]), torch.tensor([]), torch.tensor([])

    with torch.set_grad_enabled(training):
        for sample in tqdm(dataloader, desc="Batch in Progress", ascii=False, ncols=100, disable=not show_output):

            img = sample["image"].to(device).float()
            features = sample["features"].to(device).float()
            target = sample["label"].to(device).float()

            # Add a trailing dimension to the labels before computing the loss.
            target = target.unsqueeze(dim=-1)

            # forward tree
            if fusion_model is not None and image_model is None and tabular_model is None:
                # Fusion model only: it consumes the tabular features directly.
                output = fusion_model(features)
            elif fusion_model is not None:
                # All paths: concatenate both branch outputs along dim 1.
                image_features = image_model(img)
                tabular_features = tabular_model(features)
                combined_features = torch.cat((image_features, tabular_features), 1)
                output = fusion_model(combined_features)
            elif image_model is not None:
                output = image_model(img)
            else:
                output = tabular_model(features)

            loss = criterion(output, target)

            # backward
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # bookkeeping (all on CPU, detached from the graph)
            # The loss is weighted by the batch length so the epoch total can
            # be divided by the dataset size afterwards.
            total_loss += loss.item() * len(img)
            prob = torch.sigmoid(output.detach().cpu())
            # `classify` is an external helper applied to each row of `prob`.
            pred = torch.tensor([classify(p) for p in prob])
            target = target.cpu()
            correct += (pred == target.squeeze(1)).sum().item()

            probs = torch.cat((probs, prob))
            preds = torch.cat((preds, pred))
            targets = torch.cat((targets, target))

    dataset_size = len(dataloader.dataset)

    metrics = dict()
    metrics["Loss"] = total_loss
    metrics["Average Loss"] = (total_loss / dataset_size)
    metrics["Correct"] = correct
    metrics["Accuracy"] = (correct / dataset_size) * 100
    metrics["Size"] = dataset_size

    # Precision, Recall, and F1
    # NOTE(review): precision/recall pass zero_division=1 while f1_score keeps
    # its default handling; kept as in the original -- confirm this is intended.
    metrics["Precision"] = precision_score(targets, preds, zero_division=1)
    metrics["Recall"] = recall_score(targets, preds, zero_division=1)
    metrics["F1 Score"] = f1_score(targets, preds)

    metrics["y_prob"] = probs
    metrics["y_true"] = targets
    metrics["y_pred"] = preds

    return metrics
    
def train_val(epochs, models, criterion, optimizer, train_loader, val_loader, device, early_stop=None, show_output=True):
    """Train for up to ``epochs`` epochs, validating after each one.

    Args:
        epochs: Maximum number of epochs to run.
        models: Model mapping forwarded to ``one_epoch``.
        criterion: Loss callable forwarded to ``one_epoch``.
        optimizer: Optimizer used for the training pass of each epoch.
        train_loader: Loader for the training pass.
        val_loader: Loader for the validation pass (run without an optimizer).
        device: Device forwarded to ``one_epoch``.
        early_stop: Optional mapping with the keys ``"patience"``,
            ``"min_delta"`` and ``"multip"`` used to build an ``EarlyStopper``.
            When falsy, all epochs are run.
        show_output: Forwarded to ``one_epoch`` only; the per-epoch headers
            and metric summaries below are printed regardless of its value.

    Returns:
        A two-element list ``[train_history, val_history]``; each element maps
        the 1-based epoch number to the metrics dict of that epoch.
    """
    print("Beginning Training: \n")
    train_history, val_history = {}, {}

    # Build the stopper only when an early-stopping configuration was given.
    stopper = None
    if early_stop:
        stopper = EarlyStopper(
            patience=early_stop["patience"],
            min_delta=early_stop["min_delta"],
            multip=early_stop["multip"],
        )

    for epoch in range(1, epochs + 1):
        print(f'Epoch {epoch}/{epochs}')

        # Training pass: an optimizer is supplied, so weights are updated.
        train_metrics = one_epoch(train_loader, models, device, criterion, optimizer, show_output=show_output)
        print("Train Set:")
        show_metrics(train_metrics)
        train_history[epoch] = train_metrics

        # Validation pass: no optimizer, so one_epoch runs in evaluation mode.
        val_metrics = one_epoch(val_loader, models, device, criterion, show_output=show_output)
        print("Validation Set:")
        show_metrics(val_metrics)
        val_history[epoch] = val_metrics

        if stopper is None:
            continue

        # Feed the stopper the validation loss, then ask whether to halt.
        stopper(val_metrics["Average Loss"])
        if stopper.stop(epoch):
            break

    return [train_history, val_history]

def train(epochs, models, criterion, optimizer, train_loader, device):
    """Train for ``epochs`` epochs without a validation pass.

    Args:
        epochs: Number of epochs to run.
        models: Model mapping forwarded to ``one_epoch``.
        criterion: Loss callable forwarded to ``one_epoch``.
        optimizer: Optimizer used for every epoch.
        train_loader: Loader for the training data.
        device: Device forwarded to ``one_epoch``.

    Returns:
        dict mapping the 1-based epoch number to that epoch's metrics dict.
    """
    print("Beginning Training: \n")
    history = {}

    for epoch in range(1, epochs + 1):
        print(f'Epoch {epoch}/{epochs}')

        epoch_metrics = one_epoch(train_loader, models, device, criterion, optimizer)
        history[epoch] = epoch_metrics

        print("Train Set:")
        show_metrics(epoch_metrics)

    return history


def test(models, criterion, test_loader, device, show_output=True):
    """Evaluate the models on ``test_loader`` and return the metrics.

    Args:
        models: Model mapping forwarded to ``one_epoch``.
        criterion: Loss callable forwarded to ``one_epoch``.
        test_loader: Loader for the test data.
        device: Device forwarded to ``one_epoch``.
        show_output: When true, the progress bar is shown and the metrics are
            printed via ``show_metrics``; when false, both are suppressed.

    Returns:
        The metrics dict produced by ``one_epoch`` (run without an optimizer).
    """
    results = one_epoch(test_loader, models, device, criterion, show_output=show_output)

    if not show_output:
        return results

    show_metrics(results)
    return results