<a href="https://colab.research.google.com/github/kode-git/FER-Visual-Transformers/blob/main/Training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Training

This notebook is used to train transformers and deep neural networks.

## Install Dependencies and Import Libraries

In [None]:
!pip install timm
!pip install fvcore
!git clone https://github.com/davda54/sam.git

In [None]:
# classic libraries for collections.
import pandas as pd
import numpy as np

# utility library.
import random, time, copy

# plot libraries.
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns

# libraries for image processing.
import os, cv2, glob, imageio, sys
from PIL import Image
# warning library for service warnings.
import warnings

# machine learning libraries .
import timm, torch, torchvision
from torchsummary import summary

# image dataset loading and transformations.
from torchvision import datasets, models, transforms

# utility functions for specific uses.
from __future__ import print_function
from __future__ import division

# optimizer libraries.
from torch.optim import lr_scheduler
import torch.optim as optim
from sam.sam import SAM

# library for basic building blocks.
import torch.nn as nn

# library for saving and loading checkpoints.
import pickle

# libraries for metrics and evaluation phase.
from sklearn import metrics
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, confusion_matrix, ConfusionMatrixDisplay

# libraries for flop analysis.
from fvcore.nn import FlopCountAnalysis, flop_count_table, flop_count_str


In [None]:
import os
import glob

In [None]:
# Report the installed library versions for reproducibility.
print(f"PyTorch Version:  {torch.__version__}")
print(f"Torchvision Version:  {torchvision.__version__}")

## GPU Configuration

Transformers are trained using Google Colab Pro GPU: NVIDIA P100.

In [None]:
!nvidia-smi

In [None]:
# Detect if we have a GPU available.
# GPU index 3 is preferred, but it only exists on hosts with at least four
# GPUs; on smaller hosts (e.g. a single-GPU Colab runtime) fall back to the
# first GPU instead of failing later with an invalid device ordinal.
device = torch.device(
    ("cuda:3" if torch.cuda.device_count() > 3 else "cuda:0")
    if torch.cuda.is_available()
    else "cpu"
)
print(device)

## Common utilities

In [None]:
def mkdir_model(base_dir, name_model, counter):
  """
  Make a directory for the model dump.

  Tries ``base_dir/name_model`` first. If that directory already exists,
  the counter is incremented and ``base_dir/name_model_<counter>`` is tried,
  until a free name is found. The suffix is always appended to the original
  ``name_model`` (``m``, ``m_1``, ``m_2``, ...), never to an already
  suffixed name.

  Args:
    base_dir: parent directory; it must already exist.
    name_model: base name of the directory to create.
    counter: starting value of the numeric suffix (callers pass 0).

  Returns:
    The path of the directory that was actually created.
  """
  candidate = str(name_model)
  while True:
    d = "{}/{}".format(base_dir, candidate)
    try:
      os.mkdir(d)
      return d
    except FileExistsError:
      # Name taken: try the next numeric suffix on the original name.
      counter += 1
      candidate = "{}_{}".format(name_model, counter)

def save_history(history, filename):
  """
  Save the history in the file ``filename + ".pkl"`` using pickle.

  An existing ``.pkl`` file is overwritten (the file is opened in "wb"
  mode, which truncates it). A file named exactly ``filename`` (without the
  extension) is a different file and is left untouched.

  Args:
    history: picklable object to store.
    filename: destination path without the ".pkl" extension.
  """
  # The context manager closes the handle even if pickling fails.
  with open(filename + ".pkl", "wb") as file_handler:
    pickle.dump(history, file_handler)


def load_history(filename):
  """
  Load the history from the file ``filename + ".pkl"``.

  Args:
    filename: source path without the ".pkl" extension.

  Returns:
    The object that was pickled into the file.
  """
  # NOTE(review): pickle can execute arbitrary code while loading; only use
  # this on history files produced by this notebook.
  # NOTE(review): the original author left an open question here about
  # whether weight files are recognized by this loader - confirm.
  with open(filename + ".pkl", "rb") as file_handler:
    return pickle.load(file_handler)


