In [1]:
import os
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from src.model_builder import Generator, Critic
from eval.eval import *
from src.model_utils import *
from src.data_setup import *
from src.utils import *

In [2]:
# Simulation / model hyper-parameters shared by the cells below.
Nt = 4                            # number of transmit antennas
Nr = Nt                           # number of receive antennas (same as Nt)
NUM_ANTENNA_PAIRS = Nr * Nt       # one entry per (receive, transmit) pair
L = 23                            # passed to Generator as its third positional argument
T = 128                           # length of the transmitted impulse signal
z_dim = 50                        # width of the noise vector drawn per antenna pair
EMBED_DIM = 4                     # embedding size handed to the models
HIDDEN_DIM = 100                  # hidden width handed to the Generator
BATCH_SIZE = 12000                # dataloader batch size
N_CRITIC = 25                     # critic iterations per generator iteration

In [3]:
# Prefer the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cuda'

In [4]:
# Build the generator on the selected device and switch its parameters to float64.
g = Generator(Nr, Nt, L, z_dim, EMBED_DIM, HIDDEN_DIM).to(device)
g = g.double()
# Load the saved weights. map_location remaps the stored tensors onto `device`,
# so a checkpoint written from a CUDA run can still be restored when the
# notebook has fallen back to the CPU.
state_dict = torch.load(os.path.join('models', 'G_B0L1_BSZ512_EMB4_Z50.pt'),
                        map_location=device)
g.load_state_dict(state_dict)

# Put the module in eval mode; as the last expression this also displays its repr.
g.eval()

Generator(
  (embedding): Linear(in_features=16, out_features=4, bias=True)
  (main): Sequential(
    (0): Linear(in_features=54, out_features=100, bias=True)
    (1): ReLU()
    (2): Linear(in_features=100, out_features=100, bias=True)
    (3): ReLU()
    (4): Linear(in_features=100, out_features=46, bias=True)
  )
)

In [5]:
# Build the critic on the selected device and switch its parameters to float64.
c = Critic(T+Nr, Nr).to(device)
c = c.double()
# map_location remaps the stored tensors onto `device`, so a checkpoint written
# from a CUDA run can still be restored when the notebook runs on the CPU.
state_dict = torch.load(os.path.join('models', 'C_B0L1_BSZ512_EMB4_Z50.pt'),
                        map_location=device)
c.load_state_dict(state_dict)

# Put the module in eval mode; as the last expression this also displays its repr.
c.eval()

Critic(
  (embedding): Linear(in_features=4, out_features=4, bias=True)
  (main): Sequential(
    (0): Linear(in_features=268, out_features=100, bias=True)
    (1): LayerNorm((100,), eps=1e-05, elementwise_affine=True)
    (2): ReLU()
    (3): Linear(in_features=100, out_features=100, bias=True)
    (4): LayerNorm((100,), eps=1e-05, elementwise_affine=True)
    (5): ReLU()
    (6): Linear(in_features=100, out_features=1, bias=True)
  )
)

In [6]:
import torch.nn as nn

from tqdm.auto import tqdm
from typing import Dict, List, Tuple

