<a href="https://colab.research.google.com/github/DougChul/RNA/blob/Colab/relpos%2Bpairwise.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.colab import drive

# Mount Google Drive so the preprocessed data under /content/drive/MyDrive/... can be read below.
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


## Imports

In [None]:
import os
import sys
import pickle
import random
from tqdm import tqdm
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import math


In [None]:
# Seed every RNG the notebook touches (PyTorch, NumPy's legacy global RNG and the
# stdlib `random` module) so random crops and weight initialisation are reproducible.
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(0)

In [None]:
# Record library versions next to the outputs for reproducibility.
for _lib in (torch, pd):
    print(_lib.__version__)

2.6.0+cu124
2.2.2


## Config

In [None]:
config = dict(
    # --- reproducibility & temporal split ---
    seed=0,
    cutoff_date="2020-01-01",       # training set: released on/before this date
    test_cutoff_date="2022-05-01",  # validation set: after cutoff_date, up to this date
    max_len=384,                    # random-crop length used by RNA3D_Dataset
    # --- training ---
    batch_size=1,
    learning_rate=0.3*1e-5,
    weight_decay=0.0,
    mixed_precision="bf16",
    model_config_path="../working/configs/pairwise.yaml",  # Adjust path as needed
    epochs=50,
    loss_power_scale=1.0,
    max_cycles=1,
    grad_clip=0.1,
    gradient_accumulation_steps=1,
    d_clamp=30,
    # --- sequence-length filters (see the commented preprocessing cell) ---
    max_len_filter=9999999,
    min_len_filter=10,
    # --- misc training options ---
    structural_violation_epoch=50,
    balance_weight=False,
    # --- model architecture ---
    n_tokens=4,        # vocabulary: A, C, G, U
    d_model=256,
    n_heads=8,
    dropout=0.1,
    d_ff=1024,
    norm_ratio=1.0,    # fraction of encoder layers built as post-LN
    n_layers=48,
    pairwise_dimension=128,
    dim_msa=32,
)

## Load Sample Data for Building the Model

In [None]:
# Display the current working directory; relative paths such as config["model_config_path"] resolve against it.
os.getcwd()

'/content'

In [None]:

# folder_path = '/content/drive/MyDrive/RNA/stanford-rna-3d-folding'

# ## Select Data
# # set_data_fomat = '5models'
# # set_data_fomat = 'v2'
# set_data_fomat = 'v1'

# if set_data_fomat == '5models':
#     train_labels = pd.read_csv(os.path.join(folder_path, 'pdb_labels_5models.csv'))
#     train_sequences = pd.read_csv(os.path.join(folder_path, 'pdb_sequences_5models.csv'))
# elif set_data_fomat == 'v2':
#     train_labels = pd.read_csv(os.path.join(folder_path, 'train_labels.v2.csv'))
#     train_sequences = pd.read_csv(os.path.join(folder_path, 'train_sequences.v2.csv'))
# elif set_data_fomat == 'v1':
#     train_labels = pd.read_csv(os.path.join(folder_path, 'train_labels.csv'))
#     train_sequences = pd.read_csv(os.path.join(folder_path, 'train_sequences.csv'))
# else:
#     raise ValueError("Invalid set_data_fomat")


# print(train_labels.head())

# print(train_sequences.head())


In [None]:
# train_labels["pdb_id"] = train_labels["ID"].apply(lambda x: x.split("_")[0]+'_'+x.split("_")[1])
# # train_sequences["pdb_id"]

In [None]:
# all_xyz=[]

# test_sample = False
# count = 0

# for pdb_id in tqdm(train_sequences['target_id']):
#     df = train_labels[train_labels["pdb_id"]==pdb_id]
#     #break
#     # xyz=df[['x_1','y_1','z_1','x_2','y_2','z_2','x_3','y_3','z_3','x_4','y_4','z_4','x_5','y_5','z_5',]].to_numpy().astype('float32')
#     xyz=df[['x_1','y_1','z_1']].to_numpy().astype('float32')
#     xyz[xyz<-1e17]=float('Nan');
#     all_xyz.append(xyz)
#     if test_sample == True:
#       count += 1
#       if count == 100:
#         break
# # all_xyz[13]

### Filter data

In [None]:
# # filter the data
# # Filter and process data
# filter_nan = []
# max_len = 0
# filter_ratio = 0 # All data are valid
# for xyz in all_xyz:
#     if len(xyz) > max_len:
#         max_len = len(xyz)

#     filter_nan.append((np.isnan(xyz).mean() <= filter_ratio) & \
#                       (len(xyz)<config['max_len_filter']) & \
#                       (len(xyz)>config['min_len_filter']))


# print(f"Longest sequence in train: {max_len}")

# filter_nan = np.array(filter_nan)
# non_nan_indices = np.arange(len(filter_nan))[filter_nan]
# print('remain sequences:', len(non_nan_indices))
# train_sequences = train_sequences.loc[non_nan_indices].reset_index(drop=True)
# non_nan_xyz=[all_xyz[i] for i in non_nan_indices]

In [None]:
# #pack data into a dictionary

# data={
#       "pdb_id":train_sequences['target_id'].to_list(),
#       "sequence":train_sequences['sequence'].to_list(),
#       "temporal_cutoff": train_sequences['temporal_cutoff'].to_list(),
#       "description": train_sequences['description'].to_list(),
#       "all_sequences": train_sequences['all_sequences'].to_list(),
#       "xyz": non_nan_xyz
# }
# print(data['pdb_id'][2])
# print(data['sequence'][1])
# print(data['temporal_cutoff'][1])
# # print(data['description'][1])
# # print(data['all_sequences'][1])
# # print(data['xyz'][1])

In [None]:
# with open('/content/drive/MyDrive/RNA/stanford-rna-3d-folding/processed_data.v1.pkl', 'wb') as f:
#     pickle.dump(data, f)
# print("데이터가 processed_data.pkl 파일로 저장되었습니다.")

In [None]:
## Load preprocessed data
# NOTE(review): pickle.load can execute arbitrary code — only open pickles you created yourself.
PROCESSED_DATA_PATH = '/content/drive/MyDrive/RNA/stanford-rna-3d-folding/processed_data.v2_max=384.pkl'

with open(PROCESSED_DATA_PATH, 'rb') as f:
    data = pickle.load(f)

### Split into train / validation sets

In [None]:
# Temporal split: train on structures released up to `cutoff_date`,
# validate on those released in (cutoff_date, test_cutoff_date].
all_index = np.arange(len(data['sequence']))
cutoff_date = pd.Timestamp(config['cutoff_date'])
test_cutoff_date = pd.Timestamp(config['test_cutoff_date'])

# Parse every release date once instead of once per comparison.
release_dates = [pd.Timestamp(d) for d in data['temporal_cutoff']]
train_index = [i for i, released in enumerate(release_dates) if released <= cutoff_date]
test_index = [i for i, released in enumerate(release_dates)
              if cutoff_date < released <= test_cutoff_date]

### Wrap the data in a PyTorch Dataset

In [None]:
from torch.utils.data import Dataset, DataLoader
from ast import literal_eval

def get_ct(bp, s):
    """Build a contact (CT) matrix from a list of base pairs.

    Args:
        bp: iterable of pairs whose first two entries are 1-based positions (i, j).
        s: the sequence; only its length is used to size the matrix.

    Returns:
        np.ndarray of shape (len(s), len(s)), float64 zeros with 1.0 at (i-1, j-1)
        for every pair. Only the (i, j) orientation is set — the matrix is not
        symmetrised.
    """
    size = len(s)
    contact = np.zeros((size, size))
    for pair in bp:
        row, col = pair[0] - 1, pair[1] - 1  # 1-based -> 0-based
        contact[row, col] = 1
    return contact

class RNA3D_Dataset(Dataset):
    """Map-style dataset of tokenised RNA sequences with C1' coordinates.

    Each item is a dict with:
        'pbd_id'  : the entry's id from data['pdb_id'] (the key's spelling is kept
                    as-is for compatibility with existing consumers),
        'sequence': integer tensor of tokens, A=0, C=1, G=2, U=3,
        'xyz'     : coordinate tensor translated so the first residue whose
                    coordinates are all non-NaN sits at the origin.
    Sequences longer than the crop length are randomly cropped to exactly that length.
    """

    def __init__(self, indices, data, max_len=None):
        """
        Args:
            indices: positions into the parallel lists of `data` that form this split.
            data: dict of parallel lists with at least 'pdb_id', 'sequence' and 'xyz'.
            max_len: crop length; when None, config['max_len'] is read at access time.
        """
        self.indices = indices
        self.data = data
        self.tokens = {nt: i for i, nt in enumerate('ACGU')}
        self.max_len = max_len

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        idx = self.indices[idx]
        sequence = [self.tokens[nt] for nt in (self.data['sequence'][idx])]
        sequence = np.array(sequence)
        sequence = torch.tensor(sequence)

        # C1' coordinates; presumably (L, 3) with NaN marking missing atoms
        xyz = self.data['xyz'][idx]
        xyz = torch.tensor(np.array(xyz))

        max_len = self.max_len if self.max_len is not None else config['max_len']
        if len(sequence) > max_len:
            # randint's upper bound is exclusive; +1 makes the crop that ends at
            # the last residue reachable.
            crop_start = np.random.randint(len(sequence) - max_len + 1)
            crop_end = crop_start + max_len

            sequence = sequence[crop_start:crop_end]
            xyz = xyz[crop_start:crop_end]

        # Centre at the first residue whose coordinates are all present. If no
        # residue is complete (or there are none), leave xyz untouched instead of
        # subtracting a row that contains NaN.
        if len(xyz) > 0:
            complete_rows = (~torch.isnan(xyz)).all(dim=-1)
            if complete_rows.any():
                first = int(complete_rows.nonzero()[0, 0])
                xyz = xyz - xyz[first]

        return {'pbd_id': self.data['pdb_id'][idx],
                'sequence': sequence,
                'xyz': xyz}

In [None]:

# The post-cutoff split (`test_index`) doubles as the validation set.
train_dataset, val_dataset = (RNA3D_Dataset(split, data) for split in (train_index, test_index))


In [None]:
# One structure per batch: sequences differ in length, and the default collate
# function stacks tensors, so it cannot batch them without padding.
_loader_opts = {'batch_size': 1}
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_opts)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_opts)

# print(train_dataset[2]['xyz'][6])


## Model definition

### Single-example debug run

In [None]:
# Debug switch: when True, train and validate on a single example so the whole
# pipeline can be smoke-tested (and deliberately over-fit) quickly.
test_sample = True

if test_sample:
    from torch.utils.data import Subset

    target_index = 0
    # Subset the full splits rather than `train_loader.dataset`, so re-running
    # this cell does not wrap a Subset inside another Subset.
    train_dataset = Subset(RNA3D_Dataset(train_index, data), [target_index])
    train_loader = DataLoader(train_dataset, batch_size=1)

    val_dataset = Subset(RNA3D_Dataset(test_index, data), [target_index])
    val_loader = DataLoader(val_dataset, batch_size=1)

print(len(train_dataset[0]['sequence']))
print(train_dataset[0])
print(len(train_loader))
print(len(val_loader))

12
{'pbd_id': '3J0Q_3', 'sequence': tensor([2, 1, 1, 0, 2, 3, 2, 0, 0, 0, 3, 0]), 'xyz': tensor([[  0.0000,   0.0000,   0.0000],
        [ -1.0760,  -3.3350,  -3.4670],
        [ -4.4130,  -6.6660,  -7.6410],
        [ -7.0260,  -0.6310,  -9.7380],
        [ -3.1140,   3.6950, -11.0880],
        [ -2.5960,   8.3430, -14.4300],
        [  1.2760,  10.8900, -15.9760],
        [  7.1280,  12.5880, -19.2280],
        [  8.5000,   9.1200, -14.2150],
        [  4.5950,   6.0570,  -8.3870],
        [  6.2800,   5.2940,  -2.2990],
        [  6.1480,   2.7090,   3.4880]])}
1
1


### Check pairwise features (base-pairing probabilities)

In [None]:
# import RNA
# import numpy as np
# import matplotlib.pyplot as plt
# import seaborn as sns

# src = torch.tensor([2, 2, 1, 2, 3, 0, 0, 2, 2, 0, 3, 3, 0, 1, 1, 3, 0, 3, 2, 1, 1], dtype=torch.long)
# # AUGC를 0123으로 매핑
# nucleotide_map = {0: 'A', 1: 'C', 2: 'G', 3: 'U'}

# str_A = ''.join([nucleotide_map[x.item()] for x in src])

# print(len(str_A))
# print(str_A)

# md = RNA.md()

# fc = RNA.fold_compound(str_A, md)

# # predict Minmum Free Energy and corresponding secondary structure
# # (ss, mfe) =
# # print(fc.mfe())
# fc.pf()
# B = torch.tensor(fc.bpp())

# # B = torch.tensor(B)
# B = B[1:,1:]
# print(type(B))
# print(B.shape)
# B = B + B.T
# # B = np.array(B, dtype=np.bool())
# # print(B)

# plt.figure(figsize=(5,4))
# sns.heatmap(B, cmap='viridis', annot=False)  # annot=True로 설정하면 각 셀에 값 표시
# plt.title('Pairwise Contact Probability Matrix')
# plt.show()

# # def get_pairwise_features(src, seq_len, d_model):
# #     nucleotide_map = {0: 'A', 1: 'C', 2: 'G', 3: 'U'}

# #     # print("pairwise",src)
# #     src = src.squeeze(0)
# #     str_seq = ''.join([nucleotide_map[x.item()] for x in src])
# #     # print("str_seq",str_seq)

# #     md = RNA.md()
# #     fc = RNA.fold_compound(str_seq, md)
# #     fc.pf() ##???

# #     pair_matrix = torch.tensor(fc.bpp(), dtype=torch.float32)
# #     # print("pair_matrix",pair_matrix)
# #     pair_matrix = pair_matrix[1:,1:] # remove first row and column 0 index in bpp is always 0
# #     pair_matrix = pair_matrix + pair_matrix.T # symmetric matrix

# #     pair_matrix = pair_matrix.unsqueeze(0).unsqueeze(0)
# #     pair_matrix = F.interpolate(pair_matrix, size=(d_model, d_model), mode='bilinear', align_corners=False)
# #     # pair_matrix = F.interpolate(pair_matrix, size=(d_model, d_model), mode='nearest', align_corners=False)

# #     pair_matrix = pair_matrix.squeeze()
# #     # print("pair_matrix",pair_matrix.shape)

# #     return pair_matrix

### Pairwise Module

In [None]:
from torch import einsum

def exists(val):
    """Return True when `val` is anything other than None (falsy values count as existing)."""
    return not (val is None)

def default(val, d):
    """Return `val`, falling back to `d` only when `val` is None."""
    if val is None:
        return d
    return val

class OuterProductMean(nn.Module):
    """Lift per-residue features to pairwise features via an outer product.

    Each residue vector (in_dim) is projected to dim_msa channels. For every
    residue pair (i, j), the dim_msa x dim_msa outer product is flattened and
    projected to pairwise_dim: (B, L, in_dim) -> (B, L, L, pairwise_dim).
    Despite the name, no mean is taken here.
    """

    def __init__(self, in_dim=256, dim_msa=32, pairwise_dim=128):
        super(OuterProductMean, self).__init__()
        self.proj_down1 = nn.Linear(in_dim, dim_msa)
        self.proj_down2 = nn.Linear(dim_msa ** 2, pairwise_dim)

    def forward(self, seq_rep, pair_rep=None):
        """
        Args:
            seq_rep: (B, L, in_dim) residue features.
            pair_rep: optional tensor broadcastable to (B, L, L, pairwise_dim),
                added to the result.
        Returns:
            (B, L, L, pairwise_dim) pairwise features.
        """
        seq_rep = self.proj_down1(seq_rep)
        outer_product = torch.einsum('bid,bjc -> bijcd', seq_rep, seq_rep)
        # Flatten (c, d) with c as the major index. This is the same layout as
        # einops 'b i j c d -> b i j (c d)', but does not depend on an import
        # made in a later notebook cell.
        outer_product = outer_product.flatten(start_dim=3)
        outer_product = self.proj_down2(outer_product)

        if pair_rep is not None:
            outer_product = outer_product + pair_rep

        return outer_product


class TriangleAttention(nn.Module):
    """Gated triangle self-attention over a pair representation.

    wise='row': position (r, i) attends over positions (r, j) of the same row.
    wise='col': position (i, l) attends over positions (j, l) of the same column.
    Logits get an additive per-head bias projected from the pair features. The
    output is gated by a sigmoid of the normalised input, which requires
    in_dim == n_heads * dim.
    """

    def __init__(self, in_dim=128, dim=32, n_heads=4, wise='row'):
        super(TriangleAttention, self).__init__()
        self.n_heads = n_heads
        self.wise = wise
        self.norm = nn.LayerNorm(in_dim)
        self.to_qkv = nn.Linear(in_dim, dim * 3 * n_heads, bias=False)
        self.linear_for_pair = nn.Linear(in_dim, n_heads, bias=False)
        self.to_gate = nn.Sequential(
            nn.Linear(in_dim, in_dim),
            nn.Sigmoid()
        )
        self.to_out = nn.Linear(n_heads * dim, in_dim)
        # self.to_out.weight.data.fill_(0.)
        # self.to_out.bias.data.fill_(0.)

    def forward(self, z, src_mask=None):
        """
        Args:
            z: pair representation (B, L, L, in_dim).
            src_mask: (B, L) residue mask where 0 marks padding. It is NOT
                modified. None disables masking.
        Returns:
            (B, L, L, in_dim) update; the caller adds the residual.

        Masking: with s = +1 for real residues and -1 for padding, the pair mask is
        s_i * s_j. Logits where it equals -1 (exactly one side is padding) are
        set to the dtype's most negative value before the softmax.
        """
        wise = self.wise
        if wise not in ('row', 'col'):
            raise ValueError('wise should be col or row!')

        attn_mask = None
        if src_mask is not None:
            # Build the signed mask out-of-place; the caller's tensor is shared
            # with other modules and must not be rewritten.
            signed = src_mask.float()
            signed = torch.where(signed == 0, torch.full_like(signed, -1.0), signed).unsqueeze(-1)
            attn_mask = torch.matmul(signed, signed.transpose(1, 2))  # (B, L, L)

        z = self.norm(z)
        q, k, v = torch.chunk(self.to_qkv(z), 3, -1)
        # split heads: (B, L, L, H*D) -> (B, L, L, H, D)
        q, k, v = (t.reshape(*t.shape[:3], self.n_heads, -1) for t in (q, k, v))
        b = self.linear_for_pair(z)  # (B, L, L, H) pair bias
        gate = self.to_gate(z)
        scale = q.size(-1) ** .5
        if wise == 'row':
            eq_attn = 'brihd,brjhd->brijh'
            eq_multi = 'brijh,brjhd->brihd'
            b = b.unsqueeze(1)  # (B, 1, L, L, H): shared by every row r
            softmax_dim = 3
            if attn_mask is not None:
                attn_mask = attn_mask[:, None, :, :, None]
        else:
            eq_attn = 'bilhd,bjlhd->bijlh'
            eq_multi = 'bijlh,bjlhd->bilhd'
            b = b.unsqueeze(3)  # (B, L, L, 1, H): shared by every column l
            softmax_dim = 2
            if attn_mask is not None:
                attn_mask = attn_mask[:, :, :, None, None]

        logits = torch.einsum(eq_attn, q, k) / scale + b
        if attn_mask is not None:
            # Must be a large negative number; a value like -1e-9 is ~0 and masks nothing.
            logits = logits.masked_fill(attn_mask == -1, torch.finfo(logits.dtype).min)
        attn = logits.softmax(softmax_dim)
        out = torch.einsum(eq_multi, attn, v)
        out = gate * out.reshape(*out.shape[:3], -1)  # merge heads: (H, D) -> H*D
        return self.to_out(out)


class TriangleMultiplicativeModule(nn.Module):
    """Triangle multiplicative update of a pair representation.

    mix='outgoing': out[i, j] = sum_k left[i, k] * right[j, k]
    mix='ingoing' : out[i, j] = sum_k left[k, j] * right[k, i]
    left/right are sigmoid-gated projections of the LayerNorm-ed input. The
    product is LayerNorm-ed, gated by out_gate and projected back to `dim`.
    """

    def __init__(
        self,
        *,
        dim,
        hidden_dim = None,
        mix = 'ingoing'
    ):
        super().__init__()
        assert mix in {'ingoing', 'outgoing'}, 'mix must be either ingoing or outgoing'

        hidden_dim = dim if hidden_dim is None else hidden_dim
        self.norm = nn.LayerNorm(dim)

        self.left_proj = nn.Linear(dim, hidden_dim)
        self.right_proj = nn.Linear(dim, hidden_dim)

        self.left_gate = nn.Linear(dim, hidden_dim)
        self.right_gate = nn.Linear(dim, hidden_dim)
        self.out_gate = nn.Linear(dim, hidden_dim)

        # zero weights + unit bias: every gate starts at the constant sigmoid(1)
        for gate in (self.left_gate, self.right_gate, self.out_gate):
            nn.init.constant_(gate.weight, 0.)
            nn.init.constant_(gate.bias, 1.)

        if mix == 'outgoing':
            self.mix_einsum_eq = '... i k d, ... j k d -> ... i j d'
        elif mix == 'ingoing':
            self.mix_einsum_eq = '... k j d, ... k i d -> ... i j d'

        self.to_out_norm = nn.LayerNorm(hidden_dim)
        self.to_out = nn.Linear(hidden_dim, dim)

    def forward(self, x, src_mask=None):
        """
        Args:
            x: pair representation (B, L, L, dim); dims 1 and 2 must match.
            src_mask: optional (B, L) residue mask. The left/right projections of
                any pair touching a 0-masked residue are zeroed. None disables masking.
        Returns:
            (B, L, L, dim) update; the caller adds the residual.
        """
        assert x.shape[1] == x.shape[2], 'feature map must be symmetrical'

        mask = None
        if src_mask is not None:
            residue_mask = src_mask.unsqueeze(-1).float()  # (B, L, 1)
            # pair mask m_i * m_j with a trailing channel axis: (B, L, L, 1)
            mask = torch.matmul(residue_mask, residue_mask.transpose(1, 2)).unsqueeze(-1)

        x = self.norm(x)

        left = self.left_proj(x)
        right = self.right_proj(x)

        if mask is not None:
            left = left * mask
            right = right * mask

        left = left * self.left_gate(x).sigmoid()
        right = right * self.right_gate(x).sigmoid()
        out_gate = self.out_gate(x).sigmoid()

        out = torch.einsum(self.mix_einsum_eq, left, right)

        out = self.to_out_norm(out)
        out = out * out_gate
        return self.to_out(out)


### Dropout module

In [None]:
from functools import partialmethod
from typing import Union, List


class Dropout(nn.Module):
    """
    Dropout whose mask can be shared along particular dimension(s).

    The mask is sampled at reduced size (size 1 along every `batch_dim`) and
    broadcast over x, so all entries along those dims are kept or dropped
    together. Outside training mode the input is returned unchanged, as the
    same tensor object.
    """

    def __init__(self, r: float, batch_dim: Union[int, List[int]]):
        """
        Args:
            r:
                Dropout rate
            batch_dim:
                Dimension(s) along which the dropout mask is shared; a single
                int is wrapped into a list. None means no sharing.
        """
        super(Dropout, self).__init__()

        self.r = r
        if isinstance(batch_dim, int):
            batch_dim = [batch_dim]
        self.batch_dim = batch_dim
        self.dropout = nn.Dropout(self.r)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x:
                Tensor to which dropout is applied. Can have any shape
                compatible with self.batch_dim
        """
        # Identity when the inner dropout is inactive (eval mode or r == 0):
        # skip allocating and multiplying by a mask of ones.
        if not self.dropout.training or self.r == 0:
            return x
        shape = list(x.shape)
        if self.batch_dim is not None:
            for bd in self.batch_dim:
                shape[bd] = 1
        mask = self.dropout(x.new_ones(shape))
        return x * mask


class DropoutRowwise(Dropout):
    """
    Dropout whose mask is shared along dim -3 (the mask has size 1 there and is
    broadcast), as described in subsection 1.11.6.
    """

    def __init__(self, r: float, *, batch_dim=-3):
        super().__init__(r, batch_dim=batch_dim)


class DropoutColumnwise(Dropout):
    """
    Dropout whose mask is shared along dim -2 (the mask has size 1 there and is
    broadcast), as described in subsection 1.11.6.
    """

    def __init__(self, r: float, *, batch_dim=-2):
        super().__init__(r, batch_dim=batch_dim)

### Backbone

In [None]:
# import RNA
import torch.nn.functional as F
from einops import rearrange, repeat, reduce


class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to (B, L, d_model) inputs, then dropout.

    Not used by EmbedSequence/WuSubSol in this section (they use `relpos`).
    """

    def __init__(self, d_model, dropout, max_len=None):
        """
        Args:
            d_model: feature size (odd sizes are supported).
            dropout: dropout probability applied after adding the encoding.
            max_len: longest supported sequence; defaults to config['max_len'].
        """
        super(PositionalEncoding, self).__init__()
        if max_len is None:
            max_len = config['max_len']
        self.dropout = nn.Dropout(p=dropout)

        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Build in a local first: register_buffer raises KeyError if `self.pe`
        # already exists as a plain attribute.
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        x = x + self.pe[:, : x.size(1)]
        return self.dropout(x)

class relpos(nn.Module):
    """Relative-position embedding for pair features.

    The signed offset i - j between residues is clipped to [-16, 16], one-hot
    encoded over those 33 bins and linearly projected to `dim` channels.
    Input: any tensor whose dim 1 is the sequence length L.
    Output: (1, L, L, dim), which broadcasts over the batch.
    """

    def __init__(self, dim=64):
        super(relpos, self).__init__()
        self.linear = nn.Linear(33, dim)

    def forward(self, src):
        seq_len = src.shape[1]
        positions = torch.arange(seq_len, device=src.device)
        offsets = (positions[:, None] - positions[None, :]).clamp(min=-16, max=16)
        # shift [-16, 16] -> [0, 32] so each clipped offset selects one of 33 bins
        one_hot = torch.nn.functional.one_hot(offsets + 16, num_classes=33).float()
        return self.linear(one_hot.unsqueeze(0))

