In [3]:
import os, sys
from typing import List
sys.path.append(os.pardir)
import pandas as pd
%matplotlib inline
import matplotlib.pyplot as plt
import logomaker
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from src.datamodules.datamodule import BaseDataModule
from src.datamodules.components.dataset import BaseDataset
from src.models.model import BaseNet
from src.models.components.cnn import CNN
from src.models.components.rnn import RNN
from src.models.components.deepfam import DeepFam
from captum.attr import IntegratedGradients

In [3]:
# Load the training table from CSV and render it as the cell output.
df = pd.read_csv("../data/train.csv")
df

Unnamed: 0,id,seq,1h+,2h+,3h+,4h+,5h+,6h+,7h+,8h+,10h+,1h-,2h-,3h-,4h-,5h-,6h-,7h-,8h-,10h-
0,S0_M_T1,TGTCCCCGGGTCTTCCAACGGACTGGCGTTGCCCCGGTTCACTGGG...,1.15020,1.12560,1.400500,0.23320,0.73195,-0.47038,-0.57411,-0.259830,-0.76564,0.97764,0.37349,0.13216,-1.22420,-2.991800,-3.08940,-2.58650,-2.67220,-3.33630
1,S0_M_T1105,GCAGTGTATATAAACTTATAAATATTTCTCCAGCAAATGTGTAAAT...,3.16460,4.57390,4.277900,3.50270,2.85220,1.11460,0.42500,0.015806,-1.01360,2.82910,2.84920,2.36750,2.32410,-0.837060,-2.51390,0.16634,-0.53467,-2.16770
2,S0_M_T1114,ACCGGTGGATGAGGAAGGTAAATGTCTGCTCTAAGAAGTGCAGTGT...,1.20910,0.44768,1.885200,0.23320,0.45631,0.47715,0.47560,-0.420290,-0.62125,0.70201,0.83292,0.23155,0.61079,0.513920,0.13292,0.16634,-0.60179,-1.09030
3,S0_M_T1161,CGCTACAGACAATGTCTCTGTGAGACACGTATTCGCACATGGTATC...,-3.41850,-4.45510,-3.181800,-3.57420,-4.79160,-5.27770,-5.35860,-5.507800,-6.01360,-3.54590,-4.62650,-4.64700,-5.01930,-4.972600,-5.06120,-5.57940,-5.99410,-5.32870
4,S0_M_T117,ACAGCACAGACAGATTGACCTATTGGGGTGTTTCGCGAGTGTGAGA...,4.17190,3.92420,4.011800,3.58570,3.56590,3.01690,2.52390,2.362600,1.03080,4.35490,3.32770,3.00940,2.42590,1.381900,1.61100,1.02600,0.70633,-0.69453
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
67275,S3_H_T9990,AATTAAAGAGAGAGAGAGACGGAGAACACGGTGGGTTTACTAGCGC...,1.32000,0.93303,1.614300,0.88527,0.45631,0.85155,-0.38169,-0.115440,-1.01360,1.09790,0.89705,1.02420,0.35936,-0.008894,0.34401,0.49884,0.17582,-1.43120
67276,S3_H_T9991,AAGGAATTTGTAGCGCCTGCTGACAAGTCTCTAGACTTTCTTGCCA...,2.54210,2.39860,2.562500,2.18070,2.27450,1.24580,1.47520,0.918510,-0.42861,2.54150,2.54340,2.58980,2.05400,1.034200,1.74090,1.31820,1.19570,-1.43120
67277,S3_H_T9994,TTTGGCTATAGAATCAGGCGGCCGTTTTATGTGGGATTTGACGACC...,-2.42700,-1.45840,-0.068781,-1.25220,-0.33218,-2.69280,-1.05880,-1.507800,-5.01360,-0.22399,-0.62651,-0.86565,-2.44390,-2.001600,-4.06120,-1.26600,-1.99410,-2.53000
67278,S3_H_T9997,AGGATTTTTTTTTTCACCAATGCTCTTTAATACACACTTGCCTATA...,-0.26365,2.26850,2.310000,1.12630,1.29580,0.24583,0.61754,-0.922790,-3.01360,0.91351,0.83292,1.13110,1.35900,-2.413300,-1.78460,-0.50088,-1.82420,-3.33630


In [8]:
# Column labels from position 12 onward.
df.columns[12:]

