In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
from sklearn.metrics.pairwise import cosine_distances
from collections import defaultdict
import random
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm

In [None]:
# @title Data_preparation
# Load the diagnosis table, the admissions table (admittime parsed as datetime) and the
# ICD code dictionary from the local `hosp/` directory.
dx   = pd.read_csv('hosp/diagnoses_icd.csv')
adm  = pd.read_csv('hosp/admissions.csv', parse_dates=['admittime'])
meta = pd.read_csv('hosp/d_icd_diagnoses.csv')

# Per-patient summary counts on the raw diagnosis table (vectorised groupby aggregations
# instead of per-group Python lambdas).
admissions_number = dx.groupby('subject_id')['hadm_id'].nunique()
icds_number       = dx.groupby('subject_id')['icd_code'].count()

# merge dx → attach each diagnosis to its admission time
dx = dx.merge(adm[['hadm_id', 'admittime', 'dischtime', 'admission_type', 'admission_location',
                   'discharge_location', 'race', 'hospital_expire_flag']], on='hadm_id')
dx = dx.sort_values(['subject_id', 'admittime'])

# Keep only ICD-10 diagnoses and truncate codes to their first 3 characters
# (e.g., 'C787' → 'C78').
df = dx[dx['icd_version'] == 10].copy()
df['icd_code'] = df['icd_code'].str[:3]

# Keep patients with more than one distinct diagnosis code. This filter runs AFTER the
# version filter and truncation, so the "> 1 distinct code" guarantee holds for the rows
# that actually remain in `df`. Filtering earlier would count ICD-9 codes and
# pre-truncation variants that are later dropped or merged.
disease_counts = df.groupby('subject_id')['icd_code'].nunique()
valid_ids = disease_counts[disease_counts > 1].index
df = df[df['subject_id'].isin(valid_ids)].copy()


In [None]:
# @title Final_version
import math
import random
from collections import defaultdict

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics.pairwise import cosine_distances
from tqdm import tqdm
import pandas as pd

class Config:
    """Hyperparameters for the ICD masked-code model and its training loop.

    All values are class attributes: `Config.x` and `Config().x` read the same value,
    and an instance's `__dict__` stays empty.
    """
    embedding_dim: int   = 256    # model width; embedding and encoder d_model
    num_heads: int       = 8      # attention heads per encoder layer
    hidden_dim: int      = 512    # encoder feed-forward width; also the output head's hidden width
    num_layers: int      = 3      # number of stacked encoder layers
    mask_prob: float     = 0.15   # chance that a real token is selected for masking
    batch_size: int      = 64     # DataLoader batch size
    learning_rate: float = 1e-4   # AdamW lr and OneCycleLR max_lr
    epochs: int          = 40     # passes over the training split
    max_seq_len: int     = 100    # sequences are truncated / padded to this length
    dropout_rate: float  = 0.2    # dropout used in embeddings, encoder and head

class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position encodings to a batch of embeddings.

    A table for the first `max_len` positions is precomputed. Longer inputs get a
    one-off table built on the fly.

    Args:
        d_model: embedding width (may be odd).
        max_len: number of positions to precompute.
    """

    def __init__(self, d_model, max_len=100):
        super().__init__()
        self.d_model = d_model
        # persistent=False: the table is deterministic, so it is not stored in state_dict.
        self.register_buffer('pe', self._build_pe(max_len), persistent=False)

    def _build_pe(self, length):
        """Return a (1, length, d_model) table: sin on even channels, cos on odd channels."""
        position = torch.arange(length, dtype=torch.float).unsqueeze(1)  # (length, 1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2).float() * (-math.log(10000.0) / self.d_model)
        )  # (ceil(d_model / 2),)
        pe = torch.zeros(length, self.d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer odd channel than even channel,
        # so trim div_term to avoid a shape mismatch.
        pe[:, 1::2] = torch.cos(position * div_term[: self.d_model // 2])
        return pe.unsqueeze(0)  # (1, length, d_model)

    def forward(self, x):
        """Add encodings for positions [0, seq_len) to x of shape (batch, seq_len, d_model)."""
        _, seq_len, _ = x.shape
        if seq_len > self.pe.size(1):
            # Longer than the cached table: build one matching x's device and dtype.
            return x + self._build_pe(seq_len).to(device=x.device, dtype=x.dtype)
        return x + self.pe[:, :seq_len, :].to(dtype=x.dtype)

class ICDMLMModel(nn.Module):
    """Transformer encoder that predicts a vocabulary id at every sequence position.

    forward maps a (batch, seq_len) tensor of vocabulary ids to (batch, seq_len, vocab_size)
    logits. Id 0 is treated as padding. Its embedding row is zero, and padded positions are
    excluded as attention keys via `src_key_padding_mask`.

    Args:
        config: object exposing embedding_dim, num_heads, hidden_dim, num_layers,
            max_seq_len and dropout_rate.
        vocab_size: number of ids, including the special tokens.
    """

    def __init__(self, config, vocab_size):
        super().__init__()
        self.cfg = config
        # Padding id; '[PAD]' is registered as 0 when the vocabulary is built in prepare_data.
        self.pad_id = 0

        self.embedding = nn.Embedding(vocab_size, config.embedding_dim, padding_idx=self.pad_id)
        self.pos_encoder = PositionalEncoding(config.embedding_dim, config.max_seq_len)
        self.embedding_dropout = nn.Dropout(config.dropout_rate)

        enc_layer = nn.TransformerEncoderLayer(
            d_model=config.embedding_dim,
            nhead=config.num_heads,
            dim_feedforward=config.hidden_dim,
            dropout=config.dropout_rate,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(enc_layer, num_layers=config.num_layers)
        self.layer_norm  = nn.LayerNorm(config.embedding_dim)

        self.diagnosis_head = nn.Sequential(
            nn.Linear(config.embedding_dim, config.hidden_dim),
            nn.GELU(),
            nn.LayerNorm(config.hidden_dim),
            nn.Dropout(config.dropout_rate),
            nn.Linear(config.hidden_dim, vocab_size)
        )

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Xavier-uniform Linear weights with zero bias; N(0, 0.02) embeddings with a zero pad row."""
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.padding_idx is not None:
                # normal_ overwrote the zero vector nn.Embedding reserves for padding_idx.
                with torch.no_grad():
                    module.weight[module.padding_idx].zero_()

    def forward(self, input_ids):
        """input_ids: (B, L) long tensor of vocab ids → logits of shape (B, L, V)."""
        # True where the token is padding; those keys are ignored by self-attention.
        pad_mask = input_ids.eq(self.pad_id)
        # A row that is entirely padding would have every key masked, which makes the
        # attention softmax undefined (NaN). Unmask such rows so outputs stay finite.
        fully_padded = pad_mask.all(dim=1, keepdim=True)
        pad_mask = pad_mask & ~fully_padded

        x = self.embedding(input_ids)          # (B, L, D)
        x = self.pos_encoder(x)
        x = self.embedding_dropout(x)
        x = self.transformer(x, src_key_padding_mask=pad_mask)
        x = self.layer_norm(x)
        return self.diagnosis_head(x)          # (B, L, V)

def prepare_data(df):
    """Turn a diagnosis table into per-admission code sequences and a vocabulary.

    Rows are grouped by 'hadm_id' (in sorted key order). Within each admission the
    'disease_id' values are ordered by 'seq_num'.

    Returns:
        sequences: list of code lists, one per admission.
        vocab: code -> id, with '[PAD]' = 0, '[MASK]' = 1, then codes in first-seen order.
        idx_to_code: the inverse mapping id -> code.
    """
    sequences = []
    for _hadm, admission in df.groupby('hadm_id'):
        ordered = admission.sort_values('seq_num')
        sequences.append(ordered['disease_id'].tolist())

    vocab = {'[PAD]': 0, '[MASK]': 1}
    for code in (c for seq in sequences for c in seq):
        # The next free id is always the current vocabulary size.
        vocab.setdefault(code, len(vocab))

    idx_to_code = dict(zip(vocab.values(), vocab.keys()))
    return sequences, vocab, idx_to_code

