# Train Distributed U-net for: 
## 3 APs (Mt), 8 antennas per AP (Nt), and 5 users (U)

### 1. Import dependencies

In [None]:
import torch
import torch.nn as nn
import torch.utils.data as loader
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import random
import scipy.io
import csv
import h5py # Needed to read matlab v7.3 files

### 2. Import dataset

In [None]:
# Number of leading samples of the dataset file that are kept.
num_data = 20000

# Scenario dimensions -- change Nt and U if necessary.
Nt, U, T = 8, 5, 1  # antennas per AP, users, targets (T is always 1)

class CF_ISACDataset(loader.Dataset):
  """Per-AP network inputs plus the channel/sensing tensors needed by the losses.

  The data comes from a MATLAB v7.3 (HDF5) file; only the first `num_data`
  samples are kept.

  Attributes:
    x0, x1, x2: float32 tensors, one per AP, shaped (N, 2, U+T, Nt). Channel 0
      holds the real part and channel 1 the imaginary part of the user
      channels stacked on top of the sensing steering vectors of that AP.
    Hcomm: complex64 tensor built from the file variable 'H'.
    DAD: complex64 tensor built from the file variable 'DAD'.
    dat_size: number of samples N.
  """

  def __init__(self, filepath=None):
    """Load the dataset.

    Args:
      filepath: optional path of the .mat file. When omitted, the bundled
        sample dataset matching the module-level U, T and Nt is used.

    Raises:
      ValueError: if the per-AP feature shape does not match (U+T, Nt).
    """
    if filepath is None:
      # Derived from U/T/Nt so the file cannot silently disagree with them
      filepath = f'../sample_datasets/unsup_dataset_U{U}_T{T}_L3_ant{Nt}_R2x16.7m.mat'

    # Copy every variable into memory, then release the HDF5 handle
    in_file = {}
    with h5py.File(filepath, 'r') as mat_file:
      for k, v in mat_file.items():
        print(k)
        in_file[k] = np.array(v).T

    # Reinterpret the stored records as complex numbers
    Hcomm = in_file['H'][0:num_data,:,:,:].view('complex')  # data size x num users x num APs x num antennas
    sensing_beamsteering = in_file['a'][0:num_data,:,:,:].view('complex') # data size x num targets x num APs x num antennas
    dad = in_file['DAD'][0:num_data,:,:,:].view('complex')      # data size x num APs x (num APs*num antennas) x (num APs*num antennas)

    per_ap_inputs = []
    for ap in range(3):
      # Users first, then targets, along the stream axis
      stacked = np.concatenate((Hcomm[:,:,ap,:], sensing_beamsteering[:,:,ap,:]), axis=1)
      if stacked.shape[1:] != (U+T, Nt):
        raise ValueError(
            f'AP {ap}: expected per-sample features of shape {(U+T, Nt)}, got {stacked.shape[1:]}')
      features = np.stack((np.real(stacked), np.imag(stacked)), axis=1)
      per_ap_inputs.append(torch.from_numpy(features).float())
    self.x0, self.x1, self.x2 = per_ap_inputs

    self.Hcomm = torch.from_numpy(Hcomm).type(torch.complex64)
    self.DAD = torch.from_numpy(dad).type(torch.complex64)

    self.dat_size = self.x0.shape[0]
    # Inner quotes differ from the outer ones so this also parses before Python 3.12
    print(f"Total data points = {self.dat_size} out of {in_file['H'].shape[0]} points in total")
    print(f'x0 shape = {self.x0.shape}')
    print(f'H_comm shape = {self.Hcomm.shape} [Complex]')
    print(f'Dmt*A*Dmt  shape = {self.DAD.shape} [Complex]')

  def __getitem__(self,index):
    """Return (x0, x1, x2, Hcomm, DAD) for the sample(s) at `index`."""
    return self.x0[index,:], self.x1[index,:], self.x2[index,:], self.Hcomm[index,:], self.DAD[index,:]

  def __len__(self):
    """Return the number of samples held by the dataset."""
    return self.dat_size

# Build the dataset once (this reads the .mat file) and cache its length,
# which the train/validation split below is derived from.
cf_isac_dataset = CF_ISACDataset()
N_points = len(cf_isac_dataset)

### 3. Dataloader constructor and device setup

In [None]:
batch_size = 500 # How many data points taken at once
p = int(0.97*N_points) # Number of points reserved for training

# Random train/validation split of the full dataset
train, test = loader.random_split(cf_isac_dataset, [p, N_points-p])

# Training batches are reshuffled every epoch; validation runs one sample at a time
train_dataloader = loader.DataLoader(train, batch_size=batch_size, shuffle=True)
test_dataloader = loader.DataLoader(test, batch_size=1, shuffle=False)
print(f'Batch size : {batch_size}', f'Train size: {p}', f'Validation size: {N_points-p}', sep='\n')

# Indices (into the full dataset) that ended up in each split
train_ind, test_ind = train.indices, test.indices

# Device setup: use the GPU when one is available
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)


### 3.1. Saving the dataloaders (Optional)

In [None]:
def _split_to_dict(dataset, indices):
  """Collect the tensors of `dataset` at `indices` into a plain dict."""
  return {
      'x0': dataset.x0[indices,:],
      'x1': dataset.x1[indices,:],
      'x2': dataset.x2[indices,:],
      'H': dataset.Hcomm[indices,:,:,:],
      'DAD': dataset.DAD[indices,:,:,:],
  }

# Both splits are read from the dataset object wrapped by the test loader's subset.

# Validation split
valid_data_path = 'validation_data_dict4APs.pth'  # Change the directory and the file name if necessary
validation_data_dict = _split_to_dict(test_dataloader.dataset.dataset, test_ind)
torch.save(validation_data_dict, valid_data_path)

# Training split
train_data_path = 'training_data_dict4APs.pth'  # Change the directory and the file name if necessary
training_data_dict = _split_to_dict(test_dataloader.dataset.dataset, train_ind)
torch.save(training_data_dict, train_data_path)

### 3.2 Load the saved dataloaders (Run if 3.1 was run)

In [None]:
# Write each saved split back into the dataset object behind its dataloader.
# Change the directories and the file names below if necessary.
for data_path, _split_loader, _split_indices in (
    ('training_data_dict4APs.pth', train_dataloader, train_ind),
    ('validation_data_dict4APs.pth', test_dataloader, test_ind),
):
  loaded_data = torch.load(data_path)
  _target = _split_loader.dataset.dataset
  _target.x0[_split_indices,:] = loaded_data['x0']
  _target.x1[_split_indices,:] = loaded_data['x1']
  _target.x2[_split_indices,:] = loaded_data['x2']

  _target.Hcomm[_split_indices,:,:,:] = loaded_data['H']
  _target.DAD[_split_indices,:,:,:] = loaded_data['DAD']

### 4. U-net (F3)

As denoted in the paper:

F3: f=3, P=1

F5: f=5, P=2

F7: f=7, P=3

.
.
.

In [None]:
P_total = 1  # Transmit power budget in Watts

# Convolution settings shared by every U-net layer below.
f, P, S = 3, 1, 1  # filter size, padding, stride (stride is always 1)

class DownConvFirst(nn.Module): # 1st arrow class
  """First encoder stage: a convolution followed by LeakyReLU(0.2), without normalisation."""

  def __init__(self,in_ch,out_ch):
    super().__init__()
    stages = [
        nn.Conv2d(in_ch, out_ch, kernel_size=f, stride=S, padding=P),
        nn.LeakyReLU(0.2),
    ]
    self.conv = nn.Sequential(*stages)

  def forward(self,x):
    out = self.conv(x)
    return out

class DownConvMiddle(nn.Module): # 2nd arrow class
  """Inner encoder stage: convolution, batch normalisation, then LeakyReLU(0.2)."""

  def __init__(self,in_ch,out_ch):
    super().__init__()
    stages = [
        nn.Conv2d(in_ch, out_ch, kernel_size=f, stride=S, padding=P),
        nn.BatchNorm2d(out_ch),
        nn.LeakyReLU(0.2),
    ]
    self.conv = nn.Sequential(*stages)

  def forward(self,x):
    out = self.conv(x)
    return out

