In [5]:
import torch
from RealESRGAN import RealESRGAN
import os
import torch
import argparse
import itertools
import numpy as np
from tqdm import tqdm
import torch.optim as optim
from torchvision.utils import save_image
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import get_rank, init_process_group, destroy_process_group, all_gather, get_world_size
from torch import Tensor
from torchvision import transforms
from torch.utils.data import DataLoader,Dataset
from glob import glob
from torch.utils.data.distributed import DistributedSampler
import random
from PIL import Image
import torchvision
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
from torchvision import models
# Report how many CUDA devices torch can see, then pin the notebook to GPU index 6.
print(f"GPUs used:\t{torch.cuda.device_count()}")
device = torch.device("cuda:6")
print(f"Device:\t\t{device}")
import pytorch_model_summary as tms

GPUs used:	8
Device:		cuda:6


In [2]:
# Sub-directory names under params['data_path']; each one is globbed for images
# when the training file list is built.
class_list = ['BRNT', 'BRLC', 'BRIL', 'BRID', 'BRDC']

# Notebook-wide settings. In the visible cells only 'image_size', 'batch_size'
# and 'data_path' are read; the remaining keys are kept unchanged because their
# consumers are not visible from here.
params = dict(
    image_size=1024,
    lr=1e-5,
    beta1=0.5,
    beta2=0.999,
    batch_size=1,
    epochs=1000,
    n_classes=None,
    data_path='../../data/NIA/',
    image_count=10000,
    inch=3,
    modch=64,
    outch=3,
    chmul=[1, 2, 4, 8, 16, 32],
    numres=2,
    dtype=torch.float32,
    cdim=10,
    useconv=False,
    droprate=0.1,
    T=1000,
    w=1.8,
    v=0.3,
    multiplier=2.5,
    threshold=0.1,
    ddim=True,
)

# Shared converters: tensor -> PIL image, and PIL image -> tensor.
topilimage = transforms.ToPILImage()
tf = transforms.ToTensor()
def transback(data: Tensor) -> Tensor:
    """Apply the affine map ``x -> x / 2 + 0.5`` element-wise.

    This sends -1 to 0 and 1 to 1, i.e. it undoes a ``x * 2 - 1`` scaling.

    Args:
        data: tensor to rescale.

    Returns:
        A new tensor with the same shape as ``data``.
    """
    halved = data / 2
    return halved + 0.5

In [72]:

# Example dataset class (HR/LR image pairs)
class SRDataset(Dataset):
    """Dataset of (low-resolution, high-resolution) image pairs for super-resolution.

    ``hr_imgs`` is a flat list of image paths grouped by class, and
    ``count_list[k]`` is the number of consecutive paths that belong to class
    ``k``. The index space is split into equal buckets of
    ``samples_per_class`` indices per class; an index only selects the class,
    the concrete image is drawn at random from that class on every access.

    Both returned tensors are scaled to [-1, 1].

    Args:
        params: settings dict; ``params['image_size']`` is the square side
            length every HR image is resized to.
        hr_imgs: list of image file paths, grouped by class.
        count_list: number of paths per class, in the same order as ``hr_imgs``.
        scale: down-scaling factor between HR and LR (LR side = HR side / scale).
    """

    # Number of dataset indices assigned to each class bucket.
    samples_per_class = 10000

    def __init__(self, params, hr_imgs, count_list, scale=2):
        # Keep the settings passed in instead of relying on a notebook global.
        self.params = params
        self.scale = scale
        self.hr_images = hr_imgs
        self.count_list = count_list

    def __len__(self):
        # One bucket per class; equals 50000 for the five-class setup.
        return self.samples_per_class * len(self.count_list)

    def trans(self, image):
        """Flip ``image`` horizontally and/or vertically, each with probability 0.5."""
        if random.random() > 0.5:
            image = transforms.RandomHorizontalFlip(1)(image)

        if random.random() > 0.5:
            image = transforms.RandomVerticalFlip(1)(image)

        return image

    def random_resample(self, image, scale_factor=0.5):
        """Resize ``image`` by ``scale_factor`` using a randomly chosen PIL filter.

        Args:
            image: float image tensor with values in [0, 1]. ``ToPILImage``
                multiplies float tensors by 255 and casts to 8-bit, so values
                outside that range would be corrupted.
            scale_factor: multiplier applied to width and height.

        Returns:
            The resized image as a tensor scaled to [-1, 1].
        """
        image = topilimage(image)

        # Candidate interpolation filters.
        resampling_methods = [
            Image.NEAREST,
            Image.BOX,
            Image.BILINEAR,
            Image.HAMMING,
            Image.BICUBIC,
            Image.LANCZOS
        ]

        # Pick one filter at random.
        chosen_method = random.choice(resampling_methods)

        # Target size; never let a side collapse to zero pixels.
        new_size = (
            max(1, int(image.width * scale_factor)),
            max(1, int(image.height * scale_factor))
        )

        # Resize, then map [0, 1] -> [-1, 1].
        image = image.resize(new_size, resample=chosen_method)
        return tf(image) * 2 - 1

    def _sample_path(self, index):
        """Return a random image path from the class bucket that ``index`` falls in."""
        class_idx = index // self.samples_per_class
        count = self.count_list[class_idx]
        if count <= 0:
            raise ValueError(f"class bucket {class_idx} contains no images")
        # Paths of class k start after all paths of the classes before it.
        start = sum(self.count_list[:class_idx])
        return self.hr_images[random.randint(start, start + count - 1)]

    def __getitem__(self, index):
        """Return ``(lr_image, hr_image)`` for the class bucket selected by ``index``.

        Raises:
            IndexError: if ``index`` is outside ``[0, len(self))``.
            ValueError: if the selected class has no images.
        """
        if not 0 <= index < len(self):
            raise IndexError(
                f"index {index} out of range for SRDataset of length {len(self)}"
            )

        size = self.params['image_size']
        with Image.open(self._sample_path(index)) as img:
            hr_unit = tf(img.convert('RGB').resize((size, size)))
        hr_unit = self.trans(hr_unit)

        # Build the LR image from the [0, 1] tensor: ToPILImage cannot represent
        # the negative values produced by the [-1, 1] scaling below.
        lr_image = self.random_resample(hr_unit, scale_factor=1.0 / self.scale)
        hr_image = hr_unit * 2 - 1

        return lr_image, hr_image
    
# Gather image paths class by class. count_list[k] holds how many of the
# collected paths belong to class_list[k], in the same order as image_path.
image_path = []
count_list = []
for class_name in tqdm(class_list):
    image_list = glob(params['data_path'] + class_name + '/*.jpeg')
    count_list.append(len(image_list))
    image_path.extend(image_list)

train_dataset = SRDataset(params, image_path, count_list)
dataloader = DataLoader(
    train_dataset,
    batch_size=params['batch_size'],
    shuffle=True,
)

100%|██████████| 5/5 [00:24<00:00,  4.98s/it]


In [68]:

class CombinedLoss(nn.Module):
    """Weighted sum of a pixel L1 term, a VGG-feature L1 term and an optional adversarial term.

    Args:
        device: device the VGG feature extractor is moved to.
        vgg_weight: weight of the feature (perceptual) term.
        l1_weight: weight of the pixel L1 term.
        adv_weight: weight of the adversarial term.
    """

    def __init__(self, device, vgg_weight=0.006, l1_weight=1.0, adv_weight=0.001):
        super().__init__()

        # Pixel-space criterion.
        self.l1_loss = nn.L1Loss()

        # Frozen feature extractor: the first 16 layers of a pretrained VGG19,
        # kept in eval mode with gradients disabled for its parameters.
        feature_extractor = models.vgg19(pretrained=True).features[:16]
        feature_extractor = feature_extractor.eval().to(device)
        feature_extractor.requires_grad_(False)
        self.vgg = feature_extractor

        # Relative contribution of each term to the total.
        self.vgg_weight = vgg_weight
        self.l1_weight = l1_weight
        self.adv_weight = adv_weight

    def perceptual_loss(self, fake_img, real_img):
        """L1 distance between the VGG features of ``fake_img`` and ``real_img``."""
        # NOTE(review): images are fed to VGG as-is; no input normalization is
        # applied here — confirm callers supply the range the weights expect.
        return F.l1_loss(self.vgg(fake_img), self.vgg(real_img))

    def forward(self, fake_img, real_img, disc_fake_pred=None):
        """Compute the combined loss.

        Args:
            fake_img: generated image batch.
            real_img: target image batch.
            disc_fake_pred: optional discriminator output for ``fake_img``;
                when given, ``softplus(-disc_fake_pred).mean()`` is added as
                the adversarial term, otherwise that term is 0.

        Returns:
            The weighted sum of the three terms.
        """
        pixel_term = self.l1_loss(fake_img, real_img)
        feature_term = self.perceptual_loss(fake_img, real_img)

        if disc_fake_pred is None:
            adv_term = 0.0
        else:
            adv_term = F.softplus(-disc_fake_pred).mean()

        return (
            self.l1_weight * pixel_term
            + self.vgg_weight * feature_term
            + self.adv_weight * adv_term
        )
    
# Network to train: the `.model` attribute of a RealESRGAN wrapper built for scale 2, moved to `device`.
model = RealESRGAN(device, scale=2).model.to(device)
# Training criterion: pixel L1 + VGG feature L1 (+ optional adversarial term), with explicit weights.
criterion = CombinedLoss(device=device, vgg_weight=0.006, l1_weight=1.0, adv_weight=0.001)
# NOTE(review): lr is hard-coded to 1e-4 here while params['lr'] is 1e-5 — confirm which one is intended.
optimizer = optim.Adam(model.parameters(), lr=1e-4)



In [78]:
# Sanity check: smallest value in the low-resolution tensor (first element of the pair) of one sample.
train_dataset[0][0].min()

tensor(-1.)

In [80]:
# Print a per-layer summary (output shapes and parameter counts) of `model`
# for a random batch of two 3x512x512 inputs on `device`.
img=torch.rand(2,3,512,512).to(device)
tms.summary(model, img,show_input=False,print_summary=True)

---------------------------------------------------------------------------
      Layer (type)            Output Shape         Param #     Tr. Param #
          Conv2d-1       [2, 64, 256, 256]           6,976           6,976
            RRDB-2       [2, 64, 256, 256]         719,424         719,424
            RRDB-3       [2, 64, 256, 256]         719,424         719,424
            RRDB-4       [2, 64, 256, 256]         719,424         719,424
            RRDB-5       [2, 64, 256, 256]         719,424         719,424
            RRDB-6       [2, 64, 256, 256]         719,424         719,424
            RRDB-7       [2, 64, 256, 256]         719,424         719,424
            RRDB-8       [2, 64, 256, 256]         719,424         719,424
            RRDB-9       [2, 64, 256, 256]         719,424         719,424
           RRDB-10       [2, 64, 256, 256]         719,424         719,424
           RRDB-11       [2, 64, 256, 256]         719,424         719,424
           RRDB-12      

