In [None]:
from google.colab import drive
drive.mount("/content/drive")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
!sudo apt install build-essential

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
build-essential is already the newest version (12.9ubuntu3).
0 upgraded, 0 newly installed, 0 to remove and 34 not upgraded.


In [None]:
%pip install -r "/content/drive/MyDrive/Protein-binding/requirements.txt"
%pip install datasets mdtraj dssp
%pip install numpy --no-cache-dir
%pip install pandas==2.2.0
%pip install pykan

[31mERROR: Could not find a version that satisfies the requirement dssp (from versions: none)[0m[31m
[0m[31mERROR: No matching distribution found for dssp[0m[31m


In [None]:
import torch
import ast
import torch.nn as nn
import pandas as pd
import numpy as np
import torch.nn.functional as F
import warnings
import mdtraj as md
import esm
import gc
import pickle
import torch.optim as optim

from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, GATConv, global_mean_pool, global_add_pool
from Bio import SeqIO
from Bio import PDB
from Bio.PDB.DSSP import DSSP
from Bio.PDB import Selection

from Bio.PDB.Polypeptide import is_aa
from Bio.SeqUtils import seq1
from torch_geometric.data import Data
from transformers import (AutoModelForTokenClassification, AutoTokenizer,
                          AutoModelForMaskedLM, DataCollatorForTokenClassification,
                           EsmForMaskedLM, EsmTokenizer, EsmModel, EsmForTokenClassification,
                           TrainingArguments, Trainer, TrainerCallback
                        )
from kan import KAN
from transformers.trainer_callback import ProgressCallback
from sklearn.metrics import (accuracy_score, precision_recall_fscore_support,
                             matthews_corrcoef, roc_auc_score)
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from pprint import pprint
from datasets import Dataset
from datetime import datetime
from peft import get_peft_config, PeftModel, PeftConfig, get_peft_model, LoraConfig, TaskType
from glob import glob
from loguru import logger

warnings.filterwarnings("ignore", message="Ignoring unrecognized record 'END'")
device = "cuda:0" if torch.cuda.is_available() else "cpu"

### Preparing train-test dataset

In [None]:
initial_train_df = pd.read_csv("/content/drive/MyDrive/Protein-binding/data/development_set/full_grouped_train_binding_sites_df.csv")
initial_train_df['binding_sites'] = initial_train_df['binding_sites'].apply(ast.literal_eval)
initial_train_df['any_ligand_binding_sites'] = initial_train_df['any_ligand_binding_sites'].apply(ast.literal_eval)
initial_train_df['metal_binding'] = initial_train_df['metal_binding'].apply(ast.literal_eval)
initial_train_df['small_binding'] = initial_train_df['small_binding'].apply(ast.literal_eval)
initial_train_df['nuclear_binding'] = initial_train_df['nuclear_binding'].apply(ast.literal_eval)

In [None]:
np.random.seed(42)
excluded_protein_id = ['Q9NZV6']

train_df = initial_train_df[~initial_train_df['prot_id'].isin(excluded_protein_id)]

In [None]:
# test_df = pd.read_csv("/content/drive/MyDrive/Protein-binding/data/development_set/full_grouped_test_binding_sites_df.csv")
test_df = pd.read_csv("/content/drive/MyDrive/Protein-binding/data/independent_set/grouped_test_46_new_binding_sites.csv")

test_df['binding_sites'] = test_df['binding_sites'].apply(ast.literal_eval)
test_df['any_ligand_binding_sites'] = test_df['any_ligand_binding_sites'].apply(ast.literal_eval)
test_df['metal_binding'] = test_df['metal_binding'].apply(ast.literal_eval)
test_df['small_binding'] = test_df['small_binding'].apply(ast.literal_eval)
test_df['nuclear_binding'] = test_df['nuclear_binding'].apply(ast.literal_eval)

In [None]:
# Initial sequences
test_seq = test_df['sequence'].tolist()
test_labels = test_df['small_binding'].tolist() # Edit labels

train_seq = train_df['sequence'].tolist()
train_labels = train_df['small_binding'].tolist() # Edit labels

saved_model_name = '/content/drive/MyDrive/Protein-binding/trained_models/GCN_KAN_for_small_binding_Test46_23Apr_model.pth'

### Tokenization and get embeddings from ESM-2 language model

In [None]:
pretrained_model = "facebook/esm2_t33_650M_UR50D"

tokenizer = EsmTokenizer.from_pretrained(pretrained_model)
max_sequence_length = 1000

train_tokenized = tokenizer(train_seq, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)
test_tokenized = tokenizer(test_seq, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [None]:
def get_embeddings_list(tokenized_dataset, batch_size, model_name="facebook/esm2_t33_650M_UR50D", device="cuda",
                       embedding_mode="multi_layer", num_layers=4, return_hidden_states=True, return_attentions=False):
    """
    Extract ESM-2 embeddings with multi-layer aggregation or attention-guided pooling.

    Args:
        tokenized_dataset: Dict with 'input_ids', 'attention_mask', and optionally 'sequence_id'.
        batch_size: Number of sequences per batch.
        model_name: Pretrained ESM-2 model name.
        device: Device to run model on ('cuda' or 'cpu').
        embedding_mode: 'multi_layer' for layer aggregation, 'attention_guided' for attention pooling.
        num_layers: Number of layers to aggregate (for multi_layer mode).
        return_hidden_states: Whether to return hidden states.
        return_attentions: Whether to return attention weights (required for attention_guided).

    Returns:
        List of dicts with sequence_id and per-residue embeddings.
    """
    # Validate inputs
    if embedding_mode == "attention_guided" and not return_attentions:
        raise ValueError("Attention-guided pooling requires return_attentions=True")

    # Initialize model
    model = EsmModel.from_pretrained(model_name).to(device)
    model.eval()

    # Extract input tensors
    ids_list = tokenized_dataset['input_ids'].to(device)  # Shape: [num_sequences, max_length]
    attention_mask_list = tokenized_dataset['attention_mask'].to(device)
    sequence_ids = tokenized_dataset.get('sequence_id', list(range(len(ids_list))))

    num_batches = (len(ids_list) + batch_size - 1) // batch_size
    embeddings_list = []

    for i in tqdm(range(num_batches), total=num_batches):
        start_idx = i * batch_size
        end_idx = min((i + 1) * batch_size, len(ids_list))

        batch_ids = ids_list[start_idx:end_idx]  # Shape: [batch_size, max_length]
        batch_attention_mask = attention_mask_list[start_idx:end_idx]
        batch_seq_ids = sequence_ids[start_idx:end_idx]

        with torch.no_grad():
            outputs = model(
                input_ids=batch_ids,
                attention_mask=batch_attention_mask,
                output_hidden_states=return_hidden_states,
                output_attentions=(embedding_mode == "attention_guided")
            )

            if embedding_mode == "multi_layer":
                # Aggregate last N layers
                hidden_states = outputs.hidden_states[-num_layers:]  # List of (batch_size, seq_len, embedding_dim)
                aggregated_embeddings = torch.stack(hidden_states, dim=0).mean(dim=0)  # Average across layers
                # Alternative: Concatenate layers (uncomment to use)
                # aggregated_embeddings = torch.cat(hidden_states, dim=-1)  # Shape: (batch_size, seq_len, num_layers * embedding_dim)

                # Process each sequence
                for j in range(aggregated_embeddings.shape[0]):
                    mask = batch_attention_mask[j].bool()
                    seq_embeddings = aggregated_embeddings[j][mask][1:-1]  # Remove <cls>, <eos>
                    embeddings_list.append({
                        'sequence_id': batch_seq_ids[j],
                        'embeddings': seq_embeddings.cpu().numpy()
                    })

            elif embedding_mode == "attention_guided":
                # Attention-guided pooling
                hidden_states = outputs.hidden_states[-1]  # Last layer: (batch_size, seq_len, embedding_dim)
                attentions = outputs.attentions[-1]  # Last layer attention: (batch_size, num_heads, seq_len, seq_len)

                # Process each sequence individually to handle variable lengths
                for j in range(hidden_states.shape[0]):
                    mask = batch_attention_mask[j].bool()
                    seq_hidden = hidden_states[j][mask]  # (valid_len, embedding_dim)
                    seq_attention = attentions[j][:, mask][:, :, mask]  # (num_heads, valid_len, valid_len)
                    # Average attention across heads
                    attention_weights = seq_attention.mean(dim=0)  # (valid_len, valid_len)
                    attention_weights = attention_weights / (attention_weights.sum(dim=-1, keepdim=True) + 1e-8)
                    # Weighted embedding
                    weighted_embedding = torch.matmul(attention_weights, seq_hidden)  # (valid_len, embedding_dim)
                    seq_embeddings = weighted_embedding[1:-1]  # Remove <cls>, <eos>
                    embeddings_list.append({
                        'sequence_id': batch_seq_ids[j],
                        'embeddings': seq_embeddings.cpu().numpy()
                    })

            else:
                raise ValueError("embedding_mode must be 'multi_layer' or 'attention_guided'")

    return embeddings_list

In [None]:
gc.collect()
torch.cuda.empty_cache()

In [None]:
# train_embeddings = get_embeddings_list(train_tokenized, batch_size = 8, model_name = pretrained_model, device= "cuda",
#                                        embedding_mode = "multi_layer", num_layers=2, return_hidden_states = True, return_attentions = False)

# test_embeddings = get_embeddings_list(test_tokenized, batch_size = 8, model_name = pretrained_model, device = "cuda",
#                                        embedding_mode = "multi_layer", num_layers=2, return_hidden_states = True, return_attentions = False)

In [None]:
# with open('/content/drive/MyDrive/Protein-binding/train_multi_layer_embeddings.pkl', 'wb') as f:
#        pickle.dump(train_embeddings, f)

# with open('/content/drive/MyDrive/Protein-binding/test_multi_layer_embeddings.pkl', 'wb') as f:
#     pickle.dump(test_embeddings, f)

In [None]:
with open('/content/drive/MyDrive/Protein-binding/train_multi_layer_embeddings.pkl', 'rb') as f:
       train_embeddings = pickle.load(f)

with open('/content/drive/MyDrive/Protein-binding/test_multi_layer_embeddings.pkl', 'rb') as f:
    test_embeddings = pickle.load(f)

### Features extraction

In [None]:
def get_structure(prot_id, pdb_file):
    parser = PDB.PDBParser()
    structure = parser.get_structure(prot_id, pdb_file)
    return structure

def extract_coordinates(structure):
    # Extract Cα coordinates (central carbon atom)
    coordinates = []
    for model in structure:
        for chain in model:
            for residue in chain:
                if "CA" in residue:  # Get Cα atom
                    ca_atom = residue["CA"]
                    coord = ca_atom.get_coord()  # Returns numpy array [x, y, z]
                    coordinates.append(coord)

    return coordinates

def get_amino_acid_types(structure):
    amino_acids = []
    for model in structure:
        for chain in model:
            for residue in chain:
                # Get residue name (3-letter code)
                if not is_aa(residue):
                    continue
                res_name = residue.get_resname()
                # Convert to 1-letter code if needed
                one_letter = seq1(res_name)
                amino_acids.append(one_letter)
    return amino_acids

def get_secondary_structure_mdtraj(pdb_file, sequence_length):
    """
    Extract secondary structure features from a .PDB file using mdtraj.

    Args:
        pdb_file (str): Path to the .PDB file.
        sequence_length (int): Length of the protein sequence.

    Returns:
        dict: A dictionary with:
            - 'raw': List of raw secondary structure codes.
            - 'one_hot': Tensor of one-hot encoded secondary structure.
    """
    # Load the .PDB file with mdtraj
    traj = md.load(pdb_file)

    # Compute secondary structure
    ss = md.compute_dssp(traj)[0]  # Returns codes like 'H', 'E', 'C', '-'('NA')

    # One-hot encode
    ss_onehot = []
    for code in ss:
        ss_onehot.append([
            1 if code == 'H' else 0,
            1 if code == 'E' else 0,
            1 if code == 'T' else 0,
            1 if code == 'C' or code == 'NA' else 0
        ])

    ss_onehot = torch.tensor(ss_onehot, dtype=torch.float32)

    # Pad with zeros to match sequence length
    if ss_onehot.shape[0] < sequence_length:
        padding = torch.zeros((sequence_length - ss_onehot.shape[0], ss_onehot.shape[1]), dtype=torch.float32)
        ss_onehot = torch.cat([ss_onehot, padding], dim=0)

    return {
        "raw": ss.tolist(),
        "one_hot": ss_onehot
    }

def calculate_residue_distances(coordinates):
    """
    Calculate pairwise distances between residues in a protein structure.
    Args: Coordinates (list): List of residue atom's coordinates.
    Returns:
        np.ndarray: 2D array of pairwise distances.
    """
    num_residues = len(coordinates)
    distances = np.zeros((num_residues, num_residues))

    for i in range(num_residues):
        for j in range(i + 1, num_residues):
            dist = np.linalg.norm(coordinates[i] - coordinates[j])
            distances[i, j] = distances[j, i] = dist

    return distances

def get_dihedral_angles(pdb_file, sequence_length):
    traj = md.load(pdb_file)
    # Compute phi and psi angles
    phi_indices, phi_angles = md.compute_phi(traj)
    psi_indices, psi_angles = md.compute_psi(traj)

    # Convert to degrees and create tensors
    phi_angles = torch.tensor(np.degrees(phi_angles[0]), dtype=torch.float32).unsqueeze(1)  # Shape: [num_residues-1, 1]
    psi_angles = torch.tensor(np.degrees(psi_angles[0]), dtype=torch.float32).unsqueeze(1)  # Shape: [num_residues-1, 1]

    # Pad with zeros to match sequence length
    while phi_angles.shape[0] < sequence_length:
        phi_angles = torch.cat([phi_angles, torch.zeros(1, 1, dtype=torch.float32)], dim=0)
    while psi_angles.shape[0] < sequence_length:
        psi_angles = torch.cat([psi_angles, torch.zeros(1, 1, dtype=torch.float32)], dim=0)

    return phi_angles, psi_angles


def get_b_factors(structure, sequence_length):
    b_factors = []
    for model in structure:
        for chain in model:
            for residue in chain:
                if "CA" in residue:
                    ca_atom = residue["CA"]
                    b_factor = ca_atom.get_bfactor()
                    b_factors.append(b_factor)

    # Pad with zeros if b_factors length is less than sequence length
    while len(b_factors) < sequence_length:
        b_factors.append(0.0)

    return torch.tensor(b_factors, dtype=torch.float32).unsqueeze(1)



In [None]:
# Compute RSA values for amino acids
MAX_SASA = {
    'ALA': 129.0, 'ARG': 274.0, 'ASN': 195.0, 'ASP': 193.0, 'CYS': 167.0,
    'GLU': 223.0, 'GLN': 225.0, 'GLY': 104.0, 'HIS': 224.0, 'ILE': 197.0,
    'LEU': 201.0, 'LYS': 236.0, 'MET': 224.0, 'PHE': 240.0, 'PRO': 159.0,
    'SER': 155.0, 'THR': 172.0, 'TRP': 285.0, 'TYR': 263.0, 'VAL': 174.0
}

def compute_rsa(structure_file, chain_id=None):
    """
    Compute RSA for each residue in a protein structure.

    Args:
        structure_file (str): Path to the protein structure file (e.g., PDB file).
        chain_id (str, optional): Chain ID to analyze (if None, uses the first chain).

    Returns:
        list: List of (residue_name, residue_id, rsa) tuples.
    """
    traj = md.load(structure_file)

    if chain_id:
        chain = next(c for c in traj.topology.chains if c.chain_id == chain_id)
        traj = traj.atom_slice([atom.index for atom in traj.topology.atoms if atom.residue.chain == chain])

    sasa = md.shrake_rupley(traj, mode='residue')[0]  # Shape: [n_residues]
    rsa_values = []
    for i, residue in enumerate(traj.topology.residues):
        if not residue.is_protein:
            continue
        res_name = residue.name
        res_id = residue.resSeq
        residue_sasa = sasa[i]
        max_sasa = MAX_SASA.get(res_name, 0.0)
        if max_sasa == 0.0:
            print(f"Warning: No max SASA value for residue {res_name}. Skipping.")
            continue
        rsa = residue_sasa / max_sasa if max_sasa > 0 else 0.0
        rsa = min(rsa, 1.0)
        rsa_values.append((res_name, res_id, rsa))
    # return rsa_values
    return torch.tensor([rsa for _, _, rsa in rsa_values], dtype=torch.float32).unsqueeze(1)


def calculate_depth(structure, sequence_length):
    depth_per_residue = {}
    atoms = Selection.unfold_entities(structure, "A")  # Get all atoms
    for residue in Selection.unfold_entities(structure, "R"):
        res_id = (residue.get_parent().id, residue.id[1])
        min_dist = float("inf")
        for atom in residue:
            for other_atom in atoms:
                if other_atom.get_parent() != residue:
                    dist = np.linalg.norm(atom.coord - other_atom.coord)
                    min_dist = min(min_dist, dist)
        depth_per_residue[res_id] = min_dist

    # Assign default value for missing residues
    depth_values = [0.0] * sequence_length  # Or use average depth if available
    for i, residue in enumerate(structure.get_residues()):
        res_id = (residue.get_parent().id, residue.id[1])
        if res_id in depth_per_residue:
            depth_values[i] = depth_per_residue[res_id]

    return torch.tensor(depth_values, dtype=torch.float32).unsqueeze(1)


In [None]:
def fuse_features(esm2_embeddings, ss_onehot, phi_angles,
                  psi_angles, b_factors):
    """
    Fuse ESM-2 embeddings with structural features.

    Args:
        esm2_embeddings (torch.Tensor): Shape [num_residues, 1280]
        ss_onehot (torch.Tensor): Shape [num_residues, 4]
        phi_angles (torch.Tensor): Shape [num_residues, 1]
        psi_angles (torch.Tensor): Shape [num_residues, 1]
        b_factors (torch.Tensor): Shape [num_residues, 1]

    Returns:
        torch.Tensor: Fused node features, shape [num_residues, 1287] (1280 + 4 + 1 + 1 + 1)
    """
    # Ensure all features have the same length
    num_residues = esm2_embeddings.shape[0]

    # Print shapes for debugging
    print(f"esm2_embeddings shape: {esm2_embeddings.shape}")
    print(f"ss_onehot shape: {ss_onehot.shape}")
    print(f"phi_angles shape: {phi_angles.shape}")
    print(f"psi_angles shape: {psi_angles.shape}")
    print(f"b_factors shape: {b_factors.shape}")
    # print(f"residue_depths shape: {residue_depths.shape}")

    # Adjust ss_onehot, phi_angles, psi_angles, b_factors if needed
    min_length = min(num_residues, ss_onehot.shape[0], phi_angles.shape[0],
                     psi_angles.shape[0], b_factors.shape[0])

    esm2_embeddings = esm2_embeddings[:min_length]
    ss_onehot = ss_onehot[:min_length]
    phi_angles = phi_angles[:min_length]
    psi_angles = psi_angles[:min_length]
    b_factors = b_factors[:min_length]

    # Concatenate all features
    node_features = torch.cat([
        esm2_embeddings,  # [num_residues, 1280]
        ss_onehot,        # [num_residues, 4]
        phi_angles,       # [num_residues, 1]
        psi_angles,       # [num_residues, 1]
        b_factors        # [num_residues, 1]
    ], dim=1)  # Shape: [num_residues, 1287]

    return node_features

In [None]:
def create_edge_features(distances, threshold=8.0):
    num_residues = distances.shape[0]
    contact_map = (distances < threshold) & (distances > 0)

    # Restrict indices to be within the valid node range (0 to num_residues - 1)
    edge_index = torch.nonzero(torch.tensor(contact_map, dtype=torch.bool), as_tuple=False).t()
    edge_index = edge_index[:, :num_residues]  # Ensure indices are within bounds

    # Vectorized edge attribute computation
    src, dst = edge_index[0], edge_index[1]
    dists = torch.tensor(distances[src, dst], dtype=torch.float32)
    seq_seps = torch.abs(src - dst).float()

    # Normalize features
    dists = dists / threshold  # Scale to 0–1
    seq_seps = seq_seps / num_residues  # Scale to 0–1

    edge_attr = torch.stack([dists, seq_seps], dim=1)  # Shape: [num_edges, 2]

    return edge_index, edge_attr


In [None]:
def create_graph_data(node_features, edge_index, edge_attr, labels):
    """
    Create a PyTorch Geometric Data object for the protein graph.

    Args:
        node_features (torch.Tensor): Shape [num_residues, 1288], fused node features.
        edge_index (torch.Tensor): Shape [2, num_edges], indices of connected nodes.
        edge_attr (torch.Tensor): Shape [num_edges, 2], edge features.
        labels (torch.Tensor): Shape [num_residues], binary labels (0 or 1) for binding sites.

    Returns:
        Data: PyTorch Geometric Data object.
    """
    return Data(
        x=node_features,      # Node features
        edge_index=edge_index,  # Edge indices
        edge_attr=edge_attr,  # Edge features
        y=labels              # Labels for binding site prediction
    )

### Test feature extraction

In [None]:
sample_idx = 540
sample_prot_id = train_df.iloc[sample_idx]['prot_id']
sample_sequence = train_df.iloc[sample_idx]['sequence']
sample_sequence_len = len(train_df.iloc[sample_idx]['sequence'])
sample_labels = train_df.iloc[sample_idx]['any_ligand_binding_sites']
sample_train_embeddings = train_embeddings[sample_idx]

sample_structure_file = f"/content/drive/MyDrive/Protein-binding/esmFold_pdb_files/{sample_prot_id}.pdb"

sample_structure = get_structure(sample_prot_id, sample_structure_file)
sample_coordinates = extract_coordinates(sample_structure)
sample_distances = calculate_residue_distances(sample_coordinates)
sample_ss_one_hot = get_secondary_structure_mdtraj(sample_structure_file, sample_sequence_len)['one_hot'].to(device)
sample_phi_angles, sample_psi_angles = get_dihedral_angles(sample_structure_file, sample_sequence_len)
sample_b_factors = get_b_factors(sample_structure, sample_sequence_len).to(device)
sample_rsa_values = compute_rsa(sample_structure_file)
min_length = 1

### Prepare data tensors

In [None]:
def get_graph_data(df, embeddings, device):
    graph_data_list = []

    for idx in tqdm(range(len(df))):
        prot_id = df.iloc[idx]['prot_id']
        sequence = df.iloc[idx]['sequence']
        sequence_len = len(df.iloc[idx]['sequence'])
        labels = df.iloc[idx]['any_ligand_binding_sites']
        embedding = torch.tensor(embeddings[idx]['embeddings'])

        try:
            structure_file = f"/content/drive/MyDrive/Protein-binding/esmFold_pdb_files/{prot_id}.pdb"
            structure = get_structure(prot_id, structure_file)
            coordinates = extract_coordinates(structure)
            distances = calculate_residue_distances(coordinates)
            # depths = calculate_depth(structure, sequence_len)
            # rsa_values = compute_rsa(structure_file)
            ss_one_hot = get_secondary_structure_mdtraj(structure_file, sequence_len)['one_hot'].to(device)
            phi_angles, psi_angles = get_dihedral_angles(structure_file, sequence_len)
            b_factors = get_b_factors(structure, sequence_len).to(device)
        except Exception as e:
            # print(f"PDB file of {prot_id} ID cannot be found")
            print(f"Error: {e}")

        edge_index, edge_attr = create_edge_features(distances)
        node_features = fuse_features(embedding.to(device), ss_one_hot,
                                      phi_angles.to(device), psi_angles.to(device),
                                      b_factors)
        labels = torch.tensor(labels, dtype = torch.long).to(device)
        print(f"Shape of labels: {labels.shape}")
        graph_data = create_graph_data(node_features, edge_index, edge_attr, labels)

        graph_data_list.append(graph_data)

    return graph_data_list

In [None]:
train_graphs_data = get_graph_data(train_df, train_embeddings, device)
test_graphs_data = get_graph_data(test_df, test_embeddings, device)

  0%|          | 2/1013 [00:00<02:50,  5.94it/s]

esm2_embeddings shape: torch.Size([234, 1280])
ss_onehot shape: torch.Size([234, 4])
phi_angles shape: torch.Size([234, 1])
psi_angles shape: torch.Size([234, 1])
b_factors shape: torch.Size([234, 1])
Shape of labels: torch.Size([234])
esm2_embeddings shape: torch.Size([138, 1280])
ss_onehot shape: torch.Size([138, 4])
phi_angles shape: torch.Size([138, 1])
psi_angles shape: torch.Size([138, 1])
b_factors shape: torch.Size([138, 1])
Shape of labels: torch.Size([138])


  0%|          | 4/1013 [00:00<02:44,  6.15it/s]

esm2_embeddings shape: torch.Size([231, 1280])
ss_onehot shape: torch.Size([231, 4])
phi_angles shape: torch.Size([231, 1])
psi_angles shape: torch.Size([231, 1])
b_factors shape: torch.Size([231, 1])
Shape of labels: torch.Size([231])
esm2_embeddings shape: torch.Size([129, 1280])
ss_onehot shape: torch.Size([129, 4])
phi_angles shape: torch.Size([129, 1])
psi_angles shape: torch.Size([129, 1])
b_factors shape: torch.Size([129, 1])
Shape of labels: torch.Size([129])


  0%|          | 5/1013 [00:00<02:41,  6.26it/s]

esm2_embeddings shape: torch.Size([179, 1280])
ss_onehot shape: torch.Size([179, 4])
phi_angles shape: torch.Size([179, 1])
psi_angles shape: torch.Size([179, 1])
b_factors shape: torch.Size([179, 1])
Shape of labels: torch.Size([179])
esm2_embeddings shape: torch.Size([103, 1280])
ss_onehot shape: torch.Size([103, 4])
phi_angles shape: torch.Size([103, 1])
psi_angles shape: torch.Size([103, 1])
b_factors shape: torch.Size([103, 1])
Shape of labels: torch.Size([103])


  1%|          | 8/1013 [00:01<02:20,  7.15it/s]

esm2_embeddings shape: torch.Size([142, 1280])
ss_onehot shape: torch.Size([142, 4])
phi_angles shape: torch.Size([142, 1])
psi_angles shape: torch.Size([142, 1])
b_factors shape: torch.Size([142, 1])
Shape of labels: torch.Size([142])
esm2_embeddings shape: torch.Size([159, 1280])
ss_onehot shape: torch.Size([159, 4])
phi_angles shape: torch.Size([159, 1])
psi_angles shape: torch.Size([159, 1])
b_factors shape: torch.Size([159, 1])
Shape of labels: torch.Size([159])


  1%|          | 11/1013 [00:01<02:08,  7.80it/s]

esm2_embeddings shape: torch.Size([222, 1280])
ss_onehot shape: torch.Size([222, 4])
phi_angles shape: torch.Size([222, 1])
psi_angles shape: torch.Size([222, 1])
b_factors shape: torch.Size([222, 1])
Shape of labels: torch.Size([222])
esm2_embeddings shape: torch.Size([136, 1280])
ss_onehot shape: torch.Size([136, 4])
phi_angles shape: torch.Size([136, 1])
psi_angles shape: torch.Size([136, 1])
b_factors shape: torch.Size([136, 1])
Shape of labels: torch.Size([136])
esm2_embeddings shape: torch.Size([128, 1280])
ss_onehot shape: torch.Size([128, 4])
phi_angles shape: torch.Size([128, 1])
psi_angles shape: torch.Size([128, 1])
b_factors shape: torch.Size([128, 1])
Shape of labels: torch.Size([128])


  1%|▏         | 13/1013 [00:02<03:40,  4.53it/s]

esm2_embeddings shape: torch.Size([182, 1280])
ss_onehot shape: torch.Size([182, 4])
phi_angles shape: torch.Size([182, 1])
psi_angles shape: torch.Size([182, 1])
b_factors shape: torch.Size([182, 1])
Shape of labels: torch.Size([182])
esm2_embeddings shape: torch.Size([155, 1280])
ss_onehot shape: torch.Size([155, 4])
phi_angles shape: torch.Size([155, 1])
psi_angles shape: torch.Size([155, 1])
b_factors shape: torch.Size([155, 1])
Shape of labels: torch.Size([155])


  1%|▏         | 15/1013 [00:02<03:15,  5.10it/s]

esm2_embeddings shape: torch.Size([202, 1280])
ss_onehot shape: torch.Size([202, 4])
phi_angles shape: torch.Size([202, 1])
psi_angles shape: torch.Size([202, 1])
b_factors shape: torch.Size([202, 1])
Shape of labels: torch.Size([202])
esm2_embeddings shape: torch.Size([198, 1280])
ss_onehot shape: torch.Size([198, 4])
phi_angles shape: torch.Size([198, 1])
psi_angles shape: torch.Size([198, 1])
b_factors shape: torch.Size([198, 1])
Shape of labels: torch.Size([198])


  2%|▏         | 17/1013 [00:03<02:48,  5.92it/s]

esm2_embeddings shape: torch.Size([227, 1280])
ss_onehot shape: torch.Size([227, 4])
phi_angles shape: torch.Size([227, 1])
psi_angles shape: torch.Size([227, 1])
b_factors shape: torch.Size([227, 1])
Shape of labels: torch.Size([227])
esm2_embeddings shape: torch.Size([140, 1280])
ss_onehot shape: torch.Size([140, 4])
phi_angles shape: torch.Size([140, 1])
psi_angles shape: torch.Size([140, 1])
b_factors shape: torch.Size([140, 1])
Shape of labels: torch.Size([140])


  2%|▏         | 18/1013 [00:03<02:40,  6.19it/s]

esm2_embeddings shape: torch.Size([181, 1280])
ss_onehot shape: torch.Size([181, 4])
phi_angles shape: torch.Size([181, 1])
psi_angles shape: torch.Size([181, 1])
b_factors shape: torch.Size([181, 1])
Shape of labels: torch.Size([181])


  2%|▏         | 20/1013 [00:03<02:59,  5.53it/s]

esm2_embeddings shape: torch.Size([278, 1280])
ss_onehot shape: torch.Size([278, 4])
phi_angles shape: torch.Size([278, 1])
psi_angles shape: torch.Size([278, 1])
b_factors shape: torch.Size([278, 1])
Shape of labels: torch.Size([278])
esm2_embeddings shape: torch.Size([157, 1280])
ss_onehot shape: torch.Size([157, 4])
phi_angles shape: torch.Size([157, 1])
psi_angles shape: torch.Size([157, 1])
b_factors shape: torch.Size([157, 1])
Shape of labels: torch.Size([157])


  2%|▏         | 22/1013 [00:03<02:24,  6.86it/s]

esm2_embeddings shape: torch.Size([119, 1280])
ss_onehot shape: torch.Size([119, 4])
phi_angles shape: torch.Size([119, 1])
psi_angles shape: torch.Size([119, 1])
b_factors shape: torch.Size([119, 1])
Shape of labels: torch.Size([119])
esm2_embeddings shape: torch.Size([159, 1280])
ss_onehot shape: torch.Size([159, 4])
phi_angles shape: torch.Size([159, 1])
psi_angles shape: torch.Size([159, 1])
b_factors shape: torch.Size([159, 1])
Shape of labels: torch.Size([159])


  2%|▏         | 24/1013 [00:04<03:31,  4.67it/s]

esm2_embeddings shape: torch.Size([148, 1280])
ss_onehot shape: torch.Size([148, 4])
phi_angles shape: torch.Size([148, 1])
psi_angles shape: torch.Size([148, 1])
b_factors shape: torch.Size([148, 1])
Shape of labels: torch.Size([148])
esm2_embeddings shape: torch.Size([166, 1280])
ss_onehot shape: torch.Size([166, 4])
phi_angles shape: torch.Size([166, 1])
psi_angles shape: torch.Size([166, 1])
b_factors shape: torch.Size([166, 1])
Shape of labels: torch.Size([166])


  3%|▎         | 26/1013 [00:04<03:14,  5.08it/s]

esm2_embeddings shape: torch.Size([265, 1280])
ss_onehot shape: torch.Size([265, 4])
phi_angles shape: torch.Size([265, 1])
psi_angles shape: torch.Size([265, 1])
b_factors shape: torch.Size([265, 1])
Shape of labels: torch.Size([265])
esm2_embeddings shape: torch.Size([170, 1280])
ss_onehot shape: torch.Size([170, 4])
phi_angles shape: torch.Size([170, 1])
psi_angles shape: torch.Size([170, 1])
b_factors shape: torch.Size([170, 1])
Shape of labels: torch.Size([170])


  3%|▎         | 28/1013 [00:04<02:17,  7.17it/s]

esm2_embeddings shape: torch.Size([99, 1280])
ss_onehot shape: torch.Size([99, 4])
phi_angles shape: torch.Size([99, 1])
psi_angles shape: torch.Size([99, 1])
b_factors shape: torch.Size([99, 1])
Shape of labels: torch.Size([99])
esm2_embeddings shape: torch.Size([86, 1280])
ss_onehot shape: torch.Size([86, 4])
phi_angles shape: torch.Size([86, 1])
psi_angles shape: torch.Size([86, 1])
b_factors shape: torch.Size([86, 1])
Shape of labels: torch.Size([86])
esm2_embeddings shape: torch.Size([91, 1280])
ss_onehot shape: torch.Size([91, 4])
phi_angles shape: torch.Size([91, 1])
psi_angles shape: torch.Size([91, 1])
b_factors shape: torch.Size([91, 1])
Shape of labels: torch.Size([91])


  3%|▎         | 32/1013 [00:05<01:56,  8.39it/s]

esm2_embeddings shape: torch.Size([248, 1280])
ss_onehot shape: torch.Size([248, 4])
phi_angles shape: torch.Size([248, 1])
psi_angles shape: torch.Size([248, 1])
b_factors shape: torch.Size([248, 1])
Shape of labels: torch.Size([248])
esm2_embeddings shape: torch.Size([94, 1280])
ss_onehot shape: torch.Size([94, 4])
phi_angles shape: torch.Size([94, 1])
psi_angles shape: torch.Size([94, 1])
b_factors shape: torch.Size([94, 1])
Shape of labels: torch.Size([94])
esm2_embeddings shape: torch.Size([89, 1280])
ss_onehot shape: torch.Size([89, 4])
phi_angles shape: torch.Size([89, 1])
psi_angles shape: torch.Size([89, 1])
b_factors shape: torch.Size([89, 1])
Shape of labels: torch.Size([89])


  3%|▎         | 34/1013 [00:05<02:29,  6.53it/s]

esm2_embeddings shape: torch.Size([267, 1280])
ss_onehot shape: torch.Size([267, 4])
phi_angles shape: torch.Size([267, 1])
psi_angles shape: torch.Size([267, 1])
b_factors shape: torch.Size([267, 1])
Shape of labels: torch.Size([267])
esm2_embeddings shape: torch.Size([178, 1280])
ss_onehot shape: torch.Size([178, 4])
phi_angles shape: torch.Size([178, 1])
psi_angles shape: torch.Size([178, 1])
b_factors shape: torch.Size([178, 1])
Shape of labels: torch.Size([178])


  3%|▎         | 35/1013 [00:05<02:25,  6.73it/s]

esm2_embeddings shape: torch.Size([153, 1280])
ss_onehot shape: torch.Size([153, 4])
phi_angles shape: torch.Size([153, 1])
psi_angles shape: torch.Size([153, 1])
b_factors shape: torch.Size([153, 1])
Shape of labels: torch.Size([153])


  4%|▎         | 37/1013 [00:06<04:14,  3.83it/s]

esm2_embeddings shape: torch.Size([270, 1280])
ss_onehot shape: torch.Size([270, 4])
phi_angles shape: torch.Size([270, 1])
psi_angles shape: torch.Size([270, 1])
b_factors shape: torch.Size([270, 1])
Shape of labels: torch.Size([270])
esm2_embeddings shape: torch.Size([188, 1280])
ss_onehot shape: torch.Size([188, 4])
phi_angles shape: torch.Size([188, 1])
psi_angles shape: torch.Size([188, 1])
b_factors shape: torch.Size([188, 1])
Shape of labels: torch.Size([188])


  4%|▍         | 39/1013 [00:07<03:14,  5.01it/s]

esm2_embeddings shape: torch.Size([81, 1280])
ss_onehot shape: torch.Size([81, 4])
phi_angles shape: torch.Size([81, 1])
psi_angles shape: torch.Size([81, 1])
b_factors shape: torch.Size([81, 1])
Shape of labels: torch.Size([81])
esm2_embeddings shape: torch.Size([174, 1280])
ss_onehot shape: torch.Size([174, 4])
phi_angles shape: torch.Size([174, 1])
psi_angles shape: torch.Size([174, 1])
b_factors shape: torch.Size([174, 1])
Shape of labels: torch.Size([174])


  4%|▍         | 41/1013 [00:07<02:24,  6.74it/s]

esm2_embeddings shape: torch.Size([83, 1280])
ss_onehot shape: torch.Size([83, 4])
phi_angles shape: torch.Size([83, 1])
psi_angles shape: torch.Size([83, 1])
b_factors shape: torch.Size([83, 1])
Shape of labels: torch.Size([83])
esm2_embeddings shape: torch.Size([89, 1280])
ss_onehot shape: torch.Size([89, 4])
phi_angles shape: torch.Size([89, 1])
psi_angles shape: torch.Size([89, 1])
b_factors shape: torch.Size([89, 1])
Shape of labels: torch.Size([89])


  4%|▍         | 42/1013 [00:07<02:29,  6.48it/s]

esm2_embeddings shape: torch.Size([200, 1280])
ss_onehot shape: torch.Size([200, 4])
phi_angles shape: torch.Size([200, 1])
psi_angles shape: torch.Size([200, 1])
b_factors shape: torch.Size([200, 1])
Shape of labels: torch.Size([200])
esm2_embeddings shape: torch.Size([68, 1280])
ss_onehot shape: torch.Size([68, 4])
phi_angles shape: torch.Size([68, 1])
psi_angles shape: torch.Size([68, 1])
b_factors shape: torch.Size([68, 1])
Shape of labels: torch.Size([68])


  4%|▍         | 44/1013 [00:07<02:15,  7.16it/s]

esm2_embeddings shape: torch.Size([210, 1280])
ss_onehot shape: torch.Size([210, 4])
phi_angles shape: torch.Size([210, 1])
psi_angles shape: torch.Size([210, 1])
b_factors shape: torch.Size([210, 1])
Shape of labels: torch.Size([210])
esm2_embeddings shape: torch.Size([84, 1280])
ss_onehot shape: torch.Size([84, 4])
phi_angles shape: torch.Size([84, 1])
psi_angles shape: torch.Size([84, 1])
b_factors shape: torch.Size([84, 1])
Shape of labels: torch.Size([84])


  5%|▍         | 46/1013 [00:07<02:06,  7.62it/s]

esm2_embeddings shape: torch.Size([194, 1280])
ss_onehot shape: torch.Size([194, 4])
phi_angles shape: torch.Size([194, 1])
psi_angles shape: torch.Size([194, 1])
b_factors shape: torch.Size([194, 1])
Shape of labels: torch.Size([194])


  5%|▍         | 48/1013 [00:08<02:26,  6.59it/s]

esm2_embeddings shape: torch.Size([294, 1280])
ss_onehot shape: torch.Size([294, 4])
phi_angles shape: torch.Size([294, 1])
psi_angles shape: torch.Size([294, 1])
b_factors shape: torch.Size([294, 1])
Shape of labels: torch.Size([294])
esm2_embeddings shape: torch.Size([152, 1280])
ss_onehot shape: torch.Size([152, 4])
phi_angles shape: torch.Size([152, 1])
psi_angles shape: torch.Size([152, 1])
b_factors shape: torch.Size([152, 1])
Shape of labels: torch.Size([152])


  5%|▍         | 49/1013 [00:08<04:01,  4.00it/s]

esm2_embeddings shape: torch.Size([142, 1280])
ss_onehot shape: torch.Size([142, 4])
phi_angles shape: torch.Size([142, 1])
psi_angles shape: torch.Size([142, 1])
b_factors shape: torch.Size([142, 1])
Shape of labels: torch.Size([142])


  5%|▌         | 51/1013 [00:09<03:22,  4.74it/s]

esm2_embeddings shape: torch.Size([246, 1280])
ss_onehot shape: torch.Size([246, 4])
phi_angles shape: torch.Size([246, 1])
psi_angles shape: torch.Size([246, 1])
b_factors shape: torch.Size([246, 1])
Shape of labels: torch.Size([246])
esm2_embeddings shape: torch.Size([165, 1280])
ss_onehot shape: torch.Size([165, 4])
phi_angles shape: torch.Size([165, 1])
psi_angles shape: torch.Size([165, 1])
b_factors shape: torch.Size([165, 1])
Shape of labels: torch.Size([165])


  5%|▌         | 53/1013 [00:09<02:50,  5.62it/s]

esm2_embeddings shape: torch.Size([206, 1280])
ss_onehot shape: torch.Size([206, 4])
phi_angles shape: torch.Size([206, 1])
psi_angles shape: torch.Size([206, 1])
b_factors shape: torch.Size([206, 1])
Shape of labels: torch.Size([206])
esm2_embeddings shape: torch.Size([170, 1280])
ss_onehot shape: torch.Size([170, 4])
phi_angles shape: torch.Size([170, 1])
psi_angles shape: torch.Size([170, 1])
b_factors shape: torch.Size([170, 1])
Shape of labels: torch.Size([170])


  5%|▌         | 55/1013 [00:09<02:25,  6.56it/s]

esm2_embeddings shape: torch.Size([197, 1280])
ss_onehot shape: torch.Size([197, 4])
phi_angles shape: torch.Size([197, 1])
psi_angles shape: torch.Size([197, 1])
b_factors shape: torch.Size([197, 1])
Shape of labels: torch.Size([197])
esm2_embeddings shape: torch.Size([151, 1280])
ss_onehot shape: torch.Size([151, 4])
phi_angles shape: torch.Size([151, 1])
psi_angles shape: torch.Size([151, 1])
b_factors shape: torch.Size([151, 1])
Shape of labels: torch.Size([151])


  6%|▌         | 56/1013 [00:09<02:18,  6.92it/s]

esm2_embeddings shape: torch.Size([148, 1280])
ss_onehot shape: torch.Size([148, 4])
phi_angles shape: torch.Size([148, 1])
psi_angles shape: torch.Size([148, 1])
b_factors shape: torch.Size([148, 1])
Shape of labels: torch.Size([148])


  6%|▌         | 59/1013 [00:10<02:03,  7.70it/s]

esm2_embeddings shape: torch.Size([258, 1280])
ss_onehot shape: torch.Size([258, 4])
phi_angles shape: torch.Size([258, 1])
psi_angles shape: torch.Size([258, 1])
b_factors shape: torch.Size([258, 1])
Shape of labels: torch.Size([258])
esm2_embeddings shape: torch.Size([92, 1280])
ss_onehot shape: torch.Size([92, 4])
phi_angles shape: torch.Size([92, 1])
psi_angles shape: torch.Size([92, 1])
b_factors shape: torch.Size([92, 1])
Shape of labels: torch.Size([92])
esm2_embeddings shape: torch.Size([152, 1280])
ss_onehot shape: torch.Size([152, 4])
phi_angles shape: torch.Size([152, 1])
psi_angles shape: torch.Size([152, 1])
b_factors shape: torch.Size([152, 1])
Shape of labels: torch.Size([152])
esm2_embeddings shape: torch.Size([141, 1280])
ss_onehot shape: torch.Size([141, 4])
phi_angles shape: torch.Size([141, 1])
psi_angles shape: torch.Size([141, 1])
b_factors shape: torch.Size([141, 1])
Shape of labels: torch.Size([141])


  6%|▌         | 61/1013 [00:10<03:22,  4.71it/s]

esm2_embeddings shape: torch.Size([207, 1280])
ss_onehot shape: torch.Size([207, 4])
phi_angles shape: torch.Size([207, 1])
psi_angles shape: torch.Size([207, 1])
b_factors shape: torch.Size([207, 1])
Shape of labels: torch.Size([207])


  6%|▌         | 63/1013 [00:11<03:18,  4.79it/s]

esm2_embeddings shape: torch.Size([304, 1280])
ss_onehot shape: torch.Size([304, 4])
phi_angles shape: torch.Size([304, 1])
psi_angles shape: torch.Size([304, 1])
b_factors shape: torch.Size([304, 1])
Shape of labels: torch.Size([304])
esm2_embeddings shape: torch.Size([179, 1280])
ss_onehot shape: torch.Size([179, 4])
phi_angles shape: torch.Size([179, 1])
psi_angles shape: torch.Size([179, 1])
b_factors shape: torch.Size([179, 1])
Shape of labels: torch.Size([179])


  7%|▋         | 66/1013 [00:11<02:25,  6.50it/s]

esm2_embeddings shape: torch.Size([203, 1280])
ss_onehot shape: torch.Size([203, 4])
phi_angles shape: torch.Size([203, 1])
psi_angles shape: torch.Size([203, 1])
b_factors shape: torch.Size([203, 1])
Shape of labels: torch.Size([203])
esm2_embeddings shape: torch.Size([102, 1280])
ss_onehot shape: torch.Size([102, 4])
phi_angles shape: torch.Size([102, 1])
psi_angles shape: torch.Size([102, 1])
b_factors shape: torch.Size([102, 1])
Shape of labels: torch.Size([102])
esm2_embeddings shape: torch.Size([162, 1280])
ss_onehot shape: torch.Size([162, 4])
phi_angles shape: torch.Size([162, 1])
psi_angles shape: torch.Size([162, 1])
b_factors shape: torch.Size([162, 1])
Shape of labels: torch.Size([162])


  7%|▋         | 68/1013 [00:11<02:10,  7.22it/s]

esm2_embeddings shape: torch.Size([153, 1280])
ss_onehot shape: torch.Size([153, 4])
phi_angles shape: torch.Size([153, 1])
psi_angles shape: torch.Size([153, 1])
b_factors shape: torch.Size([153, 1])
Shape of labels: torch.Size([153])
esm2_embeddings shape: torch.Size([157, 1280])
ss_onehot shape: torch.Size([157, 4])
phi_angles shape: torch.Size([157, 1])
psi_angles shape: torch.Size([157, 1])
b_factors shape: torch.Size([157, 1])
Shape of labels: torch.Size([157])


  7%|▋         | 69/1013 [00:12<02:24,  6.51it/s]

esm2_embeddings shape: torch.Size([227, 1280])
ss_onehot shape: torch.Size([227, 4])
phi_angles shape: torch.Size([227, 1])
psi_angles shape: torch.Size([227, 1])
b_factors shape: torch.Size([227, 1])
Shape of labels: torch.Size([227])


  7%|▋         | 71/1013 [00:12<03:47,  4.15it/s]

esm2_embeddings shape: torch.Size([233, 1280])
ss_onehot shape: torch.Size([233, 4])
phi_angles shape: torch.Size([233, 1])
psi_angles shape: torch.Size([233, 1])
b_factors shape: torch.Size([233, 1])
Shape of labels: torch.Size([233])
esm2_embeddings shape: torch.Size([168, 1280])
ss_onehot shape: torch.Size([168, 4])
phi_angles shape: torch.Size([168, 1])
psi_angles shape: torch.Size([168, 1])
b_factors shape: torch.Size([168, 1])
Shape of labels: torch.Size([168])


  7%|▋         | 72/1013 [00:13<03:58,  3.94it/s]

esm2_embeddings shape: torch.Size([297, 1280])
ss_onehot shape: torch.Size([297, 4])
phi_angles shape: torch.Size([297, 1])
psi_angles shape: torch.Size([297, 1])
b_factors shape: torch.Size([297, 1])
Shape of labels: torch.Size([297])
esm2_embeddings shape: torch.Size([99, 1280])
ss_onehot shape: torch.Size([99, 4])
phi_angles shape: torch.Size([99, 1])
psi_angles shape: torch.Size([99, 1])
b_factors shape: torch.Size([99, 1])
Shape of labels: torch.Size([99])


  7%|▋         | 74/1013 [00:13<03:21,  4.66it/s]

esm2_embeddings shape: torch.Size([265, 1280])
ss_onehot shape: torch.Size([265, 4])
phi_angles shape: torch.Size([265, 1])
psi_angles shape: torch.Size([265, 1])
b_factors shape: torch.Size([265, 1])
Shape of labels: torch.Size([265])


  8%|▊         | 76/1013 [00:14<03:28,  4.49it/s]

esm2_embeddings shape: torch.Size([307, 1280])
ss_onehot shape: torch.Size([307, 4])
phi_angles shape: torch.Size([307, 1])
psi_angles shape: torch.Size([307, 1])
b_factors shape: torch.Size([307, 1])
Shape of labels: torch.Size([307])
esm2_embeddings shape: torch.Size([221, 1280])
ss_onehot shape: torch.Size([221, 4])
phi_angles shape: torch.Size([221, 1])
psi_angles shape: torch.Size([221, 1])
b_factors shape: torch.Size([221, 1])
Shape of labels: torch.Size([221])


  8%|▊         | 78/1013 [00:14<03:08,  4.95it/s]

esm2_embeddings shape: torch.Size([251, 1280])
ss_onehot shape: torch.Size([251, 4])
phi_angles shape: torch.Size([251, 1])
psi_angles shape: torch.Size([251, 1])
b_factors shape: torch.Size([251, 1])
Shape of labels: torch.Size([251])
esm2_embeddings shape: torch.Size([190, 1280])
ss_onehot shape: torch.Size([190, 4])
phi_angles shape: torch.Size([190, 1])
psi_angles shape: torch.Size([190, 1])
b_factors shape: torch.Size([190, 1])
Shape of labels: torch.Size([190])


  8%|▊         | 81/1013 [00:15<03:11,  4.86it/s]

esm2_embeddings shape: torch.Size([164, 1280])
ss_onehot shape: torch.Size([164, 4])
phi_angles shape: torch.Size([164, 1])
psi_angles shape: torch.Size([164, 1])
b_factors shape: torch.Size([164, 1])
Shape of labels: torch.Size([164])
esm2_embeddings shape: torch.Size([62, 1280])
ss_onehot shape: torch.Size([62, 4])
phi_angles shape: torch.Size([62, 1])
psi_angles shape: torch.Size([62, 1])
b_factors shape: torch.Size([62, 1])
Shape of labels: torch.Size([62])
esm2_embeddings shape: torch.Size([169, 1280])
ss_onehot shape: torch.Size([169, 4])
phi_angles shape: torch.Size([169, 1])
psi_angles shape: torch.Size([169, 1])
b_factors shape: torch.Size([169, 1])
Shape of labels: torch.Size([169])


  8%|▊         | 83/1013 [00:15<02:55,  5.29it/s]

esm2_embeddings shape: torch.Size([185, 1280])
ss_onehot shape: torch.Size([185, 4])
phi_angles shape: torch.Size([185, 1])
psi_angles shape: torch.Size([185, 1])
b_factors shape: torch.Size([185, 1])
Shape of labels: torch.Size([185])
esm2_embeddings shape: torch.Size([218, 1280])
ss_onehot shape: torch.Size([218, 4])
phi_angles shape: torch.Size([218, 1])
psi_angles shape: torch.Size([218, 1])
b_factors shape: torch.Size([218, 1])
Shape of labels: torch.Size([218])


  8%|▊         | 85/1013 [00:15<02:24,  6.43it/s]

esm2_embeddings shape: torch.Size([88, 1280])
ss_onehot shape: torch.Size([88, 4])
phi_angles shape: torch.Size([88, 1])
psi_angles shape: torch.Size([88, 1])
b_factors shape: torch.Size([88, 1])
Shape of labels: torch.Size([88])
esm2_embeddings shape: torch.Size([198, 1280])
ss_onehot shape: torch.Size([198, 4])
phi_angles shape: torch.Size([198, 1])
psi_angles shape: torch.Size([198, 1])
b_factors shape: torch.Size([198, 1])
Shape of labels: torch.Size([198])
esm2_embeddings shape: torch.Size([47, 1280])
ss_onehot shape: torch.Size([47, 4])
phi_angles shape: torch.Size([47, 1])
psi_angles shape: torch.Size([47, 1])
b_factors shape: torch.Size([47, 1])
Shape of labels: torch.Size([47])


  9%|▊         | 88/1013 [00:16<02:26,  6.30it/s]

esm2_embeddings shape: torch.Size([321, 1280])
ss_onehot shape: torch.Size([321, 4])
phi_angles shape: torch.Size([321, 1])
psi_angles shape: torch.Size([321, 1])
b_factors shape: torch.Size([321, 1])
Shape of labels: torch.Size([321])
esm2_embeddings shape: torch.Size([167, 1280])
ss_onehot shape: torch.Size([167, 4])
phi_angles shape: torch.Size([167, 1])
psi_angles shape: torch.Size([167, 1])
b_factors shape: torch.Size([167, 1])
Shape of labels: torch.Size([167])


  9%|▉         | 90/1013 [00:16<02:02,  7.51it/s]

esm2_embeddings shape: torch.Size([117, 1280])
ss_onehot shape: torch.Size([117, 4])
phi_angles shape: torch.Size([117, 1])
psi_angles shape: torch.Size([117, 1])
b_factors shape: torch.Size([117, 1])
Shape of labels: torch.Size([117])
esm2_embeddings shape: torch.Size([142, 1280])
ss_onehot shape: torch.Size([142, 4])
phi_angles shape: torch.Size([142, 1])
psi_angles shape: torch.Size([142, 1])
b_factors shape: torch.Size([142, 1])
Shape of labels: torch.Size([142])


  9%|▉         | 91/1013 [00:16<02:01,  7.60it/s]

esm2_embeddings shape: torch.Size([154, 1280])
ss_onehot shape: torch.Size([154, 4])
phi_angles shape: torch.Size([154, 1])
psi_angles shape: torch.Size([154, 1])
b_factors shape: torch.Size([154, 1])
Shape of labels: torch.Size([154])


  9%|▉         | 93/1013 [00:17<03:01,  5.07it/s]

esm2_embeddings shape: torch.Size([108, 1280])
ss_onehot shape: torch.Size([108, 4])
phi_angles shape: torch.Size([108, 1])
psi_angles shape: torch.Size([108, 1])
b_factors shape: torch.Size([108, 1])
Shape of labels: torch.Size([108])
esm2_embeddings shape: torch.Size([162, 1280])
ss_onehot shape: torch.Size([162, 4])
phi_angles shape: torch.Size([162, 1])
psi_angles shape: torch.Size([162, 1])
b_factors shape: torch.Size([162, 1])
Shape of labels: torch.Size([162])
esm2_embeddings shape: torch.Size([136, 1280])
ss_onehot shape: torch.Size([136, 4])
phi_angles shape: torch.Size([136, 1])
psi_angles shape: torch.Size([136, 1])
b_factors shape: torch.Size([136, 1])
Shape of labels: torch.Size([136])


  9%|▉         | 95/1013 [00:17<02:54,  5.25it/s]

esm2_embeddings shape: torch.Size([281, 1280])
ss_onehot shape: torch.Size([281, 4])
phi_angles shape: torch.Size([281, 1])
psi_angles shape: torch.Size([281, 1])
b_factors shape: torch.Size([281, 1])
Shape of labels: torch.Size([281])
esm2_embeddings shape: torch.Size([111, 1280])
ss_onehot shape: torch.Size([111, 4])
phi_angles shape: torch.Size([111, 1])
psi_angles shape: torch.Size([111, 1])
b_factors shape: torch.Size([111, 1])
Shape of labels: torch.Size([111])


 10%|▉         | 97/1013 [00:17<02:59,  5.11it/s]

esm2_embeddings shape: torch.Size([328, 1280])
ss_onehot shape: torch.Size([328, 4])
phi_angles shape: torch.Size([328, 1])
psi_angles shape: torch.Size([328, 1])
b_factors shape: torch.Size([328, 1])
Shape of labels: torch.Size([328])
esm2_embeddings shape: torch.Size([91, 1280])
ss_onehot shape: torch.Size([91, 4])
phi_angles shape: torch.Size([91, 1])
psi_angles shape: torch.Size([91, 1])
b_factors shape: torch.Size([91, 1])
Shape of labels: torch.Size([91])


 10%|▉         | 100/1013 [00:18<02:22,  6.41it/s]

esm2_embeddings shape: torch.Size([182, 1280])
ss_onehot shape: torch.Size([182, 4])
phi_angles shape: torch.Size([182, 1])
psi_angles shape: torch.Size([182, 1])
b_factors shape: torch.Size([182, 1])
Shape of labels: torch.Size([182])
esm2_embeddings shape: torch.Size([166, 1280])
ss_onehot shape: torch.Size([166, 4])
phi_angles shape: torch.Size([166, 1])
psi_angles shape: torch.Size([166, 1])
b_factors shape: torch.Size([166, 1])
Shape of labels: torch.Size([166])


 10%|█         | 103/1013 [00:18<01:49,  8.30it/s]

esm2_embeddings shape: torch.Size([194, 1280])
ss_onehot shape: torch.Size([194, 4])
phi_angles shape: torch.Size([194, 1])
psi_angles shape: torch.Size([194, 1])
b_factors shape: torch.Size([194, 1])
Shape of labels: torch.Size([194])
esm2_embeddings shape: torch.Size([121, 1280])
ss_onehot shape: torch.Size([121, 4])
phi_angles shape: torch.Size([121, 1])
psi_angles shape: torch.Size([121, 1])
b_factors shape: torch.Size([121, 1])
Shape of labels: torch.Size([121])
esm2_embeddings shape: torch.Size([62, 1280])
ss_onehot shape: torch.Size([62, 4])
phi_angles shape: torch.Size([62, 1])
psi_angles shape: torch.Size([62, 1])
b_factors shape: torch.Size([62, 1])
Shape of labels: torch.Size([62])


 10%|█         | 104/1013 [00:19<03:22,  4.49it/s]

esm2_embeddings shape: torch.Size([183, 1280])
ss_onehot shape: torch.Size([183, 4])
phi_angles shape: torch.Size([183, 1])
psi_angles shape: torch.Size([183, 1])
b_factors shape: torch.Size([183, 1])
Shape of labels: torch.Size([183])
esm2_embeddings shape: torch.Size([131, 1280])
ss_onehot shape: torch.Size([131, 4])
phi_angles shape: torch.Size([131, 1])
psi_angles shape: torch.Size([131, 1])
b_factors shape: torch.Size([131, 1])
Shape of labels: torch.Size([131])


 11%|█         | 107/1013 [00:19<02:54,  5.20it/s]

esm2_embeddings shape: torch.Size([321, 1280])
ss_onehot shape: torch.Size([321, 4])
phi_angles shape: torch.Size([321, 1])
psi_angles shape: torch.Size([321, 1])
b_factors shape: torch.Size([321, 1])
Shape of labels: torch.Size([321])
esm2_embeddings shape: torch.Size([139, 1280])
ss_onehot shape: torch.Size([139, 4])
phi_angles shape: torch.Size([139, 1])
psi_angles shape: torch.Size([139, 1])
b_factors shape: torch.Size([139, 1])
Shape of labels: torch.Size([139])
esm2_embeddings shape: torch.Size([136, 1280])
ss_onehot shape: torch.Size([136, 4])
phi_angles shape: torch.Size([136, 1])
psi_angles shape: torch.Size([136, 1])
b_factors shape: torch.Size([136, 1])
Shape of labels: torch.Size([136])


 11%|█         | 109/1013 [00:19<02:40,  5.64it/s]

esm2_embeddings shape: torch.Size([229, 1280])
ss_onehot shape: torch.Size([229, 4])
phi_angles shape: torch.Size([229, 1])
psi_angles shape: torch.Size([229, 1])
b_factors shape: torch.Size([229, 1])
Shape of labels: torch.Size([229])


 11%|█         | 111/1013 [00:20<02:28,  6.09it/s]

esm2_embeddings shape: torch.Size([234, 1280])
ss_onehot shape: torch.Size([234, 4])
phi_angles shape: torch.Size([234, 1])
psi_angles shape: torch.Size([234, 1])
b_factors shape: torch.Size([234, 1])
Shape of labels: torch.Size([234])
esm2_embeddings shape: torch.Size([141, 1280])
ss_onehot shape: torch.Size([141, 4])
phi_angles shape: torch.Size([141, 1])
psi_angles shape: torch.Size([141, 1])
b_factors shape: torch.Size([141, 1])
Shape of labels: torch.Size([141])


 11%|█         | 113/1013 [00:20<02:18,  6.50it/s]

esm2_embeddings shape: torch.Size([146, 1280])
ss_onehot shape: torch.Size([146, 4])
phi_angles shape: torch.Size([146, 1])
psi_angles shape: torch.Size([146, 1])
b_factors shape: torch.Size([146, 1])
Shape of labels: torch.Size([146])
esm2_embeddings shape: torch.Size([203, 1280])
ss_onehot shape: torch.Size([203, 4])
phi_angles shape: torch.Size([203, 1])
psi_angles shape: torch.Size([203, 1])
b_factors shape: torch.Size([203, 1])
Shape of labels: torch.Size([203])


 11%|█▏        | 114/1013 [00:21<04:23,  3.41it/s]

esm2_embeddings shape: torch.Size([255, 1280])
ss_onehot shape: torch.Size([255, 4])
phi_angles shape: torch.Size([255, 1])
psi_angles shape: torch.Size([255, 1])
b_factors shape: torch.Size([255, 1])
Shape of labels: torch.Size([255])
esm2_embeddings shape: torch.Size([108, 1280])
ss_onehot shape: torch.Size([108, 4])
phi_angles shape: torch.Size([108, 1])
psi_angles shape: torch.Size([108, 1])
b_factors shape: torch.Size([108, 1])
Shape of labels: torch.Size([108])


 12%|█▏        | 117/1013 [00:21<03:10,  4.70it/s]

esm2_embeddings shape: torch.Size([288, 1280])
ss_onehot shape: torch.Size([288, 4])
phi_angles shape: torch.Size([288, 1])
psi_angles shape: torch.Size([288, 1])
b_factors shape: torch.Size([288, 1])
Shape of labels: torch.Size([288])
esm2_embeddings shape: torch.Size([156, 1280])
ss_onehot shape: torch.Size([156, 4])
phi_angles shape: torch.Size([156, 1])
psi_angles shape: torch.Size([156, 1])
b_factors shape: torch.Size([156, 1])
Shape of labels: torch.Size([156])


 12%|█▏        | 118/1013 [00:21<03:11,  4.68it/s]

esm2_embeddings shape: torch.Size([242, 1280])
ss_onehot shape: torch.Size([242, 4])
phi_angles shape: torch.Size([242, 1])
psi_angles shape: torch.Size([242, 1])
b_factors shape: torch.Size([242, 1])
Shape of labels: torch.Size([242])
esm2_embeddings shape: torch.Size([95, 1280])
ss_onehot shape: torch.Size([95, 4])
phi_angles shape: torch.Size([95, 1])
psi_angles shape: torch.Size([95, 1])
b_factors shape: torch.Size([95, 1])
Shape of labels: torch.Size([95])


 12%|█▏        | 122/1013 [00:22<02:15,  6.57it/s]

esm2_embeddings shape: torch.Size([262, 1280])
ss_onehot shape: torch.Size([262, 4])
phi_angles shape: torch.Size([262, 1])
psi_angles shape: torch.Size([262, 1])
b_factors shape: torch.Size([262, 1])
Shape of labels: torch.Size([262])
esm2_embeddings shape: torch.Size([121, 1280])
ss_onehot shape: torch.Size([121, 4])
phi_angles shape: torch.Size([121, 1])
psi_angles shape: torch.Size([121, 1])
b_factors shape: torch.Size([121, 1])
Shape of labels: torch.Size([121])
esm2_embeddings shape: torch.Size([136, 1280])
ss_onehot shape: torch.Size([136, 4])
phi_angles shape: torch.Size([136, 1])
psi_angles shape: torch.Size([136, 1])
b_factors shape: torch.Size([136, 1])
Shape of labels: torch.Size([136])
esm2_embeddings shape: torch.Size([107, 1280])
ss_onehot shape: torch.Size([107, 4])
phi_angles shape: torch.Size([107, 1])
psi_angles shape: torch.Size([107, 1])
b_factors shape: torch.Size([107, 1])
Shape of labels: torch.Size([107])


 12%|█▏        | 125/1013 [00:23<02:49,  5.23it/s]

esm2_embeddings shape: torch.Size([169, 1280])
ss_onehot shape: torch.Size([169, 4])
phi_angles shape: torch.Size([169, 1])
psi_angles shape: torch.Size([169, 1])
b_factors shape: torch.Size([169, 1])
Shape of labels: torch.Size([169])
esm2_embeddings shape: torch.Size([143, 1280])
ss_onehot shape: torch.Size([143, 4])
phi_angles shape: torch.Size([143, 1])
psi_angles shape: torch.Size([143, 1])
b_factors shape: torch.Size([143, 1])
Shape of labels: torch.Size([143])


 13%|█▎        | 127/1013 [00:23<02:41,  5.48it/s]

esm2_embeddings shape: torch.Size([249, 1280])
ss_onehot shape: torch.Size([249, 4])
phi_angles shape: torch.Size([249, 1])
psi_angles shape: torch.Size([249, 1])
b_factors shape: torch.Size([249, 1])
Shape of labels: torch.Size([249])
esm2_embeddings shape: torch.Size([162, 1280])
ss_onehot shape: torch.Size([162, 4])
phi_angles shape: torch.Size([162, 1])
psi_angles shape: torch.Size([162, 1])
b_factors shape: torch.Size([162, 1])
Shape of labels: torch.Size([162])


 13%|█▎        | 129/1013 [00:23<02:19,  6.35it/s]

esm2_embeddings shape: torch.Size([139, 1280])
ss_onehot shape: torch.Size([139, 4])
phi_angles shape: torch.Size([139, 1])
psi_angles shape: torch.Size([139, 1])
b_factors shape: torch.Size([139, 1])
Shape of labels: torch.Size([139])
esm2_embeddings shape: torch.Size([183, 1280])
ss_onehot shape: torch.Size([183, 4])
phi_angles shape: torch.Size([183, 1])
psi_angles shape: torch.Size([183, 1])
b_factors shape: torch.Size([183, 1])
Shape of labels: torch.Size([183])


 13%|█▎        | 131/1013 [00:23<02:04,  7.08it/s]

esm2_embeddings shape: torch.Size([176, 1280])
ss_onehot shape: torch.Size([176, 4])
phi_angles shape: torch.Size([176, 1])
psi_angles shape: torch.Size([176, 1])
b_factors shape: torch.Size([176, 1])
Shape of labels: torch.Size([176])
esm2_embeddings shape: torch.Size([131, 1280])
ss_onehot shape: torch.Size([131, 4])
phi_angles shape: torch.Size([131, 1])
psi_angles shape: torch.Size([131, 1])
b_factors shape: torch.Size([131, 1])
Shape of labels: torch.Size([131])


 13%|█▎        | 132/1013 [00:24<01:58,  7.45it/s]

esm2_embeddings shape: torch.Size([153, 1280])
ss_onehot shape: torch.Size([153, 4])
phi_angles shape: torch.Size([153, 1])
psi_angles shape: torch.Size([153, 1])
b_factors shape: torch.Size([153, 1])
Shape of labels: torch.Size([153])
esm2_embeddings shape: torch.Size([113, 1280])
ss_onehot shape: torch.Size([113, 4])
phi_angles shape: torch.Size([113, 1])
psi_angles shape: torch.Size([113, 1])
b_factors shape: torch.Size([113, 1])
Shape of labels: torch.Size([113])


 13%|█▎        | 134/1013 [00:24<01:52,  7.82it/s]

esm2_embeddings shape: torch.Size([192, 1280])
ss_onehot shape: torch.Size([192, 4])
phi_angles shape: torch.Size([192, 1])
psi_angles shape: torch.Size([192, 1])
b_factors shape: torch.Size([192, 1])
Shape of labels: torch.Size([192])


 13%|█▎        | 136/1013 [00:25<04:36,  3.17it/s]

esm2_embeddings shape: torch.Size([537, 1280])
ss_onehot shape: torch.Size([537, 4])
phi_angles shape: torch.Size([537, 1])
psi_angles shape: torch.Size([537, 1])
b_factors shape: torch.Size([537, 1])
Shape of labels: torch.Size([537])
esm2_embeddings shape: torch.Size([151, 1280])
ss_onehot shape: torch.Size([151, 4])
phi_angles shape: torch.Size([151, 1])
psi_angles shape: torch.Size([151, 1])
b_factors shape: torch.Size([151, 1])
Shape of labels: torch.Size([151])
esm2_embeddings shape: torch.Size([64, 1280])
ss_onehot shape: torch.Size([64, 4])
phi_angles shape: torch.Size([64, 1])
psi_angles shape: torch.Size([64, 1])
b_factors shape: torch.Size([64, 1])
Shape of labels: torch.Size([64])


 14%|█▎        | 139/1013 [00:25<02:55,  4.98it/s]

esm2_embeddings shape: torch.Size([163, 1280])
ss_onehot shape: torch.Size([163, 4])
phi_angles shape: torch.Size([163, 1])
psi_angles shape: torch.Size([163, 1])
b_factors shape: torch.Size([163, 1])
Shape of labels: torch.Size([163])
esm2_embeddings shape: torch.Size([180, 1280])
ss_onehot shape: torch.Size([180, 4])
phi_angles shape: torch.Size([180, 1])
psi_angles shape: torch.Size([180, 1])
b_factors shape: torch.Size([180, 1])
Shape of labels: torch.Size([180])


 14%|█▍        | 141/1013 [00:26<02:29,  5.82it/s]

esm2_embeddings shape: torch.Size([96, 1280])
ss_onehot shape: torch.Size([96, 4])
phi_angles shape: torch.Size([96, 1])
psi_angles shape: torch.Size([96, 1])
b_factors shape: torch.Size([96, 1])
Shape of labels: torch.Size([96])
esm2_embeddings shape: torch.Size([218, 1280])
ss_onehot shape: torch.Size([218, 4])
phi_angles shape: torch.Size([218, 1])
psi_angles shape: torch.Size([218, 1])
b_factors shape: torch.Size([218, 1])
Shape of labels: torch.Size([218])


 14%|█▍        | 143/1013 [00:26<02:08,  6.75it/s]

esm2_embeddings shape: torch.Size([152, 1280])
ss_onehot shape: torch.Size([152, 4])
phi_angles shape: torch.Size([152, 1])
psi_angles shape: torch.Size([152, 1])
b_factors shape: torch.Size([152, 1])
Shape of labels: torch.Size([152])
esm2_embeddings shape: torch.Size([150, 1280])
ss_onehot shape: torch.Size([150, 4])
phi_angles shape: torch.Size([150, 1])
psi_angles shape: torch.Size([150, 1])
b_factors shape: torch.Size([150, 1])
Shape of labels: torch.Size([150])


 14%|█▍        | 144/1013 [00:26<02:06,  6.87it/s]

esm2_embeddings shape: torch.Size([179, 1280])
ss_onehot shape: torch.Size([179, 4])
phi_angles shape: torch.Size([179, 1])
psi_angles shape: torch.Size([179, 1])
b_factors shape: torch.Size([179, 1])
Shape of labels: torch.Size([179])


 14%|█▍        | 146/1013 [00:27<03:14,  4.46it/s]

esm2_embeddings shape: torch.Size([178, 1280])
ss_onehot shape: torch.Size([178, 4])
phi_angles shape: torch.Size([178, 1])
psi_angles shape: torch.Size([178, 1])
b_factors shape: torch.Size([178, 1])
Shape of labels: torch.Size([178])
esm2_embeddings shape: torch.Size([146, 1280])
ss_onehot shape: torch.Size([146, 4])
phi_angles shape: torch.Size([146, 1])
psi_angles shape: torch.Size([146, 1])
b_factors shape: torch.Size([146, 1])
Shape of labels: torch.Size([146])
esm2_embeddings shape: torch.Size([81, 1280])
ss_onehot shape: torch.Size([81, 4])
phi_angles shape: torch.Size([81, 1])
psi_angles shape: torch.Size([81, 1])
b_factors shape: torch.Size([81, 1])
Shape of labels: torch.Size([81])


 15%|█▍        | 148/1013 [00:27<02:09,  6.68it/s]

esm2_embeddings shape: torch.Size([77, 1280])
ss_onehot shape: torch.Size([77, 4])
phi_angles shape: torch.Size([77, 1])
psi_angles shape: torch.Size([77, 1])
b_factors shape: torch.Size([77, 1])
Shape of labels: torch.Size([77])


 15%|█▍        | 150/1013 [00:27<02:10,  6.63it/s]

esm2_embeddings shape: torch.Size([249, 1280])
ss_onehot shape: torch.Size([249, 4])
phi_angles shape: torch.Size([249, 1])
psi_angles shape: torch.Size([249, 1])
b_factors shape: torch.Size([249, 1])
Shape of labels: torch.Size([249])
esm2_embeddings shape: torch.Size([121, 1280])
ss_onehot shape: torch.Size([121, 4])
phi_angles shape: torch.Size([121, 1])
psi_angles shape: torch.Size([121, 1])
b_factors shape: torch.Size([121, 1])
Shape of labels: torch.Size([121])


 15%|█▌        | 152/1013 [00:28<02:35,  5.55it/s]

esm2_embeddings shape: torch.Size([355, 1280])
ss_onehot shape: torch.Size([355, 4])
phi_angles shape: torch.Size([355, 1])
psi_angles shape: torch.Size([355, 1])
b_factors shape: torch.Size([355, 1])
Shape of labels: torch.Size([355])
esm2_embeddings shape: torch.Size([155, 1280])
ss_onehot shape: torch.Size([155, 4])
phi_angles shape: torch.Size([155, 1])
psi_angles shape: torch.Size([155, 1])
b_factors shape: torch.Size([155, 1])
Shape of labels: torch.Size([155])


 15%|█▌        | 155/1013 [00:28<01:51,  7.68it/s]

esm2_embeddings shape: torch.Size([150, 1280])
ss_onehot shape: torch.Size([150, 4])
phi_angles shape: torch.Size([150, 1])
psi_angles shape: torch.Size([150, 1])
b_factors shape: torch.Size([150, 1])
Shape of labels: torch.Size([150])
esm2_embeddings shape: torch.Size([106, 1280])
ss_onehot shape: torch.Size([106, 4])
phi_angles shape: torch.Size([106, 1])
psi_angles shape: torch.Size([106, 1])
b_factors shape: torch.Size([106, 1])
Shape of labels: torch.Size([106])
esm2_embeddings shape: torch.Size([134, 1280])
ss_onehot shape: torch.Size([134, 4])
phi_angles shape: torch.Size([134, 1])
psi_angles shape: torch.Size([134, 1])
b_factors shape: torch.Size([134, 1])
Shape of labels: torch.Size([134])


 15%|█▌        | 156/1013 [00:28<01:54,  7.47it/s]

esm2_embeddings shape: torch.Size([189, 1280])
ss_onehot shape: torch.Size([189, 4])
phi_angles shape: torch.Size([189, 1])
psi_angles shape: torch.Size([189, 1])
b_factors shape: torch.Size([189, 1])
Shape of labels: torch.Size([189])
esm2_embeddings shape: torch.Size([93, 1280])
ss_onehot shape: torch.Size([93, 4])
phi_angles shape: torch.Size([93, 1])
psi_angles shape: torch.Size([93, 1])
b_factors shape: torch.Size([93, 1])
Shape of labels: torch.Size([93])


 16%|█▌        | 158/1013 [00:29<03:00,  4.73it/s]

esm2_embeddings shape: torch.Size([178, 1280])
ss_onehot shape: torch.Size([178, 4])
phi_angles shape: torch.Size([178, 1])
psi_angles shape: torch.Size([178, 1])
b_factors shape: torch.Size([178, 1])
Shape of labels: torch.Size([178])


 16%|█▌        | 160/1013 [00:29<02:54,  4.89it/s]

esm2_embeddings shape: torch.Size([235, 1280])
ss_onehot shape: torch.Size([235, 4])
phi_angles shape: torch.Size([235, 1])
psi_angles shape: torch.Size([235, 1])
b_factors shape: torch.Size([235, 1])
Shape of labels: torch.Size([235])
esm2_embeddings shape: torch.Size([222, 1280])
ss_onehot shape: torch.Size([222, 4])
phi_angles shape: torch.Size([222, 1])
psi_angles shape: torch.Size([222, 1])
b_factors shape: torch.Size([222, 1])
Shape of labels: torch.Size([222])


 16%|█▌        | 161/1013 [00:29<02:53,  4.91it/s]

esm2_embeddings shape: torch.Size([218, 1280])
ss_onehot shape: torch.Size([218, 4])
phi_angles shape: torch.Size([218, 1])
psi_angles shape: torch.Size([218, 1])
b_factors shape: torch.Size([218, 1])
Shape of labels: torch.Size([218])
esm2_embeddings shape: torch.Size([69, 1280])
ss_onehot shape: torch.Size([69, 4])
phi_angles shape: torch.Size([69, 1])
psi_angles shape: torch.Size([69, 1])
b_factors shape: torch.Size([69, 1])
Shape of labels: torch.Size([69])


 16%|█▌        | 164/1013 [00:30<02:19,  6.10it/s]

esm2_embeddings shape: torch.Size([199, 1280])
ss_onehot shape: torch.Size([199, 4])
phi_angles shape: torch.Size([199, 1])
psi_angles shape: torch.Size([199, 1])
b_factors shape: torch.Size([199, 1])
Shape of labels: torch.Size([199])
esm2_embeddings shape: torch.Size([200, 1280])
ss_onehot shape: torch.Size([200, 4])
phi_angles shape: torch.Size([200, 1])
psi_angles shape: torch.Size([200, 1])
b_factors shape: torch.Size([200, 1])
Shape of labels: torch.Size([200])


 16%|█▋        | 166/1013 [00:30<02:25,  5.81it/s]

esm2_embeddings shape: torch.Size([232, 1280])
ss_onehot shape: torch.Size([232, 4])
phi_angles shape: torch.Size([232, 1])
psi_angles shape: torch.Size([232, 1])
b_factors shape: torch.Size([232, 1])
Shape of labels: torch.Size([232])
esm2_embeddings shape: torch.Size([212, 1280])
ss_onehot shape: torch.Size([212, 4])
phi_angles shape: torch.Size([212, 1])
psi_angles shape: torch.Size([212, 1])
b_factors shape: torch.Size([212, 1])
Shape of labels: torch.Size([212])


 16%|█▋        | 167/1013 [00:31<03:58,  3.55it/s]

esm2_embeddings shape: torch.Size([164, 1280])
ss_onehot shape: torch.Size([164, 4])
phi_angles shape: torch.Size([164, 1])
psi_angles shape: torch.Size([164, 1])
b_factors shape: torch.Size([164, 1])
Shape of labels: torch.Size([164])
esm2_embeddings shape: torch.Size([119, 1280])
ss_onehot shape: torch.Size([119, 4])
phi_angles shape: torch.Size([119, 1])
psi_angles shape: torch.Size([119, 1])
b_factors shape: torch.Size([119, 1])
Shape of labels: torch.Size([119])


 17%|█▋        | 169/1013 [00:31<03:00,  4.68it/s]

esm2_embeddings shape: torch.Size([197, 1280])
ss_onehot shape: torch.Size([197, 4])
phi_angles shape: torch.Size([197, 1])
psi_angles shape: torch.Size([197, 1])
b_factors shape: torch.Size([197, 1])
Shape of labels: torch.Size([197])
esm2_embeddings shape: torch.Size([80, 1280])
ss_onehot shape: torch.Size([80, 4])
phi_angles shape: torch.Size([80, 1])
psi_angles shape: torch.Size([80, 1])
b_factors shape: torch.Size([80, 1])
Shape of labels: torch.Size([80])


 17%|█▋        | 172/1013 [00:31<02:43,  5.14it/s]

esm2_embeddings shape: torch.Size([324, 1280])
ss_onehot shape: torch.Size([324, 4])
phi_angles shape: torch.Size([324, 1])
psi_angles shape: torch.Size([324, 1])
b_factors shape: torch.Size([324, 1])
Shape of labels: torch.Size([324])
esm2_embeddings shape: torch.Size([198, 1280])
ss_onehot shape: torch.Size([198, 4])
phi_angles shape: torch.Size([198, 1])
psi_angles shape: torch.Size([198, 1])
b_factors shape: torch.Size([198, 1])
Shape of labels: torch.Size([198])


 17%|█▋        | 174/1013 [00:32<02:33,  5.47it/s]

esm2_embeddings shape: torch.Size([174, 1280])
ss_onehot shape: torch.Size([174, 4])
phi_angles shape: torch.Size([174, 1])
psi_angles shape: torch.Size([174, 1])
b_factors shape: torch.Size([174, 1])
Shape of labels: torch.Size([174])
esm2_embeddings shape: torch.Size([227, 1280])
ss_onehot shape: torch.Size([227, 4])
phi_angles shape: torch.Size([227, 1])
psi_angles shape: torch.Size([227, 1])
b_factors shape: torch.Size([227, 1])
Shape of labels: torch.Size([227])


 17%|█▋        | 176/1013 [00:32<02:09,  6.47it/s]

esm2_embeddings shape: torch.Size([150, 1280])
ss_onehot shape: torch.Size([150, 4])
phi_angles shape: torch.Size([150, 1])
psi_angles shape: torch.Size([150, 1])
b_factors shape: torch.Size([150, 1])
Shape of labels: torch.Size([150])
esm2_embeddings shape: torch.Size([167, 1280])
ss_onehot shape: torch.Size([167, 4])
phi_angles shape: torch.Size([167, 1])
psi_angles shape: torch.Size([167, 1])
b_factors shape: torch.Size([167, 1])
Shape of labels: torch.Size([167])


 17%|█▋        | 177/1013 [00:32<01:58,  7.07it/s]

esm2_embeddings shape: torch.Size([146, 1280])
ss_onehot shape: torch.Size([146, 4])
phi_angles shape: torch.Size([146, 1])
psi_angles shape: torch.Size([146, 1])
b_factors shape: torch.Size([146, 1])
Shape of labels: torch.Size([146])


 18%|█▊        | 179/1013 [00:33<03:38,  3.83it/s]

esm2_embeddings shape: torch.Size([269, 1280])
ss_onehot shape: torch.Size([269, 4])
phi_angles shape: torch.Size([269, 1])
psi_angles shape: torch.Size([269, 1])
b_factors shape: torch.Size([269, 1])
Shape of labels: torch.Size([269])
esm2_embeddings shape: torch.Size([201, 1280])
ss_onehot shape: torch.Size([201, 4])
phi_angles shape: torch.Size([201, 1])
psi_angles shape: torch.Size([201, 1])
b_factors shape: torch.Size([201, 1])
Shape of labels: torch.Size([201])


 18%|█▊        | 182/1013 [00:33<02:14,  6.20it/s]

esm2_embeddings shape: torch.Size([157, 1280])
ss_onehot shape: torch.Size([157, 4])
phi_angles shape: torch.Size([157, 1])
psi_angles shape: torch.Size([157, 1])
b_factors shape: torch.Size([157, 1])
Shape of labels: torch.Size([157])
esm2_embeddings shape: torch.Size([115, 1280])
ss_onehot shape: torch.Size([115, 4])
phi_angles shape: torch.Size([115, 1])
psi_angles shape: torch.Size([115, 1])
b_factors shape: torch.Size([115, 1])
Shape of labels: torch.Size([115])
esm2_embeddings shape: torch.Size([121, 1280])
ss_onehot shape: torch.Size([121, 4])
phi_angles shape: torch.Size([121, 1])
psi_angles shape: torch.Size([121, 1])
b_factors shape: torch.Size([121, 1])
Shape of labels: torch.Size([121])


 18%|█▊        | 183/1013 [00:33<02:03,  6.72it/s]

esm2_embeddings shape: torch.Size([144, 1280])
ss_onehot shape: torch.Size([144, 4])
phi_angles shape: torch.Size([144, 1])
psi_angles shape: torch.Size([144, 1])
b_factors shape: torch.Size([144, 1])
Shape of labels: torch.Size([144])
esm2_embeddings shape: torch.Size([127, 1280])
ss_onehot shape: torch.Size([127, 4])
phi_angles shape: torch.Size([127, 1])
psi_angles shape: torch.Size([127, 1])
b_factors shape: torch.Size([127, 1])
Shape of labels: torch.Size([127])


 18%|█▊        | 186/1013 [00:34<01:58,  6.97it/s]

esm2_embeddings shape: torch.Size([234, 1280])
ss_onehot shape: torch.Size([234, 4])
phi_angles shape: torch.Size([234, 1])
psi_angles shape: torch.Size([234, 1])
b_factors shape: torch.Size([234, 1])
Shape of labels: torch.Size([234])
esm2_embeddings shape: torch.Size([175, 1280])
ss_onehot shape: torch.Size([175, 4])
phi_angles shape: torch.Size([175, 1])
psi_angles shape: torch.Size([175, 1])
b_factors shape: torch.Size([175, 1])
Shape of labels: torch.Size([175])


 18%|█▊        | 187/1013 [00:34<01:54,  7.22it/s]

esm2_embeddings shape: torch.Size([157, 1280])
ss_onehot shape: torch.Size([157, 4])
phi_angles shape: torch.Size([157, 1])
psi_angles shape: torch.Size([157, 1])
b_factors shape: torch.Size([157, 1])
Shape of labels: torch.Size([157])
esm2_embeddings shape: torch.Size([111, 1280])
ss_onehot shape: torch.Size([111, 4])
phi_angles shape: torch.Size([111, 1])
psi_angles shape: torch.Size([111, 1])
b_factors shape: torch.Size([111, 1])
Shape of labels: torch.Size([111])


 19%|█▉        | 190/1013 [00:34<01:50,  7.46it/s]

esm2_embeddings shape: torch.Size([171, 1280])
ss_onehot shape: torch.Size([171, 4])
phi_angles shape: torch.Size([171, 1])
psi_angles shape: torch.Size([171, 1])
b_factors shape: torch.Size([171, 1])
Shape of labels: torch.Size([171])
esm2_embeddings shape: torch.Size([201, 1280])
ss_onehot shape: torch.Size([201, 4])
phi_angles shape: torch.Size([201, 1])
psi_angles shape: torch.Size([201, 1])
b_factors shape: torch.Size([201, 1])
Shape of labels: torch.Size([201])


 19%|█▉        | 193/1013 [00:35<02:42,  5.04it/s]

esm2_embeddings shape: torch.Size([308, 1280])
ss_onehot shape: torch.Size([308, 4])
phi_angles shape: torch.Size([308, 1])
psi_angles shape: torch.Size([308, 1])
b_factors shape: torch.Size([308, 1])
Shape of labels: torch.Size([308])
esm2_embeddings shape: torch.Size([81, 1280])
ss_onehot shape: torch.Size([81, 4])
phi_angles shape: torch.Size([81, 1])
psi_angles shape: torch.Size([81, 1])
b_factors shape: torch.Size([81, 1])
Shape of labels: torch.Size([81])
esm2_embeddings shape: torch.Size([121, 1280])
ss_onehot shape: torch.Size([121, 4])
phi_angles shape: torch.Size([121, 1])
psi_angles shape: torch.Size([121, 1])
b_factors shape: torch.Size([121, 1])
Shape of labels: torch.Size([121])


 19%|█▉        | 194/1013 [00:35<02:27,  5.57it/s]

esm2_embeddings shape: torch.Size([145, 1280])
ss_onehot shape: torch.Size([145, 4])
phi_angles shape: torch.Size([145, 1])
psi_angles shape: torch.Size([145, 1])
b_factors shape: torch.Size([145, 1])
Shape of labels: torch.Size([145])
esm2_embeddings shape: torch.Size([110, 1280])
ss_onehot shape: torch.Size([110, 4])
phi_angles shape: torch.Size([110, 1])
psi_angles shape: torch.Size([110, 1])
b_factors shape: torch.Size([110, 1])
Shape of labels: torch.Size([110])


 19%|█▉        | 197/1013 [00:36<02:04,  6.56it/s]

esm2_embeddings shape: torch.Size([165, 1280])
ss_onehot shape: torch.Size([165, 4])
phi_angles shape: torch.Size([165, 1])
psi_angles shape: torch.Size([165, 1])
b_factors shape: torch.Size([165, 1])
Shape of labels: torch.Size([165])
esm2_embeddings shape: torch.Size([190, 1280])
ss_onehot shape: torch.Size([190, 4])
phi_angles shape: torch.Size([190, 1])
psi_angles shape: torch.Size([190, 1])
b_factors shape: torch.Size([190, 1])
Shape of labels: torch.Size([190])


 20%|█▉        | 199/1013 [00:36<01:41,  8.04it/s]

esm2_embeddings shape: torch.Size([78, 1280])
ss_onehot shape: torch.Size([78, 4])
phi_angles shape: torch.Size([78, 1])
psi_angles shape: torch.Size([78, 1])
b_factors shape: torch.Size([78, 1])
Shape of labels: torch.Size([78])
esm2_embeddings shape: torch.Size([128, 1280])
ss_onehot shape: torch.Size([128, 4])
phi_angles shape: torch.Size([128, 1])
psi_angles shape: torch.Size([128, 1])
b_factors shape: torch.Size([128, 1])
Shape of labels: torch.Size([128])
esm2_embeddings shape: torch.Size([126, 1280])
ss_onehot shape: torch.Size([126, 4])
phi_angles shape: torch.Size([126, 1])
psi_angles shape: torch.Size([126, 1])
b_factors shape: torch.Size([126, 1])
Shape of labels: torch.Size([126])


 20%|█▉        | 201/1013 [00:36<01:36,  8.41it/s]

esm2_embeddings shape: torch.Size([168, 1280])
ss_onehot shape: torch.Size([168, 4])
phi_angles shape: torch.Size([168, 1])
psi_angles shape: torch.Size([168, 1])
b_factors shape: torch.Size([168, 1])
Shape of labels: torch.Size([168])


 20%|██        | 203/1013 [00:37<01:59,  6.78it/s]

esm2_embeddings shape: torch.Size([305, 1280])
ss_onehot shape: torch.Size([305, 4])
phi_angles shape: torch.Size([305, 1])
psi_angles shape: torch.Size([305, 1])
b_factors shape: torch.Size([305, 1])
Shape of labels: torch.Size([305])
esm2_embeddings shape: torch.Size([162, 1280])
ss_onehot shape: torch.Size([162, 4])
phi_angles shape: torch.Size([162, 1])
psi_angles shape: torch.Size([162, 1])
b_factors shape: torch.Size([162, 1])
Shape of labels: torch.Size([162])


 20%|██        | 204/1013 [00:37<03:19,  4.05it/s]

esm2_embeddings shape: torch.Size([168, 1280])
ss_onehot shape: torch.Size([168, 4])
phi_angles shape: torch.Size([168, 1])
psi_angles shape: torch.Size([168, 1])
b_factors shape: torch.Size([168, 1])
Shape of labels: torch.Size([168])


 20%|██        | 205/1013 [00:37<03:19,  4.06it/s]

esm2_embeddings shape: torch.Size([269, 1280])
ss_onehot shape: torch.Size([269, 4])
phi_angles shape: torch.Size([269, 1])
psi_angles shape: torch.Size([269, 1])
b_factors shape: torch.Size([269, 1])
Shape of labels: torch.Size([269])


 20%|██        | 206/1013 [00:38<03:13,  4.16it/s]

esm2_embeddings shape: torch.Size([251, 1280])
ss_onehot shape: torch.Size([251, 4])
phi_angles shape: torch.Size([251, 1])
psi_angles shape: torch.Size([251, 1])
b_factors shape: torch.Size([251, 1])
Shape of labels: torch.Size([251])


 21%|██        | 209/1013 [00:38<02:16,  5.88it/s]

esm2_embeddings shape: torch.Size([260, 1280])
ss_onehot shape: torch.Size([260, 4])
phi_angles shape: torch.Size([260, 1])
psi_angles shape: torch.Size([260, 1])
b_factors shape: torch.Size([260, 1])
Shape of labels: torch.Size([260])
esm2_embeddings shape: torch.Size([104, 1280])
ss_onehot shape: torch.Size([104, 4])
phi_angles shape: torch.Size([104, 1])
psi_angles shape: torch.Size([104, 1])
b_factors shape: torch.Size([104, 1])
Shape of labels: torch.Size([104])
esm2_embeddings shape: torch.Size([129, 1280])
ss_onehot shape: torch.Size([129, 4])
phi_angles shape: torch.Size([129, 1])
psi_angles shape: torch.Size([129, 1])
b_factors shape: torch.Size([129, 1])
Shape of labels: torch.Size([129])


 21%|██        | 211/1013 [00:38<01:45,  7.60it/s]

esm2_embeddings shape: torch.Size([65, 1280])
ss_onehot shape: torch.Size([65, 4])
phi_angles shape: torch.Size([65, 1])
psi_angles shape: torch.Size([65, 1])
b_factors shape: torch.Size([65, 1])
Shape of labels: torch.Size([65])
esm2_embeddings shape: torch.Size([129, 1280])
ss_onehot shape: torch.Size([129, 4])
phi_angles shape: torch.Size([129, 1])
psi_angles shape: torch.Size([129, 1])
b_factors shape: torch.Size([129, 1])
Shape of labels: torch.Size([129])
esm2_embeddings shape: torch.Size([37, 1280])
ss_onehot shape: torch.Size([37, 4])
phi_angles shape: torch.Size([37, 1])
psi_angles shape: torch.Size([37, 1])
b_factors shape: torch.Size([37, 1])
Shape of labels: torch.Size([37])
esm2_embeddings shape: torch.Size([78, 1280])
ss_onehot shape: torch.Size([78, 4])
phi_angles shape: torch.Size([78, 1])
psi_angles shape: torch.Size([78, 1])
b_factors shape: torch.Size([78, 1])
Shape of labels: torch.Size([78])


 21%|██▏       | 216/1013 [00:38<01:15, 10.53it/s]

esm2_embeddings shape: torch.Size([169, 1280])
ss_onehot shape: torch.Size([169, 4])
phi_angles shape: torch.Size([169, 1])
psi_angles shape: torch.Size([169, 1])
b_factors shape: torch.Size([169, 1])
Shape of labels: torch.Size([169])
esm2_embeddings shape: torch.Size([129, 1280])
ss_onehot shape: torch.Size([129, 4])
phi_angles shape: torch.Size([129, 1])
psi_angles shape: torch.Size([129, 1])
b_factors shape: torch.Size([129, 1])
Shape of labels: torch.Size([129])
esm2_embeddings shape: torch.Size([63, 1280])
ss_onehot shape: torch.Size([63, 4])
phi_angles shape: torch.Size([63, 1])
psi_angles shape: torch.Size([63, 1])
b_factors shape: torch.Size([63, 1])
Shape of labels: torch.Size([63])
esm2_embeddings shape: torch.Size([146, 1280])
ss_onehot shape: torch.Size([146, 4])
phi_angles shape: torch.Size([146, 1])
psi_angles shape: torch.Size([146, 1])
b_factors shape: torch.Size([146, 1])
Shape of labels: torch.Size([146])


 22%|██▏       | 219/1013 [00:39<02:02,  6.47it/s]

esm2_embeddings shape: torch.Size([149, 1280])
ss_onehot shape: torch.Size([149, 4])
phi_angles shape: torch.Size([149, 1])
psi_angles shape: torch.Size([149, 1])
b_factors shape: torch.Size([149, 1])
Shape of labels: torch.Size([149])
esm2_embeddings shape: torch.Size([162, 1280])
ss_onehot shape: torch.Size([162, 4])
phi_angles shape: torch.Size([162, 1])
psi_angles shape: torch.Size([162, 1])
b_factors shape: torch.Size([162, 1])
Shape of labels: torch.Size([162])


 22%|██▏       | 220/1013 [00:39<01:56,  6.80it/s]

esm2_embeddings shape: torch.Size([156, 1280])
ss_onehot shape: torch.Size([156, 4])
phi_angles shape: torch.Size([156, 1])
psi_angles shape: torch.Size([156, 1])
b_factors shape: torch.Size([156, 1])
Shape of labels: torch.Size([156])
esm2_embeddings shape: torch.Size([100, 1280])
ss_onehot shape: torch.Size([100, 4])
phi_angles shape: torch.Size([100, 1])
psi_angles shape: torch.Size([100, 1])
b_factors shape: torch.Size([100, 1])
Shape of labels: torch.Size([100])


 22%|██▏       | 224/1013 [00:40<01:39,  7.94it/s]

esm2_embeddings shape: torch.Size([241, 1280])
ss_onehot shape: torch.Size([241, 4])
phi_angles shape: torch.Size([241, 1])
psi_angles shape: torch.Size([241, 1])
b_factors shape: torch.Size([241, 1])
Shape of labels: torch.Size([241])
esm2_embeddings shape: torch.Size([125, 1280])
ss_onehot shape: torch.Size([125, 4])
phi_angles shape: torch.Size([125, 1])
psi_angles shape: torch.Size([125, 1])
b_factors shape: torch.Size([125, 1])
Shape of labels: torch.Size([125])
esm2_embeddings shape: torch.Size([136, 1280])
ss_onehot shape: torch.Size([136, 4])
phi_angles shape: torch.Size([136, 1])
psi_angles shape: torch.Size([136, 1])
b_factors shape: torch.Size([136, 1])
Shape of labels: torch.Size([136])


 22%|██▏       | 225/1013 [00:40<01:58,  6.67it/s]

esm2_embeddings shape: torch.Size([266, 1280])
ss_onehot shape: torch.Size([266, 4])
phi_angles shape: torch.Size([266, 1])
psi_angles shape: torch.Size([266, 1])
b_factors shape: torch.Size([266, 1])
Shape of labels: torch.Size([266])


 22%|██▏       | 227/1013 [00:40<02:09,  6.09it/s]

esm2_embeddings shape: torch.Size([258, 1280])
ss_onehot shape: torch.Size([258, 4])
phi_angles shape: torch.Size([258, 1])
psi_angles shape: torch.Size([258, 1])
b_factors shape: torch.Size([258, 1])
Shape of labels: torch.Size([258])
esm2_embeddings shape: torch.Size([187, 1280])
ss_onehot shape: torch.Size([187, 4])
phi_angles shape: torch.Size([187, 1])
psi_angles shape: torch.Size([187, 1])
b_factors shape: torch.Size([187, 1])
Shape of labels: torch.Size([187])


 23%|██▎       | 229/1013 [00:41<01:46,  7.37it/s]

esm2_embeddings shape: torch.Size([119, 1280])
ss_onehot shape: torch.Size([119, 4])
phi_angles shape: torch.Size([119, 1])
psi_angles shape: torch.Size([119, 1])
b_factors shape: torch.Size([119, 1])
Shape of labels: torch.Size([119])
esm2_embeddings shape: torch.Size([147, 1280])
ss_onehot shape: torch.Size([147, 4])
phi_angles shape: torch.Size([147, 1])
psi_angles shape: torch.Size([147, 1])
b_factors shape: torch.Size([147, 1])
Shape of labels: torch.Size([147])
esm2_embeddings shape: torch.Size([114, 1280])
ss_onehot shape: torch.Size([114, 4])
phi_angles shape: torch.Size([114, 1])
psi_angles shape: torch.Size([114, 1])
b_factors shape: torch.Size([114, 1])
Shape of labels: torch.Size([114])


 23%|██▎       | 232/1013 [00:41<02:35,  5.02it/s]

esm2_embeddings shape: torch.Size([201, 1280])
ss_onehot shape: torch.Size([201, 4])
phi_angles shape: torch.Size([201, 1])
psi_angles shape: torch.Size([201, 1])
b_factors shape: torch.Size([201, 1])
Shape of labels: torch.Size([201])
esm2_embeddings shape: torch.Size([201, 1280])
ss_onehot shape: torch.Size([201, 4])
phi_angles shape: torch.Size([201, 1])
psi_angles shape: torch.Size([201, 1])
b_factors shape: torch.Size([201, 1])
Shape of labels: torch.Size([201])


 23%|██▎       | 234/1013 [00:42<02:17,  5.68it/s]

esm2_embeddings shape: torch.Size([128, 1280])
ss_onehot shape: torch.Size([128, 4])
phi_angles shape: torch.Size([128, 1])
psi_angles shape: torch.Size([128, 1])
b_factors shape: torch.Size([128, 1])
Shape of labels: torch.Size([128])
esm2_embeddings shape: torch.Size([213, 1280])
ss_onehot shape: torch.Size([213, 4])
phi_angles shape: torch.Size([213, 1])
psi_angles shape: torch.Size([213, 1])
b_factors shape: torch.Size([213, 1])
Shape of labels: torch.Size([213])


 23%|██▎       | 236/1013 [00:42<02:11,  5.89it/s]

esm2_embeddings shape: torch.Size([235, 1280])
ss_onehot shape: torch.Size([235, 4])
phi_angles shape: torch.Size([235, 1])
psi_angles shape: torch.Size([235, 1])
b_factors shape: torch.Size([235, 1])
Shape of labels: torch.Size([235])
esm2_embeddings shape: torch.Size([173, 1280])
ss_onehot shape: torch.Size([173, 4])
phi_angles shape: torch.Size([173, 1])
psi_angles shape: torch.Size([173, 1])
b_factors shape: torch.Size([173, 1])
Shape of labels: torch.Size([173])


 23%|██▎       | 238/1013 [00:42<01:54,  6.75it/s]

esm2_embeddings shape: torch.Size([97, 1280])
ss_onehot shape: torch.Size([97, 4])
phi_angles shape: torch.Size([97, 1])
psi_angles shape: torch.Size([97, 1])
b_factors shape: torch.Size([97, 1])
Shape of labels: torch.Size([97])
esm2_embeddings shape: torch.Size([207, 1280])
ss_onehot shape: torch.Size([207, 4])
phi_angles shape: torch.Size([207, 1])
psi_angles shape: torch.Size([207, 1])
b_factors shape: torch.Size([207, 1])
Shape of labels: torch.Size([207])


 24%|██▎       | 240/1013 [00:42<01:33,  8.30it/s]

esm2_embeddings shape: torch.Size([63, 1280])
ss_onehot shape: torch.Size([63, 4])
phi_angles shape: torch.Size([63, 1])
psi_angles shape: torch.Size([63, 1])
b_factors shape: torch.Size([63, 1])
Shape of labels: torch.Size([63])
esm2_embeddings shape: torch.Size([101, 1280])
ss_onehot shape: torch.Size([101, 4])
phi_angles shape: torch.Size([101, 1])
psi_angles shape: torch.Size([101, 1])
b_factors shape: torch.Size([101, 1])
Shape of labels: torch.Size([101])


 24%|██▍       | 241/1013 [00:43<01:33,  8.23it/s]

esm2_embeddings shape: torch.Size([158, 1280])
ss_onehot shape: torch.Size([158, 4])
phi_angles shape: torch.Size([158, 1])
psi_angles shape: torch.Size([158, 1])
b_factors shape: torch.Size([158, 1])
Shape of labels: torch.Size([158])


 24%|██▍       | 242/1013 [00:43<01:50,  7.00it/s]

esm2_embeddings shape: torch.Size([245, 1280])
ss_onehot shape: torch.Size([245, 4])
phi_angles shape: torch.Size([245, 1])
psi_angles shape: torch.Size([245, 1])
b_factors shape: torch.Size([245, 1])
Shape of labels: torch.Size([245])


 24%|██▍       | 243/1013 [00:43<03:23,  3.79it/s]

esm2_embeddings shape: torch.Size([230, 1280])
ss_onehot shape: torch.Size([230, 4])
phi_angles shape: torch.Size([230, 1])
psi_angles shape: torch.Size([230, 1])
b_factors shape: torch.Size([230, 1])
Shape of labels: torch.Size([230])


 24%|██▍       | 245/1013 [00:44<03:10,  4.03it/s]

esm2_embeddings shape: torch.Size([315, 1280])
ss_onehot shape: torch.Size([315, 4])
phi_angles shape: torch.Size([315, 1])
psi_angles shape: torch.Size([315, 1])
b_factors shape: torch.Size([315, 1])
Shape of labels: torch.Size([315])
esm2_embeddings shape: torch.Size([196, 1280])
ss_onehot shape: torch.Size([196, 4])
phi_angles shape: torch.Size([196, 1])
psi_angles shape: torch.Size([196, 1])
b_factors shape: torch.Size([196, 1])
Shape of labels: torch.Size([196])


 24%|██▍       | 247/1013 [00:44<02:14,  5.69it/s]

esm2_embeddings shape: torch.Size([130, 1280])
ss_onehot shape: torch.Size([130, 4])
phi_angles shape: torch.Size([130, 1])
psi_angles shape: torch.Size([130, 1])
b_factors shape: torch.Size([130, 1])
Shape of labels: torch.Size([130])
esm2_embeddings shape: torch.Size([110, 1280])
ss_onehot shape: torch.Size([110, 4])
phi_angles shape: torch.Size([110, 1])
psi_angles shape: torch.Size([110, 1])
b_factors shape: torch.Size([110, 1])
Shape of labels: torch.Size([110])


 24%|██▍       | 248/1013 [00:44<02:20,  5.44it/s]

esm2_embeddings shape: torch.Size([235, 1280])
ss_onehot shape: torch.Size([235, 4])
phi_angles shape: torch.Size([235, 1])
psi_angles shape: torch.Size([235, 1])
b_factors shape: torch.Size([235, 1])
Shape of labels: torch.Size([235])
esm2_embeddings shape: torch.Size([74, 1280])
ss_onehot shape: torch.Size([74, 4])
phi_angles shape: torch.Size([74, 1])
psi_angles shape: torch.Size([74, 1])
b_factors shape: torch.Size([74, 1])
Shape of labels: torch.Size([74])


 25%|██▍       | 252/1013 [00:45<01:38,  7.75it/s]

esm2_embeddings shape: torch.Size([284, 1280])
ss_onehot shape: torch.Size([284, 4])
phi_angles shape: torch.Size([284, 1])
psi_angles shape: torch.Size([284, 1])
b_factors shape: torch.Size([284, 1])
Shape of labels: torch.Size([284])
esm2_embeddings shape: torch.Size([71, 1280])
ss_onehot shape: torch.Size([71, 4])
phi_angles shape: torch.Size([71, 1])
psi_angles shape: torch.Size([71, 1])
b_factors shape: torch.Size([71, 1])
Shape of labels: torch.Size([71])
esm2_embeddings shape: torch.Size([61, 1280])
ss_onehot shape: torch.Size([61, 4])
phi_angles shape: torch.Size([61, 1])
psi_angles shape: torch.Size([61, 1])
b_factors shape: torch.Size([61, 1])
Shape of labels: torch.Size([61])
esm2_embeddings shape: torch.Size([193, 1280])
ss_onehot shape: torch.Size([193, 4])
phi_angles shape: torch.Size([193, 1])
psi_angles shape: torch.Size([193, 1])
b_factors shape: torch.Size([193, 1])
Shape of labels: torch.Size([193])


 25%|██▌       | 255/1013 [00:46<02:29,  5.05it/s]

esm2_embeddings shape: torch.Size([245, 1280])
ss_onehot shape: torch.Size([245, 4])
phi_angles shape: torch.Size([245, 1])
psi_angles shape: torch.Size([245, 1])
b_factors shape: torch.Size([245, 1])
Shape of labels: torch.Size([245])
esm2_embeddings shape: torch.Size([138, 1280])
ss_onehot shape: torch.Size([138, 4])
phi_angles shape: torch.Size([138, 1])
psi_angles shape: torch.Size([138, 1])
b_factors shape: torch.Size([138, 1])
Shape of labels: torch.Size([138])


 25%|██▌       | 257/1013 [00:46<02:23,  5.28it/s]

esm2_embeddings shape: torch.Size([241, 1280])
ss_onehot shape: torch.Size([241, 4])
phi_angles shape: torch.Size([241, 1])
psi_angles shape: torch.Size([241, 1])
b_factors shape: torch.Size([241, 1])
Shape of labels: torch.Size([241])
esm2_embeddings shape: torch.Size([185, 1280])
ss_onehot shape: torch.Size([185, 4])
phi_angles shape: torch.Size([185, 1])
psi_angles shape: torch.Size([185, 1])
b_factors shape: torch.Size([185, 1])
Shape of labels: torch.Size([185])


 26%|██▌       | 259/1013 [00:46<02:02,  6.14it/s]

esm2_embeddings shape: torch.Size([86, 1280])
ss_onehot shape: torch.Size([86, 4])
phi_angles shape: torch.Size([86, 1])
psi_angles shape: torch.Size([86, 1])
b_factors shape: torch.Size([86, 1])
Shape of labels: torch.Size([86])
esm2_embeddings shape: torch.Size([215, 1280])
ss_onehot shape: torch.Size([215, 4])
phi_angles shape: torch.Size([215, 1])
psi_angles shape: torch.Size([215, 1])
b_factors shape: torch.Size([215, 1])
Shape of labels: torch.Size([215])


 26%|██▌       | 261/1013 [00:46<01:51,  6.75it/s]

esm2_embeddings shape: torch.Size([122, 1280])
ss_onehot shape: torch.Size([122, 4])
phi_angles shape: torch.Size([122, 1])
psi_angles shape: torch.Size([122, 1])
b_factors shape: torch.Size([122, 1])
Shape of labels: torch.Size([122])
esm2_embeddings shape: torch.Size([195, 1280])
ss_onehot shape: torch.Size([195, 4])
phi_angles shape: torch.Size([195, 1])
psi_angles shape: torch.Size([195, 1])
b_factors shape: torch.Size([195, 1])
Shape of labels: torch.Size([195])


 26%|██▌       | 263/1013 [00:47<01:37,  7.71it/s]

esm2_embeddings shape: torch.Size([93, 1280])
ss_onehot shape: torch.Size([93, 4])
phi_angles shape: torch.Size([93, 1])
psi_angles shape: torch.Size([93, 1])
b_factors shape: torch.Size([93, 1])
Shape of labels: torch.Size([93])
esm2_embeddings shape: torch.Size([165, 1280])
ss_onehot shape: torch.Size([165, 4])
phi_angles shape: torch.Size([165, 1])
psi_angles shape: torch.Size([165, 1])
b_factors shape: torch.Size([165, 1])
Shape of labels: torch.Size([165])


 26%|██▌       | 264/1013 [00:47<02:05,  5.97it/s]

esm2_embeddings shape: torch.Size([324, 1280])
ss_onehot shape: torch.Size([324, 4])
phi_angles shape: torch.Size([324, 1])
psi_angles shape: torch.Size([324, 1])
b_factors shape: torch.Size([324, 1])
Shape of labels: torch.Size([324])


 26%|██▋       | 266/1013 [00:48<02:42,  4.61it/s]

esm2_embeddings shape: torch.Size([161, 1280])
ss_onehot shape: torch.Size([161, 4])
phi_angles shape: torch.Size([161, 1])
psi_angles shape: torch.Size([161, 1])
b_factors shape: torch.Size([161, 1])
Shape of labels: torch.Size([161])
esm2_embeddings shape: torch.Size([142, 1280])
ss_onehot shape: torch.Size([142, 4])
phi_angles shape: torch.Size([142, 1])
psi_angles shape: torch.Size([142, 1])
b_factors shape: torch.Size([142, 1])
Shape of labels: torch.Size([142])
esm2_embeddings shape: torch.Size([93, 1280])
ss_onehot shape: torch.Size([93, 4])
phi_angles shape: torch.Size([93, 1])
psi_angles shape: torch.Size([93, 1])
b_factors shape: torch.Size([93, 1])
Shape of labels: torch.Size([93])


 26%|██▋       | 268/1013 [00:48<02:27,  5.04it/s]

esm2_embeddings shape: torch.Size([289, 1280])
ss_onehot shape: torch.Size([289, 4])
phi_angles shape: torch.Size([289, 1])
psi_angles shape: torch.Size([289, 1])
b_factors shape: torch.Size([289, 1])
Shape of labels: torch.Size([289])
esm2_embeddings shape: torch.Size([104, 1280])
ss_onehot shape: torch.Size([104, 4])
phi_angles shape: torch.Size([104, 1])
psi_angles shape: torch.Size([104, 1])
b_factors shape: torch.Size([104, 1])
Shape of labels: torch.Size([104])


 27%|██▋       | 272/1013 [00:48<01:46,  6.96it/s]

esm2_embeddings shape: torch.Size([267, 1280])
ss_onehot shape: torch.Size([267, 4])
phi_angles shape: torch.Size([267, 1])
psi_angles shape: torch.Size([267, 1])
b_factors shape: torch.Size([267, 1])
Shape of labels: torch.Size([267])
esm2_embeddings shape: torch.Size([98, 1280])
ss_onehot shape: torch.Size([98, 4])
phi_angles shape: torch.Size([98, 1])
psi_angles shape: torch.Size([98, 1])
b_factors shape: torch.Size([98, 1])
Shape of labels: torch.Size([98])
esm2_embeddings shape: torch.Size([80, 1280])
ss_onehot shape: torch.Size([80, 4])
phi_angles shape: torch.Size([80, 1])
psi_angles shape: torch.Size([80, 1])
b_factors shape: torch.Size([80, 1])
Shape of labels: torch.Size([80])
esm2_embeddings shape: torch.Size([97, 1280])
ss_onehot shape: torch.Size([97, 4])
phi_angles shape: torch.Size([97, 1])
psi_angles shape: torch.Size([97, 1])
b_factors shape: torch.Size([97, 1])
Shape of labels: torch.Size([97])


 27%|██▋       | 274/1013 [00:49<01:31,  8.06it/s]

esm2_embeddings shape: torch.Size([140, 1280])
ss_onehot shape: torch.Size([140, 4])
phi_angles shape: torch.Size([140, 1])
psi_angles shape: torch.Size([140, 1])
b_factors shape: torch.Size([140, 1])
Shape of labels: torch.Size([140])


 27%|██▋       | 276/1013 [00:49<02:09,  5.71it/s]

esm2_embeddings shape: torch.Size([402, 1280])
ss_onehot shape: torch.Size([402, 4])
phi_angles shape: torch.Size([402, 1])
psi_angles shape: torch.Size([402, 1])
b_factors shape: torch.Size([402, 1])
Shape of labels: torch.Size([402])
esm2_embeddings shape: torch.Size([167, 1280])
ss_onehot shape: torch.Size([167, 4])
phi_angles shape: torch.Size([167, 1])
psi_angles shape: torch.Size([167, 1])
b_factors shape: torch.Size([167, 1])
Shape of labels: torch.Size([167])


 27%|██▋       | 277/1013 [00:50<03:09,  3.88it/s]

esm2_embeddings shape: torch.Size([231, 1280])
ss_onehot shape: torch.Size([231, 4])
phi_angles shape: torch.Size([231, 1])
psi_angles shape: torch.Size([231, 1])
b_factors shape: torch.Size([231, 1])
Shape of labels: torch.Size([231])


 28%|██▊       | 280/1013 [00:50<02:18,  5.29it/s]

esm2_embeddings shape: torch.Size([267, 1280])
ss_onehot shape: torch.Size([267, 4])
phi_angles shape: torch.Size([267, 1])
psi_angles shape: torch.Size([267, 1])
b_factors shape: torch.Size([267, 1])
Shape of labels: torch.Size([267])
esm2_embeddings shape: torch.Size([129, 1280])
ss_onehot shape: torch.Size([129, 4])
phi_angles shape: torch.Size([129, 1])
psi_angles shape: torch.Size([129, 1])
b_factors shape: torch.Size([129, 1])
Shape of labels: torch.Size([129])
esm2_embeddings shape: torch.Size([94, 1280])
ss_onehot shape: torch.Size([94, 4])
phi_angles shape: torch.Size([94, 1])
psi_angles shape: torch.Size([94, 1])
b_factors shape: torch.Size([94, 1])
Shape of labels: torch.Size([94])


 28%|██▊       | 283/1013 [00:51<01:43,  7.05it/s]

esm2_embeddings shape: torch.Size([246, 1280])
ss_onehot shape: torch.Size([246, 4])
phi_angles shape: torch.Size([246, 1])
psi_angles shape: torch.Size([246, 1])
b_factors shape: torch.Size([246, 1])
Shape of labels: torch.Size([246])
esm2_embeddings shape: torch.Size([68, 1280])
ss_onehot shape: torch.Size([68, 4])
phi_angles shape: torch.Size([68, 1])
psi_angles shape: torch.Size([68, 1])
b_factors shape: torch.Size([68, 1])
Shape of labels: torch.Size([68])
esm2_embeddings shape: torch.Size([88, 1280])
ss_onehot shape: torch.Size([88, 4])
phi_angles shape: torch.Size([88, 1])
psi_angles shape: torch.Size([88, 1])
b_factors shape: torch.Size([88, 1])
Shape of labels: torch.Size([88])


 28%|██▊       | 285/1013 [00:51<01:29,  8.14it/s]

esm2_embeddings shape: torch.Size([144, 1280])
ss_onehot shape: torch.Size([144, 4])
phi_angles shape: torch.Size([144, 1])
psi_angles shape: torch.Size([144, 1])
b_factors shape: torch.Size([144, 1])
Shape of labels: torch.Size([144])
esm2_embeddings shape: torch.Size([103, 1280])
ss_onehot shape: torch.Size([103, 4])
phi_angles shape: torch.Size([103, 1])
psi_angles shape: torch.Size([103, 1])
b_factors shape: torch.Size([103, 1])
Shape of labels: torch.Size([103])
esm2_embeddings shape: torch.Size([276, 1280])
ss_onehot shape: torch.Size([276, 4])
phi_angles shape: torch.Size([276, 1])
psi_angles shape: torch.Size([276, 1])
b_factors shape: torch.Size([276, 1])
Shape of labels: torch.Size([276])


 28%|██▊       | 287/1013 [00:51<02:05,  5.81it/s]

esm2_embeddings shape: torch.Size([283, 1280])
ss_onehot shape: torch.Size([283, 4])
phi_angles shape: torch.Size([283, 1])
psi_angles shape: torch.Size([283, 1])
b_factors shape: torch.Size([283, 1])
Shape of labels: torch.Size([283])
esm2_embeddings shape: torch.Size([79, 1280])
ss_onehot shape: torch.Size([79, 4])
phi_angles shape: torch.Size([79, 1])
psi_angles shape: torch.Size([79, 1])
b_factors shape: torch.Size([79, 1])
Shape of labels: torch.Size([79])


 29%|██▊       | 289/1013 [00:52<02:34,  4.69it/s]

esm2_embeddings shape: torch.Size([160, 1280])
ss_onehot shape: torch.Size([160, 4])
phi_angles shape: torch.Size([160, 1])
psi_angles shape: torch.Size([160, 1])
b_factors shape: torch.Size([160, 1])
Shape of labels: torch.Size([160])
esm2_embeddings shape: torch.Size([34, 1280])
ss_onehot shape: torch.Size([34, 4])
phi_angles shape: torch.Size([34, 1])
psi_angles shape: torch.Size([34, 1])
b_factors shape: torch.Size([34, 1])
Shape of labels: torch.Size([34])
esm2_embeddings shape: torch.Size([81, 1280])
ss_onehot shape: torch.Size([81, 4])
phi_angles shape: torch.Size([81, 1])
psi_angles shape: torch.Size([81, 1])
b_factors shape: torch.Size([81, 1])
Shape of labels: torch.Size([81])


 29%|██▉       | 293/1013 [00:52<01:55,  6.26it/s]

esm2_embeddings shape: torch.Size([254, 1280])
ss_onehot shape: torch.Size([254, 4])
phi_angles shape: torch.Size([254, 1])
psi_angles shape: torch.Size([254, 1])
b_factors shape: torch.Size([254, 1])
Shape of labels: torch.Size([254])
esm2_embeddings shape: torch.Size([156, 1280])
ss_onehot shape: torch.Size([156, 4])
phi_angles shape: torch.Size([156, 1])
psi_angles shape: torch.Size([156, 1])
b_factors shape: torch.Size([156, 1])
Shape of labels: torch.Size([156])


 29%|██▉       | 296/1013 [00:53<01:39,  7.18it/s]

esm2_embeddings shape: torch.Size([227, 1280])
ss_onehot shape: torch.Size([227, 4])
phi_angles shape: torch.Size([227, 1])
psi_angles shape: torch.Size([227, 1])
b_factors shape: torch.Size([227, 1])
Shape of labels: torch.Size([227])
esm2_embeddings shape: torch.Size([99, 1280])
ss_onehot shape: torch.Size([99, 4])
phi_angles shape: torch.Size([99, 1])
psi_angles shape: torch.Size([99, 1])
b_factors shape: torch.Size([99, 1])
Shape of labels: torch.Size([99])
esm2_embeddings shape: torch.Size([152, 1280])
ss_onehot shape: torch.Size([152, 4])
phi_angles shape: torch.Size([152, 1])
psi_angles shape: torch.Size([152, 1])
b_factors shape: torch.Size([152, 1])
Shape of labels: torch.Size([152])


 29%|██▉       | 298/1013 [00:53<01:36,  7.41it/s]

esm2_embeddings shape: torch.Size([126, 1280])
ss_onehot shape: torch.Size([126, 4])
phi_angles shape: torch.Size([126, 1])
psi_angles shape: torch.Size([126, 1])
b_factors shape: torch.Size([126, 1])
Shape of labels: torch.Size([126])
esm2_embeddings shape: torch.Size([194, 1280])
ss_onehot shape: torch.Size([194, 4])
phi_angles shape: torch.Size([194, 1])
psi_angles shape: torch.Size([194, 1])
b_factors shape: torch.Size([194, 1])
Shape of labels: torch.Size([194])


 30%|██▉       | 300/1013 [00:53<01:35,  7.44it/s]

esm2_embeddings shape: torch.Size([153, 1280])
ss_onehot shape: torch.Size([153, 4])
phi_angles shape: torch.Size([153, 1])
psi_angles shape: torch.Size([153, 1])
b_factors shape: torch.Size([153, 1])
Shape of labels: torch.Size([153])
esm2_embeddings shape: torch.Size([185, 1280])
ss_onehot shape: torch.Size([185, 4])
phi_angles shape: torch.Size([185, 1])
psi_angles shape: torch.Size([185, 1])
b_factors shape: torch.Size([185, 1])
Shape of labels: torch.Size([185])


 30%|██▉       | 301/1013 [00:53<01:37,  7.32it/s]

esm2_embeddings shape: torch.Size([179, 1280])
ss_onehot shape: torch.Size([179, 4])
phi_angles shape: torch.Size([179, 1])
psi_angles shape: torch.Size([179, 1])
b_factors shape: torch.Size([179, 1])
Shape of labels: torch.Size([179])
esm2_embeddings shape: torch.Size([105, 1280])
ss_onehot shape: torch.Size([105, 4])
phi_angles shape: torch.Size([105, 1])
psi_angles shape: torch.Size([105, 1])
b_factors shape: torch.Size([105, 1])
Shape of labels: torch.Size([105])


 30%|███       | 304/1013 [00:54<02:25,  4.86it/s]

esm2_embeddings shape: torch.Size([235, 1280])
ss_onehot shape: torch.Size([235, 4])
phi_angles shape: torch.Size([235, 1])
psi_angles shape: torch.Size([235, 1])
b_factors shape: torch.Size([235, 1])
Shape of labels: torch.Size([235])
esm2_embeddings shape: torch.Size([176, 1280])
ss_onehot shape: torch.Size([176, 4])
phi_angles shape: torch.Size([176, 1])
psi_angles shape: torch.Size([176, 1])
b_factors shape: torch.Size([176, 1])
Shape of labels: torch.Size([176])


 30%|███       | 305/1013 [00:54<02:13,  5.32it/s]

esm2_embeddings shape: torch.Size([151, 1280])
ss_onehot shape: torch.Size([151, 4])
phi_angles shape: torch.Size([151, 1])
psi_angles shape: torch.Size([151, 1])
b_factors shape: torch.Size([151, 1])
Shape of labels: torch.Size([151])


 30%|███       | 308/1013 [00:55<01:50,  6.41it/s]

esm2_embeddings shape: torch.Size([285, 1280])
ss_onehot shape: torch.Size([285, 4])
phi_angles shape: torch.Size([285, 1])
psi_angles shape: torch.Size([285, 1])
b_factors shape: torch.Size([285, 1])
Shape of labels: torch.Size([285])
esm2_embeddings shape: torch.Size([115, 1280])
ss_onehot shape: torch.Size([115, 4])
phi_angles shape: torch.Size([115, 1])
psi_angles shape: torch.Size([115, 1])
b_factors shape: torch.Size([115, 1])
Shape of labels: torch.Size([115])
esm2_embeddings shape: torch.Size([115, 1280])
ss_onehot shape: torch.Size([115, 4])
phi_angles shape: torch.Size([115, 1])
psi_angles shape: torch.Size([115, 1])
b_factors shape: torch.Size([115, 1])
Shape of labels: torch.Size([115])


 31%|███       | 310/1013 [00:55<01:46,  6.59it/s]

esm2_embeddings shape: torch.Size([152, 1280])
ss_onehot shape: torch.Size([152, 4])
phi_angles shape: torch.Size([152, 1])
psi_angles shape: torch.Size([152, 1])
b_factors shape: torch.Size([152, 1])
Shape of labels: torch.Size([152])
esm2_embeddings shape: torch.Size([196, 1280])
ss_onehot shape: torch.Size([196, 4])
phi_angles shape: torch.Size([196, 1])
psi_angles shape: torch.Size([196, 1])
b_factors shape: torch.Size([196, 1])
Shape of labels: torch.Size([196])
esm2_embeddings shape: torch.Size([96, 1280])
ss_onehot shape: torch.Size([96, 4])
phi_angles shape: torch.Size([96, 1])
psi_angles shape: torch.Size([96, 1])
b_factors shape: torch.Size([96, 1])
Shape of labels: torch.Size([96])


 31%|███       | 313/1013 [00:55<01:40,  6.96it/s]

esm2_embeddings shape: torch.Size([211, 1280])
ss_onehot shape: torch.Size([211, 4])
phi_angles shape: torch.Size([211, 1])
psi_angles shape: torch.Size([211, 1])
b_factors shape: torch.Size([211, 1])
Shape of labels: torch.Size([211])
esm2_embeddings shape: torch.Size([177, 1280])
ss_onehot shape: torch.Size([177, 4])
phi_angles shape: torch.Size([177, 1])
psi_angles shape: torch.Size([177, 1])
b_factors shape: torch.Size([177, 1])
Shape of labels: torch.Size([177])


 31%|███       | 315/1013 [00:56<02:36,  4.46it/s]

esm2_embeddings shape: torch.Size([220, 1280])
ss_onehot shape: torch.Size([220, 4])
phi_angles shape: torch.Size([220, 1])
psi_angles shape: torch.Size([220, 1])
b_factors shape: torch.Size([220, 1])
Shape of labels: torch.Size([220])
esm2_embeddings shape: torch.Size([156, 1280])
ss_onehot shape: torch.Size([156, 4])
phi_angles shape: torch.Size([156, 1])
psi_angles shape: torch.Size([156, 1])
b_factors shape: torch.Size([156, 1])
Shape of labels: torch.Size([156])


 31%|███▏      | 317/1013 [00:56<02:13,  5.22it/s]

esm2_embeddings shape: torch.Size([150, 1280])
ss_onehot shape: torch.Size([150, 4])
phi_angles shape: torch.Size([150, 1])
psi_angles shape: torch.Size([150, 1])
b_factors shape: torch.Size([150, 1])
Shape of labels: torch.Size([150])
esm2_embeddings shape: torch.Size([216, 1280])
ss_onehot shape: torch.Size([216, 4])
phi_angles shape: torch.Size([216, 1])
psi_angles shape: torch.Size([216, 1])
b_factors shape: torch.Size([216, 1])
Shape of labels: torch.Size([216])


 32%|███▏      | 320/1013 [00:57<01:29,  7.76it/s]

esm2_embeddings shape: torch.Size([142, 1280])
ss_onehot shape: torch.Size([142, 4])
phi_angles shape: torch.Size([142, 1])
psi_angles shape: torch.Size([142, 1])
b_factors shape: torch.Size([142, 1])
Shape of labels: torch.Size([142])
esm2_embeddings shape: torch.Size([98, 1280])
ss_onehot shape: torch.Size([98, 4])
phi_angles shape: torch.Size([98, 1])
psi_angles shape: torch.Size([98, 1])
b_factors shape: torch.Size([98, 1])
Shape of labels: torch.Size([98])
esm2_embeddings shape: torch.Size([122, 1280])
ss_onehot shape: torch.Size([122, 4])
phi_angles shape: torch.Size([122, 1])
psi_angles shape: torch.Size([122, 1])
b_factors shape: torch.Size([122, 1])
Shape of labels: torch.Size([122])


 32%|███▏      | 321/1013 [00:57<01:35,  7.26it/s]

esm2_embeddings shape: torch.Size([194, 1280])
ss_onehot shape: torch.Size([194, 4])
phi_angles shape: torch.Size([194, 1])
psi_angles shape: torch.Size([194, 1])
b_factors shape: torch.Size([194, 1])
Shape of labels: torch.Size([194])
esm2_embeddings shape: torch.Size([61, 1280])
ss_onehot shape: torch.Size([61, 4])
phi_angles shape: torch.Size([61, 1])
psi_angles shape: torch.Size([61, 1])
b_factors shape: torch.Size([61, 1])
Shape of labels: torch.Size([61])


 32%|███▏      | 324/1013 [00:57<01:27,  7.83it/s]

esm2_embeddings shape: torch.Size([210, 1280])
ss_onehot shape: torch.Size([210, 4])
phi_angles shape: torch.Size([210, 1])
psi_angles shape: torch.Size([210, 1])
b_factors shape: torch.Size([210, 1])
Shape of labels: torch.Size([210])
esm2_embeddings shape: torch.Size([161, 1280])
ss_onehot shape: torch.Size([161, 4])
phi_angles shape: torch.Size([161, 1])
psi_angles shape: torch.Size([161, 1])
b_factors shape: torch.Size([161, 1])
Shape of labels: torch.Size([161])
esm2_embeddings shape: torch.Size([70, 1280])
ss_onehot shape: torch.Size([70, 4])
phi_angles shape: torch.Size([70, 1])
psi_angles shape: torch.Size([70, 1])
b_factors shape: torch.Size([70, 1])
Shape of labels: torch.Size([70])


 32%|███▏      | 328/1013 [00:58<01:10,  9.78it/s]

esm2_embeddings shape: torch.Size([131, 1280])
ss_onehot shape: torch.Size([131, 4])
phi_angles shape: torch.Size([131, 1])
psi_angles shape: torch.Size([131, 1])
b_factors shape: torch.Size([131, 1])
Shape of labels: torch.Size([131])
esm2_embeddings shape: torch.Size([105, 1280])
ss_onehot shape: torch.Size([105, 4])
phi_angles shape: torch.Size([105, 1])
psi_angles shape: torch.Size([105, 1])
b_factors shape: torch.Size([105, 1])
Shape of labels: torch.Size([105])
esm2_embeddings shape: torch.Size([155, 1280])
ss_onehot shape: torch.Size([155, 4])
phi_angles shape: torch.Size([155, 1])
psi_angles shape: torch.Size([155, 1])
b_factors shape: torch.Size([155, 1])
Shape of labels: torch.Size([155])
esm2_embeddings shape: torch.Size([71, 1280])
ss_onehot shape: torch.Size([71, 4])
phi_angles shape: torch.Size([71, 1])
psi_angles shape: torch.Size([71, 1])
b_factors shape: torch.Size([71, 1])
Shape of labels: torch.Size([71])


 33%|███▎      | 330/1013 [00:58<01:50,  6.17it/s]

esm2_embeddings shape: torch.Size([120, 1280])
ss_onehot shape: torch.Size([120, 4])
phi_angles shape: torch.Size([120, 1])
psi_angles shape: torch.Size([120, 1])
b_factors shape: torch.Size([120, 1])
Shape of labels: torch.Size([120])


 33%|███▎      | 332/1013 [00:59<01:56,  5.82it/s]

esm2_embeddings shape: torch.Size([279, 1280])
ss_onehot shape: torch.Size([279, 4])
phi_angles shape: torch.Size([279, 1])
psi_angles shape: torch.Size([279, 1])
b_factors shape: torch.Size([279, 1])
Shape of labels: torch.Size([279])
esm2_embeddings shape: torch.Size([172, 1280])
ss_onehot shape: torch.Size([172, 4])
phi_angles shape: torch.Size([172, 1])
psi_angles shape: torch.Size([172, 1])
b_factors shape: torch.Size([172, 1])
Shape of labels: torch.Size([172])


 33%|███▎      | 336/1013 [00:59<01:20,  8.42it/s]

esm2_embeddings shape: torch.Size([286, 1280])
ss_onehot shape: torch.Size([286, 4])
phi_angles shape: torch.Size([286, 1])
psi_angles shape: torch.Size([286, 1])
b_factors shape: torch.Size([286, 1])
Shape of labels: torch.Size([286])
esm2_embeddings shape: torch.Size([84, 1280])
ss_onehot shape: torch.Size([84, 4])
phi_angles shape: torch.Size([84, 1])
psi_angles shape: torch.Size([84, 1])
b_factors shape: torch.Size([84, 1])
Shape of labels: torch.Size([84])
esm2_embeddings shape: torch.Size([45, 1280])
ss_onehot shape: torch.Size([45, 4])
phi_angles shape: torch.Size([45, 1])
psi_angles shape: torch.Size([45, 1])
b_factors shape: torch.Size([45, 1])
Shape of labels: torch.Size([45])
esm2_embeddings shape: torch.Size([38, 1280])
ss_onehot shape: torch.Size([38, 4])
phi_angles shape: torch.Size([38, 1])
psi_angles shape: torch.Size([38, 1])
b_factors shape: torch.Size([38, 1])
Shape of labels: torch.Size([38])
esm2_embeddings shape: torch.Size([37, 1280])
ss_onehot shape: torch.Size(

 34%|███▎      | 341/1013 [00:59<00:54, 12.28it/s]

esm2_embeddings shape: torch.Size([36, 1280])
ss_onehot shape: torch.Size([36, 4])
phi_angles shape: torch.Size([36, 1])
psi_angles shape: torch.Size([36, 1])
b_factors shape: torch.Size([36, 1])
Shape of labels: torch.Size([36])
esm2_embeddings shape: torch.Size([126, 1280])
ss_onehot shape: torch.Size([126, 4])
phi_angles shape: torch.Size([126, 1])
psi_angles shape: torch.Size([126, 1])
b_factors shape: torch.Size([126, 1])
Shape of labels: torch.Size([126])
esm2_embeddings shape: torch.Size([85, 1280])
ss_onehot shape: torch.Size([85, 4])
phi_angles shape: torch.Size([85, 1])
psi_angles shape: torch.Size([85, 1])
b_factors shape: torch.Size([85, 1])
Shape of labels: torch.Size([85])


 34%|███▍      | 343/1013 [00:59<01:00, 11.09it/s]

esm2_embeddings shape: torch.Size([116, 1280])
ss_onehot shape: torch.Size([116, 4])
phi_angles shape: torch.Size([116, 1])
psi_angles shape: torch.Size([116, 1])
b_factors shape: torch.Size([116, 1])
Shape of labels: torch.Size([116])
esm2_embeddings shape: torch.Size([165, 1280])
ss_onehot shape: torch.Size([165, 4])
phi_angles shape: torch.Size([165, 1])
psi_angles shape: torch.Size([165, 1])
b_factors shape: torch.Size([165, 1])
Shape of labels: torch.Size([165])


 34%|███▍      | 345/1013 [01:00<01:24,  7.94it/s]

esm2_embeddings shape: torch.Size([306, 1280])
ss_onehot shape: torch.Size([306, 4])
phi_angles shape: torch.Size([306, 1])
psi_angles shape: torch.Size([306, 1])
b_factors shape: torch.Size([306, 1])
Shape of labels: torch.Size([306])
esm2_embeddings shape: torch.Size([161, 1280])
ss_onehot shape: torch.Size([161, 4])
phi_angles shape: torch.Size([161, 1])
psi_angles shape: torch.Size([161, 1])
b_factors shape: torch.Size([161, 1])
Shape of labels: torch.Size([161])


 34%|███▍      | 347/1013 [01:00<01:51,  5.96it/s]

esm2_embeddings shape: torch.Size([66, 1280])
ss_onehot shape: torch.Size([66, 4])
phi_angles shape: torch.Size([66, 1])
psi_angles shape: torch.Size([66, 1])
b_factors shape: torch.Size([66, 1])
Shape of labels: torch.Size([66])
esm2_embeddings shape: torch.Size([80, 1280])
ss_onehot shape: torch.Size([80, 4])
phi_angles shape: torch.Size([80, 1])
psi_angles shape: torch.Size([80, 1])
b_factors shape: torch.Size([80, 1])
Shape of labels: torch.Size([80])
esm2_embeddings shape: torch.Size([82, 1280])
ss_onehot shape: torch.Size([82, 4])
phi_angles shape: torch.Size([82, 1])
psi_angles shape: torch.Size([82, 1])
b_factors shape: torch.Size([82, 1])
Shape of labels: torch.Size([82])


 34%|███▍      | 349/1013 [01:01<01:43,  6.42it/s]

esm2_embeddings shape: torch.Size([226, 1280])
ss_onehot shape: torch.Size([226, 4])
phi_angles shape: torch.Size([226, 1])
psi_angles shape: torch.Size([226, 1])
b_factors shape: torch.Size([226, 1])
Shape of labels: torch.Size([226])


 35%|███▍      | 351/1013 [01:01<02:08,  5.14it/s]

esm2_embeddings shape: torch.Size([374, 1280])
ss_onehot shape: torch.Size([374, 4])
phi_angles shape: torch.Size([374, 1])
psi_angles shape: torch.Size([374, 1])
b_factors shape: torch.Size([374, 1])
Shape of labels: torch.Size([374])
esm2_embeddings shape: torch.Size([201, 1280])
ss_onehot shape: torch.Size([201, 4])
phi_angles shape: torch.Size([201, 1])
psi_angles shape: torch.Size([201, 1])
b_factors shape: torch.Size([201, 1])
Shape of labels: torch.Size([201])


 35%|███▍      | 354/1013 [01:02<01:44,  6.33it/s]

esm2_embeddings shape: torch.Size([273, 1280])
ss_onehot shape: torch.Size([273, 4])
phi_angles shape: torch.Size([273, 1])
psi_angles shape: torch.Size([273, 1])
b_factors shape: torch.Size([273, 1])
Shape of labels: torch.Size([273])
esm2_embeddings shape: torch.Size([67, 1280])
ss_onehot shape: torch.Size([67, 4])
phi_angles shape: torch.Size([67, 1])
psi_angles shape: torch.Size([67, 1])
b_factors shape: torch.Size([67, 1])
Shape of labels: torch.Size([67])
esm2_embeddings shape: torch.Size([149, 1280])
ss_onehot shape: torch.Size([149, 4])
phi_angles shape: torch.Size([149, 1])
psi_angles shape: torch.Size([149, 1])
b_factors shape: torch.Size([149, 1])
Shape of labels: torch.Size([149])


 35%|███▌      | 356/1013 [01:02<01:36,  6.83it/s]

esm2_embeddings shape: torch.Size([154, 1280])
ss_onehot shape: torch.Size([154, 4])
phi_angles shape: torch.Size([154, 1])
psi_angles shape: torch.Size([154, 1])
b_factors shape: torch.Size([154, 1])
Shape of labels: torch.Size([154])
esm2_embeddings shape: torch.Size([178, 1280])
ss_onehot shape: torch.Size([178, 4])
phi_angles shape: torch.Size([178, 1])
psi_angles shape: torch.Size([178, 1])
b_factors shape: torch.Size([178, 1])
Shape of labels: torch.Size([178])


 35%|███▌      | 357/1013 [01:03<03:11,  3.43it/s]

esm2_embeddings shape: torch.Size([295, 1280])
ss_onehot shape: torch.Size([295, 4])
phi_angles shape: torch.Size([295, 1])
psi_angles shape: torch.Size([295, 1])
b_factors shape: torch.Size([295, 1])
Shape of labels: torch.Size([295])


 35%|███▌      | 358/1013 [01:03<03:00,  3.63it/s]

esm2_embeddings shape: torch.Size([260, 1280])
ss_onehot shape: torch.Size([260, 4])
phi_angles shape: torch.Size([260, 1])
psi_angles shape: torch.Size([260, 1])
b_factors shape: torch.Size([260, 1])
Shape of labels: torch.Size([260])
esm2_embeddings shape: torch.Size([118, 1280])
ss_onehot shape: torch.Size([118, 4])
phi_angles shape: torch.Size([118, 1])
psi_angles shape: torch.Size([118, 1])
b_factors shape: torch.Size([118, 1])
Shape of labels: torch.Size([118])


 36%|███▌      | 360/1013 [01:03<02:22,  4.59it/s]

esm2_embeddings shape: torch.Size([232, 1280])
ss_onehot shape: torch.Size([232, 4])
phi_angles shape: torch.Size([232, 1])
psi_angles shape: torch.Size([232, 1])
b_factors shape: torch.Size([232, 1])
Shape of labels: torch.Size([232])


 36%|███▌      | 363/1013 [01:04<01:50,  5.86it/s]

esm2_embeddings shape: torch.Size([270, 1280])
ss_onehot shape: torch.Size([270, 4])
phi_angles shape: torch.Size([270, 1])
psi_angles shape: torch.Size([270, 1])
b_factors shape: torch.Size([270, 1])
Shape of labels: torch.Size([270])
esm2_embeddings shape: torch.Size([125, 1280])
ss_onehot shape: torch.Size([125, 4])
phi_angles shape: torch.Size([125, 1])
psi_angles shape: torch.Size([125, 1])
b_factors shape: torch.Size([125, 1])
Shape of labels: torch.Size([125])
esm2_embeddings shape: torch.Size([125, 1280])
ss_onehot shape: torch.Size([125, 4])
phi_angles shape: torch.Size([125, 1])
psi_angles shape: torch.Size([125, 1])
b_factors shape: torch.Size([125, 1])
Shape of labels: torch.Size([125])


 36%|███▌      | 364/1013 [01:04<02:27,  4.41it/s]

esm2_embeddings shape: torch.Size([385, 1280])
ss_onehot shape: torch.Size([385, 4])
phi_angles shape: torch.Size([385, 1])
psi_angles shape: torch.Size([385, 1])
b_factors shape: torch.Size([385, 1])
Shape of labels: torch.Size([385])


 36%|███▌      | 365/1013 [01:04<02:27,  4.39it/s]

esm2_embeddings shape: torch.Size([256, 1280])
ss_onehot shape: torch.Size([256, 4])
phi_angles shape: torch.Size([256, 1])
psi_angles shape: torch.Size([256, 1])
b_factors shape: torch.Size([256, 1])
Shape of labels: torch.Size([256])
esm2_embeddings shape: torch.Size([85, 1280])
ss_onehot shape: torch.Size([85, 4])
phi_angles shape: torch.Size([85, 1])
psi_angles shape: torch.Size([85, 1])
b_factors shape: torch.Size([85, 1])
Shape of labels: torch.Size([85])


 36%|███▋      | 368/1013 [01:05<02:32,  4.22it/s]

esm2_embeddings shape: torch.Size([193, 1280])
ss_onehot shape: torch.Size([193, 4])
phi_angles shape: torch.Size([193, 1])
psi_angles shape: torch.Size([193, 1])
b_factors shape: torch.Size([193, 1])
Shape of labels: torch.Size([193])
esm2_embeddings shape: torch.Size([180, 1280])
ss_onehot shape: torch.Size([180, 4])
phi_angles shape: torch.Size([180, 1])
psi_angles shape: torch.Size([180, 1])
b_factors shape: torch.Size([180, 1])
Shape of labels: torch.Size([180])


 37%|███▋      | 370/1013 [01:05<01:55,  5.55it/s]

esm2_embeddings shape: torch.Size([92, 1280])
ss_onehot shape: torch.Size([92, 4])
phi_angles shape: torch.Size([92, 1])
psi_angles shape: torch.Size([92, 1])
b_factors shape: torch.Size([92, 1])
Shape of labels: torch.Size([92])
esm2_embeddings shape: torch.Size([158, 1280])
ss_onehot shape: torch.Size([158, 4])
phi_angles shape: torch.Size([158, 1])
psi_angles shape: torch.Size([158, 1])
b_factors shape: torch.Size([158, 1])
Shape of labels: torch.Size([158])


 37%|███▋      | 374/1013 [01:06<01:25,  7.47it/s]

esm2_embeddings shape: torch.Size([290, 1280])
ss_onehot shape: torch.Size([290, 4])
phi_angles shape: torch.Size([290, 1])
psi_angles shape: torch.Size([290, 1])
b_factors shape: torch.Size([290, 1])
Shape of labels: torch.Size([290])
esm2_embeddings shape: torch.Size([69, 1280])
ss_onehot shape: torch.Size([69, 4])
phi_angles shape: torch.Size([69, 1])
psi_angles shape: torch.Size([69, 1])
b_factors shape: torch.Size([69, 1])
Shape of labels: torch.Size([69])
esm2_embeddings shape: torch.Size([65, 1280])
ss_onehot shape: torch.Size([65, 4])
phi_angles shape: torch.Size([65, 1])
psi_angles shape: torch.Size([65, 1])
b_factors shape: torch.Size([65, 1])
Shape of labels: torch.Size([65])
esm2_embeddings shape: torch.Size([112, 1280])
ss_onehot shape: torch.Size([112, 4])
phi_angles shape: torch.Size([112, 1])
psi_angles shape: torch.Size([112, 1])
b_factors shape: torch.Size([112, 1])
Shape of labels: torch.Size([112])


 37%|███▋      | 376/1013 [01:06<01:16,  8.32it/s]

esm2_embeddings shape: torch.Size([109, 1280])
ss_onehot shape: torch.Size([109, 4])
phi_angles shape: torch.Size([109, 1])
psi_angles shape: torch.Size([109, 1])
b_factors shape: torch.Size([109, 1])
Shape of labels: torch.Size([109])
esm2_embeddings shape: torch.Size([122, 1280])
ss_onehot shape: torch.Size([122, 4])
phi_angles shape: torch.Size([122, 1])
psi_angles shape: torch.Size([122, 1])
b_factors shape: torch.Size([122, 1])
Shape of labels: torch.Size([122])
esm2_embeddings shape: torch.Size([67, 1280])
ss_onehot shape: torch.Size([67, 4])
phi_angles shape: torch.Size([67, 1])
psi_angles shape: torch.Size([67, 1])
b_factors shape: torch.Size([67, 1])
Shape of labels: torch.Size([67])


 37%|███▋      | 378/1013 [01:06<01:03,  9.95it/s]

esm2_embeddings shape: torch.Size([100, 1280])
ss_onehot shape: torch.Size([100, 4])
phi_angles shape: torch.Size([100, 1])
psi_angles shape: torch.Size([100, 1])
b_factors shape: torch.Size([100, 1])
Shape of labels: torch.Size([100])


 38%|███▊      | 380/1013 [01:06<01:37,  6.46it/s]

esm2_embeddings shape: torch.Size([354, 1280])
ss_onehot shape: torch.Size([354, 4])
phi_angles shape: torch.Size([354, 1])
psi_angles shape: torch.Size([354, 1])
b_factors shape: torch.Size([354, 1])
Shape of labels: torch.Size([354])
esm2_embeddings shape: torch.Size([200, 1280])
ss_onehot shape: torch.Size([200, 4])
phi_angles shape: torch.Size([200, 1])
psi_angles shape: torch.Size([200, 1])
b_factors shape: torch.Size([200, 1])
Shape of labels: torch.Size([200])


 38%|███▊      | 383/1013 [01:07<01:52,  5.58it/s]

esm2_embeddings shape: torch.Size([200, 1280])
ss_onehot shape: torch.Size([200, 4])
phi_angles shape: torch.Size([200, 1])
psi_angles shape: torch.Size([200, 1])
b_factors shape: torch.Size([200, 1])
Shape of labels: torch.Size([200])
esm2_embeddings shape: torch.Size([70, 1280])
ss_onehot shape: torch.Size([70, 4])
phi_angles shape: torch.Size([70, 1])
psi_angles shape: torch.Size([70, 1])
b_factors shape: torch.Size([70, 1])
Shape of labels: torch.Size([70])
esm2_embeddings shape: torch.Size([132, 1280])
ss_onehot shape: torch.Size([132, 4])
phi_angles shape: torch.Size([132, 1])
psi_angles shape: torch.Size([132, 1])
b_factors shape: torch.Size([132, 1])
Shape of labels: torch.Size([132])


 38%|███▊      | 384/1013 [01:07<01:43,  6.09it/s]

esm2_embeddings shape: torch.Size([132, 1280])
ss_onehot shape: torch.Size([132, 4])
phi_angles shape: torch.Size([132, 1])
psi_angles shape: torch.Size([132, 1])
b_factors shape: torch.Size([132, 1])
Shape of labels: torch.Size([132])
esm2_embeddings shape: torch.Size([50, 1280])
ss_onehot shape: torch.Size([50, 4])
phi_angles shape: torch.Size([50, 1])
psi_angles shape: torch.Size([50, 1])
b_factors shape: torch.Size([50, 1])
Shape of labels: torch.Size([50])
esm2_embeddings shape: torch.Size([69, 1280])
ss_onehot shape: torch.Size([69, 4])
phi_angles shape: torch.Size([69, 1])
psi_angles shape: torch.Size([69, 1])
b_factors shape: torch.Size([69, 1])
Shape of labels: torch.Size([69])


 38%|███▊      | 389/1013 [01:08<01:09,  8.97it/s]

esm2_embeddings shape: torch.Size([159, 1280])
ss_onehot shape: torch.Size([159, 4])
phi_angles shape: torch.Size([159, 1])
psi_angles shape: torch.Size([159, 1])
b_factors shape: torch.Size([159, 1])
Shape of labels: torch.Size([159])
esm2_embeddings shape: torch.Size([57, 1280])
ss_onehot shape: torch.Size([57, 4])
phi_angles shape: torch.Size([57, 1])
psi_angles shape: torch.Size([57, 1])
b_factors shape: torch.Size([57, 1])
Shape of labels: torch.Size([57])
esm2_embeddings shape: torch.Size([157, 1280])
ss_onehot shape: torch.Size([157, 4])
phi_angles shape: torch.Size([157, 1])
psi_angles shape: torch.Size([157, 1])
b_factors shape: torch.Size([157, 1])
Shape of labels: torch.Size([157])


 39%|███▊      | 391/1013 [01:08<01:06,  9.42it/s]

esm2_embeddings shape: torch.Size([162, 1280])
ss_onehot shape: torch.Size([162, 4])
phi_angles shape: torch.Size([162, 1])
psi_angles shape: torch.Size([162, 1])
b_factors shape: torch.Size([162, 1])
Shape of labels: torch.Size([162])
esm2_embeddings shape: torch.Size([88, 1280])
ss_onehot shape: torch.Size([88, 4])
phi_angles shape: torch.Size([88, 1])
psi_angles shape: torch.Size([88, 1])
b_factors shape: torch.Size([88, 1])
Shape of labels: torch.Size([88])
esm2_embeddings shape: torch.Size([256, 1280])
ss_onehot shape: torch.Size([256, 4])
phi_angles shape: torch.Size([256, 1])
psi_angles shape: torch.Size([256, 1])
b_factors shape: torch.Size([256, 1])
Shape of labels: torch.Size([256])


 39%|███▉      | 395/1013 [01:09<01:17,  7.94it/s]

esm2_embeddings shape: torch.Size([295, 1280])
ss_onehot shape: torch.Size([295, 4])
phi_angles shape: torch.Size([295, 1])
psi_angles shape: torch.Size([295, 1])
b_factors shape: torch.Size([295, 1])
Shape of labels: torch.Size([295])
esm2_embeddings shape: torch.Size([126, 1280])
ss_onehot shape: torch.Size([126, 4])
phi_angles shape: torch.Size([126, 1])
psi_angles shape: torch.Size([126, 1])
b_factors shape: torch.Size([126, 1])
Shape of labels: torch.Size([126])
esm2_embeddings shape: torch.Size([54, 1280])
ss_onehot shape: torch.Size([54, 4])
phi_angles shape: torch.Size([54, 1])
psi_angles shape: torch.Size([54, 1])
b_factors shape: torch.Size([54, 1])
Shape of labels: torch.Size([54])
esm2_embeddings shape: torch.Size([105, 1280])
ss_onehot shape: torch.Size([105, 4])
phi_angles shape: torch.Size([105, 1])
psi_angles shape: torch.Size([105, 1])
b_factors shape: torch.Size([105, 1])
Shape of labels: torch.Size([105])


 39%|███▉      | 397/1013 [01:09<01:58,  5.21it/s]

esm2_embeddings shape: torch.Size([226, 1280])
ss_onehot shape: torch.Size([226, 4])
phi_angles shape: torch.Size([226, 1])
psi_angles shape: torch.Size([226, 1])
b_factors shape: torch.Size([226, 1])
Shape of labels: torch.Size([226])


 39%|███▉      | 398/1013 [01:09<01:59,  5.13it/s]

esm2_embeddings shape: torch.Size([236, 1280])
ss_onehot shape: torch.Size([236, 4])
phi_angles shape: torch.Size([236, 1])
psi_angles shape: torch.Size([236, 1])
b_factors shape: torch.Size([236, 1])
Shape of labels: torch.Size([236])


 40%|███▉      | 401/1013 [01:10<01:50,  5.55it/s]

esm2_embeddings shape: torch.Size([385, 1280])
ss_onehot shape: torch.Size([385, 4])
phi_angles shape: torch.Size([385, 1])
psi_angles shape: torch.Size([385, 1])
b_factors shape: torch.Size([385, 1])
Shape of labels: torch.Size([385])
esm2_embeddings shape: torch.Size([135, 1280])
ss_onehot shape: torch.Size([135, 4])
phi_angles shape: torch.Size([135, 1])
psi_angles shape: torch.Size([135, 1])
b_factors shape: torch.Size([135, 1])
Shape of labels: torch.Size([135])
esm2_embeddings shape: torch.Size([53, 1280])
ss_onehot shape: torch.Size([53, 4])
phi_angles shape: torch.Size([53, 1])
psi_angles shape: torch.Size([53, 1])
b_factors shape: torch.Size([53, 1])
Shape of labels: torch.Size([53])
esm2_embeddings shape: torch.Size([41, 1280])
ss_onehot shape: torch.Size([41, 4])
phi_angles shape: torch.Size([41, 1])
psi_angles shape: torch.Size([41, 1])
b_factors shape: torch.Size([41, 1])
Shape of labels: torch.Size([41])


 40%|███▉      | 403/1013 [01:10<01:29,  6.82it/s]

esm2_embeddings shape: torch.Size([177, 1280])
ss_onehot shape: torch.Size([177, 4])
phi_angles shape: torch.Size([177, 1])
psi_angles shape: torch.Size([177, 1])
b_factors shape: torch.Size([177, 1])
Shape of labels: torch.Size([177])


 40%|███▉      | 404/1013 [01:10<01:38,  6.18it/s]

esm2_embeddings shape: torch.Size([251, 1280])
ss_onehot shape: torch.Size([251, 4])
phi_angles shape: torch.Size([251, 1])
psi_angles shape: torch.Size([251, 1])
b_factors shape: torch.Size([251, 1])
Shape of labels: torch.Size([251])


 40%|███▉      | 405/1013 [01:11<01:48,  5.60it/s]

esm2_embeddings shape: torch.Size([260, 1280])
ss_onehot shape: torch.Size([260, 4])
phi_angles shape: torch.Size([260, 1])
psi_angles shape: torch.Size([260, 1])
b_factors shape: torch.Size([260, 1])
Shape of labels: torch.Size([260])


 40%|████      | 408/1013 [01:12<02:27,  4.10it/s]

esm2_embeddings shape: torch.Size([417, 1280])
ss_onehot shape: torch.Size([417, 4])
phi_angles shape: torch.Size([417, 1])
psi_angles shape: torch.Size([417, 1])
b_factors shape: torch.Size([417, 1])
Shape of labels: torch.Size([417])
esm2_embeddings shape: torch.Size([106, 1280])
ss_onehot shape: torch.Size([106, 4])
phi_angles shape: torch.Size([106, 1])
psi_angles shape: torch.Size([106, 1])
b_factors shape: torch.Size([106, 1])
Shape of labels: torch.Size([106])
esm2_embeddings shape: torch.Size([122, 1280])
ss_onehot shape: torch.Size([122, 4])
phi_angles shape: torch.Size([122, 1])
psi_angles shape: torch.Size([122, 1])
b_factors shape: torch.Size([122, 1])
Shape of labels: torch.Size([122])


 40%|████      | 410/1013 [01:12<02:09,  4.66it/s]

esm2_embeddings shape: torch.Size([263, 1280])
ss_onehot shape: torch.Size([263, 4])
phi_angles shape: torch.Size([263, 1])
psi_angles shape: torch.Size([263, 1])
b_factors shape: torch.Size([263, 1])
Shape of labels: torch.Size([263])
esm2_embeddings shape: torch.Size([165, 1280])
ss_onehot shape: torch.Size([165, 4])
phi_angles shape: torch.Size([165, 1])
psi_angles shape: torch.Size([165, 1])
b_factors shape: torch.Size([165, 1])
Shape of labels: torch.Size([165])


 41%|████      | 411/1013 [01:12<01:53,  5.31it/s]

esm2_embeddings shape: torch.Size([137, 1280])
ss_onehot shape: torch.Size([137, 4])
phi_angles shape: torch.Size([137, 1])
psi_angles shape: torch.Size([137, 1])
b_factors shape: torch.Size([137, 1])
Shape of labels: torch.Size([137])


 41%|████      | 412/1013 [01:13<02:48,  3.58it/s]

esm2_embeddings shape: torch.Size([438, 1280])
ss_onehot shape: torch.Size([438, 4])
phi_angles shape: torch.Size([438, 1])
psi_angles shape: torch.Size([438, 1])
b_factors shape: torch.Size([438, 1])
Shape of labels: torch.Size([438])


 41%|████      | 413/1013 [01:13<03:34,  2.79it/s]

esm2_embeddings shape: torch.Size([189, 1280])
ss_onehot shape: torch.Size([189, 4])
phi_angles shape: torch.Size([189, 1])
psi_angles shape: torch.Size([189, 1])
b_factors shape: torch.Size([189, 1])
Shape of labels: torch.Size([189])


 41%|████      | 414/1013 [01:14<03:37,  2.76it/s]

esm2_embeddings shape: torch.Size([350, 1280])
ss_onehot shape: torch.Size([350, 4])
phi_angles shape: torch.Size([350, 1])
psi_angles shape: torch.Size([350, 1])
b_factors shape: torch.Size([350, 1])
Shape of labels: torch.Size([350])


 41%|████      | 416/1013 [01:14<02:55,  3.39it/s]

esm2_embeddings shape: torch.Size([321, 1280])
ss_onehot shape: torch.Size([321, 4])
phi_angles shape: torch.Size([321, 1])
psi_angles shape: torch.Size([321, 1])
b_factors shape: torch.Size([321, 1])
Shape of labels: torch.Size([321])
esm2_embeddings shape: torch.Size([194, 1280])
ss_onehot shape: torch.Size([194, 4])
phi_angles shape: torch.Size([194, 1])
psi_angles shape: torch.Size([194, 1])
b_factors shape: torch.Size([194, 1])
Shape of labels: torch.Size([194])


 41%|████▏     | 419/1013 [01:14<01:34,  6.30it/s]

esm2_embeddings shape: torch.Size([178, 1280])
ss_onehot shape: torch.Size([178, 4])
phi_angles shape: torch.Size([178, 1])
psi_angles shape: torch.Size([178, 1])
b_factors shape: torch.Size([178, 1])
Shape of labels: torch.Size([178])
esm2_embeddings shape: torch.Size([57, 1280])
ss_onehot shape: torch.Size([57, 4])
phi_angles shape: torch.Size([57, 1])
psi_angles shape: torch.Size([57, 1])
b_factors shape: torch.Size([57, 1])
Shape of labels: torch.Size([57])
esm2_embeddings shape: torch.Size([92, 1280])
ss_onehot shape: torch.Size([92, 4])
phi_angles shape: torch.Size([92, 1])
psi_angles shape: torch.Size([92, 1])
b_factors shape: torch.Size([92, 1])
Shape of labels: torch.Size([92])


 42%|████▏     | 422/1013 [01:15<01:14,  7.96it/s]

esm2_embeddings shape: torch.Size([218, 1280])
ss_onehot shape: torch.Size([218, 4])
phi_angles shape: torch.Size([218, 1])
psi_angles shape: torch.Size([218, 1])
b_factors shape: torch.Size([218, 1])
Shape of labels: torch.Size([218])
esm2_embeddings shape: torch.Size([62, 1280])
ss_onehot shape: torch.Size([62, 4])
phi_angles shape: torch.Size([62, 1])
psi_angles shape: torch.Size([62, 1])
b_factors shape: torch.Size([62, 1])
Shape of labels: torch.Size([62])
esm2_embeddings shape: torch.Size([127, 1280])
ss_onehot shape: torch.Size([127, 4])
phi_angles shape: torch.Size([127, 1])
psi_angles shape: torch.Size([127, 1])
b_factors shape: torch.Size([127, 1])
Shape of labels: torch.Size([127])
esm2_embeddings shape: torch.Size([121, 1280])
ss_onehot shape: torch.Size([121, 4])
phi_angles shape: torch.Size([121, 1])
psi_angles shape: torch.Size([121, 1])
b_factors shape: torch.Size([121, 1])
Shape of labels: torch.Size([121])


 42%|████▏     | 425/1013 [01:15<01:49,  5.38it/s]

esm2_embeddings shape: torch.Size([133, 1280])
ss_onehot shape: torch.Size([133, 4])
phi_angles shape: torch.Size([133, 1])
psi_angles shape: torch.Size([133, 1])
b_factors shape: torch.Size([133, 1])
Shape of labels: torch.Size([133])
esm2_embeddings shape: torch.Size([201, 1280])
ss_onehot shape: torch.Size([201, 4])
phi_angles shape: torch.Size([201, 1])
psi_angles shape: torch.Size([201, 1])
b_factors shape: torch.Size([201, 1])
Shape of labels: torch.Size([201])


 42%|████▏     | 427/1013 [01:16<01:40,  5.81it/s]

esm2_embeddings shape: torch.Size([186, 1280])
ss_onehot shape: torch.Size([186, 4])
phi_angles shape: torch.Size([186, 1])
psi_angles shape: torch.Size([186, 1])
b_factors shape: torch.Size([186, 1])
Shape of labels: torch.Size([186])
esm2_embeddings shape: torch.Size([186, 1280])
ss_onehot shape: torch.Size([186, 4])
phi_angles shape: torch.Size([186, 1])
psi_angles shape: torch.Size([186, 1])
b_factors shape: torch.Size([186, 1])
Shape of labels: torch.Size([186])


 42%|████▏     | 429/1013 [01:16<01:24,  6.92it/s]

esm2_embeddings shape: torch.Size([136, 1280])
ss_onehot shape: torch.Size([136, 4])
phi_angles shape: torch.Size([136, 1])
psi_angles shape: torch.Size([136, 1])
b_factors shape: torch.Size([136, 1])
Shape of labels: torch.Size([136])
esm2_embeddings shape: torch.Size([148, 1280])
ss_onehot shape: torch.Size([148, 4])
phi_angles shape: torch.Size([148, 1])
psi_angles shape: torch.Size([148, 1])
b_factors shape: torch.Size([148, 1])
Shape of labels: torch.Size([148])


 42%|████▏     | 430/1013 [01:16<01:26,  6.74it/s]

esm2_embeddings shape: torch.Size([195, 1280])
ss_onehot shape: torch.Size([195, 4])
phi_angles shape: torch.Size([195, 1])
psi_angles shape: torch.Size([195, 1])
b_factors shape: torch.Size([195, 1])
Shape of labels: torch.Size([195])
esm2_embeddings shape: torch.Size([73, 1280])
ss_onehot shape: torch.Size([73, 4])
phi_angles shape: torch.Size([73, 1])
psi_angles shape: torch.Size([73, 1])
b_factors shape: torch.Size([73, 1])
Shape of labels: torch.Size([73])


 43%|████▎     | 434/1013 [01:17<01:10,  8.23it/s]

esm2_embeddings shape: torch.Size([224, 1280])
ss_onehot shape: torch.Size([224, 4])
phi_angles shape: torch.Size([224, 1])
psi_angles shape: torch.Size([224, 1])
b_factors shape: torch.Size([224, 1])
Shape of labels: torch.Size([224])
esm2_embeddings shape: torch.Size([98, 1280])
ss_onehot shape: torch.Size([98, 4])
phi_angles shape: torch.Size([98, 1])
psi_angles shape: torch.Size([98, 1])
b_factors shape: torch.Size([98, 1])
Shape of labels: torch.Size([98])
esm2_embeddings shape: torch.Size([169, 1280])
ss_onehot shape: torch.Size([169, 4])
phi_angles shape: torch.Size([169, 1])
psi_angles shape: torch.Size([169, 1])
b_factors shape: torch.Size([169, 1])
Shape of labels: torch.Size([169])
esm2_embeddings shape: torch.Size([97, 1280])
ss_onehot shape: torch.Size([97, 4])
phi_angles shape: torch.Size([97, 1])
psi_angles shape: torch.Size([97, 1])
b_factors shape: torch.Size([97, 1])
Shape of labels: torch.Size([97])


 43%|████▎     | 436/1013 [01:18<02:37,  3.67it/s]

esm2_embeddings shape: torch.Size([465, 1280])
ss_onehot shape: torch.Size([465, 4])
phi_angles shape: torch.Size([465, 1])
psi_angles shape: torch.Size([465, 1])
b_factors shape: torch.Size([465, 1])
Shape of labels: torch.Size([465])
esm2_embeddings shape: torch.Size([78, 1280])
ss_onehot shape: torch.Size([78, 4])
phi_angles shape: torch.Size([78, 1])
psi_angles shape: torch.Size([78, 1])
b_factors shape: torch.Size([78, 1])
Shape of labels: torch.Size([78])


 43%|████▎     | 440/1013 [01:18<01:41,  5.63it/s]

esm2_embeddings shape: torch.Size([235, 1280])
ss_onehot shape: torch.Size([235, 4])
phi_angles shape: torch.Size([235, 1])
psi_angles shape: torch.Size([235, 1])
b_factors shape: torch.Size([235, 1])
Shape of labels: torch.Size([235])
esm2_embeddings shape: torch.Size([70, 1280])
ss_onehot shape: torch.Size([70, 4])
phi_angles shape: torch.Size([70, 1])
psi_angles shape: torch.Size([70, 1])
b_factors shape: torch.Size([70, 1])
Shape of labels: torch.Size([70])
esm2_embeddings shape: torch.Size([151, 1280])
ss_onehot shape: torch.Size([151, 4])
phi_angles shape: torch.Size([151, 1])
psi_angles shape: torch.Size([151, 1])
b_factors shape: torch.Size([151, 1])
Shape of labels: torch.Size([151])


 44%|████▎     | 442/1013 [01:18<01:30,  6.34it/s]

esm2_embeddings shape: torch.Size([101, 1280])
ss_onehot shape: torch.Size([101, 4])
phi_angles shape: torch.Size([101, 1])
psi_angles shape: torch.Size([101, 1])
b_factors shape: torch.Size([101, 1])
Shape of labels: torch.Size([101])
esm2_embeddings shape: torch.Size([192, 1280])
ss_onehot shape: torch.Size([192, 4])
phi_angles shape: torch.Size([192, 1])
psi_angles shape: torch.Size([192, 1])
b_factors shape: torch.Size([192, 1])
Shape of labels: torch.Size([192])


 44%|████▍     | 444/1013 [01:19<01:25,  6.69it/s]

esm2_embeddings shape: torch.Size([156, 1280])
ss_onehot shape: torch.Size([156, 4])
phi_angles shape: torch.Size([156, 1])
psi_angles shape: torch.Size([156, 1])
b_factors shape: torch.Size([156, 1])
Shape of labels: torch.Size([156])
esm2_embeddings shape: torch.Size([160, 1280])
ss_onehot shape: torch.Size([160, 4])
phi_angles shape: torch.Size([160, 1])
psi_angles shape: torch.Size([160, 1])
b_factors shape: torch.Size([160, 1])
Shape of labels: torch.Size([160])


 44%|████▍     | 446/1013 [01:19<01:14,  7.57it/s]

esm2_embeddings shape: torch.Size([114, 1280])
ss_onehot shape: torch.Size([114, 4])
phi_angles shape: torch.Size([114, 1])
psi_angles shape: torch.Size([114, 1])
b_factors shape: torch.Size([114, 1])
Shape of labels: torch.Size([114])
esm2_embeddings shape: torch.Size([163, 1280])
ss_onehot shape: torch.Size([163, 4])
phi_angles shape: torch.Size([163, 1])
psi_angles shape: torch.Size([163, 1])
b_factors shape: torch.Size([163, 1])
Shape of labels: torch.Size([163])


 44%|████▍     | 449/1013 [01:19<01:03,  8.93it/s]

esm2_embeddings shape: torch.Size([135, 1280])
ss_onehot shape: torch.Size([135, 4])
phi_angles shape: torch.Size([135, 1])
psi_angles shape: torch.Size([135, 1])
b_factors shape: torch.Size([135, 1])
Shape of labels: torch.Size([135])
esm2_embeddings shape: torch.Size([98, 1280])
ss_onehot shape: torch.Size([98, 4])
phi_angles shape: torch.Size([98, 1])
psi_angles shape: torch.Size([98, 1])
b_factors shape: torch.Size([98, 1])
Shape of labels: torch.Size([98])
esm2_embeddings shape: torch.Size([142, 1280])
ss_onehot shape: torch.Size([142, 4])
phi_angles shape: torch.Size([142, 1])
psi_angles shape: torch.Size([142, 1])
b_factors shape: torch.Size([142, 1])
Shape of labels: torch.Size([142])
esm2_embeddings shape: torch.Size([60, 1280])
ss_onehot shape: torch.Size([60, 4])
phi_angles shape: torch.Size([60, 1])
psi_angles shape: torch.Size([60, 1])
b_factors shape: torch.Size([60, 1])
Shape of labels: torch.Size([60])


 45%|████▍     | 451/1013 [01:20<01:57,  4.78it/s]

esm2_embeddings shape: torch.Size([289, 1280])
ss_onehot shape: torch.Size([289, 4])
phi_angles shape: torch.Size([289, 1])
psi_angles shape: torch.Size([289, 1])
b_factors shape: torch.Size([289, 1])
Shape of labels: torch.Size([289])


 45%|████▍     | 452/1013 [01:20<02:04,  4.49it/s]

esm2_embeddings shape: torch.Size([286, 1280])
ss_onehot shape: torch.Size([286, 4])
phi_angles shape: torch.Size([286, 1])
psi_angles shape: torch.Size([286, 1])
b_factors shape: torch.Size([286, 1])
Shape of labels: torch.Size([286])
esm2_embeddings shape: torch.Size([123, 1280])
ss_onehot shape: torch.Size([123, 4])
phi_angles shape: torch.Size([123, 1])
psi_angles shape: torch.Size([123, 1])
b_factors shape: torch.Size([123, 1])
Shape of labels: torch.Size([123])


 45%|████▍     | 454/1013 [01:20<01:52,  4.95it/s]

esm2_embeddings shape: torch.Size([266, 1280])
ss_onehot shape: torch.Size([266, 4])
phi_angles shape: torch.Size([266, 1])
psi_angles shape: torch.Size([266, 1])
b_factors shape: torch.Size([266, 1])
Shape of labels: torch.Size([266])
esm2_embeddings shape: torch.Size([93, 1280])
ss_onehot shape: torch.Size([93, 4])
phi_angles shape: torch.Size([93, 1])
psi_angles shape: torch.Size([93, 1])
b_factors shape: torch.Size([93, 1])
Shape of labels: torch.Size([93])


 45%|████▌     | 456/1013 [01:21<01:33,  5.93it/s]

esm2_embeddings shape: torch.Size([180, 1280])
ss_onehot shape: torch.Size([180, 4])
phi_angles shape: torch.Size([180, 1])
psi_angles shape: torch.Size([180, 1])
b_factors shape: torch.Size([180, 1])
Shape of labels: torch.Size([180])
esm2_embeddings shape: torch.Size([86, 1280])
ss_onehot shape: torch.Size([86, 4])
phi_angles shape: torch.Size([86, 1])
psi_angles shape: torch.Size([86, 1])
b_factors shape: torch.Size([86, 1])
Shape of labels: torch.Size([86])


 45%|████▌     | 459/1013 [01:21<01:21,  6.77it/s]

esm2_embeddings shape: torch.Size([235, 1280])
ss_onehot shape: torch.Size([235, 4])
phi_angles shape: torch.Size([235, 1])
psi_angles shape: torch.Size([235, 1])
b_factors shape: torch.Size([235, 1])
Shape of labels: torch.Size([235])
esm2_embeddings shape: torch.Size([147, 1280])
ss_onehot shape: torch.Size([147, 4])
phi_angles shape: torch.Size([147, 1])
psi_angles shape: torch.Size([147, 1])
b_factors shape: torch.Size([147, 1])
Shape of labels: torch.Size([147])
esm2_embeddings shape: torch.Size([103, 1280])
ss_onehot shape: torch.Size([103, 4])
phi_angles shape: torch.Size([103, 1])
psi_angles shape: torch.Size([103, 1])
b_factors shape: torch.Size([103, 1])
Shape of labels: torch.Size([103])


 46%|████▌     | 461/1013 [01:22<01:52,  4.93it/s]

esm2_embeddings shape: torch.Size([147, 1280])
ss_onehot shape: torch.Size([147, 4])
phi_angles shape: torch.Size([147, 1])
psi_angles shape: torch.Size([147, 1])
b_factors shape: torch.Size([147, 1])
Shape of labels: torch.Size([147])


 46%|████▌     | 462/1013 [01:22<02:04,  4.43it/s]

esm2_embeddings shape: torch.Size([317, 1280])
ss_onehot shape: torch.Size([317, 4])
phi_angles shape: torch.Size([317, 1])
psi_angles shape: torch.Size([317, 1])
b_factors shape: torch.Size([317, 1])
Shape of labels: torch.Size([317])


 46%|████▌     | 464/1013 [01:22<02:01,  4.53it/s]

esm2_embeddings shape: torch.Size([345, 1280])
ss_onehot shape: torch.Size([345, 4])
phi_angles shape: torch.Size([345, 1])
psi_angles shape: torch.Size([345, 1])
b_factors shape: torch.Size([345, 1])
Shape of labels: torch.Size([345])
esm2_embeddings shape: torch.Size([130, 1280])
ss_onehot shape: torch.Size([130, 4])
phi_angles shape: torch.Size([130, 1])
psi_angles shape: torch.Size([130, 1])
b_factors shape: torch.Size([130, 1])
Shape of labels: torch.Size([130])
esm2_embeddings shape: torch.Size([86, 1280])
ss_onehot shape: torch.Size([86, 4])
phi_angles shape: torch.Size([86, 1])
psi_angles shape: torch.Size([86, 1])
b_factors shape: torch.Size([86, 1])
Shape of labels: torch.Size([86])


 46%|████▌     | 468/1013 [01:23<01:23,  6.54it/s]

esm2_embeddings shape: torch.Size([274, 1280])
ss_onehot shape: torch.Size([274, 4])
phi_angles shape: torch.Size([274, 1])
psi_angles shape: torch.Size([274, 1])
b_factors shape: torch.Size([274, 1])
Shape of labels: torch.Size([274])
esm2_embeddings shape: torch.Size([97, 1280])
ss_onehot shape: torch.Size([97, 4])
phi_angles shape: torch.Size([97, 1])
psi_angles shape: torch.Size([97, 1])
b_factors shape: torch.Size([97, 1])
Shape of labels: torch.Size([97])
esm2_embeddings shape: torch.Size([135, 1280])
ss_onehot shape: torch.Size([135, 4])
phi_angles shape: torch.Size([135, 1])
psi_angles shape: torch.Size([135, 1])
b_factors shape: torch.Size([135, 1])
Shape of labels: torch.Size([135])


 46%|████▋     | 469/1013 [01:23<01:21,  6.69it/s]

esm2_embeddings shape: torch.Size([177, 1280])
ss_onehot shape: torch.Size([177, 4])
phi_angles shape: torch.Size([177, 1])
psi_angles shape: torch.Size([177, 1])
b_factors shape: torch.Size([177, 1])
Shape of labels: torch.Size([177])


 46%|████▋     | 470/1013 [01:23<01:28,  6.16it/s]

esm2_embeddings shape: torch.Size([240, 1280])
ss_onehot shape: torch.Size([240, 4])
phi_angles shape: torch.Size([240, 1])
psi_angles shape: torch.Size([240, 1])
b_factors shape: torch.Size([240, 1])
Shape of labels: torch.Size([240])


 47%|████▋     | 472/1013 [01:24<02:17,  3.94it/s]

esm2_embeddings shape: torch.Size([289, 1280])
ss_onehot shape: torch.Size([289, 4])
phi_angles shape: torch.Size([289, 1])
psi_angles shape: torch.Size([289, 1])
b_factors shape: torch.Size([289, 1])
Shape of labels: torch.Size([289])
esm2_embeddings shape: torch.Size([172, 1280])
ss_onehot shape: torch.Size([172, 4])
phi_angles shape: torch.Size([172, 1])
psi_angles shape: torch.Size([172, 1])
b_factors shape: torch.Size([172, 1])
Shape of labels: torch.Size([172])
esm2_embeddings shape: torch.Size([37, 1280])
ss_onehot shape: torch.Size([37, 4])
phi_angles shape: torch.Size([37, 1])
psi_angles shape: torch.Size([37, 1])
b_factors shape: torch.Size([37, 1])
Shape of labels: torch.Size([37])
esm2_embeddings shape: torch.Size([46, 1280])
ss_onehot shape: torch.Size([46, 4])
phi_angles shape: torch.Size([46, 1])
psi_angles shape: torch.Size([46, 1])
b_factors shape: torch.Size([46, 1])


 47%|████▋     | 476/1013 [01:24<01:05,  8.23it/s]

Shape of labels: torch.Size([46])
esm2_embeddings shape: torch.Size([44, 1280])
ss_onehot shape: torch.Size([44, 4])
phi_angles shape: torch.Size([44, 1])
psi_angles shape: torch.Size([44, 1])
b_factors shape: torch.Size([44, 1])
Shape of labels: torch.Size([44])
esm2_embeddings shape: torch.Size([49, 1280])
ss_onehot shape: torch.Size([49, 4])
phi_angles shape: torch.Size([49, 1])
psi_angles shape: torch.Size([49, 1])
b_factors shape: torch.Size([49, 1])
Shape of labels: torch.Size([49])
esm2_embeddings shape: torch.Size([169, 1280])
ss_onehot shape: torch.Size([169, 4])
phi_angles shape: torch.Size([169, 1])
psi_angles shape: torch.Size([169, 1])
b_factors shape: torch.Size([169, 1])
Shape of labels: torch.Size([169])


 47%|████▋     | 480/1013 [01:25<00:49, 10.70it/s]

esm2_embeddings shape: torch.Size([40, 1280])
ss_onehot shape: torch.Size([40, 4])
phi_angles shape: torch.Size([40, 1])
psi_angles shape: torch.Size([40, 1])
b_factors shape: torch.Size([40, 1])
Shape of labels: torch.Size([40])
esm2_embeddings shape: torch.Size([94, 1280])
ss_onehot shape: torch.Size([94, 4])
phi_angles shape: torch.Size([94, 1])
psi_angles shape: torch.Size([94, 1])
b_factors shape: torch.Size([94, 1])
Shape of labels: torch.Size([94])
esm2_embeddings shape: torch.Size([68, 1280])
ss_onehot shape: torch.Size([68, 4])
phi_angles shape: torch.Size([68, 1])
psi_angles shape: torch.Size([68, 1])
b_factors shape: torch.Size([68, 1])
Shape of labels: torch.Size([68])


 48%|████▊     | 482/1013 [01:25<01:01,  8.68it/s]

esm2_embeddings shape: torch.Size([236, 1280])
ss_onehot shape: torch.Size([236, 4])
phi_angles shape: torch.Size([236, 1])
psi_angles shape: torch.Size([236, 1])
b_factors shape: torch.Size([236, 1])
Shape of labels: torch.Size([236])
esm2_embeddings shape: torch.Size([153, 1280])
ss_onehot shape: torch.Size([153, 4])
phi_angles shape: torch.Size([153, 1])
psi_angles shape: torch.Size([153, 1])
b_factors shape: torch.Size([153, 1])
Shape of labels: torch.Size([153])
esm2_embeddings shape: torch.Size([95, 1280])
ss_onehot shape: torch.Size([95, 4])
phi_angles shape: torch.Size([95, 1])
psi_angles shape: torch.Size([95, 1])
b_factors shape: torch.Size([95, 1])
Shape of labels: torch.Size([95])


 48%|████▊     | 486/1013 [01:25<00:52,  9.97it/s]

esm2_embeddings shape: torch.Size([196, 1280])
ss_onehot shape: torch.Size([196, 4])
phi_angles shape: torch.Size([196, 1])
psi_angles shape: torch.Size([196, 1])
b_factors shape: torch.Size([196, 1])
Shape of labels: torch.Size([196])
esm2_embeddings shape: torch.Size([92, 1280])
ss_onehot shape: torch.Size([92, 4])
phi_angles shape: torch.Size([92, 1])
psi_angles shape: torch.Size([92, 1])
b_factors shape: torch.Size([92, 1])
Shape of labels: torch.Size([92])
esm2_embeddings shape: torch.Size([104, 1280])
ss_onehot shape: torch.Size([104, 4])
phi_angles shape: torch.Size([104, 1])
psi_angles shape: torch.Size([104, 1])
b_factors shape: torch.Size([104, 1])
Shape of labels: torch.Size([104])


 48%|████▊     | 488/1013 [01:25<00:54,  9.71it/s]

esm2_embeddings shape: torch.Size([165, 1280])
ss_onehot shape: torch.Size([165, 4])
phi_angles shape: torch.Size([165, 1])
psi_angles shape: torch.Size([165, 1])
b_factors shape: torch.Size([165, 1])
Shape of labels: torch.Size([165])
esm2_embeddings shape: torch.Size([137, 1280])
ss_onehot shape: torch.Size([137, 4])
phi_angles shape: torch.Size([137, 1])
psi_angles shape: torch.Size([137, 1])
b_factors shape: torch.Size([137, 1])
Shape of labels: torch.Size([137])
esm2_embeddings shape: torch.Size([106, 1280])
ss_onehot shape: torch.Size([106, 4])
phi_angles shape: torch.Size([106, 1])
psi_angles shape: torch.Size([106, 1])
b_factors shape: torch.Size([106, 1])
Shape of labels: torch.Size([106])


 48%|████▊     | 490/1013 [01:26<00:50, 10.33it/s]

esm2_embeddings shape: torch.Size([128, 1280])
ss_onehot shape: torch.Size([128, 4])
phi_angles shape: torch.Size([128, 1])
psi_angles shape: torch.Size([128, 1])
b_factors shape: torch.Size([128, 1])
Shape of labels: torch.Size([128])


 49%|████▊     | 492/1013 [01:26<01:27,  5.92it/s]

esm2_embeddings shape: torch.Size([191, 1280])
ss_onehot shape: torch.Size([191, 4])
phi_angles shape: torch.Size([191, 1])
psi_angles shape: torch.Size([191, 1])
b_factors shape: torch.Size([191, 1])
Shape of labels: torch.Size([191])
esm2_embeddings shape: torch.Size([117, 1280])
ss_onehot shape: torch.Size([117, 4])
phi_angles shape: torch.Size([117, 1])
psi_angles shape: torch.Size([117, 1])
b_factors shape: torch.Size([117, 1])
Shape of labels: torch.Size([117])
esm2_embeddings shape: torch.Size([119, 1280])
ss_onehot shape: torch.Size([119, 4])
phi_angles shape: torch.Size([119, 1])
psi_angles shape: torch.Size([119, 1])
b_factors shape: torch.Size([119, 1])
Shape of labels: torch.Size([119])


 49%|████▉     | 496/1013 [01:27<01:04,  8.05it/s]

esm2_embeddings shape: torch.Size([72, 1280])
ss_onehot shape: torch.Size([72, 4])
phi_angles shape: torch.Size([72, 1])
psi_angles shape: torch.Size([72, 1])
b_factors shape: torch.Size([72, 1])
Shape of labels: torch.Size([72])
esm2_embeddings shape: torch.Size([148, 1280])
ss_onehot shape: torch.Size([148, 4])
phi_angles shape: torch.Size([148, 1])
psi_angles shape: torch.Size([148, 1])
b_factors shape: torch.Size([148, 1])
Shape of labels: torch.Size([148])
esm2_embeddings shape: torch.Size([94, 1280])
ss_onehot shape: torch.Size([94, 4])
phi_angles shape: torch.Size([94, 1])
psi_angles shape: torch.Size([94, 1])
b_factors shape: torch.Size([94, 1])
Shape of labels: torch.Size([94])
esm2_embeddings shape: torch.Size([324, 1280])
ss_onehot shape: torch.Size([324, 4])
phi_angles shape: torch.Size([324, 1])
psi_angles shape: torch.Size([324, 1])
b_factors shape: torch.Size([324, 1])
Shape of labels: torch.Size([324])


 49%|████▉     | 499/1013 [01:27<01:27,  5.84it/s]

esm2_embeddings shape: torch.Size([248, 1280])
ss_onehot shape: torch.Size([248, 4])
phi_angles shape: torch.Size([248, 1])
psi_angles shape: torch.Size([248, 1])
b_factors shape: torch.Size([248, 1])
Shape of labels: torch.Size([248])
esm2_embeddings shape: torch.Size([215, 1280])
ss_onehot shape: torch.Size([215, 4])
phi_angles shape: torch.Size([215, 1])
psi_angles shape: torch.Size([215, 1])
b_factors shape: torch.Size([215, 1])
Shape of labels: torch.Size([215])


 49%|████▉     | 501/1013 [01:28<01:25,  5.98it/s]

esm2_embeddings shape: torch.Size([162, 1280])
ss_onehot shape: torch.Size([162, 4])
phi_angles shape: torch.Size([162, 1])
psi_angles shape: torch.Size([162, 1])
b_factors shape: torch.Size([162, 1])
Shape of labels: torch.Size([162])
esm2_embeddings shape: torch.Size([209, 1280])
ss_onehot shape: torch.Size([209, 4])
phi_angles shape: torch.Size([209, 1])
psi_angles shape: torch.Size([209, 1])
b_factors shape: torch.Size([209, 1])
Shape of labels: torch.Size([209])
esm2_embeddings shape: torch.Size([94, 1280])
ss_onehot shape: torch.Size([94, 4])
phi_angles shape: torch.Size([94, 1])
psi_angles shape: torch.Size([94, 1])
b_factors shape: torch.Size([94, 1])
Shape of labels: torch.Size([94])


 50%|████▉     | 504/1013 [01:29<01:53,  4.47it/s]

esm2_embeddings shape: torch.Size([216, 1280])
ss_onehot shape: torch.Size([216, 4])
phi_angles shape: torch.Size([216, 1])
psi_angles shape: torch.Size([216, 1])
b_factors shape: torch.Size([216, 1])
Shape of labels: torch.Size([216])
esm2_embeddings shape: torch.Size([220, 1280])
ss_onehot shape: torch.Size([220, 4])
phi_angles shape: torch.Size([220, 1])
psi_angles shape: torch.Size([220, 1])
b_factors shape: torch.Size([220, 1])
Shape of labels: torch.Size([220])


 50%|████▉     | 506/1013 [01:29<01:28,  5.74it/s]

esm2_embeddings shape: torch.Size([78, 1280])
ss_onehot shape: torch.Size([78, 4])
phi_angles shape: torch.Size([78, 1])
psi_angles shape: torch.Size([78, 1])
b_factors shape: torch.Size([78, 1])
Shape of labels: torch.Size([78])
esm2_embeddings shape: torch.Size([169, 1280])
ss_onehot shape: torch.Size([169, 4])
phi_angles shape: torch.Size([169, 1])
psi_angles shape: torch.Size([169, 1])
b_factors shape: torch.Size([169, 1])
Shape of labels: torch.Size([169])


 50%|█████     | 508/1013 [01:29<01:11,  7.02it/s]

esm2_embeddings shape: torch.Size([118, 1280])
ss_onehot shape: torch.Size([118, 4])
phi_angles shape: torch.Size([118, 1])
psi_angles shape: torch.Size([118, 1])
b_factors shape: torch.Size([118, 1])
Shape of labels: torch.Size([118])
esm2_embeddings shape: torch.Size([115, 1280])
ss_onehot shape: torch.Size([115, 4])
phi_angles shape: torch.Size([115, 1])
psi_angles shape: torch.Size([115, 1])
b_factors shape: torch.Size([115, 1])
Shape of labels: torch.Size([115])


 50%|█████     | 511/1013 [01:29<01:05,  7.66it/s]

esm2_embeddings shape: torch.Size([233, 1280])
ss_onehot shape: torch.Size([233, 4])
phi_angles shape: torch.Size([233, 1])
psi_angles shape: torch.Size([233, 1])
b_factors shape: torch.Size([233, 1])
Shape of labels: torch.Size([233])
esm2_embeddings shape: torch.Size([103, 1280])
ss_onehot shape: torch.Size([103, 4])
phi_angles shape: torch.Size([103, 1])
psi_angles shape: torch.Size([103, 1])
b_factors shape: torch.Size([103, 1])
Shape of labels: torch.Size([103])
esm2_embeddings shape: torch.Size([149, 1280])
ss_onehot shape: torch.Size([149, 4])
phi_angles shape: torch.Size([149, 1])
psi_angles shape: torch.Size([149, 1])
b_factors shape: torch.Size([149, 1])
Shape of labels: torch.Size([149])


 51%|█████     | 513/1013 [01:29<00:52,  9.44it/s]

esm2_embeddings shape: torch.Size([64, 1280])
ss_onehot shape: torch.Size([64, 4])
phi_angles shape: torch.Size([64, 1])
psi_angles shape: torch.Size([64, 1])
b_factors shape: torch.Size([64, 1])
Shape of labels: torch.Size([64])
esm2_embeddings shape: torch.Size([110, 1280])
ss_onehot shape: torch.Size([110, 4])
phi_angles shape: torch.Size([110, 1])
psi_angles shape: torch.Size([110, 1])
b_factors shape: torch.Size([110, 1])
Shape of labels: torch.Size([110])
esm2_embeddings shape: torch.Size([133, 1280])
ss_onehot shape: torch.Size([133, 4])
phi_angles shape: torch.Size([133, 1])
psi_angles shape: torch.Size([133, 1])
b_factors shape: torch.Size([133, 1])
Shape of labels: torch.Size([133])


 51%|█████     | 515/1013 [01:30<00:47, 10.38it/s]

esm2_embeddings shape: torch.Size([90, 1280])
ss_onehot shape: torch.Size([90, 4])
phi_angles shape: torch.Size([90, 1])
psi_angles shape: torch.Size([90, 1])
b_factors shape: torch.Size([90, 1])
Shape of labels: torch.Size([90])
esm2_embeddings shape: torch.Size([104, 1280])
ss_onehot shape: torch.Size([104, 4])
phi_angles shape: torch.Size([104, 1])
psi_angles shape: torch.Size([104, 1])
b_factors shape: torch.Size([104, 1])
Shape of labels: torch.Size([104])


 51%|█████     | 517/1013 [01:30<00:54,  9.15it/s]

esm2_embeddings shape: torch.Size([230, 1280])
ss_onehot shape: torch.Size([230, 4])
phi_angles shape: torch.Size([230, 1])
psi_angles shape: torch.Size([230, 1])
b_factors shape: torch.Size([230, 1])
Shape of labels: torch.Size([230])
esm2_embeddings shape: torch.Size([164, 1280])
ss_onehot shape: torch.Size([164, 4])
phi_angles shape: torch.Size([164, 1])
psi_angles shape: torch.Size([164, 1])
b_factors shape: torch.Size([164, 1])
Shape of labels: torch.Size([164])


 51%|█████     | 519/1013 [01:31<01:32,  5.32it/s]

esm2_embeddings shape: torch.Size([210, 1280])
ss_onehot shape: torch.Size([210, 4])
phi_angles shape: torch.Size([210, 1])
psi_angles shape: torch.Size([210, 1])
b_factors shape: torch.Size([210, 1])
Shape of labels: torch.Size([210])


 51%|█████▏    | 520/1013 [01:31<01:42,  4.80it/s]

esm2_embeddings shape: torch.Size([301, 1280])
ss_onehot shape: torch.Size([301, 4])
phi_angles shape: torch.Size([301, 1])
psi_angles shape: torch.Size([301, 1])
b_factors shape: torch.Size([301, 1])
Shape of labels: torch.Size([301])


 51%|█████▏    | 521/1013 [01:31<01:52,  4.36it/s]

esm2_embeddings shape: torch.Size([309, 1280])
ss_onehot shape: torch.Size([309, 4])
phi_angles shape: torch.Size([309, 1])
psi_angles shape: torch.Size([309, 1])
b_factors shape: torch.Size([309, 1])
Shape of labels: torch.Size([309])
esm2_embeddings shape: torch.Size([122, 1280])
ss_onehot shape: torch.Size([122, 4])
phi_angles shape: torch.Size([122, 1])
psi_angles shape: torch.Size([122, 1])
b_factors shape: torch.Size([122, 1])
Shape of labels: torch.Size([122])


 52%|█████▏    | 523/1013 [01:31<01:29,  5.45it/s]

esm2_embeddings shape: torch.Size([160, 1280])
ss_onehot shape: torch.Size([160, 4])
phi_angles shape: torch.Size([160, 1])
psi_angles shape: torch.Size([160, 1])
b_factors shape: torch.Size([160, 1])
Shape of labels: torch.Size([160])
esm2_embeddings shape: torch.Size([110, 1280])
ss_onehot shape: torch.Size([110, 4])
phi_angles shape: torch.Size([110, 1])
psi_angles shape: torch.Size([110, 1])
b_factors shape: torch.Size([110, 1])
Shape of labels: torch.Size([110])


 52%|█████▏    | 526/1013 [01:32<01:12,  6.76it/s]

esm2_embeddings shape: torch.Size([161, 1280])
ss_onehot shape: torch.Size([161, 4])
phi_angles shape: torch.Size([161, 1])
psi_angles shape: torch.Size([161, 1])
b_factors shape: torch.Size([161, 1])
Shape of labels: torch.Size([161])
esm2_embeddings shape: torch.Size([157, 1280])
ss_onehot shape: torch.Size([157, 4])
phi_angles shape: torch.Size([157, 1])
psi_angles shape: torch.Size([157, 1])
b_factors shape: torch.Size([157, 1])
Shape of labels: torch.Size([157])
esm2_embeddings shape: torch.Size([83, 1280])
ss_onehot shape: torch.Size([83, 4])
phi_angles shape: torch.Size([83, 1])
psi_angles shape: torch.Size([83, 1])
b_factors shape: torch.Size([83, 1])
Shape of labels: torch.Size([83])


 52%|█████▏    | 530/1013 [01:32<00:52,  9.18it/s]

esm2_embeddings shape: torch.Size([148, 1280])
ss_onehot shape: torch.Size([148, 4])
phi_angles shape: torch.Size([148, 1])
psi_angles shape: torch.Size([148, 1])
b_factors shape: torch.Size([148, 1])
Shape of labels: torch.Size([148])
esm2_embeddings shape: torch.Size([126, 1280])
ss_onehot shape: torch.Size([126, 4])
phi_angles shape: torch.Size([126, 1])
psi_angles shape: torch.Size([126, 1])
b_factors shape: torch.Size([126, 1])
Shape of labels: torch.Size([126])
esm2_embeddings shape: torch.Size([88, 1280])
ss_onehot shape: torch.Size([88, 4])
phi_angles shape: torch.Size([88, 1])
psi_angles shape: torch.Size([88, 1])
b_factors shape: torch.Size([88, 1])
Shape of labels: torch.Size([88])


 53%|█████▎    | 532/1013 [01:33<01:25,  5.64it/s]

esm2_embeddings shape: torch.Size([148, 1280])
ss_onehot shape: torch.Size([148, 4])
phi_angles shape: torch.Size([148, 1])
psi_angles shape: torch.Size([148, 1])
b_factors shape: torch.Size([148, 1])
Shape of labels: torch.Size([148])
esm2_embeddings shape: torch.Size([131, 1280])
ss_onehot shape: torch.Size([131, 4])
phi_angles shape: torch.Size([131, 1])
psi_angles shape: torch.Size([131, 1])
b_factors shape: torch.Size([131, 1])
Shape of labels: torch.Size([131])
esm2_embeddings shape: torch.Size([101, 1280])
ss_onehot shape: torch.Size([101, 4])
phi_angles shape: torch.Size([101, 1])
psi_angles shape: torch.Size([101, 1])
b_factors shape: torch.Size([101, 1])
Shape of labels: torch.Size([101])


 53%|█████▎    | 536/1013 [01:33<00:59,  8.00it/s]

esm2_embeddings shape: torch.Size([60, 1280])
ss_onehot shape: torch.Size([60, 4])
phi_angles shape: torch.Size([60, 1])
psi_angles shape: torch.Size([60, 1])
b_factors shape: torch.Size([60, 1])
Shape of labels: torch.Size([60])
esm2_embeddings shape: torch.Size([110, 1280])
ss_onehot shape: torch.Size([110, 4])
phi_angles shape: torch.Size([110, 1])
psi_angles shape: torch.Size([110, 1])
b_factors shape: torch.Size([110, 1])
Shape of labels: torch.Size([110])
esm2_embeddings shape: torch.Size([138, 1280])
ss_onehot shape: torch.Size([138, 4])
phi_angles shape: torch.Size([138, 1])
psi_angles shape: torch.Size([138, 1])
b_factors shape: torch.Size([138, 1])
Shape of labels: torch.Size([138])


 53%|█████▎    | 538/1013 [01:33<00:52,  9.04it/s]

esm2_embeddings shape: torch.Size([131, 1280])
ss_onehot shape: torch.Size([131, 4])
phi_angles shape: torch.Size([131, 1])
psi_angles shape: torch.Size([131, 1])
b_factors shape: torch.Size([131, 1])
Shape of labels: torch.Size([131])
esm2_embeddings shape: torch.Size([91, 1280])
ss_onehot shape: torch.Size([91, 4])
phi_angles shape: torch.Size([91, 1])
psi_angles shape: torch.Size([91, 1])
b_factors shape: torch.Size([91, 1])
Shape of labels: torch.Size([91])
esm2_embeddings shape: torch.Size([138, 1280])
ss_onehot shape: torch.Size([138, 4])
phi_angles shape: torch.Size([138, 1])
psi_angles shape: torch.Size([138, 1])
b_factors shape: torch.Size([138, 1])
Shape of labels: torch.Size([138])


 53%|█████▎    | 540/1013 [01:33<01:01,  7.67it/s]

esm2_embeddings shape: torch.Size([277, 1280])
ss_onehot shape: torch.Size([277, 4])
phi_angles shape: torch.Size([277, 1])
psi_angles shape: torch.Size([277, 1])
b_factors shape: torch.Size([277, 1])
Shape of labels: torch.Size([277])
esm2_embeddings shape: torch.Size([121, 1280])
ss_onehot shape: torch.Size([121, 4])
phi_angles shape: torch.Size([121, 1])
psi_angles shape: torch.Size([121, 1])
b_factors shape: torch.Size([121, 1])
Shape of labels: torch.Size([121])


 54%|█████▎    | 542/1013 [01:34<00:58,  8.01it/s]

esm2_embeddings shape: torch.Size([182, 1280])
ss_onehot shape: torch.Size([182, 4])
phi_angles shape: torch.Size([182, 1])
psi_angles shape: torch.Size([182, 1])
b_factors shape: torch.Size([182, 1])
Shape of labels: torch.Size([182])
esm2_embeddings shape: torch.Size([113, 1280])
ss_onehot shape: torch.Size([113, 4])
phi_angles shape: torch.Size([113, 1])
psi_angles shape: torch.Size([113, 1])
b_factors shape: torch.Size([113, 1])
Shape of labels: torch.Size([113])


 54%|█████▍    | 546/1013 [01:34<00:48,  9.64it/s]

esm2_embeddings shape: torch.Size([195, 1280])
ss_onehot shape: torch.Size([195, 4])
phi_angles shape: torch.Size([195, 1])
psi_angles shape: torch.Size([195, 1])
b_factors shape: torch.Size([195, 1])
Shape of labels: torch.Size([195])
esm2_embeddings shape: torch.Size([127, 1280])
ss_onehot shape: torch.Size([127, 4])
phi_angles shape: torch.Size([127, 1])
psi_angles shape: torch.Size([127, 1])
b_factors shape: torch.Size([127, 1])
Shape of labels: torch.Size([127])
esm2_embeddings shape: torch.Size([46, 1280])
ss_onehot shape: torch.Size([46, 4])
phi_angles shape: torch.Size([46, 1])
psi_angles shape: torch.Size([46, 1])
b_factors shape: torch.Size([46, 1])
Shape of labels: torch.Size([46])
esm2_embeddings shape: torch.Size([57, 1280])
ss_onehot shape: torch.Size([57, 4])
phi_angles shape: torch.Size([57, 1])
psi_angles shape: torch.Size([57, 1])
b_factors shape: torch.Size([57, 1])
Shape of labels: torch.Size([57])


 54%|█████▍    | 548/1013 [01:34<00:40, 11.34it/s]

esm2_embeddings shape: torch.Size([94, 1280])
ss_onehot shape: torch.Size([94, 4])
phi_angles shape: torch.Size([94, 1])
psi_angles shape: torch.Size([94, 1])
b_factors shape: torch.Size([94, 1])
Shape of labels: torch.Size([94])


 54%|█████▍    | 550/1013 [01:35<01:24,  5.47it/s]

esm2_embeddings shape: torch.Size([217, 1280])
ss_onehot shape: torch.Size([217, 4])
phi_angles shape: torch.Size([217, 1])
psi_angles shape: torch.Size([217, 1])
b_factors shape: torch.Size([217, 1])
Shape of labels: torch.Size([217])
esm2_embeddings shape: torch.Size([169, 1280])
ss_onehot shape: torch.Size([217, 4])
phi_angles shape: torch.Size([216, 1])
psi_angles shape: torch.Size([216, 1])
b_factors shape: torch.Size([217, 1])
Shape of labels: torch.Size([169])


 55%|█████▍    | 554/1013 [01:35<00:57,  7.93it/s]

esm2_embeddings shape: torch.Size([128, 1280])
ss_onehot shape: torch.Size([128, 4])
phi_angles shape: torch.Size([128, 1])
psi_angles shape: torch.Size([128, 1])
b_factors shape: torch.Size([128, 1])
Shape of labels: torch.Size([128])
esm2_embeddings shape: torch.Size([84, 1280])
ss_onehot shape: torch.Size([84, 4])
phi_angles shape: torch.Size([84, 1])
psi_angles shape: torch.Size([84, 1])
b_factors shape: torch.Size([84, 1])
Shape of labels: torch.Size([84])
esm2_embeddings shape: torch.Size([112, 1280])
ss_onehot shape: torch.Size([112, 4])
phi_angles shape: torch.Size([112, 1])
psi_angles shape: torch.Size([112, 1])
b_factors shape: torch.Size([112, 1])
Shape of labels: torch.Size([112])
esm2_embeddings shape: torch.Size([63, 1280])
ss_onehot shape: torch.Size([63, 4])
phi_angles shape: torch.Size([63, 1])
psi_angles shape: torch.Size([63, 1])
b_factors shape: torch.Size([63, 1])
Shape of labels: torch.Size([63])


 55%|█████▍    | 556/1013 [01:35<00:59,  7.67it/s]

esm2_embeddings shape: torch.Size([135, 1280])
ss_onehot shape: torch.Size([135, 4])
phi_angles shape: torch.Size([135, 1])
psi_angles shape: torch.Size([135, 1])
b_factors shape: torch.Size([135, 1])
Shape of labels: torch.Size([135])
esm2_embeddings shape: torch.Size([189, 1280])
ss_onehot shape: torch.Size([189, 4])
phi_angles shape: torch.Size([189, 1])
psi_angles shape: torch.Size([189, 1])
b_factors shape: torch.Size([189, 1])
Shape of labels: torch.Size([189])


 55%|█████▌    | 558/1013 [01:36<01:08,  6.65it/s]

esm2_embeddings shape: torch.Size([297, 1280])
ss_onehot shape: torch.Size([297, 4])
phi_angles shape: torch.Size([297, 1])
psi_angles shape: torch.Size([297, 1])
b_factors shape: torch.Size([297, 1])
Shape of labels: torch.Size([297])
esm2_embeddings shape: torch.Size([151, 1280])
ss_onehot shape: torch.Size([151, 4])
phi_angles shape: torch.Size([151, 1])
psi_angles shape: torch.Size([151, 1])
b_factors shape: torch.Size([151, 1])
Shape of labels: torch.Size([151])


 55%|█████▌    | 561/1013 [01:36<00:58,  7.78it/s]

esm2_embeddings shape: torch.Size([148, 1280])
ss_onehot shape: torch.Size([148, 4])
phi_angles shape: torch.Size([148, 1])
psi_angles shape: torch.Size([148, 1])
b_factors shape: torch.Size([148, 1])
Shape of labels: torch.Size([148])
esm2_embeddings shape: torch.Size([123, 1280])
ss_onehot shape: torch.Size([123, 4])
phi_angles shape: torch.Size([123, 1])
psi_angles shape: torch.Size([123, 1])
b_factors shape: torch.Size([123, 1])
Shape of labels: torch.Size([123])
esm2_embeddings shape: torch.Size([146, 1280])
ss_onehot shape: torch.Size([146, 4])
phi_angles shape: torch.Size([146, 1])
psi_angles shape: torch.Size([146, 1])
b_factors shape: torch.Size([146, 1])
Shape of labels: torch.Size([146])


 56%|█████▌    | 563/1013 [01:36<00:53,  8.46it/s]

esm2_embeddings shape: torch.Size([122, 1280])
ss_onehot shape: torch.Size([122, 4])
phi_angles shape: torch.Size([122, 1])
psi_angles shape: torch.Size([122, 1])
b_factors shape: torch.Size([122, 1])
Shape of labels: torch.Size([122])
esm2_embeddings shape: torch.Size([142, 1280])
ss_onehot shape: torch.Size([142, 4])
phi_angles shape: torch.Size([142, 1])
psi_angles shape: torch.Size([142, 1])
b_factors shape: torch.Size([142, 1])
Shape of labels: torch.Size([142])


 56%|█████▌    | 565/1013 [01:37<01:22,  5.46it/s]

esm2_embeddings shape: torch.Size([175, 1280])
ss_onehot shape: torch.Size([175, 4])
phi_angles shape: torch.Size([175, 1])
psi_angles shape: torch.Size([175, 1])
b_factors shape: torch.Size([175, 1])
Shape of labels: torch.Size([175])
esm2_embeddings shape: torch.Size([142, 1280])
ss_onehot shape: torch.Size([142, 4])
phi_angles shape: torch.Size([142, 1])
psi_angles shape: torch.Size([142, 1])
b_factors shape: torch.Size([142, 1])
Shape of labels: torch.Size([142])


 56%|█████▌    | 566/1013 [01:37<01:36,  4.61it/s]

esm2_embeddings shape: torch.Size([323, 1280])
ss_onehot shape: torch.Size([323, 4])
phi_angles shape: torch.Size([323, 1])
psi_angles shape: torch.Size([323, 1])
b_factors shape: torch.Size([323, 1])
Shape of labels: torch.Size([323])


 56%|█████▌    | 568/1013 [01:38<01:32,  4.82it/s]

esm2_embeddings shape: torch.Size([295, 1280])
ss_onehot shape: torch.Size([295, 4])
phi_angles shape: torch.Size([295, 1])
psi_angles shape: torch.Size([295, 1])
b_factors shape: torch.Size([295, 1])
Shape of labels: torch.Size([295])
esm2_embeddings shape: torch.Size([162, 1280])
ss_onehot shape: torch.Size([162, 4])
phi_angles shape: torch.Size([162, 1])
psi_angles shape: torch.Size([162, 1])
b_factors shape: torch.Size([162, 1])
Shape of labels: torch.Size([162])


 56%|█████▋    | 570/1013 [01:38<01:32,  4.77it/s]

esm2_embeddings shape: torch.Size([271, 1280])
ss_onehot shape: torch.Size([271, 4])
phi_angles shape: torch.Size([271, 1])
psi_angles shape: torch.Size([271, 1])
b_factors shape: torch.Size([271, 1])
Shape of labels: torch.Size([271])
esm2_embeddings shape: torch.Size([216, 1280])
ss_onehot shape: torch.Size([216, 4])
phi_angles shape: torch.Size([216, 1])
psi_angles shape: torch.Size([216, 1])
b_factors shape: torch.Size([216, 1])
Shape of labels: torch.Size([216])


 56%|█████▋    | 571/1013 [01:39<02:15,  3.25it/s]

esm2_embeddings shape: torch.Size([147, 1280])
ss_onehot shape: torch.Size([147, 4])
phi_angles shape: torch.Size([147, 1])
psi_angles shape: torch.Size([147, 1])
b_factors shape: torch.Size([147, 1])
Shape of labels: torch.Size([147])


 56%|█████▋    | 572/1013 [01:39<02:03,  3.57it/s]

esm2_embeddings shape: torch.Size([236, 1280])
ss_onehot shape: torch.Size([236, 4])
phi_angles shape: torch.Size([236, 1])
psi_angles shape: torch.Size([236, 1])
b_factors shape: torch.Size([236, 1])
Shape of labels: torch.Size([236])


 57%|█████▋    | 574/1013 [01:39<01:42,  4.28it/s]

esm2_embeddings shape: torch.Size([302, 1280])
ss_onehot shape: torch.Size([302, 4])
phi_angles shape: torch.Size([302, 1])
psi_angles shape: torch.Size([302, 1])
b_factors shape: torch.Size([302, 1])
Shape of labels: torch.Size([302])
esm2_embeddings shape: torch.Size([149, 1280])
ss_onehot shape: torch.Size([149, 4])
phi_angles shape: torch.Size([149, 1])
psi_angles shape: torch.Size([149, 1])
b_factors shape: torch.Size([149, 1])
Shape of labels: torch.Size([149])


 57%|█████▋    | 575/1013 [01:40<01:32,  4.75it/s]

esm2_embeddings shape: torch.Size([195, 1280])
ss_onehot shape: torch.Size([195, 4])
phi_angles shape: torch.Size([195, 1])
psi_angles shape: torch.Size([195, 1])
b_factors shape: torch.Size([195, 1])
Shape of labels: torch.Size([195])
esm2_embeddings shape: torch.Size([127, 1280])
ss_onehot shape: torch.Size([127, 4])
phi_angles shape: torch.Size([127, 1])
psi_angles shape: torch.Size([127, 1])
b_factors shape: torch.Size([127, 1])
Shape of labels: torch.Size([127])


 57%|█████▋    | 579/1013 [01:40<01:08,  6.35it/s]

esm2_embeddings shape: torch.Size([283, 1280])
ss_onehot shape: torch.Size([283, 4])
phi_angles shape: torch.Size([283, 1])
psi_angles shape: torch.Size([283, 1])
b_factors shape: torch.Size([283, 1])
Shape of labels: torch.Size([283])
esm2_embeddings shape: torch.Size([122, 1280])
ss_onehot shape: torch.Size([122, 4])
phi_angles shape: torch.Size([122, 1])
psi_angles shape: torch.Size([122, 1])
b_factors shape: torch.Size([122, 1])
Shape of labels: torch.Size([122])
esm2_embeddings shape: torch.Size([146, 1280])
ss_onehot shape: torch.Size([146, 4])
phi_angles shape: torch.Size([146, 1])
psi_angles shape: torch.Size([146, 1])
b_factors shape: torch.Size([146, 1])
Shape of labels: torch.Size([146])
esm2_embeddings shape: torch.Size([103, 1280])
ss_onehot shape: torch.Size([103, 4])
phi_angles shape: torch.Size([103, 1])
psi_angles shape: torch.Size([103, 1])
b_factors shape: torch.Size([103, 1])
Shape of labels: torch.Size([103])


 57%|█████▋    | 582/1013 [01:41<01:25,  5.06it/s]

esm2_embeddings shape: torch.Size([180, 1280])
ss_onehot shape: torch.Size([180, 4])
phi_angles shape: torch.Size([180, 1])
psi_angles shape: torch.Size([180, 1])
b_factors shape: torch.Size([180, 1])
Shape of labels: torch.Size([180])
esm2_embeddings shape: torch.Size([158, 1280])
ss_onehot shape: torch.Size([158, 4])
phi_angles shape: torch.Size([158, 1])
psi_angles shape: torch.Size([158, 1])
b_factors shape: torch.Size([158, 1])
Shape of labels: torch.Size([158])
esm2_embeddings shape: torch.Size([127, 1280])
ss_onehot shape: torch.Size([127, 4])
phi_angles shape: torch.Size([127, 1])
psi_angles shape: torch.Size([127, 1])
b_factors shape: torch.Size([127, 1])
Shape of labels: torch.Size([127])


 58%|█████▊    | 584/1013 [01:41<01:20,  5.31it/s]

esm2_embeddings shape: torch.Size([270, 1280])
ss_onehot shape: torch.Size([270, 4])
phi_angles shape: torch.Size([270, 1])
psi_angles shape: torch.Size([270, 1])
b_factors shape: torch.Size([270, 1])
Shape of labels: torch.Size([270])


 58%|█████▊    | 586/1013 [01:42<01:42,  4.17it/s]

esm2_embeddings shape: torch.Size([470, 1280])
ss_onehot shape: torch.Size([470, 4])
phi_angles shape: torch.Size([470, 1])
psi_angles shape: torch.Size([470, 1])
b_factors shape: torch.Size([470, 1])
Shape of labels: torch.Size([470])
esm2_embeddings shape: torch.Size([150, 1280])
ss_onehot shape: torch.Size([150, 4])
phi_angles shape: torch.Size([150, 1])
psi_angles shape: torch.Size([150, 1])
b_factors shape: torch.Size([150, 1])
Shape of labels: torch.Size([150])
esm2_embeddings shape: torch.Size([112, 1280])
ss_onehot shape: torch.Size([112, 4])
phi_angles shape: torch.Size([112, 1])
psi_angles shape: torch.Size([112, 1])
b_factors shape: torch.Size([112, 1])
Shape of labels: torch.Size([112])


 58%|█████▊    | 589/1013 [01:42<01:12,  5.84it/s]

esm2_embeddings shape: torch.Size([176, 1280])
ss_onehot shape: torch.Size([176, 4])
phi_angles shape: torch.Size([176, 1])
psi_angles shape: torch.Size([176, 1])
b_factors shape: torch.Size([176, 1])
Shape of labels: torch.Size([176])
esm2_embeddings shape: torch.Size([139, 1280])
ss_onehot shape: torch.Size([139, 4])
phi_angles shape: torch.Size([139, 1])
psi_angles shape: torch.Size([139, 1])
b_factors shape: torch.Size([139, 1])
Shape of labels: torch.Size([139])
esm2_embeddings shape: torch.Size([116, 1280])
ss_onehot shape: torch.Size([116, 4])
phi_angles shape: torch.Size([116, 1])
psi_angles shape: torch.Size([116, 1])
b_factors shape: torch.Size([116, 1])
Shape of labels: torch.Size([116])


 58%|█████▊    | 592/1013 [01:43<01:23,  5.05it/s]

esm2_embeddings shape: torch.Size([138, 1280])
ss_onehot shape: torch.Size([138, 4])
phi_angles shape: torch.Size([138, 1])
psi_angles shape: torch.Size([138, 1])
b_factors shape: torch.Size([138, 1])
Shape of labels: torch.Size([138])
esm2_embeddings shape: torch.Size([132, 1280])
ss_onehot shape: torch.Size([132, 4])
phi_angles shape: torch.Size([132, 1])
psi_angles shape: torch.Size([132, 1])
b_factors shape: torch.Size([132, 1])
Shape of labels: torch.Size([132])
esm2_embeddings shape: torch.Size([126, 1280])
ss_onehot shape: torch.Size([126, 4])
phi_angles shape: torch.Size([126, 1])
psi_angles shape: torch.Size([126, 1])
b_factors shape: torch.Size([126, 1])
Shape of labels: torch.Size([126])


 59%|█████▊    | 595/1013 [01:43<01:01,  6.85it/s]

esm2_embeddings shape: torch.Size([135, 1280])
ss_onehot shape: torch.Size([135, 4])
phi_angles shape: torch.Size([135, 1])
psi_angles shape: torch.Size([135, 1])
b_factors shape: torch.Size([135, 1])
Shape of labels: torch.Size([135])
esm2_embeddings shape: torch.Size([135, 1280])
ss_onehot shape: torch.Size([135, 4])
phi_angles shape: torch.Size([135, 1])
psi_angles shape: torch.Size([135, 1])
b_factors shape: torch.Size([135, 1])
Shape of labels: torch.Size([135])


 59%|█████▉    | 596/1013 [01:43<00:58,  7.08it/s]

esm2_embeddings shape: torch.Size([164, 1280])
ss_onehot shape: torch.Size([164, 4])
phi_angles shape: torch.Size([164, 1])
psi_angles shape: torch.Size([164, 1])
b_factors shape: torch.Size([164, 1])
Shape of labels: torch.Size([164])


 59%|█████▉    | 598/1013 [01:44<01:03,  6.58it/s]

esm2_embeddings shape: torch.Size([255, 1280])
ss_onehot shape: torch.Size([255, 4])
phi_angles shape: torch.Size([255, 1])
psi_angles shape: torch.Size([255, 1])
b_factors shape: torch.Size([255, 1])
Shape of labels: torch.Size([255])
esm2_embeddings shape: torch.Size([152, 1280])
ss_onehot shape: torch.Size([152, 4])
phi_angles shape: torch.Size([152, 1])
psi_angles shape: torch.Size([152, 1])
b_factors shape: torch.Size([152, 1])
Shape of labels: torch.Size([152])


 59%|█████▉    | 601/1013 [01:44<00:47,  8.72it/s]

esm2_embeddings shape: torch.Size([159, 1280])
ss_onehot shape: torch.Size([159, 4])
phi_angles shape: torch.Size([159, 1])
psi_angles shape: torch.Size([159, 1])
b_factors shape: torch.Size([159, 1])
Shape of labels: torch.Size([159])
esm2_embeddings shape: torch.Size([121, 1280])
ss_onehot shape: torch.Size([121, 4])
phi_angles shape: torch.Size([121, 1])
psi_angles shape: torch.Size([121, 1])
b_factors shape: torch.Size([121, 1])
Shape of labels: torch.Size([121])
esm2_embeddings shape: torch.Size([94, 1280])
ss_onehot shape: torch.Size([94, 4])
phi_angles shape: torch.Size([94, 1])
psi_angles shape: torch.Size([94, 1])
b_factors shape: torch.Size([94, 1])
Shape of labels: torch.Size([94])


 59%|█████▉    | 602/1013 [01:44<00:45,  8.98it/s]

esm2_embeddings shape: torch.Size([137, 1280])
ss_onehot shape: torch.Size([137, 4])
phi_angles shape: torch.Size([137, 1])
psi_angles shape: torch.Size([137, 1])
b_factors shape: torch.Size([137, 1])
Shape of labels: torch.Size([137])


 60%|█████▉    | 603/1013 [01:44<00:58,  6.96it/s]

esm2_embeddings shape: torch.Size([264, 1280])
ss_onehot shape: torch.Size([264, 4])
phi_angles shape: torch.Size([264, 1])
psi_angles shape: torch.Size([264, 1])
b_factors shape: torch.Size([264, 1])
Shape of labels: torch.Size([264])


 60%|█████▉    | 604/1013 [01:45<01:46,  3.82it/s]

esm2_embeddings shape: torch.Size([190, 1280])
ss_onehot shape: torch.Size([190, 4])
phi_angles shape: torch.Size([190, 1])
psi_angles shape: torch.Size([190, 1])
b_factors shape: torch.Size([190, 1])
Shape of labels: torch.Size([190])


 60%|█████▉    | 606/1013 [01:45<01:37,  4.16it/s]

esm2_embeddings shape: torch.Size([350, 1280])
ss_onehot shape: torch.Size([350, 4])
phi_angles shape: torch.Size([350, 1])
psi_angles shape: torch.Size([350, 1])
b_factors shape: torch.Size([350, 1])
Shape of labels: torch.Size([350])
esm2_embeddings shape: torch.Size([134, 1280])
ss_onehot shape: torch.Size([134, 4])
phi_angles shape: torch.Size([134, 1])
psi_angles shape: torch.Size([134, 1])
b_factors shape: torch.Size([134, 1])
Shape of labels: torch.Size([134])


 60%|██████    | 608/1013 [01:46<01:13,  5.54it/s]

esm2_embeddings shape: torch.Size([160, 1280])
ss_onehot shape: torch.Size([160, 4])
phi_angles shape: torch.Size([160, 1])
psi_angles shape: torch.Size([160, 1])
b_factors shape: torch.Size([160, 1])
Shape of labels: torch.Size([160])
esm2_embeddings shape: torch.Size([120, 1280])
ss_onehot shape: torch.Size([120, 4])
phi_angles shape: torch.Size([120, 1])
psi_angles shape: torch.Size([120, 1])
b_factors shape: torch.Size([120, 1])
Shape of labels: torch.Size([120])
esm2_embeddings shape: torch.Size([69, 1280])
ss_onehot shape: torch.Size([69, 4])
phi_angles shape: torch.Size([69, 1])
psi_angles shape: torch.Size([69, 1])
b_factors shape: torch.Size([69, 1])
Shape of labels: torch.Size([69])


 60%|██████    | 610/1013 [01:46<01:00,  6.69it/s]

esm2_embeddings shape: torch.Size([190, 1280])
ss_onehot shape: torch.Size([190, 4])
phi_angles shape: torch.Size([190, 1])
psi_angles shape: torch.Size([190, 1])
b_factors shape: torch.Size([190, 1])
Shape of labels: torch.Size([190])


 61%|██████    | 613/1013 [01:46<00:58,  6.85it/s]

esm2_embeddings shape: torch.Size([295, 1280])
ss_onehot shape: torch.Size([295, 4])
phi_angles shape: torch.Size([295, 1])
psi_angles shape: torch.Size([295, 1])
b_factors shape: torch.Size([295, 1])
Shape of labels: torch.Size([295])
esm2_embeddings shape: torch.Size([133, 1280])
ss_onehot shape: torch.Size([133, 4])
phi_angles shape: torch.Size([133, 1])
psi_angles shape: torch.Size([133, 1])
b_factors shape: torch.Size([133, 1])
Shape of labels: torch.Size([133])
esm2_embeddings shape: torch.Size([140, 1280])
ss_onehot shape: torch.Size([140, 4])
phi_angles shape: torch.Size([140, 1])
psi_angles shape: torch.Size([140, 1])
b_factors shape: torch.Size([140, 1])
Shape of labels: torch.Size([140])


 61%|██████    | 614/1013 [01:47<00:54,  7.34it/s]

esm2_embeddings shape: torch.Size([121, 1280])
ss_onehot shape: torch.Size([121, 4])
phi_angles shape: torch.Size([121, 1])
psi_angles shape: torch.Size([121, 1])
b_factors shape: torch.Size([121, 1])
Shape of labels: torch.Size([121])


 61%|██████    | 615/1013 [01:47<01:47,  3.70it/s]

esm2_embeddings shape: torch.Size([279, 1280])
ss_onehot shape: torch.Size([279, 4])
phi_angles shape: torch.Size([279, 1])
psi_angles shape: torch.Size([279, 1])
b_factors shape: torch.Size([279, 1])
Shape of labels: torch.Size([279])


 61%|██████    | 616/1013 [01:48<02:36,  2.54it/s]

esm2_embeddings shape: torch.Size([540, 1280])
ss_onehot shape: torch.Size([540, 4])
phi_angles shape: torch.Size([540, 1])
psi_angles shape: torch.Size([540, 1])
b_factors shape: torch.Size([540, 1])
Shape of labels: torch.Size([540])


 61%|██████    | 618/1013 [01:48<01:56,  3.40it/s]

esm2_embeddings shape: torch.Size([274, 1280])
ss_onehot shape: torch.Size([274, 4])
phi_angles shape: torch.Size([274, 1])
psi_angles shape: torch.Size([274, 1])
b_factors shape: torch.Size([274, 1])
Shape of labels: torch.Size([274])
esm2_embeddings shape: torch.Size([166, 1280])
ss_onehot shape: torch.Size([166, 4])
phi_angles shape: torch.Size([166, 1])
psi_angles shape: torch.Size([166, 1])
b_factors shape: torch.Size([166, 1])
Shape of labels: torch.Size([166])
esm2_embeddings shape: torch.Size([79, 1280])
ss_onehot shape: torch.Size([79, 4])
phi_angles shape: torch.Size([79, 1])
psi_angles shape: torch.Size([79, 1])
b_factors shape: torch.Size([79, 1])
Shape of labels: torch.Size([79])


 61%|██████▏   | 622/1013 [01:49<00:58,  6.65it/s]

esm2_embeddings shape: torch.Size([98, 1280])
ss_onehot shape: torch.Size([98, 4])
phi_angles shape: torch.Size([98, 1])
psi_angles shape: torch.Size([98, 1])
b_factors shape: torch.Size([98, 1])
Shape of labels: torch.Size([98])
esm2_embeddings shape: torch.Size([71, 1280])
ss_onehot shape: torch.Size([71, 4])
phi_angles shape: torch.Size([71, 1])
psi_angles shape: torch.Size([71, 1])
b_factors shape: torch.Size([71, 1])
Shape of labels: torch.Size([71])
esm2_embeddings shape: torch.Size([161, 1280])
ss_onehot shape: torch.Size([161, 4])
phi_angles shape: torch.Size([161, 1])
psi_angles shape: torch.Size([161, 1])
b_factors shape: torch.Size([161, 1])
Shape of labels: torch.Size([161])


 62%|██████▏   | 624/1013 [01:49<00:52,  7.36it/s]

esm2_embeddings shape: torch.Size([86, 1280])
ss_onehot shape: torch.Size([86, 4])
phi_angles shape: torch.Size([86, 1])
psi_angles shape: torch.Size([86, 1])
b_factors shape: torch.Size([86, 1])
Shape of labels: torch.Size([86])
esm2_embeddings shape: torch.Size([179, 1280])
ss_onehot shape: torch.Size([179, 4])
phi_angles shape: torch.Size([179, 1])
psi_angles shape: torch.Size([179, 1])
b_factors shape: torch.Size([179, 1])
Shape of labels: torch.Size([179])


 62%|██████▏   | 626/1013 [01:50<01:25,  4.51it/s]

esm2_embeddings shape: torch.Size([217, 1280])
ss_onehot shape: torch.Size([217, 4])
phi_angles shape: torch.Size([217, 1])
psi_angles shape: torch.Size([217, 1])
b_factors shape: torch.Size([217, 1])
Shape of labels: torch.Size([217])
esm2_embeddings shape: torch.Size([216, 1280])
ss_onehot shape: torch.Size([216, 4])
phi_angles shape: torch.Size([216, 1])
psi_angles shape: torch.Size([216, 1])
b_factors shape: torch.Size([216, 1])
Shape of labels: torch.Size([216])


 62%|██████▏   | 628/1013 [01:50<01:11,  5.40it/s]

esm2_embeddings shape: torch.Size([134, 1280])
ss_onehot shape: torch.Size([134, 4])
phi_angles shape: torch.Size([134, 1])
psi_angles shape: torch.Size([134, 1])
b_factors shape: torch.Size([134, 1])
Shape of labels: torch.Size([134])
esm2_embeddings shape: torch.Size([198, 1280])
ss_onehot shape: torch.Size([198, 4])
phi_angles shape: torch.Size([198, 1])
psi_angles shape: torch.Size([198, 1])
b_factors shape: torch.Size([198, 1])
Shape of labels: torch.Size([198])


 62%|██████▏   | 631/1013 [01:50<01:01,  6.17it/s]

esm2_embeddings shape: torch.Size([312, 1280])
ss_onehot shape: torch.Size([312, 4])
phi_angles shape: torch.Size([312, 1])
psi_angles shape: torch.Size([312, 1])
b_factors shape: torch.Size([312, 1])
Shape of labels: torch.Size([312])
esm2_embeddings shape: torch.Size([99, 1280])
ss_onehot shape: torch.Size([99, 4])
phi_angles shape: torch.Size([99, 1])
psi_angles shape: torch.Size([99, 1])
b_factors shape: torch.Size([99, 1])
Shape of labels: torch.Size([99])
esm2_embeddings shape: torch.Size([141, 1280])
ss_onehot shape: torch.Size([141, 4])
phi_angles shape: torch.Size([141, 1])
psi_angles shape: torch.Size([141, 1])
b_factors shape: torch.Size([141, 1])
Shape of labels: torch.Size([141])


 62%|██████▏   | 633/1013 [01:51<00:54,  7.03it/s]

esm2_embeddings shape: torch.Size([145, 1280])
ss_onehot shape: torch.Size([145, 4])
phi_angles shape: torch.Size([145, 1])
psi_angles shape: torch.Size([145, 1])
b_factors shape: torch.Size([145, 1])
Shape of labels: torch.Size([145])
esm2_embeddings shape: torch.Size([150, 1280])
ss_onehot shape: torch.Size([150, 4])
phi_angles shape: torch.Size([150, 1])
psi_angles shape: torch.Size([150, 1])
b_factors shape: torch.Size([150, 1])
Shape of labels: torch.Size([150])


 63%|██████▎   | 636/1013 [01:51<00:44,  8.41it/s]

esm2_embeddings shape: torch.Size([223, 1280])
ss_onehot shape: torch.Size([223, 4])
phi_angles shape: torch.Size([223, 1])
psi_angles shape: torch.Size([223, 1])
b_factors shape: torch.Size([223, 1])
Shape of labels: torch.Size([223])
esm2_embeddings shape: torch.Size([78, 1280])
ss_onehot shape: torch.Size([78, 4])
phi_angles shape: torch.Size([78, 1])
psi_angles shape: torch.Size([78, 1])
b_factors shape: torch.Size([78, 1])
Shape of labels: torch.Size([78])
esm2_embeddings shape: torch.Size([122, 1280])
ss_onehot shape: torch.Size([122, 4])
phi_angles shape: torch.Size([122, 1])
psi_angles shape: torch.Size([122, 1])
b_factors shape: torch.Size([122, 1])
Shape of labels: torch.Size([122])
esm2_embeddings shape: torch.Size([131, 1280])
ss_onehot shape: torch.Size([131, 4])
phi_angles shape: torch.Size([131, 1])
psi_angles shape: torch.Size([131, 1])
b_factors shape: torch.Size([131, 1])
Shape of labels: torch.Size([131])


 63%|██████▎   | 639/1013 [01:52<01:18,  4.75it/s]

esm2_embeddings shape: torch.Size([267, 1280])
ss_onehot shape: torch.Size([267, 4])
phi_angles shape: torch.Size([267, 1])
psi_angles shape: torch.Size([267, 1])
b_factors shape: torch.Size([267, 1])
Shape of labels: torch.Size([267])
esm2_embeddings shape: torch.Size([207, 1280])
ss_onehot shape: torch.Size([207, 4])
phi_angles shape: torch.Size([207, 1])
psi_angles shape: torch.Size([207, 1])
b_factors shape: torch.Size([207, 1])
Shape of labels: torch.Size([207])


 63%|██████▎   | 641/1013 [01:52<01:12,  5.15it/s]

esm2_embeddings shape: torch.Size([222, 1280])
ss_onehot shape: torch.Size([222, 4])
phi_angles shape: torch.Size([222, 1])
psi_angles shape: torch.Size([222, 1])
b_factors shape: torch.Size([222, 1])
Shape of labels: torch.Size([222])
esm2_embeddings shape: torch.Size([182, 1280])
ss_onehot shape: torch.Size([182, 4])
phi_angles shape: torch.Size([182, 1])
psi_angles shape: torch.Size([182, 1])
b_factors shape: torch.Size([182, 1])
Shape of labels: torch.Size([182])


 63%|██████▎   | 643/1013 [01:53<01:01,  5.99it/s]

esm2_embeddings shape: torch.Size([105, 1280])
ss_onehot shape: torch.Size([105, 4])
phi_angles shape: torch.Size([105, 1])
psi_angles shape: torch.Size([105, 1])
b_factors shape: torch.Size([105, 1])
Shape of labels: torch.Size([105])
esm2_embeddings shape: torch.Size([219, 1280])
ss_onehot shape: torch.Size([219, 4])
phi_angles shape: torch.Size([219, 1])
psi_angles shape: torch.Size([219, 1])
b_factors shape: torch.Size([219, 1])
Shape of labels: torch.Size([219])


 64%|██████▎   | 645/1013 [01:53<00:58,  6.34it/s]

esm2_embeddings shape: torch.Size([193, 1280])
ss_onehot shape: torch.Size([193, 4])
phi_angles shape: torch.Size([193, 1])
psi_angles shape: torch.Size([193, 1])
b_factors shape: torch.Size([193, 1])
Shape of labels: torch.Size([193])
esm2_embeddings shape: torch.Size([159, 1280])
ss_onehot shape: torch.Size([159, 4])
phi_angles shape: torch.Size([159, 1])
psi_angles shape: torch.Size([159, 1])
b_factors shape: torch.Size([159, 1])
Shape of labels: torch.Size([159])


 64%|██████▍   | 647/1013 [01:53<01:00,  6.10it/s]

esm2_embeddings shape: torch.Size([183, 1280])
ss_onehot shape: torch.Size([183, 4])
phi_angles shape: torch.Size([183, 1])
psi_angles shape: torch.Size([183, 1])
b_factors shape: torch.Size([183, 1])
Shape of labels: torch.Size([183])
esm2_embeddings shape: torch.Size([228, 1280])
ss_onehot shape: torch.Size([228, 4])
phi_angles shape: torch.Size([228, 1])
psi_angles shape: torch.Size([228, 1])
b_factors shape: torch.Size([228, 1])
Shape of labels: torch.Size([228])


 64%|██████▍   | 648/1013 [01:54<01:56,  3.13it/s]

esm2_embeddings shape: torch.Size([294, 1280])
ss_onehot shape: torch.Size([294, 4])
phi_angles shape: torch.Size([294, 1])
psi_angles shape: torch.Size([294, 1])
b_factors shape: torch.Size([294, 1])
Shape of labels: torch.Size([294])


 64%|██████▍   | 649/1013 [01:54<01:48,  3.37it/s]

esm2_embeddings shape: torch.Size([261, 1280])
ss_onehot shape: torch.Size([261, 4])
phi_angles shape: torch.Size([261, 1])
psi_angles shape: torch.Size([261, 1])
b_factors shape: torch.Size([261, 1])
Shape of labels: torch.Size([261])


 64%|██████▍   | 651/1013 [01:54<01:22,  4.38it/s]

esm2_embeddings shape: torch.Size([235, 1280])
ss_onehot shape: torch.Size([235, 4])
phi_angles shape: torch.Size([235, 1])
psi_angles shape: torch.Size([235, 1])
b_factors shape: torch.Size([235, 1])
Shape of labels: torch.Size([235])
esm2_embeddings shape: torch.Size([159, 1280])
ss_onehot shape: torch.Size([159, 4])
phi_angles shape: torch.Size([159, 1])
psi_angles shape: torch.Size([159, 1])
b_factors shape: torch.Size([159, 1])
Shape of labels: torch.Size([159])


 64%|██████▍   | 653/1013 [01:55<00:58,  6.14it/s]

esm2_embeddings shape: torch.Size([130, 1280])
ss_onehot shape: torch.Size([130, 4])
phi_angles shape: torch.Size([130, 1])
psi_angles shape: torch.Size([130, 1])
b_factors shape: torch.Size([130, 1])
Shape of labels: torch.Size([130])
esm2_embeddings shape: torch.Size([78, 1280])
ss_onehot shape: torch.Size([78, 4])
phi_angles shape: torch.Size([78, 1])
psi_angles shape: torch.Size([78, 1])
b_factors shape: torch.Size([78, 1])
Shape of labels: torch.Size([78])


 65%|██████▍   | 654/1013 [01:55<01:18,  4.59it/s]

esm2_embeddings shape: torch.Size([356, 1280])
ss_onehot shape: torch.Size([356, 4])
phi_angles shape: torch.Size([356, 1])
psi_angles shape: torch.Size([356, 1])
b_factors shape: torch.Size([356, 1])
Shape of labels: torch.Size([356])


 65%|██████▍   | 656/1013 [01:56<01:36,  3.70it/s]

esm2_embeddings shape: torch.Size([178, 1280])
ss_onehot shape: torch.Size([178, 4])
phi_angles shape: torch.Size([178, 1])
psi_angles shape: torch.Size([178, 1])
b_factors shape: torch.Size([178, 1])
Shape of labels: torch.Size([178])
esm2_embeddings shape: torch.Size([200, 1280])
ss_onehot shape: torch.Size([200, 4])
phi_angles shape: torch.Size([200, 1])
psi_angles shape: torch.Size([200, 1])
b_factors shape: torch.Size([200, 1])
Shape of labels: torch.Size([200])


 65%|██████▌   | 659/1013 [01:56<00:59,  5.92it/s]

esm2_embeddings shape: torch.Size([177, 1280])
ss_onehot shape: torch.Size([177, 4])
phi_angles shape: torch.Size([177, 1])
psi_angles shape: torch.Size([177, 1])
b_factors shape: torch.Size([177, 1])
Shape of labels: torch.Size([177])
esm2_embeddings shape: torch.Size([113, 1280])
ss_onehot shape: torch.Size([113, 4])
phi_angles shape: torch.Size([113, 1])
psi_angles shape: torch.Size([113, 1])
b_factors shape: torch.Size([113, 1])
Shape of labels: torch.Size([113])
esm2_embeddings shape: torch.Size([103, 1280])
ss_onehot shape: torch.Size([103, 4])
phi_angles shape: torch.Size([103, 1])
psi_angles shape: torch.Size([103, 1])
b_factors shape: torch.Size([103, 1])
Shape of labels: torch.Size([103])


 65%|██████▌   | 661/1013 [01:56<00:47,  7.41it/s]

esm2_embeddings shape: torch.Size([68, 1280])
ss_onehot shape: torch.Size([68, 4])
phi_angles shape: torch.Size([68, 1])
psi_angles shape: torch.Size([68, 1])
b_factors shape: torch.Size([68, 1])
Shape of labels: torch.Size([68])
esm2_embeddings shape: torch.Size([147, 1280])
ss_onehot shape: torch.Size([147, 4])
phi_angles shape: torch.Size([147, 1])
psi_angles shape: torch.Size([147, 1])
b_factors shape: torch.Size([147, 1])
Shape of labels: torch.Size([147])
esm2_embeddings shape: torch.Size([40, 1280])
ss_onehot shape: torch.Size([40, 4])
phi_angles shape: torch.Size([40, 1])
psi_angles shape: torch.Size([40, 1])
b_factors shape: torch.Size([40, 1])
Shape of labels: torch.Size([40])


 65%|██████▌   | 663/1013 [01:56<00:38,  8.98it/s]

esm2_embeddings shape: torch.Size([138, 1280])
ss_onehot shape: torch.Size([138, 4])
phi_angles shape: torch.Size([138, 1])
psi_angles shape: torch.Size([138, 1])
b_factors shape: torch.Size([138, 1])
Shape of labels: torch.Size([138])
esm2_embeddings shape: torch.Size([205, 1280])
ss_onehot shape: torch.Size([205, 4])
phi_angles shape: torch.Size([205, 1])
psi_angles shape: torch.Size([205, 1])
b_factors shape: torch.Size([205, 1])
Shape of labels: torch.Size([205])


 66%|██████▌   | 665/1013 [01:57<00:38,  9.05it/s]

esm2_embeddings shape: torch.Size([71, 1280])
ss_onehot shape: torch.Size([71, 4])
phi_angles shape: torch.Size([71, 1])
psi_angles shape: torch.Size([71, 1])
b_factors shape: torch.Size([71, 1])
Shape of labels: torch.Size([71])


 66%|██████▌   | 667/1013 [01:57<00:42,  8.05it/s]

esm2_embeddings shape: torch.Size([244, 1280])
ss_onehot shape: torch.Size([244, 4])
phi_angles shape: torch.Size([244, 1])
psi_angles shape: torch.Size([244, 1])
b_factors shape: torch.Size([244, 1])
Shape of labels: torch.Size([244])
esm2_embeddings shape: torch.Size([121, 1280])
ss_onehot shape: torch.Size([121, 4])
phi_angles shape: torch.Size([121, 1])
psi_angles shape: torch.Size([121, 1])
b_factors shape: torch.Size([121, 1])
Shape of labels: torch.Size([121])


 66%|██████▌   | 669/1013 [01:57<00:43,  7.82it/s]

esm2_embeddings shape: torch.Size([157, 1280])
ss_onehot shape: torch.Size([157, 4])
phi_angles shape: torch.Size([157, 1])
psi_angles shape: torch.Size([157, 1])
b_factors shape: torch.Size([157, 1])
Shape of labels: torch.Size([157])
esm2_embeddings shape: torch.Size([185, 1280])
ss_onehot shape: torch.Size([185, 4])
phi_angles shape: torch.Size([185, 1])
psi_angles shape: torch.Size([185, 1])
b_factors shape: torch.Size([185, 1])
Shape of labels: torch.Size([185])


 66%|██████▌   | 671/1013 [01:58<01:08,  4.98it/s]

esm2_embeddings shape: torch.Size([156, 1280])
ss_onehot shape: torch.Size([156, 4])
phi_angles shape: torch.Size([156, 1])
psi_angles shape: torch.Size([156, 1])
b_factors shape: torch.Size([156, 1])
Shape of labels: torch.Size([156])
esm2_embeddings shape: torch.Size([170, 1280])
ss_onehot shape: torch.Size([170, 4])
phi_angles shape: torch.Size([170, 1])
psi_angles shape: torch.Size([170, 1])
b_factors shape: torch.Size([170, 1])
Shape of labels: torch.Size([170])
esm2_embeddings shape: torch.Size([71, 1280])
ss_onehot shape: torch.Size([71, 4])
phi_angles shape: torch.Size([71, 1])
psi_angles shape: torch.Size([71, 1])
b_factors shape: torch.Size([71, 1])
Shape of labels: torch.Size([71])


 67%|██████▋   | 674/1013 [01:58<00:45,  7.43it/s]

esm2_embeddings shape: torch.Size([94, 1280])
ss_onehot shape: torch.Size([94, 4])
phi_angles shape: torch.Size([94, 1])
psi_angles shape: torch.Size([94, 1])
b_factors shape: torch.Size([94, 1])
Shape of labels: torch.Size([94])
esm2_embeddings shape: torch.Size([146, 1280])
ss_onehot shape: torch.Size([146, 4])
phi_angles shape: torch.Size([146, 1])
psi_angles shape: torch.Size([146, 1])
b_factors shape: torch.Size([146, 1])
Shape of labels: torch.Size([146])


 67%|██████▋   | 676/1013 [01:58<00:43,  7.75it/s]

esm2_embeddings shape: torch.Size([178, 1280])
ss_onehot shape: torch.Size([178, 4])
phi_angles shape: torch.Size([178, 1])
psi_angles shape: torch.Size([178, 1])
b_factors shape: torch.Size([178, 1])
Shape of labels: torch.Size([178])
esm2_embeddings shape: torch.Size([147, 1280])
ss_onehot shape: torch.Size([147, 4])
phi_angles shape: torch.Size([147, 1])
psi_angles shape: torch.Size([147, 1])
b_factors shape: torch.Size([147, 1])
Shape of labels: torch.Size([147])


 67%|██████▋   | 677/1013 [01:58<00:41,  8.18it/s]

esm2_embeddings shape: torch.Size([125, 1280])
ss_onehot shape: torch.Size([125, 4])
phi_angles shape: torch.Size([125, 1])
psi_angles shape: torch.Size([125, 1])
b_factors shape: torch.Size([125, 1])
Shape of labels: torch.Size([125])


 67%|██████▋   | 679/1013 [01:59<00:54,  6.18it/s]

esm2_embeddings shape: torch.Size([257, 1280])
ss_onehot shape: torch.Size([257, 4])
phi_angles shape: torch.Size([257, 1])
psi_angles shape: torch.Size([257, 1])
b_factors shape: torch.Size([257, 1])
Shape of labels: torch.Size([257])
esm2_embeddings shape: torch.Size([218, 1280])
ss_onehot shape: torch.Size([218, 4])
phi_angles shape: torch.Size([218, 1])
psi_angles shape: torch.Size([218, 1])
b_factors shape: torch.Size([218, 1])
Shape of labels: torch.Size([218])


 67%|██████▋   | 681/1013 [01:59<00:50,  6.59it/s]

esm2_embeddings shape: torch.Size([228, 1280])
ss_onehot shape: torch.Size([228, 4])
phi_angles shape: torch.Size([228, 1])
psi_angles shape: torch.Size([228, 1])
b_factors shape: torch.Size([228, 1])
Shape of labels: torch.Size([228])
esm2_embeddings shape: torch.Size([139, 1280])
ss_onehot shape: torch.Size([139, 4])
phi_angles shape: torch.Size([139, 1])
psi_angles shape: torch.Size([139, 1])
b_factors shape: torch.Size([139, 1])
Shape of labels: torch.Size([139])
esm2_embeddings shape: torch.Size([119, 1280])
ss_onehot shape: torch.Size([119, 4])
phi_angles shape: torch.Size([119, 1])
psi_angles shape: torch.Size([119, 1])
b_factors shape: torch.Size([119, 1])
Shape of labels: torch.Size([119])


 67%|██████▋   | 683/1013 [02:00<01:19,  4.17it/s]

esm2_embeddings shape: torch.Size([200, 1280])
ss_onehot shape: torch.Size([200, 4])
phi_angles shape: torch.Size([200, 1])
psi_angles shape: torch.Size([200, 1])
b_factors shape: torch.Size([200, 1])
Shape of labels: torch.Size([200])


 68%|██████▊   | 686/1013 [02:00<00:57,  5.66it/s]

esm2_embeddings shape: torch.Size([249, 1280])
ss_onehot shape: torch.Size([249, 4])
phi_angles shape: torch.Size([249, 1])
psi_angles shape: torch.Size([249, 1])
b_factors shape: torch.Size([249, 1])
Shape of labels: torch.Size([249])
esm2_embeddings shape: torch.Size([133, 1280])
ss_onehot shape: torch.Size([133, 4])
phi_angles shape: torch.Size([133, 1])
psi_angles shape: torch.Size([133, 1])
b_factors shape: torch.Size([133, 1])
Shape of labels: torch.Size([133])
esm2_embeddings shape: torch.Size([117, 1280])
ss_onehot shape: torch.Size([117, 4])
phi_angles shape: torch.Size([117, 1])
psi_angles shape: torch.Size([117, 1])
b_factors shape: torch.Size([117, 1])
Shape of labels: torch.Size([117])


 68%|██████▊   | 687/1013 [02:00<00:54,  5.97it/s]

esm2_embeddings shape: torch.Size([170, 1280])
ss_onehot shape: torch.Size([170, 4])
phi_angles shape: torch.Size([170, 1])
psi_angles shape: torch.Size([170, 1])
b_factors shape: torch.Size([170, 1])
Shape of labels: torch.Size([170])


 68%|██████▊   | 690/1013 [02:01<00:43,  7.42it/s]

esm2_embeddings shape: torch.Size([236, 1280])
ss_onehot shape: torch.Size([236, 4])
phi_angles shape: torch.Size([236, 1])
psi_angles shape: torch.Size([236, 1])
b_factors shape: torch.Size([236, 1])
Shape of labels: torch.Size([236])
esm2_embeddings shape: torch.Size([76, 1280])
ss_onehot shape: torch.Size([76, 4])
phi_angles shape: torch.Size([76, 1])
psi_angles shape: torch.Size([76, 1])
b_factors shape: torch.Size([76, 1])
Shape of labels: torch.Size([76])
esm2_embeddings shape: torch.Size([119, 1280])
ss_onehot shape: torch.Size([119, 4])
phi_angles shape: torch.Size([119, 1])
psi_angles shape: torch.Size([119, 1])
b_factors shape: torch.Size([119, 1])
Shape of labels: torch.Size([119])
esm2_embeddings shape: torch.Size([86, 1280])
ss_onehot shape: torch.Size([86, 4])
phi_angles shape: torch.Size([86, 1])
psi_angles shape: torch.Size([86, 1])
b_factors shape: torch.Size([86, 1])
Shape of labels: torch.Size([86])


 68%|██████▊   | 693/1013 [02:01<00:42,  7.50it/s]

esm2_embeddings shape: torch.Size([246, 1280])
ss_onehot shape: torch.Size([246, 4])
phi_angles shape: torch.Size([246, 1])
psi_angles shape: torch.Size([246, 1])
b_factors shape: torch.Size([246, 1])
Shape of labels: torch.Size([246])
esm2_embeddings shape: torch.Size([157, 1280])
ss_onehot shape: torch.Size([157, 4])
phi_angles shape: torch.Size([157, 1])
psi_angles shape: torch.Size([157, 1])
b_factors shape: torch.Size([157, 1])
Shape of labels: torch.Size([157])


 69%|██████▊   | 694/1013 [02:01<00:41,  7.77it/s]

esm2_embeddings shape: torch.Size([132, 1280])
ss_onehot shape: torch.Size([132, 4])
phi_angles shape: torch.Size([132, 1])
psi_angles shape: torch.Size([132, 1])
b_factors shape: torch.Size([132, 1])
Shape of labels: torch.Size([132])
esm2_embeddings shape: torch.Size([93, 1280])
ss_onehot shape: torch.Size([93, 4])
phi_angles shape: torch.Size([93, 1])
psi_angles shape: torch.Size([93, 1])
b_factors shape: torch.Size([93, 1])
Shape of labels: torch.Size([93])


 69%|██████▊   | 696/1013 [02:02<01:08,  4.64it/s]

esm2_embeddings shape: torch.Size([237, 1280])
ss_onehot shape: torch.Size([237, 4])
phi_angles shape: torch.Size([237, 1])
psi_angles shape: torch.Size([237, 1])
b_factors shape: torch.Size([237, 1])
Shape of labels: torch.Size([237])
esm2_embeddings shape: torch.Size([89, 1280])
ss_onehot shape: torch.Size([89, 4])
phi_angles shape: torch.Size([89, 1])
psi_angles shape: torch.Size([89, 1])
b_factors shape: torch.Size([89, 1])
Shape of labels: torch.Size([89])


 69%|██████▉   | 700/1013 [02:02<00:43,  7.18it/s]

esm2_embeddings shape: torch.Size([196, 1280])
ss_onehot shape: torch.Size([196, 4])
phi_angles shape: torch.Size([196, 1])
psi_angles shape: torch.Size([196, 1])
b_factors shape: torch.Size([196, 1])
Shape of labels: torch.Size([196])
esm2_embeddings shape: torch.Size([88, 1280])
ss_onehot shape: torch.Size([88, 4])
phi_angles shape: torch.Size([88, 1])
psi_angles shape: torch.Size([88, 1])
b_factors shape: torch.Size([88, 1])
Shape of labels: torch.Size([88])
esm2_embeddings shape: torch.Size([90, 1280])
ss_onehot shape: torch.Size([90, 4])
phi_angles shape: torch.Size([90, 1])
psi_angles shape: torch.Size([90, 1])
b_factors shape: torch.Size([90, 1])
Shape of labels: torch.Size([90])


 69%|██████▉   | 703/1013 [02:03<00:34,  9.04it/s]

esm2_embeddings shape: torch.Size([149, 1280])
ss_onehot shape: torch.Size([149, 4])
phi_angles shape: torch.Size([149, 1])
psi_angles shape: torch.Size([149, 1])
b_factors shape: torch.Size([149, 1])
Shape of labels: torch.Size([149])
esm2_embeddings shape: torch.Size([87, 1280])
ss_onehot shape: torch.Size([87, 4])
phi_angles shape: torch.Size([87, 1])
psi_angles shape: torch.Size([87, 1])
b_factors shape: torch.Size([87, 1])
Shape of labels: torch.Size([87])
esm2_embeddings shape: torch.Size([101, 1280])
ss_onehot shape: torch.Size([101, 4])
phi_angles shape: torch.Size([101, 1])
psi_angles shape: torch.Size([101, 1])
b_factors shape: torch.Size([101, 1])
Shape of labels: torch.Size([101])


 70%|██████▉   | 705/1013 [02:03<00:41,  7.45it/s]

esm2_embeddings shape: torch.Size([220, 1280])
ss_onehot shape: torch.Size([220, 4])
phi_angles shape: torch.Size([220, 1])
psi_angles shape: torch.Size([220, 1])
b_factors shape: torch.Size([220, 1])
Shape of labels: torch.Size([220])
esm2_embeddings shape: torch.Size([214, 1280])
ss_onehot shape: torch.Size([214, 4])
phi_angles shape: torch.Size([214, 1])
psi_angles shape: torch.Size([214, 1])
b_factors shape: torch.Size([214, 1])
Shape of labels: torch.Size([214])


 70%|██████▉   | 707/1013 [02:04<01:13,  4.19it/s]

esm2_embeddings shape: torch.Size([387, 1280])
ss_onehot shape: torch.Size([387, 4])
phi_angles shape: torch.Size([387, 1])
psi_angles shape: torch.Size([387, 1])
b_factors shape: torch.Size([387, 1])
Shape of labels: torch.Size([387])
esm2_embeddings shape: torch.Size([139, 1280])
ss_onehot shape: torch.Size([139, 4])
phi_angles shape: torch.Size([139, 1])
psi_angles shape: torch.Size([139, 1])
b_factors shape: torch.Size([139, 1])
Shape of labels: torch.Size([139])


 70%|██████▉   | 709/1013 [02:04<01:01,  4.95it/s]

esm2_embeddings shape: torch.Size([147, 1280])
ss_onehot shape: torch.Size([147, 4])
phi_angles shape: torch.Size([147, 1])
psi_angles shape: torch.Size([147, 1])
b_factors shape: torch.Size([147, 1])
Shape of labels: torch.Size([147])
esm2_embeddings shape: torch.Size([216, 1280])
ss_onehot shape: torch.Size([216, 4])
phi_angles shape: torch.Size([216, 1])
psi_angles shape: torch.Size([216, 1])
b_factors shape: torch.Size([216, 1])
Shape of labels: torch.Size([216])


 70%|███████   | 710/1013 [02:04<01:00,  5.04it/s]

esm2_embeddings shape: torch.Size([214, 1280])
ss_onehot shape: torch.Size([214, 4])
phi_angles shape: torch.Size([214, 1])
psi_angles shape: torch.Size([214, 1])
b_factors shape: torch.Size([214, 1])
Shape of labels: torch.Size([214])
esm2_embeddings shape: torch.Size([131, 1280])
ss_onehot shape: torch.Size([131, 4])
phi_angles shape: torch.Size([131, 1])
psi_angles shape: torch.Size([131, 1])
b_factors shape: torch.Size([131, 1])
Shape of labels: torch.Size([131])


 70%|███████   | 712/1013 [02:05<00:49,  6.08it/s]

esm2_embeddings shape: torch.Size([163, 1280])
ss_onehot shape: torch.Size([163, 4])
phi_angles shape: torch.Size([163, 1])
psi_angles shape: torch.Size([163, 1])
b_factors shape: torch.Size([163, 1])
Shape of labels: torch.Size([163])


 70%|███████   | 714/1013 [02:05<00:51,  5.81it/s]

esm2_embeddings shape: torch.Size([298, 1280])
ss_onehot shape: torch.Size([298, 4])
phi_angles shape: torch.Size([298, 1])
psi_angles shape: torch.Size([298, 1])
b_factors shape: torch.Size([298, 1])
Shape of labels: torch.Size([298])
esm2_embeddings shape: torch.Size([145, 1280])
ss_onehot shape: torch.Size([145, 4])
phi_angles shape: torch.Size([145, 1])
psi_angles shape: torch.Size([145, 1])
b_factors shape: torch.Size([145, 1])
Shape of labels: torch.Size([145])
esm2_embeddings shape: torch.Size([119, 1280])
ss_onehot shape: torch.Size([119, 4])
phi_angles shape: torch.Size([119, 1])
psi_angles shape: torch.Size([119, 1])
b_factors shape: torch.Size([119, 1])
Shape of labels: torch.Size([119])


 71%|███████   | 718/1013 [02:05<00:35,  8.31it/s]

esm2_embeddings shape: torch.Size([146, 1280])
ss_onehot shape: torch.Size([146, 4])
phi_angles shape: torch.Size([146, 1])
psi_angles shape: torch.Size([146, 1])
b_factors shape: torch.Size([146, 1])
Shape of labels: torch.Size([146])
esm2_embeddings shape: torch.Size([53, 1280])
ss_onehot shape: torch.Size([53, 4])
phi_angles shape: torch.Size([53, 1])
psi_angles shape: torch.Size([53, 1])
b_factors shape: torch.Size([53, 1])
Shape of labels: torch.Size([53])
esm2_embeddings shape: torch.Size([175, 1280])
ss_onehot shape: torch.Size([175, 4])
phi_angles shape: torch.Size([175, 1])
psi_angles shape: torch.Size([175, 1])
b_factors shape: torch.Size([175, 1])
Shape of labels: torch.Size([175])


 71%|███████   | 719/1013 [02:06<01:08,  4.28it/s]

esm2_embeddings shape: torch.Size([211, 1280])
ss_onehot shape: torch.Size([211, 4])
phi_angles shape: torch.Size([211, 1])
psi_angles shape: torch.Size([211, 1])
b_factors shape: torch.Size([211, 1])
Shape of labels: torch.Size([211])
esm2_embeddings shape: torch.Size([131, 1280])
ss_onehot shape: torch.Size([131, 4])
phi_angles shape: torch.Size([131, 1])
psi_angles shape: torch.Size([131, 1])
b_factors shape: torch.Size([131, 1])
Shape of labels: torch.Size([131])


 71%|███████   | 721/1013 [02:07<01:10,  4.13it/s]

esm2_embeddings shape: torch.Size([373, 1280])
ss_onehot shape: torch.Size([373, 4])
phi_angles shape: torch.Size([373, 1])
psi_angles shape: torch.Size([373, 1])
b_factors shape: torch.Size([373, 1])
Shape of labels: torch.Size([373])


 71%|███████▏  | 722/1013 [02:07<01:15,  3.86it/s]

esm2_embeddings shape: torch.Size([318, 1280])
ss_onehot shape: torch.Size([318, 4])
phi_angles shape: torch.Size([318, 1])
psi_angles shape: torch.Size([318, 1])
b_factors shape: torch.Size([318, 1])
Shape of labels: torch.Size([318])


 71%|███████▏  | 723/1013 [02:07<01:18,  3.72it/s]

esm2_embeddings shape: torch.Size([310, 1280])
ss_onehot shape: torch.Size([310, 4])
phi_angles shape: torch.Size([310, 1])
psi_angles shape: torch.Size([310, 1])
b_factors shape: torch.Size([310, 1])
Shape of labels: torch.Size([310])


 71%|███████▏  | 724/1013 [02:08<02:08,  2.25it/s]

esm2_embeddings shape: torch.Size([446, 1280])
ss_onehot shape: torch.Size([446, 4])
phi_angles shape: torch.Size([446, 1])
psi_angles shape: torch.Size([446, 1])
b_factors shape: torch.Size([446, 1])
Shape of labels: torch.Size([446])
esm2_embeddings shape: torch.Size([117, 1280])
ss_onehot shape: torch.Size([117, 4])
phi_angles shape: torch.Size([117, 1])
psi_angles shape: torch.Size([117, 1])
b_factors shape: torch.Size([117, 1])
Shape of labels: torch.Size([117])


 72%|███████▏  | 726/1013 [02:08<01:27,  3.28it/s]

esm2_embeddings shape: torch.Size([169, 1280])
ss_onehot shape: torch.Size([169, 4])
phi_angles shape: torch.Size([169, 1])
psi_angles shape: torch.Size([169, 1])
b_factors shape: torch.Size([169, 1])
Shape of labels: torch.Size([169])
esm2_embeddings shape: torch.Size([98, 1280])
ss_onehot shape: torch.Size([98, 4])
phi_angles shape: torch.Size([98, 1])
psi_angles shape: torch.Size([98, 1])
b_factors shape: torch.Size([98, 1])
Shape of labels: torch.Size([98])


 72%|███████▏  | 729/1013 [02:09<01:03,  4.45it/s]

esm2_embeddings shape: torch.Size([167, 1280])
ss_onehot shape: torch.Size([167, 4])
phi_angles shape: torch.Size([167, 1])
psi_angles shape: torch.Size([167, 1])
b_factors shape: torch.Size([167, 1])
Shape of labels: torch.Size([167])
esm2_embeddings shape: torch.Size([220, 1280])
ss_onehot shape: torch.Size([220, 4])
phi_angles shape: torch.Size([220, 1])
psi_angles shape: torch.Size([220, 1])
b_factors shape: torch.Size([220, 1])
Shape of labels: torch.Size([220])


 72%|███████▏  | 731/1013 [02:09<01:05,  4.28it/s]

esm2_embeddings shape: torch.Size([338, 1280])
ss_onehot shape: torch.Size([338, 4])
phi_angles shape: torch.Size([338, 1])
psi_angles shape: torch.Size([338, 1])
b_factors shape: torch.Size([338, 1])
Shape of labels: torch.Size([338])
esm2_embeddings shape: torch.Size([199, 1280])
ss_onehot shape: torch.Size([199, 4])
phi_angles shape: torch.Size([199, 1])
psi_angles shape: torch.Size([199, 1])
b_factors shape: torch.Size([199, 1])
Shape of labels: torch.Size([199])


 72%|███████▏  | 733/1013 [02:10<00:50,  5.50it/s]

esm2_embeddings shape: torch.Size([164, 1280])
ss_onehot shape: torch.Size([164, 4])
phi_angles shape: torch.Size([164, 1])
psi_angles shape: torch.Size([164, 1])
b_factors shape: torch.Size([164, 1])
Shape of labels: torch.Size([164])
esm2_embeddings shape: torch.Size([150, 1280])
ss_onehot shape: torch.Size([150, 4])
phi_angles shape: torch.Size([150, 1])
psi_angles shape: torch.Size([150, 1])
b_factors shape: torch.Size([150, 1])
Shape of labels: torch.Size([150])


 72%|███████▏  | 734/1013 [02:10<01:31,  3.05it/s]

esm2_embeddings shape: torch.Size([274, 1280])
ss_onehot shape: torch.Size([274, 4])
phi_angles shape: torch.Size([274, 1])
psi_angles shape: torch.Size([274, 1])
b_factors shape: torch.Size([274, 1])
Shape of labels: torch.Size([274])


 73%|███████▎  | 736/1013 [02:11<01:15,  3.67it/s]

esm2_embeddings shape: torch.Size([329, 1280])
ss_onehot shape: torch.Size([329, 4])
phi_angles shape: torch.Size([329, 1])
psi_angles shape: torch.Size([329, 1])
b_factors shape: torch.Size([329, 1])
Shape of labels: torch.Size([329])
esm2_embeddings shape: torch.Size([167, 1280])
ss_onehot shape: torch.Size([167, 4])
phi_angles shape: torch.Size([167, 1])
psi_angles shape: torch.Size([167, 1])
b_factors shape: torch.Size([167, 1])
Shape of labels: torch.Size([167])


 73%|███████▎  | 739/1013 [02:11<00:46,  5.92it/s]

esm2_embeddings shape: torch.Size([201, 1280])
ss_onehot shape: torch.Size([201, 4])
phi_angles shape: torch.Size([201, 1])
psi_angles shape: torch.Size([201, 1])
b_factors shape: torch.Size([201, 1])
Shape of labels: torch.Size([201])
esm2_embeddings shape: torch.Size([130, 1280])
ss_onehot shape: torch.Size([130, 4])
phi_angles shape: torch.Size([130, 1])
psi_angles shape: torch.Size([130, 1])
b_factors shape: torch.Size([130, 1])
Shape of labels: torch.Size([130])
esm2_embeddings shape: torch.Size([99, 1280])
ss_onehot shape: torch.Size([99, 4])
phi_angles shape: torch.Size([99, 1])
psi_angles shape: torch.Size([99, 1])
b_factors shape: torch.Size([99, 1])
Shape of labels: torch.Size([99])


 73%|███████▎  | 741/1013 [02:11<00:47,  5.74it/s]

esm2_embeddings shape: torch.Size([282, 1280])
ss_onehot shape: torch.Size([282, 4])
phi_angles shape: torch.Size([282, 1])
psi_angles shape: torch.Size([282, 1])
b_factors shape: torch.Size([282, 1])
Shape of labels: torch.Size([282])
esm2_embeddings shape: torch.Size([160, 1280])
ss_onehot shape: torch.Size([160, 4])
phi_angles shape: torch.Size([160, 1])
psi_angles shape: torch.Size([160, 1])
b_factors shape: torch.Size([160, 1])
Shape of labels: torch.Size([160])


 73%|███████▎  | 743/1013 [02:12<01:10,  3.81it/s]

esm2_embeddings shape: torch.Size([264, 1280])
ss_onehot shape: torch.Size([264, 4])
phi_angles shape: torch.Size([264, 1])
psi_angles shape: torch.Size([264, 1])
b_factors shape: torch.Size([264, 1])
Shape of labels: torch.Size([264])
esm2_embeddings shape: torch.Size([185, 1280])
ss_onehot shape: torch.Size([185, 4])
phi_angles shape: torch.Size([185, 1])
psi_angles shape: torch.Size([185, 1])
b_factors shape: torch.Size([185, 1])
Shape of labels: torch.Size([185])
esm2_embeddings shape: torch.Size([65, 1280])
ss_onehot shape: torch.Size([65, 4])
phi_angles shape: torch.Size([65, 1])
psi_angles shape: torch.Size([65, 1])
b_factors shape: torch.Size([65, 1])
Shape of labels: torch.Size([65])


 74%|███████▎  | 746/1013 [02:12<00:43,  6.11it/s]

esm2_embeddings shape: torch.Size([145, 1280])
ss_onehot shape: torch.Size([145, 4])
phi_angles shape: torch.Size([145, 1])
psi_angles shape: torch.Size([145, 1])
b_factors shape: torch.Size([145, 1])
Shape of labels: torch.Size([145])
esm2_embeddings shape: torch.Size([138, 1280])
ss_onehot shape: torch.Size([138, 4])
phi_angles shape: torch.Size([138, 1])
psi_angles shape: torch.Size([138, 1])
b_factors shape: torch.Size([138, 1])
Shape of labels: torch.Size([138])
esm2_embeddings shape: torch.Size([94, 1280])
ss_onehot shape: torch.Size([94, 4])
phi_angles shape: torch.Size([94, 1])
psi_angles shape: torch.Size([94, 1])
b_factors shape: torch.Size([94, 1])
Shape of labels: torch.Size([94])


 74%|███████▍  | 748/1013 [02:13<00:34,  7.77it/s]

esm2_embeddings shape: torch.Size([104, 1280])
ss_onehot shape: torch.Size([104, 4])
phi_angles shape: torch.Size([104, 1])
psi_angles shape: torch.Size([104, 1])
b_factors shape: torch.Size([104, 1])
Shape of labels: torch.Size([104])
esm2_embeddings shape: torch.Size([95, 1280])
ss_onehot shape: torch.Size([95, 4])
phi_angles shape: torch.Size([95, 1])
psi_angles shape: torch.Size([95, 1])
b_factors shape: torch.Size([95, 1])
Shape of labels: torch.Size([95])


 74%|███████▍  | 751/1013 [02:13<00:34,  7.64it/s]

esm2_embeddings shape: torch.Size([185, 1280])
ss_onehot shape: torch.Size([185, 4])
phi_angles shape: torch.Size([185, 1])
psi_angles shape: torch.Size([185, 1])
b_factors shape: torch.Size([185, 1])
Shape of labels: torch.Size([185])
esm2_embeddings shape: torch.Size([195, 1280])
ss_onehot shape: torch.Size([195, 4])
phi_angles shape: torch.Size([195, 1])
psi_angles shape: torch.Size([195, 1])
b_factors shape: torch.Size([195, 1])
Shape of labels: torch.Size([195])


 74%|███████▍  | 753/1013 [02:13<00:41,  6.33it/s]

esm2_embeddings shape: torch.Size([288, 1280])
ss_onehot shape: torch.Size([288, 4])
phi_angles shape: torch.Size([288, 1])
psi_angles shape: torch.Size([288, 1])
b_factors shape: torch.Size([288, 1])
Shape of labels: torch.Size([288])
esm2_embeddings shape: torch.Size([172, 1280])
ss_onehot shape: torch.Size([172, 4])
phi_angles shape: torch.Size([172, 1])
psi_angles shape: torch.Size([172, 1])
b_factors shape: torch.Size([172, 1])
Shape of labels: torch.Size([172])
esm2_embeddings shape: torch.Size([135, 1280])
ss_onehot shape: torch.Size([135, 4])
phi_angles shape: torch.Size([135, 1])
psi_angles shape: torch.Size([135, 1])
b_factors shape: torch.Size([135, 1])
Shape of labels: torch.Size([135])


 75%|███████▍  | 756/1013 [02:14<00:51,  4.94it/s]

esm2_embeddings shape: torch.Size([167, 1280])
ss_onehot shape: torch.Size([167, 4])
phi_angles shape: torch.Size([167, 1])
psi_angles shape: torch.Size([167, 1])
b_factors shape: torch.Size([167, 1])
Shape of labels: torch.Size([167])
esm2_embeddings shape: torch.Size([144, 1280])
ss_onehot shape: torch.Size([144, 4])
phi_angles shape: torch.Size([144, 1])
psi_angles shape: torch.Size([144, 1])
b_factors shape: torch.Size([144, 1])
Shape of labels: torch.Size([144])
esm2_embeddings shape: torch.Size([115, 1280])
ss_onehot shape: torch.Size([115, 4])
phi_angles shape: torch.Size([115, 1])
psi_angles shape: torch.Size([115, 1])
b_factors shape: torch.Size([115, 1])
Shape of labels: torch.Size([115])


 75%|███████▍  | 759/1013 [02:15<00:39,  6.42it/s]

esm2_embeddings shape: torch.Size([123, 1280])
ss_onehot shape: torch.Size([123, 4])
phi_angles shape: torch.Size([123, 1])
psi_angles shape: torch.Size([123, 1])
b_factors shape: torch.Size([123, 1])
Shape of labels: torch.Size([123])
esm2_embeddings shape: torch.Size([191, 1280])
ss_onehot shape: torch.Size([191, 4])
phi_angles shape: torch.Size([191, 1])
psi_angles shape: torch.Size([191, 1])
b_factors shape: torch.Size([191, 1])
Shape of labels: torch.Size([191])


 75%|███████▌  | 761/1013 [02:15<00:45,  5.55it/s]

esm2_embeddings shape: torch.Size([322, 1280])
ss_onehot shape: torch.Size([322, 4])
phi_angles shape: torch.Size([322, 1])
psi_angles shape: torch.Size([322, 1])
b_factors shape: torch.Size([322, 1])
Shape of labels: torch.Size([322])
esm2_embeddings shape: torch.Size([157, 1280])
ss_onehot shape: torch.Size([157, 4])
phi_angles shape: torch.Size([157, 1])
psi_angles shape: torch.Size([157, 1])
b_factors shape: torch.Size([157, 1])
Shape of labels: torch.Size([157])
esm2_embeddings shape: torch.Size([133, 1280])
ss_onehot shape: torch.Size([133, 4])
phi_angles shape: torch.Size([133, 1])
psi_angles shape: torch.Size([133, 1])
b_factors shape: torch.Size([133, 1])
Shape of labels: torch.Size([133])


 75%|███████▌  | 763/1013 [02:15<00:42,  5.86it/s]

esm2_embeddings shape: torch.Size([246, 1280])
ss_onehot shape: torch.Size([246, 4])
phi_angles shape: torch.Size([246, 1])
psi_angles shape: torch.Size([246, 1])
b_factors shape: torch.Size([246, 1])
Shape of labels: torch.Size([246])
esm2_embeddings shape: torch.Size([60, 1280])
ss_onehot shape: torch.Size([60, 4])
phi_angles shape: torch.Size([60, 1])
psi_angles shape: torch.Size([60, 1])
b_factors shape: torch.Size([60, 1])
Shape of labels: torch.Size([60])


 76%|███████▌  | 766/1013 [02:16<00:34,  7.25it/s]

esm2_embeddings shape: torch.Size([191, 1280])
ss_onehot shape: torch.Size([191, 4])
phi_angles shape: torch.Size([191, 1])
psi_angles shape: torch.Size([191, 1])
b_factors shape: torch.Size([191, 1])
Shape of labels: torch.Size([191])
esm2_embeddings shape: torch.Size([144, 1280])
ss_onehot shape: torch.Size([144, 4])
phi_angles shape: torch.Size([144, 1])
psi_angles shape: torch.Size([144, 1])
b_factors shape: torch.Size([144, 1])
Shape of labels: torch.Size([144])


 76%|███████▌  | 767/1013 [02:16<00:57,  4.26it/s]

esm2_embeddings shape: torch.Size([154, 1280])
ss_onehot shape: torch.Size([154, 4])
phi_angles shape: torch.Size([154, 1])
psi_angles shape: torch.Size([154, 1])
b_factors shape: torch.Size([154, 1])
Shape of labels: torch.Size([154])
esm2_embeddings shape: torch.Size([110, 1280])
ss_onehot shape: torch.Size([110, 4])
phi_angles shape: torch.Size([110, 1])
psi_angles shape: torch.Size([110, 1])
b_factors shape: torch.Size([110, 1])
Shape of labels: torch.Size([110])


 76%|███████▌  | 770/1013 [02:17<00:43,  5.64it/s]

esm2_embeddings shape: torch.Size([227, 1280])
ss_onehot shape: torch.Size([227, 4])
phi_angles shape: torch.Size([227, 1])
psi_angles shape: torch.Size([227, 1])
b_factors shape: torch.Size([227, 1])
Shape of labels: torch.Size([227])
esm2_embeddings shape: torch.Size([146, 1280])
ss_onehot shape: torch.Size([146, 4])
phi_angles shape: torch.Size([146, 1])
psi_angles shape: torch.Size([146, 1])
b_factors shape: torch.Size([146, 1])
Shape of labels: torch.Size([146])


 76%|███████▌  | 772/1013 [02:17<00:35,  6.70it/s]

esm2_embeddings shape: torch.Size([128, 1280])
ss_onehot shape: torch.Size([128, 4])
phi_angles shape: torch.Size([128, 1])
psi_angles shape: torch.Size([128, 1])
b_factors shape: torch.Size([128, 1])
Shape of labels: torch.Size([128])
esm2_embeddings shape: torch.Size([159, 1280])
ss_onehot shape: torch.Size([159, 4])
phi_angles shape: torch.Size([159, 1])
psi_angles shape: torch.Size([159, 1])
b_factors shape: torch.Size([159, 1])
Shape of labels: torch.Size([159])


 76%|███████▋  | 773/1013 [02:17<00:36,  6.51it/s]

esm2_embeddings shape: torch.Size([204, 1280])
ss_onehot shape: torch.Size([204, 4])
phi_angles shape: torch.Size([204, 1])
psi_angles shape: torch.Size([204, 1])
b_factors shape: torch.Size([204, 1])
Shape of labels: torch.Size([204])
esm2_embeddings shape: torch.Size([97, 1280])
ss_onehot shape: torch.Size([97, 4])
phi_angles shape: torch.Size([97, 1])
psi_angles shape: torch.Size([97, 1])
b_factors shape: torch.Size([97, 1])
Shape of labels: torch.Size([97])


 77%|███████▋  | 776/1013 [02:17<00:35,  6.75it/s]

esm2_embeddings shape: torch.Size([233, 1280])
ss_onehot shape: torch.Size([233, 4])
phi_angles shape: torch.Size([233, 1])
psi_angles shape: torch.Size([233, 1])
b_factors shape: torch.Size([233, 1])
Shape of labels: torch.Size([233])
esm2_embeddings shape: torch.Size([186, 1280])
ss_onehot shape: torch.Size([186, 4])
phi_angles shape: torch.Size([186, 1])
psi_angles shape: torch.Size([186, 1])
b_factors shape: torch.Size([186, 1])
Shape of labels: torch.Size([186])


 77%|███████▋  | 778/1013 [02:18<00:28,  8.25it/s]

esm2_embeddings shape: torch.Size([96, 1280])
ss_onehot shape: torch.Size([96, 4])
phi_angles shape: torch.Size([96, 1])
psi_angles shape: torch.Size([96, 1])
b_factors shape: torch.Size([96, 1])
Shape of labels: torch.Size([96])
esm2_embeddings shape: torch.Size([125, 1280])
ss_onehot shape: torch.Size([125, 4])
phi_angles shape: torch.Size([125, 1])
psi_angles shape: torch.Size([125, 1])
b_factors shape: torch.Size([125, 1])
Shape of labels: torch.Size([125])


 77%|███████▋  | 779/1013 [02:18<00:28,  8.16it/s]

esm2_embeddings shape: torch.Size([169, 1280])
ss_onehot shape: torch.Size([169, 4])
phi_angles shape: torch.Size([169, 1])
psi_angles shape: torch.Size([169, 1])
b_factors shape: torch.Size([169, 1])
Shape of labels: torch.Size([169])


 77%|███████▋  | 781/1013 [02:18<00:44,  5.21it/s]

esm2_embeddings shape: torch.Size([146, 1280])
ss_onehot shape: torch.Size([146, 4])
phi_angles shape: torch.Size([146, 1])
psi_angles shape: torch.Size([146, 1])
b_factors shape: torch.Size([146, 1])
Shape of labels: torch.Size([146])
esm2_embeddings shape: torch.Size([140, 1280])
ss_onehot shape: torch.Size([140, 4])
phi_angles shape: torch.Size([140, 1])
psi_angles shape: torch.Size([140, 1])
b_factors shape: torch.Size([140, 1])
Shape of labels: torch.Size([140])


 77%|███████▋  | 783/1013 [02:19<00:35,  6.54it/s]

esm2_embeddings shape: torch.Size([150, 1280])
ss_onehot shape: torch.Size([150, 4])
phi_angles shape: torch.Size([150, 1])
psi_angles shape: torch.Size([150, 1])
b_factors shape: torch.Size([150, 1])
Shape of labels: torch.Size([150])
esm2_embeddings shape: torch.Size([126, 1280])
ss_onehot shape: torch.Size([126, 4])
phi_angles shape: torch.Size([126, 1])
psi_angles shape: torch.Size([126, 1])
b_factors shape: torch.Size([126, 1])
Shape of labels: torch.Size([126])
esm2_embeddings shape: torch.Size([119, 1280])
ss_onehot shape: torch.Size([119, 4])
phi_angles shape: torch.Size([119, 1])
psi_angles shape: torch.Size([119, 1])
b_factors shape: torch.Size([119, 1])
Shape of labels: torch.Size([119])


 77%|███████▋  | 785/1013 [02:19<00:27,  8.21it/s]

esm2_embeddings shape: torch.Size([101, 1280])
ss_onehot shape: torch.Size([101, 4])
phi_angles shape: torch.Size([101, 1])
psi_angles shape: torch.Size([101, 1])
b_factors shape: torch.Size([101, 1])
Shape of labels: torch.Size([101])


 78%|███████▊  | 787/1013 [02:19<00:34,  6.49it/s]

esm2_embeddings shape: torch.Size([300, 1280])
ss_onehot shape: torch.Size([300, 4])
phi_angles shape: torch.Size([300, 1])
psi_angles shape: torch.Size([300, 1])
b_factors shape: torch.Size([300, 1])
Shape of labels: torch.Size([300])
esm2_embeddings shape: torch.Size([135, 1280])
ss_onehot shape: torch.Size([135, 4])
phi_angles shape: torch.Size([135, 1])
psi_angles shape: torch.Size([135, 1])
b_factors shape: torch.Size([135, 1])
Shape of labels: torch.Size([135])


 78%|███████▊  | 789/1013 [02:19<00:30,  7.30it/s]

esm2_embeddings shape: torch.Size([143, 1280])
ss_onehot shape: torch.Size([143, 4])
phi_angles shape: torch.Size([143, 1])
psi_angles shape: torch.Size([143, 1])
b_factors shape: torch.Size([143, 1])
Shape of labels: torch.Size([143])
esm2_embeddings shape: torch.Size([163, 1280])
ss_onehot shape: torch.Size([163, 4])
phi_angles shape: torch.Size([163, 1])
psi_angles shape: torch.Size([163, 1])
b_factors shape: torch.Size([163, 1])
Shape of labels: torch.Size([163])
esm2_embeddings shape: torch.Size([90, 1280])
ss_onehot shape: torch.Size([90, 4])
phi_angles shape: torch.Size([90, 1])
psi_angles shape: torch.Size([90, 1])
b_factors shape: torch.Size([90, 1])
Shape of labels: torch.Size([90])


 78%|███████▊  | 792/1013 [02:20<00:25,  8.51it/s]

esm2_embeddings shape: torch.Size([131, 1280])
ss_onehot shape: torch.Size([131, 4])
phi_angles shape: torch.Size([131, 1])
psi_angles shape: torch.Size([131, 1])
b_factors shape: torch.Size([131, 1])
Shape of labels: torch.Size([131])
esm2_embeddings shape: torch.Size([175, 1280])
ss_onehot shape: torch.Size([175, 4])
phi_angles shape: torch.Size([175, 1])
psi_angles shape: torch.Size([175, 1])
b_factors shape: torch.Size([175, 1])
Shape of labels: torch.Size([175])


 78%|███████▊  | 793/1013 [02:20<00:25,  8.72it/s]

esm2_embeddings shape: torch.Size([141, 1280])
ss_onehot shape: torch.Size([141, 4])
phi_angles shape: torch.Size([141, 1])
psi_angles shape: torch.Size([141, 1])
b_factors shape: torch.Size([141, 1])
Shape of labels: torch.Size([141])
esm2_embeddings shape: torch.Size([120, 1280])
ss_onehot shape: torch.Size([120, 4])
phi_angles shape: torch.Size([120, 1])
psi_angles shape: torch.Size([120, 1])
b_factors shape: torch.Size([120, 1])
Shape of labels: torch.Size([120])


 79%|███████▊  | 797/1013 [02:21<00:42,  5.07it/s]

esm2_embeddings shape: torch.Size([421, 1280])
ss_onehot shape: torch.Size([421, 4])
phi_angles shape: torch.Size([421, 1])
psi_angles shape: torch.Size([421, 1])
b_factors shape: torch.Size([421, 1])
Shape of labels: torch.Size([421])
esm2_embeddings shape: torch.Size([41, 1280])
ss_onehot shape: torch.Size([41, 4])
phi_angles shape: torch.Size([41, 1])
psi_angles shape: torch.Size([41, 1])
b_factors shape: torch.Size([41, 1])
Shape of labels: torch.Size([41])
esm2_embeddings shape: torch.Size([112, 1280])
ss_onehot shape: torch.Size([112, 4])
phi_angles shape: torch.Size([112, 1])
psi_angles shape: torch.Size([112, 1])
b_factors shape: torch.Size([112, 1])
Shape of labels: torch.Size([112])


 79%|███████▉  | 799/1013 [02:21<00:36,  5.83it/s]

esm2_embeddings shape: torch.Size([191, 1280])
ss_onehot shape: torch.Size([191, 4])
phi_angles shape: torch.Size([191, 1])
psi_angles shape: torch.Size([191, 1])
b_factors shape: torch.Size([191, 1])
Shape of labels: torch.Size([191])
esm2_embeddings shape: torch.Size([129, 1280])
ss_onehot shape: torch.Size([129, 4])
phi_angles shape: torch.Size([129, 1])
psi_angles shape: torch.Size([129, 1])
b_factors shape: torch.Size([129, 1])
Shape of labels: torch.Size([129])
esm2_embeddings shape: torch.Size([79, 1280])
ss_onehot shape: torch.Size([79, 4])
phi_angles shape: torch.Size([79, 1])
psi_angles shape: torch.Size([79, 1])
b_factors shape: torch.Size([79, 1])
Shape of labels: torch.Size([79])


 79%|███████▉  | 801/1013 [02:22<00:43,  4.92it/s]

esm2_embeddings shape: torch.Size([398, 1280])
ss_onehot shape: torch.Size([398, 4])
phi_angles shape: torch.Size([398, 1])
psi_angles shape: torch.Size([398, 1])
b_factors shape: torch.Size([398, 1])
Shape of labels: torch.Size([398])


 79%|███████▉  | 802/1013 [02:22<00:57,  3.64it/s]

esm2_embeddings shape: torch.Size([461, 1280])
ss_onehot shape: torch.Size([461, 4])
phi_angles shape: torch.Size([461, 1])
psi_angles shape: torch.Size([461, 1])
b_factors shape: torch.Size([461, 1])
Shape of labels: torch.Size([461])


 79%|███████▉  | 803/1013 [02:23<01:10,  2.97it/s]

esm2_embeddings shape: torch.Size([166, 1280])
ss_onehot shape: torch.Size([166, 4])
phi_angles shape: torch.Size([166, 1])
psi_angles shape: torch.Size([166, 1])
b_factors shape: torch.Size([166, 1])
Shape of labels: torch.Size([166])
esm2_embeddings shape: torch.Size([116, 1280])
ss_onehot shape: torch.Size([116, 4])
phi_angles shape: torch.Size([116, 1])
psi_angles shape: torch.Size([116, 1])
b_factors shape: torch.Size([116, 1])
Shape of labels: torch.Size([116])


 80%|███████▉  | 807/1013 [02:23<00:40,  5.14it/s]

esm2_embeddings shape: torch.Size([181, 1280])
ss_onehot shape: torch.Size([181, 4])
phi_angles shape: torch.Size([181, 1])
psi_angles shape: torch.Size([181, 1])
b_factors shape: torch.Size([181, 1])
Shape of labels: torch.Size([181])
esm2_embeddings shape: torch.Size([123, 1280])
ss_onehot shape: torch.Size([123, 4])
phi_angles shape: torch.Size([123, 1])
psi_angles shape: torch.Size([123, 1])
b_factors shape: torch.Size([123, 1])
Shape of labels: torch.Size([123])
esm2_embeddings shape: torch.Size([125, 1280])
ss_onehot shape: torch.Size([125, 4])
phi_angles shape: torch.Size([125, 1])
psi_angles shape: torch.Size([125, 1])
b_factors shape: torch.Size([125, 1])
Shape of labels: torch.Size([125])


 80%|███████▉  | 810/1013 [02:24<00:30,  6.65it/s]

esm2_embeddings shape: torch.Size([215, 1280])
ss_onehot shape: torch.Size([215, 4])
phi_angles shape: torch.Size([215, 1])
psi_angles shape: torch.Size([215, 1])
b_factors shape: torch.Size([215, 1])
Shape of labels: torch.Size([215])
esm2_embeddings shape: torch.Size([79, 1280])
ss_onehot shape: torch.Size([79, 4])
phi_angles shape: torch.Size([79, 1])
psi_angles shape: torch.Size([79, 1])
b_factors shape: torch.Size([79, 1])
Shape of labels: torch.Size([79])
esm2_embeddings shape: torch.Size([129, 1280])
ss_onehot shape: torch.Size([129, 4])
phi_angles shape: torch.Size([129, 1])
psi_angles shape: torch.Size([129, 1])
b_factors shape: torch.Size([129, 1])
Shape of labels: torch.Size([129])
esm2_embeddings shape: torch.Size([105, 1280])
ss_onehot shape: torch.Size([105, 4])
phi_angles shape: torch.Size([105, 1])
psi_angles shape: torch.Size([105, 1])
b_factors shape: torch.Size([105, 1])
Shape of labels: torch.Size([105])


 80%|████████  | 812/1013 [02:24<00:29,  6.82it/s]

esm2_embeddings shape: torch.Size([236, 1280])
ss_onehot shape: torch.Size([236, 4])
phi_angles shape: torch.Size([236, 1])
psi_angles shape: torch.Size([236, 1])
b_factors shape: torch.Size([236, 1])
Shape of labels: torch.Size([236])
esm2_embeddings shape: torch.Size([90, 1280])
ss_onehot shape: torch.Size([90, 4])
phi_angles shape: torch.Size([90, 1])
psi_angles shape: torch.Size([90, 1])
b_factors shape: torch.Size([90, 1])
Shape of labels: torch.Size([90])


 80%|████████  | 814/1013 [02:24<00:30,  6.60it/s]

esm2_embeddings shape: torch.Size([280, 1280])
ss_onehot shape: torch.Size([280, 4])
phi_angles shape: torch.Size([280, 1])
psi_angles shape: torch.Size([280, 1])
b_factors shape: torch.Size([280, 1])
Shape of labels: torch.Size([280])
esm2_embeddings shape: torch.Size([136, 1280])
ss_onehot shape: torch.Size([136, 4])
phi_angles shape: torch.Size([136, 1])
psi_angles shape: torch.Size([136, 1])
b_factors shape: torch.Size([136, 1])
Shape of labels: torch.Size([136])


 81%|████████  | 817/1013 [02:25<00:44,  4.42it/s]

esm2_embeddings shape: torch.Size([365, 1280])
ss_onehot shape: torch.Size([365, 4])
phi_angles shape: torch.Size([365, 1])
psi_angles shape: torch.Size([365, 1])
b_factors shape: torch.Size([365, 1])
Shape of labels: torch.Size([365])
esm2_embeddings shape: torch.Size([157, 1280])
ss_onehot shape: torch.Size([157, 4])
phi_angles shape: torch.Size([157, 1])
psi_angles shape: torch.Size([157, 1])
b_factors shape: torch.Size([157, 1])
Shape of labels: torch.Size([157])


 81%|████████  | 820/1013 [02:26<00:32,  6.00it/s]

esm2_embeddings shape: torch.Size([158, 1280])
ss_onehot shape: torch.Size([158, 4])
phi_angles shape: torch.Size([158, 1])
psi_angles shape: torch.Size([158, 1])
b_factors shape: torch.Size([158, 1])
Shape of labels: torch.Size([158])
esm2_embeddings shape: torch.Size([137, 1280])
ss_onehot shape: torch.Size([137, 4])
phi_angles shape: torch.Size([137, 1])
psi_angles shape: torch.Size([137, 1])
b_factors shape: torch.Size([137, 1])
Shape of labels: torch.Size([137])
esm2_embeddings shape: torch.Size([135, 1280])
ss_onehot shape: torch.Size([135, 4])
phi_angles shape: torch.Size([135, 1])
psi_angles shape: torch.Size([135, 1])
b_factors shape: torch.Size([135, 1])
Shape of labels: torch.Size([135])


 81%|████████  | 822/1013 [02:26<00:26,  7.23it/s]

esm2_embeddings shape: torch.Size([88, 1280])
ss_onehot shape: torch.Size([88, 4])
phi_angles shape: torch.Size([88, 1])
psi_angles shape: torch.Size([88, 1])
b_factors shape: torch.Size([88, 1])
Shape of labels: torch.Size([88])
esm2_embeddings shape: torch.Size([146, 1280])
ss_onehot shape: torch.Size([146, 4])
phi_angles shape: torch.Size([146, 1])
psi_angles shape: torch.Size([146, 1])
b_factors shape: torch.Size([146, 1])
Shape of labels: torch.Size([146])


 81%|████████▏ | 824/1013 [02:26<00:27,  6.86it/s]

esm2_embeddings shape: torch.Size([254, 1280])
ss_onehot shape: torch.Size([254, 4])
phi_angles shape: torch.Size([254, 1])
psi_angles shape: torch.Size([254, 1])
b_factors shape: torch.Size([254, 1])
Shape of labels: torch.Size([254])
esm2_embeddings shape: torch.Size([145, 1280])
ss_onehot shape: torch.Size([145, 4])
phi_angles shape: torch.Size([145, 1])
psi_angles shape: torch.Size([145, 1])
b_factors shape: torch.Size([145, 1])
Shape of labels: torch.Size([145])
esm2_embeddings shape: torch.Size([126, 1280])
ss_onehot shape: torch.Size([126, 4])
phi_angles shape: torch.Size([126, 1])
psi_angles shape: torch.Size([126, 1])
b_factors shape: torch.Size([126, 1])
Shape of labels: torch.Size([126])


 82%|████████▏ | 827/1013 [02:26<00:24,  7.69it/s]

esm2_embeddings shape: torch.Size([103, 1280])
ss_onehot shape: torch.Size([103, 4])
phi_angles shape: torch.Size([103, 1])
psi_angles shape: torch.Size([103, 1])
b_factors shape: torch.Size([103, 1])
Shape of labels: torch.Size([103])
esm2_embeddings shape: torch.Size([207, 1280])
ss_onehot shape: torch.Size([207, 4])
phi_angles shape: torch.Size([207, 1])
psi_angles shape: torch.Size([207, 1])
b_factors shape: torch.Size([207, 1])
Shape of labels: torch.Size([207])
esm2_embeddings shape: torch.Size([113, 1280])
ss_onehot shape: torch.Size([113, 4])
phi_angles shape: torch.Size([113, 1])
psi_angles shape: torch.Size([113, 1])
b_factors shape: torch.Size([113, 1])
Shape of labels: torch.Size([113])


 82%|████████▏ | 829/1013 [02:27<00:28,  6.43it/s]

esm2_embeddings shape: torch.Size([309, 1280])
ss_onehot shape: torch.Size([309, 4])
phi_angles shape: torch.Size([309, 1])
psi_angles shape: torch.Size([309, 1])
b_factors shape: torch.Size([309, 1])
Shape of labels: torch.Size([309])


 82%|████████▏ | 831/1013 [02:27<00:40,  4.48it/s]

esm2_embeddings shape: torch.Size([114, 1280])
ss_onehot shape: torch.Size([114, 4])
phi_angles shape: torch.Size([114, 1])
psi_angles shape: torch.Size([114, 1])
b_factors shape: torch.Size([114, 1])
Shape of labels: torch.Size([114])
esm2_embeddings shape: torch.Size([220, 1280])
ss_onehot shape: torch.Size([220, 4])
phi_angles shape: torch.Size([220, 1])
psi_angles shape: torch.Size([220, 1])
b_factors shape: torch.Size([220, 1])
Shape of labels: torch.Size([220])


 82%|████████▏ | 833/1013 [02:28<00:32,  5.61it/s]

esm2_embeddings shape: torch.Size([167, 1280])
ss_onehot shape: torch.Size([167, 4])
phi_angles shape: torch.Size([167, 1])
psi_angles shape: torch.Size([167, 1])
b_factors shape: torch.Size([167, 1])
Shape of labels: torch.Size([167])
esm2_embeddings shape: torch.Size([153, 1280])
ss_onehot shape: torch.Size([153, 4])
phi_angles shape: torch.Size([153, 1])
psi_angles shape: torch.Size([153, 1])
b_factors shape: torch.Size([153, 1])
Shape of labels: torch.Size([153])


 83%|████████▎ | 836/1013 [02:28<00:22,  7.72it/s]

esm2_embeddings shape: torch.Size([168, 1280])
ss_onehot shape: torch.Size([168, 4])
phi_angles shape: torch.Size([168, 1])
psi_angles shape: torch.Size([168, 1])
b_factors shape: torch.Size([168, 1])
Shape of labels: torch.Size([168])
esm2_embeddings shape: torch.Size([99, 1280])
ss_onehot shape: torch.Size([99, 4])
phi_angles shape: torch.Size([99, 1])
psi_angles shape: torch.Size([99, 1])
b_factors shape: torch.Size([99, 1])
Shape of labels: torch.Size([99])
esm2_embeddings shape: torch.Size([125, 1280])
ss_onehot shape: torch.Size([125, 4])
phi_angles shape: torch.Size([125, 1])
psi_angles shape: torch.Size([125, 1])
b_factors shape: torch.Size([125, 1])
Shape of labels: torch.Size([125])


 83%|████████▎ | 837/1013 [02:28<00:25,  6.99it/s]

esm2_embeddings shape: torch.Size([203, 1280])
ss_onehot shape: torch.Size([203, 4])
phi_angles shape: torch.Size([203, 1])
psi_angles shape: torch.Size([203, 1])
b_factors shape: torch.Size([203, 1])
Shape of labels: torch.Size([203])


 83%|████████▎ | 839/1013 [02:29<00:32,  5.34it/s]

esm2_embeddings shape: torch.Size([342, 1280])
ss_onehot shape: torch.Size([342, 4])
phi_angles shape: torch.Size([342, 1])
psi_angles shape: torch.Size([342, 1])
b_factors shape: torch.Size([342, 1])
Shape of labels: torch.Size([342])
esm2_embeddings shape: torch.Size([195, 1280])
ss_onehot shape: torch.Size([195, 4])
phi_angles shape: torch.Size([195, 1])
psi_angles shape: torch.Size([195, 1])
b_factors shape: torch.Size([195, 1])
Shape of labels: torch.Size([195])
esm2_embeddings shape: torch.Size([101, 1280])
ss_onehot shape: torch.Size([101, 4])
phi_angles shape: torch.Size([101, 1])
psi_angles shape: torch.Size([101, 1])
b_factors shape: torch.Size([101, 1])
Shape of labels: torch.Size([101])


 83%|████████▎ | 842/1013 [02:30<00:42,  3.99it/s]

esm2_embeddings shape: torch.Size([351, 1280])
ss_onehot shape: torch.Size([351, 4])
phi_angles shape: torch.Size([351, 1])
psi_angles shape: torch.Size([351, 1])
b_factors shape: torch.Size([351, 1])
Shape of labels: torch.Size([351])
esm2_embeddings shape: torch.Size([140, 1280])
ss_onehot shape: torch.Size([140, 4])
phi_angles shape: torch.Size([140, 1])
psi_angles shape: torch.Size([140, 1])
b_factors shape: torch.Size([140, 1])
Shape of labels: torch.Size([140])
esm2_embeddings shape: torch.Size([61, 1280])
ss_onehot shape: torch.Size([61, 4])
phi_angles shape: torch.Size([61, 1])
psi_angles shape: torch.Size([61, 1])
b_factors shape: torch.Size([61, 1])
Shape of labels: torch.Size([61])
esm2_embeddings shape: torch.Size([70, 1280])
ss_onehot shape: torch.Size([70, 4])
phi_angles shape: torch.Size([70, 1])
psi_angles shape: torch.Size([70, 1])
b_factors shape: torch.Size([70, 1])
Shape of labels: torch.Size([70])


 83%|████████▎ | 845/1013 [02:30<00:24,  6.72it/s]

esm2_embeddings shape: torch.Size([66, 1280])
ss_onehot shape: torch.Size([66, 4])
phi_angles shape: torch.Size([66, 1])
psi_angles shape: torch.Size([66, 1])
b_factors shape: torch.Size([66, 1])
Shape of labels: torch.Size([66])
esm2_embeddings shape: torch.Size([221, 1280])
ss_onehot shape: torch.Size([221, 4])
phi_angles shape: torch.Size([221, 1])
psi_angles shape: torch.Size([221, 1])
b_factors shape: torch.Size([221, 1])
Shape of labels: torch.Size([221])


 84%|████████▎ | 848/1013 [02:30<00:25,  6.57it/s]

esm2_embeddings shape: torch.Size([102, 1280])
ss_onehot shape: torch.Size([102, 4])
phi_angles shape: torch.Size([102, 1])
psi_angles shape: torch.Size([102, 1])
b_factors shape: torch.Size([102, 1])
Shape of labels: torch.Size([102])
esm2_embeddings shape: torch.Size([234, 1280])
ss_onehot shape: torch.Size([234, 4])
phi_angles shape: torch.Size([234, 1])
psi_angles shape: torch.Size([234, 1])
b_factors shape: torch.Size([234, 1])
Shape of labels: torch.Size([234])


 84%|████████▍ | 850/1013 [02:31<00:25,  6.46it/s]

esm2_embeddings shape: torch.Size([262, 1280])
ss_onehot shape: torch.Size([262, 4])
phi_angles shape: torch.Size([262, 1])
psi_angles shape: torch.Size([262, 1])
b_factors shape: torch.Size([262, 1])
Shape of labels: torch.Size([262])
esm2_embeddings shape: torch.Size([144, 1280])
ss_onehot shape: torch.Size([144, 4])
phi_angles shape: torch.Size([144, 1])
psi_angles shape: torch.Size([144, 1])
b_factors shape: torch.Size([144, 1])
Shape of labels: torch.Size([144])
esm2_embeddings shape: torch.Size([72, 1280])
ss_onehot shape: torch.Size([72, 4])
phi_angles shape: torch.Size([72, 1])
psi_angles shape: torch.Size([72, 1])
b_factors shape: torch.Size([72, 1])
Shape of labels: torch.Size([72])


 84%|████████▍ | 852/1013 [02:31<00:26,  6.06it/s]

esm2_embeddings shape: torch.Size([313, 1280])
ss_onehot shape: torch.Size([313, 4])
phi_angles shape: torch.Size([313, 1])
psi_angles shape: torch.Size([313, 1])
b_factors shape: torch.Size([313, 1])
Shape of labels: torch.Size([313])


 84%|████████▍ | 855/1013 [02:32<00:32,  4.90it/s]

esm2_embeddings shape: torch.Size([279, 1280])
ss_onehot shape: torch.Size([279, 4])
phi_angles shape: torch.Size([279, 1])
psi_angles shape: torch.Size([279, 1])
b_factors shape: torch.Size([279, 1])
Shape of labels: torch.Size([279])
esm2_embeddings shape: torch.Size([75, 1280])
ss_onehot shape: torch.Size([75, 4])
phi_angles shape: torch.Size([75, 1])
psi_angles shape: torch.Size([75, 1])
b_factors shape: torch.Size([75, 1])
Shape of labels: torch.Size([75])
esm2_embeddings shape: torch.Size([109, 1280])
ss_onehot shape: torch.Size([109, 4])
phi_angles shape: torch.Size([109, 1])
psi_angles shape: torch.Size([109, 1])
b_factors shape: torch.Size([109, 1])
Shape of labels: torch.Size([109])


 85%|████████▍ | 858/1013 [02:32<00:23,  6.66it/s]

esm2_embeddings shape: torch.Size([217, 1280])
ss_onehot shape: torch.Size([217, 4])
phi_angles shape: torch.Size([217, 1])
psi_angles shape: torch.Size([217, 1])
b_factors shape: torch.Size([217, 1])
Shape of labels: torch.Size([217])
esm2_embeddings shape: torch.Size([32, 1280])
ss_onehot shape: torch.Size([32, 4])
phi_angles shape: torch.Size([32, 1])
psi_angles shape: torch.Size([32, 1])
b_factors shape: torch.Size([32, 1])
Shape of labels: torch.Size([32])
esm2_embeddings shape: torch.Size([151, 1280])
ss_onehot shape: torch.Size([151, 4])
phi_angles shape: torch.Size([151, 1])
psi_angles shape: torch.Size([151, 1])
b_factors shape: torch.Size([151, 1])
Shape of labels: torch.Size([151])
esm2_embeddings shape: torch.Size([31, 1280])
ss_onehot shape: torch.Size([31, 4])
phi_angles shape: torch.Size([31, 1])
psi_angles shape: torch.Size([31, 1])
b_factors shape: torch.Size([31, 1])
Shape of labels: torch.Size([31])


 85%|████████▍ | 860/1013 [02:32<00:18,  8.29it/s]

esm2_embeddings shape: torch.Size([126, 1280])
ss_onehot shape: torch.Size([126, 4])
phi_angles shape: torch.Size([126, 1])
psi_angles shape: torch.Size([126, 1])
b_factors shape: torch.Size([126, 1])
Shape of labels: torch.Size([126])
esm2_embeddings shape: torch.Size([230, 1280])
ss_onehot shape: torch.Size([230, 4])
phi_angles shape: torch.Size([230, 1])
psi_angles shape: torch.Size([230, 1])
b_factors shape: torch.Size([230, 1])
Shape of labels: torch.Size([230])


 85%|████████▌ | 863/1013 [02:33<00:20,  7.49it/s]

esm2_embeddings shape: torch.Size([171, 1280])
ss_onehot shape: torch.Size([171, 4])
phi_angles shape: torch.Size([171, 1])
psi_angles shape: torch.Size([171, 1])
b_factors shape: torch.Size([171, 1])
Shape of labels: torch.Size([171])
esm2_embeddings shape: torch.Size([152, 1280])
ss_onehot shape: torch.Size([152, 4])
phi_angles shape: torch.Size([152, 1])
psi_angles shape: torch.Size([152, 1])
b_factors shape: torch.Size([152, 1])
Shape of labels: torch.Size([152])


 85%|████████▌ | 865/1013 [02:33<00:22,  6.50it/s]

esm2_embeddings shape: torch.Size([253, 1280])
ss_onehot shape: torch.Size([253, 4])
phi_angles shape: torch.Size([253, 1])
psi_angles shape: torch.Size([253, 1])
b_factors shape: torch.Size([253, 1])
Shape of labels: torch.Size([253])
esm2_embeddings shape: torch.Size([179, 1280])
ss_onehot shape: torch.Size([179, 4])
phi_angles shape: torch.Size([179, 1])
psi_angles shape: torch.Size([179, 1])
b_factors shape: torch.Size([179, 1])
Shape of labels: torch.Size([179])


 85%|████████▌ | 866/1013 [02:33<00:21,  6.81it/s]

esm2_embeddings shape: torch.Size([164, 1280])
ss_onehot shape: torch.Size([164, 4])
phi_angles shape: torch.Size([164, 1])
psi_angles shape: torch.Size([164, 1])
b_factors shape: torch.Size([164, 1])
Shape of labels: torch.Size([164])


 86%|████████▌ | 868/1013 [02:34<00:31,  4.61it/s]

esm2_embeddings shape: torch.Size([138, 1280])
ss_onehot shape: torch.Size([138, 4])
phi_angles shape: torch.Size([138, 1])
psi_angles shape: torch.Size([138, 1])
b_factors shape: torch.Size([138, 1])
Shape of labels: torch.Size([138])
esm2_embeddings shape: torch.Size([161, 1280])
ss_onehot shape: torch.Size([161, 4])
phi_angles shape: torch.Size([161, 1])
psi_angles shape: torch.Size([161, 1])
b_factors shape: torch.Size([161, 1])
Shape of labels: torch.Size([161])


 86%|████████▌ | 869/1013 [02:34<00:31,  4.55it/s]

esm2_embeddings shape: torch.Size([249, 1280])
ss_onehot shape: torch.Size([249, 4])
phi_angles shape: torch.Size([249, 1])
psi_angles shape: torch.Size([249, 1])
b_factors shape: torch.Size([249, 1])
Shape of labels: torch.Size([249])
esm2_embeddings shape: torch.Size([82, 1280])
ss_onehot shape: torch.Size([82, 4])
phi_angles shape: torch.Size([82, 1])
psi_angles shape: torch.Size([82, 1])
b_factors shape: torch.Size([82, 1])
Shape of labels: torch.Size([82])


 86%|████████▌ | 872/1013 [02:35<00:25,  5.49it/s]

esm2_embeddings shape: torch.Size([232, 1280])
ss_onehot shape: torch.Size([232, 4])
phi_angles shape: torch.Size([232, 1])
psi_angles shape: torch.Size([232, 1])
b_factors shape: torch.Size([232, 1])
Shape of labels: torch.Size([232])
esm2_embeddings shape: torch.Size([215, 1280])
ss_onehot shape: torch.Size([215, 4])
phi_angles shape: torch.Size([215, 1])
psi_angles shape: torch.Size([215, 1])
b_factors shape: torch.Size([215, 1])
Shape of labels: torch.Size([215])
esm2_embeddings shape: torch.Size([113, 1280])
ss_onehot shape: torch.Size([113, 4])
phi_angles shape: torch.Size([113, 1])
psi_angles shape: torch.Size([113, 1])
b_factors shape: torch.Size([113, 1])
Shape of labels: torch.Size([113])


 86%|████████▋ | 876/1013 [02:35<00:18,  7.28it/s]

esm2_embeddings shape: torch.Size([236, 1280])
ss_onehot shape: torch.Size([236, 4])
phi_angles shape: torch.Size([236, 1])
psi_angles shape: torch.Size([236, 1])
b_factors shape: torch.Size([236, 1])
Shape of labels: torch.Size([236])
esm2_embeddings shape: torch.Size([112, 1280])
ss_onehot shape: torch.Size([112, 4])
phi_angles shape: torch.Size([112, 1])
psi_angles shape: torch.Size([112, 1])
b_factors shape: torch.Size([112, 1])
Shape of labels: torch.Size([112])
esm2_embeddings shape: torch.Size([133, 1280])
ss_onehot shape: torch.Size([133, 4])
phi_angles shape: torch.Size([133, 1])
psi_angles shape: torch.Size([133, 1])
b_factors shape: torch.Size([133, 1])
Shape of labels: torch.Size([133])


 87%|████████▋ | 877/1013 [02:35<00:19,  6.82it/s]

esm2_embeddings shape: torch.Size([215, 1280])
ss_onehot shape: torch.Size([215, 4])
phi_angles shape: torch.Size([215, 1])
psi_angles shape: torch.Size([215, 1])
b_factors shape: torch.Size([215, 1])
Shape of labels: torch.Size([215])


 87%|████████▋ | 878/1013 [02:36<00:33,  3.97it/s]

esm2_embeddings shape: torch.Size([160, 1280])
ss_onehot shape: torch.Size([160, 4])
phi_angles shape: torch.Size([160, 1])
psi_angles shape: torch.Size([160, 1])
b_factors shape: torch.Size([160, 1])
Shape of labels: torch.Size([160])


 87%|████████▋ | 879/1013 [02:36<00:37,  3.58it/s]

esm2_embeddings shape: torch.Size([341, 1280])
ss_onehot shape: torch.Size([341, 4])
phi_angles shape: torch.Size([341, 1])
psi_angles shape: torch.Size([341, 1])
b_factors shape: torch.Size([341, 1])
Shape of labels: torch.Size([341])
esm2_embeddings shape: torch.Size([110, 1280])
ss_onehot shape: torch.Size([110, 4])
phi_angles shape: torch.Size([110, 1])
psi_angles shape: torch.Size([110, 1])
b_factors shape: torch.Size([110, 1])
Shape of labels: torch.Size([110])


 87%|████████▋ | 882/1013 [02:37<00:29,  4.42it/s]

esm2_embeddings shape: torch.Size([273, 1280])
ss_onehot shape: torch.Size([273, 4])
phi_angles shape: torch.Size([273, 1])
psi_angles shape: torch.Size([273, 1])
b_factors shape: torch.Size([273, 1])
Shape of labels: torch.Size([273])
esm2_embeddings shape: torch.Size([222, 1280])
ss_onehot shape: torch.Size([222, 4])
phi_angles shape: torch.Size([222, 1])
psi_angles shape: torch.Size([222, 1])
b_factors shape: torch.Size([222, 1])
Shape of labels: torch.Size([222])


 87%|████████▋ | 883/1013 [02:37<00:30,  4.29it/s]

esm2_embeddings shape: torch.Size([273, 1280])
ss_onehot shape: torch.Size([273, 4])
phi_angles shape: torch.Size([273, 1])
psi_angles shape: torch.Size([273, 1])
b_factors shape: torch.Size([273, 1])
Shape of labels: torch.Size([273])
esm2_embeddings shape: torch.Size([89, 1280])
ss_onehot shape: torch.Size([89, 4])
phi_angles shape: torch.Size([89, 1])
psi_angles shape: torch.Size([89, 1])
b_factors shape: torch.Size([89, 1])
Shape of labels: torch.Size([89])


 87%|████████▋ | 885/1013 [02:37<00:23,  5.43it/s]

esm2_embeddings shape: torch.Size([205, 1280])
ss_onehot shape: torch.Size([205, 4])
phi_angles shape: torch.Size([205, 1])
psi_angles shape: torch.Size([205, 1])
b_factors shape: torch.Size([205, 1])
Shape of labels: torch.Size([205])
esm2_embeddings shape: torch.Size([134, 1280])
ss_onehot shape: torch.Size([134, 4])
phi_angles shape: torch.Size([134, 1])
psi_angles shape: torch.Size([134, 1])
b_factors shape: torch.Size([134, 1])
Shape of labels: torch.Size([134])


 88%|████████▊ | 887/1013 [02:38<00:19,  6.46it/s]

esm2_embeddings shape: torch.Size([136, 1280])
ss_onehot shape: torch.Size([136, 4])
phi_angles shape: torch.Size([136, 1])
psi_angles shape: torch.Size([136, 1])
b_factors shape: torch.Size([136, 1])
Shape of labels: torch.Size([136])
esm2_embeddings shape: torch.Size([56, 1280])
ss_onehot shape: torch.Size([56, 4])
phi_angles shape: torch.Size([56, 1])
psi_angles shape: torch.Size([56, 1])
b_factors shape: torch.Size([56, 1])
Shape of labels: torch.Size([56])


 88%|████████▊ | 890/1013 [02:38<00:24,  5.06it/s]

esm2_embeddings shape: torch.Size([131, 1280])
ss_onehot shape: torch.Size([131, 4])
phi_angles shape: torch.Size([131, 1])
psi_angles shape: torch.Size([131, 1])
b_factors shape: torch.Size([131, 1])
Shape of labels: torch.Size([131])
esm2_embeddings shape: torch.Size([202, 1280])
ss_onehot shape: torch.Size([202, 4])
phi_angles shape: torch.Size([202, 1])
psi_angles shape: torch.Size([202, 1])
b_factors shape: torch.Size([202, 1])
Shape of labels: torch.Size([202])


 88%|████████▊ | 892/1013 [02:38<00:19,  6.23it/s]

esm2_embeddings shape: torch.Size([134, 1280])
ss_onehot shape: torch.Size([134, 4])
phi_angles shape: torch.Size([134, 1])
psi_angles shape: torch.Size([134, 1])
b_factors shape: torch.Size([134, 1])
Shape of labels: torch.Size([134])
esm2_embeddings shape: torch.Size([144, 1280])
ss_onehot shape: torch.Size([144, 4])
phi_angles shape: torch.Size([144, 1])
psi_angles shape: torch.Size([144, 1])
b_factors shape: torch.Size([144, 1])
Shape of labels: torch.Size([144])
esm2_embeddings shape: torch.Size([60, 1280])
ss_onehot shape: torch.Size([60, 4])
phi_angles shape: torch.Size([60, 1])
psi_angles shape: torch.Size([60, 1])
b_factors shape: torch.Size([60, 1])
Shape of labels: torch.Size([60])


 88%|████████▊ | 896/1013 [02:39<00:12,  9.61it/s]

esm2_embeddings shape: torch.Size([89, 1280])
ss_onehot shape: torch.Size([89, 4])
phi_angles shape: torch.Size([89, 1])
psi_angles shape: torch.Size([89, 1])
b_factors shape: torch.Size([89, 1])
Shape of labels: torch.Size([89])
esm2_embeddings shape: torch.Size([128, 1280])
ss_onehot shape: torch.Size([128, 4])
phi_angles shape: torch.Size([128, 1])
psi_angles shape: torch.Size([128, 1])
b_factors shape: torch.Size([128, 1])
Shape of labels: torch.Size([128])
esm2_embeddings shape: torch.Size([98, 1280])
ss_onehot shape: torch.Size([98, 4])
phi_angles shape: torch.Size([98, 1])
psi_angles shape: torch.Size([98, 1])
b_factors shape: torch.Size([98, 1])
Shape of labels: torch.Size([98])


 89%|████████▊ | 898/1013 [02:39<00:10, 10.85it/s]

esm2_embeddings shape: torch.Size([98, 1280])
ss_onehot shape: torch.Size([98, 4])
phi_angles shape: torch.Size([98, 1])
psi_angles shape: torch.Size([98, 1])
b_factors shape: torch.Size([98, 1])
Shape of labels: torch.Size([98])
esm2_embeddings shape: torch.Size([97, 1280])
ss_onehot shape: torch.Size([97, 4])
phi_angles shape: torch.Size([97, 1])
psi_angles shape: torch.Size([97, 1])
b_factors shape: torch.Size([97, 1])
Shape of labels: torch.Size([97])
esm2_embeddings shape: torch.Size([159, 1280])
ss_onehot shape: torch.Size([159, 4])
phi_angles shape: torch.Size([159, 1])
psi_angles shape: torch.Size([159, 1])
b_factors shape: torch.Size([159, 1])
Shape of labels: torch.Size([159])


 89%|████████▉ | 900/1013 [02:39<00:10, 10.37it/s]

esm2_embeddings shape: torch.Size([119, 1280])
ss_onehot shape: torch.Size([119, 4])
phi_angles shape: torch.Size([119, 1])
psi_angles shape: torch.Size([119, 1])
b_factors shape: torch.Size([119, 1])
Shape of labels: torch.Size([119])
esm2_embeddings shape: torch.Size([145, 1280])
ss_onehot shape: torch.Size([145, 4])
phi_angles shape: torch.Size([145, 1])
psi_angles shape: torch.Size([145, 1])
b_factors shape: torch.Size([145, 1])
Shape of labels: torch.Size([145])


 89%|████████▉ | 902/1013 [02:39<00:12,  8.92it/s]

esm2_embeddings shape: torch.Size([215, 1280])
ss_onehot shape: torch.Size([215, 4])
phi_angles shape: torch.Size([215, 1])
psi_angles shape: torch.Size([215, 1])
b_factors shape: torch.Size([215, 1])
Shape of labels: torch.Size([215])
esm2_embeddings shape: torch.Size([212, 1280])
ss_onehot shape: torch.Size([212, 4])
phi_angles shape: torch.Size([212, 1])
psi_angles shape: torch.Size([212, 1])
b_factors shape: torch.Size([212, 1])
Shape of labels: torch.Size([212])


 89%|████████▉ | 904/1013 [02:40<00:13,  8.08it/s]

esm2_embeddings shape: torch.Size([163, 1280])
ss_onehot shape: torch.Size([163, 4])
phi_angles shape: torch.Size([163, 1])
psi_angles shape: torch.Size([163, 1])
b_factors shape: torch.Size([163, 1])
Shape of labels: torch.Size([163])


 89%|████████▉ | 905/1013 [02:40<00:21,  5.09it/s]

esm2_embeddings shape: torch.Size([123, 1280])
ss_onehot shape: torch.Size([123, 4])
phi_angles shape: torch.Size([123, 1])
psi_angles shape: torch.Size([123, 1])
b_factors shape: torch.Size([123, 1])
Shape of labels: torch.Size([123])
esm2_embeddings shape: torch.Size([137, 1280])
ss_onehot shape: torch.Size([137, 4])
phi_angles shape: torch.Size([137, 1])
psi_angles shape: torch.Size([137, 1])
b_factors shape: torch.Size([137, 1])
Shape of labels: torch.Size([137])


 90%|████████▉ | 908/1013 [02:41<00:16,  6.29it/s]

esm2_embeddings shape: torch.Size([164, 1280])
ss_onehot shape: torch.Size([164, 4])
phi_angles shape: torch.Size([164, 1])
psi_angles shape: torch.Size([164, 1])
b_factors shape: torch.Size([164, 1])
Shape of labels: torch.Size([164])
esm2_embeddings shape: torch.Size([146, 1280])
ss_onehot shape: torch.Size([146, 4])
phi_angles shape: torch.Size([146, 1])
psi_angles shape: torch.Size([146, 1])
b_factors shape: torch.Size([146, 1])
Shape of labels: torch.Size([146])


 90%|████████▉ | 909/1013 [02:41<00:16,  6.35it/s]

esm2_embeddings shape: torch.Size([185, 1280])
ss_onehot shape: torch.Size([185, 4])
phi_angles shape: torch.Size([185, 1])
psi_angles shape: torch.Size([185, 1])
b_factors shape: torch.Size([185, 1])
Shape of labels: torch.Size([185])


 90%|█████████ | 913/1013 [02:41<00:11,  8.41it/s]

esm2_embeddings shape: torch.Size([326, 1280])
ss_onehot shape: torch.Size([326, 4])
phi_angles shape: torch.Size([326, 1])
psi_angles shape: torch.Size([326, 1])
b_factors shape: torch.Size([326, 1])
Shape of labels: torch.Size([326])
esm2_embeddings shape: torch.Size([71, 1280])
ss_onehot shape: torch.Size([71, 4])
phi_angles shape: torch.Size([71, 1])
psi_angles shape: torch.Size([71, 1])
b_factors shape: torch.Size([71, 1])
Shape of labels: torch.Size([71])
esm2_embeddings shape: torch.Size([46, 1280])
ss_onehot shape: torch.Size([46, 4])
phi_angles shape: torch.Size([46, 1])
psi_angles shape: torch.Size([46, 1])
b_factors shape: torch.Size([46, 1])
Shape of labels: torch.Size([46])
esm2_embeddings shape: torch.Size([41, 1280])
ss_onehot shape: torch.Size([41, 4])
phi_angles shape: torch.Size([41, 1])
psi_angles shape: torch.Size([41, 1])
b_factors shape: torch.Size([41, 1])
Shape of labels: torch.Size([41])
esm2_embeddings shape: torch.Size([170, 1280])
ss_onehot shape: torch.Size

 90%|█████████ | 915/1013 [02:42<00:13,  7.32it/s]

esm2_embeddings shape: torch.Size([236, 1280])
ss_onehot shape: torch.Size([236, 4])
phi_angles shape: torch.Size([236, 1])
psi_angles shape: torch.Size([236, 1])
b_factors shape: torch.Size([236, 1])
Shape of labels: torch.Size([236])
esm2_embeddings shape: torch.Size([160, 1280])
ss_onehot shape: torch.Size([160, 4])
phi_angles shape: torch.Size([160, 1])
psi_angles shape: torch.Size([160, 1])
b_factors shape: torch.Size([160, 1])
Shape of labels: torch.Size([160])


 91%|█████████ | 917/1013 [02:42<00:23,  4.11it/s]

esm2_embeddings shape: torch.Size([363, 1280])
ss_onehot shape: torch.Size([363, 4])
phi_angles shape: torch.Size([363, 1])
psi_angles shape: torch.Size([363, 1])
b_factors shape: torch.Size([363, 1])
Shape of labels: torch.Size([363])
esm2_embeddings shape: torch.Size([91, 1280])
ss_onehot shape: torch.Size([91, 4])
phi_angles shape: torch.Size([91, 1])
psi_angles shape: torch.Size([91, 1])
b_factors shape: torch.Size([91, 1])
Shape of labels: torch.Size([91])


 91%|█████████ | 920/1013 [02:43<00:17,  5.18it/s]

esm2_embeddings shape: torch.Size([216, 1280])
ss_onehot shape: torch.Size([216, 4])
phi_angles shape: torch.Size([216, 1])
psi_angles shape: torch.Size([216, 1])
b_factors shape: torch.Size([216, 1])
Shape of labels: torch.Size([216])
esm2_embeddings shape: torch.Size([161, 1280])
ss_onehot shape: torch.Size([161, 4])
phi_angles shape: torch.Size([161, 1])
psi_angles shape: torch.Size([161, 1])
b_factors shape: torch.Size([161, 1])
Shape of labels: torch.Size([161])


 91%|█████████ | 921/1013 [02:43<00:16,  5.73it/s]

esm2_embeddings shape: torch.Size([139, 1280])
ss_onehot shape: torch.Size([139, 4])
phi_angles shape: torch.Size([139, 1])
psi_angles shape: torch.Size([139, 1])
b_factors shape: torch.Size([139, 1])
Shape of labels: torch.Size([139])
esm2_embeddings shape: torch.Size([92, 1280])
ss_onehot shape: torch.Size([92, 4])
phi_angles shape: torch.Size([92, 1])
psi_angles shape: torch.Size([92, 1])
b_factors shape: torch.Size([92, 1])
Shape of labels: torch.Size([92])


 91%|█████████ | 923/1013 [02:43<00:13,  6.62it/s]

esm2_embeddings shape: torch.Size([189, 1280])
ss_onehot shape: torch.Size([189, 4])
phi_angles shape: torch.Size([189, 1])
psi_angles shape: torch.Size([189, 1])
b_factors shape: torch.Size([189, 1])
Shape of labels: torch.Size([189])
esm2_embeddings shape: torch.Size([116, 1280])
ss_onehot shape: torch.Size([116, 4])
phi_angles shape: torch.Size([116, 1])
psi_angles shape: torch.Size([116, 1])
b_factors shape: torch.Size([116, 1])
Shape of labels: torch.Size([116])


 91%|█████████▏| 926/1013 [02:44<00:14,  5.94it/s]

esm2_embeddings shape: torch.Size([349, 1280])
ss_onehot shape: torch.Size([349, 4])
phi_angles shape: torch.Size([349, 1])
psi_angles shape: torch.Size([349, 1])
b_factors shape: torch.Size([349, 1])
Shape of labels: torch.Size([349])
esm2_embeddings shape: torch.Size([160, 1280])
ss_onehot shape: torch.Size([160, 4])
phi_angles shape: torch.Size([160, 1])
psi_angles shape: torch.Size([160, 1])
b_factors shape: torch.Size([160, 1])
Shape of labels: torch.Size([160])


 92%|█████████▏| 927/1013 [02:44<00:15,  5.72it/s]

esm2_embeddings shape: torch.Size([232, 1280])
ss_onehot shape: torch.Size([232, 4])
phi_angles shape: torch.Size([232, 1])
psi_angles shape: torch.Size([232, 1])
b_factors shape: torch.Size([232, 1])
Shape of labels: torch.Size([232])
esm2_embeddings shape: torch.Size([73, 1280])
ss_onehot shape: torch.Size([73, 4])
phi_angles shape: torch.Size([73, 1])
psi_angles shape: torch.Size([73, 1])
b_factors shape: torch.Size([73, 1])
Shape of labels: torch.Size([73])


 92%|█████████▏| 930/1013 [02:45<00:16,  5.01it/s]

esm2_embeddings shape: torch.Size([146, 1280])
ss_onehot shape: torch.Size([146, 4])
phi_angles shape: torch.Size([146, 1])
psi_angles shape: torch.Size([146, 1])
b_factors shape: torch.Size([146, 1])
Shape of labels: torch.Size([146])
esm2_embeddings shape: torch.Size([157, 1280])
ss_onehot shape: torch.Size([157, 4])
phi_angles shape: torch.Size([157, 1])
psi_angles shape: torch.Size([157, 1])
b_factors shape: torch.Size([157, 1])
Shape of labels: torch.Size([157])


 92%|█████████▏| 931/1013 [02:45<00:15,  5.23it/s]

esm2_embeddings shape: torch.Size([190, 1280])
ss_onehot shape: torch.Size([190, 4])
phi_angles shape: torch.Size([190, 1])
psi_angles shape: torch.Size([190, 1])
b_factors shape: torch.Size([190, 1])
Shape of labels: torch.Size([190])


 92%|█████████▏| 932/1013 [02:45<00:16,  4.80it/s]

esm2_embeddings shape: torch.Size([269, 1280])
ss_onehot shape: torch.Size([269, 4])
phi_angles shape: torch.Size([269, 1])
psi_angles shape: torch.Size([269, 1])
b_factors shape: torch.Size([269, 1])
Shape of labels: torch.Size([269])
esm2_embeddings shape: torch.Size([230, 1280])
ss_onehot shape: torch.Size([230, 4])
phi_angles shape: torch.Size([230, 1])
psi_angles shape: torch.Size([230, 1])
b_factors shape: torch.Size([230, 1])


 92%|█████████▏| 934/1013 [02:45<00:14,  5.37it/s]

Shape of labels: torch.Size([230])
esm2_embeddings shape: torch.Size([169, 1280])
ss_onehot shape: torch.Size([169, 4])
phi_angles shape: torch.Size([169, 1])
psi_angles shape: torch.Size([169, 1])
b_factors shape: torch.Size([169, 1])
Shape of labels: torch.Size([169])


 92%|█████████▏| 937/1013 [02:46<00:09,  7.73it/s]

esm2_embeddings shape: torch.Size([177, 1280])
ss_onehot shape: torch.Size([177, 4])
phi_angles shape: torch.Size([177, 1])
psi_angles shape: torch.Size([177, 1])
b_factors shape: torch.Size([177, 1])
Shape of labels: torch.Size([177])
esm2_embeddings shape: torch.Size([127, 1280])
ss_onehot shape: torch.Size([127, 4])
phi_angles shape: torch.Size([127, 1])
psi_angles shape: torch.Size([127, 1])
b_factors shape: torch.Size([127, 1])
Shape of labels: torch.Size([127])
esm2_embeddings shape: torch.Size([76, 1280])
ss_onehot shape: torch.Size([76, 4])
phi_angles shape: torch.Size([76, 1])
psi_angles shape: torch.Size([76, 1])
b_factors shape: torch.Size([76, 1])
Shape of labels: torch.Size([76])


 93%|█████████▎| 939/1013 [02:46<00:09,  7.74it/s]

esm2_embeddings shape: torch.Size([138, 1280])
ss_onehot shape: torch.Size([138, 4])
phi_angles shape: torch.Size([138, 1])
psi_angles shape: torch.Size([138, 1])
b_factors shape: torch.Size([138, 1])
Shape of labels: torch.Size([138])
esm2_embeddings shape: torch.Size([184, 1280])
ss_onehot shape: torch.Size([184, 4])
phi_angles shape: torch.Size([184, 1])
psi_angles shape: torch.Size([184, 1])
b_factors shape: torch.Size([184, 1])
Shape of labels: torch.Size([184])


 93%|█████████▎| 942/1013 [02:47<00:12,  5.48it/s]

esm2_embeddings shape: torch.Size([170, 1280])
ss_onehot shape: torch.Size([170, 4])
phi_angles shape: torch.Size([170, 1])
psi_angles shape: torch.Size([170, 1])
b_factors shape: torch.Size([170, 1])
Shape of labels: torch.Size([170])
esm2_embeddings shape: torch.Size([126, 1280])
ss_onehot shape: torch.Size([126, 4])
phi_angles shape: torch.Size([126, 1])
psi_angles shape: torch.Size([126, 1])
b_factors shape: torch.Size([126, 1])
Shape of labels: torch.Size([126])
esm2_embeddings shape: torch.Size([145, 1280])
ss_onehot shape: torch.Size([145, 4])
phi_angles shape: torch.Size([145, 1])
psi_angles shape: torch.Size([145, 1])
b_factors shape: torch.Size([145, 1])
Shape of labels: torch.Size([145])


 93%|█████████▎| 943/1013 [02:47<00:12,  5.65it/s]

esm2_embeddings shape: torch.Size([181, 1280])
ss_onehot shape: torch.Size([181, 4])
phi_angles shape: torch.Size([181, 1])
psi_angles shape: torch.Size([181, 1])
b_factors shape: torch.Size([181, 1])
Shape of labels: torch.Size([181])


 93%|█████████▎| 945/1013 [02:47<00:12,  5.35it/s]

esm2_embeddings shape: torch.Size([281, 1280])
ss_onehot shape: torch.Size([281, 4])
phi_angles shape: torch.Size([281, 1])
psi_angles shape: torch.Size([281, 1])
b_factors shape: torch.Size([281, 1])
Shape of labels: torch.Size([281])
esm2_embeddings shape: torch.Size([183, 1280])
ss_onehot shape: torch.Size([183, 4])
phi_angles shape: torch.Size([183, 1])
psi_angles shape: torch.Size([183, 1])
b_factors shape: torch.Size([183, 1])
Shape of labels: torch.Size([183])


 93%|█████████▎| 947/1013 [02:48<00:10,  6.30it/s]

esm2_embeddings shape: torch.Size([107, 1280])
ss_onehot shape: torch.Size([107, 4])
phi_angles shape: torch.Size([107, 1])
psi_angles shape: torch.Size([107, 1])
b_factors shape: torch.Size([107, 1])
Shape of labels: torch.Size([107])
esm2_embeddings shape: torch.Size([191, 1280])
ss_onehot shape: torch.Size([191, 4])
phi_angles shape: torch.Size([191, 1])
psi_angles shape: torch.Size([191, 1])
b_factors shape: torch.Size([191, 1])
Shape of labels: torch.Size([191])


 94%|█████████▎| 949/1013 [02:48<00:09,  6.62it/s]

esm2_embeddings shape: torch.Size([108, 1280])
ss_onehot shape: torch.Size([108, 4])
phi_angles shape: torch.Size([108, 1])
psi_angles shape: torch.Size([108, 1])
b_factors shape: torch.Size([108, 1])
Shape of labels: torch.Size([108])
esm2_embeddings shape: torch.Size([232, 1280])
ss_onehot shape: torch.Size([232, 4])
phi_angles shape: torch.Size([232, 1])
psi_angles shape: torch.Size([232, 1])
b_factors shape: torch.Size([232, 1])
Shape of labels: torch.Size([232])


 94%|█████████▍| 951/1013 [02:48<00:08,  7.34it/s]

esm2_embeddings shape: torch.Size([131, 1280])
ss_onehot shape: torch.Size([131, 4])
phi_angles shape: torch.Size([131, 1])
psi_angles shape: torch.Size([131, 1])
b_factors shape: torch.Size([131, 1])
Shape of labels: torch.Size([131])
esm2_embeddings shape: torch.Size([163, 1280])
ss_onehot shape: torch.Size([163, 4])
phi_angles shape: torch.Size([163, 1])
psi_angles shape: torch.Size([163, 1])
b_factors shape: torch.Size([163, 1])
Shape of labels: torch.Size([163])


 94%|█████████▍| 952/1013 [02:49<00:12,  4.73it/s]

esm2_embeddings shape: torch.Size([109, 1280])
ss_onehot shape: torch.Size([109, 4])
phi_angles shape: torch.Size([109, 1])
psi_angles shape: torch.Size([109, 1])
b_factors shape: torch.Size([109, 1])
Shape of labels: torch.Size([109])


 94%|█████████▍| 953/1013 [02:49<00:14,  4.25it/s]

esm2_embeddings shape: torch.Size([313, 1280])
ss_onehot shape: torch.Size([313, 4])
phi_angles shape: torch.Size([313, 1])
psi_angles shape: torch.Size([313, 1])
b_factors shape: torch.Size([313, 1])
Shape of labels: torch.Size([313])


 94%|█████████▍| 955/1013 [02:49<00:12,  4.64it/s]

esm2_embeddings shape: torch.Size([313, 1280])
ss_onehot shape: torch.Size([313, 4])
phi_angles shape: torch.Size([313, 1])
psi_angles shape: torch.Size([313, 1])
b_factors shape: torch.Size([313, 1])
Shape of labels: torch.Size([313])
esm2_embeddings shape: torch.Size([138, 1280])
ss_onehot shape: torch.Size([138, 4])
phi_angles shape: torch.Size([138, 1])
psi_angles shape: torch.Size([138, 1])
b_factors shape: torch.Size([138, 1])
Shape of labels: torch.Size([138])
esm2_embeddings shape: torch.Size([128, 1280])
ss_onehot shape: torch.Size([128, 4])
phi_angles shape: torch.Size([128, 1])
psi_angles shape: torch.Size([128, 1])
b_factors shape: torch.Size([128, 1])
Shape of labels: torch.Size([128])


 95%|█████████▍| 958/1013 [02:50<00:09,  5.56it/s]

esm2_embeddings shape: torch.Size([282, 1280])
ss_onehot shape: torch.Size([282, 4])
phi_angles shape: torch.Size([282, 1])
psi_angles shape: torch.Size([282, 1])
b_factors shape: torch.Size([282, 1])
Shape of labels: torch.Size([282])
esm2_embeddings shape: torch.Size([140, 1280])
ss_onehot shape: torch.Size([140, 4])
phi_angles shape: torch.Size([140, 1])
psi_angles shape: torch.Size([140, 1])
b_factors shape: torch.Size([140, 1])
Shape of labels: torch.Size([140])


 95%|█████████▍| 959/1013 [02:50<00:09,  5.58it/s]

esm2_embeddings shape: torch.Size([209, 1280])
ss_onehot shape: torch.Size([209, 4])
phi_angles shape: torch.Size([209, 1])
psi_angles shape: torch.Size([209, 1])
b_factors shape: torch.Size([209, 1])
Shape of labels: torch.Size([209])


 95%|█████████▍| 961/1013 [02:51<00:12,  4.17it/s]

esm2_embeddings shape: torch.Size([164, 1280])
ss_onehot shape: torch.Size([164, 4])
phi_angles shape: torch.Size([164, 1])
psi_angles shape: torch.Size([164, 1])
b_factors shape: torch.Size([164, 1])
Shape of labels: torch.Size([164])
esm2_embeddings shape: torch.Size([160, 1280])
ss_onehot shape: torch.Size([160, 4])
phi_angles shape: torch.Size([160, 1])
psi_angles shape: torch.Size([160, 1])
b_factors shape: torch.Size([160, 1])
Shape of labels: torch.Size([160])
esm2_embeddings shape: torch.Size([84, 1280])
ss_onehot shape: torch.Size([84, 4])
phi_angles shape: torch.Size([84, 1])
psi_angles shape: torch.Size([84, 1])
b_factors shape: torch.Size([84, 1])
Shape of labels: torch.Size([84])


 95%|█████████▌| 964/1013 [02:51<00:08,  5.51it/s]

esm2_embeddings shape: torch.Size([229, 1280])
ss_onehot shape: torch.Size([229, 4])
phi_angles shape: torch.Size([229, 1])
psi_angles shape: torch.Size([229, 1])
b_factors shape: torch.Size([229, 1])
Shape of labels: torch.Size([229])
esm2_embeddings shape: torch.Size([182, 1280])
ss_onehot shape: torch.Size([182, 4])
phi_angles shape: torch.Size([182, 1])
psi_angles shape: torch.Size([182, 1])
b_factors shape: torch.Size([182, 1])
Shape of labels: torch.Size([182])


 95%|█████████▌| 966/1013 [02:52<00:09,  5.05it/s]

esm2_embeddings shape: torch.Size([310, 1280])
ss_onehot shape: torch.Size([310, 4])
phi_angles shape: torch.Size([310, 1])
psi_angles shape: torch.Size([310, 1])
b_factors shape: torch.Size([310, 1])
Shape of labels: torch.Size([310])
esm2_embeddings shape: torch.Size([176, 1280])
ss_onehot shape: torch.Size([176, 4])
phi_angles shape: torch.Size([176, 1])
psi_angles shape: torch.Size([176, 1])
b_factors shape: torch.Size([176, 1])
Shape of labels: torch.Size([176])


 96%|█████████▌| 969/1013 [02:52<00:05,  7.53it/s]

esm2_embeddings shape: torch.Size([158, 1280])
ss_onehot shape: torch.Size([158, 4])
phi_angles shape: torch.Size([158, 1])
psi_angles shape: torch.Size([158, 1])
b_factors shape: torch.Size([158, 1])
Shape of labels: torch.Size([158])
esm2_embeddings shape: torch.Size([114, 1280])
ss_onehot shape: torch.Size([114, 4])
phi_angles shape: torch.Size([114, 1])
psi_angles shape: torch.Size([114, 1])
b_factors shape: torch.Size([114, 1])
Shape of labels: torch.Size([114])
esm2_embeddings shape: torch.Size([97, 1280])
ss_onehot shape: torch.Size([97, 4])
phi_angles shape: torch.Size([97, 1])
psi_angles shape: torch.Size([97, 1])
b_factors shape: torch.Size([97, 1])
Shape of labels: torch.Size([97])


 96%|█████████▌| 970/1013 [02:52<00:05,  7.40it/s]

esm2_embeddings shape: torch.Size([188, 1280])
ss_onehot shape: torch.Size([188, 4])
phi_angles shape: torch.Size([188, 1])
psi_angles shape: torch.Size([188, 1])
b_factors shape: torch.Size([188, 1])
Shape of labels: torch.Size([188])


 96%|█████████▌| 971/1013 [02:53<00:10,  3.83it/s]

esm2_embeddings shape: torch.Size([228, 1280])
ss_onehot shape: torch.Size([228, 4])
phi_angles shape: torch.Size([228, 1])
psi_angles shape: torch.Size([228, 1])
b_factors shape: torch.Size([228, 1])
Shape of labels: torch.Size([228])


 96%|█████████▌| 972/1013 [02:53<00:10,  4.03it/s]

esm2_embeddings shape: torch.Size([234, 1280])
ss_onehot shape: torch.Size([234, 4])
phi_angles shape: torch.Size([234, 1])
psi_angles shape: torch.Size([234, 1])
b_factors shape: torch.Size([234, 1])
Shape of labels: torch.Size([234])


 96%|█████████▌| 974/1013 [02:53<00:08,  4.81it/s]

esm2_embeddings shape: torch.Size([267, 1280])
ss_onehot shape: torch.Size([267, 4])
phi_angles shape: torch.Size([267, 1])
psi_angles shape: torch.Size([267, 1])
b_factors shape: torch.Size([267, 1])
Shape of labels: torch.Size([267])
esm2_embeddings shape: torch.Size([144, 1280])
ss_onehot shape: torch.Size([144, 4])
phi_angles shape: torch.Size([144, 1])
psi_angles shape: torch.Size([144, 1])
b_factors shape: torch.Size([144, 1])
Shape of labels: torch.Size([144])
esm2_embeddings shape: torch.Size([118, 1280])
ss_onehot shape: torch.Size([118, 4])
phi_angles shape: torch.Size([118, 1])
psi_angles shape: torch.Size([118, 1])
b_factors shape: torch.Size([118, 1])
Shape of labels: torch.Size([118])


 96%|█████████▋| 977/1013 [02:53<00:05,  6.46it/s]

esm2_embeddings shape: torch.Size([164, 1280])
ss_onehot shape: torch.Size([164, 4])
phi_angles shape: torch.Size([164, 1])
psi_angles shape: torch.Size([164, 1])
b_factors shape: torch.Size([164, 1])
Shape of labels: torch.Size([164])
esm2_embeddings shape: torch.Size([162, 1280])
ss_onehot shape: torch.Size([162, 4])
phi_angles shape: torch.Size([162, 1])
psi_angles shape: torch.Size([162, 1])
b_factors shape: torch.Size([162, 1])
Shape of labels: torch.Size([162])
esm2_embeddings shape: torch.Size([136, 1280])
ss_onehot shape: torch.Size([136, 4])
phi_angles shape: torch.Size([136, 1])
psi_angles shape: torch.Size([136, 1])
b_factors shape: torch.Size([136, 1])
Shape of labels: torch.Size([136])


 97%|█████████▋| 980/1013 [02:54<00:05,  6.49it/s]

esm2_embeddings shape: torch.Size([234, 1280])
ss_onehot shape: torch.Size([234, 4])
phi_angles shape: torch.Size([234, 1])
psi_angles shape: torch.Size([234, 1])
b_factors shape: torch.Size([234, 1])
Shape of labels: torch.Size([234])
esm2_embeddings shape: torch.Size([187, 1280])
ss_onehot shape: torch.Size([187, 4])
phi_angles shape: torch.Size([187, 1])
psi_angles shape: torch.Size([187, 1])
b_factors shape: torch.Size([187, 1])
Shape of labels: torch.Size([187])


 97%|█████████▋| 981/1013 [02:55<00:13,  2.41it/s]

esm2_embeddings shape: torch.Size([556, 1280])
ss_onehot shape: torch.Size([556, 4])
phi_angles shape: torch.Size([556, 1])
psi_angles shape: torch.Size([556, 1])
b_factors shape: torch.Size([556, 1])
Shape of labels: torch.Size([556])


 97%|█████████▋| 983/1013 [02:56<00:10,  2.93it/s]

esm2_embeddings shape: torch.Size([337, 1280])
ss_onehot shape: torch.Size([337, 4])
phi_angles shape: torch.Size([337, 1])
psi_angles shape: torch.Size([337, 1])
b_factors shape: torch.Size([337, 1])
Shape of labels: torch.Size([337])
esm2_embeddings shape: torch.Size([210, 1280])
ss_onehot shape: torch.Size([210, 4])
phi_angles shape: torch.Size([210, 1])
psi_angles shape: torch.Size([210, 1])
b_factors shape: torch.Size([210, 1])
Shape of labels: torch.Size([210])


 97%|█████████▋| 984/1013 [02:56<00:08,  3.43it/s]

esm2_embeddings shape: torch.Size([191, 1280])
ss_onehot shape: torch.Size([191, 4])
phi_angles shape: torch.Size([191, 1])
psi_angles shape: torch.Size([191, 1])
b_factors shape: torch.Size([191, 1])
Shape of labels: torch.Size([191])


 97%|█████████▋| 985/1013 [02:56<00:09,  2.86it/s]

esm2_embeddings shape: torch.Size([422, 1280])
ss_onehot shape: torch.Size([422, 4])
phi_angles shape: torch.Size([422, 1])
psi_angles shape: torch.Size([422, 1])
b_factors shape: torch.Size([422, 1])
Shape of labels: torch.Size([422])


 98%|█████████▊| 988/1013 [02:57<00:06,  3.58it/s]

esm2_embeddings shape: torch.Size([225, 1280])
ss_onehot shape: torch.Size([225, 4])
phi_angles shape: torch.Size([225, 1])
psi_angles shape: torch.Size([225, 1])
b_factors shape: torch.Size([225, 1])
Shape of labels: torch.Size([225])
esm2_embeddings shape: torch.Size([75, 1280])
ss_onehot shape: torch.Size([75, 4])
phi_angles shape: torch.Size([75, 1])
psi_angles shape: torch.Size([75, 1])
b_factors shape: torch.Size([75, 1])
Shape of labels: torch.Size([75])
esm2_embeddings shape: torch.Size([165, 1280])
ss_onehot shape: torch.Size([165, 4])
phi_angles shape: torch.Size([165, 1])
psi_angles shape: torch.Size([165, 1])
b_factors shape: torch.Size([165, 1])
Shape of labels: torch.Size([165])


 98%|█████████▊| 991/1013 [02:57<00:04,  5.47it/s]

esm2_embeddings shape: torch.Size([115, 1280])
ss_onehot shape: torch.Size([115, 4])
phi_angles shape: torch.Size([115, 1])
psi_angles shape: torch.Size([115, 1])
b_factors shape: torch.Size([115, 1])
Shape of labels: torch.Size([115])
esm2_embeddings shape: torch.Size([116, 1280])
ss_onehot shape: torch.Size([116, 4])
phi_angles shape: torch.Size([116, 1])
psi_angles shape: torch.Size([116, 1])
b_factors shape: torch.Size([116, 1])
Shape of labels: torch.Size([116])
esm2_embeddings shape: torch.Size([141, 1280])
ss_onehot shape: torch.Size([141, 4])
phi_angles shape: torch.Size([141, 1])
psi_angles shape: torch.Size([141, 1])
b_factors shape: torch.Size([141, 1])
Shape of labels: torch.Size([141])


 98%|█████████▊| 993/1013 [02:58<00:02,  6.97it/s]

esm2_embeddings shape: torch.Size([90, 1280])
ss_onehot shape: torch.Size([90, 4])
phi_angles shape: torch.Size([90, 1])
psi_angles shape: torch.Size([90, 1])
b_factors shape: torch.Size([90, 1])
Shape of labels: torch.Size([90])
esm2_embeddings shape: torch.Size([124, 1280])
ss_onehot shape: torch.Size([124, 4])
phi_angles shape: torch.Size([124, 1])
psi_angles shape: torch.Size([124, 1])
b_factors shape: torch.Size([124, 1])
Shape of labels: torch.Size([124])
esm2_embeddings shape: torch.Size([121, 1280])
ss_onehot shape: torch.Size([121, 4])
phi_angles shape: torch.Size([121, 1])
psi_angles shape: torch.Size([121, 1])
b_factors shape: torch.Size([121, 1])
Shape of labels: torch.Size([121])


 98%|█████████▊| 996/1013 [02:58<00:02,  7.89it/s]

esm2_embeddings shape: torch.Size([117, 1280])
ss_onehot shape: torch.Size([117, 4])
phi_angles shape: torch.Size([117, 1])
psi_angles shape: torch.Size([117, 1])
b_factors shape: torch.Size([117, 1])
Shape of labels: torch.Size([117])
esm2_embeddings shape: torch.Size([161, 1280])
ss_onehot shape: torch.Size([161, 4])
phi_angles shape: torch.Size([161, 1])
psi_angles shape: torch.Size([161, 1])
b_factors shape: torch.Size([161, 1])
Shape of labels: torch.Size([161])
esm2_embeddings shape: torch.Size([80, 1280])
ss_onehot shape: torch.Size([80, 4])
phi_angles shape: torch.Size([80, 1])
psi_angles shape: torch.Size([80, 1])
b_factors shape: torch.Size([80, 1])
Shape of labels: torch.Size([80])


 99%|█████████▊| 998/1013 [02:58<00:01,  8.96it/s]

esm2_embeddings shape: torch.Size([153, 1280])
ss_onehot shape: torch.Size([153, 4])
phi_angles shape: torch.Size([153, 1])
psi_angles shape: torch.Size([153, 1])
b_factors shape: torch.Size([153, 1])
Shape of labels: torch.Size([153])
esm2_embeddings shape: torch.Size([125, 1280])
ss_onehot shape: torch.Size([125, 4])
phi_angles shape: torch.Size([125, 1])
psi_angles shape: torch.Size([125, 1])
b_factors shape: torch.Size([125, 1])
Shape of labels: torch.Size([125])


 99%|█████████▊| 1000/1013 [02:58<00:01,  9.24it/s]

esm2_embeddings shape: torch.Size([144, 1280])
ss_onehot shape: torch.Size([144, 4])
phi_angles shape: torch.Size([144, 1])
psi_angles shape: torch.Size([144, 1])
b_factors shape: torch.Size([144, 1])
Shape of labels: torch.Size([144])


 99%|█████████▉| 1001/1013 [02:59<00:01,  7.51it/s]

esm2_embeddings shape: torch.Size([262, 1280])
ss_onehot shape: torch.Size([262, 4])
phi_angles shape: torch.Size([262, 1])
psi_angles shape: torch.Size([262, 1])
b_factors shape: torch.Size([262, 1])
Shape of labels: torch.Size([262])


 99%|█████████▉| 1003/1013 [02:59<00:02,  4.69it/s]

esm2_embeddings shape: torch.Size([178, 1280])
ss_onehot shape: torch.Size([178, 4])
phi_angles shape: torch.Size([178, 1])
psi_angles shape: torch.Size([178, 1])
b_factors shape: torch.Size([178, 1])
Shape of labels: torch.Size([178])
esm2_embeddings shape: torch.Size([190, 1280])
ss_onehot shape: torch.Size([190, 4])
phi_angles shape: torch.Size([190, 1])
psi_angles shape: torch.Size([190, 1])
b_factors shape: torch.Size([190, 1])
Shape of labels: torch.Size([190])


 99%|█████████▉| 1004/1013 [02:59<00:01,  5.30it/s]

esm2_embeddings shape: torch.Size([149, 1280])
ss_onehot shape: torch.Size([149, 4])
phi_angles shape: torch.Size([149, 1])
psi_angles shape: torch.Size([149, 1])
b_factors shape: torch.Size([149, 1])
Shape of labels: torch.Size([149])


 99%|█████████▉| 1006/1013 [03:00<00:01,  4.97it/s]

esm2_embeddings shape: torch.Size([325, 1280])
ss_onehot shape: torch.Size([325, 4])
phi_angles shape: torch.Size([325, 1])
psi_angles shape: torch.Size([325, 1])
b_factors shape: torch.Size([325, 1])
Shape of labels: torch.Size([325])
esm2_embeddings shape: torch.Size([167, 1280])
ss_onehot shape: torch.Size([167, 4])
phi_angles shape: torch.Size([167, 1])
psi_angles shape: torch.Size([167, 1])
b_factors shape: torch.Size([167, 1])
Shape of labels: torch.Size([167])
esm2_embeddings shape: torch.Size([54, 1280])
ss_onehot shape: torch.Size([54, 4])
phi_angles shape: torch.Size([54, 1])
psi_angles shape: torch.Size([54, 1])
b_factors shape: torch.Size([54, 1])
Shape of labels: torch.Size([54])


100%|█████████▉| 1009/1013 [03:00<00:00,  6.62it/s]

esm2_embeddings shape: torch.Size([214, 1280])
ss_onehot shape: torch.Size([214, 4])
phi_angles shape: torch.Size([214, 1])
psi_angles shape: torch.Size([214, 1])
b_factors shape: torch.Size([214, 1])
Shape of labels: torch.Size([214])
esm2_embeddings shape: torch.Size([168, 1280])
ss_onehot shape: torch.Size([168, 4])
phi_angles shape: torch.Size([168, 1])
psi_angles shape: torch.Size([168, 1])
b_factors shape: torch.Size([168, 1])
Shape of labels: torch.Size([168])


100%|█████████▉| 1010/1013 [03:01<00:00,  5.00it/s]

esm2_embeddings shape: torch.Size([327, 1280])
ss_onehot shape: torch.Size([327, 4])
phi_angles shape: torch.Size([327, 1])
psi_angles shape: torch.Size([327, 1])
b_factors shape: torch.Size([327, 1])
Shape of labels: torch.Size([327])


100%|█████████▉| 1011/1013 [03:01<00:00,  2.86it/s]

esm2_embeddings shape: torch.Size([318, 1280])
ss_onehot shape: torch.Size([318, 4])
phi_angles shape: torch.Size([318, 1])
psi_angles shape: torch.Size([318, 1])
b_factors shape: torch.Size([318, 1])
Shape of labels: torch.Size([318])


100%|█████████▉| 1012/1013 [03:02<00:00,  3.04it/s]

esm2_embeddings shape: torch.Size([283, 1280])
ss_onehot shape: torch.Size([283, 4])
phi_angles shape: torch.Size([283, 1])
psi_angles shape: torch.Size([283, 1])
b_factors shape: torch.Size([283, 1])
Shape of labels: torch.Size([283])


100%|██████████| 1013/1013 [03:02<00:00,  5.56it/s]


esm2_embeddings shape: torch.Size([272, 1280])
ss_onehot shape: torch.Size([272, 4])
phi_angles shape: torch.Size([272, 1])
psi_angles shape: torch.Size([272, 1])
b_factors shape: torch.Size([272, 1])
Shape of labels: torch.Size([272])


  2%|▏         | 1/46 [00:00<00:05,  7.78it/s]

esm2_embeddings shape: torch.Size([147, 1280])
ss_onehot shape: torch.Size([147, 4])
phi_angles shape: torch.Size([147, 1])
psi_angles shape: torch.Size([147, 1])
b_factors shape: torch.Size([147, 1])
Shape of labels: torch.Size([147])


  4%|▍         | 2/46 [00:00<00:05,  8.45it/s]

esm2_embeddings shape: torch.Size([147, 1280])
ss_onehot shape: torch.Size([147, 4])
phi_angles shape: torch.Size([147, 1])
psi_angles shape: torch.Size([147, 1])
b_factors shape: torch.Size([147, 1])
Shape of labels: torch.Size([147])
esm2_embeddings shape: torch.Size([83, 1280])
ss_onehot shape: torch.Size([83, 4])
phi_angles shape: torch.Size([83, 1])
psi_angles shape: torch.Size([83, 1])
b_factors shape: torch.Size([83, 1])
Shape of labels: torch.Size([83])


 13%|█▎        | 6/46 [00:00<00:03, 11.36it/s]

esm2_embeddings shape: torch.Size([209, 1280])
ss_onehot shape: torch.Size([209, 4])
phi_angles shape: torch.Size([209, 1])
psi_angles shape: torch.Size([209, 1])
b_factors shape: torch.Size([209, 1])
Shape of labels: torch.Size([209])
esm2_embeddings shape: torch.Size([89, 1280])
ss_onehot shape: torch.Size([89, 4])
phi_angles shape: torch.Size([89, 1])
psi_angles shape: torch.Size([89, 1])
b_factors shape: torch.Size([89, 1])
Shape of labels: torch.Size([89])
esm2_embeddings shape: torch.Size([78, 1280])
ss_onehot shape: torch.Size([78, 4])
phi_angles shape: torch.Size([78, 1])
psi_angles shape: torch.Size([78, 1])
b_factors shape: torch.Size([78, 1])
Shape of labels: torch.Size([78])
esm2_embeddings shape: torch.Size([74, 1280])
ss_onehot shape: torch.Size([74, 4])
phi_angles shape: torch.Size([74, 1])
psi_angles shape: torch.Size([74, 1])
b_factors shape: torch.Size([74, 1])
Shape of labels: torch.Size([74])
esm2_embeddings shape: torch.Size([38, 1280])
ss_onehot shape: torch.Size(

 24%|██▍       | 11/46 [00:00<00:02, 13.76it/s]

esm2_embeddings shape: torch.Size([85, 1280])
ss_onehot shape: torch.Size([85, 4])
phi_angles shape: torch.Size([85, 1])
psi_angles shape: torch.Size([85, 1])
b_factors shape: torch.Size([85, 1])
Shape of labels: torch.Size([85])
esm2_embeddings shape: torch.Size([84, 1280])
ss_onehot shape: torch.Size([84, 4])
phi_angles shape: torch.Size([84, 1])
psi_angles shape: torch.Size([84, 1])
b_factors shape: torch.Size([84, 1])
Shape of labels: torch.Size([84])
esm2_embeddings shape: torch.Size([155, 1280])
ss_onehot shape: torch.Size([155, 4])
phi_angles shape: torch.Size([155, 1])
psi_angles shape: torch.Size([155, 1])
b_factors shape: torch.Size([155, 1])
Shape of labels: torch.Size([155])


 28%|██▊       | 13/46 [00:01<00:02, 13.05it/s]

esm2_embeddings shape: torch.Size([109, 1280])
ss_onehot shape: torch.Size([109, 4])
phi_angles shape: torch.Size([109, 1])
psi_angles shape: torch.Size([109, 1])
b_factors shape: torch.Size([109, 1])
Shape of labels: torch.Size([109])
esm2_embeddings shape: torch.Size([123, 1280])
ss_onehot shape: torch.Size([123, 4])
phi_angles shape: torch.Size([123, 1])
psi_angles shape: torch.Size([123, 1])
b_factors shape: torch.Size([123, 1])
Shape of labels: torch.Size([123])
esm2_embeddings shape: torch.Size([88, 1280])
ss_onehot shape: torch.Size([88, 4])
phi_angles shape: torch.Size([88, 1])
psi_angles shape: torch.Size([88, 1])
b_factors shape: torch.Size([88, 1])
Shape of labels: torch.Size([88])


 33%|███▎      | 15/46 [00:01<00:02, 13.92it/s]

esm2_embeddings shape: torch.Size([96, 1280])
ss_onehot shape: torch.Size([96, 4])
phi_angles shape: torch.Size([96, 1])
psi_angles shape: torch.Size([96, 1])
b_factors shape: torch.Size([96, 1])
Shape of labels: torch.Size([96])
esm2_embeddings shape: torch.Size([154, 1280])
ss_onehot shape: torch.Size([154, 4])
phi_angles shape: torch.Size([154, 1])
psi_angles shape: torch.Size([154, 1])
b_factors shape: torch.Size([154, 1])
Shape of labels: torch.Size([154])


 37%|███▋      | 17/46 [00:01<00:05,  5.76it/s]

esm2_embeddings shape: torch.Size([268, 1280])
ss_onehot shape: torch.Size([268, 4])
phi_angles shape: torch.Size([268, 1])
psi_angles shape: torch.Size([268, 1])
b_factors shape: torch.Size([268, 1])
Shape of labels: torch.Size([268])


 41%|████▏     | 19/46 [00:02<00:04,  6.19it/s]

esm2_embeddings shape: torch.Size([238, 1280])
ss_onehot shape: torch.Size([238, 4])
phi_angles shape: torch.Size([238, 1])
psi_angles shape: torch.Size([238, 1])
b_factors shape: torch.Size([238, 1])
Shape of labels: torch.Size([238])
esm2_embeddings shape: torch.Size([59, 1280])
ss_onehot shape: torch.Size([59, 4])
phi_angles shape: torch.Size([59, 1])
psi_angles shape: torch.Size([59, 1])
b_factors shape: torch.Size([59, 1])
Shape of labels: torch.Size([59])
esm2_embeddings shape: torch.Size([40, 1280])
ss_onehot shape: torch.Size([40, 4])
phi_angles shape: torch.Size([40, 1])
psi_angles shape: torch.Size([40, 1])
b_factors shape: torch.Size([40, 1])
Shape of labels: torch.Size([40])
esm2_embeddings shape: torch.Size([97, 1280])
ss_onehot shape: torch.Size([97, 4])
phi_angles shape: torch.Size([97, 1])
psi_angles shape: torch.Size([97, 1])
b_factors shape: torch.Size([97, 1])
Shape of labels: torch.Size([97])


 52%|█████▏    | 24/46 [00:02<00:02,  9.17it/s]

esm2_embeddings shape: torch.Size([132, 1280])
ss_onehot shape: torch.Size([132, 4])
phi_angles shape: torch.Size([132, 1])
psi_angles shape: torch.Size([132, 1])
b_factors shape: torch.Size([132, 1])
Shape of labels: torch.Size([132])
esm2_embeddings shape: torch.Size([104, 1280])
ss_onehot shape: torch.Size([104, 4])
phi_angles shape: torch.Size([104, 1])
psi_angles shape: torch.Size([104, 1])
b_factors shape: torch.Size([104, 1])
Shape of labels: torch.Size([104])
esm2_embeddings shape: torch.Size([103, 1280])
ss_onehot shape: torch.Size([103, 4])
phi_angles shape: torch.Size([103, 1])
psi_angles shape: torch.Size([103, 1])
b_factors shape: torch.Size([103, 1])
Shape of labels: torch.Size([103])


 57%|█████▋    | 26/46 [00:02<00:02,  7.97it/s]

esm2_embeddings shape: torch.Size([165, 1280])
ss_onehot shape: torch.Size([165, 4])
phi_angles shape: torch.Size([165, 1])
psi_angles shape: torch.Size([165, 1])
b_factors shape: torch.Size([165, 1])
Shape of labels: torch.Size([165])
esm2_embeddings shape: torch.Size([217, 1280])
ss_onehot shape: torch.Size([217, 4])
phi_angles shape: torch.Size([217, 1])
psi_angles shape: torch.Size([217, 1])
b_factors shape: torch.Size([217, 1])
Shape of labels: torch.Size([217])
esm2_embeddings shape: torch.Size([161, 1280])
ss_onehot shape: torch.Size([161, 4])
phi_angles shape: torch.Size([161, 1])
psi_angles shape: torch.Size([161, 1])
b_factors shape: torch.Size([161, 1])
Shape of labels: torch.Size([161])


 61%|██████    | 28/46 [00:03<00:02,  6.82it/s]

esm2_embeddings shape: torch.Size([288, 1280])
ss_onehot shape: torch.Size([288, 4])
phi_angles shape: torch.Size([288, 1])
psi_angles shape: torch.Size([288, 1])
b_factors shape: torch.Size([288, 1])
Shape of labels: torch.Size([288])


 67%|██████▋   | 31/46 [00:04<00:02,  5.95it/s]

esm2_embeddings shape: torch.Size([176, 1280])
ss_onehot shape: torch.Size([176, 4])
phi_angles shape: torch.Size([176, 1])
psi_angles shape: torch.Size([176, 1])
b_factors shape: torch.Size([176, 1])
Shape of labels: torch.Size([176])
esm2_embeddings shape: torch.Size([94, 1280])
ss_onehot shape: torch.Size([94, 4])
phi_angles shape: torch.Size([94, 1])
psi_angles shape: torch.Size([94, 1])
b_factors shape: torch.Size([94, 1])
Shape of labels: torch.Size([94])
esm2_embeddings shape: torch.Size([83, 1280])
ss_onehot shape: torch.Size([83, 4])
phi_angles shape: torch.Size([83, 1])
psi_angles shape: torch.Size([83, 1])
b_factors shape: torch.Size([83, 1])
Shape of labels: torch.Size([83])


 74%|███████▍  | 34/46 [00:04<00:01,  6.63it/s]

esm2_embeddings shape: torch.Size([275, 1280])
ss_onehot shape: torch.Size([275, 4])
phi_angles shape: torch.Size([275, 1])
psi_angles shape: torch.Size([275, 1])
b_factors shape: torch.Size([275, 1])
Shape of labels: torch.Size([275])
esm2_embeddings shape: torch.Size([88, 1280])
ss_onehot shape: torch.Size([88, 4])
phi_angles shape: torch.Size([88, 1])
psi_angles shape: torch.Size([88, 1])
b_factors shape: torch.Size([88, 1])
Shape of labels: torch.Size([88])
esm2_embeddings shape: torch.Size([143, 1280])
ss_onehot shape: torch.Size([143, 4])
phi_angles shape: torch.Size([143, 1])
psi_angles shape: torch.Size([143, 1])
b_factors shape: torch.Size([143, 1])
Shape of labels: torch.Size([143])


 76%|███████▌  | 35/46 [00:04<00:01,  6.88it/s]

esm2_embeddings shape: torch.Size([157, 1280])
ss_onehot shape: torch.Size([157, 4])
phi_angles shape: torch.Size([157, 1])
psi_angles shape: torch.Size([157, 1])
b_factors shape: torch.Size([157, 1])
Shape of labels: torch.Size([157])
esm2_embeddings shape: torch.Size([100, 1280])
ss_onehot shape: torch.Size([100, 4])
phi_angles shape: torch.Size([100, 1])
psi_angles shape: torch.Size([100, 1])
b_factors shape: torch.Size([100, 1])
Shape of labels: torch.Size([100])


 85%|████████▍ | 39/46 [00:05<00:00,  7.47it/s]

esm2_embeddings shape: torch.Size([276, 1280])
ss_onehot shape: torch.Size([276, 4])
phi_angles shape: torch.Size([276, 1])
psi_angles shape: torch.Size([276, 1])
b_factors shape: torch.Size([276, 1])
Shape of labels: torch.Size([276])
esm2_embeddings shape: torch.Size([129, 1280])
ss_onehot shape: torch.Size([129, 4])
phi_angles shape: torch.Size([129, 1])
psi_angles shape: torch.Size([129, 1])
b_factors shape: torch.Size([129, 1])
Shape of labels: torch.Size([129])
esm2_embeddings shape: torch.Size([121, 1280])
ss_onehot shape: torch.Size([121, 4])
phi_angles shape: torch.Size([121, 1])
psi_angles shape: torch.Size([121, 1])
b_factors shape: torch.Size([121, 1])
Shape of labels: torch.Size([121])


 89%|████████▉ | 41/46 [00:05<00:00,  8.45it/s]

esm2_embeddings shape: torch.Size([92, 1280])
ss_onehot shape: torch.Size([92, 4])
phi_angles shape: torch.Size([92, 1])
psi_angles shape: torch.Size([92, 1])
b_factors shape: torch.Size([92, 1])
Shape of labels: torch.Size([92])
esm2_embeddings shape: torch.Size([139, 1280])
ss_onehot shape: torch.Size([139, 4])
phi_angles shape: torch.Size([139, 1])
psi_angles shape: torch.Size([139, 1])
b_factors shape: torch.Size([139, 1])
Shape of labels: torch.Size([139])


 91%|█████████▏| 42/46 [00:05<00:00,  8.15it/s]

esm2_embeddings shape: torch.Size([179, 1280])
ss_onehot shape: torch.Size([179, 4])
phi_angles shape: torch.Size([179, 1])
psi_angles shape: torch.Size([179, 1])
b_factors shape: torch.Size([179, 1])
Shape of labels: torch.Size([179])


 96%|█████████▌| 44/46 [00:06<00:00,  5.12it/s]

esm2_embeddings shape: torch.Size([164, 1280])
ss_onehot shape: torch.Size([164, 4])
phi_angles shape: torch.Size([164, 1])
psi_angles shape: torch.Size([164, 1])
b_factors shape: torch.Size([164, 1])
Shape of labels: torch.Size([164])
esm2_embeddings shape: torch.Size([146, 1280])
ss_onehot shape: torch.Size([146, 4])
phi_angles shape: torch.Size([146, 1])
psi_angles shape: torch.Size([146, 1])
b_factors shape: torch.Size([146, 1])
Shape of labels: torch.Size([146])
esm2_embeddings shape: torch.Size([105, 1280])
ss_onehot shape: torch.Size([105, 4])
phi_angles shape: torch.Size([105, 1])
psi_angles shape: torch.Size([105, 1])
b_factors shape: torch.Size([105, 1])
Shape of labels: torch.Size([105])


100%|██████████| 46/46 [00:06<00:00,  7.31it/s]

esm2_embeddings shape: torch.Size([129, 1280])
ss_onehot shape: torch.Size([129, 4])
phi_angles shape: torch.Size([129, 1])
psi_angles shape: torch.Size([129, 1])
b_factors shape: torch.Size([129, 1])
Shape of labels: torch.Size([129])





### Create baseline GNN

In [None]:
class BindingSiteGCN(nn.Module):
    def __init__(self, node_dim, edge_dim=2, hidden_dim=512):
        """
        A baseline GCN for binding site prediction.

        Args:
            node_dim (int): Dimension of node features (1288 in your case).
            edge_dim (int): Dimension of edge features (2 in your case).
            hidden_dim (int): Hidden dimension for GCN layers.
        """
        super(BindingSiteGCN, self).__init__()
        # GCN layers with edge features
        self.conv1 = GCNConv(node_dim, hidden_dim)  # Remove edge_dim
        self.conv2 = GCNConv(hidden_dim, 256)  # Remove edge_dim
        self.conv3 = GCNConv(256, 128)  # Remove edge_dim

        # Linear layer to process edge features
        self.edge_lin = nn.Linear(edge_dim, hidden_dim)

        # Final classifier
        # self.fc = nn.Linear(128, 2)  # 2 classes: 0 (non-binding), 1 (binding)
        self.pre_fc = nn.Linear(128, 16)  # Reduce dimension before KAN
        self.fc = KAN([16, 8, 2], grid=2, k=2)

        # Activation and dropout
        self.activation_func = nn.LeakyReLU(negative_slope=0.1)
        self.dropout = nn.Dropout(0.1)

    def forward(self, data):
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        # Get the number of nodes in the current batch

        # Process edge features separately
        edge_attr = self.edge_lin(edge_attr)

        # GCN layers
        # Pass num_nodes to gcn_norm
        x = self.conv1(x, edge_index)  # Removed edge_attr here, Added num_nodes
        x = self.activation_func(x)
        x = self.dropout(x)

        x = self.conv2(x, edge_index) # Removed edge_attr here, Added num_nodes
        x = self.activation_func(x)
        x = self.dropout(x)

        x = self.conv3(x, edge_index) # Removed edge_attr here, Added num_nodes
        x = self.activation_func(x)

        # Final classification
        x = self.fc(x)  # Shape: [num_residues, 2]

        return x  # Logits for each residue

In [None]:
class BindingSiteGAT(nn.Module):
    def __init__(self, node_dim=1287, hidden_dim=512, heads=4):
        """
        A GAT-based model for binding site prediction with a KAN classifier.

        Args:
            node_dim (int): Dimension of node features (e.g., 1287 for ESM-2 + 7 features).
            hidden_dim (int): Hidden dimension for GAT layers.
            heads (int): Number of attention heads in GAT layers.
        """
        super(BindingSiteGAT, self).__init__()
        # GAT layers with multi-head attention
        self.conv1 = GATConv(node_dim, hidden_dim, heads=heads, concat=True)
        self.conv2 = GATConv(hidden_dim * heads, 256, heads=heads, concat=True)
        self.conv3 = GATConv(256 * heads, 128, heads=1, concat=False)  # Single head for final layer

        # KAN layer for classification
        self.pre_fc = nn.Linear(128, 16)  # Reduce dimension before KAN
        self.fc = KAN([16, 8, 2], grid=2, k=2)  # Smaller KAN with reduced spline complexity

        # Activation and dropout
        self.activation_func = nn.LeakyReLU()
        self.dropout = nn.Dropout(0.1)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index  # edge_attr is unused in GATConv

        # GAT layers
        x = self.conv1(x, edge_index)
        x = self.activation_func(x)
        x = self.dropout(x)

        x = self.conv2(x, edge_index)
        x = self.activation_func(x)
        x = self.dropout(x)

        x = self.conv3(x, edge_index)
        x = self.activation_func(x)

        # KAN classifier
        x = self.fc(x)  # Shape: [num_residues, 2]

        return x  # Logits for each residue

In [None]:
train_graphs, val_graphs = train_test_split(
    train_graphs_data,
    test_size=0.05,
    random_state=42,
    shuffle=True
)

print(f"Number of training graphs: {len(train_graphs)}")
print(f"Number of validation graphs: {len(val_graphs)}")
# print(f"Number of testing graphs: {len(test_graphs_data)}")

Number of training graphs: 962
Number of validation graphs: 51


In [None]:
for idx, train_graph in enumerate(train_graphs):
    if train_graph.x.shape[0] != train_graph.y.shape[0]:
        print(f"Abnormal at index: {idx}")
        print(f"Train graph's input shape: {train_graph['x'].shape[0]} and labels shape: {train_graph['y'].shape[0]}")

In [None]:
for idx, val_graph in enumerate(val_graphs):
    if val_graph.x.shape[0] != val_graph.y.shape[0]:
        print(f"Abnormal at index: {idx}")
        print(f"Train graph's input shape: {val_graph['x'].shape[0]} and labels shape: {val_graph['y'].shape[0]}")

### Model Training

In [None]:
# Create DataLoaders
batch_size = 16

train_loader = DataLoader(train_graphs, batch_size=batch_size, shuffle=True)
eval_loader = DataLoader(val_graphs, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_graphs_data, batch_size=batch_size, shuffle=False)

### Customized loss function

In [None]:
class WeightedCrossEntropyLoss(nn.Module):
    def __init__(self, pos_weight):
        super().__init__()
        self.pos_weight = pos_weight

    def forward(self, logits, labels):
        # Handle class imbalance with weighted loss
        weight = torch.tensor([1.0, self.pos_weight]).to(logits.device)  # [weight for class 0, weight for class 1]
        loss_fct = nn.CrossEntropyLoss(weight=weight)
        return loss_fct(logits, labels)

class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha  # Weight for the positive class
        self.gamma = gamma  # Focusing parameter
        self.reduction = reduction

    def forward(self, logits, labels):
        # Compute the cross-entropy loss (without reduction)
        ce_loss = nn.functional.cross_entropy(logits, labels, reduction='none')

        # Compute the probability of the true class
        probs = torch.softmax(logits, dim=-1)
        true_probs = probs[torch.arange(probs.size(0), device=probs.device), labels]

        # Compute the focal loss term: (1 - p_t)^gamma
        focal_term = (1 - true_probs) ** self.gamma

        # Apply the alpha weighting
        alpha_weight = torch.where(labels == 1, self.alpha, 1.0 - self.alpha).to(logits.device)

        # Compute the focal loss
        loss = alpha_weight * focal_term * ce_loss

        # Apply reduction
        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss

class PositionAwareLoss(nn.Module):
    def __init__(self, pos_weight, position_weight, alpha=0.25, gamma=2.0):
        super().__init__()
        self.weighted_ce = WeightedCrossEntropyLoss(pos_weight)
        self.focal_loss = FocalLoss(alpha=alpha, gamma=gamma, reduction='mean')
        self.position_weight = position_weight

    def forward(self, logits, labels, batch=None):
        # Compute the base loss (weighted cross-entropy + focal loss)
        ce_loss = self.weighted_ce(logits, labels)
        focal_loss = self.focal_loss(logits, labels)
        base_loss = ce_loss + focal_loss

        # Position-aware component
        probs = torch.softmax(logits, dim=-1)[:, 1]  # Get binding probabilities [total_num_nodes]
        position_loss = torch.tensor(0.0).to(logits.device)

        # Since we're using a GNN, we need to account for the graph structure
        # batch.batch indicates which nodes belong to which graph
        if batch is not None and batch.num_graphs == 1:  # Single graph per batch
            num_nodes = batch.num_nodes
            # Penalize offset predictions by checking neighboring nodes
            for i in range(1, num_nodes - 1):
                # Encourage predictions to match true binding site positions
                if labels[i] == 1 or labels[i-1] == 1 or labels[i+1] == 1:
                    position_loss += torch.abs(probs[i] - (labels[i] == 1).float())

        return base_loss + self.position_weight * position_loss

In [None]:
input_train_embedding_dim = torch.tensor(train_embeddings[0]['embeddings']).shape[1] + 7

# model = BindingSiteGCN(node_dim = input_train_embedding_dim, edge_dim=2, hidden_dim=512).to(device)
model = BindingSiteGAT(node_dim = input_train_embedding_dim, hidden_dim=512, heads=4).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)

# class_weights = torch.tensor([0.2, 0.8]).to(device)
pos_weight = 5.0
position_weight = 1
focal_loss = FocalLoss(alpha=0.25, gamma=2.0)
criterion = PositionAwareLoss(pos_weight=pos_weight, position_weight=position_weight, alpha=0.25, gamma=2.0)

checkpoint directory created: ./model
saving model version 0.0


In [None]:
def evaluate(model, data_loader, device):
    model.eval()
    all_preds = []
    all_labels = []
    all_probs = []

    with torch.no_grad():
        for batch in data_loader:
            batch = batch.to(device)
            out = model(batch)

            # Debug: Print shapes
            # print(f"Batch num_graphs: {batch.num_graphs}")
            # print(f"Batch num_nodes: {batch.num_nodes}")
            # print(f"Out shape: {out.shape}")
            # print(f"Labels shape: {batch.y.shape}")

            probs = torch.softmax(out, dim=1)[:, 1]
            preds = torch.argmax(out, dim=1)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(batch.y.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())

    precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='binary')
    auc = roc_auc_score(all_labels, all_probs)
    mcc = matthews_corrcoef(all_labels, all_preds)

    return {
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "auc": auc,
        "mcc": mcc
    }

In [None]:
gc.collect()

num_epochs = 50
best_val_f1 = 0
best_model_state = None

model.train()
for epoch in range(num_epochs):
    total_train_loss = 0
    model.train()
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        # print(f"Batch shape: {batch}")
        out = model(batch)
        # print(f"Out shape: {out.shape}")
        # print(f"Labels shape: {batch.y.shape}")
        loss = criterion(out, batch.y)
        loss.backward()
        optimizer.step()
        total_train_loss += loss.item()

    avg_train_loss = total_train_loss / len(train_loader)

    val_metrics = evaluate(model, eval_loader, device)

    print(f"Epoch {epoch+1}/{num_epochs}")
    print(f"Training Loss: {avg_train_loss:.4f}")
    print(f"Validation Precision: {val_metrics['precision']:.4f}")
    print(f"Validation Recall: {val_metrics['recall']:.4f}")
    print(f"Validation F1-Score: {val_metrics['f1']:.4f}")
    print(f"Validation AUC-ROC: {val_metrics['auc']:.4f}")

    if val_metrics['f1'] > best_val_f1:
        best_val_f1 = val_metrics['f1']
        best_model_state = model.state_dict()
        print("Best validation F1-score improved! Saving model state.")

Epoch 1/50
Training Loss: 0.5445
Validation Precision: 0.3869
Validation Recall: 0.5325
Validation F1-Score: 0.4482
Validation AUC-ROC: 0.8649
Best validation F1-score improved! Saving model state.
Epoch 2/50
Training Loss: 0.4849
Validation Precision: 0.4686
Validation Recall: 0.4329
Validation F1-Score: 0.4500
Validation AUC-ROC: 0.8719
Best validation F1-score improved! Saving model state.
Epoch 3/50
Training Loss: 0.4517
Validation Precision: 0.4064
Validation Recall: 0.5436
Validation F1-Score: 0.4651
Validation AUC-ROC: 0.8733
Best validation F1-score improved! Saving model state.
Epoch 4/50
Training Loss: 0.4165
Validation Precision: 0.4113
Validation Recall: 0.5450
Validation F1-Score: 0.4688
Validation AUC-ROC: 0.8663
Best validation F1-score improved! Saving model state.
Epoch 5/50
Training Loss: 0.3816
Validation Precision: 0.4503
Validation Recall: 0.4827
Validation F1-Score: 0.4660
Validation AUC-ROC: 0.8625
Epoch 6/50
Training Loss: 0.3253
Validation Precision: 0.3523
Val

In [None]:
if best_model_state is not None:
    model.load_state_dict(best_model_state)

# with open('/content/drive/MyDrive/Protein-binding/data/test300_loader.pkl', 'rb') as f:
#     saved_test_loader = pickle.load(f)

model.eval()
test_metrics = evaluate(model, test_loader, device)

In [None]:
print(test_metrics)

{'precision': 0.3533834586466165, 'recall': 0.3269565217391304, 'f1': 0.33965672990063234, 'auc': 0.7553989353518569, 'mcc': 0.27559158269020767}


In [None]:
print(saved_model_name)

/content/drive/MyDrive/Protein-binding/trained_models/GCN_KAN_for_small_binding_Test46_23Apr_model.pth


In [None]:
torch.save(best_model_state, saved_model_name)

### Save and load model weights

In [None]:
def create_labels(data_loader):
    all_labels = []
    with torch.no_grad():
        for batch in data_loader:
            batch = batch.to(device)
            all_labels.extend(batch.y.cpu().numpy())
    return all_labels

In [None]:
def inference(model_name, data_loader):
    all_preds = []
    model_instance = BindingSiteGAT(node_dim = input_train_embedding_dim, hidden_dim=512, heads=4).to(device)
    # model_instance = BindingSiteGCN(node_dim = input_train_embedding_dim, edge_dim = 2, hidden_dim=512).to(device)
    model_instance.load_state_dict(torch.load(model_name))
    model_instance.eval()
    with torch.no_grad():
        for batch in data_loader:
            batch = batch.to(device)
            out = model_instance(batch)

            probs = torch.softmax(out, dim=1)[:, 1]
            preds = torch.argmax(out, dim=1)

            all_preds.extend(preds.cpu().numpy())

    return all_preds

nuclear_model_name = "/content/drive/MyDrive/Protein-binding/trained_models/GCN_KAN_for_nuclear_binding_Test46_23Apr_model.pth"
nuclear_preds = inference(nuclear_model_name, test_loader)

metal_model_name = "/content/drive/MyDrive/Protein-binding/trained_models/GCN_KAN_for_metal_binding_Test46_23Apr_model.pth"
metal_preds = inference(metal_model_name, test_loader)

small_model_name = "/content/drive/MyDrive/Protein-binding/trained_models/GCN_KAN_for_small_binding_Test46_23Apr_model.pth"
small_preds = inference(small_model_name, test_loader)

merged_preds_list = [max(a, b, c) for a, b, c in zip(nuclear_preds, metal_preds, small_preds)]
print(merged_preds_list)

checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

In [None]:
all_labels = create_labels(test_loader)

In [None]:
metal_precision, metal_recall, metal_f1, _ = precision_recall_fscore_support(all_labels, metal_preds, average='binary')
metal_mcc = matthews_corrcoef(all_labels, metal_preds)

nuclear_precision, nuclear_recall, nuclear_f1, _ = precision_recall_fscore_support(all_labels, nuclear_preds, average='binary')
nuclear_mcc = matthews_corrcoef(all_labels, nuclear_preds)

small_precision, small_recall, small_f1, _ = precision_recall_fscore_support(all_labels, small_preds, average='binary')
small_mcc = matthews_corrcoef(all_labels, small_preds)

overall_precision, overall_recall, overall_f1, _ = precision_recall_fscore_support(all_labels, merged_preds_list, average='binary')
overall_mcc = matthews_corrcoef(all_labels, merged_preds_list)


print(f"Metal Precision: {metal_precision:.4f}")
print(f"Metal Recall: {metal_recall:.4f}")
print(f"Metal F1-Score: {metal_f1:.4f}")
print(f"Metal MCC score: {metal_mcc:.4f}")
print("\n")

print(f"Nuclear Precision: {nuclear_precision:.4f}")
print(f"Nuclear Recall: {nuclear_recall:.4f}")
print(f"Nuclear F1-Score: {nuclear_f1:.4f}")
print(f"Nuclear MCC score: {nuclear_mcc:.4f}")
print("\n")

print(f"Small Precision: {small_precision:.4f}")
print(f"Small Recall: {small_recall:.4f}")
print(f"Small F1-Score: {small_f1:.4f}")
print(f"Small MCC score: {small_mcc:.4f}")
print("\n")

print(f"Overall Precision: {overall_precision:.4f}")
print(f"Overall Recall: {overall_recall:.4f}")
print(f"Overall F1-Score: {overall_f1:.4f}")
print(f"Overall MCC score: {overall_mcc:.4f}")

Metal Precision: 0.3770
Metal Recall: 0.3357
Metal F1-Score: 0.3551
Metal MCC score: 0.2943


Nuclear Precision: 0.3584
Nuclear Recall: 0.3565
Nuclear F1-Score: 0.3575
Nuclear MCC score: 0.2923


Small Precision: 0.3534
Small Recall: 0.3270
Small F1-Score: 0.3397
Small MCC score: 0.2756


Overall Precision: 0.3143
Overall Recall: 0.4400
Overall F1-Score: 0.3667
Overall MCC score: 0.2954


### Error analysis

In [None]:
# all_preds, all_labels, all_probs = [], [], []

# with torch.no_grad():
#     for test_batch in test_loader:
#         test_batch = test_batch.to(device)
#         out = model(test_batch)

#         probs = torch.softmax(out, dim=1)[:, 1]
#         preds = torch.argmax(out, dim=1)

#         all_preds.extend(preds.cpu().numpy())
#         all_labels.extend(batch.y.cpu().numpy())
#         all_probs.extend(probs.cpu().numpy())

In [None]:
# all_test_prot_sequences = test_df['sequence'].tolist()
# all_sequences = []
# for prot_seq in all_test_prot_sequences:
#     all_sequences.extend(list(prot_seq))

In [None]:
# from collections import Counter
# from copy import deepcopy

# test_aa_counter = Counter(all_sequences)
# false_negatives_dict = {}

# for prob, label, pred, acid in zip(all_probs, all_labels, all_preds, all_sequences):
#     if label == 1 and pred == 0:
#         print(f"Acid: {acid} and its probability: {prob}")

#         if acid not in false_negatives_dict:
#             false_negatives_dict[acid] = 1
#         else:
#             false_negatives_dict[acid] += 1

In [None]:
# false_negatives_percentage_dict = deepcopy(false_negatives_dict)

# for acid in false_negatives_percentage_dict:
#     false_negatives_percentage_dict[acid] /= test_aa_counter[acid]

In [None]:
# # print(false_negatives_percentage_dict)
# sorted_false_negatives_percentage_dict = dict(sorted(false_negatives_percentage_dict.items(), key=lambda item: item[1], reverse=True))
# print(sorted_false_negatives_percentage_dict)

In [None]:
# sorted_false_negatives_percentage_dict