# class ConstrainedPositionalEncoding(nn.Module):
#     def __init__(self, d_model, max_relative_position=32, constrained_position=4):
#         super().__init__()
#         self.max_relative_position = max_relative_position
#         self.relative_embedding = nn.Embedding(2 * max_relative_position + 1, d_model)
#         self.constrained_position = constrained_position

#     def forward(self, seq_len):

#         return pos_encoding(seq_len, self.d_model)


class EmbedSequence(nn.Module):
    """Token embedding plus the initial pair representation.

    forward returns (sequence features (B, L, d_model), pair features
    (B, L, L, out_dim)). The pair features are an outer-product projection of the
    embeddings plus a relative-position encoding.
    """

    def __init__(self, d_model, out_dim, n_tokens = config['n_tokens']):
        super().__init__()

        self.embedder = nn.Embedding(n_tokens, d_model)
        # in_dim must follow d_model; OuterProductMean's default (256) only
        # matches the current config by coincidence.
        self.outer_product_mean = OuterProductMean(in_dim=d_model, pairwise_dim=out_dim)
        self.pos_encoder = relpos(out_dim)

    def forward(self, sequence):
        """sequence: (B, L) token ids."""
        batch, length = sequence.shape
        embedded = self.embedder(sequence.long()).reshape(batch, length, -1)

        pairwise_feature = self.outer_product_mean(embedded)
        pairwise_feature = pairwise_feature + self.pos_encoder(embedded)

        return embedded, pairwise_feature


class MultiheadAtt(nn.Module):
    """Multi-head scaled dot-product attention with an optional additive logit bias.

    `mask` is not a boolean mask: it is added to the (B, H, L, L) attention
    logits before the softmax (here, the per-head pairwise bias). Keys and values
    are assumed to have the same length as the query.
    """

    def __init__(self, d_model, n_heads=8, dropout=0.1):
        super().__init__()
        head_dim = d_model // n_heads
        self.d_model, self.n_heads, self.d_head = d_model, n_heads, head_dim
        self.dropout_1 = nn.Dropout(dropout)  # on the attention weights
        self.dropout_2 = nn.Dropout(dropout)  # on the output projection

        inner_dim = n_heads * head_dim
        self.w_q = nn.Linear(d_model, inner_dim)
        self.w_k = nn.Linear(d_model, inner_dim)
        self.w_v = nn.Linear(d_model, inner_dim)

        self.w_o = nn.Linear(inner_dim, d_model)

    def forward(self, query, key, value, mask=None):
        bsz, seq_len, _ = query.size()

        def split_heads(t):
            # (B, L, H*Dh) -> (B, H, L, Dh)
            return t.view(bsz, seq_len, self.n_heads, self.d_head).transpose(1, 2)

        q_h = split_heads(self.w_q(query))
        k_h = split_heads(self.w_k(key))
        v_h = split_heads(self.w_v(value))

        logits = torch.matmul(q_h, k_h.transpose(-2, -1)) / math.sqrt(self.d_head)
        if mask is not None:
            logits = logits + mask  # additive bias, broadcast over heads if needed
        weights = self.dropout_1(F.softmax(logits, dim=-1))

        merged = torch.matmul(weights, v_h).transpose(1, 2).contiguous().view(bsz, seq_len, -1)
        return self.dropout_2(self.w_o(merged))


class FeedForward(nn.Module):
    """Position-wise MLP: Linear(d_model -> d_ff) -> ReLU -> Dropout -> Linear(d_ff -> d_model)."""

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()

        self.dropout = nn.Dropout(dropout)

        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.activation = nn.ReLU()

    def forward(self, x):
        hidden = self.dropout(self.activation(self.w_1(x)))
        return self.w_2(hidden)

class EncoderLayer(nn.Module):
    """One block of the sequence/pair encoder.

    Sequence track: multi-head self-attention whose logits are biased by a
    per-head projection of the pair features, then a feed-forward network.
    Layer-norm placement depends on `norm_type`.
    Pair track: an outer-product update from the new sequence features, outgoing
    and incoming triangle multiplicative updates, row- and column-wise triangle
    attention, and a transition MLP. Each step is added residually.
    """

    def __init__(self, d_model, n_heads, config=config, dropout=0, norm_type='post_ln'):
        super().__init__()

        if norm_type not in ('post_ln', 'pre_ln'):
            raise ValueError("Invalid norm_type")
        self.norm_type = norm_type

        self.d_ff = config['d_ff']

        pairwise_dimension = config['pairwise_dimension']
        # Project from d_model (not OuterProductMean's hard-coded 256 default).
        self.outer_product_mean = OuterProductMean(in_dim=d_model,
                                                   dim_msa=config.get('dim_msa', 32),
                                                   pairwise_dim=pairwise_dimension)

        self.norm_0 = nn.LayerNorm(d_model)
        self.norm_1 = nn.LayerNorm(d_model)
        self.norm_2 = nn.LayerNorm(d_model)

        self.attention = MultiheadAtt(d_model, n_heads, dropout)
        self.feedforward = FeedForward(d_model, self.d_ff, dropout)
        self.activation = nn.ReLU()  # not used in forward

        self.pairwise_norm = nn.LayerNorm(pairwise_dimension)
        # one additive attention bias per head, read from the pair features
        self.pairwise2heads = nn.Linear(pairwise_dimension, n_heads, bias=False)

        ## triangle multiplicative updates
        self.triangle_update_out = TriangleMultiplicativeModule(dim=pairwise_dimension, mix='outgoing')
        self.triangle_update_in = TriangleMultiplicativeModule(dim=pairwise_dimension, mix='ingoing')

        # Both multiplicative updates use row-wise dropout (batch_dim=-3),
        # matching the AlphaFold2 Evoformer.
        self.pair_dropout_out = DropoutRowwise(dropout)
        self.pair_dropout_in = DropoutRowwise(dropout)

        self.triangle_attention_out = TriangleAttention(in_dim=pairwise_dimension,
                                                        dim=pairwise_dimension // 4,
                                                        wise='row')
        self.triangle_attention_in = TriangleAttention(in_dim=pairwise_dimension,
                                                       dim=pairwise_dimension // 4,
                                                       wise='col')

        self.pair_attention_dropout_out = DropoutRowwise(dropout)
        self.pair_attention_dropout_in = DropoutColumnwise(dropout)

        self.pair_transition = nn.Sequential(
            nn.LayerNorm(pairwise_dimension),
            nn.Linear(pairwise_dimension, pairwise_dimension * 4),
            nn.ReLU(inplace=True),
            nn.Linear(pairwise_dimension * 4, pairwise_dimension))

    def forward(self, src, pairwise_features=None, src_mask=None, use_gradient_checkpoint=False):
        """
        Args:
            src: (B, L, d_model) sequence features.
            pairwise_features: (B, L, L, pairwise_dimension) pair features.
                Required despite the None default.
            src_mask: (B, L) residue mask forwarded to the triangle modules.
            use_gradient_checkpoint: recompute the outer-product, triangle
                multiplicative and transition modules during backward instead of
                storing their activations.
        Returns:
            (src, pairwise_features), both updated.
        """
        # (B, L, L, P) -> (B, n_heads, L, L) additive attention bias
        pairwise_bias = self.pairwise2heads(self.pairwise_norm(pairwise_features)).permute(0, 3, 1, 2)

        if self.norm_type == 'post_ln':
            src = self.norm_0(src)
            src = src + self.attention(src, src, src, mask=pairwise_bias)  # residual
            src = self.norm_1(src)
            src = src + self.feedforward(src)  # residual
            src = self.norm_2(src)
        else:  # 'pre_ln'
            residual = src
            src = self.norm_0(src)
            # Same pairwise bias as the post_ln branch (it used to be computed and dropped).
            src = self.attention(src, src, src, mask=pairwise_bias) + residual
            residual = src
            src = self.norm_1(src)
            src = self.feedforward(src) + residual
            src = self.norm_2(src)

        if use_gradient_checkpoint:
            from torch.utils.checkpoint import checkpoint

            def run(module, *args):
                return checkpoint(module, *args, use_reentrant=False)
        else:
            def run(module, *args):
                return module(*args)

        pairwise_features = pairwise_features + run(self.outer_product_mean, src)
        pairwise_features = pairwise_features + self.pair_dropout_out(
            run(self.triangle_update_out, pairwise_features, src_mask))
        pairwise_features = pairwise_features + self.pair_dropout_in(
            run(self.triangle_update_in, pairwise_features, src_mask))

        pairwise_features = pairwise_features + self.pair_attention_dropout_out(
            self.triangle_attention_out(pairwise_features, src_mask))
        pairwise_features = pairwise_features + self.pair_attention_dropout_in(
            self.triangle_attention_in(pairwise_features, src_mask))

        pairwise_features = pairwise_features + run(self.pair_transition, pairwise_features)

        return src, pairwise_features


