In [1]:
from tf_model import make_model
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import types
import torch.onnx


# Hyperparameters and model construction.
# C_max is the channel count used later in this notebook for the replacement
# Conv1d embedding (in_channels), the replacement generator projection
# (out_features) and the dummy input passed to the ONNX exporter.
C_max = 123
# The values below are forwarded unchanged to make_model(); presumably N is the
# number of stacked layers and d_model the feature width -- TODO confirm in tf_model.
src_vocab = 30
tgt_vocab = 30
N = 2
d_model = 128
# Enables cuDNN's algorithm auto-tuner; only relevant when running on a CUDA device.
torch.backends.cudnn.benchmark = True
model = make_model(src_vocab=src_vocab, tgt_vocab=tgt_vocab, N=N, d_model=d_model)

# Restore weights from checkpoint.pth.tar into the freshly built model; tensors
# are mapped onto the CPU. The file is expected to be a dict that stores the
# weights under the 'state_dict' key.
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a
# trusted source.
checkpoint = torch.load("checkpoint.pth.tar", map_location=torch.device("cpu"))
model.load_state_dict(checkpoint['state_dict'])

def expand_forward(self, x):
    """Embed ``x`` using only the first ``x.size(1)`` input channels of ``self.lut``.

    Args:
        x: tensor laid out as (B, C_i, T).

    Returns:
        Tensor laid out as (B, T, d_model), scaled by sqrt(self.d_model).
    """
    n_in = x.size(1)
    # Take just the first n_in input channels of the [d_model x C_max x 1] weight.
    kernel = self.lut.weight[:, :n_in, :]
    projected = F.conv1d(
        x,
        kernel,
        bias=self.lut.bias,
        stride=self.lut.stride,
        padding=self.lut.padding,
    )
    # (B, d_model, T) -> (B, T, d_model), then apply the sqrt(d_model) scale.
    scale = math.sqrt(self.d_model)
    return projected.permute(0, 2, 1) * scale


for embed in (model.src_embed, model.tgt_embed):
    exp = embed[0]  # the ExpandConv module at the head of the embedding stack
    exp.d_model = d_model  # stored on the module so the patched forward can read it
    exp._n_chan = C_max
    # Fresh 1x1 convolution mapping C_max input channels to d_model features.
    exp.lut = nn.Conv1d(C_max, d_model, kernel_size=1, bias=True)
    # Override forward on this instance only.
    exp.forward = types.MethodType(expand_forward, exp)

# Patch the generator head: record the channel count on the module and install
# a new linear projection from d_model features to C_max outputs.
model.generator._n_chan = C_max    # <-- and here as well
model.generator.proj = nn.Linear(in_features=d_model,
                                 out_features=C_max,
                                 bias=True)

def generator_forward(self, x):
    """Linearly project ``x`` onto the first ``self._n_chan`` outputs of ``self.proj``.

    Args:
        x: 3-D tensor (B, T, D) where D matches ``self.proj.weight.shape[1]``.

    Returns:
        Tensor of shape (B, T, self._n_chan).
    """
    n_out = self._n_chan
    # Keep only the first n_out rows of the projection (and matching bias entries).
    weight = self.proj.weight[:n_out, :]
    bias = self.proj.bias[:n_out]
    batch, steps, feat = x.shape
    # Collapse batch and time so the projection is one 2-D matmul: (B*T, D).
    rows = x.reshape(-1, feat)
    projected = rows @ weight.t() + bias  # (B*T, n_out)
    return projected.view(batch, steps, n_out)

# Bind generator_forward to this generator instance so it replaces the module's
# forward for this object only.
model.generator.forward = types.MethodType(generator_forward, model.generator)

def init_xavier(m):
    """Xavier-uniform initialise Conv1d/Linear weights and zero their biases.

    Meant to be passed to ``Module.apply``; modules of any other type, or whose
    weight has fewer than two dimensions, are left untouched.
    """
    if not isinstance(m, (nn.Conv1d, nn.Linear)):
        return
    if m.weight.dim() <= 1:
        return
    nn.init.xavier_uniform_(m.weight)
    # Layers built with bias=False have bias set to None.
    if m.bias is not None:
        nn.init.zeros_(m.bias)

# Re-initialise the three patched modules; apply() also visits their submodules.
for patched in (model.src_embed[0], model.tgt_embed[0], model.generator):
    patched.apply(init_xavier)

# Freeze every parameter of the network ...
for p in model.parameters():
    p.requires_grad_(False)

# ... then re-enable gradients only for the newly created layers.
for m in (model.src_embed[0].lut, model.tgt_embed[0].lut, model.generator.proj):
    for p in m.parameters():
        p.requires_grad_(True)


  nn.init.xavier_uniform(p)
  checkpoint = torch.load("checkpoint.pth.tar", map_location=torch.device("cpu"))


In [None]:
# Load the weights stored under 'model_state_dict' in my_checkpoint.pth into
# `model` (built and patched in the previous cell, which must have been run
# first), mapping tensors onto the CPU, then switch the model to eval mode.
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a
# trusted source.
checkpoint = torch.load('my_checkpoint.pth', map_location='cpu')
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

class InferenceModel(nn.Module):
    """Single-input wrapper that runs embed -> encoder -> generator.

    The source mask is built internally as all ones, so callers (and the ONNX
    exporter) only need to supply ``src``.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, src):
        """Run the wrapped model's encoder path on ``src``.

        Args:
            src: tensor with at least three axes; axis 0 and axis 2 give the
                mask's first and last dimensions.

        Returns:
            Whatever ``model.generator`` returns for the encoder output.
        """
        batch, length = src.size(0), src.size(2)
        # All-ones mask of shape (batch, 1, length), created on src's device.
        src_mask = torch.ones((batch, 1, length), device=src.device)
        embedded = self.model.src_embed(src)
        memory = self.model.encoder(embedded, src_mask)
        return self.model.generator(memory)

# Create inference model
inference_model = InferenceModel(model)
inference_model.eval()

# Example input handed to the exporter: batch of 1, C_max channels and 2048
# positions along the last axis.
dummy_input = torch.randn(1, C_max, 2048)

# Export the wrapper to ONNX with its weights embedded (export_params=True).
# Only axis 0 of the input and output is declared dynamic ('batch_size'); the
# other axes are not listed in dynamic_axes.
# NOTE(review): the output path is relative to the notebook's working
# directory -- presumably the target folder must already exist; verify.
torch.onnx.export(
    inference_model,
    dummy_input,
    "../eeg_auto_tools/cleaning_models/ART.onnx",
    export_params=True,
    opset_version=11,
    do_constant_folding=True,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
)


  checkpoint = torch.load('my_checkpoint.pth', map_location='cpu')
