# Imports

In [None]:
import torch
import torch.nn as nn
import numpy as np
from scipy.sparse import diags
from torch.nn import functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import json
import bisect
from collections import OrderedDict
import os
import time
import tqdm
import gc

# Model

In [None]:
from model import (
    CReLU,
    Identity,
    ComplexInstanceNorm2d,
    ComplexConv2d,
    ComplexConv1d,
    SpectralConv2d,
    SpectralConv2d_Complex,
    MLP,
    MLP1d,
    FourierLayer,
    FNO2d
)

In [None]:
def move_to_cuda(sample, device=None):
    if len(sample) == 0:
        return {}

    def _move_to_cuda(maybe_tensor,device):
        if torch.is_tensor(maybe_tensor):
            return maybe_tensor.to(device)
        elif isinstance(maybe_tensor, dict):
            return {key: _move_to_cuda(value,device) for key, value in maybe_tensor.items()}
        elif isinstance(maybe_tensor, list):
            return [_move_to_cuda(x,device) for x in maybe_tensor]
        elif isinstance(maybe_tensor, tuple):
            return [_move_to_cuda(x,device) for x in maybe_tensor]
        else:
            return maybe_tensor

    return _move_to_cuda(sample,device)

# def

def collate_fn(batch_data):
    inputs = []
    eig_vecs = []

    for _,data in enumerate(batch_data):
        inputs.append(data['input'])
        eig_vecs.append(data['eig_vec'])

    return {'inputs':torch.cat(inputs,dim=0),
            'eig_vecs':torch.cat(eig_vecs,dim=0)
    }

In [None]:
class mat_dataset(Dataset):
    def __init__(self, index_path, k, channel_dim=3, cache_size=4):
        super().__init__()
        self.k = k
        self.channel_dim = channel_dim

        # читаем метаданные из json (быстро)
        with open(index_path, "r") as f:
            index = json.load(f)
        self.path_list = index["paths"]
        self.file_sizes = index["sizes"]

        # кумулятивные суммы как обычный python-список
        self.cum_sizes = []
        s = 0
        for sz in self.file_sizes:
            s += sz
            self.cum_sizes.append(s)
        self.total_len = self.cum_sizes[-1]

        # LRU-кэш
        self.cache = OrderedDict()
        self.cache_size = cache_size

    def _load_file(self, path):
        if path in self.cache:
            self.cache.move_to_end(path)
            return self.cache[path]

        data = torch.load(path, map_location="cpu")

        if len(self.cache) >= self.cache_size:
            self.cache.popitem(last=False)

        self.cache[path] = data
        return data

    def __len__(self):
        return self.total_len

    def __getitem__(self, index):
        # Находим file_idx через bisect вместо .nonzero()
        # self.cum_sizes — монотонно возрастающий список
        file_idx = bisect.bisect_right(self.cum_sizes, index)
        # bisect_right вернёт позицию, поэтому индекс файла = file_idx
        # но index — 0-based, а cum_sizes — как [n0, n0+n1, ...]
        # Например, index=0..n0-1 → file_idx=0

        # аккуратно вычисляем локальный индекс
        if file_idx == 0:
            local_index = index
        else:
            local_index = index - self.cum_sizes[file_idx - 1]

        data = self._load_file(self.path_list[file_idx])

        x = data['params'][local_index]
        eig_vecs = data['eig_vecs'][local_index]

        if self.channel_dim is not None and self.channel_dim >= 0:
            # на всякий случай проверка размерности
            if self.channel_dim <= x.dim():
                x = x.unsqueeze(self.channel_dim)
            else:
                # если что-то не так, лучше явно упасть, чем тихо сломать форму
                raise ValueError(
                    f"channel_dim={self.channel_dim} is out of range for x.dim()={x.dim()}"
                )

        if self.k > 0:
            eig_vecs = eig_vecs[..., :self.k]

        # Я БЫ советовал здесь не делать unsqueeze(0), а оставить
        # это DataLoader’у, но если у тебя остальной код завязан на
        # эту форму – можно временно оставить как есть:
        return {
            'input': x.unsqueeze(0),
            'eig_vec': eig_vecs.unsqueeze(0)
        }


## Losses

In [None]:
class ProjectionLoss(nn.Module):
    def __init__(self,num_type='real',reduction='mean',p=2, dim=1):
        super(ProjectionLoss,self).__init__()
        self.num_type = num_type
        self.reduction = reduction
        self.p=p
        self.dim=dim

        if self.num_type == 'complex':
            self.trans = torch.adjoint
        else:
            self.trans = torch.transpose

    def forward(self,Q,V):
        """
        Q: [batch_size,dim,k]
        V: [batch_size,dim,K] # K > k
        caculate sum(V-Q@Q*V)
        """

        assert Q.shape[-2] == V.shape[-2], f"Shape error! Q shape [{Q.shape[-2]}] must match V shape [{V.shape[-2]}] in dimension -2"

        Qt = self.trans(Q,-2,-1)
        QtV = torch.bmm(Qt,V)
        QQtV = torch.bmm(Q,QtV)

        result = V - QQtV
        norm = torch.norm(result,p=self.p,dim=self.dim)

        loss = torch.sum(norm,dim=-1)

        if self.reduction == 'mean':
            result = torch.mean(loss)
        elif self.reduction == 'sum':
            result = torch.sum(loss)

        return result


