In [2]:
import os, sys

def find_project_root(start_path, marker="src"):
    """Walk upward from ``start_path`` to the first directory containing ``marker``.

    Parameters
    ----------
    start_path : str or os.PathLike
        Path to start searching from (made absolute first). The start
        directory itself is checked before its parents.
    marker : str, default "src"
        Name of a sub-directory whose presence identifies the project root.

    Returns
    -------
    str
        Absolute path of the first ancestor-or-self that has a ``marker``
        sub-directory.

    Raises
    ------
    RuntimeError
        If the filesystem root is reached without finding ``marker``.
    """
    start = os.path.abspath(start_path)
    cur = start
    while True:
        if os.path.isdir(os.path.join(cur, marker)):
            return cur
        parent = os.path.dirname(cur)
        if parent == cur:  # dirname() of a filesystem root is the root itself
            break
        cur = parent
    raise RuntimeError(
        f"Project root not found: no '{marker}' directory in {start} or any of its parents."
    )

ROOT = find_project_root(os.getcwd())

# Put the project root at the *front* of sys.path so its `src` package takes
# precedence over any other `src` importable from the environment, and drop
# earlier copies so re-running this cell does not keep growing sys.path.
while ROOT in sys.path:
    sys.path.remove(ROOT)
sys.path.insert(0, ROOT)
os.chdir(ROOT)

print("ROOT:", ROOT)
print("CWD :", os.getcwd())


ROOT: c:\Users\sam\Documents\VSCode_tunnel\cv_2024_upscale
CWD : c:\Users\sam\Documents\VSCode_tunnel\cv_2024_upscale


In [3]:
from src.config.train_config import TrainConfig as cfg

# TrainConfig fields to echo, in display order (name left-aligned to 15 chars).
keys = [
    "lr_dir", "hr_dir", "patch_size",
    "batch_size", "learning_rate", "num_epochs",
    "model_name", "checkpoint_dir", "num_workers", "save_every", "exp_name",
]

for k in keys:
    value = getattr(cfg, k)
    print(f"{k:<15} = {value}")

lr_dir          = data/train_lr
hr_dir          = data/train_hr
patch_size      = None
batch_size      = 16
learning_rate   = 0.0001
num_epochs      = 50
model_name      = unet
checkpoint_dir  = models_ckpt
num_workers     = 2
save_every      = 10
exp_name        = None


In [4]:
def set_train_config(
    model="unet",
    patch=None,
    bs=16,
    lr=1e-4,
    epochs=50,
    ckpt_dir="models_ckpt",
    num_workers=2,
    save_every=10,
    exp_name=None
):
    """Overwrite the training settings on ``TrainConfig`` and print a summary.

    The values are assigned as attributes on the ``TrainConfig`` class itself,
    so any code that reads ``TrainConfig`` afterwards sees them. Every argument
    except ``exp_name`` is always written, so omitted arguments are reset to
    the defaults in this signature.

    Parameters
    ----------
    model : str
        Stored as ``model_name``.
    patch : int or None
        Stored as ``patch_size``; ``None`` = full image, e.g. ``32`` = patch training.
    bs : int
        Stored as ``batch_size``.
    lr : float
        Stored as ``learning_rate``.
    epochs : int
        Stored as ``num_epochs``.
    ckpt_dir : str
        Stored as ``checkpoint_dir``.
    num_workers : int
        Stored as ``num_workers``.
    save_every : int
        Stored as ``save_every``.
    exp_name : str or None
        Stored as ``exp_name`` only when not ``None``; passing ``None`` keeps
        whatever experiment name was set previously.
    """
    # Local import keeps the function usable even if the config cell was not run.
    from src.config.train_config import TrainConfig as cfg

    cfg.model_name     = model
    cfg.patch_size     = patch
    cfg.batch_size     = bs
    cfg.learning_rate  = lr
    cfg.num_epochs     = epochs
    cfg.checkpoint_dir = ckpt_dir
    cfg.num_workers    = num_workers
    cfg.save_every     = save_every
    if exp_name is not None:
        cfg.exp_name = exp_name

    # (display label, TrainConfig attribute) pairs; labels padded to 11 chars
    # so the "=" signs line up exactly as in the original hand-written prints.
    summary_fields = (
        ("model", "model_name"),
        ("patch_size", "patch_size"),
        ("batch_size", "batch_size"),
        ("lr", "learning_rate"),
        ("epochs", "num_epochs"),
        ("ckpt_dir", "checkpoint_dir"),
        ("num_workers", "num_workers"),
        ("save_every", "save_every"),
        ("exp_name", "exp_name"),
    )
    print("✅ TrainConfig updated:")
    for label, attr in summary_fields:
        print(f"  {label:<11} = {getattr(cfg, attr)}")

