In [1]:
import os
from datetime import datetime
import traceback

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms
import torchvision
import numpy as np

from tqdm import tqdm
from tensorboardX import SummaryWriter

from dataset import voc
from retinanet import model, val
from retinanet import transforms as aug


In [2]:
# consts
# `tag`, `depth` and `split_name` are embedded in the result directory name below.
tag = 'debug'
split_name = 'voc-1'
# NOTE(review): machine-specific absolute path; adjust to the local dataset root.
root_path = '/home/voyager/data/root/voc/'

device_name = 'cuda'
batch_size = 24
depth = 50

epochs = 500
lr = 1e-5
patience = 3

image_size = 512
num_classes = 2
num_workers = 8

# info and deps
# Timestamp taken once so every artefact of this run shares the same suffix.
now = datetime.now()

# exist_ok=True makes this idempotent and avoids the check-then-create race
# of a separate os.path.exists() / os.mkdir() pair.
os.makedirs('./result', exist_ok=True)

# Layout: ./result/<tag>_<depth>_<split_name>_<YYYYmmdd_HHMMSS>
result_path = './result/{}_{}_{}_{}'.format(
    tag,
    depth,
    split_name,
    now.strftime('%Y%m%d_%H%M%S')
)

# The same directory is used for tensorboardX event files and for checkpoints.
summary_writer = SummaryWriter(result_path)


In [3]:
# data loader

# train

# Augmentation steps for the training split, listed in the order they are
# handed to aug.Compose.
train_steps = [
    aug.RandomCropAndPad(),
    aug.Pad(),  # pad to square image
    aug.Resize(image_size, image_size),
    aug.RandomFlipLeftRight(0.5),
    aug.RandomFlipUpDown(0.5),
    aug.RandomRotate(5),
    aug.RandomTranslatePc(50, 50),
    aug.AutoLevel(min_level_rate=1, max_level_rate=1),
    aug.AutoContrast(),
    aug.RandomContrast(0.5),
    aug.Contrast(1.25),
    aug.RandomChoice([
        aug.RandomSaltPepperNoise(0.9, 0.5),
        aug.RandomSaltPepperNoise(0.95, 0.5),
        aug.RandomSaltPepperNoise(0.99, 0.5)
    ]),
    aug.ToTensor(),
    # Disabled: normalisation with the mean and std of the pretrained model.
    # aug.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
]
train_trans = aug.Compose(train_steps)

# The image set name is derived from the configured split, e.g. "<split>_train".
train_split = "{}_train".format(split_name)
train_set = voc.VOCDetection(root_path, image_set=train_split, transforms=train_trans)

# Shuffled loader; batches are assembled by the project's own collate function.
train_loader = DataLoader(
    train_set,
    shuffle=True,
    batch_size=batch_size,
    num_workers=num_workers,
    collate_fn=voc.collate
)


In [4]:
# model, loss and optimizer

device = torch.device(device_name)

# Backbone depths for which a model.resnet<depth> constructor is used here.
supported_depths = (34, 50, 101, 152)

# Fail fast with a clear message: previously an unsupported depth fell through
# the if/elif chain and surfaced later as a NameError on `net`.
if depth not in supported_depths:
    raise ValueError(
        'unsupported depth {}; expected one of {}'.format(depth, supported_depths)
    )

# Resolve the constructor lazily by name (model.resnet34, model.resnet50, ...).
build_backbone = getattr(model, 'resnet{}'.format(depth))
net = build_backbone(num_classes, pretrained=True)

net = net.to(device)
net = torch.nn.DataParallel(net).to(device)

optimizer = optim.Adam(net.parameters(), lr=lr)
# Called with the epoch's mean training loss in the training loop.
# NOTE(review): `verbose` is reported as deprecated in recent PyTorch
# releases -- confirm against the installed version before upgrading.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    patience=patience,
    verbose=True
)

# train() sets the training flag on the wrapper and every sub-module, so the
# former explicit `net.training = True` was redundant and has been dropped.
net.train()
# Project-specific hook on the wrapped model, re-applied after train().
# presumably keeps batch-norm layers frozen -- confirm in retinanet.model
net.module.freeze_bn()


In [5]:
# train-val loop


