In [None]:
import math
import os
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

from API_utils.dataset_FEGS import API_FEGS_Class
from API_utils.dataset_api import API_Class

In [None]:
# Project root. Override with the API_BASE_DIR environment variable instead of
# editing the hard-coded local path below.
BASE_DIR = os.environ.get("API_BASE_DIR", "C:/Users/asus/Desktop/API")

DATASET_PATH = f"{BASE_DIR}/dataset/one_to_one.xls"
DATASET_MAT_PATH = f"{BASE_DIR}/dataset/one_to_one.mat"
TEST_DATASET_PATH = f"{BASE_DIR}/dataset/test.xlsx"
TEST_DATASET_MAT_PATH = f"{BASE_DIR}/dataset/test.mat"
# Trailing slash matters: checkpoint file names are appended to this with `+`.
SAVE_MODEL_PATH = f"{BASE_DIR}/Model/"
CSV_PATH = f"{BASE_DIR}/dataset/Dataset0.csv"
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
EPOCH = 100
lr = 0.001  # AdamW learning rate; also embedded in checkpoint file names

# Seed every RNG in play (Python, NumPy, torch on all devices) so data shuffling
# and weight initialisation are repeatable across Restart & Run All.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Report which device the model and tensors will be placed on.
print(str(DEVICE))

In [None]:
# Alternative FEGS feature datasets with separate train / test files:
# train_data = API_FEGS_Class(DATASET_PATH, DATASET_MAT_PATH, 'abc')
# test_data = API_FEGS_Class(TEST_DATASET_PATH, TEST_DATASET_MAT_PATH, 'test')

# Only one CSV is used, so hold out a fixed, seeded fraction of it for testing.
# Building train_data and test_data from the same CSV would make the reported
# "test" accuracy a training accuracy.
TEST_FRACTION = 0.2
SPLIT_SEED = 0
full_data = API_Class(CSV_PATH)
n_test = int(round(len(full_data) * TEST_FRACTION))
n_train = len(full_data) - n_test
train_data, test_data = random_split(
    full_data,
    [n_train, n_test],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [None]:
class BasicBlock(nn.Module):
    """Pre-activation 1-D residual block.

    Residual path: BN -> LeakyReLU(0.1) -> Conv1d(k=3, stride) -> BN -> LeakyReLU(0.1)
    -> [dropout] -> Conv1d(k=3, stride=1). The skip path is the identity when
    in_planes == out_planes, otherwise a 1x1 Conv1d with the same stride.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: stride of conv1 (and of the 1x1 shortcut when channels differ).
            NOTE(review): the identity shortcut is only shape-compatible when
            stride == 1, so equal in/out channels are assumed to use stride 1.
        dropRate: dropout probability between the two convolutions (0 disables it).
        activate_before_residual: when channels differ, apply bn1 + relu1 to the
            input itself, so both conv1 and the shortcut see the activated input.
    """
    def __init__(self, in_planes, out_planes, stride, dropRate=0.0, activate_before_residual=False):
        super(BasicBlock, self).__init__()
        self.bn1 = nn.BatchNorm1d(in_planes, momentum=0.001)
        self.relu1 = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.conv1 = nn.Conv1d(in_planes, out_planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm1d(out_planes, momentum=0.001)
        self.relu2 = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.conv2 = nn.Conv1d(out_planes, out_planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.droprate = dropRate
        self.equalInOut = (in_planes == out_planes)
        # 1x1 projection for the skip path, only when the channel count changes;
        # the `cond and X or None` idiom yields the Conv1d module or None.
        self.convShortcut = (not self.equalInOut) and nn.Conv1d(in_planes, out_planes, kernel_size=1, stride=stride,
                               padding=0, bias=False) or None
        self.activate_before_residual = activate_before_residual
    def forward(self, x):
        """Return skip(x) + residual(x)."""
        if not self.equalInOut and self.activate_before_residual == True:
            # Pre-activate the input itself: conv1 and the shortcut both use it.
            x = self.relu1(self.bn1(x))
        else:
            out = self.relu1(self.bn1(x))
        # NOTE(review): when channels differ and activate_before_residual is False,
        # the `out` computed above is discarded and conv1 receives the raw x, although
        # bn1 has still been evaluated on x. Confirm this is the intended architecture.
        out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, training=self.training)
        out = self.conv2(out)
        # Identity skip when channels match, otherwise the 1x1 projection.
        return torch.add(x if self.equalInOut else self.convShortcut(x), out)

class NetworkBlock(nn.Module):
    """One stage of ``nb_layers`` residual blocks chained in an ``nn.Sequential``.

    Only the first block of the stage maps in_planes -> out_planes and applies
    ``stride``; every following block maps out_planes -> out_planes with stride 1.
    ``nb_layers`` may be a float (it is truncated with ``int``).
    """
    def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0, activate_before_residual=False):
        super().__init__()
        self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate, activate_before_residual)

    def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate, activate_before_residual):
        """Instantiate the stage's blocks and wrap them in ``nn.Sequential``."""
        def build(position):
            leading = position == 0
            return block(
                in_planes if leading else out_planes,
                out_planes,
                stride if leading else 1,
                dropRate,
                activate_before_residual,
            )

        return nn.Sequential(*map(build, range(int(nb_layers))))

    def forward(self, x):
        """Pass ``x`` through every block of the stage in order."""
        return self.layer(x)

