# In [None]:
import jittor as jt
import jittor.nn as nn
jt.flags.use_cuda = 1
from jittor.dataset.mnist import MNIST
from jittornode import odeint
from tqdm import tqdm

class nmODEBlock(nn.Module):
    """Right-hand side of the nmODE: dy/dt = -y + sin^2(y + gamma).

    ``gamma`` is an external input supplied once per forward pass via
    :meth:`fresh` and then held fixed while the ODE solver repeatedly calls
    :meth:`execute`.
    """

    def __init__(self):
        super(nmODEBlock, self).__init__()
        # Kept under a leading underscore on purpose: jittor's
        # Module.parameters()/state_dict() skip underscore-prefixed attributes,
        # so the per-batch input never leaks into the parameter list or into
        # saved checkpoints. Public access goes through the ``gamma`` property.
        self._gamma = None

    @property
    def gamma(self):
        """External input from the most recent :meth:`fresh` call (None before)."""
        return self._gamma

    @gamma.setter
    def gamma(self, value):
        self._gamma = value

    def fresh(self, gamma):
        """Set the external input ``gamma`` used by subsequent :meth:`execute` calls.

        Args:
            gamma: tensor added to the state inside the sine; it must broadcast
                against the state ``y`` passed to :meth:`execute`.
        """
        self._gamma = gamma

    def execute(self, t, y):
        """Return dy/dt = -y + sin(y + gamma)^2 at state ``y``.

        Args:
            t: current time; accepted for the ODE-solver calling convention but
                unused, since the vector field is autonomous.
            y: current state.

        Raises:
            RuntimeError: if :meth:`fresh` has not been called yet.
        """
        if self._gamma is None:
            raise RuntimeError("nmODEBlock.fresh(gamma) must be called before execute()")
        dydt = -y + jt.pow(jt.sin(y + self._gamma), 2)
        return dydt
    

class nmODENet(nn.Module):
    """CNN feature extractor feeding an nmODE layer, followed by a linear classifier.

    Pipeline: conv -> ReLU -> pool -> conv -> ReLU -> global average pool ->
    flatten. The flattened features are the external input ``gamma`` of the
    nmODE, which is integrated from y(0) = 0 over ``self.t``. The final state
    y(T) goes through a linear head.

    Args:
        in_channels: number of input image channels (default 1).
        hidden_dim: output channels of the second conv; this is also the nmODE
            state size and the input width of the linear head (default 64).
        num_classes: number of output logits (default 10).
        num_time_points: number of points in the integration grid over [0, 1]
            (default 10).
    """

    def __init__(self, in_channels=1, hidden_dim=64, num_classes=10, num_time_points=10):
        super(nmODENet, self).__init__()
        self.conv1 = nn.Conv(in_channels, 32, 3, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.Pool(2, 2)
        self.conv2 = nn.Conv(32, hidden_dim, 3, padding=1)

        self.gap = nn.AdaptiveAvgPool2d(1)
        self.flatten = nn.Flatten()

        self.nmODE = nmODEBlock()
        # Fixed integration grid. jittor's Module.parameters() returns every
        # public Var attribute, so without stop_grad() the optimizer would also
        # update these time points and silently change the integration horizon.
        self.t = jt.linspace(0.0, 1.0, num_time_points).stop_grad()

        self.fc = nn.Linear(hidden_dim, num_classes)

    def execute(self, x):
        """Return class logits for the image batch ``x``.

        Args:
            x: image batch whose channel count must equal ``in_channels``;
                presumably (batch, C, H, W). TODO: confirm against callers.
        """
        x = self.relu(self.conv1(x))
        x = self.pool(x)
        x = self.relu(self.conv2(x))
        x = self.gap(x)
        gamma = self.flatten(x)  # external input of the nmODE, one row per sample

        self.nmODE.fresh(gamma)  # feed external input gamma into nmODE

        # Initial value y(0) = 0. zeros_like ties the batch size, state width
        # and dtype to gamma instead of hard-coding the channel count.
        y_0 = jt.zeros_like(gamma)
        y_T = odeint(self.nmODE, y_0, self.t, method='rk4')  # solve the nmODE
        # NOTE(review): assumes odeint stacks one state per time point along
        # dim 0 (torchdiffeq convention), so [-1] is y(T). Confirm for jittornode.
        y_T = y_T[-1]

        return self.fc(y_T)


# In [None]:
# Hyper-parameters for the training run.
BATCH_SIZE = 128
NUM_EPOCHS = 100
LOG_INTERVAL = 20  # refresh the progress-bar metrics every N training iterations

# load MNIST
# NOTE(review): the loader's images appear to carry more than one channel
# (hence the [:, 0:1] slicing below). Confirm against the installed jittor.
train_loader = MNIST(train=True).set_attrs(batch_size=BATCH_SIZE, shuffle=True)
test_loader = MNIST(train=False).set_attrs(batch_size=BATCH_SIZE)

# create nmODE network
model = nmODENet()
optimizer = nn.Adam(model.parameters(), lr=0.001)

# training and test
for epoch in range(NUM_EPOCHS):
    # Evaluation below switches the model to eval mode, so restore train mode
    # at the start of every epoch.
    model.train()
    pbar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch} [Train]")
    for i, (imgs, labels) in pbar:
        imgs = imgs[:, 0:1, :, :]  # only load the first channel
        labels = jt.array(labels)

        preds = model(imgs)
        loss = jt.nn.cross_entropy_loss(preds, labels)
        optimizer.step(loss)  # jittor: backward pass + parameter update in one call

        if i % LOG_INTERVAL == 0:
            # jt.argmax returns (indices, values); only the indices are needed.
            pred_labels, _ = jt.argmax(preds, dim=1)
            acc = (pred_labels == labels).float32().mean()

            pbar.set_postfix({
                "Iter": i,
                "Loss": f"{loss.item():.4f}",
                "Acc": f"{acc.item():.4f}"
            })

    # Evaluation: eval mode plus no_grad, so no autograd graph is recorded for
    # the test forward passes.
    model.eval()
    total, correct = 0, 0
    with jt.no_grad():
        for imgs, labels in test_loader:
            imgs = imgs[:, 0:1, :, :]
            labels = jt.array(labels)

            preds = model(imgs)
            pred_labels, _ = jt.argmax(preds, dim=1)
            correct += (pred_labels == labels).sum().item()
            total += labels.shape[0]

    test_acc = correct / total
    print(f"Epoch {epoch} [Test] Accuracy: {test_acc:.4f}")