In [1]:
import torch
import torch.nn as nn
import torch.nn.init as init
from torchinfo import summary

import easydict

from network_module import *


In [221]:
opt = easydict.EasyDict({
    "data_dir": '../dataset',
    "input_length": 220500,
    "image_height": 1025,
    "image_width": 431,
    "bbox_shape": 120,
    "mask_type": 'time_masking',
    "in_channels" : 2,
    "out_channels" : 1,
    "latent_channels" : 32,
    "pad_type": 'zero',
    "activation": 'lrelu',
    "norm":'in',
    "init_type":'xavier',
    "init_gain":0.02,
    "stage_num": 1,
    "batch_size": 4,
    "msd_latent" : 32,
    })

In [13]:
class PatchDiscriminator(nn.Module):
    def __init__(self, opt):
        super(PatchDiscriminator, self).__init__()
        # Down sampling
        self.block1 = Conv2dLayer(3, opt.latent_channels, 7, 1, 3, pad_type = opt.pad_type, activation = opt.activation, norm = 'none', sn = True)
        self.block2 = Conv2dLayer(opt.latent_channels, opt.latent_channels * 2, 4, 2, 1, pad_type = opt.pad_type, activation = opt.activation, norm = opt.norm, sn = True)
        self.block3 = Conv2dLayer(opt.latent_channels * 2, opt.latent_channels * 4, 4, 2, 1, pad_type = opt.pad_type, activation = opt.activation, norm = opt.norm, sn = True)
        self.block4 = Conv2dLayer(opt.latent_channels * 4, opt.latent_channels * 4, 4, 2, 1, pad_type = opt.pad_type, activation = opt.activation, norm = opt.norm, sn = True)
        self.block5 = Conv2dLayer(opt.latent_channels * 4, opt.latent_channels * 4, 4, 2, 1, pad_type = opt.pad_type, activation = opt.activation, norm = opt.norm, sn = True)
        self.block6 = Conv2dLayer(opt.latent_channels * 4, 1, 4, 2, 1, pad_type = opt.pad_type, activation = 'none', norm = 'none', sn = True)
        
    def forward(self, img, mask):
        # the input x should contain 4 channels because it is a combination of recon image and mask
        x = torch.cat((img, mask), 1)
        x = self.block1(x)                                      # out: [B, 64, 256, 256]
        x = self.block2(x)                                      # out: [B, 128, 128, 128]
        x = self.block3(x)                                      # out: [B, 256, 64, 64]
        x = self.block4(x)                                      # out: [B, 256, 32, 32]
        x = self.block5(x)                                      # out: [B, 256, 16, 16]
        x = self.block6(x)                                      # out: [B, 256, 8, 8]
        return x

In [14]:
discriminator = PatchDiscriminator(opt)
summary(discriminator, [(4, 2, 1024, 428), (4, 1, 1024, 428)], device='cpu')

Layer (type:depth-idx)                   Output Shape              Param #
PatchDiscriminator                       --                        --
├─Conv2dLayer: 1-1                       [4, 32, 1024, 428]        --
│    └─ZeroPad2d: 2-1                    [4, 3, 1030, 434]         --
│    └─SpectralNorm: 2-2                 [4, 32, 1024, 428]        --
│    └─LeakyReLU: 2-3                    [4, 32, 1024, 428]        --
├─Conv2dLayer: 1-2                       [4, 64, 512, 214]         --
│    └─ZeroPad2d: 2-4                    [4, 32, 1026, 430]        --
│    └─SpectralNorm: 2-5                 [4, 64, 512, 214]         --
│    └─InstanceNorm2d: 2-6               [4, 64, 512, 214]         --
│    └─LeakyReLU: 2-7                    [4, 64, 512, 214]         --
├─Conv2dLayer: 1-3                       [4, 128, 256, 107]        --
│    └─ZeroPad2d: 2-8                    [4, 64, 514, 216]         --
│    └─SpectralNorm: 2-9                 [4, 128, 256, 107]        --
│    └─Instance

In [47]:
[4, 1, 32, 13]

