# HIV Activity Classification
This is the main HW notebook for the HIV classification problem

In [70]:
import random
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from rdkit import Chem
from rdkit.Chem.rdmolops import GetAdjacencyMatrix
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader as GeometricDataLoader
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
from torch_geometric.utils import scatter
from tqdm import tqdm
from tqdm import trange


# TODO: ADD DEPENDENCY INSTALLATION FOR COLAB

In [21]:
# Use the GPU when PyTorch reports one as available, otherwise fall back to the CPU.
torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
#root_folder = colab_root_folder = os.getcwd() # for when we get on colab

# Load the HIV dataset; the path is resolved relative to the notebook's working directory.
path = "hiv.csv"
data_pd = pd.read_csv(path)
# Bare last expression so the notebook renders the frame.
data_pd

Unnamed: 0,smiles,HIV_active
0,CCC1=[O+][Cu-3]2([O+]=C(CC)C1)[O+]=C(CC)CC(CC)...,0
1,C(=Cc1ccccc1)C1=[O+][Cu-3]2([O+]=C(C=Cc3ccccc3...,0
2,CC(=O)N1c2ccccc2Sc2c1ccc1ccccc21,0
3,Nc1ccc(C=Cc2ccc(N)cc2S(=O)(=O)O)c(S(=O)(=O)O)c1,0
4,O=S(=O)(O)CCS(=O)(=O)O,0
...,...,...
41122,CCC1CCC2c3c([nH]c4ccc(C)cc34)C3C(=O)N(N(C)C)C(...,0
41123,Cc1ccc2[nH]c3c(c2c1)C1CCC(C(C)(C)C)CC1C1C(=O)N...,0
41124,Cc1ccc(N2C(=O)C3c4[nH]c5ccccc5c4C4CCC(C(C)(C)C...,0
41125,Cc1cccc(N2C(=O)C3c4[nH]c5ccccc5c4C4CCC(C(C)(C)...,0


In [22]:
# Convert the DataFrame to a NumPy array; later cells read column 0 as the
# SMILES strings and column 1 as the labels.
data_arr = np.array(data_pd)
data_arr.shape

(41127, 2)

In [23]:
def one_hot_encoding(x, permitted_list):
    """
    One-hot encode x against permitted_list.

    A value that does not occur in permitted_list is treated as the list's last
    entry, so that entry doubles as the catch-all category.

    Returns a list of 0/1 ints with one position per entry of permitted_list.
    """
    target = x if x in permitted_list else permitted_list[-1]
    return [int(target == candidate) for candidate in permitted_list]

In [24]:
def get_atom_features(atom, use_chirality=True, hydrogens_implicit=True):
    """
    Encode one RDKit atom as a flat numeric feature vector.

    The vector concatenates, in this order: one-hot element symbol, one-hot
    degree (atom.GetDegree()), one-hot formal charge, one-hot hybridisation, an
    in-ring flag, an aromaticity flag, and the atomic mass, van der Waals radius
    and covalent radius, each shifted and divided by a fixed constant. A one-hot
    chirality tag is appended when use_chirality == True, followed by a one-hot
    total hydrogen count when hydrogens_implicit == True.

    Arguments:
        atom -- RDKit atom object
        use_chirality -- append the chirality encoding
        hydrogens_implicit -- append the hydrogen-count encoding; when False,
            'H' is instead placed at the front of the permitted element list

    Returns:
        1d numpy array of atom features
    """
    # Elements with their own one-hot slot; anything else maps to 'Unknown'.
    element_symbols = ['C','N','O','S','F','Si','P','Cl','Br','Mg','Na','Ca',
                       'Fe','As','Al','I', 'B','V','K','Tl','Yb','Sb','Sn','Ag',
                       'Pd','Co','Se','Ti','Zn', 'Li','Ge','Cu','Au','Ni','Cd',
                       'In','Mn','Zr','Cr','Pt','Hg','Pb','Unknown']

    # Explicit comparison kept on purpose: only values equal to False take this path.
    if hydrogens_implicit == False:
        element_symbols = ['H'] + element_symbols

    periodic_table = Chem.GetPeriodicTable()
    atomic_number = atom.GetAtomicNum()

    features = []

    # categorical descriptors
    features += one_hot_encoding(str(atom.GetSymbol()), element_symbols)
    features += one_hot_encoding(int(atom.GetDegree()), [0, 1, 2, 3, 4, "MoreThanFour"])
    features += one_hot_encoding(int(atom.GetFormalCharge()), [-3, -2, -1, 0, 1, 2, 3, "Extreme"])
    features += one_hot_encoding(str(atom.GetHybridization()), ["S", "SP", "SP2", "SP3", "SP3D", "SP3D2", "OTHER"])

    # binary flags
    features.append(int(atom.IsInRing()))
    features.append(int(atom.GetIsAromatic()))

    # continuous descriptors, shifted and scaled by fixed constants
    features.append(float((atom.GetMass() - 10.812)/116.092))
    features.append(float((periodic_table.GetRvdw(atomic_number) - 1.5)/0.6))
    features.append(float((periodic_table.GetRcovalent(atomic_number) - 0.64)/0.76))

    if use_chirality == True:
        features += one_hot_encoding(str(atom.GetChiralTag()), ["CHI_UNSPECIFIED", "CHI_TETRAHEDRAL_CW", "CHI_TETRAHEDRAL_CCW", "CHI_OTHER"])

    if hydrogens_implicit == True:
        features += one_hot_encoding(int(atom.GetTotalNumHs()), [0, 1, 2, 3, 4, "MoreThanFour"])

    return np.array(features)

In [25]:
def get_bond_features(bond, use_stereochemistry=True):
    """
    Encode one RDKit bond as a flat numeric feature vector.

    The vector holds a one-hot bond type (single, double, triple, aromatic; any
    other type falls into the last slot), a conjugation flag and an in-ring
    flag. A one-hot stereo descriptor is appended when
    use_stereochemistry == True.

    Arguments:
        bond -- RDKit bond object
        use_stereochemistry -- append the stereo encoding

    Returns:
        1d numpy array of bond features
    """
    bond_types = [
        Chem.rdchem.BondType.SINGLE,
        Chem.rdchem.BondType.DOUBLE,
        Chem.rdchem.BondType.TRIPLE,
        Chem.rdchem.BondType.AROMATIC,
    ]

    features = one_hot_encoding(bond.GetBondType(), bond_types)
    features.append(int(bond.GetIsConjugated()))
    features.append(int(bond.IsInRing()))

    if use_stereochemistry == True:
        features.extend(one_hot_encoding(str(bond.GetStereo()), ["STEREOZ", "STEREOE", "STEREOANY", "STEREONONE"]))

    return np.array(features)

In [28]:
def create_pytorch_geom(x_smiles, y):
    """
    Convert labeled SMILES strings into PyTorch Geometric graphs.

    Inputs:

    x_smiles = [smiles_1, smiles_2, ....] ... a list of SMILES strings
    y = [y_1, y_2, ...] ... a list of numerical labels for the SMILES strings (such as associated pKi values)

    Outputs:

    data_list = [G_1, G_2, ...] ... a list of torch_geometric.data.Data objects which represent labeled molecular graphs that can readily be used for machine learning

    Each Data object carries x (n_nodes, n_node_features), edge_index (2, n_edges),
    edge_attr (n_edges, n_edge_features) and y (a one-element float tensor).

    Raises:

    ValueError ... if RDKit cannot parse one of the SMILES strings
    """

    # The feature widths do not depend on the molecule being converted, so probe
    # them once with a throwaway O=O molecule instead of once per SMILES string.
    probe_mol = Chem.MolFromSmiles("O=O")
    n_node_features = len(get_atom_features(probe_mol.GetAtomWithIdx(0)))
    n_edge_features = len(get_bond_features(probe_mol.GetBondBetweenAtoms(0, 1)))

    data_list = []

    for (smiles, y_val) in zip(x_smiles, y):

        # convert SMILES to RDKit mol object; MolFromSmiles returns None on failure
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            raise ValueError(f"RDKit could not parse SMILES string: {smiles!r}")

        # construct node feature matrix X of shape (n_nodes, n_node_features)
        X = np.zeros((mol.GetNumAtoms(), n_node_features))
        for atom in mol.GetAtoms():
            X[atom.GetIdx(), :] = get_atom_features(atom)
        X = torch.tensor(X, dtype = torch.float)

        # construct edge index array E of shape (2, n_edges); every bond appears
        # twice (i -> j and j -> i) because the adjacency matrix is symmetric
        (rows, cols) = np.nonzero(GetAdjacencyMatrix(mol)) # atoms that are adjacent
        E = torch.stack([torch.from_numpy(rows.astype(np.int64)),
                         torch.from_numpy(cols.astype(np.int64))], dim = 0)

        # construct edge feature array EF of shape (n_edges, n_edge_features),
        # sized from the edge list itself so row k always describes column k of E
        EF = np.zeros((len(rows), n_edge_features))
        for (k, (i, j)) in enumerate(zip(rows, cols)):
            EF[k] = get_bond_features(mol.GetBondBetweenAtoms(int(i), int(j)))
        EF = torch.tensor(EF, dtype = torch.float)

        # construct label tensor
        y_tensor = torch.tensor(np.array([y_val]), dtype = torch.float)

        # construct Pytorch Geometric data object and append to data list
        # NOTE, if only want to incorporate node info, can turn off edge_attr
        data_list.append(Data(x = X, edge_index = E, edge_attr = EF, y = y_tensor))

    return data_list

In [27]:
# Build one graph per row: column 0 of data_arr is passed as the SMILES strings
# and column 1 as the labels.
dataset = create_pytorch_geom(data_arr[:, 0], data_arr[:, 1])
#TODO, need to make dataset labels a 1-hot vector

 ... (more hidden) ...


In [55]:
# Print the index and label of every graph labeled 1 among the first 4000
# dataset entries.
count = 0
for i, d in enumerate(dataset):
    if d.y == 1:
        print (i, d.y)
    count += 1
    # Stop once 4000 entries have been examined.
    if count == 4000:
        break

11 tensor([1.])
16 tensor([1.])
80 tensor([1.])
203 tensor([1.])
234 tensor([1.])
235 tensor([1.])
244 tensor([1.])
271 tensor([1.])
279 tensor([1.])
326 tensor([1.])
352 tensor([1.])
353 tensor([1.])
361 tensor([1.])
384 tensor([1.])
387 tensor([1.])
429 tensor([1.])
434 tensor([1.])
443 tensor([1.])
498 tensor([1.])
499 tensor([1.])
654 tensor([1.])
676 tensor([1.])
699 tensor([1.])
740 tensor([1.])
818 tensor([1.])
869 tensor([1.])
879 tensor([1.])
978 tensor([1.])
996 tensor([1.])
1004 tensor([1.])
1059 tensor([1.])
1153 tensor([1.])
1195 tensor([1.])
1196 tensor([1.])
1198 tensor([1.])
1199 tensor([1.])
1200 tensor([1.])
1212 tensor([1.])
1266 tensor([1.])
1277 tensor([1.])
1322 tensor([1.])
1461 tensor([1.])
1482 tensor([1.])
1510 tensor([1.])
1511 tensor([1.])
1528 tensor([1.])
1530 tensor([1.])
1532 tensor([1.])
1533 tensor([1.])
1534 tensor([1.])
1535 tensor([1.])
1571 tensor([1.])
1572 tensor([1.])
1575 tensor([1.])
1576 tensor([1.])
1577 tensor([1.])
1578 tensor([1.])
1579 t

In [73]:
# A notebook cell only displays its last expression, so the label lookup on the
# first line produces no visible output; the cell shows the dataset length.
dataset[1].y
len(dataset)

41127

# Graph Convolutional Layer

In [5]:
import math

import torch

from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module

# The Kipf and Welling GCN layer
class GraphConvolution(Module):
    """
    Single graph-convolution layer, similar to https://arxiv.org/abs/1609.02907

    forward computes adj @ (input @ weight) and adds the bias when the layer
    was built with one.
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        if not bias:
            # Register the name so that self.bias exists and is None.
            self.register_parameter('bias', None)
        else:
            self.bias = Parameter(torch.FloatTensor(out_features))
        self.reset_parameters()

    def reset_parameters(self):
        """Fill weight, then bias, uniformly in +/- 1/sqrt(out_features)."""
        # these could be He initialized.
        bound = 1. / math.sqrt(self.weight.size(1))
        for param in (self.weight, self.bias):
            if param is not None:
                param.data.uniform_(-bound, bound)

    def forward(self, input, adj):
        """Return adj @ (input @ weight), plus the bias if present."""
        # TODO: make blanks out of this one.
        # NOTE(review): the original author was unsure whether batched input needs
        # a block-diagonal adj or whether the data loader handles it -- confirm.
        projected = torch.mm(input, self.weight)
        aggregated = torch.spmm(adj, projected)

        if self.bias is None:
            return aggregated
        return aggregated + self.bias

    def __repr__(self):
        return f"{self.__class__.__name__} ({self.in_features} -> {self.out_features})"


In [None]:
class GCNConv(MessagePassing):
    """
    GCN layer written against the torch_geometric MessagePassing API.

    forward adds self-loops, applies a bias-free linear map to the node
    features, weights every edge by deg^-1/2[row] * deg^-1/2[col], sums the
    weighted neighbour messages per node and finally adds a learned bias.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add')  # "Add" aggregation (Step 5).
        # nn.Linear / nn.Parameter: the bare name `Linear` is not imported in this
        # notebook, so the previous code failed with a NameError on construction.
        self.lin = nn.Linear(in_channels, out_channels, bias=False)
        self.bias = nn.Parameter(torch.empty(out_channels))

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the linear map and zero the bias."""
        self.lin.reset_parameters()
        self.bias.data.zero_()

    def forward(self, x, edge_index):
        """Apply the layer to node features x using the graph in edge_index."""
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        # Step 1: Add self-loops to the adjacency matrix.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Step 2: Linearly transform node feature matrix.
        x = self.lin(x)

        # Step 3: Compute normalization.
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        # Nodes with degree 0 would give inf; zero their weight instead.
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # Step 4-5: Start propagating messages.
        out = self.propagate(edge_index, x=x, norm=norm)

        # Step 6: Apply a final bias vector (out of place, so the tensor returned
        # by propagate is never modified behind autograd's back).
        out = out + self.bias

        return out

    def message(self, x_j, norm):
        """Scale each source-node feature row by its edge's normalisation weight."""
        # x_j has shape [E, out_channels]

        # Step 4: Normalize node features.
        return norm.view(-1, 1) * x_j

In [62]:
def global_add_pool(x: torch.Tensor, batch: Optional[torch.Tensor] = None,
                    size: Optional[int] = None) -> torch.Tensor:
    r"""Returns batch-wise graph-level-outputs by adding node features
    across the node dimension, so that for a single graph
    :math:`\mathcal{G}_i` its output is computed by

    .. math::
        \mathbf{r}_i = \sum_{n=1}^{N_i} \mathbf{x}_n.

    Functional method of the
    :class:`~torch_geometric.nn.aggr.SumAggregation` module.

    Args:
        x (torch.Tensor): Node feature matrix
            :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}`.
        batch (torch.Tensor, optional): The batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
            each node to a specific example. When :obj:`None`, all nodes
            are treated as one graph. (default: :obj:`None`)
        size (int, optional): The number of examples :math:`B`.
            Automatically calculated if not given. (default: :obj:`None`)

    Returns:
        torch.Tensor: The summed features, one row per example. With
        ``batch=None`` and ``x`` of at most two dimensions the reduced
        dimension is kept with size 1.
    """
    # Reduce over the node dimension: the only one for 1-D input, otherwise the
    # second-to-last.
    dim = -1 if x.dim() == 1 else -2

    if batch is None:
        return x.sum(dim=dim, keepdim=x.dim() <= 2)
    size = int(batch.max().item() + 1) if size is None else size
    return scatter(x, batch, dim=dim, dim_size=size, reduce='sum')


In [69]:
class DrugGCN(Module):
    """
    Graph-level classifier: a stack of GraphConvolution layers, a pooling step
    that collapses node features into one vector per graph, and a linear layer
    that produces the logits.

    Arguments:
        in_features -- width of the input node features
        hidden_features -- output width of each GraphConvolution layer, in order;
            must contain at least one entry
        out_features -- number of logits produced per graph
        activation -- module applied after every GraphConvolution layer
        pool -- callable pool(x, batch) that reduces node features to graph features

    Raises:
        ValueError -- if hidden_features is empty
    """

    def __init__(self, in_features, hidden_features=[], out_features=1, 
                 activation=nn.ReLU(), pool=global_add_pool):
        super().__init__()
        # Copy so the shared default list (or the caller's list) is never aliased.
        hidden_features = list(hidden_features)
        if not hidden_features:
            raise ValueError("hidden_features must contain at least one layer width")

        self.in_features = in_features
        self.hidden_features = hidden_features
        self.out_features = out_features
        self.activation = activation

        # Stored on self so forward can reach it and its parameters are registered.
        self.gc1 = GraphConvolution(in_features, hidden_features[0])

        # ModuleList (not a plain list) so the optimizer sees these parameters.
        # Layer k maps hidden_features[k] -> hidden_features[k + 1].
        self.hidden_layers = nn.ModuleList([
            GraphConvolution(fan_in, fan_out)
            for fan_in, fan_out in zip(hidden_features[:-1], hidden_features[1:])
        ])

        self.pool = pool

        self.out_layer = nn.Linear(hidden_features[-1], out_features)
        # Kept for test-time probabilities; forward returns raw logits.
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_graph, edge_index=None, batch=None):
        """
        Compute one row of logits per graph.

        Arguments:
            input_graph -- either the node feature matrix, or (when edge_index is
                omitted) an object exposing .x, .edge_index and optionally .batch,
                such as a torch_geometric Data/Batch
            edge_index -- [2, E] edge list; omit to read it from input_graph
            batch -- node-to-graph assignment vector; None treats all nodes as
                one graph

        Returns:
            Logits suitable for BCEWithLogitsLoss (no sigmoid applied).
        """
        if edge_index is None:
            graph = input_graph
            x, edge_index = graph.x, graph.edge_index
            if batch is None:
                batch = getattr(graph, 'batch', None)
        else:
            x = input_graph

        # implementation based off of kipf and welling: build the sparse matrix
        # A + I that GraphConvolution multiplies with.
        # NOTE(review): the adjacency is not degree-normalised -- confirm whether
        # D^-1/2 (A + I) D^-1/2 is wanted here.
        n_nodes = x.size(0)
        edge_index, _ = add_self_loops(edge_index, num_nodes=n_nodes)
        adj = torch.sparse_coo_tensor(
            edge_index,
            torch.ones(edge_index.size(1), dtype=x.dtype, device=x.device),
            (n_nodes, n_nodes),
        ).coalesce()

        x = self.activation(self.gc1(x, adj))

        for layer in self.hidden_layers:
            x = self.activation(layer(x, adj))

        x = self.pool(x, batch)
        x = self.out_layer(x) # this generates the logits for BCELoss
        # TODO, implement test time!!!
        return x

In [None]:
def train(gnn_model, data):
    """
    Canonical training loop for a PyTorch Geometric GNN model.

    Trains gnn_model in place for 10 epochs with Adam (lr 1e-3) and
    BCEWithLogitsLoss on mini-batches of 128 graphs.

    Arguments:
        gnn_model -- module called as gnn_model(batch); column 0 of its output is
            used as the logit of each graph
        data -- list of torch_geometric Data objects carrying a label in .y

    NOTE(review): this notebook defines `train` more than once; the last
    definition executed is the one that stays bound to the name.
    """
    # Graphs of different sizes need the PyG loader, which merges them into one
    # Batch; the plain torch.utils.data loader cannot collate Data objects.
    dataloader = GeometricDataLoader(dataset=data, batch_size = 2**7)
    # define loss function
    loss_function = nn.BCEWithLogitsLoss()
    # define optimiser
    optimizer = torch.optim.Adam(gnn_model.parameters(), lr = 1e-3)
    # keep every batch on whatever device the model already lives on
    model_device = next(gnn_model.parameters()).device
    # loop over 10 training epochs
    for epoch in range(10):
        # set model to training mode
        gnn_model.train()
        # loop over minibatches for training
        for batch in dataloader:
            batch = batch.to(model_device)
            # compute current value of loss function via forward pass
            output = gnn_model(batch)
            # batch.y is already a tensor: cast it instead of re-wrapping it
            targets = batch.y.view(-1).to(torch.float32)
            loss_function_value = loss_function(output[:, 0], targets)
            # set past gradient to zero
            optimizer.zero_grad()
            # compute current gradient via backward pass
            loss_function_value.backward()
            # update model weights using gradient and optimisation method
            optimizer.step()

In [None]:
def build_batch(dataset, indices):
    '''
    Gather the dataset elements at the given indices and split them by field.

    Every selected element is unpacked positionally; fields 0, 1 and 2 are each
    collected across the batch into their own numpy array. The original notes
    describe them as the inputs, the target categories and the padding start
    indices.

    Arguments:
        dataset: List[db_element] -- A list of dataset elements
        indices: List[int] -- A list of indices of the dataset to sample
    Returns:
        batch_input: np.ndarray -- field 0 of every selected element
        batch_target: np.ndarray -- field 1 of every selected element
        batch_indices: np.ndarray -- field 2 of every selected element
    '''
    selected = [dataset[idx] for idx in indices]
    # Transpose once: a list of per-element tuples becomes one tuple per field.
    columns = list(zip(*selected))
    batch_input = np.array(columns[0])
    batch_target = np.array(columns[1])
    batch_indices = np.array(columns[2])
    return batch_input, batch_target, batch_indices # lines, categories

def train(model, optimizer, criterion=nn.BCEWithLogitsLoss(), epochs=10, batch_size=2**7, seed=0):
    '''
    Train model with hand-built batches and evaluate it after every epoch.

    Arguments:
        model -- module called as model(batch_input, batch_indices)
        optimizer -- optimizer already bound to model's parameters
        criterion -- loss called as criterion(logits, batch_target)
        epochs -- number of passes over train_data
        batch_size -- number of elements per batch
        seed -- base seed; epoch e reseeds random, numpy and torch with seed + e
    Returns:
        dict with the loss/accuracy history, the hyper-parameters and the
        (name, shape) of every model parameter

    NOTE(review): reads train_data, test_data, batch_to_torch and list_to_device
    from the notebook namespace; none of them is defined in the visible part of
    this notebook -- confirm they exist before calling.
    NOTE(review): this notebook defines `train` more than once; the last
    definition executed is the one that stays bound to the name.
    '''
    model.to(torch_device)
    train_losses = []
    train_accuracies = []
    eval_accuracies = []
    # Ceil division: `len // batch_size + 1` produced an empty final batch whenever
    # the data size was an exact multiple of batch_size, which build_batch rejects.
    n_train_batches = (len(train_data) + batch_size - 1) // batch_size
    n_test_batches = (len(test_data) + batch_size - 1) // batch_size
    for epoch in range(epochs):
        random.seed(seed + epoch)
        np.random.seed(seed + epoch)
        torch.manual_seed(seed + epoch)
        # Re-enter training mode every epoch because evaluation below leaves it.
        model.train()
        indices = np.random.permutation(range(len(train_data)))
        n_correct, n_total = 0, 0
        progress_bar = tqdm(range(n_train_batches))
        for i in progress_bar:
            batch = build_batch(train_data, indices[i*batch_size:(i+1)*batch_size])
            (batch_input, batch_target, batch_indices) = batch_to_torch(*batch)
            (batch_input, batch_target, batch_indices) = list_to_device((batch_input, batch_target, batch_indices))

            logits = model(batch_input, batch_indices)
            loss = criterion(logits, batch_target)
            train_losses.append(loss.item())

            # NOTE(review): argmax assumes one logit per class; with a single
            # logit per example it is always 0 -- confirm the model's output shape.
            predictions = logits.argmax(dim=-1)
            n_correct += (predictions == batch_target).sum().item()
            n_total += batch_target.size(0)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (i + 1) % 10 == 0:
                progress_bar.set_description(f"Epoch: {epoch}  Iteration: {i}  Loss: {np.mean(train_losses[-10:])}")
        train_accuracies.append(n_correct / n_total * 100)
        print(f"Epoch: {epoch}  Train Accuracy: {n_correct / n_total * 100}")

        # Evaluation mode so layers such as dropout behave deterministically.
        model.eval()
        with torch.no_grad():
            indices = list(range(len(test_data)))
            n_correct, n_total = 0, 0
            for i in range(n_test_batches):
                batch = build_batch(test_data, indices[i*batch_size:(i+1)*batch_size])
                (batch_input, batch_target, batch_indices) = batch_to_torch(*batch)
                (batch_input, batch_target, batch_indices) = list_to_device((batch_input, batch_target, batch_indices))

                logits = model(batch_input, batch_indices)
                predictions = logits.argmax(dim=-1)
                n_correct += (predictions == batch_target).sum().item()
                n_total += batch_target.size(0)
            eval_accuracies.append(n_correct / n_total * 100)
            print(f"Epoch: {epoch}  Eval Accuracy: {n_correct / n_total * 100}")

    # Hyper-parameters are read from the objects that were passed in rather than
    # from notebook globals, so a missing global cannot discard a finished run.
    to_save = {
        "history": {
            "train_losses": train_losses,
            "train_accuracies": train_accuracies,
            "eval_accuracies": eval_accuracies,
        },
        "hparams": {
            "hidden_size": getattr(model, "hidden_size", None),
            "num_layers": getattr(model, "num_layers", None),
            "dropout": getattr(model, "dropout", None),
            "optimizer_class": type(optimizer).__name__,
            "lr": optimizer.param_groups[0]["lr"],
            "batch_size": batch_size,
            "epochs": epochs,
            "seed": seed
        },
        "model": [
            (name, list(param.shape))
            for name, param in model.named_parameters()
        ]
    }
    return to_save

In [None]:
def train_final(model, optimizer, criterion=nn.BCEWithLogitsLoss(), epochs=10, batch_size=2**7, seed=0):
    '''
    Train a graph classifier with data loaders and evaluate it after every epoch.

    Arguments:
        model -- module called as model(batch)
        optimizer -- optimizer already bound to model's parameters
        criterion -- loss called as criterion(logits, targets)
        epochs -- number of passes over train_data
        batch_size -- number of graphs per batch
        seed -- base seed; epoch e reseeds random, numpy and torch with seed + e
    Returns:
        dict with the loss/accuracy history, the hyper-parameters and the
        (name, shape) of every model parameter

    NOTE(review): reads train_data and test_data from the notebook namespace and
    assumes both are lists of torch_geometric Data objects with a binary label
    in .y, like `dataset` -- neither name is defined in the visible part of this
    notebook, confirm before calling.
    NOTE(review): assumes the model returns exactly one logit per graph -- confirm.
    '''
    model.to(torch_device)
    train_losses = []
    train_accuracies = []
    eval_accuracies = []
    # PyG loaders merge variable-sized graphs into one Batch; honour batch_size
    # instead of a hard-coded value.
    train_loader = GeometricDataLoader(train_data, batch_size=batch_size, shuffle=True)
    test_loader = GeometricDataLoader(test_data, batch_size=batch_size)
    for epoch in range(epochs):
        random.seed(seed + epoch)
        np.random.seed(seed + epoch)
        torch.manual_seed(seed + epoch)
        # Re-enter training mode every epoch because evaluation below leaves it.
        model.train()
        n_correct, n_total = 0, 0
        progress_bar = tqdm(train_loader)
        for i, batch in enumerate(progress_bar):
            batch = batch.to(torch_device)
            targets = batch.y.view(-1).float()

            # The labels go to the loss only; they must never be a model input.
            logits = model(batch).view(-1)
            loss = criterion(logits, targets)
            train_losses.append(loss.item())

            # A logit above 0 is a probability above 0.5, i.e. the positive class.
            predictions = (logits > 0).float()
            n_correct += (predictions == targets).sum().item()
            n_total += targets.size(0)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (i + 1) % 10 == 0:
                progress_bar.set_description(f"Epoch: {epoch}  Iteration: {i}  Loss: {np.mean(train_losses[-10:])}")
        train_accuracies.append(n_correct / n_total * 100)
        print(f"Epoch: {epoch}  Train Accuracy: {n_correct / n_total * 100}")

        # Evaluation mode so layers such as dropout behave deterministically.
        model.eval()
        with torch.no_grad():
            n_correct, n_total = 0, 0
            for batch in test_loader:
                batch = batch.to(torch_device)
                targets = batch.y.view(-1).float()

                logits = model(batch).view(-1)
                predictions = (logits > 0).float()
                n_correct += (predictions == targets).sum().item()
                n_total += targets.size(0)
            eval_accuracies.append(n_correct / n_total * 100)
            print(f"Epoch: {epoch}  Eval Accuracy: {n_correct / n_total * 100}")

    # Hyper-parameters are read from the objects that were passed in rather than
    # from notebook globals, so a missing global cannot discard a finished run.
    to_save = {
        "history": {
            "train_losses": train_losses,
            "train_accuracies": train_accuracies,
            "eval_accuracies": eval_accuracies,
        },
        "hparams": {
            "hidden_size": getattr(model, "hidden_size", None),
            "num_layers": getattr(model, "num_layers", None),
            "dropout": getattr(model, "dropout", None),
            "optimizer_class": type(optimizer).__name__,
            "lr": optimizer.param_groups[0]["lr"],
            "batch_size": batch_size,
            "epochs": epochs,
            "seed": seed
        },
        "model": [
            (name, list(param.shape))
            for name, param in model.named_parameters()
        ]
    }
    return to_save

In [None]:
def train(data, hidden_features=(64, 64), num_epochs=10, lr=1e-3, train_logfreq=10):
    """
    Train a DrugGCN on an 80/20 split of data, plot the curves and return the model.

    Arguments:
        data -- list of torch_geometric Data objects with node features in .x and
            a binary label in .y
        hidden_features -- widths of the DrugGCN graph-convolution layers
        num_epochs -- number of passes over the training split
        lr -- Adam learning rate
        train_logfreq -- record the training loss/accuracy every this many steps

    Returns:
        The trained DrugGCN (in training mode, on torch_device).

    NOTE(review): the split takes the first 80% of data in its given order --
    shuffle beforehand if the rows are not already in random order.
    NOTE(review): assumes the model can be called as model(batch) and returns one
    logit per graph -- confirm against DrugGCN.forward.
    NOTE(review): this notebook defines `train` more than once; the last
    definition executed is the one that stays bound to the name.
    """
    # Initialize the data loaders and the model
    split_idx = int(0.8 * len(data))
    trainloader = GeometricDataLoader(data[:split_idx], batch_size=2**7, shuffle=True)
    testloader = GeometricDataLoader(data[split_idx:], batch_size=2**7)
    # The input width is read off the first graph's node feature matrix.
    model = DrugGCN(data[0].x.size(1), list(hidden_features))
    # Move model to GPU 
    model.to(torch_device)
    # Create optimizer for the model

    # You may want to tune these hyperparameters to get better performance
    #optimizer = optim.AdamW(model.parameters(), lr=1e-3, betas=(0.9, 0.95), weight_decay=1e-9)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    total_steps = 0
    losses = []
    train_acc = []
    all_val_acc = []
    best_val_acc = 0
    loss_fn = nn.BCEWithLogitsLoss()
    epoch_iterator = trange(num_epochs)
    for epoch in epoch_iterator:
        # Train
        model.train()
        data_iterator = tqdm(trainloader)
        for batch in data_iterator:
            total_steps += 1
            batch = batch.to(torch_device)
            y = batch.y.view(-1).float()
            logits = model(batch).view(-1)
            loss = loss_fn(logits, y)
            # A logit above 0 is a probability above 0.5, i.e. the positive class.
            accuracy = torch.mean(((logits > 0).float() == y).float())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            data_iterator.set_postfix(loss=loss.item(), train_acc=accuracy.item())

            if total_steps % train_logfreq == 0:
                losses.append(loss.item())
                train_acc.append(accuracy.item())

        # Validation
        val_acc = []
        model.eval()
        for batch in testloader:
            batch = batch.to(torch_device)
            y = batch.y.view(-1).float()
            with torch.no_grad():
                logits = model(batch).view(-1)
            accuracy = torch.mean(((logits > 0).float() == y).float())
            val_acc.append(accuracy.item())
        model.train()

        all_val_acc.append(np.mean(val_acc))
        # Track the best validation accuracy seen so far
        if np.mean(val_acc) > best_val_acc:
            best_val_acc = np.mean(val_acc)

        epoch_iterator.set_postfix(val_acc=np.mean(val_acc), best_val_acc=best_val_acc)

    # One figure per curve.
    for series, title in ((losses, 'Train Loss'),
                          (train_acc, 'Train Accuracy'),
                          (all_val_acc, 'Val Accuracy')):
        plt.figure()
        plt.plot(series)
        plt.title(title)

    return model