In [None]:
!pip install numpy

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
!pip install numpy==1.16.1

Reading package lists... Done
Building dependency tree       
Reading state information... Done
[1;31mE: [0mUnable to locate package numpy=[0m


In [None]:
print(numpy.__version__)

1.24.3


In [None]:
epoch_losses = []
err_lst = []

In [None]:
import torch
import numpy as np
import torch.utils.data
import torchvision
from loader import *
import os
from fcrn import FCRN
from torch.autograd import Variable
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plot

dtype = torch.cuda.FloatTensor
weights_file = "NYU_ResNet-UpProj.npy"


def load_split():
    current_directoty = os.getcwd()
    train_lists_path = current_directoty + '/trainIdxs.txt'
    test_lists_path = current_directoty + '/testIdxs.txt'

    train_f = open(train_lists_path)
    test_f = open(test_lists_path)

    train_lists = []
    test_lists = []

    train_lists_line = train_f.readline()
    while train_lists_line:
        train_lists.append(int(train_lists_line) - 1)
        train_lists_line = train_f.readline()
    train_f.close()

    test_lists_line = test_f.readline()
    while test_lists_line:
        test_lists.append(int(test_lists_line) - 1)
        test_lists_line = test_f.readline()
    test_f.close()

    val_start_idx = int(len(train_lists) * 0.8)

    val_lists = train_lists[val_start_idx:-1]
    train_lists = train_lists[0:val_start_idx]

    return train_lists, val_lists, test_lists


def main():
    batch_size = 32
    data_path = 'nyu_depth_v2_labeled.mat'
    learning_rate = 1.0e-5
    monentum = 0.9
    weight_decay = 0.0005
    num_epochs = 50
    resume_from_file = False
    # 1.Load data
    train_lists, val_lists, test_lists = load_split()
    print("Loading data......")
    train_loader = torch.utils.data.DataLoader(NyuDepthLoader(data_path, train_lists),
                                               batch_size=batch_size, shuffle=True, drop_last=True)
    val_loader = torch.utils.data.DataLoader(NyuDepthLoader(data_path, val_lists),
                                               batch_size=batch_size, shuffle=True, drop_last=True)
    test_loader = torch.utils.data.DataLoader(NyuDepthLoader(data_path, test_lists),
                                             batch_size=batch_size, shuffle=True, drop_last=True)
    print(train_loader)
    # 2.Load model
    print("Loading model......")
    model = FCRN(batch_size)
    #resnet = torchvision.models.resnet50(pretrained=True)
    resnet = torchvision.models.resnet50()
    resnet.load_state_dict(torch.load('./resnet50-19c8e357.pth'))
    print("resnet50 loaded.")
    resnet50_pretrained_dict = resnet.state_dict()

    model.load_state_dict(load_weights(model, weights_file, dtype))
    """
    print('\nresnet50 keys:\n')
    for key, value in resnet50_pretrained_dict.items():
        print(key, value.size())
    """
    #model_dict = model.state_dict()
    """
    print('\nmodel keys:\n')
    for key, value in model_dict.items():
        print(key, value.size())

    print("resnet50.dict loaded.")
    """
    # load pretrained weights
    #resnet50_pretrained_dict = {k: v for k, v in resnet50_pretrained_dict.items() if k in model_dict}
    print("resnet50_pretrained_dict loaded.")
    """
    print('\nresnet50_pretrained keys:\n')
    for key, value in resnet50_pretrained_dict.items():
        print(key, value.size())
    """
    #model_dict.update(resnet50_pretrained_dict)
    print("model_dict updated.")
    """
    print('\nupdated model dict keys:\n')
    for key, value in model_dict.items():
        print(key, value.size())
    """
    #model.load_state_dict(model_dict)
    print("model_dict loaded.")
    model = model.cuda()

    # 3.Loss
    loss_fn = torch.nn.MSELoss().cuda()
    print("loss_fn set.")

    # 5.Train
    best_val_err = 1.0e3

    # validate
    model.eval()
    num_correct, num_samples = 0, 0
    loss_local = 0
    with torch.no_grad():
        for input, depth in val_loader:
            input_var = Variable(input.type(dtype))
            depth_var = Variable(depth.type(dtype))

            output = model(input_var)

            input_rgb_image = input_var[0].data.permute(1, 2, 0).cpu().numpy().astype(np.uint8)
            input_gt_depth_image = depth_var[0][0].data.cpu().numpy().astype(np.float32)
            pred_depth_image = output[0].data.squeeze().cpu().numpy().astype(np.float32)

            input_gt_depth_image /= np.max(input_gt_depth_image)
            pred_depth_image /= np.max(pred_depth_image)

            plot.imsave('input_rgb_epoch_0.png', input_rgb_image)
            plot.imsave('gt_depth_epoch_0.png', input_gt_depth_image, cmap="viridis")
            plot.imsave('pred_depth_epoch_0.png', pred_depth_image, cmap="viridis")

            # depth_var = depth_var[:, 0, :, :]
            # loss_fn_local = torch.nn.MSELoss()

            loss_local += loss_fn(output, depth_var)

            num_samples += 1

    err = float(loss_local) / num_samples
    print('val_error before train:', err)

    start_epoch = 0

    resume_file = 'checkpoint.pth.tar'
    if resume_from_file:
        if os.path.isfile(resume_file):
            print("=> loading checkpoint '{}'".format(resume_file))
            checkpoint = torch.load(resume_file)
            start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(resume_file, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(resume_file))

    for epoch in range(num_epochs):

        # 4.Optim
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
        # optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=monentum)
        # optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=monentum, weight_decay=weight_decay)
        print("optimizer set.")

        print('Starting train epoch %d / %d' % (start_epoch + epoch + 1, num_epochs))
        model.train()
        running_loss = 0
        count = 0
        epoch_loss = 0

        #for i, (input, depth) in enumerate(train_loader):
        for input, depth in train_loader:
            # input, depth = data
            #input_var = input.cuda()
            #depth_var = depth.cuda()
            input_var = Variable(input.type(dtype))
            depth_var = Variable(depth.type(dtype))

            output = model(input_var)
            loss = loss_fn(output, depth_var)
            print('loss:', loss.item())
            count += 1
            running_loss += loss.data.cpu().numpy()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        epoch_loss = running_loss / count
        epoch_losses.append(epoch_loss)
        print('epoch loss:', epoch_loss)

        # validate
        model.eval()
        num_correct, num_samples = 0, 0
        loss_local = 0
        with torch.no_grad():
            for input, depth in val_loader:
                input_var = Variable(input.type(dtype))
                depth_var = Variable(depth.type(dtype))

                output = model(input_var)

                input_rgb_image = input_var[0].data.permute(1, 2, 0).cpu().numpy().astype(np.uint8)
                input_gt_depth_image = depth_var[0][0].data.cpu().numpy().astype(np.float32)
                pred_depth_image = output[0].data.squeeze().cpu().numpy().astype(np.float32)

                input_gt_depth_image /= np.max(input_gt_depth_image)
                pred_depth_image /= np.max(pred_depth_image)

                plot.imsave('input_rgb_epoch_{}.png'.format(start_epoch + epoch + 1), input_rgb_image)
                plot.imsave('gt_depth_epoch_{}.png'.format(start_epoch + epoch + 1), input_gt_depth_image, cmap="viridis")
                plot.imsave('pred_depth_epoch_{}.png'.format(start_epoch + epoch + 1), pred_depth_image, cmap="viridis")

                # depth_var = depth_var[:, 0, :, :]
                # loss_fn_local = torch.nn.MSELoss()

                loss_local += loss_fn(output, depth_var)

                num_samples += 1

        err = float(loss_local) / num_samples
        print('val_error:', err)
        err_lst.append(err)
        if err < best_val_err:
            best_val_err = err
            torch.save({
                'epoch': start_epoch + epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, 'checkpoint.pth.tar')

        if epoch % 10 == 0:
            learning_rate = learning_rate * 0.6


In [None]:
def load_weights(model, weights_file, dtype):

    model_params = model.state_dict()
    data_dict = np.load(weights_file, allow_pickle = True, encoding='latin1').item()

    if True:
        model_params['conv1.weight'] = torch.from_numpy(data_dict['conv1']['weights']).type(dtype).permute(3,2,0,1)
        #model_params['conv1.bias'] = torch.from_numpy(data_dict['conv1']['biases']).type(dtype)
        model_params['bn1.weight'] = torch.from_numpy(data_dict['bn_conv1']['scale']).type(dtype)
        model_params['bn1.bias'] = torch.from_numpy(data_dict['bn_conv1']['offset']).type(dtype)

        model_params['layer1.0.downsample.0.weight'] = torch.from_numpy(data_dict['res2a_branch1']['weights']).type(dtype).permute(3,2,0,1)
        model_params['layer1.0.downsample.1.weight'] = torch.from_numpy(data_dict['bn2a_branch1']['scale']).type(dtype)
        model_params['layer1.0.downsample.1.bias'] = torch.from_numpy(data_dict['bn2a_branch1']['offset']).type(dtype)

        model_params['layer1.0.conv1.weight'] = torch.from_numpy(data_dict['res2a_branch2a']['weights']).type(dtype).permute(3,2,0,1)
        model_params['layer1.0.bn1.weight'] = torch.from_numpy(data_dict['bn2a_branch2a']['scale']).type(dtype)
        model_params['layer1.0.bn1.bias'] = torch.from_numpy(data_dict['bn2a_branch2a']['offset']).type(dtype)

        model_params['layer1.0.conv2.weight'] = torch.from_numpy(data_dict['res2a_branch2b']['weights']).type(dtype).permute(3,2,0,1)
        model_params['layer1.0.bn2.weight'] = torch.from_numpy(data_dict['bn2a_branch2b']['scale']).type(dtype)
        model_params['layer1.0.bn2.bias'] = torch.from_numpy(data_dict['bn2a_branch2b']['offset']).type(dtype)

        model_params['layer1.0.conv3.weight'] = torch.from_numpy(data_dict['res2a_branch2c']['weights']).type(dtype).permute(3,2,0,1)
        model_params['layer1.0.bn3.weight'] = torch.from_numpy(data_dict['bn2a_branch2c']['scale']).type(dtype)
        model_params['layer1.0.bn3.bias'] = torch.from_numpy(data_dict['bn2a_branch2c']['offset']).type(dtype)

        model_params['layer1.1.conv1.weight'] = torch.from_numpy(data_dict['res2b_branch2a']['weights']).type(dtype).permute(3,2,0,1)
        model_params['layer1.1.bn1.weight'] = torch.from_numpy(data_dict['bn2b_branch2a']['scale']).type(dtype)
        model_params['layer1.1.bn1.bias'] = torch.from_numpy(data_dict['bn2b_branch2a']['offset']).type(dtype)

        model_params['layer1.1.conv2.weight'] = torch.from_numpy(data_dict['res2b_branch2b']['weights']).type(dtype).permute(3,2,0,1)
        model_params['layer1.1.bn2.weight'] = torch.from_numpy(data_dict['bn2b_branch2b']['scale']).type(dtype)
        model_params['layer1.1.bn2.bias'] = torch.from_numpy(data_dict['bn2b_branch2b']['offset']).type(dtype)

        model_params['layer1.1.conv3.weight'] = torch.from_numpy(data_dict['res2b_branch2c']['weights']).type(dtype).permute(3,2,0,1)
        model_params['layer1.1.bn3.weight'] = torch.from_numpy(data_dict['bn2b_branch2c']['scale']).type(dtype)
        model_params['layer1.1.bn3.bias'] = torch.from_numpy(data_dict['bn2b_branch2c']['offset']).type(dtype)

        model_params['layer1.2.conv1.weight'] = torch.from_numpy(data_dict['res2c_branch2a']['weights']).type(dtype).permute(3,2,0,1)
        model_params['layer1.2.bn1.weight'] = torch.from_numpy(data_dict['bn2c_branch2a']['scale']).type(dtype)
        model_params['layer1.2.bn1.bias'] = torch.from_numpy(data_dict['bn2c_branch2a']['offset']).type(dtype)

        model_params['layer1.2.conv2.weight'] = torch.from_numpy(data_dict['res2c_branch2b']['weights']).type(dtype).permute(3,2,0,1)
        model_params['layer1.2.bn2.weight'] = torch.from_numpy(data_dict['bn2c_branch2b']['scale']).type(dtype)
        model_params['layer1.2.bn2.bias'] = torch.from_numpy(data_dict['bn2c_branch2b']['offset']).type(dtype)

        model_params['layer1.2.conv3.weight'] = torch.from_numpy(data_dict['res2c_branch2c']['weights']).type(dtype).permute(3,2,0,1)
        model_params['layer1.2.bn3.weight'] = torch.from_numpy(data_dict['bn2c_branch2c']['scale']).type(dtype)
        model_params['layer1.2.bn3.bias'] = torch.from_numpy(data_dict['bn2c_branch2c']['offset']).type(dtype)

        model_params['layer2.0.downsample.0.weight'] = torch.from_numpy(data_dict['res3a_branch1']['weights']).type(dtype).permute(3,2,0,1)
        model_params['layer2.0.downsample.1.weight'] = torch.from_numpy(data_dict['bn3a_branch1']['scale']).type(dtype)
        model_params['layer2.0.downsample.1.bias'] = torch.from_numpy(data_dict['bn3a_branch1']['offset']).type(dtype)

        model_params['layer2.0.conv1.weight'] = torch.from_numpy(data_dict['res3a_branch2a']['weights']).type(dtype).permute(3,2,0,1)
        model_params['layer2.0.bn1.weight'] = torch.from_numpy(data_dict['bn3a_branch2a']['scale']).type(dtype)
        model_params['layer2.0.bn1.bias'] = torch.from_numpy(data_dict['bn3a_branch2a']['offset']).type(dtype)

        model_params['layer2.0.conv2.weight'] = torch.from_numpy(data_dict['res3a_branch2b']['weights']).type(dtype).permute(3,2,0,1)
        model_params['layer2.0.bn2.weight'] = torch.from_numpy(data_dict['bn3a_branch2b']['scale']).type(dtype)
        model_params['layer2.0.bn2.bias'] = torch.from_numpy(data_dict['bn3a_branch2b']['offset']).type(dtype)

        model_params['layer2.0.conv3.weight'] = torch.from_numpy(data_dict['res3a_branch2c']['weights']).type(dtype).permute(3,2,0,1)
        model_params['layer2.0.bn3.weight'] = torch.from_numpy(data_dict['bn3a_branch2c']['scale']).type(dtype)
        model_params['layer2.0.bn3.bias'] = torch.from_numpy(data_dict['bn3a_branch2c']['offset']).type(dtype)

        model_params['layer2.1.conv1.weight'] = torch.from_numpy(data_dict['res3b_branch2a']['weights']).type(dtype).permute(3,2,0,1)
        model_params['layer2.1.bn1.weight'] = torch.from_numpy(data_dict['bn3b_branch2a']['scale']).type(dtype)
        model_params['layer2.1.bn1.bias'] = torch.from_numpy(data_dict['bn3b_branch2a']['offset']).type(dtype)

        model_params['layer2.1.conv2.weight'] = torch.from_numpy(data_dict['res3b_branch2b']['weights']).type(dtype).permute(3,2,0,1)
        model_params['layer2.1.bn2.weight'] = torch.from_numpy(data_dict['bn3b_branch2b']['scale']).type(dtype)
        model_params['layer2.1.bn2.bias'] = torch.from_numpy(data_dict['bn3b_branch2b']['offset']).type(dtype)

        model_params['layer2.1.conv3.weight'] = torch.from_numpy(data_dict['res3b_branch2c']['weights']).type(dtype).permute(3,2,0,1)
        model_params['layer2.1.bn3.weight'] = torch.from_numpy(data_dict['bn3b_branch2c']['scale']).type(dtype)
        model_params['layer2.1.bn3.bias'] = torch.from_numpy(data_dict['bn3b_branch2c']['offset']).type(dtype)

        model_params['layer2.2.conv1.weight'] = torch.from_numpy(data_dict['res3c_branch2a']['weights']).type(dtype).permute(3,2,0,1)
        model_params['layer2.2.bn1.weight'] = torch.from_numpy(data_dict['bn3c_branch2a']['scale']).type(dtype)
        model_params['layer2.2.bn1.bias'] = torch.from_numpy(data_dict['bn3c_branch2a']['offset']).type(dtype)

        model_params['layer2.2.conv2.weight'] = torch.from_numpy(data_dict['res3c_branch2b']['weights']).type(dtype).permute(3,2,0,1)
        model_params['layer2.2.bn2.weight'] = torch.from_numpy(data_dict['bn3c_branch2b']['scale']).type(dtype)
        model_params['layer2.2.bn2.bias'] = torch.from_numpy(data_dict['bn3c_branch2b']['offset']).type(dtype)

        model_params['layer2.2.conv3.weight'] = torch.from_numpy(data_dict['res3c_branch2c']['weights']).type(dtype).permute(3,2,0,1)
        model_params['layer2.2.bn3.weight'] = torch.from_numpy(data_dict['bn3c_branch2c']['scale']).type(dtype)
        model_params['layer2.2.bn3.bias'] = torch.from_numpy(data_dict['bn3c_branch2c']['offset']).type(dtype)

        model_params['layer2.3.conv1.weight'] = torch.from_numpy(data_dict['res3d_branch2a']['weights']).type(dtype).permute(3,2,0,1)
        model_params['layer2.3.bn1.weight'] = torch.from_numpy(data_dict['bn3d_branch2a']['scale']).type(dtype)
        model_params['layer2.3.bn1.bias'] = torch.from_numpy(data_dict['bn3d_branch2a']['offset']).type(dtype)

        model_params['layer2.3.conv2.weight'] = torch.from_numpy(data_dict['res3d_branch2b']['weights']).type(dtype).permute(3,2,0,1)
        model_params['layer2.3.bn2.weight'] = torch.from_numpy(data_dict['bn3d_branch2b']['scale']).type(dtype)
        model_params['layer2.3.bn2.bias'] = torch.from_numpy(data_dict['bn3d_branch2b']['offset']).type(dtype)

        model_params['layer2.3.conv3.weight'] = torch.from_numpy(data_dict['res3d_branch2c']['weights']).type(dtype).permute(3,2,0,1)
        model_params['layer2.3.bn3.weight'] = torch.from_numpy(data_dict['bn3d_branch2c']['scale']).type(dtype)
        model_params['layer2.3.bn3.bias'] = torch.from_numpy(data_dict['bn3d_branch2c']['offset']).type(dtype)

        model_params['layer3.0.downsample.0.weight'] = torch.from_numpy(data_dict['res4a_branch1']['weights']).type(dtype).permute(3,2,0,1)
        model_params['layer3.0.downsample.1.weight'] = torch.from_numpy(data_dict['bn4a_branch1']['scale']).type(dtype)
        model_params['layer3.0.downsample.1.bias'] = torch.from_numpy(data_dict['bn4a_branch1']['offset']).type(dtype)

        model_params['layer3.0.conv1.weight'] = torch.from_numpy(data_dict['res4a_branch2a']['weights']).type(dtype).permute(3,2,0,1)
        model_params['layer3.0.bn1.weight'] = torch.from_numpy(data_dict['bn4a_branch2a']['scale']).type(dtype)
        model_params['layer3.0.bn1.bias'] = torch.from_numpy(data_dict['bn4a_branch2a']['offset']).type(dtype)

        model_params['layer3.0.conv2.weight'] = torch.from_numpy(data_dict['res4a_branch2b']['weights']).type(dtype).permute(3,2,0,1)
        model_params['layer3.0.bn2.weight'] = torch.from_numpy(data_dict['bn4a_branch2b']['scale']).type(dtype)
        model_params['layer3.0.bn2.bias'] = torch.from_numpy(data_dict['bn4a_branch2b']['offset']).type(dtype)

        model_params['layer3.0.conv3.weight'] = torch.from_numpy(data_dict['res4a_branch2c']['weights']).type(dtype).permute(3,2,0,1)
        model_params['layer3.0.bn3.weight'] = torch.from_numpy(data_dict['bn4a_branch2c']['scale']).type(dtype)
        model_params['layer3.0.bn3.bias'] = torch.from_numpy(data_dict['bn4a_branch2c']['offset']).type(dtype)

        model_params['layer3.1.conv1.weight'] = torch.from_numpy(data_dict['res4b_branch2a']['weights']).type(dtype).permute(3,2,0,1)
        model_params['layer3.1.bn1.weight'] = torch.from_numpy(data_dict['bn4b_branch2a']['scale']).type(dtype)
        model_params['layer3.1.bn1.bias'] = torch.from_numpy(data_dict['bn4b_branch2a']['offset']).type(dtype)

        model_params['layer3.1.conv2.weight'] = torch.from_numpy(data_dict['res4b_branch2b']['weights']).type(dtype).permute(3,2,0,1)
        model_params['layer3.1.bn2.weight'] = torch.from_numpy(data_dict['bn4b_branch2b']['scale']).type(dtype)
        model_params['layer3.1.bn2.bias'] = torch.from_numpy(data_dict['bn4b_branch2b']['offset']).type(dtype)

        model_params['layer3.1.conv3.weight'] = torch.from_numpy(data_dict['res4b_branch2c']['weights']).type(dtype).permute(3,2,0,1)
        model_params['layer3.1.bn3.weight'] = torch.from_numpy(data_dict['bn4b_branch2c']['scale']).type(dtype)
        model_params['layer3.1.bn3.bias'] = torch.from_numpy(data_dict['bn4b_branch2c']['offset']).type(dtype)

        model_params['layer3.2.conv1.weight'] = torch.from_numpy(data_dict['res4c_branch2a']['weights']).type(dtype).permute(3,2,0,1)
        model_params['layer3.2.bn1.weight'] = torch.from_numpy(data_dict['bn4c_branch2a']['scale']).type(dtype)
        model_params['layer3.2.bn1.bias'] = torch.from_numpy(data_dict['bn4c_branch2a']['offset']).type(dtype)

        model_params['layer3.2.conv2.weight'] = torch.from_numpy(data_dict['res4c_branch2b']['weights']).type(dtype).permute(3,2,0,1)
        model_params['layer3.2.bn2.weight'] = torch.from_numpy(data_dict['bn4c_branch2b']['scale']).type(dtype)
        model_params['layer3.2.bn2.bias'] = torch.from_numpy(data_dict['bn4c_branch2b']['offset']).type(dtype)

        model_params['layer3.2.conv3.weight'] = torch.from_numpy(data_dict['res4c_branch2c']['weights']).type(dtype).permute(3,2,0,1)
        model_params['layer3.2.bn3.weight'] = torch.from_numpy(data_dict['bn4c_branch2c']['scale']).type(dtype)
        model_params['layer3.2.bn3.bias'] = torch.from_numpy(data_dict['bn4c_branch2c']['offset']).type(dtype)

        model_params['layer3.3.conv1.weight'] = torch.from_numpy(data_dict['res4d_branch2a']['weights']).type(dtype).permute(3,2,0,1)
        model_params['layer3.3.bn1.weight'] = torch.from_numpy(data_dict['bn4d_branch2a']['scale']).type(dtype)
        model_params['layer3.3.bn1.bias'] = torch.from_numpy(data_dict['bn4d_branch2a']['offset']).type(dtype)

        model_params['layer3.3.conv2.weight'] = torch.from_numpy(data_dict['res4d_branch2b']['weights']).type(dtype).permute(3,2,0,1)
        model_params['layer3.3.bn2.weight'] = torch.from_numpy(data_dict['bn4d_branch2b']['scale']).type(dtype)
        model_params['layer3.3.bn2.bias'] = torch.from_numpy(data_dict['bn4d_branch2b']['offset']).type(dtype)

        model_params['layer3.3.conv3.weight'] = torch.from_numpy(data_dict['res4d_branch2c']['weights']).type(dtype).permute(3,2,0,1)
        model_params['layer3.3.bn3.weight'] = torch.from_numpy(data_dict['bn4d_branch2c']['scale']).type(dtype)
        model_params['layer3.3.bn3.bias'] = torch.from_numpy(data_dict['bn4d_branch2c']['offset']).type(dtype)

        model_params['layer3.4.conv1.weight'] = torch.from_numpy(data_dict['res4e_branch2a']['weights']).type(dtype).permute(3,2,0,1)
        model_params['layer3.4.bn1.weight'] = torch.from_numpy(data_dict['bn4e_branch2a']['scale']).type(dtype)
        model_params['layer3.4.bn1.bias'] = torch.from_numpy(data_dict['bn4e_branch2a']['offset']).type(dtype)

        model_params['layer3.4.conv2.weight'] = torch.from_numpy(data_dict['res4e_branch2b']['weights']).type(dtype).permute(3,2,0,1)
        model_params['layer3.4.bn2.weight'] = torch.from_numpy(data_dict['bn4e_branch2b']['scale']).type(dtype)
        model_params['layer3.4.bn2.bias'] = torch.from_numpy(data_dict['bn4e_branch2b']['offset']).type(dtype)

        model_params['layer3.4.conv3.weight'] = torch.from_numpy(data_dict['res4e_branch2c']['weights']).type(dtype).permute(3,2,0,1)
        model_params['layer3.4.bn3.weight'] = torch.from_numpy(data_dict['bn4e_branch2c']['scale']).type(dtype)
        model_params['layer3.4.bn3.bias'] = torch.from_numpy(data_dict['bn4e_branch2c']['offset']).type(dtype)

        model_params['layer3.5.conv1.weight'] = torch.from_numpy(data_dict['res4f_branch2a']['weights']).type(dtype).permute(3,2,0,1)
        model_params['layer3.5.bn1.weight'] = torch.from_numpy(data_dict['bn4f_branch2a']['scale']).type(dtype)
        model_params['layer3.5.bn1.bias'] = torch.from_numpy(data_dict['bn4f_branch2a']['offset']).type(dtype)

        model_params['layer3.5.conv2.weight'] = torch.from_numpy(data_dict['res4f_branch2b']['weights']).type(dtype).permute(3,2,0,1)
        model_params['layer3.5.bn2.weight'] = torch.from_numpy(data_dict['bn4f_branch2b']['scale']).type(dtype)
        model_params['layer3.5.bn2.bias'] = torch.from_numpy(data_dict['bn4f_branch2b']['offset']).type(dtype)

        model_params['layer3.5.conv3.weight'] = torch.from_numpy(data_dict['res4f_branch2c']['weights']).type(dtype).permute(3,2,0,1)
        model_params['layer3.5.bn3.weight'] = torch.from_numpy(data_dict['bn4f_branch2c']['scale']).type(dtype)
        model_params['layer3.5.bn3.bias'] = torch.from_numpy(data_dict['bn4f_branch2c']['offset']).type(dtype)

        model_params['layer4.0.downsample.0.weight'] = torch.from_numpy(data_dict['res5a_branch1']['weights']).type(dtype).permute(3,2,0,1)
        model_params['layer4.0.downsample.1.weight'] = torch.from_numpy(data_dict['bn5a_branch1']['scale']).type(dtype)
        model_params['layer4.0.downsample.1.bias'] = torch.from_numpy(data_dict['bn5a_branch1']['offset']).type(dtype)

        model_params['layer4.0.conv1.weight'] = torch.from_numpy(data_dict['res5a_branch2a']['weights']).type(dtype).permute(3,2,0,1)
        model_params['layer4.0.bn1.weight'] = torch.from_numpy(data_dict['bn5a_branch2a']['scale']).type(dtype)
        model_params['layer4.0.bn1.bias'] = torch.from_numpy(data_dict['bn5a_branch2a']['offset']).type(dtype)

        model_params['layer4.0.conv2.weight'] = torch.from_numpy(data_dict['res5a_branch2b']['weights']).type(dtype).permute(3,2,0,1)
        model_params['layer4.0.bn2.weight'] = torch.from_numpy(data_dict['bn5a_branch2b']['scale']).type(dtype)
        model_params['layer4.0.bn2.bias'] = torch.from_numpy(data_dict['bn5a_branch2b']['offset']).type(dtype)

        model_params['layer4.0.conv3.weight'] = torch.from_numpy(data_dict['res5a_branch2c']['weights']).type(dtype).permute(3,2,0,1)
        model_params['layer4.0.bn3.weight'] = torch.from_numpy(data_dict['bn5a_branch2c']['scale']).type(dtype)
        model_params['layer4.0.bn3.bias'] = torch.from_numpy(data_dict['bn5a_branch2c']['offset']).type(dtype)

        model_params['layer4.1.conv1.weight'] = torch.from_numpy(data_dict['res5b_branch2a']['weights']).type(dtype).permute(3,2,0,1)
        model_params['layer4.1.bn1.weight'] = torch.from_numpy(data_dict['bn5b_branch2a']['scale']).type(dtype)
        model_params['layer4.1.bn1.bias'] = torch.from_numpy(data_dict['bn5b_branch2a']['offset']).type(dtype)

        model_params['layer4.1.conv2.weight'] = torch.from_numpy(data_dict['res5b_branch2b']['weights']).type(dtype).permute(3,2,0,1)
        model_params['layer4.1.bn2.weight'] = torch.from_numpy(data_dict['bn5b_branch2b']['scale']).type(dtype)
        model_params['layer4.1.bn2.bias'] = torch.from_numpy(data_dict['bn5b_branch2b']['offset']).type(dtype)

        model_params['layer4.1.conv3.weight'] = torch.from_numpy(data_dict['res5b_branch2c']['weights']).type(dtype).permute(3,2,0,1)
        model_params['layer4.1.bn3.weight'] = torch.from_numpy(data_dict['bn5b_branch2c']['scale']).type(dtype)
        model_params['layer4.1.bn3.bias'] = torch.from_numpy(data_dict['bn5b_branch2c']['offset']).type(dtype)

        model_params['layer4.2.conv1.weight'] = torch.from_numpy(data_dict['res5c_branch2a']['weights']).type(dtype).permute(3,2,0,1)
        model_params['layer4.2.bn1.weight'] = torch.from_numpy(data_dict['bn5c_branch2a']['scale']).type(dtype)
        model_params['layer4.2.bn1.bias'] = torch.from_numpy(data_dict['bn5c_branch2a']['offset']).type(dtype)

        model_params['layer4.2.conv2.weight'] = torch.from_numpy(data_dict['res5c_branch2b']['weights']).type(dtype).permute(3,2,0,1)
        model_params['layer4.2.bn2.weight'] = torch.from_numpy(data_dict['bn5c_branch2b']['scale']).type(dtype)
        model_params['layer4.2.bn2.bias'] = torch.from_numpy(data_dict['bn5c_branch2b']['offset']).type(dtype)

        model_params['layer4.2.conv3.weight'] = torch.from_numpy(data_dict['res5c_branch2c']['weights']).type(dtype).permute(3,2,0,1)
        model_params['layer4.2.bn3.weight'] = torch.from_numpy(data_dict['bn5c_branch2c']['scale']).type(dtype)
        model_params['layer4.2.bn3.bias'] = torch.from_numpy(data_dict['bn5c_branch2c']['offset']).type(dtype)

    model_params['conv2.weight'] = torch.from_numpy(data_dict['layer1']['weights']).type(dtype).permute(3,2,0,1)
    #model_params['conv2.bias'] = torch.from_numpy(data_dict['layer1']['biases']).type(dtype)
    model_params['bn2.weight'] = torch.from_numpy(data_dict['layer1_BN']['scale']).type(dtype)
    model_params['bn2.bias'] = torch.from_numpy(data_dict['layer1_BN']['offset']).type(dtype)

    # set True to enable weight import, or set False to initialize by yourself
    if True:

        model_params['up1.conv1_1.weight'] = torch.from_numpy(data_dict['layer2x_br1_ConvA']['weights']).type(dtype).permute(3,2,0,1)
        model_params['up1.conv1_1.bias'] = torch.from_numpy(data_dict['layer2x_br1_ConvA']['biases']).type(dtype)

        model_params['up1.conv1_2.weight'] = torch.from_numpy(data_dict['layer2x_br1_ConvB']['weights']).type(dtype).permute(3,2,0,1)
        model_params['up1.conv1_2.bias'] = torch.from_numpy(data_dict['layer2x_br1_ConvB']['biases']).type(dtype)

        model_params['up1.conv1_3.weight'] = torch.from_numpy(data_dict['layer2x_br1_ConvC']['weights']).type(dtype).permute(3,2,0,1)
        model_params['up1.conv1_3.bias'] = torch.from_numpy(data_dict['layer2x_br1_ConvC']['biases']).type(dtype)

        model_params['up1.conv1_4.weight'] = torch.from_numpy(data_dict['layer2x_br1_ConvD']['weights']).type(dtype).permute(3,2,0,1)
        model_params['up1.conv1_4.bias'] = torch.from_numpy(data_dict['layer2x_br1_ConvD']['biases']).type(dtype)

        model_params['up1.bn1_1.weight'] = torch.from_numpy(data_dict['layer2x_br1_BN']['scale']).type(dtype)
        model_params['up1.bn1_1.bias'] = torch.from_numpy(data_dict['layer2x_br1_BN']['offset']).type(dtype)

        model_params['up1.conv2_1.weight'] = torch.from_numpy(data_dict['layer2x_br2_ConvA']['weights']).type(dtype).permute(3,2,0,1)
        model_params['up1.conv2_1.bias'] = torch.from_numpy(data_dict['layer2x_br2_ConvA']['biases']).type(dtype)

        model_params['up1.conv2_2.weight'] = torch.from_numpy(data_dict['layer2x_br2_ConvB']['weights']).type(dtype).permute(3,2,0,1)
        model_params['up1.conv2_2.bias'] = torch.from_numpy(data_dict['layer2x_br2_ConvB']['biases']).type(dtype)

        model_params['up1.conv2_3.weight'] = torch.from_numpy(data_dict['layer2x_br2_ConvC']['weights']).type(dtype).permute(3,2,0,1)
        model_params['up1.conv2_3.bias'] = torch.from_numpy(data_dict['layer2x_br2_ConvC']['biases']).type(dtype)

        model_params['up1.conv2_4.weight'] = torch.from_numpy(data_dict['layer2x_br2_ConvD']['weights']).type(dtype).permute(3,2,0,1)
        model_params['up1.conv2_4.bias'] = torch.from_numpy(data_dict['layer2x_br2_ConvD']['biases']).type(dtype)

        model_params['up1.bn1_2.weight'] = torch.from_numpy(data_dict['layer2x_br2_BN']['scale']).type(dtype)
        model_params['up1.bn1_2.bias'] = torch.from_numpy(data_dict['layer2x_br2_BN']['offset']).type(dtype)

        model_params['up1.conv3.weight'] = torch.from_numpy(data_dict['layer2x_Conv']['weights']).type(dtype).permute(3,2,0,1)
        model_params['up1.conv3.bias'] = torch.from_numpy(data_dict['layer2x_Conv']['biases']).type(dtype)

        model_params['up1.bn2.weight'] = torch.from_numpy(data_dict['layer2x_BN']['scale']).type(dtype)
        model_params['up1.bn2.bias'] = torch.from_numpy(data_dict['layer2x_BN']['offset']).type(dtype)

        model_params['up2.conv1_1.weight'] = torch.from_numpy(data_dict['layer4x_br1_ConvA']['weights']).type(dtype).permute(3,2,0,1)
        model_params['up2.conv1_1.bias'] = torch.from_numpy(data_dict['layer4x_br1_ConvA']['biases']).type(dtype)

        model_params['up2.conv1_2.weight'] = torch.from_numpy(data_dict['layer4x_br1_ConvB']['weights']).type(dtype).permute(3,2,0,1)
        model_params['up2.conv1_2.bias'] = torch.from_numpy(data_dict['layer4x_br1_ConvB']['biases']).type(dtype)

        model_params['up2.conv1_3.weight'] = torch.from_numpy(data_dict['layer4x_br1_ConvC']['weights']).type(dtype).permute(3,2,0,1)
        model_params['up2.conv1_3.bias'] = torch.from_numpy(data_dict['layer4x_br1_ConvC']['biases']).type(dtype)

        model_params['up2.conv1_4.weight'] = torch.from_numpy(data_dict['layer4x_br1_ConvD']['weights']).type(dtype).permute(3,2,0,1)
        model_params['up2.conv1_4.bias'] = torch.from_numpy(data_dict['layer4x_br1_ConvD']['biases']).type(dtype)

        model_params['up2.bn1_1.weight'] = torch.from_numpy(data_dict['layer4x_br1_BN']['scale']).type(dtype)
        model_params['up2.bn1_1.bias'] = torch.from_numpy(data_dict['layer4x_br1_BN']['offset']).type(dtype)

        model_params['up2.conv2_1.weight'] = torch.from_numpy(data_dict['layer4x_br2_ConvA']['weights']).type(dtype).permute(3,2,0,1)
        model_params['up2.conv2_1.bias'] = torch.from_numpy(data_dict['layer4x_br2_ConvA']['biases']).type(dtype)

        model_params['up2.conv2_2.weight'] = torch.from_numpy(data_dict['layer4x_br2_ConvB']['weights']).type(dtype).permute(3,2,0,1)
        model_params['up2.conv2_2.bias'] = torch.from_numpy(data_dict['layer4x_br2_ConvB']['biases']).type(dtype)

        model_params['up2.conv2_3.weight'] = torch.from_numpy(data_dict['layer4x_br2_ConvC']['weights']).type(dtype).permute(3,2,0,1)
        model_params['up2.conv2_3.bias'] = torch.from_numpy(data_dict['layer4x_br2_ConvC']['biases']).type(dtype)

        model_params['up2.conv2_4.weight'] = torch.from_numpy(data_dict['layer4x_br2_ConvD']['weights']).type(dtype).permute(3,2,0,1)
        model_params['up2.conv2_4.bias'] = torch.from_numpy(data_dict['layer4x_br2_ConvD']['biases']).type(dtype)

        model_params['up2.bn1_2.weight'] = torch.from_numpy(data_dict['layer4x_br2_BN']['scale']).type(dtype)
        model_params['up2.bn1_2.bias'] = torch.from_numpy(data_dict['layer4x_br2_BN']['offset']).type(dtype)

        model_params['up2.conv3.weight'] = torch.from_numpy(data_dict['layer4x_Conv']['weights']).type(dtype).permute(3,2,0,1)
        model_params['up2.conv3.bias'] = torch.from_numpy(data_dict['layer4x_Conv']['biases']).type(dtype)

        model_params['up2.bn2.weight'] = torch.from_numpy(data_dict['layer4x_BN']['scale']).type(dtype)
        model_params['up2.bn2.bias'] = torch.from_numpy(data_dict['layer4x_BN']['offset']).type(dtype)

        model_params['up3.conv1_1.weight'] = torch.from_numpy(data_dict['layer8x_br1_ConvA']['weights']).type(dtype).permute(3,2,0,1)
        model_params['up3.conv1_1.bias'] = torch.from_numpy(data_dict['layer8x_br1_ConvA']['biases']).type(dtype)

        model_params['up3.conv1_2.weight'] = torch.from_numpy(data_dict['layer8x_br1_ConvB']['weights']).type(dtype).permute(3,2,0,1)
        model_params['up3.conv1_2.bias'] = torch.from_numpy(data_dict['layer8x_br1_ConvB']['biases']).type(dtype)

        model_params['up3.conv1_3.weight'] = torch.from_numpy(data_dict['layer8x_br1_ConvC']['weights']).type(dtype).permute(3,2,0,1)
        model_params['up3.conv1_3.bias'] = torch.from_numpy(data_dict['layer8x_br1_ConvC']['biases']).type(dtype)

        model_params['up3.conv1_4.weight'] = torch.from_numpy(data_dict['layer8x_br1_ConvD']['weights']).type(dtype).permute(3,2,0,1)
        model_params['up3.conv1_4.bias'] = torch.from_numpy(data_dict['layer8x_br1_ConvD']['biases']).type(dtype)

        model_params['up3.bn1_1.weight'] = torch.from_numpy(data_dict['layer8x_br1_BN']['scale']).type(dtype)
        model_params['up3.bn1_1.bias'] = torch.from_numpy(data_dict['layer8x_br1_BN']['offset']).type(dtype)

        model_params['up3.conv2_1.weight'] = torch.from_numpy(data_dict['layer8x_br2_ConvA']['weights']).type(dtype).permute(3,2,0,1)
        model_params['up3.conv2_1.bias'] = torch.from_numpy(data_dict['layer8x_br2_ConvA']['biases']).type(dtype)

        model_params['up3.conv2_2.weight'] = torch.from_numpy(data_dict['layer8x_br2_ConvB']['weights']).type(dtype).permute(3,2,0,1)
        model_params['up3.conv2_2.bias'] = torch.from_numpy(data_dict['layer8x_br2_ConvB']['biases']).type(dtype)

        model_params['up3.conv2_3.weight'] = torch.from_numpy(data_dict['layer8x_br2_ConvC']['weights']).type(dtype).permute(3,2,0,1)
        model_params['up3.conv2_3.bias'] = torch.from_numpy(data_dict['layer8x_br2_ConvC']['biases']).type(dtype)

        model_params['up3.conv2_4.weight'] = torch.from_numpy(data_dict['layer8x_br2_ConvD']['weights']).type(dtype).permute(3,2,0,1)
        model_params['up3.conv2_4.bias'] = torch.from_numpy(data_dict['layer8x_br2_ConvD']['biases']).type(dtype)

        model_params['up3.bn1_2.weight'] = torch.from_numpy(data_dict['layer8x_br2_BN']['scale']).type(dtype)
        model_params['up3.bn1_2.bias'] = torch.from_numpy(data_dict['layer8x_br2_BN']['offset']).type(dtype)

        model_params['up3.conv3.weight'] = torch.from_numpy(data_dict['layer8x_Conv']['weights']).type(dtype).permute(3,2,0,1)
        model_params['up3.conv3.bias'] = torch.from_numpy(data_dict['layer8x_Conv']['biases']).type(dtype)

        model_params['up3.bn2.weight'] = torch.from_numpy(data_dict['layer8x_BN']['scale']).type(dtype)
        model_params['up3.bn2.bias'] = torch.from_numpy(data_dict['layer8x_BN']['offset']).type(dtype)

        model_params['up4.conv1_1.weight'] = torch.from_numpy(data_dict['layer16x_br1_ConvA']['weights']).type(dtype).permute(3,2,0,1)
        model_params['up4.conv1_1.bias'] = torch.from_numpy(data_dict['layer16x_br1_ConvA']['biases']).type(dtype)

        model_params['up4.conv1_2.weight'] = torch.from_numpy(data_dict['layer16x_br1_ConvB']['weights']).type(dtype).permute(3,2,0,1)
        model_params['up4.conv1_2.bias'] = torch.from_numpy(data_dict['layer16x_br1_ConvB']['biases']).type(dtype)

        model_params['up4.conv1_3.weight'] = torch.from_numpy(data_dict['layer16x_br1_ConvC']['weights']).type(dtype).permute(3,2,0,1)
        model_params['up4.conv1_3.bias'] = torch.from_numpy(data_dict['layer16x_br1_ConvC']['biases']).type(dtype)

        model_params['up4.conv1_4.weight'] = torch.from_numpy(data_dict['layer16x_br1_ConvD']['weights']).type(dtype).permute(3,2,0,1)
        model_params['up4.conv1_4.bias'] = torch.from_numpy(data_dict['layer16x_br1_ConvD']['biases']).type(dtype)

        model_params['up4.bn1_1.weight'] = torch.from_numpy(data_dict['layer16x_br1_BN']['scale']).type(dtype)
        model_params['up4.bn1_1.bias'] = torch.from_numpy(data_dict['layer16x_br1_BN']['offset']).type(dtype)

        model_params['up4.conv2_1.weight'] = torch.from_numpy(data_dict['layer16x_br2_ConvA']['weights']).type(dtype).permute(3,2,0,1)
        model_params['up4.conv2_1.bias'] = torch.from_numpy(data_dict['layer16x_br2_ConvA']['biases']).type(dtype)

        model_params['up4.conv2_2.weight'] = torch.from_numpy(data_dict['layer16x_br2_ConvB']['weights']).type(dtype).permute(3,2,0,1)
        model_params['up4.conv2_2.bias'] = torch.from_numpy(data_dict['layer16x_br2_ConvB']['biases']).type(dtype)

        model_params['up4.conv2_3.weight'] = torch.from_numpy(data_dict['layer16x_br2_ConvC']['weights']).type(dtype).permute(3,2,0,1)
        model_params['up4.conv2_3.bias'] = torch.from_numpy(data_dict['layer16x_br2_ConvC']['biases']).type(dtype)

        model_params['up4.conv2_4.weight'] = torch.from_numpy(data_dict['layer16x_br2_ConvD']['weights']).type(dtype).permute(3,2,0,1)
        model_params['up4.conv2_4.bias'] = torch.from_numpy(data_dict['layer16x_br2_ConvD']['biases']).type(dtype)

        model_params['up4.bn1_2.weight'] = torch.from_numpy(data_dict['layer16x_br2_BN']['scale']).type(dtype)
        model_params['up4.bn1_2.bias'] = torch.from_numpy(data_dict['layer16x_br2_BN']['offset']).type(dtype)

        model_params['up4.conv3.weight'] = torch.from_numpy(data_dict['layer16x_Conv']['weights']).type(dtype).permute(3,2,0,1)
        model_params['up4.conv3.bias'] = torch.from_numpy(data_dict['layer16x_Conv']['biases']).type(dtype)

        model_params['up4.bn2.weight'] = torch.from_numpy(data_dict['layer16x_BN']['scale']).type(dtype)
        model_params['up4.bn2.bias'] = torch.from_numpy(data_dict['layer16x_BN']['offset']).type(dtype)

        model_params['conv3.weight'] = torch.from_numpy(data_dict['ConvPred']['weights']).type(dtype).permute(3,2,0,1)
        model_params['conv3.bias'] = torch.from_numpy(data_dict['ConvPred']['biases']).type(dtype)

    print('Up proj weights loaded!!!!!!!!!!!')
    return model_params


In [None]:
main()

Loading data......
<torch.utils.data.dataloader.DataLoader object at 0x7f6fc9c059c0>
Loading model......
resnet50 loaded.
Up proj weights loaded!!!!!!!!!!!
resnet50_pretrained_dict loaded.
model_dict updated.
model_dict loaded.
loss_fn set.
val_error before train: 4.300591468811035
optimizer set.
Starting train epoch 1 / 50
loss: 0.8021764755249023
loss: 0.8095851540565491
loss: 0.7461864948272705
loss: 0.6649988889694214
loss: 0.5604764223098755
loss: 0.4809028208255768
loss: 0.5316054821014404
loss: 0.33954644203186035
loss: 0.5962244272232056
loss: 0.5173604488372803
loss: 0.2292918562889099
loss: 0.2513006329536438
loss: 0.4189075231552124
loss: 0.3184768855571747
loss: 0.19435520470142365
loss: 0.5204819440841675
loss: 0.3303290605545044
loss: 0.28123360872268677
loss: 0.35343989729881287
epoch loss: 0.47088840368546936
val_error: 4.136179447174072
optimizer set.
Starting train epoch 2 / 50
loss: 0.23385798931121826
loss: 0.2538354992866516
loss: 0.1723986715078354
loss: 0.2104881

In [None]:
print(len(epoch_losses))
print(len(err_lst))

50
50


In [None]:
print(epoch_losses)

[0.47088840368546936, 0.2703721107620942, 0.23238963123999143, 0.24455989112979487, 0.22935021943167636, 0.20927249993148603, 0.20413699667704732, 0.19815596310715927, 0.17747752603731656, 0.15927156374642723, 0.18164423500236712, 0.17809604422042244, 0.16964208334684372, 0.16482967648066973, 0.15877990032497205, 0.16930784520350003, 0.16052783045329547, 0.1514621976959078, 0.14868785440921783, 0.13562150770112089, 0.16305754020025856, 0.15672702342271805, 0.14434671480404704, 0.13467198806373695, 0.15108832873796163, 0.14295380091980883, 0.13535889434187034, 0.13543948531150818, 0.15565641145957143, 0.14069202071742007, 0.14724579884817726, 0.12767035161194049, 0.13948338282735726, 0.1313298332848047, 0.13560796176132403, 0.1327349692583084, 0.1312339407832999, 0.12591311806126645, 0.12934349438077525, 0.13701279030034416, 0.13714320918447093, 0.1445345913893298, 0.1325588692959986, 0.13433271411218142, 0.12687117135838458, 0.12748802335638748, 0.12642552860473333, 0.12590271744288897

In [None]:
print(err_lst)

[4.136179447174072, 4.605781555175781, 3.94053316116333, 3.6912832260131836, 3.9037423133850098, 1.1530299186706543, 0.15613864362239838, 0.12989526987075806, 0.14757953584194183, 0.15619289875030518, 0.14398431777954102, 0.15253213047981262, 0.1655038744211197, 0.14035269618034363, 0.13756246864795685, 0.13392946124076843, 0.13890627026557922, 0.13872723281383514, 0.14875715970993042, 0.13603226840496063, 0.1402978003025055, 0.13791774213314056, 0.14268678426742554, 0.1487409770488739, 0.1336924433708191, 0.13575014472007751, 0.1376960128545761, 0.13666222989559174, 0.13965001702308655, 0.14386014640331268, 0.13212236762046814, 0.1461818516254425, 0.12969475984573364, 0.14443430304527283, 0.13757197558879852, 0.1325635313987732, 0.14857134222984314, 0.14261430501937866, 0.15052814781665802, 0.1363501399755478, 0.13397739827632904, 0.14278623461723328, 0.13513696193695068, 0.13344039022922516, 0.13782653212547302, 0.14137887954711914, 0.14705398678779602, 0.14330258965492249, 0.1354684

In [None]:
import matplotlib.pyplot as plt
plt.plot(epoch_losses, [i for i in range(1,51)])
plt.xlabel("MSE loss each epoch")
plt.ylabel("epochs")
plot.show()