def train_model(model, dataloaders, criterion, optimizer,lr_scheduler, num_epochs=25, is_inception=False, 
                is_loaded = False, load_state_ws=None, history_file_acc="history_accuracy",
                history_file_loss="history_loss", n_partial=0, model_folder="", best_acc=0.0 ):
    """
    PyTorch training loop with loading support and dump management.

    Trains the model for ``num_epochs`` epochs; each epoch runs a 'train_2'
    phase followed by a 'val_2' phase. The weights with the best validation
    accuracy seen so far are remembered and loaded back into the model before
    returning. After every epoch the loss history is pickled and the current
    weights are saved as a checkpoint.

    Relies on the module-level names ``device``, ``SAM`` and ``save_history``.

    Args:
        model: network to train.
        dataloaders: mapping with the keys 'train_2' and 'val_2'; each value
            yields ``(inputs, labels)`` batches and exposes ``batch_size``.
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimizer; a ``SAM`` instance is stepped with a closure.
        lr_scheduler: scheduler stepped once per epoch, or a falsy value to
            disable scheduling.
        num_epochs: number of epochs to run.
        is_inception: unused, kept for interface compatibility.
        is_loaded: when True and ``load_state_ws`` is given, load that
            state dict into the model before training.
        load_state_ws: path of the state dict to load.
        history_file_acc: unused, kept for interface compatibility (the
            accuracy histories are returned to the caller instead).
        history_file_loss: name suffix of the pickled loss history.
        n_partial: unused, kept for interface compatibility.
        model_folder: path prefix for the checkpoints and the loss history.
        best_acc: validation accuracy that must be exceeded before the
            weights are recorded as the best ones.

    Returns:
        ``(model, train_accuracies, val_accuracies, best_acc)`` where the
        model carries the best weights found.
    """

    history = {'val_2' : [], 'train_2' : []}
    loss_history = {'val_2' : [], 'train_2' : []}

    if is_loaded and load_state_ws is not None:
        # map_location lets a checkpoint saved on another device (e.g. a GPU
        # index that does not exist on this host) be loaded on `device`.
        state_dict = torch.load(load_state_ws, map_location=device)
        model.load_state_dict(state_dict)
        model.eval()
        print('Model loaded correctly')

    print('Starting Training')
    print('-' * 12)

    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    for epoch in range(num_epochs):
        epoch_since = time.time()
        print(f'Epoch {epoch + 1}/{num_epochs}')
        print('-' * 12)
        # Set to True when this epoch's validation accuracy beats the best so far.
        local_optima = False
        # Each epoch has a training and validation phase.
        for phase in ['train_2', 'val_2']:
            is_training = phase == 'train_2'
            dl = dataloaders[phase]
            total = len(dl)
            # Batch size of this loader; only used for the progress display.
            step_samples = dl.batch_size
            current = 0
            if is_training:
                model.train()  # Set model to training mode.
            else:
                model.eval()   # Set model to evaluate mode.

            running_loss = 0.0
            running_corrects = 0
            totalIm = 0
            # Iterate over data.
            for inputs, labels in dl:
                totalIm += len(inputs)
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients.
                optimizer.zero_grad()

                # forward; gradients are tracked only in the training phase.
                with torch.set_grad_enabled(is_training):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    _, preds = torch.max(outputs, 1)

                    # backward + optimize only if in training phase.
                    if is_training:
                        loss.backward()
                        if isinstance(optimizer, SAM):
                            # SAM re-evaluates the loss through the closure
                            # during its step.
                            def closure():
                                closure_loss = criterion(model(inputs), labels)
                                closure_loss.backward()
                                return closure_loss

                            optimizer.step(closure)
                        else:
                            optimizer.step()

                # statistics (running averages over the samples seen so far).
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
                epoch_loss = running_loss / totalIm
                epoch_acc = running_corrects.double() / totalIm
                # status update.
                current += 1
                sys.stdout.write(
                    f"\r{epoch + 1}/{num_epochs} - {phase} step : "
                    f"{current * step_samples}/{total * step_samples} - "
                    f"{phase}_accuracy : {epoch_acc:4f} - {phase}_loss : {epoch_loss:4f}"
                )
                sys.stdout.flush()
            epoch_loss = running_loss / totalIm
            epoch_acc = running_corrects.double() / totalIm
            print() # avoid result cleaning .

            history[phase].append(epoch_acc)
            loss_history[phase].append(epoch_loss)

            # deep copy the model only when the validation accuracy improves
            # (local optima).
            if phase == 'val_2' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                local_optima = True

        # Increases the scheduler's internal counter.
        if lr_scheduler:
            lr_scheduler.step()
        lr = optimizer.param_groups[0]['lr']
        interval_epoch = time.time() - epoch_since
        print('\nEpoch {} complete in. {:.0f}m {:.0f}s {} and with a learning rate of {}'.format(epoch + 1, interval_epoch // 60, interval_epoch % 60, "with best local accuracy" if local_optima else "",lr))
        save_history(loss_history, model_folder + os.path.basename(model_folder) + "_" + history_file_loss)

        # Per-epoch checkpoint; the slice drops the last character of
        # model_folder before taking its basename.
        torch.save(model.state_dict(), model_folder + "epoch_{}_{}".format(epoch + 1, os.path.basename(model_folder[:len(model_folder) - 1])))
        print("-" * 12)

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val accuracy: {:4f}'.format(best_acc))

    # load best model weights.
    model.load_state_dict(best_model_wts)
    return model, history['train_2'], history['val_2'], best_acc

## Dataset Loading

In [None]:
import os
from PIL import Image

class CustomDataset(torch.utils.data.Dataset):
    """
    Image-folder dataset: one sub-directory of ``root_dir`` per class.

    Only files whose name ends with '.jpg' are indexed. Class indices are
    assigned from the alphabetically sorted class directory names, so two
    datasets built from folders with the same class names (e.g. a training
    and a validation split) get the same label for the same class.

    Attributes:
        images: paths of the indexed images.
        labels: integer class index of each image (parallel to ``images``).
        classes: class name of each image (parallel to ``images``; one entry
            per sample, not one per class).
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.images = []
        self.labels = []
        self.classes = []  # class name of every sample.
        self.load_images()

    def _class_dirs(self):
        """Return the sorted names of the non-hidden sub-directories of root_dir."""
        return sorted(
            entry for entry in os.listdir(self.root_dir)
            if not entry.startswith('.')
            and os.path.isdir(os.path.join(self.root_dir, entry))
        )

    def load_images(self):
        """
        Index every readable '.jpg' image below ``root_dir``.

        Each candidate is opened (and passed through the transform, when one
        is set) once as a validation step; images that raise are reported and
        left out of the dataset.
        """
        # Enumerate directories only, in sorted order: stray files and hidden
        # folders must not consume a class index.
        for label, class_name in enumerate(self._class_dirs()):
            class_dir = os.path.join(self.root_dir, class_name)
            for filename in sorted(os.listdir(class_dir)):
                if not filename.endswith('.jpg'):
                    continue
                img_path = os.path.join(class_dir, filename)
                try:
                    with Image.open(img_path) as img:
                        if self.transform:
                            self.transform(img)
                except Exception:
                    # Deliberately broad: any decoding or transform failure
                    # only excludes this one image.
                    print(f"Skipping problematic image: {img_path}")
                    continue
                self.images.append(img_path)
                self.labels.append(label)  # numeric class label.
                self.classes.append(class_name)  # class name.

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """
        Return ``(image, label)`` for the sample at ``idx``.

        The transform, when set, is applied to the opened image. If the
        image cannot be read or transformed, a message is printed and None
        is returned.
        """
        img_path = self.images[idx]
        label = self.labels[idx]

        try:
            with Image.open(img_path) as img:
                if self.transform:
                    img = self.transform(img)
                return img, label
        except Exception:
            print(f"Skipping problematic image: {img_path}")
            return None


In [None]:
import shutil  # cell-local: needed to delete non-empty checkpoint folders.

# input and batch size specification.
input_size = (224,224)
batch_size = 30

# dataset directory.
data_dir= "../../../../data/image/"

# removing possible .ipynb_checkpoints folders from every split directory.
# These folders are usually not empty, so os.rmdir alone would raise; hidden
# files are left alone.
for fd in glob.glob(os.path.join(data_dir, "*")):
  for cl in glob.glob(os.path.join(fd, ".*")):
    if not os.path.isdir(cl):
      continue
    if os.path.basename(cl) == ".ipynb_checkpoints":
      shutil.rmtree(cl)
    else:
      # Other hidden folders are removed only when empty.
      try:
        os.rmdir(cl)
      except OSError:
        print(f"Could not remove hidden directory: {cl}")

# loading training and validation set: both splits are resized to the model
# input size, converted to tensors and normalized with mean/std 0.5 per channel.
data_transforms = {
    'train_2': transforms.Compose([
        transforms.Resize(input_size),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ]),
    'val_2': transforms.Compose([
        transforms.Resize(input_size),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ]),
}


# Fine-tuning after modifying the model.

print("Initializing Datasets and Dataloaders...")

# Create training and validation datasets.
image_datasets = {x: CustomDataset(os.path.join(data_dir, x), transform=data_transforms[x]) for x in ['train_2', 'val_2']}

# Create training and validation dataloaders.
dataloaders_dict = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True, num_workers=7,pin_memory=True) for x in ['train_2', 'val_2']}



In [None]:
# architecture to fine-tune (timm model identifier).
model_name = 'vit_base_patch16_224'
# model_name = 'resnet18'

# total number of target classes.
NUM_CLASSES = 7

# load the pretrained weights for the chosen architecture.
model = timm.create_model(model_name, pretrained=True)

In [None]:
# FLOP analysis on a random tensor of shape (1, 3, 224, 224).
inputs = torch.randn(1, 3, 224, 224)
model.eval()
print('-' * 40)

# build the analysis once, then print the table, the per-module string and
# the total.
flop = FlopCountAnalysis(model, inputs)
print(flop_count_table(flop, max_depth=4))
print(flop_count_str(flop))
print("Tot. flops:", flop.total())

In [None]:
# adapting the classification head to NUM_CLASSES outputs (fine-tuning).
# The input width is read from the layer being replaced instead of being
# hard-coded, so it always matches the backbone.

if model_name == 'resnet18':
  model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)
else:
  model.head = nn.Linear(model.head.in_features, NUM_CLASSES)

# display modified model.
model.eval()

In [None]:
# Ask which optimizer to use; the prompt names the optimizers that the
# choice actually selects.
optimizer_set = input('Digit 0 for Adam or other values for AdamW: ')
optimizer_set = "Adam" if optimizer_set == "0" else "AdamW"
print('Chosen {} for the model training.'.format(optimizer_set))

In [None]:
# Detect if we have a GPU available.
# GPU index 3 is preferred, but it only exists on hosts with at least four
# GPUs; otherwise fall back to the first GPU (or the CPU).
device = torch.device(
    ("cuda:3" if torch.cuda.device_count() > 3 else "cuda:0")
    if torch.cuda.is_available()
    else "cpu"
)
print(device)
# Send the model to the selected device.
model = model.to(device)
# NOTE(review): this flag is set but nothing in this cell reads it; no
# parameter is frozen here, so every parameter is passed to the optimizer.
feature_extract=True

# Gather the parameters to be optimized/updated in this run.
params_to_update = model.parameters()
print("Params to learn:")

for name,param in model.named_parameters():
    if param.requires_grad:
          print("\t",name)

print('-'*40)
lr_in = 0.001
momentum_in = 0.9
if optimizer_set == "Adam":
  optimizer_ft = optim.Adam(params_to_update, lr=lr_in)
else:
  # "AdamW" was chosen: build AdamW, as announced to the user.
  optimizer_ft = optim.AdamW(params_to_update, lr=lr_in)

print(optimizer_ft)

## Start Training

In [None]:
warnings.filterwarnings('ignore')

# Setup the loss fxn
criterion = nn.CrossEntropyLoss()

# Number of epochs used when the typed value is not an integer. The prompt
# and the fallback message are built from it so they cannot disagree.
default_num_epochs = 500
num_epochs = input(f'Digits the initial number of epochs, invalid values are equals to {default_num_epochs} epochs: ')
try:
  # Keep the converted value: the training loop needs an int, not a string.
  num_epochs = int(num_epochs)
except ValueError:
  print(f'Default number of {default_num_epochs} epochs selected.')
  num_epochs = default_num_epochs

In [None]:
# identity of this run and the folder that collects all model dumps.
name_model = "vfer_small_5"
base_dir = "../Models/"

# output locations: final weights plus train/validation accuracy histories.
model_folder = f"{base_dir}{name_model}/"
model_file = f"{model_folder}{name_model}.pth"
train_history = f"{model_folder}{name_model}_history_train"
val_history = f"{model_folder}{name_model}_history_val"


# Step schedule: the learning rate is multiplied by `gamma` every
# `step_size` epochs.
scheduler = lr_scheduler.StepLR(optimizer_ft, gamma=0.1, step_size=10)

In [None]:
mkdir_model(base_dir, name_model, 0)
# Train and evaluate. model_folder is passed so that the per-epoch
# checkpoints and the loss history are written next to the final weights
# instead of into the current working directory.
model, train_hist, val_hist, best_acc = train_model(model, dataloaders_dict, criterion, optimizer_ft,scheduler, num_epochs=num_epochs, 
                                          is_inception=False, model_folder=model_folder)
#Saving the updated model for the inference phase
torch.save(model.state_dict(), model_file)

# Save histories data
save_history(train_hist, train_history)
save_history(val_hist, val_history)

In [None]:
# epochs for this training stage.
num_epochs = 10

# identity of this run and the dump folder, created up front.
name_model = "vfer_small_15"
base_dir = "../Models/"
mkdir_model(base_dir, name_model, 0)

# output locations: final weights plus train/validation accuracy histories.
model_folder = f"{base_dir}{name_model}/"
model_file = f"{model_folder}{name_model}.pth"
train_history = f"{model_folder}{name_model}_history_train"
val_history = f"{model_folder}{name_model}_history_val"

# new starting learning rate with a fresh optimizer and step schedule.
lr_in = 0.01
optimizer_ft = optim.Adam(model.parameters(), betas=(momentum_in, 0.9), lr=lr_in)
scheduler = lr_scheduler.StepLR(optimizer_ft, gamma=0.1, step_size=10)

# Train and evaluate, restarting from the weights saved by the first run and
# carrying over the best accuracy reached so far.
model, train_hist, val_hist, best_acc = train_model(
    model,
    dataloaders_dict,
    criterion,
    optimizer_ft,
    scheduler,
    num_epochs=num_epochs,
    is_inception=False,
    is_loaded=True,
    load_state_ws="../Models/vfer_small_5/vfer_small_5.pth",
    model_folder=model_folder,
    best_acc=best_acc,
)


# Store the resulting weights for the inference phase.
torch.save(model.state_dict(), model_file)

# Store the accuracy histories.
save_history(train_hist, train_history)
save_history(val_hist, val_history)

In [None]:
# identity of this run and the dump folder, created up front.
name_model = "vfer_sam_25"
base_dir = "/content/drive/MyDrive/Models/"
mkdir_model(base_dir, name_model, 0)

# output locations: final weights plus train/validation accuracy histories.
model_folder = f"{base_dir}{name_model}/"
model_file = f"{model_folder}{name_model}.pth"
train_history = f"{model_folder}{name_model}_history_train"
val_history = f"{model_folder}{name_model}_history_val"

# epochs for this training stage.
num_epochs = 5
# new starting learning rate with an SGD optimizer and step schedule.
lr_in = 0.001
optimizer_ft = optim.SGD(model.parameters(), momentum=momentum_in, lr=lr_in)
scheduler = lr_scheduler.StepLR(optimizer_ft, gamma=0.1, step_size=10)

# Train and evaluate, restarting from a previously saved checkpoint and
# carrying over the best accuracy reached so far.
model, train_hist, val_hist, best_acc = train_model(
    model,
    dataloaders_dict,
    criterion,
    optimizer_ft,
    scheduler,
    num_epochs=num_epochs,
    is_inception=False,
    is_loaded=True,
    load_state_ws="/content/drive/MyDrive/Models/vfer_sam_10/vfer_sam_10.pth",
    model_folder=model_folder,
    best_acc=best_acc,
)


# Store the resulting weights for the inference phase.
torch.save(model.state_dict(), model_file)

# Store the accuracy histories.
save_history(train_hist, train_history)
save_history(val_hist, val_history)