In [666]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
%config Completer.use_jedi = False

In [904]:
class Small_Block_bn(nn.Module):
    def __init__(self,r,k,in_ch,drop_rate):
        super().__init__()
       
        self.conv1=nn.Conv1d(int(in_ch),int(in_ch/r/k),1,stride=1)
        self.bn1=nn.BatchNorm1d(int(in_ch/r/k))
        self.drop=nn.Dropout(drop_rate)
        
        self.conv2=nn.Conv1d(int(in_ch/r/k),int(in_ch/k),3,stride=1,padding=1)
        self.bn2=nn.BatchNorm1d(int(in_ch/k))
        self.relu=nn.ReLU()
        
        
    def forward(self, x):
        x=self.conv1(x)
        #print("The shape of data aftr 1st conv should be [1,c'/r/k,w] :{}".format(x.shape))
        x=self.bn1(x)
        x=self.relu(x)
        x=self.drop(x)
        
        x=self.conv2(x)
        x=self.bn2(x)
        x=self.relu(x)
        x=self.drop(x)
        #print("The shape of data aftr 2nd conv should be [1,c'/k,w] :{}".format(x.shape))
        #print("#############################################")
        return x
        
        

In [905]:
class Small_Block(nn.Module):
    def __init__(self,r,k,in_ch,drop_rate):
        super().__init__()
       
        self.conv1=nn.Conv1d(int(in_ch),int(in_ch/r/k),1,stride=1)
        self.drop=nn.Dropout(drop_rate)
        
        self.conv2=nn.Conv1d(int(in_ch/r/k),int(in_ch/k),3,stride=1,padding=1)
        self.relu=nn.ReLU()
        
        
    def forward(self, x):
        x=self.conv1(x)
        x=self.relu(x)
        x=self.drop(x)
        
        x=self.conv2(x)
        x=self.relu(x)
        x=self.drop(x)
        return x
        

In [906]:
class Cardinal(nn.Module):
    def __init__(self, r, k, in_ch, drop, up, radix=Small_Block_bn, radixUp=Small_Block):
        super().__init__()
        if up==False:
            self.__radix_block=radixUp #small block
        else:
            self.__radix_block=radix #small block
        self.k=k #number of cardinals  
        self.r=r #number of radixes 
        self.in_ch=in_ch
        self.drop_rate=drop
        self.radix=self.__build_radix() #sends to build layers
       
        
    def forward(self,x):
        #concat tensors for every k radixs'
        radix_current=[]
        for i, l in enumerate(self.radix):
            radix_current.append(l(x))
        radix_concat = torch.cat(radix_current, dim=1)
        
        return radix_concat
            
    def __build_radix(self):
        rdx=[]
      
        for i in range(self.k):
            rdx.append(
            self.__radix_block(self.r,self.k,self.in_ch,self.drop_rate)
            )
        out = nn.ModuleList(rdx)
            
        return nn.Sequential(*out)            

In [907]:
class ResNeSt(nn.Module):
    def __init__(self, r, k, stride,drop_rate, input_channels, ch_out, up, cardinal_block=Cardinal):
        super().__init__()
        self.cardinal_block=cardinal_block
        #self.split=__build_split(input_channels)
        self.in_ch=input_channels
        self.drop_rate=drop_rate
        
        self.r=r
        self.k=k
        self.up=up
        self.pool=nn.MaxPool1d(stride=2,kernel_size=2)
        self.pool_non=nn.MaxPool1d(stride=1,kernel_size=1)
        self.upsample=nn.ConvTranspose1d(ch_out, ch_out, kernel_size=2, stride=2)
       
        self.conv1=nn.Conv1d(
        in_channels=input_channels,
        out_channels=input_channels,
        kernel_size=1,
        )  
        self.conv2=nn.Conv1d(
        in_channels=input_channels,
        out_channels=ch_out,
        kernel_size=1,
        )
        
        self.cardinal=self.build_cardinal_block(up)
        
        if up == True:
            self.resenst_block = nn.Sequential(*[self.cardinal,self.conv2,self.pool])
        elif up == False:
            self.resenst_block = nn.Sequential(*[self.cardinal, self.conv2, self.upsample])
        else: #None
            self.resenst_block = nn.Sequential(*[self.cardinal, self.conv2, self.pool_non])

            
        self.global_pool=nn.AdaptiveAvgPool1d(int(1))
        self.dence1=nn.Linear(int(input_channels),int(input_channels/2))
        self.dence2=nn.Linear(int(input_channels/2),int(input_channels))
        self.relu=nn.ReLU()
        self.drop=nn.Dropout(0.1)
        self.softmax=nn.Softmax(dim=2)
        
        
    def forward(self,x):
        transformant=0
        #add every r cardinal
        residual=[]
        for i in range(self.r):
            r_block=self.resenst_block[0][i](x)
            transformant=transformant+r_block
            #print(transformant.shape)
            residual.append(r_block)
        
