In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import torch.optim.lr_scheduler as lr_scheduler
from scipy.io import savemat
import random

def h_single(t, r, wavelength=0.01, device='cpu'):
    eta = 120 * torch.pi  # Free space impedance
    k0 = 2 * torch.pi / wavelength

    # polarization vector e (y-direction)
    e = torch.tensor([0.0, 1.0, 0.0], dtype=torch.float32, device=device).view(1, 3, 1)

    # Distance vector and norm
    diff = r - t                             # (N, 3)
    norm = torch.norm(diff, dim=1, keepdim=True)  # (N, 1)
    norm_squeezed = norm.squeeze(1)  # (N,)

    # Phase term: e^{-j k0 r} = cos(-k0 r) + j sin(-k0 r)
    phase_real = torch.cos(-k0 * norm_squeezed)  # (N,)
    phase_imag = torch.sin(-k0 * norm_squeezed)  # (N,)
    phase = torch.complex(phase_real, phase_imag)  # (N,)

    # Projection matrix
    diff_u = diff.unsqueeze(2)                            # (N, 3, 1)
    norm_sq = norm.pow(2).unsqueeze(2)                    # (N, 1, 1)
    outer = torch.bmm(diff_u, diff_u.transpose(1, 2)) / norm_sq  # (N, 3, 3)
    I = torch.eye(3, device=device).expand(t.shape[0], 3, 3)         # (N, 3, 3)
    proj = I - outer                                       # (N, 3, 3)

    # Polarization response
    e_expanded = e.expand(t.shape[0], -1, -1)              # (N, 3, 1)
    polarization = torch.bmm(e_expanded.transpose(1, 2), torch.bmm(proj, e_expanded)).squeeze()  # (N,)

    # Final coefficient
    coeff = -1j * eta / (2 * wavelength * norm_squeezed)  # (N,)
    
    return (coeff * phase * polarization).unsqueeze(1)

# 정적분 변환
def evaluate_integral(model, user_1, user_2, device, is_complex=True):
    N = 10
    tx_vals = [-N, N]
    ty_vals = [-N, N]
    inputs = []

    for tx in tx_vals:
        for ty in ty_vals:
            t = torch.tensor([[tx, ty]], dtype=torch.float32).to(device)
            J_input = torch.cat([user_1, user_2], dim=1).to(device)
            full_input = torch.cat([t, J_input], dim=1)
            inputs.append(full_input)

    inputs = torch.cat(inputs, dim=0)  # shape: (4, input_dim)
    outputs = model(inputs)

    if is_complex:
        result = outputs[3] - outputs[1] - outputs[2] + outputs[0]
    else:
        result = outputs[3] - outputs[1] - outputs[2] + outputs[0]
        result = result.squeeze()  # (1,)

    return result

# 이중 미분 함수
def compute_second_order_derivative(output, t_x, t_y):
    if torch.is_complex(output):
        results = []
        for i in range(output.shape[1]):  # for each output dimension
            grad_tx_real = torch.autograd.grad(output[:, i].real.sum(), t_x, create_graph=True, retain_graph=True, allow_unused=True)[0]
            grad_ty_real = torch.autograd.grad(grad_tx_real.sum(), t_y, create_graph=True, retain_graph=True, allow_unused=True)[0]

            grad_tx_imag = torch.autograd.grad(output[:, i].imag.sum(), t_x, create_graph=True, retain_graph=True, allow_unused=True)[0]
            grad_ty_imag = torch.autograd.grad(grad_tx_imag.sum(), t_y, create_graph=True, retain_graph=True, allow_unused=True)[0]

            d2 = torch.complex(grad_ty_real, grad_ty_imag)  # shape: (1000,)
            results.append(d2.unsqueeze(1))  # shape: (1000,1)

        return torch.cat(results, dim=1)  # (1000, 4)

    else:
        results = []
        for i in range(output.shape[1]):
            grad_tx = torch.autograd.grad(output[:, i].sum(), t_x, create_graph=True, retain_graph=True)[0]
            grad_ty = torch.autograd.grad(grad_tx.sum(), t_y, create_graph=True)[0]
            results.append(grad_ty.unsqueeze(1))
        return torch.cat(results, dim=1)  # (1000, 4)

def generate_coordinates(num_points, x_range, y_range):
    x = np.random.uniform(x_range[0], x_range[1], num_points)
    y = np.random.uniform(y_range[0], y_range[1], num_points)
    coordinates = np.stack((x, y), axis=1)  # (num_points, 2)
    return torch.tensor(coordinates, dtype=torch.float32)

def generate_user_pairs(batch_size):
    user_pairs = generate_coordinates(batch_size * 2, x_range=(-1, 1), y_range=(-1, 1))
    return user_pairs.view(batch_size, 2, 2)  # (batch_size, 2, 2)

loss_history = {"LR1": [], "LR2": [], "LR3": [], "LR4": [], "LR5": [], "L_ic_op": [], "L_ic1": [], "L_ic2": []}

class ResidualBlock(nn.Module):
    def __init__(self, in_features, out_features, activation=nn.SiLU()):
        super(ResidualBlock, self).__init__()
        self.activation = activation
        self.fc1 = nn.Linear(in_features, out_features)
        self.fc2 = nn.Linear(out_features, out_features)
        if in_features != out_features:
            self.skip = nn.Linear(in_features, out_features)
        else:
            self.skip = nn.Identity()

    def forward(self, x):
        residual = self.skip(x)
        out = self.activation(self.fc1(x))
        out = self.activation(self.fc2(out))
        out = out + residual
        return out


In [2]:
class Complex_Model(nn.Module):
    def __init__(self, input_dim, output_dim, hidden_dim, num_blocks, act=nn.SiLU(), out_act=None):
        super().__init__()
        self.input_layer = nn.Linear(input_dim, hidden_dim)
        self.blocks = nn.ModuleList([ResidualBlock(hidden_dim, hidden_dim, act) for _ in range(num_blocks)])
        self.output_layer = nn.Linear(hidden_dim, output_dim)
        self.act = act
        self.out_act = out_act

    def forward(self, x):
        x = self.act(self.input_layer(x))
        for block in self.blocks:
            x = block(x)
        x = self.output_layer(x)

        # 출력 차원 처리
        batch_size = x.shape[0]

        # 홀수이면 그대로 반환
        if x.shape[1] % 2 != 0:
            return x

        # 짝수이면 복소수로 변환
        half_dim = x.shape[1] // 2  # 절반 크기
        x = x.view(batch_size, half_dim, 2)

        # 복소수 변환
        real_part = x[:, :, 0].unsqueeze(2)  # 실수부
        imag_part = x[:, :, 1].unsqueeze(2)  # 허수부
        complex_output = torch.complex(real_part, imag_part)

        return complex_output

class Real_Model(nn.Module):
    def __init__(self, input_dim, output_dim, hidden_dim, num_blocks, act=nn.SiLU(), out_act=None):
        super().__init__()
        self.input_layer = nn.Linear(input_dim, hidden_dim)
        self.blocks = nn.ModuleList([ResidualBlock(hidden_dim, hidden_dim, act) for _ in range(num_blocks)])
        self.output_layer = nn.Linear(hidden_dim, output_dim)
        self.act = act
        self.out_act = out_act

    def forward(self, x):
        x = self.act(self.input_layer(x))
        for block in self.blocks:
            x = block(x)
        x = self.output_layer(x)
        return x


In [3]:
mse = nn.MSELoss()

Signal = Complex_Model(input_dim=6, output_dim=4, hidden_dim=256, num_blocks=2)
mp_integral = Complex_Model(input_dim=6, output_dim=8, hidden_dim=256, num_blocks=2)
objective_integral = Real_Model(input_dim=6, output_dim=1, hidden_dim=256, num_blocks=2)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Signal.to(device)
mp_integral.to(device)
objective_integral.to(device)