Index(['2h-', '3h-', '4h-', '5h-', '6h-', '7h-', '8h-', '10h-'], dtype='object')

In [95]:
# Scratch check: json.loads turns a JSON array string into a Python list.
import json
a = "[0,1,2,3,4]"
json.loads(a)

[0, 1, 2, 3, 4]

In [77]:
# Wrap a default-constructed CNN in BaseNet and inspect its `device` attribute.
net = BaseNet(CNN())
net.device

device(type='cpu')

In [51]:
# Shape of the square subsequent mask generated for size 10.
nn.Transformer.generate_square_subsequent_mask(10).shape

torch.Size([10, 10])

In [6]:
# Re-read the training CSV, wrap it in the project dataset, and serve it in batches of 96.
df = pd.read_csv("../data/train.csv")
train_dataset = BaseDataset(df)
train_dataloader = DataLoader(train_dataset, batch_size=96)

In [7]:
# Pull one batch from a fresh iterator over the loader; a batch unpacks into three items.
X, init_level, y = next(iter(train_dataloader))

In [90]:
class ConvBlock(nn.Module):
    """Conv1d -> BatchNorm1d -> ReLU -> MaxPool1d stack for channels-first input.

    Args:
        input_dim: number of input channels.
        out_dim: number of convolution filters (output channels).
        kernel_size: width of the convolution kernel.
        pool_size: window size handed to ``nn.MaxPool1d``.
    """

    def __init__(
        self,
        input_dim: int = 4,
        out_dim: int = 256,
        kernel_size: int = 9,
        pool_size: int = 3
    ):
        super().__init__()
        # "same" padding keeps the length axis unchanged through the convolution,
        # so only the pooling layer shortens the sequence.
        conv = nn.Conv1d(input_dim, out_dim, kernel_size, padding="same")
        layers = [conv, nn.BatchNorm1d(out_dim), nn.ReLU(), nn.MaxPool1d(pool_size)]
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the stack to ``x`` of shape (N, C, L) and return the pooled feature map."""
        pooled = self.main(x)
        return pooled

class Encoder(nn.Module):
    """Multi-kernel convolutional encoder.

    Runs the same input through one ``ConvBlock`` per kernel size and
    concatenates the resulting feature maps along the channel axis.

    Args:
        kernel_sizes: convolution widths, one ``ConvBlock`` is built per entry.
            The default list is only read, never mutated.
        out_channels: number of filters produced by each block.
        pool_size: max-pool window forwarded to each block.
        input_dim: number of input channels (defaults to 4, the value that
            was previously hard-coded).

    Raises:
        ValueError: if ``kernel_sizes`` is empty (there would be nothing to
            concatenate in ``forward``).
    """

    def __init__(
        self,
        kernel_sizes: List[int] = [6, 9, 12, 15],
        out_channels: int = 256,
        pool_size: int = 3,
        input_dim: int = 4
    ):
        super().__init__()
        if len(kernel_sizes) == 0:
            raise ValueError("kernel_sizes must contain at least one kernel size")
        # Relies on the four-argument ConvBlock(input_dim, out_dim, kernel_size, pool_size)
        # defined in this same cell.
        self.conv_blocks = nn.ModuleList(
            [ConvBlock(input_dim, out_channels, k, pool_size) for k in kernel_sizes]
        )

    def forward(self, x):
        """Encode ``x`` of shape (N, L, C); returns a channels-first (N, C', L') tensor
        where C' = len(kernel_sizes) * out_channels."""
        x = x.transpose(1, 2)  # (N, C, L): Conv1d expects channels first
        return torch.cat([conv(x) for conv in self.conv_blocks], dim=1)

class TRFMDecode(nn.Module):
    """CNN encoder followed by a single Transformer decoder layer that emits
    a sequence of scalar predictions one step at a time.

    The encoded sequence serves as the decoder memory; the first target token
    is a linear embedding of ``init_level`` and every later token is the
    embedding of the previously predicted scalar.

    Args:
        kernel_sizes: convolution widths forwarded to ``Encoder``.
        out_channels: filters per convolution block, forwarded to ``Encoder``.
        pool_size: max-pool window, forwarded to ``Encoder``.
        d_model: width of the decoder layer and of the token embeddings.
        n_steps: number of scalars to predict per sample (defaults to 8, the
            value that was previously hard-coded).

    Raises:
        ValueError: if ``n_steps`` is smaller than 1.
    """

    def __init__(
        self,
        kernel_sizes: List[int] = [6, 9, 12, 15],
        out_channels: int = 256,
        pool_size: int = 3,
        d_model: int = 256,
        n_steps: int = 8
    ):
        super().__init__()
        if n_steps < 1:
            raise ValueError("n_steps must be at least 1")
        # The encoder must be built from the same hyper-parameters that size fc1;
        # a default-constructed Encoder would mismatch fc1 for non-default values.
        self.encoder = Encoder(
            kernel_sizes=kernel_sizes, out_channels=out_channels, pool_size=pool_size
        )
        self.fc1 = nn.Linear(len(kernel_sizes) * out_channels, d_model)
        self.decoder = nn.TransformerDecoderLayer(d_model=d_model, nhead=8, dim_feedforward=2048)
        self.embed_tgt = nn.Linear(1, d_model)
        self.out = nn.Linear(d_model, 1)
        self.n_steps = n_steps

    def forward(self, x, init_level):
        """Predict ``n_steps`` scalars per sample.

        Args:
            x: input of shape (N, L, C).
            init_level: per-sample starting value of shape (N,).

        Returns:
            Tensor of shape (N, n_steps).
        """
        memory = self.encoder(x)  # (N, C', L')
        memory = memory.transpose(1, 2)  # (N, L', C')
        memory = self.fc1(memory)  # (N, L', d_model)
        # The decoder layer is built with its default sequence-first layout.
        memory = memory.transpose(0, 1)  # (L', N, d_model)
        tgt = self.embed_tgt(init_level.unsqueeze(-1)).unsqueeze(0)  # (1, N, d_model)

        outputs = []
        for step in range(self.n_steps):
            # NOTE(review): `tgt` is overwritten with the decoder output, so later steps
            # feed previously decoded states (not the raw token embeddings) back into the
            # layer. Kept as-is to preserve behavior; confirm this is intended.
            tgt = self.decoder(tgt, memory)  # (T, N, d_model)
            out = self.out(tgt[-1])  # (N, 1): prediction read from the newest position
            outputs.append(out)
            if step + 1 < self.n_steps:
                # Only extend the target sequence when another step will consume it.
                next_tgt = self.embed_tgt(out).unsqueeze(0)  # (1, N, d_model)
                tgt = torch.cat([tgt, next_tgt], dim=0)  # (T+1, N, d_model)

        return torch.cat(outputs, dim=1)  # (N, n_steps)

In [91]:
# Smoke test: run the sampled batch through a default TRFMDecode and check the output shape.
net = TRFMDecode()
net(X, init_level).shape

torch.Size([96, 8])

In [54]:
# NOTE: a cell only renders the value of its last expression, so `init_level.shape`
# below is evaluated but never displayed; the recorded output is `y.shape`.
init_level.shape
y.shape

torch.Size([96, 8])

In [80]:
# Prototype of the target embedding: add a trailing feature axis to init_level,
# project it from 1 to 256 features, then add a leading axis.
embed = nn.Linear(1, 256)
tgt = init_level.unsqueeze(-1)
tgt = embed(tgt).unsqueeze(0)
tgt.shape

torch.Size([1, 96, 256])

In [83]:
# Shape check: run `tgt` through a single decoder layer against a random memory
# tensor and look at the last position. `net` is assigned here but not used in this cell.
net = Encoder()
decoder = nn.TransformerDecoderLayer(d_model=256, nhead=8, dim_feedforward=2048)
memory = torch.randn((30, 96, 256))
decoder(tgt, memory)[-1].shape

torch.Size([96, 256])

In [None]:
from typing import List
import torch
from torch import nn
import torch.nn.functional as F


