In [None]:
from helper_fuctions import *

# **LegNet**: solving the sequence-to-expression problem with SOTA convolutional networks

This tutorial demonstrates how LegNet can be used in practice with data from yeast gigantic parallel reporter assays (GPRA).

Please don't hesitate to ask questions or share any feedback: dmitrypenzar1996@gmail.com

The code below trains the LegNet model to predict gene expression from promoter sequences using data from yeast gigantic parallel reporter assays ([Sort-Seq](https://www.cell.com/cell-systems/fulltext/S2405-4712(16)30292-7?_returnURL=https%3A%2F%2Flinkinghub.elsevier.com%2Fretrieve%2Fpii%2FS2405471216302927%3Fshowall%3Dtrue)).

In our tutorial, we follow the setup of the [DREAM 2022 promoter expression challenge](https://www.synapse.org/#!Synapse:syn28469146/wiki/617075).

To train the model, you need a table with sequences and the corresponding expression values. As an example, you can use [expression measured in yeast grown in complex medium](https://zenodo.org/record/4436477/files/complex_media_training_data_Glu.txt?download=1) or [expression measured in yeast grown in defined medium](https://zenodo.org/record/4436477/files/defined_media_training_data_SC_Ura.txt?download=1), both available in the [Zenodo record](https://zenodo.org/record/4436477#.Y5QgZOxBy3J); see [Vaishnav et al.](https://doi.org/10.1038/s41586-022-04506-6) for details.

## Specify parameters

In [None]:
model_dir = 'model' # folder that will contain all outputs {type:"string"}

### 1. Data

The tab-separated training data must use the following format:

* First column: sequence. Second column: expression value. Third column: fold (optional).
* The data should be provided without any extra header line.
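
As a minimal sketch of this layout, the snippet below parses a small in-memory table with the standard `csv` module. The field names `seq`, `bin`, and `fold`, the function `read_training_table`, and the two toy rows are assumptions made for illustration; the loader in the helper module may work differently.

```python
import csv
import io

# Toy tab-separated table without a header line:
# sequence, expression value, optional fold (made-up values).
raw = "ACGTACGT\t11.0\t0\nTTGACCAA\t7.5\t1\n"

def read_training_table(handle, delimiter="\t"):
    """Parse rows into dicts; the third column (fold) is optional."""
    rows = []
    for fields in csv.reader(handle, delimiter=delimiter):
        if not fields:
            continue  # skip empty lines
        row = {"seq": fields[0], "bin": float(fields[1])}
        if len(fields) > 2:
            row["fold"] = int(fields[2])
        rows.append(row)
    return rows

table = read_training_table(io.StringIO(raw))
print(table[0])
```

A header line, if present, would be parsed as data and fail at the `float` conversion, which is why the file should not contain one.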


In [None]:
path_to_training_data = 'path' # file location or url {type:"string"}

delimiter = "tab" # ["tab", ",", ";"] {type:"string"}

If the sequences have unequal lengths, padding is **necessary**. Padding is possible only if the surrounding plasmid sequence is provided as a string in which the target promoter sequence is denoted with Ns (example: `atgcNNNatcg`).

Specify the sequence length without adapters (`sequence_len_no_adapters`).

`seqsize` is the resulting sequence size after padding.
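
The padding idea can be illustrated on the toy plasmid from the example above. This is a simplified sketch, not the `preprocess_data` function used later: the names `pad_with_plasmid` and `toy_plasmid` are hypothetical, and adapter trimming is left out.

```python
def pad_with_plasmid(insert, plasmid, insert_len, seqsize):
    """Left-pad `insert` with the plasmid bases that precede the run of Ns."""
    start = plasmid.find("N" * insert_len)
    if start == -1:
        raise ValueError("plasmid has no run of Ns of the expected length")
    padded = plasmid[:start] + insert
    if len(padded) < seqsize:
        raise ValueError("not enough upstream plasmid sequence to reach seqsize")
    # keep the rightmost `seqsize` characters
    return padded[-seqsize:]

toy_plasmid = "atgcNNNatcg"
short_padded = pad_with_plasmid("GT", toy_plasmid, insert_len=3, seqsize=5)
full_padded = pad_with_plasmid("GTA", toy_plasmid, insert_len=3, seqsize=5)
print(short_padded, full_padded)  # tgcGT gcGTA
```

Both inserts end up with the same length, and shorter inserts simply receive more upstream plasmid bases.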


In [None]:
with open("plasmid.json") as json_file:
    plasmid = json.load(json_file)

left_adapter = "TGCATTTTTTTCACATC" # {type:"string"}
right_adapter = "GGTTACGGCTGTT" # {type:"string"}

sequence_len_no_adapters = 80 # {type:"integer"}

seqsize = 150 # {type:"integer"}

if len(plasmid) == 0:
    print('Padding will not be used because the plasmid sequence is unavailable.')
else:
    print('Padding will be used because the plasmid sequence has len > 0.')

assert sequence_len_no_adapters <= seqsize, "Seqsize cannot be less than the sequence length without adapters."


### 2.1. Architecture

In [None]:
'''
final_ch: number of channels of the final convolutional layer. 

For the challenge task, it corresponds to the number of expression bins.
'''
final_ch = 18 # {type:"integer"}

'''
Number of channels for EffNet-like blocks
'''
blocks = [256, 128, 128, 64, 64, 64, 64] # {type:"list"}

'''
Kernel size of convolutional layers
'''
ks = 7 # {type:"integer"}

'''
Multiplier for the number of channels in the middle (high-dimensional) convolutional layer of an EffNet-like block
'''
resize_factor = 4 # {type:"integer"}

'''
Reduction number used in SELayer
'''
se_reduction = 4 # {type:"integer"}

'''
BatchNorm momentum
'''
bn_momentum = .1 # {type:"float"} 

### 2.2. Training

In [None]:
loss = 'kl' # ["mse", "kl"] {type:"string"}

epoch_num = 450 # {type:"integer"}

batch_per_epoch = 1000 # {type:"integer"}

optimizer_name = "adamw" # ["adam", "adamw", "rmsprop"] {type:"string"}

weight_decay = 0.01 # {type:"float"}

'''
Note: lr is set manually after an LR-range test that runs for a few epochs to find a good learning rate.
Warning: the range test requires you to choose lr manually from the resulting training plot.
'''

train_batch_size = 1024 # {type:"integer"}
valid_batch_size = 1024 # {type:"integer"}

train_workers = 8 # {type:"integer"}
valid_workers = 8 # {type:"integer"}

'''
Additional binary channel with singleton information.
Integer expression values are considered as singletons measured in the GPRA experiment only once. 
'''
use_single_channel = True # ["False", "True"] {type:"raw"}

'''
Dataset augmentation by reverse-complementing input sequences and adding a binary channel.
'''
use_reverse_channel = True # ["False", "True"] {type:"raw"}

'''
Additional substrate channel. Can be used to mix the data from yeast grown in different media.
'''
use_multisubstate_channel = True # ["False", "True"] {type:"raw"}

'''
Whether to split the training data into train and validation.
'''
foldify = True # ["False", "True"] {type:"raw"}

'''
Whether to turn on validation.
'''
use_validation = False # ["False", "True"] {type:"raw"}

seed = 42 # {type:"integer"}

gpu = 1 # {type:"integer"}

## The main code of the model training

Below we present the approach step by step, briefly describing the data preprocessing and the model architecture.

Note: additional code is located in the `helper_fuctions` module. 

### Data preprocessing

#### Input vector structure

The training sequences were padded on the 5’ end with nucleotides from the corresponding plasmid to the uniform total length (`preprocess_data`) and encoded into 4-dimensional vectors using **one-hot encoding** (`Seq2Tensor`). 
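
To make the encoding concrete, here is a dependency-free sketch of a channels-first one-hot encoding in which `N` is spread evenly over the four channels. The channel order `ACGT` and the function name `one_hot` are assumptions for illustration; the notebook's `Seq2Tensor` relies on `n2id` from the helper module, whose mapping is not shown here.

```python
ALPHABET = "ACGT"  # assumed channel order, for illustration only

def one_hot(seq):
    """Return a 4 x len(seq) matrix (channels first); N becomes 0.25 in every channel."""
    rows = []
    for base in ALPHABET:
        rows.append([0.25 if ch == "N" else float(ch == base) for ch in seq.upper()])
    return rows

encoded = one_hot("ACNT")
for base, row in zip(ALPHABET, encoded):
    print(base, row)
```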

We have considered the sequences with integer scores as **singletons**, i.e. they likely have been observed only once, while non-integer scores were obtained by averaging two or more observations. To supply this information to the model, we introduced a binary `is_singleton` channel (1 for singletons, 0 for other training sequences). The final predictions for evaluation were made by specifying `is_singleton=0`.
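
The singleton rule can be sketched in a couple of lines. The function name `is_singleton` and the toy scores are illustrative; the actual flag is computed in the helper module.

```python
def is_singleton(expression):
    """Return 1 if the score is integer-valued (likely a single observation), else 0."""
    return int(float(expression).is_integer())

scores = [11.0, 7.5, 3.0, 9.25]
flags = [is_singleton(s) for s in scores]
print(flags)  # [1, 0, 1, 0]
```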

Since the regulatory elements are often asymmetric relative to the transcription start sites, different scores were expected for direct and reverse complementary strands of a particular sequence. Thus, the data was augmented by providing each sequence twice in native and **reverse complementary** form, specifying 0 and 1, respectively, in an additional `is_reverse` channel. The test-time augmentation was to average the predictions made for direct (`is_reverse=0`) and reverse complementary (`is_reverse=1`) input.
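
The augmentation and the test-time averaging can be sketched as follows. `toy_predict` is a hypothetical stand-in for a trained model, and `augment` and `predict_with_tta` are illustrative names, not functions from the helper module.

```python
COMPLEMENT = str.maketrans("ACGTN", "TGCAN")

def reverse_complement(seq):
    """Complement every base, then reverse the string."""
    return seq.translate(COMPLEMENT)[::-1]

def augment(seq):
    """Return (sequence, is_reverse) pairs for both strands."""
    return [(seq, 0), (reverse_complement(seq), 1)]

def predict_with_tta(predict, seq):
    """Average a scoring function over both strand orientations."""
    preds = [predict(s, is_reverse=flag) for s, flag in augment(seq)]
    return sum(preds) / len(preds)

def toy_predict(seq, is_reverse):
    # hypothetical scoring rule: GC count plus the strand flag
    return seq.count("G") + seq.count("C") + is_reverse

pairs = augment("AACGT")
tta_score = predict_with_tta(toy_predict, "AACGT")
print(pairs, tta_score)
```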


<img src="img/Input.jpg" width="800" />

In [None]:
def preprocess_data(data, seqsize):
    '''
    Training sequences are padded on the 5’ end with nucleotides 
    from the corresponding plasmid to the uniform total length.
    '''
    
    data = data.copy()
    INSERT_START = plasmid.find('N' * sequence_len_no_adapters)
    
    #take the left part of the plasmid
    add_part = plasmid[INSERT_START-seqsize:INSERT_START]
    
    # cut left adapter and append the plasmid part
    data.seq = data.seq.apply(lambda x:  add_part + x[len(left_adapter):])
    
    # reduce sequence size to seqsize
    data.seq = data.seq.str.slice(-seqsize, None)
    return data

In [None]:
class Seq2Tensor(nn.Module):
    '''
    Encode sequences using one-hot encoding after preprocessing.
    '''
    def __init__(self):
        super().__init__()
    def forward(self, seq):
        if isinstance(seq, torch.FloatTensor):
            return seq
        seq = [n2id(x) for x in seq]
        code = torch.from_numpy(np.array(seq))
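        # cast to float first: one_hot returns an integer tensor,
        # so writing 0.25 into it would be truncated to 0
        code = F.one_hot(code, num_classes=5).float() # 5th class is N
        
        code[code[:, 4] == 1] = 0.25 # encode Ns with .25
        code = code[:, :4]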
        return code.transpose(0, 1)

#### The formal description of the challenge problem

We reformulated the sequence-to-expression problem arising in the GPRA ([Sort-Seq](https://www.cell.com/cell-systems/fulltext/S2405-4712(16)30292-7?_returnURL=https%3A%2F%2Flinkinghub.elsevier.com%2Fretrieve%2Fpii%2FS2405471216302927%3Fshowall%3Dtrue)) analysis as a soft classification task by transforming expression estimates to class probabilities.

Given a measured expression $e$, assume that the real expression is normally distributed: $\rho \sim N(\mu = e + 0.5, \sigma = 0.5)$.

For each class $i$ from 1 to 16, defined by an original measurement bin, the class probability is the probability mass of the interval $[i, i+1)$. Bins 0 and 17 are special cases covering $(-\infty, 1)$ and $[17, +\infty)$, respectively; these edges match `POINTS` in the code below.
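
As a worked example of this transformation, the sketch below computes the 18 class probabilities for one measured expression using only the standard library. It follows the same bin edges as described above but uses `statistics.NormalDist` in place of SciPy; the function name `bin_probabilities` is illustrative.

```python
from statistics import NormalDist

def bin_probabilities(expression, n_bins=18, shift=0.5, scale=0.5):
    """Probability mass of N(expression + shift, scale) in each expression bin."""
    dist = NormalDist(mu=expression + shift, sigma=scale)
    # CDF at the edges -inf, 1, 2, ..., n_bins - 1, +inf
    cdf = [0.0] + [dist.cdf(edge) for edge in range(1, n_bins)] + [1.0]
    return [hi - lo for lo, hi in zip(cdf[:-1], cdf[1:])]

probs = bin_probabilities(11.0)
top_bin = max(range(len(probs)), key=probs.__getitem__)
print(top_bin, round(probs[top_bin], 4), round(sum(probs), 6))
```

For a measurement of 11, the distribution is centred at 11.5, so bin 11 receives the mass within one standard deviation of the mean and the neighbouring bins share most of the rest.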


In [None]:
# points to get cdf from normal distribution, see below
POINTS = np.array([-np.inf, *range(1, final_ch, 1), np.inf])

class SeqDatasetProb(Dataset):
    
    """ Sequence dataset. """
    
    def __init__(self, ds, seqsize, use_single_channel, use_reverse_channel, use_multisubstate_channel, shift=0.5, scale=0.5):
        """
        Parameters
        ----------
        ds : pd.DataFrame
            Training dataset.
        seqsize : int
            Constant sequence length.
        use_single_channel : bool
            If True, additional binary channel with singleton information is used.
        use_reverse_channel : bool
            If True, additional reverse augmentation is used.
        use_multisubstate_channel : bool
            If True, additional substrate channel is used.
        shift : float, optional
            Offset added to the measured expression to obtain the mean of the assumed normal distribution.
        scale : float, optional
            Assumed sd of the real expression normal distribution.
        """
        self.ds = ds
        self.seqsize = seqsize
        self.totensor = Seq2Tensor() 
        self.shift = shift 
        self.scale = scale
        self.use_single_channel = use_single_channel
        self.use_reverse_channel = use_reverse_channel
        self.use_multisubstate_channel = use_multisubstate_channel
        
    def transform(self, x):
        assert isinstance(x, str)
        assert len(x) == self.seqsize
        return self.totensor(x)
    
    def __getitem__(self, i):
        """
        Output
        ----------
        X: torch.Tensor    
            Create one-hot encoding tensor with reverse and singleton channels if required.
        probs: np.ndarray
            Given a measured expression, we assume that the real expression is normally distributed
            with mean=`bin + shift` and sd=`scale`. 
            Resulting `probs` vector contains probabilities that correspond to each class (bin).     
        bin: float 
            Training expression value
        """
        seq = self.transform(self.ds.seq.values[i])
        to_concat = [seq]
        
        # add reverse augmentation channel
        if self.use_reverse_channel:
            rev = torch.full( (1, self.seqsize), self.ds.rev.values[i], dtype=torch.float32)
            to_concat.append(rev)
            
        # add singleton channel
        if self.use_single_channel:
            single = torch.full( (1, self.seqsize) , self.ds.is_singleton.values[i], dtype=torch.float32)
            to_concat.append(single)
            
        # add substrate channel
        if self.use_multisubstate_channel:
            substrate = torch.full( (1, self.seqsize) , self.ds.substrate.values[i], dtype=torch.float32)
            to_concat.append(substrate)
        
        # create final tensor
        if len(to_concat) > 1:
            X = torch.concat(to_concat, dim=0)
        else:
            X = seq
            
        bin = self.ds.bin.values[i]
        
        # generate probabilities corresponding to each class
        norm = scipy.stats.norm(loc=bin + self.shift, scale=self.scale)
        
        cumprobs = norm.cdf(POINTS)
        probs = cumprobs[1:] - cumprobs[:-1]
        return X, probs, bin
    
    def __len__(self):
        return len(self.ds.seq)

### Trainer code

The `train_step` function inside `create_trainer` illustrates the principle of the training procedure.

The model outputs a vector of `final_ch` log-**probabilities**, one per class (bin), together with the predicted **expression value**, where

$expression = \sum_{i=0}^{final\_ch - 1} i \cdot p_i$ (soft-argmax operation).

Only the probability vector is used to calculate the loss during training.
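
The two quantities can be reproduced on plain Python lists. This is an illustrative sketch with hypothetical function names, not the PyTorch code used in training; `kl_divergence` computes the per-sample divergence, which `nn.KLDivLoss(reduction="batchmean")` is understood to sum over classes and average over the batch.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over a list of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_z for v in logits]

def soft_argmax(logits):
    """Expected bin index under the softmax distribution."""
    probs = [math.exp(lp) for lp in log_softmax(logits)]
    return sum(i * p for i, p in enumerate(probs))

def kl_divergence(target_probs, logprobs):
    """KL(target || predicted); bins with zero target mass contribute nothing."""
    return sum(t * (math.log(t) - lp) for t, lp in zip(target_probs, logprobs) if t > 0)

uniform_score = soft_argmax([0.0, 0.0, 0.0, 0.0])
peaked_score = soft_argmax([0.0, 0.0, 100.0])
zero_kl = kl_divergence([0.25] * 4, log_softmax([0.0] * 4))
print(uniform_score, peaked_score, zero_kl)
```

Equal logits over bins 0 to 3 give the midpoint 1.5, while a strongly peaked distribution gives a score close to the index of the peak.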


In [None]:
def create_trainer(model, 
                   optimizer,
                   scheduler,  
                   criterion, 
                   device, 
                   model_dir,
                   use_validation,
                   valid_dl=None
                  ):
    model_dir = Path(model_dir)
    model_dir.mkdir(exist_ok=True, parents=True)
    
    train_mse =  MeanSquaredError()
    train_pearson = PearsonMetric()
    train_spearman = SpearmanMetric()
    
    def train_step(trainer, batch):
        nonlocal model
        if not model.training:
            model = model.train()
            
        # unpack one-hot encoding tensor with additional channels, probabilities and expression
        X, y_probs, y = batch 
        X = X.to(device)
        y_probs = y_probs.float().to(device)
        
        # the output of the model consists of probabilities vector and expression from these probabilities
        # only probabilities vector is used for training
        logprobs, y_pred = model(X) 
        loss = criterion(logprobs, y_probs)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()                                                                                         
        out = (y_pred.detach().cpu(), y)
        
        # calculate training metrics based on calculated expression
        train_mse.update(out)
        train_pearson.update(out)
        train_spearman.update(out)
               
        return loss.item()
    
    trainer = Engine(train_step)
    
    @trainer.on(Events.STARTED)
    def prepare_epoch(engine): 
        engine.state.metrics['train_pearson'] = -np.inf
        engine.state.metrics['train_mse'] = -np.inf
        engine.state.metrics['train_spearman'] = -np.inf
        if use_validation:
            engine.state.metrics['val_pearson'] = -np.inf
            engine.state.metrics['val_mse'] = -np.inf
            engine.state.metrics['val_spearman'] = -np.inf

    def evaluate(engine, batch):
        nonlocal model
        if model.training:
            model = model.eval()
        with torch.no_grad():
            X, y_probs, y = batch
            X = X.to(device)
            y = y.float().to(device)
            _, y_pred = model(X)
        return y_pred.cpu(), y.cpu()

    evaluator = Engine(evaluate)

    MeanSquaredError().attach(evaluator, 'mse')
    p = PearsonMetric()
    p.attach(evaluator, 'pearson')
    s = SpearmanMetric()
    s.attach(evaluator, 'spearman')
    
    @trainer.on(Events.EPOCH_COMPLETED)
    def validate(engine):
        p.reset()
            
        engine.state.metrics['train_mse'] = train_mse.compute()
        engine.state.metrics['train_pearson'] = train_pearson.compute()
        engine.state.metrics['train_spearman'] = train_spearman.compute()
        train_mse.reset()
        train_pearson.reset()
        train_spearman.reset()
        
        if use_validation and valid_dl is not None:
            evaluator.run(valid_dl, max_epochs=1)
            for name, value in evaluator.state.metrics.items():
                engine.state.metrics[f"val_{name}"] = value
        
        score_path = model_dir / f"scores_{engine.state.epoch}.json"
        with open(score_path, "w") as outp:
            json.dump(engine.state.metrics, outp)

    
    @trainer.on(Events.EPOCH_COMPLETED)
    def dump_model(engine):
        model_path = model_dir / f"model_{engine.state.epoch}.pth"
        torch.save(model.state_dict(), model_path)
        
        optimizer_path = model_dir / f"optimizer_{engine.state.epoch}.pth"
        torch.save(optimizer.state_dict(), optimizer_path)
        
        if scheduler is not None:
            scheduler_path = model_dir / f"scheduler_{engine.state.epoch}.pth"
            torch.save(scheduler.state_dict(), scheduler_path)
           
            
    pbar = ProgressBar(persist=True)
    pbar.attach(trainer, ["train_mse", "train_pearson", "train_spearman", ], 
                output_transform=lambda x: {'batch_loss': x}, 
                )
    return trainer, p

### Model architecture

Our model is based on a fully convolutional neural network architecture inspired by EfficientNetV2, with selected features from DenseNet and additional custom blocks.

![OverallArchitecture](img/A.jpg)

The EfficientNet-like convblock borrows features from EfficientNetV2. Its **Squeeze and Excitation (SE) block** is a modified version of the one used in the original EfficientNetV2.

<img src="img/SEblock.jpg" width="200" />
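
The general squeeze-and-excitation idea can be sketched without any deep learning library. The gating function below is a hypothetical placeholder (a sigmoid of the channel mean); in the actual layer the gates come from a learned stack of linear, bilinear, and activation layers.

```python
import math

def squeeze_excite(x, gate_fn):
    """x is a list of channels, each a list of values over sequence positions."""
    squeezed = [sum(ch) / len(ch) for ch in x]   # squeeze: mean over positions
    gates = gate_fn(squeezed)                    # excite: one weight per channel
    return [[v * g for v in ch] for ch, g in zip(x, gates)]

def toy_gate(means):
    # hypothetical gating rule standing in for the learned layers
    return [1.0 / (1.0 + math.exp(-m)) for m in means]

features = [[0.0, 0.0], [2.0, 4.0]]
rescaled = squeeze_excite(features, toy_gate)
print(rescaled)
```

Each channel is rescaled as a whole, so the layer changes how much a channel contributes without altering its positional pattern.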

In [None]:
class SELayer(nn.Module):
    """
    Squeeze-and-Excite layer.

    Parameters
    ----------
    inp : int
        Size from which the middle layer width is derived (`inp // reduction`).
    oup : int
        Input and output size.
    reduction : int, optional
        Reduction parameter. The default is 4.
    """
    def __init__(self, inp, oup, reduction=4):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Sequential(
                nn.Linear(oup, int(inp // reduction)),
                nn.SiLU(),
                nn.Linear(int(inp // reduction), int(inp // reduction)),
                Concater(Bilinear(int(inp // reduction), int(inp // reduction // 2), rank=0.5, bias=True)),
                nn.SiLU(),
                nn.Linear(int(inp // reduction) +  int(inp // reduction // 2), oup),
                nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, = x.size()
        y = x.view(b, c, -1).mean(dim=2)
        y = self.fc(y).view(b, c, 1)
        return x * y

The **bilinear block** is re-implemented using TensorLy-Torch instead of the default PyTorch version because of the tensor regularization functionality of the former.
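
What a bilinear layer computes can be shown on plain lists: every output is a weighted sum of pairwise products of the input features. This sketch uses a full weight tensor with hand-picked toy values; the low-rank factorization used in the actual layer is not reproduced here, and the function names are illustrative.

```python
def outer(x):
    """All pairwise products x[i] * x[j]."""
    return [[a * b for b in x] for a in x]

def bilinear(x, weight, bias):
    """out[k] = sum_ij weight[k][i][j] * x[i] * x[j] + bias[k]"""
    xx = outer(x)
    n = len(x)
    return [
        sum(w_k[i][j] * xx[i][j] for i in range(n) for j in range(n)) + b_k
        for w_k, b_k in zip(weight, bias)
    ]

x = [1.0, 2.0]
weight = [
    [[1.0, 0.0], [0.0, 1.0]],  # selects x0^2 + x1^2
    [[0.0, 1.0], [0.0, 0.0]],  # selects x0 * x1
]
bias = [0.0, 0.5]
out = bilinear(x, weight, bias)
print(out)  # [5.0, 2.5]
```

A full weight tensor needs `out * n * n` parameters, which is the cost that a low-rank decomposition is meant to reduce.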

In [None]:
class Bilinear(nn.Module):
    """
    Bilinear layer introduces pairwise product to a NN to model possible combinatorial effects.
    This particular implementation attempts to reduce the number of parameters via low-rank tensor decompositions.

    Parameters
    ----------
    n : int
        Number of input features.
    out : int, optional
        Number of output features. If None, assumed to be equal to the number of input features. The default is None.
    rank : float, optional
        Fraction of the maximal rank to be used in tensor decomposition. The default is 0.05.
    bias : bool, optional
        If True, bias is used. The default is False.

    """
    def __init__(self, n: int, out=None, rank=0.05, bias=False):        
        super().__init__()
        if out is None:
            out = (n, )
        self.trl = TRL((n, n), out, bias=bias, rank=rank)
        self.trl.weight = self.trl.weight.normal_(std=0.00075)
    
    def forward(self, x):
        x = x.unsqueeze(dim=-1)
        return self.trl(x @ x.transpose(-1, -2))

The structure of the full **EfficientNet-like convblock** is shown below.

<img src="img/C.jpg" width="200" />

In [None]:
class SeqNN(nn.Module):
    """
    LegNet neural network.

    Parameters
    ----------
    seqsize : int
        Sequence length.
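    use_single_channel : bool
        If True, singleton channel is used.
    use_reverse_channel : bool
        If True, reverse-complement channel is used.
    use_multisubstate_channel : bool
        If True, substrate channel is used.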
    block_sizes : list, optional
        List containing block sizes. The default is [256, 256, 128, 128, 64, 64, 32, 32].
    ks : int, optional
        Kernel size of convolutional layers. The default is 5.
    resize_factor : int, optional
        Resize factor used in a high-dimensional middle layer of an EffNet-like block. The default is 4.
    activation : nn.Module, optional
        Activation function. The default is nn.SiLU.
    filter_per_group : int, optional
        Number of filters per group in a middle convolutional layer of an EffNet-like block. The default is 2.
    se_reduction : int, optional
        Reduction number used in SELayer. The default is 4.
    final_ch : int, optional
        Number of channels in the final output convolutional layer. The default is 18.
    bn_momentum : float, optional
        BatchNorm momentum. The default is 0.1.
    """
    __constants__ = ('resize_factor',)
    
    def __init__(self, 
                seqsize, 
                use_single_channel, 
                use_reverse_channel,
                use_multisubstate_channel,
                block_sizes=[256, 256, 128, 128, 64, 64, 32, 32], 
                ks=5, 
                resize_factor=4, 
                activation=nn.SiLU,
                filter_per_group=2,
                se_reduction=4,
                final_ch=18,
                bn_momentum=0.1):        
        super().__init__()
        self.block_sizes = block_sizes
        self.resize_factor = resize_factor
        self.se_reduction = se_reduction
        self.seqsize = seqsize
        self.use_single_channel = use_single_channel
        self.use_reverse_channel = use_reverse_channel
        self.use_multisubstate_channel = use_multisubstate_channel
        self.final_ch = final_ch
        self.bn_momentum = bn_momentum
        seqextblocks = OrderedDict()

        in_channels_first_block = 4
        if self.use_single_channel:
            in_channels_first_block += 1
        if self.use_reverse_channel:
            in_channels_first_block += 1
        if self.use_multisubstate_channel:
            in_channels_first_block += 1
        
        block = nn.Sequential(
                       nn.Conv1d(
                            in_channels=in_channels_first_block,
                            out_channels=block_sizes[0],
                            kernel_size=ks,
                            padding='same',
                            bias=False
                       ),
                       nn.BatchNorm1d(block_sizes[0], momentum=self.bn_momentum),
                       activation()
        )
        seqextblocks[f'blc0'] = block

        
        for ind, (prev_sz, sz) in enumerate(zip(block_sizes[:-1], block_sizes[1:])):
            block = nn.Sequential(
                        nn.Conv1d(
                            in_channels=prev_sz,
                            out_channels=sz * self.resize_factor,
                            kernel_size=1,
                            padding='same',
                            bias=False
                       ),
                       nn.BatchNorm1d(sz * self.resize_factor, momentum=self.bn_momentum),
                       activation(),
                       
                       nn.Conv1d(
                            in_channels=sz * self.resize_factor,
                            out_channels=sz * self.resize_factor,
                            kernel_size=ks,
                            groups=sz * self.resize_factor // filter_per_group,
                            padding='same',
                            bias=False
                       ),
                       nn.BatchNorm1d(sz * self.resize_factor, momentum=self.bn_momentum),
                       activation(),
                
                       SELayer(prev_sz, sz * self.resize_factor, reduction=self.se_reduction),
                
                       nn.Conv1d(
                            in_channels=sz * self.resize_factor,
                            out_channels=prev_sz,
                            kernel_size=1,
                            padding='same',
                            bias=False
                       ),
                       nn.BatchNorm1d(prev_sz, momentum=self.bn_momentum),
                       activation(),
            
            )
            seqextblocks[f'inv_res_blc{ind}'] = block
            block = nn.Sequential(
                        nn.Conv1d(
                            in_channels=2 * prev_sz,
                            out_channels=sz,
                            kernel_size=ks,
                            padding='same',
                            bias=False
                       ),
                       nn.BatchNorm1d(sz, momentum=self.bn_momentum),
                       activation(),
            )
            seqextblocks[f'resize_blc{ind}'] = block

        self.seqextractor = nn.ModuleDict(seqextblocks)

        self.mapper = block = nn.Sequential(
                        nn.Conv1d(
                            in_channels=block_sizes[-1],
                            out_channels=self.final_ch,
                            kernel_size=1,
                            padding='same',
                       ),
                       activation()
        )
        
        self.register_buffer('bins', torch.arange(start=0, end=self.final_ch, step=1, requires_grad=False))
        
    def feature_extractor(self, x):
        x = self.seqextractor['blc0'](x)
        
        for i in range(len(self.block_sizes) - 1):
            x = torch.cat([x, self.seqextractor[f'inv_res_blc{i}'](x)], dim=1)
            x = self.seqextractor[f'resize_blc{i}'](x)
        return x 

    def forward(self, x):    
        f = self.feature_extractor(x)
        x = self.mapper(f)
        x = F.adaptive_avg_pool1d(x, 1)
        x = x.squeeze(2)
        logprobs = F.log_softmax(x, dim=1) 
        
        # soft-argmax operation
        x = F.softmax(x, dim=1)
        score = (x * self.bins).sum(dim=1)
        
        return logprobs, score

### Model training

In [None]:
# Environment setup

set_global_seed(seed)
torch.backends.cuda.matmul.allow_tf32 = False

model_dir = Path(model_dir)
model_dir.mkdir(exist_ok=False, parents=True)

run_backup_path = model_dir / "run.py"

shutil.copy(sys.argv[0], run_backup_path)

In [None]:
train, valid = cash_and_preprocess(seqsize=seqsize, 
                            path_to_training_data=path_to_training_data, 
                            delimiter=delimiter, 
                            foldify=foldify, 
                            use_single_channel=use_single_channel, 
                            use_reverse_channel=use_reverse_channel,
                            use_multisubstate_channel=use_multisubstate_channel,
                            preprocess_data=preprocess_data,
                            use_validation=use_validation,
                            seed=seed
                           )

In [None]:
# Train and validation (if needed) dataloader creation

train_dl, valid_dl = create_dl(train, valid, 
                               seqsize,
                               use_single_channel, use_reverse_channel, use_multisubstate_channel,
                               train_batch_size, train_workers,
                               valid_batch_size, valid_workers,
                               batch_per_epoch,
                               SeqDatasetProb,
                               shuffle_train=True, shuffle_val=False
                              )

In [None]:
# Model creation
device = torch.device(f"cuda:{gpu}")

model = get_model(
    SeqNN=SeqNN,
    seqsize=seqsize, 
    use_single_channel=use_single_channel,
    use_reverse_channel=use_reverse_channel,
    use_multisubstate_channel=use_multisubstate_channel,
    blocks= blocks, 
    ks=ks, 
    resize_factor=resize_factor, 
    se_reduction=se_reduction, 
    bn_momentum=bn_momentum,
    final_ch=final_ch,
    device=device,
)

In [None]:
# Loss declaration
if loss == "kl":
    criterion = nn.KLDivLoss(reduction= "batchmean").to(device)
elif loss == "mse":
    criterion = nn.MSELoss().to(device)
else:
    raise ValueError(f"Unknown loss: {loss!r}; expected 'kl' or 'mse'")

To select the max learning rate for the One Cycle Policy, we used the LR-range test suggested in [Smith, Cyclical Learning Rates for Training Neural Networks](https://arxiv.org/abs/1506.01186).
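
The core of an LR-range test is a schedule that sweeps the learning rate over several orders of magnitude while the loss is recorded after each mini-batch. The sketch below generates an exponentially spaced sweep; this spacing is one common choice and an assumption here, since the internals of `run_lr_finder` are not shown in this notebook. The function name and the bounds are illustrative.

```python
def lr_range_schedule(start_lr, end_lr, num_iter):
    """Exponentially spaced learning rates from start_lr to end_lr."""
    if num_iter < 2:
        raise ValueError("num_iter must be at least 2")
    ratio = end_lr / start_lr
    return [start_lr * ratio ** (i / (num_iter - 1)) for i in range(num_iter)]

lrs = lr_range_schedule(1e-7, 1.0, 8)
for lr in lrs:
    # in a real test: train one mini-batch at this lr and record the loss
    print(f"{lr:.1e}")
```

Plotting the recorded loss against these learning rates on a logarithmic axis gives the curve from which the maximum learning rate is read.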

In [None]:
# Base optimizer setup. Will be changed after lr range test.

optimizer = get_optimizer(optimizer_name, model.parameters(), .01, weight_decay)

In [None]:
run_lr_finder(model=model, optimizer=optimizer, criterion=criterion, train_dl=train_dl, device=device)

<img src="img/lrtest.png" width="500" />

The image above shows an example of the LR finder output. Choose the learning rate on the lowest plateau of the curve.

In [None]:
chosen_lr = 10e-2  # note: 10e-2 equals 0.1; replace with the value read from your own LR-range plot

In [None]:
max_lr, div_factor = chosen_lr, 25.0
min_lr = max_lr / div_factor

In [None]:
model = get_model(
    SeqNN=SeqNN,
    seqsize=seqsize, 
    use_single_channel=use_single_channel,
    use_reverse_channel=use_reverse_channel,
    use_multisubstate_channel=use_multisubstate_channel,
    blocks= blocks, 
    ks=ks, 
    resize_factor=resize_factor, 
    se_reduction=se_reduction, 
    bn_momentum=bn_momentum,
    final_ch=final_ch,
    device=device,
)
optimizer = get_optimizer(optimizer_name, model.parameters(), min_lr, weight_decay)

To train our neural network, we used the **One Cycle Policy** with FastAI modifications:

* two phases (instead of the original three), 
* the cosine annealing strategy instead of the linear one, 
* the AdamW optimizer (weight_decay=0.01) instead of the SGD with momentum.
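
The shape of a two-phase, cosine-annealed one-cycle schedule can be sketched as a plain function of the step number. This is a simplified illustration meant to mirror the behaviour of `torch.optim.lr_scheduler.OneCycleLR` with `three_phase=False`. The `final_div_factor` default of `1e4` is an assumption about that scheduler, and the step counts and `max_lr` used here are toy values.

```python
import math

def one_cycle_lr(step, total_steps, max_lr, div_factor=25.0,
                 final_div_factor=1e4, pct_start=0.3):
    """Learning rate at `step` for a two-phase cosine one-cycle schedule."""
    initial_lr = max_lr / div_factor
    min_lr = initial_lr / final_div_factor
    warmup_end = pct_start * total_steps - 1  # last step of the increasing phase

    def cos_anneal(start, end, pct):
        # moves from `start` (pct=0) to `end` (pct=1) along a half cosine
        return end + (start - end) / 2.0 * (1.0 + math.cos(math.pi * pct))

    if step <= warmup_end:
        return cos_anneal(initial_lr, max_lr, step / warmup_end)
    pct = (step - warmup_end) / (total_steps - 1 - warmup_end)
    return cos_anneal(max_lr, min_lr, pct)

schedule = [one_cycle_lr(s, total_steps=100, max_lr=0.1) for s in range(100)]
print(schedule[0], max(schedule), schedule[-1])
```

The rate starts at `max_lr / div_factor`, peaks at `max_lr` after 30% of the steps, and then decays to a value far below the starting rate.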

In [None]:
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, 
                                                max_lr=max_lr,
                                                div_factor=div_factor,
                                                steps_per_epoch=batch_per_epoch, 
                                                epochs=epoch_num, 
                                                pct_start=0.3,
                                                three_phase=False  # two-phase schedule, as described above
                                               )

In [None]:
print('Model parameters:', int(parameter_count(model)))

In [None]:
trainer, p = create_trainer(model, optimizer, scheduler, criterion, device, model_dir, 
                            use_validation=use_validation, valid_dl=valid_dl)

log_dir = model_dir / "logs"

tb_logger = TensorboardLogger(log_dir=log_dir)
tb_logger.attach_output_handler(
        trainer,
        event_name=Events.ITERATION_COMPLETED,
        tag="training",
        output_transform=lambda loss: {"batchloss": loss},
)

state = trainer.run(train_dl, max_epochs=epoch_num) 

### Model testing

Here is additional code that you can use to test your trained model on custom data (`target`).

Native yeast promoter expression measurements for yeast grown in [complex](https://zenodo.org/record/4436477/files/Native_complex.csv?download=1) and [defined](https://zenodo.org/record/4436477/files/Native_defined.csv?download=1) medium can be downloaded from the [Zenodo record](https://zenodo.org/record/4436477#.Y5QgZOxBy3J).

Results will be written to `output`. If the `target` file includes ground truth expression estimates, the `test_results` function will also output metrics.

In [None]:
target = 'path_to_target'
output = 'path_to_output'

In [None]:
device = torch.device(f"cuda:{gpu}")
model = get_model(
    SeqNN=SeqNN,
    seqsize=seqsize, 
    use_single_channel=use_single_channel,
    use_reverse_channel=use_reverse_channel,
    use_multisubstate_channel=use_multisubstate_channel,
    blocks= blocks, 
    ks=ks, 
    resize_factor=resize_factor, 
    se_reduction=se_reduction, 
    bn_momentum=bn_momentum,
    final_ch=final_ch,
    device=device,
)

In [None]:
test_results(target, output, 
             model, model_dir, epoch_num, seqsize, 
             use_single_channel, use_reverse_channel, use_multisubstate_channel, 
             preprocess_data, SeqDatasetProb, valid_batch_size, valid_workers, device)

