In [1]:
import torch
import numpy as np
from torch import nn
from torch.utils.data import *
from torchvision.transforms import ToTensor
from torchvision.io import read_image
from PIL import Image
import scipy
import os
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

In [2]:
from utils import functions, datasets, metrics

### Sprawdzanie dostępności karty graficznej:

In [None]:
# Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

# Enable cuDNN's benchmark mode (auto-tuning of convolution algorithms).
torch.backends.cudnn.benchmark = True

Definiowanie modelu składającego się z 8 warstw konwolucyjnych o filtrach wielkości 3x3. Warstwy ukryte posiadają funkcję aktywacji 'ReLU'. Wagi będą inicjalizowane przy użyciu funkcji Kaiming, natomiast obciążenie będzie równe 0.

In [None]:
class NeuralNetwork(nn.Module):
    """Eight stacked 3x3 convolutions with a global residual connection.

    The first convolution maps 3 input channels to 64 feature maps, six
    further convolutions keep 64 channels, and the last one maps back to
    3 channels. Every convolution except the last is followed by ReLU.
    All convolutions use ``padding='same'``, so the spatial size of the
    input is preserved. Convolution weights are initialised with Kaiming
    uniform initialisation (ReLU gain) and biases are set to zero.
    """

    def __init__(self):
        super().__init__()

        # Channel plan 3 -> 64 -> ... -> 64 -> 3 gives eight convolutions.
        widths = [3] + [64] * 7 + [3]
        layers = []
        for in_ch, out_ch in zip(widths[:-1], widths[1:]):
            layers.append(nn.Conv2d(in_ch, out_ch, 3, padding='same'))
            layers.append(nn.ReLU())
        # The output convolution is linear: drop the trailing activation.
        layers.pop()
        self.stack = nn.Sequential(*layers)

        self.apply(self._init_conv)

    @staticmethod
    def _init_conv(module):
        """Kaiming-uniform weights and zero bias for convolution layers only."""
        if not isinstance(module, nn.Conv2d):
            return
        torch.nn.init.kaiming_uniform_(module.weight, nonlinearity='relu')
        torch.nn.init.zeros_(module.bias)

    def forward(self, x):
        """Return ``x`` plus the residual predicted by the convolution stack."""
        residual = self.stack(x)
        return x + residual

In [None]:
# Instantiate the network, move it to the selected device and print its layer structure.
model = NeuralNetwork().to(device)
print(model)

## Wstępne uczenie sieci

Wczytywanie zbioru wykorzystywanego w procesie wstępnego uczenia

Aby notebook działał, powinieneś wcześniej wygenerować zbiór uczący z wykorzystaniem skryptu "generate_dataset_MZSR.py" w głównym katalogu.

In [None]:
# Pre-training dataset read from 'datasets/train_SR_X2.dataset'.
# Both the input and the target of every sample are converted with ToTensor.
# NOTE(review): the file name suggests x2 super-resolution pairs — confirm in utils.datasets.
training_data = datasets.MZSRPreTrain(
    'datasets/train_SR_X2.dataset',
    transform=ToTensor(),
    target_transform=ToTensor()
)

In [None]:
# Preview sample 1 of the dataset: first element 0 of the pair, then element 1.
# The dataset is indexed again for each element, exactly one lookup per plot.
for part in (0, 1):
    pred = training_data[1][part].cpu().detach().numpy()
    # Move axis 0 (channels, after ToTensor) to the end so imshow gets H x W x C.
    pred = np.moveaxis(pred, 0, -1)

    plt.imshow(pred)
    plt.show()

In [None]:
# Shuffled mini-batches of 32 samples; pin_memory requests page-locked host buffers for host-to-GPU copies.
train_dataloader = DataLoader(training_data, batch_size=32, shuffle=True, pin_memory=True)

