In [2]:
from torch import nn
import torch

# Prefer CUDA, then Apple's MPS backend, and fall back to the CPU.
# `device` stays a plain string; `.to(device)` accepts it directly.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

torch.device(device)

device(type='mps')

In [3]:
from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, GraphNorm

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
import pandas as pd

In [5]:
# Pre-split train / validation / test tables, read in that order.
train_df, val_df, test_df = (
    pd.read_csv(csv_path)
    for csv_path in (
        "data/pdb/train_data.csv",
        "data/pdb/val_data.csv",
        "data/pdb/test_data.csv",
    )
)

In [None]:
class SecondaryStructureDataset(Dataset):
    """Per-residue secondary-structure dataset that builds one graph per ``df`` row.

    Each item is a ``Data`` object with
      * ``x``: one-hot amino acids, shape ``(L, len(amino_acids))``;
      * ``y``: one-hot DSSP labels, shape ``(L, len(dssp_types))``;
      * ``edge_index``: ``(2, E)`` long tensor of directed residue pairs (no self-loops),
    where ``L`` is the sequence length. Characters outside the respective alphabet
    become all-zero rows.

    By default every residue is connected to every other residue (a complete graph).
    NOTE(review): with ``GCNConv`` (which adds self-loops and normalises symmetrically)
    a complete graph makes every node aggregate the same graph-wide mean, so
    per-residue information is lost after one convolution. Passing a finite
    ``neighbor_window`` connects only residues at most that many positions apart.
    """

    # Column order of the one-hot input features; reordering it changes what a
    # trained model's input channels mean.
    amino_acids = "ACDEFGHIKLMNPQRSTVXWY"
    # Class order of the one-hot targets (class index i <-> dssp_types[i]).
    dssp_types = "GHITEBSP-"

    @staticmethod
    def _to_one_hot(seq: str, charset: str) -> torch.Tensor:
        """One-hot encode ``seq`` over ``charset``; unknown characters give all-zero rows."""
        one_hot = torch.zeros(len(seq), len(charset), dtype=torch.float)
        for i, char in enumerate(seq):
            col = charset.find(char)
            if col >= 0:
                one_hot[i, col] = 1.0
        return one_hot

    @staticmethod
    def _build_edge_index(num_nodes: int, window: "int | None" = None) -> torch.Tensor:
        """Return a ``(2, E)`` long edge_index over ``num_nodes`` residues without self-loops.

        With ``window=None`` every ordered pair ``(i, j), i != j`` is an edge; otherwise
        only pairs with ``0 < |i - j| <= window``. Edges are ordered by ``i`` then ``j``.
        The result always has two rows, even when it has no edges (e.g. ``L == 1``).
        """
        positions = torch.arange(num_nodes)
        distance = (positions.unsqueeze(1) - positions.unsqueeze(0)).abs()
        mask = distance > 0
        if window is not None:
            mask &= distance <= window
        # nonzero() yields (E, 2) row-major indices, i.e. sorted by source then target.
        return mask.nonzero().t().contiguous()

    def __init__(self,
                 df: pd.DataFrame,
                 seq_col: str = "sequence",
                 ss_col: str = "secondary_structure",
                 neighbor_window: "int | None" = None):
        """
        Args:
            df: Table with one sequence / secondary-structure string pair per row.
            seq_col: Column holding amino-acid strings.
            ss_col: Column holding DSSP strings aligned with ``seq_col``.
            neighbor_window: ``None`` (default) for a complete residue graph, or a
                non-negative int to connect only residues at most that far apart.

        Raises:
            ValueError: if ``neighbor_window`` is negative.
        """
        super().__init__()
        if neighbor_window is not None and neighbor_window < 0:
            raise ValueError(f"neighbor_window must be >= 0 or None, got {neighbor_window}")
        self.df = df
        self.seq_col = seq_col
        self.ss_col = ss_col
        self.neighbor_window = neighbor_window

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx: int):
        """Build the graph for row ``idx``.

        Raises:
            ValueError: if the sequence and secondary-structure strings differ in length
                (node features and targets would otherwise be misaligned).
        """
        row = self.df.iloc[idx]

        sequence = row[self.seq_col]
        secondary_structure = row[self.ss_col]
        if len(sequence) != len(secondary_structure):
            raise ValueError(
                f"Row {idx}: sequence length {len(sequence)} does not match "
                f"secondary structure length {len(secondary_structure)}"
            )

        x = self._to_one_hot(sequence, self.amino_acids)
        y = self._to_one_hot(secondary_structure, self.dssp_types)
        edge_index = self._build_edge_index(len(sequence), self.neighbor_window)
        return Data(x=x, edge_index=edge_index, y=y)

    def class_weights(self):
        """Inverse-frequency weights over ``dssp_types`` for the whole ``ss_col``.

        Each class gets ``total / count``, where ``total`` counts only characters that
        are in ``dssp_types``; classes that never occur get weight 0.
        """
        ss_cat = "".join(self.df[self.ss_col])
        counts = {char: ss_cat.count(char) for char in self.dssp_types}
        total = sum(counts.values())
        weights = {char: total / count if count > 0 else 0 for char, count in counts.items()}
        return torch.tensor([weights[char] for char in self.dssp_types], dtype=torch.float)
        

