In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, ConcatDataset, random_split, SubsetRandomSampler
import torchvision
import torchvision.transforms as transforms


import numpy as np
import matplotlib.pyplot as plt
import plotly.express as px
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report
from sklearn.model_selection import KFold
import seaborn as sns

from utils import CNN

# **K-fold Validation**

In [5]:
# Preprocessing pipeline. Named `transform` (singular) so it does NOT shadow the
# `torchvision.transforms` module imported above; shadowing it made this cell fail
# with AttributeError when re-run.
transform = transforms.Compose([
        transforms.Resize([120, 120]),  # the images are not all equal in size, so resize to a standard size
        transforms.ToTensor()  # converts images to torch Tensors and rescales pixels to [0, 1]
    ])

train_dataset = torchvision.datasets.GTSRB(
    root='.',
    split='train',
    transform=transform,
    download=True
)

test_dataset = torchvision.datasets.GTSRB(
    root='.',
    split='test',
    transform=transform,
    download=True
)

In [10]:
# Pool the official train and test splits into one dataset so K-fold
# cross-validation draws folds from every labelled image, then build a shuffled
# 4-fold splitter. The same seed drives torch's RNG and the fold assignment, so
# folds are reproducible across runs.
SEED = 17
torch.manual_seed(SEED)
dataset = ConcatDataset([train_dataset, test_dataset])
splits = KFold(n_splits=4, shuffle=True, random_state=SEED)

In [12]:
def train_epoch(model, device, dataloader, criterion, optimizer):
    """Run one optimisation pass of ``model`` over ``dataloader``.

    Puts the model in training mode and, for every batch, performs
    zero_grad -> forward -> loss -> backward -> optimizer step.

    Returns:
        (summed_loss, num_correct): the per-batch loss multiplied by the batch
        size and summed over the epoch, and the number of samples whose argmax
        prediction matched the target. Divide both by the number of samples
        seen to obtain epoch averages.
    """
    running_loss = 0.0
    n_correct = 0
    model.train()
    for batch_x, batch_y in dataloader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        optimizer.zero_grad()
        logits = model(batch_x)
        batch_loss = criterion(logits, batch_y)
        batch_loss.backward()
        optimizer.step()

        # Weight by batch size so the caller can divide by the total sample
        # count (presumes criterion uses reduction='mean' — TODO confirm).
        running_loss += batch_loss.item() * batch_x.size(0)
        predicted = torch.max(logits.data, dim=1).indices
        n_correct += predicted.eq(batch_y).sum().item()

    return running_loss, n_correct
  
def valid_epoch(model, device, dataloader, criterion):
    """Evaluate ``model`` on ``dataloader`` without updating it.

    Puts the model in eval mode and runs every batch with autograd disabled.

    Returns:
        (summed_loss, num_correct): the per-batch loss multiplied by the batch
        size and summed over all batches, and the number of samples whose
        argmax prediction matched the target. Divide both by the number of
        samples to obtain averages.
    """
    valid_loss, val_correct = 0.0, 0
    model.eval()
    # Inference only: without no_grad every forward pass would build an
    # autograd graph that is never used, wasting memory and time.
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            output = model(inputs)
            loss = criterion(output, targets)
            # Weight by batch size so the caller can average over samples
            # (presumes criterion uses reduction='mean' — TODO confirm).
            valid_loss += loss.item() * inputs.size(0)
            _, predictions = torch.max(output, 1)
            val_correct += (predictions == targets).sum().item()

    return valid_loss, val_correct

In [14]:
# Per-epoch metrics; the K-fold loop appends one entry per epoch to each list.
history = {key: [] for key in ('train_loss', 'test_loss', 'train_acc', 'test_acc')}

# Use the first CUDA device when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

criterion = nn.CrossEntropyLoss()
BATCH_SIZE = 128

cpu


In [13]:
# Single source of truth for the epoch count so the training loop and the
# progress message cannot disagree.
NUM_EPOCHS = 10
NUM_CLASSES = 43  # output size passed to CNN; GTSRB has 43 sign classes

# NOTE: `history` is flat across folds — entries are fold-major, NUM_EPOCHS per fold.
for fold, (train_idx, val_idx) in enumerate(splits.split(np.arange(len(dataset)))):

    print(f'Fold: {fold+1}')

    train_sampler = SubsetRandomSampler(train_idx)
    test_sampler = SubsetRandomSampler(val_idx)
    train_loader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=train_sampler)
    test_loader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=test_sampler)
    # Sample counts per split are fixed for the whole fold.
    n_train = len(train_loader.sampler)
    n_test = len(test_loader.sampler)

    # Fresh model and optimizer per fold so no weights leak between folds.
    model = CNN(NUM_CLASSES)
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters())

    for epoch in range(NUM_EPOCHS):
        train_loss, train_correct = train_epoch(model, device, train_loader, criterion, optimizer)
        test_loss, test_correct = valid_epoch(model, device, test_loader, criterion)

        # Summed losses / correct counts -> per-sample averages and percentages.
        train_loss = train_loss / n_train
        train_acc = train_correct / n_train * 100
        test_loss = test_loss / n_test
        test_acc = test_correct / n_test * 100

        print(f"Epoch:{epoch + 1}/{NUM_EPOCHS} AVG Training Loss:{train_loss:.3f} "
              f"AVG Test Loss:{test_loss:.3f} AVG Training Acc {train_acc:.2f} % "
              f"AVG Test Acc {test_acc:.2f} %")
        history['train_loss'].append(train_loss)
        history['test_loss'].append(test_loss)
        history['train_acc'].append(train_acc)
        history['test_acc'].append(test_acc)

Fold 1


KeyboardInterrupt: 