def _pad_focal_annos(targets):
    """Pack per-image targets into one padded annotation tensor.

    Args:
        targets: sequence of per-image entries; index 0 is read as the label
            tensor and index 1 as the box tensor.
            # assumes labels are (n,) and boxes are (n, 4) -- TODO confirm
            # against voc.collate

    Returns:
        float32 tensor of shape (batch, max_n, 5) in focal loss format
        [x1, y1, x2, y2, cls]. Images with fewer than max_n annotations are
        padded with [0, 0, 0, 0, -1] rows.
        # presumably the loss treats cls == -1 as "ignore" -- confirm in
        # retinanet.model
    """
    pad_row = torch.tensor([[0, 0, 0, 0, -1]])

    per_image = []
    for target in targets:
        # NOTE(review): casting boxes to long truncates any fractional
        # coordinates before the final float32 cast; kept to preserve the
        # existing training behaviour.
        labels = target[0].to(dtype=torch.long)
        bboxes = target[1].to(dtype=torch.long)

        # boxes transposed to [4, n], labels lifted to [1, n]; stacking gives
        # [5, n], and the final transpose yields [n, 5].
        stacked = torch.cat((torch.t(bboxes), torch.unsqueeze(labels, 0)), 0)
        per_image.append(torch.t(stacked))

    max_count = max((annos.shape[0] for annos in per_image), default=0)

    padded = []
    for annos in per_image:
        missing = max_count - annos.shape[0]
        if missing > 0:
            # One concatenation per image instead of one per missing row.
            annos = torch.cat((annos, pad_row.repeat(missing, 1)), 0)
        padded.append(annos)

    return torch.stack(padded).to(dtype=torch.float32)


for epoch in range(epochs):
    print('training epoch {}:'.format(epoch))

    # train
    net.train()
    # Re-applied every epoch right after train().
    # presumably because train() would otherwise undo it -- confirm in retinanet.model
    net.module.freeze_bn()

    epoch_loss = []

    with tqdm(total=len(train_loader)) as pbar:
        for data in train_loader:
            try:
                optimizer.zero_grad()

                # convert annos to focal loss format - [x1, y1, x2, y2, cls]
                padded_batch_annos = _pad_focal_annos(data[1])

                # forward
                classification_loss, regression_loss = net([
                    data[0].to(device),
                    padded_batch_annos
                ])

                # .mean() reduces the values returned through DataParallel to scalars.
                loss = classification_loss.mean() + regression_loss.mean()
                loss_value = loss.item()
                epoch_loss.append(loss_value)

                # A zero loss is still recorded above but skips the update.
                if loss_value != 0:
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(net.parameters(), 0.1)
                    optimizer.step()

                # Advance the bar for every processed batch, including
                # zero-loss batches that skip the optimizer step.
                pbar.update(1)
            except Exception:
                # Best effort: report the failure and abandon the rest of
                # this epoch; training continues with the next epoch.
                traceback.print_exc()
                break

    # epoch-wise work and record
    if not epoch_loss:
        # No batch completed, so there is no loss to average. Skipping keeps
        # a NaN out of the LR scheduler and out of the checkpoint file name.
        print('epoch {}: no batch completed, skipping scheduler step and checkpoint'.format(epoch))
        continue

    mean_loss = np.mean(epoch_loss)
    print('epoch avg loss: {}'.format(mean_loss))

    scheduler.step(mean_loss)

    # save checkpoint
    # NOTE(review): this pickles the whole module object rather than a
    # state_dict; left unchanged so existing checkpoint loaders keep working.
    torch.save(net.module, os.path.join(result_path, '{:0>3}_{:1.4f}.pth'.format(
        epoch,
        mean_loss
    )))

    # write summary for tensorboardX
    summary_writer.add_scalar(
        '/train/loss',
        mean_loss,
        epoch
    )
    

  0%|          | 0/53 [00:00<?, ?it/s]

training epoch 0:


100%|██████████| 53/53 [01:23<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 1.5287253451797198
training epoch 1:


100%|██████████| 53/53 [01:22<00:00,  1.43s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 1.0892193036259346
training epoch 2:


100%|██████████| 53/53 [01:21<00:00,  1.39s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.8620456817015162
training epoch 4:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.8048656774017046
training epoch 5:


100%|██████████| 53/53 [01:21<00:00,  1.39s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.771808885178476
training epoch 6:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.7422746914737629
training epoch 7:


100%|██████████| 53/53 [01:21<00:00,  1.42s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.708106242260843
training epoch 8:


100%|██████████| 53/53 [01:24<00:00,  1.39s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.6882835525386738
training epoch 9:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.6676934530150216
training epoch 10:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.6577729324124894
training epoch 11:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.657033641383333
training epoch 12:


100%|██████████| 53/53 [01:21<00:00,  1.39s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.6320350496274121
training epoch 13:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.6201697970336338
training epoch 14:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.615825244278278
training epoch 15:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.6124953368924698
training epoch 16:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.5971171844680354
training epoch 17:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.5809080021561317
training epoch 18:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.5874122088810183
training epoch 19:


100%|██████████| 53/53 [01:21<00:00,  1.39s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.5780279130305884
training epoch 20:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.5673393151670132
training epoch 21:


100%|██████████| 53/53 [01:21<00:00,  1.39s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.5600070289845737
training epoch 22:


100%|██████████| 53/53 [01:21<00:00,  1.39s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.5578111516979506
training epoch 23:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.5551173017834717
training epoch 24:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.5509087263413195
training epoch 25:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.5364664399398947
training epoch 26:


100%|██████████| 53/53 [01:21<00:00,  1.39s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.5405376980889518
training epoch 27:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.5250548967775309
training epoch 28:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.5282024662449675
training epoch 29:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.5235367113689207
training epoch 30:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.5232248711136153
training epoch 31:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.514684575346281
training epoch 32:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.5095626447560653
training epoch 33:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.5079332606972389
training epoch 34:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.49925866048291045
training epoch 35:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.49887595424112285
training epoch 36:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.49741398725869523
training epoch 37:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.49210639157385194
training epoch 38:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.4894803245112581
training epoch 39:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.49126329849351125
training epoch 40:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.4884074127899026
training epoch 41:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.4796746502507408
training epoch 42:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.4750060039871144
training epoch 43:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.4763828397921796
training epoch 44:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.47349082580152546
training epoch 45:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.47008697435540975
training epoch 46:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.465011402683438
training epoch 47:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.4658107841914555
training epoch 48:


100%|██████████| 53/53 [01:21<00:00,  1.39s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.4563372146408513
training epoch 49:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.4571021542234241
training epoch 50:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.4547094415943578
training epoch 51:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.45800411869894786
training epoch 52:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.46021227060623887
training epoch 53:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.45650706707306626
training epoch 54:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.4511640212445889
training epoch 55:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.44619950764584093
training epoch 56:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.44431618980641635
training epoch 57:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.44425416048967614
training epoch 58:


  8%|▊         | 4/53 [00:09<02:05,  2.56s/it]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.33863560602350057
training epoch 273:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.33121828762990124
training epoch 274:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.33503691036746186
training epoch 275:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.3350005464733772
training epoch 276:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.33314693423936953
training epoch 277:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.3364024072323205
training epoch 278:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.3377774363418795
training epoch 279:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.3342069241235841
training epoch 280:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.33183846833570946
training epoch 281:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.3345115302868609
training epoch 282:


100%|██████████| 53/53 [01:21<00:00,  1.37s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.3374768225651867
training epoch 283:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.33565797164755046
training epoch 284:


100%|██████████| 53/53 [01:21<00:00,  1.37s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.332783136165367
training epoch 285:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.33613090526382877
training epoch 286:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.33823301657190863
training epoch 287:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.33371305971775417
training epoch 288:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.33768469348268687
training epoch 289:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.33760498829607694
training epoch 290:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.3367223728377864
training epoch 291:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.331068346522889
training epoch 292:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.3358991382256994
training epoch 293:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.33385424119121626
training epoch 294:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.33729419112205505
training epoch 295:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.3361604972830359
training epoch 296:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.33768786405617335
training epoch 297:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.3348439948738746
training epoch 298:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.33531855468480093
training epoch 299:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.33620073154287516
training epoch 300:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.3360824781768727
training epoch 301:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.3328750099775926
training epoch 302:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.3344792414386317
training epoch 303:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.33292340501299444
training epoch 304:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.33555994000074996
training epoch 305:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.33626229864246443
training epoch 306:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.3377102157979641
training epoch 307:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.33433655185519523
training epoch 308:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.33644774837313957
training epoch 309:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.33480882532191725
training epoch 310:


100%|██████████| 53/53 [01:20<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.335432694205698
training epoch 311:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.33575876703802143
training epoch 312:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.3354730431763631
training epoch 313:


100%|██████████| 53/53 [01:20<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.33915246545143846
training epoch 314:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.3402148142175854
training epoch 315:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.33453306949363565
training epoch 316:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.33582073337626905
training epoch 317:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.33227641346319664
training epoch 318:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.3385339750433868
training epoch 319:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.33595507887174497
training epoch 320:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.3338464228612072
training epoch 321:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.33810656531801764
training epoch 322:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.33337331774099815
training epoch 323:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.33760368542851144
training epoch 324:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.3352693819774772
training epoch 325:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.33467131234564873
training epoch 326:


100%|██████████| 53/53 [01:21<00:00,  1.38s/it]
  0%|          | 0/53 [00:00<?, ?it/s]

epoch avg loss: 0.3351051205734037
training epoch 327:


 19%|█▉        | 10/53 [00:18<01:09,  1.61s/it]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

