In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import os
import matplotlib.pyplot as plt
import model.krccsnet as krccsnet
from loss import *
from data_processor import *
from trainer import *
# Print tensors with 8 digits of precision.
torch.set_printoptions(precision=8)
# Restrict CUDA to GPU 0.
# NOTE(review): this only takes effect if CUDA has not been initialised yet;
# the project imports above run first — confirm none of them touch the GPU.
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
# data_loader() is provided by one of the star-imports above (presumably
# data_processor — confirm). bsds, set5 and set14 are read by valid() below;
# trn_loader is not used in the visible cells.
trn_loader, bsds, set5, set14 = data_loader()
def valid(model):
    """Evaluate ``model`` on the BSDS, Set5 and Set14 sets and print the scores.

    The model is switched to eval mode, then each benchmark is run in turn
    with ``loss_fn`` as the criterion and one PSNR/SSIM line is printed per
    benchmark. Nothing is returned.
    """
    criterion = loss_fn
    model.eval()
    # (evaluation routine, dataset, report template) in the order they are run.
    benchmarks = (
        (valid_bsds, bsds, "----------BSDS----------PSNR: %.2f----------SSIM: %.4f"),
        (valid_set, set5, "----------Set5----------PSNR: %.2f----------SSIM: %.4f"),
        (valid_set, set14, "----------Set14----------PSNR: %.2f----------SSIM: %.4f"),
    )
    for evaluate, dataset, report in benchmarks:
        psnr, ssim = evaluate(dataset, model, criterion)
        print(report % (psnr, ssim))

In [None]:
def kronecker_product(w1,w2,stride):
    """Merge two stacked convolution kernels into a single equivalent kernel.

    Computes

        w[o, i, u, v] = sum_{m, a, b} w2[o, m, a, b] * w1[m, i, u - a*stride, v - b*stride]

    i.e. every tap ``(a, b)`` of ``w2`` contributes a copy of ``w1`` (mixed
    over the shared middle channels) shifted by ``(a*stride, b*stride)``.
    For bias-free, unpadded convolutions this is the kernel of
    "conv with ``w1`` at stride ``stride``, then conv with ``w2``" expressed
    as one convolution applied at stride ``stride``.

    Args:
        w1: weight of the first convolution, shape ``(c_mid, c_in, kh1, kw1)``.
        w2: weight of the second convolution, shape ``(c_out, c_mid, kh2, kw2)``.
        stride: how many positions on the grid ``w1`` slides over correspond
            to one step of ``w2`` (the cumulative stride in front of ``w2``).

    Returns:
        Tensor of shape
        ``(c_out, c_in, kh1 + (kh2 - 1) * stride, kw1 + (kw2 - 1) * stride)``
        with the same dtype and device as ``w1``.

    Raises:
        ValueError: if the input channels of ``w2`` do not match the output
            channels of ``w1``.
    """
    c_mid, c_in, kh1, kw1 = w1.shape
    c_out, c_mid2, kh2, kw2 = w2.shape
    if c_mid2 != c_mid:
        raise ValueError(
            "w2 expects %d input channels but w1 produces %d" % (c_mid2, c_mid)
        )

    out_h = kh1 + (kh2 - 1) * stride
    out_w = kw1 + (kw2 - 1) * stride
    # new_zeros keeps dtype AND device of w1, so float64/half weights are not
    # silently converted to float32 in the accumulator.
    merged = w1.new_zeros((c_out, c_in, out_h, out_w))

    # Flatten everything but the middle-channel axis once; each tap of w2 is
    # then a (c_out, c_mid) matrix that mixes the middle channels of w1.
    w1_flat = w1.reshape(c_mid, -1)
    for a in range(kh2):
        top = a * stride
        for b in range(kw2):
            left = b * stride
            contrib = (w2[:, :, a, b] @ w1_flat).reshape(c_out, c_in, kh1, kw1)
            # Accumulate the mixed copy of w1 at its shifted position.
            merged[:, :, top:top + kh1, left:left + kw1] += contrib
    return merged

