# Create a model architecture based on Ross et al. (2018)

<div>
<img src="model_architecture.png" width="500"/>
</div>


In [67]:
import torch
import torch.nn as nn
from torch.nn import Module

In [68]:
class PolarityPicker(Module):
    """1-D CNN that maps a single-channel waveform window to 3 class scores.

    Modelled on the architecture of Ross et al. (2018):

    * three feature stages, each Conv1d -> BatchNorm1d -> ReLU -> MaxPool1d(2),
      with 32, 64 and 128 filters and kernel sizes 21, 15 and 11;
    * two 512-unit Linear -> BatchNorm1d -> ReLU stages;
    * a final Linear layer producing 3 raw scores (no softmax is applied, so
      pair the output with a logits-based loss such as ``nn.CrossEntropyLoss``).

    Args:
        input_length: number of samples per input window. The default of 1000
            reproduces the originally hard-coded layer sizes
            (``fc1.in_features == 114`` and ``fc3.in_features == 65536``).

    Raises:
        ValueError: if ``input_length`` is too short to survive the three
            convolution + pooling stages.
    """

    def __init__(self, input_length=1000):
        super().__init__()
        # Layers are registered in the same order as before so that parameter
        # initialisation order (and state_dict layout) is unchanged.
        self.cv1 = nn.Conv1d(in_channels=1, out_channels=32, kernel_size=21)
        self.cv2 = nn.Conv1d(in_channels=32, out_channels=64, kernel_size=15)
        self.cv3 = nn.Conv1d(in_channels=64, out_channels=128, kernel_size=11)
        self.bn1 = nn.BatchNorm1d(32)
        self.bn2 = nn.BatchNorm1d(64)
        self.bn3 = nn.BatchNorm1d(128)
        self.bn4 = nn.BatchNorm1d(128)
        self.bn5 = nn.BatchNorm1d(128)
        self.pooling = nn.MaxPool1d(kernel_size=2)

        # Time length left after each "valid" convolution (stride 1, no padding)
        # followed by a stride-2 max-pool. Once a stage yields <= 0 every later
        # stage stays <= 0, so checking the final value is sufficient.
        time_len = input_length
        for conv in (self.cv1, self.cv2, self.cv3):
            time_len = (time_len - conv.kernel_size[0] + 1) // 2
        if time_len < 1:
            raise ValueError(
                f"input_length={input_length} is too short for the "
                "convolution/pooling stack"
            )

        # NOTE(review): fc1/fc2 act on the last (time) axis of the
        # (batch, 128, time_len) feature map, independently for each of the 128
        # channels; only fc3 sees the flattened features. Confirm this matches
        # the intended architecture (a conventional design flattens first).
        self.fc1 = nn.Linear(in_features=time_len, out_features=512)
        self.fc2 = nn.Linear(in_features=512, out_features=512)
        self.fc3 = nn.Linear(in_features=128 * 512, out_features=3)
        self.ft = nn.Flatten()
        # Without a non-linearity the fc1 -> bn4 -> fc2 -> bn5 -> fc3 stack is a
        # single affine map (in eval mode) and the conv stages are nearly linear.
        self.act = nn.ReLU()

    def forward(self, data):
        """Return class scores of shape (batch, 3).

        Args:
            data: tensor of shape (batch, 1, input_length); the single channel
                is required by ``cv1`` and the length by ``fc1``.
        """
        x = self.pooling(self.act(self.bn1(self.cv1(data))))
        x = self.pooling(self.act(self.bn2(self.cv2(x))))
        x = self.pooling(self.act(self.bn3(self.cv3(x))))
        x = self.act(self.bn4(self.fc1(x)))
        x = self.act(self.bn5(self.fc2(x)))
        x = self.ft(x)
        return self.fc3(x)

In [69]:
# Build an untrained network instance using the default layer configuration.
model: Module = PolarityPicker()

In [70]:
# Smoke test: push one random single-channel, 1000-sample window through the
# network (`model`, not the undefined `polaritypicker`); the result is a
# (1, 3) tensor of raw class scores.
model(torch.rand([1, 1, 1000]))

tensor([[ 0.4108, -0.1912, -0.2480]], grad_fn=<AddmmBackward0>)

# Import the data

In [10]:
import h5py    
import numpy as np    
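# NOTE: h5py.File expects a local path (or a Python file-like object), not a web
# URL; a Google Drive share link points to a viewer page, so download the HDF5
# file first and open the local copy.
#f1 = h5py.File('https://drive.google.com/file/d/1cuB4YukRvZBM3Ehv18WajsM3Xd_jppSs/view?usp=share_link','r+')  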