In [None]:
# Enable IPython's autoreload extension in mode 2: imported modules are reloaded
# before each cell executes, so edits to local .py modules are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import copy
import numpy as np
import glob
import pathlib
from PIL import Image
from pathlib import Path
from tqdm.notebook import tqdm
from sklearn.metrics import r2_score, root_mean_squared_error

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset

from temporal_dataset import PorosityDataset
from temporal_model import *

# Device used for the model and every batch. Prefer the GPU, but fall back to
# the CPU so the notebook still runs on machines without CUDA.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Load the full porosity dataset from disk (padding disabled).
data_dir = '/ix1/xjia/yuw253/porosity/Ksection_porosity_fused_updated_ratio'
dataset = PorosityDataset(data_dir, use_padding=False)

# Ask the dataset for train/val/test index lists using a fixed seed.
# NOTE(review): the original comment says this call also computes normalization
# factors; that is not visible here -- confirm in temporal_dataset.
train_idx, val_idx, test_idx = dataset.get_split(seed=42)

# Wrap each index list in a Subset view over the same underlying dataset.
dataset_trn, dataset_val, dataset_tst = (
    Subset(dataset, indices) for indices in (train_idx, val_idx, test_idx)
)

In [None]:
def train_one_epoch(model, dataloader, optimizer, device, task, task_id=-1, verbose=False):
    """Run one optimisation pass over ``dataloader`` and return the batch losses.

    Args:
        model: Network to train. For any task other than ``'regression'`` it
            must expose ``var_weight`` and ``temperature``, which are read by
            the penalty term.
        dataloader: Iterable yielding ``(feats, labels)`` batches. Column 0 of
            ``labels`` is used as the classification target and columns 1..
            as the regression targets.
        optimizer: Optimizer that is stepped once per batch.
        device: Device the batch tensors are moved to.
        task: ``'regression'`` selects an MSE loss; any other value selects
            binary cross-entropy on logits plus the penalty term.
        task_id: Unused; kept so existing callers keep working.
        verbose: If true, wrap the dataloader in a tqdm progress bar.

    Returns:
        list[float]: The loss value of every batch, in iteration order.
    """
    model.train()
    batches = tqdm(dataloader) if verbose else dataloader

    loss_history = []
    for feats, labels in batches:
        feats = feats.to(device)

        if task == 'regression':
            targets = labels[:, 1:].to(device)
            logits = model(feats)
            # reshape (not view): the column slice is non-contiguous and
            # .to(device) does not copy when the tensor is already on `device`,
            # in which case view(-1) raises a RuntimeError.
            loss = F.mse_loss(logits.reshape(-1), targets.reshape(-1).float())
        else:
            targets = labels[:, 0].to(device)
            logits = model(feats)
            loss = F.binary_cross_entropy_with_logits(logits, targets.reshape(-1, 1).float())
            # Hinge penalty: only active while the mean of
            # sigmoid(var_weight / temperature) exceeds 0.6.
            mean_weight = torch.sigmoid(model.var_weight / model.temperature).mean()
            loss = loss + torch.clamp(mean_weight - 0.6, min=0)

        loss_history.append(loss.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return loss_history
        

def eval_one_epoch(model, dataloader, device, task, task_id=-1, verbose=False):
    """Evaluate ``model`` on ``dataloader`` without updating any weights.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        dataloader: Iterable yielding ``(feats, labels)`` batches. Column 0 of
            ``labels`` is used as the classification target and columns 1..
            as the regression targets.
        device: Device the batch tensors are moved to.
        task: ``'regression'`` scores with per-target RMSE; any other value is
            treated as binary classification and scored with accuracy at a
            0.5 probability threshold.
        task_id: Unused; kept so existing callers keep working.
        verbose: If true, wrap the dataloader in a tqdm progress bar.

    Returns:
        tuple: ``(loss_history, scores)``. ``loss_history`` is the list of
        per-batch loss values. ``scores`` is a list with one RMSE per target
        column for regression, or a single accuracy float otherwise.
    """
    model.eval()
    batches = tqdm(dataloader) if verbose else dataloader

    loss_history = []
    references, predictions = [], []
    # Nothing below needs gradients, so disable autograd for the whole loop.
    with torch.no_grad():
        for feats, labels in batches:
            feats = feats.to(device)
            if task == 'regression':
                labels = labels[:, 1:].to(device)
                logits = model(feats)
                # reshape (not view): the column slice is non-contiguous and
                # is not copied when it already lives on `device`.
                loss = F.mse_loss(logits.reshape(-1), labels.reshape(-1).float())
                # Keep predictions shaped like the labels (batch, n_targets) so
                # the per-column scoring below also works with a single target.
                preds = logits.reshape(labels.shape)
            else:
                labels = labels[:, 0].to(device)
                logits = model(feats)
                loss = F.binary_cross_entropy_with_logits(logits, labels.reshape(-1, 1).float())
                # Flatten to one probability per sample to line up with labels.
                preds = torch.sigmoid(logits).reshape(-1)

            loss_history.append(loss.item())
            references.append(labels.cpu())
            predictions.append(preds.cpu())

    references = torch.concat(references)
    predictions = torch.concat(predictions)

    if task == 'regression':
        # One RMSE per target column (RMSE is symmetric in its two arguments).
        scores = [
            root_mean_squared_error(references[:, j].numpy(), predictions[:, j].numpy())
            for j in range(references.shape[1])
        ]
    else:
        scores = ((predictions > 0.5) == references).float().mean().item()
    return loss_history, scores

In [None]:
task = 'classification'
# Forwarded to the epoch helpers, which accept but do not use it; -1 mirrors
# their default. Defined here so the cell runs on a fresh kernel.
task_id = -1

trn_batch_size = 2048
val_batch_size = 1024
max_epoch = 50

# Seed torch so weight initialisation and training-set shuffling are repeatable.
torch.manual_seed(42)

dataloader_trn = DataLoader(dataset_trn, batch_size=trn_batch_size, shuffle=True)
dataloader_val = DataLoader(dataset_val, batch_size=val_batch_size, shuffle=False)
dataloader_tst = DataLoader(dataset_tst, batch_size=val_batch_size, shuffle=False)

model = LSTM(input_size=7, hidden_size=256, num_layers=5, output_size=1).to(DEVICE)
# Alternative architectures (NOTE: `output_dim` is not defined in the visible
# cells and must be set before enabling the MLP line):
# model = MLP(in_dim=7, embed_dim=512, out_dim=output_dim, num_layer=5).to(DEVICE)
# model = TemporalTransformer(input_size=7, hidden_size=256, num_layers=5, output_size=1).to(DEVICE)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-03, weight_decay=0.0001)

# eval_one_epoch returns accuracy (higher is better) for classification and
# per-target RMSE (lower is better) for regression, so the direction of
# "best" depends on the task.
higher_is_better = task != 'regression'
best_acc = -float('inf') if higher_is_better else float('inf')
best_checkpoint = None
train_loss_list, val_loss_list = [], []
for epoch in tqdm(range(max_epoch)):
    trn_loss = train_one_epoch(model, dataloader_trn, optimizer, DEVICE, task=task, task_id=task_id, verbose=True)
    val_loss, val_acc = eval_one_epoch(model, dataloader_val, DEVICE, task=task, task_id=task_id, verbose=False)
    if isinstance(val_acc, list):
        # Regression yields one score per target; collapse to a single number.
        val_acc = np.mean(val_acc)
    trn_loss_mean, val_loss_mean = np.mean(trn_loss), np.mean(val_loss)
    print('[epoch {}] trn_loss={} val_loss={} val_acc={}'.format(
        epoch+1, round(trn_loss_mean, 3), round(val_loss_mean, 3), round(val_acc, 3)))

    improved = val_acc > best_acc if higher_is_better else val_acc < best_acc
    if improved:
        best_acc = val_acc
        # Deep-copy: state_dict() returns references to the live parameters.
        best_checkpoint = copy.deepcopy(model.state_dict())

    train_loss_list.append(trn_loss_mean)
    val_loss_list.append(val_loss_mean)

# Compute test accuracy with the best validation checkpoint. If no epoch ever
# improved (e.g. NaN scores or max_epoch == 0), keep the current weights.
if best_checkpoint is not None:
    model.load_state_dict(best_checkpoint)
_, tst_acc = eval_one_epoch(model, dataloader_tst, DEVICE, task=task, task_id=task_id, verbose=False)
tst_acc