In [None]:
import sys

# Make the project's `src` package (imported below) importable.
# The membership guard keeps this cell idempotent: re-running it does not
# keep appending the same directory to sys.path.
PROJECT_ROOT = '/home/jiajunb/neural-dimension-reduction'
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [None]:
import os

import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.nn import functional as F
from src.models.distance_modeling import SurveyorDataSet, Surveyor, thesis_kl_div_add_mse_loss

import copy

# Fix torch's global RNG so DataLoader shuffling (and, presumably, Surveyor's
# weight init) is reproducible. The trailing ';' suppresses the Generator
# object that manual_seed returns, which would otherwise be echoed as output.
SEED = 0
torch.manual_seed(SEED);

In [None]:
def far_func2(sorted_dist: torch.Tensor, indices: torch.Tensor, rank: int = 1):
    """Pick one neighbour column from row-sorted distance/index matrices.

    Args:
        sorted_dist: distances with one row per point, each row sorted
            (assumed 2-D, rows x neighbours; presumably ascending with column 0
            being the point itself — TODO confirm against SurveyorDataSet.from_df).
        indices: neighbour indices aligned element-wise with ``sorted_dist``.
        rank: which column to select. The default of 1 reproduces the original
            behaviour (the second entry of each sorted row).

    Returns:
        ``(dist, idx)``, both reshaped to a single column ``(N, 1)``.
    """
    return sorted_dist[:, rank].reshape(-1, 1), indices[:, rank].reshape(-1, 1)

# Build the train and dev splits from CSV. far_func2 tells the dataset which
# column of each row's sorted neighbour list to use.
train_dataset = SurveyorDataSet.from_df(
    '/home/jiajunb/neural-dimension-reduction/data/train.csv',
    far_func2,
)
val_dataset = SurveyorDataSet.from_df(
    '/home/jiajunb/neural-dimension-reduction/data/dev.csv',
    far_func2,
)

In [None]:
# Shuffled training batches of 1000 examples; pin_memory=True stages batches
# in page-locked host memory for the .to(device) copies in the training loop.
train_loader = DataLoader(
    train_dataset,
    batch_size=1000,
    shuffle=True,
    pin_memory=True,
)

In [None]:
# Optimizer / schedule hyperparameters (AdamW weight decay, learning rate,
# number of training epochs).
weight_decay, learning_rate, num_epoches = 1e-5, 1e-5, 400

In [None]:
# Use the second GPU as before, but fall back to the CPU when it does not
# exist so the notebook still runs end to end on other machines.
if torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
else:
    device = torch.device('cpu')
    print('cuda:1 not available; falling back to CPU')

# Move the model to its device *before* constructing the optimizer, as the
# torch.optim documentation recommends.
model = Surveyor().to(device)

# Parameters whose names contain any of these substrings get no weight decay
# (e.g. biases and LayerNorm weights, if the model names them this way).
no_decay = ['bias', 'LayerNorm.weight']
decay_params, no_decay_params = [], []
for name, param in model.named_parameters():
    if not param.requires_grad:
        continue  # frozen parameters are not handed to the optimizer at all
    if any(nd in name for nd in no_decay):
        no_decay_params.append(param)
    else:
        decay_params.append(param)

optimizer_grouped_parameters = [
    {'params': decay_params, 'weight_decay': weight_decay},
    {'params': no_decay_params, 'weight_decay': 0.0},
]

optimizer = torch.optim.AdamW(params=optimizer_grouped_parameters, lr=learning_rate)

In [None]:
def train_one_epoch(train_loader, model, optimizer, verbose):
    """Run a single optimization pass over ``train_loader``.

    Each batch is an ``(x1, x2, labels, q)`` tuple moved to the module-level
    ``device``. The model is called as ``model(x1, x2, q, labels)`` and must
    return five values, the last being the scalar training loss.

    Args:
        train_loader: iterable of training batches (``len()`` must work).
        model: the network to train; it is moved to ``device`` first.
        optimizer: optimizer stepping the model's parameters.
        verbose: if true, print the running mean loss every 20 batches.

    Returns:
        The mean of the per-batch losses over the epoch.
    """
    # Ensure the model is on `device` at the start of every epoch; callers may
    # have moved it elsewhere (e.g. to snapshot it on the CPU).
    model = model.to(device)
    model.train()
    running_total = 0.
    for step, batch in enumerate(train_loader):
        x1, x2, labels, q = (t.to(device) for t in batch)
        _logits, _p, _out1, _out2, loss = model(x1, x2, q, labels)
        model.zero_grad()  # clear gradients left over from the previous step
        loss.backward()
        optimizer.step()
        running_total += loss.item()
        if verbose and step % 20 == 0:
            print(f'training loss: {running_total / (step + 1):.4f}')
    return running_total / len(train_loader)

