In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GATConv, global_mean_pool
import glob
import numpy as np
import matplotlib.pyplot as plt
import os
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Define GAT model for batched data
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv, global_mean_pool, global_add_pool, global_max_pool

class GAT(nn.Module):
    """Graph-level binary classifier built from two attention convolutions.

    Pipeline: ``GATConv -> BatchNorm1d -> ReLU -> Dropout`` applied twice on
    the node features, then per-graph pooling with ``global_mean_pool`` and a
    two-layer MLP head that emits a single raw score (logit) per graph.
    Both convolutions are created with ``edge_dim=1``.

    Args:
        in_channels: Width of the input node features.
        hidden_channels: Output width of the first convolution.
        out_channels: Output width of the second convolution (and the input
            width of the MLP head).
        dropout_p: Probability used by the shared ``nn.Dropout`` module.
    """

    def __init__(self, in_channels, hidden_channels=32, out_channels=16, dropout_p=0.3):
        super().__init__()
        # Both attention layers share the same configuration.
        conv_options = dict(heads=1, concat=True, edge_dim=1)

        # Attribute names and creation order are kept stable so saved
        # state_dicts and seeded initialisation stay compatible.
        self.gat1 = GATConv(in_channels, hidden_channels, **conv_options)
        self.norm1 = nn.BatchNorm1d(hidden_channels)

        self.gat2 = GATConv(hidden_channels, out_channels, **conv_options)
        self.norm2 = nn.BatchNorm1d(out_channels)

        self.dropout = nn.Dropout(dropout_p)
        self.pool = global_mean_pool

        # MLP head: pooled graph embedding -> 128 hidden units -> 1 logit.
        self.linear1 = nn.Linear(out_channels, 128)
        self.linear2 = nn.Linear(128, 1)

    def _norm_relu_drop(self, features, norm):
        """Apply the given normalisation layer, then ReLU, then dropout."""
        return self.dropout(F.relu(norm(features)))

    def forward(self, x, edge_index, edge_attr, batch, return_attention=False):
        """Compute one logit per graph in the batch.

        Args:
            x: Node feature tensor.
            edge_index: Edge index tensor passed to both convolutions.
            edge_attr: Edge attribute tensor passed to both convolutions.
            batch: Node-to-graph assignment passed to the pooling function.
            return_attention: When True, the second convolution is called with
                ``return_attention_weights=True`` and what it returns next to
                its output is handed back to the caller (intended for
                validation only).

        Returns:
            The logits with dimension 1 squeezed away, or the tuple
            ``(logits, attn_weights)`` when ``return_attention`` is True.
        """
        hidden = self._norm_relu_drop(self.gat1(x, edge_index, edge_attr), self.norm1)

        attn_weights = None
        if return_attention:
            hidden, attn_weights = self.gat2(
                hidden, edge_index, edge_attr, return_attention_weights=True
            )
        else:
            hidden = self.gat2(hidden, edge_index, edge_attr)
        hidden = self._norm_relu_drop(hidden, self.norm2)

        graph_embedding = self.pool(hidden, batch)
        head = self.dropout(F.relu(self.linear1(graph_embedding)))
        logits = self.linear2(head).squeeze(1)

        return (logits, attn_weights) if return_attention else logits


