In [1]:
import torch
import torch.nn as nn
from diff_aug import DiffAugment
import numpy as np


In [2]:
class MLP(nn.Module):
    def __init__(self, in_feat, hid_feat=None, out_feat=None, dropout=0.0):
        super().__init__()
        if not hid_feat:
            hid_feat = in_feat
        if not out_feat:
            out_feat = in_feat
        self.fc1 = nn.Linear(in_feat, hid_feat)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hid_feat, out_feat)
        self.droprateout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return self.droprateout(x)


class Attention(nn.Module):
    def __init__(self, dim, heads=4, attention_dropout=0.0, proj_dropout=0.0):
        super().__init__()
        self.heads = heads
        self.dim = dim
        # print(dim)
        self.scale = 1.0 / dim**0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.attention_dropout = nn.Dropout(attention_dropout)
        self.out = nn.Sequential(nn.Linear(dim, dim), nn.Dropout(proj_dropout))

    def forward(self, x):
        b, n, c = x.shape
        qkv = self.qkv(x).reshape(b, n, 3, self.heads, c // self.heads)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)
        # print("qkv.shape:", qkv.shape, "Dim:", self.dim, x.shape) 
        dot = (q @ k.transpose(-2, -1)) * self.scale
        attn = dot.softmax(dim=-1)
        attn = self.attention_dropout(attn)

        x = (attn @ v).transpose(1, 2).reshape(b, n, c)
        x = self.out(x)
        return x


class ImgPatches(nn.Module):
    def __init__(self, input_channel=3, dim=768, patch_size=4):
        super().__init__()
        self.patch_embed = nn.Conv2d(
            input_channel, dim, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, img):
        patches = self.patch_embed(img).flatten(2).transpose(1, 2)
        return patches


def UpSampling(x, H, W, psfac=2):
    B, N, C = x.size()
    assert N == H * W
    x = x.permute(0, 2, 1)
    x = x.view(-1, C, H, W)
    x = nn.PixelShuffle(psfac)(x)
    B, C, H, W = x.size()
    x = x.view(-1, C, H * W)
    x = x.permute(0, 2, 1)
    return x, H, W


class Encoder_Block(nn.Module):
    def __init__(self, dim, heads, mlp_ratio=4, drop_rate=0.0):
        super().__init__()
        self.ln1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, heads, drop_rate, drop_rate)
        self.ln2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, dim * mlp_ratio, dropout=drop_rate)

    def forward(self, x):
        x1 = self.ln1(x)
        x = x + self.attn(x1)
        x2 = self.ln2(x)
        x = x + self.mlp(x2)
        return x


class TransformerEncoder(nn.Module):
    def __init__(self, depth, dim, heads, mlp_ratio=4, drop_rate=0.0):
        super().__init__()
        self.Encoder_Blocks = nn.ModuleList(
            [Encoder_Block(dim, heads, mlp_ratio, drop_rate) for i in range(depth)]
        )

    def forward(self, x):
        for Encoder_Block in self.Encoder_Blocks:
            x = Encoder_Block(x)
        return x


In [3]:
def to_square(x: torch.Tensor):
    B, N, C = x.size()
    x = x.permute(0, 2, 1)
    x = x.view(-1, C, int(N**.5), int(N**.5))
    return x
def to_flat(x: torch.Tensor):
    B, C, H, W = x.size()
    x = x.view(-1, C, H * W)
    x = x.permute(0, 2, 1)
    return x