class DownConvFinal(nn.Module): # 3rd arrow class
  """Bottom stage: a single convolution with neither normalisation nor activation."""

  def __init__(self,in_ch,out_ch):
    super().__init__()
    self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=f, stride=S, padding=P)

  def forward(self,x):
    out = self.conv(x)
    return out

class UpConvFirst(nn.Module): # 4th arrow class (Same-convolution = no 2D cropping)
  """First decoder stage: ReLU, transposed convolution, batch normalisation, 2-D dropout.

  `p` is the dropout probability; it falls back to 0.5 when left as None.
  """

  def __init__(self,in_ch,out_ch,p=None):
    super().__init__()
    drop_prob = 0.5 if p is None else p
    stages = [
        nn.ReLU(),
        nn.ConvTranspose2d(in_ch, out_ch, kernel_size=f, stride=S, padding=P),
        nn.BatchNorm2d(out_ch),
        nn.Dropout2d(drop_prob),
    ]
    self.conv = nn.Sequential(*stages)

  def forward(self,x):
    out = self.conv(x)
    return out

class UpConvMiddle(nn.Module): # 5th arrow class (Same-convolution = no 2D cropping)
  """Inner decoder stage: ReLU, transposed convolution, then batch normalisation."""

  def __init__(self,in_ch,out_ch):
    super().__init__()
    stages = [
        nn.ReLU(),
        nn.ConvTranspose2d(in_ch, out_ch, kernel_size=f, stride=S, padding=P),
        nn.BatchNorm2d(out_ch),
    ]
    self.conv = nn.Sequential(*stages)

  def forward(self,x):
    out = self.conv(x)
    return out

class UpConvFinal(nn.Module): # 6th arrow class (Same-convolution = no 2D cropping)
  """Output stage: LeakyReLU(0.2) followed by a transposed convolution."""

  def __init__(self,in_ch,out_ch):
    super().__init__()
    stages = [
        nn.LeakyReLU(0.2),
        nn.ConvTranspose2d(in_ch, out_ch, kernel_size=f, stride=S, padding=P),
    ]
    self.conv = nn.Sequential(*stages)

  def forward(self,x):
    out = self.conv(x)
    return out

class UpConvHidden(nn.Module):
  """Convolution followed by LeakyReLU(0.2); UNet applies it to the concatenated skip and decoder features."""

  def __init__(self,in_ch,out_ch):
    super().__init__()
    stages = [
        nn.Conv2d(in_ch, out_ch, kernel_size=f, stride=S, padding=P),
        nn.LeakyReLU(0.2),
    ]
    self.conv = nn.Sequential(*stages)

  def forward(self,x):
    out = self.conv(x)
    return out




class UNet(nn.Module):
  """Same-resolution U-net whose output is scaled to a fixed Frobenius norm.

  Every stage keeps the spatial size, so skip connections are concatenated
  along the channel axis without cropping.

  Args:
    in_ch: number of input channels.
    out_ch: number of output channels.
    filters: channel width of each encoder stage, from first to last. Any
      non-empty sequence is accepted; the decoder mirrors it.
  """

  def __init__(self, in_ch=2, out_ch=2, filters=(16,32,64,128)):
    super().__init__()
    filters = list(filters)
    if not filters:
      raise ValueError('filters must contain at least one entry')

    # The attribute and layer creation order is kept as-is: it fixes the order
    # of parameters()/state_dict() and the order of random initialisation.
    self.up_layers = nn.ModuleList()
    self.down_layers = nn.ModuleList()
    self.up_hidden = nn.ModuleList()

    # Descending part:
    prev_ch = in_ch
    for depth, width in enumerate(filters):
      stage = DownConvFirst if depth == 0 else DownConvMiddle
      self.down_layers.append(stage(prev_ch, width))
      prev_ch = width

    # Bottom part:
    self.bottom = DownConvFinal(filters[-1],filters[-1])

    # Ascending part:
    self.up_layers.append(UpConvFirst(filters[-1],filters[-1]))
    self.up_hidden.append(UpConvHidden(2*filters[-1],filters[-1]))
    widths = filters[::-1]
    for width, next_width in zip(widths, widths[1:]):
      # Decode down to the width of the next skip connection; the fusion layer
      # then sees that skip concatenated with the decoded features.
      self.up_layers.append(UpConvMiddle(width, next_width))
      self.up_hidden.append(UpConvHidden(2*next_width, next_width))
    self.up_layers.append(UpConvFinal(filters[0], out_ch))

  def forward(self,x):
    """Run the network on `x` and return the power-normalised output as float32."""
    skips = []
    for down in self.down_layers:
      x = down(x)
      skips.append(x)

    x = self.bottom(x)

    last = len(self.up_layers) - 1
    for i, up in enumerate(self.up_layers):
      x = up(x)
      if i < last:
        # Deepest skip first: skips[-1-i] pairs with decoder stage i
        x = self.up_hidden[i](torch.cat((skips[-1-i], x), dim=1))

    # Ensures the power constraint: each sample is scaled to Frobenius norm P_total.
    # NOTE(review): this scales the amplitude, so the output power is P_total**2;
    # identical for P_total = 1 -- confirm intent before changing the budget.
    norm = (x**2).sum(dim=[1,2,3],keepdim=True)**(0.5)
    beam = P_total * x/norm
    return beam.float()

### 5. Teacher training

### 5.1 Loss function, loss validation function, and validation regime

In [None]:
# Loss function

sigmasq_ue = 1
sigmasq_radar_rcs = 0.1000

def unsup_loss(DAD,H_comm,f_pred,alpha):
  """Unsupervised teacher loss mixing sensing SNR and worst-user SINR.

  Args:
    DAD: complex tensor indexed as DAD[:, ap, :, :].
    H_comm: complex channel tensor; axes 2 and 3 are swapped and flattened per user.
    f_pred: real tensor of rank 5; index 0/1 of axis 1 is the real/imaginary
      part and the last axis indexes the APs.
    alpha: weight of the SINR term; the SSNR term gets 1-alpha.

  Returns:
    (loss, mean SSNR, mean per-sample minimum SINR); the two metrics are detached.
  """
  n_batch, n_ap = f_pred.shape[0], f_pred.shape[-1]

  # Complex beamformers, one row per stream (users first, then targets)
  beams = torch.complex(f_pred[:,0,:,:,:], f_pred[:,1,:,:,:]).reshape((n_batch, U+T, Nt*n_ap))
  F_sum = (torch.conj(beams.transpose(1,2)) @ beams).to(device)

  # Sensing term: trace of DAD[:, ap] @ F_sum accumulated over the APs
  SSNR = 0
  for ap in range(n_ap):
    SSNR = SSNR + torch.diagonal(DAD[:,ap,:,:] @ F_sum, dim1=-1, dim2=-2).sum(-1)
  p0 = -SSNR.real*sigmasq_radar_rcs

  # Communication term: per-user SINR
  H_st = H_comm.transpose(2,3).reshape(n_batch, U, -1)
  SINR = torch.zeros((n_batch,U), device=device)
  for u in range(U):
    h_conj = torch.conj(H_st[:,u:u+1,:].to(device))
    # Power received by user u from every stream
    rx_power = [
        torch.abs(h_conj @ beams[:,st:st+1,:].to(device).transpose(1,2))**2
        for st in range(U+T)
    ]
    interference = 0
    for st, power in enumerate(rx_power):
      if st == u:
        continue
      interference = interference + power
    SINR[:,u] = rx_power[u][:,0,0] / (interference[:,0,0]+sigmasq_ue)

  SINR_min = SINR.min(dim=1).values
  p1 = -SINR_min[:,None]

  L = (1-alpha)*torch.mean(p0) + alpha*torch.mean(p1)

  return L, -torch.mean(p0).detach(), torch.mean(SINR_min).detach()



# Loss validation function


