### Imports

In [1]:
import os 
import random 
random.seed(97)
import scipy.io as io 
import numpy as np 
import cv2 
import datetime
import itertools
import matplotlib.pyplot as plt
from tqdm import tqdm

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset 
from torch.utils.data import DataLoader
from torch.autograd import Variable
from torch.nn import Parameter
from torchsummary import summary
from torch.autograd import Variable
import torch.autograd as autograd
import torch.backends.cudnn as cudnn
from torch.optim import lr_scheduler
from torch.utils.tensorboard import SummaryWriter

### Configs

In [3]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Utils

#### Visualisation Utils

In [4]:
def check_dir(path): #why do we need this
    ''' 
    check if the given directory exists, else create one
    '''
    if not os.path.exists(path):
        os.makedirs(path)

def visualise_hyperspectral_cube(sample_file,channels=[],save_imgs=True,show_imgs=False):
    '''
    To visualise & save reflectances corresponding to each wavelength as a grayscale image
    '''
    sample_dump = io.loadmat(sample_file)
    img_hsi = sample_dump['cube'] 
    img_bands = sample_dump['bands'][0] 
    img_256 = (img_hsi*255).astype(np.uint8)
    for chnl in channels:
        img_chnl = img_256[:,:,chnl]
        chnl_band = img_bands[chnl-1]
        img_save_path = sample_file.replace(".mat",f"_{chnl}.png")
        if save_imgs:
            cv2.imwrite(img_save_path, img_chnl)
        if show_imgs:
            plt.imshow(img_chnl,cmap='gray')

In [5]:
# visualise_hyperspectral_cube("sample/ARAD_HS_0464.mat",[1,2,3,14],save_imgs=False,show_imgs=True)

#### Data Utils

#### Network Utils

In [6]:
def weights_init(net, init_type = 'xavier', init_gain = 0.02):
    
    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and classname.find('Conv') != -1:
            if init_type == 'normal':
                torch.nn.init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                torch.nn.init.xavier_normal_(m.weight.data, gain = init_gain)
            elif init_type == 'kaiming':
                torch.nn.init.kaiming_normal_(m.weight.data, a = 0, mode = 'fan_in')
            elif init_type == 'orthogonal':
                torch.nn.init.orthogonal_(m.weight.data, gain = init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
        elif classname.find('BatchNorm2d') != -1:
            torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
            torch.nn.init.constant_(m.bias.data, 0.0)

    # print('initialising network with %s type' % init_type)
    net.apply(init_func)



In [7]:
def save_checkpoint(state_dicts):
    torch.save(state_dicts,CHECKPOINT_FOLDER+CHECKPOINT_FILE)

def load_checkpoint(checkpoint_file,model,optimiser,learning_rate):
    checkpoint = torch.load(checkpoint_file)
    model.load_state_dict(checkpoint["state_dict"])
    optimiser.load_state_dict(checkpoint["optimiser"])

    for params in optimiser.param_groups:
        params["lr"] = learning_rate

In [8]:
# x = torch.range(start = 0, end = 31).reshape([1, 8, 2, 2])
# y = F.pixel_shuffle(x, 2)
# x_ = pixel_unshuffle(y, 2)
# print(f"-{x.shape},y-{y.shape},x_{x_.shape}")

### Dataset

#### Dataset Class

In [9]:
class HSI_Dataset_Train(Dataset):
    def __init__(self,train_data_dir,input_image_shape,data_transforms=None):
        self.train_data_dir = train_data_dir
        self.data_transforms = data_transforms
        self.rgb_path = self.train_data_dir+"rgb/"
        self.hsi_path = self.train_data_dir + "hsi/"
        self.input_image_shape = input_image_shape
        self.input_image_size = self.input_image_shape[0]
        
        #Get file names
        self.img_root_names = []
        for img_name in os.listdir(self.hsi_path):
            self.img_root_names.append(img_name.split(".mat")[0])

        self.length = len(self.img_root_names)
        
        #generate file names for spectral cubes and jpg images
        self.rgb_image_files = []
        self.hsi_image_files = []
        for img_root_name in self.img_root_names:
            self.rgb_image_files.append(img_root_name+"_RealWorld.jpg")
            self.hsi_image_files.append(img_root_name+".mat")
    
    def __getitem__(self, index):
        rgb_img_path = self.rgb_path + self.rgb_image_files[index]
        hsi_img_path = self.hsi_path + self.hsi_image_files[index]
        img_root_name = self.img_root_names[index]

        hsi_img = io.loadmat(hsi_img_path)["cube"]
        rgb_img = cv2.imread(rgb_img_path,-1)

        rgb_img = rgb_img.astype(np.float64)/255.0 #normalisation to [0,1]

        #crop a patch of the scene
        h,w = rgb_img.shape[:2]
        if (h>self.input_image_size) or (w>self.input_image_size):
            rand_h = random.randint(0,h-self.input_image_size)
            rand_w = random.randint(0,w-self.input_image_size)

            rgb_img = rgb_img[rand_h:rand_h+self.input_image_size, rand_w:rand_w+self.input_image_size, :] 
            hsi_img = hsi_img[rand_h:rand_h+self.input_image_size, rand_w:rand_w+self.input_image_size, :] 
            
        rgb_img = torch.from_numpy(rgb_img.astype(np.float32).transpose(2, 0, 1)).contiguous() #to tensor
        hsi_img = torch.from_numpy(hsi_img.astype(np.float32).transpose(2, 0, 1)).contiguous()
            
        return rgb_img, hsi_img, img_root_name
    
    def __len__(self):
        return self.length

In [10]:
TrainDataset = HSI_Dataset_Train("Data/train/",[256,256,3],data_transforms=None)
print(f"Number of images in training set {len(TrainDataset)}")

Number of images in training set 450


#### DataLoader

In [11]:
train_loader = DataLoader(TrainDataset, batch_size = 8, shuffle = True, num_workers = 0, pin_memory = True)

In [12]:
train_iter = iter(train_loader)
first_batch = train_iter.next()
rgb_images,hsi_images,  img_names = first_batch

In [13]:
print(len(hsi_images))
print(hsi_images[0].shape)

print(len(rgb_images))
print(rgb_images[0].shape)

### Model


#### Customised Layers for the network

In [14]:
# Pixel Unshuffle layer - does opposite to what torch.nn.Functional.pixelshuffle does

def pixel_unshuffle(input, downscale_factor):
    c = input.shape[1]
    kernel = torch.zeros(size = [downscale_factor * downscale_factor * c, 1, downscale_factor, downscale_factor],
                        device = input.device)
    for y in range(downscale_factor):
        for x in range(downscale_factor):
            kernel[x + y * downscale_factor::downscale_factor * downscale_factor, 0, y, x] = 1
    return F.conv2d(input, kernel, stride = downscale_factor, groups = c)


class PixelUnShuffle(nn.Module):
    def __init__(self, downscale_factor):
        super(PixelUnShuffle, self).__init__()
        self.downscale_factor = downscale_factor

    def forward(self, input):
        return pixel_unshuffle(input, self.downscale_factor)
# Conv2d Block
class Conv2dLayer(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride = 1, padding = 0, dilation = 1, pad_type = 'zero', activation = 'lrelu', norm = 'none', sn = False):
        super(Conv2dLayer, self).__init__()
        # Initialize the padding scheme
        if pad_type == 'reflect':
            self.pad = nn.ReflectionPad2d(padding)
        elif pad_type == 'replicate':
            self.pad = nn.ReplicationPad2d(padding)
        elif pad_type == 'zero':
            self.pad = nn.ZeroPad2d(padding)
        else:
            assert 0, "Unsupported padding type: {}".format(pad_type)
        
        # Initialize the normalization type
        if norm == 'bn':
            self.norm = nn.BatchNorm2d(out_channels)
        elif norm == 'in':
            self.norm = nn.InstanceNorm2d(out_channels)
        elif norm == 'ln':
            self.norm = LayerNorm(out_channels)
        elif norm == 'none':
            self.norm = None
        else:
            assert 0, "Unsupported normalization: {}".format(norm)
        
        # Initialize the activation funtion
        if activation == 'relu':
            self.activation = nn.ReLU(inplace = True)
        elif activation == 'lrelu':
            self.activation = nn.LeakyReLU(0.2, inplace = True)
        elif activation == 'prelu':
            self.activation = nn.PReLU()
        elif activation == 'selu':
            self.activation = nn.SELU(inplace = True)
        elif activation == 'tanh':
            self.activation = nn.Tanh()
        elif activation == 'sigmoid':
            self.activation = nn.Sigmoid()
        elif activation == 'none':
            self.activation = None
        else:
            assert 0, "Unsupported activation: {}".format(activation)

        # Initialize the convolution layers
        if sn:
            self.conv2d = SpectralNorm(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding = 0, dilation = dilation))
        else:
            self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding = 0, dilation = dilation)
    
    def forward(self, x):
        x = self.pad(x)
        x = self.conv2d(x)
        if self.norm:
            x = self.norm(x)
        if self.activation:
            x = self.activation(x)
        return x