In [7]:
# Settings for the SRCNN run; unpacked into set_train_config below.
srcnn_run = dict(
    model="srcnn",
    patch=None,       # None = full image, 32 = patch training
    epochs=10,
    save_every=5,
    bs=16,
    lr=1e-4,
    ckpt_dir="models_ckpt",
    num_workers=2,
)
set_train_config(**srcnn_run)

✅ TrainConfig updated:
  model       = srcnn
  patch_size  = None
  batch_size  = 16
  lr          = 0.0001
  epochs      = 10
  ckpt_dir    = models_ckpt
  num_workers = 2
  save_every  = 5
  exp_name    = None


In [8]:
from src.config.train_config import TrainConfig as cfg
from src.train_for_unet import train as train_unet
from src.train_for_srcnn import train as train_srcnn

# Supported model names (lower-case) -> training entry point.
TRAINERS = {
    "unet": train_unet,
    "srcnn": train_srcnn,
}

model_key = (cfg.model_name or "").lower()
if model_key not in TRAINERS:
    # Fail loudly: previously an unknown/misspelled name silently skipped training.
    raise ValueError(
        f"Unknown cfg.model_name {cfg.model_name!r}; expected one of {sorted(TRAINERS)}"
    )
# Assigned (not a bare expression) so the notebook does not echo the return value.
train_result = TRAINERS[model_key]()

Training Device: cuda
Found 492 image pairs.


Epoch 1/10:   0%|          | 0/31 [00:00<?, ?it/s]

Epoch 1/10: 100%|██████████| 31/31 [00:05<00:00,  5.77it/s, loss=0.125028]
Epoch 2/10: 100%|██████████| 31/31 [00:05<00:00,  5.78it/s, loss=0.076046]
Epoch 3/10: 100%|██████████| 31/31 [00:05<00:00,  5.74it/s, loss=0.074954]
Epoch 4/10: 100%|██████████| 31/31 [00:05<00:00,  5.74it/s, loss=0.076809]
Epoch 5/10: 100%|██████████| 31/31 [00:05<00:00,  5.74it/s, loss=0.075873]
Epoch 6/10: 100%|██████████| 31/31 [00:05<00:00,  5.77it/s, loss=0.069137]
Epoch 7/10: 100%|██████████| 31/31 [00:05<00:00,  5.88it/s, loss=0.082879]
Epoch 8/10: 100%|██████████| 31/31 [00:05<00:00,  5.74it/s, loss=0.078232]
Epoch 9/10: 100%|██████████| 31/31 [00:05<00:00,  5.77it/s, loss=0.073930]
Epoch 10/10: 100%|██████████| 31/31 [00:05<00:00,  5.75it/s, loss=0.060932]

🎉 Training Finished! Final model saved to: models_ckpt\srcnn_psfull_bs16_lr1e-4_final.pth





In [12]:
from src.eval import evaluate

model_1 = "srcnn_psfull_bs16_lr1e-4_final.pth"
model_2 = "srcnn_ps32_bs16_lr1e-4_final.pth"

ckpt_dir = "models_ckpt"
# NOTE(review): these are the *training* pairs, so the scores measure fit rather
# than generalization -- point these at a held-out split if one exists.
eval_lr_dir = "data/train_lr"
eval_hr_dir = "data/train_hr"

# NOTE(review): the saved traceback for this cell shows `evaluate` building a
# UNetSR even for an SRCNN checkpoint ("Unexpected key(s) ... conv1.weight").
# Check whether `src.eval.evaluate` accepts a model-name/architecture argument
# and pass it here -- TODO confirm its signature.
eval_results = {}
for i, ckpt_name in enumerate((model_1, model_2)):
    separator = "\n" if i else ""
    print(f"{separator}Evaluating model: {ckpt_name}")
    ckpt_path = os.path.join(ckpt_dir, ckpt_name)
    if not os.path.isfile(ckpt_path):
        # Fail early with the exact path instead of deep inside checkpoint loading.
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
    # Keep whatever `evaluate` returns (possibly None) for later comparison.
    eval_results[ckpt_name] = evaluate(
        lr_dir=eval_lr_dir,
        hr_dir=eval_hr_dir,
        checkpoint=ckpt_path,
    )

Evaluating model: srcnn_psfull_bs16_lr1e-4_final.pth


RuntimeError: Error(s) in loading state_dict for UNetSR:
	Missing key(s) in state_dict: "inc.double_conv.0.weight", "inc.double_conv.1.weight", "inc.double_conv.1.bias", "inc.double_conv.1.running_mean", "inc.double_conv.1.running_var", "inc.double_conv.3.weight", "inc.double_conv.4.weight", "inc.double_conv.4.bias", "inc.double_conv.4.running_mean", "inc.double_conv.4.running_var", "down1.1.double_conv.0.weight", "down1.1.double_conv.1.weight", "down1.1.double_conv.1.bias", "down1.1.double_conv.1.running_mean", "down1.1.double_conv.1.running_var", "down1.1.double_conv.3.weight", "down1.1.double_conv.4.weight", "down1.1.double_conv.4.bias", "down1.1.double_conv.4.running_mean", "down1.1.double_conv.4.running_var", "down2.1.double_conv.0.weight", "down2.1.double_conv.1.weight", "down2.1.double_conv.1.bias", "down2.1.double_conv.1.running_mean", "down2.1.double_conv.1.running_var", "down2.1.double_conv.3.weight", "down2.1.double_conv.4.weight", "down2.1.double_conv.4.bias", "down2.1.double_conv.4.running_mean", "down2.1.double_conv.4.running_var", "down3.1.double_conv.0.weight", "down3.1.double_conv.1.weight", "down3.1.double_conv.1.bias", "down3.1.double_conv.1.running_mean", "down3.1.double_conv.1.running_var", "down3.1.double_conv.3.weight", "down3.1.double_conv.4.weight", "down3.1.double_conv.4.bias", "down3.1.double_conv.4.running_mean", "down3.1.double_conv.4.running_var", "conv_up1.double_conv.0.weight", "conv_up1.double_conv.1.weight", "conv_up1.double_conv.1.bias", "conv_up1.double_conv.1.running_mean", "conv_up1.double_conv.1.running_var", "conv_up1.double_conv.3.weight", "conv_up1.double_conv.4.weight", "conv_up1.double_conv.4.bias", "conv_up1.double_conv.4.running_mean", "conv_up1.double_conv.4.running_var", "conv_up2.double_conv.0.weight", "conv_up2.double_conv.1.weight", "conv_up2.double_conv.1.bias", "conv_up2.double_conv.1.running_mean", "conv_up2.double_conv.1.running_var", "conv_up2.double_conv.3.weight", "conv_up2.double_conv.4.weight", "conv_up2.double_conv.4.bias", 
"conv_up2.double_conv.4.running_mean", "conv_up2.double_conv.4.running_var", "conv_up3.double_conv.0.weight", "conv_up3.double_conv.1.weight", "conv_up3.double_conv.1.bias", "conv_up3.double_conv.1.running_mean", "conv_up3.double_conv.1.running_var", "conv_up3.double_conv.3.weight", "conv_up3.double_conv.4.weight", "conv_up3.double_conv.4.bias", "conv_up3.double_conv.4.running_mean", "conv_up3.double_conv.4.running_var", "outc.weight", "outc.bias". 
	Unexpected key(s) in state_dict: "conv1.weight", "conv1.bias", "conv2.weight", "conv2.bias", "conv3.weight", "conv3.bias". 