def validate_metric(DAD,H_comm,f_pred,alpha):
  """Evaluate the teacher objective on one validation batch.

  Args:
    DAD: complex tensor indexed as DAD[:, ap, :, :].
    H_comm: complex channel tensor; axes 2 and 3 are swapped and flattened per user.
    f_pred: real tensor of rank 5; index 0/1 of axis 1 is the real/imaginary
      part and the last axis indexes the APs.
    alpha: weight of the SINR term; the SSNR term gets 1-alpha.

  Returns:
    (L_valid, SSNR, min SINR). L_valid and SSNR have one entry per sample.
    The minimum SINR is a scalar taken over every user of every sample in the
    batch, so it is a per-sample figure only for batch size 1.
  """
  M_t = f_pred.shape[-1]
  batch_size = f_pred.shape[0]

  f_pred_complex = torch.complex(f_pred[:,0,:,:,:],f_pred[:,1,:,:,:])
  f_strems_anten = f_pred_complex.reshape((batch_size,U+T,Nt*M_t))
  F_sum = torch.matmul(torch.conj(torch.transpose(f_strems_anten,1,2)),f_strems_anten).to(device)

  # Sensing term: trace of DAD[:, mt] @ F_sum accumulated over the APs
  SSNR = 0
  for mt in range(M_t):
    SSNR = SSNR + (torch.matmul(DAD[:,mt,:,:],F_sum)).diagonal(offset=0, dim1=-1, dim2=-2).sum(-1)

  p0 = -SSNR.real*sigmasq_radar_rcs

  H_transposed = torch.transpose(H_comm,2,3)
  H_st = H_transposed.reshape(batch_size,U,-1)

  SINR = torch.zeros((batch_size,U),device=device)
  for u in range(U):
    h_u = H_st[:,u:u+1,:].to(device)
    f_u = f_strems_anten[:,u:u+1,:].to(device)
    SINRu_num = torch.abs(torch.matmul(torch.conj(h_u),torch.transpose(f_u,1,2)))**2
    L_u=0

    for st in range(U+T):
      if st != u :
        f_int = f_strems_anten[:,st:st+1,:].to(device)
        L_u = L_u + torch.abs(torch.matmul(torch.conj(h_u),torch.transpose(f_int,1,2)))**2
    # Drop the two singleton matmul axes explicitly (as unsup_loss does); assigning
    # the raw (batch, 1, 1) result into a (batch,) slice only works for batch size 1.
    SINR[:,u] = SINRu_num[:,0,0] / (L_u[:,0,0]+sigmasq_ue)

  vio1 = -torch.min(SINR)
  L_valid = (1-alpha)*p0 + alpha*vio1

  return L_valid, -p0, -vio1



# Validation regime

def validate(model0,model1,model2,test_dataloader,alpha):
  """Evaluate the three per-AP models on the validation loader.

  The models are switched to eval mode and left in it; callers that keep
  training must call .train() again. The accumulated loss is read with
  .item(), so it must hold a single element (the notebook uses batch size 1
  for validation).

  Returns:
    (mean validation loss per batch, mean SSNR, mean of the per-batch minimum
    SINR, variance of the per-batch minimum SINR).

  Raises:
    ValueError: if `test_dataloader` yields no batch.
  """
  models = (model0, model1, model2)
  ssnr_values = []
  min_sinr_values = []
  L_valid_sum = 0
  n_batches = 0

  with torch.no_grad():
    for model in models:
      model.eval()

    for x0_test, x1_test, x2_test, H_comm, DAD in test_dataloader:
      predictions = [
          model(x.to(device))
          for model, x in zip(models, (x0_test, x1_test, x2_test))
      ]
      # The AP index becomes the last axis, as validate_metric expects
      y_predict = torch.stack(predictions, dim=4)
      L_valid, SSNR_estimate, vio1 = validate_metric(DAD.to(device),H_comm.to(device),y_predict.to(device),alpha)

      L_valid_sum += L_valid
      n_batches += 1
      # Plain Python lists avoid re-allocating a numpy array on every batch
      ssnr_values.extend(SSNR_estimate.reshape(-1).tolist())
      min_sinr_values.extend(vio1.reshape(-1).tolist())

  if n_batches == 0:
    raise ValueError('test_dataloader yielded no batches')

  return L_valid_sum.item()/n_batches, np.mean(ssnr_values), np.mean(min_sinr_values), np.var(min_sinr_values)

### 5.2 SINR teacher training

In [None]:
M_t=3  # Num APs
seed=5
# Seed PyTorch's CPU and CUDA generators and request deterministic cuDNN kernels
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic=True
torch.backends.cudnn.benchmark=False
alpha_ten = [1.0]  # Add values for alpha_ten for manual grid search rather than dynamic balance (teacher-student training)
for alpha in alpha_ten:
    # Best validation figures so far and per-epoch histories
    best_SSNR = -100000
    best_SINR = -100000
    best_loss = 10000
    combined_loss_ten = []
    sinr_tr_ten = []
    ssnr_tr_ten = []
    validation_combined_SSNR=[]
    validation_combined_SINR=[]
    L_valid_ten = []

    # One U-net per AP, each with its own optimiser and learning-rate scheduler
    SINR_AP0 = UNet().to(device)
    SINR_AP1 = UNet().to(device)
    SINR_AP2 = UNet().to(device)

    weight_decay=0
    opt_adam_AP0 = torch.optim.Adam(SINR_AP0.parameters(), lr=1e-2)
    opt_adam_AP1 = torch.optim.Adam(SINR_AP1.parameters(), lr=1e-2)
    opt_adam_AP2 = torch.optim.Adam(SINR_AP2.parameters(), lr=1e-2)
    # NOTE(review): `verbose` is deprecated in recent PyTorch releases -- confirm against the installed version
    lr_scheduler_AP0 = torch.optim.lr_scheduler.ReduceLROnPlateau(opt_adam_AP0, 'min', 0.1, 10, verbose=True)
    lr_scheduler_AP1 = torch.optim.lr_scheduler.ReduceLROnPlateau(opt_adam_AP1, 'min', 0.1, 10, verbose=True)
    lr_scheduler_AP2 = torch.optim.lr_scheduler.ReduceLROnPlateau(opt_adam_AP2, 'min', 0.1, 10, verbose=True)

    flag=0
    sinr_save=0
    counter1= 0  # epochs since the validation SSNR last improved
    counter2= 0  # epochs since the validation SINR last improved
    disp_cycle = 5
    epochs = 2000
    patience=100

    for epoch in range(epochs):
        # Plain float accumulator: summing the loss tensors themselves would keep
        # every batch's autograd graph node alive until the end of the epoch.
        mean_loss = 0.0
        SINR_AP0.train()
        SINR_AP1.train()
        SINR_AP2.train()
        for i, (x0_train, x1_train, x2_train, H_comm, DAD) in enumerate(train_dataloader):
            x0 = x0_train.to(device)
            x1 = x1_train.to(device)
            x2 = x2_train.to(device)

            y0_estimate = SINR_AP0(x0)
            y1_estimate = SINR_AP1(x1)
            y2_estimate = SINR_AP2(x2)

            # The AP index becomes the last axis, as unsup_loss expects
            y_estimate = torch.stack((y0_estimate,y1_estimate,y2_estimate),dim=4)

            L_unsup, ssnr_tr, sinr_tr = unsup_loss(DAD.to(device),H_comm.to(device),y_estimate,alpha)

            mean_loss += L_unsup.item()

            opt_adam_AP0.zero_grad()
            opt_adam_AP1.zero_grad()
            opt_adam_AP2.zero_grad()
            # Anomaly mode is enabled around the backward pass only
            with torch.autograd.set_detect_anomaly(True):
                L_unsup.backward()

            opt_adam_AP0.step()
            opt_adam_AP1.step()
            opt_adam_AP2.step()

            if i%disp_cycle==0:
                print(f'ep {epoch+1} - p {i+1} - Loss = {L_unsup.item():.7f}, SSNR = {ssnr_tr.item():.7f}, SINR = {sinr_tr.item():.7f}')

        # The training history records the LAST batch of the epoch, not the epoch mean
        combined_loss_ten = np.append(combined_loss_ten,L_unsup.item())
        ssnr_tr_ten = np.append(ssnr_tr_ten,ssnr_tr.item())
        sinr_tr_ten = np.append(sinr_tr_ten,sinr_tr.item())

        # The schedulers follow the mean training loss of the epoch
        lr_scheduler_AP0.step(mean_loss/(i+1))
        lr_scheduler_AP1.step(mean_loss/(i+1))
        lr_scheduler_AP2.step(mean_loss/(i+1))
        current_lr = opt_adam_AP0.param_groups[0]['lr']

        L_valid, SSNR_validate, SINR_valid, vio_var_valid = validate(SINR_AP0, SINR_AP1, SINR_AP2,test_dataloader,alpha)

        validation_combined_SSNR=np.append(validation_combined_SSNR,SSNR_validate)
        validation_combined_SINR=np.append(validation_combined_SINR,SINR_valid)
        L_valid_ten = np.append(L_valid_ten,L_valid)

        print(f'ep {epoch+1} (alpha={alpha}): Loss = {L_valid:.5f}. Validation SSNR = {SSNR_validate:.7f}, average min SINR: {SINR_valid.item():.7f}, lr = {current_lr}')

        if SSNR_validate > best_SSNR:
            print(f'Improved SSNR: from {best_SSNR:.5f} to {SSNR_validate:.5f} (delta = {(SSNR_validate-best_SSNR):.5f})')
            best_SSNR = SSNR_validate

            counter1=0
        else:
            counter1+=1

        if SINR_valid > best_SINR:
            print(f'Improved SINR: from {best_SINR:.5f} to {SINR_valid:.5f} (delta = {(SINR_valid-best_SINR):.5f})')
            best_SINR = SINR_valid
            # Checkpoint the three networks whenever the validation SINR improves
            PATH = f'../temp_model/AP0_bestSINR_{M_t}APs.pth'
            torch.save(SINR_AP0.state_dict(),PATH)

            PATH = f'../temp_model/AP1_bestSINR_{M_t}APs.pth'
            torch.save(SINR_AP1.state_dict(),PATH)

            PATH = f'../temp_model/AP2_bestSINR_{M_t}APs.pth'
            torch.save(SINR_AP2.state_dict(),PATH)

            counter2=0
        else:
            counter2+=1

        if L_valid < best_loss:
            print(f'Improved loss: from {best_loss:.5f} to {L_valid:.5f} (delta = {(L_valid-best_loss):.5f})')
            best_loss = L_valid

        print(f'Best values: {best_SSNR}, {best_SINR}. Counters: {counter1}, {counter2}')
        print('########################################################')
        # Early stopping is driven by the SINR counter for this teacher
        if counter2 == patience:
            print('Early stropping')
            break

