In [None]:
import os
from pickle import load
import matplotlib.pyplot as plt
from numpy import array, stack

In [None]:
# Experiment configuration. fading, num_ant, SNR_dB and bit_codebook are all
# embedded in the dataset file name, so they must match the generated data.
fading, num_ant, SNR_dB = 0, 2, 5
bit_codebook = 1
max_iter = 100000  # NOTE(review): presumably the number of generated samples — confirm
# 2**bit_codebook classes (presumably one per codebook entry).
num_classes = 1 << bit_codebook

In [None]:
# Build the dataset path from the configuration parameters and load the
# pickled (precoders, labels) pair stored there.
dir_name = 'datasets'
if not os.path.exists(dir_name):
    # NOTE(review): this only creates an empty folder; the open() below still
    # fails if the .pkl file was never generated.
    os.makedirs(dir_name)
file_name = f'{dir_name}/{fading}_precoder_data_{num_ant}_ant_SNR_{SNR_dB}dB_{bit_codebook}_bit_codebk'
print(file_name)
pkl_path = file_name + '.pkl'
# SECURITY: pickle.load can execute arbitrary code; only load files from a trusted source.
with open(pkl_path, 'rb') as pkl_file:
    precoders, labels = load(pkl_file)

In [None]:
# Quick visual sanity check: for each of the first two entries of `precoders`,
# scatter its even-indexed values against its odd-indexed values, one figure each.
for sample_idx in (0, 1):
    sample = precoders[sample_idx]
    plt.scatter(sample[::2], sample[1::2])
    plt.show()

In [None]:
# Inputs: stack the per-sample arrays along a new leading (sample) axis.
precoder_data = stack(precoders)
# Targets: plain array of class labels.
precoder_labels = array(labels)

In [None]:
import torch 
from torch.utils.data import Dataset, DataLoader, random_split

In [None]:
class PrecoderDataset(Dataset):
    """Map-style dataset pairing feature vectors with integer class labels.

    Args:
        data: array-like of input features (NumPy array, tensor, or nested
            sequence); stored as a float32 tensor.
        target: array-like of class indices, one per entry of ``data``;
            stored as an int64 tensor.
        transform: optional callable applied to each feature tensor in
            ``__getitem__``.

    Raises:
        ValueError: if ``data`` and ``target`` have different lengths.
    """

    def __init__(self, data, target, transform=None):
        # torch.as_tensor accepts NumPy arrays (sharing memory when the dtype
        # already matches, like torch.from_numpy), tensors and plain sequences.
        self.data = torch.as_tensor(data, dtype=torch.float32)
        self.target = torch.as_tensor(target, dtype=torch.long)
        # Catch misaligned inputs here instead of failing later in __getitem__.
        if len(self.data) != len(self.target):
            raise ValueError(
                f'data and target must have the same length, '
                f'got {len(self.data)} and {len(self.target)}'
            )
        self.transform = transform

    def __getitem__(self, index):
        """Return the ``(features, label)`` pair at ``index``."""
        x = self.data[index]
        if self.transform is not None:
            x = self.transform(x)
        return x, self.target[index]

    def __len__(self):
        """Number of samples."""
        return len(self.data)

# Split the full dataset 80/20 into training and validation subsets.
# The sizes are derived from the dataset itself (not from max_iter) so they
# always sum to len(precoder_dataset), which random_split requires.
precoder_dataset = PrecoderDataset(precoder_data, precoder_labels)
n_samples = len(precoder_dataset)
n_train = int(0.8 * n_samples)
n_val = n_samples - n_train
# Fixed seed so the train/val partition is reproducible across kernel restarts.
split_generator = torch.Generator().manual_seed(0)
train, val = random_split(precoder_dataset, [n_train, n_val], generator=split_generator)

batch_size = 128
train_loader = DataLoader(
    train,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
)
# Loss/accuracy over the whole validation set do not depend on sample order,
# so the validation loader is not shuffled.
# TODO: consider a larger validation batch size (no gradients are kept there).
val_loader = DataLoader(
    val,
    batch_size=batch_size,
    shuffle=False,
    num_workers=0,
)

