In [None]:
from matplotlib import pyplot as plt
import numpy as np
import torch
import torch.nn
from torch.utils.data import DataLoader,TensorDataset
from torch import nn,optim
from torch.utils.data.dataset import Dataset
from tqdm import tqdm,trange


In [None]:
from DrugReposition.data_process.properties import DrugProperty, DiseaseProperty
from DrugReposition.data_process.similaritys import SimilarityFactory as Similarity
from DrugReposition.utils.crossValidation import DrugDiseaseCrossValidation
from DrugReposition.metrics.compare import OtherMethods

In [None]:
def out_dims(in_shape, model):
    """Return the output shape ``model`` produces for a random input of ``in_shape``.

    Intended as a quick shape probe when sizing layers (e.g. the linear layer
    that follows ``nn.Flatten``).

    Args:
        in_shape: shape of the dummy input, passed to ``torch.rand``.
        model: the module (or any callable) to probe.

    Returns:
        ``torch.Size`` of the model output.

    Note:
        The forward pass runs under ``torch.no_grad()`` because only the shape
        is needed. Modules in training mode (e.g. BatchNorm) may still update
        internal statistics, so probe a throwaway model instance.
    """
    x = torch.rand(in_shape)
    with torch.no_grad():
        out = model(x)
    return out.shape


In [None]:
class CNNs(nn.Module):
    """Small two-stage CNN with a residual connection, producing 2-class logits.

    Expects input of shape ``(batch, in_channels, 2, in_width)``. In this notebook
    the two rows are a drug row and a disease row stacked along dim 2.

    The size of the final linear layer is derived from ``in_width``. The default
    of 1444 reproduces the previously hard-coded ``nn.Linear(11552, 2)``.

    Args:
        in_channels: number of input channels.
        in_width: size of the last input dimension. Default 1444.
    """

    def __init__(self, in_channels, in_width=1444):
        super(CNNs, self).__init__()
        self.cov1_1 = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=16,
                kernel_size=(2, 2),
                stride=(1, 1),
                padding=1,
            ),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=(1, 2)),
        )
        # Shape-preserving branch: the conv (kernel 2, padding 1) grows H and W by 1
        # and the stride-1 pool (kernel 2) shrinks them back, so its output can be
        # added to the output of cov1_1 in forward().
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=16, kernel_size=2, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=1, padding=0),
        )
        self.cov2 = nn.Sequential(
            nn.Conv2d(
                in_channels=16,
                out_channels=32,
                kernel_size=(2, 2),
                stride=1,
                padding=1,
            ),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=2),
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(self._flat_features(in_width), 2),
        )

    @staticmethod
    def _flat_features(in_width, in_height=2):
        """Number of features reaching ``fc`` for an input of spatial size in_height x in_width."""

        def conv_pool(size, pool_stride):
            # Conv2d(kernel=2, stride=1, padding=1) adds 1; MaxPool2d(kernel=2)
            # then yields floor((size + 1 - 2) / pool_stride) + 1.
            return (size + 1 - 2) // pool_stride + 1

        # cov1_1 pools with stride (1, 2).
        h = conv_pool(in_height, 1)
        w = conv_pool(in_width, 2)
        # conv1 is shape-preserving, so the residual sum keeps (h, w).
        # cov2 pools with stride 2 in both directions.
        h = conv_pool(h, 2)
        w = conv_pool(w, 2)
        return 32 * h * w

    def forward(self, x):
        x1 = self.cov1_1(x)
        x2 = self.conv1(x1)
        x = x1 + x2
        x = self.cov2(x)
        x = self.fc(x)
        return x

In [None]:
class ResBlock(nn.Module):
    """Residual block: ``x + body(x)``, then a channel-changing conv and width pooling.

    Args:
        in_channels: channels of the input (and of the residual body).
        out_channels: channels produced by the final conv.
        n: the residual body contains ``n - 1`` (3x3 Conv2d + ReLU) pairs. With
            ``n <= 1`` the body is an empty ``nn.Sequential`` (identity), so the
            input is doubled before the final conv.
        device: device on which the conv weights are created.
    """

    def __init__(self, in_channels, out_channels, n, device) -> None:
        super().__init__()
        # Create fresh layers for every repetition. The form `[conv, relu] * (n - 1)`
        # repeats the *same* module objects, which makes every repetition share a
        # single set of conv weights.
        layers = []
        for _ in range(n - 1):
            layers.append(nn.Conv2d(in_channels, in_channels, (3, 3), (1, 1), (1, 1), device=device))
            layers.append(nn.ReLU())
        self.cnn1 = nn.Sequential(*layers)
        self.cnn2 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, (3, 3), (1, 1), (1, 1), device=device),
            nn.ReLU(),
            # Stride defaults to the kernel: height unchanged, width floor-divided by 3.
            nn.MaxPool2d(kernel_size=(1, 3)),
        )

    def forward(self, x):
        x = x + self.cnn1(x)
        return self.cnn2(x)

class ResCNNs(nn.Module):
    """Residual CNN producing 2-class logits from a multi-channel input.

    Layout:
        - a 3x3 stem conv (``in_channels`` -> 16) followed by a Sigmoid;
        - four ``ResBlock``s (16 -> 32 -> 64 -> 128 -> 128);
        - Flatten and a linear head with 2 outputs.

    The 3x3 convolutions use padding 1, so they keep the spatial size. Each
    ``ResBlock`` ends with ``MaxPool2d(kernel_size=(1, 3))``, which keeps the
    height and floor-divides the width by 3. The flattened size is therefore
    ``128 * in_height * (in_width // 3**4)``. The defaults reproduce the
    previously hard-coded ``nn.Linear(4352, 2)`` (128 * 2 * 17).

    Args:
        in_channels: number of input channels.
        device: device on which all layers are created. Default "cpu".
        in_width: size of the last input dimension. Default 1444.
        in_height: size of the second-to-last input dimension. Default 2.
    """

    def __init__(self, in_channels, device="cpu", in_width=1444, in_height=2):
        super(ResCNNs, self).__init__()
        # Four ResBlocks each floor-divide the width by 3. Nested floor divisions
        # of a non-negative int compose into a single division by 3**4.
        flat_width = in_width // 3 ** 4
        self.cnns = nn.Sequential(
            nn.Conv2d(in_channels, 16, (3, 3), (1, 1), (1, 1), device=device),
            nn.Sigmoid(),
            ResBlock(16, 32, 3, device),
            ResBlock(32, 64, 3, device),
            ResBlock(64, 128, 3, device),
            ResBlock(128, 128, 3, device),
            nn.Flatten(start_dim=1),
            nn.Linear(128 * in_height * flat_width, 2, device=device),
        )

    def forward(self, x):
        return self.cnns(x)


In [None]:
# Shape probe for ResCNNs on a dummy batch: 32 samples, 10 channels, 2 rows, width 1444.
# NOTE(review): presumably 10 = len(Rs1) * len(Ds1) and 1444 = #drugs + #diseases — verify.
# The model actually trained below is CNNs with len(Rs) * len(Ds) channels, not ResCNNs(10).
in_shape=(32,10,2,1444)
out_shape=out_dims(in_shape,ResCNNs(10))
print(out_shape)

In [None]:
# Drug-side similarity views. The first five (A1..E1) are clipped into [0, 1].
R_A1, R_B1, R_C1, R_D1, R_E1 = (
    Similarity.similarity("drug", view).clip(min=0, max=1)
    for view in ("A1", "B1", "C1", "D1", "E1")
)
R_dis, R_GO, R_pubchem, R_domain = (
    Similarity.similarity("drug", view)
    for view in ("disease", "go", "pubchem", "domain")
)

# Disease-side similarity views.
D_DAG, D_r = (Similarity.similarity("disease", view) for view in ("DAG", "drug"))

# Drug x disease property matrix. It is indexed RD[drug, disease] and compared to 1
# later in the notebook; presumably a binary association matrix — confirm.
RD = DrugProperty.get("disease")

# Candidate sets of drug (Rs*) and disease (Ds*) similarity matrices.
Rs1 = [R_A1, R_B1, R_C1, R_D1, R_E1]
Rs2 = [R_dis, R_GO, R_pubchem, R_domain]
Ds1 = [D_DAG, D_r]
Ds2 = [D_DAG]

