<a href="https://colab.research.google.com/github/tamtemtomm/glovitoo/blob/main/glovitoo_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# @title <p> Import data
# Upload file(s) from the local machine into the Colab runtime via
# google.colab's files.upload().
from google.colab import files
files.upload()

# Extract the archive into the working directory, then delete it (rm's
# output is discarded).
# NOTE(review): assumes the uploaded file is named data.zip and presumably
# expands to the `data/` folder read by the loading cell below — TODO confirm.
!unzip data.zip
!rm data.zip >& /dev/null

## Essential Import

In [145]:
# @title <p>Essential Import
import os, shutil, json
from PIL import Image
from zipfile import ZipFile
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np, pandas as pd, random as rd
import warnings
warnings.filterwarnings("ignore")

In [146]:
# @title <p>Torch Essential Import
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets
from torchvision.datasets import ImageFolder
from torch.utils.tensorboard import SummaryWriter
torch.manual_seed(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Data Preprocessing

In [147]:
# @title <p> Get data from the file
# Folder that holds one sub-folder per letter, each containing text recordings.
DATA_DIR  = 'data'
# Sample name (file name up to the first '.') -> array of THRESHOLD rows.
TRAIN_DATA_DICT = {}
VAL_DATA_DICT = {}
# Fraction of each letter's files (by sorted position) routed to validation.
VAL_SIZE = 0.2
# Minimum number of lines a recording needs; only the last THRESHOLD are kept.
THRESHOLD = 20


def _parse_recording(path, threshold):
  """Read one recording and return its last `threshold` rows as a float array.

  Each line is expected to hold numbers separated by ', '.

  Returns:
    A NumPy array built from the last `threshold` lines, or None when the
    file cannot be read, has fewer than `threshold` lines, or contains a
    value that does not parse as a float.
  """
  try:
    with open(path, 'r') as f:
      lines = f.read().splitlines()
    if len(lines) < threshold:
      return None
    return np.array(
        [[float(value) for value in line.split(', ')] for line in lines[-threshold:]]
    )
  except (OSError, ValueError):
    # Best-effort loading: unreadable or malformed recordings are skipped.
    # ValueError also covers undecodable bytes (UnicodeDecodeError) and,
    # on recent NumPy, rows of unequal length.
    return None


# Sorted listings make the train/validation split reproducible across runs;
# os.listdir() alone returns entries in arbitrary order.
for letter in sorted(os.listdir(DATA_DIR)):
  letter_dir = os.path.join(DATA_DIR, letter)
  if not os.path.isdir(letter_dir):
    # Ignore stray files sitting next to the letter folders.
    continue

  file_names = sorted(os.listdir(letter_dir))
  val_count = VAL_SIZE * len(file_names)

  for i, file_name in enumerate(file_names):
    sample = _parse_recording(os.path.join(letter_dir, file_name), THRESHOLD)
    if sample is None:
      continue

    # The split is decided by the file's position in the listing, so skipped
    # files still count towards the validation quota.
    target_dict = VAL_DATA_DICT if i < val_count else TRAIN_DATA_DICT
    target_dict[file_name.split('.')[0]] = sample

In [148]:
# @title <p> Initialized SIBIDataset
class SIBIDataset(Dataset):
  """Dataset built from a mapping of sample name -> NumPy array.

  Every array becomes a float32 tensor. The label comes from the sample
  name: the text before the first '-' must be a single character, and its
  code point minus 97 is one-hot encoded over 26 classes.
  """

  def __init__(self, data_dict):
    """Convert every (name, array) entry into a feature tensor and a label."""
    self.data_dict = data_dict
    self.features = []
    self.labels = []

    for sample_name, sample_array in data_dict.items():
      feature_tensor = torch.from_numpy(sample_array).float()
      self.features.append(feature_tensor)

      class_index = ord(sample_name.split('-')[0]) - 97
      one_hot_label = F.one_hot(torch.tensor(class_index), num_classes=26)
      self.labels.append(one_hot_label)

  def __len__(self):
    """Return the number of stored samples."""
    sample_count = len(self.features)
    return sample_count

  def __getitem__(self, idx):
    """Return the (features, one-hot label) pair stored at position `idx`."""
    feature_tensor = self.features[idx]
    one_hot_label = self.labels[idx]
    return feature_tensor, one_hot_label

In [149]:
# @title <p> Trainset and Trainloader
# Wrap the two splits produced by the loading cell in datasets.
trainset = SIBIDataset(TRAIN_DATA_DICT)
valset = SIBIDataset(VAL_DATA_DICT)

# Training batches are reshuffled every epoch.
trainloader = DataLoader(trainset, batch_size=4, shuffle=True, num_workers=1)
# Validation batches come from the held-out split (valset), so the reported
# accuracy is not measured on training data; order does not matter for
# evaluation, hence no shuffling.
valloader = DataLoader(valset, batch_size=4, shuffle=False, num_workers=1)

## Initialize model

In [None]:
# Initialize RNN model
# NOTE(review): the bare name `A` below looks like a leftover placeholder —
# no model definition is visible in this cell, and evaluating `A` raises
# NameError unless it is defined elsewhere. TODO confirm and replace.
A


In [None]:
# @title <p>Make training loop

def _evaluate_accuracy(model, val_loader, device):
  """Return the fraction of validation samples whose predicted class is correct."""
  model.eval()
  preds, targets = [], []
  with torch.no_grad():
    for val_data in val_loader:
      val_inputs, val_labels = val_data[0].to(device), val_data[1].to(device)
      preds.append(model(val_inputs))
      targets.append(val_labels)

  y_pred = torch.cat(preds)
  y = torch.cat(targets)

  # Predicted class = arg-max over each (flattened) sample's scores.
  pred_ids = y_pred.reshape(len(y_pred), -1).argmax(dim=1)
  # Labels may be one-hot / score vectors (arg-max them) or, when 1-D,
  # already class indices.
  true_ids = y if y.dim() == 1 else y.reshape(len(y), -1).argmax(dim=1)
  return (pred_ids == true_ids).sum().item() / len(y_pred)


def _save_checkpoint(model, save_dir, backup_dir, file_name):
  """Save the model's state dict to `save_dir` and mirror it to `backup_dir` if set."""
  os.makedirs(save_dir, exist_ok=True)
  checkpoint_path = os.path.join(save_dir, file_name)
  torch.save(model.state_dict(), checkpoint_path)
  if backup_dir is not None:
    # The dated backup folder does not exist on the first save of a day.
    os.makedirs(backup_dir, exist_ok=True)
    shutil.copy(checkpoint_path, os.path.join(backup_dir, file_name))


def train_loop(train_loader, val_loader, epoch_num, patience, model, loss_fn, optimizer, device, target_acc = 1,
               checkpoint_dir = None, backup_root = "/content/gdrive/MyDrive/compfestTrain/"):
  """Train `model`, validating and checkpointing after every epoch.

  Args:
    train_loader: DataLoader yielding batches whose items [0] and [1] are
      inputs and labels.
    val_loader: DataLoader with the same batch layout, used for accuracy.
    epoch_num: Maximum number of epochs to run.
    patience: Stop once accuracy has failed to improve on the previous
      epoch this many times in a row.
    model: Module to optimise; it is not moved to `device` here.
    loss_fn: Callable taking (outputs, labels) and returning a scalar loss.
    optimizer: Optimizer bound to `model`'s parameters.
    device: Device each batch is moved to.
    target_acc: Stop as soon as validation accuracy reaches this value.
    checkpoint_dir: Folder for best.pth / last.pth. When None, the
      notebook-level variable `model_path` is used.
    backup_root: Folder under which a sub-folder named after today's date
      receives a copy of every checkpoint. Pass None to disable the copy.

  Returns:
    (epoch_loss_values, metric_values): per-epoch mean training loss and
    per-epoch validation accuracy.
  """
  # Imported here because `date` is not among the notebook's visible imports.
  from datetime import date

  save_dir = model_path if checkpoint_dir is None else checkpoint_dir
  # Resolved once so every checkpoint of a run lands in the same dated folder.
  backup_dir = None if backup_root is None else os.path.join(backup_root, str(date.today()))

  best_metric, best_metric_epoch, cur_patience = -1, -1, 0
  epoch_loss_values, metric_values = [], []
  prev_acc_metric = 0
  total_steps = len(train_loader)

  for epoch in range(epoch_num):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{epoch_num}")
    epoch_loss, num_batches = 0.0, 0

    model.train()
    for step, batch_data in enumerate(train_loader, start=1):
      inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
      optimizer.zero_grad()
      outputs = model(inputs)
      loss = loss_fn(outputs, labels)
      loss.backward()
      optimizer.step()

      epoch_loss += loss.item()
      num_batches = step
      if step % 10 == 0:
        print(f"{step}/{total_steps}, training_loss: {loss.item():.4f}")

    # Mean over the batches actually processed (guarded for an empty loader).
    epoch_loss /= max(num_batches, 1)
    epoch_loss_values.append(epoch_loss)
    if (epoch + 1) % 10 == 0:
      print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    acc_metric = _evaluate_accuracy(model, val_loader, device)
    print(f'Accuracy : {acc_metric}')
    metric_values.append(acc_metric)

    if acc_metric > best_metric:
      best_metric = acc_metric
      best_metric_epoch = epoch + 1
      _save_checkpoint(model, save_dir, backup_dir, "best.pth")
      print("saved new best metric network")

    # Patience counts consecutive epochs that did not beat the previous epoch.
    if acc_metric <= prev_acc_metric:
      cur_patience += 1
    else:
      cur_patience = 0
    prev_acc_metric = acc_metric

    _save_checkpoint(model, save_dir, backup_dir, "last.pth")

    print(
        f"current epoch: {epoch + 1} / "
        f"current accuracy: {acc_metric:.4f} best ACC: {best_metric:.4f} "
        f"at epoch: {best_metric_epoch}/ "
        f"cur patience: {cur_patience}"
    )

    # >= rather than == so a target below 1.0 still triggers when the
    # accuracy overshoots it.
    if acc_metric >= target_acc:
      print("Got target accuracy, Stop the training session")
      break

    if cur_patience == patience:
      print("Callback Activated, Stop the training session")
      break

  return epoch_loss_values, metric_values