In [None]:
import os

import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
from tensorboardX import SummaryWriter
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm

%run Model.ipynb
%run DataLoader.ipynb

# Seed torch's RNGs (weight initialisation, DataLoader shuffling) so runs are repeatable.
# Whether the Dataset's own random cropping/augmentation is covered depends on which
# RNG DataLoader.ipynb uses — TODO confirm.
SEED = 0
torch.manual_seed(SEED)

# cuDNN autotuning benchmarks conv algorithms per input shape and caches the fastest one.
# It pays off when shapes repeat (fixed patch_size and drop_last=True in the training
# loader), but results are then not bit-exact across runs even with a fixed seed.
cudnn.benchmark = True
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
# ---- Configuration -----------------------------------------------------------
# NOTE(review): absolute, machine-specific dataset location; point it at your own copy.
dataset_root = "/home/mj/HardDisk/Github/Image_Restorer/Dataset"
train_dir = os.path.join(dataset_root, "DIV2K_train_HR")
valid_dir = os.path.join(dataset_root, "DIV2K_valid_HR")
outputs_dir = "../Model"   # one checkpoint file per epoch is written here
batch_size = 8
threads = 4                # DataLoader worker processes
lr = 5e-4                  # initial Adam learning rate
num_epochs = 100
patch_size = 128           # passed to Dataset; presumably the training crop size — TODO confirm
jpeg_quality = 10          # passed to Dataset; presumably the JPEG quality of the degraded inputs
use_augmentation = True    # applied to the training set only (validation passes False below)
use_fast_loader = False

# ---- Model, loss, optimizer --------------------------------------------------
model = My_Model().to(device)
# reduction='sum' sums the squared error over every element of the batch, so the loss
# magnitude scales with batch_size and patch_size; keep that in mind when comparing runs.
criterion = nn.MSELoss(reduction='sum')
optimizer = optim.Adam(model.parameters(), lr=lr)

# ---- Data --------------------------------------------------------------------
train_dataset = Dataset(train_dir, patch_size, jpeg_quality, use_augmentation, use_fast_loader)
dataloader_train = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True,
                              num_workers=threads, pin_memory=True, drop_last=True)

# Validation should measure the same thing every epoch, so it gets no augmentation and
# no shuffling. drop_last=True is kept on purpose: with reduction='sum' a short final
# batch yields a smaller per-batch loss and would skew the average against full batches.
# (If Dataset crops patches at random, validation loss will still vary slightly between
# epochs — TODO confirm against DataLoader.ipynb.)
valid_dataset = Dataset(valid_dir, patch_size, jpeg_quality, False, use_fast_loader)
dataloader_valid = DataLoader(dataset=valid_dataset, batch_size=batch_size, shuffle=False,
                              num_workers=threads, pin_memory=True, drop_last=True)

writer = SummaryWriter("../log")

# Halve the learning rate every 20 epochs. The previous version re-created the Adam
# optimizer under `if((epoch+1)%20):`, which is true on 19 of every 20 epochs: Adam's
# moment estimates were discarded almost every epoch and, since `lr` itself was never
# updated, the rate stayed pinned at 0.5 * lr instead of decaying. StepLR keeps the
# existing optimizer (and its state) and applies the intended step decay.
lr_decay_every = 20
lr_decay_factor = 0.5
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=lr_decay_every, gamma=lr_decay_factor)

os.makedirs(outputs_dir, exist_ok=True)  # torch.save does not create missing directories

# Batches per epoch; stride of the TensorBoard global step so it advances by one per batch.
steps_per_epoch = len(dataloader_train)


