In [1]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from upscaler import SuperResNet
from dataset import SuperResDataset
from lit_upscaler import ImageLoggerCallback, LitSuperResNet
import matplotlib.pyplot as plt



# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


In [2]:
from dotenv import load_dotenv
# Load key=value pairs from a .env file into the process environment so the
# os.getenv lookups for TRAIN_DATA_PATH / VAL_DATA_PATH can find them.
# The call is the cell's last expression, so its return value is displayed.
load_dotenv()

True

In [3]:
# NOTE(review): `images`, `lesions` and `imread` are not used in the visible
# cells; they look like leftovers and are kept only in case a later cell
# relies on them.
images = []
lesions = []
from skimage.io import imread
import os


def _require_env(name):
    """Return the value of environment variable `name`.

    Raises:
        RuntimeError: if the variable is unset or empty, so a missing .env
            entry is reported here instead of surfacing as an obscure failure
            when the dataset is built from a `None` path.
    """
    value = os.getenv(name)
    if not value:
        raise RuntimeError(
            f"Environment variable {name} is not set; "
            "define it in .env or in the shell before running this cell"
        )
    return value


# Data locations come from the environment so that machine-specific paths
# stay out of the notebook.
TRAIN_PATH = _require_env("TRAIN_DATA_PATH")
VAL_PATH = _require_env("VAL_DATA_PATH")

print(f"Train data: {TRAIN_PATH}")
print(f"Validation data: {VAL_PATH}")

# Both splits are built with identical settings; keeping them in one place
# prevents the train and validation configuration from drifting apart.
dataset_kwargs = dict(crop_size=64, downscale_denoise=2)

train_dataset = SuperResDataset(TRAIN_PATH, **dataset_kwargs)
val_dataset = SuperResDataset(VAL_PATH, **dataset_kwargs)

print(f'Loaded {len(train_dataset)} train images, {len(val_dataset)} val images')

Train data: /mnt/c/Users/efimplotnikov/Pictures/2016 - велики/Осень 2016
Validation data: /mnt/c/Users/efimplotnikov/Pictures/2016 - велики/Лето 2016/07.23 - 07.24 - вырица всей толпой, грибы
Loaded 698 train mages, 121 val images


In [4]:
import lightning as L

from lightning.pytorch.callbacks import ModelCheckpoint
# Take the logger from the same `lightning` distribution as the Trainer and
# the callbacks. The standalone `pytorch_lightning` package is a separate
# install, and mixing the two namespaces in one Trainer is not supported.
from lightning.pytorch.loggers import TensorBoardLogger


# Allow reduced-precision float32 matmuls (trades some accuracy for speed).
torch.set_float32_matmul_precision('medium')

# Keep the three checkpoints with the lowest monitored "val/loss".
# Assumes LitSuperResNet logs a metric under exactly that key - TODO confirm.
checkpoint_cb = ModelCheckpoint(
    dirpath="checkpoints",
    filename="upscaler-{epoch:03d}",
    save_top_k=3,
    monitor="val/loss",
    mode="min"
)

# Separate, unshuffled loader handed to the image-logging callback.
val_dl_for_logging = DataLoader(val_dataset, batch_size=8, shuffle=False)
logger_cb = ImageLoggerCallback(val_dl_for_logging, log_every_n_epochs=1)

# Used as the TensorBoard run name (lightning_logs/<COMMENT>/version_3).
COMMENT = "v1"

# The Trainer also handles device placement, so the explicit .to(DEVICE) is
# redundant but harmless.
# NOTE(review): the recorded run aborts during the sanity check with a channel
# mismatch inside SuperResNet (a conv expecting 128 input channels receives
# 64). That has to be fixed in the model definition, not in this cell.
lit = LitSuperResNet(lr=1e-4).to(DEVICE)

trainer = L.Trainer(
    max_epochs=600,
    logger=TensorBoardLogger("lightning_logs", name=COMMENT, version=3),
    log_every_n_steps=5,
    callbacks=[checkpoint_cb, logger_cb],
)

# Settings shared by the train and validation loaders; only shuffling differs.
loader_kwargs = dict(batch_size=16, num_workers=8, persistent_workers=True)
train_dl = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_dl = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
trainer.fit(lit, train_dl, val_dl)

  x = torch.ByteTensor(torch.ByteStorage.from_buffer(img.tobytes()))
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type        | Params | Mode  | FLOPs
------------------------------------------------------
0 | model | SuperResNet | 998 K  | train | 0    
------------------------------------------------------
998 K     Trainable params
0         Non-trainable params
998 K     Total params
3.994     Total estimated model params size (MB)
37        Modules in train mode
0         Modules in eval mode
0         Total Flops


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

RuntimeError: Given groups=1, weight of size [128, 128, 3, 3], expected input[16, 64, 31, 31] to have 128 channels, but got 64 channels instead