In [None]:
def smooth(l, alpha):
    """Exponential moving average with bias correction (as in Adam).

    With s_0 = 0 and s_t = alpha * s_{t-1} + (1 - alpha) * l_t, the value returned
    at step t (1-based) is s_t / (1 - alpha**t). The division removes the bias
    toward the zero starting value.

    Args:
        l: iterable of numbers, e.g. per-epoch losses.
        alpha: decay factor. ``alpha == 1`` raises ZeroDivisionError for a
            non-empty ``l``.

    Returns:
        list of smoothed values, one per element of ``l``.
    """
    res = []
    s = 0
    decay = alpha  # alpha**t for the current (1-based) step t
    for value in l:
        s = s * alpha + (1 - alpha) * value
        res.append(s / (1 - decay))
        decay *= alpha
    return res

In [None]:
# Choose which drug (Rs) and disease (Ds) similarity sets feed the network built below.
Rs=Rs2
Ds=Ds2

In [None]:
# Visual check: each set's similarity matrices concatenated side by side (column-wise).
plt.matshow(np.concatenate(Rs1,axis=1))
plt.matshow(np.concatenate(Rs2,axis=1))

In [None]:
def make_net(r, d, a):
    """Assemble the heterogeneous block matrix ``[[r, a], [a.T, d]]``.

    Args:
        r: square drug-side block, (n_r, n_r) presumably.
        d: square disease-side block, (n_d, n_d) presumably.
        a: off-diagonal block, (n_r, n_d).

    Returns:
        2-D array of shape (n_r + n_d, n_r + n_d): rows/columns for r first, then d.
    """
    return np.block([[r, a], [a.T, d]])

In [None]:
# 5-fold cross-validation over RD, writing to the "runs" directory.
# NOTE(review): neg_sampling=1 is interpreted by DrugDiseaseCrossValidation —
# presumably one sampled negative per positive; confirm.
folds=5
cv=DrugDiseaseCrossValidation(folds,RD,dir="runs",neg_sampling=1)

In [None]:
# Which CV fold to train and evaluate in this run (presumably 0-based, < folds).
exp_id=2

In [None]:
# rd: association matrix for this fold; it becomes the off-diagonal block of each network.
# trains / tests: index-pair arrays (column 0 = drug index, column 1 = disease index).
# NOTE(review): presumably rd hides the test associations — confirm in DrugDiseaseCrossValidation.
rd,trains,tests=cv[exp_id]

In [None]:
# Training pairs: drug index, disease index, and the label read from RD for that pair.
rs=torch.from_numpy(trains[:,0])
ds=torch.from_numpy(trains[:,1])
ls=torch.from_numpy(RD[trains[:,0],trains[:,1]])

ds_train=TensorDataset(rs,ds,ls)
# NOTE(review): no torch seed is set anywhere, so the shuffle order differs between runs.
loader_train=DataLoader(ds_train,batch_size=64,shuffle=True)

In [None]:
# Test pairs: drug index, disease index, and the label read from RD.
# NOTE(review): this reuses the names rs / ds / ls and overwrites the training tensors.
# The validation cell after the `nets` cell reads rs / ds / ls, so it silently depends
# on this cell having run last.
rs=torch.from_numpy(tests[:,0])
ds=torch.from_numpy(tests[:,1])
ls=torch.from_numpy(RD[tests[:,0],tests[:,1]])

ds_test=TensorDataset(rs,ds,ls)
# Shuffling is unnecessary here: predictions are written back by (drug, disease) index.
loader_test=DataLoader(ds_test,batch_size=2048,shuffle=True)

In [None]:
# One heterogeneous network make_net(r, d, rd) per (drug similarity, disease similarity) pair.
# Channel index = i_r * len(Ds) + i_d.
# Stacking on axis 1 gives (n_nodes, channels, n_nodes), so nets[i] is the multi-channel
# profile of node i. Drug nodes come first; disease nodes are offset by rd.shape[0].
nets=np.stack([make_net(r,d,rd) for r in Rs for d in Ds],axis=1)
# NOTE(review): hard-coded .cuda() — this notebook requires a GPU.
nets=torch.from_numpy(nets).float().cuda()

In [None]:
# "Validation" tensors: the positive pairs (ls == 1) of whichever rs / ds / ls were defined last.
# NOTE(review): when cells run in order these are the TEST positives, so the loss tracked
# during training is a test-set loss on positives only. Consider a separate held-out split.
x_r=nets[rs[ls==1],:,:]
x_d=nets[ds[ls==1]+rd.shape[0],:,:]   # disease rows sit after the rd.shape[0] drug rows
# Two rows per sample (drug profile, disease profile): (n_pos, channels, 2, n_nodes)
data_valid=torch.stack((x_r,x_d),dim=2)
label_valid=torch.ones(x_r.shape[0],dtype=torch.long,device="cuda")

