In [1]:
import torch
import torch.nn as nn
from einops.layers.torch import Rearrange
import torchaudio
from copy import deepcopy
import torchaudio
from model.encoder import ResNetExtractor

In [2]:
# Reference model: the encoder of a ResNetExtractor built without a checkpoint.
# NOTE(review): the original inline comment said "frozen" next to
# scenario="finetune" — presumably an alternative scenario value; confirm.
model2 = ResNetExtractor(
    checkpoint=None,
    scenario="finetune",
    transform=True,
).encoder

In [10]:
from resnet import ResNet

In [62]:
def _pair(value):
    """Return ``value`` as a 2-tuple: ``(v, v)`` for a scalar, ``tuple(v)`` otherwise."""
    if isinstance(value, (tuple, list)):
        return tuple(value)
    return (value, value)


def _match_placement(new_module, old_module):
    """Move ``new_module`` to ``old_module``'s device/float dtype and copy its train/eval mode.

    Layers without floating-point tensors (e.g. pooling) only get the mode copied.
    """
    for ref in (*old_module.parameters(), *old_module.buffers()):
        if ref.is_floating_point():
            new_module = new_module.to(device=ref.device, dtype=ref.dtype)
            break
    return new_module.train(old_module.training)


def _conv2d_to_1d(conv):
    """Build an ``nn.Conv1d`` from ``conv`` by averaging its kernel over the height axis.

    After ``weight.mean(dim=2)`` the kernel is (out, in // groups, kernel_width),
    so kernel size, padding and dilation come from the width axis. The stride is
    the product of both 2D strides (``stride ** 2`` for square strides), i.e. the
    same total downsampling factor the 2D layer applies.
    """
    kernel_w = conv.kernel_size[1]
    stride_h, stride_w = conv.stride
    # String padding ('same'/'valid') has no per-axis form; pass it through.
    padding = conv.padding if isinstance(conv.padding, str) else conv.padding[1]
    conv1d = nn.Conv1d(
        conv.in_channels,
        conv.out_channels,
        kernel_size=kernel_w,
        stride=stride_h * stride_w,
        padding=padding,
        dilation=conv.dilation[1],
        groups=conv.groups,
        bias=conv.bias is not None,
        padding_mode=conv.padding_mode,
    )
    conv1d = _match_placement(conv1d, conv)
    with torch.no_grad():
        conv1d.weight.copy_(conv.weight.mean(dim=2))
        if conv.bias is not None:
            conv1d.bias.copy_(conv.bias)
    return conv1d


def _batchnorm2d_to_1d(bn):
    """Build an ``nn.BatchNorm1d`` with ``bn``'s hyperparameters and any state it has."""
    bn1d = nn.BatchNorm1d(
        bn.num_features,
        eps=bn.eps,
        momentum=bn.momentum,
        affine=bn.affine,
        track_running_stats=bn.track_running_stats,
    )
    bn1d = _match_placement(bn1d, bn)
    with torch.no_grad():
        # weight/bias exist only when affine=True.
        if bn.affine:
            bn1d.weight.copy_(bn.weight)
            bn1d.bias.copy_(bn.bias)
        # Running statistics are None when track_running_stats=False.
        if bn.running_mean is not None:
            bn1d.running_mean.copy_(bn.running_mean)
            bn1d.running_var.copy_(bn.running_var)
            bn1d.num_batches_tracked.copy_(bn.num_batches_tracked)
    return bn1d


def _maxpool2d_to_1d(pool):
    """Build an ``nn.MaxPool1d`` whose window and stride are the 2D window area / stride product.

    MaxPool2d stores its settings exactly as passed (int or tuple); ``_pair``
    normalises both forms. For ints this is ``kernel ** 2`` / ``stride ** 2``.
    """
    kernel_h, kernel_w = _pair(pool.kernel_size)
    stride_h, stride_w = _pair(pool.stride)
    pool1d = nn.MaxPool1d(
        kernel_h * kernel_w,
        stride=stride_h * stride_w,
        padding=_pair(pool.padding)[1],
        dilation=_pair(pool.dilation)[1],
        return_indices=pool.return_indices,
        ceil_mode=pool.ceil_mode,
    )
    return _match_placement(pool1d, pool)


def convert_2d_to_1d(model):
    """Recursively replace 2D layers in ``model`` with 1D counterparts, in place.

    Conversions:
      * ``nn.Conv2d`` -> ``nn.Conv1d``: kernel averaged over its height axis and
        stride set to the product of the 2D strides.
      * ``nn.BatchNorm2d`` -> ``nn.BatchNorm1d``: affine parameters and running
        statistics are copied when present.
      * ``nn.MaxPool2d`` -> ``nn.MaxPool1d``: window and stride squared.
      * ``nn.AdaptiveAvgPool2d`` -> ``nn.AdaptiveAvgPool1d(1)``.
    Any other child is searched recursively. Each replacement lives on the
    source layer's device/dtype and inherits its train/eval mode.

    Args:
        model: the ``nn.Module`` to convert; its children are replaced via setattr.
    """
    # Snapshot the children because entries are replaced while looping.
    for name, module in list(model.named_children()):
        if isinstance(module, nn.Conv2d):
            replacement = _conv2d_to_1d(module)
        elif isinstance(module, nn.BatchNorm2d):
            replacement = _batchnorm2d_to_1d(module)
        elif isinstance(module, nn.MaxPool2d):
            replacement = _maxpool2d_to_1d(module)
        elif isinstance(module, nn.AdaptiveAvgPool2d):
            # NOTE(review): always global pooling; module.output_size is ignored.
            replacement = _match_placement(nn.AdaptiveAvgPool1d(1), module)
        else:
            # Not a convertible layer: look for convertible layers inside it.
            convert_2d_to_1d(module)
            continue
        setattr(model, name, replacement)

In [63]:
# Build a ResNet and replace its 2D layers with 1D equivalents (modified in place).
model = ResNet()
convert_2d_to_1d(model)

In [64]:
# Trace a random (2, 1, 48000) input through the four stages of the converted
# model, printing the shape after each one. `preprocess=0` is passed to
# compute_stage1 as-is.
x = torch.randn(2, 1, 48000)
x = model.compute_stage1(x, preprocess=0)
print(x.shape)
for compute_stage in (model.compute_stage2, model.compute_stage3, model.compute_stage4):
    x = compute_stage(x)
    print(x.shape)

torch.Size([2, 64, 2999])
torch.Size([2, 128, 750])
torch.Size([2, 256, 188])
torch.Size([2, 512, 47])


In [49]:
# Compare parameter/buffer shapes of `model` and `model2`, key by key.
# Keys that `model2` lacks are reported as "missing" instead of raising KeyError.
sd = model.state_dict()
sd2 = model2.state_dict()

for k, tensor in sd.items():
    other = sd2.get(k)
    other_shape = other.shape if other is not None else "missing"
    print(k, tensor.shape, other_shape)

model.spectrogram.window torch.Size([512]) torch.Size([512])
model.conv1.weight torch.Size([64, 1, 7]) torch.Size([64, 1, 7, 7])
model.bn1.weight torch.Size([64]) torch.Size([64])
model.bn1.bias torch.Size([64]) torch.Size([64])
model.bn1.running_mean torch.Size([64]) torch.Size([64])
model.bn1.running_var torch.Size([64]) torch.Size([64])
model.bn1.num_batches_tracked torch.Size([]) torch.Size([])
model.layer1.0.conv1.weight torch.Size([64, 64, 3]) torch.Size([64, 64, 3, 3])
model.layer1.0.bn1.weight torch.Size([64]) torch.Size([64])
model.layer1.0.bn1.bias torch.Size([64]) torch.Size([64])
model.layer1.0.bn1.running_mean torch.Size([64]) torch.Size([64])
model.layer1.0.bn1.running_var torch.Size([64]) torch.Size([64])
model.layer1.0.bn1.num_batches_tracked torch.Size([]) torch.Size([])
model.layer1.0.conv2.weight torch.Size([64, 64, 3]) torch.Size([64, 64, 3, 3])
model.layer1.0.bn2.weight torch.Size([64]) torch.Size([64])
model.layer1.0.bn2.bias torch.Size([64]) torch.Size([64])
mode

In [46]:
# A ResNet that has NOT been passed through convert_2d_to_1d; its full
# state_dict (tensor values included) is the cell's displayed output.
# NOTE(review): this rebinds `model2`, discarding the ResNetExtractor encoder
# assigned in an earlier cell — later cells see this ResNet instead.
# NOTE(review): consider showing only keys/shapes; full tensors bloat the output.
model2 = ResNet()
model2.state_dict()

OrderedDict([('model.spectrogram.window',
              tensor([0.0000e+00, 3.7640e-05, 1.5059e-04, 3.3882e-04, 6.0228e-04, 9.4095e-04,
                      1.3548e-03, 1.8437e-03, 2.4076e-03, 3.0465e-03, 3.7602e-03, 4.5487e-03,
                      5.4117e-03, 6.3493e-03, 7.3612e-03, 8.4473e-03, 9.6073e-03, 1.0841e-02,
                      1.2149e-02, 1.3530e-02, 1.4984e-02, 1.6512e-02, 1.8112e-02, 1.9785e-02,
                      2.1530e-02, 2.3347e-02, 2.5236e-02, 2.7196e-02, 2.9228e-02, 3.1330e-02,
                      3.3504e-02, 3.5747e-02, 3.8060e-02, 4.0443e-02, 4.2895e-02, 4.5416e-02,
                      4.8005e-02, 5.0663e-02, 5.3388e-02, 5.6180e-02, 5.9039e-02, 6.1965e-02,
                      6.4957e-02, 6.8014e-02, 7.1136e-02, 7.4322e-02, 7.7573e-02, 8.0888e-02,
                      8.4265e-02, 8.7705e-02, 9.1208e-02, 9.4771e-02, 9.8396e-02, 1.0208e-01,
                      1.0583e-01, 1.0963e-01, 1.1349e-01, 1.1742e-01, 1.2140e-01, 1.2543e-01,
                  

In [13]:
# List state_dict keys present in `model2` but absent from `model`, in model2's order.
# The key set is built once instead of re-running model.state_dict() per key.
model_keys = set(model.state_dict())
for k in model2.state_dict():
    if k not in model_keys:
        print(k)

spectrogram.window


In [None]:
from torchvision.models import resnet18

# Initialise model2 from torchvision's pretrained ResNet-18 weights.
# NOTE(review): this rebinds `model` (previously the converted ResNet) and `sd`.
model = resnet18(weights='DEFAULT')
sd = model.state_dict()
# Collapse the 3-channel stem filters to a single input channel by averaging over dim 1.
sd['conv1.weight'] = torch.mean(sd['conv1.weight'], dim=1, keepdim=True)

# ResNet() state_dict keys carry a "model." prefix (e.g. 'model.conv1.weight'),
# torchvision's do not; unaligned, strict=False would silently load nothing.
target_keys = model2.state_dict().keys()
if 'conv1.weight' not in target_keys and 'model.conv1.weight' in target_keys:
    sd = {f'model.{k}': v for k, v in sd.items()}
load_result = model2.load_state_dict(sd, strict=False)
# Report what did not match instead of discarding the result.
print(f"missing keys: {load_result.missing_keys}")
print(f"unexpected keys: {load_result.unexpected_keys}")