In [1]:
import torch.nn as nn
import torch.nn.functional as F
import torch
import numpy as np
from torchvision import transforms
from PIL import Image
import cv2
from torch.utils.data import DataLoader, Dataset, ConcatDataset
import torchvision.models as models
import os
import torch.optim as optim
from torch.autograd import Variable
import glob
import skimage.color as sc


import os 
#from google.colab import drive
#drive.mount('/content/gdrive')
#os.chdir('/content/gdrive/My Drive/SRGAN')

from loss_r import GeneratorLoss
from dataload.srdataload import srDataset_4c
from utils.util import psnr , PSNR , predict_simager


device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")
learning_rate = 6.3e-4  #6.3e-4 1e-5

class residualBlock(nn.Module):
    def __init__(self,channel=3):
        super(residualBlock,self).__init__()
        self.channels = 3
        self.num_filter = 64
        self.conv1= nn.Conv2d (self.num_filter,self.num_filter,(3,3) ,stride=1,padding = 'same')
        self.batchn1 = nn.BatchNorm2d (self.num_filter)
        self.prelu1 = nn.PReLU()
        self.conv2= nn.Conv2d (self.num_filter,self.num_filter,(3,3) ,stride=1,padding = 'same')
        self.batchn2 = nn.BatchNorm2d (self.num_filter)
        
        
    def forward(self,x):
        in_x = x
        out = self.conv1(x)
        out = self.batchn1(out)
        out = self.prelu1(out)
        out = self.conv2(out)
        out = self.batchn2(out)
        out = out + in_x
        
        return out
    
class  Generator(nn.Module):
    
    def __init__(self,channel=3):
        super(Generator,self).__init__()
        self.channels = channel
        self.num_filter = 64
        self.conv1= nn.Conv2d (self.channels,self.num_filter,(9,9) ,stride=1,padding=4)
        self.prelu1 = nn.PReLU()
        self.f1 = self.f_layers(residualBlock,5)
        
        self.conv2= nn.Conv2d (self.num_filter,self.num_filter,(3,3) ,stride=1,padding =1)
        self.batchn2 = nn.BatchNorm2d (self.num_filter)
        
        # 卷积 后放大图片
        self.conv3 = nn.Conv2d (self.num_filter,4 * self.num_filter,(3,3) ,stride=1,padding=1)
        self.pshuffle1 = nn.PixelShuffle(2)
        self.prelu3 = nn.PReLU()
        
        self.conv4 = nn.Conv2d ( self.num_filter,4 * self.num_filter,(3,3) ,stride=1,padding=1)
        self.pshuffle2 = nn.PixelShuffle(2)
        self.prelu4 = nn.PReLU()
        
        self.conv5= nn.Conv2d (self.num_filter,self.channels,(9,9) ,stride=1,padding=4)
        
        
        
       
    
    def f_layers ( self,block,num):
        layers =[]
        for i in range(num):
            layers.append(block())
            
        return nn.Sequential ( *layers )
            
        
        
    def forward(self,x):
        
        out = self.conv1(x)
        out = self.prelu1(out)
        in_out = out
        out = self.f1(out)
        out = self.conv2(out)
        out = self.batchn2(out)
        out = in_out + out
        
        out = self.conv3(out)
        out = self.pshuffle1(out)
        out = self.prelu3(out)
        
        out = self.conv4(out)
        out = self.pshuffle2(out)
        out = self.prelu4(out)
        
        out = self.conv5(out)
        
        return (torch.tanh(out) + 1)/2
    