class TransposeConv2dLayer(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride = 1, padding = 0, dilation = 1, pad_type = 'zero', activation = 'lrelu', norm = 'none', sn = False, scale_factor = 2):
        super(TransposeConv2dLayer, self).__init__()
        # Initialize the conv scheme
        self.scale_factor = scale_factor
        self.conv2d = Conv2dLayer(in_channels, out_channels, kernel_size, stride, padding, dilation, pad_type, activation, norm, sn)
    
    def forward(self, x):
        x = F.interpolate(x, scale_factor = self.scale_factor, mode = 'nearest')
        x = self.conv2d(x)
        return x

class ResConv2dLayer(nn.Module):
    def __init__(self, in_channels, kernel_size = 3, stride = 1, padding = 1, dilation = 1, pad_type = 'zero', activation = 'lrelu', norm = 'none', sn = False, scale_factor = 2):
        super(ResConv2dLayer, self).__init__()
        # Initialize the conv scheme
        self.conv2d = nn.Sequential(
            Conv2dLayer(in_channels, in_channels, kernel_size, stride, padding, dilation, pad_type, activation, norm, sn),
            Conv2dLayer(in_channels, in_channels, kernel_size, stride, padding, dilation, pad_type, activation = 'none', norm = norm, sn = sn)
        )
    
    def forward(self, x):
        residual = x
        out = self.conv2d(x)
        out = 0.1 * out + residual
        return out

class DenseConv2dLayer_5C(nn.Module):
    def __init__(self, in_channels, latent_channels, kernel_size = 3, stride = 1, padding = 1, dilation = 1, pad_type = 'zero', activation = 'lrelu', norm = 'none', sn = False):
        super(DenseConv2dLayer_5C, self).__init__()
        # dense convolutions
        self.conv1 = Conv2dLayer(in_channels, latent_channels, kernel_size, stride, padding, dilation, pad_type, activation, norm, sn)
        self.conv2 = Conv2dLayer(in_channels + latent_channels, latent_channels, kernel_size, stride, padding, dilation, pad_type, activation, norm, sn)
        self.conv3 = Conv2dLayer(in_channels + latent_channels * 2, latent_channels, kernel_size, stride, padding, dilation, pad_type, activation, norm, sn)
        self.conv4 = Conv2dLayer(in_channels + latent_channels * 3, latent_channels, kernel_size, stride, padding, dilation, pad_type, activation, norm, sn)
        self.conv5 = Conv2dLayer(in_channels + latent_channels * 4, in_channels, kernel_size, stride, padding, dilation, pad_type, activation, norm, sn)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(torch.cat((x, x1), 1))
        x3 = self.conv3(torch.cat((x, x1, x2), 1))
        x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5
        
