In [2]:
import os 
os.environ['KMP_DUPLICATE_LIB_OK']='True'

import time
import torch
from torch import nn
import torch.nn.functional as F
from torch import Tensor
from torch.utils.data import Dataset, DataLoader
import numpy as np
from einops import rearrange
import lightning as L
from lightning.pytorch.utilities.types import STEP_OUTPUT
from pytorch_lightning.loggers import WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint
from typing import Any
import wandb
import multiprocessing
import matplotlib.pyplot as plt

from TransPath_model import TransPathModel, GridData, PathLogger, TransPathLit

In [None]:
# Run configuration: every value a reader might tune lives in this cell.

# -- Data --
dataset_dir = './TransPath_data'   # root folder; 'train' and 'val' subfolders are read from it
mode = 'h'                         # forwarded to GridData, PathLogger and TransPathLit

# -- Optimisation --
batch_size = 256
max_epochs = 200
learning_rate = 4e-4
weight_decay = 0.0

# -- Trainer batch limits (forwarded unchanged to L.Trainer) --
limit_train_batches = None
limit_val_batches = None

# -- Experiment tracking --
proj_name = 'TransPath_runs'
run_name = 'default'

# -- Hardware --
accelerator = "cuda"
devices = [6]
# Make the last listed GPU the default device for newly created tensors.
torch.set_default_device(torch.device("cuda", devices[-1]))

In [None]:
# Build the training and validation datasets from the 'train' and 'val'
# subfolders of `dataset_dir`, both using the same `mode`.
train_data, val_data = (
    GridData(path=f'{dataset_dir}/{split}', mode=mode)
    for split in ('train', 'val')
)

# Square (height, width) pair taken from the training set's image size.
side = train_data.img_size
resolution = (side, side)

In [None]:
torch.manual_seed(42)

# Options shared by both loaders.
# NOTE: num_workers=multiprocessing.cpu_count() is intentionally left out --
# it freezes the Jupyter notebook.
loader_options = dict(batch_size=batch_size, pin_memory=True)

# The shuffling generator is created on the same GPU that was set as the
# default device in the configuration cell.
shuffle_generator = torch.Generator(device=f'cuda:{devices[-1]}')

train_dataloader = DataLoader(
    train_data,
    shuffle=True,
    generator=shuffle_generator,
    **loader_options,
)
val_dataloader = DataLoader(val_data, shuffle=False, **loader_options)

# First validation batch, kept for the PathLogger callback.
samples = next(iter(val_dataloader))

In [None]:
# Re-seed torch's global RNG before the objects below are constructed.
torch.manual_seed(42)

# Callbacks and logger.
callback = PathLogger(samples, mode=mode, num_samples=20)
checkpoints = ModelCheckpoint(
    dirpath='checkpoints/',
    filename='{epoch}',
    every_n_epochs=50,
)
wandb_logger = WandbLogger(
    project=proj_name,
    name=f'{run_name}_{mode}',
    log_model='all',
)

# Fresh model. To start from saved weights instead, load a state dict right
# after construction, e.g.:
#   model_path = './weights/alex_100_h_model'
#   model.load_state_dict(torch.load(model_path, weights_only=True))
model = TransPathModel()
lit_module = TransPathLit(
    model=model,
    mode=mode,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
)

# Trainer settings gathered in one place, then unpacked into L.Trainer.
trainer_options = dict(
    logger=wandb_logger,
    accelerator=accelerator,
    devices=devices,
    max_epochs=max_epochs,
    deterministic=False,
    limit_train_batches=limit_train_batches,
    limit_val_batches=limit_val_batches,
    callbacks=[callback, checkpoints],
)
trainer = L.Trainer(**trainer_options)

In [None]:
# Train and validate. The W&B run is closed in a `finally` clause so that a
# failed or interrupted fit() does not leave the run open in this kernel.
try:
    trainer.fit(lit_module, train_dataloader, val_dataloader)
finally:
    wandb.finish()

In [None]:
# Save the trained model's state dict under a timestamped file name.
weights_dir = './weights/'
# Create the output folder first: torch.save does not create missing
# directories, and failing here would discard the result of the whole run.
os.makedirs(weights_dir, exist_ok=True)
timestr = time.strftime("%Y%m%d-%H%M%S")
# os.path.join avoids the doubled slash that plain string formatting would
# produce with the trailing '/' in `weights_dir`; the target file is the same.
torch.save(model.state_dict(), os.path.join(weights_dir, f'model_{mode}_{timestr}'))