def organize_graph_and_add_weight(file_path, label):
    """Load one saved graph and wrap it in a ``Data`` object.

    The ``.npy`` file must hold a pickled dict with the keys
    ``'inverse_distance'`` (used as a weighted adjacency matrix) and
    ``'encoded_matrix'`` (used as the node features).

    Args:
        file_path: Path of the ``.npy`` file to read.
        label: Graph-level label; stored as a one-element float32 tensor.

    Returns:
        ``Data`` with ``x`` (node features), ``edge_index`` (positions of the
        strictly positive entries of the row-normalised adjacency),
        ``edge_attr`` (the matching normalised weights, as a 1-D tensor) and
        ``y`` (the label).
    """
    # NOTE(review): allow_pickle=True unpickles the file content, which can run
    # arbitrary code -- only load files from a trusted source.
    record = np.load(file_path, allow_pickle=True).item()
    raw_adjacency = record['inverse_distance']
    raw_features = record['encoded_matrix']

    node_features = torch.tensor(raw_features, dtype=torch.float32)
    weights = torch.tensor(raw_adjacency, dtype=torch.float32)

    # Row-normalise; the epsilon keeps all-zero rows from dividing by zero.
    row_totals = weights.sum(dim=1, keepdim=True)
    weights = weights / (row_totals + 1e-8)

    # One mask drives both the edge list and the edge weights, so the two
    # stay aligned entry by entry.
    has_edge = weights > 0

    return Data(
        x=node_features,
        edge_index=has_edge.nonzero(as_tuple=False).t(),
        edge_attr=weights[has_edge],
        y=torch.tensor([label], dtype=torch.float32),
    )

In [5]:
def train_model(model, dataloader, criterion, optimizer, device):
    """Train ``model`` for one pass over ``dataloader``.

    For every batch: forward pass, loss, backward pass, gradient-norm clipping
    to 1.0 and an optimizer step. Predictions are ``sigmoid(logit) >= 0.5``.

    Args:
        model: Callable as ``model(x, edge_index, edge_attr, batch)`` and
            returning one logit per graph (shape ``[B]`` or ``[B, 1]``).
        dataloader: Iterable of batches exposing ``to(device)``, ``x``,
            ``edge_index``, ``edge_attr``, ``batch``, ``y`` and ``num_graphs``.
        criterion: Loss taking ``(logits, targets)``, both flattened to 1-D.
        optimizer: Optimizer over ``model.parameters()``.
        device: Device each batch is moved to.

    Returns:
        ``(epoch_loss, accuracy)``: the per-graph average of the batch losses
        and the fraction of correctly classified graphs.

    Raises:
        ZeroDivisionError: If ``dataloader`` yields no graphs.
    """
    model.train()
    running_loss = 0.0
    total = 0
    correct = 0

    for batch in dataloader:
        batch = batch.to(device)
        optimizer.zero_grad()

        outputs = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
        # Flatten once and reuse for both loss and accuracy. Comparing an
        # un-flattened [B, 1] output against [B] labels would broadcast to
        # [B, B] and silently inflate the number of "correct" predictions.
        logits = outputs.reshape(-1)
        targets = batch.y.reshape(-1)

        loss = criterion(logits, targets)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # criterion returns a batch mean; re-weight by batch size so the
        # epoch value is a true per-graph average even with a short last batch.
        running_loss += loss.item() * batch.num_graphs
        preds = torch.sigmoid(logits) >= 0.5
        correct += (preds == targets.bool()).sum().item()
        total += batch.num_graphs

    epoch_loss = running_loss / total
    accuracy = correct / total
    return epoch_loss, accuracy


def validate_model(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` and collect attention weights.

    Runs in eval mode under ``torch.no_grad()``. Predictions are
    ``sigmoid(logit) >= 0.5``.

    Args:
        model: Callable as
            ``model(x, edge_index, edge_attr, batch, return_attention=True)``
            and returning ``(logits, (edge_index, alpha))`` with one logit per
            graph (shape ``[B]`` or ``[B, 1]``).
        dataloader: Iterable of batches exposing ``to(device)``, ``x``,
            ``edge_index``, ``edge_attr``, ``batch``, ``y`` and ``num_graphs``.
        criterion: Loss taking ``(logits, targets)``, both flattened to 1-D.
        device: Device each batch is moved to.

    Returns:
        ``(validation_loss, accuracy, attention_data)`` where
        ``attention_data`` holds one dict per batch with the CPU tensors
        ``"edge_index"``, ``"attention"`` and ``"batch"``.

    Raises:
        ZeroDivisionError: If ``dataloader`` yields no graphs.
    """
    model.eval()
    running_loss = 0.0
    total = 0
    correct = 0
    attention_data = []

    with torch.no_grad():
        for batch in dataloader:
            batch = batch.to(device)

            outputs, attn_weights = model(
                batch.x, batch.edge_index, batch.edge_attr, batch.batch, return_attention=True
            )
            # Flatten once and reuse for both loss and accuracy. Comparing an
            # un-flattened [B, 1] output against [B] labels would broadcast to
            # [B, B] and silently inflate the number of "correct" predictions.
            logits = outputs.reshape(-1)
            targets = batch.y.reshape(-1)

            loss = criterion(logits, targets)
            running_loss += loss.item() * batch.num_graphs

            preds = torch.sigmoid(logits) >= 0.5
            correct += (preds == targets.bool()).sum().item()
            total += batch.num_graphs

            # Keep the edge list returned by the model next to the weights:
            # it is the one the attention coefficients are indexed by.
            edge_idx, alpha = attn_weights
            attention_data.append({
                "edge_index": edge_idx.cpu(),
                "attention": alpha.cpu(),
                "batch": batch.batch.cpu()
            })

    validation_loss = running_loss / total
    accuracy = correct / total
    return validation_loss, accuracy, attention_data


In [6]:
# Create a dictionary with file names as keys and label + tensor grid as values
positive_grids = glob.glob('../../../Data/SplitData/Cholesterol/cholesterol-separate-graphs-5A/PositiveWithoutSpies/*.npy')
validation_grids = glob.glob('../../../Data/SplitData/Cholesterol/cholesterol-separate-graphs-5A/Validation_Set/*.npy')

positive_graphs = []
validation_graphs = []

for file in positive_grids:
    positive_graphs.append(organize_graph_and_add_weight(file, 1))

positive_validation_count = 0
unlabeled_validation_count = 0

for file in validation_grids:
    # Label as negative if "-f1" to "-f5" is in the filename
    if any(f"-f{i}" in file for i in range(1, 6)):
        label = 0
        unlabeled_validation_count += 1
        validation_graphs.append(organize_graph_and_add_weight(file, label))
    else:
        label = 1
        positive_validation_count += 1
        validation_graphs.append(organize_graph_and_add_weight(file, label))

print("In validation directory there are", positive_validation_count, "positives and", unlabeled_validation_count, "fragments")

k = 50
bins = []
for i in range(1, k + 1):
    bin = positive_graphs.copy()
    subset_grid = glob.glob(f'../../../Data/SplitData/Cholesterol/cholesterol-separate-graphs-5A/k_subsets/subset_{i}/*.npy')  # Adjust path as needed
    for file in subset_grid:
        bin.append(organize_graph_and_add_weight(file, 0))
    
    bins.append(bin)

for i, bin in enumerate(bins):
    pos = sum(1 for g in bin if g.y.item() == 1)
    neg = sum(1 for g in bin if g.y.item() == 0)
    print(f"Bin {i+1}: Positives = {pos}, Negatives = {neg}")

In validation directory there are 77 positives and 277 fragments
Bin 1: Positives = 385, Negatives = 385
Bin 2: Positives = 385, Negatives = 385
Bin 3: Positives = 385, Negatives = 385
Bin 4: Positives = 385, Negatives = 385
Bin 5: Positives = 385, Negatives = 385
Bin 6: Positives = 385, Negatives = 385
Bin 7: Positives = 385, Negatives = 385
Bin 8: Positives = 385, Negatives = 385
Bin 9: Positives = 385, Negatives = 385
Bin 10: Positives = 385, Negatives = 385
Bin 11: Positives = 385, Negatives = 385
Bin 12: Positives = 385, Negatives = 385
Bin 13: Positives = 385, Negatives = 385
Bin 14: Positives = 385, Negatives = 385
Bin 15: Positives = 385, Negatives = 385
Bin 16: Positives = 385, Negatives = 385
Bin 17: Positives = 385, Negatives = 385
Bin 18: Positives = 385, Negatives = 385
Bin 19: Positives = 385, Negatives = 385
Bin 20: Positives = 385, Negatives = 385
Bin 21: Positives = 385, Negatives = 385
Bin 22: Positives = 385, Negatives = 385
Bin 23: Positives = 385, Negatives = 385
B

In [7]:
def _plot_epoch_curves(curves, ylabel, title):
    """Draw one 10x6 figure with a line per ``(values, label, color)`` entry.

    The x axis is the 1-based epoch index; ``color=None`` leaves the colour to
    matplotlib's default cycle.
    """
    plt.figure(figsize=(10, 6))
    for values, label, color in curves:
        style = {} if color is None else {'color': color}
        plt.plot(range(1, len(values) + 1), values, label=label, **style)
    plt.xlabel('Epochs')
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend()
    plt.grid()
    plt.show()


def plot_graphs(train_losses, validation_losses, validation_accuracies, learning_rates):
    """Show three figures summarising a training run.

    1. Training loss and validation loss on the same axes.
    2. Validation accuracy.
    3. Learning rate.

    Args:
        train_losses: One training-loss value per epoch.
        validation_losses: One validation-loss value per epoch.
        validation_accuracies: One validation-accuracy value per epoch.
        learning_rates: One learning-rate value per epoch.
    """
    # Plot Training Loss vs Validation Loss
    _plot_epoch_curves(
        [(train_losses, 'Training Loss', None), (validation_losses, 'Validation Loss', None)],
        ylabel='Loss',
        title='Training Loss vs Validation Loss',
    )

    # Plot Validation Accuracy
    _plot_epoch_curves(
        [(validation_accuracies, 'Validation Accuracy', 'green')],
        ylabel='Accuracy',
        title='Validation Accuracy over Epochs',
    )

    # Plot Learning Rate (axis label and title previously said "Accuracy",
    # a copy-paste slip from the figure above).
    _plot_epoch_curves(
        [(learning_rates, 'Learning Rates', 'green')],
        ylabel='Learning Rate',
        title='Learning Rate over Epochs',
    )

In [8]:
import torch
import torch.nn.functional as F

def weighted_positive_loss(y_pred, y_true, lambda_weight=0.01, eps=1e-12):
    """Binary cross-entropy plus a penalty on the positive examples.

    The penalty is the root-mean-square of
    ``log(1 + y_pred) - log(1 + y_true)`` taken over the entries where
    ``y_true == 1``; it is zero when there are no such entries.

    Args:
        y_pred: Predicted probabilities. ``F.binary_cross_entropy`` is used,
            so raw logits (as returned by the ``GAT`` model in this notebook)
            must be passed through a sigmoid first.
        y_true: Targets with the same shape as ``y_pred``.
        lambda_weight: Multiplier applied to the positive-only penalty.
        eps: Small constant added inside the square root. Without it the
            derivative of ``sqrt`` at 0 is infinite and the gradient becomes
            NaN as soon as every positive prediction is exactly 1.0 (which
            happens with a saturated sigmoid in float32).

    Returns:
        Scalar tensor ``bce + lambda_weight * penalty``.
    """
    # BCE Loss
    bce_loss = F.binary_cross_entropy(y_pred, y_true)

    # Additional Weighted Positive Term
    positive_mask = y_true == 1.0

    if positive_mask.any():
        log_gap = torch.log1p(y_pred[positive_mask]) - torch.log1p(y_true[positive_mask])
        wp_term = torch.sqrt(torch.mean(log_gap ** 2) + eps)
    else:
        wp_term = y_pred.new_zeros(())

    return bce_loss + lambda_weight * wp_term


In [9]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Define paths for saving models
save_dir = "GATComplexModelsAndWeights-5A"
os.makedirs(f"{save_dir}/Models", exist_ok=True)
os.makedirs(f"{save_dir}/Weights", exist_ok=True)

# Hyper-parameters, gathered in one place so a run is easy to retune.
epochs = 1000
batch_size = 64
in_channels = 37        # NOTE(review): presumably the node-feature width of the encoded matrices -- confirm
learning_rate = 0.000001
weight_decay = 1e-4
log_every = 10          # print a progress line every `log_every` epochs

# The same validation set is used for every bin.
val_loader = DataLoader(validation_graphs, batch_size=batch_size, shuffle=False)

# One independent model is trained per bin. The loop variable is called
# `bin_graphs` rather than `bin` so the builtin bin() is not shadowed for the
# rest of the notebook session.
for i, bin_graphs in enumerate(bins):
    model = GAT(in_channels=in_channels, out_channels=16).to(device)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    train_loader = DataLoader(bin_graphs, batch_size=batch_size, shuffle=True)

    print(f"Training on bin {i+1}/{len(bins)}")

    train_losses, val_losses, val_accuracies, learning_rates = [], [], [], []

    for epoch in range(epochs):
        epoch_loss, accuracy = train_model(model, train_loader, criterion, optimizer, device)
        # attention_data is overwritten each epoch; only the final epoch's
        # weights are saved below.
        validation_loss, validation_accuracy, attention_data = validate_model(model, val_loader, criterion, device)

        current_lr = optimizer.param_groups[0]['lr']
        train_losses.append(epoch_loss)
        learning_rates.append(current_lr)
        val_losses.append(validation_loss)
        val_accuracies.append(validation_accuracy)

        if epoch % log_every == 0:
            print(
                f"Bin {i+1}, Epoch {epoch+1}/{epochs}, "
                f"Train Loss: {epoch_loss:.4f}, Validation Loss: {validation_loss:.4f},  "
                f"ValAccuracy: {validation_accuracy:.4f}, "
                f"LR: {current_lr:.6f}"
            )

    plot_graphs(train_losses, val_losses, val_accuracies, learning_rates)

    # Save the trained model
    model_path = os.path.join(save_dir, f"Models/model_bin_{i+1}.pth")
    torch.save(model.state_dict(), model_path)
    print(f"Model for bin {i+1} saved to {model_path}")

    # Save attention weights
    attn_path = os.path.join(save_dir, f"Weights/attn_bin_{i+1}.pt")
    torch.save(attention_data, attn_path)
    print(f"Attention weights for bin {i+1} saved to {attn_path}")


Training on bin 1/50
Bin 1, Epoch 1/1000, Train Loss: 0.6635, Validation Loss: 0.6878,  ValAccuracy: 0.7768, LR: 0.000001
Bin 1, Epoch 11/1000, Train Loss: 0.6687, Validation Loss: 0.7066,  ValAccuracy: 0.3277, LR: 0.000001
Bin 1, Epoch 21/1000, Train Loss: 0.6635, Validation Loss: 0.7093,  ValAccuracy: 0.3333, LR: 0.000001
Bin 1, Epoch 31/1000, Train Loss: 0.6677, Validation Loss: 0.7090,  ValAccuracy: 0.3277, LR: 0.000001
Bin 1, Epoch 41/1000, Train Loss: 0.6627, Validation Loss: 0.7075,  ValAccuracy: 0.3333, LR: 0.000001
Bin 1, Epoch 51/1000, Train Loss: 0.6646, Validation Loss: 0.7045,  ValAccuracy: 0.3503, LR: 0.000001
Bin 1, Epoch 61/1000, Train Loss: 0.6674, Validation Loss: 0.7055,  ValAccuracy: 0.3475, LR: 0.000001
Bin 1, Epoch 71/1000, Train Loss: 0.6577, Validation Loss: 0.7081,  ValAccuracy: 0.3333, LR: 0.000001
Bin 1, Epoch 81/1000, Train Loss: 0.6656, Validation Loss: 0.7097,  ValAccuracy: 0.3418, LR: 0.000001
Bin 1, Epoch 91/1000, Train Loss: 0.6638, Validation Loss: 0.7

KeyboardInterrupt: 