In [1]:
from torch.utils.data import Dataset, DataLoader
import torch
import torch.optim as optim
import torch.nn as nn
import numpy as np


def genDataFunction(x, lam=0.5, const=10):
    return lam * np.exp(-lam * x) * const + np.random.random()/100


class MyDataset(Dataset):
    def __init__(self):
        self.train_x = list((np.random.random(10000) * 10).astype(int))
        self.train_y = [genDataFunction(x) for x in self.train_x]
        self.test_x = list((np.random.random(2000) * 10).astype(int))
        self.test_y = [genDataFunction(x) for x in self.test_x]


    def __getitem__(self, index):
        x = self.train_x[index]
        y = self.train_y[index]
        return x, y

    def __len__(self):
        assert len(self.train_x) == len(self.train_y)
        return len(self.train_x)


class DataProvider:
    def __init__(self, batch_size=128):
        self.batch_size = batch_size
        self.dataset = MyDataset()
        self.dataiter = None

        self.train_len = self.dataset.__len__()
        self.train_batch_num = self.train_len // self.batch_size

    def build(self):
        self.dataloader = DataLoader(self.dataset,
                                     batch_size=self.batch_size,
                                     shuffle=True,
                                     drop_last=True)
        self.dataiter = iter(self.dataloader)

    def next(self):
        if self.dataiter is None:
            self.build()
        try:
            batch = next(self.dataiter)
            batch = (batch[0].float(), batch[1].float())
            return batch
        except StopIteration:
            self.build()
            batch = next(self.dataiter)
            batch = (batch[0].float(), batch[1].float())
            return batch


provider = DataProvider()

In [2]:
xs, ys = provider.next()
print(xs[:5], ys[:5])

tensor([6., 2., 1., 1., 6.]) tensor([0.2533, 1.8440, 3.0354, 3.0380, 0.2583])


In [3]:
class FitModel(nn.Module):
    '''
        define model here
    '''
    def __init__(self):
        super(FitModel, self).__init__()
        self.lam = nn.Parameter(torch.FloatTensor([1.]), requires_grad=True)
        self.const = nn.Parameter(torch.FloatTensor([1.]), requires_grad=True)

    def forward(self, x):
        return self.lam * torch.exp(-self.lam * x) * self.const


def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        if 'lr' in param_group.keys():
            return param_group['lr']


class LogMSELoss(nn.Module):
    '''
        define custom loss function here
    '''
    def __init__(self):
        super(LogMSELoss, self).__init__()

    def forward(self, preds, targets):
        return torch.mean(torch.pow(torch.log(preds) - torch.log(targets), 2))


In [4]:
def fit(criterionFunc = "MSE", epochMax = 50):
    if criterionFunc == "MSE":
        criterion = nn.MSELoss()
    elif criterionFunc == "LogMSE":
        criterion = LogMSELoss()
    else:
        raise NotImplementedError

    model = FitModel()  # create a new model (retrain parameters)
    optimizer = optim.Adam(model.parameters(), lr=1e-2)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.2, patience=3, verbose=True)

    WATCH_LOSS_PER_BATCH = 5
    for epoch in range(epochMax):
        sum_loss = 0.0
        epoch_loss = 0.0

        for batch_no in range(1, provider.train_batch_num + 1):
            xs, ys = provider.next()
            pred = model.forward(xs)
            loss = criterion(pred, ys)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            sum_loss += loss.item()
            epoch_loss += loss.item()

            if batch_no % WATCH_LOSS_PER_BATCH == 0 or batch_no == provider.train_batch_num:
                print("lam = {:5f}".format(model.lam.item()))
                print("con = {:5f}".format(model.const.item()))

                print("[epoch:{}, batch:{}/{}] loss: {:.5f} lr: {:.5f}".format(
                    epoch, batch_no, provider.train_batch_num, sum_loss / WATCH_LOSS_PER_BATCH, get_lr(optimizer)))
                sum_loss = 0.0
        scheduler.step(epoch_loss)


In [5]:
fit(criterionFunc='MSE', epochMax=50)

 0.00151 lr: 0.01000
lam = 0.515677
con = 9.695725
[epoch:34, batch:40/78] loss: 0.00154 lr: 0.01000
lam = 0.515414
con = 9.699303
[epoch:34, batch:45/78] loss: 0.00147 lr: 0.01000
lam = 0.515174
con = 9.702762
[epoch:34, batch:50/78] loss: 0.00138 lr: 0.01000
lam = 0.514885
con = 9.706242
[epoch:34, batch:55/78] loss: 0.00144 lr: 0.01000
lam = 0.514142
con = 9.709711
[epoch:34, batch:60/78] loss: 0.00139 lr: 0.01000
lam = 0.515126
con = 9.713155
[epoch:34, batch:65/78] loss: 0.00133 lr: 0.01000
lam = 0.513491
con = 9.716452
[epoch:34, batch:70/78] loss: 0.00126 lr: 0.01000
lam = 0.514540
con = 9.719794
[epoch:34, batch:75/78] loss: 0.00127 lr: 0.01000
lam = 0.514088
con = 9.721755
[epoch:34, batch:78/78] loss: 0.00075 lr: 0.01000
lam = 0.513274
con = 9.724997
[epoch:35, batch:5/78] loss: 0.00123 lr: 0.01000
lam = 0.514230
con = 9.728339
[epoch:35, batch:10/78] loss: 0.00130 lr: 0.01000
lam = 0.513498
con = 9.731589
[epoch:35, batch:15/78] loss: 0.00121 lr: 0.01000
lam = 0.513270
con =

In [6]:
fit(criterionFunc="LogMSE", epochMax=100)

[epoch:84, batch:50/78] loss: 0.00058 lr: 0.01000
lam = 0.489746
con = 9.967126
[epoch:84, batch:55/78] loss: 0.00069 lr: 0.01000
lam = 0.491319
con = 9.967559
[epoch:84, batch:60/78] loss: 0.00048 lr: 0.01000
lam = 0.491984
con = 9.968042
[epoch:84, batch:65/78] loss: 0.00059 lr: 0.01000
lam = 0.489747
con = 9.968682
[epoch:84, batch:70/78] loss: 0.00071 lr: 0.01000
lam = 0.490861
con = 9.969113
[epoch:84, batch:75/78] loss: 0.00054 lr: 0.01000
lam = 0.491598
con = 9.969367
[epoch:84, batch:78/78] loss: 0.00037 lr: 0.01000
lam = 0.491880
con = 9.969896
[epoch:85, batch:5/78] loss: 0.00046 lr: 0.01000
lam = 0.489827
con = 9.970576
[epoch:85, batch:10/78] loss: 0.00062 lr: 0.01000
lam = 0.490170
con = 9.971026
[epoch:85, batch:15/78] loss: 0.00061 lr: 0.01000
lam = 0.492483
con = 9.971365
[epoch:85, batch:20/78] loss: 0.00056 lr: 0.01000
lam = 0.491331
con = 9.971945
[epoch:85, batch:25/78] loss: 0.00061 lr: 0.01000
lam = 0.490858
con = 9.972493
[epoch:85, batch:30/78] loss: 0.00053 lr: