# Imports

In [79]:
# General Imports
from modify_dataset import sanity_check_dimensions
from math import ceil
from tqdm import tqdm
import os
import pandas as pd
import numpy as np

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import WeightedRandomSampler

# PyTorch Geometric
import torch_geometric
from torch_geometric.data.dataset import Dataset
from torch_geometric.nn import GCNConv, global_mean_pool

# Sets the seed for generating random numbers in PyTorch, numpy and Python.
torch_geometric.seed_everything(42)
# Floating-point dtype alias; NOTE(review): not referenced again in the visible cells.
dtype = torch.float
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Plotting Libraries
import plotly.express as px

In [80]:
# Record the library versions and CUDA availability this notebook was executed with.
for _label, _value in (
    ("Torch Version", torch.__version__),
    ("Cuda Available", torch.cuda.is_available()),
    ("Torch Geometric Version", torch_geometric.__version__),
):
    print(f"{_label}: {_value}")

Torch Version: 1.13.0
Cuda Available: True
Torch Geometric Version: 2.1.0


In [81]:
# feature_selection_columns: names later used as the column labels of the feature matrices.
# unique_proteins_list: protein accessions iterated over when checking / building protein graphs.
# NOTE(review): allow_pickle=True unpickles arbitrary objects - only load files from a trusted source.
feature_selection_columns, unique_proteins_list = (
    np.load(_npy_path, allow_pickle=True)
    for _npy_path in ("Dataset_Files/Feature_Selection/features_dd_psd_list.npy",
                      "Dataset_Files/Unique_Proteins_List.npy")
)

# Contact Maps Loading & Plotting

In [82]:
# Load the stored contact map of protein A0A0A0MRZ7 and render it as an image using a
# white-to-black colour scale.
_contact_map_path = "Dataset_Files/Protein_Graph_Data/raw/Contact_Map_Files/A0A0A0MRZ7.npy"
contact_map_A0A0A0MRZ7 = np.load(_contact_map_path)

fig = px.imshow(img=contact_map_A0A0A0MRZ7, color_continuous_scale=["white", "black"])
fig.show()

# Dimensions Sanity Check

In [83]:
# Spot-check one protein. With print_information=True the helper prints the shapes of the
# contact map, amino acid descriptors, PSSM and embedding (see the recorded output) and the
# call evaluates to a boolean - presumably True when the per-residue dimensions agree
# (TODO confirm against modify_dataset.sanity_check_dimensions).
sanity_check_dimensions("A4D1B5", print_information=True)

Contact Map Shape: (854, 854)
Amino Acid Descriptors Shape: (854, 66)
PSSM Shape: (854, 20)
UniProt Embedding Shape: (854, 1024)


True

In [84]:
# Accessions that fail sanity_check_dimensions are computed once and cached on disk;
# later runs simply reload the cached array instead of re-checking every protein.
if not os.path.exists("Dataset_Files/Protein_Graph_Data/proteins_with_wrong_dimensions.npy"):
    proteins_with_wrong_dimensions = [
        protein_accession
        for protein_accession in unique_proteins_list
        if not sanity_check_dimensions(protein_accession)
    ]
    # np.save appends the ".npy" suffix itself, producing the file tested for above.
    np.save("Dataset_Files/Protein_Graph_Data/proteins_with_wrong_dimensions",
            np.asarray(proteins_with_wrong_dimensions))
else:
    proteins_with_wrong_dimensions = np.load("Dataset_Files/Protein_Graph_Data/proteins_with_wrong_dimensions.npy")

print(f"{len(proteins_with_wrong_dimensions)} proteins with wrong dimensions!")

41 proteins with wrong dimensions!


# Training & Test Sets (Drug Descriptors and Protein Sequence Descriptors)

In [85]:
# ---- Raw arrays -------------------------------------------------------------
# Training split: feature matrix, per-row protein accession and binary activity label.
X_train = np.load("Dataset_Files/Training_Test_Sets/X_train_dd_psd_feature_selection.npy")
X_train_accession_list = np.load("Dataset_Files/Training_Test_Sets/X_train_dd_psd_accession.npy", allow_pickle=True)
y_train = np.load("Dataset_Files/Training_Test_Sets/y_train_dd_psd.npy")

# Classification test split, same three pieces.
X_test_classification = np.load("Dataset_Files/Training_Test_Sets/X_test_classification_dd_psd_feature_selection.npy")
X_test_classification_accession_list = np.load("Dataset_Files/Training_Test_Sets/X_test_classification_accession.npy",
                                               allow_pickle=True)
y_test_classification = np.load("Dataset_Files/Training_Test_Sets/y_test_classification.npy")

# ---- Training frame: [Protein_Accession | Activity_Binary | selected features] ----
X_train_accession_list_dataframe = pd.DataFrame(data=X_train_accession_list, columns=["Protein_Accession"])
y_train_dataframe = pd.DataFrame(data=y_train, columns=["Activity_Binary"])
X_train_dataframe = pd.DataFrame(data=X_train, columns=feature_selection_columns)
train_dataframe = pd.concat(
    [X_train_accession_list_dataframe, y_train_dataframe, X_train_dataframe],
    axis="columns",
)

# ---- Test frame, same column layout ----
# Note: the accession array name is rebound to its DataFrame form here.
X_test_classification_accession_list = pd.DataFrame(data=X_test_classification_accession_list,
                                                    columns=["Protein_Accession"])
y_test_dataframe = pd.DataFrame(data=y_test_classification, columns=["Activity_Binary"])
X_test_classification_dataframe = pd.DataFrame(data=X_test_classification, columns=feature_selection_columns)
test_dataframe = pd.concat(
    [X_test_classification_accession_list, y_test_dataframe, X_test_classification_dataframe],
    axis="columns",
)

In [86]:
# Drop every training row whose protein failed the dimension check, then renumber the rows
# 0..n-1 so that label-based lookups by integer position work downstream.
train_dataframe = (
    train_dataframe
    .loc[~train_dataframe["Protein_Accession"].isin(proteins_with_wrong_dimensions)]
    .reset_index(drop=True)
)
train_dataframe

Unnamed: 0,Protein_Accession,Activity_Binary,MolecularWeight,XLogP,ExactMass,MonoisotopicMass,TPSA,Complexity,HBondDonorCount,HBondAcceptorCount,...,Tripeptide_Composition_PCA_Component_2598,Tripeptide_Composition_PCA_Component_2600,Tripeptide_Composition_PCA_Component_2604,Tripeptide_Composition_PCA_Component_2605,Tripeptide_Composition_PCA_Component_2607,Tripeptide_Composition_PCA_Component_2614,Tripeptide_Composition_PCA_Component_2615,Tripeptide_Composition_PCA_Component_2617,Tripeptide_Composition_PCA_Component_2619,Tripeptide_Composition_PCA_Component_2620
0,A0AVT1,1,445.50,1.900391,445.25,445.25,139.00000,734.0,3.0,9.0,...,-0.245361,1.178711,0.026733,-0.682129,-0.876465,-1.202148,-0.278809,0.908203,-0.824707,-1.632812
1,A0AVT1,0,500.50,6.898438,500.25,500.25,41.59375,782.0,1.0,4.0,...,-0.245361,1.178711,0.026733,-0.682129,-0.876465,-1.202148,-0.278809,0.908203,-0.824707,-1.632812
2,A0AVT1,1,462.50,-0.399902,462.25,462.25,183.00000,770.0,4.0,11.0,...,-0.245361,1.178711,0.026733,-0.682129,-0.876465,-1.202148,-0.278809,0.908203,-0.824707,-1.632812
3,Q13564,1,445.50,1.900391,445.25,445.25,139.00000,734.0,3.0,9.0,...,0.500000,1.199219,0.107544,0.381104,-1.268555,0.176147,-1.222656,0.448486,0.339111,-0.546875
4,Q13564,1,500.50,6.898438,500.25,500.25,41.59375,782.0,1.0,4.0,...,0.500000,1.199219,0.107544,0.381104,-1.268555,0.176147,-1.222656,0.448486,0.339111,-0.546875
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
133673,Q9Y6K0,1,416.50,4.398438,416.25,416.25,84.62500,604.0,2.0,5.0,...,-0.209839,-1.358398,0.013573,0.680176,-0.179199,-1.126953,-1.352539,-0.505859,0.071777,-0.205322
133674,Q9Y6K0,1,481.25,4.898438,480.00,480.00,84.62500,600.0,2.0,5.0,...,-0.209839,-1.358398,0.013573,0.680176,-0.179199,-1.126953,-1.352539,-0.505859,0.071777,-0.205322
133675,Q9Y6K0,1,455.50,4.601562,455.25,455.25,108.00000,745.0,2.0,6.0,...,-0.209839,-1.358398,0.013573,0.680176,-0.179199,-1.126953,-1.352539,-0.505859,0.071777,-0.205322
133676,Q9Y6M7,0,528.00,7.101562,527.00,527.00,81.87500,922.0,0.0,4.0,...,-0.386719,0.953613,0.214233,0.876465,-0.258789,-1.447266,-0.364990,1.380859,-0.339355,-0.846680


In [87]:
# Rows removed by the wrong-dimension filter: raw row count minus filtered row count.
# Requires X_train to still hold the loaded array, i.e. run before it is reset to None below.
print(f"Training entries lost due to wrong dimensions: {X_train.shape[0] - train_dataframe.shape[0]}")

Training entries lost due to wrong dimensions: 1056


In [88]:
# Split the filtered training rows by label. Despite the "_count" suffix, both names hold
# DataFrames; the counts are their row totals.
_train_labels = train_dataframe['Activity_Binary']
train_binding_count = train_dataframe.loc[_train_labels == 1]
train_non_binding_count = train_dataframe.loc[_train_labels == 0]

print("Training Set Class Information:")
print(f"Binding Count: {len(train_binding_count)}")
print(f"Non-Binding Count: {len(train_non_binding_count)}")
# The ratio is rounded up to the next integer before being reported.
print(f"Class Imbalance: {ceil(len(train_binding_count) / len(train_non_binding_count))}:1")

Training Set Class Information:
Binding Count: 98002
Non-Binding Count: 35676
Class Imbalance: 3:1


In [89]:
# Drop every test row whose protein failed the dimension check, then renumber the rows
# 0..n-1 so that label-based lookups by integer position work downstream.
test_dataframe = (
    test_dataframe
    .loc[~test_dataframe["Protein_Accession"].isin(proteins_with_wrong_dimensions)]
    .reset_index(drop=True)
)
test_dataframe

Unnamed: 0,Protein_Accession,Activity_Binary,MolecularWeight,XLogP,ExactMass,MonoisotopicMass,TPSA,Complexity,HBondDonorCount,HBondAcceptorCount,...,Tripeptide_Composition_PCA_Component_2598,Tripeptide_Composition_PCA_Component_2600,Tripeptide_Composition_PCA_Component_2604,Tripeptide_Composition_PCA_Component_2605,Tripeptide_Composition_PCA_Component_2607,Tripeptide_Composition_PCA_Component_2614,Tripeptide_Composition_PCA_Component_2615,Tripeptide_Composition_PCA_Component_2617,Tripeptide_Composition_PCA_Component_2619,Tripeptide_Composition_PCA_Component_2620
0,O15382,1,262.250,3.300781,262.000,262.000,88.37500,299.0,1.0,5.0,...,0.161377,-1.422852,-0.381348,-0.764648,-0.824219,0.133667,0.286377,0.620605,-1.066406,0.334717
1,Q9BY41,1,624.000,4.398438,623.500,623.500,139.00000,1050.0,3.0,6.0,...,0.288574,-0.579590,0.185791,0.890625,-0.728027,0.289551,-0.427734,-0.304443,0.052032,-0.346436
2,Q3MIX3,0,362.500,2.300781,362.000,362.000,125.00000,495.0,4.0,5.0,...,0.445312,-0.116211,1.637695,2.212891,2.542969,0.291748,0.642578,0.342285,0.619141,0.274902
3,P19099,1,235.250,3.099609,235.125,235.125,30.00000,347.0,0.0,2.0,...,-0.346436,-0.118225,-0.298584,0.777832,0.364746,0.542480,0.560059,-0.350586,0.168091,0.359619
4,O75365,1,406.250,5.000000,405.000,405.000,95.68750,491.0,1.0,4.0,...,0.744629,0.598145,0.088074,0.291748,0.473633,0.034180,0.250000,0.411621,-0.345215,0.299072
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
29918,Q99814,1,296.500,3.800781,296.250,296.250,43.40625,571.0,0.0,3.0,...,0.511230,1.780273,0.210815,0.943359,0.865723,0.325928,0.522949,-0.003019,0.089233,0.430664
29919,Q9Y4K4,0,407.500,3.000000,407.250,407.250,83.62500,760.0,2.0,4.0,...,-0.442627,-1.092773,0.530762,0.792969,0.222168,-1.280273,-0.595703,0.897461,-0.779785,1.140625
29920,Q9BQB6,1,342.500,4.898438,342.250,342.250,63.59375,528.0,1.0,4.0,...,0.251221,0.152588,-0.385010,-0.333252,0.456543,-0.207397,-0.259277,-0.126953,-0.151855,0.300293
29921,P28062,1,441.500,2.699219,441.250,441.250,148.00000,681.0,2.0,8.0,...,-0.473389,0.010246,0.143677,-0.155640,0.595215,-0.091553,-0.316650,-0.113159,-0.009079,-0.826660


In [90]:
# Rows removed by the wrong-dimension filter: raw row count minus filtered row count.
# Requires X_test_classification to still hold the loaded array, i.e. run before it is reset to None below.
print(f"Testing entries lost due to wrong dimensions: {X_test_classification.shape[0] - test_dataframe.shape[0]}")

Testing entries lost due to wrong dimensions: 218


In [91]:
# Split the filtered test rows by label. Despite the "_count" suffix, both names hold
# DataFrames; the counts are their row totals.
_test_labels = test_dataframe['Activity_Binary']
test_binding_count = test_dataframe.loc[_test_labels == 1]
test_non_binding_count = test_dataframe.loc[_test_labels == 0]

print("Testing Set Class Information:")
print(f"Binding Count: {len(test_binding_count)}")
print(f"Non-Binding Count: {len(test_non_binding_count)}")
# The ratio is rounded up to the next integer before being reported.
print(f"Class Imbalance: {ceil(len(test_binding_count) / len(test_non_binding_count))}:1")

Testing Set Class Information:
Binding Count: 21835
Non-Binding Count: 8088
Class Imbalance: 3:1


In [92]:
# No reason to keep these variables: rebinding the names to None drops this notebook's
# references to the raw arrays now that the train/test DataFrames have been built.
X_train = X_train_accession_list = y_train = None
X_test_classification = X_test_classification_accession_list = y_test_classification = None

# Weighted Random Sampler
A weighted random sampler is used to balance the batches, given the clear imbalance between the two classes.

In [93]:
# Reference
# https://www.maskaravivek.com/post/pytorch-weighted-random-sampler/

# Integer class label (0 = non-binding, 1 = binding) of every training row.
train_labels = train_dataframe["Activity_Binary"].astype(int).to_numpy()

# Count samples *per class label*: train_class_counts[c] is the number of rows with label c.
# (Ordering the counts by frequency instead would silently swap the weights whenever
# class 0 is not the minority class, because the weights are looked up by label below.)
train_class_counts = np.bincount(train_labels)

# Inverse-frequency weight per class, so that each class is drawn equally often overall.
weights = 1. / train_class_counts

# Per-sample weight = weight of that sample's class.
weights_all = torch.from_numpy(weights[train_labels])

print(weights)
print(weights_all)
print(len(weights_all))

[2.80300482e-05 1.02038734e-05]
tensor([1.0204e-05, 2.8030e-05, 1.0204e-05,  ..., 1.0204e-05, 2.8030e-05,
        2.8030e-05], dtype=torch.float64)
133678


# Dataset Class & Dataloaders

In [94]:
class MyDataset(Dataset):
    """Dataset of drug-protein rows, each paired with the pre-built graph of its protein.

    One item corresponds to one row of ``dataframe`` and is returned as the tuple
    ``(drug_and_sequence_descriptors, protein_graph, label)``.

    Args:
        root: Dataset directory; raw inputs are read from ``{root}/raw`` and the protein
            graphs are written to / read from ``{root}/processed``.
        dataframe: Frame with a ``Protein_Accession`` column, an ``Activity_Binary`` column and
            the numeric descriptor columns from ``MolecularWeight`` onwards. Items are looked up
            with ``.loc[idx]``, so the index must be the default 0..n-1 range.
        unique_proteins_accessions_list: Accessions whose graphs ``process`` should build.
            When ``None``, the unique accessions found in ``dataframe`` are used.
        transform, pre_transform, pre_filter: Forwarded unchanged to the PyG ``Dataset`` base class.
    """

    def __init__(self, root, dataframe, unique_proteins_accessions_list=None, transform=None, pre_transform=None,
                 pre_filter=None):
        # Set the attributes before calling the base constructor: process() reads them and,
        # per the PyG Dataset contract, may be triggered from within that constructor.
        self.dataframe = dataframe
        self.unique_proteins_accessions_list = unique_proteins_accessions_list
        super(MyDataset, self).__init__(root, transform, pre_transform, pre_filter)

    @staticmethod
    def _list_directory(path):
        """Return the entries of ``path``, or an empty list if the directory does not exist."""
        return os.listdir(path) if os.path.isdir(path) else []

    @property
    def raw_file_names(self):
        """Entries currently present in ``{root}/raw`` (empty when the directory is missing)."""
        return self._list_directory(f"{self.root}/raw")

    @property
    def processed_file_names(self):
        """Entries currently present in ``{root}/processed`` (empty when the directory is missing)."""
        return self._list_directory(f"{self.root}/processed")

    def download(self):
        """Nothing to download; the raw files are expected to already be on disk."""
        pass

    def _build_protein_graph(self, accession):
        """Assemble the graph of one protein from its raw per-residue files."""
        raw_dir = f"{self.root}/raw"

        # Node features: descriptors, PSSM and embedding stacked column-wise (one row per residue).
        amino_acid_descriptors = np.load(f"{raw_dir}/Amino_Acid_Descriptors_And_PSSM/{accession}_Descriptors.npy")
        pssm = np.load(f"{raw_dir}/Amino_Acid_Descriptors_And_PSSM/{accession}_PSSM.npy")
        uniprot_embedding = np.load(f"{raw_dir}/Amino_Acid_Embeddings/{accession}.npy")
        amino_acid_features = np.hstack((amino_acid_descriptors, pssm, uniprot_embedding))

        # Edges: every (row, column) position where the contact map equals 1. Stacking the two
        # index arrays directly yields the [2, num_edges] layout, including the case of a map
        # without any contact (shape [2, 0]).
        contact_map = np.load(f"{raw_dir}/Contact_Map_Files/{accession}.npy")
        edge_index = torch.from_numpy(np.vstack(np.nonzero(contact_map == 1))).long()

        return torch_geometric.data.Data(x=torch.as_tensor(amino_acid_features, dtype=torch.float),
                                         edge_index=edge_index,
                                         size=contact_map.shape[0],
                                         accession=accession)

    def process(self):
        """Build and save one graph file per protein that passes ``sanity_check_dimensions``."""
        accessions = self.unique_proteins_accessions_list
        if accessions is None:
            # Fall back to the proteins that actually occur in the data.
            accessions = self.dataframe["Protein_Accession"].unique()

        print("Creating Protein Graphs")
        for accession in tqdm(accessions):
            if not sanity_check_dimensions(accession):
                continue
            torch.save(self._build_protein_graph(accession),
                       f"{self.root}/processed/protein_graph_{accession}.pt")

    def len(self):
        """Number of rows in the dataframe."""
        return self.dataframe.shape[0]

    def get(self, idx):
        """Return ``(descriptors, protein_graph, label)`` for the dataframe row labelled ``idx``."""
        dataframe = self.dataframe

        accession = dataframe.loc[idx, "Protein_Accession"]
        protein_graph = torch.load(f"{self.root}/processed/protein_graph_{accession}.pt")

        # A row of the mixed-type frame comes back with object dtype, so convert explicitly:
        # float32 matches the default dtype of nn.Linear weights, int64 is what
        # nn.CrossEntropyLoss expects for class-index targets.
        descriptor_values = dataframe.loc[idx, "MolecularWeight":].to_numpy(dtype=np.float32)
        drug_and_sequence_descriptors = torch.tensor(descriptor_values)
        label = torch.tensor(int(dataframe.loc[idx, "Activity_Binary"]), dtype=torch.long)

        return drug_and_sequence_descriptors, protein_graph, label

In [95]:
BATCH_SIZE = 64

# Draw len(weights_all) training samples per epoch according to the per-sample weights, so
# both classes are seen about equally often.
sampler = WeightedRandomSampler(weights_all, len(weights_all))

# Both loaders must be PyG loaders: each item contains a torch_geometric Data object, which
# the default collate function of torch.utils.data.DataLoader cannot batch.
trainloader = torch_geometric.loader.DataLoader(
    MyDataset(root="Dataset_Files/Protein_Graph_Data",
              dataframe=train_dataframe,
              unique_proteins_accessions_list=unique_proteins_list),
    batch_size=BATCH_SIZE,
    sampler=sampler)

# Evaluation does not depend on sample order, so the test set is left unshuffled.
testloader = torch_geometric.loader.DataLoader(
    MyDataset(root="Dataset_Files/Protein_Graph_Data",
              dataframe=test_dataframe,
              unique_proteins_accessions_list=unique_proteins_list),
    batch_size=BATCH_SIZE,
    shuffle=False)

In [96]:
# Look at the first training batch only: descriptor tensor shape, batched protein graphs
# and label tensor shape.
for _batch in trainloader:
    drug_and_sequence_descriptors, protein_graphs, labels = _batch
    for _summary in (drug_and_sequence_descriptors.shape, protein_graphs, labels.shape):
        print(_summary)
    break

torch.Size([64, 1044])
DataBatch(x=[34411, 1110], edge_index=[2, 542225], size=[64], accession=[64], batch=[34411], ptr=[65])
torch.Size([64])


# Model

In [97]:
class Model(nn.Module):
    """Two-branch network combining a protein graph with a flat descriptor vector.

    * Protein branch: three GCNConv layers (each followed by ReLU and batch normalisation),
      a per-graph mean pooling and two fully connected layers.
    * Descriptor branch: five fully connected layers, the first three followed by ReLU and
      batch normalisation.

    The two ``output_dim``-sized branch outputs are concatenated and passed through two
    hidden layers and a final linear layer producing ``n_output`` raw scores per sample.

    Args:
        n_output: Number of output scores per sample.
        dd_psd_features: Width of the drug / protein-sequence descriptor vector.
        amino_acid_features: Number of features per graph node (residue).
        output_dim: Size of each branch's output before concatenation.
        dropout: Dropout probability used throughout.
    """

    def __init__(self, n_output=1, dd_psd_features=1044, amino_acid_features=1110, output_dim=128, dropout=0.2):
        super(Model, self).__init__()

        # Protein Structure Layers
        # BatchNorm1d is required here: the activations are 2-D [num_nodes, channels] tensors,
        # which BatchNorm2d rejects (it only accepts 4-D input).
        self.pro_conv1 = GCNConv(amino_acid_features, amino_acid_features)
        self.pro_batch_normalisation1 = nn.BatchNorm1d(amino_acid_features)

        self.pro_conv2 = GCNConv(amino_acid_features, amino_acid_features * 2)
        self.pro_batch_normalisation2 = nn.BatchNorm1d(amino_acid_features * 2)

        self.pro_conv3 = GCNConv(amino_acid_features * 2, amino_acid_features * 4)
        self.pro_batch_normalisation3 = nn.BatchNorm1d(amino_acid_features * 4)

        self.pro_fc1 = nn.Linear(amino_acid_features * 4, 1024)
        self.pro_fc2 = nn.Linear(1024, output_dim)

        # Drugs Descriptors & Protein Sequence Descriptors Layers
        # Same reasoning: activations are [batch, features], hence BatchNorm1d.
        self.dd_psd_fc1 = nn.Linear(dd_psd_features, dd_psd_features)
        self.dd_psd_batch_normalisation1 = nn.BatchNorm1d(dd_psd_features)

        self.dd_psd_fc2 = nn.Linear(dd_psd_features, dd_psd_features * 2)
        self.dd_psd_batch_normalisation2 = nn.BatchNorm1d(dd_psd_features * 2)

        self.dd_psd_fc3 = nn.Linear(dd_psd_features * 2, dd_psd_features * 4)
        self.dd_psd_batch_normalisation3 = nn.BatchNorm1d(dd_psd_features * 4)

        self.dd_psd_fc4 = nn.Linear(dd_psd_features * 4, 1024)
        self.dd_psd_fc5 = nn.Linear(1024, output_dim)

        # Combined Layers
        # The second layer must be registered as combined_fc2; re-assigning combined_fc1 would
        # discard the 2*output_dim -> 1024 layer and leave forward() without combined_fc2.
        self.combined_fc1 = nn.Linear(2 * output_dim, 1024)
        self.combined_fc2 = nn.Linear(1024, 512)
        self.out = nn.Linear(512, n_output)

        # Other Layers
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, dd_psd_descriptors, protein_graph):
        """Compute the raw output scores for a batch.

        Args:
            dd_psd_descriptors: Descriptor tensor, one row per sample.
            protein_graph: Batched graph object providing ``x``, ``edge_index`` and ``batch``.

        Returns:
            Tensor with ``n_output`` raw (un-normalised) scores per sample.
        """
        # Protein Graphs: graph convolution -> ReLU -> batch norm, three times.
        x_protein = self.pro_conv1(protein_graph.x, protein_graph.edge_index)
        x_protein = self.relu(x_protein)
        x_protein = self.pro_batch_normalisation1(x_protein)

        x_protein = self.pro_conv2(x_protein, protein_graph.edge_index)
        x_protein = self.relu(x_protein)
        x_protein = self.pro_batch_normalisation2(x_protein)

        x_protein = self.pro_conv3(x_protein, protein_graph.edge_index)
        x_protein = self.relu(x_protein)
        x_protein = self.pro_batch_normalisation3(x_protein)

        # Collapse node features to one vector per graph (mean over the nodes of each graph).
        x_protein = global_mean_pool(x_protein, protein_graph.batch)

        # Project the pooled graph vector down to output_dim.
        x_protein = self.pro_fc1(x_protein)
        x_protein = self.relu(x_protein)
        x_protein = self.dropout(x_protein)

        x_protein = self.pro_fc2(x_protein)
        x_protein = self.dropout(x_protein)

        # Drug Descriptors & Protein Sequence Descriptors
        x_dd_psd = self.dd_psd_fc1(dd_psd_descriptors)
        x_dd_psd = self.relu(x_dd_psd)
        x_dd_psd = self.dd_psd_batch_normalisation1(x_dd_psd)

        x_dd_psd = self.dd_psd_fc2(x_dd_psd)
        x_dd_psd = self.relu(x_dd_psd)
        x_dd_psd = self.dd_psd_batch_normalisation2(x_dd_psd)

        x_dd_psd = self.dd_psd_fc3(x_dd_psd)
        x_dd_psd = self.relu(x_dd_psd)
        x_dd_psd = self.dd_psd_batch_normalisation3(x_dd_psd)

        x_dd_psd = self.dd_psd_fc4(x_dd_psd)
        x_dd_psd = self.relu(x_dd_psd)

        x_dd_psd = self.dd_psd_fc5(x_dd_psd)
        x_dd_psd = self.dropout(x_dd_psd)

        # Combine: concatenate both branch outputs along the feature axis.
        x_combined = torch.cat((x_protein, x_dd_psd), 1)
        x_combined = self.combined_fc1(x_combined)
        x_combined = self.relu(x_combined)
        x_combined = self.dropout(x_combined)

        x_combined = self.combined_fc2(x_combined)
        x_combined = self.relu(x_combined)
        x_combined = self.dropout(x_combined)
        out = self.out(x_combined)

        return out

# Training

In [98]:
# Print the training loss every `epoch_print_gap` epochs (the first epoch is always printed).
epoch_print_gap = 1


def training_loop(n_epochs, optimizer, model, device, loss_fn, train_loader):
    """Train ``model`` for ``n_epochs`` epochs and save its final state dict.

    Args:
        n_epochs: Number of passes over ``train_loader``.
        optimizer: Optimizer already constructed over ``model.parameters()``.
        model: Module called as ``model(descriptors, protein_graphs)``.
        device: Device the model and every batch are moved to.
        loss_fn: Callable ``loss_fn(outputs, labels)`` returning a scalar tensor.
        train_loader: Iterable yielding ``(descriptors, protein_graphs, labels)`` batches.

    Returns:
        List with one entry per epoch: the sum of the batch losses of that epoch.

    Side effects:
        Writes the trained weights to ``model_state_dict.pt`` in the working directory.
    """
    model = model.to(device)
    # Make sure dropout / batch-norm layers are in training mode even if the model was
    # previously switched to eval().
    model.train()

    loss_list = []
    for epoch in range(1, n_epochs + 1):
        loss_train = 0.0
        for drug_and_sequence_descriptors, protein_graphs, labels in train_loader:
            drug_and_sequence_descriptors = drug_and_sequence_descriptors.to(device)
            protein_graphs = protein_graphs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(drug_and_sequence_descriptors, protein_graphs)
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()

            loss_train += loss.item()

        # Record every epoch so the history always has n_epochs entries, independent of
        # how often the loss is printed.
        loss_list.append(loss_train)
        if epoch == 1 or epoch % epoch_print_gap == 0:
            print(f"Epoch: {epoch}, Training Loss: {loss_train}")

    torch.save(model.state_dict(), "model_state_dict.pt")
    return loss_list

In [99]:
# nn.CrossEntropyLoss expects one logit per class together with integer class labels, so the
# binary task (labels 0 / 1) needs two outputs; a single output makes label 1 out of range.
model = Model(n_output=2)

n_epochs = 250
lr_rate = 0.001
weight_decay = 0.0001

optimizer = optim.Adam(model.parameters(), lr=lr_rate, weight_decay = weight_decay)

loss_fn = nn.CrossEntropyLoss()

# NOTE(review): the recorded run of this cell ended in a CUDA out-of-memory error on a 4 GiB
# GPU; a smaller BATCH_SIZE or narrower layers may be needed on such hardware.
loss_list = training_loop(n_epochs = n_epochs,
                          optimizer = optimizer,
                          model = model,
                          device = device,
                          loss_fn = loss_fn,
                          train_loader = trainloader)


# fig = px.line(x=list(i for i in range(1,n_epochs+1)),
#                   y=loss_list,
#                   title="Loss Over Epochs",
#                  labels={
#                      "x": "Epoch",
#                      "y": "Cross Entropy Loss"})
# fig.write_image("loss.png")

OutOfMemoryError: CUDA out of memory. Tried to allocate 2.60 GiB (GPU 0; 4.00 GiB total capacity; 3.08 GiB already allocated; 0 bytes free; 3.09 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [53]:
# from IPython.display import Image
# Image(filename='/content/Classifying_Plankton/loss.png')

# Testing

In [54]:
# label_to_class = dict((v,k) for k,v in data['names'].items())
#
# def test_loop(model, device, test_loader, class_names):
#
#     correct_pred = {classname: 0 for classname in class_names}
#     total_pred = {classname: 0 for classname in class_names}
#
#     model.eval()
#     model = model.to(device)
#     test_loss = 0
#     correct = 0
#     with torch.no_grad():
#         for data, target in test_loader:
#             data, target = data.to(device), target.to(device)
#             output = model(data)
#             test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
#             pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
#             correct += pred.eq(target.view_as(pred)).sum().item()
#
#             for i in range(len(target)):
#                 total_pred[label_to_class[target[i].item()]] += 1
#
#                 if target[i].item() == pred[i].item():
#                     correct_pred[label_to_class[target[i].item()]] += 1
#
#     test_loss /= len(test_loader.dataset)
#
#     print(f"Test set: Average Loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset):.0f}%) \n")
#
#     for class_name, correct_predictions in correct_pred.items():
#         class_accuracy = 100. * correct_predictions / total_pred[class_name]
#         print(f'Accuracy for class: {class_name}: {class_accuracy:.0f}%')

In [55]:
# model = CNN_Model()
# model.load_state_dict(torch.load("/content/Classifying_Plankton/model_state_dict.pt"))
#
# test_loop(model = model, device = device, test_loader = testloader, class_names = data['names'].keys())