In [1]:
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
import torch.utils.data as data
from collections import OrderedDict
import numpy as np
import datasets.GeneralDataset as GenDataset
import time
from san_vision import transforms
import os.path as osp
from utils import AverageMeter
from model import ITN_CPM
from model import np2variable, variable2np
import torchvision.transforms as TT
from tensorboardX import SummaryWriter

In [3]:
global params
params = {
    'num_pts': 68,
    'convert68to49' : False,
    'path': 'xx', 
    'argmax_radius' : 3,
    'downsample' : 8,
    'batch_size' : 15,
    'heatmap_type' : 'gaussian',
    'dataset_name' : '300W/Original(GTB)',
    'momentum' : 0.9,
    'learning_rate' : 0.00005,
    'decay' : 0.0005,
    'total_epochs' : 100,
    'crop_width' : 256,
    'crop_height' : 256,
    'pre_crop_expand' : 0.2,
    'crop_perturb_max' : 30,
    'scale_prob' : 1.1,
    'scale_max' : 1,
    'scale_min' : 1,
    'scale_eval' : 1,
    'sigma' : 4,
    'train_list' : '/home/abhirup/Datasets/300W-Style/box-coords/300W-Original/300w.train.GTB',
    'resume' : 'checkpoint.pth.tar'
         }

In [4]:
def weights_init_cpm(m):
    classname = m.__class__.__name__
    # print(classname)
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0, 0.01)
        if m.bias is not None: m.bias.data.zero_()
    elif classname.find('BatchNorm2d') != -1:
        m.weight.data.fill_(1)
        m.bias.data.zero_()

In [5]:
def remove_module_dict(state_dict):
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[7:] # remove `module.`
        new_state_dict[name] = v
    return new_state_dict
def load_weight_from_dict(model, weight_state_dict, param_pair=None, remove_prefix=True):
    if remove_prefix: weight_state_dict = remove_module_dict(weight_state_dict)
    all_parameter = model.state_dict()
    all_weights   = []
    finetuned_layer, random_initial_layer = [], []
    for key, value in all_parameter.items():
        if param_pair is not None and key in param_pair:
            all_weights.append((key, weight_state_dict[ param_pair[key] ]))
        elif key in weight_state_dict:
            all_weights.append((key, weight_state_dict[key]))
            finetuned_layer.append(key)
        else:
            all_weights.append((key, value))
            random_initial_layer.append(key)
#     print ('==>[load_model] finetuned layers : {}'.format(finetuned_layer))
#     print ('==>[load_model] keeped layers : {}'.format(random_initial_layer))
    all_weights = OrderedDict(all_weights)
    model.load_state_dict(all_weights)

In [6]:
def save_checkpoint(state, filename='checkpoint.pth.tar'):
    torch.save(state, filename)

In [7]:
def compute_loss(criterion, target_var, outputs, mask_var, total_labeled_cpm):
    total_loss = 0
    each_stage_loss = []
    mask_outputs = []
    for output_var in outputs:
        stage_loss = 0
        output = torch.masked_select(output_var, mask_var)
        target = torch.masked_select(target_var, mask_var)
    #     print(output.size(), target.size())
        mask_outputs.append(output)

        stage_loss = criterion(output, target)/(total_labeled_cpm*2)
        total_loss += stage_loss
        each_stage_loss.append(stage_loss.item())
    return total_loss, each_stage_loss
# mask_var[:,:num_pts,:,:]

In [8]:
mean_fill   = tuple( [int(x*255) for x in [0.5, 0.5, 0.5] ] )
normalize   = transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                      std=[0.5, 0.5, 0.5])

In [9]:
train_transform = [transforms.PreCrop(params['pre_crop_expand'])]
train_transform = [transforms.TrainScale2WH((params['crop_width'], params['crop_height']))]
train_transform += [transforms.AugScale(params['scale_prob'], params['scale_min'], params['scale_max'])]
train_transform += [transforms.AugCrop(params['crop_width'], params['crop_height'], params['crop_perturb_max'], mean_fill)]
train_transform += [transforms.ToTensor(), normalize]
train_transform = transforms.Compose( train_transform )

