In [41]:
# Import libraries

import torch
from torch import nn
from torch.functional import F
import numpy as np
import swyft.lightning as sl
from toolz.dicttoolz import valmap
from sklearn.metrics import roc_curve, auc


In [42]:
# 1D U-Net with attention-gated skip connections: building blocks first, then the assembled Unet

class DoubleConv(nn.Module):
    """Two stacked ``Conv1d -> BatchNorm1d -> ReLU`` stages.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the second convolution.
        kernel_size: kernel size shared by both convolutions.
        mid_channels: channels between the two stages; any falsy value
            (``None`` or ``0``) falls back to ``out_channels``.
        padding: padding shared by both convolutions (the default 1 with
            kernel 3 keeps the sequence length unchanged).
        bias: whether the convolutions carry a bias term (BatchNorm follows
            each one, so the default is ``False``).
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        mid_channels=None,
        padding=1,
        bias=False,
    ):
        super().__init__()
        hidden = mid_channels or out_channels
        conv_kwargs = dict(kernel_size=kernel_size, padding=padding, bias=bias)
        layers = [
            nn.Conv1d(in_channels, hidden, **conv_kwargs),
            nn.BatchNorm1d(hidden),
            nn.ReLU(inplace=True),
            nn.Conv1d(hidden, out_channels, **conv_kwargs),
            nn.BatchNorm1d(out_channels),
            nn.ReLU(inplace=True),
        ]
        # Attribute name and layer order are kept fixed: they define state_dict keys.
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply both conv/BN/ReLU stages to ``x``."""
        features = self.double_conv(x)
        return features


class Down(nn.Module):
    """Encoder stage: ``MaxPool1d(down_sampling)`` followed by a ``DoubleConv``.

    Args:
        in_channels: channels entering the stage.
        out_channels: channels leaving the stage.
        down_sampling: max-pool kernel size (and, by PyTorch default, stride)
            along the last axis.
    """

    def __init__(self, in_channels, out_channels, down_sampling=2):
        super().__init__()
        pool = nn.MaxPool1d(down_sampling)
        convs = DoubleConv(in_channels, out_channels)
        # Index 0 = pool, index 1 = convs; kept stable for state_dict compatibility.
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        """Pool ``x`` along its last axis, then convolve."""
        pooled_features = self.maxpool_conv(x)
        return pooled_features

