In [41]:
import json
import torch
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GATv2Conv
import numpy as np
import pandas as pd

In [4]:
# Parse the training file line by line: every line is decoded as one JSON record.
train_data = []
with open('../data/stanford-covid-vaccine/train.json', 'r', encoding='utf-8') as file:
    train_data.extend(json.loads(raw_line.strip()) for raw_line in file)

In [42]:
# Display the first parsed training record to inspect which fields it carries.
train_data[0]

{'index': 0,
 'id': 'id_001f94081',
 'sequence': 'GGAAAAGCUCUAAUAACAGGAGACUAGGACUACGUAUUUCUAGGUAACUGGAAUAACCCAUACCAGCAGUUAGAGUUCGCUCUAACAAAAGAAACAACAACAACAAC',
 'structure': '.....((((((.......)))).)).((.....((..((((((....))))))..)).....))....(((((((....))))))).....................',
 'predicted_loop_type': 'EEEEESSSSSSHHHHHHHSSSSBSSXSSIIIIISSIISSSSSSHHHHSSSSSSIISSIIIIISSXXXXSSSSSSSHHHHSSSSSSSEEEEEEEEEEEEEEEEEEEEE',
 'signal_to_noise': 6.894,
 'SN_filter': 1.0,
 'seq_length': 107,
 'seq_scored': 68,
 'reactivity_error': [0.1359,
  0.207,
  0.1633,
  0.1452,
  0.1314,
  0.105,
  0.0821,
  0.0964,
  0.0756,
  0.1087,
  0.1377,
  0.1544,
  0.1622,
  0.1388,
  0.1284,
  0.1009,
  0.0941,
  0.0564,
  0.0417,
  0.0596,
  0.0482,
  0.1041,
  0.0942,
  0.052,
  0.0583,
  0.0403,
  0.0491,
  0.1003,
  0.0525,
  0.081,
  0.1103,
  0.0707,
  0.0797,
  0.0997,
  0.0968,
  0.0939,
  0.0931,
  0.0604,
  0.0427,
  0.0331,
  0.0412,
  0.0286,
  0.0415,
  0.0394,
  0.0636,
  0.0816,
  0.0474,
  0.0295,

In [43]:
# Reduce every training record to the fields read later in the notebook:
# identifier, sequence, structure, the three per-position value lists and seq_scored.
train_dataset = [
    {
        'id': record['id'],
        'sequence': record['sequence'],
        'structure': record['structure'],
        'reactivity': record['reactivity'],
        'deg_Mg_pH10': record['deg_Mg_pH10'],
        'deg_Mg_50C': record['deg_Mg_50C'],
        'seq_scored': record['seq_scored'],
    }
    for record in train_data
]
train_dataset[0]

{'id': 'id_001f94081',
 'sequence': 'GGAAAAGCUCUAAUAACAGGAGACUAGGACUACGUAUUUCUAGGUAACUGGAAUAACCCAUACCAGCAGUUAGAGUUCGCUCUAACAAAAGAAACAACAACAACAAC',
 'structure': '.....((((((.......)))).)).((.....((..((((((....))))))..)).....))....(((((((....))))))).....................',
 'reactivity': [0.3297,
  1.5693,
  1.1227,
  0.8686,
  0.7217,
  0.4384,
  0.256,
  0.3364,
  0.2168,
  0.3583,
  0.9541,
  1.4113,
  1.6911,
  1.2494,
  1.1895,
  0.6909,
  0.4736,
  0.1754,
  0.0582,
  0.2173,
  0.0785,
  0.8249,
  0.7638,
  0.1095,
  0.2568,
  0.0895,
  0.1576,
  0.7727,
  0.1573,
  0.5043,
  1.0444,
  0.4766,
  0.5588,
  0.9054,
  1.0125,
  1.0482,
  1.044,
  0.4522,
  0.211,
  0.0461,
  0.082,
  0.0643,
  0.1526,
  0.0894,
  0.5081,
  1.0745,
  0.3215,
  0.0716,
  0.0244,
  0.0123,
  0.1984,
  0.4961,
  1.0641,
  0.6394,
  0.6789,
  0.365,
  0.1741,
  0.1408,
  0.1646,
  0.5389,
  0.683,
  0.4273,
  0.0527,
  0.0693,
  0.1398,
  0.2937,
  0.2362,
  0.5731],
 'deg_Mg_pH10': [0.7556,
  2.983,
  0.2

In [44]:
# Load the private-test label table from CSV into a DataFrame and preview its first rows.
test_labels = pd.read_csv('../data/stanford-covid-vaccine/post_deadline_files/private_test_labels.csv')
test_labels.head()

Unnamed: 0,id,ID,sequence,structure,seqpos,reactivity,deg_Mg_pH10,deg_Mg_50C,errors,deg_pH10_Mg_errors,deg_50C_Mg_errors,S/N filter,predicted_loop_type,seq_scored,seq_length,cluster_id,n_neighbors,first_cluster_member,test_filter
0,id_40f52a81b,10207086,GGAAAUUUUCGCGGGACGGGCGGCCGGGCGGAGGCGGCGCGAGGGC...,.......(((((.((.((..(.(((..(((...((..((((....(...,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[0.6009, 1.3193, 1.5475, 0.5852, 1.566, 0.3387...","[0.5866, 1.4956, 1.3765, 0.5714, 2.9199, 0.925...","[1.2183, 1.718, 0.8737, 0.9644, 2.7502, 0.5862...","[0.4419, 0.4736, 0.4529, 0.30820000000000003, ...","[0.4333, 0.491, 0.4375, 0.3079, 0.5534, 0.3432...","[0.5922000000000001, 0.5873, 0.434700000000000...",1,EEEEEEESSSSSISSISSIISISSSIISSSIIISSIISSSSIIIIS...,92,130,694,2,1,1
1,id_59252b684,10207088,GGAAAUUUUCGCGGGACGGGCGGCAGGGCUGAGGUUUCGCGAGGGC...,........(((((((((...((((...))))..))))))))).......,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[0.713, 0.6368, 0.9615, 0.4066, 1.1178, 0.5775...","[0.7798, 0.9725, 0.8631, 0.5575, 3.2812, 1.241...","[0.8272, 0.9112, 1.3493, 1.0047, 2.1987, 1.499...","[0.3924, 0.3191, 0.3269, 0.2175, 0.3406, 0.259...","[0.4247, 0.3841, 0.34440000000000004, 0.2682, ...","[0.48960000000000004, 0.4184, 0.432, 0.3526, 0...",1,EEEEEEEESSSSSSSSSIIISSSSHHHSSSSIISSSSSSSSSMMMM...,92,130,694,2,0,1
2,id_ebf1148ee,10207093,GGAAAUUUUCGCGAGACAAGCGGCAGGGCUGAGAUUACGCGAGGGC...,........(((((.....(((......))).......)))))..((...,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[1.0044, 1.7519, 1.2299, 0.9241, 0.982, 1.1125...","[0.8368, 1.5918, 1.0052, 0.7231, 2.6103, 1.948...","[0.6066, 1.5895, 1.279, 0.8911, 2.9968, 1.4682...","[0.2933, 0.33140000000000003, 0.2746, 0.237500...","[0.26990000000000003, 0.3124, 0.2505, 0.2142, ...","[0.3024, 0.3673, 0.3183, 0.2696, 0.4048, 0.294...",1,EEEEEEEESSSSSIIIIISSSHHHHHHSSSIIIIIIISSSSSXXSS...,92,130,696,6,1,1
3,id_63c3b7d50,10207098,GGAAAUUUUCGCGAGACCAGCGGCAGGGCUGAGCUAACGCGAGGGC...,........(((((((..((((......))))..))..)))))((.(...,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[0.877, 1.3276, 1.0295, 0.7543, 1.2666, 1.0846...","[0.4452, 1.4232, 1.131, 0.6042, 3.3402, 2.014,...","[0.3878, 1.2893, 1.5437, 1.2499, 3.4329, 1.837...","[0.1713, 0.1758, 0.1399, 0.11220000000000001, ...","[0.15080000000000002, 0.1819, 0.14780000000000...","[0.177, 0.202, 0.1847, 0.1555, 0.2392000000000...",1,EEEEEEEESSSSSSSIISSSSHHHHHHSSSSIISSBBSSSSSSSIS...,92,130,696,6,0,0
4,id_a181978cc,10207103,GGAAAUUUUCGCGAGACCAGCGGCAGGGCUGAGCUAACGCGAGGGC...,........(((((((..((((......))))..))..))))).......,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[0.8414, 1.4938, 1.0377, 0.995, 1.4373, 1.218,...","[0.8237, 1.5692, 0.8926, 0.7978, 3.1381, 1.797...","[0.6292, 1.5747, 1.567, 1.2007, 3.254, 2.0014,...","[0.1192, 0.1288, 0.1018, 0.09290000000000001, ...","[0.1279, 0.14350000000000002, 0.1061, 0.0945, ...","[0.136, 0.1564, 0.1408, 0.11810000000000001, 0...",1,EEEEEEEESSSSSSSIISSSSHHHHHHSSSSIISSBBSSSSSXXXX...,92,130,696,6,0,0


In [45]:
def extract_base_pairs(dot_bracket):
    """
    Extract base pairs from a dot-bracket string.

    Every ')' is paired with the most recent still-open '('. Characters other
    than '(' and ')' are ignored, a ')' with no open '(' is skipped, and any
    '(' still open at the end of the string produces no pair.

    :param dot_bracket: secondary structure in dot-bracket notation
    :return: list of (open_index, close_index) tuples, in the order the
        closing brackets are encountered
    """
    open_positions = []
    pairs = []
    for pos, symbol in enumerate(dot_bracket):
        if symbol == '(':
            open_positions.append(pos)
            continue
        # Anything that is not a matchable closing bracket is skipped.
        if symbol != ')' or not open_positions:
            continue
        pairs.append((open_positions.pop(), pos))
    return pairs

# Example: three nested pairs enclosing three unpaired positions.
structure = "(((...)))"
base_pairs = extract_base_pairs(structure)
print("Пары оснований:", base_pairs)


Пары оснований: [(2, 6), (1, 7), (0, 8)]


In [56]:
def create_graph_with_targets(sequence, dot_bracket, targets, seq_scored):
    """
    Build a graph from a nucleotide sequence, its secondary structure and per-position targets.

    Nodes are nucleotides with one-hot features (column order A, U, C, G). Each pair of
    neighbouring positions is connected in both directions with edge feature [1, 0];
    each base pair returned by ``extract_base_pairs`` is connected in both directions
    with edge feature [0, 1].

    :param sequence: nucleotide string; only A, U, C and G are accepted
    :param dot_bracket: dot-bracket secondary structure, same length as ``sequence``
    :param targets: per-position target rows with 3 values each; at least ``seq_scored``
        rows are required and rows beyond ``seq_scored`` are ignored
    :param seq_scored: number of leading positions whose target values are known
    :return: ``Data`` object with ``x`` (num_nodes, 4), ``edge_index`` (2, num_edges),
        ``edge_attr`` (num_edges, 2), ``y`` (num_nodes, 3; zeros past ``seq_scored``)
        and a boolean ``mask`` (num_nodes,) that is True for the scored positions
    :raises ValueError: on invalid nucleotides, a structure/sequence length mismatch,
        ``seq_scored`` outside ``[0, len(sequence)]``, too few target rows, or a scored
        target row that does not hold exactly 3 values
    """
    num_nodes = len(sequence)  # seq_length
    num_targets = 3

    # Nodes: one-hot encoding of the nucleotide type
    nucleotide_map = {'A': [1, 0, 0, 0], 'U': [0, 1, 0, 0], 'C': [0, 0, 1, 0], 'G': [0, 0, 0, 1]}
    if any(nt not in nucleotide_map for nt in sequence):
        raise ValueError("Последовательность содержит некорректные символы. Ожидалось: A, U, C, G")
    # A structure of a different length would yield edges that point at missing nodes
    # (or silently drop pairs), so reject it up front.
    if len(dot_bracket) != num_nodes:
        raise ValueError("Длина dot_bracket должна совпадать с длиной последовательности")
    # reshape keeps the (num_nodes, 4) layout even for an empty sequence
    x = torch.tensor([nucleotide_map[nt] for nt in sequence], dtype=torch.float).reshape(-1, 4)

    # Edges: backbone neighbours + secondary structure
    edge_index = []
    edge_attr = []

    # Backbone links between consecutive positions (both directions)
    for i in range(num_nodes - 1):
        edge_index.extend([[i, i + 1], [i + 1, i]])
        edge_attr.extend([[1, 0], [1, 0]])

    # Secondary-structure links between paired bases (both directions)
    for i, j in extract_base_pairs(dot_bracket):
        edge_index.extend([[i, j], [j, i]])
        edge_attr.extend([[0, 1], [0, 1]])

    # Convert to tensors; the reshape keeps the expected rank when there are no edges
    edge_index = torch.tensor(edge_index, dtype=torch.long).reshape(-1, 2).t().contiguous()  # (2, num_edges)
    edge_attr = torch.tensor(edge_attr, dtype=torch.float).reshape(-1, 2)  # (num_edges, edge_dim)

    # Targets
    seq_scored = int(seq_scored)  # accept numpy integers coming from a DataFrame
    if seq_scored > num_nodes:
        raise ValueError("seq_scored не может быть больше длины последовательности (seq_length)")
    if seq_scored < 0:
        raise ValueError("seq_scored не может быть отрицательным")

    # Only the first seq_scored rows are used; anything after them is ignored.
    scored_targets = list(targets[:seq_scored])
    if len(scored_targets) < seq_scored:
        raise ValueError("Количество целевых значений меньше seq_scored")
    if any(len(row) != num_targets for row in scored_targets):
        raise ValueError("Каждая строка целевых значений должна содержать ровно 3 значения")

    y = torch.zeros((num_nodes, num_targets), dtype=torch.float)  # zeros for unscored nodes
    if seq_scored > 0:
        y[:seq_scored, :] = torch.tensor(scored_targets, dtype=torch.float)

    # Mask selecting the nodes whose targets are known
    mask = torch.zeros(num_nodes, dtype=torch.bool)
    mask[:seq_scored] = True  # only the first seq_scored nodes take part in the metric

    # Assemble the graph
    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, mask=mask)


In [47]:
class GNNModelNodePrediction(torch.nn.Module):
    """
    Two GATv2Conv layers followed by a linear head that emits one output row per node.

    The first convolution uses 4 heads with ``concat=True``, so the second one takes
    ``hidden_dim * 4`` input features; the second uses 4 heads with ``concat=False``
    and feeds ``hidden_dim`` features into the linear head.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, edge_dim):
        super().__init__()
        num_heads = 4
        self.conv1 = GATv2Conv(input_dim, hidden_dim, heads=num_heads, edge_dim=edge_dim, concat=True)
        self.conv2 = GATv2Conv(hidden_dim * num_heads, hidden_dim, heads=num_heads, edge_dim=edge_dim, concat=False)
        self.fc = torch.nn.Linear(hidden_dim, output_dim)  # the notebook passes output_dim = 3

    def forward(self, data):
        """Return per-node predictions of shape (num_nodes, output_dim) for ``data``."""
        # First attention layer
        hidden = self.conv1(data.x, data.edge_index, data.edge_attr).relu()
        # Second attention layer
        hidden = self.conv2(hidden, data.edge_index, data.edge_attr).relu()
        # Node-level prediction
        return self.fc(hidden)


In [48]:
def mcrmse(y_true, y_pred):
    """
    Compute MCRMSE (Mean Columnwise Root Mean Squared Error).

    :param y_true: ground-truth values (shape: num_nodes x num_targets)
    :param y_pred: predicted values (shape: num_nodes x num_targets)
    :return: scalar tensor holding the metric value
    """
    # Both inputs must have the same shape
    assert y_true.shape == y_pred.shape, "Размерности y_true и y_pred должны совпадать"

    squared_error = (y_true - y_pred) ** 2
    # One RMSE per target column (reduction over the row dimension)
    per_column_rmse = squared_error.mean(dim=0).sqrt()
    # Average over the target columns -> scalar
    return per_column_rmse.mean()


In [49]:
def masked_mcrmse(y_true, y_pred, mask):
    """
    MCRMSE restricted to the rows selected by a mask.

    :param y_true: ground-truth values (num_nodes x num_targets)
    :param y_pred: predicted values (num_nodes x num_targets)
    :param mask: mask of the nodes with known values (num_nodes,)
    :return: metric value computed by ``mcrmse`` on the selected rows
    """
    # Index both tensors with the same mask so that their rows stay aligned.
    return mcrmse(y_true[mask], y_pred[mask])


In [71]:
# Build the dataset: one graph per training record, with target rows laid out as
# (reactivity, deg_Mg_pH10, deg_Mg_50C).
graphs = [
    create_graph_with_targets(
        x['sequence'],
        x['structure'],
        list(zip(x['reactivity'], x['deg_Mg_pH10'], x['deg_Mg_50C'])),
        x['seq_scored'],
    )
    for x in train_dataset
]

# DataLoader
loader = DataLoader(graphs, batch_size=16, shuffle=True)

# Model
input_dim, hidden_dim, output_dim, edge_dim = 4, 64, 3, 2

model = GNNModelNodePrediction(input_dim, hidden_dim, output_dim, edge_dim)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

# Training
for epoch in range(20):
    epoch_loss = []
    for batch in loader:
        optimizer.zero_grad()
        pred = model(batch)  # predictions (shape: num_nodes x 3)
        loss = masked_mcrmse(batch.y, pred, batch.mask)  # MCRMSE over the masked nodes
        loss.backward()
        optimizer.step()
        epoch_loss.append(loss.item())
    print(f"Epoch {epoch + 1}, Loss: {np.mean(epoch_loss):.4f}")




Epoch 1, Loss: 0.7941
Epoch 2, Loss: 0.7186
Epoch 3, Loss: 0.7091
Epoch 4, Loss: 0.6992
Epoch 5, Loss: 0.6950
Epoch 6, Loss: 0.6871
Epoch 7, Loss: 0.6923
Epoch 8, Loss: 0.6877
Epoch 9, Loss: 0.6867
Epoch 10, Loss: 0.6888
Epoch 11, Loss: 0.6829
Epoch 12, Loss: 0.6725
Epoch 13, Loss: 0.6773
Epoch 14, Loss: 0.6689
Epoch 15, Loss: 0.6689
Epoch 16, Loss: 0.6648
Epoch 17, Loss: 0.6639
Epoch 18, Loss: 0.6664
Epoch 19, Loss: 0.6664
Epoch 20, Loss: 0.6603


In [73]:
# Build one graph per labelled test row. The per-position columns hold list literals
# stored as strings in the CSV, so each one is decoded with json.loads. Rows are
# iterated once instead of re-materialising test_labels.iloc[i] for every field.
graphs_test = [
    create_graph_with_targets(
        row['sequence'],
        row['structure'],
        list(zip(
            json.loads(row['reactivity']),
            json.loads(row['deg_Mg_pH10']),
            json.loads(row['deg_Mg_50C']),
        )),
        int(row['seq_scored']),
    )
    for _, row in test_labels.iterrows()
]

# No shuffling for evaluation: the reported score is a mean of per-batch losses, so a
# fixed batch composition is needed for the number to be reproducible between runs.
test_loader = DataLoader(graphs_test, batch_size=16, shuffle=False)

def test_model(model, test_data_loader, loss_fn, device='cpu'):
    """
    Тестирование модели на тестовых данных.

    :param model: обученная модель
    :param test_data_loader: DataLoader с тестовыми графовыми данными
    :param loss_fn: функция потерь
    :param device: устройство для вычислений ('cpu' или 'cuda')
    :return: средняя метрика MCRMSE на тестовом наборе
    """
    model.eval()  # Перевод модели в режим оценки
    
    losses = []
    all_predictions = []
    
    with torch.no_grad():
        for batch in test_data_loader:
            optimizer.zero_grad()
            pred = model(batch).to(device)  # Предсказание (размер: num_nodes x 3)
            loss = loss_fn(batch.y, pred, batch.mask)  # Используем MCRMSE
            losses.append(loss.item())
    
    return np.mean(losses)



In [72]:
# Evaluate the current `model` on the test loader with the masked MCRMSE and print the score.
print(test_model(model, test_loader, masked_mcrmse))

0.49583020175878817


In [65]:
# Persist the weights (state_dict) of the current `model` to disk.
torch.save(model.state_dict(), '../models/first_model.pt')

In [66]:
import torch
from torch_geometric.nn import GATv2Conv, BatchNorm

class EnhancedGNNModel(torch.nn.Module):
    """
    Deeper GATv2Conv network with batch normalisation, dropout and residual connections.

    Layout: an input GATv2Conv (4 heads, concatenated), ``num_layers - 2`` hidden
    GATv2Conv layers of the same width that each add a residual connection, an output
    GATv2Conv (1 head) and a linear head producing ``output_dim`` values per node.

    :param input_dim: number of features per node
    :param hidden_dim: hidden width per attention head
    :param output_dim: number of predicted values per node
    :param edge_dim: number of features per edge
    :param num_layers: total number of GATv2Conv layers (input + hidden + output), at least 2
    :param dropout: dropout probability applied after the input and hidden layers
    :raises ValueError: if ``num_layers`` is smaller than 2
    """

    def __init__(self, input_dim, hidden_dim, output_dim, edge_dim, num_layers=4, dropout=0.2):
        super(EnhancedGNNModel, self).__init__()

        # forward() always runs the input layer with norms[0] plus the output layer, so a
        # smaller value would build a model that only fails later with an IndexError.
        if num_layers < 2:
            raise ValueError("num_layers должен быть не меньше 2")

        self.num_layers = num_layers
        self.dropout = torch.nn.Dropout(dropout)

        self.input_layer = GATv2Conv(input_dim, hidden_dim, heads=4, edge_dim=edge_dim, concat=True)

        # Hidden layers keep the width at hidden_dim * 4 so the residual sum is well-formed.
        self.layers = torch.nn.ModuleList()
        for _ in range(num_layers - 2):
            self.layers.append(
                GATv2Conv(hidden_dim * 4, hidden_dim, heads=4, edge_dim=edge_dim, concat=True)
            )

        self.output_layer = GATv2Conv(hidden_dim * 4, hidden_dim, heads=1, edge_dim=edge_dim, concat=False)

        # One norm for the input layer plus one per hidden layer.
        self.norms = torch.nn.ModuleList([BatchNorm(hidden_dim * 4) for _ in range(num_layers - 1)])

        self.fc = torch.nn.Linear(hidden_dim, output_dim)

    def forward(self, data):
        """Return per-node predictions of shape (num_nodes, output_dim) for ``data``."""
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr

        x = self.input_layer(x, edge_index, edge_attr)
        x = self.norms[0](x)
        x = torch.relu(x)
        x = self.dropout(x)

        for i, layer in enumerate(self.layers):
            residual = x
            x = layer(x, edge_index, edge_attr)
            x = self.norms[i + 1](x)  # norms[0] belongs to the input layer
            x = torch.relu(x + residual)  # residual connection
            x = self.dropout(x)

        x = self.output_layer(x, edge_index, edge_attr)
        x = torch.relu(x)

        x = self.fc(x)  # output shape: (num_nodes, output_dim)
        return x


In [67]:
# Инициализация модели
input_dim = 4  # Число признаков на узел (A, U, C, G)
hidden_dim = 128
output_dim = 3  # Количество целевых переменных
edge_dim = 2  # Число признаков для ребер
num_layers = 4
dropout = 0.3

model = EnhancedGNNModel(input_dim, hidden_dim, output_dim, edge_dim, num_layers=num_layers, dropout=dropout)
loader = DataLoader(graphs, batch_size=16, shuffle=True)

# Оптимизатор и функция потерь
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
loss_fn = masked_mcrmse

# Обучение модели
for epoch in range(20):
    epoch_loss = []
    for batch in loader:
        model.train()
        optimizer.zero_grad()
        out = model(batch)  # Вызов модели
        loss = loss_fn(batch.y, out, batch.mask)
        loss.backward()
        optimizer.step()
        epoch_loss.append(loss.item())
    
    print(f"Epoch {epoch+1}, Loss: {np.mean(epoch_loss):.4f}")




Epoch 1, Loss: 0.6942
Epoch 2, Loss: 0.6640
Epoch 3, Loss: 0.6558
Epoch 4, Loss: 0.6583
Epoch 5, Loss: 0.6523
Epoch 6, Loss: 0.6560
Epoch 7, Loss: 0.6452
Epoch 8, Loss: 0.6392
Epoch 9, Loss: 0.6447
Epoch 10, Loss: 0.6469
Epoch 11, Loss: 0.6342
Epoch 12, Loss: 0.6403
Epoch 13, Loss: 0.6461
Epoch 14, Loss: 0.6333
Epoch 15, Loss: 0.6316
Epoch 16, Loss: 0.6309
Epoch 17, Loss: 0.6337
Epoch 18, Loss: 0.6292
Epoch 19, Loss: 0.6302
Epoch 20, Loss: 0.6335


In [69]:
# Evaluate the current `model` on the test loader with the masked MCRMSE and print the score.
print(test_model(model, test_loader, masked_mcrmse))

0.47124883704460585


In [70]:
# Persist the weights (state_dict) of the current `model` to disk.
torch.save(model.state_dict(), '../models/second_model.pt')