In [1]:
import copy
import os
import pickle
import sys

import numpy as np
import pandas as pd
import torch
import torch.optim as optim

# Make the project's src/ directory importable (resolved relative to the
# current working directory). os.path.join keeps this portable: a literal
# "..\\..\\src" only resolves as a nested path on Windows.
sys.path.append(os.path.join("..", "..", "src"))

from datasets import CV2ImageDataset, dataset_loader
from cifar_model import Net
from model_class import NeuralNet
from run_phase import run_phase
import albumentations as A
from albumentations.pytorch import ToTensorV2
from sklearn.model_selection import StratifiedKFold
# Use the first CUDA GPU when one is visible to torch; otherwise run on CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")


In [3]:
if __name__ == '__main__':
    # ---- Configuration ------------------------------------------------------
    epochs = 5
    n_splits = 2
    image_height = 32  # NOTE(review): not referenced in this cell
    image_width = 32   # NOTE(review): not referenced in this cell
    batch_size = 64
    # Window length passed to the validation meter's early-stopping check.
    # NOTE(review): with epochs=5 a window of 10 presumably never stops
    # training early -- confirm against the meter implementation.
    early_stopping_length = 10

    # Training pipeline: random horizontal flip, then normalize and convert
    # the image to a tensor.
    aug = A.Compose([
        A.HorizontalFlip(p=0.5),
        A.Normalize(),
        ToTensorV2(),
    ])
    # Evaluation pipeline: same preprocessing but without the random flip,
    # so test/validation metrics are not perturbed by augmentation noise.
    eval_aug = A.Compose([
        A.Normalize(),
        ToTensorV2(),
    ])

    # os.path.join instead of a hard-coded backslash so the paths also
    # resolve on Linux/macOS (identical result on Windows).
    df = pd.read_csv(os.path.join('data', 'train.csv'))
    val_df = pd.read_csv(os.path.join('data', 'test.csv'))

    # ---- Split train.csv into stratified folds ------------------------------
    # Each row's "fold" value is the split in which that row is held out.
    df["fold"] = np.nan
    skf = StratifiedKFold(n_splits=n_splits)
    for fold, (train_index, test_index) in enumerate(skf.split(df, df.label)):
        # split() yields positional indices; translate them to index labels
        # so the assignment stays correct even without a default RangeIndex.
        df.loc[df.index[test_index], "fold"] = fold

    # ---- Hold-out set (test.csv): identical for every fold, so build once ---
    val_ds = CV2ImageDataset(val_df, transform=eval_aug, device=device)
    val_ds_l = dataset_loader(val_ds, batch_size=batch_size)
    valloader = val_ds_l.get_dataloader()

    # ---- Train one model per fold ------------------------------------------
    best_model_per_fold = {}
    value_list = list(df.fold.unique())
    for fold in value_list:
        test_df = df.loc[df['fold'] == fold]
        train_df = df.loc[df['fold'] != fold]

        train_ds = CV2ImageDataset(train_df, transform=aug, device=device)
        test_ds = CV2ImageDataset(test_df, transform=eval_aug, device=device)

        train_ds_l = dataset_loader(train_ds, batch_size=batch_size)
        test_ds_l = dataset_loader(test_ds, batch_size=batch_size)

        trainloader = train_ds_l.get_dataloader()
        testloader = test_ds_l.get_dataloader()

        # A fresh network per fold so folds do not share weights.
        net = Net()
        nn_model = NeuralNet(net)
        model = nn_model.get_model()
        model.to(device)

        # Bind the optimizer to the module run_phase actually trains
        # (`model`), whatever get_model() returns, rather than to `net`.
        optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

        train_phase = run_phase(trainloader, model, 'Train', device,
                                loss=torch.nn.CrossEntropyLoss(), optimizer=optimizer)
        test_phase = run_phase(testloader, model, 'Test', device,
                               loss=torch.nn.CrossEntropyLoss())
        val_phase = run_phase(valloader, model, 'Validation', device,
                              loss=torch.nn.CrossEntropyLoss())

        for epoch in range(epochs):
            print('Current Fold {} and Epoch {}: '.format(int(fold), epoch))
            train_accuracy_meter, train_loss_meter = train_phase.run()
            test_accuracy_meter, test_loss_meter = test_phase.run()
            val_accuracy_meter, val_loss_meter = val_phase.run()

            # NOTE(review): early stopping and best-model selection are driven
            # by the test.csv meter (the final hold-out set). Selecting on the
            # fold's held-out split (test_accuracy_meter) would keep test.csv
            # out of model selection -- confirm intent.
            continue_training = val_accuracy_meter.check_min_value_in_last_elements_of_queue(early_stopping_length)
            save_model_in_fold_flag = val_accuracy_meter.update_fold_on_min_flag()
            if save_model_in_fold_flag:
                # Snapshot, so later epochs don't mutate the saved copy.
                best_model_per_fold[fold] = copy.deepcopy(nn_model)

            if not continue_training:
                break

    # NOTE(review): this pickles whole model objects, which ties the file to
    # these class definitions; only unpickle it from a trusted source.
    with open('saved_cifar_model_dictionary.pkl', 'wb') as f:
        pickle.dump(best_model_per_fold, f)