def _plot_curve(fig_num, x_values, y_values, colour, title, ylabel=None):
    """Draw one curve on its own numbered figure with a grid and an 'Epoch' x-axis."""
    plt.figure(fig_num)
    plt.plot(x_values, y_values, colour)
    plt.title(title)
    plt.xlabel(f'Epoch')
    if ylabel is not None:
        plt.ylabel(ylabel)
    plt.grid()


disp_cycle=1
loss_ten_plot = combined_loss_ten[::disp_cycle]
ssnr_ten_plot = ssnr_tr_ten[::disp_cycle]
sinr_ten_plot = sinr_tr_ten[::disp_cycle]
trials = range(len(combined_loss_ten))
trials_plot = trials[::disp_cycle]

# Training curves (one point per epoch)
_plot_curve(11, trials_plot, loss_ten_plot, 'b', f'Training Curve (Batch size = {batch_size})', f'Loss')
_plot_curve(12, trials_plot, ssnr_ten_plot, 'b', f'Training Curve SSNR (Batch size = {batch_size})', f'SSNR')
_plot_curve(13, trials_plot, sinr_ten_plot, 'b', f'Training Curve SINR (Batch size = {batch_size})', f'SINR')

# Validation curves; the x-axis follows the recorded history instead of the
# leaked loop variable `epoch`.
valid_epochs = range(len(L_valid_ten))
_plot_curve(1300, valid_epochs, L_valid_ten, 'r', f'Validation Loss (Batch size = {batch_size})')
_plot_curve(1100, valid_epochs, validation_combined_SSNR, 'r', f'Validation Curve SSNR (Batch size = {batch_size})', f'Average SSNR')
_plot_curve(1200, valid_epochs, validation_combined_SINR, 'r', f'Validation Curve SINR (Batch size = {batch_size})', f'Average SINR')

# Best epochs. The FIRST matching epoch is reported, so a tie between epochs no
# longer raises (calling .item() on a multi-element index array fails).
model_index_SSNR = np.where(validation_combined_SSNR == max(validation_combined_SSNR))
best_ssnr_epoch = int(model_index_SSNR[0][0])
print(f'SSNR: Model at epoch {best_ssnr_epoch} achieved {max(validation_combined_SSNR)}')

model_index_SINR = np.where(validation_combined_SINR == max(validation_combined_SINR))
best_sinr_epoch = int(model_index_SINR[0][0])
print(f'SINR: Model at epoch {best_sinr_epoch} achieved {max(validation_combined_SINR)}')
print(f'Average SINR after convergence = {np.mean(validation_combined_SINR[-300:])}')
print(f'SINR at maximum SSNR = {validation_combined_SINR[best_ssnr_epoch].item()}')

model_index_loss = np.where(L_valid_ten == min(L_valid_ten))
best_loss_epoch = int(model_index_loss[0][0])
print(f'Loss: Model at epoch {best_loss_epoch} achieved {min(L_valid_ten)}')


# Save results

np.savetxt(f"../save_results/train_loss_SINRteacher{M_t}APs_U{U}_ant{Nt}.csv", loss_ten_plot, delimiter=",")
np.savetxt(f"../save_results/train_ssnr_SINRteacher{M_t}APs_U{U}_ant{Nt}.csv", ssnr_ten_plot, delimiter=",")
np.savetxt(f"../save_results/train_sinr_SINRteacher{M_t}APs_U{U}_ant{Nt}.csv", sinr_ten_plot, delimiter=",")

np.savetxt(f"../save_results/valid_loss_SINRteacher{M_t}APs_U{U}_ant{Nt}.csv", L_valid_ten, delimiter=",")
np.savetxt(f"../save_results/valid_ssnr_SINRteacher{M_t}APs_U{U}_ant{Nt}.csv", validation_combined_SSNR, delimiter=",")
np.savetxt(f"../save_results/valid_sinr_SINRteacher{M_t}APs_U{U}_ant{Nt}.csv", validation_combined_SINR, delimiter=",")

### 5.3 SSNR teacher training

