# danbooru tagger training

this nb is used to show a training process with a small sample of 12 train and 4 val.  
the actual training was performed with ~1000 train and ~100 val.

In [1]:
from effnet_tagger import *

## defs

In [2]:
def get_data(train_ds, val_ds, bs):
    return (
        DataLoader(train_ds, batch_size=bs, shuffle=True,
                   #  drop_last=True
                   ),
        DataLoader(val_ds, batch_size=bs * 2)
    )

In [3]:
def get_model(msd=None, osd=None):
    model = EffnetTagger().to(dev)

    if msd is not None:
        model.load_state_dict(torch.load(msd))

    optimizer = optim.AdamW(
        model.parameters(),
        lr=0.001,
        betas=(0.9, 0.999),
        eps=1e-08,
        weight_decay=0.01,
        amsgrad=False
    )

    if osd is not None:
        optimizer.load_state_dict(torch.load(osd))

    return model, optimizer

In [4]:
def loss_batch(model, loss_func, xb, yb, opt=None):
    loss = loss_func(model(xb), yb)

    if opt is not None:
        loss.backward()
        opt.step()
        opt.zero_grad()

    return loss.item(), len(xb)

In [5]:
def fit(model, opt, train_dl, val_dl, res=pd.DataFrame(), epochs=10, loss_func=nn.MSELoss()):

    loss_data = {
        'train': [],
        'val': []
    }
    
    msd, osd = [
        Path('nn-params') / f'{var}_state_dict'
        for var in ['model', 'opt']
    ]

    for epoch in range(epochs):
        model.train()
        for xb, yb in train_dl:
            xb.to(dev)
            yb.to(dev)

            losses, nums = zip(
                *[loss_batch(model, loss_func, xb, yb, opt)
                  for xb, yb in train_dl]
            )

            torch.save(model.state_dict(), msd)
            torch.save(opt.state_dict(), osd)

        train_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)
        loss_data['train'].append(train_loss)

        model.eval()
        with torch.no_grad():
            losses, nums = zip(
                *[loss_batch(model, loss_func, xb, yb)
                  for xb, yb in val_dl]
            )
        val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)
        loss_data['val'].append(val_loss)

        print(
            f'epoch: {epoch} | train MSE: {train_loss:.4f} | val MSE: {val_loss:.4f}')

        pd.concat(
            [res, pd.DataFrame(loss_data)]
        ).to_csv('res.csv', index=False)

## loop

In [6]:
train_dir = Path('data') / 'train'
train_ids = [int(f.stem) for f in train_dir.glob('*')]
train_labels = all_labels[all_labels.id.isin(train_ids)]
train_ds = DanbooruDataset(label_data=train_labels, img_dir=train_dir)

In [7]:
val_dir = Path('data') / 'val'
val_ids = [int(f.stem) for f in val_dir.glob('*')]
val_labels = all_labels[all_labels.id.isin(val_ids)]
val_ds = DanbooruDataset(label_data=val_labels, img_dir=val_dir)

In [8]:
train_dl, val_dl = get_data(train_ds, val_ds, 2)
model, opt = get_model(
#     *[Path('nn-params') / f'{var}_state_dict'
#         for var in ['model', 'opt']]
)
# res = pd.read_csv('res.csv')

Using cache found in C:\Users\Morshay/.cache\torch\hub\pytorch_vision_v0.12.0


In [9]:
fit(model, opt, train_dl, val_dl,
    # res
    )

epoch: 0 | train MSE: 0.1637 | val MSE: 0.1351
epoch: 1 | train MSE: 0.1251 | val MSE: 0.1375
epoch: 2 | train MSE: 0.0636 | val MSE: 0.1071
epoch: 3 | train MSE: 0.0405 | val MSE: 0.0333
epoch: 4 | train MSE: 0.0173 | val MSE: 0.0116
epoch: 5 | train MSE: 0.0110 | val MSE: 0.0088
epoch: 6 | train MSE: 0.0083 | val MSE: 0.0091
epoch: 7 | train MSE: 0.0094 | val MSE: 0.0089
epoch: 8 | train MSE: 0.0098 | val MSE: 0.0094
epoch: 9 | train MSE: 0.0119 | val MSE: 0.0102


## MSE plot

In [10]:
alt.Chart(
    pd.read_csv(
        'res.csv'
    ).reset_index(
    ).rename(
        {'index': 'epoch'},
        axis='columns'
    ).melt(
        id_vars=['epoch'],
        value_vars=['train', 'val'],
        var_name='stage',
        value_name='MSE loss'
    )
).mark_line(
).encode(
    x='epoch:Q',
    y='MSE loss:Q',
    color=alt.Color(
        'stage',
        scale=alt.Scale(
            scheme='viridis'
        )
    )
)