In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from model import AlphaNetwork

class AlphaNetwork(nn.Module):
    """Single linear layer followed by a ReLU, mapping a (W, C) pair to one scalar.

    Each sample of ``W`` and ``C`` is flattened, the two flat vectors are
    concatenated, and the result goes through ``Linear(input_dim, 1)`` and a
    ReLU, so the output is always >= 0.

    The module is created on PyTorch's default device. Device placement is left
    to the caller (``model.to(device)`` / ``model.cuda()``) so the class can
    also be instantiated on machines without a GPU.

    Note: this definition shadows any ``AlphaNetwork`` imported earlier in the
    same namespace.

    Note: whenever the pre-activation is <= 0 the ReLU outputs 0 and passes a
    zero gradient to ``fc1``. If that happens for every sample of a batch, all
    parameter gradients for that batch are exactly zero.

    Args:
        input_dim: Number of features per sample after flattening and
            concatenating ``W`` and ``C``.
    """

    def __init__(self, input_dim):
        super().__init__()
        # No .cuda() here: the caller decides where the parameters live.
        self.fc1 = nn.Linear(input_dim, 1)
        self.activation = nn.ReLU()

    def forward(self, W, C):
        """Compute alpha for a batch.

        Args:
            W: Tensor whose first dimension is the batch size.
            C: Tensor with the same first dimension as ``W``.
                The flattened per-sample sizes of ``W`` and ``C`` must sum to
                the ``input_dim`` given at construction.

        Returns:
            Tensor of shape ``(batch, 1)`` with non-negative entries.
        """
        # reshape (unlike view) also accepts non-contiguous inputs such as
        # permuted or sliced tensors, copying only when it has to.
        w_flat = W.reshape(W.size(0), -1)
        c_flat = C.reshape(C.size(0), -1)
        x = torch.cat((w_flat, c_flat), dim=1)
        alpha = self.activation(self.fc1(x))
        return alpha

# Problem dimensions: each W sample has shape (N, N, K), each C sample (K, K, N).
K = 2
N = 3

# Feature count after flattening and concatenating one W sample and one C sample.
input_dim = N * N * K + K * K * N

# Pick the device once and reuse it for the model and every tensor, so they
# are guaranteed to live on the same device. Falls back to CPU when CUDA is
# unavailable (the model class itself must not force a GPU for that path).
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Fix the RNG so parameter initialisation and the dummy data are reproducible.
SEED = 0
torch.manual_seed(SEED)

# Simplified test
model = AlphaNetwork(input_dim)
model.to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Dummy training data generation
num_samples = 10000
batch_size = 32
# Integer division: a trailing partial batch (num_samples % batch_size samples)
# is not counted.
num_batches = num_samples // batch_size

# Inputs and targets are independent standard-normal noise, so there is no
# relationship between (W, C) and alpha_gt for the model to learn.
W_data = torch.randn((num_samples, N, N, K), device=device, requires_grad=False)
C_data = torch.randn((num_samples, K, K, N), device=device, requires_grad=False)
alpha_gt = torch.randn((num_samples, 1), device=device, requires_grad=False)

# Training loop: mini-batch optimisation of the MSE between predicted and
# target alpha, reporting the mean batch loss after every epoch.
num_epochs = 15
# The loss module holds no state, so build it once rather than on every step.
criterion = nn.MSELoss()
for epoch in range(num_epochs):
    total_loss = 0.0
    for i in range(num_batches):
        start_idx = i * batch_size
        end_idx = start_idx + batch_size

        # Select the mini-batch. Batches are taken in the same fixed order
        # every epoch; samples beyond num_batches * batch_size are never used.
        W_batch = W_data[start_idx:end_idx]
        C_batch = C_data[start_idx:end_idx]
        alpha_true_batch = alpha_gt[start_idx:end_idx]

        # Predict alpha
        alpha_pred_batch = model(W_batch, C_batch)

        # Compute the loss
        loss = criterion(alpha_pred_batch, alpha_true_batch)
        total_loss += loss.item()

        # Backpropagation: clear the gradients left by the previous step,
        # compute fresh ones, then update the parameters.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Report the mean loss over the epoch
    epoch_loss = total_loss / num_batches
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss}")

# Gradient check: print each parameter's .grad as it stands right now.
# Gradients are cleared before every backward pass in the training loop, so
# the values shown come from the most recent backward call only (the last
# mini-batch), not from the whole run.
# NOTE(review): with the ReLU on the network output, an all-zero gradient most
# likely means every sample in that last batch had a non-positive
# pre-activation - confirm by inspecting the pre-activation values.
for name, param in model.named_parameters():
    if param.grad is None:
        print(f"No gradients for {name}")
        continue
    print(f"Gradients for {name}: {param.grad}")


Epoch [1/15], Loss: 1.042522204036896
Epoch [2/15], Loss: 1.0023178787758718
Epoch [3/15], Loss: 1.0003985373828657
Epoch [4/15], Loss: 1.0000962380988476
Epoch [5/15], Loss: 0.999667405604552
Epoch [6/15], Loss: 0.9994626534290802
Epoch [7/15], Loss: 0.9991869853857236
Epoch [8/15], Loss: 0.9990181802557065
Epoch [9/15], Loss: 0.998842232597944
Epoch [10/15], Loss: 0.9985905465407249
Epoch [11/15], Loss: 0.9983064592457734
Epoch [12/15], Loss: 0.9980253097720635
Epoch [13/15], Loss: 0.9977899610232084
Epoch [14/15], Loss: 0.997697713952034
Epoch [15/15], Loss: 0.9977220912010242
Gradients for fc1.weight: tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0.]], device='cuda:0')
Gradients for fc1.bias: tensor([0.], device='cuda:0')
