In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader
from tqdm import tqdm
import pandas as pd
from KAN import *
import numpy as np
from sklearn.metrics import confusion_matrix, matthews_corrcoef, accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
import matplotlib.pyplot as plt


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cpu


In [3]:
train_positive_pt5 = pd.read_csv("train_positive_ProtT5-XL-UniRef50.csv", header = None).iloc[:,2:]
train_negative_pt5 = pd.read_csv("train_negative_ProtT5-XL-UniRef50.csv", header = None).iloc[:,2:]
test_positive_pt5 = pd.read_csv("test_positive_ProtT5-XL-UniRef50.csv", header = None).iloc[:,2:]
test_negative_pt5 = pd.read_csv("test_negative_ProtT5-XL-UniRef50.csv", header = None).iloc[:,2:]


# create labels
train_positive_labels = np.ones(train_positive_pt5.shape[0])
train_negative_labels = np.zeros(train_negative_pt5.shape[0])
test_positive_labels = np.ones(test_positive_pt5.shape[0])
test_negative_labels = np.zeros(test_negative_pt5.shape[0])

# stack positive and negative data together
X_train_pt5 = np.vstack((train_positive_pt5,train_negative_pt5))
X_test_pt5 = np.vstack((test_positive_pt5,test_negative_pt5))
y_train = np.concatenate((train_positive_labels, train_negative_labels), axis = 0)
y_test = np.concatenate((test_positive_labels, test_negative_labels), axis = 0)

# shuffle X and y together
X_train_pt5, y_train = shuffle(X_train_pt5, y_train)
X_test_pt5, y_test = shuffle(X_test_pt5, y_test)

In [4]:
x_train_tf = torch.from_numpy(X_train_pt5).to(torch.float32)
x_test_tf = torch.from_numpy(X_test_pt5).to(torch.float32)
y_train_tf = torch.from_numpy(y_train).to(torch.float32)
y_test_tf = torch.from_numpy(y_test).to(torch.float32)

In [11]:
trainloader = DataLoader(list(zip(x_train_tf, y_train_tf)), batch_size=1024, shuffle=True)
testloader = DataLoader(list(zip(x_test_tf, y_test_tf)), batch_size=1024, shuffle=False)

In [20]:
model = KAN([1024, 16, 1], wavelet_type='dog')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
optimizer = optim.AdamW(model.parameters(), lr=0.0002, weight_decay=0.0005)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
criterion = nn.MSELoss()

In [21]:
epochs = 10
for epoch in range(epochs):
    print(f'Epoch Number: {int(epoch)+1}')
    # Training
    train_loss, train_correct, train_total = 0.0, 0, 0
    tp, tn, fp, fn = 0,0,0,0
    model.train()
    # for samples, labels in tqdm(trainloader):
    with tqdm(trainloader) as pbar:
        for i, (samples, labels) in enumerate(pbar):
            samples = samples.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(samples).squeeze()
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            pbar.set_postfix(loss=loss.item(), lr=optimizer.param_groups[0]['lr'])
            
            train_loss += loss.item()
            pred = (outputs > 0.5).float()
            true = labels
            for x in range(len(pred)):
                if pred[x] and true[x]:
                    tp+=1
                if not pred[x] and not true[x]:
                    tn+=1
                if pred[x] and not true[x]:
                    fp+=1
                if not pred[x] and true[x]:
                   fn+=1
            train_total += labels.size(0)
            train_correct += (pred == labels).sum().item()

    train_loss /= len(trainloader)
    print(f'TP: [{tp}] TN: [{tn}] FP: [{fp}] FN: [{fn}]')
    train_acc = (tp+tn) / (tp+tn+fp+fn)
    print(f'Train Loss: {train_loss}')
    print(f'Train Accuracy: {train_acc}')
    # Update learning rate
    scheduler.step()
model.eval()
y_pred_prob = []
with torch.no_grad():
    for samples, labels in testloader:
        samples = samples.to(device)
        outputs = model(samples).squeeze()
        y_pred_prob.extend(torch.sigmoid(outputs).detach().cpu().numpy())
y_pred_prob = np.array(y_pred_prob)
y_pred = (y_pred_prob > 0.5).astype(int)
y_test_np = y_test_tf.cpu().numpy()
mcc = matthews_corrcoef(y_test_np, y_pred)
cm = confusion_matrix(y_test_np, y_pred)
acc = accuracy_score(y_test_np, y_pred)
print('MCC:', mcc)
print('Accuracy:', acc)
print('Confusion Martix:', cm)

Epoch Number: 1


100%|████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:08<00:00,  1.15it/s, loss=1.19, lr=0.0002]


TP: [1599] TN: [3322] FP: [1428] FN: [3150]
Train Loss: 1.4156940698623657
Train Accuracy: 0.5180545320560059
Epoch Number: 2


100%|███████████████████████████████████████████████████████████████████████████████████| 10/10 [00:08<00:00,  1.12it/s, loss=1.08, lr=0.00018]


TP: [2012] TN: [3723] FP: [1027] FN: [2737]
Train Loss: 1.1998140573501588
Train Accuracy: 0.6037477629224128
Epoch Number: 3


100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:08<00:00,  1.13it/s, loss=1.09, lr=0.000162]


TP: [2186] TN: [3817] FP: [933] FN: [2563]
Train Loss: 1.1159532070159912
Train Accuracy: 0.6319612590799032
Epoch Number: 4


100%|█████████████████████████████████████████████████████████████████████████████████| 10/10 [00:08<00:00,  1.21it/s, loss=0.988, lr=0.000146]


TP: [2279] TN: [3894] FP: [856] FN: [2470]
Train Loss: 1.0662566125392914
Train Accuracy: 0.6498578797768186
Epoch Number: 5


100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:08<00:00,  1.14it/s, loss=1.07, lr=0.000131]


TP: [2346] TN: [3925] FP: [825] FN: [2403]
Train Loss: 1.042097669839859
Train Accuracy: 0.6601747552373934
Epoch Number: 6


100%|█████████████████████████████████████████████████████████████████████████████████| 10/10 [00:08<00:00,  1.17it/s, loss=0.949, lr=0.000118]


TP: [2417] TN: [3960] FP: [790] FN: [2332]
Train Loss: 1.0099605321884155
Train Accuracy: 0.6713338246131172
Epoch Number: 7


100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:09<00:00,  1.06it/s, loss=1.01, lr=0.000106]


TP: [2465] TN: [4001] FP: [749] FN: [2284]
Train Loss: 0.9949743330478669
Train Accuracy: 0.6807032319191494
Epoch Number: 8


100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:08<00:00,  1.16it/s, loss=0.918, lr=9.57e-5]


TP: [2476] TN: [4029] FP: [721] FN: [2273]
Train Loss: 0.9725735783576965
Train Accuracy: 0.6848089272555006
Epoch Number: 9


100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:09<00:00,  1.07it/s, loss=0.873, lr=8.61e-5]


TP: [2490] TN: [4062] FP: [688] FN: [2259]
Train Loss: 0.9572371959686279
Train Accuracy: 0.6897568165070007
Epoch Number: 10


100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:09<00:00,  1.11it/s, loss=0.909, lr=7.75e-5]


TP: [2515] TN: [4073] FP: [677] FN: [2234]
Train Loss: 0.9492528736591339
Train Accuracy: 0.693546689125171
MCC: 0.19693377535787304
Accuracy: 0.6822690638561686
Confusion Martix: [[2033  940]
 [  85  168]]