class Up(nn.Module):
    """Decoder stage: upsample the coarser features, gate the skip connection, fuse.

    Args:
        in_channels: channels of the coarser decoder input ``g``.
        out_channels: channels produced by this stage. The skip connection
            ``x`` passed to ``forward`` must also have this many channels,
            because the attention gate projects it from ``out_channels``.
        scale_factor: upsampling factor applied to ``g`` along the last axis
            (normally the ``down_sampling`` factor of the matching ``Down``).
        kernel_size, stride: unused. The upsampling conv is fixed at
            kernel 3 / stride 1 / padding 1, which preserves length. These
            parameters are kept so existing call sites that pass them still work.
    """

    def __init__(self, in_channels, out_channels, scale_factor, kernel_size=3, stride=2):
        super().__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=scale_factor),
            nn.ConvTranspose1d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm1d(out_channels),
            nn.ReLU(inplace=True),
        )
        self.att = AttentionGate(out_channels, out_channels // 2)
        # forward() concatenates the gated skip (out_channels) with the upsampled
        # gate (out_channels), so the fusing conv sees 2 * out_channels inputs.
        # Using in_channels here only worked when in_channels == 2 * out_channels.
        self.conv = DoubleConv(2 * out_channels, out_channels)

    @staticmethod
    def _match_length(t, length):
        """Zero-pad (or crop, if too long) ``t`` symmetrically on its last axis to ``length``."""
        diff = length - t.size(-1)
        if diff == 0:
            return t
        # Negative pad amounts crop, so this also covers an over-long ``t``.
        return nn.functional.pad(t, [diff // 2, diff - diff // 2])

    def forward(self, g, x):
        """Fuse the coarser features ``g`` with the skip connection ``x``.

        Args:
            g: gate / coarser decoder features with ``in_channels`` channels.
            x: skip connection with ``out_channels`` channels.

        Returns:
            Tensor with ``out_channels`` channels and the same length as ``x``.
        """
        # Max-pooling floors odd lengths, so the upsampled gate can come out
        # shorter than the skip; align it before the attention gate adds them.
        upsampled = self._match_length(self.up(g), x.size(-1))
        gated_skip = self.att(upsampled, x)
        return self.conv(torch.cat([gated_skip, upsampled], dim=1))


class AttentionGate(nn.Module):
    """Additive attention gate that rescales a skip connection.

    The gate ``g`` and the skip ``x`` are each projected to ``out_channels``
    by a 1x1 conv + BatchNorm. The projections are summed, passed through
    ReLU, and reduced by ``psi`` to a single sigmoid channel. That mask
    (values in [0, 1]) multiplies every channel of ``x``.

    Args:
        in_channels: channel count of both ``g`` and ``x``.
        out_channels: width of the intermediate projection.
        bias: whether the 1x1 convolutions carry a bias term.
    """

    def __init__(self, in_channels, out_channels, bias=False):
        super().__init__()

        def pointwise(c_in, c_out):
            # 1x1 convolution (stride 1, no padding) followed by batch norm.
            return [
                nn.Conv1d(c_in, c_out, kernel_size=1, stride=1, padding=0, bias=bias),
                nn.BatchNorm1d(c_out),
            ]

        # Attribute names / construction order are kept: they define state_dict keys.
        self.Wg = nn.Sequential(*pointwise(in_channels, out_channels))
        self.Wx = nn.Sequential(*pointwise(in_channels, out_channels))
        self.psi = nn.Sequential(*pointwise(out_channels, 1), nn.Sigmoid())
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        """Return ``x`` scaled by the attention mask computed from ``g`` and ``x``.

        ``g`` and ``x`` must share channel count and length so that their
        projections can be added.
        """
        mask = self.psi(self.relu(self.Wg(g) + self.Wx(x)))
        return x * mask
    

class OutConv(nn.Module):
    """Final projection of decoder features to the requested output channels.

    Args:
        in_channels: channels of the decoder features.
        out_channels: channels of the network output.
        kernel_size: convolution kernel size (1 = per-position linear map).
    """

    def __init__(self, in_channels, out_channels, kernel_size=1):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
        )

    def forward(self, x):
        """Project ``x`` to ``out_channels`` channels."""
        projected = self.conv(x)
        return projected


class Unet(nn.Module):
    """1D U-Net: four pooling encoder stages and four attention-gated decoder stages.

    Args:
        n_in_channels: channels of the input signal.
        n_out_channels: channels of the output map.
        sizes: channel width per resolution level, finest first (the first
            five entries are used).
        down_sampling: pooling factor of each encoder stage (the first four
            entries are used). Each decoder stage upsamples by the factor of
            its matching encoder stage.

    The input length is assumed to be divisible by ``prod(down_sampling)`` so
    every skip connection lines up with its upsampled counterpart
    (e.g. 4096 with ``(2, 2, 2, 2)``).
    """

    def __init__(
        self,
        n_in_channels,
        n_out_channels,
        sizes=(16, 32, 64, 128, 256),
        down_sampling=(2, 2, 2, 2),
    ):
        super().__init__()
        c, r = sizes, down_sampling
        # Submodule names and creation order are kept stable (state_dict keys / init RNG order).
        self.inc = DoubleConv(n_in_channels, c[0])
        # Encoder: level k -> k+1, pooling by r[k].
        self.down1 = Down(c[0], c[1], r[0])
        self.down2 = Down(c[1], c[2], r[1])
        self.down3 = Down(c[2], c[3], r[2])
        self.down4 = Down(c[3], c[4], r[3])
        # Decoder: mirror of the encoder, deepest level first.
        self.up1 = Up(c[4], c[3], r[3])
        self.up2 = Up(c[3], c[2], r[2])
        self.up3 = Up(c[2], c[1], r[1])
        self.up4 = Up(c[1], c[0], r[0])
        self.outc = OutConv(c[0], n_out_channels)

    def forward(self, x):
        """Map ``x`` to an ``n_out_channels`` output via the encoder/decoder."""
        # Encoder: remember every resolution except the bottleneck for the skips.
        skips = [self.inc(x)]
        for down in (self.down1, self.down2, self.down3):
            skips.append(down(skips[-1]))
        hidden = self.down4(skips[-1])  # bottleneck
        # Decoder: consume skips deepest-first (up1 <-> down3 output, ..., up4 <-> inc output).
        for up in (self.up1, self.up2, self.up3, self.up4):
            hidden = up(hidden, skips.pop())
        return self.outc(hidden)


class LinearCompression(nn.Module):
    """MLP that infers its input width lazily and compresses to 16 features.

    Layer widths: ``in -> 1024 -> 256 -> 64 -> 16``, with ReLU between
    consecutive linear layers and no activation after the last one.
    """

    def __init__(self):
        super().__init__()
        widths = (1024, 256, 64, 16)
        layers = []
        for position, width in enumerate(widths):
            if position > 0:
                layers.append(nn.ReLU())
            # LazyLinear infers in_features from the first batch it sees.
            layers.append(nn.LazyLinear(width))
        self.sequential = nn.Sequential(*layers)

    def forward(self, x):
        """Compress ``x`` along its last dimension to 16 features."""
        compressed = self.sequential(x)
        return compressed


In [43]:
# Seed the RNG so the random inputs and the weight initialisation below are reproducible.
torch.manual_seed(0)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Random stand-in inputs shaped (batch, channels, length); presumably time- and
# frequency-domain data (inferred from the _t / _f suffixes), TODO confirm.
data_t = torch.rand(5, 3, 8192).to(device)
data_f = torch.rand(5, 6, 4097).to(device)

# Time branch: 8192 / 8**4 = 2 samples remain at the bottleneck.
unet_t = Unet(
    n_in_channels=3,
    n_out_channels=1,
    sizes=(16, 32, 64, 128, 256),
    down_sampling=(8, 8, 8, 8),
).to(device)

# Frequency branch: 4097 is not a multiple of 2**4 = 16, so the last sample
# is dropped before the forward pass below (4096 / 16 = 256 at the bottleneck).
unet_f = Unet(
    n_in_channels=6,
    n_out_channels=1,
    sizes=(16, 32, 64, 128, 256),
    down_sampling=(2, 2, 2, 2),
).to(device)

In [45]:
# Drop the final sample: 4097 -> 4096 = 2**4 * 256, so each of unet_f's four
# pooling stages halves the length exactly and every decoder stage's upsampled
# tensor matches its skip connection.
unet_f(data_f[..., :-1])

tensor([[[ 0.0206, -0.3780, -0.0874,  ...,  0.0109,  0.2588,  0.3941]],

        [[-0.0722,  0.0090,  0.0559,  ..., -0.0300, -0.0280,  0.1172]],

        [[ 0.0403,  0.6867,  0.4020,  ...,  0.1316,  0.3188,  0.1926]],

        [[-0.0856, -0.8712, -0.8845,  ..., -0.0046,  0.0463,  0.4451]],

        [[ 0.2795,  0.1790,  0.3238,  ..., -0.0909,  0.1677,  0.3383]]],
       grad_fn=<ConvolutionBackward0>)