In [None]:
M_t=3  # Num APs
seed=5
# Seed PyTorch's CPU and CUDA generators and request deterministic cuDNN kernels
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic=True
torch.backends.cudnn.benchmark=False
# unsup_loss weighs the SSNR term by (1-alpha) and the min-SINR term by alpha, so
# the SSNR teacher needs alpha = 0. With alpha = 1 the objective is pure min-SINR,
# i.e. the same objective as the SINR teacher.
alpha_ten = [0.0]
for alpha in alpha_ten:
    # Best validation figures so far and per-epoch histories
    best_SSNR = -100000
    best_SINR = -100000
    best_loss = 10000
    combined_loss_ten = []
    sinr_tr_ten = []
    ssnr_tr_ten = []
    validation_combined_SSNR=[]
    validation_combined_SINR=[]
    L_valid_ten = []

    # One U-net per AP, each with its own optimiser and learning-rate scheduler
    SSNR_AP0 = UNet().to(device)
    SSNR_AP1 = UNet().to(device)
    SSNR_AP2 = UNet().to(device)

    weight_decay=0
    opt_adam_AP0 = torch.optim.Adam(SSNR_AP0.parameters(), lr=1e-2)
    opt_adam_AP1 = torch.optim.Adam(SSNR_AP1.parameters(), lr=1e-2)
    opt_adam_AP2 = torch.optim.Adam(SSNR_AP2.parameters(), lr=1e-2)
    # NOTE(review): `verbose` is deprecated in recent PyTorch releases -- confirm against the installed version
    lr_scheduler_AP0 = torch.optim.lr_scheduler.ReduceLROnPlateau(opt_adam_AP0, 'min', 0.1, 10, verbose=True)
    lr_scheduler_AP1 = torch.optim.lr_scheduler.ReduceLROnPlateau(opt_adam_AP1, 'min', 0.1, 10, verbose=True)
    lr_scheduler_AP2 = torch.optim.lr_scheduler.ReduceLROnPlateau(opt_adam_AP2, 'min', 0.1, 10, verbose=True)

    flag=0
    sinr_save=0
    counter1= 0  # epochs since the validation SSNR last improved
    counter2= 0  # epochs since the validation SINR last improved
    disp_cycle = 5
    epochs = 100
    patience=100

    for epoch in range(epochs):
        # Plain float accumulator: summing the loss tensors themselves would keep
        # every batch's autograd graph node alive until the end of the epoch.
        mean_loss = 0.0
        SSNR_AP0.train()
        SSNR_AP1.train()
        SSNR_AP2.train()
        for i, (x0_train, x1_train, x2_train, H_comm, DAD) in enumerate(train_dataloader):
            x0 = x0_train.to(device)
            x1 = x1_train.to(device)
            x2 = x2_train.to(device)

            y0_estimate = SSNR_AP0(x0)
            y1_estimate = SSNR_AP1(x1)
            y2_estimate = SSNR_AP2(x2)

            # The AP index becomes the last axis, as unsup_loss expects
            y_estimate = torch.stack((y0_estimate,y1_estimate,y2_estimate),dim=4)

            L_unsup, ssnr_tr, sinr_tr = unsup_loss(DAD.to(device),H_comm.to(device),y_estimate,alpha)

            mean_loss += L_unsup.item()

            opt_adam_AP0.zero_grad()
            opt_adam_AP1.zero_grad()
            opt_adam_AP2.zero_grad()
            # Anomaly mode is enabled around the backward pass only
            with torch.autograd.set_detect_anomaly(True):
                L_unsup.backward()

            opt_adam_AP0.step()
            opt_adam_AP1.step()
            opt_adam_AP2.step()

            if i%disp_cycle==0:
                print(f'ep {epoch+1} - p {i+1} - Loss = {L_unsup.item():.7f}, SSNR = {ssnr_tr.item():.7f}, SINR = {sinr_tr.item():.7f}')

        # The training history records the LAST batch of the epoch, not the epoch mean
        combined_loss_ten = np.append(combined_loss_ten,L_unsup.item())
        ssnr_tr_ten = np.append(ssnr_tr_ten,ssnr_tr.item())
        sinr_tr_ten = np.append(sinr_tr_ten,sinr_tr.item())

        # The schedulers follow the mean training loss of the epoch
        lr_scheduler_AP0.step(mean_loss/(i+1))
        lr_scheduler_AP1.step(mean_loss/(i+1))
        lr_scheduler_AP2.step(mean_loss/(i+1))
        current_lr = opt_adam_AP0.param_groups[0]['lr']

        L_valid, SSNR_validate, SINR_valid, vio_var_valid = validate(SSNR_AP0, SSNR_AP1, SSNR_AP2,test_dataloader,alpha)

        validation_combined_SSNR=np.append(validation_combined_SSNR,SSNR_validate)
        validation_combined_SINR=np.append(validation_combined_SINR,SINR_valid)
        L_valid_ten = np.append(L_valid_ten,L_valid)

        print(f'ep {epoch+1} (alpha={alpha}): Loss = {L_valid:.5f}. Validation SSNR = {SSNR_validate:.7f}, average min SINR: {SINR_valid.item():.7f}, lr = {current_lr}')

        if SSNR_validate > best_SSNR:
            print(f'Improved SSNR: from {best_SSNR:.5f} to {SSNR_validate:.5f} (delta = {(SSNR_validate-best_SSNR):.5f})')
            best_SSNR = SSNR_validate

            # Checkpoint the three networks whenever the validation SSNR improves
            PATH = f'../temp_model/AP0_bestSSNR_{M_t}APs.pth'
            torch.save(SSNR_AP0.state_dict(),PATH)

            PATH = f'../temp_model/AP1_bestSSNR_{M_t}APs.pth'
            torch.save(SSNR_AP1.state_dict(),PATH)

            PATH = f'../temp_model/AP2_bestSSNR_{M_t}APs.pth'
            torch.save(SSNR_AP2.state_dict(),PATH)

            counter1=0
        else:
            counter1+=1

        if SINR_valid > best_SINR:
            print(f'Improved SINR: from {best_SINR:.5f} to {SINR_valid:.5f} (delta = {(SINR_valid-best_SINR):.5f})')
            best_SINR = SINR_valid

            counter2=0
        else:
            counter2+=1

        if L_valid < best_loss:
            print(f'Improved loss: from {best_loss:.5f} to {L_valid:.5f} (delta = {(L_valid-best_loss):.5f})')
            best_loss = L_valid

        print(f'Best values: {best_SSNR}, {best_SINR}. Counters: {counter1}, {counter2}')
        print('########################################################')
        # Early stopping is driven by the SSNR counter for this teacher
        if counter1 == patience:
            print('Early stropping')
            break

def _plot_curve(fig_num, x_values, y_values, colour, title, ylabel=None):
    """Draw one curve on its own numbered figure with a grid and an 'Epoch' x-axis."""
    plt.figure(fig_num)
    plt.plot(x_values, y_values, colour)
    plt.title(title)
    plt.xlabel(f'Epoch')
    if ylabel is not None:
        plt.ylabel(ylabel)
    plt.grid()


disp_cycle=1
loss_ten_plot = combined_loss_ten[::disp_cycle]
ssnr_ten_plot = ssnr_tr_ten[::disp_cycle]
sinr_ten_plot = sinr_tr_ten[::disp_cycle]
trials = range(len(combined_loss_ten))
trials_plot = trials[::disp_cycle]

# Training curves (one point per epoch)
_plot_curve(11, trials_plot, loss_ten_plot, 'b', f'Training Curve (Batch size = {batch_size})', f'Loss')
_plot_curve(12, trials_plot, ssnr_ten_plot, 'b', f'Training Curve SSNR (Batch size = {batch_size})', f'SSNR')
_plot_curve(13, trials_plot, sinr_ten_plot, 'b', f'Training Curve SINR (Batch size = {batch_size})', f'SINR')

# Validation curves; the x-axis follows the recorded history instead of the
# leaked loop variable `epoch`.
valid_epochs = range(len(L_valid_ten))
_plot_curve(1300, valid_epochs, L_valid_ten, 'r', f'Validation Loss (Batch size = {batch_size})')
_plot_curve(1100, valid_epochs, validation_combined_SSNR, 'r', f'Validation Curve SSNR (Batch size = {batch_size})', f'Average SSNR')
_plot_curve(1200, valid_epochs, validation_combined_SINR, 'r', f'Validation Curve SINR (Batch size = {batch_size})', f'Average SINR')

# Best epochs. The FIRST matching epoch is reported, so a tie between epochs no
# longer raises (calling .item() on a multi-element index array fails).
model_index_SSNR = np.where(validation_combined_SSNR == max(validation_combined_SSNR))
best_ssnr_epoch = int(model_index_SSNR[0][0])
print(f'SSNR: Model at epoch {best_ssnr_epoch} achieved {max(validation_combined_SSNR)}')

model_index_SINR = np.where(validation_combined_SINR == max(validation_combined_SINR))
best_sinr_epoch = int(model_index_SINR[0][0])
print(f'SINR: Model at epoch {best_sinr_epoch} achieved {max(validation_combined_SINR)}')
print(f'Average SINR after convergence = {np.mean(validation_combined_SINR[-300:])}')
print(f'SINR at maximum SSNR = {validation_combined_SINR[best_ssnr_epoch].item()}')

model_index_loss = np.where(L_valid_ten == min(L_valid_ten))
best_loss_epoch = int(model_index_loss[0][0])
print(f'Loss: Model at epoch {best_loss_epoch} achieved {min(L_valid_ten)}')


