In [None]:
import copy

import torch
from torch import nn

# Choose the best available accelerator: CUDA first, then Apple MPS, otherwise CPU.
# Multi-process (DDP) training is only switched on when more than one CUDA device exists.
if torch.cuda.is_available():
    device = "cuda"
    distributed = torch.cuda.device_count() > 1
elif torch.backends.mps.is_available():
    device = "mps"
    distributed = False
else:
    device = "cpu"
    distributed = False

torch.device(device)

In [None]:
import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import os

In [None]:
from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, GraphNorm

In [None]:
import pandas as pd

In [None]:
class SecondaryStructureDataset(Dataset):
    """Residue-graph dataset built from a DataFrame of sequences and DSSP strings.

    Each row becomes a ``Data`` object with
      * ``x``: one-hot amino acids, shape ``(len(sequence), len(amino_acids))``;
      * ``y``: one-hot DSSP labels, shape ``(len(secondary_structure), len(dssp_types))``;
      * ``edge_index``: directed edges ``i -> i + d`` and ``i -> i - d`` for every ``d`` in
        ``edge_distances`` that stays inside the sequence, shape ``(2, E)``.

    Characters that are not in the relevant charset are encoded as an all-zero row.
    NOTE(review): the sequence and DSSP string of a row are assumed to have the same
    length (x and y must line up node-for-node) — TODO confirm for the CSVs in use.
    """
    amino_acids = "ACDEFGHIKLMNPQRSTVXWY"
    dssp_types = "GHITEBSP-"
    edge_distances = [1, 2, 3, 4]

    @staticmethod
    def _to_one_hot(seq: str, charset: str) -> torch.Tensor:
        """One-hot encode ``seq`` over ``charset``; unknown characters give a zero row."""
        one_hot = torch.zeros(len(seq), len(charset), dtype=torch.float)
        for i, char in enumerate(seq):
            idx = charset.find(char)
            if idx >= 0:
                one_hot[i, idx] = 1.0
        return one_hot

    @staticmethod
    def _to_index(seq: str, charset: str) -> torch.Tensor:
        """Map ``seq`` to charset indices; unknown characters map to 0."""
        indices = torch.zeros(len(seq), dtype=torch.long)
        for i, char in enumerate(seq):
            idx = charset.find(char)
            if idx >= 0:
                indices[i] = idx
        return indices

    @classmethod
    def _local_edge_index(cls, length: int) -> torch.Tensor:
        """Directed edges between residues ``edge_distances`` apart, as a ``(2, E)`` long tensor."""
        edges = []
        for i in range(length):
            for d in cls.edge_distances:
                if i + d < length:
                    edges.append((i, i + d))
                if i - d >= 0:
                    edges.append((i, i - d))
        if not edges:
            # torch.tensor([]).t() would be 1-D; message-passing layers need shape (2, 0).
            return torch.empty((2, 0), dtype=torch.long)
        return torch.tensor(edges, dtype=torch.long).t().contiguous()

    def __init__(self,
                 df: pd.DataFrame,
                 seq_col: str = "sequence",
                 ss_col: str = "secondary_structure"):
        """
        Args:
            df: One row per protein.
            seq_col: Column holding the amino-acid sequence string.
            ss_col: Column holding the DSSP secondary-structure string.
        """
        super().__init__()
        self.df = df
        self.seq_col = seq_col
        self.ss_col = ss_col
        self.max_seq_length = df[seq_col].str.len().max()

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx: int):
        """Build the graph for row ``idx`` (positional, via ``iloc``)."""
        row = self.df.iloc[idx]

        sequence = row[self.seq_col]
        secondary_structure = row[self.ss_col]

        x = self._to_one_hot(sequence, self.amino_acids)
        y = self._to_one_hot(secondary_structure, self.dssp_types)
        edge_index = self._local_edge_index(len(sequence))

        return Data(x=x, edge_index=edge_index, y=y)

    def class_weights(self):
        """Inverse-frequency weights ``total / count`` per DSSP class, in ``dssp_types`` order.

        Classes that never occur get weight 0. Characters outside ``dssp_types`` are not
        counted in ``total``.
        """
        ss_cat = "".join(self.df[self.ss_col])
        counts = {char: ss_cat.count(char) for char in self.dssp_types}
        total = sum(counts.values())
        weights = {char: total / count if count > 0 else 0 for char, count in counts.items()}
        return torch.tensor([weights[char] for char in self.dssp_types], dtype=torch.float)
        