In [None]:
# One dataset per split; each graph is built on demand when an item is indexed.
train_dataset, val_dataset, test_dataset = (
    SecondaryStructureDataset(split_df) for split_df in (train_df, val_df, test_df)
)

In [None]:
BATCH_SIZE = 8

# Only the training split is shuffled. Evaluation losses average per-batch means
# weighted by the number of graphs, so they depend on how graphs are grouped into
# batches; a fixed order keeps val loss (and best-model selection) deterministic.
training_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Parametric ELU activation (single-parameter variant) with a learnable, strictly positive alpha.
class PELU(nn.Module):
    """f(x) = x for x >= 0, alpha * (exp(x) - 1) for x < 0.

    ``alpha`` is learned through ``log_alpha`` so it always stays positive.

    Raises:
        ValueError: if the initial ``alpha`` is not positive (its log is undefined).
    """

    def __init__(self, alpha=1.0):
        super(PELU, self).__init__()
        if alpha <= 0:
            raise ValueError(f"alpha must be positive, got {alpha}")
        self.log_alpha = nn.Parameter(torch.log(torch.tensor(float(alpha))))

    def forward(self, x):
        alpha = torch.exp(self.log_alpha)
        # Clamp before expm1 so the branch torch.where discards never overflows to
        # inf for large positive x; otherwise its backward yields 0 * inf = NaN.
        negative_branch = alpha * torch.expm1(torch.clamp(x, max=0))
        return torch.where(x >= 0, x, negative_branch)

