In [None]:
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from sklearn import metrics
import pandas as pd


class _KmerDataset(Dataset):
    """In-memory dataset pairing a feature matrix with class labels.

    Features are stored as float32 and labels as int64 (``torch.long``),
    both on the CPU. Indexing returns one ``(features, label)`` pair.
    """

    def __init__(self, x, y):
        self.x = torch.tensor(x, dtype=torch.float32, device='cpu')
        self.y = torch.tensor(y, dtype=torch.long, device='cpu')
        self.length = self.x.shape[0]

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

    def __len__(self):
        return self.length


class ConvoNet(nn.Module):
    """Single-layer 1-D CNN: Conv1d -> ReLU -> MaxPool1d -> Linear.

    Args:
        input_shape: Number of features per sample, i.e. the length of the
            last axis of the ``(batch, 1, input_shape)`` input tensor.
        output_shape: Number of output logits (classes).

    Raises:
        ValueError: If ``input_shape`` is too short to survive the
            convolution and the pooling step.
    """

    KERNEL_SIZE = 8
    OUT_CHANNELS = 2
    POOL_SIZE = 2

    def __init__(self, input_shape, output_shape):
        super().__init__()
        # Length after an unpadded, stride-1 convolution.
        conv_length = int(input_shape) - self.KERNEL_SIZE + 1
        # MaxPool1d with kernel == stride floors away a trailing odd element.
        pooled_length = conv_length // self.POOL_SIZE
        if pooled_length < 1:
            raise ValueError(
                f'input_shape={input_shape} is too short for a kernel of '
                f'{self.KERNEL_SIZE} followed by pooling of {self.POOL_SIZE}'
            )

        self.conv1 = nn.Conv1d(
            in_channels=1,
            out_channels=self.OUT_CHANNELS,
            kernel_size=self.KERNEL_SIZE,
            stride=1,
        )
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool1d(kernel_size=self.POOL_SIZE, stride=self.POOL_SIZE)
        # The flattened size is channels * pooled length. The previous formula
        # (input_shape - 7) only matched this when input_shape was odd, so
        # even feature counts failed with a shape mismatch in fc1.
        self.fc1 = nn.Linear(self.OUT_CHANNELS * pooled_length, output_shape)

    def forward(self, x):
        out = self.conv1(x)
        out = self.relu(out)
        out = self.pool(out)

        out = out.view(out.size(0), -1)  # flatten to (batch, channels * length)
        out = self.fc1(out)
        return out


def _run_epoch(model, loader, loss_fn, optimizer=None):
    """Run one pass over ``loader``; return (mean loss, accuracy in %).

    When ``optimizer`` is given the model is put in training mode and updated
    after every batch; otherwise the model is put in eval mode and gradients
    are disabled.
    """
    training = optimizer is not None
    model.train(training)
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.set_grad_enabled(training):
        for features, labels in loader:
            # Conv1d expects a channel axis: (batch, features) -> (batch, 1, features).
            logits = model(features.unsqueeze(1))
            loss = loss_fn(logits, labels)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            batch_size = labels.size(0)
            # loss is a batch mean; weight it so the epoch value is a sample mean.
            running_loss += loss.item() * batch_size
            correct += (torch.argmax(logits, dim=1) == labels).sum().item()
            total += batch_size

    return running_loss / len(loader.dataset), 100. * correct / total


def _plot_history(train_values, val_values, title, ylabel, filename):
    """Plot training vs. validation curves on a fresh figure, save and show it."""
    # A dedicated figure keeps curves from leaking into the next plot on
    # backends where plt.show() does not clear the current figure.
    fig = plt.figure()
    plt.plot(train_values, label='Training', color='#1E64C8', linewidth=1)
    plt.plot(val_values, label='Validation', color='black', linewidth=1)
    plt.title(title)
    plt.xlabel('Epoch')
    plt.ylabel(ylabel)
    plt.legend()
    plt.savefig(filename)
    plt.show()
    plt.close(fig)


def CNNK(kmer, tax, epochsize):
    """Train, validate and test the 1-D CNN for one k-mer / taxonomy pair.

    Loads ``Train_X_{kmer}.npy``, ``Train_Y_{tax}.npy`` and the matching
    ``Test_*`` / ``Validation_*`` arrays from the current working directory,
    trains ``ConvoNet`` with Adam and cross-entropy loss, prints per-epoch
    statistics and a scikit-learn classification report for the test set,
    and writes loss and accuracy curves as SVG files.

    Args:
        kmer: k-mer tag used in the feature file names (e.g. ``'3mer'``).
        tax: Taxonomic level used in the label file names (e.g. ``'genus'``).
        epochsize: Number of training epochs.

    Returns:
        None. Results are printed, plotted and saved to disk. The function
        waits for the user to press Enter before returning.
    """
    print(f'Initiating training, validation and testing on {tax} level with {kmer}.')
    print('Model constructed...')

    TrainX = np.load(f'Train_X_{kmer}.npy')
    TrainY = np.load(f'Train_Y_{tax}.npy')
    TestX = np.load(f'Test_X_{kmer}.npy')
    TestY = np.load(f'Test_Y_{tax}.npy')
    ValX = np.load(f'Validation_X_{kmer}.npy')
    ValY = np.load(f'Validation_Y_{tax}.npy')
    print('Training, test and validation datasets are loaded...')

    batches = 100
    trainset = _KmerDataset(TrainX, TrainY)
    valset = _KmerDataset(ValX, ValY)
    testset = _KmerDataset(TestX, TestY)
    trainloader = DataLoader(trainset, batch_size=batches, shuffle=True)
    valloader = DataLoader(valset, batch_size=batches, shuffle=False)
    testloader = DataLoader(testset, batch_size=batches, shuffle=False)
    print('Loading trainset, trainloader, testset, testloader ...')

    learning_rate = 0.0005
    epochs = epochsize
    input_size = TrainX.shape[1]
    # NOTE(review): assumes labels are encoded as 0..C-1 and that every class
    # occurs in the training split; CrossEntropyLoss rejects targets >= C.
    output_size = len(np.unique(TrainY))
    model = ConvoNet(input_shape=input_size, output_shape=output_size)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    loss_fn = nn.CrossEntropyLoss()
    print('Setting hyperparameters...')

    training_losses = []
    training_accuracies = []
    validation_losses = []
    validation_accuracies = []

    print('Training model...')
    for epoch in range(epochs):
        epoch_loss, epoch_acc = _run_epoch(model, trainloader, loss_fn, optimizer)
        print(f'Epoch [{epoch + 1}] Training Loss: {epoch_loss:.4f}, Training Accuracy: {epoch_acc:.2f}%')
        training_losses.append(epoch_loss)
        training_accuracies.append(epoch_acc)

        epoch_val_loss, epoch_val_acc = _run_epoch(model, valloader, loss_fn)
        print(f'Epoch [{epoch + 1}] Validation Loss: {epoch_val_loss:.4f}, Validation Accuracy: {epoch_val_acc:.2f}%')
        validation_losses.append(epoch_val_loss)
        validation_accuracies.append(epoch_val_acc)

    # Testing. Set eval mode explicitly so this also holds when epochs == 0.
    model.eval()
    y_pred = []
    y_true = []
    with torch.no_grad():
        for x_test, y_test in testloader:
            test_output = model(x_test.unsqueeze(1))
            y_pred.extend(torch.argmax(test_output, dim=1).tolist())
            y_true.extend(y_test.tolist())
    report = metrics.classification_report(y_true, y_pred, digits=3)
    print(report)

    # Project settings for plots
    plt.style.use('bmh')
    plt.rcParams['font.family'] = 'serif'
    plt.rcParams['font.serif'] = 'UGent Panno Text'
    plt.rcParams['font.monospace'] = 'UGent Panno Text'
    plt.rcParams['font.size'] = 10
    plt.rcParams['axes.labelsize'] = 10
    plt.rcParams['axes.labelweight'] = 'bold'
    plt.rcParams['axes.titlesize'] = 10
    plt.rcParams['xtick.labelsize'] = 8
    plt.rcParams['ytick.labelsize'] = 8
    plt.rcParams['legend.fontsize'] = 10
    plt.rcParams['figure.titlesize'] = 12

    # Titles previously said "FCN" although this model is a CNN, and the loss
    # axis was labelled as a percentage although it is a cross-entropy value.
    _plot_history(
        training_losses,
        validation_losses,
        title=f'Training and Validation Loss with CNN on {tax} level with {kmer}',
        ylabel='Loss',
        filename=f'CNNK{tax}{kmer}Loss.svg',
    )
    # NOTE(review): unlike the loss plot, this file name has no 'CNNK' prefix;
    # kept as-is so existing consumers of the file are not broken.
    _plot_history(
        training_accuracies,
        validation_accuracies,
        title=f'Training and Validation Accuracy CNN on {tax} level with {kmer}',
        ylabel='Accuracy (in %)',
        filename=f'{tax}{kmer}Accuracy.svg',
    )
    print(f'Training, validation and testing on {tax} level with {kmer} is completed.')
    input('Continue?')


In [None]:
from defCNN import CNNK

# Run every k-mer size against every taxonomic level, 100 epochs each.
# The order matches the original hand-written sequence: all levels for
# 3-mers, then 4-mers, then 5-mers.
for kmer_tag in ('3mer', '4mer', '5mer'):
    for tax_level in ('phylum', 'class', 'order', 'family', 'genus'):
        CNNK(kmer_tag, tax_level, 100)