In [438]:
import torch
from torch import nn

In [439]:
# def pixel_norm(x):
#     eps= 10e-8
#     return x/torch.sqrt(torch.mean(x,dim=1)+eps)

class Pixel_norm(nn.Module):
    def __init__(self):
        super().__init__()
        self.eps=10e-8

    def forward(self,x):
            return x/torch.sqrt(torch.mean(x,dim=1,keepdim=True)+self.eps)


class WConv(nn.Module):
    def __init__(self,in_chan,out_chan,kernel=3,stride=1,padding=1):
        super().__init__()
        self.conv=nn.Conv2d(in_chan,out_chan,kernel,stride,padding)
        self.equalized_weights =(2/(in_chan*kernel**2))**0.5

    def forward(self,x):
        return self.conv(x*self.equalized_weights)


class WConvTrans(nn.Module):
    def __init__(self,in_chan,out_chan,kernel=3,stride=1,padding=1):
        super().__init__()
        self.conv=nn.ConvTranspose2d(in_chan,out_chan,kernel,stride,padding)
        self.equalized_weights =(2/(out_chan*kernel**2))**0.5

    def forward(self,x):
        return self.conv(x*self.equalized_weights)


class ConvBlock(nn.Module):
    def __init__(self,in_chan,out_chan):
        super().__init__()
        self.conv = WConv(in_chan,out_chan)
        self.conv1 = WConv(out_chan,out_chan)
        self.pix = Pixel_norm()
        self.leaky = nn.LeakyReLU(0.2)

    def forward(self,x):
        return self.pix(self.leaky(self.conv1(self.pix(self.leaky(self.conv(x))))))

class DisBlock(nn.Module):
    def __init__(self,in_chan,out_chan):
        super().__init__()
        self.conv = WConv(in_chan,in_chan)
        self.conv1 = WConv(in_chan,out_chan)
        self.leaky = nn.LeakyReLU(0.2)

    def forward(self,x):
        return self.leaky(self.conv1(self.leaky(self.conv(x))))

