In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

In [2]:
import pandas as pd
import numpy as np
from pathlib import Path
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from tqdm import tqdm
import os
import csv

BASE_PATH = "/home/jupyter/datasphere/project/transformer_dataset_2/"
# Metadata (relative `path` + label `marks`) is read from the original dataset
# directory, while the embedding files are resolved against BASE_PATH.
INFO_CSV = '/home/jupyter/datasphere/project/transformer_dataset/_info.csv'

# Load metadata
df = pd.read_csv(INFO_CSV, sep=',', skipinitialspace=True, header=0)
# Strip stray whitespace from every object (string) column.
df = df.apply(lambda col: col.str.strip() if col.dtype == "object" else col)

df['full_path'] = BASE_PATH + df['path']


def _split_files(marker):
    """Return the full paths of rows whose relative ``path`` contains ``marker``.

    The match is literal (``regex=False``), and rows with a missing path are
    excluded (``na=False``) instead of poisoning the boolean mask with NaN.
    """
    in_split = df['path'].str.contains(marker, regex=False, na=False)
    return df.loc[in_split, 'full_path'].tolist()


# Train/val/test split is encoded in the directory part of each relative path.
train_files = _split_files('train/')
val_files = _split_files('val/')
test_files = _split_files('test/')

print(f"Train files: {len(train_files)}, Val files: {len(val_files)}, Test files: {len(test_files)}")

Train files: 6990, Val files: 1535, Test files: 1475


In [3]:
# --- Training hyper-parameters ---
NUM_EPOCHS = 50
BATCH_SIZE = 128
LEARNING_RATE = 1e-4   # base learning rate
GAMMA = 0.7            # multiplicative LR decay factor for StepLR

# NOTE(review): NUM_WORKERS and OUTPUT_DIM are not referenced by any other
# visible cell; create_dataloader and the model constructor use literals instead.
NUM_WORKERS = 0
OUTPUT_DIM = 64

In [4]:
class EmbeddingSequenceDataset(Dataset):
    """Per-file embedding sequences with per-step binary labels.

    Each item is read from a ``.npy`` file holding a 2-D ``(seq_len, dim)``
    array and returned as ``(embeddings, labels, seq_len)``:

    * ``embeddings`` -- float32 tensor ``(max_length, dim)``, truncated or
      zero-padded along the first axis (also for empty files);
    * ``labels`` -- float32 tensor ``(max_length,)`` with ones on the half-open
      range ``[start, end)`` parsed from the ``"start,end"`` string in the
      ``marks`` column of the matching ``df`` row (all zeros if the file is empty);
    * ``seq_len`` -- the original, untruncated number of rows in the file.

    Args:
        file_paths: paths of the ``.npy`` files to serve.
        df: metadata frame with ``full_path`` and ``marks`` columns.
        max_length: length every sequence is padded/truncated to.
    """

    def __init__(self, file_paths, df, max_length=200):
        self.file_paths = file_paths
        self.df = df
        self.max_length = max_length
        # Exact full_path -> marks lookup built once. Scanning the whole frame
        # with a regex substring match on the file stem per item is
        # O(len(df)) per sample and can pick the wrong row (stem "x_1" also
        # matches ".../x_10.npy").
        self._marks_by_path = {}
        for path, marks in zip(df['full_path'], df['marks']):
            self._marks_by_path.setdefault(path, marks)  # first row wins

    def __len__(self):
        return len(self.file_paths)

    def _lookup_marks(self, file_path):
        """Return the ``marks`` value for ``file_path``.

        Prefers an exact ``full_path`` match; otherwise falls back to a literal
        (non-regex) substring match of the file stem and takes the first hit.

        Raises:
            KeyError: if no metadata row matches.
        """
        if file_path in self._marks_by_path:
            return self._marks_by_path[file_path]
        file_id = Path(file_path).stem
        matches = self.df.loc[
            self.df['full_path'].str.contains(file_id, regex=False, na=False), 'marks'
        ]
        if matches.empty:
            raise KeyError(f"no metadata row for {file_path!r}")
        return matches.values[0]

    def __getitem__(self, idx):
        file_path = self.file_paths[idx]

        # Load embeddings.
        raw = np.load(file_path)
        if raw.ndim != 2:
            raise ValueError(
                f"expected a 2-D (seq_len, dim) array in {file_path!r}, got shape {raw.shape}"
            )
        seq_len = len(raw)

        # Binary labels from the "start,end" marks. Steps past seq_len are
        # padding; the training loop masks them out using seq_len.
        start, end = map(int, self._lookup_marks(file_path).split(','))
        labels = torch.zeros(self.max_length, dtype=torch.float32)
        if seq_len > 0:
            labels[start:end] = 1.0

        # Truncate / zero-pad to max_length. Empty files still yield a
        # (max_length, dim) tensor so the default collate_fn can stack items.
        n_keep = min(seq_len, self.max_length)
        embeddings = torch.zeros(self.max_length, raw.shape[1], dtype=torch.float32)
        embeddings[:n_keep] = torch.as_tensor(raw[:n_keep], dtype=torch.float32)

        return embeddings, labels, seq_len

