# Model Prototyping

Building the basis for our model experimentation

In [1]:
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

import numpy as np
import pandas as pd
import torch
import os
import json

from torch.utils import data
from torch.nn import Conv2d, AvgPool2d, ReLU, Dropout, Flatten, Linear, Sequential, Module
from torch.optim import Adam
from time import time

from tqdm import tqdm

# Prefer the first CUDA device, but fall back to the CPU so the notebook still
# runs on machines without a usable GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Tensors and modules created after this point default to double precision.
torch.set_default_dtype(torch.float64)

# Root directory that each run's model path is joined onto.
MODELS_DIR  = '/home/cxw/sonos_rirs/models/'

In [2]:
# Metadata describing this experiment run.
model_dict = {
    'name': "testrun2_regularization",
    'notes': "same as test run but with regularization",
    'data_path': '/home/cxw/sonos_rirs/features/080122_5k_phase/feature_df.csv',
}
# The model path is derived from the run name, so it is filled in afterwards.
model_dict['model_path'] = os.path.join(MODELS_DIR, model_dict['name'])

In [3]:
try:
    # IPython is optional; autoreload is only configured when it is importable.
    from IPython import get_ipython
    # get_ipython() returns None when not running inside an IPython session.
    _ipython_shell = get_ipython()
    if _ipython_shell is not None:
        # Equivalent to `%load_ext autoreload` followed by `%autoreload 2`,
        # but written as plain Python so the cell also parses outside IPython.
        _ipython_shell.run_line_magic('load_ext', 'autoreload')
        _ipython_shell.run_line_magic('autoreload', '2')
except ImportError:
    # IPython is not installed, so there is nothing to configure.
    pass

In [4]:
# %autoreload 2
# # import volume_estimation.modeling as model_funcs
# model_funcs.train_model(model_funcs.Baseline_Model, model_dict,\
#                         overwrite=True, epochs=1,log=False) #######################################################

In [5]:
# Read the feature index CSV referenced by this run's metadata.
feat_df = pd.read_csv(model_dict['data_path'])
# Same location as model_dict['model_path'], rebuilt from the run name.
model_path = os.path.join(MODELS_DIR, model_dict['name'])

dataset = list()

    
def create_dataloader(feature_df, batch_size=1, log=True):
    """Load every feature file listed in ``feature_df`` and wrap it in a DataLoader.

    Args:
        feature_df: DataFrame with a ``file_feature`` column holding one file
            path per row. Each file is opened with ``np.load`` and must provide
            the entries ``feat`` and ``vol`` (i.e. an ``.npz``-style archive).
        batch_size: Number of samples per batch produced by the DataLoader.
        log: When True, the ``vol`` target is replaced by ``log10(vol)``.

    Returns:
        A ``torch.utils.data.DataLoader`` over an in-memory list of
        ``(feature, vol)`` pairs, in the row order of ``feature_df``.
    """
    samples = []
    # iterrows() has no length, so pass total= to give tqdm a real progress bar.
    for _, row in tqdm(feature_df.iterrows(), total=len(feature_df)):
        # The context manager closes the archive's file handle immediately
        # instead of leaving one open handle per row for the GC to reclaim.
        with np.load(row['file_feature']) as loaded:
            feature = loaded['feat']
            vol = loaded['vol']

        # Keep only the first two axes (the reshape fails unless any trailing
        # axes have size 1) and drop the imaginary part.
        feature = feature.reshape((feature.shape[0], feature.shape[1]))
        feature = np.real(feature)

        if log:
            vol = np.log10(vol)
        samples.append((feature, vol))

    return torch.utils.data.DataLoader(samples, batch_size=batch_size)

# Build a single loader over every row of feat_df (default batch_size=1),
# keeping the volume targets on their raw, non-log scale.
dataloader = create_dataloader(feat_df, log=False)

32000it [00:19, 1620.30it/s]


In [6]:
# Round-trip the run metadata through a JSON file and echo what was read back.
savename = './testmodeldict.json'
with open(savename, 'w') as f:
    f.write(json.dumps(model_dict))

with open(savename) as f:
    load_dict = json.loads(f.read())

for key in load_dict:
    print(key, load_dict[key])

name testrun2_regularization
notes same as test run but with regularization
data_path /home/cxw/sonos_rirs/features/080122_5k_phase/feature_df.csv
model_path /home/cxw/sonos_rirs/models/testrun2_regularization


In [7]:
# Partition the feature index using its 'split' column.
train_df = feat_df.loc[feat_df['split'] == 'train']
val_df = feat_df.loc[feat_df['split'] == 'val']
test_df = feat_df.loc[feat_df['split'] == 'test']

# Only the training loader is batched; validation and test use the default batch_size=1.
print("Creating training dataloader")
train_dataloader = create_dataloader(train_df, batch_size=16)

print("Creating validation dataloader")
val_dataloader = create_dataloader(val_df)

print("Creating test dataloader")
test_dataloader = create_dataloader(test_df)

Creating training dataloader


19200it [00:11, 1612.44it/s]


Creating validation dataloader


6420it [00:03, 1664.66it/s]


Creating test dataloader


6380it [00:03, 1663.27it/s]


In [8]:
# Pull a single training batch to sanity-check the tensor shapes.
features, labels = next(iter(train_dataloader))
# features = features.squeeze(1)
# features = features.transpose(1,2)
print("Feature batch shape: {}".format(features.size()))
print("Labels batch shape: {}".format(labels.size()))


Feature batch shape: torch.Size([16, 30, 1997])
Labels batch shape: torch.Size([16])


In [9]:
# class Baseline_Model(Module):
#     def __init__(self, input_shape):
#         #accepts a tuple with the height/width of the feature
#         #matrix to set the FC layer dimensions
#         super(Baseline_Model, self).__init__()
        
#         #block1
#         Conv1 = Conv2d(1, 30, kernel_size=(1,10), stride=(1,1))
#         Avgpool1 = AvgPool2d((1,2), stride=(1,2))

#         #block2
#         Conv2 = Conv2d(30, 20, kernel_size=(1,10), stride=(1,1))
#         Avgpool2 = AvgPool2d((1,2), stride=(1,2))

#         #block3
#         Conv3 = Conv2d(20, 10, kernel_size=(1,10), stride=(1,1))
#         Avgpool3 = AvgPool2d((1,2), stride=(1,2))

#         #block4
#         Conv4 = Conv2d(10, 10, kernel_size=(1,10), stride=(1,1))
#         Avgpool4 = AvgPool2d((1,2), stride=(1,2))

#         #block5
#         Conv5 = Conv2d(10, 5, kernel_size=(3,9), stride=(1,1))
#         Avgpool5 = AvgPool2d((1,2), stride=(1,2))

#         #block6
#         Conv6 = Conv2d(5, 5, kernel_size=(3,9), stride=(1,1))
#         Avgpool6 = AvgPool2d((2,2), stride=(2,2))

#         #dropout
#         dropout_layer = Dropout(p=0.5)
#         height5 = input_shape[0] - 2
#         height6 = (height5 - 2) // 2

#         time1 = (input_shape[1] - 9) // 2
#         time2 = (time1 - 9) // 2
#         time3 = (time2 - 9) // 2
#         time4 = (time3 - 9) // 2
#         time5 = (time4 - 7) // 2
#         time6 = (time5 - 7) // 2

#         flat_dims = 5 * height6 * time6
#         fc_layer = Linear(flat_dims, 1)
        
#         self.net = Sequential(
#                     Conv1, ReLU(), Avgpool1,
#                     Conv2, ReLU(), Avgpool2,
#                     Conv3, ReLU(), Avgpool3,
#                     Conv4, ReLU(), Avgpool4,
#                     Conv5, ReLU(), Avgpool5,
#                     Conv6, ReLU(), Avgpool6,
#                     dropout_layer, Flatten(),
#                     fc_layer, Flatten()
#                 )
#     def forward(self, x):
#         for layer in self.net:
#             x = layer(x)
#         return x

In [10]:
# UniRepLKNet: A Universal Perception Large-Kernel ConvNet for Audio, Video, Point Cloud, Time-Series and Image Recognition
# Github source: https://github.com/AILab-CVC/UniRepLKNet
# Licensed under The Apache License 2.0 License [see LICENSE for details]
# Based on RepLKNet, ConvNeXt, timm, DINO and DeiT code bases
# https://github.com/DingXiaoH/RepLKNet-pytorch
# https://github.com/facebookresearch/ConvNeXt
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/facebookresearch/deit/
# https://github.com/facebookresearch/dino
# --------------------------------------------------------'
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import trunc_normal_, DropPath, to_2tuple
from timm.models.registry import register_model
from functools import partial
import torch.utils.checkpoint as checkpoint
try:
    from huggingface_hub import hf_hub_download
except ImportError:
    # huggingface_hub is optional; install it to download models conveniently from huggingface.
    hf_hub_download = None

# Flags recording which OpenMMLab integration was importable; both start as False.
has_mmdet = False
has_mmseg = False
#   =============== The two segments below make this file directly usable in MMSegmentation and MMDetection.
#   =============== Ignore them if you do not plan to use either framework.
#   =============== Delete one of the two segments if they conflict with each other.
try:
    from mmseg.models.builder import BACKBONES as seg_BACKBONES
    from mmseg.utils import get_root_logger
    from mmcv.runner import _load_checkpoint
    has_mmseg = True