class PrincipalAngle(nn.Module):
    """
    Caculate the angle between two subspace.
    type[Option]: biggest or smallest angle between two subspace
    """
    def __init__(self, angle_type='biggest', reduction = 'mean', clip_value=True):
        super(PrincipalAngle,self).__init__()
        self.angle_type = angle_type
        self.reduction = reduction
        self.clip_value = clip_value
        self.compare = torch.max if self.angle_type == 'biggest' else torch.min

    def forward(self,Q,V):
        """
        Q,V: base vectors which make up the subspace
        """

        _, values, _ = torch.linalg.svd(torch.bmm(torch.transpose(Q,-2,-1),V))

        if self.clip_value:
            values = torch.clamp(values,min=-1,max=1)

        angles = torch.acos(values)

        angle,_ = self.compare(angles,dim=-1)

        if self.reduction == 'mean':
            result = torch.mean(angle)
        elif self.reduction == 'sum':
            result = torch.sum(angle)
        else:
            result = angle

        return result

## For colab

In [None]:
from google.colab import drive
import time, os, shutil
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


## Path for .pt file data

In [None]:
folder_path = '/content/drive/MyDrive/heat_backup'
files = [os.path.join(folder_path, f) for f in os.listdir(folder_path) if f.startswith('heat_batch_') and f.endswith('.pt')]

In [7]:
divide_step = int(len(files) / 5 * 4)

In [None]:
def build_index(path_list, index_path="sizes_index.json"):
    file_sizes = []
    for path in path_list:
        data = torch.load(path, map_location="cpu")
        # предполагаем, что data['params'] есть в каждом файле
        file_sizes.append(data['params'].shape[0])
        del data
    index = {
        "paths": path_list,
        "sizes": file_sizes
    }
    with open(index_path, "w") as f:
        json.dump(index, f)
    print(f"Saved index to {index_path}")


In [None]:
# Files for train, valid and test (sample)

train_index = "train_sizes_index_10.json"
valid_index = "valid_sizes_index_10.json"
test_index = "test_sizes_index_10.json"

# Setup

In [None]:
SAVE_PATH = './results/'
if not os.path.exists(SAVE_PATH):
    os.mkdir(SAVE_PATH)

r = 300
width = 10
modes1 = 8
modes2 = 8
in_channel_size = 2
out_size = (r)**2
num_layers = 5
norm = True
COMPLEX = False
grid = True
num_type = 'complex' if COMPLEX else 'real'

k = 25
batch_size = 32
num_workers = 0

print("*"*20+" DATA LOADING "+"*"*20)

train_loader = DataLoader(
    mat_dataset(index_path=train_index, k=k, cache_size=8),
    batch_size=batch_size,
    num_workers=num_workers,
    drop_last=True,
    shuffle=True,
    pin_memory=False,
    collate_fn=collate_fn,
)

valid_loader = DataLoader(
    mat_dataset(index_path=valid_index, k=k, cache_size=8),
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=False,
    pin_memory=False,
    collate_fn=collate_fn,
)


******************** DATA LOADING ********************


In [None]:
torch.cuda.empty_cache()
gc.collect()

device = 'cuda:0'
lr = 5e-3
weight_decay = 1e-5
num_epoch = 100

model = FNO2d(modes1=modes1,
              modes2=modes2,
              width=width,
              in_channel_size=in_channel_size,
              out_size=out_size,
              resolution=r,
              num_layers=num_layers,
              num_space_size=k,
              norm=norm,
              grid=grid,
              COMPLEX=COMPLEX).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
iterations = num_epoch * len(train_loader)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=iterations)

#! Hybrid loss
class HybridLoss(nn.Module):
    def __init__(self, alpha=10.0):
        super().__init__()
        self.proj_loss = ProjectionLoss(num_type='real', reduction='mean')
        self.angle_loss = PrincipalAngle(angle_type='biggest', reduction='mean')
        self.alpha = alpha

    def forward(self, Q, V):
        proj = self.proj_loss(Q, V)
        angle = self.angle_loss(Q, V)
        return proj + self.alpha * angle

    def get_components(self, Q, V):
        with torch.no_grad():
            proj = self.proj_loss(Q, V).item()
            angle = self.angle_loss(Q, V).item()
        return proj, angle

train_loss_fn = HybridLoss(alpha=10.0)
test_loss_fn = PrincipalAngle()