class convUp(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super().__init__()
        print(in_channels, out_channels)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.bnorm = nn.BatchNorm2d(out_channels)
        self.act = nn.GELU()
    def forward(self, x):
        x = to_square(x)
        x = self.conv(x)
        x = self.bnorm(x)
        x = self.act(x)
        x = to_flat(x)
        return x

In [4]:
class Generator(nn.Module):
    """docstring for Generator"""

    # ,device=device):
    def __init__(
        self,
        depth1=5,
        depth2=4,
        depth3=2,
        initial_size=8,
        dim=384,
        heads=4,
        mlp_ratio=4,
        drop_rate=0.0,
        latent_dim=1024,
        output_channels=1,
        psfac=2,
        cvupmult = 1
    ):
        super(Generator, self).__init__()

        # self.device = device
        self.initial_size = initial_size
        self.dim = dim
        self.depth1 = depth1
        self.depth2 = depth2
        self.depth3 = depth3
        self.heads = heads
        self.mlp_ratio = mlp_ratio
        self.droprate_rate = drop_rate
        
        self.mlp = nn.Linear(latent_dim, (self.initial_size**2) * self.dim)

        self.positional_embedding_1 = nn.Parameter(
            torch.zeros(1, (initial_size**2), self.dim)
        )
        self.positional_embedding_2 = nn.Parameter(
            torch.zeros(1, (initial_size * (psfac**1)) ** 2, (self.dim // (psfac**2)))
        )
        self.positional_embedding_3 = nn.Parameter(
            torch.zeros(1, (initial_size * (psfac**2)) ** 2, heads)
        )

        self.TransformerEncoder_encoder1 = TransformerEncoder(
            depth=self.depth1,
            dim=self.dim,
            heads=self.heads,
            mlp_ratio=self.mlp_ratio,
            drop_rate=self.droprate_rate,
        )
        self.TransformerEncoder_encoder2 = TransformerEncoder(
            depth=self.depth2,
            dim=(self.dim // (psfac**2)),
            heads=self.heads,
            mlp_ratio=self.mlp_ratio,
            drop_rate=self.droprate_rate,
        )
        self.TransformerEncoder_encoder3 = TransformerEncoder(
            depth=self.depth3,
            dim=heads,
            heads=self.heads,
            mlp_ratio=self.mlp_ratio,
            drop_rate=self.droprate_rate,
        )

        self.linear = nn.Sequential(
            nn.Conv2d(self.dim // psfac**4, output_channels, 1, 1, 0)
        )
        self.psfac = psfac

        self.cvup1 = convUp(in_channels = int(dim/(psfac**2)) ,out_channels= int(dim/(psfac**2))*cvupmult, kernel_size= 3,stride = 1, padding = 1)
        self.cvdown = convUp(in_channels = int(dim/(psfac**4))*(cvupmult) ,out_channels= heads, kernel_size= 3,stride = 1, padding = 1)
    def forward(self, noise):
        H, W = self.initial_size, self.initial_size
        x = self.mlp(noise).view(-1, self.initial_size**2, self.dim)
        print(x.shape, "initial MLP")
        x = x + self.positional_embedding_1
        print(x.shape, "pos embedding one")
        x = self.TransformerEncoder_encoder1(x)
        print(x.shape, "transformer encoder 1")
        x, H, W = UpSampling(x, H, W, psfac=self.psfac)
        print(x.shape, "upsampling 1")
        x = x + self.positional_embedding_2
        print(x.shape, "pos embedding 2")
        x = self.TransformerEncoder_encoder2(x)
        print(x.shape, "transformer encoder 2")
        x = self.cvup1(x)
        print(x.shape)
        x, H, W = UpSampling(x, H, W, psfac=self.psfac)
        print(x.shape, "upsampling 2")
        x = x + self.positional_embedding_3
        print(x.shape, "pos embedding 3")
        x = self.TransformerEncoder_encoder3(x)
        print(x.shape, "transformer encoder 3")
        x = self.linear(x.permute(0, 2, 1).view(-1, self.dim // (self.psfac**4), H, W))
        print(x.shape, "reshaping for output")
        return x
    
    def generate(self, num_images):
        return self.forward(torch.cuda.FloatTensor(np.random.normal(0, 1, (num_images, 256))))
generator = Generator(
        depth1=1,
        depth2=1,
        depth3=1,
        initial_size=8,
        dim=128*2,
        heads=2,
        mlp_ratio=1,
        drop_rate=0.5,
        latent_dim=256,
        psfac=4,
        cvupmult = 2
    ).to("cuda:0")

16 32
2 2


In [5]:
generator.generate(2).shape


torch.Size([2, 64, 256]) initial MLP
torch.Size([2, 64, 256]) pos embedding one
torch.Size([2, 64, 256]) transformer encoder 1
torch.Size([2, 1024, 16]) upsampling 1
torch.Size([2, 1024, 16]) pos embedding 2
torch.Size([2, 1024, 16]) transformer encoder 2
torch.Size([2, 1024, 32])
torch.Size([2, 16384, 2]) upsampling 2
torch.Size([2, 16384, 2]) pos embedding 3


OutOfMemoryError: CUDA out of memory. Tried to allocate 4.00 GiB (GPU 0; 8.00 GiB total capacity; 4.06 GiB already allocated; 2.86 GiB free; 4.06 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF