In [None]:
from torch import nn
import torch

# Pick the preferred backend: CUDA first, then Apple MPS, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Multi-process training is only switched on when more than one CUDA device is visible.
distributed = device == "cuda" and torch.cuda.device_count() > 1

torch.device(device)

device(type='cuda')

In [None]:
import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import Dataset, DataLoader
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import os

In [None]:
from torch_geometric.data import Batch, Data
from torch_geometric.nn import GCNConv, GraphNorm

In [None]:
import pandas as pd

In [None]:
class SecondaryStructureDataset(Dataset):
    """Pairs amino-acid sequences with per-residue secondary-structure labels.

    Each item is ``(x, y)``:

    * ``x`` - 1-based indices into ``amino_acids``; 0 is used both for padding
      (added later by ``collate_fn``) and for characters outside the alphabet.
    * ``y`` - 0-based indices into ``dssp_types``; characters outside the
      alphabet fall back to index 0.
    """

    amino_acids = "ACDEFGHIKLMNPQRSTVXWY"
    dssp_types = "GHITEBSP-"

    @staticmethod
    def _to_one_hot(seq: str, charset: str) -> torch.Tensor:
        """Encode ``seq`` as a ``(len(seq), len(charset) + 1)`` one-hot matrix.

        Column 0 is never set; a character missing from ``charset`` yields an
        all-zero row.
        """
        one_hot = torch.zeros(len(seq), len(charset) + 1, dtype=torch.float)
        for i, char in enumerate(seq):
            position = charset.find(char)
            if position >= 0:
                one_hot[i, position + 1] = 1.0
        return one_hot

    @staticmethod
    def _to_index(seq: str, charset: str) -> torch.Tensor:
        """Encode ``seq`` as 1-based ``charset`` indices (0 for unknown characters)."""
        # str.find returns -1 for a missing character, which maps to index 0.
        return torch.tensor([charset.find(char) + 1 for char in seq], dtype=torch.long)

    @staticmethod
    def _to_label_indices(seq: str, charset: str) -> torch.Tensor:
        """Encode ``seq`` as 0-based ``charset`` indices (0 for unknown characters)."""
        return torch.tensor([max(charset.find(char), 0) for char in seq], dtype=torch.long)

    @staticmethod
    def collate_fn(batch):
        """Collate ``(x, y)`` items into a padded batch.

        Returns:
            xs_padded: ``(N, L)`` index tensor, right-padded with 0.
            y_flat: labels of all sequences concatenated, restricted to each
                sequence's true length.
            lengths: ``(N,)`` true sequence lengths.
            mask: ``(N, L)`` bool tensor, True for positions inside the true length.
        """
        xs, ys = zip(*batch)
        xs_padded = nn.utils.rnn.pad_sequence(xs, batch_first=True, padding_value=0)
        lengths = torch.tensor([len(x) for x in xs], dtype=torch.long)

        # Derive the mask from the lengths rather than from `xs_padded != 0`:
        # index 0 is also the code for residues outside the alphabet, so the
        # value-based mask would flag real residues as padding while their
        # labels are still part of y_flat.
        positions = torch.arange(xs_padded.size(1)).unsqueeze(0)
        mask = positions < lengths.unsqueeze(1)

        y_flat = torch.cat([y[:n] for y, n in zip(ys, lengths.tolist())], dim=0)
        return xs_padded, y_flat, lengths, mask

    def __init__(self,
                 df: pd.DataFrame,
                 seq_col: str = "sequence",
                 ss_col: str = "secondary_structure"):
        """Wrap ``df``; ``seq_col`` / ``ss_col`` name the sequence and label columns."""
        super().__init__()
        self.df = df
        self.seq_col = seq_col
        self.ss_col = ss_col
        # Longest sequence (in characters) found in this frame.
        self.max_seq_length = df[seq_col].str.len().max()

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx: int):
        """Return ``(x, y)`` index tensors for the row at position ``idx``."""
        row = self.df.iloc[idx]

        sequence = row[self.seq_col]
        secondary_structure = row[self.ss_col]

        x = self._to_index(sequence, self.amino_acids)
        y = self._to_label_indices(secondary_structure, self.dssp_types)

        return x, y

    def class_weights(self):
        """Return inverse-frequency weights, one per entry of ``dssp_types``.

        The weight of a class is ``total / count`` over all labels in the
        frame; a class that never occurs gets weight 0.
        """
        ss_cat = "".join(self.df[self.ss_col])
        counts = {char: ss_cat.count(char) for char in self.dssp_types}
        total = sum(counts.values())
        weights = {char: total / count if count > 0 else 0 for char, count in counts.items()}
        return torch.tensor([weights[char] for char in self.dssp_types], dtype=torch.float)


In [None]:
# Parametric Exponential Linear Unit (PELU) activation function
class PELU(nn.Module):
    """ELU-style activation with a learnable, strictly positive scale.

    ``forward`` returns ``x`` where ``x >= 0`` and ``alpha * (exp(x) - 1)``
    elsewhere. ``alpha`` is stored as ``log_alpha`` so that it stays positive
    during optimisation.
    """

    def __init__(self, alpha=1.0):
        super().__init__()
        self.log_alpha = nn.Parameter(torch.log(torch.tensor(alpha)))

    def forward(self, x):
        alpha = torch.exp(self.log_alpha)
        # Clamp before exponentiating: torch.where evaluates both branches, and
        # for large positive x an unclamped exp(x) overflows to inf, which turns
        # the (zero) gradient of the unselected branch into NaN (0 * inf).
        # For x < 0 the clamp is the identity, so the output is unchanged.
        negative_branch = alpha * (torch.exp(torch.clamp(x, max=0)) - 1)
        return torch.where(x >= 0, x, negative_branch)

In [None]:
class GCN(nn.Module):
    """Stack of ``GCNConv`` layers with ``GraphNorm``, ``PELU`` and skip connections.

    Layout: one ``in_channels -> hidden_channels`` layer, ``num_layers - 2``
    ``hidden -> hidden`` layers, and a final ``hidden -> out_channels`` layer
    that is applied without normalisation or activation. An input and an
    output layer are always created, even when ``num_layers`` is below 2.
    """

    def __init__(self,
                 in_channels: int = len(SecondaryStructureDataset.amino_acids),
                 out_channels: int = len(SecondaryStructureDataset.dssp_types),
                 hidden_channels: int = 64,
                 num_layers: int = 3):
        super().__init__()
        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()

        self.convs.append(GCNConv(in_channels, hidden_channels))
        self.norms.append(GraphNorm(hidden_channels))

        for _ in range(num_layers - 2):
            self.convs.append(GCNConv(hidden_channels, hidden_channels))
            self.norms.append(GraphNorm(hidden_channels))

        self.convs.append(GCNConv(hidden_channels, out_channels))
        self.act = PELU()

    def forward(self, data: Data) -> torch.Tensor:
        """Return per-node outputs of width ``out_channels`` for ``data``.

        ``data`` must provide ``x`` and ``edge_index``; when it also carries a
        ``batch`` vector (as a ``Batch`` does), normalisation is done per graph.
        """
        x, edge_index = data.x, data.edge_index
        # Node-to-graph assignment. Without it GraphNorm would normalise all
        # nodes of the mini-batch as if they belonged to a single graph, making
        # a protein's output depend on what else is in the batch.
        batch = getattr(data, "batch", None)

        for i, (conv, norm) in enumerate(zip(self.convs[:-1], self.norms)):
            identity = x
            x = conv(x, edge_index)
            x = norm(x, batch)
            x = self.act(x)
            # The first layer changes the feature width, so the residual
            # connection is only applied from the second layer on.
            if i != 0:
                x = x + identity

        x = self.convs[-1](x, edge_index)
        return x

In [None]:
class GraphBuilder(nn.Module):
    """Turns padded per-residue embeddings into a batch of similarity graphs.

    For every sequence, each residue is linked to its ``top_k`` most similar
    residues (dot-product similarity), skipping pairs whose sequence distance
    is at most ``local_dist`` (which always excludes self-loops). Edges are
    made bidirectional and de-duplicated.
    """

    def __init__(self,
                 local_dist: int = 0,
                 top_k: int = 32):
        super().__init__()

        self.local_dist = local_dist
        self.top_k = top_k

    @staticmethod
    def _edgeless_graph(nodes: torch.Tensor) -> Data:
        """Build a graph with the given node features and no edges.

        It carries the same attributes (``edge_attr`` and ``edge_score``) as a
        graph with edges, so all graphs passed to ``Batch.from_data_list``
        share one attribute set.
        """
        no_scores = nodes.new_empty((0, 1))
        return Data(
            x=nodes,
            edge_index=torch.empty((2, 0), dtype=torch.long, device=nodes.device),
            edge_attr=no_scores,
            num_nodes=nodes.size(0),
            edge_score=no_scores,
        )

    def forward(self, x: torch.Tensor, lengths: torch.Tensor, mask: torch.Tensor) -> Batch:
        """Build one graph per sequence and return them as a ``Batch``.

        Args:
            x: ``(N, L, E)`` embedded features.
            lengths: ``(N,)`` true sequence lengths; graph ``b`` has
                ``lengths[b]`` nodes.
            mask: ``(N, L)`` bool tensor, True for valid positions; positions
                marked False never receive edges.
        """
        batch_size, seq_length, embed_dim = x.size()

        # Pairwise similarity scores per batch
        scores = torch.matmul(x, x.transpose(1, 2))  # (N, L, L)

        # Ignore self and local neighborhood within |i-j| <= local_dist
        idx = torch.arange(seq_length, device=x.device)
        dist = (idx.unsqueeze(0) - idx.unsqueeze(1)).abs()
        scores = scores.masked_fill(dist.unsqueeze(0) <= self.local_dist, float('-inf'))

        # Mask padded positions, both as source and as target
        pad_mask = ~mask  # padded=True
        scores = scores.masked_fill(pad_mask.unsqueeze(1), float('-inf'))
        scores = scores.masked_fill(pad_mask.unsqueeze(2), float('-inf'))

        # Read all lengths in one transfer instead of one .item() per graph.
        valid_lengths = [int(n) for n in lengths.tolist()]

        data_list = []
        for b, valid_len in enumerate(valid_lengths):
            nodes = x[b, :valid_len, :]
            if valid_len < 2:
                data_list.append(self._edgeless_graph(nodes))
                continue

            scores_b = scores[b, :valid_len, :valid_len]
            k = min(self.top_k, max(valid_len - 1, 1))

            # Top-k neighbor indices per node
            topk_vals, topk_idx = torch.topk(scores_b, k=k, dim=-1)  # (valid_len, k)

            src = torch.arange(valid_len, device=x.device).unsqueeze(1).expand(valid_len, k)
            # Drop candidates that only made the top-k because everything else
            # was masked out (their score is -inf).
            keep = torch.isfinite(topk_vals.reshape(-1))
            flat_src = src.reshape(-1)[keep]
            flat_dst = topk_idx.reshape(-1)[keep]

            if flat_src.numel() == 0:
                data_list.append(self._edgeless_graph(nodes))
                continue

            directed_edges = torch.stack([flat_src, flat_dst], dim=0)  # (2, M)

            # Add the reverse of every edge, then remove duplicate columns
            bidir_initial = torch.cat([directed_edges, directed_edges.flip(0)], dim=1)  # (2, 2M)
            bidir_edges = torch.unique(bidir_initial, dim=1)

            # Similarity score of every remaining edge, shape (num_edges, 1)
            edge_scores = scores_b[bidir_edges[0], bidir_edges[1]].unsqueeze(1)

            data_list.append(Data(
                x=nodes,
                edge_index=bidir_edges,
                edge_attr=edge_scores,
                num_nodes=valid_len,
                edge_score=edge_scores
            ))

        return Batch.from_data_list(data_list)

In [None]:
class Embedder(nn.Module):
    """Embeds residue indices and refines them with a 1-D convolution stack.

    Pipeline: token embedding (index 0 is the padding index) + sinusoidal
    positional encoding, followed by ``num_layers`` blocks of
    ``Conv1d(embed_dim, embed_dim, kernel_size, padding=kernel_size // 2)``
    and ``GELU``. With this padding an odd ``kernel_size`` preserves the
    sequence length, which the masking in ``forward`` relies on.
    """

    def __init__(self,
                 amino_acid_dim: int = len(SecondaryStructureDataset.amino_acids),
                 max_length: int = 2048,
                 embed_dim: int = 32,
                 kernel_size: int = 7,
                 num_layers: int = 3):
        super().__init__()

        # amino_acid_dim is the count of characters; indices run 0..amino_acid_dim
        # so embedding needs +1 to accommodate padding (0) and max index (amino_acid_dim)
        self.amino_acid_dim = amino_acid_dim
        self.embed_dim = embed_dim

        self.embedding = nn.Embedding(amino_acid_dim + 1, embed_dim, padding_idx=0)

        # Sinusoidal positional encoding, precomputed for max_length positions
        self.register_buffer("positional_encoding", self._gen_positional_encoding(max_length))

        self.conv_layers = nn.ModuleList()

        for _ in range(num_layers):
            self.conv_layers.extend([
                nn.Conv1d(embed_dim, embed_dim, kernel_size=kernel_size, padding=kernel_size//2),
                nn.GELU(),
            ])

    def _gen_positional_encoding(self, length: int) -> torch.Tensor:
        """Return a ``(length, embed_dim)`` sinusoidal positional-encoding table."""
        length = int(length)  # also accepts numpy integers (e.g. from pandas .max())
        pe = torch.zeros(length, self.embed_dim)
        position = torch.arange(0, length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, self.embed_dim, 2).float() * (-torch.log(torch.tensor(10000.0)) / self.embed_dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For an odd embed_dim there is one cosine column fewer than sine columns.
        pe[:, 1::2] = torch.cos(position * div_term[: self.embed_dim // 2])
        return pe

    def forward(self, x: torch.Tensor, lengths: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Embed a padded batch of residue indices.

        Args:
            x: ``(N, L)`` LongTensor of indices.
            lengths: ``(N,)`` LongTensor of sequence lengths (not used here).
            mask: ``(N, L)`` BoolTensor where True indicates valid positions.

        Returns:
            ``(N, L, E)`` features that are zero at masked positions.
        """
        batch_size, seq_length = x.size()

        x = self.embedding(x)  # (N, L, E)

        # Sequences longer than the precomputed table get an encoding built on
        # the fly instead of failing with a shape mismatch.
        if seq_length > self.positional_encoding.size(0):
            pos_table = self._gen_positional_encoding(seq_length)
        else:
            pos_table = self.positional_encoding[:seq_length, :]

        # Ensure positional encoding on same device and dtype
        pos = pos_table.unsqueeze(0).to(x.device, dtype=x.dtype)
        x = x + pos  # Add positional encoding

        # Zero out embeddings at padding positions
        keep_nlc = mask.unsqueeze(2).to(x.dtype)  # (N, L, 1)
        x = x * keep_nlc

        x = x.transpose(1, 2)  # (N, E, L)
        keep_ncl = mask.unsqueeze(1).to(x.dtype)  # (N, 1, L)

        for layer in self.conv_layers:
            # Re-mask after every layer: the convolution bias makes padded
            # positions non-zero, and the next convolution would otherwise mix
            # those values into the last valid residues, so the result for a
            # sequence would depend on how much padding its batch added.
            x = layer(x) * keep_ncl

        x = x.transpose(1, 2)  # (N, L, E)

        return x


In [None]:
class SecondaryStructurePredictor(nn.Module):
    """End-to-end model: ``Embedder`` -> ``GraphBuilder`` -> ``GCN``.

    The ``embed_*`` arguments configure the embedder, the ``edge_*``
    arguments the graph builder and the ``gcn_*`` arguments the graph
    network, whose input width is ``embed_dim``.
    """

    def __init__(self,
                 embed_dim: int = 32,
                 embed_kernel_size: int = 7,
                 embed_num_layers: int = 3,
                 embed_max_length: int = 2048,
                 embed_amino_acid_dim: int = len(SecondaryStructureDataset.amino_acids),
                 gcn_out_channels: int = len(SecondaryStructureDataset.dssp_types),
                 gcn_hidden_channels: int = 64,
                 gcn_num_layers: int = 3,
                 edge_local_dist: int = 0,
                 edge_top_k: int = 32):
        super().__init__()

        embedder_options = dict(
            amino_acid_dim=embed_amino_acid_dim,
            max_length=embed_max_length,
            embed_dim=embed_dim,
            kernel_size=embed_kernel_size,
            num_layers=embed_num_layers,
        )
        gcn_options = dict(
            in_channels=embed_dim,
            out_channels=gcn_out_channels,
            hidden_channels=gcn_hidden_channels,
            num_layers=gcn_num_layers,
        )

        # Sub-modules are created in pipeline order.
        self.embedder = Embedder(**embedder_options)
        self.edge_model = GraphBuilder(local_dist=edge_local_dist, top_k=edge_top_k)
        self.gcn_model = GCN(**gcn_options)

    def forward(self, x: torch.Tensor, lengths: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Run the full pipeline and return the graph network's output."""
        embedded = self.embedder(x, lengths, mask)  # (N, L, E)
        graphs = self.edge_model(embedded, lengths, mask)
        return self.gcn_model(graphs)

In [None]:
from torch.utils.tensorboard import SummaryWriter
from torchmetrics import MetricCollection, Accuracy, F1Score, Precision, Recall

In [None]:
def _train_new_metrics(num_classes: int, device) -> MetricCollection:
    """Create a metric collection made of brand-new metric instances."""
    metrics = {}
    for avg in [None, 'macro', 'micro', 'weighted']:
        metrics[f'accuracy_{avg}'] = Accuracy(num_classes=num_classes, task='multiclass', average=avg)
        metrics[f'f1_score_{avg}'] = F1Score(num_classes=num_classes, task='multiclass', average=avg)
        metrics[f'precision_{avg}'] = Precision(num_classes=num_classes, task='multiclass', average=avg)
        metrics[f'recall_{avg}'] = Recall(num_classes=num_classes, task='multiclass', average=avg)
    return MetricCollection(metrics=metrics).to(device)


def _train_average_over_ranks(stats: dict, distributed: bool) -> dict:
    """Average every tensor in ``stats`` over all ranks (a collective call)."""
    if not distributed:
        return dict(stats)
    world_size = torch.distributed.get_world_size()
    averaged = {}
    for key, value in stats.items():
        # Work on a copy so the tensors returned by compute() are left untouched.
        value = value.clone()
        torch.distributed.all_reduce(value, op=torch.distributed.ReduceOp.SUM)
        averaged[key] = value / world_size
    return averaged


def _train_log_stats(writer: SummaryWriter, stats: dict, scope: str, step: int) -> None:
    """Write ``stats`` to TensorBoard; per-class vectors get one scalar per class."""
    for key, value in stats.items():
        if value.dim() > 0 and value.size(0) > 1:
            for idx, cls in enumerate(SecondaryStructureDataset.dssp_types):
                writer.add_scalar(f"{key}/{scope}/{cls}", value[idx].item(), step)
        else:
            writer.add_scalar(f"{key}/{scope}", value.item(), step)


def train(epoch: int,
          model: SecondaryStructurePredictor,
          loader: DataLoader,
          optimizer: torch.optim.Optimizer,
          scheduler: torch.optim.lr_scheduler._LRScheduler,
          criterion: nn.Module,
          device: str | torch.device | int,
          batch_print_freq: int = 32,
          writer: SummaryWriter = None,
          distributed: bool = False) -> tuple[float, float, float, float, float]:
    """Train ``model`` for one epoch.

    ``scheduler.step()`` is called once per batch. Every ``batch_print_freq``
    batches the running loss and metrics are logged and then reset; at the end
    the epoch loss and metrics are logged. Logging only happens when ``writer``
    is given. In distributed mode every rank takes part in all reductions,
    whether or not it owns a writer.

    Returns:
        ``(loss, accuracy, f1, precision, recall)`` for the epoch; the metric
        values are the macro-averaged ones.
    """
    model.train()

    total_loss = torch.zeros(1, device=device, requires_grad=False)
    running_loss = torch.zeros(1, device=device, requires_grad=False)
    total_count = torch.zeros(1, device=device, requires_grad=False)
    running_count = torch.zeros(1, device=device, requires_grad=False)

    num_class = len(SecondaryStructureDataset.dssp_types)

    # Two collections built from separate metric objects. Building both from
    # one dict of instances would make them share state: every update would be
    # counted twice and resetting the batch metrics would also wipe the epoch
    # metrics.
    batch_metrics = _train_new_metrics(num_class, device)
    epoch_metrics = _train_new_metrics(num_class, device)

    steps_per_epoch = len(loader)

    for i, (x, y, lengths, mask) in enumerate(loader):
        x = x.to(device)
        y = y.to(device)
        lengths = lengths.to(device)
        mask = mask.to(device)

        optimizer.zero_grad()
        out = model(x, lengths, mask)
        loss = criterion(out, y)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 3.0)
        optimizer.step()

        # Weight the batch loss by its number of labelled nodes.
        node_count = y.numel()
        total_loss += loss.item() * node_count
        running_loss += loss.item() * node_count
        total_count += node_count
        running_count += node_count

        scheduler.step()

        preds = out.argmax(dim=1)
        batch_metrics.update(preds, y)
        epoch_metrics.update(preds, y)

        if i % batch_print_freq == 0 and (writer is not None or distributed):
            global_step = epoch * steps_per_epoch + i

            if distributed:
                torch.distributed.all_reduce(
                    running_loss,
                    op=torch.distributed.ReduceOp.SUM
                )
                torch.distributed.all_reduce(
                    running_count,
                    op=torch.distributed.ReduceOp.SUM
                )

            # Computing and reducing the metrics involves collective
            # communication, so it must run on every rank; doing it only where
            # a writer exists would leave that rank waiting for the others.
            batch_stats = _train_average_over_ranks(batch_metrics.compute(), distributed)

            if writer is not None:
                _train_log_stats(writer, batch_stats, "train/batch", global_step)
                writer.add_scalar("Loss/train/batch", running_loss.item() / running_count.item(), global_step)

            batch_metrics.reset()
            running_loss.zero_()
            running_count.zero_()

    if distributed:
        torch.distributed.all_reduce(total_loss, op=torch.distributed.ReduceOp.SUM)
        torch.distributed.all_reduce(total_count, op=torch.distributed.ReduceOp.SUM)
    total_loss /= total_count.item()

    if writer is not None:
        writer.add_scalar("Loss/train/epoch", total_loss.item(), epoch)

    epoch_stats = _train_average_over_ranks(epoch_metrics.compute(), distributed)

    if writer is not None:
        _train_log_stats(writer, epoch_stats, "train/epoch", epoch)

    return total_loss.item(), epoch_stats['accuracy_macro'].item(), epoch_stats['f1_score_macro'].item(), epoch_stats['precision_macro'].item(), epoch_stats['recall_macro'].item()

In [None]:
def evaluate(model: SecondaryStructurePredictor,
            loader: DataLoader,
            criterion: nn.Module,
            dataset_name: str,
            device: str | torch.device | int,
            epoch: int = 0,
            writer: SummaryWriter = None,
            distributed: bool = False) -> tuple[float, float, float, float, float]:
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    The loss is averaged over labelled nodes (summed over ranks first when
    ``distributed``). Loss and metrics are written to ``writer`` under
    ``dataset_name`` when a writer is given.

    Returns:
        ``(loss, f1, recall, precision, accuracy)``; the metric values are the
        macro-averaged ones. Note that this order differs from ``train``.
    """
    model.eval()

    total_loss = torch.zeros(1, device=device, requires_grad=False)
    total_count = torch.zeros(1, device=device, requires_grad=False)

    num_class = len(SecondaryStructureDataset.dssp_types)

    metricDict = {}

    # One metric of each kind per averaging mode; None yields per-class values.
    for avg in [None, 'macro', 'micro', 'weighted']:
        metricDict[f'accuracy_{avg}'] = Accuracy(num_classes=num_class, task='multiclass', average=avg)
        metricDict[f'f1_score_{avg}'] = F1Score(num_classes=num_class, task='multiclass', average=avg)
        metricDict[f'precision_{avg}'] = Precision(num_classes=num_class, task='multiclass', average=avg)
        metricDict[f'recall_{avg}'] = Recall(num_classes=num_class, task='multiclass', average=avg)

    metrics = MetricCollection(metricDict).to(device)

    with torch.no_grad():
        for (x, y, lengths, mask) in loader:
            x = x.to(device)
            y = y.to(device)
            lengths = lengths.to(device)
            mask = mask.to(device)

            out = model(x, lengths, mask)
            loss = criterion(out, y)

            # Weight the batch loss by its number of labelled nodes.
            node_count = y.numel()
            total_loss += loss.item() * node_count
            total_count += node_count

            metrics.update(out.argmax(dim=1), y)

    if distributed:
        torch.distributed.all_reduce(total_loss, op=torch.distributed.ReduceOp.SUM)
        torch.distributed.all_reduce(total_count, op=torch.distributed.ReduceOp.SUM)

    total_loss /= total_count.item()

    # Compute metrics
    stats = dict(metrics.compute())

    # Average the metrics over all ranks; reduce copies so the tensors
    # returned by compute() are not modified in place.
    if distributed:
        world_size = torch.distributed.get_world_size()
        for key in list(stats):
            reduced = stats[key].clone()
            torch.distributed.all_reduce(reduced, op=torch.distributed.ReduceOp.SUM)
            stats[key] = reduced / world_size

    # Write to tensorboard
    if writer is not None:
        writer.add_scalar(f"Loss/{dataset_name}/epoch", total_loss.item(), epoch)
        for key, value in stats.items():
            if value.dim() > 0 and value.size(0) > 1:
                # Per-class vector: one scalar per secondary-structure class
                for idx, cls in enumerate(SecondaryStructureDataset.dssp_types):
                    writer.add_scalar(f"{key}/{dataset_name}/epoch/{cls}", value[idx].item(), epoch)
            else:
                writer.add_scalar(f"{key}/{dataset_name}/epoch", value.item(), epoch)

    return total_loss.item(), stats['f1_score_macro'].item(), stats['recall_macro'].item(), stats['precision_macro'].item(), stats['accuracy_macro'].item()

In [None]:
import pathlib
from datetime import datetime

In [None]:
def ddp_setup(rank, world_size):
    """Join the default process group as ``rank`` out of ``world_size`` processes.

    The rendezvous address is fixed to ``localhost:12355`` and CUDA device
    ``rank`` becomes the current device of the calling process.
    """
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    torch.cuda.set_device(rank)
    # NCCL is the backend PyTorch recommends for CUDA tensors; gloo is kept as
    # a fallback for builds that ship without NCCL.
    backend = "nccl" if torch.distributed.is_nccl_available() else "gloo"
    init_process_group(backend, rank=rank, world_size=world_size)

In [None]:
def run(
        save_every: int,
        total_epochs: int,
        train_dataloader: DataLoader,
        val_dataloader: DataLoader,
        test_dataloader: DataLoader,
        device: str | torch.device | int,
        max_seq_length: int,
        save_path: str = "",
        distributed: bool = False):
    """Train, validate, checkpoint and finally test a ``SecondaryStructurePredictor``.

    Args:
        save_every: batch interval for the in-epoch logging done by ``train``.
        total_epochs: number of training epochs.
        train_dataloader / val_dataloader / test_dataloader: data splits; the
            training dataset must provide ``class_weights()``.
        device: device the model runs on (the rank when ``distributed``).
        max_seq_length: size of the model's positional-encoding table.
        save_path: root directory for ``runs/`` (TensorBoard) and ``models/``.
        distributed: when True the default process group must already be
            initialised; the model is wrapped in DDP and the group is
            destroyed at the end.
    """
    model = SecondaryStructurePredictor(
        embed_dim=32,
        embed_kernel_size=7,
        embed_num_layers=5,
        embed_max_length=int(max_seq_length),
        gcn_out_channels=len(SecondaryStructureDataset.dssp_types),
        gcn_hidden_channels=128,
        gcn_num_layers=10,
        edge_local_dist=0,
        edge_top_k=32
    ).to(device)
    model_name = f"SecondCount-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"

    if distributed:
        model = DDP(model, device_ids=[device])

    # Only one process logs to TensorBoard, prints results and writes the best model.
    is_main = not distributed or torch.distributed.get_rank() == 0

    writer = SummaryWriter(os.path.join(save_path, f"runs/{model_name}")) if is_main else None

    try:
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        criterion = nn.CrossEntropyLoss(weight=train_dataloader.dataset.class_weights().to(device))

        steps_per_epoch = len(train_dataloader)

        # Both schedulers count optimizer steps (train() steps them once per
        # batch). The max(1, ...) guards keep the periods positive for very
        # short loaders.
        warmup = torch.optim.lr_scheduler.LinearLR(
            optimizer=optimizer,
            start_factor=0.1,
            total_iters=max(1, steps_per_epoch // 5)
        )

        cosine = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer=optimizer,
            T_0=max(1, steps_per_epoch * 25)
        )

        scheduler = torch.optim.lr_scheduler.ChainedScheduler(
            schedulers=[warmup, cosine]
        )

        best_val_loss = float('inf')

        model_dir = os.path.join(save_path, f"models/{model_name}")
        pathlib.Path(model_dir).mkdir(parents=True, exist_ok=True)

        # Per-rank suffix keeps the checkpoints of different processes apart.
        checkpoint_suffix = f"-{device}" if distributed else ""

        for epoch in range(total_epochs):
            print(f"Epoch {epoch + 1}/{total_epochs}")

            # A DistributedSampler reuses the same ordering every epoch unless
            # it is told which epoch is running.
            train_sampler = getattr(train_dataloader, "sampler", None)
            if distributed and hasattr(train_sampler, "set_epoch"):
                train_sampler.set_epoch(epoch)

            # The scheduler is advanced inside train(), once per batch; it is
            # deliberately not stepped again here.
            train_loss, train_accuracy, train_f1, train_precision, train_recall = train(
                epoch, model, train_dataloader, optimizer, scheduler, criterion, device,
                batch_print_freq=save_every, writer=writer, distributed=distributed
            )
            if is_main:
                print(f"Train Loss: {train_loss:.4f}, Accuracy: {train_accuracy:.4f}, F1: {train_f1:.4f}, Precision: {train_precision:.4f}, Recall: {train_recall:.4f}")

            val_loss, val_f1, val_recall, val_precision, val_accuracy = evaluate(
                model, val_dataloader, criterion, "val", device, epoch, writer, distributed,
            )
            if is_main:
                print(f"Validation Loss: {val_loss:.4f}, F1: {val_f1:.4f}, Recall: {val_recall:.4f}, Precision: {val_precision:.4f}, Accuracy: {val_accuracy:.4f}")

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                # The validation loss is reduced over all ranks, so every rank
                # reaches this branch together; a single writer avoids
                # concurrent writes to the same file.
                if is_main:
                    torch.save(model.state_dict(), os.path.join(model_dir, "best_model.pth"))
                    print(f"New best model saved with validation loss: {best_val_loss:.4f}")

            save_state = {
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "epoch": epoch,
                "train_loss": train_loss,
                "val_loss": val_loss,
                "best_val_loss": best_val_loss,
            }

            torch.save(save_state, os.path.join(model_dir, f"checkpoint-{epoch}{checkpoint_suffix}.pth"))

        if is_main:
            print("Training complete. Evaluating on test set...")
        test_loss, test_f1, test_recall, test_precision, test_accuracy = evaluate(
            model, test_dataloader, criterion, "test", device, writer=writer, distributed=distributed
        )

        if is_main:
            print(f"Test Loss: {test_loss:.4f}, F1: {test_f1:.4f}, Recall: {test_recall:.4f}, Precision: {test_precision:.4f}, Accuracy: {test_accuracy:.4f}")
    finally:
        # Flush and close the TensorBoard writer even if training raised.
        if writer is not None:
            writer.close()

    if distributed:
        destroy_process_group()

In [None]:
def dist_run(rank: int,
             world_size: int,
             save_every: int,
             total_epochs: int,
             batch_size: int,
             train_dataset: SecondaryStructureDataset,
             val_dataset: SecondaryStructureDataset,
             test_dataset: SecondaryStructureDataset,
             max_seq_length: int):
    """Per-process entry point for distributed training.

    Initialises the process group for ``rank``, builds rank-sharded data
    loaders for the three splits and hands over to ``run`` with
    ``device=rank``.
    """
    ddp_setup(rank, world_size)

    torch.manual_seed(432250798121)

    train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
    # Evaluation does not need shuffling.
    val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=rank, shuffle=False)
    test_sampler = DistributedSampler(test_dataset, num_replicas=world_size, rank=rank, shuffle=False)

    # The dataset's collate_fn is required: sequences have different lengths
    # and the training/evaluation loops unpack (x, y, lengths, mask) batches.
    collate = SecondaryStructureDataset.collate_fn
    train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler, collate_fn=collate)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, sampler=val_sampler, collate_fn=collate)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, sampler=test_sampler, collate_fn=collate)

    run(
        save_every=save_every,
        total_epochs=total_epochs,
        train_dataloader=train_loader,
        val_dataloader=val_loader,
        test_dataloader=test_loader,
        device=rank,
        max_seq_length=max_seq_length,
        distributed=True,
        save_path="drive/MyDrive/SecondCount"
    )

In [None]:
if __name__ == "__main__":
    train_df = pd.read_csv("data/pdb/train_data.csv")
    val_df = pd.read_csv("data/pdb/val_data.csv")
    test_df = pd.read_csv("data/pdb/test_data.csv")

    # Calculate global max_seq_length (as a plain int rather than a numpy scalar)
    all_sequences = pd.concat([train_df["sequence"], val_df["sequence"], test_df["sequence"]])
    global_max_seq_length = int(all_sequences.str.len().max())
    print(f"Global maximum sequence length: {global_max_seq_length}")

    train_dataset = SecondaryStructureDataset(train_df)
    val_dataset = SecondaryStructureDataset(val_df)
    test_dataset = SecondaryStructureDataset(test_df)

    batch_size = 8

    if distributed:
        # Each spawned process builds its own samplers and loaders in dist_run.
        # A DistributedSampler cannot be created here because no process group
        # exists in the parent process.
        world_size = torch.cuda.device_count()
        mp.spawn(dist_run, args=(world_size, 32, 10, batch_size, train_dataset, val_dataset, test_dataset, global_max_seq_length), nprocs=world_size, join=True)
    else:
        # Only the training split is shuffled; evaluation order is kept fixed.
        training_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=SecondaryStructureDataset.collate_fn)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=SecondaryStructureDataset.collate_fn)
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=SecondaryStructureDataset.collate_fn)

        run(
            save_every=32,
            total_epochs=10,
            train_dataloader=training_loader,
            val_dataloader=val_loader,
            test_dataloader=test_loader,
            device=device,
            max_seq_length=global_max_seq_length, # Pass global_max_seq_length
            distributed=False
        )