In [91]:
from AIA.rl.AlphaTensor.dataset.dataset import generate_synthetic_dataset

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [92]:
# Generate the synthetic dataset in memory. 10000 is the first positional
# argument (presumably the number of samples — confirm against
# generate_synthetic_dataset); save_to_file=False is passed so no file output
# is requested.
dataset = generate_synthetic_dataset(10000, save_to_file=False)

In [93]:
# Show the freshly generated dataset as this cell's output.
dataset

In [94]:
# Remove the "actions" column; only "tensor" and "value" are used further down.
# A non-mutating drop with errors="ignore" keeps this cell idempotent: the
# previous in-place drop raised KeyError when the cell was run a second time.
dataset = dataset.drop(columns=["actions"], errors="ignore")

In [95]:
# Show the dataset again to check the result of dropping the "actions" column.
dataset

In [96]:
import torch
from torch.utils.data import Dataset

class TensorDataset3D(Dataset):
    """In-memory dataset of ``(tensor, label)`` pairs, converted to torch tensors up front."""

    def __init__(self, data):
        """Convert every ``(array, label)`` pair of ``data`` into torch tensors.

        data: iterable of ``(array_like, label)`` pairs. The arrays are
            described by the original author as numpy arrays of shape
            [4, 4, 4]; nothing here enforces that shape.
        """
        self.samples = [
            (
                torch.tensor(features, dtype=torch.float32),
                torch.tensor(target, dtype=torch.long),
            )
            for features, target in data
        ]

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the ``(float32 tensor, int64 label)`` pair stored at ``idx``."""
        return self.samples[idx]


In [97]:
import ast
import numpy as np

def parse_tensor(tensor_str):
    """Parse a Python-literal string (e.g. nested lists) into a float32 numpy array."""
    literal = ast.literal_eval(tensor_str)
    return np.array(literal, dtype=np.float32)

In [98]:
# Pair every entry of the "tensor" column (cast to float32) with the matching
# entry of the "value" column (cast to a plain Python int).
raw_data = [
    (array.astype(np.float32), int(target))
    for array, target in zip(dataset["tensor"], dataset["value"])
]

In [99]:
# Inspect the tensor of the first (tensor, label) pair as a sanity check.
raw_data[0][0]

In [100]:
from AIA.rl.AlphaTensor.net.network import Torso, ValueHead
from sklearn.model_selection import train_test_split

In [101]:
# Hold out 20% of the samples for testing; random_state is fixed so the split
# is the same on every run.
train_data, test_data = train_test_split(raw_data, test_size=0.2, random_state=42)

In [102]:
from torch.utils.data import DataLoader
# Training split: wrapped as a dataset and served in shuffled mini-batches of 16.
train_dataset = TensorDataset3D(train_data)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=16)

# Held-out split: same batch size, iterated in a fixed order.
test_dataset = TensorDataset3D(test_data)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=16)

In [103]:
import torch.nn as nn
import torch.optim as optim

In [107]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The torso is constructed before the value head, as before, so the order in
# which the two modules are created is unchanged.
model_torso = Torso(scalar_size=0, hidden_dim=1024)
model_torso.to(device)

model_value = ValueHead(hidden_dim=1024)
model_value.to(device)

# A single Adam optimizer updates the parameters of both modules.
params = [*model_torso.parameters(), *model_value.parameters()]
optimizer = optim.Adam(params, lr=1e-3)

# Mean squared error between the value head's output and the target value.
criterion = nn.MSELoss()


In [108]:
# Train the torso + value head for a fixed number of epochs, reporting the
# sample-weighted mean loss on the training and test splits after each epoch.
epochs = 100

for epoch in range(epochs):
    print("==========================")
    print("Epoch", epoch)

    # ---- training pass ----
    # Switch to train mode once per epoch: the evaluation pass below leaves
    # both modules in eval mode, so it has to be restored here.
    model_torso.train()
    model_value.train()

    total_loss = 0.0

    for x_batch, y_batch in train_loader:
        x_batch = x_batch.to(device)
        # Labels are stored as int64; MSELoss needs floating-point targets.
        y_batch = y_batch.float().to(device)

        # Flatten every sample to a vector of 4 * 4 * 4 = 64 features.
        input_flat = x_batch.reshape(-1, 4 * 4 * 4)

        # Forward pass.
        emb = model_torso(input_flat)
        # NOTE(review): the original comment says the head returns shape [B]
        # after a squeeze — confirm in ValueHead; a [B, 1] output would
        # broadcast against the [B] targets inside MSELoss.
        out = model_value(emb)

        loss = criterion(out, y_batch)

        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight the batch-mean loss by the batch size so that the epoch
        # average stays correct when the last batch is smaller.
        total_loss += loss.item() * x_batch.size(0)

    avg_loss = total_loss / len(train_loader.dataset)
    print(f"Train | loss = {avg_loss:.4f}")

    # ---- evaluation pass ----
    model_torso.eval()
    model_value.eval()

    total_loss = 0.0
    # No gradients are needed for evaluation; disabling autograd avoids
    # building a graph for every test batch.
    with torch.no_grad():
        for x_batch, y_batch in test_loader:
            x_batch = x_batch.to(device)
            y_batch = y_batch.float().to(device)

            input_flat = x_batch.reshape(-1, 4 * 4 * 4)

            emb = model_torso(input_flat)
            out = model_value(emb)

            loss = criterion(out, y_batch)

            total_loss += loss.item() * x_batch.size(0)

    avg_loss = total_loss / len(test_loader.dataset)
    print(f"Test | loss = {avg_loss:.4f}")

Epoch 0
Train | loss = 2.2533
Test | loss = 1.5852
Epoch 1
Train | loss = 1.6280
Test | loss = 1.6285
Epoch 2
Train | loss = 1.4101
Test | loss = 1.5255
Epoch 3
Train | loss = 1.2902
Test | loss = 1.2984
Epoch 4
Train | loss = 1.1283
Test | loss = 1.2562
Epoch 5
Train | loss = 0.9878
Test | loss = 1.2672
Epoch 6
Train | loss = 0.8650
Test | loss = 1.1346
Epoch 7
Train | loss = 0.7683
Test | loss = 0.9887
Epoch 8
Train | loss = 0.7003
Test | loss = 0.9454
Epoch 9
Train | loss = 0.6271
Test | loss = 0.9419
Epoch 10
Train | loss = 0.5741
Test | loss = 0.9629
Epoch 11
Train | loss = 0.5474
Test | loss = 1.0002
Epoch 12
Train | loss = 0.5115
Test | loss = 0.9128
Epoch 13
Train | loss = 0.4892
Test | loss = 0.8455
Epoch 14
Train | loss = 0.4580
Test | loss = 0.9188
Epoch 15
Train | loss = 0.4372
Test | loss = 0.8939
Epoch 16
Train | loss = 0.4154
Test | loss = 0.8537
Epoch 17
Train | loss = 0.4053
Test | loss = 0.9452
Epoch 18
Train | loss = 0.3849
Test | loss = 0.9172
Epoch 19
Train | loss 