In [12]:
import torch 
import torch.optim as optim
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np

from API_utils.dataset_FEGS import API_FEGS_Class
from API_utils.dataset_api import API_Class

In [13]:
DATASET_PATH = "C:/Users/asus/Desktop/API/dataset/one_to_one.xls"
DATASET_MAT_PATH = "C:/Users/asus/Desktop/API/dataset/one_to_one.mat"
TEST_DATASET_PATH = "C:/Users/asus/Desktop/API/dataset/test.xlsx"
TEST_DATASET_MAT_PATH = "C:/Users/asus/Desktop/API/dataset/test.mat"
SAVE_MODEL_PATH = "C:/Users/asus/Desktop/API/Model/"
CSV_PATH = "C:/Users/asus/Desktop/API/dataset/Dataset.csv"
DEVICE= torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
EPOCH = 100
lr=0.0001

In [14]:
print(DEVICE)

cuda:0


In [15]:
# train_data = API_FEGS_Class(DATASET_PATH,DATASET_MAT_PATH,'abc')
# test_data = API_FEGS_Class(TEST_DATASET_PATH,TEST_DATASET_MAT_PATH,'test')

train_data = API_Class(CSV_PATH)
test_data = API_Class(CSV_PATH)

In [16]:
class SpatialGatingUnit(nn.Module):  # [-1,256,256]
    def __init__(self, d_ffn, seq_len):
        super().__init__()
        self.norm = nn.LayerNorm(d_ffn)   # [-1,256,256]->[-1,256,512]
        self.spatial_proj = nn.Conv1d(seq_len, seq_len, kernel_size=1) # [-1,256,512]->[-1,256,512]
        nn.init.constant_(self.spatial_proj.bias, 1.0)  # 偏差
 
    def forward(self, x):
        # chunk(arr, size)接收两个参数，一个是原数组，一个是分块的大小size，默认值为1，
        # 原数组中的元素会按照size的大小从头开始分块，每一块组成一个新数组，如果最后元素个数不足size的大小，那么它们会组成一个快。
        u, v = x.chunk(2, dim=-1)
        v = self.norm(v)
        v = self.spatial_proj(v)
        out = u * v
        return out

In [17]:
class gMLPBlock(nn.Module):
    def __init__(self, d_model, d_ffn, seq_len):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.channel_proj1 = nn.Linear(d_model, d_ffn * 2) # (256, d_ffn * 2=1024)  [-1,256,1024]
        self.sgu = SpatialGatingUnit(d_ffn, seq_len)   #
        self.channel_proj2 = nn.Linear(d_ffn, d_model)
 
    def forward(self, x):
        residual = x
        x = self.norm(x)     # [-1,256,256]
        x = F.gelu(self.channel_proj1(x))  # GELU激活函数 [-1,256,256]
        x = self.sgu(x)   # [-1,256,256]
        x = self.channel_proj2(x)
        out = x + residual
        return out

In [18]:
class gMLP(nn.Module):
    def __init__(self, d_model=256, d_ffn=512, seq_len=256, num_layers=6):
        super().__init__()
        self.model = nn.Sequential(
            *[gMLPBlock(d_model, d_ffn, seq_len) for _ in range(num_layers)]
        )
        # [gMLPBlock(d_model=256, d_ffn=512, seq_len=256) for _ in range(num_layers)]
 
    def forward(self, x):
        return self.model(x)

In [19]:
class gMLPForImageClassification(gMLP):
    def __init__(
        self,
        image_size=256+64,
        patch_size=16,
        in_channels=1,
        num_classes=2,
        d_model=256,
        d_ffn=512,
        seq_len=256,
        num_layers=6,
    ):
        # num_patches = check_sizes(image_size, patch_size)  # num_patches=256
        super().__init__(d_model, d_ffn, seq_len, num_layers)
        self.patcher = nn.Conv1d(
            in_channels, d_model, kernel_size=patch_size, stride=patch_size
        )  # [2, 3, 256, 256] -> [2, 256, 16, 16]
        self.classifier = nn.Linear(d_model, num_classes)
 
    def forward(self, x):
        # a = x.shape = [2,3,256,256]
        patches = self.patcher(x)
        batch_size, num_channels, _, _ = patches.shape  # [2,256,16,16]
        patches = patches.permute(0, 2, 3, 1)  # 将tensor的维度换位 [2,256,16,16]->[2,16,16,256]
        patches = patches.view(batch_size, -1, num_channels) # 转为(2,-1,256)  即为[2,256,256]
        # a = patches.shape
        embedding = self.model(patches)
        # a = embedding.shape = [2,256,256]
 
        embedding = embedding.mean(dim=1)
        out = self.classifier(embedding)
        return out

In [20]:
train_data_loader = DataLoader(train_data,batch_size=1,shuffle=True)
test_data_loader = DataLoader(test_data,batch_size=1,shuffle=True)
Net = gMLPForImageClassification().to(DEVICE)
optimizer = optim.AdamW(params=Net.parameters(),lr=lr)
loss_fn = nn.BCEWithLogitsLoss()


In [21]:
def evaluate(model_path,test_data_loader,device=DEVICE):
    count=0
    acc = 0
    model = torch.load(model_path)
    for idx, data in enumerate(test_data_loader):
        api_input, api_label = data
        count+=1
        rna_input = api_input[:,0:256+64]
        # print(rna_input)
        protein_input = api_input[:,256+64:]
        rna_input = torch.unsqueeze(rna_input,dim=0)
        protein_input = torch.unsqueeze(protein_input,dim=0)
        #api_label = torch.unsqueeze(api_label,dim=0)
        rna_input = rna_input.to(dtype=torch.float32).to(DEVICE)
        protein_input = protein_input.to(dtype=torch.float32).to(DEVICE)
        output = model(rna_input,protein_input).to(DEVICE)
        # print(torch.round(torch.sigmoid(output)))
        if(torch.round(torch.sigmoid(output))==api_label[0]):
            acc+=1

    
    print("Accuracy",acc/count)

In [22]:
for epoch in range(EPOCH):
    for data in tqdm(train_data_loader):
        api_input, api_label = data
        # print(api_input.shape)
        # print(api_input)
        # break
        api_input = api_input.to(DEVICE)
        api_input = api_input.to(dtype=torch.float32)
        api_label = api_label.to(DEVICE)
        rna_input = api_input[:,0:64+256]
        # print(api_input.shape)
        protein_input = api_input[:,64+256:]
        rna_input = torch.unsqueeze(rna_input,dim=0)
        protein_input = torch.unsqueeze(protein_input,dim=0)

        api_label = torch.unsqueeze(api_label,dim=0).to(dtype=torch.float32)
        rna_input = rna_input.to(dtype=torch.float32)
        protein_input = protein_input.to(dtype=torch.float32)
        output = Net(api_input)
        # print(output)
        optimizer.zero_grad()    
        Loss = loss_fn(output, api_label)  
        Loss.backward()  
        optimizer.step()
    
    print("Loss",Loss.item())
    save_path = SAVE_MODEL_PATH+f'Epoch={epoch}_lr={lr}.pth'
    torch.save(Net, save_path)
    evaluate(SAVE_MODEL_PATH+f'Epoch={epoch}_lr={lr}.pth',test_data_loader)
    # break
    
torch.save(Net, SAVE_MODEL_PATH+f'Epoch={epoch}_lr={lr}_final.pth')
print('CSX')

  0%|          | 0/1160 [00:00<?, ?it/s]


ValueError: not enough values to unpack (expected 4, got 2)

In [None]:
model=torch.load(SAVE_MODEL_PATH+f'Epoch={epoch}_lr={lr}_final.pth')

In [None]:

acc=0
count=0
for idx, data in enumerate(test_data_loader):
    api_input, api_label = data
    count+=1
    rna_input = api_input[:,0:64]
    protein_input = api_input[:,64:909]
    rna_input = torch.unsqueeze(rna_input,dim=0)
    protein_input = torch.unsqueeze(protein_input,dim=0)
    #api_label = torch.unsqueeze(api_label,dim=0)
    rna_input = rna_input.to(dtype=torch.float32)
    protein_input = protein_input.to(dtype=torch.float32)
    output = model(rna_input,protein_input)
    if(torch.argmax(output)==api_label[0]):
        acc+=1

    
print(acc/count)

