# Cross validation in PyTorch

This short tutorial shows how to implement cross-validation in PyTorch. We will assume you already know the basics of training a simple neural network in PyTorch. As a running example, we use a small convolutional classifier on the MNIST dataset:

In [None]:
import torchvision, torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt
from torch import nn
from sklearn.model_selection import KFold

In [None]:
class model(nn.Module):
    """Small convolutional classifier that emits 10 scores per input image.

    Two conv -> max-pool -> ReLU stages are followed by a final convolution
    with 10 output channels. The spatial positions left after that last
    convolution are averaged, so every image yields one value per channel.
    """

    def __init__(self):
        super().__init__()
        # Stage 1: 1 input channel -> 32 feature maps, then 2x2 max pooling.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3)
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.act1 = nn.ReLU()
        # Stage 2: 32 -> 64 feature maps, then 2x2 max pooling.
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3)
        self.pool2 = nn.MaxPool2d(kernel_size=2)
        self.act2 = nn.ReLU()
        # Head: one output channel per class score.
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=10, kernel_size=3)

    def forward(self, image):
        """Return one row of 10 scores for each image in the 4-D batch."""
        stage1 = self.act1(self.pool1(self.conv1(image)))
        stage2 = self.act2(self.pool2(self.conv2(stage1)))
        score_maps = self.conv3(stage2)
        # Average over the two spatial axes, keeping (batch, channel).
        return score_maps.mean(dim=(2, 3))
    
from torch.optim import Adam
# Mini-batch size handed to the DataLoader objects created in this notebook.
Bsize=10

We can still use sklearn's `KFold` to perform the split; for each fold it returns the indices of the training and test samples:

In [None]:
# Transform applied to every image when a sample is fetched from the dataset.
tensorize = torchvision.transforms.ToTensor()

# Training portion of MNIST, downloaded into `root` if it is not there yet.
# NOTE(review): `root` is a machine-specific absolute path; adjust it locally.
MNIST = torchvision.datasets.MNIST(
    root='/home/morning/ramdisk/MNIST',
    train=True,
    transform=tensorize,
    download=True,
)

# Shuffled 5-fold splitter; here we only look at the first fold's indices.
splitter = KFold(n_splits=5, shuffle=True)
train, test = next(splitter.split(MNIST))

print(len(MNIST))
print(test)

Compared to the usual training loop, all we need to do is to introduce a random sampler in the dataloader object. By using **SubsetRandomSampler**, we can at the same time randomize the order in which each sample is presented during training, and limit this choice to a specific subset of the overall dataset. Having defined it, we proceed as normal.

In [None]:
# Restrict sampling to the training indices, visiting them in random order.
train_sampler = torch.utils.data.SubsetRandomSampler(indices=train)
traindata = torch.utils.data.DataLoader(
    dataset=MNIST,
    sampler=train_sampler,
    batch_size=Bsize,
)
# Number of batches the loader produces in one pass (shown as the cell output).
len(traindata)

Finally, the training loop:

In [None]:
import tqdm  # To display a progress bar

# Number of passes over the training indices performed for every fold.
epochs=2

splitter = KFold(5,shuffle=True)
# Validation accuracy after the last epoch of each fold.
accs = []

for train, test in splitter.split(range(len(MNIST))): ### To loop over all folds
    print('A new network')
    # Initialize a new network for the new fold, and its optimizer, so that
    # nothing learned on a previous fold carries over into this one.
    net = model()
    Loss = nn.CrossEntropyLoss()
    optimizer = Adam(net.parameters(), lr=0.001)

    # Prepare the data samplers for these train and test indices.
    train_sampler = torch.utils.data.SubsetRandomSampler(train)
    test_sampler = torch.utils.data.SubsetRandomSampler(test)

    # Prepare the dataloaders: both read from the same dataset object and
    # differ only in which indices their sampler is allowed to draw.
    traindata = torch.utils.data.DataLoader(MNIST, batch_size=Bsize, sampler=train_sampler)
    testdata = torch.utils.data.DataLoader(MNIST, batch_size=Bsize, sampler=test_sampler)

    # Scoped to this fold: stays NaN if no epoch runs, instead of raising a
    # NameError or silently reusing the value left over from another fold.
    fold_accuracy = float('nan')

    for epoch in range(epochs):
        # Perform the training loop as normal, looping over all epochs.
        net.train()

        for sample, label in tqdm.tqdm(traindata, total=len(traindata)):
            # One optimisation step per (sample, label) batch from the loader.
            optimizer.zero_grad()
            predicted = net(sample)
            loss = Loss(predicted, label)
            loss.backward()
            optimizer.step()

        # Before starting the next epoch, evaluate on the test data.
        net.eval()
        loss_sum = 0.0
        correct = 0
        seen = 0
        with torch.no_grad():
            for sample, label in testdata:
                predicted = net(sample)
                batch_count = label.size(0)
                # The loss is a per-batch mean; scale it back to a per-batch
                # sum so a smaller final batch is weighted by its true size.
                loss_sum += Loss(predicted, label).item() * batch_count
                correct += (predicted.argmax(dim=1) == label).sum().item()
                seen += batch_count

        # Per-sample averages over the whole test fold.
        fold_loss = loss_sum / seen
        fold_accuracy = correct / seen
        print('\nloss:', fold_loss)
        print('accuracy:', fold_accuracy)

    # Record the last validation accuracy result from each fold.
    accs.append(fold_accuracy)

# Print the final results.
print('Average accuracy:', np.mean(accs))