In [1]:
from torch.func import functional_call
import torch.autograd.forward_ad as fwAD
import torch
import torch.nn as nn
from torch.optim import Adam

from model import SimpleNet
from dataset import get_core_train_loader, split_dataset_core_train

model = SimpleNet(784, 10)
checkpoint = torch.load('./checkpoints/042924-train-core-dataset-mnist/last_checkpoint.pth')
model.load_state_dict(checkpoint['model'])

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
params = {name: p.detach().clone().to(device) for name, p in model.named_parameters()}
del model

_, train_dataset = split_dataset_core_train('mnist', 'simple', 0.2)
_, train_loader = get_core_train_loader(_, train_dataset, 128)

In [2]:
print(params)

{'lin_1.weight': tensor([[ 0.0208,  0.0336,  0.0317,  ...,  0.0054,  0.0312,  0.0195],
        [ 0.0208, -0.0243, -0.0056,  ..., -0.0243,  0.0037,  0.0152],
        [ 0.0162,  0.0131, -0.0214,  ...,  0.0194,  0.0005,  0.0351],
        ...,
        [-0.0551, -0.0420, -0.0266,  ...,  0.0027, -0.0332, -0.0093],
        [ 0.0195,  0.0358,  0.0273,  ...,  0.0100,  0.0219,  0.0078],
        [ 0.0364,  0.0124,  0.0196,  ...,  0.0343,  0.0170,  0.0218]],
       device='cuda:0'), 'lin_1.bias': tensor([ 0.0095,  0.0004,  0.0197,  0.0114, -0.0445,  0.0104, -0.0115, -0.0331,
         0.0105, -0.0243,  0.0063, -0.0338,  0.0037, -0.0364, -0.0052, -0.0204,
        -0.0054, -0.0315, -0.0095,  0.0188,  0.0281,  0.0110,  0.0150,  0.0233,
         0.0278, -0.0041,  0.0072,  0.0102,  0.0416, -0.0135, -0.0125,  0.0089,
         0.0245, -0.0384,  0.0027, -0.0257,  0.0003, -0.0337,  0.0234, -0.0157,
        -0.0351,  0.0114, -0.0169, -0.0192, -0.0372, -0.0257, -0.0181,  0.0144,
        -0.0060, -0.0163, -0.0

In [3]:
class MixedLinear(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.tangent_model = SimpleNet(784, 10) 
        self.tangents = {name: p for name, p in self.tangent_model.named_parameters()}
    
    def forward(self, params, x):
        dual_params = {}
        with fwAD.dual_level():
            for name, p in params.items():
                dual_params[name] = fwAD.make_dual(p, self.tangents[name])
            out = functional_call(self.tangent_model, dual_params, x)
            jvp = fwAD.unpack_dual(out).tangent
        return out + jvp

In [4]:
mixed_linear = MixedLinear()
mixed_linear = mixed_linear.to(device)

In [5]:
mixed_linear.tangents

{'lin_1.weight': Parameter containing:
 tensor([[ 0.0200, -0.0089, -0.0070,  ..., -0.0355, -0.0322, -0.0152],
         [ 0.0178,  0.0353,  0.0136,  ..., -0.0044,  0.0349, -0.0337],
         [ 0.0165,  0.0110, -0.0238,  ..., -0.0047,  0.0092,  0.0350],
         ...,
         [-0.0174, -0.0197, -0.0018,  ...,  0.0297,  0.0218,  0.0176],
         [-0.0319,  0.0075, -0.0006,  ..., -0.0277,  0.0324,  0.0008],
         [ 0.0107,  0.0331, -0.0029,  ...,  0.0101,  0.0118,  0.0353]],
        device='cuda:0', requires_grad=True),
 'lin_1.bias': Parameter containing:
 tensor([-3.2776e-02, -1.6389e-03, -3.2103e-02, -4.2866e-03, -2.3340e-02,
          8.1726e-05,  4.9326e-03,  1.6578e-02,  5.0908e-03,  3.1007e-03,
         -2.9043e-02, -1.6292e-02, -2.1973e-02, -3.4411e-03, -1.7099e-02,
         -2.2489e-02,  1.0228e-02, -2.8306e-03, -2.7606e-02, -4.1055e-03,
         -1.3854e-02,  2.7847e-02, -6.8181e-03,  3.5263e-02,  2.8912e-02,
         -1.8661e-02,  1.7160e-02,  8.4342e-03, -1.7303e-02, -2.862

In [6]:
[param for param in mixed_linear.tangent_model.parameters()]

[Parameter containing:
 tensor([[ 0.0200, -0.0089, -0.0070,  ..., -0.0355, -0.0322, -0.0152],
         [ 0.0178,  0.0353,  0.0136,  ..., -0.0044,  0.0349, -0.0337],
         [ 0.0165,  0.0110, -0.0238,  ..., -0.0047,  0.0092,  0.0350],
         ...,
         [-0.0174, -0.0197, -0.0018,  ...,  0.0297,  0.0218,  0.0176],
         [-0.0319,  0.0075, -0.0006,  ..., -0.0277,  0.0324,  0.0008],
         [ 0.0107,  0.0331, -0.0029,  ...,  0.0101,  0.0118,  0.0353]],
        device='cuda:0', requires_grad=True),
 Parameter containing:
 tensor([-3.2776e-02, -1.6389e-03, -3.2103e-02, -4.2866e-03, -2.3340e-02,
          8.1726e-05,  4.9326e-03,  1.6578e-02,  5.0908e-03,  3.1007e-03,
         -2.9043e-02, -1.6292e-02, -2.1973e-02, -3.4411e-03, -1.7099e-02,
         -2.2489e-02,  1.0228e-02, -2.8306e-03, -2.7606e-02, -4.1055e-03,
         -1.3854e-02,  2.7847e-02, -6.8181e-03,  3.5263e-02,  2.8912e-02,
         -1.8661e-02,  1.7160e-02,  8.4342e-03, -1.7303e-02, -2.8621e-02,
         -8.3778e-03,  

In [7]:
optimizer = Adam(mixed_linear.parameters(), lr=0.001)
criterion = nn.MSELoss()

for epoch in range(100):
    mixed_linear.train()
    for iter_idx, (data, label) in enumerate(train_loader):
        data, label = data.to(device), label.to(device)
        data = data.reshape((-1, 784))
        optimizer.zero_grad()
        preds = mixed_linear(params, data)
        loss = criterion(preds, label)
        loss.backward()
        optimizer.step()
        if iter_idx == 0 or (iter_idx + 1) % 50 == 0 or (iter_idx + 1) == len(train_loader): 
            print('epoch: {}, iter_idx: {}, loss: {}'.format(epoch + 1, iter_idx + 1, loss.item()))
    
    mixed_linear.eval()
    with torch.no_grad():
        true_count = 0
        sample_count = 0
        for data, label in train_loader:
            data, label = data.to(device), label.to(device)
            data = data.reshape((-1, 784))
            preds = mixed_linear(params, data)
            predicted_labels = torch.argmax(preds, dim=1)
            ground_truth = torch.argmax(label, dim=1)
            true_count += torch.count_nonzero(predicted_labels == ground_truth).item()
            sample_count += data.shape[0]
        print('epoch: {}, accuracy: {}'.format(epoch + 1, (true_count / sample_count)))

epoch: 1, iter_idx: 1, loss: 235.1256866455078
epoch: 1, iter_idx: 50, loss: 20.913421630859375
epoch: 1, iter_idx: 94, loss: 13.852892875671387
epoch: 1, accuracy: 0.3179166666666667
epoch: 2, iter_idx: 1, loss: 12.282530784606934
epoch: 2, iter_idx: 50, loss: 9.477126121520996
epoch: 2, iter_idx: 94, loss: 7.5135321617126465
epoch: 2, accuracy: 0.30316666666666664
epoch: 3, iter_idx: 1, loss: 7.1286821365356445
epoch: 3, iter_idx: 50, loss: 5.605279445648193
epoch: 3, iter_idx: 94, loss: 4.649112224578857
epoch: 3, accuracy: 0.30725
epoch: 4, iter_idx: 1, loss: 4.4858832359313965
epoch: 4, iter_idx: 50, loss: 3.8514435291290283
epoch: 4, iter_idx: 94, loss: 3.0707461833953857
epoch: 4, accuracy: 0.32016666666666665
epoch: 5, iter_idx: 1, loss: 2.8337345123291016
epoch: 5, iter_idx: 50, loss: 2.5082554817199707
epoch: 5, iter_idx: 94, loss: 2.6235270500183105
epoch: 5, accuracy: 0.31333333333333335
epoch: 6, iter_idx: 1, loss: 2.4107778072357178
epoch: 6, iter_idx: 50, loss: 1.9695236

In [5]:
params

{'lin_1.weight': tensor([[ 0.0208,  0.0336,  0.0317,  ...,  0.0054,  0.0312,  0.0195],
         [ 0.0208, -0.0243, -0.0056,  ..., -0.0243,  0.0037,  0.0152],
         [ 0.0162,  0.0131, -0.0214,  ...,  0.0194,  0.0005,  0.0351],
         ...,
         [-0.0551, -0.0420, -0.0266,  ...,  0.0027, -0.0332, -0.0093],
         [ 0.0195,  0.0358,  0.0273,  ...,  0.0100,  0.0219,  0.0078],
         [ 0.0364,  0.0124,  0.0196,  ...,  0.0343,  0.0170,  0.0218]],
        device='cuda:0'),
 'lin_1.bias': tensor([ 0.0095,  0.0004,  0.0197,  0.0114, -0.0445,  0.0104, -0.0115, -0.0331,
          0.0105, -0.0243,  0.0063, -0.0338,  0.0037, -0.0364, -0.0052, -0.0204,
         -0.0054, -0.0315, -0.0095,  0.0188,  0.0281,  0.0110,  0.0150,  0.0233,
          0.0278, -0.0041,  0.0072,  0.0102,  0.0416, -0.0135, -0.0125,  0.0089,
          0.0245, -0.0384,  0.0027, -0.0257,  0.0003, -0.0337,  0.0234, -0.0157,
         -0.0351,  0.0114, -0.0169, -0.0192, -0.0372, -0.0257, -0.0181,  0.0144,
         -0.0060,

In [6]:
mixed_linear.tangents

{'lin_1.weight': Parameter containing:
 tensor([[-0.0066, -0.0198,  0.0250,  ..., -0.0190,  0.0086,  0.0204],
         [ 0.0261,  0.0283, -0.0039,  ..., -0.0066, -0.0087, -0.0004],
         [ 0.0279,  0.0105,  0.0245,  ..., -0.0267, -0.0085, -0.0197],
         ...,
         [ 0.0298,  0.0151,  0.0262,  ...,  0.0159,  0.0442,  0.0422],
         [ 0.0279,  0.0092, -0.0378,  ..., -0.0040,  0.0196, -0.0344],
         [-0.0120, -0.0096,  0.0039,  ...,  0.0317,  0.0370,  0.0101]],
        device='cuda:0', requires_grad=True),
 'lin_1.bias': Parameter containing:
 tensor([-0.0162, -0.0100,  0.0187,  0.0209, -0.0109,  0.0053, -0.0072,  0.0106,
         -0.0198, -0.0135, -0.0385,  0.0133, -0.0445, -0.0157, -0.0005,  0.0169,
         -0.0282,  0.0248, -0.0379, -0.0052,  0.0254,  0.0207,  0.0285, -0.0328,
         -0.0051, -0.0054,  0.0040, -0.0369, -0.0362, -0.0127, -0.0369, -0.0254,
         -0.0228, -0.0053,  0.0194,  0.0185,  0.0202, -0.0364,  0.0232,  0.0195,
         -0.0066,  0.0183, -0.03

In [8]:
[param for param in mixed_linear.tangent_model.parameters()]

[Parameter containing:
 tensor([[-0.0066, -0.0198,  0.0250,  ..., -0.0190,  0.0086,  0.0204],
         [ 0.0261,  0.0283, -0.0039,  ..., -0.0066, -0.0087, -0.0004],
         [ 0.0279,  0.0105,  0.0245,  ..., -0.0267, -0.0085, -0.0197],
         ...,
         [ 0.0298,  0.0151,  0.0262,  ...,  0.0159,  0.0442,  0.0422],
         [ 0.0279,  0.0092, -0.0378,  ..., -0.0040,  0.0196, -0.0344],
         [-0.0120, -0.0096,  0.0039,  ...,  0.0317,  0.0370,  0.0101]],
        device='cuda:0', requires_grad=True),
 Parameter containing:
 tensor([-0.0162, -0.0100,  0.0187,  0.0209, -0.0109,  0.0053, -0.0072,  0.0106,
         -0.0198, -0.0135, -0.0385,  0.0133, -0.0445, -0.0157, -0.0005,  0.0169,
         -0.0282,  0.0248, -0.0379, -0.0052,  0.0254,  0.0207,  0.0285, -0.0328,
         -0.0051, -0.0054,  0.0040, -0.0369, -0.0362, -0.0127, -0.0369, -0.0254,
         -0.0228, -0.0053,  0.0194,  0.0185,  0.0202, -0.0364,  0.0232,  0.0195,
         -0.0066,  0.0183, -0.0361,  0.0032,  0.0269, -0.0149,