In [18]:
import torch
from torch import nn
from myModel import SRResNet
from myDataset import dataSet
from torch.utils.data import DataLoader
import torch.optim as optim
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid
import numpy as np


In [19]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataPath = './output/'
writer = SummaryWriter()

def run():
    SRnet = SRResNet()
    
    SRnet =  SRnet.to(device)
    loss_function = nn.MSELoss().to(device) 
    optimizer = optim.Adam(SRnet.parameters(),lr=0.001) 
    
    train_dataset = dataSet(dataPath=dataPath, dataType='train')
    train_loader = DataLoader(train_dataset,batch_size=64)
    
    Epoches = 50
    for epoch in range(Epoches):
        SRnet.train()
        total_loss = 0.0
    
    
        #n_iter = len(train_loader)
        #print(n_iter)
        
        for i,(result, cropResult) in tqdm(enumerate(train_loader,1)):
            result = result.to(device)
            cropResult = cropResult.to(device)
            #print(i)
            
            #forward 
            next_imgs = SRnet(cropResult)
            
            #calculate loss
            loss = loss_function(next_imgs,result)
            #backward
            optimizer.zero_grad()
            loss.backward()
            #renew model
            optimizer.step()
            
            total_loss += loss.item()
            
            
            if i == 4:
                writer.add_image('SRResNet/epoch_'+str(epoch)+'_1', make_grid(cropResult[:4,:3,:,:].cpu(), nrow=4, normalize=True),epoch)
                writer.add_image('SRResNet/epoch_'+str(epoch)+'_2', make_grid(next_imgs[:4,:3,:,:].cpu(), nrow=4, normalize=True),epoch)
                writer.add_image('SRResNet/epoch_'+str(epoch)+'_3', make_grid(result[:4,:3,:,:].cpu(), nrow=4, normalize=True),epoch)
            
            


            
        average_lose = total_loss/(i+1)
        writer.add_scalar('SRResNet/MSE_Loss', average_lose, epoch)   

        
        #delete memory
        del result, cropResult, next_imgs
        
    writer.close()   
    print("trainning end...")




In [20]:
run()

4it [00:34,  8.56s/it]
4it [00:32,  8.11s/it]
4it [00:32,  8.09s/it]
4it [00:32,  8.16s/it]
4it [00:32,  8.13s/it]
4it [00:32,  8.08s/it]
4it [00:32,  8.14s/it]
4it [00:32,  8.13s/it]
4it [00:32,  8.10s/it]
4it [00:32,  8.07s/it]
4it [00:34,  8.69s/it]
4it [00:36,  9.04s/it]
4it [00:35,  8.88s/it]
4it [00:35,  8.92s/it]
4it [00:35,  8.96s/it]
4it [00:35,  8.90s/it]
4it [00:35,  8.80s/it]
4it [00:35,  8.85s/it]
4it [00:35,  8.92s/it]
4it [00:35,  8.89s/it]
4it [00:35,  8.88s/it]
4it [00:35,  8.84s/it]
4it [00:35,  8.93s/it]
4it [00:35,  8.89s/it]
4it [00:35,  8.84s/it]
4it [00:35,  8.84s/it]
4it [00:35,  8.79s/it]
4it [00:35,  8.85s/it]
4it [00:35,  8.88s/it]
4it [00:35,  8.76s/it]
4it [00:35,  8.77s/it]
4it [00:35,  8.94s/it]
4it [00:35,  8.95s/it]
4it [00:35,  8.92s/it]
4it [00:35,  8.84s/it]
4it [00:35,  8.87s/it]
4it [00:35,  8.94s/it]
4it [00:35,  8.87s/it]
4it [00:35,  8.98s/it]
4it [00:35,  8.95s/it]
4it [00:35,  8.88s/it]
4it [00:35,  8.88s/it]
4it [00:35,  8.88s/it]
4it [00:35,

trainning end...



