In [33]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
import random
import numpy as np

from models.ar_cnn import AR_CNN, QuickLoss  
from utils.online_patch_dataset import OnlinePatchDataset
from utils.metrics import calculate_psnr, calculate_ssim

In [34]:
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [35]:
set_seed(42)

num_epochs = 50
batch_size = 32
learning_rate = 1e-3
grad_weight = 0.5

patchs_per_frame = 6
patch_size = 64

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = AR_CNN().to(device)
criterion = QuickLoss(grad_weight=grad_weight)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
transform = transforms.ToTensor()

# Diretórios dos dados
train_original_dir = "data/frames_y/train/original"
train_processed_dir = "data/frames_y/train/qp63"
val_original_dir = "data/frames_y/val/original"
val_processed_dir = "data/frames_y/val/qp63"

In [36]:
# Criação dos datasets

train_dataset = OnlinePatchDataset(
    original_dir = train_original_dir,
    processed_dir = train_processed_dir,
    patch_size = patch_size,
    patches_per_frame = patchs_per_frame,
    transform = transform
)

val_dataset = OnlinePatchDataset(
    original_dir = val_original_dir,
    processed_dir = val_processed_dir,
    patch_size = patch_size,
    patches_per_frame = patchs_per_frame,
    transform = transform
)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle= False, num_workers=4)

In [41]:
def train_loop():
    for epoch in range(0, num_epochs + 1):
        print(f'Starting epoch {epoch}')
        model.train()
        running_loss = 0.0

        for batch_idx, (orig, proc) in enumerate(train_loader):
            orig = orig.to(device)
            proc = proc.to(device)

            optimizer.zero_grad()
            output = model(proc)
            loss = criterion(output, orig)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            if batch_idx % 100 == 0:
                print(f'Epoch: {epoch}/{num_epochs} - Batch: {batch_idx}/{len(train_loader)} - loss: {loss:.6f}' )
        
        avg_train_loss = running_loss / len(train_loader)
        print(f'\navg_loss: {avg_train_loss}')

        model.eval()
        total_psnr = 0.0
        baseline_total_psnr = 0.0
        total_ssim = 0.0
        baseline_total_ssim = 0.0   

        val_loss = 0.0
        with torch.no_grad():
            for orig, proc in val_loader:
                orig = orig.to(device)
                proc = proc.to(device)
                output = model(proc)
                loss = criterion(output, orig)
                val_loss += loss.item()

                total_psnr += calculate_psnr(output, orig)
                total_ssim += calculate_ssim(output, orig)

                baseline_total_psnr += calculate_psnr(proc, orig)
                baseline_total_ssim += calculate_ssim(proc, orig)


        avg_psnr = total_psnr / len(val_loader)
        avg_baseline_psnr = baseline_total_psnr / len(val_loader)
        avg_ssim = total_ssim / len(val_loader)
        avg_baseline_ssim = baseline_total_ssim / len(val_loader)
        
        summary_msg = (
            f"Epoch [{epoch}/{num_epochs}], Loss: {avg_train_loss:.4f}, "
            f"Val PSNR (Rede): {avg_psnr:.2f} dB, Baseline PSNR: {avg_baseline_psnr:.2f} dB, "
            f"Val SSIM (Rede): {avg_ssim:.4f}, Baseline SSIM: {avg_baseline_ssim:.4f}"
        )
        print(summary_msg)
        #avg_val_loss = val_loss / len(val_loader)

        # Salva o modelo a cada época (opcional)
        checkpoint_path = f"model_epoch_{epoch}.pth"
        torch.save(model.state_dict(), checkpoint_path)
        print(f"Modelo salvo em {checkpoint_path}")


In [42]:
train_loop()
    

Starting epoch 0
Epoch: 0/50 - Batch: 0/856 - loss: 0.032921
Epoch: 0/50 - Batch: 100/856 - loss: 0.028255
Epoch: 0/50 - Batch: 200/856 - loss: 0.033035
Epoch: 0/50 - Batch: 300/856 - loss: 0.030309
Epoch: 0/50 - Batch: 400/856 - loss: 0.044145
Epoch: 0/50 - Batch: 500/856 - loss: 0.027116
Epoch: 0/50 - Batch: 600/856 - loss: 0.029818
Epoch: 0/50 - Batch: 700/856 - loss: 0.034245
Epoch: 0/50 - Batch: 800/856 - loss: 0.037121

avg_loss: 0.03209924878265326
Epoch [0/50], Loss: 0.0321, Val PSNR (Rede): 22.29 dB, Baseline PSNR: 22.29 dB, Val SSIM (Rede): 0.5506, Baseline SSIM: 0.5508
Modelo salvo em model_epoch_0.pth
Starting epoch 1
Epoch: 1/50 - Batch: 0/856 - loss: 0.028126
Epoch: 1/50 - Batch: 100/856 - loss: 0.034965
Epoch: 1/50 - Batch: 200/856 - loss: 0.034166
Epoch: 1/50 - Batch: 300/856 - loss: 0.028730
Epoch: 1/50 - Batch: 400/856 - loss: 0.028679
Epoch: 1/50 - Batch: 500/856 - loss: 0.036695
Epoch: 1/50 - Batch: 600/856 - loss: 0.034961
Epoch: 1/50 - Batch: 700/856 - loss: 0.033

KeyboardInterrupt: 