In [1]:
# Numerical and plotting libraries.
import numpy as np
import matplotlib.pyplot as plt

# PyTorch; float32 is set explicitly as the default floating-point dtype.
import torch
import torch.nn as nn
torch.set_default_dtype(torch.float32)

import sys
from jupyterthemes import jtplot

# Apply the jupyterthemes plotting style.
jtplot.style()
# Record the library and interpreter versions this notebook was run with.
print(torch.__version__)
print('Python ',sys.version_info)

# Debugger import kept for interactive use; no active code in the cells
# visible here calls it.
import pdb

1.9.0
Python  sys.version_info(major=3, minor=8, micro=11, releaselevel='final', serial=0)


In [2]:
# Number of time frames in the dummy input, and depth of the LSTM stack.
T = 1250
Dlstms = 2

# Random complex tensor of shape (1, 1, T, 161) used as a stand-in input.
testSpec = torch.randn((1, 1, T, 161), dtype=torch.complex64)

# Real-valued view of the same data: channel 0 holds the real part and
# channel 1 the imaginary part, giving shape (1, 2, T, 161).
testReIm = torch.cat((testSpec.real, testSpec.imag), dim=1)

In [3]:
class Encoder(nn.Module):
    """Five-stage convolutional encoder.

    Every stage is Conv2d -> BatchNorm2d -> ELU. The convolution uses a
    (1, 3) kernel with stride (1, 2), so it only mixes and downsamples the
    last axis; the third axis passes through with its length unchanged.
    Channel widths grow 2 -> 16 -> 32 -> 64 -> 128 -> 256.
    """

    # Channel count entering stage 1, followed by each stage's output width.
    _CHANNELS = (2, 16, 32, 64, 128, 256)

    def __init__(self):
        super().__init__()
        pairs = zip(self._CHANNELS[:-1], self._CHANNELS[1:])
        for idx, (n_in, n_out) in enumerate(pairs, start=1):
            # Registered as enc1 .. enc5, in that order, so parameter names
            # and initialisation order match a hand-written definition.
            setattr(self, "enc%d" % idx, self._make_stage(n_in, n_out))

    @staticmethod
    def _make_stage(n_in, n_out):
        """Build one Conv2d -> BatchNorm2d -> ELU stage."""
        return nn.Sequential(
            nn.Conv2d(in_channels=n_in, out_channels=n_out,
                      kernel_size=(1, 3), stride=(1, 2)),
            nn.BatchNorm2d(n_out, track_running_stats=True),
            nn.ELU(),
        )

    def forward(self, input):
        """Run all five stages.

        Args:
            input: tensor with 2 channels on dim 1 (4-D, as required by Conv2d).

        Returns:
            A pair ``(last, outputs)`` where ``outputs`` is the list of the
            five stage outputs in order and ``last`` is ``outputs[-1]``.
        """
        outputs = []
        current = input
        for stage in (self.enc1, self.enc2, self.enc3, self.enc4, self.enc5):
            current = stage(current)
            outputs.append(current)
        return current, outputs


encoder = Encoder()

In [7]:
class Decoder(nn.Module):
    """Five-stage transposed-convolution decoder with skip connections.

    Each stage receives the previous stage's output concatenated on the
    channel axis with one of the supplied skip tensors, which is why every
    stage's input width is twice the width of the tensor it continues from.
    Stages dec5 .. dec2 are ConvTranspose2d -> BatchNorm2d -> ELU; dec1 is a
    bare ConvTranspose2d producing a single channel.
    """

    def __init__(self):
        super().__init__()
        # In the notebook's test run the tensor entering dec5 (before the
        # skip concat) has shape (1, 256, 1250, 4).
        self.dec5 = self._make_stage(512, 128)
        self.dec4 = self._make_stage(256, 64)
        self.dec3 = self._make_stage(128, 32)
        # Only this stage pads the output by one extra column on the last axis.
        self.dec2 = self._make_stage(64, 16, output_padding=(0, 1))
        # Final projection: no normalisation or activation.
        self.dec1 = nn.Sequential(self._make_deconv(32, 1))

    @staticmethod
    def _make_deconv(n_in, n_out, output_padding=0):
        """Build the (1, 3)-kernel, stride-(1, 2) transposed convolution."""
        return nn.ConvTranspose2d(in_channels=n_in, out_channels=n_out,
                                  kernel_size=(1, 3), stride=(1, 2),
                                  output_padding=output_padding)

    @classmethod
    def _make_stage(cls, n_in, n_out, output_padding=0):
        """Build one ConvTranspose2d -> BatchNorm2d -> ELU stage."""
        return nn.Sequential(
            cls._make_deconv(n_in, n_out, output_padding),
            nn.BatchNorm2d(n_out, track_running_stats=True),
            nn.ELU(),
        )

    def forward(self, input, enc_list):
        """Decode ``input`` using the skip tensors in ``enc_list``.

        Args:
            input: tensor fed to the first stage (dec5).
            enc_list: sequence of skip tensors; stage k (counting from 1)
                uses ``enc_list[-k]``, so the last element pairs with dec5.

        Returns:
            The output of dec1 (one channel on dim 1).
        """
        stages = (self.dec5, self.dec4, self.dec3, self.dec2, self.dec1)
        current = input
        for depth, stage in enumerate(stages, start=1):
            current = stage(torch.cat([current, enc_list[-depth]], dim=1))
        return current


dec_Re = Decoder()
dec_Im = Decoder()

In [8]:
class Recurrent(nn.Module):
    """Stacked LSTM operating on 1024-dimensional feature vectors.

    The number of stacked layers is read from the notebook-level constant
    ``Dlstms``. Inputs are batch-first: ``(batch, seq, 1024)``.
    """

    def __init__(self):
        super().__init__()
        lstm_config = dict(input_size=1024, hidden_size=1024,
                           num_layers=Dlstms, batch_first=True)
        self.lstms = nn.LSTM(**lstm_config)

    def forward(self, x, state_in):
        """Run the LSTM stack.

        Args:
            x: batch-first input sequence with 1024 features per step.
            state_in: ``(h_0, c_0)`` initial state tuple.

        Returns:
            ``(output, (h_n, c_n))`` exactly as produced by ``nn.LSTM``.
        """
        return self.lstms(x, state_in)
recurrent = Recurrent()

def init_state(batch_size=1, num_layers=None, hidden_size=1024):
    """Create a random initial ``(h_0, c_0)`` state for the LSTM stack.

    Called with no arguments this behaves as before: two tensors of shape
    ``(Dlstms, 1, 1024)`` drawn from a standard normal distribution.

    Args:
        batch_size: size of the batch dimension (dim 1) of both tensors.
        num_layers: size of the layer dimension (dim 0); when ``None`` the
            notebook-level constant ``Dlstms`` is used. It is looked up at
            call time so later changes to ``Dlstms`` are honoured.
        hidden_size: size of the feature dimension (dim 2).

    Returns:
        Tuple ``(hn, cn)`` of tensors shaped
        ``(num_layers, batch_size, hidden_size)``.
    """
    if num_layers is None:
        num_layers = Dlstms
    # Draw hn first, then cn, so results for a given seed stay unchanged.
    hn = torch.randn(num_layers, batch_size, hidden_size)
    cn = torch.randn(num_layers, batch_size, hidden_size)
    return (hn, cn)

In [15]:
state0 = init_state()
e_vec, e_list = encoder(testReIm)
# Resize for RNN.
# e_vec is (1, 256, T, 4). The LSTM is batch-first and needs one
# 1024-feature vector (256 channels x 4 bins) per time frame, so the frame
# axis has to be moved in front of the channel axis before flattening.
# A plain .view(1, T, 1024) would only reinterpret memory and hand the LSTM
# chunks of a single channel spanning many frames.
center = e_vec.permute(0, 2, 1, 3).reshape(1, T, 1024)
rnn_out, state = recurrent(center, state0)
# Resize for Conv.
# Exact inverse of the rearrangement above: split the features back into
# (channels, bins) and restore the (batch, channels, frames, bins) layout.
out_rnn = rnn_out.reshape(1, T, 256, 4).permute(0, 2, 1, 3).contiguous()
yhat = torch.cat([dec_Re(out_rnn, e_list), dec_Im(out_rnn, e_list)], dim=1)

In [16]:
# Shape check: the recorded output is torch.Size([1, 2, 1250, 161]),
# i.e. the same shape as the input testReIm.
yhat.size()





torch.Size([1, 2, 1250, 161])