In [6]:
'''Some useful utilities'''
import cv2
import numpy as np
import torch
import matplotlib.pyplot as plt

def identical(img):
    return img

def rot90(img):
    result=np.rot90(img, 1, axes=(0, 1))
    return result.copy()

def rot180(img):
    result=np.rot90(img, 2, axes=(0, 1))
    return result.copy()

def rot270(img):
    result=np.rot90(img, 3, axes=(0, 1))
    return result.copy()

def hormir(img):
    result=np.fliplr(img)
    return result.copy()

def vertmir(img):
    result=np.flipud(img)
    return result.copy()

def medianFilter(img, kernelshape=(3, 3), paddle='zero'):
    (M, N)=img.shape
    (m, n)=kernelshape
    result=np.zeros(img.shape, float)
    if paddle == 'zero':
        temp=np.zeros((M+2*int(m/2), N+2*int(n/2)))
    else:
        temp=np.zeros((M+2*int(m/2), N+2*int(n/2)))
    temp[int(m/2):int(m/2)+M, int(n/2):int(n/2)+N]=img.copy()
    for i in range(0, M):
        for j in range(0, N):
            result[i, j]=np.median(temp[i:i+m, j:j+n].copy())
    return result

def preprocessing(img):
    normed=((img-np.min(img))/(np.max(img)-np.min(img))*255).astype(np.uint8)
    gaussed=cv2.GaussianBlur(normed, (5, 5), 0.5)
    # plt.imshow(gaussed, 'gray')
    # plt.title('gauss')
    # plt.show()
    normed=((gaussed-np.min(gaussed))/(np.max(gaussed)-np.min(gaussed))*255).astype(np.uint8)
    meded=cv2.medianBlur(normed, 5)
    # plt.imshow(meded, 'gray')
    # plt.title('median')
    # plt.show()
    normed=((meded-np.min(meded))/(np.max(meded)-np.min(meded))*255).astype(np.uint8)
    clahe=cv2.createCLAHE(2., (8, 8))
    enhanced=clahe.apply(normed)
    # plt.imshow(enhanced, 'gray')
    # plt.title('contrast enhance')
    # plt.show()
    result=(enhanced-np.min(enhanced))/(np.max(enhanced)-np.min(enhanced))
    return result.copy()

def dataAug(img, mask):
    trans_func=np.random.choice([identical, rot90, rot180, rot270, hormir, vertmir])
    img_res, mask_res=trans_func(img), trans_func(mask)
    return img_res, mask_res

def myCrossEntropyLoss(output, target):
    output=output.detach().numpy()[0, 1].reshape((512, 512))
    target=target.detach().numpy().reshape((512, 512))
    loss=-np.sum(target*np.log(output+1e-12)+(1-target)*np.log(1-output+1e-12))/(512**2)
    return loss

def my_dice_score(set_A, set_B):
    inter=np.sum(set_A*set_B)
    union=np.sum(set_A+set_B)
    return 2*inter/(union+1e-12)