def reparm(csm:krccsnet.LKSN):
    """Collapse an LKSN encoder into a single bias-free ``nn.Conv2d``.

    Starting from ``csm.conv``, the weights of each ``csm.down[i].conv`` and
    finally ``csm.linear`` are merged into one kernel with
    ``kronecker_product``. The cumulative stride doubles after every down
    stage.

    Before merging, a per-channel 2x2 block of 0.25 is added at rows/cols 1-2
    of each down-stage kernel.
    NOTE(review): this presumably folds a parallel 2x2 average-pooling branch
    of the down block into its (4x4, stride-2) convolution — confirm against
    the LKSN definition.

    Args:
        csm: encoder exposing ``conv``, ``down`` (indexable, each item with a
            ``conv``), ``linear`` and ``depth``.

    Returns:
        ``nn.Conv2d`` with kernel size ``k``, padding ``(k - 1) // 2`` and the
        cumulative stride, whose weight is the merged kernel.
    """
    w1 = csm.conv.weight.data
    s = 1
    for i in range(csm.depth):
        conv_weight = csm.down[i].conv.weight.data
        pool_weight = torch.zeros_like(conv_weight)
        # Diagonal (channel j -> channel j) entries; derived from the weight
        # shape instead of a hard-coded 32 so other channel widths work too.
        n_diag = min(conv_weight.shape[0], conv_weight.shape[1])
        ch = torch.arange(n_diag, device=conv_weight.device)
        # Explicit positions (not a slice) so a kernel smaller than 3x3 still
        # fails loudly with an IndexError.
        for row, col in ((1, 1), (1, 2), (2, 1), (2, 2)):
            pool_weight[ch, ch, row, col] = 0.25

        w1 = kronecker_product(w1, conv_weight + pool_weight, s)
        s *= 2

    w = kronecker_product(w1, csm.linear.weight.data, s)

    cout, cin, k, _ = w.shape
    pad = (k - 1) // 2
    ret = nn.Conv2d(cin, cout, kernel_size=k, padding=pad, stride=s, bias=False)
    # Replace the freshly initialised weight with the merged kernel (which
    # lives on the same device as the encoder weights).
    ret.weight.data = w
    return ret

In [None]:
# Sensing rate of the model; it is passed to the model builder and embedded in
# the checkpoint file names used in the cells below.
sensing_rate=0.25

In [None]:
# Build the network for the configured sensing rate and restore the trained weights.
model_kr=krccsnet.build_LKSN_ARM(sensing_rate)
path='./saved_model/krccsnet_train_'+str(sensing_rate)+'.pth'
# Load onto the CPU first: without map_location, tensors are restored to the
# device they were saved from, which fails if that device is not visible here.
# The model is moved to the GPU right after the weights are copied in.
state_dict = torch.load(path, map_location='cpu')
model_kr.load_state_dict(state_dict)
model_kr.cuda()
# NOTE(review): p is not read by any of the visible cells — confirm it is still needed.
p=0

In [None]:
# Collapse the encoder of the loaded model into a single convolution layer.
KR_LKSN=reparm(model_kr.encoder)

In [None]:
# Swap the original encoder for the merged convolution returned by reparm().
# NOTE: this mutates model_kr in place; re-running the reparm() cell after this
# one would pass the merged Conv2d (not the original encoder) to reparm().
model_kr.encoder=KR_LKSN
dic=model_kr.state_dict()
# NOTE(review): unlike the training checkpoint name, there is no '_' between
# 'nofinetune' and the sensing rate (e.g. '...nofinetune0.25.pth') — confirm
# this is the name downstream code expects.
path='./saved_model/krccsnet_rep_nofinetune'+str(sensing_rate)+'.pth'
torch.save(dic,path)

In [None]:
# Print PSNR/SSIM of the re-parameterised model on BSDS, Set5 and Set14.
valid(model_kr)