class WuSubSol(nn.Module):
    """Sequence -> 3D coordinate regressor.

    An EmbedSequence front-end, a stack of EncoderLayers and a linear head to 3
    coordinates per residue. The first n_layers * norm_ratio layers are
    'post_ln'; the remainder are 'pre_ln'.
    """

    def __init__(self, config):
        super(WuSubSol, self).__init__()
        self.config = config
        self.n_layers = config['n_layers']
        self.d_model = config['d_model']

        self.norm_ratio = config['norm_ratio']
        n_post_ln = self.n_layers * self.norm_ratio
        # float() so integer products also work (int.is_integer needs Python >= 3.12)
        if not float(n_post_ln).is_integer():
            raise ValueError("Invalid norm_ratio")
        self.n_post_ln = int(n_post_ln)

        self.embedding = EmbedSequence(self.d_model, config['pairwise_dimension'])

        # EncoderLayer's signature is (d_model, n_heads, config, dropout, norm_type);
        # both layer kinds must pass config positionally before dropout.
        layers = []
        for _ in range(self.n_post_ln):
            layers.append(EncoderLayer(self.d_model, config['n_heads'], config, config['dropout'], 'post_ln'))
        for _ in range(self.n_layers - self.n_post_ln):
            layers.append(EncoderLayer(self.d_model, config['n_heads'], config, config['dropout'], 'pre_ln'))
        count_post_ln = self.n_post_ln
        count_pre_ln = len(layers) - self.n_post_ln

        print(f"{count_post_ln} post_ln layers and {count_pre_ln} pre_ln layers")

        self.encoder_layers = nn.ModuleList(layers)

        self.final_linear = nn.Linear(self.d_model, 3)

        print(f"{self.n_layers} layers of encoder constructed")

    def forward(self, src, src_mask=None):
        """
        Args:
            src: (B, L) token ids.
            src_mask: optional (B, L) residue mask (0 = padding); all ones when None.
        Returns:
            Predicted coordinates: (L, 3) when B == 1, otherwise (B, L, 3). Each
            structure is translated so its first NaN-free residue is at the origin.
        """
        if src_mask is None:
            src_mask = torch.ones_like(src)
        src, pairwise_feature = self.embedding(src)  # (B, L, d_model), (B, L, L, P)

        for layer in self.encoder_layers:
            src, pairwise_feature = layer(src, pairwise_feature, src_mask)

        coords = self.final_linear(src)  # (B, L, 3)
        if coords.shape[0] == 1:
            # drop only the batch axis (a bare squeeze() would also drop L when L == 1)
            return self._center_on_first_valid(coords[0])
        return torch.stack([self._center_on_first_valid(c) for c in coords])

    @staticmethod
    def _center_on_first_valid(coords):
        """Translate (L, 3) coords so the first NaN-free residue is at the origin (the last residue if none is)."""
        if coords.shape[0] == 0:
            return coords
        complete = (~torch.isnan(coords)).all(dim=-1)
        anchor = int(complete.nonzero()[0, 0]) if complete.any() else coords.shape[0] - 1
        return coords - coords[anchor]



In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the model and move its parameters to `device` (Module.to returns the module itself).
model = WuSubSol(config).to(device)
model  # last expression: display the module summary


48 post_ln layers and 0 pre_ln layers
48 layers of encoder constructed


