<a href="https://colab.research.google.com/github/PiehTVH/What-is-this-Fruit/blob/main/Train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Remember to upload your Kaggle API token (`kaggle.json`) so the notebook can access your Kaggle account.

In [None]:
# !pip install -q kaggle
!mkdir ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json
!kaggle datasets download -d 'moltean/fruits'

In [None]:
!unzip /content/fruits.zip

In [None]:
# NOTE(review): timm is not imported by any of the cells shown here; presumably
# kept for experimenting with other backbones - confirm before removing.
!pip install timm

In [None]:
# geffnet provides create_model(), used below to build EfficientNet-B0.
!pip install geffnet

In [None]:
# Provides EfficientNet; it is imported below but not otherwise used by the
# cells shown here.
!pip install efficientnet_pytorch

In [None]:
import os
import re
import PIL
import sys
import json
import time
import math
import copy
import torch
import pickle
import geffnet
import logging
import fnmatch
import argparse
import torchvision
import numpy as np
%matplotlib inline
import pandas as pd
import seaborn as sns
import torch.nn as nn
from PIL import Image
from pathlib import Path
from copy import deepcopy
from sklearn import metrics
import torch.optim as optim
from datetime import datetime
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torch.utils.data as data
from geffnet import create_model
from torch.autograd import Variable
from tqdm import tqdm, tqdm_notebook
from torch.optim import lr_scheduler
from efficientnet_pytorch import EfficientNet
from torchvision import transforms, models, datasets
from torch.utils.data.sampler import SubsetRandomSampler
# Default to the CPU and switch to the first GPU whenever CUDA is usable.
device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda:0")

In [None]:
data_dir = '/content/fruits-360_dataset/fruits-360'
train_dir = os.path.join(data_dir, 'Training')
valid_dir = os.path.join(data_dir, 'Test')

# The two ImageFolder sub-directories; also the keys of every dict below.
PHASES = ['Training', 'Test']

# Standard ImageNet per-channel statistics used by Normalize.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

data_transforms = {
    # Random rotation / crop / flip augment the training images.
    'Training': transforms.Compose([
        transforms.RandomRotation(30),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]),
    # Deterministic resize + center crop to the same 224x224 input size.
    'Test': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]),
}

# Load the datasets with ImageFolder
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
                  for x in PHASES}

# ImageFolder derives label indices from each split's own sub-folder names, so
# Training and Test must contain exactly the same classes; otherwise the Test
# labels would silently disagree with the indices the model is trained on.
train_classes = image_datasets['Training'].classes
test_classes = image_datasets['Test'].classes
if train_classes != test_classes:
    mismatched = sorted(set(train_classes) ^ set(test_classes))
    raise ValueError(f'Training and Test class folders differ: {mismatched}')

batch_size = 64
data_loader = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size,
                                             shuffle=True, num_workers=4, pin_memory=True)
               for x in PHASES}

dataset_sizes = {x: len(image_datasets[x]) for x in PHASES}

# Class names in label-index order, taken from the split the model learns from.
class_names = train_classes

In [None]:
#Save file labels
# Write labels.txt: one class name per line, in the order of class_names.
with open('labels.txt', 'w') as f:
    f.writelines(f'{name}\n' for name in class_names)

In [None]:
# Invert ImageFolder's class-name -> index mapping so predicted indices can be
# turned back into class names. A named variable is used instead of `_`, which
# IPython reserves for the last cell output.
training_class_to_idx = image_datasets['Training'].class_to_idx
cat_to_name = {idx: name for name, idx in training_class_to_idx.items()}
# Maps the string form of each index to the int index ("0" -> 0, "1" -> 1, ...).
# NOTE(review): despite its name this is not a class-name lookup; confirm what
# the consumer of class_to_idx.json expects.
class_to_idx = {str(i): i for i in range(len(class_names))}

# Run this to test the data loader: fetch one batch and show its tensor size.
images, labels = next(iter(data_loader['Test']))
images.size()

In [None]:
#Save file json: cat_to_name
# Persist both lookup tables as JSON files. json.dump writes the integer keys
# of cat_to_name as strings.
for json_path, payload in (("cat_to_name.json", cat_to_name),
                           ("class_to_idx.json", class_to_idx)):
    with open(json_path, "w") as outfile:
        json.dump(payload, outfile)

In [None]:
import warnings

# Suppress every warning category for the rest of the session; note this also
# hides deprecation notices from torch/torchvision.
warnings.simplefilter('ignore')