def create_dataloader(file_paths, df, batch_size=32, shuffle=True,
                      num_workers=2, pin_memory=None, max_length=200):
    """Wrap ``file_paths`` in an ``EmbeddingSequenceDataset`` and a ``DataLoader``.

    Args:
        file_paths: ``.npy`` files to serve.
        df: metadata frame forwarded to the dataset.
        batch_size: samples per batch.
        shuffle: reshuffle every epoch (pass ``False`` for val/test).
        num_workers: number of loader worker processes (default 2, as before).
        pin_memory: pin host memory for faster host-to-GPU copies. ``None``
            enables it only when CUDA is available, since pinning has no benefit
            on a CPU-only machine.
        max_length: pad/truncate length forwarded to the dataset.

    Returns:
        A configured ``torch.utils.data.DataLoader``.
    """
    if pin_memory is None:
        pin_memory = torch.cuda.is_available()
    dataset = EmbeddingSequenceDataset(file_paths, df, max_length=max_length)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory
    )

# Build one DataLoader per split; only the training split is shuffled.
train_loader, val_loader, test_loader = (
    create_dataloader(split_files, df, batch_size=BATCH_SIZE, shuffle=do_shuffle)
    for split_files, do_shuffle in (
        (train_files, True),
        (val_files, False),
        (test_files, False),
    )
)

