In [None]:
from train import *

# Configuration for the ESM2 variant. How HybridModel interprets the
# use_graph / use_esm2 flags is defined in `train`, not here.
config_name = "ESM2"
config_params = {"use_graph": False, "use_esm2": True}

# Collected test metrics, keyed by configuration name.
results = {}

# Base hyper-parameters; the variant flags from config_params are merged in
# last, so they come after (and would override) the base keys.
model_params = dict(
    {
        "model_embedding_size": 32,
        "model_attention_heads": 3,
        "model_layers": 4,
        "model_dropout_rate": 0.20,
        "model_top_k_ratio": 0.5,
        "model_top_k_every_n": 1,
        "model_dense_neurons": 64,
        "use_pooling": True,
    },
    **config_params,
)

# Dimensions passed to the HybridModel constructor.
num_classes = 10
feature_size = 8
edge_dim = 2


# Pick the device up front so the restored model can be placed on it: the
# evaluation loop moves every batch to `device`, and the model must live on
# the same device as its inputs.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print('Loading state dict...')
model = HybridModel(feature_size, edge_dim, model_params, num_classes)
# Nothing below steps the optimizer; it is built only because
# load_checkpoint takes (and returns) one.
optimizer = optim.Adam(model.parameters(), lr=0.01)
# Checkpoint path follows the configuration name ("ESM2" -> the same file as before).
filename = f"states/checkpoint_best_{config_name}_model.pth"
model, optimizer, epoch, best_val_acc = load_checkpoint(model, optimizer, filename)
model = model.to(device)

print('Loading data...')
test_data = ProteinDataset('test_partitions', 'test')
test_loader = DataLoader(test_data, batch_size=32)
criterion = torch.nn.CrossEntropyLoss()

