#author: Jedrzej Chmiel

In [None]:
# %load_ext autoreload
# %autoreload 2

In [None]:
# Install nltk pinned to version 3.7 into the notebook environment.
!pip install nltk==3.7

In [None]:
import torch
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn as nn
from words_batch_dataset import WordsBatchDataset
from one_item_dataset import OneItemDataset
from corpus import Corpus
from embedding import Embedding
from words_batch import WordsBatch
from tqdm import tqdm
import os
import os.path
import pickle
from typing import List, Dict
from datetime import datetime
from google.colab import drive

In [None]:
import nltk
# Download nltk's 'punkt' resource.
# NOTE(review): presumably required by the project's text-processing modules — confirm.
nltk.download('punkt')

In [None]:
# Select the torch device string: "cuda" when CUDA is available, otherwise "cpu".
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
DEVICE  # echo the chosen device as the cell output

In [None]:
# Mount Google Drive at /content/drive (Colab only).
drive.mount('/content/drive')

In [28]:
# Root directory under which this notebook reads the dictionary and the prepared
# book texts and writes model checkpoints. The path is relative, so it resolves
# against the current working directory.
MAIN_DATA_DIR = 'drive/MyDrive/data_harry_potter'
# Alternative root: a 'data' directory relative to the working directory.
# MAIN_DATA_DIR = 'data'

In [29]:
# Build the corpus from the pickled dictionary stored under MAIN_DATA_DIR, then
# create a OneItemDataset sized by the number of corpus entries.
corpus = Corpus(f"{MAIN_DATA_DIR}/dictionary/second_dictionary_30_04.pickle")
one_item_dataset = OneItemDataset(len(corpus))

In [25]:
# One WordsBatchDataset per prepared book file (files numbered 1 through 7),
# each built with the corpus dictionary and a sequence length of 6.
words_batch_datasets = [
    WordsBatchDataset(
        f"{MAIN_DATA_DIR}/harry_potter_books/prepared_txt/harry_potter_{book_number}_prepared.txt",
        corpus.dictionary,
        sequence_length=6,
    )
    for book_number in range(1, 8)
]

In [6]:
def test_words_batch(model:WordsBatch, datasets: List[OneItemDataset], batch_size=2048) -> float:
    """
    Evaluate a WordsBatch model on all datasets and return its average error.

    Every (X, y) batch is moved to DEVICE, the targets ``y`` are converted with
    ``model.embedding.to_dense`` and compared with ``model(X)`` using a summed
    squared error. The grand total is divided by the total number of samples
    (``sum(len(dataset))``), so the returned value is the summed squared error
    per sample, not per tensor element.

    The model is switched to evaluation mode for the duration of the test and
    its previous train/eval mode is restored before returning.

    Parameters:
        model:
            Object of WordsBatch class to be tested.
        datasets:
            List of datasets whose items are (X, y) pairs; the notebook passes
            WordsBatchDataset objects here.
        batch_size:
            Number of samples per evaluation batch.
    Returns:
        Summed squared error divided by the total number of samples (float).
    """
    # reduction='sum' so that partial (last) batches are weighted correctly
    # when the total is normalised once at the end.
    loss_function = nn.MSELoss(reduction='sum')
    squared_error = 0.0

    # Evaluate without dropout etc.; remember the mode so the caller's
    # training state is not silently changed.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for dataset in datasets:
                loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, drop_last=False)
                for X, y in loader:
                    X = X.to(DEVICE)
                    y = model.embedding.to_dense(y.to(DEVICE))
                    pred = model(X)
                    squared_error += loss_function(pred, y).item()
    finally:
        model.train(was_training)

    total_length = sum(len(dataset) for dataset in datasets)
    return squared_error / total_length

In [7]:
def train_words_batch(model:WordsBatch, datasets: List[OneItemDataset], batch_size: int, epochs: int,
                    optimizer: optim.Optimizer, saves_dir:str = None, results:Dict[str, int or float] = None,
                    start_epoch:int = 0):
    """
    Train a WordsBatch model with an MSE loss.

    In every epoch the function iterates over all batches of the first dataset,
    then all batches of the second dataset, and so on until the last dataset.
    Targets are converted with ``model.embedding.to_dense`` (without gradient
    tracking) before the loss is computed.

    Parameters:
        model:
            Model to be trained. Object of WordsBatch class.
        datasets:
            List of datasets yielding (X, y) pairs on which the model is trained.
        batch_size:
            Batch size used for the training loaders.
        epochs:
            Number of epochs to run.
        optimizer:
            Optimizer used to update weights and biases.
        saves_dir:
            If not None, the directory is created when missing and after each
            epoch ``model.save`` is called with
            ``<saves_dir>/words_batch_epoch_<epoch>.pth``.
        results:
            Optional dictionary of run metadata. When both ``saves_dir`` and
            ``results`` are given, it is written to ``results.pickle`` and
            ``results.txt`` before training; after every epoch the value
            returned by ``test_words_batch`` is stored in it under
            ``mse_after_epoch_<epoch>`` (the dictionary is mutated in place)
            and both files are updated.
        start_epoch:
            Number of the first epoch; epochs are numbered
            ``start_epoch .. start_epoch + epochs - 1``.
    Returns:
        None.
    """
    model.train()
    loss_function = nn.MSELoss()
    loaders = [DataLoader(dataset, batch_size, shuffle=True) for dataset in datasets]
    if saves_dir is not None:
        os.makedirs(saves_dir, exist_ok=True)
        # Only dump the metadata when there is some; results may be None.
        if results is not None:
            with open(saves_dir+'/results.pickle', 'wb') as file:
                pickle.dump(results, file)
            with open(saves_dir+'/results.txt', 'wt') as file:
                for key, item in results.items():
                    file.write(f"{key}: {item}\n")

    end_epoch = start_epoch+epochs
    for epoch in range(start_epoch, end_epoch):
        for i, loader in enumerate(loaders):
            print(f"Epoch {epoch}/{end_epoch}, dataset {i+1}/{len(loaders)}")
            for X, y in tqdm(loader, desc="batch: "):
                X = X.to(DEVICE)
                # Targets are fixed for this step, so no gradient is tracked for them.
                with torch.no_grad():
                    y = model.embedding.to_dense(y.to(DEVICE))
                pred = model(X)
                loss = loss_function(pred, y)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        if saves_dir is not None:
            model.save(saves_dir+'/'+f"words_batch_epoch_{epoch}.pth")
            if results is not None:
                mse = test_words_batch(model, datasets)
                print(f"MSE is: {mse}")
                results[f'mse_after_epoch_{epoch}'] = mse
                # Overwrite the same pickle that was written before training.
                with open(saves_dir+'/results.pickle', 'wb') as file:
                    pickle.dump(results, file)
                with open(saves_dir+'/results.txt', 'at') as file:
                    file.write(f'mse_after_epoch_{epoch} : {mse}\n')
                # Make sure training mode is active again after the evaluation.
                model.train()

In [8]:
def train_encoding(model: Embedding, dataset: OneItemDataset, batch_size: int, epochs: int,
                   optimizer: optim.Optimizer, saves_dir: str = None, with_training:bool = True) -> List[float]:
    """
    Train an Embedding model with CrossEntropyLoss.

    For every batch of targets ``y`` the input is ``model.to_dense(y)``
    (computed without gradient tracking); ``model(X)`` is then trained to
    predict ``y`` back.

    Parameters:
        model:
            Model to be trained. Object of Embedding class.
        dataset:
            Object of OneItemDataset class. Data on which to train.
        batch_size:
            Batch size used for the training loader.
        epochs:
            Number of epochs to run.
        optimizer:
            Optimizer used to update weights and biases.
        saves_dir:
            If not None, the directory is created when missing and after each
            epoch ``model.save`` is called with
            ``<saves_dir>/embedding_epoch_<epoch>.pth``.
        with_training:
            Currently unused; kept for backward compatibility.
    Returns:
        List with one accuracy fraction per epoch: the share of items whose
        highest-scoring class equalled the target, counted on the training
        batches while that epoch was running.
    """
    model.train()
    loss_function = nn.CrossEntropyLoss()
    acurates = []
    loader = DataLoader(dataset, batch_size, shuffle=True)
    if saves_dir is not None:
        os.makedirs(saves_dir, exist_ok=True)

    for epoch in range(epochs):
        print(f"Epoch {epoch}/{epochs}")
        acurate = 0  # correct predictions in this epoch only
        for y in tqdm(loader, desc="batch: "):
            y = y.to(DEVICE)
            # Use the model being trained, not a notebook-level global.
            with torch.no_grad():
                X = model.to_dense(y)
            pred = model(X)
            loss = loss_function(pred, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # With cross-entropy the predicted class is the highest score.
            acurate += torch.count_nonzero(torch.argmax(pred, dim=1) == y).item()

        fraction = acurate/len(dataset)
        print(f"Acurate: {fraction}")
        acurates.append(fraction)
        if saves_dir is not None:
            model.save(saves_dir+'/'+f"embedding_epoch_{epoch}.pth")

    return acurates

In [9]:
def test_encoding(model, dataset):
    """
    Test an embedding model on a dataset and return its accuracy fraction.

    Each item ``y`` of the dataset is turned into a batch of one, converted
    with ``model.to_dense`` and passed through the model; the prediction counts
    as correct when the highest-scoring class equals ``y``. The model is left
    in evaluation mode.

    Parameters:
        model:
            Object of Embedding class to be tested.
        dataset:
            Object of OneItemDataset class. Data on which to test.
    Return:
        Fraction of correctly predicted items (float).
    """
    acurate = 0
    model.eval()
    with torch.no_grad():
        for y in tqdm(dataset):
            y = torch.unsqueeze(y, dim=0).to(DEVICE)
            X = model.to_dense(y)
            pred = model(X)
            # The model is trained with CrossEntropyLoss, so the predicted
            # class is the one with the highest score (argmax, not argmin).
            acurate += torch.count_nonzero(torch.argmax(pred, dim=1) == y).item()
    result = acurate/len(dataset)
    print(f"Acurate fraction: {result}")
    return result

In [30]:
# Build the embedding model and move it to DEVICE.
embedding = Embedding(
    corpus_size=len(corpus),
    embedding_size=64,
    dropout_factor=0.18,
    sizes=[128, 512],
).to(DEVICE)

# Build the words-batch model on top of that embedding and move it to DEVICE.
words_batch = WordsBatch(
    embedding,
    hidden_state_size=96,
    dropout_factor=0.18,
    sequence_length=6,
    dense_layer_size=256,
).to(DEVICE)

In [39]:
# Timestamped output directory for this run, under MAIN_DATA_DIR/models.
# NOTE(review): the prefix says "training_embedding" although this run records
# words_batch results — confirm the intended directory name.
dir_name = '{}/models/{}'.format(
    MAIN_DATA_DIR,
    datetime.now().strftime("training_embedding_%d_%m_%Y___%H_%M"),
)
# Start the results record from the model's own info and add the error
# measured before any training.
results = words_batch.info()
results['mse_initial'] = test_words_batch(words_batch, words_batch_datasets)

In [None]:
# Show the collected results and the output directory as the cell output.
results, dir_name

In [None]:
# words_batch = WordsBatch.load(dir_name+'/words_batch_epoch_6.pth')
# words_batch = words_batch.to(DEVICE)
# with open(dir_name+'/results.pickle', 'rb') as file:
#   results = pickle.load(file)
# results

In [49]:
# dir_name = MAIN_DATA_DIR+'/models/training_words_batch_09_05_2022___15_58'

In [32]:
# Record the optimizer type in the results and create an Adam optimizer
# (default hyperparameters) over the words_batch parameters.
results['optimizer_type'] = 'Adam'
optimizer_words_batch = optim.Adam(params=words_batch.parameters())

In [None]:
# train_words_batch(words_batch, words_batch_datasets, batch_size=16, epochs=15,
#               optimizer=optimizer_words_batch, saves_dir=dir_name, results=results,
#               start_epoch = 0)