def train_WGAN_GP(generator: nn.Module,
                  critic: nn.Module,
                  train_dataloader: torch.utils.data.DataLoader, 
                  val_dataloader: torch.utils.data.DataLoader, 
                  sample_indices: torch.Tensor,
                  epochs: int,
                  device: torch.device,
                  gen_optimizer: torch.optim.Optimizer = None,
                  critic_optimizer: torch.optim.Optimizer = None,
                  lambda_gp: float = 10.0) -> Tuple[List[float], List[float]]:
    """Train a conditional generator/critic pair with the WGAN-GP objective.

    For every batch the critic is updated ``N_CRITIC`` times on
    ``-(E[critic(real)] - E[critic(fake)]) + lambda_gp * gradient_penalty``,
    then the generator is updated once on ``-E[critic(fake)]``.

    The function relies on notebook-level names: the constants ``Nt``, ``Nr``,
    ``T``, ``NUM_ANTENNA_PAIRS``, ``BATCH_SIZE``, ``N_CRITIC`` and ``z_dim``,
    and the helpers ``get_fake_batch``, ``prepare_complex_signal`` and
    ``gradient_penalty`` (brought in by the star imports).

    Args:
        generator: Module called as ``generator(z, ij_matrix)``.
        critic: Module called as ``critic(signal, i_matrix)``.
        train_dataloader: Yields batches of real received signals.
        val_dataloader: Accepted for interface compatibility; currently unused.
        sample_indices: Forwarded unchanged to ``get_fake_batch``.
        epochs: Number of passes over ``train_dataloader``.
        device: Device on which all tensors are created.
        gen_optimizer: Optimizer for the generator. Defaults to
            ``Adam(lr=1e-4, betas=(0.0, 0.9))``.
        critic_optimizer: Optimizer for the critic. Same default as above.
        lambda_gp: Weight of the gradient-penalty term (previously the
            hard-coded ``10``).

    Returns:
        ``(c_losses, g_losses)``: per-epoch mean critic loss and per-epoch mean
        generator loss.

    Raises:
        ValueError: If a batch is larger than ``BATCH_SIZE``, because the
            preallocated conditioning matrices would then be too short.
    """
    # Without optimizers the loop only evaluates losses and never learns.
    if gen_optimizer is None:
        gen_optimizer = torch.optim.Adam(generator.parameters(), lr=1e-4, betas=(0.0, 0.9))
    if critic_optimizer is None:
        critic_optimizer = torch.optim.Adam(critic.parameters(), lr=1e-4, betas=(0.0, 0.9))

    # setting up transmitted signal - unit power discrete impulse
    # (one impulse per transmit antenna, each at a different time index)
    input_signal = torch.zeros(1, Nt, T, device=device, dtype=torch.complex128)
    input_signal[0,0,0] = 1
    input_signal[0,1,12] = 1
    input_signal[0,2,25] = 1
    input_signal[0,3,39] = 1

    # One-hot conditioning rows, tiled once for the largest expected batch and
    # sliced per batch below.
    ij_matrix_full = torch.eye(NUM_ANTENNA_PAIRS, dtype=torch.float64, device=device).repeat(BATCH_SIZE, 1) 
    i_matrix_full = torch.eye(Nr, dtype=torch.float64, device=device).repeat(BATCH_SIZE, 1)

    generator.train()
    critic.train()

    c_losses = []
    g_losses = []

    for epoch in range(epochs):
        c_loss = g_loss = 0
        for batch_idx, batch_real in enumerate(tqdm(train_dataloader)):
            
            batch_real = batch_real.to(device) # shape [BATCH_SIZE, Nr, T+Nr]
            cur_batch_size = batch_real.shape[0]
            if cur_batch_size > BATCH_SIZE:
                raise ValueError(
                    f"batch of size {cur_batch_size} exceeds BATCH_SIZE={BATCH_SIZE}; "
                    "conditioning matrices would be truncated"
                )
            # setting up conditioning information
            ij_matrix = ij_matrix_full[:(cur_batch_size*NUM_ANTENNA_PAIRS)]
            i_matrix = i_matrix_full[:(cur_batch_size*Nr)] 

            # The real batch does not change across critic iterations, so it is
            # prepared (real/imaginary interleaved) once per batch.
            batch_real_int = prepare_complex_signal(batch_real).view(cur_batch_size*Nr, -1)

            # Train Critic: max E[critic(real)] - E[critic(fake)]
            # equivalent to minimizing the negative of that
            for _ in range(N_CRITIC):

                # generating a batch of fake data; no generator graph is needed
                # for the critic update
                z = torch.randn(cur_batch_size*NUM_ANTENNA_PAIRS, z_dim, dtype=torch.float64, device=device)
                with torch.no_grad():
                    channel_tensor = generator(z, ij_matrix)
                    batch_fake = get_fake_batch(input_signal, channel_tensor, sample_indices)
                batch_fake_int = prepare_complex_signal(batch_fake).view(cur_batch_size*Nr, -1)

                # calculating critic loss
                critic_real = critic(batch_real_int, i_matrix).view(-1)
                critic_fake = critic(batch_fake_int, i_matrix).view(-1)
                gp = gradient_penalty(critic, batch_real, batch_fake, i_matrix)
                critic_loss = (-(torch.mean(critic_real) - torch.mean(critic_fake))) + (lambda_gp * gp)

                critic_optimizer.zero_grad()
                critic_loss.backward()
                critic_optimizer.step()
                c_loss += critic_loss.item()

            # Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
            # A fresh fake batch is generated with gradients enabled so the loss
            # can back-propagate into the generator.
            # NOTE(review): assumes get_fake_batch / prepare_complex_signal are
            # differentiable - confirm against src.model_utils.
            z = torch.randn(cur_batch_size*NUM_ANTENNA_PAIRS, z_dim, dtype=torch.float64, device=device)
            channel_tensor = generator(z, ij_matrix)
            batch_fake = get_fake_batch(input_signal, channel_tensor, sample_indices)
            batch_fake_int = prepare_complex_signal(batch_fake).view(cur_batch_size*Nr, -1)
            gen_fake = critic(batch_fake_int, i_matrix).view(-1)
            gen_loss = -(torch.mean(gen_fake))

            # Critic gradients accumulated here are cleared by the next
            # critic_optimizer.zero_grad() call.
            gen_optimizer.zero_grad()
            gen_loss.backward()
            gen_optimizer.step()
            g_loss += gen_loss.item()

        # Average over every critic / generator update performed this epoch.
        c_loss = c_loss / (N_CRITIC * len(train_dataloader))
        g_loss = g_loss / len(train_dataloader)
        c_losses.append(c_loss)
        g_losses.append(g_loss)

        # "\\" prints the same single backslash as before without the
        # invalid-escape warning.
        print(f"Epoch [{epoch+1}/{epochs}] \\ Loss D: {c_loss:.4f}, loss G: {g_loss:.4f}")

    return c_losses, g_losses

In [7]:
# Paths of the TDL-A train / test / validation .mat files
train_dataset_path = os.path.join("Dataset", "train_data_TDL_A.mat")
test_dataset_path = os.path.join("Dataset", "test_data_TDL_A.mat")
val_dataset_path = os.path.join("Dataset", "val_data_TDL_A.mat")

# Build the loaders; the test loader (second return value) is not needed here.
train_dataloader, _, val_dataloader = create_dataloaders(
    train_dataset_path, test_dataset_path, val_dataset_path,
    "rx_train_data", "rx_test_data", "rx_val_data",
    BATCH_SIZE, 0,
)

# Initialize fresh models (Generator & Critic) in float64 on the chosen device.
# NOTE(review): these two instances are not the ones handed to train_WGAN_GP
# below (the call passes `g` and `c`) - confirm which pair is intended.
generator = Generator(Nr=Nr, Nt=Nt, l=L, z_dim=z_dim, embed_dim=EMBED_DIM).to(device).double()
critic = Critic(N=T+Nr, num_receive_antennas=Nr, embed_dim=EMBED_DIM).to(device).double()

# Indices forwarded to train_WGAN_GP as `sample_indices`
sample_indices = torch.tensor(
    [0, 1, 2, 3, 4, 5, 6, 7, 8, 14, 17, 19, 20, 22, 23, 28, 35, 38, 42, 44, 46, 49, 89],
    device=device,
)

# Run training.
# NOTE(review): the validation loader is supplied as the training loader here,
# and train_dataloader is left unused - confirm this is deliberate.
c_losses, g_losses = train_WGAN_GP(
    g,
    c,
    train_dataloader=val_dataloader,
    val_dataloader=val_dataloader,
    sample_indices=sample_indices,
    epochs=500,
    device=device,
)

  0%|          | 0/1 [00:00<?, ?it/s]

KeyboardInterrupt: 