## Imports

In [None]:
import torch
from torch.utils.data import DataLoader
from model.cnn import CNN
from model.encoder import Encoder
from model.decoder import Decoder
from model.endtoend import HME2LaTeX
from data_processing.loadData import HMEDataset
from model.language import Lang, tensorFromSentence, indexesFromSentence
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd


## Dataset

### Initialization

In [None]:
# Symbol-level training data: a label file plus a directory of PNG images,
# both handed to HMEDataset.
labels_file = './data/symbol_train_labels.txt'
images_directory = './data/symbol_train_png/'
dataset = HMEDataset(labels_file, images_directory)

# Number of samples per mini-batch.
BATCH_SIZE = 32

# Prefer the GPU whenever CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

### Accessing items

In [None]:
# Extract tuple from index 0
image_tensor, target_label, index = dataset[0]

# Prints visual of extracted information
print(image_tensor)
print('-'*70)
print(target_label)
print('-'*70)
print(index)


### DataLoader

In [None]:
# Shuffled mini-batches of BATCH_SIZE samples for training.
# NOTE(review): the encoder/decoder are constructed with a fixed batch_size;
# when len(dataset) is not a multiple of BATCH_SIZE the final batch is smaller.
# Consider drop_last=True if the model cannot handle a partial batch.
train_dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

## Language

### Initialization

In [None]:
# Create Lang object
latex = Lang('latex')

# Extract label column from dataset
label_list = dataset.labels_file.iloc[:,1]

# Populate latex language by cycling through label column
for label in label_list:
    latex.addSentence(label)

# Language display
print(latex.index2word)

In [None]:
# Creates a tensor with each image's label as its index to the latex language
labels_latex_index = [tensorFromSentence(latex, i) for i in dataset.labels_file.iloc[:,1]]
labels_by_lang_index = torch.cat(labels_latex_index).unsqueeze(1).unsqueeze(0)

In [None]:
# Construct the three sub-networks and wrap them in the end-to-end model.
# NOTE(review): the encoder's seq_size scales with BATCH_SIZE, but the decoder's
# num_features is hard-coded to 32*31; they only agree while BATCH_SIZE == 32.
# Confirm whether that 32 is meant to be BATCH_SIZE.
# NOTE(review): labels_by_lang_index.shape[0] is 1 here because of the leading
# unsqueeze(0) in the label-encoding cell — confirm that is the intended value.
cnn = CNN(device).to(device)
encoder = Encoder(
    input_size=512,
    hidden_size=256,
    seq_size=(BATCH_SIZE * 31),
    batch_size=BATCH_SIZE,
).to(device)
decoder = Decoder(
    input_size=1,
    hidden_size=512,
    output_size=latex.n_words,
    num_features=32 * 31,
    batch_size=BATCH_SIZE,
    device=device,
).to(device)
model = HME2LaTeX(
    cnn,
    encoder,
    decoder,
    labels_by_lang_index.shape[0],
    BATCH_SIZE,
    latex.n_words,
    device,
)

In [None]:
# Classification loss over the LaTeX vocabulary.
# NOTE(review): nn.CrossEntropyLoss applies log-softmax internally and expects
# raw logits; confirm the model's first output is not already softmax-normalized
# (the training loop names it "probabilities").
loss = torch.nn.CrossEntropyLoss()

# SGD with momentum over all of the model's parameters.
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.1,
    momentum=0.9,
)

In [None]:
# Checkpoint file: written by the training loop, read back when resuming.
PATH = './' + 'symbol_model_4.tar'

In [None]:
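# Uncomment to resume training from an existing checkpoint.
# Run this AFTER the cell that creates loss_list/accuracy_list, otherwise
# that cell will reset the restored history to empty lists.
# Note: checkpoint['loss'] holds only the last batch's loss (a float);
# the full per-batch history is stored under 'losses' and 'accuracies'.
# checkpoint = torch.load(PATH, map_location=device)
# model.load_state_dict(checkpoint['model_state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# loss_list = checkpoint['losses']
# accuracy_list = checkpoint['accuracies']
# model.train()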

In [None]:
# Per-batch training history (loss values and accuracy percentages), appended
# to by the training loop and saved into each checkpoint.
loss_list, accuracy_list = [], []

In [None]:
epochs = 4

# Number of batches the DataLoader really yields per epoch. Unlike
# len(dataset) // BATCH_SIZE, this also counts a final partial batch.
total_batches = len(train_dataloader)

# Print progress, and write a checkpoint, every LOG_EVERY batches.
LOG_EVERY = 50


def _save_checkpoint(last_loss):
    """Write model/optimizer state and the loss/accuracy history to PATH.

    Args:
        last_loss: Python float loss of the most recent batch (stored under
            'loss'), or None if no batch has run yet.
    """
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': last_loss,
        'losses': loss_list,
        'accuracies': accuracy_list,
    }, PATH)


# Ensure training-mode behaviour (e.g. dropout) even if the model was switched
# to eval mode earlier in the notebook.
model.train()

batch_loss_value = None
for epoch in range(epochs):
    print('Epoch', str(epoch + 1) + '/' + str(epochs))
    print('\t' + 'Batch'.ljust(20) + '\t' + 'Accuracy'.ljust(20) + '\t' + 'Loss'.ljust(20))

    # Iterate through every batch in the dataset.
    for i, (batch_images, batch_labels, batch_indices) in enumerate(train_dataloader):

        # Reset gradients accumulated by the previous step.
        optimizer.zero_grad()

        # Look up each sample's target by its dataset index and move data to the device.
        batch_label_indices = labels_by_lang_index[0][batch_indices].float().to(device)
        batch_images = batch_images.float().to(device)

        # Forward pass; the first element of the model's output is used as the scores.
        batch_prediction_probabilities = model(batch_images, batch_label_indices)[0]

        # Batch accuracy: argmax over dim 1 compared to the integer targets.
        batch_predicted_labels = torch.argmax(batch_prediction_probabilities, dim=1)
        batch_label_indices = batch_label_indices.squeeze(1).long()
        correct = torch.sum(batch_predicted_labels == batch_label_indices).item()
        batch_accuracy = correct / len(batch_label_indices)
        batch_accuracy_percentage = batch_accuracy * 100

        # Cross-entropy loss against the integer targets.
        batch_loss = loss(batch_prediction_probabilities, batch_label_indices)

        # Backpropagate and update the parameters.
        batch_loss.backward()
        optimizer.step()

        # Record history as plain Python floats. Appending the loss tensor itself
        # would keep every batch's autograd graph alive (memory leak) and pickle
        # tensors into the checkpoint.
        batch_loss_value = batch_loss.item()
        loss_list.append(batch_loss_value)
        accuracy_list.append(batch_accuracy_percentage)

        # Periodic progress report and checkpoint. Saving only at this interval
        # (not every batch) avoids rewriting the whole model and ever-growing
        # history to disk on each step.
        if i % LOG_EVERY == 0 and i != 0:
            _save_checkpoint(batch_loss_value)
            print('\t' + (str(i) + '/' + str(total_batches)).ljust(20),
                  '\t' + (str(batch_accuracy_percentage) + '%').ljust(20),
                  '\t' + str(batch_loss_value).ljust(20))

    # Always checkpoint at the end of an epoch so no completed epoch is lost.
    _save_checkpoint(batch_loss_value)
