In [1]:
import sys
sys.path.append("../../py/")

In [15]:
import os
import logging
import pickle as pk
from datetime import datetime

import numpy as np
import cv2 as cv
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

logger = logging.getLogger()
logger.setLevel(logging.INFO)

logging.basicConfig(format='%(asctime)s %(levelname)-4s %(message)s',
                    level=logging.INFO,
                    datefmt='%d-%m-%Y %H:%M:%S')

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset

from mlp import MLP

In [46]:
class YearPredictorDataset(Dataset):
    def __init__(self, patch_dataset, meta_df):
        embeddings, labels, map_name_to_class, patch_data_idx_to_emb_idx = YearPredictorDataset.init_from_patch_dataset(patch_dataset, meta_df)
        
        self.patch_dataset = patch_dataset
        self.embeddings = embeddings
        self.labels = labels
        self.patch_data_idx_to_emb_idx = patch_data_idx_to_emb_idx
        self.map_name_to_class = map_name_to_class
    
    def __len__(self):
        return len(embeddings)
    
    def __getitem__(self,i):
        if isinstance(i, slice):
            start = i.start if i.start else 0
            stop = i.stop if i.stop else len(self.X_1)
            step = i.step if i.step else 1

            return [(self.embeddings[j], self.labels[j]) for j in range(start, stop, step)]

        return (self.embeddings[i], self.labels[i])
    
    @staticmethod
    def get_map_name(patch):
        origin_map = patch.origin_map
        end_idx = origin_map.index(".")
        start_idx = len(origin_map) - origin_map[::-1].index("/")
        
        return patch.origin_map[start_idx : end_idx]
    
    @staticmethod
    def assign_class(year):
        if year < 1900:
            return 0
        elif year < 1910:
            return 1
        elif year < 1930:
            return 2
        else:
            return 3
    
    @staticmethod
    def init_from_patch_dataset(patch_dataset, meta_df):
        patch_data_idx_to_emb_idx = {}
        embeddings = []
        
        emb_idx = 0
        Xs = [patch_dataset.X_1, patch_dataset.X_2]
        labels = []
        map_name_to_class = {}
        
        for i in range(len(patch_dataset)):
            for j in range(2):
                patch = Xs[j][i]
                map_name = int(YearPredictorDataset.get_map_name(patch))
                
                if map_name in map_name_to_class:
                    year_class = map_name_to_class[map_name]
                else:
                    year = meta_df[meta_df["IMAGE"] == map_name].iloc[0]["YEAR"]
                    year_class = YearPredictorDataset.assign_class(year)
                    map_name_to_class[map_name] = year_class
                    
                embeddings.append(patch.patch_shift)
                labels.append(year_class)
                patch_data_idx_to_emb_idx[emb_idx] = (i,j)
                emb_idx += 1
        
        return embeddings, labels, map_name_to_class, patch_data_idx_to_emb_idx

In [10]:
with open("../../py/output/patch_train_dataset_128.pk", "rb") as f:
    train_data = pk.load(f)
    
with open("../../py/output/patch_val_dataset_128.pk", "rb") as f:
    val_data = pk.load(f)

In [49]:
metadata = pd.read_csv("../../../os_meta.csv")
meta = metadata[["PUB_SORT", "IMAGE"]].rename(columns = {"PUB_SORT" : "YEAR"})

In [51]:
meta.head()

Unnamed: 0,YEAR,IMAGE
0,1894,82877892
1,1894,82877928
2,1894,82877940
3,1894,82877949
4,1894,82877958


In [53]:
year_counts = meta.groupby("YEAR").count().reset_index().rename(columns = {"IMAGE" : "MAP COUNT"})
year_counts["CLASS"] = year_counts["YEAR"].apply(lambda year : assign_class(year))
year_counts.groupby("CLASS").sum()

Unnamed: 0_level_0,YEAR,MAP COUNT
CLASS,Unnamed: 1_level_1,Unnamed: 2_level_1
0,5685,57
1,5721,57
2,3827,57
3,5814,41


In [54]:
yp_train = YearPredictorDataset(patch_dataset = train_data, meta_df = meta)

