# Model Training
This notebook provides an example training loop, along with a few supporting utilities (e.g. per-epoch CSV logging and experiment setup).

In [None]:
import sys
sys.path.insert(0, "../")
import os
import json
from glob import glob
import argparse
import copy
import time
import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
from torch.optim import Adam, Adamax
from torch.utils.data import DataLoader
import torch.nn as nn
from torchinfo import summary

from utils.display_utils import loss_curves, display_x_y, display_progress
from torch_datasets import PSFDataset
from vanilla_cnn_solution import VanillaCNN

In [None]:
def save_logs(logs_dic, path_out, iter, filename):
	"""Write one row of metrics to ``<path_out>/logs/<filename>`` as CSV.

	When ``iter == 1`` the file is (re)created and a header row is written;
	for any other value the row is appended without a header, so successive
	calls build up a single table with one row per call. The DataFrame index
	(always ``0``) is written as the first column.

	Args:
		logs_dic: mapping of column name to scalar value for this row.
		path_out: experiment directory; its ``logs`` sub-directory must exist.
		iter: 1-based call counter (``train`` passes the epoch number). The
			name shadows the builtin but is kept for backward compatibility.
		filename: name of the CSV file inside ``<path_out>/logs``.
	"""
	logs_df = pd.DataFrame(data=logs_dic, index=[0])

	first_call = iter == 1
	# Let pandas open the file itself: a text handle opened without
	# ``newline=""`` yields blank lines between rows on Windows.
	logs_df.to_csv(
		os.path.join(path_out, "logs", filename),
		mode="w" if first_call else "a",
		header=first_call,
	)

In [None]:
def train(model,
			  dataset_dir,
			  data_size,
			  loss_fn,
			  optimizer,
			  lr,
			  model_name,
			  save_model_every,
			  plot,
			  plot_every,
			  style,
			  exp_path,
			  batch_size,
			  epochs,
			  device=torch.device("cpu")):
	"""Train ``model`` on a ``PSFDataset``, logging losses and saving checkpoints.

	Each epoch first evaluates the model on the validation split and then runs
	one optimisation pass over the training split. As a result, the
	``valid_loss`` logged for epoch ``e`` is measured with the weights obtained
	after epoch ``e - 1``. The per-epoch row (``epoch``, ``valid_loss``,
	``train_loss``, ``lr``) is written with ``save_logs`` to
	``<exp_path>/logs/logs.csv``.

	Args:
		model: network to optimise. ``train`` never moves it, so it should
			already be on ``device``.
		dataset_dir: directory passed to ``PSFDataset`` as ``dir_path``.
		data_size: forwarded to ``PSFDataset`` as ``size``.
		loss_fn: loss class (e.g. ``nn.MSELoss``); instantiated here.
		optimizer: optimizer class (e.g. ``Adamax``); instantiated here with
			``model.parameters()`` and ``lr``.
		lr: learning rate.
		model_name: filename prefix of the saved ``.pt`` state dicts.
		save_model_every: save a checkpoint every this many epochs; ``None``
			or ``0`` disables intermediate checkpoints.
		plot: whether to display a data sample and training progress.
		plot_every: display progress every this many epochs (when ``plot``);
			``0`` / ``None`` disables the progress display.
		style: plotting style forwarded to the display utilities.
		exp_path: experiment directory; it must already contain ``logs`` and
			``models`` sub-directories.
		batch_size: mini-batch size for both splits.
		epochs: number of epochs to run.
		device: device that batches and the loss module are moved to.
	"""
	criterion = loss_fn().to(device, non_blocking=True)
	# ``optimizer`` is the class; keep the instance under its own name.
	opt = optimizer(model.parameters(), lr=lr)

	start = time.time()

	dataset = PSFDataset(dir_path=dataset_dir, size=data_size, val_frac=0.2, test_frac=0.2)
	display_norm = None
	dset_name = os.path.basename(dataset_dir)
	models_dir = os.path.join(exp_path, "models")

	# ``dataset.method`` is set before each shallow copy so every loader gets its
	# own split (presumably PSFDataset reads ``method`` to select it -- TODO
	# confirm). The loaders do not depend on the epoch, so they are built once;
	# ``shuffle=True`` still draws a fresh order each time a loader is iterated.
	dataset.method = "train"
	train_dataloader = DataLoader(copy.copy(dataset), batch_size=batch_size, shuffle=True)
	dataset.method = "val"
	val_dataloader = DataLoader(copy.copy(dataset), batch_size=batch_size, shuffle=True)

	for epoch in tqdm(range(1, epochs + 1)):
		print('\n')
		print(f'Epoch {epoch}/{epochs}')
		print('-' * 10)

		epoch_logs = {"epoch": epoch}

		# Validation runs before training, i.e. on the previous epoch's weights.
		for phase, dataloader in (("valid", val_dataloader), ("train", train_dataloader)):
			is_train = phase == "train"
			model.train(is_train)

			running_loss = 0.

			for i, (x, y) in enumerate(dataloader):
				x = x.to(device, non_blocking=True)
				y = y.to(device, non_blocking=True)

				if is_train:
					# display what the data looks like (first batch of the run)
					if plot and epoch == 1 and i == 0:
						display_x_y(x.cpu().detach().numpy(), y.cpu().detach().numpy(), dset_name, size=5, norm=display_norm, style=style)

					opt.zero_grad()
					yhat = model(x)					# forward pass
					loss = criterion(yhat, y)		# evaluate loss
					loss.backward()					# backward pass
					opt.step()						# update network

					if plot and plot_every and epoch % plot_every == 0 and i == 0:
						display_progress(x.cpu().detach().numpy(), yhat.cpu().detach().numpy(), y.cpu().detach().numpy(), epoch, norm=display_norm, style=style)
				else:
					with torch.no_grad():
						yhat = model(x)
						loss = criterion(yhat, y)

				running_loss += float(loss.detach())

			# Epoch loss is the mean of the per-batch losses.
			epoch_loss = running_loss / len(dataloader)
			epoch_logs[f"{phase}_loss"] = epoch_loss
			if is_train:
				epoch_logs["lr"] = opt.param_groups[0]['lr']

			print(f'epoch {epoch} ==>  {phase} loss: {epoch_loss:.4e}')

		save_logs(epoch_logs, exp_path, epoch, "logs.csv")

		# Periodic checkpoint; a falsy ``save_model_every`` (None / 0) disables it.
		if save_model_every and epoch % save_model_every == 0:
			torch.save(model.state_dict(), os.path.join(models_dir, f"{model_name}_epoch_{epoch:03d}.pt"))

	time_elapsed = time.time() - start
	print(f'Training completed in {(time_elapsed // 60):.0f}m {(time_elapsed % 60):.0f}s')

	# Final weights, saved regardless of ``save_model_every``.
	torch.save(model.state_dict(), os.path.join(models_dir, f"{model_name}_complete.pt"))

### Experiment parameters

In [None]:
# All inputs and outputs live under $ASTROMATIC_PATH. Fail early with a clear
# message instead of the TypeError that os.path.join(None) would raise.
astromatic_path = os.getenv("ASTROMATIC_PATH")
if astromatic_path is None:
	raise RuntimeError("The ASTROMATIC_PATH environment variable must be set before running this notebook.")

path_in = astromatic_path
path_out = os.path.join(astromatic_path, "Problems", "P7_PSFdeconvolution")
exp_name = "debug"
dataset = "debug_deconv_dataset_2"	# sub-directory of <path_out>/datasets
npix = 256							# model inputs are npix x npix
data_size = None					# forwarded to PSFDataset(size=...)
batch_size = 10
n_epochs = 3
lr = 5e-3
loss = "MSE"
optimizer = "Adamax"
nla = "ReLU"						# activation passed to VanillaCNN
model_type = "VanillaCNN"
model_name = "debug_model"
save_model_every = 20				# > n_epochs: only the final model is saved
plot = True
plot_every = n_epochs
plt_style = "science"
seed = 0

exp_path = os.path.join(path_out, "experiments", exp_name)

### Experiment setup

In [None]:
# Seed the NumPy and PyTorch RNGs so runs are reproducible.
np.random.seed(seed)
torch.manual_seed(seed)

# train() saves state dicts under <exp_path>/models and save_logs() writes
# CSV rows under <exp_path>/logs; both expect the directories to exist.
os.makedirs(os.path.join(exp_path, "models"), exist_ok=True)
os.makedirs(os.path.join(exp_path, "logs"), exist_ok=True)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# - loss function -
loss_functions = {"MSE": nn.MSELoss}
if loss not in loss_functions:
	raise ValueError(f"loss {loss} not found")
loss_fct = loss_functions[loss]

# - optimizer -
# ``optimizer`` starts as a name and is rebound to the class train() expects;
# the isinstance check keeps this cell safe to re-run.
optimizers = {"Adam": Adam, "Adamax": Adamax}
if isinstance(optimizer, str):
	if optimizer not in optimizers:
		raise ValueError(f"optimizer {optimizer} not found")
	optimizer = optimizers[optimizer]

in_ch = 2
out_ch = 1

# --- Model ---
if model_type == "VanillaCNN":
	model = VanillaCNN(npix=npix, in_ch=in_ch, out_ch=out_ch, activation=nla).float()
else:
	raise ValueError(f"model of type {model_type} not found")

print(summary(model, input_size=(batch_size, in_ch, npix, npix)))

model = model.to(device, non_blocking=True)

### Train a model

In [None]:
# Run training with the objects built in the setup cell. Every argument is
# passed by keyword, so the grouping below is purely for readability.
train(
	# model / optimisation
	model=model,
	device=device,
	loss_fn=loss_fct,
	optimizer=optimizer,
	lr=lr,
	epochs=n_epochs,
	batch_size=batch_size,
	# data
	dataset_dir=os.path.join(path_out, "datasets", dataset),
	data_size=data_size,
	# outputs
	exp_path=exp_path,
	model_name=model_name,
	save_model_every=save_model_every,
	# plotting
	plot=plot,
	plot_every=plot_every,
	style=plt_style,
)

In [None]:
# Plot this experiment's loss curves without saving the figure.
# loss_curves receives exp_path, presumably to read the CSV log written during
# training -- TODO confirm.
if plot:
	loss_curves(
		exp_path,
		loss,
		model_name=model_name,
		save=False,
		style=plt_style,
	)