In [None]:
class GCN(nn.Module):
    """Stack of GCNConv layers producing per-node class logits.

    Layout: hidden blocks of ``GCNConv -> GraphNorm -> PELU`` (in -> hidden, then
    ``num_layers - 2`` hidden -> hidden), each with a residual connection when its
    input and output widths match, followed by a final ``GCNConv`` to
    ``out_channels`` raw logits. At least two convolutions are always built, even
    when ``num_layers`` < 2.

    NOTE(review): on a complete graph GCNConv averages over all nodes, so every node
    receives the same message; use a sparser (e.g. windowed) residue graph if
    per-residue predictions are expected to differ within a protein.
    """

    def __init__(self, 
                 in_channels: int = len(SecondaryStructureDataset.amino_acids),
                 out_channels: int = len(SecondaryStructureDataset.dssp_types),
                 hidden_channels: int = 64,
                 num_layers: int = 3):
        super(GCN, self).__init__()
        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()

        self.convs.append(GCNConv(in_channels, hidden_channels))
        self.norms.append(GraphNorm(hidden_channels))

        for _ in range(num_layers - 2):
            self.convs.append(GCNConv(hidden_channels, hidden_channels))
            self.norms.append(GraphNorm(hidden_channels))

        # Output layer: raw logits, no normalisation or activation.
        self.convs.append(GCNConv(hidden_channels, out_channels))
        # One PELU instance, so its learnable alpha is shared by all hidden blocks.
        self.act = PELU()

    def forward(self, data: Data) -> torch.Tensor:
        """Return logits of shape ``(num_nodes, out_channels)`` for (batched) ``data``."""
        x, edge_index = data.x, data.edge_index
        # GraphNorm needs graph membership to normalise each graph of a mini-batch
        # separately; an un-batched Data object has no `batch`, and None means
        # "treat all nodes as one graph".
        batch = getattr(data, "batch", None)
        for conv, norm in zip(self.convs[:-1], self.norms):
            identity = x
            x = self.act(norm(conv(x, edge_index), batch))
            # The first block changes width (in_channels -> hidden_channels), so the
            # skip connection is only applied where the shapes agree.
            if x.shape == identity.shape:
                x = x + identity

        return self.convs[-1](x, edge_index)

In [None]:
from sklearn.metrics import classification_report, accuracy_score
from torch.utils.tensorboard import SummaryWriter

In [None]:
def classification_metrics(y_true, y_pred, name = "", tensorboard_writer = None, writer_val = 0):
    """Compute a per-class report and accuracy, optionally logging them to TensorBoard.

    Args:
        y_true: Integer class indices into ``SecondaryStructureDataset.dssp_types``.
        y_pred: Predicted class indices, same length as ``y_true``.
        name: Tag prefix for the logged scalars (e.g. ``"val/epoch"``).
        tensorboard_writer: Optional ``SummaryWriter``; nothing is logged when ``None``.
        writer_val: Global step used for every logged scalar.

    Returns:
        ``(report, accuracy)`` where ``report`` is sklearn's ``output_dict`` report.
    """
    target_names = list(SecondaryStructureDataset.dssp_types)
    # Pin the full label set. Without `labels`, sklearn infers classes from the data
    # and raises ValueError whenever a batch window lacks some class. Absent classes
    # now score 0 (zero_division=0) and lower the macro averages.
    labels = list(range(len(target_names)))
    report = classification_report(
        y_true,
        y_pred,
        labels=labels,
        target_names=target_names,
        zero_division=0,
        output_dict=True,
    )
    acc = accuracy_score(y_true, y_pred)

    if tensorboard_writer is not None:
        for key in target_names + ["macro avg", "weighted avg"]:
            metrics = report.get(key)
            if metrics is None:
                continue
            tensorboard_writer.add_scalar(f"Precision/{name}/{key}", metrics["precision"], writer_val)
            tensorboard_writer.add_scalar(f"Recall/{name}/{key}", metrics["recall"], writer_val)
            tensorboard_writer.add_scalar(f"F1/{name}/{key}", metrics["f1-score"], writer_val)
        # Previously hard-coded to step 0, which overwrote every accuracy point.
        tensorboard_writer.add_scalar(f"Accuracy/{name}", acc, writer_val)

    return report, acc