def train_one_epoch(epoch):
    """Train ``model`` for one pass over ``dataloader_train``.

    Args:
        epoch: zero-based epoch index, used for the progress-bar label and the
            TensorBoard global step.

    Returns:
        ``(train_losses, target_losses)``: AverageMeters of ``criterion(preds, labels)``
        and of the baseline ``criterion(inputs, labels)`` (loss of the input as-is).
    """
    model.train()
    train_losses = AverageMeter()
    target_losses = AverageMeter()

    # The progress bar counts samples, not batches: it is advanced by len(inputs).
    n_samples = len(train_dataset) - len(train_dataset) % batch_size
    with tqdm(total=n_samples) as pbar:
        pbar.set_description('epoch: {}/{}'.format(epoch + 1, num_epochs))
        for step, data in enumerate(dataloader_train):
            inputs, labels = data[0].to(device), data[1].to(device)

            preds = model(inputs)
            loss = criterion(preds, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Baseline: error of the unprocessed input, for comparison with `loss`.
            with torch.no_grad():
                target = criterion(inputs, labels)

            train_losses.update(loss.item(), len(inputs))
            target_losses.update(target.item(), len(inputs))

            pbar.set_postfix(loss='{:.6f}, target={:.6f}'.format(train_losses.avg, target_losses.avg))
            pbar.update(len(inputs))

            # One global step per batch, continuous across epochs.
            writer.add_scalar('Train//Loss', train_losses.avg, epoch * steps_per_epoch + step)

    return train_losses, target_losses


def evaluate():
    """Return an AverageMeter of ``criterion(model(inputs), labels)`` over ``dataloader_valid``.

    Runs in eval mode and under ``torch.no_grad()`` so no autograd graph is built and
    training-mode layer behaviour does not leak into the validation loss.
    """
    model.eval()
    valid_losses = AverageMeter()
    with torch.no_grad():
        for data in dataloader_valid:
            inputs, labels = data[0].to(device), data[1].to(device)
            loss = criterion(model(inputs), labels)
            valid_losses.update(loss.item(), len(inputs))
    return valid_losses


try:
    for epoch in range(num_epochs):
        train_one_epoch(epoch)
        valid_losses = evaluate()

        writer.add_scalar('Valid//Loss', valid_losses.avg, epoch)
        writer.add_scalar('Train//LR', scheduler.get_last_lr()[0], epoch)
        writer.flush()  # once per epoch; flushing after every batch only adds I/O

        torch.save(model.state_dict(), os.path.join(outputs_dir, 'My_Model_epoch_{}.pth'.format(epoch)))
        scheduler.step()
finally:
    # Close the writer even if training is interrupted (e.g. KeyboardInterrupt).
    writer.close()

epoch: 1/100: 100%|██████████| 1296/1296 [00:28<00:00, 44.99it/s, loss=20654.657272, target1114.252736]
epoch: 2/100: 100%|██████████| 1296/1296 [00:31<00:00, 40.81it/s, loss=9594.526415, target1116.258057] 
epoch: 3/100: 100%|██████████| 1296/1296 [00:32<00:00, 40.44it/s, loss=7011.711272, target1109.246432]
epoch: 4/100: 100%|██████████| 1296/1296 [00:34<00:00, 37.74it/s, loss=6854.404744, target1099.734469]
epoch: 5/100: 100%|██████████| 1296/1296 [00:33<00:00, 38.54it/s, loss=6576.793816, target1121.778900]
epoch: 6/100: 100%|██████████| 1296/1296 [00:33<00:00, 38.24it/s, loss=6346.408656, target1116.162322]
epoch: 7/100: 100%|██████████| 1296/1296 [00:34<00:00, 38.03it/s, loss=7303.589630, target1113.969589]
epoch: 8/100: 100%|██████████| 1296/1296 [00:33<00:00, 38.83it/s, loss=6229.056395, target1130.198233]
epoch: 9/100: 100%|██████████| 1296/1296 [00:34<00:00, 38.07it/s, loss=5582.721862, target1084.473521]
epoch: 10/100: 100%|██████████| 1296/1296 [00:35<00:00, 36.87it/s, loss

KeyboardInterrupt: 

In [None]:
import matplotlib.pyplot as plt
import torchvision

In [None]:
# Visual check: run the model on one random validation image and show the input
# ("JPG"), the original ("Ori") and the model output ("Fixed") side by side.
# A dedicated name so this cell does not overwrite the training `patch_size` global.
test_patch_size = 640

# NOTE(review): this path is under Image_Compressor while training used the
# Image_Restorer dataset copy — confirm this is the intended directory.
test_dir = "/home/mj/HardDisk/Github/Image_Compressor/Dataset/DIV2K_valid_HR"
# Previously an undefined `images_dir` was passed here (NameError on a fresh kernel).
test_dataset = Dataset(test_dir, test_patch_size, jpeg_quality, False, use_fast_loader)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=True,
                             num_workers=1, pin_memory=True, drop_last=True)


def to_display_image(batch):
    """Convert the first image of a float batch into an HWC uint8 array for ``imshow``.

    Assumes ``batch`` is (N, C, H, W) with values meant to lie in [0, 1] (the previous
    code scaled by 255). Values are clamped first: casting an out-of-range float straight
    to uint8 wraps around and shows up as bright/dark speckles, and nothing visible here
    bounds the model output.
    """
    image = batch[0].detach().clamp(0, 1).mul(255).round().to(torch.uint8)
    return image.permute(1, 2, 0).cpu().numpy()


inputs, labels = next(iter(test_dataloader))

# Inference only: eval mode and no autograd graph; restore the previous mode afterwards
# so a later re-run of the training cell is unaffected.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        preds = model(inputs.to(device))
finally:
    model.train(was_training)

fig, axes = plt.subplots(1, 3, figsize=(15, 15))
for ax, title, batch in zip(axes, ("JPG", "Ori", "Fixed"), (inputs, labels, preds)):
    ax.imshow(to_display_image(batch))
    ax.set_title(title)
    ax.set_xticks([])
    ax.set_yticks([])
plt.show()