In [48]:
discriminator = PatchDiscriminator(opt)
summary(discriminator, [(4, 2, 1024, 428), (4, 1, 1024, 428)], device='cpu')discriminator = jj_Discriminator()
summary(discriminator, [(4, 1, 1024, 428), (4, 1, 1024, 428)],  device='cpu')

Layer (type:depth-idx)                   Output Shape              Param #
jj_Discriminator2                        --                        --
├─Conv2d: 1-1                            [4, 64, 1024, 214]        512
├─BatchNorm2d: 1-2                       [4, 64, 1024, 214]        128
├─LeakyReLU: 1-3                         [4, 64, 1024, 214]        --
├─Conv2d: 1-4                            [4, 128, 512, 107]        73,728
├─BatchNorm2d: 1-5                       [4, 128, 512, 107]        256
├─LeakyReLU: 1-6                         [4, 128, 512, 107]        --
├─Conv2d: 1-7                            [4, 256, 256, 54]         294,912
├─BatchNorm2d: 1-8                       [4, 256, 256, 54]         512
├─LeakyReLU: 1-9                         [4, 256, 256, 54]         --
├─Conv2d: 1-10                           [4, 512, 256, 54]         1,179,648
├─BatchNorm2d: 1-11                      [4, 512, 256, 54]         1,024
├─LeakyReLU: 1-12                        [4, 512, 256, 54]    

In [135]:
inp_spec = torch.zeros([4, 1, 1024, 428])

In [None]:
class Multi_Scale_Discriminator(nn.Module):
    def __init__(self, opt):
        super(multi_scale_discriminator, self).__init__()
        self.frame_lengths = [16, 32, 64, 128]
        
        self.scale0_discriminator = Scale0_Discriminator()
        self.scale1_discriminator = Scale1_Discriminator()
        self.scale2_discriminator = Scale2_Discriminator()
        self.scale3_discriminator = Scale3_Discriminator()
        
    def make_scale_input(self, frame_length, mask_start, mask_end, max_time_index):
        spec_end_start = mask_start + 0.5 * frame_length
        spec_end_end = mask_end + 0.5 * frame_length
        spec_end = (torch.rand(self.batch_size) * (spec_end_end - spec_end_start + 1) + spec_end_start).int()        
        spec_end = torch.min(spec_end, max_time_index)
        spec_start = spec_end - frame_length
        spec_start = torch.max(spec_start, 0)
        return spec_start
    
    def forward(self, img, mask, mask_start, mask_end):
        self.batch_size = img.shape[0]
        scale0_spec_start = self.make_scale_input(self.frame_lengths[0], mask_start, mask_end, img.shape[-1]-1)
        scale1_spec_start = self.make_scale_input(self.frame_lengths[1], mask_start, mask_end, img.shape[-1]-1)
        scale2_spec_start = self.make_scale_input(self.frame_lengths[2], mask_start, mask_end, img.shape[-1]-1)
        scale3_spec_start = self.make_scale_input(self.frame_lengths[3], mask_start, mask_end, img.shape[-1]-1)
                
        scale0_output = self.scale0_discriminator(img[...,scale0_spec_start:scale0_spec_start + frame_lengths[0] + 1])
        scale1_output = self.scale1_discriminator(img[...,scale1_spec_start:scale1_spec_start + frame_lengths[1] + 1])
        scale2_output = self.scale2_discriminator(img[...,scale2_spec_start:scale2_spec_start + frame_lengths[2] + 1])
        scale3_output = self.scale3_discriminator(img[...,scale3_spec_start:scale3_spec_start + frame_lengths[3] + 1])
        
        return scale1_output, scale2_output, scale3_output, scale4_output

