In [None]:
from torchmetrics import MeanMetric
import tqdm 
import torch
import torch.nn as  nn
import os


def train_one_epoch(model, train_loader, loss_fn, optimizer, metric, metric2, metric3, scheduler=None, epoch=None):
    """Run one optimisation pass over ``train_loader`` and report running statistics.

    Args:
        model: Module to train; it is switched to training mode. The model is
            NOT moved to a device here, only the batches are.
        train_loader: Iterable yielding ``(inputs, targets)`` batches.
        loss_fn: Callable invoked as ``loss_fn(outputs, targets.int())``.
        optimizer: Optimizer stepped once per batch.
        metric, metric2, metric3: Metric objects exposing ``reset()``,
            ``update(preds, target)`` and ``compute()`` (torchmetrics-style).
            They are reset at the start of the epoch.
        scheduler: Optional LR scheduler. It is stepped once per *batch*,
            not once per epoch.
        epoch: Optional epoch index, used only to label the progress bar.

    Returns:
        Tuple ``(model, mean_loss, metric, metric2, metric3)`` where the last
        four entries are Python floats accumulated over the whole epoch.
    """
    model.train()
    loss_train = MeanMetric()
    for tracked in (metric, metric2, metric3):
        tracked.reset()
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    with tqdm.tqdm(train_loader, unit='batch') as tepoch:
        # The label does not change during the epoch, so set it once.
        if epoch is not None:
            tepoch.set_description(f"Epoch{epoch}")

        for inputs, targets in tepoch:
            inputs = inputs.to(device)
            targets = targets.to(device)
            int_targets = targets.int()

            # Clear gradients *before* the backward pass so that anything left
            # over from outside this function cannot leak into the first step.
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_fn(outputs, int_targets)
            loss.backward()
            optimizer.step()

            # Weight by the number of samples in the batch (same convention as
            # evaluate) so a smaller final batch does not count as a full one.
            loss_train.update(loss.item(), weight=len(targets))
            metric.update(outputs, int_targets)
            metric2.update(outputs, int_targets)
            metric3.update(outputs, int_targets)

            tepoch.set_postfix(loss=loss_train.compute().item(),
                               LR=optimizer.param_groups[0]['lr'],
                               metric=metric.compute().item(),
                               metric2=metric2.compute().item(),
                               metric3=metric3.compute().item())

            if scheduler is not None:
                scheduler.step()

    return (
        model,
        loss_train.compute().item(),
        metric.compute().item(),
        metric2.compute().item(),
        metric3.compute().item(),
    )
            

def evaluate(model, test_dl, loss_fn, metric, metric2, metric3):
    """Compute the mean loss and three metrics of ``model`` over ``test_dl``.

    The model is put in eval mode and all forward passes run under
    ``torch.no_grad()`` so no autograd graph is built during evaluation.

    Args:
        model: Module to evaluate; expected to already live on the device
            selected below (CUDA when available, otherwise CPU).
        test_dl: Iterable yielding ``(x_batch, y_batch)`` batches.
        loss_fn: Callable invoked as ``loss_fn(y_pred, y_batch.int())``.
        metric, metric2, metric3: Metric objects exposing ``reset()``,
            ``update(preds, target)`` and ``compute()`` (torchmetrics-style).
            They are reset before the first batch.

    Returns:
        Tuple ``(mean_loss, metric, metric2, metric3)`` of Python floats,
        where the loss is averaged with each batch weighted by its size.
    """
    model.eval()
    # Resolve the device locally (same rule as train_one_epoch) instead of
    # depending on a module-level ``device`` variable.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    loss_test = MeanMetric()
    for tracked in (metric, metric2, metric3):
        tracked.reset()

    with torch.no_grad():
        for x_batch, y_batch in test_dl:
            x_batch = x_batch.to(device)
            y_batch = y_batch.to(device)
            int_targets = y_batch.int()

            y_pred = model(x_batch)
            loss = loss_fn(y_pred, int_targets)

            loss_test.update(loss.item(), weight=len(y_batch))
            # Only the accumulated state is needed, so use update() rather
            # than calling the metric, which would also compute an unused
            # per-batch value.
            metric.update(y_pred, int_targets)
            metric2.update(y_pred, int_targets)
            metric3.update(y_pred, int_targets)

    return (
        loss_test.compute().item(),
        metric.compute().item(),
        metric2.compute().item(),
        metric3.compute().item(),
    )



def train_model(model, train_loader, val_loader, loss_fn, optimizer, metric, metric2, metric3, num_epochs, scheduler, checkpoint_dir, patience=25, checkpoint_name='best_model_version_Unet++_v04.pt'):
    """Train ``model`` with per-epoch validation, best-model checkpointing and early stopping.

    Each epoch runs ``train_one_epoch`` followed by ``evaluate``. Whenever the
    validation loss improves, a checkpoint is written to
    ``checkpoint_dir/checkpoint_name``. Training stops early once the
    ``EarlyStopping`` helper sets its ``early_stop`` flag.

    Args:
        model: Module to train; moved to CUDA when available, otherwise CPU.
        train_loader: Batches for ``train_one_epoch``.
        val_loader: Batches for ``evaluate``.
        loss_fn: Loss callable forwarded to both helpers.
        optimizer: Optimizer forwarded to ``train_one_epoch``; its state is
            stored in the checkpoint.
        metric, metric2, metric3: Metric objects shared by the training and
            validation passes (each pass resets them before use).
        num_epochs: Maximum number of epochs.
        scheduler: LR scheduler forwarded to ``train_one_epoch`` (may be None).
        checkpoint_dir: Directory for the checkpoint; created if missing.
        patience: Forwarded to ``EarlyStopping``.
        checkpoint_name: File name of the best-model checkpoint.

    Returns:
        Tuple ``(best_val_loss, val_loss_history, val_metric_history,
        val_metric2_history, val_metric3_history)``; each history list holds
        one validation value per completed epoch.
    """
    # Resolve the device locally (same rule as train_one_epoch) instead of
    # depending on a module-level ``device`` variable.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(device)
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name)
    early_stopping = EarlyStopping(patience=patience, verbose=True)

    best_val_loss = torch.inf
    val_loss_history = []
    val_metric_history = []
    val_metric2_history = []
    val_metric3_history = []

    for epoch in range(num_epochs):
        model, train_loss, train_metric, train_metric2, train_metric3 = train_one_epoch(
            model, train_loader, loss_fn, optimizer,
            metric, metric2, metric3,
            scheduler=scheduler, epoch=epoch,
        )
        val_loss, val_metric, val_metric2, val_metric3 = evaluate(
            model, val_loader, loss_fn, metric, metric2, metric3,
        )

        val_loss_history.append(val_loss)
        val_metric_history.append(val_metric)
        # Record the *validation* values; the training metrics were
        # previously appended to these validation histories by mistake.
        val_metric2_history.append(val_metric2)
        val_metric3_history.append(val_metric3)
        print(f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val DICE: {val_metric:.4f}, IOU Metric:{val_metric2: .4f}, F1 Metric:{val_metric3:.4f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save({
                "epoch": epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': train_loss,
                'val_loss': val_loss,
                'val_metric': val_metric,
            }, checkpoint_path)
            print(f"Bets model save to {checkpoint_path} with val loss: {val_loss:.4f}")

        early_stopping(val_loss, model)
        if early_stopping.early_stop:
            print('Early stopping triggered')
            break

    print('Training Complete!!')
    return best_val_loss, val_loss_history, val_metric_history, val_metric2_history, val_metric3_history

