In [None]:
import torch
import torch.nn.functional as F
import torch.nn as nn
from transformers import get_linear_schedule_with_warmup

import json
import numpy as np
import random
import os
from sklearn.model_selection import train_test_split
from scipy.optimize import linear_sum_assignment

import matplotlib.pyplot as plt
from tqdm.auto import tqdm

In [None]:
def cxcywh_to_xyxy(boxes):
    """Convert boxes from centre format (cx, cy, w, h) to corner format (x1, y1, x2, y2).

    The last dimension must have size 4; all leading dimensions are preserved.
    """
    cx, cy, w, h = boxes.unbind(-1)
    half_w = w / 2
    half_h = h / 2
    corners = (cx - half_w, cy - half_h, cx + half_w, cy + half_h)
    return torch.stack(corners, dim=-1)

def box_iou(boxes1, boxes2):
    """Pairwise intersection-over-union between two sets of (cx, cy, w, h) boxes.

    Args:
        boxes1: (N, 4) tensor of boxes in centre format.
        boxes2: (M, 4) tensor of boxes in centre format.

    Returns:
        (N, M) tensor whose [i, j] entry is IoU(boxes1[i], boxes2[j]). The union is
        clamped to at least 1e-6 so degenerate (zero-area) boxes do not divide by zero.
    """
    a = cxcywh_to_xyxy(boxes1)
    b = cxcywh_to_xyxy(boxes2)

    # Corners of the intersection rectangle for every (i, j) pair: (N, M, 2) each.
    top_left = torch.max(a[:, None, :2], b[:, :2])
    bottom_right = torch.min(a[:, None, 2:], b[:, 2:])
    overlap = (bottom_right - top_left).clamp(min=0)
    inter = overlap[..., 0] * overlap[..., 1]

    size_a = a[:, 2:] - a[:, :2]
    size_b = b[:, 2:] - b[:, :2]
    area_a = size_a[:, 0] * size_a[:, 1]
    area_b = size_b[:, 0] * size_b[:, 1]

    union = area_a[:, None] + area_b - inter
    return inter / union.clamp(min=1e-6)

def hungarian_loss_batch(pred, target, class_weight=10.0, l1_weight=1.0, iou_weight=1.0):
    """DETR-style set-prediction loss with Hungarian matching, averaged over the batch.

    For every sample, predictions are matched one-to-one to real targets by
    minimising ``class_weight * (-score) + l1_weight * L1(box) + iou_weight * (-IoU)``.
    Matched predictions are trained towards score 1 and towards their target box.
    Unmatched predictions are trained towards score 0.

    Args:
        pred: (B, N, 5) tensor. Channel 0 is a score logit; channels 1: are a box
            in (cx, cy, w, h) format.
        target: (B, M, 4) tensor of (cx, cy, w, h) boxes. Rows that are exactly
            all-zero are treated as padding and ignored. ``batched_hungarian_loss``
            pads samples with fewer than M real boxes this way.
        class_weight, l1_weight, iou_weight: weights of the score, box-regression
            and IoU terms. They are used both in the matching cost and in the loss.

    Returns:
        Scalar tensor: the mean per-sample loss.
    """
    B = pred.shape[0]
    if B == 0:
        # Nothing to average over; keep the result attached to the graph.
        return pred.sum() * 0.0

    total_loss = 0.0
    for b in range(B):
        logits = pred[b, :, 0]          # (N,) score logits
        boxes_pred = pred[b, :, 1:]     # (N, 4) predicted (cx, cy, w, h)
        tgt_b = target[b]
        # Drop all-zero padding rows so they are never matched as real objects.
        tgt_b = tgt_b[(tgt_b != 0).any(dim=-1)]   # (M_b, 4)

        # Detached per-prediction BCE weight 1 + 1 / (0.1 + (5 * |(cx, cy)|)^2).
        # It is largest (11) for predictions centred at the origin and tends to 1
        # far away. NOTE(review): the motivation for this weighting is not documented.
        distances = torch.norm(boxes_pred[:, :2].detach(), dim=1)
        bce_weight = 1 + 1 / (0.1 + (5 * distances) ** 2)

        target_scores = torch.zeros_like(logits)
        if tgt_b.shape[0] > 0:
            with torch.no_grad():
                scores = torch.sigmoid(logits)
                cost = (
                    class_weight * -scores[:, None]
                    + l1_weight * torch.cdist(boxes_pred, tgt_b, p=1)
                    + iou_weight * -box_iou(boxes_pred, tgt_b)
                )
                pred_idx, tgt_idx = linear_sum_assignment(cost.float().cpu().numpy())
            pred_idx = torch.as_tensor(pred_idx, dtype=torch.long, device=pred.device)
            tgt_idx = torch.as_tensor(tgt_idx, dtype=torch.long, device=pred.device)

            # Matched predictions should fire (score 1); all others stay at 0.
            target_scores[pred_idx] = 1.0

            matched_pred = boxes_pred[pred_idx]
            matched_tgt = tgt_b[tgt_idx]
            # NOTE: despite the name `l1_weight`, the regression term is an MSE.
            loss_bbox = F.mse_loss(matched_pred, matched_tgt)
            # Pairs are aligned row by row, so only the diagonal of the pairwise IoU
            # matrix compares each prediction with its own matched target.
            loss_iou = 1 - box_iou(matched_pred, matched_tgt).diagonal().mean()
        else:
            # No real targets: only the "no object" score term applies.
            loss_bbox = boxes_pred.sum() * 0.0
            loss_iou = boxes_pred.sum() * 0.0

        # Logit-space BCE: same value as BCE(sigmoid(logits)) but numerically stable.
        loss_score = F.binary_cross_entropy_with_logits(logits, target_scores, weight=bce_weight)

        total_loss = total_loss + (
            class_weight * loss_score
            + l1_weight * loss_bbox
            + iou_weight * loss_iou
        )

    return total_loss / B