except ImportError:
    # mmseg/mmcv could not be imported: the checkpoint-loading helpers are set to None.
    get_root_logger = None
    _load_checkpoint = None

# try:
#     from mmdet.models.builder import BACKBONES as det_BACKBONES
#     from mmdet.utils import get_root_logger
#     from mmcv.runner import _load_checkpoint
#     has_mmdet = True
# except ImportError:
#     get_root_logger = None
#     _load_checkpoint = None
#   ===========================================================================================


class GRNwithNHWC(nn.Module):
    """Global Response Normalization (GRN) for channels-last tensors.

    GRN was introduced in ConvNeXt V2 (https://arxiv.org/abs/2301.00808); the
    reference implementation lives at https://github.com/facebookresearch/ConvNeXt-V2.
    Inputs are expected in (N, H, W, C) layout.

    Args:
        dim: Number of channels C.
        use_bias: When True, a learnable per-channel offset ``beta`` is added.
    """
    def __init__(self, dim, use_bias=True):
        super().__init__()
        self.use_bias = use_bias
        # Both parameters start at zero, so the layer is initially the identity.
        self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
        if use_bias:
            self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))

    def forward(self, x):
        # Per-channel L2 norm over the two spatial axes.
        spatial_norm = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
        # Divide by the mean norm across channels (epsilon guards against zero).
        channel_mean = spatial_norm.mean(dim=-1, keepdim=True)
        response = spatial_norm / (channel_mean + 1e-6)
        scaled = (self.gamma * response + 1) * x
        if not self.use_bias:
            return scaled
        return scaled + self.beta


class NCHWtoNHWC(nn.Module):
    """Parameter-free layer that permutes a 4-D tensor from (N, C, H, W) to (N, H, W, C)."""

    def forward(self, x):
        channels_last_order = (0, 2, 3, 1)
        return x.permute(*channels_last_order)


class NHWCtoNCHW(nn.Module):
    """Parameter-free layer that permutes a 4-D tensor from (N, H, W, C) to (N, C, H, W)."""

    def forward(self, x):
        channels_first_order = (0, 3, 1, 2)
        return x.permute(*channels_first_order)

#================== This function decides which conv implementation (the native or iGEMM) to use
#   Note that iGEMM large-kernel conv impl will be used if
#       -   you attempt to do so (attempt_use_lk_impl=True), and
#       -   it has been installed (follow https://github.com/AILab-CVC/UniRepLKNet), and
#       -   the conv layer is depth-wise, stride = 1, non-dilated, kernel_size > 5, and padding == kernel_size // 2
def get_conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias,
               attempt_use_lk_impl=True):
    """Build a 2-D convolution, using the iGEMM depth-wise kernel when it applies.

    The optional ``DepthWiseConv2dImplicitGEMM`` layer is returned only when all of
    the following hold: ``attempt_use_lk_impl`` is True, the kernel is square and
    larger than 5, ``padding == kernel_size // 2``, the conv is depth-wise
    (``in_channels == out_channels == groups``), ``stride == 1``, ``dilation == 1``,
    and the ``depthwise_conv2d_implicit_gemm`` package imports successfully.
    Otherwise a plain ``nn.Conv2d`` is returned.

    Args:
        in_channels, out_channels, stride, dilation, groups, bias: Passed to ``nn.Conv2d``.
        kernel_size: Int or pair; normalized to a pair with ``to_2tuple``.
        padding: Int, pair, or None. None means ``kernel_size // 2`` per axis.
        attempt_use_lk_impl: Set to False to never try the iGEMM implementation.

    Returns:
        The constructed convolution module.
    """
    kernel_size = to_2tuple(kernel_size)
    if padding is None:
        padding = (kernel_size[0] // 2, kernel_size[1] // 2)
    else:
        padding = to_2tuple(padding)
    need_large_impl = kernel_size[0] == kernel_size[1] and kernel_size[0] > 5 and padding == (kernel_size[0] // 2, kernel_size[1] // 2)

    if attempt_use_lk_impl and need_large_impl:
        print('---------------- trying to import iGEMM implementation for large-kernel conv')
        try:
            from depthwise_conv2d_implicit_gemm import DepthWiseConv2dImplicitGEMM
            print('---------------- found iGEMM implementation ')
        except Exception:
            # Best-effort: any failure while importing the optional extension
            # falls back to the native conv. `Exception` (rather than a bare
            # except) lets KeyboardInterrupt/SystemExit propagate.
            DepthWiseConv2dImplicitGEMM = None
            print('---------------- found no iGEMM. use original conv. follow https://github.com/AILab-CVC/UniRepLKNet to install it.')
        # The iGEMM layer is only used for depth-wise, stride-1, non-dilated convs.
        is_plain_depthwise = in_channels == out_channels and out_channels == groups \
            and stride == 1 and dilation == 1
        if DepthWiseConv2dImplicitGEMM is not None and is_plain_depthwise:
            print(f'===== iGEMM Efficient Conv Impl, channels {in_channels}, kernel size {kernel_size} =====')
            return DepthWiseConv2dImplicitGEMM(in_channels, kernel_size, bias=bias)
    return nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
                     padding=padding, dilation=dilation, groups=groups, bias=bias)


def get_bn(dim, use_sync_bn=False):
    """Return a batch-norm layer over ``dim`` channels.

    ``nn.SyncBatchNorm`` is used when ``use_sync_bn`` is truthy, otherwise
    ``nn.BatchNorm2d``.
    """
    norm_cls = nn.SyncBatchNorm if use_sync_bn else nn.BatchNorm2d
    return norm_cls(dim)

class SEBlock(nn.Module):
    """
    Squeeze-and-Excitation Block proposed in SENet (https://arxiv.org/abs/1709.01507)
    We assume the inputs to this layer are (N, C, H, W)

    Args:
        input_channels: Number of channels C of the input (and of the output gate).
        internal_neurons: Width of the bottleneck between the two 1x1 convs.
    """
    def __init__(self, input_channels, internal_neurons):
        super().__init__()
        self.down = nn.Conv2d(in_channels=input_channels, out_channels=internal_neurons,
                              kernel_size=1, stride=1, bias=True)
        self.up = nn.Conv2d(in_channels=internal_neurons, out_channels=input_channels,
                            kernel_size=1, stride=1, bias=True)
        self.input_channels = input_channels
        self.nonlinear = nn.ReLU(inplace=True)

    def forward(self, inputs):
        # Squeeze: global average pool to one value per channel.
        x = F.adaptive_avg_pool2d(inputs, output_size=(1, 1))
        # Excite: bottleneck of two 1x1 convs producing a per-channel gate in (0, 1).
        x = self.down(x)
        x = self.nonlinear(x)
        x = self.up(x)
        # torch.sigmoid replaces the deprecated F.sigmoid.
        x = torch.sigmoid(x)
        return inputs * x.view(-1, self.input_channels, 1, 1)

def fuse_bn(conv, bn):
    """Fold a batch-norm layer's running statistics into the preceding conv.

    Args:
        conv: Convolution whose ``weight`` (and optional ``bias``) are fused.
        bn: Batch-norm layer providing ``running_mean``, ``running_var``,
            ``eps``, ``weight`` and ``bias``.

    Returns:
        ``(weight, bias)`` of a single conv equivalent to ``bn(conv(x))`` when
        ``bn`` normalizes with its running statistics.
    """
    std = (bn.running_var + bn.eps).sqrt()
    # A conv without bias contributes a zero offset.
    if conv.bias is None:
        conv_bias = 0
    else:
        conv_bias = conv.bias
    fused_weight = conv.weight * (bn.weight / std).reshape(-1, 1, 1, 1)
    fused_bias = bn.bias + (conv_bias - bn.running_mean) * bn.weight / std
    return fused_weight, fused_bias

def convert_dilated_to_nondilated(kernel, dilate_rate):
    """Expand a dilated conv kernel into the equivalent dense (non-dilated) kernel.

    Zeros are inserted between the kernel taps by running a transposed conv with
    a 1x1 all-ones kernel and ``stride=dilate_rate``, so a ``k x k`` kernel
    becomes ``(dilate_rate * (k - 1) + 1)`` wide.

    Args:
        kernel: 4-D conv weight. When its second dimension is 1 it is handled in
            one call (depth-wise case); otherwise each input-channel slice is
            expanded separately.
        dilate_rate: Dilation rate of the conv the kernel belongs to.

    Returns:
        The expanded kernel, with the same dtype and device as ``kernel``.
    """
    # Match the kernel's dtype as well as its device: a default-dtype identity
    # would make conv_transpose2d fail whenever the kernel's dtype differs from
    # torch.get_default_dtype().
    identity_kernel = torch.ones((1, 1, 1, 1), dtype=kernel.dtype, device=kernel.device)
    if kernel.size(1) == 1:
        #   This is a DW kernel
        return F.conv_transpose2d(kernel, identity_kernel, stride=dilate_rate)
    #   This is a dense or group-wise (but not DW) kernel
    slices = [F.conv_transpose2d(kernel[:, i:i + 1, :, :], identity_kernel, stride=dilate_rate)
              for i in range(kernel.size(1))]
    return torch.cat(slices, dim=1)

def merge_dilated_into_large_kernel(large_kernel, dilated_kernel, dilated_r):
    """Add a dilated kernel, expanded to dense form and zero-padded, onto a large kernel.

    Args:
        large_kernel: Weight of the large conv; its size is read from dimension 2.
        dilated_kernel: Weight of the dilated conv; its size is read from dimension 2.
        dilated_r: Dilation rate of the dilated conv.

    Returns:
        ``large_kernel`` plus the centered, zero-padded dense equivalent of
        ``dilated_kernel``.
    """
    dense_kernel = convert_dilated_to_nondilated(dilated_kernel, dilated_r)
    # Extent covered by a k-tap kernel with dilation r.
    dense_size = dilated_r * (dilated_kernel.size(2) - 1) + 1
    # Equal zero border on all four sides centers the dense kernel in the large one.
    border = large_kernel.size(2) // 2 - dense_size // 2
    padded = F.pad(dense_kernel, [border, border, border, border])
    return large_kernel + padded


class DilatedReparamBlock(nn.Module):
    """
    Dilated Reparam Block from UniRepLKNet (https://github.com/AILab-CVC/UniRepLKNet).

    A large depth-wise conv (``lk_origin``) runs in parallel with several small
    dilated depth-wise convs, each followed by batch norm, and the outputs are
    summed. ``merge_dilated_branches`` folds the parallel branches into
    ``lk_origin``. Inputs are expected to be (N, C, H, W).

    Args:
        channels: Number of input/output channels (the convs are depth-wise).
        kernel_size: Size of the large kernel; one of 5, 7, 9, 11, 13, 15, 17.
        deploy: When True, only the large conv (with bias) is built.
        use_sync_bn: Forwarded to ``get_bn``.
        attempt_use_lk_impl: Forwarded to ``get_conv2d``.
    """
    def __init__(self, channels, kernel_size, deploy, use_sync_bn=False, attempt_use_lk_impl=True):
        super().__init__()
        self.lk_origin = get_conv2d(channels, channels, kernel_size, stride=1,
                                    padding=kernel_size//2, dilation=1, groups=channels, bias=deploy,
                                    attempt_use_lk_impl=attempt_use_lk_impl)
        self.attempt_use_lk_impl = attempt_use_lk_impl

        #   Default settings. We did not tune them carefully. Different settings may work better.
        #   large kernel size -> (branch kernel sizes, branch dilation rates)
        branch_table = {
            17: ([5, 9, 3, 3, 3], [1, 2, 4, 5, 7]),
            15: ([5, 7, 3, 3, 3], [1, 2, 3, 5, 7]),
            13: ([5, 7, 3, 3, 3], [1, 2, 3, 4, 5]),
            11: ([5, 5, 3, 3, 3], [1, 2, 3, 4, 5]),
            9: ([5, 5, 3, 3], [1, 2, 3, 4]),
            7: ([5, 3, 3], [1, 2, 3]),
            5: ([3, 3], [1, 2]),
        }
        if kernel_size not in branch_table:
            raise ValueError('Dilated Reparam Block requires kernel_size >= 5')
        self.kernel_sizes, self.dilates = branch_table[kernel_size]

        if deploy:
            return

        self.origin_bn = get_bn(channels, use_sync_bn)
        for k, r in zip(self.kernel_sizes, self.dilates):
            # Padding keeps each dilated branch's output the same size as the input.
            branch_conv = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=k, stride=1,
                                    padding=(r * (k - 1) + 1) // 2, dilation=r, groups=channels,
                                    bias=False)
            setattr(self, f'dil_conv_k{k}_{r}', branch_conv)
            setattr(self, f'dil_bn_k{k}_{r}', get_bn(channels, use_sync_bn=use_sync_bn))

    def forward(self, x):
        # Without origin_bn the block is in deploy mode (built that way or merged).
        if not hasattr(self, 'origin_bn'):
            return self.lk_origin(x)
        out = self.origin_bn(self.lk_origin(x))
        for k, r in zip(self.kernel_sizes, self.dilates):
            branch_conv = getattr(self, f'dil_conv_k{k}_{r}')
            branch_bn = getattr(self, f'dil_bn_k{k}_{r}')
            out = out + branch_bn(branch_conv(x))
        return out

    def merge_dilated_branches(self):
        """Fold every conv+BN branch into a single biased large conv; no-op if already merged."""
        if not hasattr(self, 'origin_bn'):
            return
        origin_k, origin_b = fuse_bn(self.lk_origin, self.origin_bn)
        for k, r in zip(self.kernel_sizes, self.dilates):
            branch_conv = getattr(self, f'dil_conv_k{k}_{r}')
            branch_bn = getattr(self, f'dil_bn_k{k}_{r}')
            branch_k, branch_b = fuse_bn(branch_conv, branch_bn)
            origin_k = merge_dilated_into_large_kernel(origin_k, branch_k, r)
            origin_b += branch_b

        n_channels = origin_k.size(0)
        large_k = origin_k.size(2)
        merged_conv = get_conv2d(n_channels, n_channels, large_k, stride=1,
                                 padding=large_k//2, dilation=1, groups=n_channels, bias=True,
                                 attempt_use_lk_impl=self.attempt_use_lk_impl)
        merged_conv.weight.data = origin_k
        merged_conv.bias.data = origin_b
        self.lk_origin = merged_conv

        # Drop the training-time branches so forward() takes the deploy path.
        delattr(self, 'origin_bn')
        for k, r in zip(self.kernel_sizes, self.dilates):
            delattr(self, f'dil_conv_k{k}_{r}')
            delattr(self, f'dil_bn_k{k}_{r}')