In [None]:
# Parametric Exponential Linear Unit (PELU) activation function
class PELU(nn.Module):
    """ELU with a learnable, strictly positive ``alpha``.

    f(x) = x                       for x >= 0
    f(x) = alpha * (exp(x) - 1)    for x < 0

    ``alpha`` is stored as ``log_alpha`` so that it stays positive during optimisation.
    NOTE: this is a single-parameter variant. The two-parameter PELU of Trottier et al.
    additionally scales the positive branch.
    """
    def __init__(self, alpha=1.0):
        super(PELU, self).__init__()
        self.log_alpha = nn.Parameter(torch.log(torch.tensor(float(alpha))))

    def forward(self, x):
        alpha = torch.exp(self.log_alpha)
        # torch.where back-propagates through both branches. Without the clamp, exp(x)
        # overflows to inf for large positive x, and the (unselected) negative branch turns
        # a zero upstream gradient into 0 * inf = NaN. expm1 is also more accurate near 0.
        negative = alpha * torch.expm1(torch.clamp(x, max=0))
        return torch.where(x >= 0, x, negative)

In [None]:
class GCN(nn.Module):
    """Stack of GCNConv layers with GraphNorm, a shared PELU and residual connections.

    Layout: ``in -> hidden``, then ``num_layers - 2`` x ``hidden -> hidden`` (each with a
    residual connection), then ``hidden -> out`` without norm/activation.
    NOTE: ``num_layers < 2`` still builds the first and last layers (two convolutions).
    """
    def __init__(self,
                 in_channels: int = len(SecondaryStructureDataset.amino_acids),
                 out_channels: int = len(SecondaryStructureDataset.dssp_types),
                 hidden_channels: int = 64,
                 num_layers: int = 3):
        super(GCN, self).__init__()
        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()

        self.convs.append(GCNConv(in_channels, hidden_channels))
        self.norms.append(GraphNorm(hidden_channels))

        for _ in range(num_layers - 2):
            self.convs.append(GCNConv(hidden_channels, hidden_channels))
            self.norms.append(GraphNorm(hidden_channels))

        self.convs.append(GCNConv(hidden_channels, out_channels))
        # A single activation module, so all layers share one learnable alpha.
        self.act = PELU()

    def forward(self, data: Data) -> torch.Tensor:
        """Return per-node logits of shape ``(num_nodes, out_channels)``."""
        x, edge_index = data.x, data.edge_index
        # GraphNorm needs the node -> graph assignment to normalise each graph separately.
        # Without it, a mini-batch is normalised as if it were one large graph.
        batch = getattr(data, "batch", None)

        for i, (conv, norm) in enumerate(zip(self.convs[:-1], self.norms)):
            identity = x
            x = conv(x, edge_index)
            x = norm(x, batch)
            x = self.act(x)
            # The first layer changes the width (in -> hidden), so it has no residual.
            if i != 0:
                x = x + identity

        return self.convs[-1](x, edge_index)

