In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import scipy.stats as st
import torchvision.utils as vutils
from torchvision import models

from torch.autograd import Variable
from torch.optim.lr_scheduler import StepLR
from torchsummary import *
from tqdm import tqdm_notebook as tqdm

from torch.utils.data import Dataset, DataLoader, TensorDataset
from torchvision import transforms
import h5py
import cv2
import os
import pywt


import torchnet as tnt
from torchnet.logger import VisdomPlotLogger, VisdomLogger
from torchvision.utils import make_grid
import shutil

In [None]:
# Experiment configuration: every tunable value used by the cells below.
params = {
    'log_file': 'ResNet_smallNorb',   # run name -> runs/nets_<log_file>/ and runs/ims_<log_file>/
    'n_epoch': 100,
    'n_batch': 64,
    'lr': 0.001,
    'n_class': 5,
    'img_size': (1, 96, 96),
    'device': 'cuda:1',
    'lr_decay': 0.5,                 # StepLR gamma
    'lr_step': 20                    # StepLR step size (epochs)
}

torch.cuda.set_device(params['device'])
# Tensors created without an explicit device are allocated on the GPU.
# NOTE(review): set_default_tensor_type is deprecated in recent PyTorch releases
# (torch.set_default_device is the replacement) - confirm against the installed version.
torch.set_default_tensor_type('torch.cuda.FloatTensor')
device=torch.device(params['device'])

# Seed every RNG the notebook draws from: the degradation functions use NumPy's
# global RNG; weight initialisation and shuffling use torch's RNGs.
SEED = 1
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

# Degradation models

### add noise

In [None]:
def Gaussian_noise(img):
    """Return ``img`` corrupted by additive zero-mean Gaussian noise with std 0.1.

    The noise field has the same shape as ``img`` and is drawn from NumPy's
    global RNG; ``img`` itself is not modified.
    """
    sigma = 0.1
    return img + np.random.normal(0, sigma, img.shape)

def salt_pepper_noise(img):
    """Replace roughly 10% of the pixels of ``img`` with 0 (pepper) or 1 (salt).

    Two arrays are drawn from NumPy's global RNG, in this order: a salt/pepper
    value per pixel (0 or 1 with equal probability), then a corruption mask that
    selects a pixel with probability 0.1. Unselected pixels keep their value.
    """
    salt_or_pepper = np.random.choice([0, 1], size=img.shape, p=[0.5, 0.5])
    corrupt = np.random.choice([0, 1], size=img.shape, p=[0.9, 0.1])
    keep = 1 - corrupt
    return keep * img + corrupt * salt_or_pepper

def speckle_noise(img):
    """Apply multiplicative speckle: each pixel is scaled by ``1 + u`` with ``u ~ U(-0.1, 0.1)``.

    ``u`` is drawn per pixel from NumPy's global RNG.
    """
    gain = 1 + np.random.uniform(-0.1, 0.1, img.shape)
    return img * gain

def add_noise(img):
    """Corrupt ``img`` with one of three noise models chosen uniformly at random.

    The choice index is drawn from NumPy's global RNG before any noise is
    generated: 0 -> Gaussian, 1 -> salt & pepper, 2 -> speckle.
    """
    noise_models = (Gaussian_noise, salt_pepper_noise, speckle_noise)
    choice = np.random.randint(3)
    return noise_models[choice](img)

### resolution reduction

In [None]:
def wavelets_transform(img, level = 1):
    """Simulate a low-resolution image by discarding Haar wavelet detail bands.

    Applies a 2-D Haar DWT ``level`` times, each time keeping only the
    approximation (LL) band. The final LL band is then resized back to the
    input's spatial size with area interpolation.

    Args:
        img: 2-D array of shape (H, W).
        level: number of successive DWT decompositions; must be >= 1.

    Returns:
        Array of shape (H, W) containing the upsampled LL band.

    Raises:
        ValueError: if ``level`` < 1.
    """
    if level < 1:
        raise ValueError("level must be >= 1, got %r" % (level,))

    height, width = img.shape[0], img.shape[1]
    approx = img
    for _ in range(level):
        # Only the approximation band is propagated; detail bands are dropped.
        approx, _details = pywt.dwt2(approx, 'haar')

    # cv2.resize expects dsize as (width, height), so non-square inputs keep their shape.
    return cv2.resize(approx, (width, height), interpolation=cv2.INTER_AREA)

def reduce_resolution(img):
    """Downgrade ``img`` with 1 or 2 Haar-DWT levels (chosen uniformly via NumPy's global RNG)."""
    n_levels = 1 + np.random.randint(2)
    return wavelets_transform(img, level=n_levels)

### Motion blurring

In [None]:
def Motion_model(w, h, a=0.1, b=0.1, T=1):
    """Frequency response of uniform linear motion blur.

    Evaluates ``H(u, v) = T * sin(phi) / phi * exp(-j * phi)`` with
    ``phi = pi * (a*u + b*v) + 1e-5``. The small offset keeps ``phi`` away
    from zero at the grid origin.

    Returns:
        Complex array of shape (w, h). ``u`` runs along axis 1 (h integer
        samples starting at ``-h//2``); ``v`` runs along axis 0 (w samples
        starting at ``-w//2``).
    """
    u_axis = np.arange(-h // 2, h // 2)
    v_axis = np.arange(-w // 2, w // 2)
    u, v = np.meshgrid(u_axis, v_axis)
    phi = np.pi * (u * a + v * b) + 10e-6
    return (T / phi) * np.sin(phi) * np.exp(-phi * 1j)

def Motion_bluring(img):
    """Apply random linear motion blur to a 2-D image in the frequency domain.

    Blur coefficients ``a`` then ``b`` are drawn from N(0, 0.05) using NumPy's
    global RNG. The steps are:

    1. Centre the spectrum with fftshift.
    2. Zero the central 2x2 block, which contains the DC term.
    3. Multiply by the transfer function from ``Motion_model``.
    4. Take the inverse FFT.

    Args:
        img: 2-D array of shape (rows, cols).

    Returns:
        Real part of the blurred image, same shape as ``img``.
    """
    a = np.random.normal(0, 0.05)
    b = np.random.normal(0, 0.05)

    n_rows, n_cols = img.shape
    # Motion_model(w, h) returns an array of shape (w, h): pass rows first.
    H = Motion_model(n_rows, n_cols, a, b)

    spectrum = np.fft.fftshift(np.fft.fft2(img))
    # Rows are sliced around the row centre and columns around the column
    # centre, so non-square images zero the right block as well.
    r0, c0 = n_rows // 2, n_cols // 2
    spectrum[r0 - 1:r0 + 1, c0 - 1:c0 + 1] = 0

    blurred = np.fft.ifft2(np.fft.ifftshift(spectrum * H))
    return blurred.real


# Load dataset

### SmallNorb dataset

In [None]:
class SmallNorbread(Dataset):
    """SmallNORB images read from an HDF5 file, served as ``(clean, degraded, label)``.

    The file must contain a ``data`` dataset (cast to uint8, then scaled to
    [0, 1]) and a ``labels`` dataset (cast to int64). Each item is min-max
    normalised per image.

    Args:
        name: path to the HDF5 file.
        degrade_fn: optional callable applied to a copy of the image before its
            normalisation; produces the "degraded" output.
        transform: optional callable applied to both outputs (as float32).
    """

    def __init__(self, name, degrade_fn = None, transform=None):
        # `with` releases the file handle even if reading a dataset fails.
        with h5py.File(name, 'r') as hf:
            input_images = np.array(hf.get('data')).astype(np.uint8)
            # np.long (an alias of the builtin int) no longer exists in NumPy >= 1.24.
            self.target_labels = np.array(hf.get('labels')).astype(np.int64)
        self.input_images = input_images / 255.

        self.transform = transform
        self.degrade_fn = degrade_fn

    @staticmethod
    def _minmax(images):
        """Scale ``images`` to [0, 1]; a constant image maps to zeros instead of NaN."""
        lo = np.min(images)
        span = np.max(images) - lo
        if span == 0:
            return np.zeros(np.shape(images))
        return (images - lo) / span

    def __len__(self):
        return self.input_images.shape[0]

    def __getitem__(self, idx):
        images = self.input_images[idx]
        classes = self.target_labels[idx]
        # The clean copy is normalised independently of the degraded one.
        o_images = self._minmax(images)

        if self.degrade_fn is not None:
            # Pass a copy so an in-place degradation cannot corrupt the stored dataset.
            images = self.degrade_fn(images.copy())
        images = self._minmax(images)

        if self.transform is not None:
            images = self.transform(images.astype(np.float32))
            o_images = self.transform(o_images.astype(np.float32))

        return o_images, images, classes

In [None]:
# Dataset root: "<parent of the current working directory>/data/".
data_path = os.path.dirname(os.getcwd()) + "/data/"

# Augmented pipeline for training images.
# NOTE(review): defined but never used below - Train_data is built with Test_transform.
# Confirm whether that is intentional.
Train_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.RandomHorizontalFlip(),
        transforms.RandomAffine(degrees = 5, translate = (0.05, 0.05)),
        transforms.ColorJitter(brightness=0.01, contrast = 0.2),
        transforms.ToTensor(),# NOTE(review): inputs are float32 already in [0, 1]; ToTensor only rescales uint8 input - confirm
        transforms.Normalize(mean = (0.5,), std = (0.5,))
        ])
        
# Deterministic pipeline: array -> PIL -> tensor, then Normalize(0.5, 0.5) maps [0, 1] to [-1, 1].
Test_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.ToTensor(),# NOTE(review): inputs are float32 already in [0, 1]; ToTensor only rescales uint8 input - confirm
        transforms.Normalize(mean = (0.5,), std = (0.5,))
    ])

# NOTE(review): Val_data and Test_data both read smallNorb_test96.h5, so model
# selection (best validation accuracy) is done on the test set.
Train_data = SmallNorbread(name=data_path + "smallNorb/smallNorb_train96.h5", transform=Test_transform)
Val_data = SmallNorbread(name=data_path + "smallNorb/smallNorb_test96.h5", transform=Test_transform)
Test_data = SmallNorbread(name=data_path + "smallNorb/smallNorb_test96.h5", transform=Test_transform)

# Only the training loader is shuffled; evaluation order is fixed.
Train_dataloader = DataLoader(dataset=Train_data, batch_size = params['n_batch'], shuffle=True)
Val_dataloader = DataLoader(dataset=Val_data, batch_size = params['n_batch'], shuffle=False)
Test_dataloader = DataLoader(dataset=Test_data, batch_size = params['n_batch'], shuffle=False)

### show sample images

In [None]:
# Preview up to 64 model-input images (element 1 of each (clean, degraded, label) triple).
real_batch = next(iter(Val_dataloader))
plt.figure(figsize=(10,10))
plt.axis("off")
# The batch comes from the validation loader, so label the figure accordingly.
plt.title("Validation Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[1].to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))
plt.show()

### model

In [None]:
class autoencoder(nn.Module):
    """Convolutional encoder-decoder used as an image-restoration model.

    Five stride-2 convolutions halve the spatial size each time (96x96 -> 3x3
    with 256 channels). Five transposed convolutions double it back, so the
    input size is reproduced when it is divisible by 32. The final Tanh bounds
    outputs to [-1, 1].

    Args:
        negative_slope: slope of the decoder's LeakyReLU units for x < 0.
            ``nn.LeakyReLU(True)`` passes True as ``negative_slope`` (slope 1.0,
            i.e. an identity activation). Use ``negative_slope=1.0`` to reproduce
            models trained with that configuration. The layer indices, and
            therefore the state_dict keys, are the same for any slope.
    """
    def __init__(self, negative_slope=0.01):
        super(autoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, 3, stride=2, padding=1),
            nn.ReLU(True),
            nn.BatchNorm2d(16),
            nn.Conv2d(16, 32, 3, stride=2, padding=1),
            nn.ReLU(True),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.ReLU(True),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.ReLU(True),
            nn.Conv2d(128, 256, 3, stride=2, padding=1),
            nn.ReLU(True),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1),
            nn.LeakyReLU(negative_slope, inplace=True),
            nn.BatchNorm2d(128),
            nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1),
            nn.LeakyReLU(negative_slope, inplace=True),
            nn.BatchNorm2d(64),
            nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),
            nn.LeakyReLU(negative_slope, inplace=True),
            nn.BatchNorm2d(32),
            nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=1),
            nn.LeakyReLU(negative_slope, inplace=True),
            nn.BatchNorm2d(16),
            nn.ConvTranspose2d(16, 1, 3, stride=2, padding=1, output_padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        """Map a (N, 1, H, W) batch to a reconstruction of the same shape (H, W divisible by 32)."""
        x = self.encoder(x)
        x = self.decoder(x)
        return x
    
class SRCNN(nn.Module):
    """Fully convolutional image-enhancement network in the style of SRCNN.

    Five 'same'-padded convolutions (kernels 5, 5, 5, 3, 3; widths 32, 64,
    128, 32, ``num_channels``) keep the spatial size unchanged. Hidden layers
    use ReLU and the output uses Tanh, so predictions lie in [-1, 1].

    Args:
        num_channels: channel count of both the input and the output image.
    """
    def __init__(self, num_channels=1):
        super(SRCNN, self).__init__()

        def same_conv(c_in, c_out, k):
            # An odd kernel with padding k // 2 preserves height and width.
            return nn.Conv2d(c_in, c_out, kernel_size=k, padding=k // 2)

        self.conv1 = same_conv(num_channels, 32, 5)
        self.conv2 = same_conv(32, 64, 5)
        self.conv3 = same_conv(64, 128, 5)
        self.conv4 = same_conv(128, 32, 3)
        self.conv5 = same_conv(32, num_channels, 3)
        self.relu = nn.ReLU(inplace=True)
        self.tanh = nn.Tanh()

    def forward(self, x):
        """Return the enhanced image, same shape as ``x``."""
        for hidden in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = self.relu(hidden(x))
        return self.tanh(self.conv5(x))
    
def set_parameter_requires_grad(model, feature_extracting):
    """Freeze every parameter of ``model`` when ``feature_extracting`` is truthy.

    With a falsy flag the model is left untouched; already-frozen parameters
    are NOT re-enabled.
    """
    if not feature_extracting:
        return
    for p in model.parameters():
        p.requires_grad = False
            
class Classifier(nn.Module):
    """ResNet-34 based image classifier for single-channel input.

    A 1x1 convolution lifts 1-channel images to 3 channels. A pretrained
    torchvision ResNet-34 with its last child module removed yields 512
    features. A 3-layer MLP head maps those features to ``n_class`` logits.
    """
    def __init__(self, n_class: int):
        super(Classifier, self).__init__()
        
        # Learned 1 -> 3 channel projection so the RGB backbone accepts grayscale input.
        self.to_3channel = nn.Conv2d(1, 3, 1)
        
        # NOTE(review): newer torchvision releases deprecate `pretrained=` in favour of `weights=`.
        model = models.resnet34(pretrained = True)
        # Drop the last child (the final fully-connected layer of torchvision's ResNet);
        # the remaining trunk is followed by a 512-wide Linear below.
        self.model_ft = torch.nn.Sequential(*(list(model.children())[:-1]))
        
        # No-op with False: nothing is frozen, so the backbone is fine-tuned.
        set_parameter_requires_grad(self.model_ft, False)
        
        self.classifier = nn.Sequential(
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Linear(1024, n_class),
        )
        
#         self.weights_init()
    
    def forward(self, x):
        """Return unnormalised class logits (N, n_class) for a (N, 1, H, W) batch."""
        x = self.to_3channel(x)
        x = self.model_ft(x)
        # Flatten the backbone output to (N, 512) for the MLP head.
        x = x.view(x.size(0), -1)
        x = self.classifier(x)

        return x
    
    def weights_init(self):
        """Re-initialise all submodules: Xavier-uniform for convolutions, BN to
        weight 1 / bias 0, and normal(std=1e-3) for linear layers.

        Not called by default (the call in ``__init__`` is commented out); note it
        would also overwrite the pretrained backbone's weights.
        """
        for m in self.modules():
            if isinstance(m, (nn.Conv2d)):
#                 nn.init.kaiming_normal_(m.weight, mode='fan_out')
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm2d)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=1e-3)
                # nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

# Training classifier

### Logger

In [None]:
class Customized_Logger():
    """
    logging the training process
    write to file and draw learning curves

    Tracks the following with torchnet meters:
      - train loss and accuracy;
      - validation loss and accuracy;
      - a normalised confusion matrix (validation batches only).

    The meters are mirrored to Visdom on port 8097. One tab-separated line per
    epoch is appended to ``runs/nets_<log_file>/log.txt``, and model weights are
    saved to ``runs/nets_<log_file>/net.pth``.
    """
    def __init__(self, model, params="", summary_string=""):
        # `params` must provide at least 'log_file' and 'n_class'; every entry is echoed to the log.
        self.file_name = params['log_file']
        self.model = model
        num_classes = params['n_class']

        # measurements
        self.train_meter_loss = tnt.meter.AverageValueMeter()
        self.train_classerr = tnt.meter.ClassErrorMeter(accuracy=True)
        self.val_meter_loss = tnt.meter.AverageValueMeter()
        self.val_classerr = tnt.meter.ClassErrorMeter(accuracy=True)
        self.confusion_meter = tnt.meter.ConfusionMeter(num_classes, normalized=True)

        # Output folders; the parent "runs" directory is created as needed.
        os.makedirs("runs/ims_%s" % self.file_name, exist_ok=True)
        os.makedirs("runs/nets_%s" % self.file_name, exist_ok=True)

        # Plot loggers; every window uses the same Visdom port.
        port = 8097
        self.loss_logger = VisdomPlotLogger('line', port=port, win = "Loss" + self.file_name, opts={'title': 'Loss Logger'})
        self.acc_logger = VisdomPlotLogger('line', port=port, win = "Acc" + self.file_name, opts={'title': 'Accuracy Logger'})
        self.confusion_logger = VisdomLogger('heatmap', port=port, win="confusion" + self.file_name, opts={'title': 'Confusion matrix',
                                                                'columnnames': list(range(num_classes)),
                                                                'rownames': list(range(num_classes))})
        self.reconstruction_logger = VisdomLogger('image', port=port, opts={'title': 'Reconstruction'})

        # Best validation accuracy so far; used by epoch_update(save_best=True).
        self.best_acc = 0
        self.best_epoch = -1
        with open("runs/nets_%s/log.txt" % self.file_name, 'w+') as log_file:
            log_file.write("--SETTINGS--\n")
            if params != "":
                for k in params:
                    log_file.write('%s: %s\n'%(k, params[k]))
            log_file.writelines(summary_string)
            log_file.write("--WRITE LOG--\n")
            log_file.write("train_acc\tval_acc\ttrain_loss\tval_loss\n")

    @staticmethod
    def _first_value(meter):
        """Return ``meter.value()[0]``, or NaN when the meter has not seen any sample."""
        try:
            return meter.value()[0]
        except ZeroDivisionError:
            # An empty ClassErrorMeter divides by its zero sample count.
            return float('nan')

    def plot(self, epoch, recons=None):
        """Push loss/accuracy curves, the confusion matrix and optional reconstructions."""
        curves = (
            (self.loss_logger, self.train_meter_loss, "train"),
            (self.acc_logger, self.train_classerr, "train"),
            (self.loss_logger, self.val_meter_loss, "val"),
            (self.acc_logger, self.val_classerr, "val"),
        )
        for logger, meter, name in curves:
            value = self._first_value(meter)
            # Skip meters without samples (e.g. no training batches in a test-only pass).
            if not np.isnan(value):
                logger.log(epoch, value, name=name)
        self.confusion_logger.log(self.confusion_meter.value())

        if recons is not None:
            self.reconstruction_logger.log(
                make_grid(recons, nrow=int(recons.size(0) ** 0.5), padding=2, normalize=True).numpy())
            # Reconstructions come from a Tanh in [-1, 1]; normalize so negative
            # values are not clipped to black, matching the Visdom image above.
            vutils.save_image(recons.data,
                              'runs/ims_%s/reconstructions_epoch_%03d.png' % (self.file_name, epoch),
                              normalize=True)

    def batch_update(self, outputs, targets, loss, train=True):
        """Accumulate one batch into the train meters (``train=True``) or into the
        validation meters and confusion matrix (``train=False``)."""
        if train:
            self.train_classerr.add(outputs.data, targets)
            self.train_meter_loss.add(loss.item())
        else:
            self.val_classerr.add(outputs.data, targets)
            self.val_meter_loss.add(loss.item())
            self.confusion_meter.add(outputs.data, targets)

    def epoch_update(self, epoch, recons=None, save_best = True):
        """Log the epoch's metrics, save weights, plot, then reset all meters.

        With ``save_best`` the model is saved only when validation accuracy
        improves; otherwise it is saved every epoch.
        """
        train_acc = self._first_value(self.train_classerr)
        val_acc = self._first_value(self.val_classerr)
        train_err = self._first_value(self.train_meter_loss)
        val_err = self._first_value(self.val_meter_loss)
        if save_best:
            # NaN compares False, so an epoch without validation data never counts as best.
            if val_acc > self.best_acc:
                self.best_acc = val_acc
                self.best_epoch = epoch
                torch.save(self.model.state_dict(), 'runs/nets_%s/net.pth' % self.file_name)
        else:
            torch.save(self.model.state_dict(), 'runs/nets_%s/net.pth' % self.file_name)

        with open("runs/nets_%s/log.txt" % self.file_name, 'a') as log_file:
            log_file.write('%.4f\t%.4f\t%.4f\t%.4f\n'%(train_acc, val_acc, train_err, val_err))

        print("training accuracy : %.4f, validation accuracy %.4f, training loss : %.4f, validation loss %.4f"%(train_acc, val_acc, train_err, val_err))
        self.plot(epoch, recons=recons)
        self.reset_meters()

    def final_update(self, training_time=0):
        """Append the best-epoch summary and the current validation/test metrics to the log."""
        test_acc = self._first_value(self.val_classerr)
        test_err = self._first_value(self.val_meter_loss)
        with open("runs/nets_%s/log.txt" % self.file_name, 'a') as log_file:
            log_file.write('\nBest accuracy %.4f at epoch %d\n'%(self.best_acc, self.best_epoch))
            log_file.write('%.4f\t%.4f\n'%(test_acc, test_err))
            log_file.write('Training time: %.4f seconds\n'%(training_time))
        print("test accuracy %.4f"%(test_acc))

    def reset_meters(self):
        """Clear every meter before the next epoch."""
        self.train_classerr.reset()
        self.train_meter_loss.reset()
        self.val_classerr.reset()
        self.val_meter_loss.reset()

        self.confusion_meter.reset()

### Loss

In [None]:
class LabelSmoothingLoss(nn.Module):
    """Cross-entropy against label-smoothed one-hot targets.

    The true class receives probability ``1 - smoothing`` and each of the other
    ``classes - 1`` classes receives ``smoothing / (classes - 1)``.

    Args:
        classes: number of classes (size of ``pred`` along ``dim``).
        smoothing: probability mass moved off the true class; 0 gives plain cross-entropy.
        dim: class dimension of ``pred``.
    """
    def __init__(self, classes, smoothing=0.0, dim=-1):
        super(LabelSmoothingLoss, self).__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.cls = classes
        self.dim = dim

    def forward(self, pred, target):
        """Return the mean smoothed cross-entropy.

        Args:
            pred: unnormalised logits with classes along ``self.dim``.
            target: integer class indices; ``pred``'s shape without ``self.dim``.
        """
        log_probs = pred.log_softmax(dim=self.dim)
        with torch.no_grad():
            # A single-class problem has no other class to receive smoothing mass.
            off_value = self.smoothing / (self.cls - 1) if self.cls > 1 else 0.0
            true_dist = torch.full_like(log_probs, off_value)
            # Scatter along the same class dimension used by log_softmax.
            true_dist.scatter_(self.dim, target.unsqueeze(self.dim), self.confidence)
        return torch.mean(torch.sum(-true_dist * log_probs, dim=self.dim))
    
class CrossEntropyLoss(nn.Module):
    """Thin wrapper around ``nn.CrossEntropyLoss`` that takes raw logits.

    ``nn.CrossEntropyLoss`` applies log-softmax internally, so logits must be
    passed through unchanged. A preceding softmax squashes them into [0, 1],
    which distorts the loss and flattens its gradients.

    Args:
        num_classes: stored for reference; not used in the computation.
    """
    def __init__(self, num_classes):
        super(CrossEntropyLoss, self).__init__()
        self.num_classes = num_classes
        self.loss = nn.CrossEntropyLoss()

    def forward(self, output, target):
        """Mean cross-entropy of logits ``output`` (N, C) against class indices ``target`` (N,)."""
        return self.loss(output, target)

# Trainer

In [None]:
def predict_l2(features, cl, gallery, cls, topk=5):
    """Count queries whose label appears among their ``topk`` nearest gallery items (L2).

    Args:
        features: (B, D) query embeddings.
        cl: (B,) query labels.
        gallery: (G, D) gallery embeddings on the same device as ``features``.
        cls: (G,) gallery labels.
        topk: neighbours considered per query; capped at G.

    Returns:
        Number of queries with a hit (int).
    """
    # cdist computes the (B, G) distance matrix without materialising (B, G, D).
    dists = torch.cdist(features, gallery)
    k = min(topk, gallery.size(0))
    _, topk_index = torch.topk(dists, k, dim=-1, largest=False)

    predicted_cls = cls[topk_index]                       # (B, k)
    hits = (predicted_cls == cl.view(-1, 1)).any(dim=-1)  # (B,)
    return int(hits.sum().item())

def predict_cosine(features, cl, gallery, cls, topk=5):
    """Count queries whose label appears among their ``topk`` most cosine-similar gallery items.

    Args:
        features: (B, D) query embeddings.
        cl: (B,) query labels.
        gallery: (G, D) gallery embeddings on the same device as ``features``.
        cls: (G,) gallery labels.
        topk: neighbours considered per query; capped at G.

    Returns:
        Number of queries with a hit (int).
    """
    # F.normalize guards against zero-norm rows (no NaN) and matches x / ||x|| otherwise.
    features = F.normalize(features, dim=-1)
    gallery = F.normalize(gallery, dim=-1)
    # Matrix product gives (B, G) similarities without a (B, G, D) intermediate.
    sims = features @ gallery.t()
    k = min(topk, gallery.size(0))
    _, topk_index = torch.topk(sims, k, dim=-1, largest=True)

    hits = (cls[topk_index] == cl.view(-1, 1)).any(dim=-1)
    return int(hits.sum().item())


def retrieval(features, gallery, topk=5):
    """Return the indices of the ``topk`` gallery rows closest to ``features`` (Euclidean).

    The matching distances are printed, nearest first.

    Args:
        features: query vector, broadcast against ``gallery``.
        gallery: NumPy array with one candidate per row.
    """
    distances = np.linalg.norm(gallery - features, axis=-1)
    nearest = np.argsort(distances)[:topk]
    print(distances[nearest])
    return nearest


def show(query_img, result_index, topk=5, gallery_images=None):
    """Display a query image, then its retrieved gallery images in a 5-column grid.

    Args:
        query_img: tensor whose slice ``[0]`` is the 2-D query image (e.g. (1, H, W)).
        result_index: indices into ``gallery_images`` of the retrieved items.
        topk: number of retrieved images to draw (at most ``len(result_index)``).
        gallery_images: array of 2-D gallery images; defaults to the global
            ``Train_data.input_images``.
    """
    if gallery_images is None:
        gallery_images = Train_data.input_images
    retrieved = gallery_images[result_index]
    n_show = min(topk, len(retrieved))

    # show real image
    plt.title("query Image")
    plt.imshow(query_img.cpu().numpy()[0], cmap='gray')
    plt.show()

    # show retrieval: ceil division so topk below 5 or not a multiple of 5 still fits the grid
    n_cols = 5
    n_rows = max(1, -(-n_show // n_cols))
    fig = plt.figure(figsize=(8, 2 * n_rows), dpi=200)
    plt.axis("off")
    plt.title("Retrieval Images")
    for i in range(n_show):
        fig.add_subplot(n_rows, n_cols, i + 1)
        plt.axis("off")
        plt.imshow(retrieved[i], cmap='gray')
    plt.show()


class Trainer():
    """Train and evaluate the ResNet ``Classifier`` and run feature retrieval with it.

    Data loaders must yield ``(clean, degraded, label)`` triples as produced by
    ``SmallNorbread``. Classification uses the clean image; retrieval queries use
    the degraded one. Checkpoints and the feature gallery live under
    ``runs/nets_<params['log_file']>/``.
    """
    def __init__(self, Train_dataloader, Val_dataloader, Test_dataloader, params):
        self.train_data = Train_dataloader
        self.val_data = Val_dataloader
        self.test_data = Test_dataloader
        self.params = params

        # One test batch is drawn only to obtain the per-sample input size for the summary.
        images, _, _ = next(iter(self.test_data))
        self._model = Classifier(n_class = params['n_class']).to(device)

        summary(self._model, input_size= images[0].size(), device="cuda")

        self._loss = LabelSmoothingLoss(params['n_class'], smoothing = 0.1)

        ###logger
        self.logger = Customized_Logger(model=self._model, params=params, summary_string="")

    def _checkpoint_path(self):
        """Path of the classifier weights written by the logger."""
        return 'runs/nets_%s/net.pth' % self.params['log_file']

    def _gallery_path(self):
        """Path of the HDF5 file holding gallery features and labels."""
        return 'runs/nets_%s/feas512_gallery.h5' % self.params['log_file']

    def _load_best_model(self):
        """Load the saved classifier weights and switch to eval mode."""
        self._model.load_state_dict(torch.load(self._checkpoint_path()))
        self._model.eval()

    def _embed(self, xb):
        """Return flattened backbone features (input of the MLP head) for batch ``xb``."""
        x = self._model.to_3channel(xb)
        return self._model.model_ft(x).view(x.size(0), -1)

    def _load_gallery(self):
        """Read the gallery (features, labels) as NumPy arrays, closing the file afterwards."""
        with h5py.File(self._gallery_path(), 'r') as hf:
            gallery = np.array(hf.get('features'))
            gallery_cls = np.array(hf.get('targets'))
        return gallery, gallery_cls

    def run_batch(self, data, targets):
        """Forward one batch.

        Returns:
            (label-smoothed loss, softmax class probabilities, labels on ``device``).
        """
        labels = targets.to(device)
        xb = data.to(device).float()

        f_class = self._model(xb)
        p_class = F.softmax(f_class, dim=1)

        loss = self._loss(pred=f_class, target=labels)

        return loss, p_class, labels

    def train(self):
        """Train for ``n_epoch`` epochs (Adam + StepLR), validating after every epoch."""
        optimizer = torch.optim.Adam(self._model.parameters(), lr= self.params['lr'])
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, self.params['lr_step'], self.params['lr_decay'])

        for epoch in range(1, self.params['n_epoch'] + 1):
            print("epoch: [%d/%d]"%(epoch, self.params['n_epoch']))

            self._model.train()
            for data, _, targets in tqdm(self.train_data):
                loss, p_class, labels = self.run_batch(data, targets)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                self.logger.batch_update(outputs=p_class, targets=labels, loss=loss)

            self._model.eval()
            # Validation needs no autograd graph; this saves memory and time.
            with torch.no_grad():
                for data, _, targets in self.val_data:
                    loss, p_class, labels = self.run_batch(data, targets)
                    self.logger.batch_update(outputs=p_class, targets=labels, loss=loss, train=False)

            scheduler.step()
            self.logger.epoch_update(epoch=epoch)

    def test(self):
        """Evaluate the saved weights on the test loader and log the results."""
        self._load_best_model()

        with torch.no_grad():
            for data, _, targets in tqdm(self.test_data):
                loss, p_class, labels = self.run_batch(data, targets)
                self.logger.batch_update(outputs=p_class, targets=labels, loss=loss, train=False)

        self.logger.epoch_update(epoch=1)

    def extract_feature(self, data_loader):
        """Embed every clean image of ``data_loader`` and write the gallery file.

        Writes a ``features`` dataset (N, 512, float64) and a ``targets``
        dataset (N, int64).
        """
        self._load_best_model()

        feature_chunks = []
        label_chunks = []
        with torch.no_grad():
            for data, _, targets in tqdm(data_loader):
                xb = data.to(device).float()
                feature_chunks.append(self._embed(xb).cpu().numpy())
                label_chunks.append(targets.cpu().numpy())

        # Concatenate once instead of growing arrays per batch; explicit dtypes
        # replace np.float / np.int, which no longer exist in NumPy >= 1.24.
        if feature_chunks:
            feas_vector = np.concatenate(feature_chunks, axis=0).astype(np.float64)
            cls = np.concatenate(label_chunks, axis=0).astype(np.int64)
        else:
            feas_vector = np.empty((0, 512), np.float64)
            cls = np.empty((0,), np.int64)

        with h5py.File(self._gallery_path(), 'w') as hf:
            hf.create_dataset('features', data=feas_vector)
            hf.create_dataset('targets', data=cls)

    def retrieval_result(self, data_loader, predict_fn, topk=5):
        """Print and return the top-``topk`` retrieval hit rate over ``data_loader``.

        Processing steps for each query:

        1. Take the degraded image.
        2. Enhance it with the SRCNN loaded from ``runs/nets_upscale_smallNorb/net.pth``.
        3. Min-max rescale it per image to [-1, 1].
        4. Embed it with the classifier backbone.
        5. Match it against the saved gallery with ``predict_fn``
           (``predict_l2`` or ``predict_cosine``).
        """
        gallery, gallery_cls = self._load_gallery()

        enhance_model2 = SRCNN()
        enhance_model2.load_state_dict(torch.load('runs/nets_upscale_smallNorb/net.pth'))
        enhance_model2.eval()

        # Deblurring alternative; swap it in below to evaluate with the autoencoder.
        enhance_model = autoencoder()
        enhance_model.load_state_dict(torch.load('runs/nets_deblur_smallNorb/net.pth'))
        enhance_model.eval()

        self._load_best_model()

        gallery = torch.from_numpy(gallery).float().to(device)
        gallery_cls = torch.from_numpy(gallery_cls).long().to(device)

        sum_predicted = 0
        with torch.no_grad():
            for _, data, targets in tqdm(data_loader):
                labels = targets.to(device)
                xb = data.to(device).float()

                # xb = enhance_model(xb)
                xb = enhance_model2(xb)

                # Per-image min-max rescale to [-1, 1], the range produced by the
                # dataset normalisation followed by Normalize(0.5, 0.5).
                flat = xb.view(xb.size(0), -1)
                mi = flat.min(dim=-1)[0].view(-1, 1, 1, 1)
                ma = flat.max(dim=-1)[0].view(-1, 1, 1, 1)
                xb = (xb - mi) / (ma - mi).clamp_min(1e-12)
                xb = 2 * xb - 1

                sum_predicted += predict_fn(self._embed(xb), labels, gallery, gallery_cls, topk=topk)

        # Normalise by the dataset that was actually evaluated.
        hit_rate = sum_predicted / len(data_loader.dataset)
        print(hit_rate)
        return hit_rate

    def show_retrieval(self, data_loader, topk=5):
        """For every batch, show one random degraded query and its ``topk`` L2 neighbours.

        Queries are embedded without enhancement and matched against the saved
        gallery with ``retrieval``; results are drawn with ``show``.
        """
        gallery, _ = self._load_gallery()
        self._load_best_model()

        with torch.no_grad():
            for _, data, _ in tqdm(data_loader):
                # Sample within the real batch size: the last batch may be shorter than n_batch.
                i = np.random.randint(0, data.size(0))
                xb = data.to(device).float()

                feas = self._embed(xb)[i].cpu().numpy()

                _index = retrieval(feas, gallery, topk=topk)
                show(data[i], _index, topk=topk)
        
        
 
        
        
        

### train

# Denoising Autoencoder

In [None]:
class Denoising():
    """Train and evaluate the convolutional ``autoencoder`` as an image restorer.

    Loaders yield ``(clean, degraded, label)`` triples. The model maps the
    degraded image to the clean one under an MSE loss. The shared logger's
    classification meters receive fixed dummy tensors, so only its loss values
    are meaningful here.
    """

    def __init__(self, Train_dataloader, Val_dataloader, Test_dataloader, params):
        self.train_data = Train_dataloader
        self.val_data = Val_dataloader
        self.test_data = Test_dataloader
        self.params = params

        # One test batch is drawn only to obtain the per-sample input size for the summary.
        images, _, _ = next(iter(self.test_data))
        self._model = autoencoder().to(device)

        summary(self._model, input_size= images[0].size(), device="cuda")

        self._loss = nn.MSELoss()

        ###logger
        self.logger = Customized_Logger(model=self._model, params=params, summary_string="")

        # Placeholder scores / labels that keep the logger's classification meters fed.
        self.dummy = torch.rand((params['n_batch'], params['n_class'])).to(device)
        self.dummy1 = torch.randint(high=params['n_class'], size=(1, params['n_batch']))[0].to(device)

    def run_batch(self, data, targets):
        """Return (MSE loss, reconstruction) for degraded ``data`` against clean ``targets``."""
        targets = targets.to(device).float()
        xb = data.to(device).float()

        x = self._model(xb)

        loss = self._loss(x, targets)

        return loss, x

    def train(self):
        """Train for ``n_epoch`` epochs (Adam + StepLR), saving weights after every epoch."""
        optimizer = torch.optim.Adam(self._model.parameters(), lr= self.params['lr'])
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, self.params['lr_step'], self.params['lr_decay'])

        for epoch in range(1, self.params['n_epoch'] + 1):
            print("epoch: [%d/%d]"%(epoch, self.params['n_epoch']))

            self._model.train()
            for targets, data, _ in tqdm(self.train_data):
                loss, x = self.run_batch(data, targets)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                self.logger.batch_update(outputs=self.dummy, targets=self.dummy1, loss=loss)

            self._model.eval()
            # Validation needs no autograd graph.
            with torch.no_grad():
                for targets, data, _ in self.val_data:
                    loss, x = self.run_batch(data, targets)
                    self.logger.batch_update(outputs=self.dummy, targets=self.dummy1, loss=loss, train=False)

            scheduler.step()
            # `x` holds the reconstructions of the last validation batch.
            self.logger.epoch_update(epoch=epoch, recons = x.detach().cpu(), save_best=False)

    def test(self):
        """Report the loss over the whole test loader and plot the first batch's reconstructions."""
        self._model.load_state_dict(torch.load('runs/nets_%s/net.pth' % self.params['log_file']))
        self._model.eval()

        with torch.no_grad():
            for batch_idx, (targets, data, _) in enumerate(tqdm(self.test_data)):
                loss, x = self.run_batch(data, targets)
                self.logger.batch_update(outputs=self.dummy, targets=self.dummy1, loss=loss, train=False)

                if batch_idx == 0:
                    plt.figure(figsize=(10,10))
                    plt.axis("off")
                    plt.imshow(np.transpose(vutils.make_grid(x[:64], padding=2, normalize=True).cpu(),(1,2,0)))

        self.logger.final_update()

### train denoising model

# Super resolution

In [None]:
class Upscale():
    """Train and evaluate ``SRCNN`` to restore low-resolution images.

    Loaders yield ``(clean, degraded, label)`` triples. The model maps the
    degraded image to the clean one under an MSE loss. The shared logger's
    classification meters receive fixed dummy tensors, so only its loss values
    are meaningful here.
    """

    def __init__(self, Train_dataloader, Val_dataloader, Test_dataloader, params):
        self.train_data = Train_dataloader
        self.val_data = Val_dataloader
        self.test_data = Test_dataloader
        self.params = params

        # One test batch is drawn only to obtain the per-sample input size for the summary.
        images, _, _ = next(iter(self.test_data))
        self._model = SRCNN(num_channels=1).to(device)

        summary(self._model, input_size= images[0].size(), device="cuda")

        self._loss = nn.MSELoss()

        ###logger
        self.logger = Customized_Logger(model=self._model, params=params, summary_string="")

        # Placeholder scores / labels that keep the logger's classification meters fed.
        self.dummy = torch.rand((params['n_batch'], params['n_class'])).to(device)
        self.dummy1 = torch.randint(high=params['n_class'], size=(1, params['n_batch']))[0].to(device)

    def run_batch(self, data, targets):
        """Return (MSE loss, restored image) for degraded ``data`` against clean ``targets``."""
        targets = targets.to(device).float()
        xb = data.to(device).float()

        x = self._model(xb)

        loss = self._loss(x, targets)

        return loss, x

    def train(self):
        """Train for ``n_epoch`` epochs (Adam + StepLR), saving weights after every epoch."""
        optimizer = torch.optim.Adam(self._model.parameters(), lr= self.params['lr'])
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, self.params['lr_step'], self.params['lr_decay'])

        for epoch in range(1, self.params['n_epoch'] + 1):
            print("epoch: [%d/%d]"%(epoch, self.params['n_epoch']))

            self._model.train()
            for targets, data, _ in tqdm(self.train_data):
                loss, x = self.run_batch(data, targets)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                self.logger.batch_update(outputs=self.dummy, targets=self.dummy1, loss=loss)

            self._model.eval()
            # Validation needs no autograd graph.
            with torch.no_grad():
                for targets, data, _ in self.val_data:
                    loss, x = self.run_batch(data, targets)
                    self.logger.batch_update(outputs=self.dummy, targets=self.dummy1, loss=loss, train=False)

            scheduler.step()
            # `x` holds the outputs of the last validation batch.
            self.logger.epoch_update(epoch=epoch, recons = x.detach().cpu(), save_best=False)

    def test(self):
        """Report the loss over the whole test loader and plot the first batch's outputs."""
        self._model.load_state_dict(torch.load('runs/nets_%s/net.pth' % self.params['log_file']))
        self._model.eval()

        with torch.no_grad():
            for batch_idx, (targets, data, _) in enumerate(tqdm(self.test_data)):
                loss, x = self.run_batch(data, targets)
                self.logger.batch_update(outputs=self.dummy, targets=self.dummy1, loss=loss, train=False)

                if batch_idx == 0:
                    plt.figure(figsize=(10,10))
                    plt.axis("off")
                    plt.imshow(np.transpose(vutils.make_grid(x[:64], padding=2, normalize=True).cpu(),(1,2,0)))

        self.logger.final_update()

# Model evaluation

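### SmallNorb validation retrieval
Top-k retrieval hit rate (fraction of queries whose label appears among the k nearest gallery items):\
k = 1: Mean match = 0.894753\
k = 5: Mean match = 0.91823\
k = 10: Mean match = 0.923004\
validation accuracy = 88.73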

## Random retrieval

In [None]:
# Build the classifier wrapper (it reads the clean test loader for its summary).
trainer = Trainer(Train_dataloader, Val_dataloader, Test_dataloader, params)
# Rebind the test set to its motion-blurred version for the retrieval demo.
Test_data = SmallNorbread(name=data_path + "smallNorb/smallNorb_test96.h5", degrade_fn=Motion_bluring, transform=Test_transform)
Test_dataloader = DataLoader(dataset=Test_data, batch_size = params['n_batch'], shuffle=False)

# Run once to (re)build the feature gallery from the clean training images:
# Gallery_data = SmallNorbread(name=data_path + "smallNorb/smallNorb_train96.h5", transform=Test_transform)
# Gallery_dataloader = DataLoader(dataset=Gallery_data, batch_size = params['n_batch'], shuffle=False)

# trainer.extract_feature(Gallery_dataloader)

trainer.show_retrieval(Test_dataloader, topk=10)