### In this image classification tutorial, we use the PyTorch framework. The tutorial is carried out in three steps:

In [49]:
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models

from load_data import Data_Loader, show_transformed_images

In [50]:
# Build the training and test data loaders with the project-local `load_data.Data_Loader` helper.
# NOTE(review): batch size, transforms and dataset location are decided inside Data_Loader — confirm there.
train_loader, test_loader= Data_Loader()

### 1. Setting the device (GPU when available, otherwise CPU):

In [51]:
def set_device(gpu_index= 6): 
    """Pick the torch device used for training and evaluation.

    Args:
        gpu_index: ordinal of the preferred CUDA device. Defaults to 6, the
            ordinal this notebook was originally written against.

    Returns:
        torch.device: ``cuda:<gpu_index>`` when CUDA is available and that
        ordinal exists, ``cuda:0`` when CUDA is available but the requested
        ordinal does not exist, and ``cpu`` when CUDA is not available.
    """
    if not torch.cuda.is_available(): 
        return torch.device("cpu")

    # A fixed ordinal only exists on hosts with enough GPUs. Building the
    # device object for a missing ordinal succeeds, but the first `.to(device)`
    # would then fail, so fall back to the first GPU instead.
    if not 0 <= gpu_index < torch.cuda.device_count(): 
        gpu_index= 0

    return torch.device("cuda:%d" % gpu_index)

device= set_device()

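### The model we are going to use in this case is ResNet-18. torchvision also offers it pre-trained on an image dataset (ImageNet), but the cell below passes `pretrained=False`, so the network starts from random weights and is trained from scratch on our data.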

In [52]:
# ResNet-18 backbone; pretrained=False means the weights start from random initialisation.
resnet18_model= models.resnet18(pretrained= False)

# Replace the stock classification head with one sized for our label set.
num_features= resnet18_model.fc.in_features
num_classes= 10                                             # 10 monkey species
resnet18_model.fc= nn.Linear(in_features= num_features, out_features= num_classes)

# Move the network to the selected device before the optimizer is built from its parameters.
resnet18_model= resnet18_model.to(device)

loss_fn= nn.CrossEntropyLoss()
optimizer= optim.SGD(
    resnet18_model.parameters(),
    lr= 0.01,
    momentum= 0.9,
    weight_decay= 0.003,
)



In [53]:
def train_nn(model, train_loader, test_loader, loss, optimizer, n_epochs, patience): 
    """Train `model` with early stopping driven by accuracy on `test_loader`.

    After every epoch the model is scored with `evaluate_model`. Each time the
    score beats the best one seen so far, a checkpoint is written through
    `save_checkpoint` and the patience counter is reset. Training stops once
    `patience` consecutive epochs have passed without a new best score.

    Args:
        model: network to optimise; it is updated in place and is expected to
            already live on the device returned by `set_device()`.
        train_loader: iterable of (images, labels) batches used for the updates.
        test_loader: iterable of (images, labels) batches handed to `evaluate_model`.
        loss: callable mapping (outputs, labels) to a scalar loss tensor.
        optimizer: optimizer already bound to the model's parameters.
        n_epochs: maximum number of passes over `train_loader`.
        patience: number of consecutive non-improving epochs tolerated.

    Returns:
        The same `model` object, in the state reached by the last epoch that
        ran. The best-scoring weights are the ones stored by `save_checkpoint`.
    """
    device= set_device()

    # Both trackers must live across epochs: resetting the best score every
    # epoch would make each epoch look like an improvement and disable early stopping.
    best_acc= 0
    counter= 0
    
    for epoch in range(n_epochs): 
        print("Epoch number %d" % (epoch+1))
        model.train()               # evaluate_model() leaves the model in eval mode
        running_loss= 0.
        running_correct= 0
        total= 0 
        
        for images, labels in train_loader: 
            images= images.to(device)
            labels= labels.to(device)
            total += labels.size(0)
            
            optimizer.zero_grad()       # Setting the gradients to zero before back propagation 
            
            outputs= model(images)
            
            # Predictions are only used for the accuracy count, so take them
            # from a detached view that stays out of the autograd graph.
            _, predicted= torch.max(outputs.detach(), 1)
            
            losses= loss(outputs, labels)
            losses.backward()
            
            optimizer.step()
            
            running_loss += losses.item()
            running_correct += (labels==predicted).sum().item()
        
        epoch_loss= running_loss/len(train_loader)      # mean of the per-batch losses
        epoch_accuracy= 100.0 * running_correct/total
        
        print("-Training dataset. Got %d out of %d images correctly. acc: %.3f%%. Epoch loss: %.3f" 
              % (running_correct, total, epoch_accuracy,epoch_loss))
        
        validation_accuracy= evaluate_model(model, test_loader)
        
        if validation_accuracy > best_acc: 
            best_acc= validation_accuracy
            counter= 0              # patience counts consecutive epochs without a new best
            save_checkpoint(model, epoch, optimizer, best_acc)
        else: 
            counter += 1
            
        if counter >= patience: 
            # Report the 1-based epoch number, matching the "Epoch number" banner above.
            print(f"Early stopping at epoch {epoch + 1}")
            break
    
    print("Finished")
    return model

### Defining a function for evaluating the performance of the model on the validation set

In [54]:
def evaluate_model(model, test_loader): 
    """Score `model` on `test_loader` and return its accuracy as a percentage.

    The model is switched to eval mode and is left in that mode on return.
    Every batch is moved to the device returned by `set_device()` and the
    forward passes run with gradient tracking disabled.

    Args:
        model: network to evaluate.
        test_loader: iterable yielding (images, labels) batches.

    Returns:
        float: 100 * correct / total over all batches; the same figure is
        also printed together with the raw counts.
    """
    model.eval()
    target_device= set_device()
    n_seen= 0
    n_hits= 0

    with torch.no_grad(): 
        for batch_images, batch_labels in test_loader: 
            batch_images= batch_images.to(target_device)
            batch_labels= batch_labels.to(target_device)

            logits= model(batch_images)
            # Index 1 of torch.max's result holds the arg-max along dim 1, i.e. the predicted class.
            top_class= torch.max(logits.data, 1)[1]

            n_seen += batch_labels.size(0)
            n_hits += (top_class==batch_labels).sum().item()

    epoch_acc= 100.0 * n_hits/n_seen

    print("-Testing dataset. Got %d out of %d images correctly. acc: %.3f%%" % (n_hits, n_seen, epoch_acc))

    return epoch_acc

### Defining a function for saving the checkpoint: 

In [55]:
def save_checkpoint(model, epoch, optimizer, best_acc, path= "checkpoints/model_best_checkpoint.pth.tar"): 
    """Write a training checkpoint to disk with `torch.save`.

    Args:
        model: network whose `state_dict()` is stored under the "model" key.
        epoch: zero-based index of the epoch that just finished; it is stored
            as `epoch + 1` under the "epoch" key.
        optimizer: optimizer whose `state_dict()` is stored under "optimizer".
        best_acc: accuracy value stored under "best_accuracy".
        path: destination file. Defaults to the location this notebook reads
            the checkpoint back from.
    """
    state= {
        "epoch": epoch + 1, 
        "model": model.state_dict(),
        "best_accuracy": best_acc,
        "optimizer": optimizer.state_dict(),
        "comments": "Monkey classification model" 
    }
    
    # torch.save does not create missing folders, so make sure the parent
    # directory exists (a bare file name has no parent to create).
    folder= os.path.dirname(path)
    if folder: 
        os.makedirs(folder, exist_ok= True)
    
    torch.save(state, path)

### Training the model: 

In [56]:
# Launch training: at most 100 epochs, with patience=10 as the early-stopping budget.
# The call is the last expression of the cell, so the returned model is displayed below the log.
train_nn(
    model= resnet18_model,
    train_loader= train_loader,
    test_loader= test_loader,
    loss= loss_fn,
    optimizer= optimizer,
    n_epochs= 100,
    patience= 10,
)

Epoch number 1
-Training dataset. Got 275 out of 1097 images correctly. acc: 25.068%. Epoch loss: 2.168
-Testing dataset. Got 38 out of 272 images correctly. acc: 13.971%
Epoch number 2
-Training dataset. Got 389 out of 1097 images correctly. acc: 35.460%. Epoch loss: 1.848
-Testing dataset. Got 80 out of 272 images correctly. acc: 29.412%
Epoch number 3
-Training dataset. Got 476 out of 1097 images correctly. acc: 43.391%. Epoch loss: 1.613
-Testing dataset. Got 113 out of 272 images correctly. acc: 41.544%
Epoch number 4
-Training dataset. Got 525 out of 1097 images correctly. acc: 47.858%. Epoch loss: 1.514
-Testing dataset. Got 108 out of 272 images correctly. acc: 39.706%
Epoch number 5
-Training dataset. Got 528 out of 1097 images correctly. acc: 48.131%. Epoch loss: 1.461
-Testing dataset. Got 136 out of 272 images correctly. acc: 50.000%
Epoch number 6
-Training dataset. Got 643 out of 1097 images correctly. acc: 58.614%. Epoch loss: 1.182
-Testing dataset. Got 141 out of 272 i

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [59]:
# Restore the best weights written during training.
# map_location remaps the stored tensors onto the current `device`, so the
# checkpoint still opens on a host that lacks the GPU it was saved from.
checkpoint= torch.load("checkpoints/model_best_checkpoint.pth.tar", map_location= device)
resnet18_model.load_state_dict(checkpoint["model"])

# torch.save does not create missing folders.
os.makedirs("saved_model", exist_ok= True)

# NOTE: saving the module object itself pickles it, so loading this file later
# requires the same model class definitions to be importable.
torch.save(resnet18_model, "saved_model/save_model.pth")