<a href="https://colab.research.google.com/github/Vaibhavs10/scratchpad/blob/main/UNet_error.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn

# Basic UNet works

In [None]:
class DiscriminatorNet(nn.Module):
    """U-Net style discriminator over single-channel 2-D inputs.

    The ``up`` path is a stack of weight-normalised 3x3 convolutions
    (stride 1, padding 2), so every layer grows each spatial dimension by 2
    while widening the channels 1 -> 256.  The ``down`` path mirrors it with
    transposed convolutions that shrink each spatial dimension by 2, and
    every ``down`` layer except the first receives its input concatenated
    (on the channel axis) with the matching ``up`` activation.

    Args:
        input_size: ``(height, width)`` of the tensors passed to ``forward``.
            It sizes the final linear layer, because the ``down`` path
            returns to the input resolution with a single channel.  The
            default ``(100, 80)`` gives the previous fixed width of 8000.
    """

    def __init__(self, input_size=(100, 80)):
        super().__init__()
        self.up = nn.ModuleList([
            nn.utils.weight_norm(nn.Conv2d(1, 8, kernel_size=3, stride=1, padding=2)),
            nn.utils.weight_norm(nn.Conv2d(8, 16, kernel_size=3, stride=1, padding=2)),
            nn.utils.weight_norm(nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=2)),
            nn.utils.weight_norm(nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=2)),
            nn.utils.weight_norm(nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=2)),
            nn.utils.weight_norm(nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=2)),
        ])

        # 1-channel projection of the deepest ``up`` activation (feature map).
        self.up_out = nn.utils.weight_norm(nn.Conv2d(256, 1, kernel_size=3, stride=1))

        # Input channels from the second layer on are doubled because of the
        # skip concatenation (e.g. 128 decoder + 128 skip = 256).
        self.down = nn.ModuleList([
            nn.utils.weight_norm(nn.ConvTranspose2d(256, 128, kernel_size=3, stride=1, padding=2)),
            nn.utils.weight_norm(nn.ConvTranspose2d(256, 64, kernel_size=3, stride=1, padding=2)),
            nn.utils.weight_norm(nn.ConvTranspose2d(128, 32, kernel_size=3, stride=1, padding=2)),
            nn.utils.weight_norm(nn.ConvTranspose2d(64, 16, kernel_size=3, stride=1, padding=2)),
            nn.utils.weight_norm(nn.ConvTranspose2d(32, 8, kernel_size=3, stride=1, padding=2)),
            nn.utils.weight_norm(nn.ConvTranspose2d(16, 1, kernel_size=3, stride=1, padding=2)),
        ])

        # 1-channel projection of the final ``down`` activation (feature map).
        self.down_out = nn.utils.weight_norm(nn.Conv2d(1, 1, kernel_size=3, stride=1))

        # The last ``down`` activation has one channel at the input
        # resolution, so its flattened length is height * width.
        height, width = input_size
        self.fc = nn.Linear(height * width, 1)

    def forward(self, y):
        """Score a batch and collect intermediate feature maps.

        Args:
            y: Tensor of shape ``(batch, 1, height, width)`` where
                ``(height, width)`` equals the ``input_size`` given to the
                constructor.

        Returns:
            A tuple ``(score, feature_maps)``: ``score`` has shape
            ``(batch, 1)`` and ``feature_maps`` is the two-element list
            ``[up_out(...), down_out(...)]``.
        """
        skips = []
        feature_maps = []
        for d_up in self.up:
            y = nn.functional.leaky_relu(d_up(y), 0.2)
            skips.append(y)

        feature_maps.append(self.up_out(y))

        for depth, d_down in enumerate(self.down):
            if depth > 0:
                # skips[-1] is ``y`` itself at depth 0, so the k-th decoder
                # layer pairs with the (k + 1)-th activation from the end.
                y = torch.cat((y, skips[-1 - depth]), dim=1)
            y = nn.functional.leaky_relu(d_down(y), 0.2)

        feature_maps.append(self.down_out(y))

        y = torch.flatten(y, 1, -1)
        y = self.fc(y)

        return y, feature_maps

In [None]:
# Random stand-ins for spectrogram batches, laid out as
# [Batch, Sequence Length, Spectrogram Buckets].
fake = torch.randn(2, 100, 80)
real = torch.randn(2, 100, 80)

d = DiscriminatorNet()
# The first Conv2d takes one input channel, so add a channel axis at dim 1.
y, fmaps = d(fake[:, None])

# Updating the stride + kernel size breaks the skip connections unless the decoder and skip shapes are realigned