# Save results

np.savetxt(f"../save_results/train_loss_SSNRteacher{M_t}APs_U{U}_ant{Nt}.csv", loss_ten_plot, delimiter=",")
np.savetxt(f"../save_results/train_ssnr_SSNRteacher{M_t}APs_U{U}_ant{Nt}.csv", ssnr_ten_plot, delimiter=",")
np.savetxt(f"../save_results/train_sinr_SSNRteacher{M_t}APs_U{U}_ant{Nt}.csv", sinr_ten_plot, delimiter=",")

np.savetxt(f"../save_results/valid_loss_SSNRteacher{M_t}APs_U{U}_ant{Nt}.csv", L_valid_ten, delimiter=",")
np.savetxt(f"../save_results/valid_ssnr_SSNRteacher{M_t}APs_U{U}_ant{Nt}.csv", validation_combined_SSNR, delimiter=",")
np.savetxt(f"../save_results/valid_sinr_SSNRteacher{M_t}APs_U{U}_ant{Nt}.csv", validation_combined_SINR, delimiter=",")

### 6. Evaluate reference values for SINR and SSNR

### 6.1 Load saved models

In [None]:
# Initialize models (created on the CPU):
SINR_AP0, SINR_AP1, SINR_AP2 = UNet(), UNet(), UNet()
SSNR_AP0, SSNR_AP1, SSNR_AP2 = UNet(), UNet(), UNet()

# Load the best checkpoints. map_location='cpu' lets checkpoints that were
# saved from CUDA tensors load on a host without a GPU; the models themselves
# live on the CPU here.
for model_path, _teacher in (
    # SINR-biased teachers
    (f'../temp_model/AP0_bestSINR_{M_t}APs.pth', SINR_AP0),
    (f'../temp_model/AP1_bestSINR_{M_t}APs.pth', SINR_AP1),
    (f'../temp_model/AP2_bestSINR_{M_t}APs.pth', SINR_AP2),
    # SSNR-biased teachers
    (f'../temp_model/AP0_bestSSNR_{M_t}APs.pth', SSNR_AP0),
    (f'../temp_model/AP1_bestSSNR_{M_t}APs.pth', SSNR_AP1),
    (f'../temp_model/AP2_bestSSNR_{M_t}APs.pth', SSNR_AP2),
):
    _teacher.load_state_dict(torch.load(model_path, map_location='cpu'))

### 6.2 Reference values

In [None]:
sinr_gt_ten = []
ssnr_gt_ten = []

# The teachers are only evaluated here: switch them to eval mode so Dropout2d is
# disabled and BatchNorm uses its running statistics, matching how the
# checkpoints were selected during validation. They stay in eval mode afterwards.
for _teacher in (SINR_AP0, SINR_AP1, SINR_AP2, SSNR_AP0, SSNR_AP1, SSNR_AP2):
    _teacher.eval()

# No gradients are needed for reference values
with torch.no_grad():
    for i, (x0_train, x1_train, x2_train, H_comm, DAD) in enumerate(train_dataloader):
        BS=x0_train.shape[0]
        # Inputs are left on the loader's device; only the stacked outputs are moved
        x0 = x0_train
        x1 = x1_train
        x2 = x2_train

        # SINR-biased teachers: keep the mean minimum SINR of the batch
        y0_sinr = SINR_AP0(x0)
        y1_sinr = SINR_AP1(x1)
        y2_sinr = SINR_AP2(x2)

        y_sinr = torch.stack((y0_sinr,y1_sinr,y2_sinr),dim=4)
        _, _, sinr = unsup_loss(DAD.to(device),H_comm.to(device),y_sinr.to(device),0)

        # SSNR-biased teachers: keep the mean SSNR of the batch
        y0_ssnr = SSNR_AP0(x0)
        y1_ssnr = SSNR_AP1(x1)
        y2_ssnr = SSNR_AP2(x2)

        y_ssnr = torch.stack((y0_ssnr,y1_ssnr,y2_ssnr),dim=4)
        _, ssnr, _ = unsup_loss(DAD.to(device),H_comm.to(device),y_ssnr.to(device),0)
        sinr_gt_ten = np.append(sinr_gt_ten,sinr.to('cpu'))
        ssnr_gt_ten = np.append(ssnr_gt_ten,ssnr.to('cpu'))

sinr_gt = np.mean(sinr_gt_ten)  # Used for student training
ssnr_gt = np.mean(ssnr_gt_ten)  # Used for student training
print(f'Ground truth: SSNR = {ssnr_gt}, SINR = {sinr_gt}')

### 7. Student training

### 7.1 Loss function, validation loss function, and validation regime

In [None]:
sigmasq_ue = 1
sigmasq_radar_rcs = 0.1000

# Step sizes of the lambda update (e1: SSNR branch, e2: SINR branch)
e1=1e-2
e2=1e-2

# Loss function

def unsup_loss(DAD,H_comm,f_pred,SSNR_gt,SINR_gt,lam):
  """Student loss: SSNR and worst-user SINR, each normalised by its reference value.

  Args:
    DAD: complex tensor indexed as DAD[:, ap, :, :].
    H_comm: complex channel tensor; axes 2 and 3 are swapped and flattened per user.
    f_pred: real tensor of rank 5; index 0/1 of axis 1 is the real/imaginary
      part and the last axis indexes the APs.
    SSNR_gt, SINR_gt: reference values used as normalisers.
    lam: tensor weighting the SINR term. It is UPDATED IN PLACE on every call.

  Returns:
    (loss, mean SSNR, mean per-sample minimum SINR, lam); the last three are detached.
  """
  n_batch, n_ap = f_pred.shape[0], f_pred.shape[-1]

  # Complex beamformers, one row per stream (users first, then targets)
  beams = torch.complex(f_pred[:,0,:,:,:], f_pred[:,1,:,:,:]).reshape((n_batch, U+T, Nt*n_ap))
  F_sum = (torch.conj(beams.transpose(1,2)) @ beams).to(device)

  # Sensing term: trace of DAD[:, ap] @ F_sum accumulated over the APs
  SSNR = 0
  for ap in range(n_ap):
    SSNR = SSNR + torch.diagonal(DAD[:,ap,:,:] @ F_sum, dim1=-1, dim2=-2).sum(-1)
  p0 = SSNR.real*sigmasq_radar_rcs

  # Communication term: per-user SINR
  H_st = H_comm.transpose(2,3).reshape(n_batch, U, -1)
  SINR = torch.zeros((n_batch,U), device=device)
  for u in range(U):
    h_conj = torch.conj(H_st[:,u:u+1,:].to(device))
    # Power received by user u from every stream
    rx_power = [
        torch.abs(h_conj @ beams[:,st:st+1,:].to(device).transpose(1,2))**2
        for st in range(U+T)
    ]
    interference = 0
    for st, power in enumerate(rx_power):
      if st == u:
        continue
      interference = interference + power
    SINR[:,u] = rx_power[u][:,0,0] / (interference[:,0,0]+sigmasq_ue)

  SINR_min = SINR.min(dim=1).values

  # Lambda update criterion: move the weight towards whichever metric has the
  # larger normalised gap to its reference.
  # NOTE(review): the update is not wrapped in no_grad, so gradients flow
  # through lam into the loss -- confirm this is intended.
  sinr_gap = torch.mean((SINR_gt-SINR_min)/SINR_gt)
  ssnr_gap = torch.mean((SSNR_gt-p0)/SSNR_gt)
  if sinr_gap >= ssnr_gap:
    lam += e2*torch.mean(SINR_gt-SINR_min)/SINR_gt
    lam[lam>1] = 1.0
  else:
    lam -= e1*torch.mean(SSNR_gt-p0)/SSNR_gt
    lam[lam<0] = 0.0

  ssnr_term = torch.mean(p0/SSNR_gt)
  sinr_term = torch.mean(SINR_min/SINR_gt)
  L = -((1-lam)*ssnr_term + lam*sinr_term)

  return L, torch.mean(p0).detach(), torch.mean(SINR_min).detach(), lam.detach()