def evaluate_model(model, test_loader, criterion, device, labels=None):
    """Run ``model`` over every batch of ``test_loader`` and summarise its performance.

    Args:
        model: network called as
            ``model(batch.x, batch.edge_attr, batch.edge_index, batch.batch, batch.sequence)``.
            It is moved to ``device`` and switched to eval mode.
        test_loader: iterable of batch objects that support ``.to(device)`` and
            expose ``x``, ``edge_attr``, ``edge_index``, ``batch``, ``sequence``
            and ``y`` (the target class indices); must also support ``len()``.
        criterion: loss applied as ``criterion(outputs, targets)``; assumed to
            average over the batch (``reduction='mean'``), as the default
            ``CrossEntropyLoss`` does.
        device: torch device the model and every batch are moved to.
        labels: optional list of class indices forwarded to ``confusion_matrix``
            so it has one row/column per class even when some class never
            appears in the targets or predictions. ``None`` keeps the previous
            behaviour (only classes that occur).

    Returns:
        ``(test_loss, test_acc, precision, recall, f1, conf_matrix)``:
        per-sample mean loss, fraction of correct argmax predictions,
        support-weighted precision/recall/F1, and the confusion matrix.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    model = model.to(device)
    model.eval()
    loss_sum, correct, total = 0.0, 0, 0
    all_preds, all_labels = [], []

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Processing Batches", total=len(test_loader)):
            batch = batch.to(device)
            targets = batch.y

            outputs = model(batch.x, batch.edge_attr, batch.edge_index, batch.batch, batch.sequence)
            batch_size = targets.size(0)

            # The criterion returns a batch mean; weight it by batch size so a
            # smaller final batch does not skew the overall average.
            loss_sum += criterion(outputs, targets).item() * batch_size

            predicted = outputs.argmax(dim=1)
            total += batch_size
            correct += (predicted == targets).sum().item()

            all_preds.extend(predicted.cpu().tolist())
            all_labels.extend(targets.cpu().tolist())

    if total == 0:
        raise ValueError("test_loader yielded no samples; cannot compute metrics")

    test_loss = loss_sum / total
    test_acc = correct / total
    # zero_division=0 yields the same scores as sklearn's default ("warn")
    # without emitting warnings for classes that are never predicted.
    precision, recall, f1, _ = precision_recall_fscore_support(
        all_labels, all_preds, average='weighted', zero_division=0
    )
    conf_matrix = confusion_matrix(all_labels, all_preds, labels=labels)

    return test_loss, test_acc, precision, recall, f1, conf_matrix


# Evaluate the restored checkpoint on the test split and record the metrics.
test_loss, test_acc, precision, recall, f1, conf_matrix = evaluate_model(model, test_loader, criterion, device)

results[config_name] = {
    "test_loss": test_loss,
    "test_accuracy": test_acc,
    "precision": precision,
    "recall": recall,
    "f1_score": f1,
}

print(f"\nResults for {config_name}:")
print(f"Test Loss: {test_loss:.4f}")
print(f"Test Accuracy: {test_acc:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1:.4f}")


# Keys look like CATH (class, architecture) codes; "label" is the integer
# class index the model is trained to predict.
architecture_names = {
    (1, 10): {"name": "Mainly Alpha: Orthogonal Bundle", "label": 0},
    (1, 20): {"name": "Mainly Alpha: Up-down Bundle", "label": 1},
    (2, 30): {"name": "Mainly Beta: Roll", "label": 2},
    (2, 40): {"name": "Mainly Beta: Beta Barrel", "label": 3},
    (2, 60): {"name": "Mainly Beta: Sandwich", "label": 4},
    (3, 10): {"name": "Alpha Beta: Roll", "label": 5},
    (3, 20): {"name": "Alpha Beta: Alpha-Beta Barrel", "label": 6},
    (3, 30): {"name": "Alpha Beta: 2-Layer Sandwich", "label": 7},
    (3, 40): {"name": "Alpha Beta: 3-Layer(aba) Sandwich", "label": 8},
    (3, 90): {"name": "Alpha Beta: Alpha-Beta Complex", "label": 9}
}
# Order tick labels by integer label rather than relying on dict insertion order.
class_names = [
    str(code)
    for code, _ in sorted(architecture_names.items(), key=lambda item: item[1]["label"])
]
# NOTE: plot_confusion_matrix is defined in the following cell, so calling it
# here raises NameError when cells run top to bottom (unless `train` exports a
# function of that name). The heatmap is rendered by the next cell instead.


In [None]:
def plot_confusion_matrix(conf_matrix, class_names, config_name, output_dir='images'):
    """Draw ``conf_matrix`` as an annotated heatmap and save it as a PNG.

    Args:
        conf_matrix: integer-valued square matrix (cells are formatted with
            ``fmt='d'``); rows are true classes, columns predicted classes.
        class_names: tick labels for both axes, in the matrix's row/column order.
        config_name: used in the plot title and the output filename.
        output_dir: directory that receives ``{config_name}_confusion_matrix.png``;
            created if missing. Defaults to ``'images'``.
    """
    import os

    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names, ax=ax)
    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')
    ax.set_title(f'{config_name} Confusion Matrix')
    fig.tight_layout()

    # savefig does not create missing directories.
    os.makedirs(output_dir, exist_ok=True)
    fig.savefig(os.path.join(output_dir, f'{config_name}_confusion_matrix.png'))
    # Close this specific figure so repeated calls do not accumulate open figures.
    plt.close(fig)

# Keys look like CATH (class, architecture) codes; "label" is the integer
# class index the model is trained to predict.
architecture_names = {
    (1, 10): {"name": "Mainly Alpha: Orthogonal Bundle", "label": 0},
    (1, 20): {"name": "Mainly Alpha: Up-down Bundle", "label": 1},
    (2, 30): {"name": "Mainly Beta: Roll", "label": 2},
    (2, 40): {"name": "Mainly Beta: Beta Barrel", "label": 3},
    (2, 60): {"name": "Mainly Beta: Sandwich", "label": 4},
    (3, 10): {"name": "Alpha Beta: Roll", "label": 5},
    (3, 20): {"name": "Alpha Beta: Alpha-Beta Barrel", "label": 6},
    (3, 30): {"name": "Alpha Beta: 2-Layer Sandwich", "label": 7},
    (3, 40): {"name": "Alpha Beta: 3-Layer(aba) Sandwich", "label": 8},
    (3, 90): {"name": "Alpha Beta: Alpha-Beta Complex", "label": 9}
}
# Tick labels must follow integer-label order: with sklearn's confusion_matrix,
# row/column i corresponds to label i (assuming all ten labels occur).
# Sort explicitly instead of relying on dict insertion order.
class_names = [
    str(code)
    for code, _ in sorted(architecture_names.items(), key=lambda item: item[1]["label"])
]
plot_confusion_matrix(conf_matrix, class_names, config_name)

In [None]:
def plot_training_curves(train_losses, train_accs, val_losses, val_accs, config_name, output_dir='images'):
    """Plot per-epoch loss (left) and accuracy (right) curves and save them as a PNG.

    Args:
        train_losses: per-epoch training loss values.
        train_accs: per-epoch training accuracy values.
        val_losses: per-epoch validation loss values.
        val_accs: per-epoch validation accuracy values.
        config_name: used in the subplot titles and the output filename.
        output_dir: directory that receives ``{config_name}_training_curves.png``;
            created if missing. Defaults to ``'images'``.
    """
    import os

    def _epochs(values):
        # 1-based epoch numbers, so the "Epoch" axis matches the epoch count.
        return range(1, len(values) + 1)

    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 5))

    loss_ax.plot(_epochs(train_losses), train_losses, label='Train Loss')
    loss_ax.plot(_epochs(val_losses), val_losses, label='Validation Loss')
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    loss_ax.legend()
    loss_ax.set_title(f'{config_name} Training and Validation Loss')

    acc_ax.plot(_epochs(train_accs), train_accs, label='Train Accuracy')
    acc_ax.plot(_epochs(val_accs), val_accs, label='Validation Accuracy')
    acc_ax.set_xlabel('Epoch')
    acc_ax.set_ylabel('Accuracy')
    acc_ax.legend()
    acc_ax.set_title(f'{config_name} Training and Validation Accuracy')

    fig.tight_layout()
    # savefig does not create missing directories.
    os.makedirs(output_dir, exist_ok=True)
    fig.savefig(os.path.join(output_dir, f'{config_name}_training_curves.png'))
    plt.close(fig)


# Per-epoch history of the ESM2 run (13 epochs each), presumably transcribed
# from the training output; element 0 is the first epoch.
train_losses = [
    2.1233, 1.7255, 1.5408, 1.3875, 1.2810, 1.1908, 1.1604,
    1.0914, 1.0837, 1.0022, 0.9747, 0.9720, 0.9120,
]
val_losses = [
    1.8933, 1.5749, 1.4455, 1.4069, 1.3474, 1.3792, 1.3500,
    1.2890, 1.2435, 1.2681, 1.2639, 1.1586, 1.2322,
]
train_accs = [
    0.2205, 0.3385, 0.4128, 0.4615, 0.5000, 0.5404, 0.5513,
    0.5974, 0.5769, 0.6160, 0.6327, 0.6346, 0.6481,
]
val_accs = [
    0.2846, 0.4051, 0.4410, 0.4821, 0.4821, 0.4821, 0.5256,
    0.5385, 0.5231, 0.5231, 0.5718, 0.5744, 0.5615,
]
config_name = "ESM2"

plot_training_curves(
    train_losses=train_losses,
    train_accs=train_accs,
    val_losses=val_losses,
    val_accs=val_accs,
    config_name=config_name,
)