In [1]:
import os
import torch
import torch.nn.functional as F

## 1. Metrics

In [2]:
def mean_absolute_error(y_true, y_pred):
    """Mean absolute error between targets and predictions.

    Returns a 0-dim tensor: the mean of |y_true - y_pred| over all elements.
    """
    abs_diff = (y_true - y_pred).abs()
    return abs_diff.mean()

In [3]:
def mse_error(y_true, y_pred):
    """Mean squared error between true class labels and predicted class indices.

    Args:
        y_true: class labels, one per sample (cast to float before comparison).
        y_pred: per-class scores (logits or probabilities) with classes on dim 1;
            the predicted class is the argmax along dim 1.

    Returns:
        0-dim float tensor: mean of (predicted_class - y_true) ** 2.
    """
    # softmax is monotonic, so argmax(softmax(s)) == argmax(s) mathematically.
    # Skipping it saves work and avoids float32 rounding in softmax collapsing
    # near-equal scores into a tie (which would pick the wrong index).
    pred_class = torch.argmax(y_pred.float(), dim=1)
    squared_error = (pred_class.float() - y_true.float()) ** 2
    return torch.mean(squared_error)

In [4]:
def mean_absolute_percentage_error(y_true, y_pred, epsilon=1e-8):
    """Mean absolute percentage error, in percent.

    Computes mean(|y_true - y_pred| / max(|y_true|, epsilon)) * 100.

    Args:
        y_true: target values.
        y_pred: predictions, broadcastable against ``y_true``.
        epsilon: lower bound on the denominator magnitude to avoid division by zero.

    Returns:
        0-dim tensor with the MAPE as a percentage.
    """
    # Clamp |y_true| rather than adding epsilon to y_true: y_true + epsilon is
    # exactly zero when y_true == -epsilon, which would yield inf.
    # Promote to the same floating dtype that `y_true + epsilon` would produce
    # so integer targets still work.
    abs_true = torch.abs(y_true).to(torch.result_type(y_true, epsilon))
    denominator = torch.clamp(abs_true, min=epsilon)
    return torch.mean(torch.abs(y_true - y_pred) / denominator) * 100

In [5]:
def r2_score(y_true, y_pred):
    """Coefficient of determination R^2 = 1 - SS_res / SS_tot.

    Args:
        y_true: target values.
        y_pred: predictions, same shape as ``y_true``.

    Returns:
        0-dim tensor. When ``y_true`` is constant (SS_tot == 0) the ratio is
        undefined; following scikit-learn's finite convention, 1.0 is returned
        for a perfect prediction and 0.0 otherwise instead of nan / -inf.
    """
    ss_res = torch.sum((y_true - y_pred) ** 2)
    ss_tot = torch.sum((y_true - torch.mean(y_true)) ** 2)
    if ss_tot == 0:
        # Constant target: avoid 0/0 (nan) or x/0 (-inf).
        return (ss_res == 0).to(ss_res.dtype)
    return 1 - ss_res / ss_tot

## 2. Training Helper Functions

### *2.1 Move Data to the GPU*

In [6]:
def data_gpu(data, device):
    """Detach tensors and move them to ``device`` (non-blocking copy).

    - A tensor is detached, moved, and the new tensor is returned.
    - A dict is updated in place: every tensor-valued entry is replaced by its
      detached, moved copy; non-tensor values are left alone. The same dict
      object is returned.
    - Any other object is returned unchanged.
    """
    def _move(tensor):
        return tensor.detach().to(device, non_blocking=True)

    if isinstance(data, torch.Tensor):
        return _move(data)

    if isinstance(data, dict):
        tensor_keys = [key for key, value in data.items() if isinstance(value, torch.Tensor)]
        for key in tensor_keys:
            data[key] = _move(data[key])

    return data

### *2.2 Checkpoint Saving & Loading*