class WideResNet(nn.Module):
    """1-D Wide-ResNet-style classifier.

    stem Conv1d(1 -> 640) -> three residual stages -> BN + LeakyReLU ->
    global average pooling over the length axis -> Linear(160 -> num_classes).

    Input is expected as ``(batch, 1, length)``: the stem conv has one input channel.
    Stages 2 and 3 use stride 2, so the length shrinks roughly 4x before pooling.

    Args:
        num_classes: number of output logits.
        depth: must satisfy (depth - 4) % 6 == 0; each stage holds (depth - 4) // 6 blocks.
        widen_factor: kept for interface compatibility; currently has no effect because
            the channel widths are fixed to [640, 640, 320, 160].
        dropRate: dropout probability inside each residual block.
    """
    def __init__(self, num_classes, depth=28, widen_factor=2, dropRate=0.0):
        super(WideResNet, self).__init__()
        # Widths of [stem, stage1, stage2, stage3]. The CIFAR-style widths would be
        # [16, 16*widen_factor, 32*widen_factor, 64*widen_factor].
        nChannels = [640, 640, 320, 160]
        assert (depth - 4) % 6 == 0, "depth must be of the form 6k + 4"
        n_blocks = (depth - 4) // 6  # residual blocks per stage
        block = BasicBlock
        # 1st conv before any network block
        self.conv1 = nn.Conv1d(1, nChannels[0], kernel_size=3, stride=1,
                               padding=1, bias=False)
        # 1st block: width unchanged, stride 1
        self.block1 = NetworkBlock(n_blocks, nChannels[0], nChannels[1], block, 1, dropRate, activate_before_residual=True)
        # 2nd block: 640 -> 320 channels, stride 2
        self.block2 = NetworkBlock(n_blocks, nChannels[1], nChannels[2], block, 2, dropRate)
        # 3rd block: 320 -> 160 channels, stride 2
        self.block3 = NetworkBlock(n_blocks, nChannels[2], nChannels[3], block, 2, dropRate)
        # final pre-activation, global average pooling and classifier
        self.bn1 = nn.BatchNorm1d(nChannels[3], momentum=0.001)
        self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.fc = nn.Linear(nChannels[3], num_classes)
        self.nChannels = nChannels[3]

        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                # He-normal init scaled by fan-out (kernel_size * out_channels).
                fan_out = m.kernel_size[0] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(m, nn.BatchNorm1d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight.data)
                m.bias.data.zero_()

    def forward(self, x):
        """Return logits of shape ``(batch, num_classes)``."""
        out = self.conv1(x)
        out = self.block1(out)
        out = self.block2(out)
        out = self.block3(out)
        out = self.relu(self.bn1(out))
        # True global average pooling, independent of the input length. A fixed
        # window (e.g. avg_pool1d(out, 128)) only covers the whole axis for one
        # specific length: shorter inputs error, longer ones are truncated, or
        # leave several pooled positions that a view(-1, C) folds into the batch.
        out = F.adaptive_avg_pool1d(out, 1)
        out = torch.flatten(out, 1)  # (batch, nChannels)
        return self.fc(out)

In [None]:
# One sample per step for both loaders.
train_data_loader = DataLoader(train_data, batch_size=1, shuffle=True)
# Evaluation needs no shuffling; a fixed order keeps evaluation runs comparable.
test_data_loader = DataLoader(test_data, batch_size=1, shuffle=False)
Net = WideResNet(num_classes=1).to(DEVICE)
optimizer = optim.AdamW(params=Net.parameters(), lr=lr)
# Single-logit binary classifier: BCEWithLogitsLoss applies the sigmoid internally.
loss_fn = nn.BCEWithLogitsLoss()


In [None]:
def evaluate(model_path, test_data_loader, device=DEVICE):
    """Load a pickled model and print its binary accuracy on ``test_data_loader``.

    The model is expected to emit one logit per sample; a sample counts as correct
    when ``round(sigmoid(logit))`` equals its label.

    Args:
        model_path: path of a full model saved with ``torch.save(model, path)``.
        test_data_loader: iterable of ``(api_input, api_label)`` batches.
        device: device the model and the inputs are moved to.

    Returns:
        The accuracy as a float (NaN if the loader yields no samples). It is also printed.
    """
    # torch.load unpickles arbitrary objects: only load checkpoints you trust.
    # weights_only=False is needed to load a full module on torch >= 2.6 (where the
    # default became True); older torch versions reject the argument.
    try:
        model = torch.load(model_path, map_location=device, weights_only=False)
    except TypeError:
        model = torch.load(model_path, map_location=device)
    model.to(device)
    model.eval()  # use BatchNorm running statistics and disable dropout while scoring

    correct = 0
    count = 0
    with torch.no_grad():
        for api_input, api_label in test_data_loader:
            api_input = api_input.to(device=device, dtype=torch.float32)
            output = model(api_input)
            preds = torch.round(torch.sigmoid(output)).view(-1)
            labels = api_label.to(device).view(-1)
            correct += (preds == labels).sum().item()
            count += labels.numel()

    accuracy = correct / count if count else float("nan")
    print("Accuracy", accuracy)
    return accuracy

In [None]:
for epoch in range(EPOCH):
    Net.train()  # BatchNorm / dropout in training mode for this epoch
    epoch_loss = 0.0
    n_batches = 0
    for api_input, api_label in tqdm(train_data_loader):
        # The model consumes the whole feature vector in a single forward call.
        api_input = api_input.to(DEVICE, dtype=torch.float32)
        output = Net(api_input)
        # BCEWithLogitsLoss needs float targets shaped exactly like the logits,
        # e.g. (batch,) -> (batch, 1).
        api_label = api_label.to(DEVICE, dtype=torch.float32).reshape(output.shape)
        optimizer.zero_grad()
        batch_loss = loss_fn(output, api_label)
        batch_loss.backward()
        optimizer.step()
        epoch_loss += batch_loss.item()
        n_batches += 1

    # Mean loss over the epoch, rather than the noisy last-batch value.
    print("Loss", epoch_loss / max(n_batches, 1))
    save_path = SAVE_MODEL_PATH + f'Epoch={epoch}_lr={lr}.pth'
    torch.save(Net, save_path)
    evaluate(save_path, test_data_loader)

torch.save(Net, SAVE_MODEL_PATH + f'Epoch={epoch}_lr={lr}_final.pth')
print('CSX')

In [None]:
# Reload the final checkpoint written by the training loop. torch.load unpickles the
# whole module (trusted files only); weights_only=False is required on torch >= 2.6,
# and older torch versions do not accept that argument.
final_model_path = SAVE_MODEL_PATH + f'Epoch={epoch}_lr={lr}_final.pth'
try:
    model = torch.load(final_model_path, map_location=DEVICE, weights_only=False)
except TypeError:
    model = torch.load(final_model_path, map_location=DEVICE)

In [None]:

# Score the reloaded model on the test loader. WideResNet.forward takes the full
# feature vector as its single argument and returns one logit per sample, so the
# prediction is round(sigmoid(logit)).
model.eval()  # BatchNorm running stats, no dropout
acc = 0
count = 0
with torch.no_grad():
    for api_input, api_label in test_data_loader:
        api_input = api_input.to(DEVICE, dtype=torch.float32)
        output = model(api_input)
        preds = torch.round(torch.sigmoid(output)).view(-1)
        labels = api_label.to(DEVICE).view(-1)
        acc += (preds == labels).sum().item()
        count += labels.numel()

print(acc / count if count else float('nan'))