In [238]:
class Scale0_Discriminator(nn.Module):
    def __init__(self, opt):
        super(Scale0_Discriminator, self).__init__()
        # Down sampling
        self.block1 = Conv2dLayer(1, opt.msd_latent, (7,4), 1, 3, pad_type = opt.pad_type, activation = opt.activation, norm = 'none', sn = True)
        self.block2 = Conv2dLayer(opt.msd_latent, opt.latent_channels * 2, 4, (2,1), 1, pad_type = opt.pad_type, activation = opt.activation, norm = opt.norm, sn = True)
        self.block3 = Conv2dLayer(opt.msd_latent * 2, opt.latent_channels * 4, 4, (2,2), 1, pad_type = opt.pad_type, activation = opt.activation, norm = opt.norm, sn = True)
        self.block4 = Conv2dLayer(opt.msd_latent * 4, opt.latent_channels * 4, 4, (2,2), 1, pad_type = opt.pad_type, activation = opt.activation, norm = opt.norm, sn = True)
        self.block5 = Conv2dLayer(opt.msd_latent * 4, opt.latent_channels * 4, 4, (2,2), 1, pad_type = opt.pad_type, activation = opt.activation, norm = opt.norm, sn = True)
        self.block6 = Conv2dLayer(opt.msd_latent * 4, 1, 4, 2, 1, pad_type = opt.pad_type, activation = 'none', norm = 'none', sn = True)
        
    def forward(self, img):
        # the input x should contain 4 channels because it is a combination of recon image and mask
        x = img
        x = self.block1(x)                                      # out: [B, 64, 256, 256]
        x = self.block2(x)                                      # out: [B, 128, 128, 128]
        x = self.block3(x)                                      # out: [B, 256, 64, 64]
        x = self.block4(x)                                      # out: [B, 256, 32, 32]
        x = self.block5(x)                                      # out: [B, 256, 16, 16]
        x = self.block6(x)                                      # out: [B, 256, 8, 8]
        return x

class Scale1_Discriminator(nn.Module):
    def __init__(self, opt):
        super(Scale1_Discriminator, self).__init__()
        # Down sampling
        self.block1 = Conv2dLayer(1, opt.msd_latent, 7, 1, 3, pad_type = opt.pad_type, activation = opt.activation, norm = 'none', sn = True)
        self.block2 = Conv2dLayer(opt.msd_latent, opt.latent_channels * 2, 4, 2, 1, pad_type = opt.pad_type, activation = opt.activation, norm = opt.norm, sn = True)
        self.block3 = Conv2dLayer(opt.msd_latent * 2, opt.latent_channels * 4, 4, 2, 1, pad_type = opt.pad_type, activation = opt.activation, norm = opt.norm, sn = True)
        self.block4 = Conv2dLayer(opt.msd_latent * 4, opt.latent_channels * 4, 4, 2, 1, pad_type = opt.pad_type, activation = opt.activation, norm = opt.norm, sn = True)
        self.block5 = Conv2dLayer(opt.msd_latent * 4, opt.latent_channels * 4, 4, 2, 1, pad_type = opt.pad_type, activation = opt.activation, norm = opt.norm, sn = True)
        self.block6 = Conv2dLayer(opt.msd_latent * 4, 1, 4, 2, 1, pad_type = opt.pad_type, activation = 'none', norm = 'none', sn = True)
        
    def forward(self, img):
        # the input x should contain 4 channels because it is a combination of recon image and mask
        x = img
        x = self.block1(x)                                      # out: [B, 64, 256, 256]
        x = self.block2(x)                                      # out: [B, 128, 128, 128]
        x = self.block3(x)                                      # out: [B, 256, 64, 64]
        x = self.block4(x)                                      # out: [B, 256, 32, 32]
        x = self.block5(x)                                      # out: [B, 256, 16, 16]
        x = self.block6(x)                                      # out: [B, 256, 8, 8]
        return x

