In [None]:
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from sklearn import metrics


class dataset(Dataset):
    """In-memory map-style dataset pairing feature rows with integer labels.

    Args:
        x: Array-like of features. Converted to a float32 CPU tensor and
            indexed along its first dimension.
        y: Array-like of labels. Converted to an int64 (long) CPU tensor and
            indexed along its first dimension.

    Raises:
        ValueError: If ``y`` does not provide exactly one entry per row of
            ``x``.
    """

    def __init__(self, x, y):
        self.x = torch.tensor(x, dtype=torch.float32, device='cpu')
        self.y = torch.tensor(y, dtype=torch.long, device='cpu')
        self.length = self.x.shape[0]
        # Fail fast: without this check a mismatch only shows up later as an
        # IndexError in the middle of an epoch (too few labels) or is silently
        # ignored (too many labels).
        if self.y.dim() == 0 or self.y.shape[0] != self.length:
            raise ValueError(
                f'x and y must have the same number of samples, '
                f'got x with shape {tuple(self.x.shape)} and '
                f'y with shape {tuple(self.y.shape)}'
            )

    def __getitem__(self, idx):
        """Return the ``(features, label)`` pair stored at ``idx``."""
        return self.x[idx], self.y[idx]

    def __len__(self):
        """Return the number of samples (size of the first dimension of x)."""
        return self.length


class Net(nn.Module):
    """Fully connected classifier with two hidden ReLU layers (126 and 64 units).

    Args:
        input_shape: Number of input features per sample.
        output_shape: Number of output values produced per sample.
    """

    def __init__(self, input_shape, output_shape):
        super().__init__()
        # Linear layers, created in the order in which forward() applies them.
        self.fc1 = nn.Linear(input_shape, 126)
        self.fc2 = nn.Linear(126, 64)
        self.fcout = nn.Linear(64, output_shape)

        self.relu = nn.ReLU()
        # Registered on the module but not applied anywhere in forward().
        self.dropout = nn.Dropout(p=0.25)

    def forward(self, x):
        """Return the raw output-layer activations (no softmax) for ``x``."""
        hidden = self.relu(self.fc1(x))
        hidden = self.relu(self.fc2(hidden))
        return self.fcout(hidden)


# def compute_accuracy(output, targets):
#    predicted_labels = torch.argmax(output, dim=1)
#    targets = torch.argmax(targets, dim=1)  # get the index of the maximum value
#    num_correct = torch.sum(predicted_labels == targets).item()
#    accuracy = num_correct / len(targets)
#    return accuracy

# def compute_accuracy(output, targets):
#    output_softmax = torch.log_softmax(output, dim=1)
#    _, output_tags = torch.max(output_softmax, dim=1)
#    _, targets = torch.max(targets, dim=1)
#    correct_pred = (output_tags == targets)
#    acc = correct_pred.sum() / len(correct_pred)
#    acc = torch.round(acc * 100)
#    return acc

# Load the pre-split feature arrays and label arrays from the working
# directory (file names indicate 3-mer features and phylum labels).
TrainX = np.load('Train_X_3mer.npy')
TrainY = np.load('Train_Y_PHYLUM.npy')
TestX = np.load('Test_X_3mer.npy')
TestY = np.load('Test_Y_PHYLUM.npy')
ValX = np.load('Validation_X_3mer.npy')
ValY = np.load('Validation_Y_PHYLUM.npy')
print('Training, test and validation datasets are loaded...')

# Mini-batch size shared by the batched loaders below.
batches = 100
trainset = dataset(TrainX, TrainY)
valset = dataset(ValX, ValY)
testset = dataset(TestX, TestY)
# Only the training data is shuffled; validation and test keep file order.
trainloader = DataLoader(trainset, batch_size=batches, shuffle=True)
valloader = DataLoader(valset, batch_size=batches, shuffle=False)
testloader = DataLoader(testset, batch_size=batches, shuffle=False)
# Same test set without an explicit batch_size (DataLoader default).
testloader2 = DataLoader(testset, shuffle=False)
print('Loading trainset, trainloader, testset, testloader ...')

learning_rate = 0.001
epochs = 200
# Take the input width from the loaded features instead of hard-coding it, so
# the first layer always matches the data. nn.Linear acts on the last
# dimension of its input, hence shape[-1].
model = Net(input_shape=TrainX.shape[-1], output_shape=16)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# CrossEntropyLoss takes raw model outputs and integer class indices; with 16
# outputs the labels must lie in 0..15.
loss_fn = nn.CrossEntropyLoss()
print('Setting hyperparameters...')

# Per-epoch history of mean loss and accuracy (in percent).
training_losses = []
training_accuracies = []
validation_losses = []
validation_accuracies = []

