In [1]:
import torch 
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader , Dataset
import torchvision.datasets as datasets 
import torchvision.transforms as transforms

In [2]:
class basic_NN(nn.Module):
    """Small MLP regressor: Linear -> ReLU -> BatchNorm1d -> Linear(hidden_size, 1).

    Args:
        input_size: number of input features per sample.
        hidden_size: width of the single hidden layer.

    Used below with inputs such as (batch, input_size) and produces (batch, 1).
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        # The position of each layer inside the Sequential fixes the state_dict
        # keys (layers.0 ... layers.3); keep the order stable for checkpoints.
        stack = [
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_size),
            nn.Linear(hidden_size, 1),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the layer stack to ``x`` and return the single-column output."""
        out = self.layers(x)
        return out
# Smoke test: a (512, 200) batch should yield one prediction per row.
# Fall back to CPU so the notebook also runs on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
check_model = basic_NN(200, 64).to(device)

check_out = check_model(torch.randn(512, 200).to(device))
assert check_out.shape == (512, 1)
check_out.shape

torch.Size([512, 1])

In [3]:
class improved_NN(nn.Module):
    """Per-feature embedding network with three parallel branches of different widths.

    In each branch, every scalar input feature is projected by the same
    ``nn.Linear(1, width)`` (the weights are shared across features), passed
    through ReLU, flattened back to one row per sample and batch-normalised.
    The three branch outputs are concatenated, passed through dropout and
    mapped to a single output.

    Args:
        input_size: number of input features per sample.
        hidden_size: embedding width of branch 1.
        hidden2: embedding width of branch 2.
        hidden3: embedding width of branch 3.
        dropout: dropout probability applied before the final layer
            (default 0.5, the value that used to be hard-coded).

    Shapes: (N, input_size) -> (N, 1).
    """

    def __init__(self, input_size, hidden_size, hidden2, hidden3, dropout=0.5):
        super().__init__()
        # Attribute names and creation order are unchanged so state_dict keys and
        # seeded parameter initialisation match the previous version of this class.
        self.fc11 = nn.Linear(1, hidden_size)
        self.relu = nn.ReLU()
        self.bn_input = nn.BatchNorm1d(input_size)
        self.bn_hidden1 = nn.BatchNorm1d(input_size * hidden_size)

        self.fc21 = nn.Linear(1, hidden2)
        self.bn_hidden2 = nn.BatchNorm1d(input_size * hidden2)

        self.fc31 = nn.Linear(1, hidden3)
        self.bn_hidden3 = nn.BatchNorm1d(input_size * hidden3)

        self.all = nn.Linear(input_size * (hidden_size + hidden2 + hidden3), 1)
        self.dropout = nn.Dropout(dropout)

    def _branch(self, x, fc, bn):
        """Embed every feature of ``x`` with ``fc`` and batch-normalise per sample.

        Shapes: (N, F) -view-> (N*F, 1) -fc+ReLU-> (N*F, width)
        -reshape-> (N, F*width) -bn-> (N, F*width).
        The row-major reshape keeps each sample's F*width values in its own row,
        so samples are never mixed.
        """
        n_batch = x.shape[0]
        emb = self.relu(fc(x.view(-1, 1)))
        return bn(emb.reshape(n_batch, -1))

    def forward(self, x):
        """Map ``x`` of shape (N, input_size) to predictions of shape (N, 1)."""
        x = self.bn_input(x)
        features = torch.cat(
            (
                self._branch(x, self.fc11, self.bn_hidden1),
                self._branch(x, self.fc21, self.bn_hidden2),
                self._branch(x, self.fc31, self.bn_hidden3),
            ),
            1,
        )  # (N, input_size * (hidden_size + hidden2 + hidden3))
        return self.all(self.dropout(features))

# Smoke test for improved_NN; fall back to CPU when no GPU is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
check_model = improved_NN(200, 8, 16, 32).to(device)
check_out = check_model(torch.randn(512, 200).to(device))
assert check_out.shape == (512, 1)
check_out.shape

torch.Size([512, 1])

In [14]:
class modify_Improved_NN(nn.Module):
    """Pairs column i with column i + input_size//2 and embeds each pair jointly.

    Presumably the first half of the columns holds the original features and the
    second half the generated ones (per the original variable names; confirm
    against the data pipeline). Each (first-half, second-half) pair is projected
    by the same ``nn.Linear(2, hidden_size)``, passed through ReLU, flattened per
    sample, batch-normalised, passed through dropout and mapped to one output.

    Args:
        input_size: total number of input columns; must be even.
        hidden_size: embedding width per feature pair.
        hidden2, hidden3: widths of the branch-2/3 modules. These modules are
            currently NOT used in ``forward``.
        dropout: dropout probability before the final layer (default 0.22).

    Raises:
        ValueError: if ``input_size`` is odd (the halves could not be paired).

    Shapes: (N, input_size) -> (N, 1).
    """

    def __init__(self, input_size, hidden_size, hidden2, hidden3, dropout=0.22):
        super().__init__()
        if input_size % 2:
            raise ValueError(f"input_size must be even, got {input_size}")
        n_pairs = input_size // 2

        self.fc11 = nn.Linear(2, hidden_size)
        self.bn_hidden1 = nn.BatchNorm1d(n_pairs * hidden_size)

        # NOTE: fc21/bn_hidden2/fc31/bn_hidden3 are unused by forward(). They are
        # kept so existing state_dicts still load strictly and seeded initialisation
        # of the used layers stays the same.
        self.fc21 = nn.Linear(2, hidden2)
        self.bn_hidden2 = nn.BatchNorm1d(n_pairs * hidden2)

        self.fc31 = nn.Linear(2, hidden3)
        self.bn_hidden3 = nn.BatchNorm1d(n_pairs * hidden3)

        self.all = nn.Linear(n_pairs * hidden_size, 1)
        self.relu = nn.ReLU()
        self.bn_input = nn.BatchNorm1d(input_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Map ``x`` of shape (N, input_size) to predictions of shape (N, 1)."""
        n_batch = x.shape[0]
        # Derive the split point from the configured width instead of a fixed 200,
        # so any even input_size works (previously only input_size == 400 did).
        n_pairs = self.bn_input.num_features // 2
        x = self.bn_input(x)
        # (N, 2*P) -> (N, P, 2): column i is paired with column i + P.
        pairs = torch.stack((x[:, :n_pairs], x[:, n_pairs:]), dim=2)
        h = self.relu(self.fc11(pairs))          # (N, P, hidden_size)
        h = h.reshape(n_batch, -1)               # (N, P * hidden_size)
        h = self.bn_hidden1(h)
        return self.all(self.dropout(h))

# Smoke test for modify_Improved_NN; fall back to CPU when no GPU is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
check_model = modify_Improved_NN(400, 2, 4, 6).to(device)
check_out = check_model(torch.randn(512, 400).to(device))
assert check_out.shape == (512, 1)
check_out.shape

torch.Size([512, 1])