class Discriminator(nn.Module):
    def __init__(self,channel=3):
        super(Discriminator,self).__init__()
        self.channels = channel
        self.num_filter = 64
        self.conv1= nn.Conv2d (self.channels,self.num_filter,(3,3) ,stride=1,padding=1)
        self.lrelu1 = nn.LeakyReLU(0.2)
    
        self.conv2= nn.Conv2d (self.num_filter,self.num_filter,(3,3) ,stride=2,padding=1)
        self.batchn2 = nn.BatchNorm2d (self.num_filter)
        self.lrelu2 = nn.LeakyReLU(0.2)
       
        self.conv3= nn.Conv2d (self.num_filter,2 * self.num_filter,(3,3) ,stride=1,padding=1)
        self.batchn3 = nn.BatchNorm2d ( 2 * self.num_filter)
        self.lrelu3 = nn.LeakyReLU(0.2)
        self.conv4= nn.Conv2d ( 2 * self.num_filter,2 * self.num_filter,(3,3) ,stride=2,padding=1)
        self.batchn4 = nn.BatchNorm2d ( 2 * self.num_filter)
        self.lrelu4 = nn.LeakyReLU(0.2)
        
        self.conv5= nn.Conv2d (2 * self.num_filter,4 * self.num_filter,(3,3) ,stride=1,padding=1)
        self.batchn5 = nn.BatchNorm2d ( 4 * self.num_filter)
        self.lrelu5 = nn.LeakyReLU(0.2)
        self.conv6= nn.Conv2d ( 4 * self.num_filter,4 * self.num_filter,(3,3) ,stride=2,padding=1)
        self.batchn6 = nn.BatchNorm2d ( 4 * self.num_filter)
        self.lrelu6 = nn.LeakyReLU(0.2)
        
        self.conv7= nn.Conv2d (4 * self.num_filter,8 * self.num_filter,(3,3) ,stride=1,padding=1)
        self.batchn7 = nn.BatchNorm2d ( 8 * self.num_filter)
        self.lrelu7 = nn.LeakyReLU(0.2)
        self.conv8= nn.Conv2d ( 8 * self.num_filter,8 * self.num_filter,(3,3) ,stride=2,padding=1)
        self.batchn8 = nn.BatchNorm2d ( 8 * self.num_filter)
        self.lrelu8 = nn.LeakyReLU(0.2)
        
        self.fc = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(512, 1024, kernel_size=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(1024, 1, kernel_size=1)
           
            
        )
        
        
    def forward(self,x):
        batch_size = x.size(0)
        out = self.conv1(x)
        out = self.lrelu1(out)
        
        out = self.conv2(out)
        out = self.batchn2(out)
        out = self.lrelu2(out)
        
        out = self.conv3(out)
        out = self.batchn3(out)
        out = self.lrelu3(out)
        out = self.conv4(out)
        out = self.batchn4(out)
        out = self.lrelu4(out)
        
        out = self.conv5(out)
        out = self.batchn5(out)
        out = self.lrelu5(out)
        out = self.conv6(out)
        out = self.batchn6(out)
        out = self.lrelu6(out)
        
        out = self.conv7(out)
        out = self.batchn7(out)
        out = self.lrelu7(out)
        out = self.conv8(out)
        out = self.batchn8(out)
        out = self.lrelu8(out)
           
        out = self.fc(out).view(batch_size)
        
        return torch.sigmoid(out)
        

def train_model( g_model,d_model,dataset1,steps,batch_size ,loss1,device,learning_rate ):
    train_dataloader = DataLoader(dataset1, batch_size= batch_size, shuffle=True)
    
    loss_f = loss1
    size = len(train_dataloader.dataset)
    
    optimizerG = optim.Adam(g_model.parameters(),learning_rate)
    optimizerD = optim.Adam(d_model.parameters(),learning_rate)
  
    #img_path = 'test_imgs/b_4down3.jpg'
    #img_path2 = 'test_imgs/butterfly.png'
    img_path = 'test_imgs/down1.png'
    img_path2 = 'test_imgs/t2_1.png'
    data_transform = transforms.Compose([
    transforms.ToTensor()#,transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])
    ])
    
    for step in range( steps):
        g_model.train()
        d_model.train()
        for batch , (X, y) in enumerate(train_dataloader):
        # Compute prediction and loss
           
            input1 ,target = Variable(X), Variable(y, requires_grad=False)
            input1 = input1.to(device)
            target = target.to(device)
           
            real_img =target
            fake_img = g_model(input1)
             # 更新判别器D的参数
            d_model.zero_grad()
            real_out = d_model(real_img).mean()
            fake_out = d_model(fake_img).mean()
            d_loss = 1 - real_out + fake_out
            d_loss.backward(retain_graph=True)
            optimizerD.step()
            
            # 更新生成器G的参数
            g_model.zero_grad()
            
            fake_img = g_model(input1)
            fake_out = d_model(fake_img).mean()
            ##
            #print ( fake_out.shape)
            #print ( fake_img.shape)
            #print ( real_img.shape)
            g_loss = generator_criterion(fake_out, fake_img, real_img)
            g_loss.backward()
            
            fake_img = g_model(input1)
            fake_out = d_model(fake_img).mean()
            
            
            optimizerG.step()
       
            predict_simager ( g_model,img_path,img_path2,data_transform,device)

            if batch % 100 == 0:
                loss, current = g_loss.item(), batch * len(X)
                predict_simager ( g_model,img_path,img_path2,data_transform,device)
                print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

            ''' if (batch+1) % 500 == 0:
              torch.save(g_model.state_dict(), 'save_models/'+str(step)+'_'+str(batch)+'gmodel_weights2.pth')
              torch.save(d_model.state_dict(), 'save_models/'+str(step)+'_'+str(batch)+'dmodel_weights2.pth')
            if (batch+1) % 700 == 0:
              torch.save(g_model.state_dict(), 'save_models/'+str(step)+'_'+str(batch)+'gmodel_weights2.pth')
              torch.save(d_model.state_dict(), 'save_models/'+str(step)+'_'+str(batch)+'dmodel_weights2.pth')
            
            if (batch+1) % 1000 == 0:
              torch.save(g_model.state_dict(), 'save_models/'+str(step)+'_'+str(batch)+'gmodel_weights2.pth')
              torch.save(d_model.state_dict(), 'save_models/'+str(step)+'_'+str(batch)+'dmodel_weights2.pth')'''


        torch.save(g_model.state_dict(), 'save_models_d/'+str(step)+'gmodel_weights2.pth')
        torch.save(d_model.state_dict(), 'save_models_d/'+str(step)+'dmodel_weights2.pth')
        predict_simager ( g_model,img_path,img_path2,data_transform,device)