In [None]:
# Expected layout: (n_pos, channels, 2, n_nodes)
data_valid.shape

In [None]:
# One input channel per (drug similarity, disease similarity) network.
model=CNNs(len(Rs)*len(Ds)).cuda()
optimizer=optim.Adam(model.parameters(),3e-4)   # learning rate 3e-4
loss_fn=nn.CrossEntropyLoss()

# Per-epoch validation losses, appended by the training loop.
loss_valid=[]

In [None]:
# Train for 50 epochs; after each epoch record the loss on the validation tensors.
for e in trange(50):
    # Training mode: BatchNorm normalizes with batch statistics and updates its running stats.
    model.train()
    for r, d, l in loader_train:
        x_r = nets[r, :, :]
        x_d = nets[d + rd.shape[0], :, :]   # disease rows follow the drug rows
        y = model(torch.stack((x_r, x_d), dim=2))
        loss_t = loss_fn(y, l.long().cuda())
        optimizer.zero_grad()
        loss_t.backward()
        optimizer.step()
    # Evaluation mode: BatchNorm uses its running statistics, so the validation pass
    # neither depends on validation-batch statistics nor updates the running stats.
    model.eval()
    with torch.no_grad():
        loss_v = loss_fn(model(data_valid), label_valid)
        loss_valid.append(loss_v.item())

In [None]:
# Per-epoch validation loss, raw and bias-corrected EMA-smoothed (alpha = 0.9).
loss_smooth=smooth(loss_valid,0.9)
plt.plot(loss_valid,label="validate")
plt.plot(loss_smooth,label="validate_smooth")
plt.legend()


In [None]:
# Score every test pair. Entries of `scores` for pairs outside the test set stay 0.
scores = torch.zeros(rd.shape, device="cuda")

# Inference: BatchNorm must use its running statistics and must not be updated by test data.
model.eval()
with torch.no_grad():
    for r, d, l in tqdm(loader_test):
        x_r = nets[r, :, :]
        x_d = nets[d + rd.shape[0], :, :]
        y = model(torch.stack((x_r, x_d), dim=2))
        # Predicted probability of class 1 (association). The raw class-1 logit alone
        # ignores the class-0 logit, so it is not monotone in this probability and
        # would give a different ranking.
        scores[r, d] = torch.softmax(y, dim=1)[:, 1]

scores = scores.cpu().numpy()
cv.record_predictions(exp_id, scores)

In [None]:
# Score distributions over the scored (test) pairs only. `scores` is 0 for every pair
# outside `tests`, so using the whole matrix would mix those placeholder zeros in.
test_scores = scores[tests[:, 0], tests[:, 1]]
test_labels = RD[tests[:, 0], tests[:, 1]]
plt.hist(test_scores, bins=256)
plt.figure()
x = test_scores[test_labels == 1]   # positive test pairs
y = test_scores[test_labels == 0]   # negative test pairs
len_x = len(x)
len_y = len(y)
# Spread the positives across the same x-range as the negatives for a visual comparison.
rate = len_y / len_x if len_x else 0.0
plt.scatter(range(len(y)), y)
plt.scatter([int(i * rate) for i, _ in enumerate(x)], x);

In [None]:
# Fold metrics. fpr / tpr are ROC curve points; r / p are presumably recall / precision curve points.
# NOTE(review): the meaning of the second argument (True) is defined by DrugDiseaseCrossValidation.metrics — confirm.
# This rebinds the name `r`.
fpr, tpr, r, p=cv.metrics(exp_id,True)

In [None]:
# Compare this model ("MTRD") against the other methods' precision-recall and ROC curves.
# NOTE(review): the last recall/precision point is dropped — presumably the (recall 0, precision 1)
# endpoint appended by a precision-recall-curve routine; confirm against cv.metrics.
OtherMethods.compare("MTRD",r[:-1],p[:-1],fpr,tpr)

In [None]:
# Top-K recall comparison using this model's recall values sorted in descending order.
# Only computed values are reported. Overwriting individual ranks by hand would make
# the reported top-K recall no longer reflect the model's predictions.
rr=r.copy()
rr.sort()
rr=rr[::-1]
OtherMethods.topK_recall("MTRD",rr)