class Scale2_Discriminator(nn.Module):
    def __init__(self, opt):
        super(Scale2_Discriminator, self).__init__()
        # Down sampling
        self.block1 = Conv2dLayer(1, opt.msd_latent, 7, 1, 3, pad_type = opt.pad_type, activation = opt.activation, norm = 'none', sn = True)
        self.block2 = Conv2dLayer(opt.msd_latent, opt.latent_channels * 2, 4, 2, 1, pad_type = opt.pad_type, activation = opt.activation, norm = opt.norm, sn = True)
        self.block3 = Conv2dLayer(opt.msd_latent * 2, opt.latent_channels * 4, 4, 2, 1, pad_type = opt.pad_type, activation = opt.activation, norm = opt.norm, sn = True)
        self.block4 = Conv2dLayer(opt.msd_latent * 4, opt.latent_channels * 4, 4, 2, 1, pad_type = opt.pad_type, activation = opt.activation, norm = opt.norm, sn = True)
        self.block5 = Conv2dLayer(opt.msd_latent * 4, opt.latent_channels * 4, 4, 2, 1, pad_type = opt.pad_type, activation = opt.activation, norm = opt.norm, sn = True)
        self.block6 = Conv2dLayer(opt.msd_latent * 4, 1, 4, 2, 1, pad_type = opt.pad_type, activation = 'none', norm = 'none', sn = True)
        
    def forward(self, img):
        # the input x should contain 4 channels because it is a combination of recon image and mask
        x = img
        x = self.block1(x)                                      # out: [B, 64, 256, 256]
        x = self.block2(x)                                      # out: [B, 128, 128, 128]
        x = self.block3(x)                                      # out: [B, 256, 64, 64]
        x = self.block4(x)                                      # out: [B, 256, 32, 32]
        x = self.block5(x)                                      # out: [B, 256, 16, 16]
        x = self.block6(x)                                      # out: [B, 256, 8, 8]
        return x
    
class Scale3_Discriminator(nn.Module):
    def __init__(self, opt):
        super(Scale3_Discriminator, self).__init__()
        # Down sampling
        self.block1 = Conv2dLayer(1, opt.msd_latent, 7, 1, 3, pad_type = opt.pad_type, activation = opt.activation, norm = 'none', sn = True)
        self.block2 = Conv2dLayer(opt.msd_latent, opt.latent_channels * 2, 4, 2, 1, pad_type = opt.pad_type, activation = opt.activation, norm = opt.norm, sn = True)
        self.block3 = Conv2dLayer(opt.msd_latent * 2, opt.latent_channels * 4, 4, 2, 1, pad_type = opt.pad_type, activation = opt.activation, norm = opt.norm, sn = True)
        self.block4 = Conv2dLayer(opt.msd_latent * 4, opt.latent_channels * 4, 4, 2, 1, pad_type = opt.pad_type, activation = opt.activation, norm = opt.norm, sn = True)
        self.block5 = Conv2dLayer(opt.msd_latent * 4, opt.latent_channels * 4, 4, 2, 1, pad_type = opt.pad_type, activation = opt.activation, norm = opt.norm, sn = True)
        self.block6 = Conv2dLayer(opt.msd_latent * 4, 1, 4, 2, 1, pad_type = opt.pad_type, activation = 'none', norm = 'none', sn = True)
        
    def forward(self, img):
        # the input x should contain 4 channels because it is a combination of recon image and mask
        x = img
        x = self.block1(x)                                      # out: [B, 64, 256, 256]
        x = self.block2(x)                                      # out: [B, 128, 128, 128]
        x = self.block3(x)                                      # out: [B, 256, 64, 64]
        x = self.block4(x)                                      # out: [B, 256, 32, 32]
        x = self.block5(x)                                      # out: [B, 256, 16, 16]
        x = self.block6(x)                                      # out: [B, 256, 8, 8]
        return x    

In [239]:
# discriminator = Scale0_Discriminator(opt)
# summary(discriminator, (4, 1, 1024, 16), device='cpu')

# discriminator = Scale1_Discriminator(opt)
# summary(discriminator, (4, 1, 1024, 32), device='cpu')

# discriminator = Scale2_Discriminator(opt)
# summary(discriminator, (4, 1, 1024, 64), device='cpu')

discriminator = Scale3_Discriminator(opt)
summary(discriminator, (4, 1, 1024, 128), device='cpu')

torch.Size([4, 1, 32])


Layer (type:depth-idx)                   Output Shape              Param #
Scale3_Discriminator                     --                        --
├─Conv2dLayer: 1-1                       [4, 32, 1024, 128]        --
│    └─ZeroPad2d: 2-1                    [4, 1, 1030, 134]         --
│    └─SpectralNorm: 2-2                 [4, 32, 1024, 128]        --
│    └─LeakyReLU: 2-3                    [4, 32, 1024, 128]        --
├─Conv2dLayer: 1-2                       [4, 64, 512, 64]          --
│    └─ZeroPad2d: 2-4                    [4, 32, 1026, 130]        --
│    └─SpectralNorm: 2-5                 [4, 64, 512, 64]          --
│    └─InstanceNorm2d: 2-6               [4, 64, 512, 64]          --
│    └─LeakyReLU: 2-7                    [4, 64, 512, 64]          --
├─Conv2dLayer: 1-3                       [4, 128, 256, 32]         --
│    └─ZeroPad2d: 2-8                    [4, 64, 514, 66]          --
│    └─SpectralNorm: 2-9                 [4, 128, 256, 32]         --
│    └─Instance

In [235]:
a = torch.ones([4, 1, 1025, 4])
b = torch.mean(a, -1)
b.shape

torch.Size([4, 1, 1025])

In [167]:
mask_start = torch.tensor(list([10, 50, 100, 200]))
mask_end = torch.tensor(list([20, 80, 150, 400]))
mask_start, mask_end

(tensor([ 10,  50, 100, 200]), tensor([ 20,  80, 150, 400]))

In [246]:
a = torch.ones([4, 1, 1025, 4])
print(a.shape)
torch.mean(a.float())

torch.Size([4, 1, 1025, 4])


tensor(1.)

In [None]:
class Scale1_Discriminator(nn.Module):
    def __init__(self, opt):
        super(scale_Discriminator, self).__init__()
        # Down sampling
        self.block1 = Conv2dLayer(1, opt.latent_channels, 4, 1, 3, pad_type = opt.pad_type, activation = opt.activation, norm = 'none', sn = True)
        self.block2 = Conv2dLayer(opt.latent_channels, opt.latent_channels * 2, 4, (2, 1), 1, pad_type = opt.pad_type, activation = opt.activation, norm = opt.norm, sn = True)
        self.block3 = Conv2dLayer(opt.latent_channels * 2, opt.latent_channels * 4, 4, (2, 1), 1, pad_type = opt.pad_type, activation = opt.activation, norm = opt.norm, sn = True)
        self.block4 = Conv2dLayer(opt.latent_channels * 4, opt.latent_channels * 4, 4, 2, 1, pad_type = opt.pad_type, activation = opt.activation, norm = opt.norm, sn = True)
        self.block5 = Conv2dLayer(opt.latent_channels * 4, opt.latent_channels * 4, 4, 2, 1, pad_type = opt.pad_type, activation = opt.activation, norm = opt.norm, sn = True)
        self.block6 = Conv2dLayer(opt.latent_channels * 4, 1, 4, 2, 1, pad_type = opt.pad_type, activation = 'none', norm = 'none', sn = True)

    
    def forward(self, img):
        # the input x should contain 4 channels because it is a combination of recon image and mask
        x = img
        x = self.block1(x)                                      # out: [B, 64, 256, 256]
        x = self.block2(x)                                      # out: [B, 128, 128, 128]
        x = self.block3(x)                                      # out: [B, 256, 64, 64]
        x = self.block4(x)                                      # out: [B, 256, 32, 32]
        x = self.block5(x)                                      # out: [B, 256, 16, 16]
        x = self.block6(x)                                      # out: [B, 256, 8, 8]
        return x

In [None]:
discriminator = scale_Discriminator(opt)
summary(discriminator, (4, 1, 1024, 428), device='cpu')

In [168]:
torch.randint(mask_start, mask_end, 1)

TypeError: randint() received an invalid combination of arguments - got (Tensor, Tensor, int), but expected one of:
 * (int high, tuple of ints size, *, torch.Generator generator, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool requires_grad)
 * (int low, int high, tuple of ints size, *, torch.Generator generator, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool requires_grad)


In [260]:
frame_length = 16
spec_start = mask_start
spec_start

tensor([ 10,  50, 100, 200])

