In [None]:
import torch

# Environment sanity check: library version and GPU availability.
print(torch.__version__)
print(torch.cuda.is_available())
ngpu = 1
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")
print(device)
# get_device_name(0) raises when no CUDA device exists, so only query it on GPU.
if device.type == "cuda":
    print(torch.cuda.get_device_name(0))

In [None]:
import torch.nn as nn
from complexPyTorch.complexLayers import ComplexConv2d, ComplexMaxPool2d, ComplexReLU
from complexPyTorch.complexFunctions import complex_relu, complex_max_pool2d, complex_upsample

# class ComplexUpsample(Module):

#     def __init__(self,kernel_size, stride= None, padding = 0,
#                  dilation = 1, return_indices = False, ceil_mode = False):
#         super(ComplexMaxPool2d,self).__init__()
#         self.kernel_size = kernel_size
#         self.stride = stride
#         self.padding = padding
#         self.dilation = dilation
#         self.ceil_mode = ceil_mode
#         self.return_indices = return_indices

#     def forward(self,input):
#         return complex_max_pool2d(input,kernel_size = self.kernel_size,
#                                 stride = self.stride, padding = self.padding,
#                                 dilation = self.dilation, ceil_mode = self.ceil_mode,
#                                 return_indices = self.return_indices)


def double_conv(in_channels, out_channels):
    """Build two stacked 3x3 complex convolutions, each followed by ComplexReLU.

    The first convolution maps ``in_channels`` -> ``out_channels`` and the second
    keeps ``out_channels``. Both use padding=1, which with a 3x3 kernel preserves
    height and width (assuming the layer's default stride of 1).
    """
    layers = []
    for conv_in in (in_channels, out_channels):
        layers.append(ComplexConv2d(conv_in, out_channels, 3, padding=1))
        layers.append(ComplexReLU())
    return nn.Sequential(*layers)

def coe_to_spatial(img):
    """Map Fourier coefficients back to a spatial magnitude image normalised to [0, 1].

    Applies an inverse 2-D FFT over the last two dimensions, takes the magnitude,
    and min-max normalises over the whole tensor (all channels together).

    Args:
        img: real or complex tensor of Fourier coefficients; the inverse FFT runs
            over its last two dims (presumably (..., H, W) -- confirm with caller).

    Returns:
        Real tensor with the same shape as ``img``. Values are in [0, 1]. A constant
        input yields all zeros instead of the NaNs a 0/0 division would give.
    """
    iimage = torch.abs(torch.fft.ifft2(img))
    if iimage.numel() == 0:
        # min()/max() are undefined on empty tensors; nothing to normalise.
        return iimage

    # Out-of-place ops so tensors that are part of an autograd graph are not
    # modified in place.
    iimage = iimage - torch.min(iimage)
    peak = torch.max(iimage)
    if peak > 0:
        iimage = iimage / peak
    return iimage

## main model

In [None]:
class UNet(nn.Module):
    """Three-level complex-valued U-Net built from complexPyTorch layers.

    Encoder: double_conv blocks with 32/64/128 channels separated by 2x2 complex
    max-pooling, then a 256-channel bottleneck. Decoder: bilinear complex
    upsampling by 2, concatenation with the matching encoder feature map (skip
    connection) along dim 1, then a double_conv. A final 1x1 complex conv maps
    32 channels to ``n_class`` output channels.

    Input height and width must be divisible by 8 (three pooling stages);
    otherwise the skip-connection concatenations fail on mismatched sizes.

    Args:
        n_class: number of output channels.
        in_channels: number of input channels (defaults to 3, the original value).
    """

    def __init__(self, n_class, in_channels=3):
        super().__init__()
        # Encoder / contracting path; maxpool halves H and W between levels.
        self.dconv_down1 = double_conv(in_channels, 32)
        self.dconv_down2 = double_conv(32, 64)
        self.dconv_down3 = double_conv(64, 128)

        # Bottleneck.
        self.dconv_down4 = double_conv(128, 256)

        # Decoder / expanding path; input channels = upsampled + skip channels.
        self.dconv_up3 = double_conv(128 + 256, 128)
        self.dconv_up2 = double_conv(64 + 128, 64)
        self.dconv_up1 = double_conv(32 + 64, 32)

        self.conv_last = ComplexConv2d(32, n_class, 1)
        # NOTE(review): conv_llast is never used in forward(). It is kept so the
        # state_dict keys stay identical and existing checkpoints still load.
        self.conv_llast = ComplexConv2d(3, 1, 1)
        # maxpool is used for downsampling; upsampling is done functionally
        # with complex_upsample in forward().
        self.maxpool = ComplexMaxPool2d(2)

    @staticmethod
    def _up_and_merge(x, skip):
        """Upsample ``x`` by 2 (bilinear, align_corners=True) and concat ``skip`` on dim 1."""
        x = complex_upsample(x, scale_factor=2, mode='bilinear', align_corners=True)
        return torch.cat([x, skip], dim=1)

    def forward(self, x):
        """Run the U-Net; the output has ``n_class`` channels at the input resolution."""
        conv1 = self.dconv_down1(x)
        x = self.maxpool(conv1)

        conv2 = self.dconv_down2(x)
        x = self.maxpool(conv2)

        conv3 = self.dconv_down3(x)
        x = self.maxpool(conv3)

        x = self.dconv_down4(x)

        x = self.dconv_up3(self._up_and_merge(x, conv3))
        x = self.dconv_up2(self._up_and_merge(x, conv2))
        x = self.dconv_up1(self._up_and_merge(x, conv1))

        return self.conv_last(x)

def weights_init(m):
    """Initialise Linear layers: weights ~ U(0, 1), bias = 0.

    Intended for ``model.apply(weights_init)``. Modules whose class name does not
    contain 'Linear' are left untouched. For the convolution-only UNet in this
    notebook this appears to be a no-op -- confirm if Linear layers are added.

    Args:
        m: a module visited by ``nn.Module.apply``.
    """
    classname = m.__class__.__name__
    if 'Linear' not in classname:
        return
    with torch.no_grad():
        # Wrapper modules whose name contains 'Linear' may hold no weight of their
        # own (e.g. a pair of inner nn.Linear); apply() visits those children.
        if getattr(m, 'weight', None) is not None:
            m.weight.uniform_(0.0, 1.0)
        # nn.Linear(..., bias=False) stores bias as None.
        if getattr(m, 'bias', None) is not None:
            m.bias.fill_(0)

<p>繪製 model (visualize the computation graph with torchviz):</p>

<pre><code>from torchviz import make_dot
x = torch.randn(1, 3, 256, 256).requires_grad_(True).cuda()
y = model(x)
vis_graph = make_dot(y, params=dict(list(model.named_parameters()) + [('x', x)]))
vis_graph.view()
</code></pre>

<p>輸出 Model parameters (print every parameter tensor):</p>
<pre><code> 
model = UNet(n_class=3).cuda()
for name, param in model.named_parameters():
     print(param)
</code></pre>

## training loop

In [None]:
from collections import defaultdict
import torch.nn.functional as F
from tqdm.notebook import tqdm
from tensorboardX import SummaryWriter

def mse_loss(output, target):
    """Mean squared error that is valid for real and complex tensors.

    Uses ``|output - target| ** 2`` so the loss is always a real scalar. The plain
    ``(output - target) ** 2`` gives a complex mean for complex inputs. Autograd
    cannot call ``backward()`` on that, and it cannot be compared with ``<``.
    For real inputs the value is unchanged.
    """
    return torch.mean(torch.abs(output - target) ** 2)