# Validation loss function

def validate_metric(DAD,H_comm,f_pred,lam):
  """Evaluate the weighted loss, sensing term and minimum user SINR for one batch.

  Performs the same SSNR / SINR computation as unsup_loss but leaves ``lam``
  untouched and normalises with the module-level ``ssnr_gt`` / ``sinr_gt``.
  Also reads the module-level names U, T, Nt, device, sigmasq_ue and
  sigmasq_radar_rcs.

  Args:
    DAD: tensor indexed as ``DAD[:, mt, :, :]`` for ``mt < M_t``.
    H_comm: communication channels; dims 2 and 3 are swapped and the result
      is flattened to ``(batch, U, -1)``.
    f_pred: real-valued precoder prediction. Index 0 / 1 along dim 1 hold the
      real / imaginary parts and the last dim is the number of APs
      (presumably shaped (batch, 2, U+T, Nt, M_t) -- confirm against the model).
    lam: weight of the SINR term; ``1-lam`` weights the SSNR term.

  Returns:
    Tuple ``(L, p0, vio1)``: the weighted loss, the per-sample sensing term
    of shape ``(batch,)`` and the per-sample minimum SINR over users of shape
    ``(batch,)``. For a batch of one sample ``vio1`` holds the same value the
    previous whole-tensor minimum produced.
  """
  M_t = f_pred.shape[-1]
  batch_size = f_pred.shape[0]

  # Rebuild complex precoders and flatten (antenna, AP) into one axis per stream.
  f_pred_complex = torch.complex(f_pred[:,0,:,:,:],f_pred[:,1,:,:,:])
  f_strems_anten = f_pred_complex.reshape((batch_size,U+T,Nt*M_t))
  F_sum = torch.matmul(torch.conj(torch.transpose(f_strems_anten,1,2)),f_strems_anten).to(device)

  # Sensing term: accumulate trace(DAD[:, mt] @ F_sum) over the APs.
  SSNR = 0
  for mt in range(M_t):
    SSNR = SSNR + (torch.matmul(DAD[:,mt,:,:],F_sum)).diagonal(offset=0, dim1=-1, dim2=-2).sum(-1)

  p0 = SSNR.real*sigmasq_radar_rcs

  H_transposed = torch.transpose(H_comm,2,3)
  H_st = H_transposed.reshape(batch_size,U,-1)

  SINR = torch.zeros((batch_size,U),device=device)
  for u in range(U):
    h_u = H_st[:,u:u+1,:].to(device)
    f_u = f_strems_anten[:,u:u+1,:].to(device)
    SINRu_num = torch.abs(torch.matmul(torch.conj(h_u),torch.transpose(f_u,1,2)))**2
    L_u=0
    for st in range(U+T):
      if st != u :
        f_int = f_strems_anten[:,st:st+1,:].to(device)
        L_u = L_u + torch.abs(torch.matmul(torch.conj(h_u),torch.transpose(f_int,1,2)))**2
    # Both terms are (batch, 1, 1); drop the singleton dims so the result
    # fits the (batch,) slice for any batch size (same indexing as unsup_loss).
    SINR[:,u] = SINRu_num[:,0,0] / (L_u[:,0,0]+sigmasq_ue)

  # Worst user per sample; reducing over the batch axis as well would mix
  # different channel realisations.
  vio1,_ = torch.min(SINR,dim=1)
  L = -((1-lam)*torch.mean(p0/ssnr_gt) + lam*torch.mean(vio1/sinr_gt))

  return L, p0, vio1





# Validation regime

def validate(model0,model1,model2,test_dataloader,lam):
  """Run the three per-AP models over a dataloader and summarise validate_metric.

  Each model receives its own input tensor; the three outputs are stacked
  along a new last axis and scored with validate_metric. The models are
  switched to eval mode and are left in that mode on return.

  Args:
    model0, model1, model2: the per-AP networks.
    test_dataloader: iterable yielding ``(x0, x1, x2, H_comm, DAD)`` batches.
    lam: weight forwarded unchanged to validate_metric.

  Returns:
    Tuple ``(mean loss per batch, mean sensing term, mean of the minimum-SINR
    values, variance of the minimum-SINR values)``.

  Raises:
    ValueError: if the dataloader yields no batches.
  """
  # Plain Python lists: appending is O(1), whereas np.append copies the whole
  # array on every batch.
  ssnr_values = []
  min_sinr_values = []
  L_valid_sum = 0
  num_batches = 0

  model0.eval()
  model1.eval()
  model2.eval()

  with torch.no_grad():
    for x0_test, x1_test, x2_test, H_comm, DAD in test_dataloader:
      y0_predict = model0(x0_test.to(device))
      y1_predict = model1(x1_test.to(device))
      y2_predict = model2(x2_test.to(device))

      # New last axis indexes the AP, as validate_metric expects.
      y_predict = torch.stack((y0_predict,y1_predict,y2_predict),dim=4)
      L_valid,SSNR_estimate,vio1 = validate_metric(DAD.to(device),H_comm.to(device),y_predict.to(device),lam)

      L_valid_sum = L_valid_sum + L_valid
      num_batches += 1
      # reshape(-1) handles both 0-d and 1-d results.
      ssnr_values.extend(SSNR_estimate.reshape(-1).tolist())
      min_sinr_values.extend(vio1.reshape(-1).tolist())

  if num_batches == 0:
    raise ValueError('validate() received a dataloader that yields no batches')

  return L_valid_sum.item()/num_batches,np.mean(ssnr_values),np.mean(min_sinr_values),np.var(min_sinr_values)

### 7.2 Training

In [None]:
# --- Reproducibility: seed before the models are created ---
M_t=3
seed=5
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic=True
torch.backends.cudnn.benchmark=False

# --- Per-epoch histories ---
mse_ue_ten = []
mse_t_ten = []
mse_ue_valid_ten = []
mse_t_valid_ten = []
combined_loss_ten = []
sinr_tr_ten = []
ssnr_tr_ten = []
validation_combined_SSNR=[]
validation_combined_SINR=[]
L_valid_ten = []
# Initialised here so this cell neither fails with a NameError nor appends to
# values left over from an earlier cell.
var_valid_ten = []

# Adaptive SINR/SSNR weight; its history starts with the initial value.
lam1=torch.tensor([0.5],device=device)
lam1_ten = []
lam1_ten=np.append(lam1_ten,lam1.to('cpu'))

best_SSNR = -100000
best_SINR = -100000
best_loss = 10000

# --- One network, optimiser and LR scheduler per AP ---
model_AP0 = UNet().to(device)
model_AP1 = UNet().to(device)
model_AP2 = UNet().to(device)

opt_adam_AP0 = torch.optim.Adam(model_AP0.parameters(), lr=1e-2)
opt_adam_AP1 = torch.optim.Adam(model_AP1.parameters(), lr=1e-2)
opt_adam_AP2 = torch.optim.Adam(model_AP2.parameters(), lr=1e-2)
lr_scheduler_AP0 = torch.optim.lr_scheduler.ReduceLROnPlateau(opt_adam_AP0, 'min', 0.1, 10, verbose=True)
lr_scheduler_AP1 = torch.optim.lr_scheduler.ReduceLROnPlateau(opt_adam_AP1, 'min', 0.1, 10, verbose=True)
lr_scheduler_AP2 = torch.optim.lr_scheduler.ReduceLROnPlateau(opt_adam_AP2, 'min', 0.1, 10, verbose=True)

flag=0

sinr_save=0
counter1= 0   # epochs since the validation SSNR last improved
counter2= 0   # epochs since the validation SINR last improved
disp_cycle = 5
epochs = 1000
patience=100