In [258]:
index_matrix = torch.zeros([4, frame_length], requires_grad=False)
index_matrix += torch.arange(0, frame_length)
index_matrix += spec_start.unsqueeze(-1)
index_matrix = index_matrix.unsqueeze(1).unsqueeze(1)
index_matrix = index_matrix.expand(fake_data.shape[0], fake_data.shape[1], 128, index_matrix.shape[-1])
index_matrix = index_matrix.type(torch.int64)

fake_data = db_to_linear(fake_data, opt)
fake_data_pad = torch.nn.functional.pad(fake_data, (0, 0, 0, 1), mode='constant', value=0)
mel_fake_data = audio_utils.convert_mel_scale(fake_data_pad)
log_mel_fake_data = torch.log10(mel_fake_data+1e-7)
mel_spec_22050 = torch.gather(log_mel_fake_data, -1, index_matrix)


RuntimeError: The size of tensor a (16) must match the size of tensor b (4) at non-singleton dimension 1

In [265]:
index_matrix = torch.zeros([4, frame_length], requires_grad=False)
index_matrix += torch.arange(0, frame_length)
index_matrix += spec_start.unsqueeze(-1)
index_matrix.shape

torch.Size([4, 16])

In [264]:
spec_start

tensor([ 10,  50, 100, 200])

In [145]:
spec_end_start = mask_start + 0.5 * frame_length
spec_end_end = mask_end + 0.5 * frame_length
spec_end = torch.randint(spec_end_start, spec_end_end+1, 1)
spec_end = torch.min(spec_end, img.shape[-1])
spec_start = spec_end - frame_length
spec_start = torch.max(spec_start, 0)

120.67999999999999

In [266]:
5 / 471 * 16

0.16985138004246284

In [193]:
170ms, 340ms, 680ms, 1320ms

tensor([ 11,  67, 110, 281], dtype=torch.int32)

In [172]:
mask_start

tensor([ 10,  50, 100, 200])

In [173]:
mask_end

tensor([ 20,  80, 150, 400])

In [175]:
mask_region_length = mask_end - mask_start
mask_region_length

tensor([ 10,  30,  50, 200])

In [271]:
mask = torch.zeros([1, 1, 1025, 431])
mask[...,300:] = 1

In [281]:
randn = torch.randn(mask.shape) + 1
randn

tensor([[[[ 0.2558,  0.3840,  0.5906,  ...,  1.4871,  0.9725,  0.7602],
          [ 1.2828,  3.1549,  1.8883,  ...,  1.5925,  1.7560, -0.5502],
          [ 1.7750, -1.3085,  2.3714,  ...,  0.3994,  0.6844, -1.2932],
          ...,
          [ 0.9250,  0.8512, -0.6145,  ..., -0.1641,  1.2646,  0.6053],
          [ 1.1421,  0.0322,  2.3493,  ...,  0.3243,  0.6510,  0.7888],
          [ 1.0212,  2.9284,  1.3780,  ...,  0.1066,  1.8088,  2.3443]]]])

In [282]:
mask2 = mask*randn

In [283]:
mask2

tensor([[[[ 0.0000,  0.0000,  0.0000,  ...,  1.4871,  0.9725,  0.7602],
          [ 0.0000,  0.0000,  0.0000,  ...,  1.5925,  1.7560, -0.5502],
          [ 0.0000, -0.0000,  0.0000,  ...,  0.3994,  0.6844, -1.2932],
          ...,
          [ 0.0000,  0.0000, -0.0000,  ..., -0.1641,  1.2646,  0.6053],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.3243,  0.6510,  0.7888],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.1066,  1.8088,  2.3443]]]])

In [280]:
mask

tensor([[[[0., 0., 0.,  ..., 1., 1., 1.],
          [0., 0., 0.,  ..., 1., 1., 1.],
          [0., 0., 0.,  ..., 1., 1., 1.],
          ...,
          [0., 0., 0.,  ..., 1., 1., 1.],
          [0., 0., 0.,  ..., 1., 1., 1.],
          [0., 0., 0.,  ..., 1., 1., 1.]]]])

In [284]:
-0.0000 == 0

True