def train_model(model, optimizer, scheduler, num_epochs=25, loaders=None):
    """Train ``model`` with ``mse_loss`` and return it with the best-validation weights.

    Each epoch does the following:
    - one optimisation pass over ``loaders['train']``;
    - one gradient-free pass over ``loaders['val']``, with images and labels cast
      to ``torch.cfloat``;
    - prints and logs the sample-weighted mean losses to TensorBoard under
      ./log/log-complex;
    - steps ``scheduler`` once, if given;
    - pickles the whole model to ./net/net-complex/net-<epoch>.pkl.

    Args:
        model: network to train (its parameters are expected on the same device
            the batches are moved to).
        optimizer: optimiser over ``model``'s parameters.
        scheduler: learning-rate scheduler stepped once per epoch, or None.
        num_epochs: number of epochs to run.
        loaders: dict with 'train' and 'val' iterables of (image, label) batches.
            Defaults to the notebook-level ``dataloaders`` for backward compatibility.

    Returns:
        ``model`` loaded with the state dict of the epoch with the lowest val loss.
    """
    # Local imports: the notebook imports these only in a later cell.
    import copy
    import os
    import time

    if loaders is None:
        loaders = dataloaders  # notebook global from the data-setup cell

    log_path = os.path.join(os.getcwd(), "log/log-complex")
    os.makedirs(log_path, exist_ok=True)
    net_path = os.path.join(os.getcwd(), "net/net-complex")
    os.makedirs(net_path, exist_ok=True)

    writer = SummaryWriter(log_path)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = 1e10

    try:
        for epoch in range(num_epochs):
            print('Epoch {}/{}'.format(epoch, num_epochs - 1))
            print('-' * 10)
            for param_group in optimizer.param_groups:
                print('LR', param_group['lr'])
            since = time.time()

            # ---- train ----
            model.train()
            with torch.set_grad_enabled(True):
                train_epoch_loss = _run_epoch(model, loaders['train'], device, optimizer)
            print('training: total loss: {:.6f}'.format(train_epoch_loss))
            writer.add_scalar('Training/Loss', train_epoch_loss, epoch)

            # ---- validate ----
            model.eval()
            with torch.no_grad():
                epoch_loss = _run_epoch(model, loaders['val'], device)
            print('validation: total loss: {:.6f}'.format(epoch_loss))
            writer.add_scalar('Validation/Loss', epoch_loss, epoch)

            # Step the LR schedule once per epoch, after the optimiser steps.
            # Before this fix the scheduler was never stepped, so the LR never changed.
            if scheduler is not None:
                scheduler.step()

            torch.save(model, net_path + '/net-' + str(epoch + 1) + '.pkl')
            if epoch_loss < best_loss:
                print("saving best model")
                best_loss = epoch_loss
                best_model_wts = copy.deepcopy(model.state_dict())

            time_elapsed = time.time() - since
            print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    finally:
        writer.close()

    print('Best val loss: {:.4f}'.format(best_loss))

    # Load the best model weights.
    model.load_state_dict(best_model_wts)
    return model


def _run_epoch(model, loader, device, optimizer=None):
    """Make one pass over ``loader`` and return the sample-weighted mean ``mse_loss``.

    When ``optimizer`` is given, every batch is back-propagated and an optimiser
    step is taken. Otherwise the pass is evaluation-only, and the caller decides
    whether gradients are enabled. An empty loader returns NaN instead of dividing
    by zero, so it is never selected as a "best" epoch.
    """
    total_loss = 0.0
    n_samples = 0
    for image, label in tqdm(loader):
        images = image.to(device).to(torch.cfloat)
        labels = label.to(device).to(torch.cfloat)

        if optimizer is not None:
            optimizer.zero_grad()
        outputs = model(images)
        loss = mse_loss(outputs, labels)
        if optimizer is not None:
            loss.backward()
            optimizer.step()

        batch_size = images.size(0)
        total_loss += loss.item() * batch_size
        n_samples += batch_size
    return total_loss / n_samples if n_samples else float('nan')

In [None]:
import torch.optim as optim
from torch.optim import lr_scheduler
import time
import copy
from dataloader import get_train_augmentation, get_test_augmentation, get_loader
import torchvision
from torchvision import transforms, datasets, models
import os

# DUTS-TR training images and saliency masks (absolute local paths -- adjust per machine).
DUTS_TR_ROOT = r'C:\Users\user\pythonProject\mission87\data\DUTS\DUTS-TR'
tr_img_folder = DUTS_TR_ROOT + r'\DUTS-TR-Image'
tr_gt_folder = DUTS_TR_ROOT + r'\DUTS-TR-Mask'

# `ver` selects the representation built by the augmentation pipeline
# (the data-preview cell labels 2 as 'Fourier coefficients').
ver = 2
train_transform = get_train_augmentation(img_size=224, ver=ver)
test_transform = get_test_augmentation(img_size=224)

# Settings shared by both splits.
loader_kwargs = dict(batch_size=2, num_workers=0, transform=train_transform)
# NOTE(review): the 'val' split also reads the training folders and uses
# train_transform (test_transform is built but unused); presumably get_loader
# splits by `phase` -- confirm.
dataloaders = {
    'train': get_loader(tr_img_folder, tr_gt_folder, phase='train', shuffle=True, **loader_kwargs),
    'val': get_loader(tr_img_folder, tr_gt_folder, phase='val', shuffle=False, **loader_kwargs),
}
train_loader = dataloaders['train']
val_loader = dataloaders['val']

model = UNet(n_class=3).cuda()
model.apply(weights_init)
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer_ft = optim.Adam(trainable_params, lr=1e-3)

# Multiply the learning rate by 0.1 every 5 scheduler steps.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=5, gamma=0.1)

## Model smoke test (complex input)

In [None]:
# Forward + backward pass on real Gaussian noise cast to complex64, sized like a training batch.
test = torch.randn([1, 3, 224, 224]).to(torch.cfloat).cuda()

pred = model(test)
loss = mse_loss(test, pred)
print(loss)
# Gradients accumulate into the model parameters here; train_model calls
# optimizer.zero_grad() before every optimisation step.
loss.backward()

## Check training/validation data

In [None]:
import matplotlib.pyplot as plt
image, mask = next(iter(train_loader))

# Human-readable names for the `ver` representation codes.
s = {1: 'Nothing change', 2: 'Fourier coefficients', 3: 'Fourier amplitude', 4: 'Fourier phase'}
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
fig.suptitle(s[ver], verticalalignment='bottom')

# First sample of the batch, mapped back to the spatial domain, channels last.
panels = (
    (image, 'image', {}),
    (mask, 'mask', {'cmap': 'gray'}),
)
for ax, (batch, title, imshow_opts) in zip(axes, panels):
    ax.imshow(coe_to_spatial(batch[0].to(torch.cfloat)).permute(1, 2, 0), **imshow_opts)
    ax.set_title(title)

fig.tight_layout()
fig.subplots_adjust(top=0.99)
plt.show()

In [None]:
def mse_loss(output, target):
    """Mean squared error that is valid for real and complex tensors.

    Kept identical to the ``mse_loss`` in the training-loop cell. Uses
    ``|output - target| ** 2`` so the loss is a real scalar even for complex
    inputs, which ``backward()`` and ``<`` comparisons require. For real inputs
    the value equals the plain squared error.
    """
    loss = torch.mean(torch.abs(output - target) ** 2)
    return loss

# Check the loss on the previewed batch: the image goes to complex64 on the GPU,
# and the mask is only moved to the GPU (its dtype is left as loaded).
image, mask = image.to(torch.cfloat).cuda(), mask.cuda()
pred = model(image)
mse_loss(pred, mask)

# Start Training

In [None]:
NUM_EPOCHS = 50  # length of the full training run
model = train_model(model=model, optimizer=optimizer_ft, scheduler=exp_lr_scheduler, num_epochs=NUM_EPOCHS)

# Eval

In [None]:
model.eval()
image, mask = image.cuda(), mask.cuda()
# Inference only: no_grad avoids building an autograd graph (and keeping its
# activations in GPU memory) for this forward pass.
with torch.no_grad():
    pred = model(image)
# Show a compact summary instead of dumping the whole prediction tensor.
pred.shape, pred.dtype

In [None]:
# Input image back in the spatial domain. An fftshift followed by ifftshift over
# all dims is the identity, so that round trip is dropped. coe_to_spatial performs
# the inverse FFT, the magnitude and the [0, 1] min-max normalisation.
plt.imshow(coe_to_spatial(image[0]).permute(1, 2, 0).cpu().detach().numpy())
plt.title('input image (spatial domain)');

In [None]:
# Ground-truth mask back in the spatial domain. An fftshift followed by ifftshift
# over all dims is the identity, so that round trip is dropped. coe_to_spatial
# performs the inverse FFT, the magnitude and the [0, 1] min-max normalisation.
plt.imshow(coe_to_spatial(mask[0]).permute(1, 2, 0).cpu().detach().numpy())
plt.title('ground-truth mask (spatial domain)');

In [None]:
# Model prediction back in the spatial domain. An fftshift followed by ifftshift
# over all dims is the identity, so that round trip is dropped. coe_to_spatial
# performs the inverse FFT, the magnitude and the [0, 1] min-max normalisation.
plt.imshow(coe_to_spatial(pred[0]).permute(1, 2, 0).cpu().detach().numpy())
plt.title('prediction (spatial domain)');