class ResidualDenseBlock_5C(nn.Module):
    def __init__(self, in_channels, latent_channels, kernel_size = 3, stride = 1, padding = 1, dilation = 1, pad_type = 'zero', activation = 'lrelu', norm = 'none', sn = False):
        super(ResidualDenseBlock_5C, self).__init__()
        # dense convolutions
        self.conv1 = Conv2dLayer(in_channels, latent_channels, kernel_size, stride, padding, dilation, pad_type, activation, norm, sn)
        self.conv2 = Conv2dLayer(in_channels + latent_channels, latent_channels, kernel_size, stride, padding, dilation, pad_type, activation, norm, sn)
        self.conv3 = Conv2dLayer(in_channels + latent_channels * 2, latent_channels, kernel_size, stride, padding, dilation, pad_type, activation, norm, sn)
        self.conv4 = Conv2dLayer(in_channels + latent_channels * 3, latent_channels, kernel_size, stride, padding, dilation, pad_type, activation, norm, sn)
        self.conv5 = Conv2dLayer(in_channels + latent_channels * 4, in_channels, kernel_size, stride, padding, dilation, pad_type, activation, norm, sn)

    def forward(self, x):
        residual = x
        x1 = self.conv1(x)
        x2 = self.conv2(torch.cat((x, x1), 1))
        x3 = self.conv3(torch.cat((x, x1, x2), 1))
        x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        x5 = 0.1 * x5 + residual
        return x5

# Layer Norm

class LayerNorm(nn.Module):
    def __init__(self, num_features, eps = 1e-8, affine = True):
        super(LayerNorm, self).__init__()
        self.num_features = num_features
        self.affine = affine
        self.eps = eps

        if self.affine:
            self.gamma = Parameter(torch.Tensor(num_features).uniform_())
            self.beta = Parameter(torch.zeros(num_features))

    def forward(self, x):
        # layer norm
        shape = [-1] + [1] * (x.dim() - 1)                                  # for 4d input: [-1, 1, 1, 1]
        if x.size(0) == 1:
            # These two lines run much faster in pytorch 0.4 than the two lines listed below.
            mean = x.view(-1).mean().view(*shape)
            std = x.view(-1).std().view(*shape)
        else:
            mean = x.view(x.size(0), -1).mean(1).view(*shape)
            std = x.view(x.size(0), -1).std(1).view(*shape)
        x = (x - mean) / (std + self.eps)
        # if it is learnable
        if self.affine:
            shape = [1, -1] + [1] * (x.dim() - 2)                          # for 4d input: [1, -1, 1, 1]
            x = x * self.gamma.view(*shape) + self.beta.view(*shape)
        return x

# spectral norm
def l2normalize(v, eps = 1e-12):
    return v / (v.norm() + eps)

class SpectralNorm(nn.Module):
    def __init__(self, module, name = 'weight', power_iterations = 1):
        super(SpectralNorm, self).__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        if not self._made_params():
            self._make_params()

    def _update_u_v(self):
        u = getattr(self.module, self.name + "_u")
        v = getattr(self.module, self.name + "_v")
        w = getattr(self.module, self.name + "_bar")

        height = w.data.shape[0]
        for _ in range(self.power_iterations):
            v.data = l2normalize(torch.mv(torch.t(w.view(height,-1).data), u.data))
            u.data = l2normalize(torch.mv(w.view(height,-1).data, v.data))

        # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
        sigma = u.dot(w.view(height, -1).mv(v))
        setattr(self.module, self.name, w / sigma.expand_as(w))

    def _made_params(self):
        try:
            u = getattr(self.module, self.name + "_u")
            v = getattr(self.module, self.name + "_v")
            w = getattr(self.module, self.name + "_bar")
            return True
        except AttributeError:
            return False

    def _make_params(self):
        w = getattr(self.module, self.name)

        height = w.data.shape[0]
        width = w.view(height, -1).data.shape[1]

        u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
        v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
        u.data = l2normalize(u.data)
        v.data = l2normalize(v.data)
        w_bar = Parameter(w.data)

        del self.module._parameters[self.name]

        self.module.register_parameter(self.name + "_u", u)
        self.module.register_parameter(self.name + "_v", v)
        self.module.register_parameter(self.name + "_bar", w_bar)

    def forward(self, *args):
        self._update_u_v()
        return self.module.forward(*args)