In [None]:
def showimage(data_loader, number_images, cat_to_name,
              mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Plot the first ``number_images`` images of one batch with their class names.

    Args:
        data_loader: iterable yielding ``(images, labels)`` batches, where
            ``images`` is an (N, C, H, W) tensor normalized with ``mean``/``std``.
        number_images: how many images of the batch to plot; must not exceed
            the batch size.
        cat_to_name: mapping from integer label to class name.
        mean, std: per-channel normalization to undo before display; the
            defaults equal the values passed to ``transforms.Normalize``.

    Raises:
        ValueError: if ``number_images`` exceeds the size of the batch.
    """
    images, labels = next(iter(data_loader))
    images = images.numpy()  # convert images to numpy for display
    labels = labels.tolist()
    if number_images > len(images):
        raise ValueError(f'requested {number_images} images but the batch has {len(images)}')

    # Shaped (1, 1, C) so they broadcast over an (H, W, C) image.
    mean = np.asarray(mean).reshape(1, 1, -1)
    std = np.asarray(std).reshape(1, 1, -1)

    # Two rows (one if only a single image is requested) and enough columns
    # for every image, so odd counts also fit.
    n_rows = 2 if number_images > 1 else 1
    n_cols = math.ceil(number_images / n_rows)
    fig = plt.figure(figsize=(number_images, 4))
    for idx in range(number_images):
        ax = fig.add_subplot(n_rows, n_cols, idx + 1, xticks=[], yticks=[])
        # (C, H, W) -> (H, W, C). A bare np.transpose would reverse all axes to
        # (W, H, C) and display the image transposed.
        img = images[idx].transpose(1, 2, 0)
        # Undo Normalize so colours look right, then clip into imshow's [0, 1] range.
        img = np.clip(img * std + mean, 0, 1)
        ax.imshow(img)
        ax.set_title(cat_to_name[labels[idx]])
        

# Preview a couple of test images together with their class names.
showimage(data_loader=data_loader['Test'], number_images=2, cat_to_name=cat_to_name)

In [None]:
# EfficientNet-B0 with pretrained weights from geffnet.
model = create_model('efficientnet_b0', pretrained=True)

# Fine-tune the whole network: make every backbone parameter trainable.
for param in model.parameters():
    param.requires_grad = True

# Replace the classification head with one output per training class. The
# count is taken from the dataset instead of being hard-coded, so the head
# always matches the labels ImageFolder produces.
n_classes = len(image_datasets['Training'].classes)
model.classifier = nn.Linear(model.classifier.in_features, n_classes)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(),
                      lr=0.001,
                      momentum=0.9,
                      nesterov=True,
                      weight_decay=0.0001)
# Multiply the learning rate by 0.1 every 10 calls to scheduler.step().
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

In [None]:
# Attach the training label mapping (and its inverse) to the model so that
# predictions can be decoded without the dataset object.
model.class_to_idx = image_datasets['Training'].class_to_idx
model.idx_to_class = dict(zip(model.class_to_idx.values(), model.class_to_idx.keys()))
# Display the (class name, index) pairs as the cell output.
list(model.class_to_idx.items())

In [None]:
# Move the network's parameters and buffers onto the selected device
# (Module.to returns the same module, so rebinding is a no-op).
model = model.to(device)

def _run_phase(model, criterion, optimizer, phase, epoch):
    """Run one pass over ``data_loader[phase]`` and return ``(mean_loss, accuracy)``.

    Gradients are computed and ``optimizer`` is stepped only when ``phase`` is
    'Training'. Reads the notebook globals ``data_loader``, ``dataset_sizes``
    and ``device``. Both returned values are Python floats.
    """
    is_training = phase == 'Training'
    model.train(is_training)  # train(False) is what model.eval() does

    running_loss = 0.0
    running_correct = 0
    samples_seen = 0

    for i, (inputs, labels) in enumerate(data_loader[phase]):
        inputs = inputs.to(device)
        labels = labels.to(device)

        if i % 1000 == 999:
            # Mean per-sample loss over the batches processed so far.
            print('[%d, %d] loss: %.8f' % (epoch + 1, i, running_loss / samples_seen))

        if is_training:
            optimizer.zero_grad()
        # Track autograd history only while training.
        with torch.set_grad_enabled(is_training):
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)
            if is_training:
                loss.backward()
                optimizer.step()

        n = inputs.size(0)
        running_loss += loss.item() * n
        # .item() keeps the count a plain int, so accuracy is never a device tensor.
        running_correct += (preds == labels).sum().item()
        samples_seen += n

    epoch_loss = running_loss / dataset_sizes[phase]
    epoch_acc = running_correct / dataset_sizes[phase]
    return epoch_loss, epoch_acc


def _save_checkpoint(path, model, optimizer, scheduler, best_loss, best_acc):
    """Save model/optimizer/scheduler state and the best validation scores to ``path``."""
    torch.save({'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_val_loss': best_loss,
                'best_val_accuracy': best_acc,
                'scheduler_state_dict': scheduler.state_dict()},
               path)


def train_model(model, criterion, optimizer, scheduler, num_epochs=200, checkpoint=None,
                checkpoint_path=None):
    """Train ``model`` and keep the weights with the lowest 'Test' loss.

    Each epoch runs a 'Training' pass (gradient updates, then one
    ``scheduler.step()``) followed by a 'Test' pass without gradients. Whenever
    the Test loss improves, a checkpoint is written.

    Args:
        model: network to train, already on ``device``.
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: learning-rate scheduler, stepped once per epoch.
        num_epochs: number of epochs to run.
        checkpoint: optional dict previously written by this function; when
            given, model/optimizer/scheduler state and best scores are restored.
        checkpoint_path: file for improved checkpoints; when None, the notebook
            global ``CHECK_POINT_PATH`` is used (looked up only when saving).

    Returns:
        ``(model, best_loss, best_acc)`` with the best weights loaded into
        ``model``; the scores are Python floats.
    """
    since = time.time()

    if checkpoint is None:
        best_loss = math.inf
        best_acc = 0.0
    else:
        print(f'Val loss: {checkpoint["best_val_loss"]}, Val accuracy: {checkpoint["best_val_accuracy"]}')
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        # float() also accepts the 0-d tensors found in checkpoints written by
        # earlier code, which stored the accuracy as a device tensor.
        best_loss = float(checkpoint['best_val_loss'])
        best_acc = float(checkpoint['best_val_accuracy'])
    best_model_wts = copy.deepcopy(model.state_dict())

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and a validation phase.
        for phase in ['Training', 'Test']:
            epoch_loss, epoch_acc = _run_phase(model, criterion, optimizer, phase, epoch)
            if phase == 'Training':
                scheduler.step()

            print('{} Loss: {:.8f} Acc: {:.8f}'.format(phase, epoch_loss, epoch_acc))

            # Keep (and persist) the weights with the best validation loss.
            if phase == 'Test' and epoch_loss < best_loss:
                print('New best model found!')
                print(f'New record loss: {epoch_loss}, previous record loss: {best_loss}')
                best_loss = epoch_loss
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                path = checkpoint_path if checkpoint_path is not None else CHECK_POINT_PATH
                _save_checkpoint(path, model, optimizer, scheduler, best_loss, best_acc)
                print(f'New record loss is SAVED: {epoch_loss}')
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:.8f} Best val loss: {:.8f}'.format(best_acc, best_loss))

    # Load best model weights
    model.load_state_dict(best_model_wts)
    return model, best_loss, best_acc

In [None]:
CHECKPOINT_PATH = '/content/EfficientNet_B0_SGD.pth'
# Always define CHECK_POINT_PATH: train_model writes improved checkpoints there,
# so leaving it undefined when a checkpoint already exists raises NameError
# in the middle of training.
CHECK_POINT_PATH = CHECKPOINT_PATH

# Set to True to continue from the saved checkpoint instead of training from scratch.
RESUME_TRAINING = False

try:
    # map_location lets a checkpoint saved on a GPU load on a CPU-only runtime.
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
    print('checkpoint loaded')
except FileNotFoundError:
    checkpoint = None
    print('checkpoint not found')
except Exception as err:  # e.g. a truncated or incompatible file
    checkpoint = None
    print(f'checkpoint could not be loaded: {err}')

model, best_val_loss, best_val_acc = train_model(model,
                                                 criterion,
                                                 optimizer,
                                                 scheduler,
                                                 num_epochs=100,
                                                 checkpoint=checkpoint if RESUME_TRAINING else None)

# Final save: the best weights (restored by train_model) together with the
# optimizer/scheduler state from the last epoch. This overwrites the
# checkpoint written during training.
torch.save({'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'best_val_loss': best_val_loss,
            'best_val_accuracy': float(best_val_acc),
            'scheduler_state_dict': scheduler.state_dict(),
            }, CHECK_POINT_PATH)

#End