In [20]:
import sys
import torch
import matplotlib.pyplot as plt
from pathlib import Path
from torch import nn
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm
sys.path.append("../")
from MLRF.datatools import h5kit
import seaborn as sns


In [None]:
# Prefer the CUDA device whenever PyTorch can see a GPU; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

class NeuralNetwork(nn.Module):
    """1-D CNN that maps a single-channel sequence to 2 class logits.

    Architecture: three [Conv1d(kernel=3, stride=1, padding=1) -> ReLU -> MaxPool1d(2)]
    stages with 16/32/64 output channels, then Flatten -> Linear -> ReLU -> Linear(128, 2).

    Each convolution preserves the sequence length and each pooling step floors it by
    half, so the flattened feature size fed to ``fc1`` is ``64 * (input_length // 8)``.

    Args:
        input_length: Length of every input sequence. The default of 2048 reproduces the
            previously hard-coded ``fc1`` input size of ``64 * 256``.

    Raises:
        ValueError: if ``input_length`` is smaller than 8 (nothing would survive pooling).
    """

    def __init__(self, input_length: int = 2048):
        super().__init__()
        if input_length < 8:
            raise ValueError(f"input_length must be >= 8, got {input_length}")

        self.conv1 = nn.Conv1d(
            in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1
        )
        self.conv2 = nn.Conv1d(
            in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1
        )
        self.conv3 = nn.Conv1d(
            in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1
        )
        self.pool = nn.MaxPool1d(kernel_size=2, stride=2, padding=0)
        self.flatten = nn.Flatten()
        # Three halvings by self.pool leave input_length // 8 positions, 64 channels each.
        self.fc1 = nn.Linear(64 * (input_length // 8), 128)
        # NOTE: constructed (so the module layout is unchanged) but deliberately not applied
        # in forward(); re-enable there if regularisation is wanted.
        self.dropout = nn.Dropout(p=0.1)
        self.fc2 = nn.Linear(128, 2)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return raw class logits of shape ``(batch, 2)``.

        ``x`` may be ``(batch, length)``, in which case a channel dimension is inserted,
        or already ``(batch, 1, length)``.
        """
        if x.dim() == 2:
            x = x.unsqueeze(1)

        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = self.pool(self.relu(self.conv3(x)))

        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)

        return x

def train(dataloader, model, loss_fn, optimizer):
    """Run one optimisation epoch of ``model`` over ``dataloader``.

    Every batch is moved to the module-level ``device``, passed through the model, and
    followed by a zero_grad / backward / step update.

    Args:
        dataloader: iterable of ``(sample, label)`` batches that also supports ``len()``.
        model: the ``nn.Module`` being trained (switched to train mode here).
        loss_fn: callable ``loss_fn(pred, label)`` returning a scalar loss tensor.
        optimizer: optimiser over ``model``'s parameters.

    Returns:
        ``(epoch_loss, batch_losses)``. ``batch_losses`` holds the per-batch loss values
        as Python floats, in iteration order. ``epoch_loss`` is their unweighted mean, or
        ``nan`` if the dataloader yielded no batches.
    """
    model.train()
    batch_losses = []

    pbar = tqdm(dataloader, desc="Training", total=len(dataloader), unit="batchs", leave=False)

    for sample, label in pbar:
        sample, label = sample.to(device), label.to(device)

        pred = model(sample)
        loss = loss_fn(pred, label)
        # .item() synchronises with the device; read the value once and reuse it.
        loss_value = loss.item()
        batch_losses.append(loss_value)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        pbar.set_postfix(loss=loss_value)

    # Report the epoch as the mean over batches instead of only the last (possibly partial)
    # batch. This also avoids referencing an unbound `loss` when the loader is empty.
    epoch_loss = sum(batch_losses) / len(batch_losses) if batch_losses else float("nan")
    return epoch_loss, batch_losses

def test(dataloader, model):
    """Measure classification accuracy of ``model`` on ``dataloader`` without gradients.

    Args:
        dataloader: iterable of ``(sample, label)`` batches that also supports ``len()``.
        model: the ``nn.Module`` to evaluate (switched to eval mode here).

    Returns:
        ``(overall_accuracy, batch_accuracies)``. ``overall_accuracy`` is the percentage
        of correctly classified samples across all batches (``nan`` if no samples were
        seen). ``batch_accuracies`` holds the per-batch percentages in iteration order.
    """
    model.eval()
    total, correct = 0, 0
    batch_accuracies = []

    with torch.no_grad():
        pbar = tqdm(dataloader, desc="Testing", total=len(dataloader), unit="batchs", leave=False)
        for sample, label in pbar:
            sample, label = sample.to(device), label.to(device)
            output = model(sample)
            # Predicted class = index of the largest logit along the class dimension.
            predicted = output.argmax(dim=1)

            batch_size = label.size(0)
            batch_correct = (predicted == label).sum().item()
            total += batch_size
            correct += batch_correct
            batch_accuracies.append(100 * batch_correct / batch_size)

            pbar.set_postfix(accuracy=100 * correct / total)

    # Guard against an empty loader instead of dividing by zero.
    overall_accuracy = 100 * correct / total if total else float("nan")
    return overall_accuracy, batch_accuracies

class psd_dataset(Dataset):
    """Map-style dataset over the entries of an HDF5 store opened through ``h5kit``.

    Item ``idx`` is ``(sample, label)``:

    - ``sample`` is whatever ``h5kit.read(key)`` returns for the idx-th key.
    - ``label`` is ``0`` when the key's prefix before the first ``'_'`` is ``'wifi'``,
      and ``1`` otherwise.

    Args:
        data: passed straight to ``h5kit``; presumably the path of the ``.h5`` file
            (the notebook passes paths such as ``"../psd_train.h5"``).
    """

    def __init__(self, data):
        self.data = h5kit(data)
        # Materialise the keys once so len() and integer indexing work even if
        # h5kit.keys() returns a view/iterator rather than a list.
        self.keys = list(self.data.keys())

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, idx):
        key = self.keys[idx]
        label = 0 if key.split('_')[0] == 'wifi' else 1
        return self.data.read(key), label
# Report which device (set in the config above) the model and batches will live on.
print("Using", device, "device")

In [22]:
# Loader configuration shared by the training and evaluation pipelines.
bSize = 512
epochs = 10
NUM_WORKERS = 16  # parallel loader processes per DataLoader; lower on machines with fewer cores

training_dataset = psd_dataset("../psd_train.h5")
testing_dataset = psd_dataset("../psd_test.h5")

# Only the training set is shuffled. Evaluation order does not change overall accuracy,
# and a fixed order keeps per-batch test accuracies reproducible and comparable across epochs.
training_dataloader = DataLoader(training_dataset, batch_size=bSize, shuffle=True, num_workers=NUM_WORKERS)
testing_dataloader = DataLoader(testing_dataset, batch_size=bSize, shuffle=False, num_workers=NUM_WORKERS)

In [None]:
model = NeuralNetwork().to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

train_losses = []  # per-epoch loss value returned by train()
test_losses = []   # not populated by the loop below
accuracies = []    # per-epoch overall test accuracy (%) returned by test()

for t in tqdm(range(epochs), desc="Epochs", total=epochs):
    train_loss, train_batch_loss = train(training_dataloader, model, loss_fn, optimizer)
    accuracy, batch_accuracy = test(testing_dataloader, model)
    train_losses.append(train_loss)
    accuracies.append(accuracy)

# `__vsc_ipynb_file__` is only injected by VS Code's Jupyter integration. Fall back to a
# fixed name elsewhere (JupyterLab, nbconvert, papermill) instead of raising KeyError.
notebook_path = globals().get('__vsc_ipynb_file__')
model_filename = Path(notebook_path).stem if notebook_path else "model"
model_path = f"{model_filename}.pth"

# NOTE(review): this pickles the whole nn.Module, so loading requires the NeuralNetwork
# class to be importable. torch.save(model.state_dict(), ...) is the more portable format.
torch.save(model, model_path)

print(f"Saved PyTorch Model State to {model_path}")

In [None]:
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(20, 5))

# Each panel: (axes, values, legend label, title, y-label, fallback y tick step).
# The fallback step is used when the series is empty or constant, so that the tick
# spacing stays positive.
history_panels = (
    (loss_ax, train_losses, "train", "Training Loss", "Loss", 0.025),
    (acc_ax, accuracies, "accuracy", "Accuracy", "Accuracy (%)", 0.25),
)

for ax, values, label, title, ylabel, fallback_step in history_panels:
    # List index 0 is the first epoch, so plot against 1..N to match the "Epochs" axis label.
    ax.plot(range(1, len(values) + 1), values, label=label)
    ax.set(title=title, xlabel="Epochs", ylabel=ylabel)
    ax.legend()
    ax.grid(True, which="both", linestyle="--", linewidth=0.5)

    # Roughly four y ticks across the observed range. `not step > 0` also catches NaN.
    step = (max(values) - min(values)) / 4 if values else 0
    if not step > 0:
        step = fallback_step
    ax.yaxis.set_major_locator(plt.MultipleLocator(step))
    ax.xaxis.set_major_locator(plt.MultipleLocator(1))

fig.tight_layout()
plt.show()