In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from time import time

from dataloader import ImageDataset
from model import *
from utils import save_model

# Seed numpy and torch so data shuffling and weight init are reproducible.
SEED = 69
np.random.seed(SEED)
torch.manual_seed(SEED)

# Training is pinned to the CPU by default; set FORCE_CPU = False to use a
# GPU whenever torch reports one as available.
FORCE_CPU = True
device = 'cpu' if FORCE_CPU or not torch.cuda.is_available() else 'cuda'
print(f'Using: {device}')
if device == 'cuda':
    print(torch.cuda.get_device_name())

In [None]:
BATCH_SIZE = 4
EPOCHS = 200
IMAGE_SIZE = 512
# Initial Adam learning rate; the preload cell overwrites it with the
# LR-finder suggestion when not resuming from a checkpoint.
LEARNING_RATE = 3.56E-04

# for checkpoints
SAVE_CHECKPOINTS = False
CHECKPOINT_INTERVAL = 5  # save a periodic checkpoint every N epochs

# for preload
preload = True
preload_optimizer = False
checkpoint_model = 'SRTransformer6_best_86.pth'  # file name inside models/

# Pinned host memory only helps host->GPU copies; on a CPU run it is wasted
# work (and recent PyTorch versions warn about it), so enable it only for CUDA.
PIN_MEMORY = str(device) == 'cuda'

# NOTE(review): the second positional argument (2) is presumably the
# super-resolution scale factor -- confirm against dataloader.ImageDataset.
train_set = ImageDataset("data/train/", 2, size=IMAGE_SIZE)
test_set = ImageDataset("data/validation/", 2, size=IMAGE_SIZE)

train_loader = DataLoader(dataset=train_set, batch_size=BATCH_SIZE, shuffle=True, pin_memory=PIN_MEMORY)
test_loader = DataLoader(dataset=test_set, batch_size=BATCH_SIZE, shuffle=False, pin_memory=PIN_MEMORY)

# init model (nn.Module.to returns the module itself)
model = SRTransformer6().to(device)

# L1 drives training; MSE is only reported during validation
loss_fn = nn.L1Loss()
loss_mse = nn.MSELoss()

# create the optimizer
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
# alternative previously tried: optim.AdamW(model.parameters(), lr=0.00001)

In [None]:
from torch_lr_finder import LRFinder
from utils import save_plot

# LR range test: sweep the learning rate from 1e-6 to 1 over 50 training
# iterations and record the loss at each step.
lr_finder = LRFinder(model, optimizer, loss_fn, device=device)
lr_finder.range_test(train_loader, start_lr=1e-6, end_lr=1, num_iter=50)

# plot() returns (ax, suggested_lr) only when it managed to compute a
# suggestion; otherwise (e.g. too few points after an early divergence) it
# returns just the Axes, so don't unpack blindly.
plot_result = lr_finder.plot()
if isinstance(plot_result, tuple):
    ax, lr = plot_result
else:
    ax, lr = plot_result, None

# restore the model and optimizer to their pre-test state
lr_finder.reset()

if lr is None:
    # No suggestion: fall back to the optimizer's configured learning rate so
    # `lr` is always defined for the preload cell.
    lr = optimizer.param_groups[0]['lr']
    print(f'No LR suggestion; keeping lr={lr:.2E}')

# save figure
fig = ax.get_figure()
save_plot(fig, f'{model.__class__.__name__}_lr_finder')

In [None]:
import os

# Resume from a saved checkpoint, or start fresh using the LR-finder result.
if preload:
    checkpoint_path = os.path.join('models', checkpoint_model)
    # map_location remaps tensors saved on another device (e.g. a GPU) onto
    # the device selected for this run, so CPU-only runs can still load them.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    # continue with the epoch after the one stored in the checkpoint
    epoch = checkpoint['epoch'] + 1
    if preload_optimizer:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
else:
    epoch = 0
    # apply the learning rate chosen by the LR range test
    for g in optimizer.param_groups:
        g['lr'] = lr

In [None]:

def _checkpoint_state(epoch):
    """Bundle the epoch number, model weights and optimizer state for saving."""
    return {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }


all_losses = []        # one list of per-batch training losses per epoch
test_losses = []       # per-epoch mean validation loss (loss_fn)
test_losses_mse = []   # per-epoch mean validation MSE
best_loss = None
start_epoch = epoch
for epoch in range(start_epoch, EPOCHS):
    t0 = time()
    all_losses.append([])
    model.train()  # training mode
    for batch in train_loader:
        # load data to the device
        x, y = batch[0].to(device), batch[1].to(device)

        # set_to_none drops the old grad tensors instead of zero-filling them
        optimizer.zero_grad(set_to_none=True)

        # call the module rather than .forward() so registered hooks run
        out = model(x)
        loss = loss_fn(out, y)
        loss.backward()
        optimizer.step()

        # hold the loss
        all_losses[-1].append(loss.item())

    model.eval()  # evaluation mode
    t1 = time()
    with torch.no_grad():
        test_loss = 0.0
        test_loss_mse = 0.0
        for batch in test_loader:
            # load data to the device
            x, y = batch[0].to(device), batch[1].to(device)
            out = model(x)
            test_loss += loss_fn(out, y).item()
            test_loss_mse += loss_mse(out, y).item()
    # Average both metrics per batch so they are reported on the same scale
    # (previously the MSE was printed as a raw sum over batches).
    test_losses.append(test_loss / len(test_loader))
    test_losses_mse.append(test_loss_mse / len(test_loader))
    train_loss = sum(all_losses[-1]) / len(train_loader)
    print(f'{epoch}: Val loss (MSE): {test_losses_mse[-1]:.6f} | Val loss: {test_losses[-1]:.6f} | loss: {train_loss:.6f} | Train time: {t1-t0:.2f} | Test time: {time()-t1:.2f}')

    # periodic checkpoint on every CHECKPOINT_INTERVAL-th epoch
    if SAVE_CHECKPOINTS and epoch % CHECKPOINT_INTERVAL == CHECKPOINT_INTERVAL - 1:
        save_model(_checkpoint_state(epoch), 'models', f'{model.__class__.__name__}_{epoch:02d}.pth')

    # save the best model (lowest validation loss seen in this run)
    if best_loss is None or test_losses[-1] < best_loss:
        best_loss = test_losses[-1]
        save_model(_checkpoint_state(epoch), 'models', f'{model.__class__.__name__}_best_{epoch:02d}.pth')

# Only write a "final" checkpoint if at least one epoch actually ran; otherwise
# `epoch` still holds the untrained start value and resuming from it would
# silently skip that epoch.
if test_losses:
    save_model(_checkpoint_state(epoch), 'models', f'{model.__class__.__name__}_{epoch:02d}_final.pth')
else:
    print(f'No epochs run (start epoch {start_epoch} >= EPOCHS {EPOCHS}); final checkpoint not saved.')



In [None]:
import matplotlib.pyplot as plt
from utils import save_plot

# mean training loss per epoch (average over that epoch's batches)
train_losses = [sum(l)/len(train_loader) for l in all_losses]
epoch_data = list(range(start_epoch, start_epoch + len(train_losses)))

fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(epoch_data, test_losses, label='validation')
ax.plot(epoch_data, train_losses, label='train')
ax.set(title=f'{model.__class__.__name__} training curves', xlabel='epoch', ylabel='loss')
ax.legend()

# Save before plt.show() so the written file does not depend on any
# backend-specific show/close behaviour.
save_plot(fig, f'{model.__class__.__name__}_losses')
plt.show()