In [5]:
class BinaryClassifierTransformer(nn.Module):
    """Transformer encoder that predicts a probability for every input position.

    Input ``x`` has shape ``(batch, seq_len, input_dim)``; the output has shape
    ``(batch, seq_len)`` with values in ``[0, 1]``. The sigmoid is already
    applied, so train it with ``nn.BCELoss`` rather than ``BCEWithLogitsLoss``.

    Args:
        input_dim: size of each input feature vector.
        d_model: transformer hidden size.
        nhead: attention heads per layer (must divide ``d_model``).
        num_layers: number of encoder layers.
        dim_feedforward: hidden size of each layer's feed-forward block.
        dropout: dropout probability used throughout.
        max_len: number of learned positional embeddings, i.e. the longest
            sequence accepted (default 200, matching existing checkpoints).
    """

    def __init__(self, input_dim, d_model=256, nhead=8, num_layers=8, dim_feedforward=256,
                 dropout=0.1, max_len=200):
        super().__init__()
        self.max_len = max_len  # plain attribute, not part of the state dict

        # 1. Two-stage input projection to d_model
        self.embedding = nn.Sequential(
            nn.Linear(input_dim, d_model),
            nn.LayerNorm(d_model),
            nn.Dropout(dropout),
            nn.GELU(),
            nn.Linear(d_model, d_model),
            nn.LayerNorm(d_model),
            nn.GELU()
        )

        # 2. Learned positional embeddings, one per position up to max_len
        self.pos_encoder = nn.Parameter(torch.randn(1, max_len, d_model))

        # 3. Encoder stack (batch-first tensors)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
            activation='gelu'
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # 4. Per-position MLP head producing a single logit
        self.classifier = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Linear(d_model, d_model//2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model//2, 1)
        )

        # Weight initialisation
        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform init for every parameter with more than one dimension.

        This includes ``pos_encoder`` and the transformer's own weights; 1-D
        parameters (biases, LayerNorm scales) keep their default init.
        """
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, x, src_key_padding_mask=None):
        """Score every position of ``x``.

        Args:
            x: ``(batch, seq_len, input_dim)`` tensor with ``seq_len <= max_len``.
            src_key_padding_mask: optional bool ``(batch, seq_len)`` tensor that is
                ``True`` at padded positions, so real positions do not attend to
                them. When omitted, every position is attended (previous behavior).

        Returns:
            ``(batch, seq_len)`` tensor of probabilities.

        Raises:
            ValueError: if ``seq_len`` exceeds ``max_len``.
        """
        seq_len = x.size(1)
        if seq_len > self.max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len={self.max_len} positional embeddings"
            )
        hidden = self.embedding(x)  # [batch_size, seq_len, d_model]
        hidden = hidden + self.pos_encoder[:, :seq_len, :]
        hidden = self.transformer(hidden, src_key_padding_mask=src_key_padding_mask)
        return torch.sigmoid(self.classifier(hidden)).squeeze(-1)  # [batch_size, seq_len]

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# input_dim=64 presumably matches the width of the .npy embeddings (OUTPUT_DIM
# is also 64) -- TODO confirm.
model = BinaryClassifierTransformer(64).to(device)

# Resume from the checkpoint written after OFFSET epochs. train_model names new
# checkpoints transformer_epoch{epoch + 1 + OFFSET}.pt, so deriving the path
# from OFFSET keeps the loaded file and the continued numbering in sync.
OFFSET = 50
model.load_state_dict(torch.load(f'checkpoints/transformer_epoch{OFFSET}.pt', map_location=device))

<All keys matched successfully>

In [7]:
def _padding_mask(seq_lens, max_len, device):
    """Bool ``(batch, max_len)`` mask that is True on real (non-padded) steps."""
    steps = torch.arange(max_len, device=device)
    return steps.unsqueeze(0) < seq_lens.to(device).unsqueeze(1)


def _masked_loss(criterion, outputs, labels, seq_lens):
    """Return ``(mean loss over real steps, number of real steps)`` for one batch."""
    mask = _padding_mask(seq_lens, outputs.size(1), outputs.device)
    return criterion(outputs[mask], labels[mask]), int(mask.sum())


def train_model():
    """Train the module-level ``model`` for NUM_EPOCHS epochs, continuing after OFFSET.

    Reads the globals ``model``, ``device``, ``train_loader``, ``val_loader``,
    ``NUM_EPOCHS``, ``LEARNING_RATE``, ``GAMMA`` and ``OFFSET``.

    Side effects:
        * appends to ``logs/train_log.csv`` (header written only if the file is
          new or empty): one row every 50 training steps with that batch's loss,
          plus one "final" row per epoch with the token-weighted mean train and
          validation losses. Epoch numbers match the checkpoint numbering;
        * saves ``checkpoints/transformer_epoch{epoch + 1 + OFFSET}.pt`` after
          every epoch.
    """
    criterion = nn.BCELoss()  # the model already outputs probabilities
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=GAMMA)
    # NOTE(review): optimizer/scheduler state is not checkpointed, so a resumed
    # run restarts the AdamW moments and the LR schedule from scratch.

    # Logging
    os.makedirs("checkpoints", exist_ok=True)
    os.makedirs("logs", exist_ok=True)
    csv_log_path = "logs/train_log.csv"

    # Append rather than truncate so resuming (OFFSET > 0) keeps the earlier log.
    if not os.path.exists(csv_log_path) or os.path.getsize(csv_log_path) == 0:
        with open(csv_log_path, mode='w', newline='') as f:
            csv.writer(f).writerow(["epoch", "step", "train_loss", "val_loss"])

    for epoch in range(NUM_EPOCHS):
        global_epoch = epoch + 1 + OFFSET  # same number as the checkpoint file
        model.train()
        train_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS} [Train]")
        train_sum, train_count = 0.0, 0

        for step, (embeddings, labels, seq_lens) in enumerate(train_bar):
            embeddings, labels = embeddings.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(embeddings)
            # Loss only over real steps; padded positions are masked out.
            loss, n_valid = _masked_loss(criterion, outputs, labels, seq_lens)
            loss.backward()
            optimizer.step()

            if n_valid:  # an all-padding batch yields NaN loss; keep it out of the mean
                train_sum += loss.item() * n_valid
                train_count += n_valid

            if step % 50 == 0:
                with open(csv_log_path, mode='a', newline='') as f:
                    csv.writer(f).writerow([global_epoch, step, loss.item(), ""])

        # Validation: token-weighted mean, so batches with more real steps count more.
        model.eval()
        val_sum, val_count = 0.0, 0
        with torch.no_grad():
            for embeddings, labels, seq_lens in val_loader:
                embeddings, labels = embeddings.to(device), labels.to(device)
                outputs = model(embeddings)
                val_loss, n_valid = _masked_loss(criterion, outputs, labels, seq_lens)
                if n_valid:
                    val_sum += val_loss.item() * n_valid
                    val_count += n_valid

        avg_train_loss = train_sum / train_count if train_count else float("nan")
        avg_val_loss = val_sum / val_count if val_count else float("nan")
        scheduler.step()

        print(scheduler.get_last_lr())
        with open(csv_log_path, mode='a', newline='') as f:
            csv.writer(f).writerow([global_epoch, "final", avg_train_loss, avg_val_loss])

        # Save model weights
        torch.save(model.state_dict(), f"checkpoints/transformer_epoch{global_epoch}.pt")

# Inside a notebook kernel __name__ is "__main__", so executing this cell trains.
if "__main__" == __name__:
    train_model()

Epoch 1/50 [Train]: 100%|██████████| 55/55 [00:22<00:00,  2.45it/s]


[0.0001]


Epoch 2/50 [Train]: 100%|██████████| 55/55 [00:22<00:00,  2.45it/s]


[0.0001]


Epoch 3/50 [Train]: 100%|██████████| 55/55 [00:21<00:00,  2.54it/s]


[0.0001]


Epoch 4/50 [Train]: 100%|██████████| 55/55 [00:22<00:00,  2.48it/s]


[0.0001]


Epoch 5/50 [Train]: 100%|██████████| 55/55 [00:21<00:00,  2.52it/s]


[0.0001]


Epoch 6/50 [Train]: 100%|██████████| 55/55 [00:22<00:00,  2.43it/s]


[0.0001]


Epoch 7/50 [Train]: 100%|██████████| 55/55 [00:22<00:00,  2.44it/s]


[0.0001]


Epoch 8/50 [Train]: 100%|██████████| 55/55 [00:22<00:00,  2.45it/s]


[0.0001]


Epoch 9/50 [Train]: 100%|██████████| 55/55 [00:22<00:00,  2.48it/s]


[0.0001]


Epoch 10/50 [Train]: 100%|██████████| 55/55 [00:22<00:00,  2.48it/s]


[7e-05]


Epoch 11/50 [Train]: 100%|██████████| 55/55 [00:22<00:00,  2.48it/s]


[7e-05]


Epoch 12/50 [Train]: 100%|██████████| 55/55 [00:22<00:00,  2.48it/s]


[7e-05]


Epoch 13/50 [Train]: 100%|██████████| 55/55 [00:21<00:00,  2.52it/s]


[7e-05]


Epoch 14/50 [Train]: 100%|██████████| 55/55 [00:22<00:00,  2.47it/s]


[7e-05]


Epoch 15/50 [Train]: 100%|██████████| 55/55 [00:22<00:00,  2.44it/s]


[7e-05]


Epoch 16/50 [Train]: 100%|██████████| 55/55 [00:22<00:00,  2.48it/s]


[7e-05]


Epoch 17/50 [Train]: 100%|██████████| 55/55 [00:22<00:00,  2.45it/s]


[7e-05]


Epoch 18/50 [Train]: 100%|██████████| 55/55 [00:22<00:00,  2.48it/s]


[7e-05]


Epoch 19/50 [Train]: 100%|██████████| 55/55 [00:21<00:00,  2.53it/s]


[7e-05]


Epoch 20/50 [Train]: 100%|██████████| 55/55 [00:22<00:00,  2.49it/s]


[4.899999999999999e-05]


Epoch 21/50 [Train]: 100%|██████████| 55/55 [00:22<00:00,  2.49it/s]


[4.899999999999999e-05]


Epoch 22/50 [Train]: 100%|██████████| 55/55 [00:22<00:00,  2.50it/s]


[4.899999999999999e-05]


Epoch 23/50 [Train]: 100%|██████████| 55/55 [00:22<00:00,  2.50it/s]


[4.899999999999999e-05]


Epoch 24/50 [Train]: 100%|██████████| 55/55 [00:21<00:00,  2.50it/s]


[4.899999999999999e-05]


Epoch 25/50 [Train]: 100%|██████████| 55/55 [00:22<00:00,  2.48it/s]


[4.899999999999999e-05]


Epoch 26/50 [Train]: 100%|██████████| 55/55 [00:22<00:00,  2.49it/s]


[4.899999999999999e-05]


Epoch 27/50 [Train]: 100%|██████████| 55/55 [00:21<00:00,  2.52it/s]


[4.899999999999999e-05]


Epoch 28/50 [Train]:  71%|███████   | 39/55 [00:16<00:06,  2.38it/s]


KeyboardInterrupt: 