if __name__ == "__main__":
    
    #hr_path = r'gan_data/hr_hdf5_file.h5'
    #lr_path = r'gan_data/lr_hdf5_file.h5'
    hr_path = r'gan_data/hr_srgan.h5'
    lr_path = r'gan_data/lr_srgan.h5'

    data_transform = transforms.Compose([
        transforms.ToTensor()#,transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])
    ])
    #sr_data = srDataset_4c( lr_dir = lr_path,hr_dir = hr_path, mode='train', transform=data_transform)
    generator = Generator(4).to(device)
    print('# generator parameters:', sum(param.numel() for param in generator.parameters()))
    discriminator = Discriminator(4).to(device) #Discriminator
    print('# discriminator parameters:', sum(param.numel() for param in discriminator.parameters()))
    generator.load_state_dict(torch.load('save_models_d/9gmodel_weights2.pth'))
    discriminator.load_state_dict(torch.load('save_models_d/9dmodel_weights2.pth'))
   # generator_criterion = GeneratorLoss().to(device)
    #print(sr_data[5][0].shape)  # 此处为__getitem__的用法
    
    #train_model( generator,discriminator,sr_data,10,64 ,generator_criterion,device,learning_rate)
        

Using cuda device
# generator parameters: 744588
# discriminator parameters: 5216001


In [None]:
from PIL import Image
img_path = 'test_imgs/down1.jpg'
img_path2 = 'test_imgs/t2_1.png'
img = Image.open(img_path)
target = Image.open(img_path2)
#imgt = transforms1(img).unsqueeze(0)
#imgt = imgt.to(device)
    
imgn = np.array (img)
target = np.array ( target)
print ( imgn.shape)
print ( target.shape)


(64, 64, 3)
(256, 256, 4)


In [2]:
def predict_simagetr1 ( model,img_path,img_path2,transforms1,device):
    img = Image.open(img_path)
    target = Image.open(img_path2)
    imgt = transforms1(img).unsqueeze(0)
    imgt = imgt.to(device)
    
    imgn = np.array (img)
    target = np.array ( target)
    print ( imgn.shape)
    
    pre = model(imgt)
    pren = pre.cpu().detach().numpy().squeeze(0)
    pren *= 225
    pren = np.float32(pren).transpose(1,2,0)
    #pre_t = detransform(pre,device)
    #print ( imgn)
    #print ( pre_t)
    pre_s = pren.astype(np.uint8)
    pre_s1 = Image.fromarray ( pre_s)
    pre_s1.save ("test_set5/d_data/test_image24.png")
    print ( pren.shape)
    p1 = psnr (target,pren)
    p3 = PSNR ( pren,target)
    #p4 = PSNR (imgn,target)
    #p2 = psnr ( target,imgn )
    print ( p1)
    print ( p3)
    #print ( p4)
    #print ( p2)
img_path = 'C:/e/SR_datasets/level6_process/test/down_cubic/24t.png'
img_path2 = 'C:/e/SR_datasets/level6_process/test/down_cubic/24down.png'
#img_path = 'C:/e/SR_datasets/SR/processed/LR2/9070.jpg' #C:\e\SR_datasets\SR\processed\HR
#img_path2 = 'C:/e/SR_datasets/SR/processed/HR/9070.jpg'  #C:/e/SR_datasets/SR/test/butterfly/0h.jpg
predict_simagetr1 ( generator,img_path2,img_path,data_transform,device)

(64, 64, 4)
(256, 256, 4)
20.484045741139575
20.484045741139575


In [2]:
from google.colab import drive 
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [None]:
!nvidia-smi

Fri May  6 01:07:53 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   36C    P8    11W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces