In [None]:
!nvidia-smi

Wed Apr 21 06:48:54 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.67       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-SXM2...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   35C    P0    23W / 300W |      0MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
# Colab-only setup: mount Google Drive at /content/gdrive so that later cells
# can reach files stored on the drive.
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [None]:
cd "/content/gdrive/My Drive/Colab Notebooks"

/content/gdrive/My Drive/Colab Notebooks


DataLoader载入数据

In [None]:
from os import listdir
from os.path import join

from PIL import Image
from torch.utils.data.dataset import Dataset
from torchvision.transforms import Compose, RandomResizedCrop, RandomHorizontalFlip, RandomVerticalFlip, RandomRotation, RandomCrop, ToTensor, ToPILImage, CenterCrop, Resize, Grayscale

import random
import math
from torch.autograd import Variable
import torch

import torchvision.transforms as transforms

# gray = transforms.Gray()
import numpy as np

def is_image_file(filename):
    """Return True when ``filename`` ends with a supported image extension.

    Supported extensions are ``.png``, ``.bmp``, ``.jpg`` and ``.jpeg``. The
    comparison is case-insensitive, so ``.PNG`` / ``.Jpg`` etc. are accepted.

    The previous implementation listed ``'bmp'`` without the leading dot, so
    any name that merely ended in the letters "bmp" (e.g. ``"thumbmp"``) was
    treated as an image.
    """
    return filename.lower().endswith(('.png', '.bmp', '.jpg', '.jpeg'))


def calculate_valid_crop_size(crop_size, blocksize):
    """Round ``crop_size`` down to the nearest multiple of ``blocksize``."""
    _, leftover = divmod(crop_size, blocksize)
    return crop_size - leftover


def train_hr_transform(crop_size):
    """Build the training-time augmentation pipeline.

    The returned ``Compose`` applies, in order: a random ``crop_size`` crop,
    a random horizontal flip (p=0.5), a random vertical flip (p=0.5),
    ``Grayscale()`` and finally ``ToTensor()``.
    """
    augmentation_steps = [
        RandomCrop(crop_size),
        RandomHorizontalFlip(p=0.5),
        RandomVerticalFlip(p=0.5),
        Grayscale(),
        ToTensor(),
    ]
    return Compose(augmentation_steps)



def psnr(img1, img2):
    """Peak signal-to-noise ratio (dB) between two tensors, with peak value 1.0.

    Returns the integer 100 when the mean squared error is below 1e-10
    (the images are treated as identical); otherwise returns a Python float.
    """
    PIXEL_MAX = 1.0
    squared_error = (img1 - img2) ** 2
    mse = torch.mean(squared_error)
    if mse < 1.0e-10:
        return 100
    rmse = math.sqrt(mse)
    return 20 * math.log10(PIXEL_MAX / rmse)


class TrainDatasetFromFolder(Dataset):
    """Training dataset that yields augmented crops from a folder of images.

    Every sample is the pair ``(crop, crop)``: the same tensor is returned as
    both input and target.

    Args:
        dataset_dir: directory scanned (non-recursively) for image files.
        crop_size: requested crop edge length; rounded down to a multiple of
            ``blocksize`` before use.
        blocksize: block edge length the crop size must be divisible by.
    """

    def __init__(self, dataset_dir, crop_size, blocksize):
        super(TrainDatasetFromFolder, self).__init__()
        self.image_filenames = [join(dataset_dir, x) for x in listdir(dataset_dir) if is_image_file(x)]
        crop_size = calculate_valid_crop_size(crop_size, blocksize)
        self.hr_transform = train_hr_transform(crop_size)

    def __getitem__(self, index):
        """Return ``(crop, crop)`` for ``index``, skipping unusable images.

        If the image at ``index`` cannot be opened or transformed, the
        following images are tried in turn (wrapping around at the end of the
        list). The old code used a bare ``except:`` and ``index + 1``, which
        hid every error (including KeyboardInterrupt) and raised IndexError
        whenever the last image was the bad one.

        Raises:
            IndexError: if ``index`` is out of range.
            RuntimeError: if no image in the dataset can be loaded.
        """
        # Index first with the raw value so an invalid index still raises IndexError.
        self.image_filenames[index]
        total = len(self.image_filenames)
        last_error = None
        for offset in range(total):
            path = self.image_filenames[(index + offset) % total]
            try:
                # The context manager closes the file once the transform has read it.
                with Image.open(path) as image:
                    hr_image = self.hr_transform(image)
            except (OSError, ValueError) as err:
                # OSError: unreadable/corrupt file; ValueError: e.g. image smaller than the crop.
                last_error = err
                continue
            return hr_image, hr_image
        raise RuntimeError("no loadable image found in dataset") from last_error

    def __len__(self):
        """Number of image files found in ``dataset_dir``."""
        return len(self.image_filenames)


class TestDatasetFromFolder(Dataset):
    """Evaluation dataset yielding centre crops from a folder of images.

    Each sample is ``(tensor, tensor)`` where the tensor is produced by
    ``CenterCrop(crop_size) -> Grayscale() -> ToTensor() -> Normalize(0.5, 0.5)``.
    Note that, unlike the training pipeline in this file, the test pipeline
    applies ``Normalize``.
    """

    def __init__(self, dataset_dir, blocksize, crop_size=96):
        super().__init__()
        self.blocksize = blocksize
        self.high_res_length = crop_size
        self.image_filenames = [
            join(dataset_dir, name)
            for name in listdir(dataset_dir)
            if is_image_file(name)
        ]
        pipeline = [
            CenterCrop(crop_size),
            Grayscale(),
            ToTensor(),
            transforms.Normalize(mean=0.5, std=0.5),
        ]
        self.test_compose = Compose(pipeline)

    def __getitem__(self, index):
        """Load image ``index``, run the test pipeline and return it twice."""
        sample = self.test_compose(Image.open(self.image_filenames[index]))
        return sample, sample

    def __len__(self):
        """Number of image files found in ``dataset_dir``."""
        return len(self.image_filenames)



AttentionLayer(以Self_Attn为主)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function

"""Attention Layer"""
class SELayer(nn.Module):
    """Squeeze-and-excitation style channel gating.

    Args:
        channel: number of input channels.
        reduction: bottleneck ratio; the hidden width is ``channel // reduction``.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        hidden = channel // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``(x * gate, gate)`` for a 4-D input ``x``.

        ``gate`` has shape ``(batch, channels)`` and is broadcast over the two
        trailing dimensions of ``x`` for the multiplication.
        """
        batch, channels, _, _ = x.size()
        squeezed = self.avg_pool(x).view(batch, channels)
        atten = self.fc(squeezed)
        gate = atten.view(batch, channels, 1, 1)
        return x * gate.expand_as(x), atten

# Not mature yet; known to have some bugs.
class Self_Attn(nn.Module):
    """Self-attention layer over the spatial positions of a feature map.

    Args:
        in_dim: number of input (and output) channels; query/key projections
            use ``in_dim // 8`` channels.
        activation: stored on the instance but not used by ``forward``.
        blocksize: stored on the instance but not used by ``forward``.
    """

    def __init__(self, in_dim, activation, blocksize=32):
        super().__init__()
        self.chanel_in = in_dim
        self.activation = activation
        self.blocksize = blocksize

        reduced_dim = in_dim // 8
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=reduced_dim, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=reduced_dim, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        # Learnable residual weight; starts at 0 so the layer is initially an identity.
        self.gamma = nn.Parameter(torch.zeros(1))

        # Disabled experiment (marked as buggy by the author):
        # self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # self.linear = nn.Linear(in_features=in_dim, out_features=in_dim*blocksize*blocksize,bias=True)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Apply self-attention and add the result to the input.

        Args:
            x: input feature maps, B x C x W x H.

        Returns:
            ``gamma * attended + x`` with the same shape as ``x``.
        """
        m_batchsize, C, width, height = x.size()
        n_positions = width * height

        queries = self.query_conv(x).view(m_batchsize, -1, n_positions).permute(0, 2, 1)  # B x N x C'
        keys = self.key_conv(x).view(m_batchsize, -1, n_positions)  # B x C' x N
        values = self.value_conv(x).view(m_batchsize, -1, n_positions)  # B x C x N

        # B x N x N, each row normalised over the last dimension.
        attention = self.softmax(torch.bmm(queries, keys))

        attended = torch.bmm(values, attention.permute(0, 2, 1))
        attended = attended.view(m_batchsize, C, width, height)

        # Disabled experiment (marked as buggy by the author):
        # atten = self.avg_pool(out).view(m_batchsize,C)
        # atten = self.linear(atten)
        # atten = atten.view(m_batchsize, C, self.blocksize, self.blocksize)
        return self.gamma * attended + x

In [None]:
"""参数化，区别是否需要反向传播更新"""
def to_var(x, requires_grad=False, volatile=False):
    """Wrap ``x`` in a ``Variable``, moving it to CUDA first when available.

    ``requires_grad`` and ``volatile`` are forwarded unchanged to ``Variable``.
    NOTE(review): ``Variable``/``volatile`` are legacy autograd API; confirm
    they are still needed for the PyTorch version in use.
    """
    tensor = x.cuda() if torch.cuda.is_available() else x
    return Variable(tensor, requires_grad=requires_grad, volatile=volatile)

# Binarisation to +1 / -1.
# Static autograd methods for activations (A).
class Binary_a(Function):
    """Sign function with a saturated straight-through estimator (STE).

    Forward returns ``torch.sign(input)`` (so an exact 0 stays 0). Backward
    passes the incoming gradient through unchanged where ``-1 < input < 1``
    and zeroes it elsewhere.
    """

    @staticmethod
    def forward(self, input):
        # ``self`` is the autograd context object.
        self.save_for_backward(input)
        return torch.sign(input)

    @staticmethod
    def backward(self, grad_output):
        (saved_input,) = self.saved_tensors
        # Saturated STE: block the gradient wherever the input is outside (-1, 1).
        saturated = saved_input.ge(1) | saved_input.le(-1)
        grad_input = grad_output.clone().masked_fill_(saturated, 0)
        # Alternative "soft STE" kept for reference (disabled):
        #   zeros = torch.zeros(input.size()).cuda()
        #   grad = torch.max(zeros, 1 - torch.abs(input))
        #   grad_input = grad_output * grad
        return grad_input
# Static autograd methods for weights (W).
class Binary_w(Function):
    """Sign function with a plain straight-through estimator.

    Forward returns ``torch.sign(input)``; backward returns a copy of the
    incoming gradient unchanged.
    """

    @staticmethod
    def forward(self, input):
        # ``self`` is the autograd context object; nothing needs saving.
        return torch.sign(input)

    @staticmethod
    def backward(self, grad_output):
        # Straight-through estimator: gradient of sign() is treated as 1.
        return grad_output.clone()
# Ternarisation to +1 / -1 / 0.
class Ternary(Function):
    """Ternary quantiser with a straight-through estimator.

    Forward expects a 4-D tensor and returns ``(output, threshold)``:
    ``threshold`` is 0.7 times the mean absolute value taken over dims
    1, 2 and 3 (one value per index of dim 0, kept as size-1 dims), and
    ``output`` holds -1, 0 or +1 depending on where each element lies
    relative to ``±threshold``.
    """

    @staticmethod
    def forward(self, input):
        # Mean of |input| over dims (3, 2, 1), i.e. one value per slice of dim 0.
        mean_abs = torch.mean(torch.abs(input), (3, 2, 1), keepdim=True)
        threshold = mean_abs * 0.7
        # sign(x + t) + sign(x - t) is ±2 outside the band, 0 inside, ±1 on its edge;
        # a final sign() collapses that to {-1, 0, +1}.
        upper_side = torch.sign(torch.add(input, threshold))
        lower_side = torch.sign(torch.add(input, -threshold))
        output = torch.sign(torch.add(upper_side, lower_side))
        return output, threshold

    @staticmethod
    def backward(self, grad_output, grad_threshold):
        # Straight-through estimator; the gradient w.r.t. the threshold is ignored.
        return grad_output.clone()

# Binary activation (currently unused).
class activation_bin(nn.Module):
    """Activation that binarises when ``A == 2`` and applies ReLU otherwise.

    Args:
        A: activation mode; 2 selects ``Binary_a``, anything else selects an
            in-place ReLU.
    """

    def __init__(self, A):
        super().__init__()
        self.A = A
        self.relu = nn.ReLU(inplace=True)

    def binary(self, input):
        """Binarise ``input`` with ``Binary_a``."""
        return Binary_a.apply(input)

    def forward(self, input):
        if self.A != 2:
            # In-place ReLU: ``input`` itself is modified.
            return self.relu(input)
        # A == 2: sign-based values. A {1, 0} variant was left disabled by the author:
        #   a = torch.clamp(a, min=0)
        return self.binary(input)
# ********************* Weight (model parameter) quantisation: ternary / binary ***********************
def meancenter_clampConvParams(w):
    """Mean-centre ``w`` along dim 1 and clamp it to [-1, 1], in place.

    The previous version called the out-of-place ``sub`` / ``clamp`` and
    discarded their results, so ``w`` was returned unmodified. The in-place
    variants are used here so the centring and clamping actually take effect.

    Args:
        w: tensor/parameter whose ``.data`` is modified in place.

    Returns:
        The same object ``w``.
    """
    mean = w.data.mean(1, keepdim=True)
    w.data.sub_(mean)  # centre along dim 1
    w.data.clamp_(-1.0, 1.0)  # clip to [-1, 1]
    return w
class weight_tnn_bin(nn.Module):
    """Weight quantiser: binary (``W == 2``), ternary (``W == 3``) or identity.

    The previous ternary branch also computed a per-filter scaling factor
    ``alpha`` on every forward pass (clone, boolean-mask writes and a division
    that could be 0/0) but never applied it, because the scaling line was
    commented out. That dead computation has been removed; the returned
    tensor is unchanged. To reintroduce scaling, compute the mean of
    ``|input|`` over the elements with ``|input| > threshold`` per filter and
    multiply the quantised output by it.

    Args:
        W: quantisation mode (2 = binary, 3 = ternary, anything else = none).
    """

    def __init__(self, W):
        super().__init__()
        self.W = W

    def binary(self, input):
        """Binarise ``input`` with ``Binary_w``."""
        return Binary_w.apply(input)

    def ternary(self, input):
        """Ternarise ``input`` with ``Ternary``; returns ``(output, threshold)``."""
        return Ternary.apply(input)

    def forward(self, input):
        if self.W == 2:
            # Binary weights. Mean-centring/clamping and the scaling factor
            # were left disabled by the author.
            return self.binary(input)
        if self.W == 3:
            # Ternary weights; the threshold is only needed for the (disabled) scaling.
            output, _threshold = self.ternary(input)
            return output
        # Any other mode: leave the weights untouched.
        return input

# Convolution (the public-facing class that callers use).
class Conv2d_Q(nn.Conv2d):  #_Atten_Pruning
    """``nn.Conv2d`` whose weights are quantised on every forward pass.

    Accepts the usual ``nn.Conv2d`` arguments plus:

    Args:
        A: activation quantisation mode. Accepted for interface compatibility;
            the activation quantiser is currently disabled.
        W: weight quantisation mode forwarded to ``weight_tnn_bin``.

    An optional mask can be installed with ``set_mask``; it is multiplied into
    the stored weights once and into the quantised weights on every forward.

    Fixes relative to the previous version:
      * ``self.mask`` is initialised to ``None`` so ``get_mask`` no longer
        raises AttributeError before ``set_mask`` has been called.
      * The mask is a plain attribute (not a parameter/buffer), so it does not
        follow ``.cuda()`` / ``.to()``; it is now moved to the weight's device
        at the point of use instead of failing with a device mismatch.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=False,
        A=2,
        W=2
      ):
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias
        )
        # Instantiate the weight quantiser (the activation quantiser is disabled).
        # self.activation_quantizer = activation_bin(A=A)
        self.weight_quantizer = weight_tnn_bin(W=W)
        self.mask = None
        self.mask_flag = False

    def set_mask(self, mask):
        """Install ``mask`` and apply it to the stored weights immediately."""
        self.mask = mask
        self.weight.data = self.weight.data * self.mask.data.to(self.weight.device)
        self.mask_flag = True

    def get_mask(self):
        """Print whether a mask is set and return it (``None`` if never set)."""
        print(self.mask_flag)
        return getattr(self, "mask", None)

    def forward(self, input):
        # Quantise the weights (activation quantisation is disabled).
        # bin_input = self.activation_quantizer(input)
        tnn_bin_weight = self.weight_quantizer(self.weight)
        if self.mask_flag:
            tnn_bin_weight = tnn_bin_weight * self.mask.to(tnn_bin_weight.device)

        # Convolve with the quantised (and optionally masked) weights.
        output = F.conv2d(
            input=input,
            weight=tnn_bin_weight,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups)
        return output

In [None]:
class csPart(nn.Module):
    """Block-based sampling stage followed by self-attention.

    Two strided convolutions with kernel = stride = ``blocksize`` map each
    block to ``round(blocksize * blocksize * subrate * channel)`` channels:
    ``sampling`` (binary-weight ``Conv2d_Q``) and ``real_atten`` (ordinary
    ``nn.Conv2d``). Only the ``real_atten`` branch feeds the returned value.
    """

    def __init__(self, blocksize=32, subrate=0.3, channel=1):
        super().__init__()
        n_measurements = int(np.round(blocksize*blocksize*subrate*channel))
        block_conv_args = dict(
            in_channels=channel,
            out_channels=n_measurements,
            kernel_size=blocksize,
            stride=blocksize,
            padding=0,
            bias=False,
        )
        # Sampling with binary weights.
        self.sampling = Conv2d_Q(W=2, **block_conv_args)
        self.real_atten = nn.Conv2d(**block_conv_args)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.atten = Self_Attn(in_dim=n_measurements, activation=False, blocksize=blocksize)
        # An initial-reconstruction upsampling conv was left disabled by the author.

    def forward(self, input):
        # NOTE(review): the sampling branch is evaluated but its result is not
        # part of the returned value; kept as in the original.
        x = self.sampling(input)
        x = self.gamma*x + x

        scaled = self.real_atten(input)
        scaled = self.gamma*scaled + scaled
        return self.atten(scaled)
    
class targetPart(nn.Module):
    """Single block-wise convolution scaled by ``(gamma * x + x)``.

    The convolution uses kernel = stride = ``blocksize`` and produces
    ``round(blocksize * blocksize * subrate * channel)`` output channels.
    """

    def __init__(self, blocksize=32, subrate=0.3, channel=1):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1))#torch.Tensor([-0.3960])
        n_measurements = int(np.round(blocksize*blocksize*subrate*channel))
        self.real_atten = nn.Conv2d(
            in_channels=channel,
            out_channels=n_measurements,
            kernel_size=blocksize,
            stride=blocksize,
            padding=0,
            bias=False,
        )

    def forward(self, input):
        measurements = self.real_atten(input)
        return self.gamma*measurements + measurements


class sconv2d(nn.Module):
    """Two stacked 1x1 convolutions: per-channel (grouped) then channel-mixing.

    Args:
        channels: input channel count (also the group count of the first conv).
        outchannels: output channel count of the second conv.
    """

    def __init__(self, channels=64, outchannels=64):
        super().__init__()
        self.channels = channels
        per_channel = nn.Conv2d(in_channels=channels,
                                out_channels=channels,
                                kernel_size=(1, 1), stride=1, padding=0,
                                groups=channels)
        mixing = nn.Conv2d(in_channels=channels, out_channels=outchannels,
                           kernel_size=(1, 1), padding=0)
        # No activation here; callers add their own.
        self.separable_conv2d = nn.Sequential(per_channel, mixing)

    def forward(self, input):
        return self.separable_conv2d(input)


class baseblock(nn.Module):
    """Two ``sconv2d`` + LeakyReLU stages with a skip connection after the first.

    Computes ``conv2(conv1(input) + input)``.
    """

    def __init__(self, channels=64):
        super().__init__()
        # A 3x3 nn.Conv2d variant of each stage was left disabled by the author.
        self.conv1 = nn.Sequential(
            sconv2d(channels=channels),
            nn.LeakyReLU(inplace=True),
        )
        self.conv2 = nn.Sequential(
            sconv2d(channels=channels),
            nn.LeakyReLU(inplace=True),
        )

    def forward(self, input):
        skip_sum = self.conv1(input) + input
        return self.conv2(skip_sum)


class DenseBlock(nn.Module):
    """Densely connected block built from ``sconv2d`` stages.

    Stages ``res1``..``res3`` each append their output to the running feature
    stack along the channel dimension; ``res4`` fuses the stack, and the
    result is scaled by ``beta`` and added to the input.

    Args:
        conv_in: input channel count.
        conv_out: output channel count of every stage. The stage input widths
            are ``conv_in * k``, so the concatenations only line up when
            ``conv_out == conv_in``.
        k_size: accepted but not used by the current ``sconv2d``-based stages.
        beta: residual scaling factor.
    """

    def __init__(self, conv_in, conv_out, k_size, beta=0.2):
        super().__init__()

        def make_stage(width_multiplier, with_activation=True):
            layers = [sconv2d(channels=conv_in * width_multiplier, outchannels=conv_out)]
            if with_activation:
                layers.append(nn.LeakyReLU(inplace=True))
            return nn.Sequential(*layers)

        self.res1 = make_stage(1)
        self.res2 = make_stage(2)
        self.res3 = make_stage(3)
        # The fusing stage has no activation.
        self.res4 = make_stage(4, with_activation=False)
        self.beta = beta

    def forward(self, input):
        stacked = input
        for stage in (self.res1, self.res2, self.res3):
            # Grow the channel stack with this stage's output.
            stacked = torch.cat([stacked, stage(stacked)], 1)

        fused = self.res4(stacked)
        return fused.mul(self.beta) + input

class ResidualInResidualDenseBlock(nn.Module):
    """Three chained ``DenseBlock``s wrapped in a scaled residual connection.

    Args:
        conv_in: channel count used for both input and output of each block.
        k_size: forwarded to ``DenseBlock``.
        beta: scaling of the outer residual. The inner blocks are built with
            their own default ``beta``.
    """

    def __init__(self, conv_in=64, k_size=3, beta=0.2):
        super().__init__()
        for position in (1, 2, 3):
            setattr(self, "dense{}".format(position), DenseBlock(conv_in, conv_in, k_size))
        self.beta = beta

    def forward(self, input):
        features = input
        for block in (self.dense1, self.dense2, self.dense3):
            features = block(features)
        return features.mul(self.beta) + input


In [None]:
class CSNetPlus(nn.Module):
    """Thin wrapper that currently consists only of the ``csPart`` stage.

    Args:
        blocksize: block edge length passed to ``csPart``.
        subrate: sampling rate passed to ``csPart``.
        channels: image channel count passed to ``csPart``.
    """

    def __init__(self,blocksize=32, subrate=0.2, channels=1):
        super().__init__()
        self.blocksize, self.subrate, self.channels = blocksize, subrate, channels
        self.csPart = csPart(blocksize, subrate, channels)

    def forward(self, input):
        return self.csPart(input)

Train

In [9]:
import argparse, os
import torch
import torchvision
import math, random
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
# from tqdm import tqdm
from tqdm import tqdm_notebook as tqdm
from torchvision import models
import torch.utils.model_zoo as model_zoo
import matplotlib.pyplot as plt
from PIL import Image


# Training settings
# Command-line options. In this notebook they are parsed from an empty list in
# main(), so the defaults below are what actually takes effect unless main()
# overrides them. Help texts now state the real default values.
parser = argparse.ArgumentParser(description="PyTorch SRResNet")
parser.add_argument("--batchSize", type=int, default=1024, help="training batch size")
parser.add_argument("--nEpochs", type=int, default=900, help="number of epochs to train for")
parser.add_argument("--lr", type=float, default=1e-4, help="Learning Rate. Default=1e-4")
parser.add_argument("--step", type=int, default=200, help="Sets the learning rate to the initial LR decayed by momentum every n epochs, Default: n=200")
parser.add_argument("--cuda", action="store_true", help="Use cuda?")
parser.add_argument("--resume", default="", type=str, help="Path to checkpoint (default: none)")
parser.add_argument("--start-epoch", default=1, type=int, help="Manual epoch number (useful on restarts)")
parser.add_argument("--threads", type=int, default=4, help="Number of threads for data loader to use, Default: 4")
parser.add_argument("--bin_model", default="", type=str, help="path to bin_model model (default: none)")
parser.add_argument("--vgg_loss", action="store_true", help="Use content loss?")
parser.add_argument("--gpus", default="0", type=str, help="gpu ids (default: 0)")

def main():
    """Train a ``targetPart`` to reproduce the output of a frozen ``CSNetPlus``.

    Steps performed:
      1. Parse default options and override several of them in code.
      2. Seed the torch RNGs with a randomly drawn seed (printed for reference).
      3. Build the training/test datasets and loaders.
      4. Load the pretrained ``CSNetPlus`` from ``opt.bin_model`` (only the
         keys present in the current model), freeze it, and create a
         ``targetPart`` that shares its ``gamma`` parameter.
      5. Optionally resume / load weights, then train with Adam + StepLR,
         saving a checkpoint every epoch and an extra copy to Drive when the
         epoch loss is a new minimum below 0.006 on every 5th epoch.

    Changes relative to the previous version: ``nn.MSELoss(size_average=True)``
    (deprecated) is replaced by the equivalent ``reduction="mean"``, the unused
    local ``input_args`` is removed, and ``True if ... else False`` is
    simplified.

    NOTE: ``torch.load`` unpickles arbitrary objects; only load checkpoint
    files from trusted sources.
    """
    global opt, model, netContent
    opt = parser.parse_args([])
    # print(opt)
    channel = 1
    opt.vgg_loss = channel == 3
    opt.cuda = True
    opt.bin_model = "saved_models/0.5/ten_db3_10_31.43.pth" #
    # opt.bin_model = "CSNetPlus_model_t/0.2_888.6.pth"
    load_model_path = ""
    # load_model_path = "/content/checkpoint/model_epoch_101_.pth"
    # load_model_path = "CSNetPlus_model_mask/12000.pth"
    opt.lr = 0.01
    scheduler_step_size=50
    scheduler_gamma=0.2
    subrate = 0.5
    cuda = opt.cuda
    if cuda:
        print("=> use gpu id: '{}'".format(opt.gpus))
        os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpus
        if not torch.cuda.is_available():
            raise Exception("No GPU found or Wrong gpu id, please run without --cuda")

    opt.seed = random.randint(1, 10000)
    print("Random Seed: ", opt.seed) #5868  7434
    torch.manual_seed(opt.seed)
    if cuda:
        torch.cuda.manual_seed(opt.seed)

    cudnn.benchmark = True

    print("===> Loading datasets")
    # train_set = DatasetFromHdf5("/path/to/your/hdf5/data/like/rgb_srresnet_x4.h5")
    train_data_val2017 = 'dataSet/val2017'
    train_data_BSDS500 = 'dataSet/BSDS500/data/images/train'

    train_path = train_data_val2017
    train_set = TrainDatasetFromFolder(train_path, crop_size=96, blocksize=32)
    test_set = TestDatasetFromFolder("dataSet/Test/Set14/", blocksize=32)
    # train_loader = DataLoader(dataset=train_set, num_workers=4, batch_size=opt.batchSize, shuffle=True)

    training_data_loader = DataLoader(dataset=train_set, num_workers=opt.threads, \
        batch_size=opt.batchSize, shuffle=True)
    testing_data_loader = DataLoader(test_set, batch_size=16, shuffle=False, num_workers=0)

    if opt.vgg_loss:
        print('===> Loading VGG model')
        netVGG = models.vgg19()
        netVGG.load_state_dict(model_zoo.load_url('https://download.pytorch.org/models/vgg19-dcbb9e9d.pth'))
        class _content_model(nn.Module):
            def __init__(self):
                super(_content_model, self).__init__()
                self.feature = nn.Sequential(*list(netVGG.features.children())[:-1])

            def forward(self, x):
                out = self.feature(x)
                return out

        netContent = _content_model()

    print("===> Building model")
    model_val = CSNetPlus(blocksize=32, subrate=subrate, channels=channel)
    weights = torch.load(opt.bin_model, map_location=torch.device('cpu'))
    bin_model_dict = weights['model'].state_dict()
    model_dict = model_val.state_dict()
    # Drop entries for layers that do not exist in the current model.
    bin_model_dict = {k: v for k, v in bin_model_dict.items() if k in model_dict}
    # Merge the pretrained entries into the current model's state dict.
    model_dict.update(bin_model_dict)
    model_val.load_state_dict(model_dict)

    # The reference model is frozen; it only provides training targets.
    for p in model_val.parameters():
        p.requires_grad = False
    print(model_val.csPart.gamma)
    model = targetPart(blocksize=32, subrate=subrate, channel=channel)
    # Share (not copy) the frozen gamma parameter with the model being trained.
    model.gamma = model_val.csPart.gamma#.item()
    model.gamma.requires_grad = False
    # criterion = nn.MSELoss(reduction="sum")
    criterion = nn.MSELoss(reduction="mean")
    print("===> Setting GPU")
    if cuda:
        model = model.cuda()
        model_val = model_val.cuda()
        criterion = criterion.cuda()
        if opt.vgg_loss:
            netContent = netContent.cuda()

    # optionally resume from a checkpoint
    if opt.resume:
        if os.path.isfile(opt.resume):
            print("=> loading checkpoint '{}'".format(opt.resume))
            checkpoint = torch.load(opt.resume)
            opt.start_epoch = checkpoint["epoch"] + 1
            model.load_state_dict(checkpoint["model"].state_dict())
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))

    # optionally copy weights from a checkpoint
    if load_model_path:
        if os.path.isfile(load_model_path):
            print("=> loading model '{}'".format(load_model_path))
            weights = torch.load(load_model_path)
            model.load_state_dict(weights['model'].state_dict())
        else:
            print("=> no model found at '{}'".format(load_model_path))

    print("===> Setting Optimizer")
    optimizer = optim.Adam(model.parameters(), lr=opt.lr, betas=(0.9, 0.999))
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_step_size, gamma=scheduler_gamma)

    print("===> Training")
    now_min_loss = float('inf')
    for epoch in range(opt.start_epoch, opt.nEpochs + 1):
        epoch_loss = train(training_data_loader, optimizer, scheduler, model, model_val, criterion, epoch, testing_data_loader)
        print(epoch_loss)
        save_checkpoint("/content/checkpoint/", model, epoch ,0)
        # Keep a copy on Drive only for new best losses below 0.006, checked every 5th epoch.
        if epoch%5 == 0 and epoch_loss < 0.006 and epoch_loss < now_min_loss:
            save_checkpoint("/content/gdrive/My Drive/Colab Notebooks/CSNetPlus_model_mask/",model,epoch,round(epoch_loss,5))
            now_min_loss = epoch_loss

def adjust_learning_rate(optimizer, epoch):
    """Compute the step-decayed learning rate for ``epoch``.

    The rate is ``opt.lr`` multiplied by 0.1 once per ``opt.step`` completed
    epochs. The value is only returned; ``optimizer`` is not modified here.
    """
    completed_decays = epoch // opt.step
    return opt.lr * (0.1 ** completed_decays)

def train(training_data_loader, optimizer, scheduler, model, model_val, criterion, epoch, testing_data_loader):
    """Run one training epoch and return the sample-weighted mean loss.

    For every batch the trainable ``model`` is fitted to the output of the
    reference ``model_val`` on the same input; the ``target`` element of each
    batch is not used. Reads the module-level ``opt`` (``cuda``, ``vgg_loss``)
    and, when ``opt.vgg_loss`` is set, the module-level ``netContent``.

    Changes relative to the previous version:
      * ``scheduler.step()`` is called at the END of the epoch, after the
        optimizer updates. Calling it first (as before) skips the first value
        of the learning-rate schedule and triggers a PyTorch warning.
      * The reference output is computed under ``torch.no_grad()``.
      * The batch target is no longer copied to the GPU only to be overwritten,
        and the no-op ``Variable`` wrappers are gone.

    Args:
        training_data_loader: iterable of ``(data, target)`` batches.
        optimizer: optimizer for ``model``.
        scheduler: LR scheduler stepped once per epoch.
        model: network being trained.
        model_val: frozen reference network providing the targets.
        criterion: loss function.
        epoch: epoch number, used for the progress-bar text only.
        testing_data_loader: unused (kept for the disabled sample preview).

    Returns:
        Sum of per-batch losses weighted by batch size, divided by the number
        of samples seen.
    """
    train_bar = tqdm(training_data_loader)
    running_results = {'batch_sizes': 0, 'loss': 0, }
    model.train()
    for data, target in train_bar:
        batch_size = data.size(0)
        if batch_size <= 0:
            continue
        running_results['batch_sizes'] += batch_size

        z = data.cuda() if opt.cuda else data

        fake_img = model(z)
        # The reference network only supplies targets; no graph is needed for it.
        with torch.no_grad():
            real_img = model_val(z)

        loss = criterion(fake_img, real_img)

        if opt.vgg_loss:
            content_input = netContent(fake_img)
            content_target = netContent(real_img)
            content_target = content_target.detach()
            content_loss = criterion(content_input, content_target)

        optimizer.zero_grad()

        if opt.vgg_loss:
            netContent.zero_grad()
            # retain_graph: the graph through ``model`` is reused by loss.backward() below.
            content_loss.backward(retain_graph=True)

        loss.backward()

        optimizer.step()
        running_results['loss'] += loss.item() * batch_size

        train_bar.set_description(desc='[%d] Loss: %.4f lr: %.7f' % (
            epoch, running_results['loss'] / running_results['batch_sizes'], optimizer.param_groups[0]['lr']))

    # Advance the LR schedule once per epoch, after the optimizer updates.
    scheduler.step()

    # show samples
    # low_res_sample, high_res_sample = next(iter(testing_data_loader))
    # idx = np.random.randint(0, 14, 1)
    # fake_image = model(low_res_sample[idx].cuda())
    # fake_image = fake_image.cpu().detach()
    # ground_truth = high_res_sample[idx]
    # image_grid = torchvision.utils.make_grid([fake_image[0], ground_truth[0]], nrow=2, normalize=True)
    # _, plot = plt.subplots(figsize=(12, 12))
    # plt.axis('off')
    # plot.imshow(image_grid.permute(1, 2, 0))
    # # plot.show(image_grid.permute(1, 2, 0))
    # plt.savefig("CSNetPlus_CheckpointImage_Full" + '/epoch_{}_checkpoint.jpg'.format(epoch), bbox_inches='tight')
    return running_results['loss'] / running_results['batch_sizes']


def save_checkpoint(path, model, epoch, epoch_loss=0):
    """Write a per-epoch checkpoint file into the directory ``path``.

    The checkpoint is a dict with three entries: ``"epoch"`` (the epoch
    number), ``"model"`` (the model object itself, pickled as a whole rather
    than as a ``state_dict``) and ``"atten_matrix"``
    (``model.real_atten.weight``).

    Args:
        path: Target directory. It is created when missing. A trailing
            separator is optional; an empty string means the current
            working directory.
        model: Model to store; must expose ``real_atten.weight``.
        epoch: Epoch number, stored in the checkpoint and in the file name.
        epoch_loss: Optional loss value appended to the file name. The
            default ``0`` leaves that slot empty, giving
            ``model_epoch_<epoch>_.pth``.

    Returns:
        None. The location of the written file is printed.
    """
    # A loss of 0 is treated as "not supplied" and contributes nothing to the name.
    loss_tag = "" if epoch_loss == 0 else epoch_loss
    file_name = "model_epoch_{}_{}.pth".format(epoch, loss_tag)

    # Join instead of concatenating: with plain `path + name` a directory given
    # without a trailing separator produced a file *next to* the directory
    # (e.g. "ckptmodel_epoch_1_.pth") instead of inside it.
    model_out_path = os.path.join(path, file_name)

    state = {"epoch": epoch, "model": model, "atten_matrix": model.real_atten.weight}

    # exist_ok avoids the check-then-create race; skip entirely for the empty
    # path because os.makedirs("") raises.
    if path:
        os.makedirs(path, exist_ok=True)

    torch.save(state, model_out_path)

    print("Checkpoint saved to {}".format(model_out_path))

# Entry point: call main() only when this code runs as the top-level module
# (main() is defined outside this part of the notebook).
if __name__ == "__main__":
    main()


0.006793432869017124
Checkpoint saved to /content/checkpoint/model_epoch_737_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.007084176875650883
Checkpoint saved to /content/checkpoint/model_epoch_738_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006625528447329998
Checkpoint saved to /content/checkpoint/model_epoch_739_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006922130472958088
Checkpoint saved to /content/checkpoint/model_epoch_740_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006557442247867584
Checkpoint saved to /content/checkpoint/model_epoch_741_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006778225768357515
Checkpoint saved to /content/checkpoint/model_epoch_742_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006501179654151201
Checkpoint saved to /content/checkpoint/model_epoch_743_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006961490958929062
Checkpoint saved to /content/checkpoint/model_epoch_744_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006978743709623814
Checkpoint saved to /content/checkpoint/model_epoch_745_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006463638041168451
Checkpoint saved to /content/checkpoint/model_epoch_746_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006586454808712006
Checkpoint saved to /content/checkpoint/model_epoch_747_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006652457639575005
Checkpoint saved to /content/checkpoint/model_epoch_748_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006579251028597355
Checkpoint saved to /content/checkpoint/model_epoch_749_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.0069098928943276405
Checkpoint saved to /content/checkpoint/model_epoch_750_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006732426583766937
Checkpoint saved to /content/checkpoint/model_epoch_751_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.007018225267529488
Checkpoint saved to /content/checkpoint/model_epoch_752_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006851258687674999
Checkpoint saved to /content/checkpoint/model_epoch_753_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006799377501010895
Checkpoint saved to /content/checkpoint/model_epoch_754_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.007017404306679964
Checkpoint saved to /content/checkpoint/model_epoch_755_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.007263494655489922
Checkpoint saved to /content/checkpoint/model_epoch_756_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006551843602210283
Checkpoint saved to /content/checkpoint/model_epoch_757_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006738689262419939
Checkpoint saved to /content/checkpoint/model_epoch_758_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.007048619911074638
Checkpoint saved to /content/checkpoint/model_epoch_759_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006664409302175045
Checkpoint saved to /content/checkpoint/model_epoch_760_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006978254299610853
Checkpoint saved to /content/checkpoint/model_epoch_761_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.0068546864204108715
Checkpoint saved to /content/checkpoint/model_epoch_762_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006394478492438793
Checkpoint saved to /content/checkpoint/model_epoch_763_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.0068376800045371056
Checkpoint saved to /content/checkpoint/model_epoch_764_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.0066493116319179535
Checkpoint saved to /content/checkpoint/model_epoch_765_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.007141037378460169
Checkpoint saved to /content/checkpoint/model_epoch_766_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006589914672076702
Checkpoint saved to /content/checkpoint/model_epoch_767_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006575440987944603
Checkpoint saved to /content/checkpoint/model_epoch_768_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006589259020984173
Checkpoint saved to /content/checkpoint/model_epoch_769_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006808034610003233
Checkpoint saved to /content/checkpoint/model_epoch_770_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006782702170312405
Checkpoint saved to /content/checkpoint/model_epoch_771_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006811609026044607
Checkpoint saved to /content/checkpoint/model_epoch_772_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.00688254414126277
Checkpoint saved to /content/checkpoint/model_epoch_773_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006892540957778692
Checkpoint saved to /content/checkpoint/model_epoch_774_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006381948944181204
Checkpoint saved to /content/checkpoint/model_epoch_775_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.0067244418896734715
Checkpoint saved to /content/checkpoint/model_epoch_776_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006531162187457085
Checkpoint saved to /content/checkpoint/model_epoch_777_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006826044525951147
Checkpoint saved to /content/checkpoint/model_epoch_778_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006711797323077917
Checkpoint saved to /content/checkpoint/model_epoch_779_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006718799471855164
Checkpoint saved to /content/checkpoint/model_epoch_780_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006593038327991962
Checkpoint saved to /content/checkpoint/model_epoch_781_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006815225351601839
Checkpoint saved to /content/checkpoint/model_epoch_782_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006649625487625599
Checkpoint saved to /content/checkpoint/model_epoch_783_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006932128686457872
Checkpoint saved to /content/checkpoint/model_epoch_784_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.0069266739301383495
Checkpoint saved to /content/checkpoint/model_epoch_785_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.00674635311588645
Checkpoint saved to /content/checkpoint/model_epoch_786_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.007221365347504616
Checkpoint saved to /content/checkpoint/model_epoch_787_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006856655236333609
Checkpoint saved to /content/checkpoint/model_epoch_788_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.0069531286135315895
Checkpoint saved to /content/checkpoint/model_epoch_789_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.00650915177538991
Checkpoint saved to /content/checkpoint/model_epoch_790_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.00664856331422925
Checkpoint saved to /content/checkpoint/model_epoch_791_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006724754814058542
Checkpoint saved to /content/checkpoint/model_epoch_792_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.00667242705821991
Checkpoint saved to /content/checkpoint/model_epoch_793_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006560597103089094
Checkpoint saved to /content/checkpoint/model_epoch_794_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006833352148532867
Checkpoint saved to /content/checkpoint/model_epoch_795_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006750509142875671
Checkpoint saved to /content/checkpoint/model_epoch_796_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006595691200345755
Checkpoint saved to /content/checkpoint/model_epoch_797_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006945295259356499
Checkpoint saved to /content/checkpoint/model_epoch_798_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006974311079829931
Checkpoint saved to /content/checkpoint/model_epoch_799_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.0067968824878335
Checkpoint saved to /content/checkpoint/model_epoch_800_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.0065943459048867226
Checkpoint saved to /content/checkpoint/model_epoch_801_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006553881801664829
Checkpoint saved to /content/checkpoint/model_epoch_802_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006739153992384672
Checkpoint saved to /content/checkpoint/model_epoch_803_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.007003439590334892
Checkpoint saved to /content/checkpoint/model_epoch_804_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.007103412412106991
Checkpoint saved to /content/checkpoint/model_epoch_805_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006655183155089617
Checkpoint saved to /content/checkpoint/model_epoch_806_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006590808741748333
Checkpoint saved to /content/checkpoint/model_epoch_807_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006743184290826321
Checkpoint saved to /content/checkpoint/model_epoch_808_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006984792649745941
Checkpoint saved to /content/checkpoint/model_epoch_809_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006833663675934076
Checkpoint saved to /content/checkpoint/model_epoch_810_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.007041651755571365
Checkpoint saved to /content/checkpoint/model_epoch_811_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006875952240079641
Checkpoint saved to /content/checkpoint/model_epoch_812_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006786395329982042
Checkpoint saved to /content/checkpoint/model_epoch_813_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006684118416160345
Checkpoint saved to /content/checkpoint/model_epoch_814_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006764323450624943
Checkpoint saved to /content/checkpoint/model_epoch_815_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006877973675727844
Checkpoint saved to /content/checkpoint/model_epoch_816_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006922882050275803
Checkpoint saved to /content/checkpoint/model_epoch_817_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006725771352648735
Checkpoint saved to /content/checkpoint/model_epoch_818_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006942263338714838
Checkpoint saved to /content/checkpoint/model_epoch_819_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006581514608114958
Checkpoint saved to /content/checkpoint/model_epoch_820_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006771516520529985
Checkpoint saved to /content/checkpoint/model_epoch_821_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.007213970646262169
Checkpoint saved to /content/checkpoint/model_epoch_822_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.0067020803689956665
Checkpoint saved to /content/checkpoint/model_epoch_823_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006915762554854155
Checkpoint saved to /content/checkpoint/model_epoch_824_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.0070911189541220665
Checkpoint saved to /content/checkpoint/model_epoch_825_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006794577930122614
Checkpoint saved to /content/checkpoint/model_epoch_826_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006860753521323204
Checkpoint saved to /content/checkpoint/model_epoch_827_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006776328198611736
Checkpoint saved to /content/checkpoint/model_epoch_828_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006804990582168102
Checkpoint saved to /content/checkpoint/model_epoch_829_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006888412404805422
Checkpoint saved to /content/checkpoint/model_epoch_830_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006572061683982611
Checkpoint saved to /content/checkpoint/model_epoch_831_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006796070374548435
Checkpoint saved to /content/checkpoint/model_epoch_832_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006713916547596455
Checkpoint saved to /content/checkpoint/model_epoch_833_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006613388657569885
Checkpoint saved to /content/checkpoint/model_epoch_834_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.00646093999966979
Checkpoint saved to /content/checkpoint/model_epoch_835_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.007043809629976749
Checkpoint saved to /content/checkpoint/model_epoch_836_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006686200387775898
Checkpoint saved to /content/checkpoint/model_epoch_837_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.007068253587931395
Checkpoint saved to /content/checkpoint/model_epoch_838_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006487398408353329
Checkpoint saved to /content/checkpoint/model_epoch_839_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006443660240620375
Checkpoint saved to /content/checkpoint/model_epoch_840_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.00680741760879755
Checkpoint saved to /content/checkpoint/model_epoch_841_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006782348267734051
Checkpoint saved to /content/checkpoint/model_epoch_842_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006562987808138132
Checkpoint saved to /content/checkpoint/model_epoch_843_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006704967003315687
Checkpoint saved to /content/checkpoint/model_epoch_844_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006524764932692051
Checkpoint saved to /content/checkpoint/model_epoch_845_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006712827831506729
Checkpoint saved to /content/checkpoint/model_epoch_846_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006371133029460907
Checkpoint saved to /content/checkpoint/model_epoch_847_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006649642717093229
Checkpoint saved to /content/checkpoint/model_epoch_848_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006841589696705341
Checkpoint saved to /content/checkpoint/model_epoch_849_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006842087022960186
Checkpoint saved to /content/checkpoint/model_epoch_850_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006559121888130903
Checkpoint saved to /content/checkpoint/model_epoch_851_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.00660333689302206
Checkpoint saved to /content/checkpoint/model_epoch_852_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.0067625208757817745
Checkpoint saved to /content/checkpoint/model_epoch_853_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006854780949652195
Checkpoint saved to /content/checkpoint/model_epoch_854_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006538321264088154
Checkpoint saved to /content/checkpoint/model_epoch_855_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.0065682074055075645
Checkpoint saved to /content/checkpoint/model_epoch_856_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006891417782753706
Checkpoint saved to /content/checkpoint/model_epoch_857_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006684223189949989
Checkpoint saved to /content/checkpoint/model_epoch_858_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.0067443205043673515
Checkpoint saved to /content/checkpoint/model_epoch_859_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006947016343474388
Checkpoint saved to /content/checkpoint/model_epoch_860_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006765610072761774
Checkpoint saved to /content/checkpoint/model_epoch_861_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006777054630219936
Checkpoint saved to /content/checkpoint/model_epoch_862_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006788724567741156
Checkpoint saved to /content/checkpoint/model_epoch_863_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.00624810578301549
Checkpoint saved to /content/checkpoint/model_epoch_864_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006428946275264025
Checkpoint saved to /content/checkpoint/model_epoch_865_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006284665782004595
Checkpoint saved to /content/checkpoint/model_epoch_866_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006939119193702936
Checkpoint saved to /content/checkpoint/model_epoch_867_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.007125953212380409
Checkpoint saved to /content/checkpoint/model_epoch_868_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006891201715916395
Checkpoint saved to /content/checkpoint/model_epoch_869_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006410390604287386
Checkpoint saved to /content/checkpoint/model_epoch_870_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006747124716639519
Checkpoint saved to /content/checkpoint/model_epoch_871_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006943248677998781
Checkpoint saved to /content/checkpoint/model_epoch_872_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.0068453894928097725
Checkpoint saved to /content/checkpoint/model_epoch_873_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006780968047678471
Checkpoint saved to /content/checkpoint/model_epoch_874_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.007097747642546892
Checkpoint saved to /content/checkpoint/model_epoch_875_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006717209238559008
Checkpoint saved to /content/checkpoint/model_epoch_876_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006752924062311649
Checkpoint saved to /content/checkpoint/model_epoch_877_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.0070533850230276585
Checkpoint saved to /content/checkpoint/model_epoch_878_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006937201134860516
Checkpoint saved to /content/checkpoint/model_epoch_879_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006836982909590006
Checkpoint saved to /content/checkpoint/model_epoch_880_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006604248192161322
Checkpoint saved to /content/checkpoint/model_epoch_881_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006866122595965862
Checkpoint saved to /content/checkpoint/model_epoch_882_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006927751004695892
Checkpoint saved to /content/checkpoint/model_epoch_883_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006778917741030455
Checkpoint saved to /content/checkpoint/model_epoch_884_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006996175739914179
Checkpoint saved to /content/checkpoint/model_epoch_885_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.0069787404499948025
Checkpoint saved to /content/checkpoint/model_epoch_886_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.0072504556737840176
Checkpoint saved to /content/checkpoint/model_epoch_887_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006612785160541534
Checkpoint saved to /content/checkpoint/model_epoch_888_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006651769392192364
Checkpoint saved to /content/checkpoint/model_epoch_889_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006757343653589487
Checkpoint saved to /content/checkpoint/model_epoch_890_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006945288740098476
Checkpoint saved to /content/checkpoint/model_epoch_891_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006619151681661606
Checkpoint saved to /content/checkpoint/model_epoch_892_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006609760224819183
Checkpoint saved to /content/checkpoint/model_epoch_893_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.00673195393756032
Checkpoint saved to /content/checkpoint/model_epoch_894_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006763251032680273
Checkpoint saved to /content/checkpoint/model_epoch_895_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006700266618281603
Checkpoint saved to /content/checkpoint/model_epoch_896_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.007114702835679054
Checkpoint saved to /content/checkpoint/model_epoch_897_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006516875233501196
Checkpoint saved to /content/checkpoint/model_epoch_898_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.006845629774034023
Checkpoint saved to /content/checkpoint/model_epoch_899_.pth


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


0.0067397248931229115
Checkpoint saved to /content/checkpoint/model_epoch_900_.pth


In [10]:
# Copy the epoch-900 checkpoint from the local checkpoint directory to Google Drive.
# NOTE(review): the target name "0.5_0.0067" presumably encodes a 0.5 setting and
# the final loss (~0.0067) — confirm the naming convention.
!cp "/content/checkpoint/model_epoch_900_.pth" "/content/gdrive/MyDrive/Colab Notebooks/CSNetPlus_model_mask/0.5_0.0067.pth"