class ConvBlock(nn.Module):
    """Conv1d -> BatchNorm1d -> ReLU -> global max-pool over the length axis.

    The convolution always takes 4 input channels. Pooling collapses the whole
    length axis to a single value per filter, so the block works for any input
    length the convolution accepts. For the length-110 inputs the previous
    fixed-window pooling was sized for, the result is unchanged.

    NOTE: this class reuses the name ``ConvBlock`` with a different constructor
    signature than the four-argument ``ConvBlock`` defined earlier in this
    notebook; whichever cell ran last wins.

    Args:
        kernel_size: width of the convolution kernel.
        out_channels: number of convolution filters.
    """

    def __init__(
        self,
        kernel_size: int = 9,
        out_channels: int = 256
    ):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv1d(in_channels=4, out_channels=out_channels, kernel_size=kernel_size),
            nn.BatchNorm1d(out_channels),
            nn.ReLU(),
            # Adaptive pooling to length 1 is a global max that does not depend on
            # a hard-coded sequence length.
            nn.AdaptiveMaxPool1d(1)
        )

    def forward(self, x):
        """Map ``x`` of shape (N, C, L) to per-filter maxima of shape (N, out_channels)."""
        # The pooled tensor is (N, out_channels, 1); drop the singleton length axis.
        return self.conv(x).squeeze(-1)

    
class DeepDecode_v2(nn.Module):
    """Multi-kernel CNN encoder with a GRU-cell decoder that rolls out scalar predictions.

    Each ``ConvBlock`` summarises the input into one vector; the concatenated
    vectors are projected to the initial GRU hidden state. The GRU cell is then
    stepped ``n_steps`` times, each time consuming the previous scalar output
    (``init_level`` for the first step) and producing the next one through an
    MLP head.

    Args:
        kernel_sizes: convolution widths, one ``ConvBlock`` per entry.
        out_channels: filters per convolution block.
        embed_dim: size of the GRU hidden state.
        fc_dim: the two hidden widths of the output MLP; only the first two
            entries are read.
        dropout: dropout probability used inside the output MLP.
        n_steps: number of scalars to predict per sample (defaults to 8, the
            value that was previously hard-coded).

    Raises:
        ValueError: if ``n_steps`` is smaller than 1.
    """

    def __init__(
        self,
        kernel_sizes: List[int] = [6, 9, 12, 15],
        out_channels: int = 256,
        embed_dim: int = 256,
        fc_dim: List[int] = [1024, 64],
        dropout: float = 0.1,
        n_steps: int = 8
    ):
        super().__init__()
        if n_steps < 1:
            raise ValueError("n_steps must be at least 1")
        self.conv_blocks = nn.ModuleList([ConvBlock(i, out_channels) for i in kernel_sizes])
        self.embed = nn.Linear(len(kernel_sizes) * out_channels, embed_dim)
        # The GRU input is the previous scalar prediction, hence input size 1.
        self.gru_cell = nn.GRUCell(1, embed_dim)
        self.out = nn.Sequential(
            nn.Linear(embed_dim, fc_dim[0]),
            nn.Dropout(dropout),
            nn.ReLU(),
            nn.Linear(fc_dim[0], fc_dim[1]),
            nn.Dropout(dropout),
            nn.ReLU(),
            nn.Linear(fc_dim[1], 1)
        )
        self.n_steps = n_steps

    def forward(self, x, init_level):
        """Predict ``n_steps`` scalars per sample.

        Args:
            x: input of shape (N, L, C).
            init_level: per-sample starting value of shape (N,).

        Returns:
            Tensor of shape (N, n_steps).
        """
        x = x.transpose(1, 2)  # (N, C, L): Conv1d expects channels first
        features = torch.cat([conv(x) for conv in self.conv_blocks], dim=1)
        h = self.embed(features)  # initial hidden state, (N, embed_dim)

        outputs = []
        out = init_level.unsqueeze(1)  # (N, 1): first GRU input
        for _ in range(self.n_steps):
            h = self.gru_cell(out, h)
            out = self.out(h)  # (N, 1), fed back as the next GRU input
            outputs.append(out)

        return torch.cat(outputs, dim=1)

In [38]:
# Smoke test: run the sampled batch through a default DeepDecode_v2 and check the output shape.
net = DeepDecode_v2()
emb = net(X, init_level)
emb.shape

torch.Size([96, 8])

In [18]:
# NOTE(review): nn.GRUCell(256, 256) expects 256 input features, but the `emb`
# assigned in the DeepDecode_v2 smoke-test cell has shape (96, 8) per its recorded
# output. This cell's lower execution count (18 vs 38) suggests `emb` held a different
# tensor when it was run; on a top-to-bottom run this call likely fails. Confirm
# which tensor `emb` is supposed to be here.
gru_cell = nn.GRUCell(256, 256)
gru_cell(emb).shape

torch.Size([96, 256])