optimizer1 = optim.Adam(Signal.parameters(), lr=1e-3)
optimizer2 = optim.Adam(mp_integral.parameters(), lr=1e-3)
optimizer3 = optim.Adam(objective_integral.parameters(), lr=1e-3)

N = 10

# SINR Constraint
gamma_db_1 = 5  # dB
gamma_db_2 = 5  # dB
gamma_p_1 = 10**(gamma_db_1 / 10)
gamma_p_2 = 10**(gamma_db_2 / 10)

sigma_sq = 1

# Number of Users
P = 2

gamma_list = torch.tensor([gamma_p_1, gamma_p_2]) 

C_list = []

for p in range(P):
    c = torch.ones(P) / sigma_sq
    c[p] = -1.0 / (gamma_list[p] * sigma_sq)
    C_p = torch.diag(c).to(device)
    C_list.append(C_p)
    
# C matrix for Constraint
C_1, C_2 = C_list

iterations = 100000
num_tx = 100
batch_size = 20
Total = []

for epoch in range(iterations):

    # Gradient-Zero Condition
    optimizer1.zero_grad()
    optimizer2.zero_grad()    
    optimizer3.zero_grad()
    
    t_n_x = np.random.uniform(low=-10, high=10, size=(100, 1))
    t_n_y = np.random.uniform(low=-10, high=10, size=(100, 1))

    tx = torch.tensor(t_n_x, dtype=torch.float32, requires_grad=True).to(device)
    ty = torch.tensor(t_n_y, dtype=torch.float32, requires_grad=True).to(device)
    t = torch.cat([tx, ty], dim=1)
    tz = torch.zeros_like(t[:, :1]).to(device).requires_grad_(True)
    t_3d = torch.cat([t, tz], dim=1)

    user_pairs = generate_user_pairs(batch_size).to(device).requires_grad_(True)

    L_ic1 = L_ic2 = L_ic_op = LR1 = LR2 = LR3 = LR4 = LR5 = 0.0

    for pair in user_pairs:
        user10, user20 = pair.unbind(0)
        user10 = user10.unsqueeze(0)
        user20 = user20.unsqueeze(0)

        pair = pair.unsqueeze(0).expand(num_tx, -1, -1)
        user1 = pair[:, 0, :]
        user2 = pair[:, 1, :]

        z = torch.full((100, 1), 30.0)
        user1_3d = torch.cat([user1, z], dim=1)
        user2_3d = torch.cat([user2, z], dim=1)

        # Channel
        h1 = h_single(t_3d, user1_3d, wavelength=0.01, device=device)
        h2 = h_single(t_3d, user2_3d, wavelength=0.01, device=device)

        flattened_view = pair.view(100, -1).requires_grad_(True)
        Input = torch.cat([t, flattened_view], dim=1)

        Signal_output = Signal(Input).squeeze(-1)
        mp_integral_output = mp_integral(Input).squeeze(-1)
        objective_integral_output = objective_integral(Input).squeeze(-1)

        # Signal
        q1, q2 = Signal_output[:, 0], Signal_output[:, 1]

        # 2-time Derivative
        results = []
        for i in range(mp_integral_output.shape[1]):
            # 실수부 미분
            grad_tx_real = torch.autograd.grad(mp_integral_output[:, i].real.sum(), tx, create_graph=True)[0]
            grad_ty_real = torch.autograd.grad(grad_tx_real.sum(), ty, create_graph=True)[0]

            # 허수부 미분
            grad_tx_imag = torch.autograd.grad(mp_integral_output[:, i].imag.sum(), tx, create_graph=True)[0]
            grad_ty_imag = torch.autograd.grad(grad_tx_imag.sum(), ty, create_graph=True)[0]

            # 복소수 결합
            d2 = torch.complex(grad_ty_real, grad_ty_imag)
            results.append(d2.unsqueeze(1))

        # 최종 결합
        d2_mp = torch.cat(results, dim=1)

        grad_tx = torch.autograd.grad(objective_integral_output.sum(), tx, create_graph=True)[0]
        grad_ty = torch.autograd.grad(grad_tx.sum(), ty, create_graph=True)[0]
        d2_obj = grad_ty

        # integrate
        int_mp = evaluate_integral(mp_integral, user10, user20, device, is_complex=True)
        int_obj = evaluate_integral(objective_integral, user10, user20, device, is_complex=False)

        M = int_mp.view(2, 2)
        C_1 = C_1.to(dtype=torch.complex64, device=device)
        C_2 = C_2.to(dtype=torch.complex64, device=device)
        term1 = torch.conj(M[0]) @ C_1 @ M[0]
        term2 = torch.conj(M[1]) @ C_2 @ M[1]

        L_ic1 += torch.relu(term1.real + 1)
        L_ic2 += torch.relu(term2.real + 1)
        L_ic_op += int_obj

        LR1 += mse(q1.real**2 + q1.imag**2 + q2.real**2 + q2.imag**2, d2_obj.view(-1))
        LR2 += mse((torch.conj(h1) * q1.unsqueeze(1)).real, d2_mp[:,0].real) + mse((torch.conj(h1) * q1.unsqueeze(1)).imag, d2_mp[:,0].imag)
        LR3 += mse((torch.conj(h1) * q2.unsqueeze(1)).real, d2_mp[:,1].real) + mse((torch.conj(h1) * q2.unsqueeze(1)).imag, d2_mp[:,1].imag)
        LR4 += mse((torch.conj(h2) * q1.unsqueeze(1)).real, d2_mp[:,2].real) + mse((torch.conj(h2) * q1.unsqueeze(1)).imag, d2_mp[:,2].imag)
        LR5 += mse((torch.conj(h2) * q2.unsqueeze(1)).real, d2_mp[:,3].real) + mse((torch.conj(h2) * q2.unsqueeze(1)).imag, d2_mp[:,3].imag)

    for key, value in zip(loss_history.keys(), [LR1, LR2, LR3, LR4, LR5, L_ic_op, L_ic1, L_ic2]):
        loss_history[key].append(value.item())

    Total_Loss = LR1 + 3*LR2 + 3*LR3 + 3*LR4 + 3*LR5 + L_ic1 + L_ic2 + 1e-8 * L_ic_op

    Total_Loss.backward()
    optimizer1.step()
    optimizer2.step()
    optimizer3.step()

    Total.append(Total_Loss.item())

    with torch.no_grad():
        print(epoch, f"Total Loss: {Total_Loss.item():.6f}")

    if epoch % 100 == 0:
        plt.figure(figsize=(8, 6))
        plt.plot(Total, label="Total Loss", color=np.random.rand(3,))
        plt.xlabel("Epochs")
        plt.ylabel("Log Loss Value")
        plt.title("Loss during Training")
        plt.legend()
        plt.grid()
        plt.show()

    if epoch % 100 == 0:
        for key, values in loss_history.items():
            plt.figure(figsize=(8, 6))
            if key != "L_ic_op":
                values = np.log10(np.maximum(np.abs(values), 1e-16))
            plt.plot(values, label=key, color=np.random.rand(3,))
            plt.xlabel("Epochs")
            plt.ylabel("Log Loss Value" if key != "L_ic_op" else "Loss Value")
            plt.title(f"{key} During Training")
            plt.legend()
            plt.grid()
            plt.show()

for key, values in loss_history.items():
    plt.figure(figsize=(8, 5))
    if key != "L_ic_op":
        values = np.log10(np.maximum(np.abs(values), 1e-16))
    plt.plot(values, label=key, color=np.random.rand(3,))
    plt.xlabel("Epochs")
    plt.ylabel("Log Loss Value" if key != "L_ic_op" else "Loss Value")
    plt.title(f"{key} During Training")
    plt.legend()
    plt.grid()
    plt.savefig(f"{key.replace(' ', '_')}_history.png")
    plt.show()


RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument tensors in method wrapper_CUDA_cat)

True
NVIDIA GeForce RTX 4080