class MaskedICDDataset(Dataset):
    """Code sequences with BERT-style random masking applied on every access.

    Each item is truncated/padded to cfg.max_seq_len. Every real (non-pad) position is
    selected with probability cfg.mask_prob. A selected position gets its original id as
    the label, and its input is replaced by [MASK] (80%), by a random real code (10%), or
    left unchanged (10%). All other positions get label -100, CrossEntropyLoss's default
    ignore_index.

    Args:
        sequences: list of code lists.
        vocab: code -> id mapping containing '[PAD]' and '[MASK]'.
        cfg: object exposing max_seq_len and mask_prob.
    """

    def __init__(self, sequences, vocab, cfg):
        self.sequences = sequences
        self.vocab     = vocab
        self.cfg       = cfg
        self.mask_id   = vocab['[MASK]']
        self.pad_id    = vocab['[PAD]']
        # Candidates for random replacement: real codes only, never [PAD] or [MASK].
        # Computed once rather than rebuilding the id list for every masked token.
        special = {self.mask_id, self.pad_id}
        self.replacement_ids = [i for i in vocab.values() if i not in special]

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        max_len = self.cfg.max_seq_len
        seq = self.sequences[idx][:max_len]
        input_ids = [self.vocab[c] for c in seq]
        n_real = len(input_ids)
        input_ids += [self.pad_id] * (max_len - n_real)

        labels = [-100] * max_len
        for i in range(n_real):
            if random.random() < self.cfg.mask_prob:
                labels[i] = input_ids[i]
                r = random.random()
                if r < 0.8:
                    input_ids[i] = self.mask_id
                elif r < 0.9 and self.replacement_ids:
                    input_ids[i] = random.choice(self.replacement_ids)
                # otherwise the original token is kept
        return {
            'input_ids': torch.tensor(input_ids, dtype=torch.long),
            'labels':    torch.tensor(labels,    dtype=torch.long)
        }

def calculate_accuracy(logits, labels, ignore_index=-100):
    """Fraction of scored positions whose argmax prediction equals the label.

    Args:
        logits: scores whose last dim indexes classes; argmax over it must match labels' shape.
        labels: target ids; positions equal to `ignore_index` are not scored.
        ignore_index: label value to skip.

    Returns:
        Accuracy as a Python float, or 0.0 when no position is scored.
    """
    scored = labels.ne(ignore_index)
    n_scored = scored.sum().float()
    if not n_scored > 0:
        return 0.0
    predicted = logits.argmax(dim=-1).masked_select(scored)
    hits = predicted.eq(labels.masked_select(scored)).sum().float()
    return (hits / n_scored).item()

def get_optimizer(model, cfg, train_loader):
    """Build AdamW (weight_decay=0.01) and a one-cycle LR schedule stepped once per batch.

    The schedule peaks at cfg.learning_rate and anneals linearly over
    cfg.epochs * len(train_loader) steps.

    Returns:
        (optimizer, scheduler)
    """
    steps_per_epoch = len(train_loader)
    optimizer = optim.AdamW(
        model.parameters(),
        lr=cfg.learning_rate,
        weight_decay=0.01,
    )
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=cfg.learning_rate,
        epochs=cfg.epochs,
        steps_per_epoch=steps_per_epoch,
        anneal_strategy='linear',
    )
    return optimizer, scheduler

def extract_embeddings(model, idx_to_code, vocab):
    """Encode every real code with the model and return pairwise cosine distances.

    Each code is encoded independently as a length-1 sequence (a batch of V sequences). Its
    vector therefore depends only on the code itself, not on its vocabulary index or on
    other codes. Feeding all codes as one long sequence would instead give each code the
    positional encoding of its vocabulary index, injecting an arbitrary ordering into the
    distances.

    Args:
        model: module exposing .embedding, .pos_encoder, .transformer and .layer_norm.
        idx_to_code: id -> code mapping, including '[PAD]' and '[MASK]' (which are skipped).
        vocab: unused; kept for interface compatibility.

    Returns:
        Square DataFrame of cosine distances, indexed and columned by code.
    """
    model.eval()
    device = next(model.parameters()).device

    valid = [i for i, c in idx_to_code.items() if c not in ('[PAD]', '[MASK]')]
    codes = [idx_to_code[i] for i in valid]
    if not valid:
        return pd.DataFrame(index=codes, columns=codes, dtype=float)

    input_ids = torch.tensor(valid, dtype=torch.long, device=device).unsqueeze(1)  # (V, 1)

    with torch.no_grad():
        x = model.embedding(input_ids)
        x = model.pos_encoder(x)
        x = model.transformer(x)
        x = model.layer_norm(x)
        emb = x.squeeze(1).cpu().numpy()  # (V, D)

    dist = cosine_distances(emb)
    return pd.DataFrame(dist, index=codes, columns=codes)