WuSubSol(
  (embedding): EmbedSequence(
    (embedder): Embedding(4, 256)
    (outer_product_mean): OuterProductMean(
      (proj_down1): Linear(in_features=256, out_features=32, bias=True)
      (proj_down2): Linear(in_features=1024, out_features=128, bias=True)
    )
    (pos_encoder): relpos(
      (linear): Linear(in_features=33, out_features=128, bias=True)
    )
  )
  (encoder_layers): ModuleList(
    (0-47): 48 x EncoderLayer(
      (outer_product_mean): OuterProductMean(
        (proj_down1): Linear(in_features=256, out_features=32, bias=True)
        (proj_down2): Linear(in_features=1024, out_features=128, bias=True)
      )
      (norm_0): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (norm_1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (norm_2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (attention): MultiheadAtt(
        (dropout_1): Dropout(p=0.1, inplace=False)
        (dropout_2): Dropout(p=0.1, inplace=False)
        (w_q)

## Training Loop

### Define Loss Function

In [None]:
def calculate_distance_matrix(X,Y,epsilon=1e-4):
    """Euclidean distances between the rows of X and the rows of Y.

    For X of shape (N, D) and Y of shape (M, D), the result is (N, M).
    `epsilon` is added to every squared coordinate difference before summing
    (D * epsilon in total), which keeps sqrt differentiable at zero distance.
    NaN coordinates propagate to NaN distances.
    """
    diff = X.unsqueeze(1) - Y.unsqueeze(0)
    return torch.sqrt((diff.square() + epsilon).sum(dim=-1))

def dRMAE(pred_x,
          pred_y,
          gt_x,
          gt_y,
          epsilon=1e-4,Z=10,d_clamp=None):
    """Distance-matrix mean absolute error, divided by Z.

    Compares pairwise distances of predicted vs ground-truth coordinates. Pairs
    whose ground-truth distance is NaN, and the main diagonal, are ignored.

    Args:
        pred_x, pred_y: predicted coordinates (usually the same tensor).
        gt_x, gt_y: matching ground-truth coordinates; NaN marks missing residues.
        epsilon: smoothing term forwarded to calculate_distance_matrix.
        Z: normalisation constant the mean error is divided by.
        d_clamp: when given, each absolute distance error is capped at this value.
    Returns:
        Scalar tensor.
    """
    pred_dm = calculate_distance_matrix(pred_x, pred_y, epsilon=epsilon)
    gt_dm = calculate_distance_matrix(gt_x, gt_y, epsilon=epsilon)

    mask = ~torch.isnan(gt_dm)
    # Exclude self-distances; build the eye on the mask's device.
    mask = mask & ~torch.eye(mask.shape[0], mask.shape[1], dtype=torch.bool, device=mask.device)

    abs_err = torch.abs(pred_dm[mask] - gt_dm[mask])
    if d_clamp is not None:
        abs_err = abs_err.clamp(max=d_clamp)

    return abs_err.mean() / Z

def align_svd_mae(input, target, Z=10):
    """
    Kabsch-align `input` onto `target` and return their mean absolute error / Z.

    Rows whose target contains a NaN are dropped first. The optimal proper
    rotation (no reflection) comes from an SVD of the cross-covariance. It is
    treated as a constant, as are both centroids, so gradients reach `input`
    only through the centred coordinates.

    Args:
        input (torch.Tensor): Nx3 predicted coordinates.
        target (torch.Tensor): Nx3 reference coordinates; NaN marks missing residues.
        Z (float): normalisation constant.

    Returns:
        torch.Tensor: scalar mean |aligned_input - target| / Z over the kept rows.
    """
    assert input.shape == target.shape, "Input and target must have the same shape"

    mask = ~torch.isnan(target.sum(-1))
    input = input[mask]
    target = target[mask]

    centroid_input = input.mean(dim=0, keepdim=True)
    centroid_target = target.mean(dim=0, keepdim=True)

    input_centered = input - centroid_input.detach()
    target_centered = target - centroid_target

    with torch.no_grad():
        cov_matrix = input_centered.T @ target_centered
        if cov_matrix.dtype not in (torch.float32, torch.float64):
            cov_matrix = cov_matrix.float()  # SVD/det have no half-precision kernels
        U, _, Vh = torch.linalg.svd(cov_matrix)
        V = Vh.mT  # linalg.svd returns V^T; Kabsch needs V
        R = V @ U.T
        if torch.linalg.det(R) < 0:
            # Flip the singular vector of the smallest singular value: a COLUMN of V.
            V = torch.cat([V[:, :-1], -V[:, -1:]], dim=1)
            R = V @ U.T
        R = R.to(input_centered.dtype)

    aligned_input = input_centered @ R.T + centroid_target.detach()
    return torch.abs(aligned_input - target).mean() / Z

def align_svd_rmsd(input, target):
    """
    Rigidly superimpose ``input`` onto ``target`` (Kabsch / SVD Procrustes) and
    return the root-mean-square coordinate error.

    Rows of ``target`` that contain any NaN are treated as unresolved and are
    dropped, together with the matching rows of ``input``, before alignment.
    The mean is taken over all kept coordinate values (N_kept x 3 elements),
    not over per-atom squared distances.

    The optimal rotation is solved on detached tensors and used as a constant,
    and the input centroid is detached too, so gradients reach ``input`` only
    through its centred coordinates.

    Args:
        input (torch.Tensor): Nx3 predicted coordinates.
        target (torch.Tensor): Nx3 reference coordinates (may contain NaN rows).

    Returns:
        torch.Tensor: scalar ``sqrt(mean((aligned_input - target) ** 2))`` over
        the kept rows (NaN if every row of ``target`` is masked).

    Raises:
        AssertionError: if ``input`` and ``target`` have different shapes.
    """
    assert input.shape == target.shape, "Input and target must have the same shape"

    # Drop residues whose reference coordinates are missing (NaN).
    mask = ~torch.isnan(target.sum(-1))
    input = input[mask]
    target = target[mask]

    centroid_input = input.mean(dim=0, keepdim=True)
    centroid_target = target.mean(dim=0, keepdim=True)

    input_centered = input - centroid_input.detach()
    target_centered = target - centroid_target

    # The rotation is treated as a constant, so solve for it outside autograd.
    # SVD has no half-precision kernels, hence the upcast to at least float32.
    with torch.no_grad():
        work_dtype = torch.promote_types(input_centered.dtype, torch.float32)
        cov_matrix = input_centered.to(work_dtype).T @ target_centered.to(work_dtype)
        # torch.linalg.svd returns Vh (= V^T); the Kabsch rotation is V @ U^T.
        U, S, Vh = torch.linalg.svd(cov_matrix)
        V = Vh.T.clone()
        # Proper rotation only: if V @ U^T is a reflection, negate the column of V
        # paired with the smallest singular value (the last one, values are sorted
        # in descending order).
        if torch.det(V @ U.T) < 0:
            V[:, -1] *= -1
        R = (V @ U.T).to(input_centered.dtype)

    # Rotate the centred input and place it on the target centroid.
    aligned_input = input_centered @ R.T + centroid_target.detach()

    return torch.square(aligned_input - target).mean().sqrt()

def compute_lddt(ground_truth_atoms, predicted_atoms, cutoff=30.0, thresholds=(1.0, 2.0, 4.0, 8.0)):
    """
    Compute an lDDT score between ground-truth and predicted coordinates.

    For every atom i, the reference pairs (i, j) with 0 < d_gt(i, j) < cutoff are
    scored. For each threshold t, the fraction of those pairs with
    |d_gt(i, j) - d_pred(i, j)| < t is computed. Pairs whose distance difference
    is NaN (missing coordinates) are ignored. Per-atom fractions are averaged over
    the atoms that have at least one scorable pair, so atoms with missing
    ground truth do not drag the score down. The result is the mean over
    thresholds.

    Parameters:
        ground_truth_atoms (array-like): Nx3 ground-truth coordinates (NaN rows allowed).
        predicted_atoms (array-like): Nx3 predicted coordinates.
        cutoff (float): Reference-distance cutoff in Ångströms. Default is 30 Å.
        thresholds (sequence of float): Tolerances in Ångströms. Default is (1, 2, 4, 8).

    Returns:
        float: The lDDT score in [0, 1]; 0.0 if no atom has a scorable pair
        (including empty input).
    """
    gt = np.asarray(ground_truth_atoms, dtype=float)
    pred = np.asarray(predicted_atoms, dtype=float)
    thr = np.asarray(thresholds, dtype=float)

    fractions = np.zeros(len(thr))
    n_scored = 0  # atoms that contributed at least one pair

    for i in range(gt.shape[0]):
        gt_distances = np.linalg.norm(gt[i] - gt, axis=1)
        pred_distances = np.linalg.norm(pred[i] - pred, axis=1)

        # Neighbours within the cutoff, excluding the atom itself (distance 0).
        with np.errstate(invalid='ignore'):
            in_range = (gt_distances > 0) & (gt_distances < cutoff)

        distance_diff = np.abs(gt_distances[in_range] - pred_distances[in_range])
        distance_diff = distance_diff[~np.isnan(distance_diff)]
        if distance_diff.size == 0:
            continue

        # Fraction of preserved distances for every threshold at once.
        fractions += (distance_diff[:, None] < thr[None, :]).mean(axis=0)
        n_scored += 1

    if n_scored == 0:
        return 0.0

    return float(np.mean(fractions / n_scored))

### Training Setup

In [None]:
from tqdm import tqdm

epochs = config['epochs']
# Number of epochs trained at constant LR before cosine annealing starts (the training
# loop only calls `schedule.step()` once `epoch + 1 > cos_epoch`). 0 = anneal from the start.
cos_epoch = config.get('cos_epoch', 0)

best_loss = np.inf
# No weight decay, following AlphaFold; the value comes from config (0.0 there).
optimizer = torch.optim.Adam(
    model.parameters(),
    weight_decay=config['weight_decay'],
    lr=config['learning_rate'],
)

# Gradient-accumulation factor: the training loop divides each loss by `batch_size` and
# calls `optimizer.step()` every `batch_size` samples (and once more for a trailing
# partial group at the end of the epoch).
batch_size = config['gradient_accumulation_steps']

# NOTE(review): not referenced by the training loop below.
criterion = torch.nn.CrossEntropyLoss(reduction='none')

# One scheduler step per optimizer step. Because a trailing partial group also triggers
# a step, an epoch has ceil(len(train_loader) / batch_size) steps; flooring would make
# T_max too short, and CosineAnnealingLR would start raising the LR again before the
# end of training.
optimizer_steps_per_epoch = math.ceil(len(train_loader) / batch_size)
schedule = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=(epochs - cos_epoch) * optimizer_steps_per_epoch
)


### Train

In [None]:
version_name = 'triangle_test_set_1'
save_dir = '/content/drive/MyDrive/RNA/WuSubSol/Save_Data'

# Unresolved residues are NaN rows in batch['xyz']. dRMAE and align_svd_mae mask NaNs
# themselves, so ground truth is passed through unchanged. Zero-filling it would teach the
# model to put missing residues at the origin. A sample still needs at least two resolved
# residues, otherwise the pairwise-distance term averages over nothing (NaN).
MIN_RESOLVED = 2


def n_resolved(xyz):
    """Return how many rows (residues) of `xyz` contain no NaN coordinate."""
    return int((~torch.isnan(xyz).any(dim=-1)).sum())


best_val_loss = float('inf')
loss_records = []  # one dict per epoch; converted to a DataFrame once at the end

for epoch in range(epochs):
    # ---------------- training ----------------
    model.train()
    tbar = tqdm(train_loader)
    total_loss = 0.0
    n_train = 0
    train_loss = float('nan')

    for idx, batch in enumerate(tbar):
        sequence = batch['sequence'].cuda()
        gt_xyz = batch['xyz'].squeeze().cuda()

        if n_resolved(gt_xyz) >= MIN_RESOLVED:
            pred_xyz = model(sequence).squeeze()
            if epoch == epochs - 1:
                print(pred_xyz)
                print(gt_xyz)

            loss = dRMAE(pred_xyz, pred_xyz, gt_xyz, gt_xyz) + align_svd_mae(pred_xyz, gt_xyz)
            if torch.isnan(loss):
                raise RuntimeError(f"NaN training loss at epoch {epoch + 1}, sample {idx}")

            try:
                (loss / batch_size).backward()
            except RuntimeError as err:
                # Best effort: report the failing sample and keep training.
                tqdm.write(f"backward failed on sample {idx} (gt shape {tuple(gt_xyz.shape)}): {err}")

            # .item() takes a plain float, so the autograd graph is not kept alive all epoch.
            total_loss += loss.item()
            n_train += 1
            train_loss = total_loss / n_train
        else:
            tqdm.write(f"skipping train sample {idx}: fewer than {MIN_RESOLVED} resolved residues")

        # Gradient accumulation: step every `batch_size` samples and on the last sample.
        if (idx + 1) % batch_size == 0 or idx + 1 == len(tbar):
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
            optimizer.step()
            optimizer.zero_grad()
            if (epoch + 1) > cos_epoch:
                schedule.step()

        tbar.set_description(f"Epoch {epoch + 1} Loss: {train_loss}")

    # ---------------- validation ----------------
    model.eval()
    tbar = tqdm(val_loader)
    val_preds = []
    val_total = 0.0
    n_val = 0

    with torch.no_grad():
        for idx, batch in enumerate(tbar):
            sequence = batch['sequence'].cuda()
            gt_xyz = batch['xyz'].squeeze().cuda()

            pred_xyz = model(sequence).squeeze()
            if epoch == epochs - 1:
                print(pred_xyz)
                print(gt_xyz)
            val_preds.append([gt_xyz.cpu().numpy(), pred_xyz.cpu().numpy()])

            if n_resolved(gt_xyz) < MIN_RESOLVED:
                continue
            loss = dRMAE(pred_xyz, pred_xyz, gt_xyz, gt_xyz) + align_svd_mae(pred_xyz, gt_xyz)
            val_total += loss.item()
            n_val += 1

    val_loss = val_total / n_val if n_val else float('nan')
    print(f"val loss: {val_loss}")

    loss_records.append({'epoch': epoch, 'train_loss': train_loss, 'val_loss': val_loss})

    # Keep the checkpoint with the lowest validation loss seen so far.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_preds = val_preds
        torch.save(model.state_dict(), os.path.join(save_dir, f'{version_name}_{epochs}_best.pt'))

torch.save(model.state_dict(), os.path.join(save_dir, f'{version_name}_{epochs}_last.pt'))

# Per-epoch losses, built once (no concat-in-a-loop); read back by the plotting cell.
loss_df = pd.DataFrame(loss_records, columns=['epoch', 'train_loss', 'val_loss'])
loss_df.to_csv(os.path.join(save_dir, f'{version_name}_{epochs}_loss.csv'), index=False)

print('save complete')

Epoch 1 Loss: 1.7904695272445679: 100%|██████████| 1/1 [00:01<00:00,  1.57s/it]
100%|██████████| 1/1 [00:00<00:00,  3.11it/s]

The behavior of DataFrame concatenation with empty or all-NA entries is deprecated. In a future version, this will no longer exclude empty or all-NA columns when determining the result dtypes. To retain the old behavior, exclude the relevant entries before the concat operation.



val loss: 3.4629976749420166


Epoch 2 Loss: 1.8015655279159546: 100%|██████████| 1/1 [00:00<00:00,  1.16it/s]
100%|██████████| 1/1 [00:00<00:00,  4.10it/s]


val loss: 3.460308074951172


Epoch 3 Loss: 1.8047804832458496: 100%|██████████| 1/1 [00:00<00:00,  1.18it/s]
100%|██████████| 1/1 [00:00<00:00,  3.97it/s]


val loss: 3.45858097076416


Epoch 4 Loss: 1.7915139198303223: 100%|██████████| 1/1 [00:00<00:00,  1.17it/s]
100%|██████████| 1/1 [00:00<00:00,  4.35it/s]


val loss: 3.456367015838623


Epoch 5 Loss: 1.7953580617904663: 100%|██████████| 1/1 [00:00<00:00,  1.16it/s]
100%|██████████| 1/1 [00:00<00:00,  3.67it/s]


val loss: 3.4538633823394775


Epoch 6 Loss: 1.7889316082000732: 100%|██████████| 1/1 [00:01<00:00,  1.27s/it]
100%|██████████| 1/1 [00:00<00:00,  2.97it/s]


val loss: 3.451791524887085


Epoch 7 Loss: 1.7997839450836182: 100%|██████████| 1/1 [00:00<00:00,  1.18it/s]
100%|██████████| 1/1 [00:00<00:00,  4.21it/s]


val loss: 3.45019268989563


Epoch 8 Loss: 1.7939746379852295: 100%|██████████| 1/1 [00:01<00:00,  1.33s/it]
100%|██████████| 1/1 [00:00<00:00,  4.16it/s]


val loss: 3.449519634246826


Epoch 9 Loss: 1.7883073091506958: 100%|██████████| 1/1 [00:00<00:00,  1.07it/s]
100%|██████████| 1/1 [00:00<00:00,  2.90it/s]


val loss: 3.4495856761932373


Epoch 10 Loss: 1.7955527305603027: 100%|██████████| 1/1 [00:01<00:00,  1.14s/it]
100%|██████████| 1/1 [00:00<00:00,  3.23it/s]


val loss: 3.4504454135894775


Epoch 11 Loss: 1.7908759117126465: 100%|██████████| 1/1 [00:01<00:00,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00,  3.87it/s]


val loss: 3.450075626373291


Epoch 12 Loss: 1.7847557067871094: 100%|██████████| 1/1 [00:00<00:00,  1.22it/s]
100%|██████████| 1/1 [00:00<00:00,  4.03it/s]


val loss: 3.448533535003662


Epoch 13 Loss: 1.7719858884811401: 100%|██████████| 1/1 [00:00<00:00,  1.17it/s]
100%|██████████| 1/1 [00:00<00:00,  4.28it/s]


val loss: 3.44669246673584


Epoch 14 Loss: 1.7680602073669434: 100%|██████████| 1/1 [00:00<00:00,  1.17it/s]
100%|██████████| 1/1 [00:00<00:00,  4.33it/s]


val loss: 3.445756435394287


Epoch 15 Loss: 1.7697553634643555: 100%|██████████| 1/1 [00:01<00:00,  1.10s/it]
100%|██████████| 1/1 [00:00<00:00,  3.05it/s]


val loss: 3.4444780349731445


Epoch 16 Loss: 1.7653493881225586: 100%|██████████| 1/1 [00:00<00:00,  1.16it/s]
100%|██████████| 1/1 [00:00<00:00,  3.67it/s]


val loss: 3.4435269832611084


Epoch 17 Loss: 1.7562270164489746: 100%|██████████| 1/1 [00:00<00:00,  1.16it/s]
100%|██████████| 1/1 [00:00<00:00,  4.09it/s]


val loss: 3.4427361488342285


Epoch 18 Loss: 1.746379017829895: 100%|██████████| 1/1 [00:00<00:00,  1.16it/s]
100%|██████████| 1/1 [00:00<00:00,  3.52it/s]


val loss: 3.4402456283569336


Epoch 19 Loss: 1.7865328788757324: 100%|██████████| 1/1 [00:01<00:00,  1.07s/it]
100%|██████████| 1/1 [00:00<00:00,  3.06it/s]


val loss: 3.437685012817383


Epoch 20 Loss: 1.7595462799072266: 100%|██████████| 1/1 [00:01<00:00,  1.04s/it]
100%|██████████| 1/1 [00:00<00:00,  3.96it/s]


val loss: 3.43430495262146


Epoch 21 Loss: 1.7659273147583008: 100%|██████████| 1/1 [00:00<00:00,  1.16it/s]
100%|██████████| 1/1 [00:00<00:00,  4.08it/s]


val loss: 3.430065870285034


Epoch 22 Loss: 1.7695096731185913: 100%|██████████| 1/1 [00:01<00:00,  1.18s/it]
100%|██████████| 1/1 [00:00<00:00,  3.08it/s]


val loss: 3.4263079166412354


Epoch 23 Loss: 1.7540922164916992: 100%|██████████| 1/1 [00:00<00:00,  1.18it/s]
100%|██████████| 1/1 [00:00<00:00,  4.15it/s]


val loss: 3.422123908996582


Epoch 24 Loss: 1.7473466396331787: 100%|██████████| 1/1 [00:00<00:00,  1.18it/s]
100%|██████████| 1/1 [00:00<00:00,  4.36it/s]


val loss: 3.4174444675445557


Epoch 25 Loss: 1.7335705757141113: 100%|██████████| 1/1 [00:00<00:00,  1.11it/s]
100%|██████████| 1/1 [00:00<00:00,  3.90it/s]


val loss: 3.413233995437622


Epoch 26 Loss: 1.7469279766082764: 100%|██████████| 1/1 [00:00<00:00,  1.22it/s]
100%|██████████| 1/1 [00:00<00:00,  3.97it/s]


val loss: 3.4089882373809814


Epoch 27 Loss: 1.7340044975280762: 100%|██████████| 1/1 [00:00<00:00,  1.20it/s]
100%|██████████| 1/1 [00:00<00:00,  4.35it/s]


val loss: 3.404493808746338


Epoch 28 Loss: 1.669268250465393: 100%|██████████| 1/1 [00:00<00:00,  1.18it/s]
100%|██████████| 1/1 [00:00<00:00,  3.82it/s]


val loss: 3.4000484943389893


Epoch 29 Loss: 1.6987203359603882: 100%|██████████| 1/1 [00:00<00:00,  1.20it/s]
100%|██████████| 1/1 [00:00<00:00,  4.30it/s]


val loss: 3.39547061920166


Epoch 30 Loss: 1.6879661083221436: 100%|██████████| 1/1 [00:00<00:00,  1.14it/s]
100%|██████████| 1/1 [00:00<00:00,  4.32it/s]


val loss: 3.390557050704956


Epoch 31 Loss: 1.6977711915969849: 100%|██████████| 1/1 [00:00<00:00,  1.20it/s]
100%|██████████| 1/1 [00:00<00:00,  3.98it/s]


val loss: 3.3861632347106934


Epoch 32 Loss: 1.6965217590332031: 100%|██████████| 1/1 [00:00<00:00,  1.20it/s]
100%|██████████| 1/1 [00:00<00:00,  3.34it/s]


val loss: 3.3830347061157227


Epoch 33 Loss: 1.641868233680725: 100%|██████████| 1/1 [00:01<00:00,  1.25s/it]
100%|██████████| 1/1 [00:00<00:00,  3.17it/s]


val loss: 3.379650115966797


Epoch 34 Loss: 1.6933892965316772: 100%|██████████| 1/1 [00:00<00:00,  1.19it/s]
100%|██████████| 1/1 [00:00<00:00,  3.83it/s]


val loss: 3.3761074542999268


Epoch 35 Loss: 1.665663242340088: 100%|██████████| 1/1 [00:00<00:00,  1.19it/s]
100%|██████████| 1/1 [00:00<00:00,  4.36it/s]


val loss: 3.372830390930176


Epoch 36 Loss: 1.645926594734192: 100%|██████████| 1/1 [00:00<00:00,  1.15it/s]
100%|██████████| 1/1 [00:00<00:00,  4.23it/s]


val loss: 3.3690366744995117


Epoch 37 Loss: 1.5995373725891113: 100%|██████████| 1/1 [00:00<00:00,  1.21it/s]
100%|██████████| 1/1 [00:00<00:00,  4.25it/s]


val loss: 3.366209030151367


Epoch 38 Loss: 1.6560109853744507: 100%|██████████| 1/1 [00:01<00:00,  1.13s/it]
100%|██████████| 1/1 [00:00<00:00,  3.14it/s]


val loss: 3.3636980056762695


Epoch 39 Loss: 1.5940635204315186: 100%|██████████| 1/1 [00:00<00:00,  1.10it/s]
100%|██████████| 1/1 [00:00<00:00,  4.28it/s]


val loss: 3.3615775108337402


Epoch 40 Loss: 1.6197618246078491: 100%|██████████| 1/1 [00:00<00:00,  1.20it/s]
100%|██████████| 1/1 [00:00<00:00,  3.94it/s]


val loss: 3.3601255416870117


Epoch 41 Loss: 1.6156129837036133: 100%|██████████| 1/1 [00:00<00:00,  1.16it/s]
100%|██████████| 1/1 [00:00<00:00,  4.16it/s]


val loss: 3.3590869903564453


Epoch 42 Loss: 1.6110918521881104: 100%|██████████| 1/1 [00:00<00:00,  1.06it/s]
100%|██████████| 1/1 [00:00<00:00,  3.16it/s]


val loss: 3.358269214630127


Epoch 43 Loss: 1.5965704917907715: 100%|██████████| 1/1 [00:01<00:00,  1.22s/it]
100%|██████████| 1/1 [00:00<00:00,  4.27it/s]


val loss: 3.3575634956359863


Epoch 44 Loss: 1.6149322986602783: 100%|██████████| 1/1 [00:00<00:00,  1.17it/s]
100%|██████████| 1/1 [00:00<00:00,  4.20it/s]


val loss: 3.357110023498535


Epoch 45 Loss: 1.6233247518539429: 100%|██████████| 1/1 [00:01<00:00,  1.19s/it]
100%|██████████| 1/1 [00:00<00:00,  2.90it/s]


val loss: 3.3567376136779785


Epoch 46 Loss: 1.6374107599258423: 100%|██████████| 1/1 [00:00<00:00,  1.13it/s]
100%|██████████| 1/1 [00:00<00:00,  3.99it/s]


val loss: 3.356426239013672


Epoch 47 Loss: 1.600538730621338: 100%|██████████| 1/1 [00:00<00:00,  1.08it/s]
100%|██████████| 1/1 [00:00<00:00,  4.32it/s]


val loss: 3.3562464714050293


Epoch 48 Loss: 1.5816577672958374: 100%|██████████| 1/1 [00:00<00:00,  1.19it/s]
100%|██████████| 1/1 [00:00<00:00,  4.19it/s]


val loss: 3.3561272621154785


Epoch 49 Loss: 1.5690792798995972: 100%|██████████| 1/1 [00:00<00:00,  1.16it/s]
100%|██████████| 1/1 [00:00<00:00,  4.31it/s]


val loss: 3.3560667037963867


  0%|          | 0/1 [00:00<?, ?it/s]

tensor([[ 0.0000,  0.0000,  0.0000],
        [-0.0203, -0.5614, -0.1539],
        [-0.1074,  0.2944,  0.8067],
        [ 1.3420,  1.6670,  1.6438],
        [ 0.0341, -0.7036,  0.5366],
        [ 1.2859,  0.9763,  2.1010],
        [-0.3677, -0.4877,  0.8381],
        [ 1.7884,  1.8754,  2.4167],
        [ 1.4266,  1.5372,  2.5281],
        [ 2.3883,  1.8390,  2.1705],
        [ 1.2334,  1.1832,  1.9712],
        [ 2.6285,  1.8483,  1.9626]], device='cuda:0',
       grad_fn=<SqueezeBackward0>)
tensor([[  0.0000,   0.0000,   0.0000],
        [ -1.0760,  -3.3350,  -3.4670],
        [ -4.4130,  -6.6660,  -7.6410],
        [ -7.0260,  -0.6310,  -9.7380],
        [ -3.1140,   3.6950, -11.0880],
        [ -2.5960,   8.3430, -14.4300],
        [  1.2760,  10.8900, -15.9760],
        [  7.1280,  12.5880, -19.2280],
        [  8.5000,   9.1200, -14.2150],
        [  4.5950,   6.0570,  -8.3870],
        [  6.2800,   5.2940,  -2.2990],
        [  6.1480,   2.7090,   3.4880]], device='cuda:0')


Epoch 50 Loss: 1.597622275352478: 100%|██████████| 1/1 [00:00<00:00,  1.18it/s]
100%|██████████| 1/1 [00:00<00:00,  3.95it/s]

tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00],
        [-1.0294e-03, -3.4233e-03, -4.1739e-03],
        [ 1.7007e+00,  1.9879e+00,  8.8714e-01],
        [-3.3333e-02,  1.8321e-01, -3.3374e-01],
        [-3.0462e-02,  1.7849e-01, -3.2829e-01],
        [ 1.5612e+00,  1.3867e+00,  9.6201e-01],
        [ 1.7058e+00,  1.9888e+00,  8.8780e-01],
        [-3.1648e-02,  1.8571e-01, -3.3299e-01],
        [ 1.5685e+00,  1.3868e+00,  9.6611e-01],
        [ 1.7106e+00,  1.9841e+00,  8.9008e-01],
        [-3.1808e-02,  1.8268e-01, -3.3836e-01],
        [ 1.6387e-03, -4.4628e-03, -2.2395e-03],
        [ 1.7140e+00,  1.9871e+00,  8.8825e-01],
        [ 5.0555e-03, -1.1941e-03,  3.2316e-03],
        [-3.4204e-02,  1.8508e-01, -3.2536e-01],
        [ 7.8284e-03, -4.6958e-03, -6.6966e-03],
        [-3.4074e-02,  1.8470e-01, -3.3213e-01],
        [-3.6754e-02,  1.8024e-01, -3.3602e-01],
        [ 1.7098e+00,  1.9694e+00,  8.9321e-01],
        [ 1.5626e+00,  1.3807e+00,  9.5211e-01],
        [ 1.5687e+00




save complete


# Plot Loss Graph

In [None]:
import pandas as pd
import plotly.graph_objects as go
epochs = config['epochs']
# Load the per-epoch losses written by the training cell.
df = pd.read_csv(f'/content/drive/MyDrive/RNA/WuSubSol/Save_Data/{version_name}_{epochs}_loss.csv')

# Use a separate name for the epoch column so the global `epochs` (an int used by the
# training cells for `range(epochs)` and checkpoint file names) is not overwritten.
epoch_axis = df['epoch']
train_loss = df['train_loss']
val_loss = df['val_loss']

# Location of each curve's minimum.
min_train_loss = train_loss.min()
min_train_epoch = epoch_axis[train_loss.idxmin()]
min_val_loss = val_loss.min()
min_val_epoch = epoch_axis[val_loss.idxmin()]

# Build the figure.
fig = go.Figure()

# Train loss curve.
fig.add_trace(go.Scatter(
    x=epoch_axis,
    y=train_loss,
    mode='lines+markers',
    name='Train Loss',
    line=dict(color='blue'),
    hovertemplate='Epoch: %{x}<br>Train Loss: %{y:.4f}<extra></extra>'
))

# Validation loss curve.
fig.add_trace(go.Scatter(
    x=epoch_axis,
    y=val_loss,
    mode='lines+markers',
    name='Val Loss',
    line=dict(color='red'),
    hovertemplate='Epoch: %{x}<br>Val Loss: %{y:.4f}<extra></extra>'
))

# Mark the minima.
fig.add_trace(go.Scatter(
    x=[min_val_epoch],
    y=[min_val_loss],
    mode='markers+text',
    marker=dict(color='green', size=10),
    text=[f"Min Val Loss: {min_val_loss:.4f} (Epoch {min_val_epoch})"],
    textposition='top center',
    name='Min Val',
))

fig.add_trace(go.Scatter(
    x=[min_train_epoch],
    y=[min_train_loss],
    mode='markers+text',
    marker=dict(color='green', size=10),
    text=[f"Min Train Loss: {min_train_loss:.4f} (Epoch {min_train_epoch})"],
    textposition='top center',
    name='Min Train',
))

# Layout.
fig.update_layout(
    title=f"{version_name} - Train & Validation Loss : lr={config['learning_rate']}_nlayers={config['n_layers']}",
    xaxis_title="Epoch",
    yaxis_title="Loss",
    hovermode='x unified',
    template='plotly_white',
    legend=dict(x=0.01, y=0.99)
)

fig.show()