def dilation(src, kernel, pad_value=0, mode='b'):
    (M, N)=src.shape
    (m, n)=kernel.shape
    kernel=np.rot90(kernel, k=2, axes=(0, 1))
    dst=src.copy()
    if ~pad_value:
        padded=np.zeros((M+2*m//2, N+2*n//2))
    elif pad_value:
        padded=np.ones((M+2*m//2, N+2*n//2))
    padded[m//2:m//2+M, n//2:n//2+N]=src.copy()
    for i in range(M):
        for j in range(N):
            window=padded[i:i+m, j:j+n].copy()
            # 若与A的交不为空
            if (kernel.astype(np.bool8) & window.astype(np.bool8)).any():
                dst[i, j]=1
            else:
                dst[i, j]=0
    return dst

def getEdge(src):
    kernel=np.ones((3, 3))
    return (dilation(src, kernel)-src).copy()

In [7]:
# Dataset
import os
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision import transforms
import h5py
from utils import *

class TumorDataset(Dataset):
    def __init__(self, dataset_dir: str, train: bool = True, transform: transforms = None):
        self.dataset_dir=dataset_dir
        self.transform=transform
        self.train=train
        

    def __len__(self):
        '''90% training set and 10% testing set'''
        if self.train:
            return 2742
        else:
            return 307
        

    def __getitem__(self, idx):
        '''Get data from dataset and return its image, mask and label fields'''
        for i in os.walk(self.dataset_dir):
            name_list=i[2]
            try:
                name_list.remove('.DS_Store')
            except:
                pass
        path=self.dataset_dir+name_list[idx]
        image=self.load(path, 'image')
        mask=self.load(path, 'tumorMask')
        edge=getEdge(mask)
        label=int(self.load(path, 'label'))-1
        image=preprocessing(image) # 高斯模糊，中值滤波， 对比度增强， 归一化
        image, mask=dataAug(image, mask) # 数据增强，包括旋转，镜像，对图像与mask施加同样的操作
        if self.transform:
            image=self.transform(image)
        
        return (
            torch.as_tensor(image).float(), 
            torch.as_tensor(mask).long(), 
            torch.as_tensor(label).long(),
            torch.as_tensor(edge).float()
        )
        
    @staticmethod
    def load(path, field):
        '''Load and preprocess a single .mat data file'''
        
        assert field in ['image', 'label', 'tumorMask', 'tumorBorder'], 'Incorrect data field'
        
        with h5py.File(path, 'r') as f:
            result=np.array(f['cjdata'][field])
        
        return result
    
    


In [41]:
# Edge Unet Network definition
import torch
import torch.nn.functional as F

class ConvBNReLU(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        self.layers=torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding), 
            torch.nn.BatchNorm2d(out_channels), 
            torch.nn.ReLU(True)
        )
    
    def forward(self, x):
        return self.layers(x)

class ImageEncodeBlock1(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.layers=torch.nn.Sequential(
            ConvBNReLU(in_channels, out_channels),
            ConvBNReLU(out_channels, out_channels),
            ConvBNReLU(out_channels, out_channels),
        )
        
    def forward(self, x):
        return self.layers(x)
    
class ImageEncodeBlock2(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.layers=torch.nn.Sequential(
            ConvBNReLU(in_channels, out_channels),
            ConvBNReLU(out_channels, out_channels),
            ConvBNReLU(out_channels, out_channels),
            ConvBNReLU(out_channels, out_channels),
        )
    
    def forward(self, x):
        return self.layers(x)
    
class ImageEncodeBlock3(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.layers=torch.nn.Sequential(
            ConvBNReLU(in_channels, out_channels, kernel_size=1, stride=1, padding=0),
            ConvBNReLU(out_channels, out_channels, kernel_size=1, stride=1, padding=0),
            torch.nn.Dropout2d(0.5),
        )
    
    def forward(self, x):
        return self.layers(x)

class EdgeEncodeBlock(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.layers=torch.nn.Sequential(
            ConvBNReLU(in_channels, out_channels, kernel_size=1, stride=1, padding=0),
            ConvBNReLU(out_channels, out_channels, kernel_size=1, stride=1, padding=0),
            ConvBNReLU(out_channels, out_channels, kernel_size=1, stride=1, padding=0),
        )
        
    
    def forward(self, x):
        return self.layers(x)

class EdgeGuidanceBlock(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0):
        super().__init__()
        self.conv1=torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
        self.conv2=torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
        
        
    def forward(self, X, E):
        gamma=torch.sigmoid(self.conv1(E))
        beta=gamma*X
        output=F.relu(self.conv2(gamma+beta))
        return output

class DeconvBNReLU(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=2, stride=2):
        super().__init__()
        self.layers=torch.nn.Sequential(
            torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride),
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.ReLU(inplace=True)
        )
    
    def forward(self, x):
        return self.layers(x)

class DecodeBlock(torch.nn.Module):
    def __init__(self, in_channels, out_channels, interpolate=True):
        super().__init__()
        if interpolate:
            self.upsample=torch.nn.Sequential(
                torch.nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), 
                torch.nn.Conv2d(in_channels, out_channels, 1, 1, 0)
            )
        else:
            self.upsample=DeconvBNReLU(in_channels, out_channels)
        self.conv=torch.nn.Sequential(
            ConvBNReLU(out_channels*2, out_channels), 
            ConvBNReLU(out_channels, out_channels)
            )

    def forward(self, X, E):
        print('concatenate size x e', X.size(), E.size())
        concatenated=torch.cat([E, self.upsample(X)], dim=1)
        output=self.conv(concatenated)
        return output

class EdgeUNet(torch.nn.Module):
    def __init__(self, in_channels=1, out_channels=2, interpolate=True):
        super().__init__()
        self.image_encode_block1=ImageEncodeBlock1(in_channels, 64)
        self.image_encode_block2=ImageEncodeBlock1(64, 128)
        self.image_encode_block3=ImageEncodeBlock2(128, 256)
        self.image_encode_block4=ImageEncodeBlock2(256, 512)
        self.image_encode_block5=ImageEncodeBlock3(512, 1024)
        self.maxpool=torch.nn.MaxPool2d(2, 2)

        self.edge_encode_block1=EdgeEncodeBlock(in_channels, 64)
        self.edge_encode_block2=EdgeEncodeBlock(64, 128)
        self.edge_encode_block3=EdgeEncodeBlock(128, 256)
        self.edge_encode_block4=EdgeEncodeBlock(256, 512)
        self.avgpool=torch.nn.AvgPool2d(2, 2)

        self.edge_guidance_block1=EdgeGuidanceBlock(64, 64)
        self.edge_guidance_block2=EdgeGuidanceBlock(128, 128)
        self.edge_guidance_block3=EdgeGuidanceBlock(256, 256)
        self.edge_guidance_block4=EdgeGuidanceBlock(512, 512)

        self.decode_block1=DecodeBlock(1024, 512, interpolate=interpolate)
        self.decode_block2=DecodeBlock(512, 256, interpolate=interpolate)
        self.decode_block3=DecodeBlock(256, 128, interpolate=interpolate)
        self.decode_block4=DecodeBlock(128, 64, interpolate=interpolate)

        self.out_conv=torch.nn.Conv2d(in_channels=64, out_channels=out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, image, edge):
        image1=self.image_encode_block1(image)
        downimg1=self.maxpool(image1)
        image2=self.image_encode_block2(downimg1)
        downimg2=self.maxpool(image2)
        image3=self.image_encode_block3(downimg2)
        downimg3=self.maxpool(image3)
        image4=self.image_encode_block4(downimg3)
        downimg4=self.maxpool(image4)
        image5=self.image_encode_block5(downimg4)

        edge1=self.edge_encode_block1(edge)
        downedg1=self.avgpool(edge1)
        edge2=self.edge_encode_block2(downedg1)
        downedg2=self.avgpool(edge2)
        edge3=self.edge_encode_block3(downedg2)
        downedg3=self.avgpool(edge3)
        edge4=self.edge_encode_block4(downedg3)

        EGB1=self.edge_guidance_block1(image1, edge1)
        EGB2=self.edge_guidance_block2(image2, edge2)
        EGB3=self.edge_guidance_block3(image3, edge3)
        EGB4=self.edge_guidance_block4(image4, edge4)

        decode1=self.decode_block1(image5, EGB4)
        decode2=self.decode_block2(decode1, EGB3)
        decode3=self.decode_block3(decode2, EGB2)
        decode4=self.decode_block4(decode3, EGB1)
        
        output=self.out_conv(decode4)
        return torch.sigmoid(output)

In [9]:
# Unet network definition
import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)

class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=True):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return torch.sigmoid(logits)


In [10]:
# training process definition of unet
import torch
from torchvision import transforms
from torch.utils.data import DataLoader

def train(model: UNet, 
          device: torch.device, 
          batch_size: int, 
          train_loader: DataLoader, 
          optimizer: torch.optim.Optimizer, 
          criterion: torch.nn.Module, 
          epoch: int):
    '''
    one epoch training process, containing: forwarding, calculating loss value, back propagation, printing some of the training progress
    '''
    running_loss=0.
    for batch_idx, data in enumerate(train_loader, 0):
        inputs, targets, labels, edge=data
        inputs, targets, labels=inputs.to(device), targets.to(device), labels.to(device)
        optimizer.zero_grad() # set gradient of last epoch to zero
        outputs=model(inputs)
        loss=criterion(outputs, targets)
        loss.backward() # backward the gradient
        optimizer.step() # update parameters
        running_loss+=loss.item() # sum of total loss
        if batch_idx % 300 == 299:
            print('第{}轮次已训练{}批样本, 本批次平均loss值: {}'.format(epoch+1, batch_idx+1, running_loss/300))
            running_loss=0.
    return

def eval(model: UNet, 
         device: torch.device, 
         test_loader: DataLoader,
         criterion: torch.nn.Module):
    '''
    testing the accuracy of current partly-trained model and print
    '''
    correct=0
    total=0
    total_loss=0
    with torch.no_grad():
        for data in test_loader:
            images, targets, labels, edge=data
            images, targets, labels=images.to(device), targets.to(device), labels.to(device)
            # outputs1, outputs2=model(images)
            outputs2=model(images)
            # _, predicted=torch.max(outputs1.data, dim=1)
            total+=labels.size(0)
            # correct+=(predicted-labels<0.5).sum().item()
            total_loss+=criterion(outputs2, targets)
    # print('accuracy on test set: {}%\naverage dice score: {}'.format(100*correct/total, total_dice/total))
    print(f'current epoch average cross entropy loss: {total_loss/total}')
    image=np.array(images.cpu())
    target=np.array(targets.cpu())
    output2=np.array(outputs2.cpu())
    output2[output2<=0.5]=0
    output2[output2>0.5]=1
    brain_MRI=image[0, 0].copy()
    groun_truth=target[0].copy()
    predict=output2[0, 1].copy()
    print('Dice score', my_dice_score(predict, target))
    plt.figure(figsize=(40, 40))
    plt.subplot(131)
    plt.imshow(brain_MRI, 'gray')
    plt.title('original')
    plt.subplot(132)
    plt.imshow(groun_truth, 'gray')
    plt.title('target')
    plt.subplot(133)
    plt.imshow(predict, 'gray')
    plt.title('predict')
    plt.show()
    return total_loss/total

In [None]:
# Unet training process
batch_size=1
epochs=60
model=UNet(1, 2, False)
device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(torch.cuda.is_available())
model.to(device)
transform=transforms.Compose([transforms.ToTensor()])
train_dataset=TumorDataset(dataset_dir='./dataset/training/', train=True, transform=transform)
train_loader=DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
test_dataset=TumorDataset(dataset_dir='./dataset/testing/', train=False, transform=transform)
test_loader=DataLoader(test_dataset, shuffle=False, batch_size=batch_size)
criterion=torch.nn.CrossEntropyLoss(weight=torch.tensor([0.1, 0.9], device=device))
optimizer=torch.optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.999), weight_decay=1e-8)
scheduler=torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=2)

for epoch in range(epochs):
    train(model, device, batch_size, train_loader, optimizer, criterion, epoch)
    eval_loss=eval(model, device, test_loader, criterion)
    scheduler.step(eval_loss)

In [11]:
# training process definition of edge unet
import torch
from torchvision import transforms
from torch.utils.data import DataLoader

def train(model: EdgeUNet, 
          device: torch.device, 
          batch_size: int, 
          train_loader: DataLoader, 
          optimizer: torch.optim.Optimizer, 
          criterion: torch.nn.Module, 
          epoch: int):
    '''
    one epoch training process, containing: forwarding, calculating loss value, back propagation, printing some of the training progress
    '''
    running_loss=0.
    for batch_idx, data in enumerate(train_loader, 0):
        images, targets, labels, edges=data
        edges=edges.reshape((-1, 1, 512, 512))
        images, targets, labels, edges=images.to(device), targets.to(device), labels.to(device), edges.to(device)
        optimizer.zero_grad() # set gradient of last epoch to zero
        outputs=model(images, edges)
        loss=criterion(outputs, targets)
        loss.backward() # backward the gradient
        optimizer.step() # update parameters
        running_loss+=loss.item() # sum of total loss
        if batch_idx % 300 == 299:
            print('第{}轮次已训练{}批样本, 本批次平均loss值: {}'.format(epoch+1, batch_idx+1, running_loss/300))
            running_loss=0.
    return

def eval(model: EdgeUNet, 
         device: torch.device, 
         test_loader: DataLoader,
         criterion: torch.nn.Module):
    '''
    testing the accuracy of current partly-trained model and print
    '''
    correct=0
    total=0
    total_loss=0
    with torch.no_grad():
        for data in test_loader:
            images, targets, labels, edges=data
            edges=edges.reshape((-1, 1, 512, 512))
            images, targets, labels, edges=images.to(device), targets.to(device), labels.to(device), edges.to(device)
            # outputs1, outputs2=model(images)
            outputs2=model(images, edges)
            # _, predicted=torch.max(outputs1.data, dim=1)
            total+=labels.size(0)
            # correct+=(predicted-labels<0.5).sum().item()
            total_loss+=criterion(outputs2, targets)
    # print('accuracy on test set: {}%\naverage dice score: {}'.format(100*correct/total, total_dice/total))
    print(f'current epoch average cross entropy loss: {total_loss/total}')
    image=np.array(images.cpu())
    target=np.array(targets.cpu())
    output2=np.array(outputs2.cpu())
    output2[output2<=0.5]=0
    output2[output2>0.5]=1
    brain_MRI=image[0, 0].copy()
    groun_truth=target[0].copy()
    predict=output2[0, 1].copy()
    print('Dice score', my_dice_score(predict, target))
    plt.figure(figsize=(40, 40))
    plt.subplot(131)
    plt.imshow(brain_MRI, 'gray')
    plt.title('original')
    plt.subplot(132)
    plt.imshow(groun_truth, 'gray')
    plt.title('target')
    plt.subplot(133)
    plt.imshow(predict, 'gray')
    plt.title('predict')
    plt.show()
    return total_loss/total

In [42]:
# Edge unet training process
batch_size=1
epochs=3
model=EdgeUNet(1, 2, False)
device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(torch.cuda.is_available())
model.to(device)
transform=transforms.Compose([transforms.ToTensor()])
train_dataset=TumorDataset(dataset_dir='./dataset/training/', train=True, transform=transform)
train_loader=DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
test_dataset=TumorDataset(dataset_dir='./dataset/testing/', train=False, transform=transform)
test_loader=DataLoader(test_dataset, shuffle=False, batch_size=batch_size)
criterion=torch.nn.CrossEntropyLoss(weight=torch.tensor([0.1, 0.9], device=device))
optimizer=torch.optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.999), weight_decay=1e-8)
scheduler=torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=2)