Current Fold 0 and Epoch 0: 


Train: 100%|███████████████████████████████████████████████████████████████████| 391/391 [00:15<00:00, 25.74it/s]


Train Accuracy for epoch : 0.12476
Train Loss for epoch : 0.03597745267868042


Test: 100%|████████████████████████████████████████████████████████████████████| 391/391 [00:13<00:00, 28.35it/s]


Test Accuracy for epoch : 0.14832
Test Loss for epoch : 0.035843337697982786


Validation: 100%|██████████████████████████████████████████████████████████████| 157/157 [00:07<00:00, 20.00it/s]


Validation Accuracy for epoch : 0.1452
Validation Loss for epoch : 0.0359808274269104
Current Fold 0 and Epoch 1: 


Train: 100%|███████████████████████████████████████████████████████████████████| 391/391 [00:13<00:00, 29.32it/s]


Train Accuracy for epoch : 0.1765319217196022
Train Loss for epoch : 0.03476908751355558


Test: 100%|████████████████████████████████████████████████████████████████████| 391/391 [00:15<00:00, 25.81it/s]


Test Accuracy for epoch : 0.24390439525184474
Test Loss for epoch : 0.03248719846184987


Validation: 100%|██████████████████████████████████████████████████████████████| 157/157 [00:08<00:00, 18.68it/s]


Validation Accuracy for epoch : 0.25010064412238325
Validation Loss for epoch : 0.03256629137457281
Current Fold 0 and Epoch 2: 


Train: 100%|███████████████████████████████████████████████████████████████████| 391/391 [00:13<00:00, 28.44it/s]


Train Accuracy for epoch : 0.27422200834135385
Train Loss for epoch : 0.030996357377309006


Test: 100%|████████████████████████████████████████████████████████████████████| 391/391 [00:13<00:00, 28.24it/s]


Test Accuracy for epoch : 0.31011389156239977
Test Loss for epoch : 0.029631766603281377


Validation: 100%|██████████████████████████████████████████████████████████████| 157/157 [00:07<00:00, 19.70it/s]


Validation Accuracy for epoch : 0.31199677938808373
Validation Loss for epoch : 0.029697306358199956
Current Fold 0 and Epoch 3: 


Train: 100%|███████████████████████████████████████████████████████████████████| 391/391 [00:14<00:00, 26.79it/s]


Train Accuracy for epoch : 0.33561918511389155
Train Loss for epoch : 0.028718269608477426


Test: 100%|████████████████████████████████████████████████████████████████████| 391/391 [00:14<00:00, 26.39it/s]


Test Accuracy for epoch : 0.3683830606352262
Test Loss for epoch : 0.027451809539908126


Validation: 100%|██████████████████████████████████████████████████████████████| 157/157 [00:08<00:00, 18.70it/s]


Validation Accuracy for epoch : 0.3698671497584541
Validation Loss for epoch : 0.02751926458688174
Current Fold 0 and Epoch 4: 


Train: 100%|███████████████████████████████████████████████████████████████████| 391/391 [00:14<00:00, 27.11it/s]


Train Accuracy for epoch : 0.37455887070901506
Train Loss for epoch : 0.026954610855639802


Test: 100%|████████████████████████████████████████████████████████████████████| 391/391 [00:14<00:00, 26.98it/s]


Test Accuracy for epoch : 0.38542669233237087
Test Loss for epoch : 0.026220375838706535


Validation: 100%|██████████████████████████████████████████████████████████████| 157/157 [00:08<00:00, 17.67it/s]


Validation Accuracy for epoch : 0.39049919484702095
Validation Loss for epoch : 0.026237906741155326
Current Fold 1 and Epoch 0: 


Train: 100%|███████████████████████████████████████████████████████████████████| 391/391 [00:15<00:00, 25.89it/s]


Train Accuracy for epoch : 0.1176
Train Loss for epoch : 0.035794912481307985


Test: 100%|████████████████████████████████████████████████████████████████████| 391/391 [00:15<00:00, 26.01it/s]


Test Accuracy for epoch : 0.1424
Test Loss for epoch : 0.0351770959854126


Validation: 100%|██████████████████████████████████████████████████████████████| 157/157 [00:08<00:00, 18.73it/s]


Validation Accuracy for epoch : 0.1405
Validation Loss for epoch : 0.03527008845806122
Current Fold 1 and Epoch 1: 


Train: 100%|███████████████████████████████████████████████████████████████████| 391/391 [00:14<00:00, 26.64it/s]


Train Accuracy for epoch : 0.22152710940006418
Train Loss for epoch : 0.033735850065191246


Test: 100%|████████████████████████████████████████████████████████████████████| 391/391 [00:14<00:00, 26.42it/s]


Test Accuracy for epoch : 0.2522056464549246
Test Loss for epoch : 0.03207749380279819


Validation: 100%|██████████████████████████████████████████████████████████████| 157/157 [00:08<00:00, 17.46it/s]


Validation Accuracy for epoch : 0.25754830917874394
Validation Loss for epoch : 0.03202401714672405
Current Fold 1 and Epoch 2: 


Train: 100%|███████████████████████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.76it/s]


Train Accuracy for epoch : 0.2735803657362849
Train Loss for epoch : 0.031199433002710573


Test: 100%|████████████████████████████████████████████████████████████████████| 391/391 [00:14<00:00, 26.49it/s]


Test Accuracy for epoch : 0.2932707731793391
Test Loss for epoch : 0.03037609210620907


Validation: 100%|██████████████████████████████████████████████████████████████| 157/157 [00:08<00:00, 18.63it/s]


Validation Accuracy for epoch : 0.3035426731078905
Validation Loss for epoch : 0.030286079119656972
Current Fold 1 and Epoch 3: 


Train: 100%|███████████████████████████████████████████████████████████████████| 391/391 [00:14<00:00, 27.33it/s]


Train Accuracy for epoch : 0.3190968880333654
Train Loss for epoch : 0.029326382817692738


Test: 100%|████████████████████████████████████████████████████████████████████| 391/391 [00:15<00:00, 25.88it/s]


Test Accuracy for epoch : 0.345283926852743
Test Loss for epoch : 0.028418784541008293


Validation: 100%|██████████████████████████████████████████████████████████████| 157/157 [00:09<00:00, 16.36it/s]


Validation Accuracy for epoch : 0.34993961352657005
Validation Loss for epoch : 0.028259372041709183
Current Fold 1 and Epoch 4: 


Train: 100%|███████████████████████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.47it/s]


Train Accuracy for epoch : 0.3685434712864934
Train Loss for epoch : 0.027288519924062546


Test: 100%|████████████████████████████████████████████████████████████████████| 391/391 [00:16<00:00, 23.50it/s]


Test Accuracy for epoch : 0.38494546037856914
Test Loss for epoch : 0.026681615281724608


Validation: 100%|██████████████████████████████████████████████████████████████| 157/157 [00:08<00:00, 18.05it/s]

Validation Accuracy for epoch : 0.39311594202898553
Validation Loss for epoch : 0.026478697472340437



