In [None]:
import warnings
warnings.filterwarnings("ignore")

import numpy as np
import torch
import torch.nn as nn
import torchvision
import torch.optim as optim
from torch.utils.data import DataLoader,random_split,Dataset, ConcatDataset ,SubsetRandomSampler 
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

from torchinfo import summary
from tqdm import tqdm
from sklearn.model_selection import KFold

In [None]:
# --- Reproducibility -------------------------------------------------------
# Seed every RNG this notebook has imported so a fresh "Restart & Run All"
# repeats the same run as closely as the hardware allows.
seed = 129
torch.manual_seed(seed)
np.random.seed(seed)
torch.backends.cudnn.deterministic = True
# cuDNN autotuning may pick different kernels between runs; turn it off to go
# with the deterministic flag above.
torch.backends.cudnn.benchmark = False

# --- Training configuration -----------------------------------------------
num_epochs = 10
batch_size = 32
# Use the GPU when one is present, otherwise fall back to the CPU instead of
# failing on the first ``.to(device)`` call.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# CIFAR-100 training and test splits, stored under ./data (downloaded when
# requested via download=True) and converted to tensors on access.
train_dataset = torchvision.datasets.CIFAR100(
    './data',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
test_dataset = torchvision.datasets.CIFAR100(
    './data',
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)

In [None]:
def train_model(model, train_loader, criterion, optimizer, device, epochs):
    """Run one optimisation pass over ``train_loader`` and report loss/accuracy.

    Args:
        model: Network to optimise; it is switched to training mode here.
        train_loader: Iterable yielding ``(imgs, labels)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer whose gradients are zeroed and stepped per batch.
        device: Device each batch is moved to before the forward pass.
        epochs: Zero-based index of the current epoch; only used in the
            progress message, where it is printed as ``epochs + 1``.

    Returns:
        tuple: ``(mean_loss, accuracy)``. ``mean_loss`` is the loss averaged
        over all samples seen and ``accuracy`` is a percentage in [0, 100].
        For an empty loader the result is ``(nan, 0.0)``.
    """
    model.train()
    loss_sum, total_samples, total_correct = 0.0, 0, 0

    for imgs, labels in tqdm(train_loader):
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()

        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_count = labels.size(0)
        # Weight the batch loss by its size so a smaller final batch does not
        # skew the average. This assumes ``criterion`` returns the batch mean
        # (the default reduction of nn.CrossEntropyLoss).
        loss_sum += loss.item() * batch_count
        predicted = outputs.detach().argmax(dim=1)

        total_correct += (predicted == labels).sum().item()
        total_samples += batch_count

    # Normalise by the number of samples, not the number of batches: the sum
    # above is already multiplied by each batch size.
    if total_samples:
        mean_loss = loss_sum / total_samples
        accuracy = (total_correct / total_samples) * 100
    else:
        mean_loss, accuracy = float('nan'), 0.0

    print("Trainining Epoch: [{}]  loss: [{:.2f}] Accuracy :[{:.2f}]".format(epochs+1, mean_loss, accuracy))

    return mean_loss, accuracy

In [None]:
def test_model(model, test_loader, criterion,  device, epochs):
    model.eval()
    test_loss, total_samples, total_correct = 0,0,0
    with torch.no_grad():
        for i,(imgs, labels) in enumerate(tqdm(test_loader)):
            imgs, labels = imgs.to(device), labels.to(device)

            outputs = model(imgs)
            loss = criterion(outputs, labels)

            test_loss += loss.item() * imgs.size(0)
            _, predicted = torch.max(outputs.data,1)

            total_correct += (predicted == labels).sum().item()
            total_samples += labels.size(0)
            accuracy = (total_correct/total_samples)*100

    print("Testing Epoch: [{}]  loss: [{:.2f}] Accuracy :[{:.2f}]".format(epochs+1, test_loss/len(test_loader), accuracy))

    return test_loss/len(test_loader), accuracy

In [None]:
# Cross-validation setup: number of folds, the shuffled splitter that yields
# (train, validation) index arrays, and a dict collecting each fold's history.
k = 3
splits = KFold(
    n_splits=k,
    shuffle=True,
    random_state=0,
)
per_fold_result = dict()

In [None]:
# Train a freshly constructed ResNet-18 on every K-fold split and store its
# per-epoch metrics in per_fold_result under the key 'fold<N>'.
for i, (train_idx,val_idx) in enumerate(splits.split(np.arange(len(train_dataset)))):

    print("Fold no.{}:".format(i + 1))

    # Both loaders read from train_dataset; the samplers restrict each loader
    # to the indices of its side of the split.
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(val_idx)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler, num_workers = 4)
    valid_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=valid_sampler, num_workers = 4)

    # A new model per fold so no weights carry over between folds. The final
    # layer is replaced with a 100-way head for CIFAR-100, then the whole
    # model is moved once to the configured device (the same one the batches
    # are sent to inside train_model / test_model).
    model = torchvision.models.resnet18()
    model.fc = torch.nn.Linear(in_features=model.fc.in_features,
                    out_features=100,
                    bias=True)
    model = model.to(device)

    optimizer = torch.optim.SGD(model.parameters(),lr=0.0001)
    criterion = nn.CrossEntropyLoss()

    history = {'train_loss': [], 'valid_loss': [],'train_acc':[],'valid_acc':[]}

    for epoch in range(num_epochs):
        train_loss, train_acc = train_model(model, train_loader, criterion, optimizer, device, epoch)
        valid_loss, valid_acc = test_model(model, valid_loader, criterion, device, epoch)

        history['train_loss'].append(train_loss)
        history['valid_loss'].append(valid_loss)
        history['train_acc'].append(train_acc)
        history['valid_acc'].append(valid_acc)

    per_fold_result['fold{}'.format(i+1)] = history

In [None]:
# Loss curves (left panel) and accuracy curves (right panel) for every fold
# stored in per_fold_result, so the figure follows whatever number of folds
# was actually run instead of assuming exactly three.
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(8, 5))

for fold_key, fold_history in per_fold_result.items():
    # Keys have the form 'fold<N>'; turn them into the legend prefix 'Fold <N>'.
    fold_label = fold_key.replace('fold', 'Fold ')
    loss_ax.plot(fold_history['train_loss'], label='{} Training Loss'.format(fold_label))
    loss_ax.plot(fold_history['valid_loss'], label='{} Validation Loss'.format(fold_label))
    acc_ax.plot(fold_history['train_acc'], label='{} Training Accuracy'.format(fold_label))
    acc_ax.plot(fold_history['valid_acc'], label='{} Validation Accuracy'.format(fold_label))

for ax, y_label in ((loss_ax, 'Loss'), (acc_ax, 'Accuracy')):
    ax.legend()
    ax.set_xlabel('Epochs')
    ax.set_ylabel(y_label)

plt.show()