## Imports

In [1]:
import os
import torch
import torch.nn as nn
import wandb
import numpy as np

from torchvision import transforms
from mmcv_csn import ResNet3dCSN
from csn import csn50
from i3d_head import I3DHead
from cls_autoencoder import EncoderDecoder
from reconstruction_head import RecontructionHead
from scheduler import GradualWarmupScheduler
from mmaction.datasets import build_dataset

os.chdir('../')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
wandb.init(entity="cares", project="autoencoder",
           group="wlasl-100", name="reconstruction")

[34m[1mwandb[0m: Currently logged in as: [33msttaseen[0m ([33mcares[0m). Use [1m`wandb login --relogin`[0m to force relogin


## Device Agnostic Code

In [3]:
# Set up device agnostic code
try:
    device = 'mps' if torch.backends.mps.is_available() else 'cpu'
except:
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

## Dataset

In [4]:
train_cfg=dict(
    type='RawframeDataset',
    ann_file='data/wlasl/train_annotations.txt',
    data_prefix='data/wlasl/rawframes',
    pipeline=[
        dict(
            type='SampleFrames',
            clip_len=32,
            frame_interval=2,
            num_clips=1),
        dict(type='RawFrameDecode'),
        dict(type='Resize', scale=(-1, 256)),
        dict(type='RandomResizedCrop'),
        dict(type='Resize', scale=(224, 224), keep_ratio=False),
        dict(type='Flip', flip_ratio=0.5),
        dict(
            type='Normalize',
            mean=[123.675, 116.28, 103.53],
            std=[58.395, 57.12, 57.375],
            to_bgr=False),
        dict(type='FormatShape', input_format='NCTHW'),
        dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
        dict(type='ToTensor', keys=['imgs', 'label'])
    ])


test_cfg=dict(
        type='RawframeDataset',
        ann_file='data/wlasl/test_annotations.txt',
        data_prefix='data/wlasl/rawframes',
        pipeline=[
            dict(
                type='SampleFrames',
                clip_len=32,
                frame_interval=2,
                num_clips=1,
                test_mode=True),
            dict(type='RawFrameDecode'),
            dict(type='Resize', scale=(-1, 256)),
            dict(type='CenterCrop', crop_size=224),
            dict(
                type='Normalize',
                mean=[123.675, 116.28, 103.53],
                std=[58.395, 57.12, 57.375],
                to_bgr=False),
            dict(type='FormatShape', input_format='NCTHW'),
            dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
            dict(type='ToTensor', keys=['imgs'])
    ])

In [5]:
work_dir = 'work_dirs/wlasl-dataset/'

os.makedirs(work_dir, exist_ok=True)

In [6]:
# Building the datasets

batch_size = 2

train_dataset = build_dataset(train_cfg)
test_dataset = build_dataset(test_cfg)

# Setting up dataloaders
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                    batch_size=batch_size,
                                    shuffle=True,
                                    num_workers=4,
                                    pin_memory=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                    batch_size=1,
                                    shuffle=True,
                                    num_workers=4,
                                    pin_memory=True)

## Model

In [7]:
# Create a CSN model
encoder = ResNet3dCSN(
    pretrained2d=False,
    # pretrained=None,
    pretrained='https://download.openmmlab.com/mmaction/recognition/csn/ircsn_from_scratch_r50_ig65m_20210617-ce545a37.pth',
    depth=50,
    with_pool2=False,
    bottleneck_mode='ir',
    norm_eval=True,
    zero_init_residual=False,
    bn_frozen=True
)

encoder.init_weights()

reconstruct_head = RecontructionHead()

decoder = I3DHead(num_classes=400,
                 in_channels=2048,
                 spatial_type='avg',
                 dropout_ratio=0.5,
                 init_std=0.01)

decoder.init_weights()

model = EncoderDecoder(encoder, decoder, reconstruct_head)

## Optimizer

In [8]:
# Specify optimizer
optimizer = torch.optim.SGD(
    model.parameters(), lr=0.000125, momentum=0.9, weight_decay=0.00001)

# Specify Loss
loss_cls = nn.CrossEntropyLoss()
loss_reconstruct = nn.MSELoss()

# Specify total epochs
epochs = 100

# Specify learning rate scheduler
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer, step_size=120, gamma=0.1)

scheduler_steplr = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[34, 84], gamma=0.1)
scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=16, after_scheduler=scheduler_steplr)

model.to(device)

EncoderDecoder(
  (encoder): ResNet3dCSN(
    (conv1): ConvModule(
      (conv): Conv3d(3, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2), padding=(1, 3, 3), bias=False)
      (bn): BatchNorm3d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      (activate): ReLU(inplace=True)
    )
    (maxpool): MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1), dilation=1, ceil_mode=False)
    (pool2): MaxPool3d(kernel_size=(2, 1, 1), stride=(2, 1, 1), padding=0, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): CSNBottleneck3d(
        (conv1): ConvModule(
          (conv): Conv3d(64, 64, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
          (bn): BatchNorm3d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv2): Sequential(
          (0): ConvModule(
            (conv): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), groups=64, bi

## Train Loop

In [9]:
# Setup wandb
wandb.watch(model, log_freq=10)

def top_k_accuracy(scores, labels, topk=(1, )):
    """Calculate top k accuracy score.
    Args:
        scores (list[np.ndarray]): Prediction scores for each class.
        labels (list[int]): Ground truth labels.
        topk (tuple[int]): K value for top_k_accuracy. Default: (1, ).
    Returns:
        list[float]: Top k accuracy score for each k.
    """
    res = np.zeros(len(topk))
    labels = np.array(labels)[:, np.newaxis]
    for i, k in enumerate(topk):
        max_k_preds = np.argsort(scores, axis=1)[:, -k:][:, ::-1]
        match_array = np.logical_or.reduce(max_k_preds == labels, axis=1)
        topk_acc_score = match_array.sum() / match_array.shape[0]
        res[i] = topk_acc_score

    return res


def train_one_epoch(epoch_index, interval=5):
    """Run one epoch for training.
    Args:
        epoch_index (int): Current epoch.
        interval (int): Frequency at which to print logs.
    Returns:
        last_loss (float): Loss value for the last batch.
    """
    running_loss = 0.
    last_loss = 0.

    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, x in enumerate(train_loader):
        # Every data instance is an input + label pair
        images, targets = x['imgs'].to(device), x['label'].to(device)
        images = images.reshape((-1, ) + images.shape[2:])
        targets = targets.reshape(-1, )
        
        # Zero your gradients for every batch!
        optimizer.zero_grad()

        # Make predictions for this batch
        cls_score, reconstructed = model(images)

        # Get losses
        loss_cls_score = loss_cls(cls_score, targets)
        loss_reconstruct_score = loss_reconstruct(reconstructed, images)
        loss = 0.8 * loss_cls_score + 0.2 * loss_reconstruct_score
        
        # Compute the loss and its gradients
        loss.backward()

        # Gradient Clipping
        torch.nn.utils.clip_grad_norm_(
            model.parameters(), max_norm=40, norm_type=2.0)

        # Adjust learning weights
        optimizer.step()
        

        # Gather data and report
        running_loss += loss.item()
        if i % interval == interval-1:
            last_loss = running_loss / interval  # loss per batch
            print(
                f'Epoch [{epoch_index}][{i+1}/{len(train_loader)}], loss_cls: {loss_cls_score.item():.5}, reconstruct_loss: {loss_reconstruct_score.item():.5} lr: {scheduler.get_last_lr()[0]:.5e}, loss: {last_loss:.5}')
            running_loss = 0.

    return last_loss, scheduler.get_last_lr()[0]


def validate():
    """Run one epoch for validation.
    Returns:
        avg_vloss (float): Validation loss value for the last batch.
        top1_acc (float): Top-1 accuracy in decimal.
        top5_acc (float): Top-5 accuracy in decimal.
    """
    running_vloss = 0.0
    running_vacc = np.zeros(2)

    print('Evaluating top_k_accuracy...')

    with torch.inference_mode():
        for i, x in enumerate(test_loader):
            vimages, vtargets = x['imgs'].to(device), x['label'].to(device)
            vimages = vimages.reshape((-1, ) + vimages.shape[2:])
            vtargets = vtargets.reshape(-1, )
            
            cls_score, reconstructed = model(vimages)

            # Get losses
            loss_cls_score = loss_cls(cls_score, vtargets)
            loss_reconstruct_score = loss_reconstruct(reconstructed, vimages)
            vloss = 0.8 * loss_cls_score + 0.2 * loss_reconstruct_score
            
            running_vloss += vloss

            running_vacc += top_k_accuracy(cls_score.detach().cpu().numpy(),
                                           vtargets.detach().cpu().numpy(), topk=(1, 5))

    avg_vloss = running_vloss / (i + 1)

    acc = running_vacc/len(test_loader)
    top1_acc = acc[0].item()
    top5_acc = acc[1].item()

    return (avg_vloss, top1_acc, top5_acc)


# Train Loop
best_vloss = 1_000_000.

for epoch in range(epochs):
    # Turn on gradient tracking and do a forward pass
    model.train(True)
    avg_loss, learning_rate = train_one_epoch(epoch+1)

    # Turn off  gradients for reporting
    model.train(False)

    avg_vloss, top1_acc, top5_acc = validate()

    print(
        f'top1_acc: {top1_acc:.4}, top5_acc: {top5_acc:.4}, train_loss: {avg_loss:.5}, val_loss: {avg_vloss:.5}')

    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = work_dir + f'epoch_{epoch+1}.pth'
        print(f'Saving checkpoint at {epoch+1} epochs...')
        torch.save(model.state_dict(), model_path)

     # Adjust learning rate
    scheduler.step()

    # Track wandb
    wandb.log({'train/loss': avg_loss,
               'train/learning_rate': learning_rate,
               'val/loss': avg_vloss,
               'val/top1_accuracy': top1_acc,
               'val/top5_accuracy': top5_acc})

Epoch [1][5/721], loss_cls: 5.9559, reconstruct_loss: 3.8625 lr: 0.00000e+00, loss: 5.411
Epoch [1][10/721], loss_cls: 6.077, reconstruct_loss: 2.3185 lr: 0.00000e+00, loss: 5.3872
Epoch [1][15/721], loss_cls: 5.8407, reconstruct_loss: 3.0153 lr: 0.00000e+00, loss: 5.1726
Epoch [1][20/721], loss_cls: 5.9913, reconstruct_loss: 2.202 lr: 0.00000e+00, loss: 5.3453
Epoch [1][25/721], loss_cls: 5.9656, reconstruct_loss: 2.1652 lr: 0.00000e+00, loss: 5.3256
Epoch [1][30/721], loss_cls: 5.7886, reconstruct_loss: 2.0581 lr: 0.00000e+00, loss: 5.4699
Epoch [1][35/721], loss_cls: 5.9942, reconstruct_loss: 2.1716 lr: 0.00000e+00, loss: 5.3626
Epoch [1][40/721], loss_cls: 5.832, reconstruct_loss: 2.1708 lr: 0.00000e+00, loss: 5.2411
Epoch [1][45/721], loss_cls: 6.0216, reconstruct_loss: 2.9662 lr: 0.00000e+00, loss: 5.2971
Epoch [1][50/721], loss_cls: 6.1063, reconstruct_loss: 2.2705 lr: 0.00000e+00, loss: 5.2872
Epoch [1][55/721], loss_cls: 5.9997, reconstruct_loss: 1.8191 lr: 0.00000e+00, loss: 

Epoch [1][450/721], loss_cls: 5.9606, reconstruct_loss: 3.3753 lr: 0.00000e+00, loss: 5.3101
Epoch [1][455/721], loss_cls: 6.1868, reconstruct_loss: 2.3195 lr: 0.00000e+00, loss: 5.2679
Epoch [1][460/721], loss_cls: 6.0628, reconstruct_loss: 1.9187 lr: 0.00000e+00, loss: 5.3227
Epoch [1][465/721], loss_cls: 6.1129, reconstruct_loss: 2.0251 lr: 0.00000e+00, loss: 5.3119
Epoch [1][470/721], loss_cls: 6.0044, reconstruct_loss: 1.2593 lr: 0.00000e+00, loss: 5.308
Epoch [1][475/721], loss_cls: 5.9143, reconstruct_loss: 1.8403 lr: 0.00000e+00, loss: 5.3218
Epoch [1][480/721], loss_cls: 6.0154, reconstruct_loss: 2.7038 lr: 0.00000e+00, loss: 5.2919
Epoch [1][485/721], loss_cls: 6.0508, reconstruct_loss: 1.6417 lr: 0.00000e+00, loss: 5.3086
Epoch [1][490/721], loss_cls: 6.2807, reconstruct_loss: 3.1535 lr: 0.00000e+00, loss: 5.3136
Epoch [1][495/721], loss_cls: 5.9491, reconstruct_loss: 2.9956 lr: 0.00000e+00, loss: 5.2549
Epoch [1][500/721], loss_cls: 5.9422, reconstruct_loss: 2.0317 lr: 0.00

Epoch [2][170/721], loss_cls: 5.8941, reconstruct_loss: 3.057 lr: 7.81250e-06, loss: 5.2558
Epoch [2][175/721], loss_cls: 5.8287, reconstruct_loss: 2.8421 lr: 7.81250e-06, loss: 5.2006
Epoch [2][180/721], loss_cls: 6.1002, reconstruct_loss: 2.7902 lr: 7.81250e-06, loss: 5.5164
Epoch [2][185/721], loss_cls: 5.859, reconstruct_loss: 1.5538 lr: 7.81250e-06, loss: 5.2662
Epoch [2][190/721], loss_cls: 6.0371, reconstruct_loss: 1.6094 lr: 7.81250e-06, loss: 5.2619
Epoch [2][195/721], loss_cls: 6.0879, reconstruct_loss: 2.9528 lr: 7.81250e-06, loss: 5.3787
Epoch [2][200/721], loss_cls: 5.9397, reconstruct_loss: 2.0017 lr: 7.81250e-06, loss: 5.2498
Epoch [2][205/721], loss_cls: 6.1234, reconstruct_loss: 1.9043 lr: 7.81250e-06, loss: 5.2813
Epoch [2][210/721], loss_cls: 5.8246, reconstruct_loss: 2.6347 lr: 7.81250e-06, loss: 5.2983
Epoch [2][215/721], loss_cls: 5.8106, reconstruct_loss: 1.4554 lr: 7.81250e-06, loss: 5.2101
Epoch [2][220/721], loss_cls: 6.3837, reconstruct_loss: 2.0751 lr: 7.812

Epoch [2][615/721], loss_cls: 5.8299, reconstruct_loss: 2.1944 lr: 7.81250e-06, loss: 5.3437
Epoch [2][620/721], loss_cls: 5.5651, reconstruct_loss: 2.0238 lr: 7.81250e-06, loss: 5.155
Epoch [2][625/721], loss_cls: 5.8258, reconstruct_loss: 2.6239 lr: 7.81250e-06, loss: 5.2447
Epoch [2][630/721], loss_cls: 5.8735, reconstruct_loss: 2.0369 lr: 7.81250e-06, loss: 5.2954
Epoch [2][635/721], loss_cls: 6.1254, reconstruct_loss: 3.5061 lr: 7.81250e-06, loss: 5.1621
Epoch [2][640/721], loss_cls: 6.0142, reconstruct_loss: 2.1146 lr: 7.81250e-06, loss: 5.3238
Epoch [2][645/721], loss_cls: 5.8725, reconstruct_loss: 0.76192 lr: 7.81250e-06, loss: 5.1959
Epoch [2][650/721], loss_cls: 5.9484, reconstruct_loss: 2.5367 lr: 7.81250e-06, loss: 5.2149
Epoch [2][655/721], loss_cls: 5.9788, reconstruct_loss: 2.9929 lr: 7.81250e-06, loss: 5.0581
Epoch [2][660/721], loss_cls: 5.7171, reconstruct_loss: 2.2 lr: 7.81250e-06, loss: 5.0674
Epoch [2][665/721], loss_cls: 6.2615, reconstruct_loss: 2.7697 lr: 7.8125

Epoch [3][335/721], loss_cls: 5.9034, reconstruct_loss: 1.5243 lr: 1.56250e-05, loss: 4.749
Epoch [3][340/721], loss_cls: 5.9739, reconstruct_loss: 2.0973 lr: 1.56250e-05, loss: 4.6631
Epoch [3][345/721], loss_cls: 4.6865, reconstruct_loss: 1.3546 lr: 1.56250e-05, loss: 4.4989
Epoch [3][350/721], loss_cls: 5.5368, reconstruct_loss: 1.625 lr: 1.56250e-05, loss: 5.0658
Epoch [3][355/721], loss_cls: 4.901, reconstruct_loss: 3.0479 lr: 1.56250e-05, loss: 4.8572
Epoch [3][360/721], loss_cls: 5.3632, reconstruct_loss: 2.9578 lr: 1.56250e-05, loss: 4.9926
Epoch [3][365/721], loss_cls: 5.4993, reconstruct_loss: 2.0886 lr: 1.56250e-05, loss: 5.1123
Epoch [3][370/721], loss_cls: 6.0068, reconstruct_loss: 2.9424 lr: 1.56250e-05, loss: 4.8817
Epoch [3][375/721], loss_cls: 5.2896, reconstruct_loss: 2.9635 lr: 1.56250e-05, loss: 4.7918
Epoch [3][380/721], loss_cls: 5.3786, reconstruct_loss: 1.9215 lr: 1.56250e-05, loss: 4.9412
Epoch [3][385/721], loss_cls: 5.5944, reconstruct_loss: 3.0504 lr: 1.5625

Epoch [4][55/721], loss_cls: 5.2941, reconstruct_loss: 2.3552 lr: 2.34375e-05, loss: 4.4579
Epoch [4][60/721], loss_cls: 3.8712, reconstruct_loss: 2.8643 lr: 2.34375e-05, loss: 4.4667
Epoch [4][65/721], loss_cls: 4.8961, reconstruct_loss: 2.1197 lr: 2.34375e-05, loss: 4.855
Epoch [4][70/721], loss_cls: 5.285, reconstruct_loss: 1.4669 lr: 2.34375e-05, loss: 4.5517
Epoch [4][75/721], loss_cls: 4.4621, reconstruct_loss: 1.298 lr: 2.34375e-05, loss: 4.5179
Epoch [4][80/721], loss_cls: 4.31, reconstruct_loss: 2.4093 lr: 2.34375e-05, loss: 4.5033
Epoch [4][85/721], loss_cls: 5.7649, reconstruct_loss: 2.4129 lr: 2.34375e-05, loss: 4.8026
Epoch [4][90/721], loss_cls: 4.3198, reconstruct_loss: 1.9382 lr: 2.34375e-05, loss: 4.3824
Epoch [4][95/721], loss_cls: 4.6917, reconstruct_loss: 2.0446 lr: 2.34375e-05, loss: 4.1453
Epoch [4][100/721], loss_cls: 3.7388, reconstruct_loss: 2.3478 lr: 2.34375e-05, loss: 4.5288
Epoch [4][105/721], loss_cls: 4.9031, reconstruct_loss: 3.6034 lr: 2.34375e-05, loss

Epoch [4][500/721], loss_cls: 5.0503, reconstruct_loss: 1.9544 lr: 2.34375e-05, loss: 4.3417
Epoch [4][505/721], loss_cls: 5.2414, reconstruct_loss: 2.4866 lr: 2.34375e-05, loss: 4.334
Epoch [4][510/721], loss_cls: 3.8299, reconstruct_loss: 1.9243 lr: 2.34375e-05, loss: 4.1526
Epoch [4][515/721], loss_cls: 5.0416, reconstruct_loss: 1.1555 lr: 2.34375e-05, loss: 4.3132
Epoch [4][520/721], loss_cls: 4.6436, reconstruct_loss: 3.5353 lr: 2.34375e-05, loss: 4.3868
Epoch [4][525/721], loss_cls: 5.892, reconstruct_loss: 1.9366 lr: 2.34375e-05, loss: 4.7217
Epoch [4][530/721], loss_cls: 5.4733, reconstruct_loss: 3.698 lr: 2.34375e-05, loss: 4.6756
Epoch [4][535/721], loss_cls: 4.6125, reconstruct_loss: 1.5938 lr: 2.34375e-05, loss: 4.5655
Epoch [4][540/721], loss_cls: 5.4042, reconstruct_loss: 1.6591 lr: 2.34375e-05, loss: 4.2625
Epoch [4][545/721], loss_cls: 5.9773, reconstruct_loss: 1.5682 lr: 2.34375e-05, loss: 4.6146
Epoch [4][550/721], loss_cls: 4.3122, reconstruct_loss: 3.8903 lr: 2.3437

Epoch [5][220/721], loss_cls: 4.5725, reconstruct_loss: 2.4375 lr: 3.12500e-05, loss: 4.6007
Epoch [5][225/721], loss_cls: 4.9231, reconstruct_loss: 0.84579 lr: 3.12500e-05, loss: 4.4483
Epoch [5][230/721], loss_cls: 5.2238, reconstruct_loss: 1.5312 lr: 3.12500e-05, loss: 4.671
Epoch [5][235/721], loss_cls: 5.3373, reconstruct_loss: 1.7664 lr: 3.12500e-05, loss: 4.4858
Epoch [5][240/721], loss_cls: 4.7613, reconstruct_loss: 1.6432 lr: 3.12500e-05, loss: 4.6445
Epoch [5][245/721], loss_cls: 5.2738, reconstruct_loss: 1.9339 lr: 3.12500e-05, loss: 4.1184
Epoch [5][250/721], loss_cls: 4.4927, reconstruct_loss: 3.1957 lr: 3.12500e-05, loss: 4.0839
Epoch [5][255/721], loss_cls: 4.4511, reconstruct_loss: 1.3127 lr: 3.12500e-05, loss: 4.2321
Epoch [5][260/721], loss_cls: 5.0431, reconstruct_loss: 1.5481 lr: 3.12500e-05, loss: 4.2565
Epoch [5][265/721], loss_cls: 4.5535, reconstruct_loss: 2.3397 lr: 3.12500e-05, loss: 4.1746
Epoch [5][270/721], loss_cls: 4.8482, reconstruct_loss: 1.2574 lr: 3.1

Epoch [5][665/721], loss_cls: 5.0136, reconstruct_loss: 1.1936 lr: 3.12500e-05, loss: 4.51
Epoch [5][670/721], loss_cls: 5.0649, reconstruct_loss: 2.0647 lr: 3.12500e-05, loss: 4.1457
Epoch [5][675/721], loss_cls: 4.8582, reconstruct_loss: 1.2851 lr: 3.12500e-05, loss: 4.0935
Epoch [5][680/721], loss_cls: 5.0291, reconstruct_loss: 1.7455 lr: 3.12500e-05, loss: 4.4321
Epoch [5][685/721], loss_cls: 4.5276, reconstruct_loss: 1.16 lr: 3.12500e-05, loss: 4.0952
Epoch [5][690/721], loss_cls: 4.5706, reconstruct_loss: 3.0713 lr: 3.12500e-05, loss: 4.3967
Epoch [5][695/721], loss_cls: 5.6889, reconstruct_loss: 2.2758 lr: 3.12500e-05, loss: 4.3168
Epoch [5][700/721], loss_cls: 4.5588, reconstruct_loss: 1.9038 lr: 3.12500e-05, loss: 4.2532
Epoch [5][705/721], loss_cls: 5.1533, reconstruct_loss: 1.1397 lr: 3.12500e-05, loss: 4.43
Epoch [5][710/721], loss_cls: 4.6373, reconstruct_loss: 1.057 lr: 3.12500e-05, loss: 4.0773
Epoch [5][715/721], loss_cls: 4.9935, reconstruct_loss: 1.8231 lr: 3.12500e-0

KeyboardInterrupt: 