In [None]:
def train(epoch: int,
          model: GCN,
          loader: DataLoader,
          optimizer: torch.optim.Optimizer,
          criterion: nn.Module,
          batch_print_freq: int = 32,
          writer: SummaryWriter = None) -> "tuple[float, float, float, float, float]":
    """Run one training epoch over ``loader`` and return epoch-level metrics.

    Args:
        epoch: Zero-based epoch index, used to compute TensorBoard step numbers.
        model: Model mapping a batched ``Data`` object to per-node logits.
        loader: Loader yielding batched graphs with one-hot ``y`` targets.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss called as ``criterion(logits, data.y)``.
        batch_print_freq: Every this many batches, log the loss and metrics of just
            those batches to ``writer``. Values <= 0 disable batch-level logging.
        writer: Optional TensorBoard writer; nothing is logged when ``None``.

    Returns:
        ``(loss, macro_f1, macro_recall, macro_precision, accuracy)`` where ``loss`` is
        the sum of per-batch losses weighted by ``num_graphs``, divided by the dataset size.

    Raises:
        ValueError: if ``loader`` has an empty dataset.
    """
    if len(loader.dataset) == 0:
        raise ValueError("cannot train on an empty loader")

    model.train()

    total_loss = 0.0
    # Weighted loss and graph count accumulated since the last batch-level log.
    window_loss = 0.0
    window_graphs = 0

    all_preds = []
    all_labels = []
    log_batches = writer is not None and batch_print_freq > 0

    for batch_idx, data in enumerate(loader, start=1):
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data)
        loss = criterion(out, data.y)
        loss.backward()
        optimizer.step()

        # Assuming a mean-reduced criterion, weight by num_graphs so the epoch total
        # can be normalised by the number of graphs in the dataset.
        batch_loss = loss.item() * data.num_graphs
        total_loss += batch_loss
        window_loss += batch_loss
        window_graphs += data.num_graphs

        all_preds.append(out.argmax(dim=1).cpu())
        # NOTE(review): residues whose label character is outside dssp_types have an
        # all-zero y row; argmax maps them to class 0, so they are scored as dssp_types[0].
        all_labels.append(data.y.argmax(dim=1).cpu())

        # Log after batches freq, 2*freq, ...; each log covers exactly the batches
        # seen since the previous one, averaged over the graphs they contained.
        if log_batches and batch_idx % batch_print_freq == 0:
            step = epoch * len(loader) + batch_idx
            writer.add_scalar("Loss/train", window_loss / window_graphs, step)
            window_loss = 0.0
            window_graphs = 0

            classification_metrics(
                torch.cat(all_labels[-batch_print_freq:]).numpy(),
                torch.cat(all_preds[-batch_print_freq:]).numpy(),
                name="train/batch",
                tensorboard_writer=writer,
                writer_val=step,
            )

    total_loss /= len(loader.dataset)

    report, accuracy = classification_metrics(
        torch.cat(all_labels).numpy(),
        torch.cat(all_preds).numpy(),
        name="train/epoch",
        tensorboard_writer=writer,
        writer_val=epoch,
    )

    macro = report["macro avg"]
    return total_loss, macro["f1-score"], macro["recall"], macro["precision"], accuracy

In [None]:
def evaluate(model: GCN,
            loader: DataLoader,
            criterion: nn.Module,
            dataset_name: str,
            epoch: int = 0,
            writer: SummaryWriter = None) -> "tuple[float, float, float, float, float]":
    """Evaluate ``model`` on every batch of ``loader`` without gradient tracking.

    Args:
        model: Model mapping a batched ``Data`` object to per-node logits.
        loader: Loader yielding batched graphs with one-hot ``y`` targets.
        criterion: Loss called as ``criterion(logits, data.y)``.
        dataset_name: Tag prefix for TensorBoard scalars (``"<name>/epoch"``).
        epoch: Global step used when logging to ``writer``.
        writer: Optional TensorBoard writer; nothing is logged when ``None``.

    Returns:
        ``(loss, macro_f1, macro_recall, macro_precision, accuracy)`` where ``loss`` is
        the sum of per-batch losses weighted by ``num_graphs``, divided by the dataset size.

    Raises:
        ValueError: if ``loader`` has an empty dataset.
    """
    if len(loader.dataset) == 0:
        raise ValueError(f"cannot evaluate on an empty '{dataset_name}' loader")

    model.eval()

    total_loss = 0.0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            out = model(data)
            total_loss += criterion(out, data.y).item() * data.num_graphs
            all_preds.append(out.argmax(dim=1).cpu())
            # NOTE(review): all-zero y rows (labels outside dssp_types) map to class 0.
            all_labels.append(data.y.argmax(dim=1).cpu())

    total_loss /= len(loader.dataset)

    report, accuracy = classification_metrics(
        torch.cat(all_labels).numpy(),
        torch.cat(all_preds).numpy(),
        name=f"{dataset_name}/epoch",
        tensorboard_writer=writer,
        writer_val=epoch,
    )

    macro = report["macro avg"]
    return total_loss, macro["f1-score"], macro["recall"], macro["precision"], accuracy