def batched_hungarian_loss(res1, bbox_target, views_masks, house_count, hungarian_loss_fn):
    """Flatten every valid (sample, view) pair into one batch and apply a set loss.

    Args:
        res1: (B, V, N, 5) box predictions.
        bbox_target: (B, V, M, 4) target boxes.
        views_masks: (B, V) boolean mask of valid views.
        house_count: (B,) number of real boxes per sample. Only the first
            ``house_count[b]`` target rows of each view are kept; the remaining
            rows are replaced by all-zero rows so every entry still has M targets.
        hungarian_loss_fn: callable taking (B', N, 5) predictions and (B', M, 4)
            targets, where B' is the total number of valid views.

    Returns:
        Whatever ``hungarian_loss_fn`` returns, or a constant 0.0 tensor (with no
        gradient history) when no view in the batch is valid.
    """
    device = res1.device
    n_samples, _, _, _ = res1.shape
    total_targets = bbox_target.shape[2]

    preds, targets = [], []
    for sample_idx in range(n_samples):
        n_real = house_count[sample_idx]
        for view_idx in torch.where(views_masks[sample_idx])[0]:
            boxes = bbox_target[sample_idx, view_idx, :n_real]
            n_missing = total_targets - n_real
            if n_missing > 0:
                filler = torch.zeros((n_missing, 4), device=device)
                boxes = torch.cat([boxes, filler], dim=0)
            preds.append(res1[sample_idx, view_idx])
            targets.append(boxes)

    if not preds:
        return torch.tensor(0.0, device=device)

    return hungarian_loss_fn(torch.stack(preds), torch.stack(targets))

def batched_camera_loss(camera_target, res2, views_masks):
    """Mean squared error over the (valid view x valid view) block of every sample.

    Args:
        camera_target: (B, V, V, D) ground-truth pairwise values.
        res2: (B, V, V, D) predicted pairwise values.
        views_masks: (B, V) boolean mask of valid views per sample.

    Returns:
        Sum of squared errors over all selected blocks divided by the total number
        of selected elements. Samples without valid views are skipped, and a
        constant 0.0 tensor is returned when nothing is selected.
    """
    n_samples = camera_target.shape[0]
    squared_error_sum = 0.0
    element_count = 0

    for i in range(n_samples):
        keep = views_masks[i]
        if not keep.any():
            continue  # no valid view in this sample

        block_target = camera_target[i][keep][:, keep]   # (m, m, D)
        block_pred = res2[i][keep][:, keep]
        squared_error_sum += F.mse_loss(block_pred, block_target, reduction='sum')
        element_count += block_target.numel()

    if element_count == 0:
        return torch.tensor(0.0, device=camera_target.device)
    return squared_error_sum / element_count