In [None]:
class DiscriminatorNet(nn.Module):
    """U-Net style discriminator with strided convolutions.

    The ``up`` path alternates stride-1 and stride-2 weight-normalised 3x3
    convolutions (padding 2) while widening the channels 1 -> 256.  The
    ``down`` path mirrors it with transposed convolutions, and every
    ``down`` layer except the first receives its input concatenated (on the
    channel axis) with the matching ``up`` activation.

    A stride-2 convolution maps two neighbouring input sizes to the same
    output size, so the stride-2 transposed convolution cannot always
    reproduce the size of the skip tensor (e.g. 30x26 versus 30x25 for a
    100x80 input).  ``forward`` therefore crops/pads the decoder tensor to
    the skip tensor's spatial size before concatenating.

    Args:
        input_size: ``(height, width)`` of the tensors passed to ``forward``.
            It sizes the final linear layer, because the ``down`` path
            returns to the input resolution with a single channel.  The
            default ``(100, 80)`` gives the previous fixed width of 8000.
    """

    def __init__(self, input_size=(100, 80)):
        super().__init__()
        self.up = nn.ModuleList([
            nn.utils.weight_norm(nn.Conv2d(1, 8, kernel_size=3, stride=1, padding=2)),
            nn.utils.weight_norm(nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=2)),
            nn.utils.weight_norm(nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=2)),
            nn.utils.weight_norm(nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=2)),
            nn.utils.weight_norm(nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=2)),
            nn.utils.weight_norm(nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=2)),
        ])

        # 1-channel projection of the deepest ``up`` activation (feature map).
        self.up_out = nn.utils.weight_norm(nn.Conv2d(256, 1, kernel_size=3, stride=1))

        # Input channels from the second layer on are doubled because of the
        # skip concatenation (e.g. 128 decoder + 128 skip = 256).
        self.down = nn.ModuleList([
            nn.utils.weight_norm(nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=2)),
            nn.utils.weight_norm(nn.ConvTranspose2d(256, 64, kernel_size=3, stride=1, padding=2)),
            nn.utils.weight_norm(nn.ConvTranspose2d(128, 32, kernel_size=4, stride=2, padding=2)),
            nn.utils.weight_norm(nn.ConvTranspose2d(64, 16, kernel_size=3, stride=1, padding=2)),
            nn.utils.weight_norm(nn.ConvTranspose2d(32, 8, kernel_size=4, stride=2, padding=2)),
            nn.utils.weight_norm(nn.ConvTranspose2d(16, 1, kernel_size=3, stride=1, padding=2)),
        ])

        # 1-channel projection of the final ``down`` activation (feature map).
        self.down_out = nn.utils.weight_norm(nn.Conv2d(1, 1, kernel_size=3, stride=1))

        # The last ``down`` activation has one channel at the input
        # resolution, so its flattened length is height * width.
        height, width = input_size
        self.fc = nn.Linear(height * width, 1)

    @staticmethod
    def _match_spatial(y, ref):
        """Crop or zero-pad the last two dims of ``y`` to those of ``ref``."""
        target_h, target_w = ref.shape[-2], ref.shape[-1]
        pad_h = max(target_h - y.shape[-2], 0)
        pad_w = max(target_w - y.shape[-1], 0)
        if pad_h or pad_w:
            # F.pad lists the last dimension first: (left, right, top, bottom).
            y = nn.functional.pad(y, (0, pad_w, 0, pad_h))
        return y[..., :target_h, :target_w]

    def forward(self, y):
        """Score a batch and collect intermediate feature maps.

        Args:
            y: Tensor of shape ``(batch, 1, height, width)`` where
                ``(height, width)`` equals the ``input_size`` given to the
                constructor.

        Returns:
            A tuple ``(score, feature_maps)``: ``score`` has shape
            ``(batch, 1)`` and ``feature_maps`` is the two-element list
            ``[up_out(...), down_out(...)]``.
        """
        skips = []
        feature_maps = []
        for d_up in self.up:
            y = nn.functional.leaky_relu(d_up(y), 0.2)
            skips.append(y)

        feature_maps.append(self.up_out(y))

        for depth, d_down in enumerate(self.down):
            if depth > 0:
                # skips[-1] is ``y`` itself at depth 0, so the k-th decoder
                # layer pairs with the (k + 1)-th activation from the end.
                skip = skips[-1 - depth]
                # torch.cat needs identical spatial sizes; the stride-2
                # transposed conv can overshoot the skip tensor by one.
                y = torch.cat((self._match_spatial(y, skip), skip), dim=1)
            y = nn.functional.leaky_relu(d_down(y), 0.2)

        feature_maps.append(self.down_out(y))

        y = torch.flatten(y, 1, -1)
        y = self.fc(y)

        return y, feature_maps

In [None]:
# Random stand-ins for spectrogram batches, laid out as
# [Batch, Sequence Length, Spectrogram Buckets].
fake = torch.randn(2, 100, 80)
real = torch.randn(2, 100, 80)

d = DiscriminatorNet()
# The first Conv2d takes one input channel, so add a channel axis at dim 1.
y, fmaps = d(fake[:, None])