In [None]:
class EdgeProbabilisties(nn.Module):
    """Predict long-range residue-residue edge probabilities from the sequence alone.

    These edges are meant to supplement the fixed local edges built by the dataset,
    which only connect residues a few positions apart.

    Pipeline:
      1. one-hot -> index -> learned embedding + sinusoidal positional encoding;
      2. ``num_layers`` x (Conv1d -> GELU -> LayerNorm);
      3. pairwise dot products -> sigmoid.

    Input: one-hot sequences of shape ``(N, L, amino_acid_dim)``.
    Output: probabilities of shape ``(N, L, L)``.
    (The class name's spelling is kept because other code refers to it.)
    """
    def __init__(self,
                 amino_acid_dim: int = len(SecondaryStructureDataset.amino_acids),
                 max_length: int = 2048,
                 embed_dim: int = 32,
                 kernel_size: int = 7,
                 num_layers: int = 3):
        super(EdgeProbabilisties, self).__init__()

        self.amino_acid_dim = amino_acid_dim
        self.embed_dim = embed_dim

        self.embedding = nn.Embedding(amino_acid_dim, embed_dim)

        # Sinusoidal positional encoding, precomputed up to max_length. Longer sequences
        # get an encoding generated on the fly in forward().
        self.register_buffer("positional_encoding", self._gen_positional_encoding(max_length))

        # nn.ModuleList.append accepts exactly one module, so each conv and its LayerNorm
        # are kept in parallel lists. GELU has no parameters and is shared.
        self.conv_layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()
        for _ in range(num_layers):
            # padding="same" keeps the length L for odd and even kernel sizes alike.
            self.conv_layers.append(
                nn.Conv1d(embed_dim, embed_dim, kernel_size=kernel_size, padding="same")
            )
            self.norm_layers.append(nn.LayerNorm(embed_dim))
        self.activation = nn.GELU()

    def _gen_positional_encoding(self, length: int) -> torch.Tensor:
        """Standard sinusoidal encoding of shape ``(length, embed_dim)``."""
        pe = torch.zeros(length, self.embed_dim)
        position = torch.arange(0, length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.embed_dim, 2).float()
            * (-torch.log(torch.tensor(10000.0)) / self.embed_dim)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd embed_dim there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(position * div_term[: self.embed_dim // 2])
        return pe

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map one-hot sequences ``(N, L, A)`` to edge probabilities ``(N, L, L)``."""
        _, seq_length, _ = x.size()
        tokens = torch.argmax(x, dim=-1)  # one-hot -> indices; all-zero rows map to 0
        h = self.embedding(tokens)  # (N, L, E)

        if seq_length <= self.positional_encoding.size(0):
            pe = self.positional_encoding[:seq_length]
        else:
            pe = self._gen_positional_encoding(seq_length).to(h.device, h.dtype)
        h = h + pe.unsqueeze(0)

        h = h.transpose(1, 2)  # (N, E, L) for Conv1d
        for conv, norm in zip(self.conv_layers, self.norm_layers):
            h = self.activation(conv(h))
            # LayerNorm normalises the last dimension, so move channels there and back.
            h = norm(h.transpose(1, 2)).transpose(1, 2)
        h = h.transpose(1, 2)  # (N, L, E)

        # Pairwise scores between every pair of residues.
        scores = torch.matmul(h, h.transpose(1, 2))  # (N, L, L)
        return torch.sigmoid(scores)


In [None]:
class SecondaryStructurePredictor(nn.Module):
    """GCN secondary-structure classifier whose graph is augmented with learned edges.

    For every graph in the input batch, ``EdgeProbabilisties`` scores all residue pairs.
    Pairs scoring at least ``edge_threshold`` (optionally limited to the ``edge_top_k``
    best per residue) are added as directed edges, unless they already exist. The GCN
    then runs on the augmented graph.

    NOTE(review): the edge probabilities only feed a boolean mask, so no gradient ever
    reaches ``self.edge_model``. Its weights stay at their random initialisation, and DDP
    needs ``find_unused_parameters=True``. Pass the probabilities to the GCN as edge
    weights if the edge model is meant to be learned.
    """
    def __init__(self,
                 gcn_in_channels: int = len(SecondaryStructureDataset.amino_acids),
                 gcn_out_channels: int = len(SecondaryStructureDataset.dssp_types),
                 gcn_hidden_channels: int = 64,
                 gcn_num_layers: int = 3,
                 edge_embed_dim: int = 32,
                 edge_kernel_size: int = 7,
                 edge_num_layers: int = 3,
                 edge_threshold: float = 0.7,
                 edge_top_k: int = -1):
        super(SecondaryStructurePredictor, self).__init__()
        self.edge_model = EdgeProbabilisties(
            amino_acid_dim=gcn_in_channels,
            embed_dim=edge_embed_dim,
            kernel_size=edge_kernel_size,
            num_layers=edge_num_layers
        )
        self.gcn_model = GCN(
            in_channels=gcn_in_channels,
            out_channels=gcn_out_channels,
            hidden_channels=gcn_hidden_channels,
            num_layers=gcn_num_layers
        )

        self.edge_threshold = edge_threshold
        # A value <= 0 disables the per-node top-k filter.
        self.edge_top_k = edge_top_k

    def _predict_graph_edges(self, x: torch.Tensor) -> torch.Tensor:
        """Learned edges for ONE graph with node features ``x`` of shape ``(L, A)``.

        Returns a ``(2, E)`` long tensor of graph-local node indices.
        """
        length = x.size(0)
        probs = self.edge_model(x.unsqueeze(0)).squeeze(0)  # (L, L)

        mask = probs >= self.edge_threshold
        if self.edge_top_k > 0:
            # topk raises if k exceeds the row length, so clamp it for short sequences.
            k = min(self.edge_top_k, length)
            topk_indices = torch.topk(probs, k, dim=-1).indices
            topk_mask = torch.zeros_like(mask)
            topk_mask.scatter_(1, topk_indices, torch.ones_like(topk_indices, dtype=torch.bool))
            mask = mask & topk_mask

        # GCNConv adds self-loops itself; keep only edges between distinct residues.
        mask = mask & ~torch.eye(length, dtype=torch.bool, device=mask.device)
        return mask.nonzero().t()

    def forward(self, data: Data) -> torch.Tensor:
        """Return per-node logits for ``data``; the input object is not modified."""
        x, edge_index = data.x, data.edge_index
        num_nodes = x.size(0)

        batch = getattr(data, "batch", None)
        if batch is None:
            # A single, un-batched graph: every node belongs to graph 0.
            batch = torch.zeros(num_nodes, dtype=torch.long, device=x.device)

        # Edges must be predicted per protein. Scoring the concatenated batch as one
        # sequence would link residues of different proteins. PyG stores each graph's
        # nodes contiguously and in batch order, so graph g is the slice
        # [offset_g, offset_g + count_g).
        learned = []
        offset = 0
        for count in torch.bincount(batch).tolist():
            if count > 0:
                learned.append(self._predict_graph_edges(x[offset:offset + count]) + offset)
            offset += count

        if learned:
            new_edge_index = torch.cat(learned, dim=1)
            # Drop learned edges that duplicate existing (local) edges. Encoding each
            # (src, dst) pair as one integer id allows a vectorised membership test.
            existing_ids = edge_index[0] * num_nodes + edge_index[1]
            new_ids = new_edge_index[0] * num_nodes + new_edge_index[1]
            keep = ~torch.isin(new_ids, existing_ids)
            edge_index = torch.cat([edge_index, new_edge_index[:, keep]], dim=1)

        # Shallow copy so the caller's batch keeps its original edge_index.
        graph = copy.copy(data)
        graph.edge_index = edge_index
        return self.gcn_model(graph)


In [None]:
from torch.utils.tensorboard import SummaryWriter
from torchmetrics import MetricCollection, Accuracy, F1Score, Precision, Recall

In [None]:
def train(epoch: int,
          model: SecondaryStructurePredictor,
          loader: DataLoader,
          optimizer: torch.optim.Optimizer,
          scheduler: torch.optim.lr_scheduler._LRScheduler,
          criterion: nn.Module,
          device: str | torch.device | int,
          batch_print_freq: int = 32,
          writer: SummaryWriter = None,
          distributed: bool = False) -> tuple[float, float, float, float, float]:
    """Run one training epoch.

    The scheduler is stepped once per batch. Every ``batch_print_freq`` batches, the
    running loss and the metrics accumulated since the previous log are written to
    ``writer`` (if given). Epoch-level loss and metrics are written at the end.

    Args:
        epoch: Zero-based epoch index, used to derive TensorBoard step numbers.
        model: Model (optionally DDP-wrapped) mapping a PyG batch to per-node logits.
        loader: DataLoader yielding PyG batches whose ``y`` is one-hot.
        optimizer: Optimiser stepped once per batch.
        scheduler: LR scheduler stepped once per batch.
        criterion: Loss between logits and ``data.y``.
        device: Device for batches, loss accumulators and metric states.
        batch_print_freq: Batches between batch-level logs.
        writer: TensorBoard writer; ``None`` disables logging (e.g. on non-zero ranks).
        distributed: Whether losses are all-reduced across the process group.

    Returns:
        ``(loss, accuracy, f1, precision, recall)`` for the epoch. ``loss`` is the
        graph-weighted mean batch loss; the metrics are macro-averaged over DSSP classes.
    """
    model.train()

    class_names = SecondaryStructureDataset.dssp_types
    num_class = len(class_names)

    def build_metrics() -> MetricCollection:
        # Fresh Metric objects on every call. Collections that share instances would
        # double-count each update and reset each other's state.
        metrics = {}
        for avg in [None, 'macro', 'micro', 'weighted']:
            metrics[f'accuracy_{avg}'] = Accuracy(num_classes=num_class, task='multiclass', average=avg)
            metrics[f'f1_score_{avg}'] = F1Score(num_classes=num_class, task='multiclass', average=avg)
            metrics[f'precision_{avg}'] = Precision(num_classes=num_class, task='multiclass', average=avg)
            metrics[f'recall_{avg}'] = Recall(num_classes=num_class, task='multiclass', average=avg)
        return MetricCollection(metrics=metrics).to(device)

    def log_stats(stats: dict, scope: str, step: int) -> None:
        # Per-class (vector) metrics are logged as one scalar per DSSP class.
        for key, value in stats.items():
            if value.dim() > 0 and value.size(0) > 1:
                for idx, cls in enumerate(class_names):
                    writer.add_scalar(f"{key}/train/{scope}/{cls}", value[idx].item(), step)
            else:
                writer.add_scalar(f"{key}/train/{scope}", value.item(), step)

    batch_metrics = build_metrics()
    epoch_metrics = build_metrics()

    total_loss = torch.zeros(1, device=device)
    running_loss = torch.zeros(1, device=device)
    total_count = torch.zeros(1, device=device)
    running_count = torch.zeros(1, device=device)

    for i, data in enumerate(loader):
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data)
        loss = criterion(out, data.y)
        loss.backward()
        optimizer.step()

        weighted_loss = loss.item() * data.num_graphs
        total_loss += weighted_loss
        running_loss += weighted_loss
        total_count += data.num_graphs
        running_count += data.num_graphs

        scheduler.step()

        preds = out.argmax(dim=1)
        targets = data.y.argmax(dim=1)
        batch_metrics.update(preds, targets)
        epoch_metrics.update(preds, targets)

        if i % batch_print_freq == 0 and (writer is not None or distributed):
            # Every rank must take part in the collectives below: the all_reduces, and
            # the cross-process sync inside MetricCollection.compute(). Running them only
            # on the rank that owns the writer deadlocks the other ranks.
            if distributed:
                torch.distributed.all_reduce(running_loss, op=torch.distributed.ReduceOp.SUM)
                torch.distributed.all_reduce(running_count, op=torch.distributed.ReduceOp.SUM)
            batch_stats = batch_metrics.compute()

            if writer is not None:
                step = epoch * len(loader) + i
                log_stats(batch_stats, "batch", step)
                writer.add_scalar("Loss/train/batch", running_loss.item() / running_count.item(), step)

            batch_metrics.reset()
            running_loss.zero_()
            running_count.zero_()

    if distributed:
        torch.distributed.all_reduce(total_loss, op=torch.distributed.ReduceOp.SUM)
        torch.distributed.all_reduce(total_count, op=torch.distributed.ReduceOp.SUM)
    # max(..., 1) guards against an empty loader.
    total_loss /= max(total_count.item(), 1.0)

    # torchmetrics synchronises metric state across ranks inside compute()
    # (sync_on_compute=True by default), so no manual all_reduce/averaging is needed.
    epoch_stats = epoch_metrics.compute()

    if writer is not None:
        writer.add_scalar("Loss/train/epoch", total_loss.item(), epoch)
        log_stats(epoch_stats, "epoch", epoch)

    return (total_loss.item(),
            epoch_stats['accuracy_macro'].item(),
            epoch_stats['f1_score_macro'].item(),
            epoch_stats['precision_macro'].item(),
            epoch_stats['recall_macro'].item())

In [None]:
def evaluate(model: SecondaryStructurePredictor,
            loader: DataLoader,
            criterion: nn.Module,
            dataset_name: str,
            device: str | torch.device | int,
            epoch: int = 0,
            writer: SummaryWriter = None,
            distributed: bool = False) -> tuple[float, float, float, float, float]:
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: Model (optionally DDP-wrapped) mapping a PyG batch to per-node logits.
        loader: DataLoader yielding PyG batches whose ``y`` is one-hot.
        criterion: Loss between logits and ``data.y``.
        dataset_name: Tag used in TensorBoard keys (e.g. ``"val"``, ``"test"``).
        device: Device for batches, loss accumulators and metric states.
        epoch: TensorBoard step for the logged values.
        writer: TensorBoard writer; ``None`` disables logging.
        distributed: Whether losses are all-reduced across the process group.

    Returns:
        ``(loss, f1, recall, precision, accuracy)``. ``loss`` is the graph-weighted mean
        batch loss; the metrics are macro-averaged over DSSP classes.
    """
    model.eval()

    class_names = SecondaryStructureDataset.dssp_types
    num_class = len(class_names)

    metric_dict = {}
    for avg in [None, 'macro', 'micro', 'weighted']:
        metric_dict[f'accuracy_{avg}'] = Accuracy(num_classes=num_class, task='multiclass', average=avg)
        metric_dict[f'f1_score_{avg}'] = F1Score(num_classes=num_class, task='multiclass', average=avg)
        metric_dict[f'precision_{avg}'] = Precision(num_classes=num_class, task='multiclass', average=avg)
        metric_dict[f'recall_{avg}'] = Recall(num_classes=num_class, task='multiclass', average=avg)
    metrics = MetricCollection(metric_dict).to(device)

    total_loss = torch.zeros(1, device=device)
    total_count = torch.zeros(1, device=device)

    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            out = model(data)
            loss = criterion(out, data.y)

            total_loss += loss.item() * data.num_graphs
            total_count += data.num_graphs

            metrics.update(out.argmax(dim=1), data.y.argmax(dim=1))

    if distributed:
        torch.distributed.all_reduce(total_loss, op=torch.distributed.ReduceOp.SUM)
        torch.distributed.all_reduce(total_count, op=torch.distributed.ReduceOp.SUM)
    # max(..., 1) guards against an empty loader.
    total_loss /= max(total_count.item(), 1.0)

    # torchmetrics synchronises metric state across ranks inside compute()
    # (sync_on_compute=True by default); every rank must call it.
    # NOTE(review): DistributedSampler pads the dataset so that all ranks see the same
    # number of samples, so a few samples may be counted twice.
    stats = metrics.compute()

    if writer is not None:
        writer.add_scalar(f"Loss/{dataset_name}/epoch", total_loss.item(), epoch)
        for key, value in stats.items():
            if value.dim() > 0 and value.size(0) > 1:
                for idx, cls in enumerate(class_names):
                    writer.add_scalar(f"{key}/{dataset_name}/epoch/{cls}", value[idx].item(), epoch)
            else:
                writer.add_scalar(f"{key}/{dataset_name}/epoch", value.item(), epoch)

    return (total_loss.item(),
            stats['f1_score_macro'].item(),
            stats['recall_macro'].item(),
            stats['precision_macro'].item(),
            stats['accuracy_macro'].item())

In [None]:
import pathlib
from datetime import datetime

In [None]:
def ddp_setup(rank, world_size):
    """Initialise the default process group for single-node, one-GPU-per-process training.

    MASTER_ADDR / MASTER_PORT default to ``localhost:12355``. Values already present in
    the environment are respected, so several jobs can share a machine.

    Args:
        rank: Process rank; also used as the CUDA device index.
        world_size: Total number of processes.
    """
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '12355')
    torch.cuda.set_device(rank)
    # NCCL is PyTorch's recommended backend for CUDA tensors. Fall back to gloo where
    # NCCL is not built in (e.g. on Windows).
    backend = "nccl" if torch.distributed.is_nccl_available() else "gloo"
    init_process_group(backend, rank=rank, world_size=world_size)

In [None]:
def run(save_every: int,
        total_epochs: int,
        train_dataloader: DataLoader,
        val_dataloader: DataLoader,
        test_dataloader: DataLoader,
        device: str | torch.device | int,
        distributed: bool = False):
    """Train, validate and test a SecondaryStructurePredictor end to end.

    Args:
        save_every: Batches between batch-level TensorBoard logs (forwarded to ``train``
            as ``batch_print_freq``). Checkpoints are written once per epoch.
        total_epochs: Number of training epochs.
        train_dataloader: Training loader; its dataset must provide ``class_weights()``.
        val_dataloader: Validation loader, evaluated after every epoch.
        test_dataloader: Test loader, evaluated once at the end.
        device: Device for model and tensors; the local rank when ``distributed``.
        distributed: Whether the default process group is initialised and DDP is used.

    Side effects:
        Writes TensorBoard logs to ``runs/<model_name>`` (rank 0 only), the best model to
        ``models/<model_name>/best_model.pth`` (rank 0 only), and one checkpoint per epoch
        to ``models/<model_name>/checkpoint-<epoch>[-<rank>].pth``. Destroys the process
        group at the end when ``distributed``.
    """
    is_main = not distributed or torch.distributed.get_rank() == 0

    model = SecondaryStructurePredictor(
        gcn_in_channels=len(SecondaryStructureDataset.amino_acids),
        gcn_out_channels=len(SecondaryStructureDataset.dssp_types),
        gcn_hidden_channels=128,
        gcn_num_layers=10,
        edge_embed_dim=32,
        edge_kernel_size=7,
        edge_num_layers=3,
        edge_threshold=0.7,
        edge_top_k=-1,
    ).to(device)

    model_name = f"SecondCount-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
    if distributed:
        # Ranks may read the clock in different seconds; use rank 0's name everywhere.
        name_holder = [model_name]
        torch.distributed.broadcast_object_list(name_holder, src=0)
        model_name = name_holder[0]

    model_dir = pathlib.Path("models") / model_name
    model_dir.mkdir(parents=True, exist_ok=True)

    if distributed:
        # The edge model only produces a boolean mask, so its parameters never receive
        # gradients. DDP must be told to expect unused parameters, or it raises an error
        # on the next iteration.
        model = DDP(model, device_ids=[device], find_unused_parameters=True)

    writer = SummaryWriter(f"runs/{model_name}") if is_main else None

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss(weight=train_dataloader.dataset.class_weights().to(device))

    steps_per_epoch = len(train_dataloader)
    # Both schedulers need at least one step; tiny loaders would otherwise give 0.
    warmup_steps = max(1, steps_per_epoch // 5)

    warmup = torch.optim.lr_scheduler.LinearLR(
        optimizer=optimizer,
        start_factor=0.1,
        total_iters=warmup_steps
    )

    cosine = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer=optimizer,
        T_0=max(1, steps_per_epoch * 25)
    )

    # SequentialLR hands over from warmup to cosine after warmup_steps. ChainedScheduler
    # would step both on every call, and CosineAnnealingWarmRestarts recomputes the LR
    # from the base LR, discarding the warmup factor.
    scheduler = torch.optim.lr_scheduler.SequentialLR(
        optimizer,
        schedulers=[warmup, cosine],
        milestones=[warmup_steps]
    )

    best_val_loss = float("inf")

    for epoch in range(total_epochs):
        if is_main:
            print(f"Epoch {epoch + 1}/{total_epochs}")

        # DistributedSampler only reshuffles when it is told the current epoch.
        sampler = getattr(train_dataloader, "sampler", None)
        if hasattr(sampler, "set_epoch"):
            sampler.set_epoch(epoch)

        # train() steps the scheduler once per batch; there is no extra per-epoch step.
        train_loss, train_accuracy, train_f1, train_precision, train_recall = train(
            epoch, model, train_dataloader, optimizer, scheduler, criterion, device,
            batch_print_freq=save_every, writer=writer, distributed=distributed
        )
        if is_main:
            print(f"Train Loss: {train_loss:.4f}, Accuracy: {train_accuracy:.4f}, F1: {train_f1:.4f}, Precision: {train_precision:.4f}, Recall: {train_recall:.4f}")

        val_loss, val_f1, val_recall, val_precision, val_accuracy = evaluate(
            model, val_dataloader, criterion, "val", device, epoch, writer, distributed,
        )
        if is_main:
            print(f"Validation Loss: {val_loss:.4f}, F1: {val_f1:.4f}, Recall: {val_recall:.4f}, Precision: {val_precision:.4f}, Accuracy: {val_accuracy:.4f}")

        # val_loss is all-reduced inside evaluate(), so every rank takes the same branch.
        # Only rank 0 writes the file, avoiding concurrent writes to one path.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            if is_main:
                torch.save(model.state_dict(), model_dir / "best_model.pth")
                print(f"New best model saved with validation loss: {best_val_loss:.4f}")

        save_state = {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
            "epoch": epoch,
            "train_loss": train_loss,
            "val_loss": val_loss,
            "best_val_loss": best_val_loss,
        }

        rank_suffix = f"-{device}" if distributed else ""
        torch.save(save_state, model_dir / f"checkpoint-{epoch}{rank_suffix}.pth")

    if is_main:
        print("Training complete. Evaluating on test set...")
    test_loss, test_f1, test_recall, test_precision, test_accuracy = evaluate(
        model, test_dataloader, criterion, "test", device, writer=writer, distributed=distributed
    )

    if is_main:
        print(f"Test Loss: {test_loss:.4f}, F1: {test_f1:.4f}, Recall: {test_recall:.4f}, Precision: {test_precision:.4f}, Accuracy: {test_accuracy:.4f}")
        writer.close()

    if distributed:
        destroy_process_group()

    

In [None]:
def dist_run(rank: int,
             world_size: int,
             save_every: int,
             total_epochs: int,
             batch_size: int,
             train_dataset: SecondaryStructureDataset,
             val_dataset: SecondaryStructureDataset,
             test_dataset: SecondaryStructureDataset):
    """Per-process entry point for ``mp.spawn``: set up DDP, build sharded loaders, train.

    Args:
        rank: Process rank supplied by ``mp.spawn``; also used as the CUDA device.
        world_size: Number of processes / GPUs.
        save_every: Batches between batch-level TensorBoard logs (see ``run``).
        total_epochs: Number of training epochs.
        batch_size: Per-process batch size.
        train_dataset / val_dataset / test_dataset: Datasets sharded with DistributedSampler.
    """
    ddp_setup(rank, world_size)

    # Same seed on every rank, so that model initialisation starts identical.
    torch.manual_seed(432250798121)

    train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
    # Evaluation order does not affect the metrics; keep val/test passes deterministic.
    val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=rank, shuffle=False)
    test_sampler = DistributedSampler(test_dataset, num_replicas=world_size, rank=rank, shuffle=False)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, sampler=val_sampler)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, sampler=test_sampler)

    # batch_size is already baked into the loaders; run() has no batch_size parameter.
    run(
        save_every=save_every,
        total_epochs=total_epochs,
        train_dataloader=train_loader,
        val_dataloader=val_loader,
        test_dataloader=test_loader,
        device=rank,
        distributed=True
    )