print("*"*20 + " TRAINING" + "*"*20)
print(f"Config: width={20}, modes={12}x{12}, layers={5}")
print(f"Loss: ProjectionLoss + {10.0} * AngleLoss")
print(f"Data: {len(train_loader)} train batches, {len(valid_loader)} val batches\n")

best_val_angle = float('inf')
patience = 20
no_improve_count = 0


accumulation_steps = 2

for epoch in range(num_epoch):
    model.train()

    train_loss_total = 0
    train_angle_total = 0
    train_proj_total = 0

    optimizer.zero_grad()

    for batch_idx, batch_data in enumerate(train_loader):
        batch_data = move_to_cuda(batch_data, device=device)

        inputs = batch_data['inputs']
        if inputs.ndim == 5:
            inputs = inputs.squeeze(-1)
        if inputs.shape[1] == 302 and inputs.shape[2] == 302:
            inputs = inputs[:, 1:-1, 1:-1, :]

        label = batch_data['eig_vecs']
        output = model(**{'inputs': inputs})

        loss = train_loss_fn(output, label)

        loss = loss / accumulation_steps

        loss.backward()

        if (batch_idx + 1) % accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        with torch.no_grad():
            angle = test_loss_fn(output, label)
            proj, angle_comp = train_loss_fn.get_components(output, label)

            train_loss_total += loss.item() * accumulation_steps  # умножаем обратно
            train_angle_total += angle.item()
            train_proj_total += proj

        if (batch_idx + 1) % 10 == 0:
            torch.cuda.empty_cache()

    if len(train_loader) % accumulation_steps != 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()

    avg_train_loss = train_loss_total / len(train_loader)
    avg_train_angle = train_angle_total / len(train_loader)
    avg_train_proj = train_proj_total / len(train_loader)

    # ============================================================
    # VALIDATION
    # ============================================================
    model.eval()
    val_angle_total = 0
    val_proj_total = 0

    with torch.no_grad():
        for batch_data in valid_loader:
            batch_data = move_to_cuda(batch_data, device=device)

            inputs = batch_data['inputs']
            if inputs.ndim == 5:
                inputs = inputs.squeeze(-1)
            if inputs.shape[1] == 302:
                inputs = inputs[:, 1:-1, 1:-1, :]

            label = batch_data['eig_vecs']
            output = model(**{'inputs': inputs})

            angle = test_loss_fn(output, label)
            proj, _ = train_loss_fn.get_components(output, label)

            val_angle_total += angle.item()
            val_proj_total += proj

    avg_val_angle = val_angle_total / len(valid_loader)
    avg_val_proj = val_proj_total / len(valid_loader)

    print(f"Epoch {epoch:3d} | "
          f"Train: angle={avg_train_angle:.4f} proj={avg_train_proj:.2f} | "
          f"Val: angle={avg_val_angle:.4f} proj={avg_val_proj:.2f} | "
          f"LR={scheduler.get_last_lr()[0]:.6f}")

    if avg_val_angle < best_val_angle:
        best_val_angle = avg_val_angle
        no_improve_count = 0
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_angle': avg_val_angle,
        }, SAVE_PATH + 'best_model.pt')
        print(f"    ✓ Saved best model (val_angle={best_val_angle:.4f})")
    else:
        no_improve_count += 1

    if no_improve_count >= patience:
        print(f"\n⚠️  Early stopping: no improvement for {patience} epochs")
        break

    torch.cuda.empty_cache()
    gc.collect()

print("\n" + "="*70)
print(f"TRAINING COMPLETED")
print(f"Best validation angle: {best_val_angle:.4f}")
print("="*70)

******************** TRAINING С ГИБРИДНОЙ LOSS ********************
Конфигурация: width=20, modes=12x12, layers=5
Loss: ProjectionLoss + 10.0 * AngleLoss
Data: 40 train batches, 7 val batches

Epoch   0 | Train: angle=1.1797 proj=9.52 | Val: angle=1.0508 proj=6.25 | LR=0.005000
    ✓ Saved best model (val_angle=1.0508)
Epoch   1 | Train: angle=1.0012 proj=5.29 | Val: angle=0.9981 proj=4.76 | LR=0.004999
    ✓ Saved best model (val_angle=0.9981)
Epoch   2 | Train: angle=0.9661 proj=4.38 | Val: angle=0.9534 proj=4.19 | LR=0.004997
    ✓ Saved best model (val_angle=0.9534)
Epoch   3 | Train: angle=0.8545 proj=3.97 | Val: angle=0.7762 proj=3.69 | LR=0.004995
    ✓ Saved best model (val_angle=0.7762)
Epoch   4 | Train: angle=0.6914 proj=3.39 | Val: angle=0.6928 proj=3.14 | LR=0.004992
    ✓ Saved best model (val_angle=0.6928)
Epoch   5 | Train: angle=0.6290 proj=3.00 | Val: angle=0.6532 proj=2.95 | LR=0.004989
    ✓ Saved best model (val_angle=0.6532)
Epoch   6 | Train: angle=0.5790 proj=2.