In [None]:
import pathlib
from datetime import datetime

In [None]:
def train_model(model: GCN,
                train_loader: DataLoader,
                val_loader: DataLoader,
                test_loader: DataLoader,
                optimizer: torch.optim.Optimizer,
                criterion: nn.Module,
                epochs: int = 100,
                best_val_loss: float = float("inf"),
                batch_print_freq: int = 32,
                model_name: str = None,) -> GCN:
    """Train ``model``, checkpoint every epoch, keep the best by val loss, then test.

    Writes ``models/<model_name>/checkpoint.pth`` after every epoch and
    ``models/<model_name>/best_model.pth`` whenever val loss improves on
    ``best_val_loss``; TensorBoard logs go to ``runs/<model_name>``. A timestamped
    name is generated when ``model_name`` is None.

    Returns:
        ``model`` with its final-epoch weights. The test-set evaluation also uses the
        final-epoch weights, not the best checkpoint on disk.
    """
    if model_name is None:
        timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        model_name = f"SecondCount-{timestamp}"

    model.to(device)
    model_dir = pathlib.Path("models") / model_name
    model_dir.mkdir(parents=True, exist_ok=True)

    writer = SummaryWriter(f"runs/{model_name}")
    # try/finally so the event file is flushed and closed even if training fails.
    try:
        for epoch in range(epochs):
            print(f"Epoch {epoch + 1}/{epochs}")
            train_loss, _, _, _, _ = train(
                epoch, model, train_loader, optimizer, criterion, batch_print_freq, writer
            )
            print(f"Train Loss: {train_loss:.4f}")

            val_loss, val_f1, val_recall, val_precision, val_accuracy = evaluate(
                model, val_loader, criterion, "val", epoch, writer
            )
            print(f"Validation Loss: {val_loss:.4f}, F1: {val_f1:.4f}, Recall: {val_recall:.4f}, Precision: {val_precision:.4f}, Accuracy: {val_accuracy:.4f}")

            # Update the best loss *before* building the state so checkpoints record it.
            is_best = val_loss < best_val_loss
            if is_best:
                best_val_loss = val_loss

            save_state = {
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "epoch": epoch,
                "train_loss": train_loss,
                "val_loss": val_loss,
                "best_val_loss": best_val_loss,
            }

            torch.save(save_state, model_dir / "checkpoint.pth")

            if is_best:
                torch.save(save_state, model_dir / "best_model.pth")
                print(f"New best model saved with validation loss: {best_val_loss:.4f}")

        print("Training complete. Evaluating on test set...")
        # Pass writer by keyword: positionally it landed in `epoch`, so test metrics
        # were never logged.
        test_loss, test_f1, test_recall, test_precision, test_accuracy = evaluate(
            model, test_loader, criterion, "test", writer=writer
        )
        print(f"Test Loss: {test_loss:.4f}, F1: {test_f1:.4f}, Recall: {test_recall:.4f}, Precision: {test_precision:.4f}, Accuracy: {test_accuracy:.4f}")
    finally:
        writer.close()

    return model

In [None]:
# Seed the global torch RNG before building the model so weight initialisation and
# the (globally seeded) training-loader shuffle order are reproducible.
SEED = 42
torch.manual_seed(SEED)

# Move the model to the target device before handing its parameters to the optimizer.
model = GCN().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss(weight=train_dataset.class_weights().to(device))
trained_model = train_model(
    model,
    training_loader,
    val_loader,
    test_loader,
    optimizer,
    criterion
)