In [1]:
from bp_torch import bp_torch
from bp_paddle import bp_paddle
import argparse
parser = argparse.ArgumentParser(description="FNet-CoLA")

parser.add_argument("--seed", type=int, default=1234, help='随机种子')
parser.add_argument("--torch-dir", type=str, default='google/fnet-large', help='模型位置')
parser.add_argument("--paddle-dir", type=str, default='../model/paddle/fnet-large', help='模型位置')
parser.add_argument("--batch-size", type=int, default=4, help='Batch Size')
parser.add_argument("--lr", type=float, default=1e-5, help='Learning Rate')
parser.add_argument("--warmup", type=int, default=0, help='Warmup Steps')
parser.add_argument("--num-epochs", type=int, default=3, help='Epoch 数')
args = parser.parse_args(args=[])
args.device = "cpu"


In [2]:
def compare(a, b):
    if isinstance(a, int):
        print(a == b)
        return
    if isinstance(a, float):
        print(a == b)
        return
    a = torch.from_numpy(a.detach().numpy()).float()
    b = torch.from_numpy(b.detach().numpy()).float()
    print("mean difference:", (a - b).abs().mean().numpy())
    print("max difference:", (a - b).abs().max().numpy())
    print("min difference:", (a - b).abs().min().numpy())

In [3]:
import numpy as np
import paddle
import random
from modeling import FNetModel, FNetForSequenceClassification
from reprod_log import ReprodLogger
paddle.set_device("cpu")
    
model_paddle = FNetForSequenceClassification.from_pretrained(
        "/root/autodl-tmp/PaddleFNet/model/paddle/fnet-base/")
classifier_weights_paddle = paddle.load(
        "classifier_weights/paddle_classifier_weights.bin")
model_paddle.load_dict(classifier_weights_paddle)
model_paddle.eval()
loss_fnc_paddle = paddle.nn.CrossEntropyLoss()

[32m[2021-12-01 13:47:05,734] [    INFO][0m - Weights of FNetForSequenceClassification not initialized from pretrained model: ['classifier.weight', 'classifier.bias'][0m


In [4]:
import numpy as np
import torch
import torch.fft as fft
from reprod_log import ReprodLogger
from transformers import AdamW
from transformers.models.fnet import FNetForSequenceClassification

model_torch = FNetForSequenceClassification.from_pretrained(
        "/root/autodl-tmp/PaddleFNet/model/pytorch/fnet-base/", num_labels=2)
classifier_weights_torch = torch.load(
        "classifier_weights/torch_classifier_weights.bin")
model_torch.load_state_dict(classifier_weights_torch, strict=False)
model_torch.eval()
model_torch.to(args.device)
loss_fnc_torch = torch.nn.CrossEntropyLoss()

Some weights of FNetForSequenceClassification were not initialized from the model checkpoint at /root/autodl-tmp/PaddleFNet/model/pytorch/fnet-base/ and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [5]:
# paddle
decay_params_paddle = [
        p.name for n, p in model_paddle.named_parameters()
        if not any(nd in n for nd in ["bias", "norm"])
    ]
optimizer_paddle = paddle.optimizer.AdamW(
        learning_rate=args.lr,
        parameters=model_paddle.parameters(),
        weight_decay=1e-2,
        epsilon=1e-6,
        apply_decay_param_fun=lambda x: x in decay_params_paddle, )
fake_data = np.load("fake_data/fake_data.npy")
fake_label = np.load("fake_data/fake_label.npy")
input_ids_paddle = paddle.to_tensor(fake_data)
labels_paddle = paddle.to_tensor(fake_label)

# Torch
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
optimizer_grouped_parameters_torch = [
        {'params': [p for n, p in model_torch.named_parameters() if not any(nd in n for nd in no_decay)],
         'weight_decay': 0.01},
        {'params': [p for n, p in model_torch.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
optimizer_torch = AdamW(
        optimizer_grouped_parameters_torch,
        lr=args.lr,
        eps=1e-6
)
input_ids_torch = torch.from_numpy(fake_data).to(args.device)
labels_torch = torch.from_numpy(fake_label).to(args.device)

In [12]:
output_paddle = model_paddle(input_ids_paddle)
output_torch = model_torch(input_ids_torch)[0]
compare(output_torch, output_paddle)

loss_paddle = loss_fnc_paddle(output_paddle, labels_paddle)
loss_torch = loss_fnc_torch(output_torch, labels_torch)
compare(loss_torch, loss_paddle)

mean difference: 9.3877316e-07
max difference: 2.451241e-06
min difference: 0.0
mean difference: 3.5762787e-07
max difference: 3.5762787e-07
min difference: 3.5762787e-07


In [11]:
loss_paddle.backward()
loss_torch.backward()

#optimizer_paddle.step()
#optimizer_torch.step()

In [None]:
for name, param in model_paddle.named_parameters():
    print(name)

In [9]:
for i, val1 in enumerate(model_torch.named_parameters()):
    for j, val2 in enumerate(model_paddle.named_parameters()):
        if i != j:
            continue
        if ".weight" in val1[0]:
            if ("embeddings." not in val1[0] and ".LayerNorm." not in val1[0]) or "embeddings.projection" in val1[0]:
                if val1[1].ndim == 2:
                    compare(val1[1].grad,val2[1].grad.t())
                else:
                    compare(val1[1].grad,val2[1].grad)
            else:
                compare(val1[1].grad,val2[1].grad)

mean difference: 4.6746648e-11
max difference: 8.586794e-07
min difference: 0.0
mean difference: 1.551815e-09
max difference: 1.0319054e-06
min difference: 0.0
mean difference: 7.804825e-08
max difference: 6.2584877e-06
min difference: 0.0
mean difference: 1.7867082e-08
max difference: 5.5442797e-07
min difference: 2.910383e-11
mean difference: 4.3043946e-09
max difference: 7.0547685e-08
min difference: 0.0
mean difference: 1.1985196e-08
max difference: 7.3661795e-08
min difference: 0.0
mean difference: 1.0260561e-09
max difference: 8.8475645e-07
min difference: 0.0
mean difference: 9.781608e-10
max difference: 1.3189856e-07
min difference: 0.0
mean difference: 8.5712015e-09
max difference: 1.6624108e-07
min difference: 0.0
mean difference: 8.457069e-09
max difference: 5.6810677e-08
min difference: 0.0
mean difference: 1.293989e-09
max difference: 5.4831617e-08
min difference: 0.0
mean difference: 1.4648006e-09
max difference: 6.056507e-08
min difference: 0.0
mean difference: 9.587563e