In [55]:
import torch
import numpy as np
import pandas as pd
import os
import h5py
from exabiome.nn.loader import read_dataset, LazySeqDataset
import argparse
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn

In [3]:
# Toy HDF5 input. Set EXABIOME_INPUT_H5 to point elsewhere so the notebook is not
# tied to a single user's home directory; the default is the original location.
path = os.environ.get('EXABIOME_INPUT_H5', '/global/homes/a/azaidi/ar122_r202.toy.input.h5')

In [4]:
# Loader options handed to LazySeqDataset: 4096-long windows with step 4096
# (presumably non-overlapping windows — confirm in the loader), phylum-level
# classification targets, forward strand only.
hparams = argparse.Namespace(
    load=False,
    window=4096,
    step=4096,
    classify=True,
    tgt_tax_lvl="phylum",
    fwd_only=True,
)

In [7]:
# Chunked view over the HDF5 input; `keep_open` is forwarded as-is to
# LazySeqDataset (presumably keeps the file handle open between reads — TODO confirm).
chunks = LazySeqDataset(
    hparams,
    path=path,
    keep_open=True,
)
len(chunks)

19010

Let's pad the x value (the sequence) with a transform function instead of putting the padding logic in the dataset class.

In [21]:
def pad_seq(seq, length=4096):
    """Right-pad a 1-D tensor with zeros up to ``length`` elements.

    Sequences that are already at least ``length`` long are returned unchanged
    (they are not truncated). The padded buffer is allocated with
    ``seq.new_zeros`` so it keeps the input's dtype and device. A plain
    ``torch.zeros`` would silently turn short sequences into float32 while
    full-length ones keep their original dtype.

    Args:
        seq: 1-D tensor to pad.
        length: target length; defaults to 4096, the chunk window used in this notebook.

    Returns:
        ``seq`` itself when ``len(seq) >= length``; otherwise a new tensor of
        shape ``(length,)`` whose first ``len(seq)`` entries are ``seq`` and
        whose remaining entries are zero.
    """
    if len(seq) >= length:
        return seq
    padded = seq.new_zeros(length)
    padded[:len(seq)] = seq
    return padded

The transform function above isn't the cleanest, but that's fine for now — the PyTorch docs use lambda functions for transforms anyway ;)

We also don't want to unsqueeze at the batch level every time a batch is drawn — let's add the channel dimension in the dataset instead :)

In [34]:
class taxon_ds(Dataset):
    """Map-style dataset yielding ``(x, y)`` pairs from an indexable collection of chunks.

    Each chunk record is read as ``record[1]`` for the sequence and ``record[2]``
    for the label. This layout is taken from how this notebook uses
    ``LazySeqDataset`` items — TODO confirm against the loader. The sequence is
    passed through ``transform`` and then given a leading channel dimension, so
    batches come out as ``(batch, 1, length)`` without reshaping in the
    training loop.

    Args:
        chunks: sized, indexable collection of chunk records.
        transform: callable applied to the sequence before unsqueezing; ``None``
            means the sequence is used as-is.
    """

    def __init__(self, chunks, transform=None):
        self.chunks = chunks
        self.transform = transform

    def __len__(self):
        return len(self.chunks)

    def __getitem__(self, idx):
        # Index the wrapped collection exactly once. A lookup may hit the HDF5
        # file, and reading the global `chunks` would ignore the collection this
        # instance was constructed with.
        record = self.chunks[idx]
        x = record[1]
        if self.transform is not None:
            x = self.transform(x)
        y = record[2]
        return x.unsqueeze(0), y

In [35]:
ds = taxon_ds(chunks=chunks, transform=pad_seq)  # pad every sequence to the 4096 window

In [36]:
tuple(ds[i][0].size() for i in (0, 2))  # x shapes of two samples; both should be (1, 4096)

(torch.Size([1, 4096]), torch.Size([1, 4096]))

In [37]:
# Seed the shuffle so the batch inspected below is reproducible on re-run.
dl = DataLoader(ds, batch_size=16, shuffle=True,
                generator=torch.Generator().manual_seed(0))
len(dl)

1189

In [38]:
batch = next(iter(dl))
tuple(t.size() for t in batch[:2])  # (inputs, labels) shapes

(torch.Size([16, 1, 4096]), torch.Size([16]))

We'll use timm to inspect the EfficientNet architecture.

In [41]:
import timm

In [73]:
# Randomly initialised EfficientNet-B0 with one input channel, used only to
# inspect the architecture.
model = timm.create_model(
    'efficientnet_b0',
    in_chans=1,
    pretrained=False,
)

In [74]:
#uncomment below to see full arch
#model

In [71]:
# Stem convolution of the timm model; in_chans=1 gives it a single input channel.
model.conv_stem

Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)

In [72]:
# First block stage: a depthwise-separable conv with squeeze-excite (2-D layers).
model.blocks[0]

Sequential(
  (0): DepthwiseSeparableConv(
    (conv_dw): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
    (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act1): SiLU(inplace=True)
    (se): SqueezeExcite(
      (conv_reduce): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))
      (act1): SiLU(inplace=True)
      (conv_expand): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))
    )
    (conv_pw): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act2): Identity()
  )
)

For 1-D sequences, this suggests a stack of Conv1d → BatchNorm1d → SiLU layers, the 1-D analogues of the Conv2d → BatchNorm2d → SiLU pattern above.

Below we build an arbitrary stack just to check that the data can pass through these kinds of layers. The layer sizes are chosen arbitrarily, except that the first conv layer must accept one input channel to match the data.

In [76]:
def _conv_bn_silu(in_channels, out_channels):
    """Return one Conv1d(kernel_size=3) -> BatchNorm1d -> SiLU stage as a list of layers."""
    return [
        nn.Conv1d(in_channels, out_channels, kernel_size=3),
        nn.BatchNorm1d(out_channels),
        nn.SiLU(inplace=True),
    ]


# Two stages. Only the first conv's in_channels (1) is dictated by the data;
# the widths are arbitrary for this smoke test.
model = nn.Sequential(*_conv_bn_silu(1, 32), *_conv_bn_silu(32, 32))

In [77]:
model(batch[0]).size()  # (batch, channels, length) after the two convs

torch.Size([16, 32, 4092])

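The output shape is (16, 32, 4092). Each unpadded `kernel_size=3` convolution trims 2 positions (4096 → 4094 → 4092). Adapting the architecture to 1-D looks straightforward enough, as expected :)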