class UniRepLKNetBlock(nn.Module):
    """Residual block: depth-wise conv -> norm -> SE -> point-wise FFN, added to the input.

    Args:
        dim: Number of channels of the block's input and output.
        kernel_size: 0 uses an identity in place of the depth-wise conv, 3 or 5
            uses a plain depth-wise conv, and values >= 7 use a DilatedReparamBlock.
        drop_path: DropPath rate applied to the residual branch; 0 disables it.
        layer_scale_init_value: Initial value of the per-channel residual scale
            ``gamma``; None or a non-positive value (or deploy mode) disables it.
        deploy: Build the inference structure (biased conv, no BN layers).
        attempt_use_lk_impl: Forwarded to the conv builders.
        with_cp: Run the block under torch.utils.checkpoint when the input requires grad.
        use_sync_bn: Forwarded to get_bn.
        ffn_factor: Width multiplier of the FFN hidden layer relative to ``dim``.
    """

    def __init__(self,
                 dim,
                 kernel_size,
                 drop_path=0.,
                 layer_scale_init_value=1e-6,
                 deploy=False,
                 attempt_use_lk_impl=True,
                 with_cp=False,
                 use_sync_bn=False,
                 ffn_factor=4):
        super().__init__()
        self.with_cp = with_cp
        if deploy:
            print('------------------------------- Note: deploy mode')
        if self.with_cp:
            print('****** note with_cp = True, reduce memory consumption but may slow down training ******')

        # Choose the spatial (depth-wise) operator from the kernel size.
        if kernel_size == 0:
            self.dwconv = nn.Identity()
        elif kernel_size >= 7:
            self.dwconv = DilatedReparamBlock(dim, kernel_size, deploy=deploy,
                                              use_sync_bn=use_sync_bn,
                                              attempt_use_lk_impl=attempt_use_lk_impl)

        else:
            assert kernel_size in [3, 5]
            self.dwconv = get_conv2d(dim, dim, kernel_size=kernel_size, stride=1, padding=kernel_size // 2,
                                     dilation=1, groups=dim, bias=deploy,
                                     attempt_use_lk_impl=attempt_use_lk_impl)

        # No BN after the spatial operator in deploy mode or when that operator is the identity.
        if deploy or kernel_size == 0:
            self.norm = nn.Identity()
        else:
            self.norm = get_bn(dim, use_sync_bn=use_sync_bn)

        self.se = SEBlock(dim, dim // 4)

        # Point-wise FFN implemented with Linear layers on channels-last tensors.
        ffn_dim = int(ffn_factor * dim)
        self.pwconv1 = nn.Sequential(
            NCHWtoNHWC(),
            nn.Linear(dim, ffn_dim))
        self.act = nn.Sequential(
            nn.GELU(),
            GRNwithNHWC(ffn_dim, use_bias=not deploy))
        if deploy:
            self.pwconv2 = nn.Sequential(
                nn.Linear(ffn_dim, dim),
                NHWCtoNCHW())
        else:
            # Training structure: bias-free Linear followed by BN (merged later by reparameterize()).
            self.pwconv2 = nn.Sequential(
                nn.Linear(ffn_dim, dim, bias=False),
                NHWCtoNCHW(),
                get_bn(dim, use_sync_bn=use_sync_bn))

        # Per-channel layer scale on the residual branch; None when disabled.
        self.gamma = nn.Parameter(layer_scale_init_value * torch.ones(dim),
                                  requires_grad=True) if (not deploy) and layer_scale_init_value is not None \
                                                         and layer_scale_init_value > 0 else None
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def compute_residual(self, x):
        """Return the residual branch output for ``x`` (before it is added to ``x``)."""
        y = self.se(self.norm(self.dwconv(x)))
        y = self.pwconv2(self.act(self.pwconv1(y)))
        if self.gamma is not None:
            y = self.gamma.view(1, -1, 1, 1) * y
        return self.drop_path(y)

    def forward(self, inputs):

        def _f(x):
            return x + self.compute_residual(x)

        # Checkpointing is only used when enabled and the input takes part in autograd.
        if self.with_cp and inputs.requires_grad:
            out = checkpoint.checkpoint(_f, inputs)
        else:
            out = _f(inputs)
        return out

    def reparameterize(self):
        """Convert the block in place to its inference structure.

        Merges the dilated branches (if any), folds ``self.norm`` into the
        depth-wise conv, and folds the layer scale, the GRN bias and the BN of
        ``pwconv2`` into a single biased Linear.
        """
        if hasattr(self.dwconv, 'merge_dilated_branches'):
            self.dwconv.merge_dilated_branches()
        # Only a real BN layer has running_var; nn.Identity is skipped.
        if hasattr(self.norm, 'running_var'):
            std = (self.norm.running_var + self.norm.eps).sqrt()
            if hasattr(self.dwconv, 'lk_origin'):
                # DilatedReparamBlock: lk_origin is a biased conv after merging; fold the BN into it.
                self.dwconv.lk_origin.weight.data *= (self.norm.weight / std).view(-1, 1, 1, 1)
                self.dwconv.lk_origin.bias.data = self.norm.bias + (
                            self.dwconv.lk_origin.bias - self.norm.running_mean) * self.norm.weight / std
            else:
                # Plain depth-wise conv: build a biased replacement carrying the folded BN.
                # The new bias formula has no conv-bias term, i.e. it treats the
                # existing conv as bias-free (as built by __init__ when deploy=False).
                conv = nn.Conv2d(self.dwconv.in_channels, self.dwconv.out_channels, self.dwconv.kernel_size,
                                 padding=self.dwconv.padding, groups=self.dwconv.groups, bias=True)
                conv.weight.data = self.dwconv.weight * (self.norm.weight / std).view(-1, 1, 1, 1)
                conv.bias.data = self.norm.bias - self.norm.running_mean * self.norm.weight / std
                self.dwconv = conv
            self.norm = nn.Identity()
        # The layer scale is absorbed into the final Linear below.
        if self.gamma is not None:
            final_scale = self.gamma.data
            self.gamma = None
        else:
            final_scale = 1
        # len(self.pwconv2) == 3 identifies the training structure (Linear, permute, BN).
        if self.act[1].use_bias and len(self.pwconv2) == 3:
            grn_bias = self.act[1].beta.data
            self.act[1].__delattr__('beta')
            self.act[1].use_bias = False
            linear = self.pwconv2[0]
            # The GRN bias passes through the Linear, so it becomes a constant output offset.
            grn_bias_projected_bias = (linear.weight.data @ grn_bias.view(-1, 1)).squeeze()
            bn = self.pwconv2[2]
            std = (bn.running_var + bn.eps).sqrt()
            new_linear = nn.Linear(linear.in_features, linear.out_features, bias=True)
            new_linear.weight.data = linear.weight * (bn.weight / std * final_scale).view(-1, 1)
            # When linear.bias is None this starts as the int 0 and += rebinds it to the tensor.
            linear_bias = 0 if linear.bias is None else linear.bias.data
            linear_bias += grn_bias_projected_bias
            new_linear.bias.data = (bn.bias + (linear_bias - bn.running_mean) * bn.weight / std) * final_scale
            # Keep the NHWC->NCHW permute, drop the now-merged BN.
            self.pwconv2 = nn.Sequential(new_linear, self.pwconv2[1])



# Default per-block kernel sizes for each model family: one inner tuple per
# stage, one entry per block in that stage.
default_UniRepLKNet_A_F_P_kernel_sizes = ((3, 3),
                                      (13, 13),
                                      (13, 13, 13, 13, 13, 13),
                                      (13, 13))
default_UniRepLKNet_N_kernel_sizes = ((3, 3),
                                      (13, 13),
                                      (13, 13, 13, 13, 13, 13, 13, 13),
                                      (13, 13))
default_UniRepLKNet_T_kernel_sizes = ((3, 3, 3),
                                      (13, 13, 13),
                                      (13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3),
                                      (13, 13, 13))
default_UniRepLKNet_S_B_L_XL_kernel_sizes = ((3, 3, 3),
                                             (13, 13, 13),
                                             (13, 3, 3, 13, 3, 3, 13, 3, 3, 13, 3, 3, 13, 3, 3, 13, 3, 3, 13, 3, 3, 13, 3, 3, 13, 3, 3),
                                             (13, 13, 13))
# Number of blocks per stage for each model family.
UniRepLKNet_A_F_P_depths = (2, 2, 6, 2)
UniRepLKNet_N_depths = (2, 2, 8, 2)
UniRepLKNet_T_depths = (3, 3, 18, 3)
UniRepLKNet_S_B_L_XL_depths = (3, 3, 27, 3)

# Lookup used by UniRepLKNet when kernel_sizes is None: depths tuple -> default kernel sizes.
default_depths_to_kernel_sizes = {
    UniRepLKNet_A_F_P_depths: default_UniRepLKNet_A_F_P_kernel_sizes,
    UniRepLKNet_N_depths: default_UniRepLKNet_N_kernel_sizes,
    UniRepLKNet_T_depths: default_UniRepLKNet_T_kernel_sizes,
    UniRepLKNet_S_B_L_XL_depths: default_UniRepLKNet_S_B_L_XL_kernel_sizes
}

class UniRepLKNet(nn.Module):
    r""" UniRepLKNet
        A PyTorch impl of UniRepLKNet

    Args:
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Output size of the classification head. Must be None when init_cfg is given. Default: 768
        depths (tuple(int)): Number of blocks at each stage. Default: (3, 3, 27, 3)
        dims (tuple(int)): Feature dimension at each stage. Default: (96, 192, 384, 768)
        drop_path_rate (float): Stochastic depth rate. Default: 0.
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
        head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
        kernel_sizes (tuple(tuple(int))): Kernel size for each block. None means using the default settings. Default: None.
        deploy (bool): deploy = True means using the inference structure. Default: False
        with_cp (bool): with_cp = True means using torch.utils.checkpoint to save GPU memory. Default: False
        init_cfg (dict): weights to load; its 'checkpoint' entry is the checkpoint path (may be None).
            Providing it switches the model to feature-extraction mode (OpenMMLab style). Default: None
        attempt_use_lk_impl (bool): try to load the efficient iGEMM large-kernel impl. Setting it to False disables the iGEMM impl. Default: True
        use_sync_bn (bool): use_sync_bn = True means using sync BN. Use it if your batch size is small. Default: False

    Raises:
        ValueError: if kernel_sizes is None and there is no default for ``depths``.
        ImportError: if init_cfg names a checkpoint but the mmseg/mmcv helpers are unavailable.
    """
    def __init__(self,
                 in_chans=3,
                 num_classes=768,
                 depths=(3, 3, 27, 3),
                 dims=(96, 192, 384, 768),
                 drop_path_rate=0.,
                 layer_scale_init_value=1e-6,
                 head_init_scale=1.,
                 kernel_sizes=None,
                 deploy=False,
                 with_cp=False,
                 init_cfg=None,
                 attempt_use_lk_impl=True,
                 use_sync_bn=False,
                 **kwargs
                 ):
        super().__init__()

        depths = tuple(depths)
        if kernel_sizes is None:
            if depths in default_depths_to_kernel_sizes:
                print('=========== use default kernel size ')
                kernel_sizes = default_depths_to_kernel_sizes[depths]
            else:
                raise ValueError('no default kernel size settings for the given depths, '
                                 'please specify kernel sizes for each block, e.g., '
                                 '((3, 3), (13, 13), (13, 13, 13, 13, 13, 13), (13, 13))')
        print(kernel_sizes)
        for i in range(4):
            assert len(kernel_sizes[i]) == depths[i], 'kernel sizes do not match the depths'

        self.with_cp = with_cp

        # One drop-path rate per block, increasing linearly from 0 to drop_path_rate.
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        print('=========== drop path rates: ', dp_rates)

        # Stem: two stride-2 convs, followed by one stride-2 downsample layer per later stage.
        self.downsample_layers = nn.ModuleList()
        self.downsample_layers.append(nn.Sequential(
            nn.Conv2d(in_chans, dims[0] // 2, kernel_size=3, stride=2, padding=1),
            LayerNorm(dims[0] // 2, eps=1e-6, data_format="channels_first"),
            nn.GELU(),
            nn.Conv2d(dims[0] // 2, dims[0], kernel_size=3, stride=2, padding=1),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first")))

        for i in range(3):
            self.downsample_layers.append(nn.Sequential(
                nn.Conv2d(dims[i], dims[i + 1], kernel_size=3, stride=2, padding=1),
                LayerNorm(dims[i + 1], eps=1e-6, data_format="channels_first")))

        self.stages = nn.ModuleList()

        # `cur` is the index of the stage's first block in the flat dp_rates list.
        cur = 0
        for i in range(4):
            main_stage = nn.Sequential(
                *[UniRepLKNetBlock(dim=dims[i], kernel_size=kernel_sizes[i][j], drop_path=dp_rates[cur + j],
                                   layer_scale_init_value=layer_scale_init_value, deploy=deploy,
                                   attempt_use_lk_impl=attempt_use_lk_impl,
                                   with_cp=with_cp, use_sync_bn=use_sync_bn) for j in
                  range(depths[i])])
            self.stages.append(main_stage)
            cur += depths[i]

        last_channels = dims[-1]

        # Without init_cfg the model has a head and returns logits; with it, the
        # model returns the per-stage feature maps.
        self.for_pretrain = init_cfg is None
        self.for_downstream = not self.for_pretrain     # there may be some other scenarios
        if self.for_downstream:
            assert num_classes is None

        if self.for_pretrain:
            self.init_cfg = None
            self.norm = nn.LayerNorm(last_channels, eps=1e-6)  # final norm layer
            self.head = nn.Linear(last_channels, num_classes)
            self.apply(self._init_weights)
            self.head.weight.data.mul_(head_init_scale)
            self.head.bias.data.mul_(head_init_scale)
            self.output_mode = 'logits'
        else:
            self.init_cfg = init_cfg        # OpenMMLab style init
            self.init_weights()
            self.output_mode = 'features'
            norm_layer = partial(LayerNorm, eps=1e-6, data_format="channels_first")
            for i_layer in range(4):
                layer = norm_layer(dims[i_layer])
                layer_name = f'norm{i_layer}'
                self.add_module(layer_name, layer)


    #   load pretrained backbone weights in the OpenMMLab style
    def init_weights(self):
        """Load the checkpoint named by ``self.init_cfg['checkpoint']`` (non-strict).

        Does nothing except print a note when the checkpoint path is None.
        """

        def load_state_dict(module, state_dict, strict=False, logger=None):
            """Copy matching entries of state_dict into module; report unexpected/missing keys."""
            unexpected_keys = []
            own_state = module.state_dict()
            for name, param in state_dict.items():
                if name not in own_state:
                    unexpected_keys.append(name)
                    continue
                if isinstance(param, torch.nn.Parameter):
                    # backwards compatibility for serialized parameters
                    param = param.data
                try:
                    own_state[name].copy_(param)
                except Exception as exc:
                    # Chain the original error so the root cause stays visible.
                    raise RuntimeError(
                        'While copying the parameter named {}, '
                        'whose dimensions in the model are {} and '
                        'whose dimensions in the checkpoint are {}.'.format(
                            name, own_state[name].size(), param.size())) from exc
            missing_keys = set(own_state.keys()) - set(state_dict.keys())

            err_msg = []
            if unexpected_keys:
                err_msg.append('unexpected key in source state_dict: {}\n'.format(', '.join(unexpected_keys)))
            if missing_keys:
                err_msg.append('missing keys in source state_dict: {}\n'.format(', '.join(missing_keys)))
            err_msg = '\n'.join(err_msg)
            if err_msg:
                if strict:
                    raise RuntimeError(err_msg)
                elif logger is not None:
                    # Logger.warn is a deprecated alias of Logger.warning.
                    logger.warning(err_msg)
                else:
                    print(err_msg)

        assert self.init_cfg is not None
        ckpt_path = self.init_cfg['checkpoint']
        if ckpt_path is None:
            print('================ Note: init_cfg is provided but I got no init ckpt path, so skip initialization')
            return

        # The helpers are None when the optional mmseg/mmcv import at the top of
        # the file failed; fail with a clear message instead of calling None.
        if get_root_logger is None or _load_checkpoint is None:
            raise ImportError('loading a checkpoint through init_cfg requires mmseg and mmcv '
                              '(get_root_logger / _load_checkpoint could not be imported)')

        logger = get_root_logger()
        ckpt = _load_checkpoint(ckpt_path, logger=logger, map_location='cpu')
        # Checkpoints may wrap the weights under 'state_dict' or 'model'.
        if 'state_dict' in ckpt:
            _state_dict = ckpt['state_dict']
        elif 'model' in ckpt:
            _state_dict = ckpt['model']
        else:
            _state_dict = ckpt

        load_state_dict(self, _state_dict, strict=False, logger=logger)


    def _init_weights(self, m):
        """Truncated-normal init (std 0.02) for conv/linear weights; zero their biases."""
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            if hasattr(m, 'bias') and m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return logits in 'logits' mode, or a list of four normalized stage outputs in 'features' mode."""
        if self.output_mode == 'logits':
            for stage_idx in range(4):
                x = self.downsample_layers[stage_idx](x)
                x = self.stages[stage_idx](x)
            # Global average pool over the two spatial axes, then norm + head.
            x = self.norm(x.mean([-2, -1]))
            x = self.head(x)
            return x
        elif self.output_mode == 'features':
            outs = []
            for stage_idx in range(4):
                x = self.downsample_layers[stage_idx](x)
                x = self.stages[stage_idx](x)
                outs.append(getattr(self, f'norm{stage_idx}')(x))
            return outs
        else:
            raise ValueError('Defined new output mode?')

    def reparameterize_unireplknet(self):
        """Call reparameterize() on every submodule that defines it."""
        for m in self.modules():
            if hasattr(m, 'reparameterize'):
                m.reparameterize()



class LayerNorm(nn.Module):
    r"""Layer normalisation with a learnable per-channel scale and shift.

    Two input layouts are accepted, selected by ``data_format``:

    * ``"channels_last"`` (default): the normalised axis is the last one and
      the work is delegated to ``F.layer_norm``.
    * ``"channels_first"``: the normalised axis is axis 1; mean and variance
      are computed explicitly over that axis and the scale/shift are
      broadcast over two trailing axes.

    Any other ``data_format`` raises ``NotImplementedError`` at construction.
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last", reshape_last_to_first=False):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ("channels_last", "channels_first"):
            raise NotImplementedError
        self.normalized_shape = (normalized_shape,)
        # Stored on the instance only; forward() in this class never reads it.
        self.reshape_last_to_first = reshape_last_to_first

    def forward(self, x):
        """Normalise ``x`` according to ``self.data_format`` and return the result."""
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        if self.data_format == "channels_first":
            mean = x.mean(1, keepdim=True)
            var = (x - mean).pow(2).mean(1, keepdim=True)
            normed = (x - mean) / torch.sqrt(var + self.eps)
            return self.weight[:, None, None] * normed + self.bias[:, None, None]


#   For easy use as backbone in MMDetection framework. Ignore these lines if you do not use MMDetection
if has_mmdet:
    @det_BACKBONES.register_module()
    class UniRepLKNetBackbone(UniRepLKNet):
        """UniRepLKNet registered with the MMDetection backbone registry.

        Forwards its arguments to ``UniRepLKNet`` while pinning ``in_chans=3``,
        ``num_classes=None`` and ``use_sync_bn=True``. ``init_cfg`` is mandatory.
        """

        def __init__(self,
                     depths=(3, 3, 27, 3),
                     dims=(96, 192, 384, 768),
                     drop_path_rate=0.,
                     layer_scale_init_value=1e-6,
                     kernel_sizes=None,
                     deploy=False,
                     with_cp=False,
                     init_cfg=None,
                     attempt_use_lk_impl=False):
            assert init_cfg is not None
            backbone_kwargs = dict(
                in_chans=3,
                num_classes=None,
                depths=depths,
                dims=dims,
                drop_path_rate=drop_path_rate,
                layer_scale_init_value=layer_scale_init_value,
                kernel_sizes=kernel_sizes,
                deploy=deploy,
                with_cp=with_cp,
                init_cfg=init_cfg,
                attempt_use_lk_impl=attempt_use_lk_impl,
                use_sync_bn=True,
            )
            super().__init__(**backbone_kwargs)

#   For easy use as backbone in MMSegmentation framework. Ignore these lines if you do not use MMSegmentation
if has_mmseg:
    @seg_BACKBONES.register_module()
    class UniRepLKNetBackbone(UniRepLKNet):
        """UniRepLKNet registered with the MMSegmentation backbone registry.

        Forwards its arguments to ``UniRepLKNet`` while pinning ``in_chans=3``,
        ``num_classes=None`` and ``use_sync_bn=True``. ``init_cfg`` is mandatory.
        """

        def __init__(self,
                     depths=(3, 3, 27, 3),
                     dims=(96, 192, 384, 768),
                     drop_path_rate=0.,
                     layer_scale_init_value=1e-6,
                     kernel_sizes=None,
                     deploy=False,
                     with_cp=False,
                     init_cfg=None,
                     attempt_use_lk_impl=False):
            assert init_cfg is not None
            backbone_kwargs = dict(
                in_chans=3,
                num_classes=None,
                depths=depths,
                dims=dims,
                drop_path_rate=drop_path_rate,
                layer_scale_init_value=layer_scale_init_value,
                kernel_sizes=kernel_sizes,
                deploy=deploy,
                with_cp=with_cp,
                init_cfg=init_cfg,
                attempt_use_lk_impl=attempt_use_lk_impl,
                use_sync_bn=True,
            )
            super().__init__(**backbone_kwargs)


# Direct-download URLs keyed by pretrained-weights key. ``load_with_key`` reads this
# mapping only when ``hf_hub_download`` is None; it is currently empty.
model_urls = {
    #TODO: it seems that google drive does not support direct downloading with url? so where to upload the checkpoints other than huggingface? any suggestions?
}

# Checkpoint file name for each pretrained-weights key; ``load_with_key`` passes the
# looked-up name to ``hf_hub_download`` as ``filename``.
huggingface_file_names = {
    "unireplknet_a_1k": "unireplknet_a_in1k_224_acc77.03.pth",
    "unireplknet_f_1k": "unireplknet_f_in1k_224_acc78.58.pth",
    "unireplknet_p_1k": "unireplknet_p_in1k_224_acc80.23.pth",
    "unireplknet_n_1k": "unireplknet_n_in1k_224_acc81.64.pth",
    "unireplknet_t_1k": "unireplknet_t_in1k_224_acc83.21.pth",
    "unireplknet_s_1k": "unireplknet_s_in1k_224_acc83.91.pth",
    "unireplknet_s_22k": "unireplknet_s_in22k_pretrain.pth",
    "unireplknet_s_22k_to_1k": "unireplknet_s_in22k_to_in1k_384_acc86.44.pth",
    "unireplknet_b_22k": "unireplknet_b_in22k_pretrain.pth",
    "unireplknet_b_22k_to_1k": "unireplknet_b_in22k_to_in1k_384_acc87.40.pth",
    "unireplknet_l_22k": "unireplknet_l_in22k_pretrain.pth",
    "unireplknet_l_22k_to_1k": "unireplknet_l_in22k_to_in1k_384_acc87.88.pth",
    "unireplknet_xl_22k": "unireplknet_xl_in22k_pretrain.pth",
    "unireplknet_xl_22k_to_1k": "unireplknet_xl_in22k_to_in1k_384_acc87.96.pth"
}

def load_with_key(model, key):
    """Fetch the pretrained checkpoint registered under ``key`` and load it into ``model``.

    When ``hf_hub_download`` is available the file named by
    ``huggingface_file_names[key]`` is downloaded from the
    ``DingXiaoH/UniRepLKNet`` repository; otherwise the URL stored in
    ``model_urls[key]`` is used. If the loaded object contains a ``'model'``
    entry, that entry is treated as the state dict.

    :param model: module whose ``load_state_dict`` receives the weights
    :param key: pretrained-weights key such as ``'unireplknet_s_22k'``
    :raises KeyError: if ``key`` has no registered file name / URL. The message
        names the key and explains what is missing instead of echoing only the key.
    """
    if hf_hub_download is not None:
        # if huggingface hub is found, download from our huggingface repo
        if key not in huggingface_file_names:
            raise KeyError(
                f"no pretrained checkpoint is registered for '{key}'; "
                f"available keys: {sorted(huggingface_file_names)}"
            )
        repo_id = 'DingXiaoH/UniRepLKNet'
        cache_file = hf_hub_download(repo_id=repo_id, filename=huggingface_file_names[key])
        checkpoint = torch.load(cache_file, map_location='cpu')
    else:
        # model_urls may be empty, so fail with an actionable message rather than a bare KeyError.
        if key not in model_urls:
            raise KeyError(
                f"huggingface_hub is unavailable and no direct-download URL is registered for '{key}'; "
                "install huggingface_hub to fetch the checkpoint from the Hugging Face Hub"
            )
        checkpoint = torch.hub.load_state_dict_from_url(url=model_urls[key], map_location="cpu", check_hash=True)
    if 'model' in checkpoint:
        checkpoint = checkpoint['model']
    model.load_state_dict(checkpoint)

def initialize_with_pretrained(model, model_name, in_1k_pretrained, in_22k_pretrained, in_22k_to_1k):
    """Load pretrained weights into ``model`` for the first flag that is set.

    The flags are checked in the order ``in_1k_pretrained``,
    ``in_22k_pretrained``, ``in_22k_to_1k``; the matching suffix (``_1k``,
    ``_22k``, ``_22k_to_1k``) is appended to ``model_name`` to form the key
    handed to ``load_with_key``. With no flag set nothing is loaded.
    """
    suffix_by_flag = (
        (in_1k_pretrained, '_1k'),
        (in_22k_pretrained, '_22k'),
        (in_22k_to_1k, '_22k_to_1k'),
    )
    for enabled, suffix in suffix_by_flag:
        if enabled:
            load_with_key(model, model_name + suffix)
            return

def _build_unireplknet(variant, depths, dims, in_1k_pretrained, in_22k_pretrained, in_22k_to_1k, kwargs):
    """Construct a ``UniRepLKNet`` and optionally load pretrained weights for ``variant``.

    ``kwargs`` is the dict of extra keyword arguments forwarded to ``UniRepLKNet``.
    """
    model = UniRepLKNet(depths=depths, dims=dims, **kwargs)
    initialize_with_pretrained(model, variant, in_1k_pretrained, in_22k_pretrained, in_22k_to_1k)
    return model


@register_model
def unireplknet_a(in_1k_pretrained=False, **kwargs):
    """UniRepLKNet with dims (40, 80, 160, 320); optional ``_1k`` weights."""
    return _build_unireplknet('unireplknet_a', UniRepLKNet_A_F_P_depths, (40, 80, 160, 320),
                              in_1k_pretrained, False, False, kwargs)


@register_model
def unireplknet_f(in_1k_pretrained=False, **kwargs):
    """UniRepLKNet with dims (48, 96, 192, 384); optional ``_1k`` weights."""
    return _build_unireplknet('unireplknet_f', UniRepLKNet_A_F_P_depths, (48, 96, 192, 384),
                              in_1k_pretrained, False, False, kwargs)


@register_model
def unireplknet_p(in_1k_pretrained=False, **kwargs):
    """UniRepLKNet with dims (64, 128, 256, 512); optional ``_1k`` weights."""
    return _build_unireplknet('unireplknet_p', UniRepLKNet_A_F_P_depths, (64, 128, 256, 512),
                              in_1k_pretrained, False, False, kwargs)


@register_model
def unireplknet_n(in_1k_pretrained=False, **kwargs):
    """UniRepLKNet with ``UniRepLKNet_N_depths`` and dims (80, 160, 320, 640); optional ``_1k`` weights."""
    return _build_unireplknet('unireplknet_n', UniRepLKNet_N_depths, (80, 160, 320, 640),
                              in_1k_pretrained, False, False, kwargs)


@register_model
def unireplknet_t(in_1k_pretrained=False, **kwargs):
    """UniRepLKNet with ``UniRepLKNet_T_depths`` and dims (80, 160, 320, 640); optional ``_1k`` weights."""
    return _build_unireplknet('unireplknet_t', UniRepLKNet_T_depths, (80, 160, 320, 640),
                              in_1k_pretrained, False, False, kwargs)


@register_model
def unireplknet_s(in_1k_pretrained=False, in_22k_pretrained=False, in_22k_to_1k=False, **kwargs):
    """UniRepLKNet with dims (96, 192, 384, 768); optional ``_1k`` / ``_22k`` / ``_22k_to_1k`` weights."""
    return _build_unireplknet('unireplknet_s', UniRepLKNet_S_B_L_XL_depths, (96, 192, 384, 768),
                              in_1k_pretrained, in_22k_pretrained, in_22k_to_1k, kwargs)


@register_model
def unireplknet_b(in_22k_pretrained=False, in_22k_to_1k=False, **kwargs):
    """UniRepLKNet with dims (128, 256, 512, 1024); optional ``_22k`` / ``_22k_to_1k`` weights."""
    return _build_unireplknet('unireplknet_b', UniRepLKNet_S_B_L_XL_depths, (128, 256, 512, 1024),
                              False, in_22k_pretrained, in_22k_to_1k, kwargs)


@register_model
def unireplknet_l(in_22k_pretrained=False, in_22k_to_1k=False, **kwargs):
    """UniRepLKNet with dims (192, 384, 768, 1536); optional ``_22k`` / ``_22k_to_1k`` weights."""
    return _build_unireplknet('unireplknet_l', UniRepLKNet_S_B_L_XL_depths, (192, 384, 768, 1536),
                              False, in_22k_pretrained, in_22k_to_1k, kwargs)


@register_model
def unireplknet_xl(in_22k_pretrained=False, in_22k_to_1k=False, **kwargs):
    """UniRepLKNet with dims (256, 512, 1024, 2048); optional ``_22k`` / ``_22k_to_1k`` weights."""
    return _build_unireplknet('unireplknet_xl', UniRepLKNet_S_B_L_XL_depths, (256, 512, 1024, 2048),
                              False, in_22k_pretrained, in_22k_to_1k, kwargs)



if __name__ == '__main__':
    #   Test case showing the equivalency of Structural Re-parameterization:
    #   run one block before and after reparameterize() on the same input and
    #   print the difference between the two outputs.
    x = torch.randn(2, 4, 19, 19)
    layer = UniRepLKNetBlock(4, kernel_size=13, attempt_use_lk_impl=False)
    # Overwrite every parameter: names containing 'beta' become ones, all others are drawn from a normal distribution.
    for n, p in layer.named_parameters():
        if 'beta' in n:
            torch.nn.init.ones_(p)
        else:
            torch.nn.init.normal_(p)
    # Randomise the running statistics held in buffers; running_var is additionally shifted by +2.
    for n, p in layer.named_buffers():
        if 'running_var' in n:
            print('random init var')
            torch.nn.init.uniform_(p)
            p.data += 2
        elif 'running_mean' in n:
            print('random init mean')
            torch.nn.init.uniform_(p)
    layer.gamma.data += 0.5
    # Switch to eval mode before comparing the two forward passes.
    layer.eval()
    origin_y = layer(x)
    layer.reparameterize()
    eq_y = layer(x)
    print(layer)
    # Element-wise difference, then the summed absolute difference relative to the original output.
    print(eq_y - origin_y)
    print((eq_y - origin_y).abs().sum() / origin_y.abs().sum())

import torch
import torch.nn as nn
from torch.cuda.amp import autocast
import os
import wget
os.environ['TORCH_HOME'] = '../../pretrained_models'
import timm
from timm.models.layers import to_2tuple,trunc_normal_
# from unireplknet import UniRepLKNet

# override the timm package to relax the input shape constraint.
class PatchEmbed(nn.Module):
    """Strided-convolution patch embedding.

    ``img_size`` and ``patch_size`` are expanded with ``to_2tuple``;
    ``num_patches`` is the product of the integer quotients of image size by
    patch size along each axis. ``forward`` applies the convolution, flattens
    everything from axis 2 onward and swaps axes 1 and 2.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()

        self.img_size = to_2tuple(img_size)
        self.patch_size = to_2tuple(patch_size)
        patches_axis0 = self.img_size[0] // self.patch_size[0]
        patches_axis1 = self.img_size[1] // self.patch_size[1]
        self.num_patches = patches_axis1 * patches_axis0

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=self.patch_size, stride=self.patch_size)

    def forward(self, x):
        projected = self.proj(x)
        return projected.flatten(2).transpose(1, 2)

class ASTModel(nn.Module):
    """
    Spectrogram model built from a ``UniRepLKNet`` backbone (``self.v``) followed by a
    ``LayerNorm`` + ``Linear`` head (``self.mlp_head``).

    :param label_dim: output size of the final linear layer
    :param fstride: accepted for interface compatibility; the constructor does not read it
    :param tstride: accepted for interface compatibility; the constructor does not read it
    :param input_fdim: accepted for interface compatibility; the constructor does not read it
    :param input_tdim: accepted for interface compatibility; the constructor does not read it
    :param imagenet_pretrain: only echoed in the verbose summary; no weights are loaded here
    :param audioset_pretrain: only echoed in the verbose summary; no weights are loaded here
    :param model_size: accepted for interface compatibility; the constructor does not read it
    :param verbose: when equal to True, print a two-line summary at construction
    """
    def __init__(self, label_dim=527, fstride=10, tstride=10, input_fdim=224, input_tdim=1024, imagenet_pretrain=True, audioset_pretrain=False, model_size='base384', verbose=True):

        super().__init__()

        if verbose == True:
            banner = '---------------AST Model Summary---------------'
            print(banner)
            flags = 'ImageNet pretraining: {:s}, AudioSet pretraining: {:s}'.format(
                str(imagenet_pretrain), str(audioset_pretrain))
            print(flags)

        # Replace timm's PatchEmbed with the relaxed version defined in this file.
        timm.models.vision_transformer.PatchEmbed = PatchEmbed

        backbone_cfg = dict(
            depths=[3, 3, 27, 3],
            dims=[96, 192, 384, 768],
            drop_path_rate=0.4,
            in_chans=1,
            disable_iGEMM=True,
        )
        self.v = UniRepLKNet(**backbone_cfg)

        self.original_embedding_dim = 768
        head_norm = nn.LayerNorm(self.original_embedding_dim)
        head_fc = nn.Linear(self.original_embedding_dim, label_dim)
        self.mlp_head = nn.Sequential(head_norm, head_fc)

    def get_shape(self, fstride, tstride, input_fdim=224, input_tdim=1024):
        """Return the (frequency, time) grid size a 16x16 conv with the given strides yields.

        A random ``(1, 1, input_fdim, input_tdim)`` tensor is pushed through a
        throw-away ``nn.Conv2d`` and the last two output dimensions are returned.
        """
        probe_in = torch.randn(1, 1, input_fdim, input_tdim)
        probe_conv = nn.Conv2d(1, self.original_embedding_dim, kernel_size=(16, 16), stride=(fstride, tstride))
        probe_out = probe_conv(probe_in)
        _, _, f_dim, t_dim = probe_out.shape
        return f_dim, t_dim

    @autocast()
    def forward(self, x):
        """
        :param x: the input spectrogram, expected shape: (batch_size, time_frame_num, frequency_bins), e.g., (12, 1024, 128)
        :return: prediction
        """
        # (batch, time, freq) -> (batch, 1, time, freq) -> (batch, 1, freq, time)
        spec = x.unsqueeze(1).transpose(2, 3)
        backbone_out = self.v(spec)
        return self.mlp_head(backbone_out)

# if __name__ == '__main__':
#     input_tdim = 100
#     ast_mdl = ASTModel(input_tdim=input_tdim)
#     # input a batch of 10 spectrogram, each with 100 time frames and 128 frequency bins
#     test_input = torch.rand([10, input_tdim, 128])
#     test_output = ast_mdl(test_input)
#     # output should be in shape [10, 527], i.e., 10 samples, each with prediction of 527 classes.
#     print(test_output.shape)

#     input_tdim = 256
#     ast_mdl = ASTModel(input_tdim=input_tdim,label_dim=50, audioset_pretrain=True)
#     # input a batch of 10 spectrograms, each with 256 time frames and 128 frequency bins
#     test_input = torch.rand([10, input_tdim, 128])
#     test_output = ast_mdl(test_input)
#     # output should be in shape [10, 50], i.e., 10 samples, each with prediction of 50 classes.
#     print(test_output.shape)

random init mean
random init var
random init mean
random init var
random init mean
random init var
random init mean
random init var
random init mean
random init var
random init mean
random init var
random init mean
random init var
random init mean
random init var
UniRepLKNetBlock(
  (dwconv): DilatedReparamBlock(
    (lk_origin): Conv2d(4, 4, kernel_size=(13, 13), stride=(1, 1), padding=(6, 6), groups=4)
  )
  (norm): Identity()
  (se): SEBlock(
    (down): Conv2d(4, 1, kernel_size=(1, 1), stride=(1, 1))
    (up): Conv2d(1, 4, kernel_size=(1, 1), stride=(1, 1))
    (nonlinear): ReLU(inplace=True)
  )
  (pwconv1): Sequential(
    (0): NCHWtoNHWC()
    (1): Linear(in_features=4, out_features=16, bias=True)
  )
  (act): Sequential(
    (0): GELU(approximate=none)
    (1): GRNwithNHWC()
  )
  (pwconv2): Sequential(
    (0): Linear(in_features=16, out_features=4, bias=True)
    (1): NHWCtoNCHW()
  )
  (drop_path): Identity()
)
tensor([[[[ 1.3323e-15, -4.4409e-15, -7.1054e-15,  ..., -1.7764e



In [11]:
# Read the two trailing dimensions of `features` (axes 1 and 2).
feature_dims = features.size()
input_height, input_width = feature_dims[1], feature_dims[2]

# Single-output model moved to the configured device.
model = ASTModel(
    label_dim=1,
    fstride=10,
    tstride=10,
    input_fdim=input_height,
    input_tdim=input_width,
    imagenet_pretrain=True,
    audioset_pretrain=False,
    model_size='base384',
    verbose=True,
).to(device)
# model = Baseline_Model((input_height, input_width)).to(device)

---------------AST Model Summary---------------
ImageNet pretraining: True, AudioSet pretraining: False
((3, 3, 3), (13, 13, 13), (13, 3, 3, 13, 3, 3, 13, 3, 3, 13, 3, 3, 13, 3, 3, 13, 3, 3, 13, 3, 3, 13, 3, 3, 13, 3, 3), (13, 13, 13))
---------------- trying to import iGEMM implementation for large-kernel conv
---------------- found no iGEMM. use original conv. follow https://github.com/AILab-CVC/UniRepLKNet to install it.
---------------- trying to import iGEMM implementation for large-kernel conv
---------------- found no iGEMM. use original conv. follow https://github.com/AILab-CVC/UniRepLKNet to install it.
---------------- trying to import iGEMM implementation for large-kernel conv
---------------- found no iGEMM. use original conv. follow https://github.com/AILab-CVC/UniRepLKNet to install it.
---------------- trying to import iGEMM implementation for large-kernel conv
---------------- found no iGEMM. use original conv. follow https://github.com/AILab-CVC/UniRepLKNet to install 

In [12]:
def MSE(output, target):
    """Mean of the squared element-wise difference between ``output`` and ``target``."""
    residual = output - target
    return torch.mean(residual ** 2)

def Bias(output, target):
    """Mean absolute difference between ``10**output`` and ``10**target``."""
    linear_output = 10 ** output
    linear_target = 10 ** target
    return torch.mean(torch.abs(linear_output - linear_target))

def CovStep(output, target, output_mean, target_mean):
    """Mean of the product of the two mean-centred inputs."""
    centred_output = output - output_mean
    centred_target = target - target_mean
    return torch.mean(centred_output * centred_target)

def MeanAbsLogStep(output, target, log=True):
    """Mean absolute natural log of the prediction/target ratio.

    With ``log=True`` both inputs are first mapped through ``10**x``; with
    ``log=False`` they are used as given.
    """
    if log:
        # convert out of log
        vol_pred, vol_target = 10 ** output, 10 ** target
    else:
        vol_pred, vol_target = output, target
    ratio = vol_pred / vol_target
    return torch.mean(torch.abs(torch.log(ratio)))

def compute_eval_metrics(dataloader, model, log=True):
    """Evaluate ``model`` on ``dataloader`` and return summary regression metrics.

    The loader is traversed once with autograd disabled. Predictions and
    targets are flattened to one value per sample (so a ``(B, 1)`` prediction
    and a ``(B,)`` target line up instead of broadcasting to ``(B, B)``),
    converted to float64 and concatenated; every metric is then computed over
    all samples, which makes the result independent of the batch size.

    :param dataloader: iterable of ``(x, y)`` batches
    :param model: callable mapping ``x`` to predictions
    :param log: forwarded to ``MeanAbsLogStep``; ``Bias`` always applies ``10**``
    :return: dict with float entries ``'mse'``, ``'bias'``, ``'pearson_cor'``
        and ``'mean_mult'``
    :raises ValueError: if the loader yields no batches, or a batch's
        predictions and targets do not contain the same number of values
    """
    pred_chunks = []
    target_chunks = []

    # no_grad here so callers that forget it do not build (and retain) autograd graphs.
    with torch.no_grad():
        for (x, y) in dataloader:
            (x, y) = (x.to(device), y.to(device))
            pred = model(x)

            pred_flat = pred.detach().reshape(-1).double()
            target_flat = y.detach().reshape(-1).double()
            if pred_flat.shape != target_flat.shape:
                raise ValueError(
                    'prediction and target sizes differ: {} vs {}'.format(
                        tuple(pred.shape), tuple(y.shape)))

            pred_chunks.append(pred_flat)
            target_chunks.append(target_flat)

        if not pred_chunks:
            raise ValueError('compute_eval_metrics received an empty dataloader')

        preds = torch.cat(pred_chunks)
        targets = torch.cat(target_chunks)

        pred_mean = preds.mean()
        target_mean = targets.mean()

        cov = CovStep(preds, targets, pred_mean, target_mean)
        var_pred = MSE(preds, pred_mean)
        var_target = MSE(targets, target_mean)

        out_dict = {}
        out_dict['mse'] = MSE(preds, targets).item()
        out_dict['bias'] = Bias(preds, targets).item()
        out_dict['pearson_cor'] = (cov / (torch.sqrt(var_pred) * torch.sqrt(var_target))).item()
        out_dict['mean_mult'] = torch.exp(MeanAbsLogStep(preds, targets, log=log)).item()

    return out_dict
    
# with torch.no_grad():
#     eval_dict = compute_eval_metrics(val_dataloader, model)
#     print(eval_dict)

In [None]:
# Optimiser for every model parameter.
opt = Adam(model.parameters(), lr=0.000005, weight_decay=1e-2)

# Per-epoch history; every entry is a plain Python float.
hist = {
    "train_loss": [],
    "val_loss": [],
    "val_bias": [],
    "val_pearson_cor": [],
    "val_mean_mult": []
}

NUM_EPOCHS = 150

for ep in range(NUM_EPOCHS):
    t_start = time()
    model.train()

    train_loss = 0.0
    train_steps = 0

    for (x, y) in train_dataloader:
        (x, y) = (x.to(device), y.to(device))
        pred = model(x)
        # Targets are reshaped to (batch, 1) so they line up with the prediction.
        loss = MSE(pred, y.reshape((y.shape[0], 1)))

        opt.zero_grad()
        loss.backward()
        opt.step()

        # .item() yields a detached float; summing the tensor itself would keep
        # every step's autograd graph alive for the whole epoch.
        train_loss += loss.item()
        train_steps += 1

    model.eval()
    with torch.no_grad():
        val_metrics = compute_eval_metrics(val_dataloader, model)

    epoch_train_loss = train_loss / train_steps

    hist['train_loss'].append(epoch_train_loss)
    hist['val_loss'].append(val_metrics['mse'])
    hist['val_bias'].append(val_metrics['bias'])
    hist['val_pearson_cor'].append(val_metrics['pearson_cor'])
    hist['val_mean_mult'].append(val_metrics['mean_mult'])

    t_end = time()

    t_elapsed = t_end - t_start
    print("Epoch: {}\tDuration: {:.2f}s\tTrain loss: {:.4f}\tVal loss: {:.4f}\tVal bias:{:.4f}\tVal Pearson correlation: {:.4e}\tVal MeanMult: {:.4f}"\
          .format(ep, t_elapsed, epoch_train_loss, val_metrics['mse'],\
                  val_metrics['bias'], val_metrics['pearson_cor'],val_metrics['mean_mult']))
    
    
    

Epoch: 0	Duration: 2103.09s	Train loss: 0.6371	Val loss: 0.6571	Val bias:2638.3502	Val Pearson correlation: 4.6044e-01	Val MeanMult: 4.3455
Epoch: 1	Duration: 2125.05s	Train loss: 0.5750	Val loss: 0.6197	Val bias:2618.5975	Val Pearson correlation: 4.9670e-01	Val MeanMult: 4.1576
Epoch: 2	Duration: 2148.23s	Train loss: 0.5545	Val loss: 0.6008	Val bias:2603.1476	Val Pearson correlation: 5.1809e-01	Val MeanMult: 4.0621


In [None]:
# Ask PyTorch's CUDA caching allocator to release the cached memory it is not currently using.
torch.cuda.empty_cache()

In [None]:
import csv

# Rows of [prediction, target] collected for export.
data_to_save = []

test1_df = test_df.sample(5)

# Only referenced by the disabled MAE sketch further down in this cell.
mae = 0.0
total_samples = 0

# Inference only: evaluation mode and no autograd graphs.
model.eval()
with torch.no_grad():
    test_dataloader1 = create_dataloader(test1_df)
    test_random = compute_eval_metrics(test_dataloader1, model)
    print(test_random)

    # NOTE(review): this loop iterates the full `test_dataloader`, not the
    # 5-sample `test_dataloader1` built above -- confirm that is intended.
    for (x, y) in test_dataloader:
        (x, y) = (x.to(device), y.to(device))
        pred = model(x)
        for i in range(len(pred)):
            data_to_save.append([pred[i].item(), y[i].item()])

# Name of the CSV file to write.
csv_filename = 'predictions.csv'

# Open the CSV file and write the collected rows.
with open(csv_filename, 'w', newline='') as csvfile:
    csv_writer = csv.writer(csvfile)

    # Header row.
    csv_writer.writerow(['Prediction', 'Actual'])

    # Data rows.
    csv_writer.writerows(data_to_save)

print(f'Data saved to {csv_filename}')
#     print(pred,'///',y)

#     # compute the absolute error
#     absolute_error = torch.abs(pred-y)

#     # accumulate the absolute error and the sample count
#     mae += absolute_error.sum().item()
    

# #     compute the mean absolute error


# mae /= 5 
# print("MAE:", mae)


In [None]:
# Compute the metrics on the test loader with autograd disabled and print them.
with torch.no_grad():
    eval_test = compute_eval_metrics(test_dataloader, model)
    print(eval_test)

In [None]:
# Scratch inspection cell: in a notebook only the value of the last expression is displayed.
hist.keys()
# Spread of the validation loss from epoch index 15 onward.
np.std(hist['val_loss'][15:])
np.arange(100)[-10:]

In [None]:
# Draw 20 rows from the feature table for a quick spot check.
random_df = feat_df.sample(20)

with torch.no_grad():
    random_dataloader = create_dataloader(random_df)

    # Aggregate metrics on the sample.
    eval_random = compute_eval_metrics(random_dataloader, model)
    print(eval_random)

    # Individual predictions, one per batch.
    for (x, y) in random_dataloader:
        x, y = x.to(device), y.to(device)
        pred = model(x)
        print(pred.item())

In [None]:
class print_Model(Module):
    """Debug wrapper that prints each layer and the tensor size it produces.

    ``seq`` is iterated layer by layer in ``forward``; the input size is
    printed first, then every layer followed by the size of its output.
    """

    def __init__(self, seq):
        super().__init__()
        self.net = seq

    def forward(self, x):
        print("Start\n{}".format(x.size()))
        out = x
        for layer in self.net:
            out = layer(out)
            print(layer, out.size(), sep="\n")
        return out

In [None]:
%load_ext autoreload


In [None]:
random_df = feat_df.sample(20)

random_dataloader = create_dataloader(random_df, log=False)


# Rebuild the baseline architecture and restore its saved weights.
load_model = Baseline_Model((input_height, input_width)).to(device)
model_name = 'testrun2_regularization'
load_model.load_state_dict(torch.load(os.path.join(MODELS_DIR,model_name,'model_state.pt'), map_location=torch.device('cpu')))
# A freshly constructed module is in training mode; switch to evaluation mode
# so layers such as dropout behave deterministically during inference.
load_model.eval()

# Inference only: no autograd graphs are needed.
with torch.no_grad():
    load_metrics = compute_eval_metrics(test_dataloader, load_model, log=False)
    for key in load_metrics.keys():
        print(key, load_metrics[key])

    for (x, y) in random_dataloader:
        (x, y) = (x.to(device), y.to(device))
        pred = load_model(x)
        print(pred.item())


In [None]:
%load_ext autoreload


In [None]:
# Histogram of the 'vol' column of the feature table.
feat_df['vol'].hist()

### 