In [7]:
# Checkpoint flag -> file name; shared by save_checkpoint and load_checkpoint
# so the two can never disagree.
_CHECKPOINT_FILENAMES = {
    "mean": "best_mean_ckpt.pth.tar",
    "median": "best_median_ckpt.pth.tar",
    "newest": "newest_ckpt.pth.tar",
}


def _checkpoint_path(savepath, flag):
    """Return the checkpoint file path for ``flag``; raise ValueError if unknown."""
    try:
        return os.path.join(savepath, _CHECKPOINT_FILENAMES[flag])
    except KeyError:
        raise ValueError(
            f"unknown checkpoint flag {flag!r}; expected one of {sorted(_CHECKPOINT_FILENAMES)}"
        ) from None


def save_checkpoint(state, savepath, flag):
    """Save ``state`` with torch.save under ``savepath``.

    Args:
        state: object to serialize (e.g. a dict of state_dicts).
        savepath: directory; created if missing.
        flag: one of "mean", "median", "newest" — selects the file name.

    Raises:
        ValueError: if ``flag`` is not a known checkpoint flag.
    """
    filename = _checkpoint_path(savepath, flag)  # validate before touching disk
    # exist_ok avoids the check-then-create race of a separate isdir() test.
    os.makedirs(savepath, 0o777, exist_ok=True)
    torch.save(state, filename)


def load_checkpoint(savepath, flag, map_location=None):
    """Load the checkpoint selected by ``flag`` from ``savepath``.

    Args:
        savepath: directory holding the checkpoints.
        flag: one of "mean", "median", "newest".
        map_location: forwarded to torch.load (e.g. "cpu" to load a GPU
            checkpoint on a CPU-only machine). Default None keeps torch's default.

    Returns:
        The deserialized state, or None if the file does not exist.

    Raises:
        ValueError: if ``flag`` is not a known checkpoint flag.
    """
    filename = _checkpoint_path(savepath, flag)
    if not os.path.isfile(filename):
        return None
    # torch.load unpickles the file: only load checkpoints from trusted sources.
    return torch.load(filename, map_location=map_location)

### *2.3 Average Meter for Logging*

In [8]:
class AverageMeter(object):
    """Computes and stores the current value, running (weighted) average, and median."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0      # most recent value passed to update()
        self.avg = 0      # sum / count
        self.sum = 0      # weighted sum of values
        self.count = 0    # total weight (sum of n)
        self.values = []  # each value repeated n times, used by median()

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` (e.g. a batch mean over n samples)."""
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard n=0 on an empty meter, which would otherwise divide by zero.
        if self.count:
            self.avg = self.sum / self.count
        self.values.extend([val] * n)

    def median(self):
        """Median of all recorded values (weighted by n); 0 if nothing recorded.

        For an even count, the mean of the two middle values is returned.
        """
        if not self.values:
            # Match the reset default of avg instead of raising IndexError.
            return 0
        ordered = sorted(self.values)
        mid = len(ordered) // 2
        if len(ordered) % 2:
            return ordered[mid]
        return (ordered[mid - 1] + ordered[mid]) / 2

### *2.4 Adjust Learning Rate*

In [9]:
def adjust_learning_rate(optimizer, base_lr, decay_rate, step_size, epoch, min_lr=0.001):
    """Set every param group's LR to a step-decayed value and return it.

    lr = base_lr * decay_rate ** (epoch // step_size), floored at ``min_lr``.

    Args:
        optimizer: optimizer whose ``param_groups`` are updated in place.
        base_lr: learning rate at epoch 0.
        decay_rate: multiplicative factor applied every ``step_size`` epochs.
        step_size: number of epochs between decays.
        epoch: current epoch index.
        min_lr: lower bound on the decayed LR (default 0.001, the previous
            hard-coded floor). It is capped at ``base_lr`` so the floor can
            never push the LR above the configured starting value.

    Returns:
        The learning rate that was set.
    """
    lr = base_lr * decay_rate ** (epoch // step_size)
    lr = max(lr, min(min_lr, base_lr))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr