In [235]:
!apt-get install -y -qq software-properties-common python-software-properties module-init-tools
!add-apt-repository -y ppa:alessandro-strada/ppa 2>&1 > /dev/null
!apt-get update -qq 2>&1 > /dev/null
!apt-get -y install -qq google-drive-ocamlfuse fuse
from google.colab import auth
auth.authenticate_user()
from oauth2client.client import GoogleCredentials
creds = GoogleCredentials.get_application_default()
import getpass
!google-drive-ocamlfuse -headless -id={creds.client_id} -secret={creds.client_secret} < /dev/null 2>&1 | grep URL
vcode = getpass.getpass()
!echo {vcode} | google-drive-ocamlfuse -headless -id={creds.client_id} -secret={creds.client_secret}

!mkdir -p drive
!google-drive-ocamlfuse drive  -o nonempty

E: Package 'python-software-properties' has no installation candidate
Selecting previously unselected package google-drive-ocamlfuse.
(Reading database ... 131304 files and directories currently installed.)
Preparing to unpack .../google-drive-ocamlfuse_0.7.3-0ubuntu3~ubuntu18.04.1_amd64.deb ...
Unpacking google-drive-ocamlfuse (0.7.3-0ubuntu3~ubuntu18.04.1) ...
Setting up google-drive-ocamlfuse (0.7.3-0ubuntu3~ubuntu18.04.1) ...
Processing triggers for man-db (2.8.3-2ubuntu0.1) ...
Please, open the following URL in a web browser: https://accounts.google.com/o/oauth2/auth?client_id=32555940559.apps.googleusercontent.com&redirect_uri=urn%3Aietf%3Awg%3Aoauth%3A2.0%3Aoob&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive&response_type=code&access_type=offline&approval_prompt=force
··········
Please, open the following URL in a web browser: https://accounts.google.com/o/oauth2/auth?client_id=32555940559.apps.googleusercontent.com&redirect_uri=urn%3Aietf%3Awg%3Aoauth%3A2.0%3Aoob&scope=ht

In [236]:
import os 
os.chdir("content/drive/")
!ls

FileNotFoundError: ignored

In [0]:
from __future__ import print_function, division
%matplotlib inline

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable,Function
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from scipy.optimize import linear_sum_assignment
# Ignore warnings
import warnings
warnings.filterwarnings("ignore")
import os
plt.ion()   # interactive mode
USE_CUDA = torch.cuda.is_available()
device = torch.device("cuda" if USE_CUDA else "cpu")


# Any results you write to the current directory are saved as output.

In [0]:
TRAIN_PATH = './drive/mini_data/DATA/stage1_train/'
TEST_PATH = '../drive/mini_data/DATA/stage1_test/'
UNET_PATH = './drive/mini_program/'

# Loss Function

In [0]:
class loss_f(nn.Module):
    def __init__(self,lamda=1):
        # super parameter
        super(loss_f,self).__init__()
        self.lamda = lamda
    def forward(self,output_list,scores,gt_masks):
        """
        gt_masks: sample['masks'],shape=[1,masks_number,H,W]
        output_list: tensor ,size = [premasks_number,H,W]
        scores: tensor size = [premasks_numbers,1]
        batchsize = 1
        """
        
        gt_masks = gt_masks.squeeze(0) # max_element = 1
        num_gt = gt_masks.size()[0]
        num_pre = output_list.size()[0]
        num_min = min(num_gt,num_pre)
        loss = 0
        Matrix = torch.zeros(num_min,num_gt,device=output_list.device)
        for i in range(num_min): # when pre>gt, drop the later prediction
            for j in range(num_gt): 
                #generate a num_min*num_gt matrix 
                Matrix[i][j] = self._iou(output_list[i],gt_masks[j]) ## IoU is a tensor scalar
        
        #Matrix = Matrix.cuda()
        numpy_m = Matrix.cpu().detach().numpy()
        
        row_ind,col_ind = linear_sum_assignment(numpy_m)
        self.para = (output_list,scores,gt_masks,num_pre,num_gt,num_min,numpy_m,row_ind,col_ind)
        for i in range(num_min):
            loss = loss + (-Matrix[row_ind[i]][col_ind[i]]+self.lamda*F.binary_cross_entropy(scores[i],torch.tensor([1.0],device=output_list.device)))
        for i in range(num_gt,num_pre):#pre> gt时进入 ,pre<gt时不进入
            loss = loss + self.lamda*F.binary_cross_entropy(scores[i],torch.tensor([0.0],device=output_list.device)) #F.binary_cross_entropy(b,a)
        return loss
    
    def backward(self,grad_output):
        (output_list,scores,gt_masks,num_pre,num_gt,num_min,row_ind,col_ind) = self.para
        #grad_hung = torch.zeros(num_pre,device=output_list.device)
        grad_s = torch.zeros(num_pre,device=output_list.device)
        grad_mask = torch.zeros_like(output_list) #size = [premasks_number,H,W],device=cuda
        for t in range(num_pre):
            if t<=num_min:
                ######hungarian##################
                #grad_hung[t] = -numpy_m[row_ind[t]][col_ind[t]] #value of iou
                ######hungarian##################
                
                ######iou########################
                X = output_list[row_ind[t]] # pre mask [h,w],device = cuda
                Y = gt_masks[col_ind[t]] # gt mask [h,w]
                I = torch.sum(torch.mul(X,Y)) # scalar
                U = torch.sum(X)+torch.sum(Y)-I # scalar
                grad_mask[t] = -1*(U*Y-I*(1-Y))/(U**2) # [h,w]
                ######iou########################
                grad_s[t] = -self.lamda*(1/(scores[t]))
            else:
                #grad_mask[t]
                grad_s[t] = self.lamda*(1/(1-scores[t]))
        # grad_input = [derivate forward(input) wrt parameters] * grad_output
        grad_mask = grad_mask*grad_output
        grad_s = grad_s*grad_output
        print(grad_mask,grad_s)
        return grad_mask,grad_s
    
    def _iou(self,x,y):
        iou_inter = torch.sum(torch.mul(x,y))
        iou = iou_inter/(torch.sum(x)+torch.sum(y)-iou_inter)
        return iou

# Model Class

In [0]:
class UNet(nn.Module):
    def __init__(self, n_channels, n_classes):
        super(UNet, self).__init__()
        self.inc = inconv(n_channels, 64)
        self.down1 = down(64, 128)
        self.down2 = down(128, 256)
        self.down3 = down(256, 512)
        self.down4 = down(512, 512)
        self.up1 = up(1024, 256)
        self.up2 = up(512, 128)
        self.up3 = up(256, 64)
        self.up4 = up(128, 64)
        self.outc = outconv(64, n_classes)
        # ConvLSTM part
        #self.conv = nn.Conv2d(64,16,3,stride=1,padding=1)
        self.convlstm = ConvLSTM(input_size=(176,176),
                                 input_dim=1,
                                 hidden_dim=[1,1],
                                 kernel_size=(3, 3),
                                 num_layers=2,
                                 batch_first=False,
                                 bias=True,
                                 return_all_layers=True)
        # SI part
        
        self.conv8 = nn.Conv2d(1,1,1,stride=1,padding=0)
        self.bias = nn.Parameter(torch.ones([1]),requires_grad=True)
    
        self.fc = nn.Linear(1*176*176, 1)
        #self.up = nn.ConvTranspose2d(1, 1, 2, stride=2,padding=0)
        torch.nn.init.uniform_(self.conv8.weight, a=-0.08, b=0.08)
        torch.nn.init.uniform_(self.conv8.bias, a=-0.08, b=0.08)

    def forward(self, x,seq_len,mode='train'):
        
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x64 = self.up4(x, x1)
        x = F.sigmoid(self.outc(x64))
        if mode=='pre':
            return (x,x64)
        # x.shape() = [1,1,176,176]
        #x = self.conv(x)#[1,1,176,176]
        size = x.size()[-1] # W
        # convlstm and si
        convlstm_input = x.unsqueeze(0) #[1,batch,d,h,w],(t, b, c, h, w) -> (b, t, c, h, w) (i.e.[1,1,1,176,176])
        output_list = []
        scores = []
        hidden_state=None
        for i in range(seq_len):
            [layer_output_list, last_state_list] = self.convlstm(convlstm_input,hidden_state) # layer_output_list=[2,batch,t,d,h,w],t=1,(i.e.[2,1,1,1,176,176])
                                                                                 # last_state_list(i.e. [[h,c]])=[1,2,batch,t,d,h,w] (i.e.[1,2,1,1,1,176,176])
            hidden_state = last_state_list

            layer_output_list = layer_output_list[-1].squeeze(1)#[1,1,176,176]
            
            # produce score
            score_input = layer_output_list
            #score_input = F.max_pool2d(score_input,2) #[88,88]
            [b,c,h,w] = score_input.size()

            score = F.sigmoid(self.fc(score_input.view(b*c*h*w)))
            scores.append(score)
            # produce masks
            SI = F.log_softmax(self.conv8(layer_output_list)) #the input is [1,d,h,w],SI is [1,1,h,w] (i.e.[1,1,176,176])
            SI = F.sigmoid(SI +self.bias)
            #SI = self.up(SI)
            mask = SI.squeeze() #[batch,h,w] ,since batchsize=1,it should be [h,w] (i.e.[176,176])
            output_list.append(mask)
        return (output_list,scores)


# sub-parts of the U-Net model

class double_conv(nn.Module):
    '''(conv => BN => ReLU) * 2'''
    def __init__(self, in_ch, out_ch):
        super(double_conv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.conv(x)
        return x


class inconv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(inconv, self).__init__()
        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x):
        x = self.conv(x)
        return x


class down(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(down, self).__init__()
        self.mpconv = nn.Sequential(
            nn.MaxPool2d(2),
            double_conv(in_ch, out_ch)
        )

    def forward(self, x):
        x = self.mpconv(x)
        return x


class up(nn.Module):
    def __init__(self, in_ch, out_ch, bilinear=True):
        super(up, self).__init__()

        #  would be a nice idea if the upsampling could be learned too,
        #  but my machine do not have enough memory to handle all those weights
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2)

        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        
        # input is CHW
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, (diffX // 2, diffX - diffX//2,
                        diffY // 2, diffY - diffY//2))
        
        # for padding issues, see 
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd

        x = torch.cat([x2, x1], dim=1)
        x = self.conv(x)
        return x


class outconv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(outconv, self).__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 1)

    def forward(self, x):
        x = self.conv(x)
        return x

# Helper Classes

In [0]:
class NeuralDataset(Dataset):
    """Neural dataset."""

    def __init__(self,root_dir, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the examples.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        #self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.example_list = os.listdir(root_dir)
        self.transform = transform

    def __len__(self):
        return len(self.example_list)

    def __getitem__(self, idx):
        example_dir = os.path.join(self.root_dir,
                                self.example_list[idx])
        img_dir = example_dir+'/images'
        masks_dir = example_dir+'/masks'
        
        img_name = img_dir+'/'+os.listdir(img_dir)[0]
        image = io.imread(img_name)[:,:,0:3]
        
        maskwalk = os.walk(masks_dir).__next__()
        masks = []
        for item in maskwalk[2]:
            masks_name = os.path.join(masks_dir,item)
            mask = io.imread(masks_name)
            masks.append(mask)
        masks = np.stack(masks)
        
        sample = {'image': image, 'masks': masks,'id_':img_dir}# masks is [masknumber,H,W],image [H,W,C]

        if self.transform:
            sample = self.transform(sample)

        return sample
    
class Rescale(object):
    """Rescale the image in a sample to a given size.

    Args:
        output_size (tuple or int): Desired output size. If tuple, output is
            matched to output_size. If int, smaller of image edges is matched
            to output_size keeping aspect ratio the same.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        image, masks,id_ = sample['image'], sample['masks'],sample['id_']

        h, w = image.shape[:2]
        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size

        new_h, new_w = int(new_h), int(new_w)

        img = transform.resize(image, (new_h, new_w))

        # resize the masks
        new_msk = []
        for mask in masks:
            new_msk.append(transform.resize(mask, (new_h, new_w)))
        new_mask = np.stack(new_msk)
        
        return {'image': img, 'masks': new_mask,'id_':id_}


class RandomCrop(object):
    """Crop randomly the image in a sample.

    Args:
        output_size (tuple or int): Desired output size. If int, square crop
            is made.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        image, masks,id_ = sample['image'], sample['masks'],sample['id_']

        h, w = image.shape[:2]
        
        new_h, new_w = self.output_size
        if h-new_h==0:
            top = 0
        else:
            top = np.random.randint(0, h - new_h)
        if w-new_w==0:
            left = 0
        else:
            left = np.random.randint(0, w - new_w)

        image = image[top: top + new_h,
                      left: left + new_w]

        # crop the masks
        new_msk = []
        for mask in masks:
            newm = mask[top: top + new_h,
                        left: left + new_w]
            new_msk.append(newm)
        new_mask = np.stack(new_msk)

        return {'image': image, 'masks': new_mask,'id_':id_}


class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):
        image, masks,id_ = sample['image'], sample['masks'],sample['id_']

        # swap color axis because
        # numpy image: H x W x C
        # torch image: C X H X W
        image = image.transpose((2, 0, 1))
        # numpy masks: N x H x W
        # torch masks: N X H X W
        return {'image': torch.from_numpy(image),
                'masks': torch.from_numpy(masks),
                'id_':id_}
class ConvLSTMCell(nn.Module):

    def __init__(self, input_size, input_dim, hidden_dim, kernel_size, bias):
        """
        Initialize ConvLSTM cell.
        
        Parameters
        ----------
        input_size: (int, int)
            Height and width of input tensor as (height, width).
        input_dim: int
            Number of channels of input tensor.
        hidden_dim: int
            Number of channels of hidden state.
        kernel_size: (int, int)
            Size of the convolutional kernel.
        bias: bool
            Whether or not to add the bias.
        """

        super(ConvLSTMCell, self).__init__()

        self.height, self.width = input_size
        self.input_dim  = input_dim
        self.hidden_dim = hidden_dim

        self.kernel_size = kernel_size
        self.padding     = kernel_size[0] // 2, kernel_size[1] // 2
        self.bias        = bias
        
        self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,
                              out_channels=4 * self.hidden_dim,
                              kernel_size=self.kernel_size,
                              padding=self.padding,
                              bias=self.bias)
        self.batch = nn.BatchNorm2d(4 * self.hidden_dim)
        torch.nn.init.uniform_(self.conv.weight, a=-0.08, b=0.08)
        torch.nn.init.uniform_(self.conv.bias, a=-0.08, b=0.08)

    def forward(self, input_tensor, cur_state):
        # input_tensor is [batch,channel,h,w]
        h_cur, c_cur = cur_state     
        combined = torch.cat([input_tensor, h_cur], dim=1)  # concatenate along channel axis
        
        combined = self.conv(combined)
        conbined = self.batch(combined)
        cc_i, cc_f, cc_o, cc_g = torch.split(combined, self.hidden_dim, dim=1) 
        i = torch.sigmoid(cc_i)
        f = torch.sigmoid(cc_f)
        o = torch.sigmoid(cc_o)
        g = torch.tanh(cc_g)

        c_next = f * c_cur + i * g
        h_next = o * torch.tanh(c_next) # size of h and c is [batch,channel,h,w]
        
        return h_next, c_next

    def init_hidden(self, batch_size):
        return [Variable(torch.zeros(batch_size, self.hidden_dim, self.height, self.width)).cuda(),
                Variable(torch.zeros(batch_size, self.hidden_dim, self.height, self.width)).cuda()]


class ConvLSTM(nn.Module):

    def __init__(self, input_size, input_dim, hidden_dim, kernel_size, num_layers,
                 batch_first=False, bias=True, return_all_layers=False):
        super(ConvLSTM, self).__init__()

        self._check_kernel_size_consistency(kernel_size)

        # Make sure that both `kernel_size` and `hidden_dim` are lists having len == num_layers
        kernel_size = self._extend_for_multilayer(kernel_size, num_layers)
        hidden_dim  = self._extend_for_multilayer(hidden_dim, num_layers)
        if not len(kernel_size) == len(hidden_dim) == num_layers:
            raise ValueError('Inconsistent list length.')

        self.height, self.width = input_size

        self.input_dim  = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.num_layers = num_layers
        self.batch_first = batch_first
        self.bias = bias
        self.return_all_layers = return_all_layers

        cell_list = []
        for i in range(0, self.num_layers):
            cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i-1]

            cell_list.append(ConvLSTMCell(input_size=(self.height, self.width),
                                          input_dim=cur_input_dim,
                                          hidden_dim=self.hidden_dim[i],
                                          kernel_size=self.kernel_size[i],
                                          bias=self.bias))

        self.cell_list = nn.ModuleList(cell_list)
        

    def forward(self, input_tensor, hidden_state=None):
        """
        
        Parameters
        ----------
        input_tensor: todo 
            5-D Tensor either of shape (t, b, c, h, w) or (b, t, c, h, w)
        hidden_state: todo
            None. todo implement stateful
            
        Returns
        -------
        last_state_list, layer_output
        """
        if not self.batch_first:
            # (t, b, c, h, w) -> (b, t, c, h, w)
            input_tensor = input_tensor.permute(1, 0, 2, 3, 4)

        # Implement stateful ConvLSTM
        if hidden_state is None:
            hidden_state = self._init_hidden(batch_size=input_tensor.size(0))

        layer_output_list = []
        last_state_list   = []

        seq_len = input_tensor.size(1)
        cur_layer_input = input_tensor

        for layer_idx in range(self.num_layers):  #num_layers should be 2

            h, c = hidden_state[layer_idx]
            output_inner = []
            
            for t in range(seq_len):
                h, c = self.cell_list[layer_idx](input_tensor=cur_layer_input[:, t, :, :, :],
                                                 cur_state=[h, c])# size of h and c is [batch,c,h,w]
                output_inner.append(h)#[t,batch,c,h,w]

            layer_output = torch.stack(output_inner, dim=1)# layer_output is [batch,t,c,h,w]
            cur_layer_input = layer_output

            layer_output_list.append(layer_output)# [1,batch,t,c,h,w]
            last_state_list.append([h, c])#[1,2,batch,t,c,h,w]
        torch.cuda.empty_cache()
        if not self.return_all_layers:
            layer_output_list = layer_output_list[-1:]
            last_state_list   = last_state_list[-1:]

        return layer_output_list, last_state_list

    def _init_hidden(self, batch_size):
        init_states = []
        for i in range(self.num_layers):
            init_states.append(self.cell_list[i].init_hidden(batch_size))
        return init_states # [[h,c]]

    @staticmethod
    def _check_kernel_size_consistency(kernel_size):
        if not (isinstance(kernel_size, tuple) or
                    (isinstance(kernel_size, list) and all([isinstance(elem, tuple) for elem in kernel_size]))):
            raise ValueError('`kernel_size` must be tuple or list of tuples')

    @staticmethod
    def _extend_for_multilayer(param, num_layers):
        if not isinstance(param, list):
            param = [param] * num_layers
        return param

In [0]:
# reading data
transformed_dataset = NeuralDataset(root_dir=TRAIN_PATH,
                                    transform=transforms.Compose([
                                        Rescale(176),
                                        RandomCrop(176),
                                        ToTensor()]))

dataloader = DataLoader(transformed_dataset, batch_size=1,
                        shuffle=True, num_workers=0)


In [261]:
#set up model
unet = UNet(3,1)
if os.path.exists('./drive/mini_program/unet_params.pkl') == False:
    pretrained_dict = torch.load('./drive/mini_program/unet.pkl').state_dict()
    model_dict = unet.state_dict()
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    model_dict.update(pretrained_dict) 
    unet.load_state_dict(model_dict)
    print(unet.state_dict)
else:
    unet.load_state_dict(torch.load('./drive/mini_program/unet_params.pkl'))
    print('load the parameters')
for i,p in enumerate(unet.parameters()):
    if i <72:
        p.requires_grad = False
unet.bias.requires_grad = True
# training
unet = unet.to(device)
torch.cuda.empty_cache()
# create your optimizer
criterion = loss_f()
optimizer = optim.Adam(unet.parameters(), lr=0.00001)#,momentum=0.9,weight_decay=0.0005)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')

<bound method Module.state_dict of UNet(
  (inc): inconv(
    (conv): double_conv(
      (conv): Sequential(
        (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace)
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace)
      )
    )
  )
  (down1): down(
    (mpconv): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): double_conv(
        (conv): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace)
          (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
     

In [0]:
file_path = './'
for epoch in range(840):  # loop over the dataset multiple times
    els = []
    running_loss = 0.0
    epoch_loss =0.0
    k = 0
    #converge_loss = np.array([0,0,0,0,0])
    for i, data in enumerate(dataloader, 0):
        # get the inputs
        inputs = (data['image'].type(torch.float32)).to(device)
        labels = data['masks'].type(torch.float32).to(device)
        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        seq_len = epoch%419+1
        (output_list,scores) = unet(inputs,seq_len)
        output_list = torch.stack(output_list) # shape = [premasks,H,W]
        scores = torch.stack(scores)
        
        loss = criterion(output_list,scores,labels) #  the minimum of loss is -1
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(filter(lambda p:p.requires_grad,unet.parameters()), max_norm=5, norm_type=1)
        optimizer.step()
        #scheduler.step(loss)

        #judge whether converge
#         for j in range(1,5):
#             converge_loss[j]=converge_loss[j-1]
#         converge_loss[0] = loss.item()
#         if converge_loss.max()<-0.65:
#             seq_len = seq_len+1
#             torch.save(unet.state_dict(), '../mini_program/unet_params.pkl')
#             converge_loss = np.array([0,0,0,0,0])
        
        #print('epoch=',epoch,'num_data=',i,'seq_len=',seq_len,'loss',loss.item(),'lr',optimizer.param_groups[0]['lr'])
        # print statistics
        running_loss += loss.item()
        epoch_loss+= loss.item()
        k = k+1
        if k % 67 == 66:    # print every 134 mini-batches
            print('[%d, %5d] loss: %.8f,seq_len:%d lr:%.10f' %
                  (epoch + 1, k + 1, running_loss /67,seq_len,optimizer.param_groups[0]['lr']))
            running_loss = 0.0
            
#     if seq_len>=420:
#         break
    print('[%d] loss: %.8f,seq_len:%d lr:%.10f' %
      (epoch + 1, epoch_loss /670,seq_len,optimizer.param_groups[0]['lr']))
    torch.save(unet.state_dict(), './drive/mini_program/unet_params.pkl')
#     with open(os.path.join(file_path, "rls.csv"),'a') as f:
#         f.write(rls)
    with open(os.path.join(file_path, "els.csv"),'a') as f:
        f.write('epoch='+str(epoch+1)+'loss='+str(epoch_loss /670)+'\n')
print('Finished Training')

[1,    67] loss: 0.65305278,seq_len:1 lr:0.0000100000
[1,   134] loss: 0.59979281,seq_len:1 lr:0.0000100000
[1,   201] loss: 0.53420841,seq_len:1 lr:0.0000100000
[1,   268] loss: 0.46902846,seq_len:1 lr:0.0000100000
[1,   335] loss: 0.40554993,seq_len:1 lr:0.0000100000
[1,   402] loss: 0.34459421,seq_len:1 lr:0.0000100000
[1,   469] loss: 0.28764544,seq_len:1 lr:0.0000100000
[1,   536] loss: 0.23617993,seq_len:1 lr:0.0000100000
[1,   603] loss: 0.19077480,seq_len:1 lr:0.0000100000
[1,   670] loss: 0.15115173,seq_len:1 lr:0.0000100000
[1] loss: 0.38739688,seq_len:1 lr:0.0000100000
[2,    67] loss: 0.20076842,seq_len:2 lr:0.0000100000
[2,   134] loss: 0.11199258,seq_len:2 lr:0.0000100000
[2,   201] loss: 0.08167406,seq_len:2 lr:0.0000100000
[2,   268] loss: 0.05702280,seq_len:2 lr:0.0000100000
[2,   335] loss: 0.11805749,seq_len:2 lr:0.0000100000
[2,   402] loss: 0.02827004,seq_len:2 lr:0.0000100000
[2,   469] loss: 0.01945247,seq_len:2 lr:0.0000100000
[2,   536] loss: 0.11446592,seq_len