<a href="https://colab.research.google.com/github/2019mohamed/DNA-and-NLP/blob/main/Transformer_based.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [42]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data.sampler import SubsetRandomSampler , BatchSampler
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim


class SeqData(Dataset):
    """Dataset of DNA sequences, one-hot encoded over the alphabet a/c/g/t.

    The CSV file must provide a ``Sequence`` column (the nucleotide string,
    possibly containing tab characters) and a ``Class`` column (``'+'`` for
    the class encoded as 0; any other value is encoded as 1).

    Parameters
    ----------
    path: str
        Location of the CSV file. Defaults to ``'promoters.csv'`` so that
        ``SeqData()`` keeps working as before.
    """

    def __init__(self, path='promoters.csv'):
        data = pd.read_csv(path)
        # A single literal replacement of '\t' also removes doubled tabs, so
        # the separate '\t\t' pass is not needed. regex=False keeps the
        # behaviour independent of the pandas version's regex default.
        cleaned = data['Sequence'].str.replace('\t', '', regex=False)
        self.seqs = list(cleaned)
        # Length of the longest sequence; every sample is padded to this size.
        self.maxlen = max(map(len, self.seqs), default=0)
        self.labels = list(data['Class'])
        # Column index of each nucleotide in the one-hot encoding.
        self.map = {'a': 0, 'c': 1, 'g': 2, 't': 3}

    def __len__(self):
        """Return the number of sequences."""
        return len(self.seqs)

    def __getitem__(self, index):
        """Return ``(x, label)`` for the sequence at ``index``.

        ``x`` is a ``(maxlen, 4)`` NumPy array with a single 1 per occupied
        position; rows past the end of the sequence stay all-zero.
        ``label`` is 0 for class ``'+'`` and 1 otherwise.

        Raises
        ------
        KeyError
            If the sequence contains a character other than a/c/g/t
            (case-insensitive).
        """
        x = np.zeros((self.maxlen, len(self.map)))
        for position, base in enumerate(self.seqs[index].lower()):
            x[position, self.map[base]] = 1
        label = 0 if self.labels[index] == '+' else 1
        return x, label



import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data.sampler import SubsetRandomSampler , BatchSampler
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim


class MLP(nn.Module):
    """Feed-forward network with ReLU hidden activations and a linear output."""

    def __init__(self, num_layers, input_dim, hidden_dim, output_dim):
        """Build the layers of the MLP.

        Parameters
        ----------
        num_layers: int
            The number of linear layers; 1 yields a plain linear model.
        input_dim: int
            The dimensionality of input features
        hidden_dim: int
            The dimensionality of hidden units at ALL layers
        output_dim: int
            The dimensionality of the output of the last layer

        Raises
        ------
        ValueError
            If ``num_layers`` is smaller than 1.
        """
        super().__init__()
        if num_layers < 1:
            raise ValueError("number of layers should be positive!")

        self.num_layers = num_layers
        self.output_dim = output_dim
        # True when the model degenerates to a single linear map.
        self.linear_or_not = num_layers == 1

        if self.linear_or_not:
            self.linear = nn.Linear(input_dim, output_dim)
            return

        # input -> hidden -> ... -> hidden -> output, num_layers maps in total.
        widths = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        self.linears = nn.ModuleList(
            [nn.Linear(fan_in, fan_out) for fan_in, fan_out in zip(widths, widths[1:])]
        )
        # One batch norm per hidden layer is registered here, but forward()
        # does not apply them.
        self.batch_norms = nn.ModuleList(
            [nn.BatchNorm1d(hidden_dim) for _ in range(num_layers - 1)]
        )

    def forward(self, x):
        """Apply the network to ``x`` and return the un-activated output."""
        if self.linear_or_not:
            return self.linear(x)

        h = x
        # Every layer but the last is followed by a ReLU.
        for hidden_layer in self.linears[: self.num_layers - 1]:
            h = F.relu(hidden_layer(h))
        return self.linears[-1](h)


class Transformer(nn.Module):
    """Sequence encoder: segment-wise transformer followed by an MLP readout.

    The input sequence is cut into ``len_sequence // segment_size`` segments;
    each segment is flattened into one token of size
    ``alphabet_size * segment_size``, passed through a transformer encoder and
    the concatenated token encodings are mapped by an MLP to a vector of size
    ``embedding_size``.

    Parameters
    ----------
    len_sequence: int
        Length of the input sequences.
    segment_size: int
        Number of consecutive positions merged into one token.
    embedding_size: int
        Hidden and output width of the readout MLP.
    hidden_size: int
        Width used inside the transformer encoder.
    trans_layers: int
        Number of transformer encoder layers.
    readout_layers: int
        Number of linear layers in the readout MLP.
    alphabet_size: int
        Size of the one-hot alphabet of the input.
    dropout: float
        Dropout probability passed to the transformer encoder.
    heads: int
        Number of attention heads.
    layer_norm: bool
        Whether a final LayerNorm is applied by the transformer encoder.
    mask: str
        Attention mask: ``'empty'``, ``'no_prev'`` or ``'local<k>'`` where
        ``<k>`` is an integer band width (e.g. ``'local3'``).

    Raises
    ------
    ValueError
        If ``mask`` is not one of the supported values.
    """

    def __init__(self, len_sequence, segment_size, embedding_size, hidden_size, trans_layers, readout_layers, alphabet_size=4,
                 dropout=0.0, heads=1, layer_norm=False, mask='empty'):
        super(Transformer, self).__init__()

        self.segment_size = segment_size
        num_segments = len_sequence // segment_size

        if mask == "empty":
            mask_sequence = generate_empty_mask(num_segments)
        elif mask == "no_prev":
            mask_sequence = generate_square_previous_mask(num_segments)
        elif mask.startswith("local"):
            mask_sequence = generate_local_mask(num_segments, k=int(mask[5:]))
        else:
            # Fail here rather than with an AttributeError on the first
            # forward pass, when self.mask_sequence would be missing.
            raise ValueError(
                "unknown mask %r: expected 'empty', 'no_prev' or 'local<k>'" % (mask,)
            )
        # A buffer follows the module across .to(device)/.cuda(); it is
        # non-persistent so the keys of state_dict() stay unchanged.
        self.register_buffer('mask_sequence', mask_sequence, persistent=False)

        self.sequence_trans = TransformerEncoderModel(ntoken=alphabet_size*segment_size, nout=hidden_size, ninp=hidden_size,
                                                  nhead=heads, nhid=hidden_size, nlayers=trans_layers, dropout=dropout,
                                                  layer_norm=layer_norm, max_len=num_segments)

        self.readout = MLP(input_dim=num_segments * hidden_size, hidden_dim=embedding_size, output_dim=embedding_size, num_layers=readout_layers)

    def forward(self, sequence):
        """Encode a batch of sequences.

        ``sequence`` has three dimensions ``(B, N, alphabet)``; the result of
        the readout has ``B`` rows.
        """
        (B, N, _) = sequence.shape

        # Group positions into segments and move the segment axis first, the
        # layout passed to the transformer encoder.
        sequence = sequence.reshape((B, N//self.segment_size, -1)).transpose(0, 1)
        enc_sequence = self.sequence_trans(sequence, self.mask_sequence)

        # Back to batch-first, then flatten all token encodings per sample.
        enc_sequence = enc_sequence.transpose(0, 1).reshape(B, -1)
        embedding = self.readout(enc_sequence)
        return embedding


class TransformerEncoderModel(nn.Module):
    """ Part of this code was adapted from the examples of the PyTorch library

    Linear input projection, positional encoding, a stack of transformer
    encoder layers and a linear output projection.
    """

    def __init__(self, ntoken, nout, ninp, nhead, nhid, nlayers, max_len, dropout=0.0, layer_norm=False):
        """Build the model.

        ``ntoken`` is the size of each input token, ``ninp`` the width used by
        the encoder stack, ``nhid`` its feed-forward width, ``nhead`` the
        number of attention heads, ``nlayers`` the number of encoder layers
        and ``nout`` the size of each output token. ``max_len`` bounds the
        positional encoding table. When ``layer_norm`` is true a LayerNorm is
        passed as the encoder's ``norm``.
        """
        super().__init__()
        self.pos_encoder = PositionalEncoding(ninp, dropout, max_len=max_len)

        layer = nn.TransformerEncoderLayer(ninp, nhead, nhid, dropout)
        if layer_norm:
            final_norm = nn.LayerNorm(normalized_shape=ninp, eps=1e-6)
        else:
            final_norm = None
        self.transformer_encoder = nn.TransformerEncoder(layer, nlayers, norm=final_norm)

        self.encoder = nn.Linear(ntoken, ninp)
        self.ninp = ninp
        self.decoder = nn.Linear(ninp, nout)

        self.init_weights()

    def init_weights(self):
        """Re-initialise the projections: uniform weights in [-0.1, 0.1], zero decoder bias."""
        bound = 0.1
        nn.init.uniform_(self.encoder.weight, -bound, bound)
        nn.init.zeros_(self.decoder.bias)
        nn.init.uniform_(self.decoder.weight, -bound, bound)

    def forward(self, src, src_mask):
        """Project ``src``, add positions, run the encoder under ``src_mask`` and project out."""
        # The projection is scaled by sqrt(ninp) before positions are added.
        scaled = self.encoder(src) * math.sqrt(self.ninp)
        encoded = self.transformer_encoder(self.pos_encoder(scaled), src_mask)
        return self.decoder(encoded)


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to the input, then apply dropout.

    Parameters
    ----------
    d_model: int
        Size of the last dimension of the inputs (even or odd).
    dropout: float
        Dropout probability applied after the encodings are added.
    max_len: int
        Number of positions for which encodings are precomputed.
    """

    def __init__(self, d_model, dropout=0.1, max_len=10000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one cosine column fewer than sine columns,
        # so the frequencies are truncated to match; for even d_model the
        # slice is a no-op.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Stored as (max_len, 1, d_model) so it broadcasts over dimension 1
        # of the input.
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``dropout(x + pe[:x.size(0)])``; dimension 0 of ``x`` indexes positions."""
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)


def generate_square_previous_mask(sz):
    """Return an ``sz x sz`` float mask with -inf strictly below the diagonal.

    Entries on and above the diagonal are 0.0.
    """
    below_diagonal = torch.ones(sz, sz).tril(diagonal=-1).bool()
    return torch.zeros(sz, sz, dtype=torch.float).masked_fill(below_diagonal, float('-inf'))


def generate_local_mask(sz, k=3):
    """Return an ``sz x sz`` float band mask of half-width ``k``.

    Entry ``(i, j)`` is 0.0 when ``|i - j| <= k`` and -inf otherwise. A
    non-positive ``k`` keeps only the diagonal, and ``k`` larger than ``sz``
    yields an all-zero mask.
    """
    positions = torch.arange(sz)
    # Distance of every (row, column) pair from the diagonal.
    distance = (positions.unsqueeze(0) - positions.unsqueeze(1)).abs()
    outside_band = distance > max(k, 0)
    return torch.zeros(sz, sz, dtype=torch.float).masked_fill(outside_band, float('-inf'))


def generate_empty_mask(sz):
    """Return an ``sz x sz`` all-zero mask."""
    return torch.zeros((sz, sz))

def split_rand(dataset,batch_size, split_ratio=0.7, seed=42, shuffle=True):
    """Split ``dataset`` into a training and a validation ``DataLoader``.

    Parameters
    ----------
    dataset:
        Object supporting ``len()`` and integer indexing.
    batch_size: int
        Batch size of both loaders.
    split_ratio: float
        Fraction of the entries (rounded down) assigned to training.
    seed: int
        Seed for NumPy's global random generator, applied before the indices
        are permuted.
    shuffle: bool
        When true (default) the indices are permuted before splitting; when
        false the first entries go to training and the remaining ones to
        validation, and NumPy's generator is left untouched.

    Returns
    -------
    (train_loader, valid_loader)
        Both loaders draw their subset through a ``SubsetRandomSampler``.
    """
    import math
    num_entries = len(dataset)
    indices = list(range(num_entries))
    if shuffle:
        np.random.seed(seed)
        np.random.shuffle(indices)
    split = int(math.floor(split_ratio * num_entries))
    train_idx, valid_idx = indices[:split], indices[split:]

    train_loader, valid_loader = (
        DataLoader(dataset, sampler=SubsetRandomSampler(subset), batch_size=batch_size)
        for subset in (train_idx, valid_idx)
    )

    return train_loader, valid_loader

def train(net, trainloader, optimizer, criterion):
    """Run one optimisation pass over ``trainloader``.

    ``net`` is put in training mode; for every ``(x, labels)`` batch the loss
    ``criterion(net(x.float()), labels)`` is back-propagated and ``optimizer``
    takes a step. Returns the mean of the per-batch losses (the sum divided by
    ``len(trainloader)``).
    """
    net.train()

    num_batches = len(trainloader)
    loss_sum = 0

    for inputs, targets in trainloader:
        # Clear gradients left over from the previous step.
        optimizer.zero_grad()

        batch_loss = criterion(net(inputs.float()), targets)
        loss_sum += batch_loss.item()

        batch_loss.backward()
        optimizer.step()

    return loss_sum / num_batches

def eval_net(net, dataloader, criterion):
    """Evaluate ``net`` on ``dataloader`` without tracking gradients.

    ``net`` is put in evaluation mode. For every ``(x, labels)`` batch the
    prediction is the index of the largest output along dimension 1.

    Returns
    -------
    (loss, acc)
        ``loss`` is the per-batch criterion value weighted by batch size and
        divided by the number of samples; ``acc`` is the fraction of samples
        whose prediction equals the label.

    Raises
    ------
    ZeroDivisionError
        If ``dataloader`` yields no samples.
    """
    net.eval()

    total = 0
    total_loss = 0
    total_correct = 0

    # No backward pass happens here, so autograd bookkeeping is skipped.
    with torch.no_grad():
        for x, labels in dataloader:
            batch_size = len(labels)
            outputs = net(x.float())
            _, predicted = torch.max(outputs, 1)

            total += batch_size
            total_correct += (predicted == labels).sum().item()
            total_loss += criterion(outputs, labels).item() * batch_size

    loss, acc = 1.0*total_loss / total, 1.0*total_correct / total

    return loss, acc



# Load the sequences and build the model. With segment_size equal to
# len_sequence, each sequence becomes a single token.
dataset = SeqData()

# NOTE(review): the readout emits embedding_size=128 values per sample while
# the dataset labels are only 0/1 — confirm the extra outputs are intended.
model = Transformer(
    len_sequence=57,
    segment_size=57,
    embedding_size=128,
    hidden_size=64,
    trans_layers=2,
    readout_layers=2,
    heads=8,
    layer_norm=False,
)

train_loader, test_loader = split_rand(dataset, batch_size=16)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Each epoch prints the mean training loss, then the (loss, accuracy) pair
# measured on the held-out split.
for _ in range(500):
    epoch_loss = train(model, train_loader, optimizer, criterion)
    print(epoch_loss)

    held_out_metrics = eval_net(model, test_loader, criterion)
    print(held_out_metrics)




4.621615314483643
(4.074456691741943, 0.65625)
3.585037851333618
(2.8197051286697388, 0.46875)
2.124591088294983
(1.2110081315040588, 0.46875)
0.8958014845848083
(0.7132685482501984, 0.40625)
0.6626940488815307
(0.6979579627513885, 0.40625)
0.6314382553100586
(0.6404063701629639, 0.75)
0.5582497477531433
(0.5475876033306122, 0.9375)
0.3720456659793854
(0.345610573887825, 0.875)
0.13121481984853745
(0.21237029135227203, 0.9375)
0.04211123213171959
(0.24333789199590683, 0.90625)
0.01967153958976269
(0.25595518201589584, 0.9375)
0.00611236309632659
(0.26373300701379776, 0.9375)
0.0023978045675903557
(0.27077573991846293, 0.9375)
0.001294133672490716
(0.2759672701358795, 0.9375)
0.0008651306503452361
(0.2787700966000557, 0.9375)
0.0006644129520282149
(0.2804830027744174, 0.9375)
0.0005550084519200027
(0.2815995393320918, 0.9375)
0.00048698969767428933
(0.2825254537165165, 0.9375)
0.0004422612255439162
(0.2830926850438118, 0.9375)
0.0004125510982703418
(0.28345635777805, 0.9375)
0.000385966