In [10]:
train_data = GenDataset(train_transform, params['sigma'], params['downsample'], params['heatmap_type'], params['dataset_name'])
train_data.load_list(params['train_list'], params['num_pts'], True)
# if(params['convert68to49']): train_data.convert68to49()

The general dataset initialization done, sigma is 4, downsample is 8, dataset-name : 300W/Original(GTB), self is : GeneralDataset(number of point=-1, heatmap_type=gaussian)
Load list from /home/abhirup/Datasets/300W-Style/box-coords/300W-Original/300w.train.GTB
Load [0/1]-th list : /home/abhirup/Datasets/300W-Style/box-coords/300W-Original/300w.train.GTB with 3148 images
Start load data for the general datas
Load data done for the general dataset, which has 3148 images.


In [11]:
# eval_transform  = transforms.Compose([transforms.PreCrop(params['pre_crop_expand']),
#                                       transforms.TrainScale2WH((params['crop_width'],params['crop_height'])),
#                                       transforms.ToTensor(), normalize])

# eval_data = GenDataset(eval_transform, params['sigma'], params['downsample'], params['heatmap_type'], params['dataset_name'])
# eval_data.load_list('/home/abhirup/Datasets/300W-Style/box-coords/300W-Original/test',
#                    68, True)

In [12]:
writer = SummaryWriter(log_dir='./log')
net = ITN_CPM(params)
writer.add_graph(net)
net.apply(weights_init_cpm)
# net_param_dict = net.parameters()
net = net.cuda()

Error occurs, No graph saved


In [13]:
criterion = torch.nn.MSELoss(False)
criterion.cuda()
optimizer = torch.optim.Adam(
#                             net.parameters(), lr=params['learning_rate'],
                            net.specify_parameter(base_lr=params['learning_rate'], 
                                                  base_weight_decay=params['decay']))
#                             momentum=params['momentum'],
#                             nesterov=True)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: 0.75**(epoch//5))

In [14]:
start_epoch=0
best_loss=99999
if (params['resume'] and osp.isfile(params['resume'])):
    print("=> loading checkpoint '{}'".format(params['resume']))
    checkpoint = torch.load(params['resume'])
    start_epoch = checkpoint['epoch']
    best_loss = checkpoint['best_loss']
    net.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    print("=> loaded checkpoint '{}' (epoch {}) current_loss = {}"
          .format(params['resume'], checkpoint['epoch'], best_loss))
else:
    model_urls = 'http://download.pytorch.org/models/vgg16-397923af.pth'
    weights = model_zoo.load_url(model_urls)
    load_weight_from_dict(net, weights, None, False);

=> loading checkpoint 'checkpoint.pth.tar'
=> loaded checkpoint 'checkpoint.pth.tar' (epoch 21) current_loss = 16.389808654785156


In [15]:
train_loader = data.DataLoader(train_data, batch_size=params['batch_size'], shuffle=True, pin_memory=False)

In [16]:
batch_time = AverageMeter()
data_time = AverageMeter()
forward_time = AverageMeter()
visible_points = AverageMeter()

In [17]:
# start_epoch = 1
end = time.time()
for epoch in range(start_epoch, params['total_epochs']):
    scheduler.step(epoch)
    print('Epoch:{}, Now learning rate:{}'.format(epoch, scheduler.get_lr()))
    for i, (inputs, target, mask, points, image_index, label_sign) in enumerate(train_loader):
        # inputs : Batch, Squence, Channel, Height, Width
        # data prepare
        target = target.cuda(async=True)
        # get the real mask
        mask.masked_scatter_((1-label_sign).unsqueeze(-1).unsqueeze(-1), torch.ByteTensor(mask.size()).zero_())
        mask_var   = mask.cuda(async=True)

        batch_size, num_pts = inputs.size(0), mask.size(1)-1
        image_index = variable2np(image_index).squeeze(1).tolist()
        # check the label indicator, whether is has annotation or not
        sign_list = variable2np(label_sign).astype('bool').squeeze(1).tolist()
        data_time.update(time.time() - end)
        cvisible_points = torch.sum(mask[:,:-1,:,:]) * 1. / batch_size
        visible_points.update(cvisible_points, batch_size)

        batch_cpms, batch_locs, batch_scos = net(inputs.cuda())