In [57]:
class YearPredictorClassifier(nn.Module):
    def __init__(self, first_hidden_parameters, second_hidden_parameters, output_parameters):
        super(YearPredictorClassifier, self).__init__()
        
        self.hidden_mlp_1 = MLP(**first_hidden_parameterst_hidden_parameters)
        self.hidden_mlp_2 = MLP(**second_hidden_parameters)
        self.output_mlp = MLP(**output_parameters)
    
        self.classifier = nn.Sequential(self.hidden_mlp_1, self.hidden_mlp_2, self.output_mlp)
        
        self.criterion = nn.CrossEntropyLoss()
        self.optimiser = None
        
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.checkpoint = {"epoch": 0,
                           "batch": 0,
                           "model_state_dict": self.state_dict(),
                           "optimiser_state_dict": None,
                           "loss": 0,
                           "avg_batch_losses_20": [],
                           "batch_losses": [],
                           "validation_losses ": [],
                           "run_start": datetime.now().strftime("%d/%m/%Y %H:%M:%S"),
                           "run_end": None,
                           "model_kwargs": self.kwargs}
        
    def forward(self, x):
        return self.classifier(x)
    
    def predict(self, x):
        y = self.classifier(x)
        
        return np.argmax(y)
    
    def compile_optimiser(self, **kwargs):
        self.optimiser = optim.Adam(self.parameters(), **kwargs) 
        
    def get_loss(self, y_true, y_pred):
        return self.criterion(y_true, y_pred)
    
    @torch.no_grad()
    def evaluate(self, evaluation_loader, validation = True):

        eval_losses = []

        self.eval()

        for embeddings, y_true in validation_loader:
            embeddings, y_true = embeddings.to(self.device), y_true.to(self.device)
            y_pred = self.classifier(embeddings)
            val_losses.append(self.get_loss(y_true = y_true, y_pred = y_pred).cpu())

        if validation:
            self.train()
            return np.mean(val_losses)

        return val_losses
    
    def update_checkpoint(self, checkpoint_dir, batch_losses, validation_losses, **checkpoint_data):

        for k, v in checkpoint_data.items():
            if k in self.checkpoint:
                self.checkpoint[k] = v

        if checkpoint_dir is not None:

            if not os.path.isdir(checkpoint_dir):
                os.makedirs(checkpoint_dir)

            model_params_dir = os.path.join(checkpoint_dir, "year_classifier_checkpoint.pt")
            torch.save(self.checkpoint, model_params_dir)

            batch_loss_train_dir = os.path.join(checkpoint_dir,
                                                f"batch_loss_logs_t{checkpoint_data.get('epoch', 0)}.pk")
            with open(batch_loss_train_dir, "wb") as f:
                pk.dump(batch_losses, f)

            batch_loss_validation_dir = os.path.join(checkpoint_dir,
                                                     f"batch_loss_logs_v{checkpoint_data.get('epoch', 0)}.pk")
            with open(batch_loss_validation_dir, "wb") as f:
                pk.dump(validation_losses, f)
    
    def train_model(self, train_loader, validation_loader, epochs, checkpoint_dir=None, batch_log_rate=100):

        self.to(self.device)

        for epoch in range(epochs):
            batch_losses = []
            validation_losses = []
            avg_batch_losses_20 = []

            logging.info(f"Starting Epoch: {epoch + 1}")

            for batch, (embeddings, y_true) in enumerate(train_loader):

                self.optimiser.zero_grad()

                embeddings, y_true = embeddings.to(self.device), y_true.to(self.device)
                y_pred = self.classifier(embeddings)

                loss = self.get_loss(y_true = y_true, y_pred = y_pred)

                batch_losses.append(loss.cpu().detach())

                loss.backward()
                self.optimiser.step()

                if batch % (len(train_loader) // batch_log_rate + 1) == 0 and batch != 0:
                    with torch.no_grad():
                        avg_loss = np.mean(batch_losses[-20:])
                        avg_batch_losses_20.append(avg_loss)
                        logging.info(
                            f"Epoch {epoch + 1}: [{batch + 1}/{len(train_loader)}] ---- CrossEntropy Training Loss = {avg_loss}")

                        if batch % (len(train_loader) // (batch_log_rate//4) + 1) == 0:
                            validation_loss = self.get_validation_loss(validation_loader)
                            validation_losses.append(validation_loss)
                            logging.info(
                                f"Epoch {epoch + 1}: [{batch + 1}/{len(train_loader)}] ---- CrossEntropy Validation Loss = {validation_loss}")

                        self.update_checkpoint(checkpoint_dir=checkpoint_dir,
                                               batch_losses=batch_losses,
                                               validation_losses=validation_losses,
                                               epoch=epoch,
                                               batch=batch,
                                               model_state_dict=self.state_dict(),
                                               optimiser_state_dict=self.optimiser.state_dict,
                                               loss=loss.cpu().detach(),
                                               avg_batch_losses_20=avg_batch_losses_20,
                                               run_end=datetime.now().strftime("%d/%m/%Y %H:%M:%S"))

            with torch.no_grad():
                self.update_checkpoint(checkpoint_dir=checkpoint_dir,
                                       batch_losses=batch_losses,
                                       validation_losses=validation_losses,
                                       epoch=epochs,
                                       batch=len(train_loader),
                                       model_state_dict=self.state_dict(),
                                       optimiser_state_dict=self.optimiser.state_dict,
                                       loss=loss.cpu().detach(),
                                       avg_batch_losses_20=avg_batch_losses_20,
                                       run_end=datetime.now().strftime("%d/%m/%Y %H:%M:%S"))

        return self.checkpoint