In [None]:
if __name__ == "__main__":
    # Run configuration.
    SAVE_EVERY = 32      # batches between batch-level TensorBoard logs
    TOTAL_EPOCHS = 10
    BATCH_SIZE = 8

    train_df = pd.read_csv("data/pdb/train_data.csv")
    val_df = pd.read_csv("data/pdb/val_data.csv")
    test_df = pd.read_csv("data/pdb/test_data.csv")

    train_dataset = SecondaryStructureDataset(train_df)
    val_dataset = SecondaryStructureDataset(val_df)
    test_dataset = SecondaryStructureDataset(test_df)

    if distributed:
        # NOTE(review): mp.spawn starts fresh interpreters that must be able to import
        # dist_run from __main__. That generally fails when this code lives in a Jupyter
        # kernel, so the multi-GPU path is expected to run from an exported script.
        world_size = torch.cuda.device_count()
        mp.spawn(
            dist_run,
            args=(world_size, SAVE_EVERY, TOTAL_EPOCHS, BATCH_SIZE,
                  train_dataset, val_dataset, test_dataset),
            nprocs=world_size,
            join=True,
        )
    else:
        # Loaders are built only on the single-process path. DistributedSampler needs an
        # initialised process group and cannot be combined with shuffle=True, so
        # dist_run builds its own per-rank loaders instead.
        training_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
        test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

        run(
            save_every=SAVE_EVERY,
            total_epochs=TOTAL_EPOCHS,
            train_dataloader=training_loader,
            val_dataloader=val_loader,
            test_dataloader=test_loader,
            device=device,
            distributed=False
        )
