In [1]:
import os 
os.environ['KMP_DUPLICATE_LIB_OK']='True'

import time
import torch
from torch import nn
import torch.nn.functional as F
from torch import Tensor
from torch.utils.data import Dataset, DataLoader
import numpy as np
from einops import rearrange
import lightning as L
from lightning.pytorch.utilities.types import STEP_OUTPUT
from pytorch_lightning.loggers import WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint
from typing import Any
import wandb
import multiprocessing
import matplotlib.pyplot as plt

from TransPath_model import TransPathModel, GridData, PathLogger, TransPathLit

In [2]:
# ---------------------------------------------------------------------------
# Run configuration: every value a reader might want to tune lives here.
# ---------------------------------------------------------------------------

# Hardware: Lightning accelerator name and the list of device indices.
accelerator = "cuda"
devices = [6]

# Dataset root and the `mode` string that is forwarded to GridData,
# PathLogger and TransPathLit and embedded in run / weight-file names.
dataset_dir = './TransPath_data'
mode = 'h'

# Optimisation hyper-parameters.
batch_size = 256
max_epochs = 1
learning_rate = 4e-4
weight_decay = 0.0

# Forwarded unchanged to L.Trainer as limit_train_batches / limit_val_batches.
limit_train_batches = None
limit_val_batches = None

# Weights & Biases project and run naming.
proj_name = 'TransPath_runs'
run_name = 'default'

# Make the last listed device the process-wide default for new tensors.
default_device = torch.device(f"{accelerator}:{devices[-1]}")
torch.set_default_device(default_device)

In [3]:
# Build the training and validation datasets from the `train` and `val`
# sub-directories of `dataset_dir`; both use the same `mode`.
train_data = GridData(path=f'{dataset_dir}/train', mode=mode)
val_data = GridData(path=f'{dataset_dir}/val', mode=mode)

# Two-element tuple built from the training set's `img_size`
# (presumably a square height/width pair; confirm against GridData).
resolution = (train_data.img_size,) * 2

In [4]:
torch.manual_seed(42)

# Options shared by the training and validation loaders.
# NOTE: num_workers=multiprocessing.cpu_count() is intentionally not passed;
# it freezes the Jupyter notebook.
loader_kwargs = dict(batch_size=batch_size, pin_memory=True)

# Training loader: shuffled, with the shuffling generator created on the same
# device string that the configuration cell installs as the default device.
train_dataloader = DataLoader(
    train_data,
    shuffle=True,
    generator=torch.Generator(device=f'{accelerator}:{devices[-1]}'),
    **loader_kwargs,
)

# Validation loader: fixed order, so the first batch is the same on every run.
val_dataloader = DataLoader(val_data, shuffle=False, **loader_kwargs)

# First validation batch, kept for the PathLogger callback created later.
samples = next(iter(val_dataloader))

In [5]:
# Re-seed so that model initialisation is repeatable regardless of how much
# randomness the cells above consumed.
torch.manual_seed(42)

# Callback built from the fixed validation batch `samples`
# (presumably logs predicted paths for 20 of them; see TransPath_model.PathLogger).
callback = PathLogger(samples, mode=mode, num_samples=20)

# Periodic checkpointing. ModelCheckpoint only fires on epochs that are a
# multiple of `every_n_epochs`, so a fixed interval of 50 silently writes
# nothing for runs shorter than 50 epochs (e.g. max_epochs = 1). Clamp the
# interval to the run length so a short run still produces a checkpoint.
checkpoint_every = 50
if max_epochs is not None and 0 < max_epochs < checkpoint_every:
    checkpoint_every = max_epochs
checkpoints = ModelCheckpoint(
    dirpath='checkpoints/',
    filename='{epoch}',
    every_n_epochs=checkpoint_every,
)

# W&B logger; log_model='all' asks it to upload the checkpoints that
# ModelCheckpoint saves.
wandb_logger = WandbLogger(project=proj_name, name=f'{run_name}_{mode}', log_model='all')

# Fresh model. To warm-start instead, load a state dict from disk, e.g.:
#   model_path = './weights/alex_100_h_model'
#   model.load_state_dict(torch.load(model_path, weights_only=True))
model = TransPathModel()

# Lightning wrapper holding the model together with the optimiser settings.
lit_module = TransPathLit(
    model=model,
    mode=mode,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
)

trainer = L.Trainer(
    logger=wandb_logger,
    accelerator=accelerator,
    devices=devices,
    max_epochs=max_epochs,
    deterministic=False,
    limit_train_batches=limit_train_batches,
    limit_val_batches=limit_val_batches,
    callbacks=[callback, checkpoints],
)

/home/drozdovdan/miniconda3/lib/python3.12/site-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['model'])`.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [6]:
# Train, then close the W&B run on every exit path. If the run were left open
# after a failure or interrupt, a later WandbLogger in this kernel would reuse
# it and mix metrics from separate attempts into one run.
try:
    trainer.fit(lit_module, train_dataloader, val_dataloader)
except BaseException:
    # Mark the run as failed rather than finished, then surface the error.
    wandb.finish(exit_code=1)
    raise
else:
    wandb.finish()

You are using a CUDA device ('NVIDIA A100-PCIE-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
[34m[1mwandb[0m: Currently logged in as: [33mdaniil-drozdovjr[0m ([33mdaniil-drozdovjr-saint-petersburg-state-university[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6]
Loading `train_dataloader` to estimate number of stepping batches.
/home/drozdovdan/miniconda3/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=255` in the `DataLoader` to improve performance.

  | Name  | Type           | Params | Mode 
-------------------------------------------------
0 | model | TransPathModel | 963 K  | train
1 | loss  | L1Loss         | 0      | train
-------------------------------------------------
962 K     Trainable params
512       Non-trainable params
963 K     Total params
3.854     Total estimated model params size (MB)
135       Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                                                      | 0/? [00:00<?, ?it/s]

/home/drozdovdan/miniconda3/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=255` in the `DataLoader` to improve performance.


Training: |                                                                             | 0/? [00:00<?, ?it/s]

Validation: |                                                                           | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=1` reached.


0,1
epoch,▁▁
train_loss,▁
trainer/global_step,▁▁██
val_loss,▁

0,1
epoch,0.0
train_loss,56.90691
trainer/global_step,1999.0
val_loss,17.21014


In [8]:
# Save the trained weights (state dict only) under a timestamped file name.
weights_dir = './weights/'

# torch.save does not create missing parent directories; without this a fresh
# checkout would fail here, after training has already finished.
os.makedirs(weights_dir, exist_ok=True)

timestr = time.strftime("%Y%m%d-%H%M%S")
# os.path.join avoids the doubled separator that plain string concatenation
# produced with the trailing slash in `weights_dir`.
torch.save(model.state_dict(), os.path.join(weights_dir, f'model_{mode}_{timestr}'))