In [2]:
%load_ext lab_black

In [12]:
import torch
from torch import nn
from torch.nn.utils import spectral_norm
import torch.nn.functional as F

In [4]:
from functools import partial


def add_func(a, b, c):
    """Echo the three operands as ``a+b+c`` on stdout and return their sum.

    Args:
        a: First operand.
        b: Second operand.
        c: Third operand.

    Returns:
        The result of ``a + b + c``.
    """
    # Print first so the operands are shown even if the addition below fails.
    print(f"{a}+{b}+{c}")
    total = a + b + c
    return total


# Freeze b=5 and c=3 with functools.partial, then feed 0..5 in as ``a``.
add_list = [*map(partial(add_func, b=5, c=3), range(6))]

print(add_list)

0+5+3
1+5+3
2+5+3
3+5+3
4+5+3
5+5+3
[8, 9, 10, 11, 12, 13]


In [5]:
def upsample(x):
    """Return ``x`` resized by a factor of 2 using ``interpolate``'s default mode."""
    return torch.nn.functional.interpolate(x, scale_factor=2)


# Random (1, 3, 64, 64) tensor used to try out the upsampling helper.
x = torch.rand(1, 3, 64, 64)
up_x = upsample(x)

In [13]:
class SelfAttention(nn.Module):
    """Self-attention mechanism for feature maps.

    Four spectrally-normalised 1x1 convolutions project the input:
    ``theta`` and ``phi`` to ``ch // 8`` channels, ``g`` to ``ch // 2``
    channels, and ``o`` back from ``ch // 2`` to ``ch`` channels.  The
    ``phi`` and ``g`` branches are 2x2 max-pooled before attention is
    computed.  The attended features are scaled by the learnable scalar
    ``gamma`` (initialised to 0) and added to the input.

    Args:
        ch: Number of input (and output) channels.  The projections use
            ``ch // 8`` and ``ch // 2`` channels, so ``ch`` below 8 yields a
            zero-channel projection.
    """

    def __init__(self, ch):
        super(SelfAttention, self).__init__()
        self.ch = ch
        self.theta = spectral_norm(
            nn.Conv2d(self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False)
        )
        self.phi = spectral_norm(
            nn.Conv2d(self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False)
        )
        self.g = spectral_norm(
            nn.Conv2d(self.ch, self.ch // 2, kernel_size=1, padding=0, bias=False)
        )
        self.o = spectral_norm(
            nn.Conv2d(self.ch // 2, self.ch, kernel_size=1, padding=0, bias=False)
        )
        # gamma starts at 0, so the block initially returns its input unchanged.
        self.gamma = nn.Parameter(torch.tensor(0.0), requires_grad=True)

    def forward(self, x, y=None):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape ``(batch, ch, height, width)``.
            y: Unused; accepted so the module can be called like the other
                blocks in this file.

        Returns:
            Tensor with the same shape as ``x``: ``gamma * attention(x) + x``.
        """
        batch, _, height, width = x.shape

        # 1x1 projections; phi and g are spatially down-sampled by pooling.
        theta = self.theta(x)
        phi = F.max_pool2d(self.phi(x), [2, 2])
        g = F.max_pool2d(self.g(x), [2, 2])

        # Flatten the spatial axes.  The pooled length is taken from the
        # tensors themselves (-1) rather than computed as H * W // 4, because
        # pooling rounds each axis down and the two differ for odd H or W.
        theta = theta.view(batch, self.ch // 8, -1)
        phi = phi.view(batch, self.ch // 8, -1)
        g = g.view(batch, self.ch // 2, -1)

        # (batch, H*W, pooled) attention map, normalised over pooled positions.
        beta = F.softmax(torch.bmm(theta.transpose(1, 2), phi), -1)

        # Weight the g features by the attention map and restore the layout.
        attended = torch.bmm(g, beta.transpose(1, 2))
        o = self.o(attended.view(batch, self.ch // 2, height, width))
        return self.gamma * o + x

In [11]:
# Instantiate the module so the notebook displays its repr.
# NOTE(review): the saved output below lists query_conv / key_conv / value_conv /
# softmax, but the SelfAttention class in this notebook defines theta / phi / g / o.
# The output looks stale; re-run this cell after the class definition.
SelfAttention(64)

SelfAttention(
  (query_conv): Conv2d(64, 8, kernel_size=(1, 1), stride=(1, 1))
  (key_conv): Conv2d(64, 8, kernel_size=(1, 1), stride=(1, 1))
  (value_conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
  (softmax): Softmax(dim=-1)
)

In [18]:
class ResBlock_UP(nn.Module):
    """Residual block that doubles the spatial size of its input.

    Residual path: BatchNorm -> ReLU -> 2x upsample -> 3x3 conv ->
    BatchNorm -> ReLU -> 3x3 conv.
    Shortcut path: 2x upsample, followed by a 1x1 conv when the channel
    count changes (``in_ch != out_ch``).
    All convolutions are wrapped in spectral normalisation.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
    """

    def __init__(self, in_ch, out_ch):
        super(ResBlock_UP, self).__init__()

        self.sn_conv0 = spectral_norm(nn.Conv2d(in_ch, out_ch, 3, 1, 1, bias=True))
        self.sn_conv1 = spectral_norm(nn.Conv2d(out_ch, out_ch, 3, 1, 1, bias=True))
        # Created unconditionally (as before) so the parameter set does not
        # depend on whether the shortcut projection is actually used.
        self.sn_conv_sc = spectral_norm(nn.Conv2d(in_ch, out_ch, 1, 1, 0, bias=True))

        self.bn0 = nn.BatchNorm2d(in_ch)
        self.bn1 = nn.BatchNorm2d(out_ch)

        self.activation = nn.ReLU(inplace=False)

        # The shortcut needs a 1x1 projection only when the channel count changes.
        self.learnable_sc = in_ch != out_ch

    def upsample(self, x):
        """Resize ``x`` by a factor of 2 using ``interpolate``'s default mode."""
        return torch.nn.functional.interpolate(x, scale_factor=2)

    def residual(self, z, y=None):
        """Run the residual path on ``z``.

        ``y`` is accepted for call compatibility but is not used: ``bn0`` is a
        plain ``nn.BatchNorm2d``, which takes a single input.
        """
        h = self.bn0(z)
        h = self.activation(h)
        h = self.upsample(h)
        h = self.sn_conv0(h)
        h = self.bn1(h)
        h = self.activation(h)
        h = self.sn_conv1(h)
        return h

    def shortcut(self, x):
        """Run the shortcut path on ``x``.

        The input is always upsampled so its spatial size matches the residual
        path; the 1x1 conv is applied only when the channel count changes.
        """
        x = self.upsample(x)
        if self.learnable_sc:
            x = self.sn_conv_sc(x)
        return x

    def forward(self, z, y=None):
        """Return ``residual(z, y) + shortcut(z)``.

        Args:
            z: Tensor of shape ``(batch, in_ch, height, width)``.
            y: Optional, currently unused (see ``residual``).

        Returns:
            Tensor of shape ``(batch, out_ch, 2 * height, 2 * width)``.
        """
        return self.residual(z, y) + self.shortcut(z)

    # Alias for the original (misspelled) method name, kept for existing callers.
    foward = forward

In [19]:
# Scratch cell: "_"-separated resolutions at which attention is applied, and
# the base channel width.
attention = "128"
ch = 64
# Map each resolution 8..256 to whether it is listed in `attention`.
{
    res: res in [int(token) for token in attention.split("_")]
    for res in (2 ** exp for exp in range(3, 9))
}
# Channel widths obtained by scaling the base width per stage.
[ch * multiplier for multiplier in (16, 8, 8, 4, 2, 1)]

[1024, 512, 512, 256, 128, 64]

In [22]:
class Generator(nn.Module):
    """Generator built from six upsampling residual blocks and one attention layer.

    A spectrally-normalised linear layer maps the latent vector to a
    ``(16 * ch, 4, 4)`` feature map.  Six ``ResBlock_UP`` stages then reduce
    the channel count from ``16 * ch`` to ``ch``, with ``SelfAttention``
    inserted before the last stage.  The output head is BatchNorm -> ReLU ->
    3x3 conv to 3 channels.

    Args:
        ngpu: Stored on the instance as ``self.ngpu``; not otherwise used here.
        ch: Base channel width.
        dim_z: Size of the latent vector.
    """

    def __init__(self, ngpu, ch=64, dim_z=128):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.sn_linear = spectral_norm(nn.Linear(dim_z, 4 * 4 * 16 * ch, bias=True))
        self.res0 = ResBlock_UP(16 * ch, 16 * ch)
        self.res1 = ResBlock_UP(16 * ch, 8 * ch)
        self.res2 = ResBlock_UP(8 * ch, 8 * ch)
        self.res3 = ResBlock_UP(8 * ch, 4 * ch)
        self.res4 = ResBlock_UP(4 * ch, 2 * ch)
        self.attn = SelfAttention(2 * ch)
        self.res5 = ResBlock_UP(2 * ch, ch)

        self.output = nn.Sequential(
            nn.BatchNorm2d(ch),
            nn.ReLU(inplace=False),
            # nn.Conv2d requires a kernel size; 3x3 / stride 1 / padding 1
            # matches the convs in ResBlock_UP and keeps the spatial size.
            spectral_norm(
                nn.Conv2d(ch, 3, kernel_size=3, stride=1, padding=1, bias=True)
            ),
        )

    def forward(self, input):
        """Map latent vectors to 3-channel feature maps.

        Args:
            input: Tensor of shape ``(batch, dim_z)``.

        Returns:
            The output of the ``self.output`` head applied after all
            residual/attention stages.

        Relies on each ``ResBlock_UP`` instance being callable as
        ``block(h, None)``.
        """
        h = self.sn_linear(input)
        # The linear layer emits 4 * 4 * 16 * ch features per sample.
        h = h.view(h.size(0), -1, 4, 4)
        for block in (self.res0, self.res1, self.res2, self.res3, self.res4):
            h = block(h, None)
        h = self.attn(h)
        h = self.res5(h, None)
        # NOTE(review): no final squashing (e.g. tanh) is applied; add one if
        # the training data is scaled to a bounded range.
        return self.output(h)

In [None]:
class Discriminator(nn.Module):
    """Placeholder discriminator: no layers are defined yet.

    Args:
        ngpu: Stored on the instance as ``self.ngpu``.
    """

    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu

    def forward(self, input):
        """Not implemented yet.

        Raises:
            NotImplementedError: Always, until the network is defined.
        """
        # An empty method body is a syntax error; fail explicitly instead.
        raise NotImplementedError("Discriminator.forward is not implemented yet")

In [3]:
def weight_init(m):
    """Re-initialise ``m`` in place according to its class name.

    * Class name contains ``"Conv"``: weight drawn from N(0.0, 0.02).
    * Otherwise, class name contains ``"BatchNorm"``: weight drawn from
      N(1.0, 0.02) and bias set to 0.
    * Any other module is left untouched.

    Args:
        m: The module to initialise.
    """
    layer_kind = type(m).__name__
    if "Conv" in layer_kind:
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return
    if "BatchNorm" in layer_kind:
        nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        nn.init.constant_(m.bias.data, 0)

In [8]:
# Wrap a 20 -> 40 linear layer in spectral normalisation to inspect the result.
m = torch.nn.utils.spectral_norm(nn.Linear(20, 40))

In [9]:
# Display the wrapped module's repr.
m

Linear(in_features=20, out_features=40, bias=True)