In [440]:
class Generator(nn.Module):
    def __init__(self,in_dim,img_channel):
        super(Generator,self).__init__()
        self.blocks =nn.ModuleList()
        self.rgb_layers = nn.ModuleList()
        self.rgb_layers.append(WConv(in_dim,img_channel,1,1,0))
        self.up = nn.Upsample(scale_factor=2,mode='nearest')
        self.blocks.append(nn.Sequential(
            WConvTrans(in_dim,in_dim,4,1,0),
            nn.LeakyReLU(0.2),
            Pixel_norm(),
            WConv(in_dim,in_dim),
            nn.LeakyReLU(0.2),
            Pixel_norm(),
        ))

        for i in range(3):
            self.blocks.append(
                ConvBlock(in_dim,in_dim)
            )
            self.rgb_layers.append(WConv(in_dim,img_channel,1,1,0))
        n=0
        while in_dim//2**n>16:

            self.blocks.append(
                ConvBlock(in_dim//2**n,in_dim//2**(n+1))
            )
            self.rgb_layers.append(WConv(in_dim//2**(n),img_channel,1,1,0))

            n+=1
        self.rgb_layers.append(WConv(in_dim//2**(n),img_channel,1,1,0))

    def forward(self,x, out_size,alpha):
        i=0
        while 2**i<out_size:
            i+=1
        x=self.blocks[0](x)
        x_up=x
        for j in range(1,i-1):
            print(j)
            x_up=self.up(x)
            x=self.blocks[j](x_up)
        x_rgb = self.rgb_layers[i-2](x_up)
        print(x_rgb.shape)
        x_rgb_up = self.rgb_layers[i-1](x)
        print(x_rgb_up.shape)
        return torch.tanh((1-alpha)*x_rgb+alpha*x_rgb_up)

In [441]:
class Discriminator(nn.Module):
    def __init__(self,in_chan,img_channel):
        super(Discriminator,self).__init__()
        self.pool = nn.AvgPool2d(2,2)
        self.blocks =nn.ModuleList()
        self.rgb_layers = nn.ModuleList()
        # self.rgb_layers.append(WConv(img_channel,img_channel,1,1,0))
        n=16
        while n<in_chan:
            self.blocks.append(DisBlock(n,n*2))
            self.rgb_layers.append(WConv(img_channel,n,1,1,0))
            n=n*2

        for i in range(3):
            self.blocks.append(DisBlock(in_chan,in_chan))
            self.rgb_layers.append(WConv(img_channel,in_chan,1,1,0))
        self.blocks.append(nn.Sequential(
            WConv(in_chan+1,in_chan),
            nn.LeakyReLU(0.2),
            WConv(in_chan,in_chan,4,1,0),
            nn.LeakyReLU(0.2),
            WConv(in_chan,1,1,1,0),
            nn.Sigmoid()
        ))
        self.rgb_layers.append(WConv(img_channel,in_chan,1,1,0))
    def minibatch_std(self, x):
        batch_statistics = (
            torch.std(x, dim=0).mean().repeat(x.shape[0], 1, x.shape[2], x.shape[3])
        )
        return torch.cat([x, batch_statistics], dim=1)

    def forward(self,x,size,alpha=0.5):
        i=0
        while 2**i<size:
            i+=1
        x_rgb=self.rgb_layers[-(i-1)](x)
        if i==2:
            x=self.minibatch_std(x_rgb)
            return (self.blocks[-1](x)).view(x.shape[0], -1)

        x_down= self.pool(self.blocks[-(i-1)](x_rgb))
        x=self.rgb_layers[-(i-2)](self.pool(x))
        x=(1-alpha)*x_down+alpha*x

        for j in range(i-2,1,-1):
            x=self.blocks[-j](x)
            x=self.pool(x)
        x=self.minibatch_std(x)
        return (self.blocks[-1](x)).view(x.shape[0], -1)




In [442]:
from torchsummary import summary
gen = Generator(512,3)
x= torch.randn((2,512,1,1))
gen.train()
dis = Discriminator(512,3)
dis.train()

for i in range(2,11):
    im=gen(x,2**i,0.5)
    print(im.shape)
    print(dis(im,2**i))





torch.Size([2, 3, 4, 4])
torch.Size([2, 3, 4, 4])
torch.Size([2, 3, 4, 4])
tensor([[0.4948],
        [0.4948]], grad_fn=<ViewBackward0>)
1
torch.Size([2, 3, 8, 8])
torch.Size([2, 3, 8, 8])
torch.Size([2, 3, 8, 8])
tensor([[0.4948],
        [0.4948]], grad_fn=<ViewBackward0>)
1
2
torch.Size([2, 3, 16, 16])
torch.Size([2, 3, 16, 16])
torch.Size([2, 3, 16, 16])
tensor([[0.4948],
        [0.4948]], grad_fn=<ViewBackward0>)
1
2
3
torch.Size([2, 3, 32, 32])
torch.Size([2, 3, 32, 32])
torch.Size([2, 3, 32, 32])
tensor([[0.4948],
        [0.4948]], grad_fn=<ViewBackward0>)
1
2
3
4
torch.Size([2, 3, 64, 64])
torch.Size([2, 3, 64, 64])
torch.Size([2, 3, 64, 64])
tensor([[0.4948],
        [0.4948]], grad_fn=<ViewBackward0>)
1
2
3
4
5
torch.Size([2, 3, 128, 128])
torch.Size([2, 3, 128, 128])
torch.Size([2, 3, 128, 128])
tensor([[0.4948],
        [0.4948]], grad_fn=<ViewBackward0>)
1
2
3
4
5
6
torch.Size([2, 3, 256, 256])
torch.Size([2, 3, 256, 256])
torch.Size([2, 3, 256, 256])
tensor([[0.4948],
 

In [443]:
for j in range(4-2,1,-1):
    print(j)

2