In [None]:
def pre_train(dataloader, model, loss_fn, optimizer, scheduler):
    """Train `model` for one pass over `dataloader`.

    For every batch the inputs and targets are moved to the module-level
    `device`, the loss is back-propagated, and both the optimizer and the
    learning-rate scheduler are stepped (the scheduler advances once per
    batch, not once per epoch). Every 100th batch the current learning
    rate and the batch loss are printed.

    Args:
        dataloader: iterable of ``(input, target)`` batches exposing ``.dataset``.
        model: network to optimise; switched to training mode.
        loss_fn: callable ``loss_fn(prediction, target)`` returning a scalar tensor.
        optimizer: optimizer whose ``step()`` is called after each backward pass.
        scheduler: learning-rate scheduler stepped after each optimizer step.
    """
    total_samples = len(dataloader.dataset)
    model.train()

    for step, (inputs, targets) in enumerate(dataloader):
        inputs = inputs.to(device)
        targets = targets.to(device)

        batch_loss = loss_fn(model(inputs), targets)

        # Reset gradients to None (instead of zero-filling) before backprop.
        for weight in model.parameters():
            weight.grad = None
        batch_loss.backward()
        optimizer.step()
        scheduler.step()

        if step % 100 != 0:
            continue

        # Periodic progress report: learning rate, loss and samples seen so far.
        print(scheduler.get_last_lr())
        seen = step * len(inputs)
        print(f"loss: {batch_loss.item():>7f}  [{seen:>5d}/{total_samples:>5d}]")

In [None]:
# Pre-training setup: L1 loss and Adam with a base learning rate of 4e-4.
loss_fn = nn.L1Loss()
optimizer = torch.optim.Adam(model.parameters(), lr=4e-4)
decay_rate = 0.5
decay_step = 1e5

# `pre_train` calls scheduler.step() once per batch, so the value passed to
# lr_lambda counts optimizer steps, not epochs. The multiplier is halved each
# time 32 * step crosses another multiple of decay_step and is floored at 0.25.
# NOTE(review): 32 presumably mirrors the DataLoader batch size (samples seen) — confirm.
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lr_lambda=lambda step: max(0.25, decay_rate ** (32 * step // decay_step)),
)

# Create the checkpoint directory up front so torch.save cannot fail on a
# missing directory after a whole epoch of training has already been spent.
os.makedirs('models', exist_ok=True)

epochs = 10
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    pre_train(train_dataloader, model, loss_fn, optimizer, scheduler)
    # Overwrite the checkpoint after every epoch so the latest weights survive an interruption.
    torch.save(model.state_dict(), 'models/MZSR_pretrained.model')
print("Done!")

## Wczytywanie zbioru wykorzystywanego w procesie meta-uczenia
Wczytywanie wag nauczonych na etapie wstępnego uczenia

In [None]:
# Restore the pre-trained weights. map_location remaps the stored tensors to the
# current device, so a checkpoint saved on a GPU also loads on a CPU-only machine.
model.load_state_dict(torch.load('models/MZSR_pretrained.model', map_location=device))

In [None]:
# Meta-training dataset read from 'datasets/train_SR_MZSR.dataset'; this rebinds
# `training_data`, replacing the pre-training dataset created earlier.
# Both the input and the target of every sample are converted with ToTensor.
training_data = datasets.MZSRMetaTrain(
    'datasets/train_SR_MZSR.dataset',
    transform=ToTensor(),
    target_transform=ToTensor()
)

In [None]:
# Preview sample 0 of the dataset: first element 0 of the pair, then element 1.
# The dataset is indexed again for each element, exactly one lookup per plot.
for part in (0, 1):
    pred = training_data[0][part].cpu().detach().numpy()
    # Move axis 0 (channels, after ToTensor) to the end so imshow gets H x W x C.
    pred = np.moveaxis(pred, 0, -1)

    plt.imshow(pred)
    plt.show()

In [None]:
import learn2learn as l2l
import tqdm

class MetaTrainer:
    """MAML meta-training loop built on learn2learn.

    Every batch produced by the internal DataLoader is treated as one task.
    The batch holds ``2 * task_batch`` samples: the even-indexed half is used
    to adapt a clone of the model, the odd-indexed half to evaluate it.
    Gradients of ``batch_size`` consecutive tasks are accumulated before a
    single Adam step updates the meta-parameters.

    Args:
        dataset: dataset yielding ``(input, target)`` pairs; it must also
            provide ``regenerate_kernel()``, which is called after each task.
        model: network to meta-train (wrapped in ``l2l.algorithms.MAML``).
        alpha: inner-loop (adaptation) learning rate.
        beta: outer-loop (Adam) learning rate.
        loss_fn: loss used for both adaptation and evaluation.
        task_iter: number of inner adaptation steps per task.
        batch_size: number of tasks accumulated per meta-update.
        task_batch: number of adaptation (and of evaluation) samples per task.
    """

    def __init__(self, dataset, model, alpha=1e-2, beta=1e-4, loss_fn=nn.L1Loss(), task_iter=5, batch_size=5, task_batch=8):
        self.dataset = dataset
        # Twice task_batch: each batch is split in half inside `adapt`.
        self.dataloader = DataLoader(dataset, batch_size=2*task_batch, shuffle=True, pin_memory=True)

        self.task_iter = task_iter
        self.loss = loss_fn
        self.batch_size = batch_size
        self.task_batch = task_batch
        self.step = 0  # number of meta-updates performed so far
        self.beta = beta

        self.maml_model = l2l.algorithms.MAML(model, lr=alpha, first_order=False)
        self.optimizer = torch.optim.Adam(self.maml_model.parameters(), lr=self.beta)

    def adapt(self, batch, learner):
        """Adapt `learner` on one task and return its weighted meta-loss.

        Requires ``self.loss_weight`` to be set (done by `epoch`).

        Args:
            batch: ``(data, labels)`` tensors for one task.
            learner: clone of the MAML model; it is adapted ``task_iter`` times.

        Returns:
            Tuple of (meta-loss tensor divided by ``batch_size``, ready for
            ``backward()``; adaptation loss before any adaptation as float;
            undivided weighted meta-loss as float).
        """
        data, labels = batch
        # Follow the learner's own device instead of relying on a notebook global.
        target_device = next(learner.parameters()).device
        data, labels = data.to(target_device), labels.to(target_device)

        # Even positions form the adaptation set, odd positions the evaluation set.
        adaptation_data, adaptation_labels = data[0::2], labels[0::2]
        evaluation_data, evaluation_labels = data[1::2], labels[1::2]

        pretrain_error = None
        meta_train_error = 0
        for step in range(self.task_iter):
            train_error = self.loss(learner(adaptation_data), adaptation_labels)
            if step == 0:
                # Loss of the un-adapted learner, reported for monitoring only.
                pretrain_error = train_error
            learner.adapt(train_error)

            # Evaluation loss after this inner step, weighted by the current schedule.
            valid_error = self.loss(learner(evaluation_data), evaluation_labels)
            meta_train_error = meta_train_error + valid_error * self.loss_weight[step]

        return meta_train_error / self.batch_size, float(pretrain_error), float(meta_train_error)

    def meta_train(self, epochs=5):
        """Run `epochs` meta-training epochs, saving a checkpoint after each one.

        Checkpoints are written to ``models/MZSR_meta-learned-<epoch>.model``
        and contain the state dict of the model wrapped by the MAML learner.
        """
        # Make sure the checkpoint directory exists before the first (long) epoch.
        os.makedirs('models', exist_ok=True)
        for t in range(epochs):
            tqdm.tqdm.write(f"Epoch {t+1}\n-------------------------------")
            self.epoch()
            # Save the wrapped module itself: same keys as the plain model's
            # state dict, and no dependence on a global `model` variable.
            torch.save(self.maml_model.module.state_dict(), f'models/MZSR_meta-learned-{t+1}.model')
        print("Done!")

    def reduce_loss(self, loss):
        """Return the sum of the elements of the non-empty sequence `loss`."""
        error = loss[0]

        for term in loss[1:]:
            error = error + term

        return error

    def epoch(self):
        """Run one pass over the dataloader, one task per batch.

        Gradients are zeroed at the start of every group of ``batch_size``
        tasks and applied at its end. A trailing incomplete group is never
        applied: its gradients are cleared when the next epoch starts.
        """
        pbar = tqdm.tqdm(self.dataloader)

        for num, batch in enumerate(pbar):
            if num % self.batch_size == 0:
                # Start of a new meta-batch: reset gradients and running statistics.
                self.optimizer.zero_grad()
                self.loss_weight = self.get_loss_weights(self.step)

                pretrain_losses = 0
                meta_train_losses = 0

            # Compute meta-training loss
            learner = self.maml_model.clone()
            error, pretrain_error, meta_train_error = self.adapt(batch,
                                                                 learner)
            # Running means over the tasks of the current meta-batch.
            pretrain_losses += pretrain_error / self.batch_size
            meta_train_losses += meta_train_error / self.batch_size
            error.backward()

            # NOTE(review): presumably draws a new degradation kernel for the
            # next task — confirm against utils.datasets.
            self.dataloader.dataset.regenerate_kernel()

            if (num + 1) % self.batch_size == 0:
                self.optimizer.step()
                self.step += 1
                pbar.set_postfix({'loss': f'[{float(pretrain_losses):>7f}, {float(meta_train_losses):>7f}]'})

    def get_loss_weights(self, step):
        """Return the per-inner-step loss weights for meta-update number `step`.

        At ``step == 0`` all ``task_iter`` weights equal ``1 / task_iter``.
        As `step` grows, the weights of all but the last inner step decrease
        linearly down to a floor of ``0.03 / task_iter``, while the weight of
        the last inner step increases by the amount the others lost, capped so
        that the weights sum to 1.

        Returns:
            numpy array of length ``task_iter``.
        """
        loss_weights = np.ones(shape=(self.task_iter)) * (1.0 / self.task_iter)
        # Without the floor, the early weights would reach zero after 10000/3 meta-updates.
        decay_rate = 1.0 / self.task_iter / (10000 / 3)
        min_value = 0.03 / self.task_iter

        loss_weights_pre = np.maximum(loss_weights[:-1] - step * decay_rate, np.ones(shape=(self.task_iter-1)) * min_value)
        loss_weight_cur = np.minimum(loss_weights[-1] + step * decay_rate * (self.task_iter-1), 1.0 - ((self.task_iter - 1) * min_value))

        loss_weights = np.concatenate((loss_weights_pre, [loss_weight_cur]))
        return loss_weights

In [None]:
# Release cached GPU memory held by PyTorch's allocator before meta-training starts.
torch.cuda.empty_cache()

# Meta-train the current `model` with 5 inner adaptation steps per task.
trainer = MetaTrainer(training_data, model, task_iter=5)

# Three epochs; a checkpoint models/MZSR_meta-learned-<epoch>.model is written after each.
trainer.meta_train(epochs=3)

## Testowanie nauczonego modelu

Funkcja zapisuje 4 zdjęcia w katalogu results: zdjęcie w niskiej rozdzielczości (low), wynik działania algorytmu 'bicubic' (bicubic), wynik działania sieci bez douczania (init) oraz wynik po wykonaniu douczenia z wykorzystaniem idealnego kernela zmniejszającego (trained).

In [None]:
import gkernel
import scipy.io as sio
from image_resize import *
# Module-level L1 loss and Adam optimizer (lr=5e-4) over `model`'s parameters,
# set up for the test-time fine-tuning cell below. These rebind the names used
# earlier for pre-training.
loss_fn = nn.L1Loss()
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)

def _save_rgb(array, file_path):
    """Clip `array` to [0, 1], convert it to 8-bit and save it as an RGB image."""
    Image.fromarray(np.uint8(np.clip(array, 0, 1) * 255), mode='RGB').save(file_path)


def _super_resolve(model, image_batch, run_device):
    """Run `model` in eval mode on `image_batch`; return the first output as an H x W x C array."""
    model.eval()
    with torch.no_grad():
        pred = model(image_batch.to(run_device))
    return np.moveaxis(pred.cpu().numpy()[0], 0, -1)


def meta_test(model, sample_number, dataset=None,
              checkpoint_path='models/MZSR_meta-learned-3.model',
              adapt_epochs=5, lr=5e-4):
    """Evaluate the meta-learned model on one image, before and after test-time adaptation.

    The high-resolution image is downscaled by 2 with a Gaussian kernel, then
    four images are written to the results directory with the suffixes
    ``-low`` (downscaled input), ``-bicubic`` (bicubic upscaling), ``-init``
    (network output with the meta-learned weights) and ``-trained`` (network
    output after fine-tuning on data derived from the low-resolution image
    and the known kernel). PSNR and SSIM of the bicubic result are printed.

    The loss is taken from the module-level `loss_fn`. A fresh Adam optimizer
    is created on every call, so optimizer state from a previously processed
    image cannot leak into the adaptation of the next one.

    Args:
        model: network to evaluate; its weights are overwritten from
            `checkpoint_path` and then fine-tuned in place.
        sample_number: image identifier inserted into the file name (string).
        dataset: ``'Urban'`` selects ``datasets/Urban100/img_<n>.png``; any
            other value selects ``datasets/DIV2K/DIV2K/DIV2K_valid_HR/<n>.png``.
        checkpoint_path: state-dict file loaded into `model`.
        adapt_epochs: number of passes over the adaptation data.
        lr: learning rate of the adaptation optimizer.
    """
    run_device = next(model.parameters()).device
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    model.load_state_dict(torch.load(checkpoint_path, map_location=run_device))
    kernel = gkernel.anisotropic_Gaussian(15, 0, 0.2, 0.2)

    if dataset == 'Urban':
        path = 'datasets/Urban100/img_'
        results_path = 'results/Urban100/'
    else:
        path = 'datasets/DIV2K/DIV2K/DIV2K_valid_HR/'
        results_path = 'results/DIV2K/'
    # Image.save does not create missing directories.
    os.makedirs(results_path, exist_ok=True)

    hr_image = mpimg.imread(f'{path}{sample_number}.png')

    # Low-resolution input: HR image downscaled by 2 with the Gaussian kernel.
    train_lr_image = image_resize(hr_image, scale=1/2, kernel=kernel).astype(np.float32)
    _save_rgb(train_lr_image, f'{results_path}{sample_number}-low.png')

    training_data = datasets.MZSRMetaTest(
        train_lr_image,
        kernel,
        transform=ToTensor(),
        target_transform=ToTensor()
    )
    dataloader = DataLoader(training_data, batch_size=32, shuffle=True)

    # Bicubic baseline; the unclipped version is also what the network receives.
    lr_image = image_resize(train_lr_image, scale=2, kernel='cubic').astype(np.float32)
    bicubic = np.clip(lr_image, 0, 1)
    _save_rgb(bicubic, f'{results_path}{sample_number}-bicubic.png')

    print(metrics.PSNR(bicubic, hr_image))
    print(metrics.SSIM(bicubic, hr_image))

    # Add a leading batch dimension for the network.
    lr_batch = ToTensor()(lr_image)[None]

    # Output of the meta-learned weights before any adaptation.
    _save_rgb(_super_resolve(model, lr_batch, run_device),
              f'{results_path}{sample_number}-init.png')

    # Test-time adaptation with a fresh optimizer.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model.train()
    for _ in range(adapt_epochs):
        for X, y in dataloader:
            X, y = X.to(run_device), y.to(run_device)

            loss = loss_fn(model(X), y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    # Output after adaptation.
    _save_rgb(_super_resolve(model, lr_batch, run_device),
              f'{results_path}{sample_number}-trained.png')
        
# Disabled batch run over images '0801'..'0900' of the default (DIV2K) set; uncomment to use:
#for i in range(801, 901):
#    meta_test(model, '0' + str(i))
    
# Evaluate a single Urban100 image (datasets/Urban100/img_085.png).
meta_test(model, '085', dataset='Urban')