In [55]:
from AIA.rl.AlphaTensor.dataset.dataset import generate_synthetic_dataset

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [56]:
# Number of synthetic training examples to generate.
N_SYNTHETIC_SAMPLES = 10_000
# NOTE(review): generation is presumably random; seed whatever RNG
# generate_synthetic_dataset uses to make this cell reproducible — TODO confirm.
dataset = generate_synthetic_dataset(N_SYNTHETIC_SAMPLES, save_to_file=False)

In [57]:
# Preview the frame returned by generate_synthetic_dataset.
dataset

In [58]:
# Drop the "actions" column; only the "tensor" and "value" columns are used
# downstream. Reassigning (instead of inplace=True) with errors="ignore" keeps
# the cell idempotent: re-running it no longer raises KeyError.
dataset = dataset.drop(columns=["actions"], errors="ignore")

In [59]:
# Re-display the dataset after the "actions" column has been dropped.
dataset

In [60]:
import torch
from torch.utils.data import Dataset

class TensorDataset3D(Dataset):
    """In-memory torch Dataset of (input tensor, integer label) pairs.

    All samples are converted to tensors once, at construction time, so
    ``__getitem__`` is a plain list lookup.
    """

    def __init__(self, data):
        """
        data: iterable of (array_like, label) pairs. Inputs are expected to be
        [4, 4, 4] arrays (not enforced here); they are stored as float32
        tensors, and labels as int64 (torch.long) scalar tensors.
        """
        self.samples = [
            (
                torch.tensor(array, dtype=torch.float32),
                torch.tensor(target, dtype=torch.long),
            )
            for array, target in data
        ]

    def __len__(self):
        """Number of stored samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the (float32 input, long label) tuple stored at ``idx``."""
        return self.samples[idx]


In [61]:
import ast
import numpy as np

def parse_tensor(tensor_str):
    """Convert a serialized or in-memory tensor into a float32 NumPy array.

    Parameters
    ----------
    tensor_str : str or array-like
        Either a Python-literal string such as ``"[[1, 2], [3, 4]]"`` (e.g. a
        tensor read back from a saved file), or an already-materialized
        array-like (nested lists, ``np.ndarray``). Non-string inputs are
        converted directly; previously they made ``ast.literal_eval`` raise.

    Returns
    -------
    np.ndarray
        A new float32 array with the nesting structure of the input.

    Raises
    ------
    ValueError, SyntaxError
        If a string input is not a valid Python literal (from ``ast.literal_eval``).
    """
    if isinstance(tensor_str, str):
        # literal_eval only evaluates Python literals, so a string cannot run code.
        return np.array(ast.literal_eval(tensor_str), dtype=np.float32)
    # np.array (not asarray) always copies, so the result never aliases the input.
    return np.array(tensor_str, dtype=np.float32)

In [62]:
# One (float32 input array, int target) pair per dataset row.
raw_data = [
    (row.tensor.astype(np.float32), int(row.value))
    for row in dataset[["tensor", "value"]].itertuples(index=False)
]

In [63]:
# Inspect the input array of the first (tensor, label) pair.
raw_data[0][0]

In [64]:
from AIA.rl.AlphaTensor.net.network import Torso, ValueHead

# Wrap the (tensor, label) pairs in a torch Dataset so a DataLoader can batch them.
tensor_dataset = TensorDataset3D(raw_data)

In [65]:
from torch.utils.data import DataLoader
# Samples per optimisation step.
BATCH_SIZE = 16
# Dedicated, seeded generator: the shuffle order is reproducible across
# Restart & Run All and independent of other uses of torch's global RNG.
LOADER_SEED = 0
loader = DataLoader(
    tensor_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(LOADER_SEED),
)

In [66]:
import torch.nn as nn
import torch.optim as optim

In [67]:
# Seed torch's global RNG so parameter initialisation is reproducible.
MODEL_SEED = 0
LEARNING_RATE = 1e-3
torch.manual_seed(MODEL_SEED)

# NOTE(review): scalar_size=0 presumably means "no extra scalar inputs" —
# confirm against the Torso definition.
model_torso = Torso(scalar_size=0)
model_value = ValueHead()

# Move to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_torso.to(device)
model_value.to(device)

# Loss: regression -> MSELoss
criterion = nn.MSELoss()

# Optimizer: a single shared one over both modules, so they are trained jointly
params = list(model_torso.parameters()) + list(model_value.parameters())
optimizer = optim.Adam(params, lr=LEARNING_RATE)


In [68]:
epochs = 100
loss_history = []  # mean training loss per epoch, kept for later plotting

for epoch in range(epochs):
    model_torso.train()
    model_value.train()

    total_loss = 0.0

    for x_batch, y_batch in loader:
        x_batch = x_batch.to(device)          # [B, 4, 4, 4]
        y_batch = y_batch.float().to(device)  # [B]

        # Flatten all non-batch dims: [B, 4, 4, 4] -> [B, 64]. flatten keeps
        # the batch dim and does not hard-code the 4*4*4 input size.
        input_flat = x_batch.flatten(start_dim=1)

        # forward pass
        emb = model_torso(input_flat)  # [B, hidden_dim]
        # Force the prediction to the target's shape. If ValueHead returns
        # [B, 1], MSELoss against a [B] target silently broadcasts to [B, B];
        # reshape fixes [B, 1] and raises on any other size mismatch.
        out = model_value(emb).reshape(y_batch.shape)

        # loss
        loss = criterion(out, y_batch)

        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # loss is a per-sample mean; weight by batch size so the epoch
        # average is exact even if the last batch is smaller.
        total_loss += loss.item() * x_batch.size(0)

    avg_loss = total_loss / len(loader.dataset)
    loss_history.append(avg_loss)
    print(f"Epoch {epoch+1}: loss = {avg_loss:.4f}")

Epoch 1: loss = 2.1583
Epoch 2: loss = 1.6185
Epoch 3: loss = 1.4410
Epoch 4: loss = 1.2902
Epoch 5: loss = 1.0963
Epoch 6: loss = 0.9587
Epoch 7: loss = 0.8215
Epoch 8: loss = 0.7309
Epoch 9: loss = 0.6658
Epoch 10: loss = 0.6215
Epoch 11: loss = 0.5633
Epoch 12: loss = 0.5321
Epoch 13: loss = 0.5149
Epoch 14: loss = 0.4916
Epoch 15: loss = 0.4690
Epoch 16: loss = 0.4543
Epoch 17: loss = 0.4253
Epoch 18: loss = 0.4213
Epoch 19: loss = 0.4101
Epoch 20: loss = 0.3869
Epoch 21: loss = 0.3820
Epoch 22: loss = 0.3746
Epoch 23: loss = 0.3559
Epoch 24: loss = 0.3546
Epoch 25: loss = 0.3443
Epoch 26: loss = 0.3422
Epoch 27: loss = 0.3305
Epoch 28: loss = 0.3263
Epoch 29: loss = 0.3224
Epoch 30: loss = 0.3144
Epoch 31: loss = 0.3115
Epoch 32: loss = 0.3066
Epoch 33: loss = 0.3027
Epoch 34: loss = 0.2963
Epoch 35: loss = 0.2971
Epoch 36: loss = 0.2915
Epoch 37: loss = 0.2917
Epoch 38: loss = 0.2809
Epoch 39: loss = 0.2828
Epoch 40: loss = 0.2803
Epoch 41: loss = 0.2778
Epoch 42: loss = 0.2737
E