for epoch in range(epochs):
    train(model, device, batch_size, train_loader, optimizer, criterion, epoch)
    eval_loss=eval(model, device, test_loader, criterion)
    scheduler.step(eval_loss)

False
input size torch.Size([1, 1, 512, 512]) torch.Size([1, 1, 512, 512])
image1 size torch.Size([1, 64, 512, 512])
image2 size torch.Size([1, 128, 256, 256])
image3 size torch.Size([1, 256, 128, 128])
image4 size torch.Size([1, 512, 64, 64])
image5 size torch.Size([1, 1024, 32, 32])
edge1 size torch.Size([1, 64, 512, 512])
edge2 size torch.Size([1, 128, 256, 256])
edge3 size torch.Size([1, 256, 128, 128])
edge4 size torch.Size([1, 512, 64, 64])
egb1 size torch.Size([1, 64, 512, 512])
egb2 size torch.Size([1, 128, 256, 256])
egb3 size torch.Size([1, 256, 128, 128])
egb4 size torch.Size([1, 512, 64, 64])
concatenate size x e torch.Size([1, 1024, 32, 32]) torch.Size([1, 512, 64, 64])
decode1 size torch.Size([1, 512, 64, 64])
concatenate size x e torch.Size([1, 512, 64, 64]) torch.Size([1, 256, 128, 128])
decode2 size torch.Size([1, 256, 128, 128])
concatenate size x e torch.Size([1, 256, 128, 128]) torch.Size([1, 128, 256, 256])
decode3 size torch.Size([1, 128, 256, 256])
concatenate si

KeyboardInterrupt: 