def _run_epoch(model, loader, criterion, device, vocab_size,
               optimizer=None, scheduler=None, desc=''):
    """One pass over `loader`: trains when `optimizer` is given, otherwise evaluates.

    Loss and accuracy are averaged over scored tokens (labels != criterion.ignore_index)
    rather than over batches, so a small trailing batch does not weigh as much as a full one.
    Batches without any scored token are skipped, because CrossEntropyLoss with mean
    reduction over zero targets returns NaN, which would poison the running sums.

    Returns:
        (loss, accuracy); (nan, 0.0) if no batch contained a scored token.
    """
    training = optimizer is not None
    ignore_index = criterion.ignore_index
    model.train(training)

    loss_sum, correct_sum, n_tokens = 0.0, 0.0, 0
    with torch.set_grad_enabled(training):
        for batch in tqdm(loader, desc=desc):
            inp = batch['input_ids'].to(device)
            lbl = batch['labels'].to(device)
            n_scored = int((lbl != ignore_index).sum())
            if n_scored == 0:
                continue

            out = model(inp)
            loss = criterion(out.reshape(-1, vocab_size), lbl.reshape(-1))
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                if scheduler is not None:
                    scheduler.step()

            loss_sum += loss.item() * n_scored
            correct_sum += calculate_accuracy(out, lbl, ignore_index) * n_scored
            n_tokens += n_scored

    if n_tokens == 0:
        return float('nan'), 0.0
    return loss_sum / n_tokens, correct_sum / n_tokens


def train_mlm(df):
    """Train ICDMLMModel on per-admission code sequences from `df` and return code distances.

    `df` must provide the columns prepare_data reads ('hadm_id', 'seq_num', 'disease_id').
    The weights with the lowest validation loss are written to 'icd_mlm_model.pth', together
    with the vocabulary, hyperparameters and validation metrics.

    Returns:
        The DataFrame produced by extract_embeddings (pairwise cosine distances between codes).
    """
    cfg = Config()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    sequences, vocab, idx_to_code = prepare_data(df)
    vocab_size = len(vocab)

    # 90/10 train/validation split, in the order prepare_data returns sequences (no shuffle).
    n_train = int(0.9 * len(sequences))
    train_ds = MaskedICDDataset(sequences[:n_train], vocab, cfg)
    val_ds   = MaskedICDDataset(sequences[n_train:], vocab, cfg)

    train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True)
    val_loader   = DataLoader(val_ds,   batch_size=cfg.batch_size, shuffle=False)

    model = ICDMLMModel(cfg, vocab_size).to(device)
    optimizer, scheduler = get_optimizer(model, cfg, train_loader)
    criterion = nn.CrossEntropyLoss(ignore_index=-100)

    best_val_loss = float('inf')
    best_val_acc  = 0.0
    best_state    = None

    for epoch in range(1, cfg.epochs + 1):
        avg_train_loss, avg_train_acc = _run_epoch(
            model, train_loader, criterion, device, vocab_size,
            optimizer=optimizer, scheduler=scheduler,
            desc=f"Train {epoch}/{cfg.epochs}",
        )
        avg_val_loss, avg_val_acc = _run_epoch(
            model, val_loader, criterion, device, vocab_size,
            desc=f"Valid {epoch}/{cfg.epochs}",
        )

        print(f"\nEpoch {epoch} Summary:")
        print(f"Train Loss: {avg_train_loss:.4f} | Train Acc: {avg_train_acc:.4f}")
        print(f" Val  Loss: {avg_val_loss:.4f} |  Val  Acc: {avg_val_acc:.4f}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            best_val_acc  = avg_val_acc
            # state_dict() returns references to the live parameter tensors. Without cloning,
            # later optimizer steps would overwrite this "best" snapshot in place.
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            print("New best model saved!")

    if best_state is not None:
        model.load_state_dict(best_state)
    else:
        # No finite validation loss was observed (e.g. an empty validation split);
        # keep the final weights rather than crashing on load_state_dict(None).
        print("Warning: no validation improvement recorded; saving final weights.")

    # Config keeps its hyperparameters as class attributes, so cfg.__dict__ is empty;
    # collect the public, non-callable attributes explicitly.
    config_dict = {
        name: getattr(cfg, name)
        for name in dir(cfg)
        if not name.startswith('_') and not callable(getattr(cfg, name))
    }
    torch.save({
        'model_state_dict': model.state_dict(),
        'vocab':            vocab,
        'config':           config_dict,
        'val_loss':         best_val_loss,
        'val_acc':          best_val_acc
    }, 'icd_mlm_model.pth')
    print("Модель сохранена в 'icd_mlm_model.pth'")

    return extract_embeddings(model, idx_to_code, vocab)


In [None]:
if __name__ == '__main__':
    # Vocabulary token per diagnosis: code plus ICD version, e.g. 'C78' + '_v' + '10'.
    df['disease_id'] = df['icd_code'] + '_v' + df['icd_version'].astype(str)
    distance_matrix = train_mlm(df)
    # Saving stays inside the guard: distance_matrix only exists when training ran.
    distance_matrix.to_csv('icd_distance_matrix.csv')
    print("\nМатрица сходства сохранена в 'icd_distance_matrix.csv'")