###############stuff here 

        vector=self.global_pool(transformant).mean(2) #shapes [1,8,1] => [1,8]
        branch=self.dence1(vector) 
        branch=self.relu(branch)
        branch=self.drop(branch)
        ##need batch norm
        dout=[]
        branch.shape
        for t in range(self.r):
            variable=self.dence2(branch)
            variable=self.relu(variable)
            variable=self.drop(variable)
            variable=variable.unsqueeze(dim=2)
            dout.append(variable)
        variable = torch.cat(dout, dim=0)#<2  #torch.Size([2, 4])

        branch=self.softmax(variable)
        #branch=branch.unsqueeze(dim=2) #shapes [1,8] => [1,8,1]
        
        trans=0
        for t in range(self.r):
            trans=trans+residual[t]*branch[t].unsqueeze(dim=0)  #shapes [1,8,3000] * [1,8,1]
            
        #torch.Size([1, 4...16,32,64])
        trans=self.conv1(trans)
        #print("shape of transformant {}, shape of branch {}".format(transformant.shape,x.shape))
        trans=trans+x #adding from x
        trans=self.resenst_block[1](trans)    
        trans=self.resenst_block[2](trans)    
        print(trans.shape)
        return trans
        
    
    def build_cardinal_block(self,up):
        crb=[]
        
        for i in range (self.r):
            crb.append(
            self.cardinal_block(self.r,self.k,self.in_ch,self.drop_rate,up)
            )
        out = nn.ModuleList(crb)
        return nn.Sequential(*out)     

In [908]:
m = nn.Softmax(dim=1)
input = torch.randn(2, 3)
output = m(input)
print(input)
print(output)

tensor([[ 0.4432, -0.2075, -0.9288],
        [-0.7698,  0.6224, -0.5034]])
tensor([[0.5633, 0.2939, 0.1428],
        [0.1580, 0.6358, 0.2062]])


In [909]:
class AutoEncoder(nn.Module):
    def __init__(self,hparams, resnest=ResNeSt):
        super().__init__()
        self.resnest_block=resnest
        self.hparams=hparams
        self.r=hparams['r_k'][0]
        self.k=hparams['r_k'][1]
        self.stride=hparams["stride"]
        self.drop_rate=self.hparams['dropout_rate']
        self.encoder=self.build_encoder()
        self.decoder=self.build_decoder()
        shape_out = self.hparams['n_samples']
        
        
        for i in range(len(self.hparams['layer_feature_maps'])-1):
            shape_out //= self.hparams['pool_size']

        for i in range(len(self.hparams['layer_feature_maps'])-1):
            shape_out *= self.hparams['pool_size']
        
        self.__out_cnn = nn.Conv1d(
            in_channels=self.hparams["layer_feature_maps"][0],
            out_channels=self.hparams["n_channels"],
            kernel_size=self.hparams["kernel_size"],
            padding=(self.hparams['n_samples'] - shape_out) // 2,
            dilation=1,
            stride=1,
            bias=False,
        )
    
    def forward(self,x):
        x=self.encoder(x)
        x=self.decoder(x)
        x=self.__out_cnn(x)
        return x  
        
        
    def build_encoder(self):
        resnest=[]
        resnest.append(self.resnest_block(self.r,self.k, self.stride, self.drop_rate, self.hparams["n_channels"],self.hparams["layer_feature_maps"][0],None))
        for i in range(len(self.hparams["layer_feature_maps"])-1):
            resnest.append(self.resnest_block(self.r,self.k,self.stride,self.drop_rate, self.hparams["layer_feature_maps"][i],self.hparams["layer_feature_maps"][i+1],True))
        
        return nn.Sequential(*resnest)  
    
        
    def build_decoder(self):
        resnest=[]

        lfm=self.hparams['layer_feature_maps']
        lfm.reverse()

        resnest.append(self.resnest_block(self.r, self.k,self.stride,self.drop_rate, self.hparams["layer_feature_maps"][0],self.hparams["layer_feature_maps"][0],None))
        
        for i in range(len(self.hparams["layer_feature_maps"])-1):
            resnest.append(self.resnest_block(self.r, self.k,self.stride,self.drop_rate, self.hparams["layer_feature_maps"][i],self.hparams["layer_feature_maps"][i+1],False))
        
        lfm.reverse()
        return nn.Sequential(*resnest) 
        
        

In [910]:
hparams={"kernel_size":1,"pool_size":2,"r_k":[2,2],"n_channels":4,"layer_feature_maps":[16,32,64,128,256],"dropout_rate":0.1,'n_samples':3000,"stride":2}
#n_channels=c_prime                            #3,2               30
num_ch=hparams["n_channels"]; 
num_samp=hparams["n_samples"]
inp=10+np.random.randint(0,20,(num_ch,num_samp))/5
x=torch.tensor([inp]) 
x = x.type(torch.float)

#net=ResNeSt(hparams,30,20)
#out=net.forward(x) #torch.Size([1, 30, 300])

net=AutoEncoder(hparams)
out=net.forward(x) #torch.Size([1, 30, 300])

torch.Size([1, 16, 3000])
torch.Size([1, 32, 1500])
torch.Size([1, 64, 750])
torch.Size([1, 128, 375])
torch.Size([1, 256, 187])
torch.Size([1, 256, 187])
torch.Size([1, 128, 374])
torch.Size([1, 64, 748])
torch.Size([1, 32, 1496])
torch.Size([1, 16, 2992])


In [897]:
out.shape

torch.Size([1, 4, 3000])

In [696]:
m = nn.AdaptiveAvgPool1d(1)
input1 = torch.randn(1, 2, 3)
print(input1.shape)
print(input1)
branch=torch.randn(1, 2)
branch=branch.unsqueeze(dim=2)
print(branch.shape)
print(branch)
input1*branch
#input1=input1.squeeze(dim=[1,1])
#input1=input1.mean(2)
#print(input1.shape)
#print(input1[0])
#output = m(input1)

#print(output.shape)
#print(output)

torch.Size([1, 2, 3])
tensor([[[ 0.5327, -0.8996, -0.0028],
         [-1.1559, -0.0129, -1.7405]]])
torch.Size([1, 2, 1])
tensor([[[0.8170],
         [1.3763]]])


tensor([[[ 4.3518e-01, -7.3494e-01, -2.2482e-03],
         [-1.5908e+00, -1.7687e-02, -2.3955e+00]]])