#         forward_time.update(time.time() - end)

        total_labeled_cpm = int(np.sum(sign_list))

        cpm_loss, each_stage_loss_values = compute_loss(criterion, target, batch_cpms, mask_var, total_labeled_cpm)
        writer.add_scalars('Loss', {'total loss':cpm_loss, 
                                    'stage1 loss': each_stage_loss_values[0],
                                    'stage2 loss': each_stage_loss_values[1],
                                    'stage3 loss': each_stage_loss_values[2]}, epoch*209+i)
        for name, param in net.named_parameters():
            writer.add_histogram(name, param.clone().cpu().data.numpy(), epoch*209+i)
        #passing 'weight_of_idt' or identity loss as None
        optimizer.zero_grad()
        cpm_loss.backward()
        optimizer.step()

        if (cpm_loss < best_loss + 5 and i%10==0):
            best_loss = cpm_loss
            save_checkpoint({
                'epoch': epoch,
                'state_dict': net.state_dict(),
                'best_loss': best_loss,
                'optimizer' : optimizer.state_dict(),
            })
        print("batch_iter:{}, total_loss:{}, stage_losses:{}, {}, {}".format(i, cpm_loss, each_stage_loss_values[0], each_stage_loss_values[1], each_stage_loss_values[2]))
    print("End of {}th epoch\n".format(epoch))
#     if cpm_loss < best_loss:
#             save_checkpoint({
#                 'epoch': epoch,
#                 'state_dict': net.state_dict(),
#                 'best_loss': best_loss,
#                 'optimizer' : optimizer.state_dict(),
#             })

Epoch:21, Now learning rate:[1.58203125e-05, 3.1640625e-05, 1.58203125e-05, 3.1640625e-05, 1.58203125e-05, 3.1640625e-05, 6.328125e-05, 0.0001265625, 6.328125e-05, 0.0001265625]
batch_iter:0, total_loss:15.049715042114258, stage_losses:6.573525428771973, 4.20458984375, 4.271600246429443
batch_iter:1, total_loss:19.081539154052734, stage_losses:8.266315460205078, 5.330289363861084, 5.4849348068237305
batch_iter:2, total_loss:22.137542724609375, stage_losses:9.175634384155273, 6.411427974700928, 6.550479412078857
batch_iter:3, total_loss:23.48540687561035, stage_losses:10.772645950317383, 6.105295181274414, 6.607465744018555
batch_iter:4, total_loss:16.30769157409668, stage_losses:6.805949687957764, 4.577945709228516, 4.923795700073242
batch_iter:5, total_loss:15.921429634094238, stage_losses:7.456477165222168, 4.127904891967773, 4.337047576904297
batch_iter:6, total_loss:22.738784790039062, stage_losses:9.461689949035645, 6.606312274932861, 6.670783042907715
batch_iter:7, total_loss:19.

batch_iter:71, total_loss:23.048736572265625, stage_losses:10.383448600769043, 6.548351764678955, 6.116935729980469
batch_iter:72, total_loss:18.10077667236328, stage_losses:7.706736087799072, 5.077030658721924, 5.317008972167969
batch_iter:73, total_loss:16.537954330444336, stage_losses:7.32468318939209, 4.510649681091309, 4.702621936798096
batch_iter:74, total_loss:23.472980499267578, stage_losses:9.08523178100586, 7.127549648284912, 7.260199069976807
batch_iter:75, total_loss:15.96142864227295, stage_losses:6.711132526397705, 4.517739295959473, 4.732556343078613
batch_iter:76, total_loss:24.177059173583984, stage_losses:10.083646774291992, 6.937820911407471, 7.155591011047363
batch_iter:77, total_loss:16.96393585205078, stage_losses:6.988500118255615, 4.817051410675049, 5.158383846282959
batch_iter:78, total_loss:16.88089942932129, stage_losses:8.046085357666016, 4.3485107421875, 4.486303806304932
batch_iter:79, total_loss:22.284265518188477, stage_losses:9.47745132446289, 6.2535324

batch_iter:143, total_loss:17.407482147216797, stage_losses:9.226593971252441, 3.9282495975494385, 4.25263786315918
batch_iter:144, total_loss:20.88248634338379, stage_losses:10.43427848815918, 5.140512466430664, 5.3076958656311035
batch_iter:145, total_loss:21.891128540039062, stage_losses:9.049473762512207, 6.38106632232666, 6.460587978363037
batch_iter:146, total_loss:17.416492462158203, stage_losses:7.346584796905518, 4.914101600646973, 5.155805587768555
batch_iter:147, total_loss:25.971630096435547, stage_losses:9.879573822021484, 8.008023262023926, 8.08403491973877
batch_iter:148, total_loss:17.776317596435547, stage_losses:8.150806427001953, 4.720216274261475, 4.905293941497803
batch_iter:149, total_loss:18.642513275146484, stage_losses:9.18429946899414, 4.674628734588623, 4.783583641052246
batch_iter:150, total_loss:18.899005889892578, stage_losses:8.10024642944336, 5.354413032531738, 5.4443464279174805
batch_iter:151, total_loss:18.67066192626953, stage_losses:8.19082164764404

batch_iter:3, total_loss:16.110477447509766, stage_losses:7.314036846160889, 4.331753253936768, 4.464687347412109
batch_iter:4, total_loss:19.101978302001953, stage_losses:9.102346420288086, 4.864576816558838, 5.135056018829346
batch_iter:5, total_loss:20.77322006225586, stage_losses:8.804616928100586, 5.802681922912598, 6.165920734405518
batch_iter:6, total_loss:18.853626251220703, stage_losses:8.655705451965332, 4.985413074493408, 5.212508201599121
batch_iter:7, total_loss:17.630460739135742, stage_losses:8.583230018615723, 4.492417812347412, 4.554813861846924
batch_iter:8, total_loss:15.253774642944336, stage_losses:6.933773040771484, 4.082043170928955, 4.2379584312438965
batch_iter:9, total_loss:17.597877502441406, stage_losses:7.518567085266113, 4.920166969299316, 5.159142971038818
batch_iter:10, total_loss:23.833633422851562, stage_losses:9.617773056030273, 6.966115474700928, 7.249744415283203
batch_iter:11, total_loss:19.892248153686523, stage_losses:7.882940769195557, 5.8425116

batch_iter:75, total_loss:17.653884887695312, stage_losses:7.546725273132324, 5.046789169311523, 5.060370922088623
batch_iter:76, total_loss:17.709001541137695, stage_losses:7.525148391723633, 5.005068302154541, 5.17878532409668
batch_iter:77, total_loss:16.799564361572266, stage_losses:7.600728511810303, 4.464161396026611, 4.734674453735352
batch_iter:78, total_loss:21.71517562866211, stage_losses:8.671329498291016, 6.431989669799805, 6.611855506896973
batch_iter:79, total_loss:19.86014175415039, stage_losses:9.08254623413086, 5.338425159454346, 5.439170837402344
batch_iter:80, total_loss:19.49692726135254, stage_losses:8.16125774383545, 5.449502944946289, 5.886166572570801
batch_iter:81, total_loss:26.899606704711914, stage_losses:10.87236213684082, 7.835318088531494, 8.191926956176758
batch_iter:82, total_loss:25.447851181030273, stage_losses:10.583828926086426, 7.180391788482666, 7.683630466461182
batch_iter:83, total_loss:19.806415557861328, stage_losses:8.005828857421875, 5.76820

batch_iter:146, total_loss:18.326080322265625, stage_losses:7.595579624176025, 5.3178181648254395, 5.412681579589844
batch_iter:147, total_loss:23.566667556762695, stage_losses:11.801682472229004, 5.697995662689209, 6.066989421844482
batch_iter:148, total_loss:18.853010177612305, stage_losses:7.752085208892822, 5.381708145141602, 5.719217300415039
batch_iter:149, total_loss:21.158571243286133, stage_losses:9.633211135864258, 5.519448280334473, 6.005911350250244
batch_iter:150, total_loss:19.592815399169922, stage_losses:7.830234050750732, 5.795152187347412, 5.967428684234619
batch_iter:151, total_loss:14.60750961303711, stage_losses:5.893192291259766, 4.26025915145874, 4.454057216644287
batch_iter:152, total_loss:19.332712173461914, stage_losses:8.258960723876953, 5.479161262512207, 5.594589710235596
batch_iter:153, total_loss:14.281683921813965, stage_losses:5.886199951171875, 4.095258712768555, 4.300225257873535
batch_iter:154, total_loss:15.923076629638672, stage_losses:7.0768599510

batch_iter:6, total_loss:18.331871032714844, stage_losses:9.160070419311523, 4.523953437805176, 4.647848129272461
batch_iter:7, total_loss:19.139869689941406, stage_losses:8.334922790527344, 5.348844528198242, 5.456101417541504
batch_iter:8, total_loss:18.872587203979492, stage_losses:7.603793621063232, 5.497414588928223, 5.771379470825195
batch_iter:9, total_loss:16.937009811401367, stage_losses:7.029748439788818, 4.829770088195801, 5.077491283416748
batch_iter:10, total_loss:17.740793228149414, stage_losses:7.430605411529541, 5.043863296508789, 5.266324996948242
batch_iter:11, total_loss:27.237751007080078, stage_losses:10.985891342163086, 8.102035522460938, 8.149824142456055
batch_iter:12, total_loss:14.725700378417969, stage_losses:6.02839994430542, 4.265719890594482, 4.431580543518066
batch_iter:13, total_loss:26.27009391784668, stage_losses:10.605324745178223, 7.770695209503174, 7.894073963165283
batch_iter:14, total_loss:24.17293357849121, stage_losses:10.234453201293945, 6.8427

batch_iter:78, total_loss:17.04192543029785, stage_losses:7.13090705871582, 4.852120876312256, 5.058896541595459
batch_iter:79, total_loss:23.529090881347656, stage_losses:8.757772445678711, 7.352036952972412, 7.419281959533691
batch_iter:80, total_loss:18.238262176513672, stage_losses:7.511116981506348, 5.273924350738525, 5.453220844268799
batch_iter:81, total_loss:20.75341796875, stage_losses:8.146415710449219, 6.296120643615723, 6.310882091522217
batch_iter:82, total_loss:17.12959861755371, stage_losses:7.5050225257873535, 4.728637218475342, 4.895938873291016
batch_iter:83, total_loss:18.283096313476562, stage_losses:8.576423645019531, 4.7762675285339355, 4.930405616760254
batch_iter:84, total_loss:19.81382179260254, stage_losses:8.414891242980957, 5.628922939300537, 5.770007610321045
batch_iter:85, total_loss:23.969114303588867, stage_losses:8.867297172546387, 7.34736967086792, 7.7544474601745605
batch_iter:86, total_loss:13.759481430053711, stage_losses:6.506687641143799, 3.606296

batch_iter:150, total_loss:18.765201568603516, stage_losses:7.667509078979492, 5.423860549926758, 5.673830986022949
batch_iter:151, total_loss:16.689767837524414, stage_losses:6.875678062438965, 4.785345077514648, 5.028744697570801
batch_iter:152, total_loss:20.540752410888672, stage_losses:9.242878913879395, 5.460787773132324, 5.837086200714111
batch_iter:153, total_loss:26.499584197998047, stage_losses:12.287437438964844, 7.067485809326172, 7.144660949707031
batch_iter:154, total_loss:17.375131607055664, stage_losses:7.196289539337158, 4.990610122680664, 5.188232421875
batch_iter:155, total_loss:17.85262107849121, stage_losses:7.8445725440979, 4.899966716766357, 5.108081340789795
batch_iter:156, total_loss:18.391881942749023, stage_losses:7.736755847930908, 5.243439197540283, 5.41168737411499
batch_iter:157, total_loss:17.259822845458984, stage_losses:7.1061296463012695, 5.020379543304443, 5.1333136558532715
batch_iter:158, total_loss:19.845989227294922, stage_losses:8.88106346130371

batch_iter:10, total_loss:14.520185470581055, stage_losses:6.443264484405518, 3.914031505584717, 4.162889003753662
batch_iter:11, total_loss:15.801066398620605, stage_losses:6.621889591217041, 4.548105239868164, 4.631072044372559
batch_iter:12, total_loss:23.293052673339844, stage_losses:9.592063903808594, 6.711413860321045, 6.9895734786987305
batch_iter:13, total_loss:18.975738525390625, stage_losses:9.800392150878906, 4.510701656341553, 4.664644241333008
batch_iter:14, total_loss:21.663156509399414, stage_losses:9.102652549743652, 6.118202209472656, 6.4423017501831055
batch_iter:15, total_loss:18.188261032104492, stage_losses:7.787289619445801, 5.109525203704834, 5.291447162628174
batch_iter:16, total_loss:18.13134765625, stage_losses:7.657528877258301, 5.1520185470581055, 5.32180118560791
batch_iter:17, total_loss:16.358131408691406, stage_losses:7.558090686798096, 4.272150039672852, 4.527890205383301
batch_iter:18, total_loss:15.6695556640625, stage_losses:6.739135265350342, 4.4120

batch_iter:82, total_loss:17.574203491210938, stage_losses:7.203171253204346, 5.136532783508301, 5.234498500823975
batch_iter:83, total_loss:13.873575210571289, stage_losses:5.983663082122803, 3.863811492919922, 4.026099681854248
batch_iter:84, total_loss:16.81344985961914, stage_losses:7.267373085021973, 4.725351333618164, 4.820725440979004
batch_iter:85, total_loss:23.33655548095703, stage_losses:9.518970489501953, 6.737149715423584, 7.080434322357178
batch_iter:86, total_loss:17.882938385009766, stage_losses:8.63771915435791, 4.615190505981445, 4.630029201507568
batch_iter:87, total_loss:17.216815948486328, stage_losses:7.159514427185059, 4.964321136474609, 5.09298038482666
batch_iter:88, total_loss:17.141090393066406, stage_losses:7.32796049118042, 4.799417972564697, 5.013712406158447
batch_iter:89, total_loss:16.873210906982422, stage_losses:7.208226203918457, 4.76766300201416, 4.897322654724121
batch_iter:90, total_loss:18.350217819213867, stage_losses:7.548828125, 5.305482387542

batch_iter:154, total_loss:17.35076904296875, stage_losses:6.915550708770752, 5.174703121185303, 5.2605156898498535
batch_iter:155, total_loss:16.613910675048828, stage_losses:8.055932998657227, 4.138819694519043, 4.419157028198242
batch_iter:156, total_loss:22.420398712158203, stage_losses:10.736286163330078, 5.890026092529297, 5.794086933135986
batch_iter:157, total_loss:17.554777145385742, stage_losses:8.09831428527832, 4.669867992401123, 4.786594867706299
batch_iter:158, total_loss:22.882850646972656, stage_losses:9.5603666305542, 6.504299163818359, 6.818184852600098
batch_iter:159, total_loss:17.986589431762695, stage_losses:7.469429016113281, 5.157529830932617, 5.359631061553955
batch_iter:160, total_loss:20.230484008789062, stage_losses:8.415669441223145, 5.8016204833984375, 6.013194561004639
batch_iter:161, total_loss:18.152252197265625, stage_losses:7.543661594390869, 5.194729804992676, 5.4138593673706055
batch_iter:162, total_loss:21.506637573242188, stage_losses:9.2649059295

batch_iter:13, total_loss:17.090736389160156, stage_losses:7.196557998657227, 4.934258460998535, 4.9599199295043945
batch_iter:14, total_loss:16.77192497253418, stage_losses:7.098435401916504, 4.778145790100098, 4.89534330368042
batch_iter:15, total_loss:22.221515655517578, stage_losses:8.778929710388184, 6.807624340057373, 6.63496208190918
batch_iter:16, total_loss:17.233051300048828, stage_losses:7.783247470855713, 4.654256343841553, 4.795546531677246
batch_iter:17, total_loss:16.082855224609375, stage_losses:7.027029514312744, 4.447066307067871, 4.608758449554443
batch_iter:18, total_loss:17.715944290161133, stage_losses:7.49129581451416, 5.074543476104736, 5.150104522705078
batch_iter:19, total_loss:16.505672454833984, stage_losses:7.665604114532471, 4.3253068923950195, 4.5147600173950195
batch_iter:20, total_loss:16.69073486328125, stage_losses:7.281590938568115, 4.6007256507873535, 4.808419227600098
batch_iter:21, total_loss:18.13294219970703, stage_losses:8.2887544631958, 4.8280

KeyboardInterrupt: 

In [None]:
writer.close()