In [None]:
from torch import nn 
from torch import optim 
from torchsummary import summary

In [None]:
class Model(torch.nn.Module):
    """Small fully connected classifier: Linear -> ReLU -> Dropout -> Linear.

    Args:
        input_size: number of input features per sample.
        output_size: number of classes (one output logit per class).
        hidden_size: width of the hidden layer (default 8, the original width).
        dropout: dropout probability applied to the hidden activations
            (default 0.1).

    ``forward`` returns raw, unnormalised scores (no softmax), which is the
    input form ``nn.CrossEntropyLoss`` expects.
    """

    def __init__(self, input_size, output_size, hidden_size=8, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(input_size, hidden_size)
        # Non-linearity between the two linear layers. Without it, linear1 and
        # linear2 compose into a single affine map and the hidden layer adds
        # no capacity. ReLU has no parameters, so state_dict keys are unchanged.
        self.act = nn.ReLU()
        self.drop = nn.Dropout(dropout)
        self.linear2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Map a batch of feature vectors ``x`` to unnormalised class scores."""
        hidden = self.drop(self.act(self.linear1(x)))
        return self.linear2(hidden)

# Network dimensions: 2 * num_ant**2 input features (NOTE(review): presumably
# real and imaginary parts of a num_ant x num_ant matrix — confirm) and one
# output logit per class.
input_size = 2 * num_ant ** 2
output_size = num_classes
model = Model(input_size=input_size, output_size=output_size)

use_cuda = torch.cuda.is_available()
if use_cuda:
    # Module.cuda() moves parameters in place and returns the same module.
    model = model.cuda()

# SGD with momentum, and cross-entropy on the model's raw output scores.
optimizer = optim.SGD(params=model.parameters(), lr=1e-4, momentum=0.9)
loss = nn.CrossEntropyLoss()

In [None]:
num_epochs = 500  # number of full passes over the training set


def _run_epoch(net, loader, loss_fn, optimizer=None):
    """Run one pass of ``net`` over ``loader``.

    If ``optimizer`` is given, the network is put in training mode and a
    gradient step is taken after every batch; otherwise it is put in eval mode
    and gradient tracking is disabled for the whole pass.

    Returns:
        ``(mean_loss, accuracy_percent)`` over all samples yielded by
        ``loader``; both are ``nan`` if the loader yields no samples.
    """
    training = optimizer is not None
    net.train(training)
    use_cuda = torch.cuda.is_available()  # checked once per pass, not per batch
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    with torch.set_grad_enabled(training):
        for x, y in loader:
            if use_cuda:
                x, y = x.cuda(), y.cuda()
            logits = net(x)
            batch_loss = loss_fn(logits, y)
            if training:
                optimizer.zero_grad()
                batch_loss.backward()
                optimizer.step()
            batch_size = y.size(0)
            # Assumes loss_fn returns a per-batch mean (nn.CrossEntropyLoss's
            # default reduction); weight it by batch size so a short final
            # batch does not skew the epoch mean.
            loss_sum += batch_loss.item() * batch_size
            # Vectorised comparison: one host sync per batch instead of one
            # per sample.
            n_correct += (logits.argmax(dim=1) == y).sum().item()
            n_seen += batch_size
    if n_seen == 0:
        return float('nan'), float('nan')
    return loss_sum / n_seen, 100.0 * n_correct / n_seen


for epoch in range(num_epochs):
    training_loss, training_accuracy = _run_epoch(model, train_loader, loss, optimizer)
    validation_loss, validation_accuracy = _run_epoch(model, val_loader, loss)
    print(
        f'Epoch {epoch + 1}, training loss: {training_loss:.8f}, '
        f'training accuracy: {training_accuracy:.2f}%, '
        f'validation loss: {validation_loss:.8f}, '
        f'validation accuracy: {validation_accuracy:.2f}%'
    )