In [None]:
# Colab-only setup: mount the user's Google Drive at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# ===== U-Net / generator architecture definitions =====

In [16]:
import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F

In [17]:
def crop(image, new_shape):
    '''
    Center-crop an image tensor to the spatial size taken from another shape.

    Only entries 2 (height) and 3 (width) of new_shape are read; the first two
    dimensions of image are passed through unchanged.
    Parameters:
        image: image tensor indexed as (batch size, channels, height, width)
        new_shape: a shape (e.g. torch.Size) whose entries 2 and 3 give the
            height and width of the window to keep
    Returns:
        the window of image with the requested height and width, positioned
        around the center of the image
    '''
    target_h, target_w = new_shape[2], new_shape[3]
    # Top-left corner of the window, measured from the image's center pixel.
    top = image.shape[2] // 2 - target_h // 2
    left = image.shape[3] // 2 - target_w // 2
    return image[:, :, top:top + target_h, left:left + target_w]

In [18]:
class Flatten(nn.Module):
    '''Reshape a tensor to (batch, -1), merging every dimension after the first.'''

    def forward(self, x):
        batch = x.shape[0]
        return x.view(batch, -1)

In [19]:
def logsumexp_2d(tensor):
    '''
    Log-sum-exp over all positions after the first two dimensions.

    Parameters:
        tensor: tensor whose first two dimensions are kept (batch, channels);
            every remaining dimension is flattened and reduced
    Returns:
        tensor of shape (batch, channels, 1)
    '''
    batch, channels = tensor.size(0), tensor.size(1)
    flat = tensor.view(batch, channels, -1)
    # The per-(batch, channel) maximum is subtracted before exponentiating and
    # added back afterwards, so exp() never sees a positive argument.
    peak = flat.max(dim=2, keepdim=True)[0]
    shifted_sum = torch.exp(flat - peak).sum(dim=2, keepdim=True)
    return peak + torch.log(shifted_sum)

In [20]:
class ChannelGate(nn.Module):
    '''
    Channel-attention gate: scales every channel of the input by a weight in (0, 1).

    For each configured pooling type the input is pooled over its full spatial
    extent, the pooled vector is passed through a shared two-layer MLP, the MLP
    outputs are summed over pooling types, and a sigmoid of that sum is
    broadcast over height and width and multiplied into the input.
    Parameters:
        gate_channels: number of channels of the tensors passed to forward
        reduction_ratio: the MLP's hidden width is gate_channels // reduction_ratio
        pool_types: non-empty iterable drawn from 'avg', 'max', 'lp', 'lse'
    Raises:
        ValueError: if pool_types is empty or contains an unsupported name
    '''
    _SUPPORTED_POOLS = ('avg', 'max', 'lp', 'lse')

    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']):
        super(ChannelGate, self).__init__()
        self.gate_channels = gate_channels
        self.mlp = nn.Sequential(
            # Built-in flatten: (batch, C, 1, 1) or (batch, C, 1) -> (batch, C).
            nn.Flatten(),
            nn.Linear(gate_channels, gate_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(gate_channels // reduction_ratio, gate_channels)
            )
        # Copy so the shared default list (or a caller's list) is never aliased.
        self.pool_types = list(pool_types)
        unsupported = [p for p in self.pool_types if p not in self._SUPPORTED_POOLS]
        if not self.pool_types or unsupported:
            raise ValueError(
                "pool_types must be a non-empty subset of %s, got %r"
                % (self._SUPPORTED_POOLS, pool_types))

    def forward(self, x):
        '''
        Parameters:
            x: tensor of shape (batch, gate_channels, height, width)
        Returns:
            tensor of the same shape as x, re-weighted per channel
        '''
        # Pooling window equal to the whole feature map -> one value per channel.
        spatial = (x.size(2), x.size(3))
        channel_att_sum = None
        for pool_type in self.pool_types:
            if pool_type == 'avg':
                pooled = F.avg_pool2d(x, spatial, stride=spatial)
            elif pool_type == 'max':
                pooled = F.max_pool2d(x, spatial, stride=spatial)
            elif pool_type == 'lp':
                pooled = F.lp_pool2d(x, 2, spatial, stride=spatial)
            elif pool_type == 'lse':
                pooled = logsumexp_2d(x)
            else:
                # Guards against pool_types being modified after construction;
                # without it the previous branch's result would be added again.
                raise ValueError("unsupported pool type: %r" % (pool_type,))

            channel_att_raw = self.mlp(pooled)
            if channel_att_sum is None:
                channel_att_sum = channel_att_raw
            else:
                channel_att_sum = channel_att_sum + channel_att_raw

        # torch.sigmoid is used instead of F.sigmoid, which warns on older PyTorch.
        scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
        return x * scale

In [21]:
class ChannelPool(nn.Module):
    '''
    Collapse the channel dimension (dim 1) into two maps: the per-position
    maximum over channels followed by the per-position mean over channels.
    '''

    def forward(self, x):
        channel_max = torch.max(x, 1, keepdim=True)[0]
        channel_mean = torch.mean(x, 1, keepdim=True)
        return torch.cat((channel_max, channel_mean), dim=1)

In [22]:
class SpatialGate(nn.Module):
    '''
    Spatial-attention gate: multiplies the input by a single-channel sigmoid
    mask computed from the channel-wise max and mean maps of the input.
    '''

    def __init__(self):
        super().__init__()
        kernel_size = 7
        self.compress = ChannelPool()
        # Dilated 7x7 convolution (used in place of an earlier plain BasicConv
        # variant). With dilation 2 the kernel spans 13 positions, so a padding
        # of kernel_size - 1 = 6 keeps the output the same size as the input.
        self.spatial = nn.Conv2d(2, 1, kernel_size, stride=1, padding=kernel_size - 1, dilation=2)

    def forward(self, x):
        pooled = self.compress(x)
        # The (batch, 1, H, W) mask broadcasts over the channel dimension of x.
        mask = F.sigmoid(self.spatial(pooled))
        return x * mask

In [23]:
class CBAM(nn.Module):
    '''
    Attention block that applies a ChannelGate and then, unless no_spatial is
    set, a SpatialGate to its input.
    Parameters:
        gate_channels, reduction_ratio, pool_types: forwarded to ChannelGate
        no_spatial: when True the SpatialGate is neither created nor applied
    '''

    def __init__(self,gate_channels,reduction_ratio=16,pool_types=['avg','max'], no_spatial=False):
        super().__init__()
        self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
        self.no_spatial = no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGate()

    def forward(self, x):
        refined = self.ChannelGate(x)
        if self.no_spatial:
            return refined
        return self.SpatialGate(refined)

In [24]:
class FeatureMapBlock(nn.Module):
    '''
    1x1 convolution mapping input_channels to output_channels; the spatial size
    of the input is unchanged.

    NOTE(review): a second class named FeatureMapBlock (7x7, reflect-padded) is
    defined later in this file. Whichever definition was executed last is the
    one picked up when the generators are instantiated - confirm which is intended.
    '''

    def __init__(self, input_channels, output_channels):
        super().__init__()
        self.conv = nn.Conv2d(input_channels, output_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)

In [25]:
class my_contracting_block(nn.Module):
    '''
    Encoder stage: two 3x3 convolutions (the first doubles the channel count)
    followed by a 2x2 max-pool with stride 2.

    After each convolution the optional batch norm and dropout are applied,
    then LeakyReLU(0.2). The optional CBAM attention runs once, directly after
    the first convolution.
    Parameters:
        input_channels: channels of the incoming tensor; output has twice as many
        use_dropout: apply nn.Dropout() after each (optional) batch norm
        use_bn: apply batch norm after each convolution
        use_CBAM: apply the CBAM block after the first convolution
    '''

    def __init__(self, input_channels, use_dropout=False, use_bn=True,use_CBAM=False):
        super().__init__()
        doubled = input_channels * 2
        self.conv1 = nn.Conv2d(input_channels, doubled, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(doubled, doubled, kernel_size=3, padding=1)
        self.activation = nn.LeakyReLU(0.2)
        # Created regardless of use_CBAM, so its parameters are always part of
        # the module even when the attention is not applied in forward.
        self.CBAM = CBAM(doubled)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        if use_bn:
            # A single BatchNorm2d instance serves both convolutions.
            self.batchnorm = nn.BatchNorm2d(doubled)
        self.use_bn = use_bn
        if use_dropout:
            self.dropout = nn.Dropout()
        self.use_dropout = use_dropout
        self.use_CBAM = use_CBAM

    def _norm_drop_act(self, x):
        # Shared tail of both convolution stages: (batch norm) -> (dropout) -> activation.
        if self.use_bn:
            x = self.batchnorm(x)
        if self.use_dropout:
            x = self.dropout(x)
        return self.activation(x)

    def forward(self, x):
        out = self.conv1(x)
        if self.use_CBAM:
            out = self.CBAM(out)
        out = self._norm_drop_act(out)
        out = self._norm_drop_act(self.conv2(out))
        return self.maxpool(out)

In [26]:
class std_contracting_block(nn.Module):
    '''
    Downsampling stage: one stride-2, reflect-padded convolution that doubles
    the channel count, then optional CBAM, optional batch norm and an activation.
    Parameters:
        input_channels: channels of the incoming tensor; output has twice as many
        use_bn: apply the norm layer after the convolution
        kernel_size: kernel size of the convolution (padding is always 1)
        activation: 'relu' selects nn.ReLU, anything else nn.LeakyReLU(0.2)
        use_CBAM: apply the CBAM block right after the convolution
    '''

    def __init__(self, input_channels, use_bn=True, kernel_size=3, activation='relu',use_CBAM = False):
        super().__init__()
        out_channels = input_channels * 2
        self.conv1 = nn.Conv2d(
            input_channels, out_channels,
            kernel_size=kernel_size, padding=1, stride=2, padding_mode='reflect',
        )
        if activation == 'relu':
            self.activation = nn.ReLU()
        else:
            self.activation = nn.LeakyReLU(0.2)
        # Created regardless of use_CBAM; only applied when the flag is set.
        self.CBAM = CBAM(out_channels)
        if use_bn:
            # The attribute is called instancenorm but holds a BatchNorm2d layer.
            self.instancenorm = nn.BatchNorm2d(out_channels)
        self.use_bn = use_bn
        self.use_CBAM = use_CBAM

    def forward(self, x):
        out = self.conv1(x)
        if self.use_CBAM:
            out = self.CBAM(out)
        if self.use_bn:
            out = self.instancenorm(out)
        return self.activation(out)

In [27]:
class std_expanding_block(nn.Module):
    '''
    Upsampling stage: one stride-2 transposed convolution that halves the
    channel count, then optional CBAM, optional batch norm and ReLU.
    Parameters:
        input_channels: channels of the incoming tensor; output has half as many
        use_bn: apply the norm layer after the transposed convolution
        use_CBAM: apply the CBAM block right after the transposed convolution
    '''

    def __init__(self, input_channels, use_bn=True, use_CBAM=False):
        super().__init__()
        halved = input_channels // 2
        self.conv1 = nn.ConvTranspose2d(
            input_channels, halved,
            kernel_size=3, stride=2, padding=1, output_padding=1,
        )
        if use_bn:
            # The attribute is called instancenorm but holds a BatchNorm2d layer.
            self.instancenorm = nn.BatchNorm2d(halved)
        self.use_bn = use_bn
        # Created regardless of use_CBAM; only applied when the flag is set.
        self.CBAM = CBAM(halved)
        self.activation = nn.ReLU()
        self.use_CBAM = use_CBAM

    def forward(self, x):
        out = self.conv1(x)
        if self.use_CBAM:
            out = self.CBAM(out)
        if self.use_bn:
            out = self.instancenorm(out)
        return self.activation(out)

In [28]:
class my_expanding_block(nn.Module):
    '''
    Decoder stage with a skip connection.

    The input is bilinearly upsampled by 2 and passed through a 2x2 convolution
    that halves its channels; the skip tensor is center-cropped to the result's
    spatial size and concatenated on the channel dimension; two further
    convolutions (3x3, then 2x2 with padding 1) follow.
    Parameters:
        input_channels: channels of x; the output has half as many
        use_dropout: apply nn.Dropout() after each (optional) batch norm
        use_bn: apply batch norm after the second and third convolutions
        use_CBAM: apply the CBAM block before the first activation
    '''

    def __init__(self, input_channels, use_dropout=False, use_bn=True,use_CBAM=False):
        super().__init__()
        halved = input_channels // 2
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv1 = nn.Conv2d(input_channels, halved, kernel_size=2)
        # conv2 consumes the concatenation of conv1's output and the skip tensor.
        self.conv2 = nn.Conv2d(input_channels, halved, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(halved, halved, kernel_size=2, padding=1)
        if use_bn:
            # A single BatchNorm2d instance is applied after both conv2 and conv3.
            self.batchnorm = nn.BatchNorm2d(halved)
        self.use_bn = use_bn
        self.activation = nn.ReLU()
        if use_dropout:
            self.dropout = nn.Dropout()
        self.use_dropout = use_dropout
        self.use_CBAM = use_CBAM
        # Created regardless of use_CBAM; only applied when the flag is set.
        self.CBAM = CBAM(halved)

    def _norm_and_drop(self, x):
        # Optional batch norm followed by optional dropout.
        if self.use_bn:
            x = self.batchnorm(x)
        if self.use_dropout:
            x = self.dropout(x)
        return x

    def forward(self, x, skip_con_x):
        up = self.conv1(self.upsample(x))
        skip = crop(skip_con_x, up.shape)
        out = self.conv2(torch.cat([up, skip], dim=1))
        out = self._norm_and_drop(out)
        if self.use_CBAM:
            out = self.CBAM(out)
        out = self.activation(out)
        out = self._norm_and_drop(self.conv3(out))
        return self.activation(out)

In [29]:
class ResidualBlock(nn.Module):
    '''
    Residual block: two reflect-padded 3x3 convolutions whose result is added
    to the block's input. Channel count and spatial size are preserved.
    Parameters:
        input_channels: channels of the incoming (and outgoing) tensor
        use_CBAM: apply the CBAM block right after the first convolution
    '''

    def __init__(self, input_channels,use_CBAM=False):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1, padding_mode='reflect')
        self.conv1 = nn.Conv2d(input_channels, input_channels, **conv_kwargs)
        self.conv2 = nn.Conv2d(input_channels, input_channels, **conv_kwargs)
        # Named instancenorm but holds a BatchNorm2d; the same instance is
        # applied after both convolutions.
        self.instancenorm = nn.BatchNorm2d(input_channels)
        # Created regardless of use_CBAM; only applied when the flag is set.
        self.CBAM = CBAM(input_channels)
        self.use_CBAM = use_CBAM
        self.activation = nn.ReLU()

    def forward(self, x):
        shortcut = x.clone()
        out = self.conv1(x)
        if self.use_CBAM:
            out = self.CBAM(out)
        out = self.activation(self.instancenorm(out))
        out = self.instancenorm(self.conv2(out))
        return shortcut + out

In [30]:
class my_gen(nn.Module):
    '''
    U-Net style generator: a feature-map block, six contracting stages, six
    expanding stages with skip connections, a final feature-map block and a
    sigmoid.
    Parameters:
        input_channels: channels of the input tensor
        output_channels: channels of the generated tensor
        hidden_channels: channels produced by the first feature-map block
        use_CBAM: forwarded to every contracting stage and to expand0..expand4
    '''

    def __init__(self, input_channels, output_channels, hidden_channels=32, use_CBAM=False):
        super().__init__()
        h = hidden_channels
        attn = dict(use_CBAM=use_CBAM)

        self.upfeature = FeatureMapBlock(input_channels, h)
        # The first three contracting stages use dropout, the deeper ones do not.
        self.contract1 = my_contracting_block(h, use_dropout=True, **attn)
        self.contract2 = my_contracting_block(h * 2, use_dropout=True, **attn)
        self.contract3 = my_contracting_block(h * 4, use_dropout=True, **attn)
        self.contract4 = my_contracting_block(h * 8, **attn)
        self.contract5 = my_contracting_block(h * 16, **attn)
        self.contract6 = my_contracting_block(h * 32, **attn)
        self.expand0 = my_expanding_block(h * 64, **attn)
        self.expand1 = my_expanding_block(h * 32, **attn)
        self.expand2 = my_expanding_block(h * 16, **attn)
        self.expand3 = my_expanding_block(h * 8, **attn)
        self.expand4 = my_expanding_block(h * 4, **attn)
        # NOTE(review): expand5 is the only stage that does not receive use_CBAM -
        # confirm this is intentional.
        self.expand5 = my_expanding_block(h * 2)
        self.downfeature = FeatureMapBlock(h, output_channels)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        # Encoder: remember every activation so the decoder can reuse it as a skip.
        skips = [self.upfeature(x)]
        for down in (self.contract1, self.contract2, self.contract3,
                     self.contract4, self.contract5, self.contract6):
            skips.append(down(skips[-1]))

        # Decoder: start from the deepest activation and consume the skips in
        # reverse order (the output of contract5 first, the upfeature output last).
        out = skips.pop()
        for up in (self.expand0, self.expand1, self.expand2,
                   self.expand3, self.expand4, self.expand5):
            out = up(out, skips.pop())

        return self.sigmoid(self.downfeature(out))

In [31]:
class res_gen(nn.Module):
    '''
    ResNet style generator: a feature-map block, two stride-2 contracting
    stages, nine residual blocks, two transposed-convolution expanding stages,
    a final feature-map block and tanh.
    Parameters:
        input_channels: channels of the input tensor
        output_channels: channels of the generated tensor
        hidden_channels: channels produced by the first feature-map block
        use_CBAM: forwarded to the contracting and expanding stages
        res_CBAM: forwarded to every residual block
    '''

    def __init__(self, input_channels, output_channels, hidden_channels=64,use_CBAM=False,res_CBAM=False):
        super().__init__()
        h = hidden_channels
        self.upfeature = FeatureMapBlock(input_channels, h)
        self.contract1 = std_contracting_block(h, use_CBAM=use_CBAM)
        self.contract2 = std_contracting_block(h * 2, use_CBAM=use_CBAM)
        res_mult = 4
        res_channels = h * res_mult
        self.res0 = ResidualBlock(res_channels, use_CBAM=res_CBAM)
        self.res1 = ResidualBlock(res_channels, use_CBAM=res_CBAM)
        self.res2 = ResidualBlock(res_channels, use_CBAM=res_CBAM)
        self.res3 = ResidualBlock(res_channels, use_CBAM=res_CBAM)
        self.res4 = ResidualBlock(res_channels, use_CBAM=res_CBAM)
        self.res5 = ResidualBlock(res_channels, use_CBAM=res_CBAM)
        self.res6 = ResidualBlock(res_channels, use_CBAM=res_CBAM)
        self.res7 = ResidualBlock(res_channels, use_CBAM=res_CBAM)
        self.res8 = ResidualBlock(res_channels, use_CBAM=res_CBAM)
        self.expand2 = std_expanding_block(h * 4, use_CBAM=use_CBAM)
        self.expand3 = std_expanding_block(h * 2, use_CBAM=use_CBAM)
        self.downfeature = FeatureMapBlock(h, output_channels)
        self.tanh = torch.nn.Tanh()

    def forward(self, x):
        out = self.contract2(self.contract1(self.upfeature(x)))
        for res in (self.res0, self.res1, self.res2, self.res3, self.res4,
                    self.res5, self.res6, self.res7, self.res8):
            out = res(out)
        out = self.expand3(self.expand2(out))
        return self.tanh(self.downfeature(out))

In [32]:
class my_res_gen(nn.Module):
    '''
    ResNet style generator built from the U-Net style stages: a feature-map
    block, two my_contracting_block stages, nine residual blocks, two
    my_expanding_block stages with skip connections, a final feature-map block
    and tanh.
    Parameters:
        input_channels: channels of the input tensor
        output_channels: channels of the generated tensor
        hidden_channels: channels produced by the first feature-map block
        use_CBAM: forwarded to the contracting and expanding stages
        res_CBAM: forwarded to every residual block
    '''

    def __init__(self, input_channels, output_channels, hidden_channels=64,use_CBAM=False,res_CBAM=False):
        super(my_res_gen, self).__init__()
        self.upfeature = FeatureMapBlock(input_channels, hidden_channels)
        self.contract1 = my_contracting_block(hidden_channels,use_CBAM=use_CBAM)
        self.contract2 = my_contracting_block(hidden_channels * 2,use_CBAM=use_CBAM)
        res_mult = 4
        self.res0 = ResidualBlock(hidden_channels * res_mult,use_CBAM=res_CBAM)
        self.res1 = ResidualBlock(hidden_channels * res_mult,use_CBAM=res_CBAM)
        self.res2 = ResidualBlock(hidden_channels * res_mult,use_CBAM=res_CBAM)
        self.res3 = ResidualBlock(hidden_channels * res_mult,use_CBAM=res_CBAM)
        self.res4 = ResidualBlock(hidden_channels * res_mult,use_CBAM=res_CBAM)
        self.res5 = ResidualBlock(hidden_channels * res_mult,use_CBAM=res_CBAM)
        self.res6 = ResidualBlock(hidden_channels * res_mult,use_CBAM=res_CBAM)
        self.res7 = ResidualBlock(hidden_channels * res_mult,use_CBAM=res_CBAM)
        self.res8 = ResidualBlock(hidden_channels * res_mult,use_CBAM=res_CBAM)
        self.expand2 = my_expanding_block(hidden_channels * 4,use_CBAM=use_CBAM)
        self.expand3 = my_expanding_block(hidden_channels * 2,use_CBAM=use_CBAM)
        self.downfeature = FeatureMapBlock(hidden_channels, output_channels)
        self.tanh = torch.nn.Tanh()

    def forward(self, x):
        '''
        Parameters:
            x: input tensor with input_channels channels
        Returns:
            tensor with output_channels channels and values in (-1, 1)
        '''
        x0 = self.upfeature(x)
        x1 = self.contract1(x0)
        x2 = self.contract2(x1)

        out = x2
        for res in (self.res0, self.res1, self.res2, self.res3, self.res4,
                    self.res5, self.res6, self.res7, self.res8):
            out = res(out)

        # my_expanding_block.forward takes (x, skip_con_x) and concatenates the
        # skip tensor before its second convolution, so a skip is mandatory.
        # Each stage gets the encoder activation with the channel count its
        # conv2 expects: x1 (hidden * 2) for expand2, x0 (hidden) for expand3.
        out = self.expand2(out, x1)
        out = self.expand3(out, x0)
        return self.tanh(self.downfeature(out))

In [None]:
# ===== End of U-Net section: training utilities, dataset and discriminator follow =====

In [33]:
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import VOCSegmentation
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
# Seed PyTorch's global random number generator with a fixed value.
torch.manual_seed(0)
# Disabled imports of an external unet_archs module; the architectures are
# defined inline in the cells above instead.
#from unet1 import unet_archs
##import unet_archs

def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):
    '''
    Plot the first num_images images of a batch as a grid with 5 images per row.

    Values are mapped with (x + 1) / 2 before plotting, i.e. inputs in [-1, 1]
    end up in [0, 1].
    Parameters:
        image_tensor: tensor of images; it is detached, moved to the CPU and
            viewed as (-1, *size)
        num_images: maximum number of images to draw
        size: (channels, height, width) of a single image
            NOTE(review): the default is (1, 28, 28); callers with other image
            sizes must pass size explicitly.
    '''
    rescaled = (image_tensor + 1) / 2
    images = rescaled.detach().cpu().view(-1, *size)[:num_images]
    grid = make_grid(images, nrow=5)
    # imshow expects channels last, hence the permute to (H, W, C).
    plt.imshow(grid.permute(1, 2, 0).squeeze())
    plt.show()

In [34]:
import glob
import random
import os
from torch.utils.data import Dataset
from PIL import Image



# Inspired by https://github.com/aitorzip/PyTorch-CycleGAN/blob/master/datasets.py
class ImageDataset(Dataset):
    '''
    Unpaired two-domain image dataset read from <root>/<mode>A and <root>/<mode>B.

    Each item is a pair (item_A, item_B) of 3-channel tensors rescaled with
    (x - 0.5) * 2. item_A is picked by index; item_B is picked through a random
    permutation that is redrawn every time the last index is requested.

    Note: if folder A holds more files than folder B the two file lists are
    swapped, so item_A always comes from the smaller folder.
    Parameters:
        root: directory containing the '<mode>A' and '<mode>B' sub-folders
        transform: callable applied to each PIL image; it must return a
            (channels, height, width) tensor. When None, the image is only
            converted to a tensor.
        mode: prefix of the sub-folder names (default 'train')
    Raises:
        AssertionError: if the smaller of the two folders contains no files
    '''

    def __init__(self, root, transform=None, mode='train'):
        self.transform = transform
        self.files_A = sorted(glob.glob(os.path.join(root, '%sA' % mode) + '/*.*'))
        self.files_B = sorted(glob.glob(os.path.join(root, '%sB' % mode) + '/*.*'))
        if len(self.files_A) > len(self.files_B):
            self.files_A, self.files_B = self.files_B, self.files_A
        assert len(self.files_A) > 0, (
            "no files found in %s" % os.path.join(root, '%s[A|B]' % mode))
        self.new_perm()

    def new_perm(self):
        '''Draw a fresh random choice of len(files_A) distinct indices into files_B.'''
        self.randperm = torch.randperm(len(self.files_B))[:len(self.files_A)]

    def _load(self, path):
        # Load one image file as a 3-channel tensor.
        # convert() returns a fully loaded copy, so the file handle can be
        # closed right away; it also maps grayscale/palette/RGBA files to RGB.
        with Image.open(path) as raw:
            image = raw.convert('RGB')
        if self.transform is None:
            # Without a transform the arithmetic in __getitem__ would fail on a
            # PIL image, so fall back to a plain tensor conversion.
            from torchvision.transforms.functional import to_tensor
            tensor = to_tensor(image)
        else:
            tensor = self.transform(image)
        # A transform may still reduce the image to a single channel.
        if tensor.shape[0] == 1:
            tensor = tensor.repeat(3, 1, 1)
        return tensor

    def __getitem__(self, index):
        item_A = self._load(self.files_A[index % len(self.files_A)])
        item_B = self._load(self.files_B[int(self.randperm[index])])
        if index == len(self) - 1:
            self.new_perm()
        # Old versions of PyTorch didn't support normalization for different-channeled images
        return (item_A - 0.5) * 2, (item_B - 0.5) * 2

    def __len__(self):
        return min(len(self.files_A), len(self.files_B))

In [35]:
class ContractingBlock(nn.Module):
    '''
    Downsampling stage: one stride-2, reflect-padded convolution that doubles
    the channel count, followed by an optional batch norm and an activation.
    Parameters:
        input_channels: channels of the incoming tensor; output has twice as many
        use_bn: apply the norm layer after the convolution
        kernel_size: kernel size of the convolution (padding is always 1)
        activation: 'relu' selects nn.ReLU, anything else nn.LeakyReLU(0.2)
    '''

    def __init__(self, input_channels, use_bn=True, kernel_size=3, activation='relu'):
        super().__init__()
        out_channels = input_channels * 2
        self.conv1 = nn.Conv2d(
            input_channels, out_channels,
            kernel_size=kernel_size, padding=1, stride=2, padding_mode='reflect',
        )
        if activation == 'relu':
            self.activation = nn.ReLU()
        else:
            self.activation = nn.LeakyReLU(0.2)
        if use_bn:
            # The attribute is called instancenorm but holds a BatchNorm2d layer.
            self.instancenorm = nn.BatchNorm2d(out_channels)
        self.use_bn = use_bn

    def forward(self, x):
        out = self.conv1(x)
        if self.use_bn:
            out = self.instancenorm(out)
        return self.activation(out)



class FeatureMapBlock(nn.Module):
    '''
    7x7 reflect-padded convolution mapping input_channels to output_channels;
    with padding 3 the spatial size of the input is unchanged.

    NOTE(review): this redefines the FeatureMapBlock (1x1 convolution) declared
    earlier in this file; modules instantiated after this cell has run use
    this version - confirm that is intended.
    '''

    def __init__(self, input_channels, output_channels):
        super().__init__()
        self.conv = nn.Conv2d(
            input_channels, output_channels,
            kernel_size=7, padding=3, padding_mode='reflect',
        )

    def forward(self, x):
        return self.conv(x)

In [36]:
class Discriminator(nn.Module):
    '''
    Convolutional discriminator: a feature-map block, three stride-2
    contracting blocks (4x4 kernels, LeakyReLU) and a final 1x1 convolution
    that produces a single-channel map of scores.
    Parameters:
        input_channels: channels of the tensors passed to forward
        hidden_channels: channels produced by the first feature-map block
    '''

    def __init__(self, input_channels, hidden_channels=64):
        super().__init__()
        stage_kwargs = dict(kernel_size=4, activation='lrelu')
        self.upfeature = FeatureMapBlock(input_channels, hidden_channels)
        # Only the first contracting block skips the norm layer.
        self.contract1 = ContractingBlock(hidden_channels, use_bn=False, **stage_kwargs)
        self.contract2 = ContractingBlock(hidden_channels * 2, **stage_kwargs)
        self.contract3 = ContractingBlock(hidden_channels * 4, **stage_kwargs)
        self.final = nn.Conv2d(hidden_channels * 8, 1, kernel_size=1)

    def forward(self, x):
        features = self.upfeature(x)
        for stage in (self.contract1, self.contract2, self.contract3):
            features = stage(features)
        return self.final(features)

In [37]:
import torchvision.models as models

class VGG19(nn.Module):
    '''
    Frozen feature extractor built from the first 30 layers of torchvision's
    pretrained VGG-19 `features` stack, split into five consecutive stages.
    forward returns the output of every stage.

    NOTE(review): the stage boundaries (2, 7, 12, 21, 30) presumably end at the
    relu1_1 ... relu5_1 activations - verify against the torchvision layer list.
    NOTE(review): newer torchvision releases may prefer a `weights=` argument
    over `pretrained=True` - confirm against the installed version.
    '''

    def __init__(self):
        super().__init__()
        vgg_features = models.vgg19(pretrained=True).features

        def stage(start, stop):
            # Re-wrap the layers in a fresh Sequential (indices restart at 0).
            return nn.Sequential(*[vgg_features[i] for i in range(start, stop)])

        self.f1 = stage(0, 2)
        self.f2 = stage(2, 7)
        self.f3 = stage(7, 12)
        self.f4 = stage(12, 21)
        self.f5 = stage(21, 30)

        # The extractor is never trained: exclude all weights from autograd.
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        outputs = []
        for block in (self.f1, self.f2, self.f3, self.f4, self.f5):
            x = block(x)
            outputs.append(x)
        return outputs


In [38]:
import torch.nn.functional as F

# Loss functions: squared error for the adversarial term, smooth L1 for the
# reconstruction term.
adv_criterion = nn.MSELoss()
recon_criterion = nn.SmoothL1Loss()

# Training hyper-parameters.
n_epochs = 100
dim_A = 3            # channels of domain-A images
dim_B = 3            # channels of domain-B images
display_step = 200
batch_size = 4
lr = 0.00005
load_shape = 286     # images are resized to this size ...
target_shape = 256   # ... and then randomly cropped to this size
# Use the GPU when one is present, otherwise fall back to the CPU so the
# notebook still runs on a machine without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


def lambda1(epoch):
    '''Linear decay factor: 1 at epoch 0, decreasing to 0 at epoch n_epochs.'''
    return 1 - epoch / n_epochs

In [None]:
# Augmentation pipeline applied to every loaded image: resize to load_shape,
# take a random target_shape crop, flip horizontally at random, convert to a tensor.
transform = transforms.Compose([
    transforms.Resize(load_shape),
    transforms.RandomCrop(target_shape),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

import torchvision
# Images are read from the relative folder "dataset" (default mode 'train',
# i.e. the sub-folders trainA and trainB).
dataset = ImageDataset("dataset", transform=transform)
# Sanity check: print the shape of the first tensor of the third sample.
print(dataset[2][0].shape)