In [None]:
class Dataset(torch.utils.data.Dataset):
    """Loads one scene per JSON file and randomly subsamples its views.

    Each file is expected (inferred from the indexing below) to contain:
      'v': list of views (only its length is used),
      'n': number of distinct houses,
      'b': per view, a list of rows ``[house_idx, 4 values]``,
      't': per view, ``n_houses`` rows of 5 values. The first column is dropped and
           the remaining 4 are used as the target box (presumably (cx, cy, w, h);
           TODO confirm),
      'c': a (V, V, 2) nested list of pairwise camera values.

    Every call draws a fresh random embedding of size ``f_dim`` for each house.
    Each call also selects a random subset of k views, keeping at least half and
    at least one, in their original order. The selected views are packed into
    slots 0..k-1, and the call returns:
      house_fites (v_len, h_len, f_dim + 4): features; padded houses are -150 and
          padded views are -250;
      views_masks (v_len,) bool, True for the first k slots;
      house_masks (v_len, h_len) bool, True for real houses;
      camera      (v_len, v_len, 2): pairwise values in [:k, :k], -75 elsewhere;
      target      (v_len, h_len, 4): boxes in [:k, :n_houses], -95 elsewhere;
      n_houses    int: number of target boxes per view.
    """

    def __init__(self, pathes, v_len=64, h_len=16, f_dim=118):
        self.pathes = pathes
        self.f_dim = f_dim      # size of the random per-house embedding
        self.v_len = v_len      # number of view slots after padding
        self.h_len = h_len      # number of house slots per view after padding

    def __len__(self):
        return len(self.pathes)

    def _view_features(self, view_boxes, house_embeddings):
        """Return the padded (h_len, f_dim + 4) features and (h_len,) mask of one view."""
        rows = [
            torch.cat([box[1:], house_embeddings[int(box[0])]])
            for box in torch.tensor(view_boxes)
        ]
        n_real = len(rows)
        if n_real > self.h_len:
            raise ValueError(f"view has {n_real} houses but h_len is {self.h_len}")

        feat_dim = self.f_dim + 4
        # A view may contain no houses; torch.stack would fail on an empty list.
        real = torch.stack(rows) if rows else torch.empty((0, feat_dim))
        padding = torch.full((self.h_len - n_real, feat_dim), -150.)
        mask = torch.arange(self.h_len) < n_real
        return torch.cat([real, padding], dim=0), mask

    def __getitem__(self, idx):
        with open(self.pathes[idx], 'r', encoding='utf-8') as file:
            data = json.load(file)

        n_total = len(data['v'])
        if n_total == 0:
            raise ValueError(f"{self.pathes[idx]} contains no views")

        # Random subset of at least half (and at least one) of the views, kept sorted.
        n_views = np.random.randint(max(1, n_total // 2), n_total + 1)
        views = sorted(random.sample(range(n_total), n_views))
        if n_views > self.v_len:
            raise ValueError(f"sampled {n_views} views but v_len is {self.v_len}")
        house_embeddings = torch.rand(data['n'], self.f_dim)

        n_pad_views = self.v_len - n_views
        feat_dim = self.f_dim + 4
        targets_all = torch.tensor(data['t'])
        n_houses = targets_all.shape[1]
        views_masks = torch.arange(self.v_len) < n_views

        feats, masks = zip(*(self._view_features(data['b'][v], house_embeddings) for v in views))
        house_fites = torch.cat(
            [torch.stack(feats), torch.full((n_pad_views, self.h_len, feat_dim), -250.)], dim=0)
        house_masks = torch.cat(
            [torch.stack(masks), torch.zeros((n_pad_views, self.h_len), dtype=torch.bool)], dim=0)

        selected = torch.tensor(views)
        # Pack the selected views into slots 0..k-1 so the camera matrix lines up
        # with house_fites, views_masks and target (all packed the same way).
        camera = torch.full((self.v_len, self.v_len, 2), -75.)
        camera[:n_views, :n_views] = torch.tensor(data['c'])[selected][:, selected]

        target = torch.full((self.v_len, self.h_len, 4), -95.)
        target[:n_views, :n_houses] = targets_all[selected][:, :, 1:]

        return house_fites, views_masks, house_masks, camera, target, n_houses

In [None]:
data_path = ...  # Path to the directory with the saved scene JSON files; set before running.

# Sort the listing: os.listdir order is filesystem-dependent, so without sorting the
# random_state=42 split below would not be reproducible across machines.
all_files = [os.path.join(data_path, name) for name in sorted(os.listdir(data_path))]
train_files, valid_files = train_test_split(all_files, test_size=0.2, random_state=42)

train_data = Dataset(train_files)
valid_data = Dataset(valid_files)

In [None]:
class Encoder(nn.Module):
    def __init__(self, input_dim, mlp_hidden_dim, transformer_dim, n_heads, ff_dim, num_layers, dropout=0.1, batch_first=True):
        super().__init__()
        
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, mlp_hidden_dim),
            nn.ReLU(),
            nn.Linear(mlp_hidden_dim, transformer_dim))
        
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=transformer_dim,
            nhead=n_heads,
            dim_feedforward=ff_dim,
            dropout=dropout,
            batch_first=batch_first)
        
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

    def forward(self, x, src_key_padding_mask=None):
        x = self.mlp(x)
        x = self.encoder(x, src_key_padding_mask=src_key_padding_mask)
        return x

In [None]:
class Matrix(nn.Module):
    """Multi-head scaled dot-product attention *scores*, used as a pairwise matrix.

    The input is projected by a two-layer MLP and split into ``num_heads`` heads of
    size ``embed_dim // num_heads``. ``forward`` returns the raw, pre-softmax scores
    Q K^T / sqrt(head_dim) with shape (B, num_heads, S, S). No softmax and no value
    aggregation are applied.

    ``v_proj`` and ``out_proj`` are never used in ``forward``. They are kept so the
    parameter set, and therefore saved state_dicts and the parameter
    initialisation order, stays unchanged.
    """

    def __init__(self, input_dim, mlp_hidden_dim, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"

        self.mlp = nn.Sequential(
            nn.Linear(input_dim, mlp_hidden_dim),
            nn.ReLU(),
            nn.Linear(mlp_hidden_dim, embed_dim)
        )

        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.head_dim = embed_dim // num_heads

        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)    # unused; kept for checkpoint compatibility

        self.out_proj = nn.Linear(embed_dim, embed_dim)  # unused; kept for checkpoint compatibility

    def forward(self, x, src_key_padding_mask=None):
        """
        Args:
            x: (B, S, input_dim) tensor.
            src_key_padding_mask: optional (B, S) bool tensor, True = padded position.
                Scores in padded *key* columns are set to 0.0. Padded query rows are
                left as computed.

        Returns:
            (B, num_heads, S, S) tensor of raw attention scores.
        """
        B, S, _ = x.shape
        x = self.mlp(x)

        def reshape_heads(tensor):
            # (B, S, E) -> (B, H, S, head_dim)
            return tensor.view(B, S, self.num_heads, self.head_dim).transpose(1, 2)

        Q = reshape_heads(self.q_proj(x))
        K = reshape_heads(self.k_proj(x))
        raw_attention = torch.matmul(Q, K.transpose(-2, -1)) / self.head_dim**0.5
        if src_key_padding_mask is not None:
            # (B, 1, 1, S): broadcast over heads and queries.
            mask = src_key_padding_mask[:, None, None, :]
            raw_attention = raw_attention.masked_fill(mask, 0.0)

        return raw_attention

In [None]:
class MultiTaskLossWrapper(torch.nn.Module):
    """Combine two losses with learned uncertainty weights.

    With learnable scalars s_i = log(sigma_i^2), both initialised to 0, the result is
        0.5 * exp(-s_1) * loss1 + 0.5 * s_1 + 0.5 * exp(-s_2) * loss2 + 0.5 * s_2.
    """

    def __init__(self):
        super().__init__()
        self.log_sigma1_sq = torch.nn.Parameter(torch.tensor(0.0))
        self.log_sigma2_sq = torch.nn.Parameter(torch.tensor(0.0))

    def forward(self, loss1, loss2):
        s1, s2 = self.log_sigma1_sq, self.log_sigma2_sq
        terms = (
            0.5 * torch.exp(-s1) * loss1,
            0.5 * s1,
            0.5 * torch.exp(-s2) * loss2,
            0.5 * s2,
        )
        return sum(terms)

class MLP(nn.Module):
    """Fully connected network: input -> hidden x (1 + num_hidden_layers) -> output.

    Every Linear except the final one is followed by a fresh ``activation()``
    instance. The layers live in ``self.model`` (an nn.Sequential).
    """

    def __init__(self, input_dim, hidden_dim, num_hidden_layers, output_dim, activation=nn.SiLU):
        super().__init__()
        # One input->hidden layer plus num_hidden_layers hidden->hidden layers
        # (a negative count behaves like zero, as with range()).
        widths = [input_dim] + [hidden_dim] * (1 + max(num_hidden_layers, 0))
        blocks = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            blocks += [nn.Linear(fan_in, fan_out), activation()]
        blocks.append(nn.Linear(hidden_dim, output_dim))
        self.model = nn.Sequential(*blocks)

    def forward(self, x):
        out = self.model(x)
        return out

In [None]:
class Merge(nn.Module):
    """Full model: per-view house encoder -> view encoder -> box head and camera head.

    Inputs (see ``forward``) are house features per view. The outputs are
    per-view box predictions and a pairwise (view x view) camera matrix.
    Layer sizes are hard-coded. With the Dataset defaults (f_dim=118, h_len=16),
    the house features have 118 + 4 = 122 channels, giving 8 + 120 = 128 channels
    after the rotation augmentation in ``forward``, and each view has 16 house slots.
    """

    def __init__(self):
        super().__init__()
        
        # Per-view box head: 128-d view vector -> 16 boxes x 5 values
        # (score logit + 4 box values), reshaped in forward.
        self.boxes = nn.Sequential(
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 5*16))
        
        # Encodes the houses of each valid view (sequence over houses).
        self.house_encoder = Encoder(
            input_dim=128,         # input vector size
            mlp_hidden_dim=256,    # hidden width of the input MLP
            transformer_dim=128,   # transformer d_model
            n_heads=4,             # number of attention heads
            ff_dim=128,            # transformer feed-forward width
            num_layers=2,          # number of transformer layers
        )

        # Encodes the pooled view vectors (sequence over views).
        self.view_encoder = Encoder(
            input_dim=128,         # input vector size
            mlp_hidden_dim=256,    # hidden width of the input MLP
            transformer_dim=128,    # transformer d_model
            n_heads=4,             # number of attention heads
            ff_dim=128,            # transformer feed-forward width
            num_layers=2,          # number of transformer layers
        )

        # Pairwise view scores: returns (bs, 32, nv, nv), one map per head.
        self.camera = Matrix(
            input_dim=128,
            mlp_hidden_dim=256,
            embed_dim=128,
            num_heads=32
        )

        # Maps the 32 per-head scores of each view pair to 2 output values;
        # the 32 must equal Matrix num_heads above.
        self.coords = nn.Sequential(
            nn.Linear(32, 8),
            nn.ReLU(),
            nn.Linear(8, 2))

    def forward(self, house_fites, views_masks, house_masks):
        """
        Args:
            house_fites: (bs, nv, nh, F) house features; channels 0 and 1 are
                rotated below (presumably x, y coordinates; TODO confirm).
            views_masks: (bs, nv) bool, True for valid views.
            house_masks: (bs, nv, nh) bool, True for real houses.

        Returns:
            boxes: (bs, nv, nh, 5). The view() requires nh == 16.
            camera_matrix: (bs, nv, nv, 2).
        """
        bs, nv, nh, _ = house_fites.shape
        x, y = house_fites[..., 0], house_fites[..., 1]
        # Append all four 90-degree rotations of (x, y) as extra channels.
        rot0 = torch.stack([x, y], dim=-1)       # 0°
        rot90 = torch.stack([-y, x], dim=-1)     # 90°
        rot180 = torch.stack([-x, -y], dim=-1)   # 180°
        rot270 = torch.stack([y, -x], dim=-1)    # 270°
        all_rots = torch.cat([rot0, rot90, rot180, rot270, house_fites[..., 2:]], dim=-1)
        house_seq = all_rots[views_masks]
        # In-place: rows of valid views are replaced by their encoded houses, while
        # rows of padded views keep the raw (rotated) padding features. This relies
        # on the encoder output width (128) equalling the all_rots width.
        all_rots[views_masks] = self.house_encoder(house_seq, src_key_padding_mask=~house_masks[views_masks])
        # Max-pool over houses -> one vector per view.
        # NOTE(review): padded house slots are not masked out of this max, so the
        # encoder outputs at padded positions can dominate the view vector.
        view_vectors = all_rots.max(dim=2).values
        final_view_vectors = self.view_encoder(view_vectors, src_key_padding_mask=~views_masks)
        camera_matrix = self.camera(final_view_vectors, src_key_padding_mask=~views_masks)
        # (bs, heads, nv, nv) -> (bs, nv, nv, heads) -> (bs, nv, nv, 2)
        camera_matrix = self.coords(camera_matrix.permute(0, 2, 3, 1))
        boxes = self.boxes(final_view_vectors)
        boxes = boxes.view(bs, nv, nh, 5)
        
        return boxes, camera_matrix

In [None]:
w_path = ...  # Path to the saved model weights; set before running.

# `device` must be defined before use here: under Restart & Run All this cell runs
# before the training-setup cell.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = Merge()
model.to(device)
# map_location lets a checkpoint saved on another device load onto `device`.
# NOTE(review): torch.load unpickles arbitrary objects, so only load trusted files
# (newer PyTorch versions accept weights_only=True for plain state_dicts).
model.load_state_dict(torch.load(w_path, map_location=device))

In [None]:
num_epochs = 100
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)

# Create the loaders first: the scheduler length below depends on len(train_loader).
train_loader = torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=True)
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=32, shuffle=False)

# NOTE(review): if MultiTaskLossWrapper is re-enabled, its parameters must be
# added to this optimizer as well.
optimizer = torch.optim.AdamW(
    list(model.parameters()),
    lr=1e-4,
    weight_decay=0.01)

num_training_steps = len(train_loader) * num_epochs
num_warmup_steps = int(0.1 * num_training_steps)   # 10% linear warm-up
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps
)

In [None]:
best_loss = torch.inf
for epoch in tqdm(range(num_epochs), desc="Epoches: "):
    model.train()
    total_train_bboxes = 0.0
    train_bar = tqdm(train_loader, leave=False, desc="Train: ")

    for batch in train_bar:
        house_fites, views_masks, house_masks, camera_target, bbox_target, house_count = [b.to(device) for b in batch]
        res1, res2 = model(house_fites, views_masks, house_masks)

        # Only the box loss is trained for now. To re-enable the camera loss:
        #   camera_loss = batched_camera_loss(camera_target, res2, views_masks)
        #   loss = multitask_loss_fn(camera_loss, bboxes_loss)  # e.g. MultiTaskLossWrapper()
        bboxes_loss = batched_hungarian_loss(res1, bbox_target, views_masks, house_count, hungarian_loss_fn=hungarian_loss_batch)

        optimizer.zero_grad()
        # batched_hungarian_loss returns a constant (grad-less) zero when a batch
        # has no valid view; backward() would raise on it.
        if bboxes_loss.requires_grad:
            bboxes_loss.backward()
            optimizer.step()
        scheduler.step()

        batch_loss = bboxes_loss.item()
        total_train_bboxes += batch_loss
        train_bar.set_postfix({"BBox": f"{batch_loss:.4f}"})
    total_train_bboxes /= max(len(train_loader), 1)

    # Validation
    model.eval()
    val_bboxes = 0.0
    with torch.no_grad():
        for batch in tqdm(valid_loader, leave=False, desc="Valid: "):
            house_fites, views_masks, house_masks, camera_target, bbox_target, house_count = [b.to(device) for b in batch]
            res1, res2 = model(house_fites, views_masks, house_masks)
            bboxes_loss = batched_hungarian_loss(res1, bbox_target, views_masks, house_count, hungarian_loss_fn=hungarian_loss_batch)
            val_bboxes += bboxes_loss.item()
    val_bboxes /= max(len(valid_loader), 1)

    # Checkpoint only when the validation loss improves.
    if val_bboxes < best_loss:
        best_loss = val_bboxes
        torch.save(model.state_dict(), 'w_mlp.pth')

    print(f"Epoch {epoch+1}/{num_epochs}, Train loss: box={total_train_bboxes:.4f}, Val   loss: box={val_bboxes:.4f}")