print('Training model...')
for epoch in range(epochs):
    # ---- Training pass: one optimiser step per mini-batch ----
    model.train()
    training_loss = 0.0
    correct = 0
    total = 0
    for x_train, y_train in trainloader:
        output = model(x_train)
        loss = loss_fn(output, y_train)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # loss is a batch mean, so weight it by the batch size before summing;
        # the sum is turned back into a per-sample mean after the loop.
        training_loss += loss.item() * x_train.size(0)
        output_tags = output.argmax(dim=1)
        correct += (output_tags == y_train).sum().item()
        total += y_train.size(0)

    epoch_loss = training_loss / len(trainloader.dataset)
    epoch_acc = 100. * correct / total
    print(f'Epoch [{epoch + 1}] Training Loss: {epoch_loss:.4f}, Training Accuracy: {epoch_acc:.2f}%')

    training_losses.append(epoch_loss)
    training_accuracies.append(epoch_acc)

    # ---- Validation pass: no gradient tracking, no parameter updates ----
    model.eval()
    validation_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for x_val, y_val in valloader:
            output = model(x_val)
            loss = loss_fn(output, y_val)
            validation_loss += loss.item() * x_val.size(0)
            output_tags = output.argmax(dim=1)
            correct += (output_tags == y_val).sum().item()
            total += y_val.size(0)

    epoch_val_loss = validation_loss / len(valloader.dataset)
    epoch_val_acc = 100. * correct / total
    print(f'Epoch [{epoch + 1}] Validation Loss: {epoch_val_loss:.4f}, Validation Accuracy: {epoch_val_acc:.2f}%')

    validation_losses.append(epoch_val_loss)
    validation_accuracies.append(epoch_val_acc)


# Testing
# Switch to evaluation mode explicitly rather than relying on the state left
# behind by the last validation pass of the training loop.
model.eval()
correct = 0
total = 0
y_pred = []
y_true = []
with torch.no_grad():
    for x_test, y_test in testloader:
        test_output = model(x_test)
        # Predicted class = index of the largest output value per sample.
        output_tags = torch.argmax(test_output, dim=1)
        correct += (output_tags == y_test).sum().item()
        total += y_test.size(0)
        # Collect plain Python ints for the per-class report below.
        y_pred.extend(output_tags.tolist())
        y_true.extend(y_test.tolist())

print(metrics.classification_report(y_true, y_pred, digits=3))
epoch_test_acc = 100. * correct / total
print(f'Test accuracy: {epoch_test_acc:.2f}')



# Project settings for plots
plt.style.use('bmh')
plt.rcParams['font.family'] = 'serif'
plt.rcParams['font.serif'] = 'UGent Panno Text'
plt.rcParams['font.monospace'] = 'UGent Panno Text'
plt.rcParams['font.size'] = 10
plt.rcParams['axes.labelsize'] = 10
plt.rcParams['axes.labelweight'] = 'bold'
plt.rcParams['axes.titlesize'] = 10
plt.rcParams['xtick.labelsize'] = 8
plt.rcParams['ytick.labelsize'] = 8
plt.rcParams['legend.fontsize'] = 10
plt.rcParams['figure.titlesize'] = 12
# NOTE: no aspect ratio / figure size is configured here.

# Loss curves: one point per epoch.
plt.plot(training_losses, label='Training', color='#1E64C8', linewidth=1)
plt.plot(validation_losses, label='Validation', color='black', linewidth=1)
plt.title('Training and Validation Loss with NAME on 3mers and phyla')
plt.xlabel('Epoch')
# The plotted values are mean cross-entropy losses, not percentages.
plt.ylabel('Loss')
plt.legend()
plt.show()

# Accuracy curves: values are stored as percentages.
plt.plot(training_accuracies, label='Training', color='#1E64C8', linewidth=1)
plt.plot(validation_accuracies, label='Validation', color='black', linewidth=1)
plt.title('Training and Validation Accuracy with NAME on 3mers and phyla')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (in %)')
plt.legend()
plt.show()

Running the code above on 3-mers, using phyla as labels, results in the following graphs and table:
![losstrainval](https://user-images.githubusercontent.com/127412115/236807579-5cb2d45e-cca4-4e84-872a-9e2875b08d78.png)\
![acctrainval](https://user-images.githubusercontent.com/127412115/236807578-af234a32-f50c-42c2-9ed8-2dea7626d053.png)

TEST RESULTS:
|     PHYLUM |  precision|    recall|  f1-score|   support|
|-----------:|:---------:|:--------:|:--------:|:---------|
|           0|      0.988|     0.996|     0.992|      8952|
|           1|      1.000|     1.000|     1.000|        11|
|           2|      0.997|     0.989|     0.993|     10231|
|           3|      0.714|     0.833|     0.769|         6|
|           4|      1.000|     1.000|     1.000|         5|
|           5|      0.827|     0.937|     0.878|       158|
|           6|      1.000|     0.818|     0.900|        11|
|           7|      1.000|     0.769|     0.870|        13|
|           8|      0.997|     0.994|     0.996|       689|
|           9|      0.892|     0.868|     0.880|        38|
|          10|      0.000|     0.000|     0.000|         3|
|          11|      0.986|     0.976|     0.981|       291|
|          12|      0.952|     0.959|     0.956|       292|
|          13|      1.000|     0.895|     0.944|        19|
|          14|      1.000|     1.000|     1.000|         2|
|          15|      0.800|     0.800|     0.800|        10|
|            |           |          |          |          |
|    accuracy|           |          |     0.990|     20731|
|   macro avg|      0.885|     0.865|     0.872|     20731|
|weighted avg|      0.990|     0.990|     0.990|     20731|