for epoch in range(epochs):
    if epoch == 100:
        # Restart the best-value tracking once the first 100 epochs are over.
        best_SSNR = -100000
        best_SINR = -100000
    mean_loss=0.0
    model_AP0.train()
    model_AP1.train()
    model_AP2.train()

    for i, (x0_train, x1_train, x2_train, H_comm, DAD) in enumerate(train_dataloader):
        BS=x0_train.shape[0]
        x0 = x0_train.to(device)
        x1 = x1_train.to(device)
        x2 = x2_train.to(device)

        y0_estimate = model_AP0(x0)
        y1_estimate = model_AP1(x1)
        y2_estimate = model_AP2(x2)
        # New last axis indexes the AP, as unsup_loss expects.
        y_estimate = torch.stack((y0_estimate,y1_estimate,y2_estimate),dim=4)

        L_distil, ssnr_tr, sinr_tr,lam1 = unsup_loss(DAD.to(device),H_comm.to(device),y_estimate,ssnr_gt, sinr_gt,lam1)
        # Sum plain floats: adding the loss tensor itself would keep every
        # batch's autograd history referenced until the end of the epoch.
        mean_loss=mean_loss+L_distil.item()

        opt_adam_AP0.zero_grad()
        opt_adam_AP1.zero_grad()
        opt_adam_AP2.zero_grad()

        # NOTE(review): anomaly detection is a debugging aid that slows the
        # backward pass; consider disabling it once training is stable.
        with torch.autograd.set_detect_anomaly(True):
            L_distil.backward()

        opt_adam_AP0.step()
        opt_adam_AP1.step()
        opt_adam_AP2.step()

        if i%disp_cycle==0:
            print(f'ep {epoch+1} - p {i+1} - Loss = {L_distil.item():.4f}, SSNR = {ssnr_tr.item():.5f}, SINR = {sinr_tr.item():.5f}, lam = {lam1.item():.3f}')

    # Training histories record the values of the last batch of the epoch.
    combined_loss_ten = np.append(combined_loss_ten,L_distil.item())
    ssnr_tr_ten = np.append(ssnr_tr_ten,ssnr_tr.item())
    sinr_tr_ten = np.append(sinr_tr_ten,sinr_tr.item())
    lam1_ten = np.append(lam1_ten,lam1.to('cpu'))

    # All three schedulers are driven by the same epoch-average training loss.
    epoch_mean_loss = mean_loss/(i+1)
    lr_scheduler_AP0.step(epoch_mean_loss)
    lr_scheduler_AP1.step(epoch_mean_loss)
    lr_scheduler_AP2.step(epoch_mean_loss)
    current_lr = opt_adam_AP0.param_groups[0]['lr']

    L_valid,SSNR_validate, SINR_valid, vio_var_valid = validate(model_AP0, model_AP1, model_AP2, test_dataloader,lam1)

    validation_combined_SSNR=np.append(validation_combined_SSNR,SSNR_validate)
    validation_combined_SINR=np.append(validation_combined_SINR,SINR_valid)
    L_valid_ten = np.append(L_valid_ten,L_valid)
    var_valid_ten = np.append(var_valid_ten,vio_var_valid)

    # NOTE(review): `alpha` is not assigned in this cell -- confirm that an
    # earlier cell defines it, otherwise this print raises NameError.
    print(f'ep {epoch+1} (alpha={alpha}): Loss = {L_valid:.5f}. Validation SSNR = {SSNR_validate:.7f}, average min SINR: {SINR_valid.item():.7f}, lr = {current_lr}')

    if SSNR_validate > best_SSNR:
        print(f'Improved SSNR: from {best_SSNR:.5f} to {SSNR_validate:.5f} (delta = {(SSNR_validate-best_SSNR):.5f})')
        best_SSNR = SSNR_validate
        counter1=0
    else:
        counter1+=1

    if SINR_valid > best_SINR:
        print(f'Improved SINR: from {best_SINR:.5f} to {SINR_valid:.5f} (delta = {(SINR_valid-best_SINR):.5f})')
        best_SINR = SINR_valid
        counter2=0
    else:
        counter2+=1

    print(f'Best values: {best_SSNR}, {best_SINR}. Counters: {counter1}, {counter2}')

    print('########################################################')
    # Stop only when neither metric has improved for `patience` epochs.
    if counter1 >= patience and counter2 >= patience:
        print('Early stropping')
        break
import os  # local to this cell: only used to create the output folder below


def _plot_curve(fig_id, y, color, title, x=None, ylabel=None):
    """Draw one curve in its own numbered figure with an 'Epoch' x-axis and a grid."""
    plt.figure(fig_id)
    if x is None:
        plt.plot(y, color)
    else:
        plt.plot(x, y, color)
    plt.title(title)
    plt.xlabel('Epoch')
    if ylabel is not None:
        plt.ylabel(ylabel)
    plt.grid()


counter=0
disp_cycle=1
loss_ten_plot = combined_loss_ten[::disp_cycle]
ssnr_ten_plot = ssnr_tr_ten[::disp_cycle]
sinr_ten_plot = sinr_tr_ten[::disp_cycle]
trials = range(len(combined_loss_ten))
trials_plot = trials[::disp_cycle]
# One validation entry per completed epoch (epoch is the last loop index).
valid_epochs = range(epoch+1)

# Training curves (blue).
_plot_curve(110+counter, loss_ten_plot, 'b', 'Training Curve ', x=trials_plot, ylabel='Loss')
_plot_curve(120+counter, ssnr_ten_plot, 'b', 'Training Curve SSNR ', x=trials_plot, ylabel='SSNR')
_plot_curve(130+counter, sinr_ten_plot, 'b', 'Training Curve SINR ', x=trials_plot, ylabel='SINR')

# Validation curves (red).
_plot_curve(1100+counter, L_valid_ten, 'r', 'Validation Loss ', x=valid_epochs)
_plot_curve(1200+counter, validation_combined_SSNR, 'r', 'Validation Curve SSNR ', x=valid_epochs, ylabel='Average SSNR')
_plot_curve(1300+counter, validation_combined_SINR, 'r', 'Validation Curve SINR ', x=valid_epochs, ylabel='Average SINR')

# History of the adaptive weight (initial value followed by one entry per epoch).
_plot_curve(1600+counter, lam1_ten, 'g', 'lam1', ylabel='lam1')


# Report the best epochs. [0][0] takes the first epoch reaching the maximum,
# so a tie between several epochs no longer makes .item() raise.
ssnr_peak = max(validation_combined_SSNR)
model_index_SSNR = np.where(validation_combined_SSNR == ssnr_peak)
best_ssnr_epoch = model_index_SSNR[0][0]
print(f'SSNR: Model at epoch {best_ssnr_epoch.item()} achieved {ssnr_peak}')

sinr_peak = max(validation_combined_SINR)
model_index_SINR = np.where(validation_combined_SINR == sinr_peak)
best_sinr_epoch = model_index_SINR[0][0]
print(f'SINR: Model at epoch {best_sinr_epoch.item()} achieved {sinr_peak}')
print(f'Average SINR after convergence = {np.mean(validation_combined_SINR[-300:])}')
print(f'SINR at maximum SSNR = {validation_combined_SINR[best_ssnr_epoch].item()}')


# Save results

# Create the output folder first so a missing directory cannot discard the
# results of a finished training run.
os.makedirs("../save_results", exist_ok=True)

np.savetxt(f"../save_results/train_loss_student{M_t}APs_U{U}_ant{Nt}.csv", loss_ten_plot, delimiter=",")
np.savetxt(f"../save_results/train_ssnr_student{M_t}APs_U{U}_ant{Nt}.csv", ssnr_ten_plot, delimiter=",")
np.savetxt(f"../save_results/train_sinr_student{M_t}APs_U{U}_ant{Nt}.csv", sinr_ten_plot, delimiter=",")

np.savetxt(f"../save_results/valid_loss_student{M_t}APs_U{U}_ant{Nt}.csv", L_valid_ten, delimiter=",")
np.savetxt(f"../save_results/valid_ssnr_student{M_t}APs_U{U}_ant{Nt}.csv", validation_combined_SSNR, delimiter=",")
np.savetxt(f"../save_results/valid_sinr_student{M_t}APs_U{U}_ant{Nt}.csv", validation_combined_SINR, delimiter=",")

np.savetxt(f"../save_results/lam_student{M_t}APs_U{U}_ant{Nt}.csv", lam1_ten, delimiter=",")