def val_one_epoch(val_loader, model):
    """Evaluate ``model`` on ``val_loader`` with gradients disabled.

    Inputs are moved to the module-level ``device``; the model is expected to
    already live there.

    Args:
        val_loader: iterable of ``(x1, x2, labels, q)`` batches.
        model: called as ``model(x1, x2, q, labels=None)``; must return
            ``(logits, p, out1, out2)``.

    Returns:
        ``(avg_cross_entropy, avg_thesis_loss, accuracy)``, where accuracy is
        the fraction of examples whose argmax logit equals the label.
    """
    model.eval()
    # reduction='sum' so that dividing by the example count below yields a true
    # per-example mean. The default reduction='mean' combined with that division
    # under-reported the loss by roughly a factor of the batch size.
    loss_fn1 = nn.CrossEntropyLoss(reduction='sum')
    loss_fn2 = thesis_kl_div_add_mse_loss
    preds_list = list()
    labels_list = list()
    val_xentropy_loss = 0.
    val_thesis_loss = 0.
    with torch.no_grad():
        for x1, x2, labels, q in val_loader:
            x1, x2, q = x1.to(device), x2.to(device), q.to(device)
            logits, p, out1, out2 = model(x1, x2, q, labels=None)
            # softmax is monotonic, so argmax over raw logits picks the same class.
            preds = logits.argmax(dim=1)
            preds_list.append(preds.cpu())
            labels_list.append(labels.cpu())
            labels = labels.to(device)
            val_xentropy_loss += loss_fn1(logits, labels).item()
            # NOTE(review): thesis_kl_div_add_mse_loss's reduction is defined
            # elsewhere. If it returns a batch mean, this sum / example-count
            # average is scaled down by ~batch size as well — confirm.
            val_thesis_loss += loss_fn2(p, q).item()
    y_preds = torch.cat(preds_list)
    y_golds = torch.cat(labels_list)
    n_examples = len(y_preds)
    accuracy = float((y_preds == y_golds).sum().item()) / n_examples
    return val_xentropy_loss / n_examples, val_thesis_loss / n_examples, accuracy

In [None]:
def train_with_eval(train_loader, val_loader, model, optimizer, num_epoches, verbose, log_every=5):
    """Train for ``num_epoches`` epochs, validating after each one.

    The best checkpoint is selected by strictly higher validation accuracy.

    Args:
        train_loader: training batches, passed to ``train_one_epoch``.
        val_loader: validation batches, passed to ``val_one_epoch``.
        model: the network being trained (updated in place).
        optimizer: optimizer over ``model``'s parameters.
        num_epoches: number of epochs to run.
        verbose: if true, print a summary line every ``log_every`` epochs.
        log_every: positive logging interval in epochs (default 5, as before).

    Returns:
        ``(best_avg_xentropy_loss, best_avg_thesis_loss, best_val_accuracy,
        best_model, model)``. ``best_model`` is a CPU copy of the best epoch's
        model, or ``None`` if validation accuracy never exceeded 0.
    """
    best_model = None
    best_avg_xentropy_loss, best_avg_thesis_loss, best_val_accuracy = float('inf'), float('inf'), 0.
    for epoch_idx in range(1, num_epoches + 1):
        avg_loss = train_one_epoch(train_loader, model, optimizer, False)
        avg_xentropy_loss, avg_thesis_loss, val_accuracy = val_one_epoch(val_loader, model)
        if val_accuracy > best_val_accuracy:
            best_avg_xentropy_loss, best_avg_thesis_loss, best_val_accuracy = (
                avg_xentropy_loss, avg_thesis_loss, val_accuracy)
            # Copy first, then move only the copy to the CPU. The previous
            # `copy.deepcopy(model.cpu())` moved the live model off its device
            # in place, forcing a full device round-trip on the next epoch.
            best_model = copy.deepcopy(model).cpu()
        if verbose and epoch_idx % log_every == 0:
            print(f'epoch [{epoch_idx}]/[{num_epoches}] training loss: {avg_loss:.4f} '
                  f'val_cross_entropy_loss: {avg_xentropy_loss:.4f} '
                  f'val_thesis_loss: {avg_thesis_loss:.4f} '
                  f'val_accuracy: {val_accuracy:.4f} ')
    return best_avg_xentropy_loss, best_avg_thesis_loss, best_val_accuracy, best_model, model

In [None]:
# Validation batches of 1000 examples, served in a fixed (unshuffled) order.
val_loader = DataLoader(
    val_dataset,
    batch_size=1000,
    shuffle=False,
    pin_memory=True,
)

In [None]:
# Full training run with per-epoch validation. Keeps both the best-accuracy
# snapshot and the model as it stands after the last epoch.
(best_avg_xentropy_loss, best_avg_thesis_loss, best_val_accuracy,
 best_model, final_model) = train_with_eval(
    train_loader, val_loader, model, optimizer, num_epoches, verbose=True)


In [None]:
# Validation metrics of the best-accuracy epoch: (cross-entropy, thesis loss, accuracy).
(best_avg_xentropy_loss, best_avg_thesis_loss, best_val_accuracy)

In [None]:
# torch.save({
#     "best_model": best_model.state_dict(),
#     "best_avg_xentropy_loss": best_avg_xentropy_loss,
#     "best_avg_thesis_loss": best_avg_thesis_loss, 
#     "best_val_accuracy": best_val_accuracy
# }, '../saves/surveyor.on.full.100')

In [None]:
# os.makedirs('checkpoints')