#non local block
class Self_Attn(nn.Module):
    """ Self attention Layer for Feature Map dimension"""
    def __init__(self, in_dim, latent_dim = 8):
        super(Self_Attn, self).__init__()
        self.channel_in = in_dim
        self.channel_latent = in_dim // latent_dim
        self.query_conv = nn.Conv2d(in_channels = in_dim, out_channels = in_dim // latent_dim, kernel_size = 1)
        self.key_conv = nn.Conv2d(in_channels = in_dim, out_channels = in_dim // latent_dim, kernel_size = 1)
        self.value_conv = nn.Conv2d(in_channels = in_dim, out_channels = in_dim, kernel_size = 1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax  = nn.Softmax(dim = -1)

    def forward(self, x):
        
        batchsize, C, height, width = x.size()
        # proj_query: reshape to B x N x c, N = H x W
        proj_query  = self.query_conv(x).view(batchsize, -1, height * width).permute(0, 2, 1)
        # proj_query: reshape to B x c x N, N = H x W
        proj_key =  self.key_conv(x).view(batchsize, -1, height * width)
        # transpose check, energy: B x N x N, N = H x W
        energy =  torch.bmm(proj_query, proj_key)
        # attention: B x N x N, N = H x W
        attention = self.softmax(energy)
        # proj_value is normal convolution, B x C x N
        proj_value = self.value_conv(x).view(batchsize, -1, height * width)
        # out: B x C x N
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(batchsize, C, height, width)
        
        out = self.gamma * out + x
        return out
#Global Block
class SELayer(nn.Module):
    def __init__(self, channel, reduction = 16):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias = False),
            nn.ReLU(inplace = True),
            nn.Linear(channel // reduction, channel // reduction, bias = False),
            nn.ReLU(inplace = True),
            nn.Linear(channel // reduction, channel, bias = False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)

class GlobalBlock(nn.Module):
    def __init__(self, in_channels, kernel_size, stride = 1, padding = 0, dilation = 1, pad_type = 'zero', activation = 'lrelu', norm = 'none', sn = False, reduction = 8):
        super(GlobalBlock, self).__init__()
        self.conv1 = Conv2dLayer(in_channels, in_channels, kernel_size, stride, padding, dilation, pad_type, activation, norm, sn)
        self.conv2 = Conv2dLayer(in_channels, in_channels, kernel_size, stride, padding, dilation, pad_type, activation, norm, sn)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels // reduction, bias = False),
            nn.ReLU(inplace = True),
            nn.Linear(in_channels // reduction, in_channels // reduction, bias = False),
            nn.ReLU(inplace = True),
            nn.Linear(in_channels // reduction, in_channels, bias = False),
            nn.Sigmoid()
        )

    def forward(self, x):
        # residual
        residual = x
        # Sequeeze-and-Excitation(SE)
        b, c, _, _ = x.size()
        x = self.conv1(x)
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        y = x * y.expand_as(x)
        y = self.conv2(y)
        # addition
        out = 0.1 * y + residual
        return out



#### Designing the network

In [15]:
class RGB2HS(nn.Module):
    def __init__(self, in_channels,out_channels,start_channels,activ,norm,pad):
        super(RGB2HS, self).__init__()
        # Top subnetwork, K = 3
        self.top1 = Conv2dLayer(in_channels * (4 ** 3), start_channels * (2 ** 3), 3, 1, 1, pad_type = pad, activation = activ, norm = norm)
        self.top21 = ResidualDenseBlock_5C(start_channels * (2 ** 3), start_channels * (2 ** 2), 3, 1, 1, pad_type = pad, activation = activ, norm = norm)
        self.top22 = GlobalBlock(start_channels * (2 ** 3), 3, 1, 1, pad_type = pad, activation = activ, norm = norm, sn = False, reduction = 4)
        self.top3 = Conv2dLayer(start_channels * (2 ** 3), start_channels * (2 ** 3), 1, 1, 0, pad_type = pad, activation = activ, norm = norm)
        # Middle subnetwork, K = 2
        self.mid1 = Conv2dLayer(in_channels * (4 ** 2), start_channels * (2 ** 2), 3, 1, 1, pad_type = pad, activation = activ, norm = norm)
        self.mid2 = Conv2dLayer(int(start_channels * (2 ** 2 + 2 ** 3 / 4)), start_channels * (2 ** 2), 3, 1, 1, pad_type = pad, activation = activ, norm = norm)
        self.mid31 = ResidualDenseBlock_5C(start_channels * (2 ** 2), start_channels * (2 ** 1), 3, 1, 1, pad_type = pad, activation = activ, norm = norm)
        self.mid32 = GlobalBlock(start_channels * (2 ** 2), 3, 1, 1, pad_type = pad, activation = activ, norm = norm, sn = False, reduction = 4)
        self.mid4 = Conv2dLayer(start_channels * (2 ** 2), start_channels * (2 ** 2), 3, 1, 1, pad_type = pad, activation = activ, norm = norm)
        # Bottom subnetwork, K = 1
        self.bot1 = Conv2dLayer(in_channels * (4 ** 1), start_channels * (2 ** 1), 3, 1, 1, pad_type = pad, activation = activ, norm = norm)
        self.bot2 = Conv2dLayer(int(start_channels * (2 ** 1 + 2 ** 2 / 4)), start_channels * (2 ** 1), 3, 1, 1, pad_type = pad, activation = activ, norm = norm)
        self.bot31 = ResidualDenseBlock_5C(start_channels * (2 ** 1), start_channels * (2 ** 0), 3, 1, 1, pad_type = pad, activation = activ, norm = norm)
        self.bot32 = ResidualDenseBlock_5C(start_channels * (2 ** 1), start_channels * (2 ** 0), 3, 1, 1, pad_type = pad, activation = activ, norm = norm)
        self.bot33 = GlobalBlock(start_channels * (2 ** 1), 3, 1, 1, pad_type = pad, activation = activ, norm = norm, sn = False, reduction = 4)
        self.bot4 = Conv2dLayer(start_channels * (2 ** 1), start_channels * (2 ** 1), 3, 1, 1, pad_type = pad, activation = activ, norm = norm)
        # Mainstream
        self.main1 = Conv2dLayer(in_channels, start_channels, 3, 1, 1, pad_type = pad, activation = activ, norm = norm)
        self.main2 = Conv2dLayer(int(start_channels * (2 ** 0 + 2 ** 1 / 4)), start_channels, 3, 1, 1, pad_type = pad, activation = activ, norm = norm)
        self.main31 = ResidualDenseBlock_5C(start_channels, start_channels // 2, 3, 1, 1, pad_type = pad, activation = activ, norm = norm)
        self.main32 = ResidualDenseBlock_5C(start_channels, start_channels // 2, 3, 1, 1, pad_type = pad, activation = activ, norm = norm)
        self.main33 = ResidualDenseBlock_5C(start_channels, start_channels // 2, 3, 1, 1, pad_type = pad, activation = activ, norm = norm)
        self.main34 = ResidualDenseBlock_5C(start_channels, start_channels // 2, 3, 1, 1, pad_type = pad, activation = activ, norm = norm)
        self.main35 = GlobalBlock(start_channels, 3, 1, 1, pad_type = pad, activation = activ, norm = norm, sn = False, reduction = 4)
        self.main4 = Conv2dLayer(start_channels, out_channels, 3, 1, 1, pad_type = pad, activation = activ, norm = norm)

    def forward(self, x):
        # PixelUnShuffle                                        input: batch * 3 * 256 * 256
        x1 = pixel_unshuffle(x, 2)               # out: batch * 12 * 128 * 128
        x2 = pixel_unshuffle(x, 4)               # out: batch * 48 * 64 * 64
        x3 = pixel_unshuffle(x, 8)               # out: batch * 192 * 32 * 32
        # Top subnetwork                                        suppose the start_channels = 32
        x3 = self.top1(x3)                                      # out: batch * 256 * 32 * 32
        x3 = self.top21(x3)                                     # out: batch * 256 * 32 * 32
        x3 = self.top22(x3)                                     # out: batch * 256 * 32 * 32
        x3 = self.top3(x3)                                      # out: batch * 256 * 32 * 32
        x3 = F.pixel_shuffle(x3, 2)                             # out: batch * 64 * 64 * 64, ready to be concatenated
        # Middle subnetwork
        x2 = self.mid1(x2)                                      # out: batch * 128 * 64 * 64
        x2 = torch.cat((x2, x3), 1)                             # out: batch * (128 + 64) * 64 * 64
        x2 = self.mid2(x2)                                      # out: batch * 128 * 64 * 64
        x2 = self.mid31(x2)                                     # out: batch * 128 * 64 * 64
        x2 = self.mid32(x2)                                     # out: batch * 128 * 64 * 64
        x2 = self.mid4(x2)                                      # out: batch * 128 * 64 * 64
        x2 = F.pixel_shuffle(x2, 2)                             # out: batch * 32 * 128 * 128, ready to be concatenated
        # Bottom subnetwork
        x1 = self.bot1(x1)                                      # out: batch * 64 * 128 * 128
        x1 = torch.cat((x1, x2), 1)                             # out: batch * (64 + 32) * 128 * 128
        x1 = self.bot2(x1)                                      # out: batch * 64 * 128 * 128
        x1 = self.bot31(x1)                                     # out: batch * 64 * 128 * 128
        x1 = self.bot32(x1)                                     # out: batch * 64 * 128 * 128
        x1 = self.bot33(x1)                                     # out: batch * 64 * 128 * 128
        x1 = self.bot4(x1)                                      # out: batch * 64 * 128 * 128
        x1 = F.pixel_shuffle(x1, 2)                             # out: batch * 16 * 256 * 256, ready to be concatenated
        # U-Net generator with skip connections from encoder to decoder
        x = self.main1(x)                                       # out: batch * 32 * 256 * 256
        x = torch.cat((x, x1), 1)                               # out: batch * (32 + 16) * 256 * 256
        x = self.main2(x)                                       # out: batch * 32 * 256 * 256
        x = self.main31(x)                                      # out: batch * 32 * 256 * 256
        x = self.main32(x)                                      # out: batch * 32 * 256 * 256
        x = self.main33(x)                                      # out: batch * 32 * 256 * 256
        x = self.main34(x)                                      # out: batch * 32 * 256 * 256
        x = self.main35(x)                                      # out: batch * 32 * 256 * 256
        x = self.main4(x)                                       # out: batch * 3 * 256 * 256
        return x

In [16]:
# model = RGB2HS(3,31,64,'lrelu','none','reflect')
# model = model.to(DEVICE)

In [17]:
# summary(model,input_size=(3,256,256)) 

In [18]:
# summary(model,input_size=(3,32,32))

### Training


#### Model Configs

In [19]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 2
NUM_WORKERS = 0 
NUM_EPOCHS = 100

# CHECKPOINT_FOLDER = "Checkpoints/"
# CHECKPOINT_FILE = "EfNetB3_v1.pth.tar"

PIN_MEMORY = True
# SAVE_MODEL = True
# LOAD_MODEL = True

INPUT_IMAGE_WIDTH = [256]
INPUT_IMAGE_HEIGHT = [256]
INPUT_IMAGE_SIZE = INPUT_IMAGE_WIDTH + INPUT_IMAGE_HEIGHT
INPUT_IMAGE_CHANNELS = [3]
INPUT_IMAGE_SHAPE = INPUT_IMAGE_SIZE + INPUT_IMAGE_CHANNELS

TRAIN_DATA_DIR = "Data/train/"

LEARNING_RATE = 0.0001
WEIGHT_B1 = 0.5
WEIGHT_B2 = 0.999
WEIGHT_DECAY = 0.0

CHECKPOINT_FOLDER = "Checkpoints/"
CHECKPOINT_FILE = "SpecReconV1OnlineData_full.pth" #.pth

writer = SummaryWriter("runs/SpecRecon_v1_online_data_full dataset")


In [20]:
model = RGB2HS(3,31,64,'lrelu','none','reflect')
weights_init(model, init_type = 'xavier', init_gain = 0.02)
model = model.to(DEVICE)

TrainDataset = HSI_Dataset_Train(TRAIN_DATA_DIR,INPUT_IMAGE_SHAPE,data_transforms=None)
print(f"Number of images in training set {len(TrainDataset)}")

train_loader = DataLoader(TrainDataset, batch_size = BATCH_SIZE, shuffle = True, num_workers = NUM_WORKERS, pin_memory = PIN_MEMORY)

Number of images in training set 450


##### Loss FUnction, Optimiser

In [21]:
loss_func = torch.nn.L1Loss().cuda()
optimiser = torch.optim.Adam(model.parameters(), lr = LEARNING_RATE, betas = (WEIGHT_B1, WEIGHT_B2), weight_decay = WEIGHT_DECAY)
scaler = torch.cuda.amp.GradScaler()
lr_scheduler1 = lr_scheduler.StepLR(optimiser,step_size=250,gamma=0.1)

#### Training Loop

In [22]:
def train_epoch(model,dataloader,optimiser,scheduler,loss_func,scaler,epoch_num):
    total_loss = 0
    dataiter = tqdm(dataloader)
    for i,(rgb_img,hsi_img,img_name) in enumerate(dataiter):
        rgb_img = rgb_img.to(DEVICE)
        hsi_img = hsi_img.to(DEVICE)

        #forward pass
        with torch.cuda.amp.autocast():
            recons_img = model(rgb_img)
            batch_loss = loss_func(recons_img,hsi_img)
            total_loss += batch_loss.item()
            
        #backward step
        optimiser.zero_grad()
        scaler.scale(batch_loss).backward()
        scaler.step(optimiser)
        scaler.update()
        lr_scheduler1.step()
        dataiter.set_postfix(loss=batch_loss.item())
        
    epoch_loss = total_loss/len(dataloader)
    writer.add_scalar('MAE_loss',epoch_loss,global_step=epoch_num)
    writer.close()
#     if epoch_num%100==0:
    print(f"Epoch {epoch_num}, Loss {epoch_loss}")
    return epoch_loss

    

In [23]:
# trainer = tqdm(range(NUM_EPOCHS))
best_loss = 10000
best_epoch = 0
for epoch in range(NUM_EPOCHS):
    epochloss = train_epoch(model,train_loader,optimiser,lr_scheduler1,loss_func,scaler,epoch)
    if epochloss<best_loss :
        best_loss = epochloss
        best_epoch = epoch
        checkpoint_stuff = {"state_dict":model.state_dict() ,"optimiser":optimiser.state_dict()}
        best_epoch = epoch
#         print(f"Checkpoint Saved at {best_epoch}")
        save_checkpoint(checkpoint_stuff)
    #trainer.set_postfix(loss=epochloss)
print(f"Best Epoch : {best_epoch}, Best Epoch Loss : {best_loss}")


100%|██████████| 225/225 [01:36<00:00,  2.33it/s, loss=0.0815]


Epoch 0, Loss 0.06556085799717241


100%|██████████| 225/225 [01:36<00:00,  2.32it/s, loss=0.0425]


Epoch 1, Loss 0.041837270822789936


100%|██████████| 225/225 [01:36<00:00,  2.33it/s, loss=0.0349] 


Epoch 2, Loss 0.04100427851287855


100%|██████████| 225/225 [01:38<00:00,  2.29it/s, loss=0.0247] 


Epoch 3, Loss 0.041213166308071876


100%|██████████| 225/225 [01:36<00:00,  2.33it/s, loss=0.0244] 


Epoch 4, Loss 0.04084135151778658


100%|██████████| 225/225 [01:36<00:00,  2.33it/s, loss=0.0456] 


Epoch 5, Loss 0.04076602334777514


100%|██████████| 225/225 [01:35<00:00,  2.36it/s, loss=0.0204] 


Epoch 6, Loss 0.04081608824431896


100%|██████████| 225/225 [01:35<00:00,  2.35it/s, loss=0.0671] 


Epoch 7, Loss 0.04092896956536505


100%|██████████| 225/225 [01:36<00:00,  2.33it/s, loss=0.0359]


Epoch 8, Loss 0.04034722871664498


100%|██████████| 225/225 [01:36<00:00,  2.33it/s, loss=0.033]  


Epoch 9, Loss 0.04055109910459982


100%|██████████| 225/225 [01:35<00:00,  2.37it/s, loss=0.0796] 


Epoch 10, Loss 0.04042940511471695


100%|██████████| 225/225 [01:34<00:00,  2.38it/s, loss=0.0473] 


Epoch 11, Loss 0.04074585429496235


100%|██████████| 225/225 [01:34<00:00,  2.38it/s, loss=0.0234] 


Epoch 12, Loss 0.04108961870893836


100%|██████████| 225/225 [01:35<00:00,  2.36it/s, loss=0.0786] 


Epoch 13, Loss 0.040624146991305884


100%|██████████| 225/225 [01:34<00:00,  2.38it/s, loss=0.0383] 


Epoch 14, Loss 0.041570554218358466


100%|██████████| 225/225 [01:35<00:00,  2.36it/s, loss=0.0249] 


Epoch 15, Loss 0.04059853114601639


100%|██████████| 225/225 [01:35<00:00,  2.36it/s, loss=0.0176] 


Epoch 16, Loss 0.04060666109952662


100%|██████████| 225/225 [01:35<00:00,  2.36it/s, loss=0.0345]


Epoch 17, Loss 0.04095387700945139


100%|██████████| 225/225 [01:38<00:00,  2.27it/s, loss=0.0433] 


Epoch 18, Loss 0.04098437938839197


100%|██████████| 225/225 [01:37<00:00,  2.30it/s, loss=0.0404] 


Epoch 19, Loss 0.040452718796829386


100%|██████████| 225/225 [01:38<00:00,  2.28it/s, loss=0.0115] 


Epoch 20, Loss 0.040803615454998284


100%|██████████| 225/225 [01:37<00:00,  2.30it/s, loss=0.0186]


Epoch 21, Loss 0.04079579018470314


100%|██████████| 225/225 [01:37<00:00,  2.32it/s, loss=0.0348] 


Epoch 22, Loss 0.040770741740448604


100%|██████████| 225/225 [01:37<00:00,  2.31it/s, loss=0.0419] 


Epoch 23, Loss 0.040911201474567255


100%|██████████| 225/225 [01:38<00:00,  2.29it/s, loss=0.0533]


Epoch 24, Loss 0.04020418982124991


100%|██████████| 225/225 [01:36<00:00,  2.34it/s, loss=0.0294] 


Epoch 25, Loss 0.04170353556258811


100%|██████████| 225/225 [01:35<00:00,  2.37it/s, loss=0.032]  


Epoch 26, Loss 0.040800969373020864


100%|██████████| 225/225 [01:34<00:00,  2.37it/s, loss=0.0182] 


Epoch 27, Loss 0.040504590498490464


100%|██████████| 225/225 [01:35<00:00,  2.37it/s, loss=0.0361] 


Epoch 28, Loss 0.040436577366458046


100%|██████████| 225/225 [01:35<00:00,  2.36it/s, loss=0.0314] 


Epoch 29, Loss 0.04035615060064528


100%|██████████| 225/225 [01:37<00:00,  2.30it/s, loss=0.0376] 


Epoch 30, Loss 0.04030147278267476


100%|██████████| 225/225 [01:36<00:00,  2.33it/s, loss=0.0522] 


Epoch 31, Loss 0.04072173051329123


100%|██████████| 225/225 [01:36<00:00,  2.34it/s, loss=0.0335] 


Epoch 32, Loss 0.041158003146863645


100%|██████████| 225/225 [01:34<00:00,  2.37it/s, loss=0.0275] 


Epoch 33, Loss 0.0410289044967956


100%|██████████| 225/225 [01:34<00:00,  2.37it/s, loss=0.048]  


Epoch 34, Loss 0.04078147942200303


100%|██████████| 225/225 [01:34<00:00,  2.37it/s, loss=0.0559]


Epoch 35, Loss 0.04113164725816912


100%|██████████| 225/225 [01:35<00:00,  2.36it/s, loss=0.00873]


Epoch 36, Loss 0.041369544424944456


100%|██████████| 225/225 [01:34<00:00,  2.38it/s, loss=0.0398] 


Epoch 37, Loss 0.04128971901618772


100%|██████████| 225/225 [01:34<00:00,  2.37it/s, loss=0.0808] 


Epoch 38, Loss 0.04125530114190446


100%|██████████| 225/225 [01:35<00:00,  2.35it/s, loss=0.0555]


Epoch 39, Loss 0.04069312654021714


100%|██████████| 225/225 [01:37<00:00,  2.32it/s, loss=0.0872] 


Epoch 40, Loss 0.04061228990347849


100%|██████████| 225/225 [01:39<00:00,  2.27it/s, loss=0.041] 


Epoch 41, Loss 0.03992189841551913


100%|██████████| 225/225 [01:39<00:00,  2.27it/s, loss=0.0219] 


Epoch 42, Loss 0.04087017623707652


100%|██████████| 225/225 [01:38<00:00,  2.29it/s, loss=0.0879] 


Epoch 43, Loss 0.04098172065284517


100%|██████████| 225/225 [01:38<00:00,  2.29it/s, loss=0.0302] 


Epoch 44, Loss 0.04104617925567759


100%|██████████| 225/225 [01:36<00:00,  2.32it/s, loss=0.0246] 


Epoch 45, Loss 0.04063085508843263


100%|██████████| 225/225 [01:39<00:00,  2.26it/s, loss=0.0436] 


Epoch 46, Loss 0.040229817651626136


100%|██████████| 225/225 [01:36<00:00,  2.34it/s, loss=0.0775] 


Epoch 47, Loss 0.040522443352060185


100%|██████████| 225/225 [01:34<00:00,  2.37it/s, loss=0.0381] 


Epoch 48, Loss 0.04091128734250864


100%|██████████| 225/225 [01:35<00:00,  2.35it/s, loss=0.0145] 


Epoch 49, Loss 0.04157537032539646


100%|██████████| 225/225 [01:35<00:00,  2.36it/s, loss=0.0339] 


Epoch 50, Loss 0.04039242411653201


100%|██████████| 225/225 [01:37<00:00,  2.31it/s, loss=0.0311] 


Epoch 51, Loss 0.040871697035100726


100%|██████████| 225/225 [01:37<00:00,  2.31it/s, loss=0.00819]


Epoch 52, Loss 0.04160178371808595


100%|██████████| 225/225 [01:38<00:00,  2.29it/s, loss=0.0972] 


Epoch 53, Loss 0.041340159032907754


100%|██████████| 225/225 [01:36<00:00,  2.32it/s, loss=0.0252] 


Epoch 54, Loss 0.03990380442183879


100%|██████████| 225/225 [01:36<00:00,  2.33it/s, loss=0.0677] 


Epoch 55, Loss 0.04152699755297767


100%|██████████| 225/225 [01:36<00:00,  2.34it/s, loss=0.0129] 


Epoch 56, Loss 0.040911396640456385


100%|██████████| 225/225 [01:36<00:00,  2.34it/s, loss=0.00571]


Epoch 57, Loss 0.04043928465288547


100%|██████████| 225/225 [01:36<00:00,  2.34it/s, loss=0.038] 


Epoch 58, Loss 0.040724327671859


100%|██████████| 225/225 [01:37<00:00,  2.30it/s, loss=0.0572] 


Epoch 59, Loss 0.04090990070874492


100%|██████████| 225/225 [01:39<00:00,  2.27it/s, loss=0.0652] 


Epoch 60, Loss 0.041011621811323694


100%|██████████| 225/225 [01:42<00:00,  2.19it/s, loss=0.0446] 


Epoch 61, Loss 0.040935354895061914


100%|██████████| 225/225 [01:39<00:00,  2.26it/s, loss=0.0332] 


Epoch 62, Loss 0.04102139346715477


100%|██████████| 225/225 [01:35<00:00,  2.34it/s, loss=0.0367] 


Epoch 63, Loss 0.040229693858159915


100%|██████████| 225/225 [01:35<00:00,  2.36it/s, loss=0.0514] 


Epoch 64, Loss 0.04137041169322199


100%|██████████| 225/225 [01:35<00:00,  2.35it/s, loss=0.0309] 


Epoch 65, Loss 0.04040788381256991


100%|██████████| 225/225 [01:35<00:00,  2.36it/s, loss=0.0408] 


Epoch 66, Loss 0.04075706483796239


100%|██████████| 225/225 [01:37<00:00,  2.30it/s, loss=0.016]  


Epoch 67, Loss 0.04115014864131808


100%|██████████| 225/225 [01:36<00:00,  2.33it/s, loss=0.0592] 


Epoch 68, Loss 0.041055587397681344


100%|██████████| 225/225 [01:35<00:00,  2.35it/s, loss=0.0136]


Epoch 69, Loss 0.040857503335509034


100%|██████████| 225/225 [01:38<00:00,  2.29it/s, loss=0.028]  


Epoch 70, Loss 0.040883250395870874


100%|██████████| 225/225 [01:38<00:00,  2.27it/s, loss=0.0266] 


Epoch 71, Loss 0.04053987538028094


100%|██████████| 225/225 [01:39<00:00,  2.26it/s, loss=0.0233] 


Epoch 72, Loss 0.040859780663417444


100%|██████████| 225/225 [01:37<00:00,  2.31it/s, loss=0.0438] 


Epoch 73, Loss 0.040186455465025375


100%|██████████| 225/225 [01:38<00:00,  2.28it/s, loss=0.0559]


Epoch 74, Loss 0.04041419621970919


100%|██████████| 225/225 [01:37<00:00,  2.32it/s, loss=0.0502]


Epoch 75, Loss 0.040553990993648766


100%|██████████| 225/225 [01:36<00:00,  2.34it/s, loss=0.027]  


Epoch 76, Loss 0.04087033724205361


100%|██████████| 225/225 [01:36<00:00,  2.33it/s, loss=0.0574] 


Epoch 77, Loss 0.040095378371576465


100%|██████████| 225/225 [01:36<00:00,  2.34it/s, loss=0.0266] 


Epoch 78, Loss 0.03994101274137696


100%|██████████| 225/225 [01:36<00:00,  2.32it/s, loss=0.0297] 


Epoch 79, Loss 0.04143315695019232


100%|██████████| 225/225 [01:37<00:00,  2.31it/s, loss=0.0472] 


Epoch 80, Loss 0.0405866617233389


100%|██████████| 225/225 [01:38<00:00,  2.29it/s, loss=0.0434] 


Epoch 81, Loss 0.04076615801701943


100%|██████████| 225/225 [01:36<00:00,  2.34it/s, loss=0.0459]


Epoch 82, Loss 0.04096736317293512


100%|██████████| 225/225 [01:35<00:00,  2.35it/s, loss=0.0286] 


Epoch 83, Loss 0.04142142458508412


100%|██████████| 225/225 [01:36<00:00,  2.32it/s, loss=0.0844] 


Epoch 84, Loss 0.041288666762411594


100%|██████████| 225/225 [01:35<00:00,  2.36it/s, loss=0.0155] 


Epoch 85, Loss 0.04059868760406971


100%|██████████| 225/225 [01:35<00:00,  2.36it/s, loss=0.0343]


Epoch 86, Loss 0.041113055418762895


100%|██████████| 225/225 [01:35<00:00,  2.35it/s, loss=0.0178] 


Epoch 87, Loss 0.04037062184471223


100%|██████████| 225/225 [01:35<00:00,  2.37it/s, loss=0.0157]


Epoch 88, Loss 0.041013087814466824


100%|██████████| 225/225 [01:35<00:00,  2.35it/s, loss=0.057] 


Epoch 89, Loss 0.04048398819234636


100%|██████████| 225/225 [01:35<00:00,  2.36it/s, loss=0.0122] 


Epoch 90, Loss 0.040956406945155725


100%|██████████| 225/225 [01:35<00:00,  2.35it/s, loss=0.0354] 


Epoch 91, Loss 0.04040572730617391


100%|██████████| 225/225 [01:35<00:00,  2.36it/s, loss=0.0111] 


Epoch 92, Loss 0.04073727660915918


100%|██████████| 225/225 [01:35<00:00,  2.36it/s, loss=0.0557] 


Epoch 93, Loss 0.04132699180808332


100%|██████████| 225/225 [01:35<00:00,  2.37it/s, loss=0.0631]


Epoch 94, Loss 0.04061649943391482


100%|██████████| 225/225 [01:35<00:00,  2.36it/s, loss=0.0317] 


Epoch 95, Loss 0.041056230761524704


100%|██████████| 225/225 [01:34<00:00,  2.37it/s, loss=0.0506]


Epoch 96, Loss 0.04099229085155659


100%|██████████| 225/225 [01:34<00:00,  2.37it/s, loss=0.0495] 


Epoch 97, Loss 0.041034724277754626


100%|██████████| 225/225 [01:34<00:00,  2.37it/s, loss=0.0273]


Epoch 98, Loss 0.04173011793237594


100%|██████████| 225/225 [01:34<00:00,  2.37it/s, loss=0.0244] 

Epoch 99, Loss 0.04068570215668943
Best Epoch : 54, Best Epoch Loss : 0.03990380442183879





In [24]:
#Load the best model
model = RGB2HS(3,31,64,'lrelu','none','reflect')
weights_init(model, init_type = 'xavier', init_gain = 0.02)
optimiser = torch.optim.Adam(model.parameters(), lr = LEARNING_RATE, betas = (WEIGHT_B1, WEIGHT_B2), weight_decay = WEIGHT_DECAY)
load_checkpoint(CHECKPOINT_FOLDER+CHECKPOINT_FILE,model,optimiser,LEARNING_RATE)


In [None]:
TrainDataset = HSI_Dataset_Train("Dataset/validation-trial-norway-2/","_flash_img.jpg","_spectra_2.mat",[32,32,3],data_transforms=None)
print(f"Number of images in training set {len(TrainDataset)}")



